(prepare_data_fun, out_label)
| 110 | |
| 111 | |
def print_eval(prepare_data_fun, out_label):
    """Evaluate the best saved snapshot on a dataset and export its predictions.

    Loads ``best_model.pth`` from the module-level ``snapshot_dir``, builds the
    evaluation dataset with ``prepare_data_fun``, runs the restored model over
    it with ``run_model`` and passes the resulting question ids and scores to
    ``print_result`` (with ``json_only=False`` and a pickle path), which is
    expected to write the output files next to the checkpoint.

    Args:
        prepare_data_fun: Callable that builds the evaluation dataset. It is
            called with the entries of ``cfg["data"]`` and ``cfg["model"]``
            as keyword arguments plus ``verbose=True``; the returned dataset
            must expose an ``answer_dict`` attribute.
        out_label: Tag inserted into the output file names
            ``best_model_predict_<out_label>.json`` and
            ``best_model_predict_<out_label>.pkl``.
    """
    model_file = os.path.join(snapshot_dir, "best_model.pth")
    pkl_res_file = os.path.join(snapshot_dir, "best_model_predict_%s.pkl" % out_label)
    out_file = os.path.join(snapshot_dir, "best_model_predict_%s.json" % out_label)

    # NOTE(review): if cfg["data"] and cfg["model"] share a key, this call
    # raises TypeError (duplicate keyword argument) — confirm they are disjoint.
    data_set_test = prepare_data_fun(**cfg["data"], **cfg["model"], verbose=True)
    data_reader_test = DataLoader(
        data_set_test,
        shuffle=False,  # evaluation: iterate the dataset in its natural order
        batch_size=cfg.data.batch_size,
        num_workers=cfg.data.num_workers,
    )
    ans_dic = data_set_test.answer_dict

    model = build_model(cfg, data_set_test)
    # Load the checkpoint onto the CPU first. load_state_dict copies the values
    # into the model's existing parameters (on whatever device build_model put
    # them), so this works for GPU-saved checkpoints on CPU-only machines and
    # avoids a second full copy of the weights in GPU memory.
    checkpoint = torch.load(model_file, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()

    # presumably UNK_idx is the answer-vocabulary index for unknown answers —
    # confirm against the answer_dict implementation.
    question_ids, soft_max_result = run_model(model, data_reader_test, ans_dic.UNK_idx)
    print_result(
        question_ids,
        soft_max_result,
        ans_dic,
        out_file,
        json_only=False,
        pkl_res_file=pkl_res_file,
    )
| 139 | |
| 140 | |
| 141 | if __name__ == "__main__": |
no test coverage detected