Function test()

github.com/hunkim/PyTorchZeroToAll · 14_2_seq2seq_att.py:23–45

```python
import torch

# `sm` (the seq2seq model module) and `cuda_variable` are defined elsewhere
# in 14_2_seq2seq_att.py; they are not part of this excerpt.


# Simple smoke test: run one input sequence through the encoder and the
# attention decoder and print the tensor shapes at each step.
def test():
    # Toy-sized encoder and attention decoder
    encoder_test = sm.EncoderRNN(10, 10, 2)
    decoder_test = sm.AttnDecoderRNN(10, 10, 2)

    if torch.cuda.is_available():
        encoder_test.cuda()
        decoder_test.cuda()

    # Encode a 3-token input sequence
    encoder_hidden = encoder_test.init_hidden()
    word_input = cuda_variable(torch.LongTensor([1, 2, 3]))
    encoder_outputs, encoder_hidden = encoder_test(word_input, encoder_hidden)
    print(encoder_outputs.size())

    # Decode one target token at a time, starting from the encoder's final hidden state
    word_target = cuda_variable(torch.LongTensor([1, 2, 3]))
    decoder_attns = torch.zeros(1, len(word_target), len(word_input))
    decoder_hidden = encoder_hidden

    for c in range(len(word_target)):
        decoder_output, decoder_hidden, decoder_attn = \
            decoder_test(word_target[c], decoder_hidden, encoder_outputs)
        print(decoder_output.size(), decoder_hidden.size(), decoder_attn.size())
        # Store this step's attention weights over the encoder outputs (on CPU)
        decoder_attns[0, c] = decoder_attn.squeeze(0).cpu().data
```
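The loop in test() fills decoder_attns one row per target token: each row holds the decoder's attention weights over the encoder outputs for that step. The listing does not show how AttnDecoderRNN scores encoder outputs, so the sketch below assumes plain dot-product scoring followed by a softmax; attention_step and the toy vectors are hypothetical, written in plain Python so the arithmetic is visible.

```python
import math
from typing import List

Vector = List[float]


def softmax(scores: Vector) -> Vector:
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_step(decoder_hidden: Vector, encoder_outputs: List[Vector]):
    """One dot-product attention step (hypothetical stand-in, not AttnDecoderRNN's code).

    Returns the attention weights over the encoder outputs and the
    context vector, i.e. the weighted sum of the encoder outputs.
    """
    scores = [sum(h * e for h, e in zip(decoder_hidden, enc)) for enc in encoder_outputs]
    weights = softmax(scores)
    dim = len(encoder_outputs[0])
    context = [sum(w * enc[i] for w, enc in zip(weights, encoder_outputs)) for i in range(dim)]
    return weights, context


# Toy encoder outputs for a 3-token input, hidden size 2
encoder_outputs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
# Toy decoder states for a 3-token target (in the real model these come from the RNN)
decoder_states = [[2.0, 0.0], [0.0, 2.0], [1.0, 1.0]]

# One row of weights per target step: a 3 x 3 matrix, analogous to decoder_attns[0]
attn_matrix = [attention_step(h, encoder_outputs)[0] for h in decoder_states]
for row in attn_matrix:
    print([round(w, 3) for w in row])
```

Because each row comes from a softmax, it is non-negative and sums to 1, which is why the stacked rows can be read as "where the decoder looked" at each output step.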

Callers

nothing calls this directly

Calls (2)

init_hidden · method · 0.95
cuda_variable · function · 0.90

Tested by

no test coverage detected
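Since test() only prints tensor sizes, it cannot fail when a shape is wrong. A minimal sketch of assertions it could make instead, written as a hypothetical helper over nested lists and assuming the attention weights are softmax-normalized (the listing does not confirm this):

```python
from typing import List


def check_attention_matrix(attns: List[List[List[float]]], target_len: int,
                           input_len: int, tol: float = 1e-6) -> bool:
    """Hypothetical check for attention weights shaped (1, target_len, input_len).

    Asserts one batch entry, one row per target token, one weight per input
    token, and that every row is a probability distribution.
    """
    assert len(attns) == 1, "expected a batch dimension of 1"
    rows = attns[0]
    assert len(rows) == target_len, "expected one row per target token"
    for row in rows:
        assert len(row) == input_len, "expected one weight per input token"
        assert all(w >= 0.0 for w in row), "attention weights must be non-negative"
        assert abs(sum(row) - 1.0) <= tol, "attention weights must sum to 1"
    return True
```

Inside test(), after the loop, one could call check_attention_matrix(decoder_attns.tolist(), len(word_target), len(word_input)); Tensor.tolist() converts the PyTorch tensor into the nested lists this helper expects.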