(conf)
| 59 | |
| 60 | |
def eval(conf):
    """Evaluate a checkpointed classification model on the test files.

    Builds the test dataset from ``conf.data.test_json_files``, restores the
    model from ``conf.eval.model_dir``, runs every test batch through the
    model, scores the collected probabilities with ``cEvaluator`` and logs
    the micro-averaged precision / recall / f-score together with the
    right / predicted / standard counts. The evaluator is saved at the end.

    Note: the function name shadows the builtin ``eval``; it is kept
    unchanged because callers refer to it by this name.

    Args:
        conf: Configuration object. The attributes read here are
            ``model_name``, ``data.test_json_files``, ``data.num_worker``,
            ``eval.batch_size``, ``eval.model_dir``, ``eval.dir``,
            ``eval.threshold``, ``eval.top_k``, ``eval.is_flat`` and
            ``task_info.label_type``.
    """
    logger = util.Logger(conf)
    model_name = conf.model_name
    dataset_name = "ClassificationDataset"
    collate_name = ("FastTextCollator" if model_name == "FastText"
                    else "ClassificationCollator")

    # Dataset and collator classes are resolved by name from this module's
    # globals, so they must be imported at file level.
    test_dataset = globals()[dataset_name](conf, conf.data.test_json_files)
    collate_fn = globals()[collate_name](conf, len(test_dataset.label_map))
    test_data_loader = DataLoader(
        test_dataset, batch_size=conf.eval.batch_size, shuffle=False,
        num_workers=conf.data.num_worker, collate_fn=collate_fn,
        pin_memory=True)

    # A dataset built from no input files is what the model factory and the
    # evaluator's label_map are fed from.
    # NOTE(review): presumably it only carries the vocab/label maps -- confirm.
    empty_dataset = globals()[dataset_name](conf, [])
    model = get_classification_model(model_name, empty_dataset, conf)
    # load_checkpoint takes an optimizer argument, so one is created even
    # though no training step happens in this function.
    optimizer = get_optimizer(conf, model)
    load_checkpoint(conf.eval.model_dir, conf, model, optimizer)
    model.eval()

    is_multi = conf.task_info.label_type == ClassificationType.MULTI_LABEL

    predict_probs = []
    standard_labels = []
    evaluator = cEvaluator(conf.eval.dir)

    # Inference only: disable autograd so no computation graph is recorded
    # for the forward passes.
    with torch.no_grad():
        for batch in test_data_loader:
            if model_name == "HMCN":
                # HMCN returns three outputs; only the last one is scored.
                _, _, logits = model(batch)
            else:
                logits = model(batch)

            if is_multi:
                # Multi-label: independent per-label probabilities.
                probs = torch.sigmoid(logits)
            else:
                # Single-label: distribution over labels along dim 1.
                probs = torch.nn.functional.softmax(logits, dim=1)

            predict_probs.extend(probs.cpu().tolist())
            standard_labels.extend(
                batch[ClassificationDataset.DOC_LABEL_LIST])

    (_, precision_list, recall_list, fscore_list, right_list,
     predict_list, standard_list) = evaluator.evaluate(
        predict_probs, standard_label_ids=standard_labels,
        label_map=empty_dataset.label_map,
        threshold=conf.eval.threshold, top_k=conf.eval.top_k,
        is_flat=conf.eval.is_flat, is_multi=is_multi)

    # Only entry 0 of each result list is reported.
    # NOTE(review): presumably entry 0 is the overall level -- confirm
    # against cEvaluator.evaluate.
    micro = cEvaluator.MICRO_AVERAGE
    logger.warn(
        "Performance is precision: %f, "
        "recall: %f, fscore: %f, right: %d, predict: %d, standard: %d." % (
            precision_list[0][micro],
            recall_list[0][micro],
            fscore_list[0][micro],
            right_list[0][micro],
            predict_list[0][micro],
            standard_list[0][micro]))
    evaluator.save()
| 113 | |
| 114 | |
| 115 | if __name__ == '__main__': |
no test coverage detected