MCPcopy Index your code
hub / github.com/dmlc/dgl / evaluate

Function evaluate

examples/pytorch/diffpool/train.py:326–355  ·  view source on GitHub ↗

evaluate function

(dataloader, model, prog_args, logger=None)

Source from the content-addressed store, hash-verified

324
325
def evaluate(dataloader, model, prog_args, logger=None):
    """Compute the classification accuracy of ``model`` over ``dataloader``.

    When ``logger`` is given and ``prog_args.save_dir`` is set, the checkpoint
    ``<save_dir>/<dataset>/model.iter-<logger["best_epoch"]>`` is loaded into
    ``model`` before evaluating.

    Args:
        dataloader: iterable of ``(batch_graph, graph_labels)`` pairs.
            ``batch_graph`` must expose an ``ndata`` mapping and a ``to``
            method (presumably a batched DGL graph -- confirm against caller);
            ``graph_labels`` is a tensor holding one label per graph.
        model: module mapping ``batch_graph`` to class scores; the predicted
            class is the argmax over dim 1.
        prog_args: namespace; ``save_dir`` is read when ``logger`` is not
            None, and ``dataset`` when the checkpoint is loaded.
        logger: optional dict whose ``"best_epoch"`` entry selects the
            checkpoint to load.

    Returns:
        The fraction of graphs whose predicted class equals the label, or
        0.0 if ``dataloader`` yields no samples.
    """
    if logger is not None and prog_args.save_dir is not None:
        checkpoint = (
            f"{prog_args.save_dir}/{prog_args.dataset}"
            f"/model.iter-{logger['best_epoch']}"
        )
        # map_location="cpu" lets a checkpoint saved from GPU be loaded on a
        # CPU-only host; load_state_dict then copies the tensors onto the
        # devices of the model's own parameters.
        model.load_state_dict(torch.load(checkpoint, map_location="cpu"))
    model.eval()
    use_cuda = torch.cuda.is_available()
    num_correct = 0
    num_samples = 0
    with torch.no_grad():
        for batch_graph, graph_labels in dataloader:
            # Snapshot the items so the mapping is not mutated while a live
            # view over it is being iterated.
            for key, value in list(batch_graph.ndata.items()):
                batch_graph.ndata[key] = value.float()
            graph_labels = graph_labels.long()
            if use_cuda:
                batch_graph = batch_graph.to(torch.cuda.current_device())
                graph_labels = graph_labels.cuda()
            ypred = model(batch_graph)
            predictions = torch.argmax(ypred, dim=1)
            num_correct += (predictions == graph_labels).sum().item()
            # Count the samples actually seen: the final batch may be smaller
            # than batch_size, so len(dataloader) * batch_size would overcount.
            num_samples += graph_labels.size(0)
    if num_samples == 0:
        return 0.0
    return num_correct / num_samples
356
357
358def main():

Callers (2)

graph_classify_task (Function) · 0.70
train (Function) · 0.70

Calls (7)

load_state_dict (Method) · 0.80
cuda (Method) · 0.80
load (Method) · 0.45
items (Method) · 0.45
float (Method) · 0.45
long (Method) · 0.45
to (Method) · 0.45

Tested by

No test coverage detected.