Repository: github.com/dmlc/dgl

Method decode

examples/pytorch/dtgrnn/model.py:210–225
(self, g, teacher_states, hidden_states, batch_cnt, device)


```python
        return hidden_states

    def decode(self, g, teacher_states, hidden_states, batch_cnt, device):
        outputs = []
        # The first decoder input is all zeros, one row per graph node.
        inputs = torch.zeros(g.num_nodes(), self.in_feats).to(device)
        for i in range(self.seq_len):
            # Scheduled sampling: in training mode, with probability
            # compute_thresh(batch_cnt), feed the ground-truth state for step i.
            if (
                np.random.random() < self.compute_thresh(batch_cnt)
                and self.training
            ):
                inputs, hidden_states = self.decoder(
                    g, teacher_states[i], hidden_states
                )
            else:
                # Otherwise feed back the decoder's own previous output.
                inputs, hidden_states = self.decoder(g, inputs, hidden_states)
            outputs.append(inputs)
        # Stack the seq_len per-step outputs along a new leading dimension.
        outputs = torch.stack(outputs)
        return outputs

    def forward(self, g, inputs, teacher_states, batch_cnt, device):
        hidden = self.encode(g, inputs, device)
```
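The branch in decode is a form of scheduled sampling: at each of the seq_len steps in training mode, the decoder receives the ground-truth teacher_states[i] with probability compute_thresh(batch_cnt), and otherwise its own previous output. The framework-free sketch below isolates that control flow under stated assumptions: tensors and the graph are replaced by plain floats, and the decoder callable, the name scheduled_sampling_decode, and the injected rng are hypothetical stand-ins, not DGL or PyTorch APIs.

```python
import random
from typing import Callable, List, Sequence, Tuple

# Hypothetical stand-in for self.decoder: (input, hidden) -> (output, new_hidden).
Decoder = Callable[[float, float], Tuple[float, float]]


def scheduled_sampling_decode(
    decoder: Decoder,
    teacher_states: Sequence[float],
    hidden: float,
    seq_len: int,
    thresh: float,
    training: bool,
    rng: random.Random,
) -> List[float]:
    """Unroll `decoder` for seq_len steps, mixing teacher forcing and free running."""
    outputs: List[float] = []
    inputs = 0.0  # plays the role of the all-zeros first input in decode()
    for i in range(seq_len):
        # Checking `training` first skips the random draw at eval time; the
        # original draws first, which only changes how much RNG state is used.
        if training and rng.random() < thresh:
            inputs, hidden = decoder(teacher_states[i], hidden)  # teacher forcing
        else:
            inputs, hidden = decoder(inputs, hidden)  # feed back own prediction
        outputs.append(inputs)
    return outputs


if __name__ == "__main__":
    step = lambda x, h: (x + 1.0, h)  # toy decoder: output = input + 1
    rng = random.Random(0)
    # thresh=0.0 never teacher-forces, so this prints [1.0, 2.0, 3.0]
    print(scheduled_sampling_decode(step, [10.0, 20.0, 30.0], 0.0, 3, 0.0, True, rng))
```

With thresh=1.0 and training=True every step is teacher-forced, so the same call returns [11.0, 21.0, 31.0]; with training=False the teacher states are ignored regardless of thresh, matching the self.training guard in decode.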

Callers 15

forward (Method) · 0.95
get_killed_pids (Function) · 0.45
get_remote_pids (Function) · 0.45
get_killed_pids (Function) · 0.45
get_remote_pids (Function) · 0.45
read (Method) · 0.45
base.py (File) · 0.45
test_train (Function) · 0.45
test_recipe (Function) · 0.45
check_file (Function) · 0.45
check_file (Function) · 0.45
test_gcn (Function) · 0.45

Calls 5

compute_thresh (Method) · 0.95
decoder (Method) · 0.80
append (Method) · 0.80
to (Method) · 0.45
num_nodes (Method) · 0.45
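decode takes its teacher-forcing probability from compute_thresh(batch_cnt), listed under Calls; its body is not shown on this page. Because teacher forcing should become rarer as training proceeds, one common choice for this kind of model is the inverse sigmoid decay from scheduled sampling, p = k / (k + exp(batch_cnt / k)). The sketch below assumes that schedule; the name inverse_sigmoid_thresh and the decay constant k (passed as decay_steps) are hypothetical.

```python
import math


def inverse_sigmoid_thresh(batch_cnt: int, decay_steps: float) -> float:
    """Teacher-forcing probability k / (k + exp(batch_cnt / k)), with k = decay_steps.

    Assumed schedule (inverse sigmoid decay); not taken from the DGL source.
    """
    # Cap the exponent so math.exp cannot overflow for very large batch counts;
    # the result is then a tiny positive probability.
    ratio = min(batch_cnt / decay_steps, 700.0)
    return decay_steps / (decay_steps + math.exp(ratio))


if __name__ == "__main__":
    k = 2000.0  # hypothetical decay constant
    for cnt in (0, 10_000, 20_000, 40_000):
        print(cnt, round(inverse_sigmoid_thresh(cnt, k), 4))
```

The probability starts just below 1 (k / (k + 1) at batch_cnt = 0) and falls toward 0, so early batches are mostly teacher-forced and later ones mostly feed back the model's own predictions; a larger k keeps teacher forcing active for longer.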

Tested by 14

test_train (Function) · 0.36
test_recipe (Function) · 0.36
test_gcn (Function) · 0.36
test_gcnii (Function) · 0.36
test_appnp (Function) · 0.36
test_c_and_s (Function) · 0.36
test_gat (Function) · 0.36
test_hgnn (Function) · 0.36
test_hypergraphatt (Function) · 0.36
test_sgc (Function) · 0.36
_test_sign (Function) · 0.36
test_twirls (Function) · 0.36