evaluate function
(dataloader, model, prog_args, logger=None)
| 324 | |
| 325 | |
def evaluate(dataloader, model, prog_args, logger=None):
    """Compute classification accuracy of ``model`` over ``dataloader``.

    If both ``logger`` and ``prog_args.save_dir`` are set, the checkpoint
    ``<save_dir>/<dataset>/model.iter-<logger["best_epoch"]>`` is loaded into
    ``model`` before evaluating.

    Args:
        dataloader: Iterable of ``(batch_graph, graph_labels)`` pairs.
            ``batch_graph`` must expose an ``ndata`` mapping and a ``.to()``
            method (presumably a batched DGL graph -- TODO confirm).
        model: Callable taking ``batch_graph`` and returning per-class scores;
            the predicted class is taken with ``argmax`` over dim 1, so the
            output is assumed to be (batch, num_classes).
        prog_args: Namespace read for ``save_dir`` and ``dataset``.
            ``batch_size`` is no longer needed: the sample count is measured
            from the labels actually seen.
        logger: Optional mapping with a ``"best_epoch"`` entry selecting the
            checkpoint to restore.

    Returns:
        float: Fraction of correctly classified samples, or 0.0 if the
        dataloader yields no samples.
    """
    if logger is not None and prog_args.save_dir is not None:
        checkpoint = (
            f"{prog_args.save_dir}/{prog_args.dataset}"
            f"/model.iter-{logger['best_epoch']}"
        )
        # Load onto CPU first so a checkpoint saved on GPU also loads on a
        # CPU-only host; load_state_dict copies the tensors onto whatever
        # device the model's parameters already live on.
        model.load_state_dict(torch.load(checkpoint, map_location="cpu"))
    model.eval()

    use_cuda = torch.cuda.is_available()
    correct_label = 0
    total_label = 0
    with torch.no_grad():
        for batch_graph, graph_labels in dataloader:
            # Snapshot the keys: the loop reassigns entries of ndata.
            for key in list(batch_graph.ndata.keys()):
                batch_graph.ndata[key] = batch_graph.ndata[key].float()
            graph_labels = graph_labels.long()
            if use_cuda:
                batch_graph = batch_graph.to(torch.cuda.current_device())
                graph_labels = graph_labels.cuda()
            ypred = model(batch_graph)
            indi = torch.argmax(ypred, dim=1)
            correct_label += torch.sum(indi == graph_labels).item()
            # Count real samples: the last batch may be smaller than
            # batch_size, so len(dataloader) * batch_size would overcount.
            total_label += len(graph_labels)

    if total_label == 0:
        return 0.0
    return correct_label / total_label
| 356 | |
| 357 | |
| 358 | def main(): |