Repository: github.com/hunkim/PyTorchZeroToAll

Function train()

13_2_rnn_classification.py:136–157


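# Train cycle
def train():
    # Uses module-level globals defined elsewhere in the script:
    # train_loader, classifier, criterion, optimizer, epoch, start.
    total_loss = 0

    for i, (names, countries) in enumerate(train_loader, 1):
        # Encode the batch of names, their lengths, and the target country labels.
        input, seq_lengths, target = make_variables(names, countries)
        output = classifier(input, seq_lengths)

        loss = criterion(output, target)
        # Indexing a 0-dim loss tensor with [0] (loss.data[0]) is an error in
        # recent PyTorch releases; .item() returns the value as a Python float.
        total_loss += loss.item()

        # Clear old gradients, backpropagate, then update the weights.
        classifier.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 10 == 0:
            print('[{}] Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.2f}'.format(
                time_since(start), epoch, i * len(names), len(train_loader.dataset),
                100. * i * len(names) / len(train_loader.dataset),
                total_loss / i))  # running mean of the per-batch losses

    return total_loss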
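The bookkeeping in train() (summing one scalar loss per batch and logging every 10 batches) can be checked without PyTorch. The sketch below is a minimal simulation, assuming a list of precomputed per-batch mean losses stands in for the value each criterion call returns (nn.CrossEntropyLoss averages over the batch by default). run_epoch and its parameters are hypothetical names, and no model is trained.

```python
def run_epoch(batch_losses, batch_size, dataset_size, epoch=1, log_every=10):
    """Mimic train()'s loss accumulation and progress logging.

    batch_losses: hypothetical per-batch mean losses standing in for loss.item().
    Returns (total_loss, log_lines).
    """
    total_loss = 0.0
    lines = []
    for i, loss in enumerate(batch_losses, 1):
        total_loss += loss
        if i % log_every == 0:
            # Clamp so a short final batch cannot push the count past the dataset size.
            seen = min(i * batch_size, dataset_size)
            lines.append('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.2f}'.format(
                epoch, seen, dataset_size, 100. * seen / dataset_size,
                total_loss / i))  # running mean of the per-batch losses
    return total_loss, lines


total, lines = run_epoch([1.0] * 20, batch_size=5, dataset_size=100)
for line in lines:
    print(line)
```

Python's `/` and `*` have equal precedence and associate left to right, so an expression such as `total_loss / i * len(names)` evaluates as `(total_loss / i) * len(names)`: it scales the running mean by the batch size instead of dividing by it. Dividing by `i` alone reports the mean batch loss.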

Callers: 1

Calls: 2

time_since · Function · 0.85
make_variables · Function · 0.70
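Neither callee is shown on this page. The sketch below gives hypothetical, pure-Python versions consistent with how train() uses them: time_since(start) must return a printable elapsed-time string, and make_variables must turn a batch of names into padded sequences, lengths, and targets. make_batch and its country_list parameter are illustrative assumptions that use plain lists instead of tensors; the real helper may differ. Sorting longest-first reflects that torch.nn.utils.rnn.pack_padded_sequence by default expects lengths in decreasing order.

```python
import math
import time


def time_since(since):
    """Hypothetical helper: elapsed wall-clock time since `since`, as 'Xm Ys'."""
    s = time.time() - since
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


def make_batch(names, countries, country_list):
    """Hypothetical stand-in for make_variables, using lists instead of tensors.

    Encodes each name as ASCII codes, zero-pads to the longest name, maps each
    country to its index in country_list, and sorts the batch by length
    (longest first, stable for ties).
    """
    seqs = [[ord(c) for c in name] for name in names]
    lengths = [len(s) for s in seqs]
    max_len = max(lengths)
    padded = [s + [0] * (max_len - len(s)) for s in seqs]
    targets = [country_list.index(c) for c in countries]
    # Batch indices ordered by sequence length, descending.
    order = sorted(range(len(seqs)), key=lambda k: lengths[k], reverse=True)
    return ([padded[k] for k in order],
            [lengths[k] for k in order],
            [targets[k] for k in order])


padded, lengths, targets = make_batch(
    ['Li', 'Kim', 'Abe'], ['Chinese', 'Korean', 'Japanese'],
    ['Chinese', 'Japanese', 'Korean'])
print(lengths, targets)
```

The same permutation is applied to the sequences, lengths, and targets so each label stays aligned with its name after sorting.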

Tested by: no test coverage detected