Repository: github.com/dmlc/dgl

Function evaluate

examples/distributed/graphsage/node_classification.py, lines 158–194

Evaluate the model on the validation and test set, returning the validation accuracy and the test accuracy.

(model, g, inputs, labels, val_nid, test_nid, batch_size, device)


```python
def evaluate(model, g, inputs, labels, val_nid, test_nid, batch_size, device):
    """
    Evaluate the model on the validation and test set.

    Parameters
    ----------
    model : DistSAGE
        The model to be evaluated.
    g : DistGraph
        The entire graph.
    inputs : DistTensor
        The feature data of all the nodes.
    labels : DistTensor
        The labels of all the nodes.
    val_nid : torch.Tensor
        The node IDs for validation.
    test_nid : torch.Tensor
        The node IDs for test.
    batch_size : int
        Batch size for evaluation.
    device : torch.device
        The target device to evaluate on.

    Returns
    -------
    float
        Validation accuracy.
    float
        Test accuracy.
    """
    model.eval()  # put layers such as dropout into inference behaviour
    with th.no_grad():  # th is torch; no autograd graph is recorded here
        # One inference pass produces predictions for every node in g.
        pred = model.inference(g, inputs, batch_size, device)
    model.train()  # return the model to training mode for the caller
    # Score the same predictions on the validation and test node IDs.
    return compute_acc(pred[val_nid], labels[val_nid]), compute_acc(
        pred[test_nid], labels[test_nid]
    )
```
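In `evaluate`, `model.inference` produces predictions for every node in a single pass, and the validation and test scores are then two index selections into that one `pred` tensor. The sketch below mirrors that slicing-and-scoring step in plain Python so it runs without DGL or PyTorch. It assumes `compute_acc` counts a prediction as correct when the row's highest-scoring class equals the label; that helper is not shown on this page, so `accuracy`, `evaluate_splits`, and the toy data are hypothetical stand-ins.

```python
from typing import List, Sequence, Tuple


def argmax(row: Sequence[float]) -> int:
    """Index of the largest score in one row of class scores."""
    return max(range(len(row)), key=lambda i: row[i])


def accuracy(pred: List[List[float]], labels: List[int]) -> float:
    """Fraction of rows whose argmax matches the label (assumed compute_acc behaviour)."""
    correct = sum(1 for row, y in zip(pred, labels) if argmax(row) == y)
    return correct / len(pred)


def evaluate_splits(
    pred: List[List[float]], labels: List[int], val_nid: List[int], test_nid: List[int]
) -> Tuple[float, float]:
    """Score one full set of predictions on two node-ID subsets (hypothetical helper)."""
    val_acc = accuracy([pred[i] for i in val_nid], [labels[i] for i in val_nid])
    test_acc = accuracy([pred[i] for i in test_nid], [labels[i] for i in test_nid])
    return val_acc, test_acc


# Toy data: 6 nodes, 3 classes; pred holds one row of scores per node.
pred = [
    [0.9, 0.05, 0.05],  # node 0 -> class 0
    [0.1, 0.8, 0.1],    # node 1 -> class 1
    [0.2, 0.2, 0.6],    # node 2 -> class 2
    [0.7, 0.2, 0.1],    # node 3 -> class 0
    [0.1, 0.1, 0.8],    # node 4 -> class 2
    [0.3, 0.6, 0.1],    # node 5 -> class 1
]
labels = [0, 1, 1, 0, 2, 1]
val_nid = [0, 1, 2]
test_nid = [3, 4, 5]

val_acc, test_acc = evaluate_splits(pred, labels, val_nid, test_nid)
print(val_acc, test_acc)  # prints 0.6666666666666666 1.0
```

Running inference once and slicing afterwards avoids a second pass over the graph for the test set; the trade-off is that predictions are computed for all nodes, not only the ones being scored.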

Callers 1

run · Function · 0.70

Calls 3

compute_acc · Function · 0.70
inference · Method · 0.45
train · Method · 0.45
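`evaluate` brackets inference with `model.eval()` and an unconditional `model.train()`, so the model always leaves in training mode, which suits a caller that evaluates in the middle of training. A variation that instead restores whatever mode was active before is sketched below. It uses a hypothetical `ToyModule` that copies the `train(mode)`, `eval()`, and `training` interface of `torch.nn.Module` so the example runs without PyTorch; `eval_mode` is an illustrative helper, not part of DGL or PyTorch, and it does not replace `th.no_grad()`, which controls gradient tracking separately.

```python
from contextlib import contextmanager


class ToyModule:
    """Hypothetical stand-in for torch.nn.Module's train/eval mode flag."""

    def __init__(self) -> None:
        self.training = True

    def train(self, mode: bool = True) -> "ToyModule":
        self.training = mode
        return self

    def eval(self) -> "ToyModule":
        return self.train(False)


@contextmanager
def eval_mode(model):
    """Switch to eval mode for the block, then restore the previous mode."""
    was_training = model.training
    model.eval()
    try:
        yield model
    finally:
        model.train(was_training)


m = ToyModule()
with eval_mode(m):
    assert m.training is False  # layers would use inference behaviour here
print(m.training)  # prints True

m.eval()
with eval_mode(m):
    pass
print(m.training)  # prints False: the earlier eval mode is preserved
```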

Tested by

no test coverage detected