对一批数据进行预测,返回混淆矩阵以及Accuracy :param net: :param data_loader: :param set_name: eg: 'valid' 'train' 'test' :param classes_name: :return:
(net, data_loader, set_name, classes_name)
| 71 | |
| 72 | |
def validate(net, data_loader, set_name, classes_name):
    """Run ``net`` over ``data_loader`` and summarise its predictions.

    Prints one line per class (sample count, correct count, recall and
    precision) followed by the overall accuracy, then returns the confusion
    matrix together with the accuracy.

    :param net: model invoked as ``net(images)``. It is switched to eval mode
        and the index of the maximum along dimension 1 of its output is taken
        as the predicted class.
    :param data_loader: iterable yielding ``(images, labels)`` batches.
    :param set_name: name printed in the accuracy line,
        e.g. 'valid', 'train', 'test'.
    :param classes_name: sequence of class names; its length fixes the size of
        the confusion matrix.
    :return: ``(conf_mat, accuracy)``. ``conf_mat[t, p]`` counts the samples
        whose label is ``t`` and whose prediction is ``p``. ``accuracy`` is a
        string produced by ``'{:.2}'.format(...)``; it is ``'nan'`` when the
        loader yielded no samples.
    """
    net.eval()
    cls_num = len(classes_name)
    conf_mat = np.zeros([cls_num, cls_num])

    # Evaluation never needs gradients, so skip building the autograd graph.
    with torch.no_grad():
        for images, labels in data_loader:
            outputs = net(images)
            _, predicted = torch.max(outputs, 1)

            # Move to host memory before converting so tensors living on
            # another device do not make ``.numpy()`` raise.
            true_idx = labels.detach().cpu().numpy().reshape(-1)
            pred_idx = predicted.detach().cpu().numpy().reshape(-1)

            # ``np.add.at`` accumulates repeated (true, pred) pairs, which a
            # plain fancy-indexed ``+=`` would count only once.
            np.add.at(conf_mat, (true_idx, pred_idx), 1.0)

    row_totals = conf_mat.sum(axis=1)  # samples per true class
    col_totals = conf_mat.sum(axis=0)  # samples per predicted class

    for i, class_name in enumerate(classes_name):
        correct = conf_mat[i, i]
        # Guard the empty case explicitly instead of padding the denominator
        # with 1, which would understate every reported ratio.
        recall = correct / row_totals[i] if row_totals[i] else 0.0
        precision = correct / col_totals[i] if col_totals[i] else 0.0
        print('class:{:<10}, total num:{:<6}, correct num:{:<5} Recall: {:.2%} Precision: {:.2%}'.format(
            class_name, row_totals[i], correct, recall, precision))

    total = conf_mat.sum()
    # Accuracy over zero samples is undefined; report NaN without tripping a
    # NumPy divide warning.
    accuracy = np.trace(conf_mat) / total if total else float('nan')

    print('{} set Accuracy:{:.2%}'.format(set_name, accuracy))

    return conf_mat, '{:.2}'.format(accuracy)
| 110 | |
| 111 | |
| 112 | def show_confMat(confusion_mat, classes, set_name, out_dir): |
no outgoing calls
no test coverage detected