Created March 29, 2018 22:49
-
-
Save vritant24/4e9ba27589d63e44a4ad3e4235c3e44b to your computer and use it in GitHub Desktop.
Implementation of the same neural network in the Lantern Scala front end
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Implementation of the same neural network in the Lantern Scala front end.
// NOTE(review): "the same" presumably refers to a companion version of this model
// (the train_loader / batch_idx / log_interval structure mirrors a PyTorch-style
// training script) — confirm. This is an illustrative sketch: `...` marks code
// elided from the listing.
package lantern

...

/** Staged program body: defines a two-layer network and trains it over `epochs`.
  *
  * @param a staged string argument; not referenced in the visible code.
  */
def snippet(a: Rep[String]): Rep[Unit] = {
  ...
  val train_loader = new DataLoader(...)

  // initialize all parameters
  // pars(0)/pars(1) are the dot-product operand and additive term of layer 1,
  // pars(2)/pars(3) the same for layer 2 (i.e. weight/bias pairs).
  val pars = List(...)

  /** Builds the loss for one (input, target) example.
    *
    * Returns a function of a dummy TensorR (presumably the form `grad_loss` expects —
    * confirm) that computes: view(784) -> dot + bias -> relu -> dot + bias ->
    * logSoftmax -> nllLoss(target).
    */
  def forward(input: TensorR, target: Rep[Int]) = { (dummy: TensorR) =>
    // flatten the input to 784 elements before the first layer
    val resL1 = (input.view(784).dot(pars(0)) + pars(1)).relu()
    val resL2 = resL1.dot(pars(2)) + pars(3)
    resL2.logSoftmax().nllLoss(target)
  }

  /** Runs one pass over train_loader: computes the loss and gradients for each batch,
    * calls step() once per parameter, and periodically logs progress.
    *
    * @param epoch current epoch index; not referenced in the visible code.
    */
  def train(epoch: Rep[Int]) = {
    // NOTE(review): zip(length, loader) appears to pair a batch index with each batch
    // (enumerate-style) — confirm the helper's semantics.
    for ((batch_idx, (data, target)) <- zip(train_loader.length, train_loader)) {
      val loss = forward(data, target)
      val loss_value = grad_loss(loss) // grad_loss triggers backward prop and returns loss_value
      // NOTE(review): `par` is never passed to step(); presumably each parameter should
      // be updated individually (e.g. par.step()) — confirm against the optimizer API.
      for (par <- pars) step()
      // log once every `log_interval` batches (batch_idx is 0-based, hence the + 1)
      if ((batch_idx + 1) % log_interval == 0) {
        print_time_and_loss()
      }
    }
  }

  for (idx <- epochs) train(idx)
}
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment