Repository: github.com/lazyprogrammer/machine_learning_examples

Function train

ann_class2/pytorch_example2.py, lines 66–85
(model, loss, optimizer, inputs, labels)
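A minimal usage sketch, assuming a toy linear classifier trained with torch.nn.CrossEntropyLoss and full-batch SGD. The model, data, and hyperparameters are illustrative assumptions, not taken from the repository. The train() body is repeated (without the deprecated Variable wrappers) so the snippet runs on its own.

```python
import torch
import torch.nn as nn


def train(model, loss, optimizer, inputs, labels):
    # Same steps as the listed function, minus the Variable wrappers
    optimizer.zero_grad()
    logits = model(inputs)
    output = loss(logits, labels)
    output.backward()
    optimizer.step()
    return output.item()


torch.manual_seed(0)

# Hypothetical toy data: 2 features, class label given by the sign of feature 0
inputs = torch.randn(200, 2)
labels = (inputs[:, 0] > 0).long()  # CrossEntropyLoss expects int64 targets

model = nn.Linear(2, 2)             # produces logits for 2 classes
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)

# One call to train() per iteration; each returns the cost as a Python float
costs = [train(model, loss, optimizer, inputs, labels) for _ in range(100)]
print(f"first cost {costs[0]:.3f}, last cost {costs[-1]:.3f}")
```

Passing the whole dataset on every call, as here, is plain gradient descent; calling train() on mini-batches instead gives stochastic gradient descent.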


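# One training step (reset gradients, forward, backward, update),
# encapsulated in a function.
# Note: inputs and labels are torch tensors
def train(model, loss, optimizer, inputs, labels):
    # Since PyTorch 0.4, tensors track autograd state themselves, so the
    # original Variable(x, requires_grad=False) wrappers are unnecessary:
    # data tensors already default to requires_grad=False.

    # Reset gradients: backward() accumulates into .grad rather than overwriting
    optimizer.zero_grad()

    # Forward (calling the module directly runs forward() plus any hooks)
    logits = model(inputs)
    output = loss(logits, labels)

    # Backward: compute d(output)/d(param) and store it in each param's .grad
    output.backward()

    # Update parameters using the gradients stored in .grad
    optimizer.step()

    # backward() only computes gradients; step() is what changes the weights.

    return output.item()


# similar to train() but not doing the backprop step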
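The listing's comment asks how backward() differs from step(). In PyTorch, backward() only computes gradients and stores them in each parameter's .grad; step() is what changes the parameters, using those stored gradients. The sketch below checks this on a single hypothetical scalar parameter with plain SGD (lr=0.1 and loss w² are illustrative choices), and also shows why zero_grad() comes first: repeated backward() calls add to .grad.

```python
import torch

# One hypothetical scalar parameter w, starting at 3.0
w = torch.nn.Parameter(torch.tensor(3.0))
optimizer = torch.optim.SGD([w], lr=0.1)

# loss = w^2, so d(loss)/dw = 2w = 6.0
(w ** 2).backward()
grad_after_backward = w.grad.item()   # gradient is now stored in w.grad
w_after_backward = w.item()           # ...but w itself has not changed

# A second backward() without zero_grad() adds to the stored gradient
(w ** 2).backward()
grad_accumulated = w.grad.item()      # 6.0 + 6.0

# Reset, recompute, then let the optimizer apply w <- w - lr * grad
optimizer.zero_grad()
(w ** 2).backward()
optimizer.step()
w_after_step = w.item()               # 3.0 - 0.1 * 6.0

print(grad_after_backward, w_after_backward, grad_accumulated, w_after_step)
```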

Callers: 1

Calls: 2
forward (method) · 0.45
step (method) · 0.45

Tested by: no test coverage detected
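The comment at the end of the listing refers to a companion function that works like train() but skips the backprop step; its source is not shown on this page. A hypothetical sketch of such a function follows. The name evaluate and its return value are assumptions. It runs only the forward pass under torch.no_grad(), so no autograd graph is built and no parameters change.

```python
import torch
import torch.nn as nn


def evaluate(model, loss, inputs, labels):
    # Hypothetical counterpart to train(): forward pass only, no backward/step
    with torch.no_grad():  # disable graph construction for this block
        logits = model(inputs)
        output = loss(logits, labels)
    return output.item()


torch.manual_seed(0)
model = nn.Linear(2, 2)
inputs = torch.randn(10, 2)
labels = torch.randint(0, 2, (10,))

# Snapshot the parameters, evaluate, then confirm nothing was modified
before = [p.detach().clone() for p in model.parameters()]
cost = evaluate(model, nn.CrossEntropyLoss(), inputs, labels)
unchanged = all(torch.equal(a, b) for a, b in zip(before, model.parameters()))
print(cost, unchanged)
```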