github.com/facebookresearch/mmf

Function print_eval

pythia/legacy/train.py, lines 112–138
(prepare_data_fun, out_label)


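```python
import os

import torch
from torch.utils.data import DataLoader

# cfg, snapshot_dir, build_model, run_model and print_result are not defined in
# this function; they come from elsewhere in train.py / pythia/legacy.
# cfg supports both item access (cfg["data"]) and attribute access (cfg.data).


def print_eval(prepare_data_fun, out_label):
    # All files live in the training snapshot directory.
    model_file = os.path.join(snapshot_dir, "best_model.pth")
    pkl_res_file = os.path.join(snapshot_dir, "best_model_predict_%s.pkl" % out_label)
    out_file = os.path.join(snapshot_dir, "best_model_predict_%s.json" % out_label)

    # Build the evaluation set; shuffle=False keeps a deterministic batch order.
    data_set_test = prepare_data_fun(**cfg["data"], **cfg["model"], verbose=True)
    data_reader_test = DataLoader(
        data_set_test,
        shuffle=False,
        batch_size=cfg.data.batch_size,
        num_workers=cfg.data.num_workers,
    )
    ans_dic = data_set_test.answer_dict

    # Restore the best checkpoint (a dict with a "state_dict" entry) and
    # switch the model to inference mode.
    model = build_model(cfg, data_set_test)
    model.load_state_dict(torch.load(model_file)["state_dict"])
    model.eval()

    # Run inference, then hand the scores to print_result, which receives
    # both the JSON path (out_file) and the pickle path (pkl_res_file).
    question_ids, soft_max_result = run_model(model, data_reader_test, ans_dic.UNK_idx)
    print_result(
        question_ids,
        soft_max_result,
        ans_dic,
        out_file,
        json_only=False,
        pkl_res_file=pkl_res_file,
    )
```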
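Every file print_eval touches is derived from just two values, `snapshot_dir` and `out_label`. The sketch below pulls that naming scheme into a hypothetical helper, `eval_output_paths` (not part of mmf), which can be handy for locating the checkpoint and prediction files after a run. The directory name `"snapshots"` and the label `"test"` in the usage line are illustrative assumptions.

```python
import os


def eval_output_paths(snapshot_dir: str, out_label: str) -> dict:
    """Hypothetical helper reproducing the file names built inside print_eval."""
    return {
        # Checkpoint that print_eval loads; it indexes the loaded object with ["state_dict"].
        "model_file": os.path.join(snapshot_dir, "best_model.pth"),
        # Path forwarded to print_result as pkl_res_file.
        "pkl_res_file": os.path.join(snapshot_dir, "best_model_predict_%s.pkl" % out_label),
        # Path forwarded to print_result as out_file.
        "out_file": os.path.join(snapshot_dir, "best_model_predict_%s.json" % out_label),
    }


paths = eval_output_paths("snapshots", "test")
print(os.path.basename(paths["out_file"]))  # best_model_predict_test.json
```

Because print_eval reads `torch.load(model_file)["state_dict"]`, whatever wrote `best_model.pth` must have saved a dict containing a `"state_dict"` key rather than a bare state dict.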

Callers (1)

train.py (file)

Calls (5)

build_model (function)
run_model (function)
print_result (function)
load_state_dict (method)
load (method)

Tested by

No test coverage detected.
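print_eval delegates the last step to print_result, whose body is not shown on this page. As a rough, hypothetical illustration of turning per-question soft-max scores into answers, the helper below picks the highest-scoring index for each question and maps it through a plain list of answer strings. The list stands in for `ans_dic`; the real answer dictionary's interface, print_result's output format, and what it pickles are all assumptions here, and `write_predictions` is not an mmf function.

```python
import json
import pickle
from typing import List, Optional, Sequence


def write_predictions(
    question_ids: Sequence[int],
    soft_max_result: Sequence[Sequence[float]],
    answer_list: Sequence[str],
    out_file: Optional[str] = None,
    pkl_res_file: Optional[str] = None,
) -> List[dict]:
    """Hypothetical stand-in for print_result: argmax each row, map it to an answer."""
    predictions = []
    for qid, scores in zip(question_ids, soft_max_result):
        # Index of the highest soft-max score for this question (first one on ties).
        best_idx = max(range(len(scores)), key=scores.__getitem__)
        predictions.append({"question_id": qid, "answer": answer_list[best_idx]})

    if out_file is not None:
        with open(out_file, "w") as f:
            json.dump(predictions, f)
    if pkl_res_file is not None:
        # Assumed format: keep the raw scores so they can be re-analysed later.
        with open(pkl_res_file, "wb") as f:
            pickle.dump(
                {"question_ids": list(question_ids), "soft_max_result": soft_max_result}, f
            )
    return predictions


preds = write_predictions([101, 102], [[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]], ["no", "yes", "2"])
print(preds)  # [{'question_id': 101, 'answer': 'yes'}, {'question_id': 102, 'answer': 'no'}]
```

Keeping `shuffle=False` in print_eval's DataLoader matters for this kind of step: the scores are only useful if each row can be paired with the right question ID.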