MCPcopy
hub / github.com/microsoft/qlib / test_epoch

Method test_epoch

examples/benchmarks/TRA/src/model.py:146–201  ·  view source on GitHub ↗
(self, data_set, return_pred=False)

Source from the content-addressed store, hash-verified

144 return total_loss
145
def test_epoch(self, data_set, return_pred=False):
    """Run one evaluation pass of the backbone model and TRA router over ``data_set``.

    Puts ``self.model``, ``self.tra`` and ``data_set`` into eval mode and runs the
    forward pass of every batch under ``torch.no_grad()``.

    Side effect: each batch's per-state squared prediction error is written back
    through ``data_set.assign_data``. Each sample's row is shifted so its smallest
    entry is zero. NOTE(review): this is presumably surfaced again as the trailing
    loss-history columns of later batches; confirm against ``assign_data``.

    Args:
        data_set: iterable of batch dicts with ``"data"``, ``"label"`` and
            ``"index"`` tensors. It must also expose ``horizon``, ``eval()``,
            ``assign_data()`` and ``restore_index()``.
        return_pred: when True, also collect and return per-sample predictions.

    Returns:
        Tuple ``(metrics, preds)``.

        ``metrics`` is a dict with these keys:

        - ``"MSE"``, ``"MAE"``, ``"IC"``: means over batches of the values
          returned by ``evaluate``.
        - ``"ICIR"``: mean IC divided by the IC standard deviation across
          batches. It is NaN with fewer than two batches.

        Every metric is NaN if ``data_set`` yields no batches.

        ``preds`` depends on ``return_pred``:

        - True: a DataFrame with ``score`` and ``label`` columns, plus
          ``score_<i>`` / ``prob_<i>`` columns when the router returns
          probabilities. Its index is passed through
          ``data_set.restore_index``, has its levels swapped, and is sorted.
        - False: an empty list.
    """
    self.model.eval()
    self.tra.eval()
    data_set.eval()

    num_states = self.tra.num_states
    preds = []
    batch_metrics = []
    for batch in tqdm(data_set):
        data, label, index = batch["data"], batch["label"], batch["index"]

        # The trailing `num_states` entries of the last axis hold the per-state
        # loss history; everything before them is model input. The last
        # `horizon` time steps are excluded from the history. NOTE(review):
        # presumably to avoid leaking losses whose labels are not yet
        # observable; confirm against the dataset.
        feature = data[:, :, :-num_states]
        hist_loss = data[:, : -data_set.horizon, -num_states:]

        with torch.no_grad():
            hidden = self.model(feature)
            pred, all_preds, prob = self.tra(hidden, hist_loss)

            # Squared error of every state's prediction, shifted so the best
            # state of each sample has zero loss (normalize & ensure positive input).
            loss = (all_preds - label[:, None]).pow(2)
            loss -= loss.min(dim=-1, keepdim=True).values

        data_set.assign_data(index, loss)  # save loss to memory

        X = np.c_[pred.cpu().numpy(), label.cpu().numpy()]
        columns = ["score", "label"]
        if prob is not None:
            X = np.c_[X, all_preds.cpu().numpy(), prob.cpu().numpy()]
            # Size each group of names from its own tensor so the column count
            # always matches the stacked array.
            columns += [f"score_{d}" for d in range(all_preds.shape[1])]
            columns += [f"prob_{d}" for d in range(prob.shape[1])]

        batch_pred = pd.DataFrame(X, index=index.cpu().numpy(), columns=columns)

        batch_metrics.append(evaluate(batch_pred))

        if return_pred:
            preds.append(batch_pred)

    # Fix the columns (and a float dtype) up front. An empty `data_set` then
    # yields NaN metrics instead of an AttributeError on a column-less DataFrame.
    batch_metrics = pd.DataFrame(batch_metrics, columns=["MSE", "MAE", "IC"], dtype="float64")
    ic_mean = batch_metrics.IC.mean()
    metrics = {
        "MSE": batch_metrics.MSE.mean(),
        "MAE": batch_metrics.MAE.mean(),
        "IC": ic_mean,
        "ICIR": ic_mean / batch_metrics.IC.std(),
    }

    if return_pred:
        preds = pd.concat(preds, axis=0)
        preds.index = data_set.restore_index(preds.index)
        preds.index = preds.index.swaplevel()
        preds.sort_index(inplace=True)

    return metrics, preds
202
203 def fit(self, dataset, evals_result=dict()):

Callers (2)

fit (method) · 0.95
predict (method) · 0.95

Calls (6)

evaluate (function) · 0.70
eval (method) · 0.45
assign_data (method) · 0.45
mean (method) · 0.45
restore_index (method) · 0.45
sort_index (method) · 0.45

Tested by

No test coverage detected.