(self, word_input, last_hidden, encoder_hiddens)
| 99 | self.out = nn.Linear(hidden_size * 2, output_size) |
| 100 | |
def forward(self, word_input, last_hidden, encoder_hiddens):
    """Run one decoding step and score the next word.

    The embedded input is reshaped with ``view(1, 1, -1)``, so each call
    processes a single time step, and that reshape also pins the batch
    dimension to 1.

    Args:
        word_input: Input for the embedding lookup; the current input word
            (i.e. the previously emitted word).
        last_hidden: Previous hidden state, handed to ``self.gru`` as-is.
        encoder_hiddens: Encoder states to attend over. Presumably laid out
            as (source_len, batch, hidden), since axes 0 and 1 are swapped
            before the batched matmul -- TODO confirm against the encoder.

    Returns:
        A tuple ``(output, hidden, attn_weights)``: the result of
        ``self.out`` applied to the concatenated GRU output and context
        vector, the new GRU hidden state, and the attention weights
        (returned so callers can visualize them).
    """
    # Embed the current word and present it to the GRU as a length-1 sequence.
    embedded = self.embedding(word_input)
    step_input = embedded.view(1, 1, -1)
    gru_out, hidden = self.gru(step_input, last_hidden)

    # Attention weights come from the current GRU output (sequence axis
    # removed) together with every encoder state.
    attn_weights = self.get_att_weight(gru_out.squeeze(0), encoder_hiddens)

    # Weighted sum of the encoder states; the batched matmul needs the
    # encoder states with axes 0 and 1 swapped. The singleton axis 1 of the
    # product is dropped right away.
    encoder_swapped = encoder_hiddens.transpose(0, 1)
    context = attn_weights.bmm(encoder_swapped).squeeze(1)

    # Next-word prediction from the GRU output joined with the context
    # vector along the feature axis.
    decoder_state = gru_out.squeeze(0)
    features = torch.cat([decoder_state, context], dim=1)
    output = self.out(features)

    return output, hidden, attn_weights
| 123 | |
| 124 | def get_att_weight(self, hidden, encoder_hiddens): |
| 125 | seq_len = len(encoder_hiddens) |
nothing calls this directly
no test coverage detected