(model, loss, optimizer, inputs, labels)
| 64 | # so we encapsulate it in a function |
| 65 | # Note: inputs and labels are torch tensors |
def train(model, loss, optimizer, inputs, labels):
    """Run one optimization step on a single batch and return its loss.

    Args:
        model: presumably a ``torch.nn.Module``; called as ``model(inputs)``
            to produce logits.
        loss: loss module (or callable) applied as ``loss(logits, labels)``;
            must return a single-element tensor, since ``.item()`` is called
            on it.
        optimizer: optimizer over the model's parameters; its ``zero_grad()``
            and ``step()`` are called once each.
        inputs: input tensor for the batch.
        labels: target tensor for the batch.

    Returns:
        The batch loss as a Python number (via ``Tensor.item()``).
    """
    # Since PyTorch 0.4 Variable is merged into Tensor, so the tensors are used
    # as-is; wrapping them in Variable(...) would be a deprecated no-op.

    # Reset gradients: PyTorch accumulates into .grad instead of overwriting it.
    # https://discuss.pytorch.org/t/why-do-we-need-to-set-the-gradients-manually-to-zero-in-pytorch/4903/7
    optimizer.zero_grad()

    # Forward. Call the modules instead of .forward() so that registered
    # forward hooks / pre-hooks are honoured (calling .forward() skips them).
    logits = model(inputs)
    output = loss(logits, labels)

    # Backward: compute gradients of the loss w.r.t. the parameters.
    output.backward()

    # Update parameters. backward() only computes gradients; step() applies them.
    # https://discuss.pytorch.org/t/what-does-the-backward-function-do/9944
    optimizer.step()

    return output.item()
| 88 | |
| 89 | |
| 90 | # define the prediction procedure |
no test coverage detected