(state)
# Restore the trained weights. map_location="cpu" lets a checkpoint that was
# saved from a GPU run load on a CPU-only machine; load_state_dict then copies
# each tensor into the model's existing parameters on their own device.
# NOTE(review): torch.load unpickles SAVE_PATH -- only load trusted checkpoints
# (consider weights_only=True if the installed PyTorch supports it).
model.load_state_dict(torch.load(SAVE_PATH, map_location="cpu"))
# Put layers such as dropout / batch-norm (if the model has any) into
# inference mode before evaluating the policy.
model.eval()

# Reference parameter used to match observation dtype/device to the model
# (None if the model has no parameters, in which case inputs are left as-is).
_ref_param = next(model.parameters(), None)


def _to_model_input(state):
    """Convert an observation into a tensor compatible with ``model``'s weights."""
    obs = torch.as_tensor(state)
    if _ref_param is None:
        return obs
    if obs.is_floating_point():
        # e.g. float64 NumPy observations would otherwise clash with float32 weights.
        return obs.to(device=_ref_param.device, dtype=_ref_param.dtype)
    # Non-float inputs (e.g. integer state ids) keep their dtype.
    return obs.to(device=_ref_param.device)


def pick(state, *, greedy=False):
    """Choose an action for ``state`` using the loaded policy network.

    ``model(x)`` returns a pair whose first element is treated as the action
    logits; the second element is ignored here.

    Args:
        state: An observation accepted by ``torch.as_tensor`` (for example a
            NumPy array, a list of numbers, or an integer state id).
        greedy: If True, return the highest-scoring action instead of sampling.
            Defaults to False, i.e. sample from the categorical policy as before.

    Returns:
        The chosen action index as a Python ``int``.
    """
    obs = _to_model_input(state)
    with torch.no_grad():
        logits, _ = model(obs)
        if greedy:
            action = logits.argmax(dim=-1)
        else:
            action = torch.distributions.Categorical(logits=logits).sample()
    return int(action.item())


run_test_loop(env, pick)