()
| 276 | |
| 277 | |
def test_cn_ppo_strategy():
    """Backtest a pretrained recurrent PPO policy on the CN sample orders and check its metrics.

    Loads the CN orders, restores ``ppo_recurrent_30min.pth`` into a ``PPO``
    policy wrapping a ``Recurrent`` network, and runs ``backtest`` with a
    console writer and a CSV writer. It then reads ``result.csv`` back from the
    CSV writer's output directory and compares the mean metrics against
    recorded reference values.
    """
    set_log_with_config(C.logging_config)
    # The data starts with 9:31 and ends with 15:00
    # NOTE(review): end_time is 14:58 rather than the data's 15:00 end; the reason
    # is not visible here -- confirm it is intentional.
    orders = pickle_styled.load_orders(CN_ORDER_DIR, start_time=pd.Timestamp("9:31"), end_time=pd.Timestamp("14:58"))
    assert len(orders) == 40

    # NOTE(review): positional args (8, 240, 6) and (4) are not self-describing;
    # confirm their meaning against the interpreter constructors' signatures.
    state_interp = FullHistoryStateInterpreter(8, 240, 6, PickleProcessedDataProvider(CN_FEATURE_DATA_DIR))
    action_interp = CategoricalActionInterpreter(4)
    network = Recurrent(state_interp.observation_space)
    policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4)
    # map_location="cpu" so the checkpoint loads on machines without a GPU.
    policy.load_state_dict(torch.load(CN_POLICY_WEIGHTS_DIR / "ppo_recurrent_30min.pth", map_location="cpu"))

    # One path for both the CSV writer and the read-back below, so the two cannot drift apart.
    output_dir = Path(__file__).parent / ".output"
    csv_writer = CsvWriter(output_dir)

    backtest(
        # presumably ticks_per_step=30 matches the "30min" checkpoint above -- TODO confirm
        partial(SingleAssetOrderExecutionSimple, data_dir=CN_DATA_DIR, ticks_per_step=30),
        state_interp,
        action_interp,
        orders,
        policy,
        [ConsoleWriter(total_episodes=len(orders)), csv_writer],
        concurrency=4,
    )

    metrics = pd.read_csv(output_dir / "result.csv")
    # One result row per order.
    assert len(metrics) == len(orders)
    # Reference values recorded from a previous run of this checkpoint on this data.
    assert np.isclose(metrics["ffr"].mean(), 1.0)
    assert np.isclose(metrics["pa"].mean(), -16.21578303474833)
    assert np.isclose(metrics["market_price"].mean(), 58.68277690875527)
    assert np.isclose(metrics["trade_price"].mean(), 58.76063985000002)
| 307 | |
| 308 | |
| 309 | def test_ppo_train(): |
nothing calls this directly
no test coverage detected