hub / github.com/hunkim/PyTorchZeroToAll / train

Function train

14_2_seq2seq_att.py:54–76
(src, target)

Source from the content-addressed store, hash-verified

# We need to use (1) batch and (2) data parallelism
# http://pytorch.org/tutorials/beginner/former_torchies/parallelism_tutorial.html.
def train(src, target):
    loss = 0

    src_var = str2tensor(src)
    target_var = str2tensor(target, eos=True)  # Add the EOS token

    encoder_hidden = encoder.init_hidden()
    encoder_outputs, encoder_hidden = encoder(src_var, encoder_hidden)

    hidden = encoder_hidden

    for c in range(len(target_var)):
        # Feed SOS at the first step; afterwards use teacher forcing
        # (the previous ground-truth token becomes the next input).
        token = target_var[c - 1] if c else str2tensor(SOS_token)
        output, hidden, attention = decoder(token, hidden, encoder_outputs)
        loss += criterion(output, target_var[c])

    encoder.zero_grad()
    decoder.zero_grad()
    loss.backward()
    optimizer.step()

    # Originally loss.data[0], which raises an error on 0-dim tensors in
    # current PyTorch; .item() returns the same Python scalar.
    return loss.item() / len(target_var)


# Translate the given input
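
The decoding loop in train can be hard to picture from the listing alone. As a rough sketch in plain Python (no PyTorch), the snippet below shows which token the decoder would be fed at each step under teacher forcing, and how the summed loss is averaged over the target length. The names SOS, EOS, teacher_forcing_pairs, and mean_step_loss are hypothetical and do not appear in the repository. The sketch also assumes that str2tensor(target, eos=True) simply appends one EOS token.

```python
# Hypothetical stand-ins; the real token values are defined elsewhere in the repo.
SOS = "<sos>"
EOS = "<eos>"


def teacher_forcing_pairs(target):
    """Return (decoder_input, expected_output) for every decoding step."""
    # Assumption: mirrors str2tensor(target, eos=True) by appending EOS.
    steps = list(target) + [EOS]
    pairs = []
    for c in range(len(steps)):
        # Step 0 is fed SOS; later steps are fed the previous ground-truth token.
        token = steps[c - 1] if c else SOS
        pairs.append((token, steps[c]))
    return pairs


def mean_step_loss(step_losses):
    """Average per-step losses, like loss / len(target_var) in train."""
    return sum(step_losses) / len(step_losses)


print(teacher_forcing_pairs("hi"))
# [('<sos>', 'h'), ('h', 'i'), ('i', '<eos>')]
```

Each pair reads as "given this input, the decoder should predict that output". The decoder's own predictions are never fed back in during training, which is what distinguishes teacher forcing from free-running decoding.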

Callers 1

Calls 2

str2tensor · Function · 0.90
init_hidden · Method · 0.45

Tested by

no test coverage detected