| 266 | |
| 267 | |
def test_train_dagger_main(tmpdir):
    """Run the ``dagger`` command end-to-end on seals CartPole with fast configs.

    Asserts that the Sacred run completes with a dict result, and that no
    ``UserWarning`` containing "NumPy array is not writeable" was emitted
    while it ran.
    """
    # ``pytest.warns(None)`` is deprecated in pytest 7 and an error in pytest 8,
    # so record warnings with the stdlib instead. ``simplefilter("always")``
    # stops the warnings registry from suppressing repeats of an
    # already-emitted warning, so every occurrence lands in ``record``.
    import warnings

    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        run = train_imitation.train_imitation_ex.run(
            command_name="dagger",
            named_configs=["seals_cartpole"] + ALGO_FAST_CONFIGS["imitation"],
            config_updates=dict(
                logging=dict(log_root=tmpdir),
                demonstrations=dict(path=CARTPOLE_TEST_ROLLOUT_PATH),
            ),
        )

    # PyTorch wants writeable arrays.
    # See https://github.com/HumanCompatibleAI/imitation/issues/219
    # ``str(w.message)`` is used rather than ``w.message.args[0]`` so a warning
    # constructed without arguments cannot raise IndexError here.
    non_writeable_warnings = [
        w
        for w in record
        if issubclass(w.category, UserWarning)
        and "NumPy array is not writeable" in str(w.message)
    ]
    assert not non_writeable_warnings, non_writeable_warnings
    assert run.status == "COMPLETED"
    assert isinstance(run.result, dict)
| 287 | |
| 288 | |
| 289 | def test_train_dagger_warmstart(tmpdir): |