MCPcopy
hub / github.com/TingsongYu/PyTorch_Tutorial / validate

Function validate

Code/utils/utils.py:73–109  ·  view source on GitHub ↗

对一批数据进行预测,返回混淆矩阵以及Accuracy :param net: :param data_loader: :param set_name: eg: 'valid' 'train' 'test' :param classes_name: :return:

(net, data_loader, set_name, classes_name)

Source from the content-addressed store, hash-verified

71
72
def validate(net, data_loader, set_name, classes_name):
    """Run ``net`` over every batch of ``data_loader`` and report classification metrics.

    Builds a confusion matrix (rows = true label, columns = predicted label).
    It then prints per-class recall/precision and the overall accuracy.
    Returns the matrix together with the accuracy formatted as a string.

    :param net: a ``torch.nn.Module`` classifier; the predicted class is the
        index of the maximum of its output along dim 1. The module is switched
        to eval mode for the pass and its previous train/eval mode is restored
        afterwards.
    :param data_loader: iterable yielding ``(images, labels)`` pairs where
        ``labels`` is a tensor of integer class indices in
        ``[0, len(classes_name))``.
    :param set_name: name shown in the printed summary, e.g. 'valid', 'train', 'test'.
    :param classes_name: sequence of class names; its length sets the matrix size.
    :return: ``(conf_mat, acc_str)`` where ``conf_mat`` is a float
        ``np.ndarray`` of shape ``(cls_num, cls_num)`` holding counts. ``acc_str``
        is the accuracy formatted with ``'{:.2}'`` (e.g. ``'0.75'``), or ``'nan'``
        when the loader yields no samples.
    """
    cls_num = len(classes_name)
    conf_mat = np.zeros([cls_num, cls_num])

    was_training = net.training
    net.eval()
    try:
        # Evaluation needs no autograd graph; no_grad saves memory and time.
        with torch.no_grad():
            for images, labels in data_loader:
                outputs = net(images)
                _, predicted = torch.max(outputs, 1)
                # Move to CPU before .numpy() so GPU tensors work too.
                # np.add.at accumulates repeated (label, prediction) pairs
                # correctly, unlike a fancy-indexed "+=".
                true_idx = labels.cpu().numpy().reshape(-1)
                pred_idx = predicted.cpu().numpy().reshape(-1)
                np.add.at(conf_mat, (true_idx, pred_idx), 1.0)
    finally:
        # Do not leave the caller's model stuck in eval mode.
        net.train(was_training)

    for i, name in enumerate(classes_name):
        row_total = np.sum(conf_mat[i, :])
        col_total = np.sum(conf_mat[:, i])
        correct = conf_mat[i, i]
        # A class with no samples (or no predictions) gets 0 instead of a
        # division by zero; no "+1" smoothing, which biased every value.
        recall = correct / row_total if row_total else 0.0
        precision = correct / col_total if col_total else 0.0
        print('class:{:<10}, total num:{:<6}, correct num:{:<5} Recall: {:.2%} Precision: {:.2%}'.format(
            name, row_total, correct, recall, precision))

    total = np.sum(conf_mat)
    # Keep the historical 'nan' result for an empty loader, minus the 0/0 warning.
    accuracy = np.trace(conf_mat) / total if total else float('nan')
    print('{} set Accuracy:{:.2%}'.format(set_name, accuracy))

    return conf_mat, '{:.2}'.format(accuracy)
110
111
112def show_confMat(confusion_mat, classes, set_name, out_dir):

Callers 2

main.pyFile · 0.90
2_finetune.pyFile · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected