hub / github.com/hunkim/PyTorchZeroToAll / forward

Method forward

seq2seq_models.py:101–122
(self, word_input, last_hidden, encoder_hiddens)

Source from the content-addressed store, hash-verified

        self.out = nn.Linear(hidden_size * 2, output_size)

    def forward(self, word_input, last_hidden, encoder_hiddens):
        # Note: we run this one step (S=1) at a time
        # Get the embedding of the current input word (last output word)
        rnn_input = self.embedding(word_input).view(1, 1, -1)  # S=1 x B x I
        rnn_output, hidden = self.gru(rnn_input, last_hidden)

        # Calculate attention from current RNN state and all encoder outputs;
        # apply to encoder outputs
        attn_weights = self.get_att_weight(
            rnn_output.squeeze(0), encoder_hiddens)
        context = attn_weights.bmm(
            encoder_hiddens.transpose(0, 1))  # B x S(=1) x I

        # Final output layer (next word prediction) using the RNN hidden state
        # and context vector
        rnn_output = rnn_output.squeeze(0)  # S(=1) x B x I -> B x I
        context = context.squeeze(1)  # B x S(=1) x I -> B x I
        output = self.out(torch.cat((rnn_output, context), 1))

        # Return final output, hidden state, and attention weights (for
        # visualization)
        return output, hidden, attn_weights

    def get_att_weight(self, hidden, encoder_hiddens):
        seq_len = len(encoder_hiddens)
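The shape bookkeeping in forward can be traced in isolation. The following is a minimal sketch, not the repository's code: the sizes B, S, H, V are hypothetical, and dot_attention is an assumed stand-in for get_att_weight, whose body is cut off in the excerpt above. It only illustrates how the attention weights, context vector, and output projection fit together for one decoder step.

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes, chosen only for illustration.
B, S, H, V = 1, 5, 8, 12  # batch, source length, hidden size, output size

torch.manual_seed(0)
rnn_output = torch.randn(1, B, H)       # one decoder step: 1 x B x H
encoder_hiddens = torch.randn(S, B, H)  # encoder outputs: S x B x H
out = torch.nn.Linear(H * 2, V)         # same layout as self.out in the excerpt


def dot_attention(hidden, encoder_hiddens):
    """Assumed scoring function (dot product); the original may differ.

    hidden: B x H, encoder_hiddens: S x B x H -> weights: B x 1 x S
    """
    enc = encoder_hiddens.transpose(0, 1)            # B x S x H
    scores = enc.bmm(hidden.unsqueeze(2))            # B x S x 1
    return F.softmax(scores, dim=1).transpose(1, 2)  # B x 1 x S


attn_weights = dot_attention(rnn_output.squeeze(0), encoder_hiddens)

# Weighted sum of encoder outputs: (B x 1 x S) @ (B x S x H) -> B x 1 x H
context = attn_weights.bmm(encoder_hiddens.transpose(0, 1))

# Concatenate decoder state and context, then project to output scores: B x V
output = out(torch.cat((rnn_output.squeeze(0), context.squeeze(1)), 1))

print(attn_weights.shape, context.shape, output.shape)
```

The sketch uses B = 1 because the excerpt's .view(1, 1, -1) reshapes the embedding to a single batch element; a larger batch would need the reshape to keep the batch dimension.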

Callers

nothing calls this directly

Calls 1

get_att_weight (Method) · 0.95

Tested by

no test coverage detected