测试 :param testDataList:测试数据集 :param testLabelList: 测试标签集 :param tree: 提升树 :return: 准确率
(testDataList, testLabelList, tree)
| 214 | else: return H |
| 215 | |
| 216 | def model_test(testDataList, testLabelList, tree): |
| 217 | ''' |
| 218 | 测试 |
| 219 | :param testDataList:测试数据集 |
| 220 | :param testLabelList: 测试标签集 |
| 221 | :param tree: 提升树 |
| 222 | :return: 准确率 |
| 223 | ''' |
| 224 | #错误率计数值 |
| 225 | errorCnt = 0 |
| 226 | #遍历每一个测试样本 |
| 227 | for i in range(len(testDataList)): |
| 228 | #预测结果值,初始为0 |
| 229 | result = 0 |
| 230 | #依据算法8.1式8.6 |
| 231 | #预测式子是一个求和式,对于每一层的结果都要进行一次累加 |
| 232 | #遍历每层的树 |
| 233 | for curTree in tree: |
| 234 | #获取该层参数 |
| 235 | div = curTree['div'] |
| 236 | rule = curTree['rule'] |
| 237 | feature = curTree['feature'] |
| 238 | alpha = curTree['alpha'] |
| 239 | #将当前层结果加入预测中 |
| 240 | result += alpha * predict(testDataList[i], div, rule, feature) |
| 241 | #预测结果取sign值,如果大于0 sign为1,反之为0 |
| 242 | if np.sign(result) != testLabelList[i]: errorCnt += 1 |
| 243 | #返回准确率 |
| 244 | return 1 - errorCnt / len(testDataList) |
| 245 | |
| 246 | if __name__ == '__main__': |
| 247 | #开始时间 |