()
| 21 | |
# Simple smoke test showing how the encoder/attention-decoder pair is driven.
def test():
    """Run a tiny encoder/attention-decoder pass and print tensor sizes.

    Builds ``sm.EncoderRNN(10, 10, 2)`` and ``sm.AttnDecoderRNN(10, 10, 2)``,
    moves both to the GPU when CUDA is available, and encodes the input
    sequence ``[1, 2, 3]``. It then feeds each target token to the decoder in
    turn, printing the sizes of the decoder's output, hidden state and
    attention at every step.

    Returns:
        torch.Tensor: the collected attention weights, allocated with shape
        ``(1, len(word_target), len(word_input))``. Row ``c`` holds the
        attention the decoder returned at target step ``c``.
    """
    encoder_test = sm.EncoderRNN(10, 10, 2)
    decoder_test = sm.AttnDecoderRNN(10, 10, 2)

    if torch.cuda.is_available():
        encoder_test.cuda()
        decoder_test.cuda()

    encoder_hidden = encoder_test.init_hidden()
    word_input = cuda_variable(torch.LongTensor([1, 2, 3]))
    encoder_outputs, encoder_hidden = encoder_test(word_input, encoder_hidden)
    print(encoder_outputs.size())

    word_target = cuda_variable(torch.LongTensor([1, 2, 3]))
    target_len = len(word_target)
    # Size the attention buffer from the sequences instead of hard-coding 3,
    # so editing either example sequence does not break the row assignment.
    # NOTE(review): assumes the decoder emits one attention weight per source
    # token (i.e. width == len(word_input)) -- confirm against AttnDecoderRNN.
    decoder_attns = torch.zeros(1, target_len, len(word_input))
    # The decoder starts from the hidden state returned by the encoder.
    decoder_hidden = encoder_hidden

    for c in range(target_len):
        decoder_output, decoder_hidden, decoder_attn = decoder_test(
            word_target[c], decoder_hidden, encoder_outputs)
        print(decoder_output.size(), decoder_hidden.size(), decoder_attn.size())
        # Copy this step's attention to the CPU buffer (row c).
        decoder_attns[0, c] = decoder_attn.squeeze(0).cpu().data

    return decoder_attns
| 46 | |
| 47 | |
| 48 | # Train for a given src and target |
nothing calls this directly
no test coverage detected