hub / github.com/THUDM/CogDL / forward

Method forward

examples/VRGCN/VRGCN.py:92–120
(self, x, sample_ids_adjs, full_ids_adjs)

Source from the content-addressed store, hash-verified

            norm.reset_parameters()

    def forward(self, x, sample_ids_adjs, full_ids_adjs) -> Tensor:
        sample_ids, sample_adjs = sample_ids_adjs
        full_ids, full_adjs = full_ids_adjs

        """VR-GCN"""
        x = x[sample_ids[0]].to(self._device)
        x_list = []
        for i in range(self.num_layers):
            sample_adj, cur_id, target_id = sample_adjs[i], sample_ids[i], sample_ids[i + 1]
            full_id, full_adj = full_ids[i], full_adjs[i]
            full_adj = full_adj.to(x.device)
            sample_adj = sample_adj.to(x.device)

            x = x - self.histories[i].pull(cur_id).detach()
            h = self.histories[i].pull(full_id)

            x = spmm(sample_adj, x)[: target_id.shape[0]] + spmm(full_adj, h)[: target_id.shape[0]].detach()
            x = self.lins[i](x)

            if i != self.num_layers - 1:
                x = self.norms[i](x)
                x = x.relu_()
                x_list.append(x)
                x = F.dropout(x, p=self.dropout, training=self.training)

        """history embedding update"""
        for i in range(1, self.num_layers):
            self.histories[i].push(x_list[i - 1].detach(), sample_ids[i])
        return x.log_softmax(dim=-1)

    def initialize_history(self, x, test_loader):
        _, xs = self.inference_batch(x, test_loader)
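The aggregation inside the layer loop combines two terms: a sampled adjacency applied to the difference between current and historical activations, and a full adjacency applied to the historical activations alone. The following is a minimal pure-Python sketch of that idea, not the CogDL implementation. It assumes, for simplicity, that both adjacencies cover the same node set (the listing above uses separate `cur_id` and `full_id` index sets), and the names `matmul` and `vr_aggregate` are hypothetical helpers introduced only for illustration.

```python
# Illustrative sketch only: dense lists stand in for sparse tensors,
# and both adjacencies are assumed to index the same nodes.

def matmul(a, b):
    """Dense matrix product for small lists of lists."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def vr_aggregate(sample_adj, full_adj, x, hist):
    """Return sample_adj @ (x - hist) + full_adj @ hist."""
    delta = [[xv - hv for xv, hv in zip(xr, hr)] for xr, hr in zip(x, hist)]
    sampled = matmul(sample_adj, delta)      # cheap term on sampled neighbors
    historical = matmul(full_adj, hist)      # full-neighborhood term on stale values
    return [[s + h for s, h in zip(sr, hr)] for sr, hr in zip(sampled, historical)]


# Row-normalized full adjacency over 3 nodes (hypothetical data).
full_adj = [[0.5, 0.5, 0.0], [0.0, 0.5, 0.5], [0.5, 0.0, 0.5]]
# One sampled neighbor per row, rescaled so each row still sums to 1.
sample_adj = [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]
x = [[1.0], [2.0], [3.0]]
zeros = [[0.0], [0.0], [0.0]]

exact = vr_aggregate(sample_adj, full_adj, x, x)       # history is up to date
naive = vr_aggregate(sample_adj, full_adj, x, zeros)   # history is empty
```

In this sketch, an up-to-date history makes the result equal to the full aggregation `full_adj @ x` whatever neighbors were sampled, while an all-zero history reduces it to the plain sampled estimate `sample_adj @ x`. The closer the history is to the current activations, the less the sampling choice matters.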

Callers

nothing calls this directly

Calls 4

spmm (Function) · 0.90
pull (Method) · 0.80
push (Method) · 0.80
to (Method) · 0.45
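The `pull` and `push` targets listed here belong to the per-layer history objects, whose implementation is not shown on this page. As a rough sketch of the contract the call sites imply (`pull(ids)` returns stored rows, `push(values, ids)` overwrites them), a history buffer might look like the following. The class name `History`, its constructor arguments, and the zero initialization are assumptions made for illustration, and plain lists stand in for tensors.

```python
class History:
    """Hypothetical per-layer store of the last embedding seen for each node."""

    def __init__(self, num_nodes, dim):
        # Assumed initialization: every node starts with an all-zero embedding.
        self.emb = [[0.0] * dim for _ in range(num_nodes)]

    def pull(self, ids):
        """Return copies of the stored rows for the given node ids."""
        return [list(self.emb[i]) for i in ids]

    def push(self, values, ids):
        """Overwrite the stored rows for `ids` with `values` (values first, ids second)."""
        for i, row in zip(ids, values):
            self.emb[i] = list(row)


history = History(num_nodes=4, dim=2)
history.push([[1.0, 2.0]], [2])          # node 2 gets a fresh embedding
pulled = history.pull([2, 0])            # node 0 was never pushed
```

The argument order of `push` mirrors the call in `forward`, which passes the detached activations first and the node ids second.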

Tested by

no test coverage detected