MCPcopy
hub / github.com/Tencent/NeuralNLP-NeuralClassifier / eval

Function eval

eval.py:61–112  ·  view source on GitHub ↗
(conf)

Source from the content-addressed store, hash-verified

59
60
def eval(conf):
    """Evaluate a trained classification model on the configured test files.

    Builds a ``ClassificationDataset`` from ``conf.data.test_json_files``.
    Restores model weights from ``conf.eval.model_dir`` via
    ``load_checkpoint``. Runs inference with autograd disabled and scores the
    collected probabilities with ``cEvaluator``. The micro-averaged
    precision/recall/F-score of the first entry of each result list is
    logged, and ``evaluator.save()`` is called at the end (presumably writing
    under ``conf.eval.dir`` -- confirm against ``cEvaluator``).

    Args:
        conf: Project configuration object. Fields read here:
            ``model_name``, ``data.test_json_files``, ``data.num_worker``,
            ``eval.batch_size``, ``eval.model_dir``, ``eval.dir``,
            ``eval.threshold``, ``eval.top_k``, ``eval.is_flat`` and
            ``task_info.label_type``.
    """
    logger = util.Logger(conf)
    model_name = conf.model_name
    dataset_name = "ClassificationDataset"
    collate_name = "FastTextCollator" if model_name == "FastText" \
        else "ClassificationCollator"

    # Dataset/collator classes are resolved by name from module globals.
    test_dataset = globals()[dataset_name](conf, conf.data.test_json_files)
    collate_fn = globals()[collate_name](conf, len(test_dataset.label_map))
    test_data_loader = DataLoader(
        test_dataset, batch_size=conf.eval.batch_size, shuffle=False,
        num_workers=conf.data.num_worker, collate_fn=collate_fn,
        pin_memory=True)

    # A dataset built from no files. It is used only to construct the model
    # and to supply ``label_map`` to the evaluator (presumably the maps are
    # loaded from ``conf`` rather than from data files -- verify).
    empty_dataset = globals()[dataset_name](conf, [])
    model = get_classification_model(model_name, empty_dataset, conf)
    # The optimizer is never stepped in this function. It is created only
    # because ``load_checkpoint`` takes one as an argument.
    optimizer = get_optimizer(conf, model)
    load_checkpoint(conf.eval.model_dir, conf, model, optimizer)
    model.eval()

    is_multi = conf.task_info.label_type == ClassificationType.MULTI_LABEL
    predict_probs = []
    standard_labels = []
    evaluator = cEvaluator(conf.eval.dir)

    # ``model.eval()`` only switches dropout/batch-norm to inference mode. It
    # does not stop autograd, so disable gradient tracking explicitly to
    # avoid building an unused graph for every test batch.
    with torch.no_grad():
        for batch in test_data_loader:
            if model_name == "HMCN":
                # HMCN returns (global_logits, local_logits, logits).
                # Only the final ``logits`` are scored here.
                _, _, logits = model(batch)
            else:
                logits = model(batch)
            if is_multi:
                # Multi-label: independent per-label probabilities.
                result = torch.sigmoid(logits).cpu().tolist()
            else:
                # Single-label: a distribution over labels along dim 1.
                result = torch.nn.functional.softmax(
                    logits, dim=1).cpu().tolist()
            predict_probs.extend(result)
            standard_labels.extend(batch[ClassificationDataset.DOC_LABEL_LIST])

    (_, precision_list, recall_list, fscore_list, right_list,
     predict_list, standard_list) = evaluator.evaluate(
        predict_probs, standard_label_ids=standard_labels,
        label_map=empty_dataset.label_map,
        threshold=conf.eval.threshold, top_k=conf.eval.top_k,
        is_flat=conf.eval.is_flat, is_multi=is_multi)
    # Index 0 of each result list is reported. It is presumably the overall
    # (non per-level) result -- confirm against ``cEvaluator.evaluate``.
    logger.warn(
        "Performance is precision: %f, "
        "recall: %f, fscore: %f, right: %d, predict: %d, standard: %d." % (
            precision_list[0][cEvaluator.MICRO_AVERAGE],
            recall_list[0][cEvaluator.MICRO_AVERAGE],
            fscore_list[0][cEvaluator.MICRO_AVERAGE],
            right_list[0][cEvaluator.MICRO_AVERAGE],
            predict_list[0][cEvaluator.MICRO_AVERAGE],
            standard_list[0][cEvaluator.MICRO_AVERAGE]))
    evaluator.save()
113
114
115if __name__ == '__main__':

Callers 1

eval.py (File) · 0.85

Calls 7

warn (Method) · 0.95
get_optimizer (Function) · 0.90
eval (Method) · 0.80
evaluate (Method) · 0.80
save (Method) · 0.80
get_classification_model (Function) · 0.70
load_checkpoint (Function) · 0.70

Tested by

no test coverage detected