(model, loss, optimizer, inputs, labels)
| 64 | # so we encapsulate it in a function |
| 65 | # Note: inputs and labels are torch tensors |
def train(model, loss, optimizer, inputs, labels):
    """Run one optimization step on a single batch and return the loss value.

    Args:
        model: module mapping ``inputs`` to logits. It is called as
            ``model(inputs)``, so it is expected to be a ``torch.nn.Module``
            (or another callable).
        loss: loss module, called as ``loss(logits, labels)``.
        optimizer: optimizer over ``model``'s parameters. It provides
            ``zero_grad()`` and ``step()``.
        inputs: input tensor for the batch.
        labels: target tensor for the batch.

    Returns:
        The batch loss as a Python number, obtained via ``Tensor.item()``.
        The loss must therefore reduce to a single element.
    """
    # torch.autograd.Variable is deprecated because tensors carry autograd
    # state themselves. The old Variable(t, requires_grad=False) wrapper
    # returned a detached alias, so detach() keeps the same behavior: no
    # gradient flows back into the data.
    inputs = inputs.detach()
    labels = labels.detach()

    # Gradients accumulate across backward() calls, so clear the previous
    # step's gradients first.
    optimizer.zero_grad()

    # Forward pass. Call the modules rather than .forward() directly so
    # that any registered forward hooks and pre-hooks actually run.
    logits = model(inputs)
    output = loss(logits, labels)

    # backward() only computes d(output)/d(param) and stores it in each
    # parameter's .grad. It does not change any parameter values.
    output.backward()

    # step() actually updates the parameters. It uses the .grad values that
    # backward() just filled in, together with the optimizer's update rule.
    optimizer.step()

    return output.item()
| 86 | |
| 87 | |
| 88 | # similar to train() but not doing the backprop step |
no test coverage detected