MCPcopy
hub / github.com/microsoft/qlib / predict

Method predict

examples/benchmarks/TFT/tft.py:252–284  ·  view source on GitHub ↗
(self, dataset)

Source from the content-addressed store, hash-verified

250 # ===========================Training Process===========================
251
def predict(self, dataset):
    """Predict scores for the ``"test"`` segment of ``dataset`` with the fitted TFT model.

    The test segment is prepared via ``transform_df``, a label shift
    (``get_shifted_label``) and ``process_qlib_data``. It is then run through
    the stored TensorFlow graph/session. The returned score is the mean of the
    formatted p50 and p90 quantile forecasts.

    Parameters
    ----------
    dataset : DatasetH
        Dataset exposing a ``"test"`` segment with ``feature`` and ``label``
        column sets.

    Returns
    -------
    The element-wise average of ``format_score(p50, "pred", 1)`` and
    ``format_score(p90, "pred", 1)``. Its concrete type is whatever
    ``format_score`` returns (presumably a pandas Series -- TODO confirm).

    Raises
    ------
    ValueError
        If the model has not been fitted yet (``self.model is None``).
    """
    if self.model is None:
        raise ValueError("model is not fitted yet!")

    # ---- Data preparation ----
    d_test = dataset.prepare("test", col_set=["feature", "label"])
    d_test = transform_df(d_test)
    d_test.loc[:, self.label_col] = get_shifted_label(d_test, shifts=self.label_shift, col_shift=self.label_col)
    # NOTE(review): ``.dropna()`` removes every row that still contains NaN
    # after processing. If the shifted label is NaN for the most recent
    # ``label_shift`` dates, those dates get no prediction -- confirm this is
    # intended.
    test = process_qlib_data(d_test, self.expt_name, fillna=True).dropna()

    # ---- Prediction ----
    # Remember the caller's Keras session so it can be restored afterwards.
    default_keras_session = tf.keras.backend.get_session()

    print("*** Begin predicting ***")
    # Resets TensorFlow's global default graph (kept from the original code).
    # Prediction itself runs inside ``self.tf_graph``, entered below.
    tf.reset_default_graph()

    with self.tf_graph.as_default():
        tf.keras.backend.set_session(self.sess)
        try:
            output_map = self.model.predict(test, return_targets=True)
            p50_forecast = self.data_formatter.format_predictions(output_map["p50"])
            p90_forecast = self.data_formatter.format_predictions(output_map["p90"])
        finally:
            # Restore the caller's session even when prediction fails, so a
            # failed call does not leave ``self.sess`` installed globally.
            tf.keras.backend.set_session(default_keras_session)

    pred50 = format_score(p50_forecast, "pred", 1)
    pred90 = format_score(p90_forecast, "pred", 1)
    # The point forecast is the average of the p50 and p90 quantile scores.
    return (pred50 + pred90) / 2
285
286 def finetune(self, dataset: DatasetH):
287 """

Callers 1

prepare_data (Function) · 0.45

Calls 8

transform_df (Function) · 0.85
get_shifted_label (Function) · 0.85
process_qlib_data (Function) · 0.85
format_score (Function) · 0.85
get_experiment_params (Method) · 0.80
prepare (Method) · 0.45
format_predictions (Method) · 0.45

Tested by

no test coverage detected