hub / github.com/hunkim/PyTorchZeroToAll / train

Function train

14_1_seq2seq.py:38–60
(src, target)

Source from the content-addressed store, hash-verified

# We need to use (1) batch and (2) data parallelism
# http://pytorch.org/tutorials/beginner/former_torchies/parallelism_tutorial.html.
def train(src, target):
    src_var = str2tensor(src)
    target_var = str2tensor(target, eos=True)  # Add the EOS token

    encoder_hidden = encoder.init_hidden()
    encoder_outputs, encoder_hidden = encoder(src_var, encoder_hidden)

    hidden = encoder_hidden
    loss = 0

    for c in range(len(target_var)):
        # First, we feed SOS
        # Others, we use teacher forcing
        token = target_var[c - 1] if c else str2tensor(SOS_token)
        output, hidden = decoder(token, hidden)
        loss += criterion(output, target_var[c])

    encoder.zero_grad()
    decoder.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.data[0] / len(target_var)


# Translate the given input
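
The decoding loop in `train` uses teacher forcing: the decoder receives SOS at step 0 and the ground-truth previous target token at every later step, never its own prediction. The following is a minimal, framework-free sketch of that input schedule only. `SOS_ID`, `EOS_ID`, `encode`, and `teacher_forcing_pairs` are hypothetical names introduced for illustration, and the character-to-code-point encoding is an assumption about what `str2tensor` does, since its body is not shown here.

```python
# Hypothetical token ids for illustration; the real SOS/EOS values are
# defined elsewhere in 14_1_seq2seq.py and may differ.
SOS_ID = 0
EOS_ID = 1


def encode(text, eos=False):
    """Map each character to its code point, optionally appending EOS_ID.

    This is an assumed stand-in for str2tensor, returning a plain list.
    """
    ids = [ord(ch) for ch in text]
    if eos:
        ids.append(EOS_ID)
    return ids


def teacher_forcing_pairs(target_ids):
    """Return (decoder_input, expected_output) for every decoding step.

    Step 0 is fed SOS_ID. Step c > 0 is fed the ground-truth token at
    position c - 1, regardless of what the decoder predicted before.
    """
    pairs = []
    for c in range(len(target_ids)):
        decoder_input = target_ids[c - 1] if c else SOS_ID
        pairs.append((decoder_input, target_ids[c]))
    return pairs


target_ids = encode("hi", eos=True)
print(teacher_forcing_pairs(target_ids))
# prints [(0, 104), (104, 105), (105, 1)]
```

Each pair corresponds to one `decoder(token, hidden)` call followed by one `criterion(output, target_var[c])` term, so the summed loss has `len(target_var)` terms and the final division yields a per-token average. When running the original function on PyTorch 0.4 or later, where a scalar loss is a zero-dimensional tensor, `loss.item()` is the usual replacement for `loss.data[0]`.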

Callers 1

14_1_seq2seq.py · File · 0.70

Calls 2

str2tensor · Function · 0.90
init_hidden · Method · 0.45

Tested by

no test coverage detected