| 287 | |
| 288 | |
def test_train_dagger_warmstart(tmpdir):
    """DAgger run completes when given a policy checkpoint from a previous DAgger run.

    A first DAgger run on seals_cartpole is executed from scratch. The test then
    expects that run to have left a checkpoint at
    ``<log_dir>/scratch/policy-latest.pt``. A second run receives that path as
    ``bc.agent_path`` and must also complete and return a dict.
    """

    def _run_dagger(**extra_updates):
        # Build fresh config objects for every run so the two runs never share
        # (and possibly mutate) the same list/dict instances.
        config_updates = dict(
            logging=dict(log_root=tmpdir),
            demonstrations=dict(path=CARTPOLE_TEST_ROLLOUT_PATH),
            **extra_updates,
        )
        return train_imitation.train_imitation_ex.run(
            command_name="dagger",
            named_configs=["seals_cartpole"] + ALGO_FAST_CONFIGS["imitation"],
            config_updates=config_updates,
        )

    run = _run_dagger()
    assert run.status == "COMPLETED"
    assert isinstance(run.result, dict)

    log_dir = util.parse_path(run.config["logging"]["log_dir"])
    policy_path = log_dir / "scratch" / "policy-latest.pt"
    # Fail here with a clear message if the first run did not write the
    # checkpoint. Otherwise the warm-start run below would fail for a reason
    # that looks unrelated.
    assert policy_path.exists(), f"expected DAgger checkpoint at {policy_path}"

    run_warmstart = _run_dagger(bc=dict(agent_path=policy_path))
    assert run_warmstart.status == "COMPLETED"
    assert isinstance(run_warmstart.result, dict)
| 313 | |
| 314 | |
| 315 | def test_train_bc_main_with_none_demonstrations_raises_value_error(tmpdir): |