MCPcopy
hub / github.com/microsoft/qlib / test_epoch

Method test_epoch

qlib/contrib/model/pytorch_igmtf.py:222–246  ·  view source on GitHub ↗
(self, data_x, data_y, train_hidden, train_hidden_day)

Source from the content-addressed store, hash-verified

220 self.train_optimizer.step()
221
222 def test_epoch(self, data_x, data_y, train_hidden, train_hidden_day):
223 # prepare training data
224 x_values = data_x.values
225 y_values = np.squeeze(data_y.values)
226
227 self.igmtf_model.eval()
228
229 scores = []
230 losses = []
231
232 daily_index, daily_count = self.get_daily_inter(data_x, shuffle=False)
233
234 for idx, count in zip(daily_index, daily_count):
235 batch = slice(idx, idx + count)
236 feature = torch.from_numpy(x_values[batch]).float().to(self.device)
237 label = torch.from_numpy(y_values[batch]).float().to(self.device)
238
239 pred = self.igmtf_model(feature, train_hidden=train_hidden, train_hidden_day=train_hidden_day)
240 loss = self.loss_fn(pred, label)
241 losses.append(loss.item())
242
243 score = self.metric_fn(pred, label)
244 scores.append(score.item())
245
246 return np.mean(losses), np.mean(scores)
247
248 def fit(
249 self,

Callers 1

fit Method · 0.95

Calls 5

get_daily_inter Method · 0.95
loss_fn Method · 0.95
metric_fn Method · 0.95
eval Method · 0.45
mean Method · 0.45

Tested by

no test coverage detected