| 90 | norm.reset_parameters() |
| 91 | |
def forward(self, x, sample_ids_adjs, full_ids_adjs) -> Tensor:
    """Run the sampled VR-GCN forward pass and refresh the layer histories.

    Args:
        x: Node feature matrix. The rows selected by ``sample_ids[0]`` form
            the layer-0 input and are moved to ``self._device``.
        sample_ids_adjs: Pair ``(sample_ids, sample_adjs)``. ``sample_ids``
            is indexed ``0 .. num_layers``: entry ``i`` holds the input node
            ids of layer ``i`` and entry ``i + 1`` its target node ids.
            ``sample_adjs[i]`` is the sampled sparse adjacency for layer ``i``.
        full_ids_adjs: Pair ``(full_ids, full_adjs)``. ``full_adjs[i]`` is
            applied to the historical embeddings pulled for ``full_ids[i]``.

    Returns:
        ``log_softmax`` over the last dimension of the final layer output,
        with one row per id in ``sample_ids[self.num_layers]``.

    Side effects:
        Pushes the detached hidden-layer outputs into ``self.histories[1:]``
        at the corresponding ``sample_ids`` positions.
    """
    sample_ids, sample_adjs = sample_ids_adjs
    full_ids, full_adjs = full_ids_adjs

    # VR-GCN estimator: propagate only the difference between the fresh
    # activations and their stored histories over the sampled adjacency, then
    # add back the historical embeddings propagated over the full adjacency.
    x = x[sample_ids[0]].to(self._device)
    hidden_outputs = []  # post-ReLU, pre-dropout output of each hidden layer
    last_layer = self.num_layers - 1
    for i in range(self.num_layers):
        cur_id, target_id = sample_ids[i], sample_ids[i + 1]
        full_id = full_ids[i]
        full_adj = full_adjs[i].to(x.device)
        sample_adj = sample_adjs[i].to(x.device)
        # Rows beyond this count are dropped. Presumably the sampler places
        # target nodes first in the row ordering -- TODO confirm.
        num_targets = target_id.shape[0]

        history = self.histories[i]
        # Detached so no gradient flows into the stored history values.
        x = x - history.pull(cur_id).detach()
        h = history.pull(full_id)

        sampled = spmm(sample_adj, x)[:num_targets]
        historical = spmm(full_adj, h)[:num_targets].detach()
        x = self.lins[i](sampled + historical)

        # NOTE(review): nesting reconstructed from a flattened listing;
        # confirm that recording and dropout apply to hidden layers only.
        if i != last_layer:
            x = self.norms[i](x)
            x = x.relu_()
            hidden_outputs.append(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

    # History update: the output of layer i - 1 holds embeddings for the
    # nodes in sample_ids[i], so it refreshes histories[i] at those ids.
    for i in range(1, self.num_layers):
        self.histories[i].push(hidden_outputs[i - 1].detach(), sample_ids[i])
    return x.log_softmax(dim=-1)
| 121 | |
| 122 | def initialize_history(self, x, test_loader): |
| 123 | _, xs = self.inference_batch(x, test_loader) |