github.com/hunkim/PyTorchZeroToAll

Function train

13_3_char_rnn.py:107–124
(line)


def train(line):
    # Inputs are every character but the last; targets are the same
    # characters shifted one position ahead.
    input = str2tensor(line[:-1])
    target = str2tensor(line[1:])

    hidden = decoder.init_hidden()
    decoder_in = input[0]
    loss = 0

    for c in range(len(input)):
        output, hidden = decoder(decoder_in, hidden)
        loss += criterion(output, target[c])
        # Feed the model's most likely character back in as the next input.
        decoder_in = output.max(1)[1]

    decoder.zero_grad()
    loss.backward()
    decoder_optimizer.step()

    # Average loss per character, extracted as a Python float.
    return loss.item() / len(input)

if __name__ == '__main__':
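
The first two statements of `train` build the input and the target from the same string, offset by one character. A minimal sketch of that offset in plain Python follows. `str2codes` is a hypothetical stand-in for `str2tensor`, which is not shown on this page, and is assumed here to map each character to an integer code.

```python
def str2codes(string):
    """Hypothetical stand-in for str2tensor: one integer code per character."""
    return [ord(ch) for ch in string]


line = "hello"
inputs = str2codes(line[:-1])   # codes for "hell"
targets = str2codes(line[1:])   # codes for "ello"

# targets[c] is the character that follows inputs[c] in the original line.
pairs = [(chr(i), chr(t)) for i, t in zip(inputs, targets)]
print(pairs)  # [('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o')]
```

In `train` itself only `input[0]` is fed to the decoder directly. Later steps feed back the decoder's own prediction, but every step is still scored against `target[c]`.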

Callers (1)

13_3_char_rnn.py · File · 0.70

Calls (2)

str2tensor · Function · 0.70
init_hidden · Method · 0.45
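
The callees listed above, and the globals `decoder`, `criterion`, and `decoder_optimizer`, are not shown on this page. The following is a minimal sketch, assuming PyTorch 1.0 or later, of stand-ins that let `train` run end to end. Everything other than the body of `train` is hypothetical: the `TinyCharRNN` class and its sizes, the ASCII vocabulary, the column-shaped `str2tensor`, and the choice of Adam are illustrative assumptions, not the repository's definitions.

```python
import math

import torch
import torch.nn as nn

N_CHARS = 128  # assumption: plain ASCII input


def str2tensor(string):
    # Hypothetical stand-in. Shape (len, 1), so indexing one position yields
    # a shape-(1,) tensor: a batch holding a single character code.
    codes = [ord(ch) for ch in string]
    return torch.tensor(codes, dtype=torch.long).unsqueeze(1)


class TinyCharRNN(nn.Module):
    """Hypothetical decoder: embedding -> GRU -> linear, one character per call."""

    def __init__(self, n_chars=N_CHARS, hidden_size=32, n_layers=1):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.embedding = nn.Embedding(n_chars, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers)
        self.linear = nn.Linear(hidden_size, n_chars)

    def forward(self, char_index, hidden):
        # Treat the single character code as (seq_len=1, batch=1).
        embedded = self.embedding(char_index.view(1, -1))
        output, hidden = self.gru(embedded, hidden)
        # Logits of shape (1, n_chars), so output.max(1) picks a character.
        return self.linear(output.view(1, -1)), hidden

    def init_hidden(self):
        return torch.zeros(self.n_layers, 1, self.hidden_size)


torch.manual_seed(0)
decoder = TinyCharRNN()
criterion = nn.CrossEntropyLoss()
decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=0.01)


def train(line):
    input = str2tensor(line[:-1])
    target = str2tensor(line[1:])

    hidden = decoder.init_hidden()
    decoder_in = input[0]
    loss = 0

    for c in range(len(input)):
        output, hidden = decoder(decoder_in, hidden)
        loss += criterion(output, target[c])
        decoder_in = output.max(1)[1]  # index of the highest-scoring character

    decoder.zero_grad()
    loss.backward()
    decoder_optimizer.step()

    return loss.item() / len(input)


# Five optimisation steps on one line; each entry is an average per-character loss.
losses = [train("hello world") for _ in range(5)]
print(losses)
```

The stand-in `str2tensor` returns a column tensor so that `target[c]` has shape `(1,)`, which matches the `(1, n_chars)` logits that `nn.CrossEntropyLoss` expects for a batch of one.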

Tested by

no test coverage detected