| 208 | return hidden_states |
| 209 | |
def decode(self, g, teacher_states, hidden_states, batch_cnt, device):
    """Unroll the decoder for ``self.seq_len`` steps and stack its outputs.

    Decoding starts from an all-zeros input of shape
    ``(g.num_nodes(), self.in_feats)``. At every step the decoder receives
    either its own previous output or, only while ``self.training`` is true,
    the teacher value ``teacher_states[i]``. In training mode the choice is
    made independently at each step: a uniform draw from the global NumPy
    RNG below ``self.compute_thresh(batch_cnt)`` selects the teacher value.

    Args:
        g: Graph forwarded to ``self.decoder``; only ``g.num_nodes()`` is
            used directly here.
        teacher_states: Per-step values indexable by ``0..seq_len-1``; only
            read in training mode.
        hidden_states: Initial decoder hidden state, threaded through every
            ``self.decoder`` call.
        batch_cnt: Counter forwarded to ``self.compute_thresh`` to obtain
            the teacher-forcing threshold.
        device: Device on which the initial zero input is allocated.

    Returns:
        ``torch.Tensor`` stacking the ``seq_len`` per-step decoder outputs
        along a new leading dimension.
    """
    outputs = []
    # Allocate directly on the target device rather than building the
    # tensor on the default device and copying it with ``.to(device)``.
    inputs = torch.zeros(g.num_nodes(), self.in_feats, device=device)

    # Teacher forcing only applies in training mode. Checking
    # ``self.training`` first means evaluation never consumes the global
    # NumPy RNG (which would silently shift later random streams) and never
    # evaluates the threshold. The threshold is loop-invariant, so it is
    # computed once.
    # NOTE(review): hoisting assumes compute_thresh is a pure function of
    # batch_cnt — confirm.
    teacher_forcing = self.training
    thresh = self.compute_thresh(batch_cnt) if teacher_forcing else 0.0

    for i in range(self.seq_len):
        use_teacher = teacher_forcing and np.random.random() < thresh
        # NOTE(review): teacher_states[i] is fed at step i, the same step
        # whose output becomes outputs[i]. If callers pass the unshifted
        # target sequence, the decoder sees the value it is asked to
        # predict — confirm the caller shifts it by one step.
        step_input = teacher_states[i] if use_teacher else inputs
        inputs, hidden_states = self.decoder(g, step_input, hidden_states)
        outputs.append(inputs)
    return torch.stack(outputs)
| 226 | |
| 227 | def forward(self, g, inputs, teacher_states, batch_cnt, device): |
| 228 | hidden = self.encode(g, inputs, device) |