MCPcopy
hub / github.com/microsoft/qlib / test_cn_ppo_strategy

Function test_cn_ppo_strategy

tests/rl/test_saoe_simple.py:278–306  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

276
277
def test_cn_ppo_strategy():
    """Backtest a pre-trained recurrent PPO policy on the CN order set.

    Loads the orders and the ``ppo_recurrent_30min.pth`` weights, runs
    ``backtest`` with a console writer and a CSV writer, then reads the
    CSV result back and compares the mean of selected metric columns
    against recorded reference values.
    """
    set_log_with_config(C.logging_config)

    # The data starts with 9:31 and ends with 15:00
    window_start = pd.Timestamp("9:31")
    window_end = pd.Timestamp("14:58")
    orders = pickle_styled.load_orders(CN_ORDER_DIR, start_time=window_start, end_time=window_end)
    assert len(orders) == 40

    # Interpreters and policy; constructor arguments are kept positional
    # exactly as the original call sites pass them.
    feature_provider = PickleProcessedDataProvider(CN_FEATURE_DATA_DIR)
    state_interp = FullHistoryStateInterpreter(8, 240, 6, feature_provider)
    action_interp = CategoricalActionInterpreter(4)
    network = Recurrent(state_interp.observation_space)
    policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4)

    # Restore the saved weights onto CPU.
    weights_file = CN_POLICY_WEIGHTS_DIR / "ppo_recurrent_30min.pth"
    policy.load_state_dict(torch.load(weights_file, map_location="cpu"))

    # The CSV writer is pointed at this directory; the same directory is
    # read back below to check the results.
    output_dir = Path(__file__).parent / ".output"
    csv_writer = CsvWriter(output_dir)

    simulator_factory = partial(SingleAssetOrderExecutionSimple, data_dir=CN_DATA_DIR, ticks_per_step=30)
    loggers = [ConsoleWriter(total_episodes=len(orders)), csv_writer]
    backtest(
        simulator_factory,
        state_interp,
        action_interp,
        orders,
        policy,
        loggers,
        concurrency=4,
    )

    metrics = pd.read_csv(output_dir / "result.csv")
    assert len(metrics) == len(orders)

    # Reference means per metric column, checked in the original order.
    expected_means = (
        ("ffr", 1.0),
        ("pa", -16.21578303474833),
        ("market_price", 58.68277690875527),
        ("trade_price", 58.76063985000002),
    )
    for column, expected in expected_means:
        assert np.isclose(metrics[column].mean(), expected)
307
308
309def test_ppo_train():

Callers

nothing calls this directly

Calls 12

set_log_with_config · Function · 0.90
CsvWriter · Class · 0.90
backtest · Function · 0.90
ConsoleWriter · Class · 0.90
Recurrent · Class · 0.85
PPO · Class · 0.85
load_state_dict · Method · 0.45
load · Method · 0.45
mean · Method · 0.45

Tested by

no test coverage detected