Skip to content

Instantly share code, notes, and snippets.

@vritant24
Created March 29, 2018 22:49
Show Gist options
  • Select an option

  • Save vritant24/4e9ba27589d63e44a4ad3e4235c3e44b to your computer and use it in GitHub Desktop.

Select an option

Save vritant24/4e9ba27589d63e44a4ad3e4235c3e44b to your computer and use it in GitHub Desktop.
Implementation of the same neural network in the Scala front end (package `lantern`)
// Implementation of the same neural network in Scala front end
package lantern
...
/** Staged training program for a two-layer fully connected classifier.
  *
  * NOTE(review): this is an abridged sketch. The `...` lines elide code that is not shown
  * here, including the DataLoader arguments and the parameter initialisation. The
  * definitions of `zip`, `grad_loss`, `step`, `log_interval`, `epochs` and
  * `print_time_and_loss` are also not visible.
  *
  * @param a staged string argument; not referenced in the visible code.
  */
def snippet(a: Rep[String]): Rep[Unit] = {
  ...
  val train_loader = new DataLoader(...)

  // initialize all parameters
  // presumably pars = List(W1, b1, W2, b2), given how forward() indexes pars(0..3) — confirm
  val pars = List(...)

  /** Returns a closure that computes the loss for one (input, target) pair.
    *
    * The closure ignores its `dummy` argument. It presumably exists only to match the
    * function shape that `grad_loss` expects — confirm against grad_loss's definition.
    */
  def forward(input: TensorR, target: Rep[Int]) = { (dummy: TensorR) =>
    // layer 1: view input with size 784, affine map with pars(0)/pars(1), then ReLU
    val resL1 = (input.view(784).dot(pars(0)) + pars(1)).relu()
    // layer 2: affine map with pars(2)/pars(3)
    val resL2 = resL1.dot(pars(2)) + pars(3)
    // log-softmax followed by NLL loss against the integer class target
    resL2.logSoftmax().nllLoss(target)
  }

  /** Iterates once over `train_loader`. For each batch it runs forward/backward, calls
    * `step()` once per entry in `pars`, and periodically logs.
    *
    * NOTE(review): `epoch` is not referenced in the visible body.
    */
  def train(epoch: Rep[Int]) = {
    for ((batch_idx, (data, target)) <- zip(train_loader.length, train_loader)) {
      val loss = forward(data, target)
      val loss_value = grad_loss(loss) // grad_loss triggers backward prop and returns loss_value
      // NOTE(review): `par` is never used inside the loop. Presumably each parameter should
      // be updated here (e.g. `par.step()`) — confirm against the optimizer API.
      for (par <- pars) step()
      // Scala has no `if cond:` form: the condition needs parentheses and the body a block.
      if (((batch_idx + 1) % log_interval) == 0) {
        print_time_and_loss()
      }
    }
  }

  for (idx <- epochs) train(idx)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment