(obs)
# Rebuild the actor-critic network on `device` and restore the weights saved at
# SAVE_PATH. map_location remaps stored tensors onto `device`, so a checkpoint
# written on another device (e.g. GPU) can still be loaded here.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (on torch>=1.13 consider weights_only=True).
model = ActorCritic(n_actions).to(device)
model.load_state_dict(torch.load(SAVE_PATH, map_location=device))
# The model is only used for inference from here on. Put it in evaluation mode
# so dropout / batch-norm layers (if ActorCritic has any) use their inference
# behaviour and batch-norm running statistics are not updated.
model.eval()
def policy_action(obs):
    """Return an action index sampled from the loaded policy for one observation.

    ``obs`` is converted to a tensor on ``device`` and given a leading batch
    dimension of size 1 before the forward pass through ``model``. The first
    network output is treated as categorical logits, and a single action is
    sampled from them (stochastic, not greedy). The second output is ignored.
    """
    obs_batch = torch.as_tensor(np.asarray(obs), device=device)[None, ...]
    with torch.no_grad():
        logits, _second_output = model(obs_batch)
        action = torch.distributions.Categorical(logits=logits).sample()
    return int(action.item())
# Hand the sampling policy defined above to the test loop, together with env.
run_test_loop(env, policy_action)

# NOTE(review): indentation was lost in this listing; confirm whether this
# line is meant to run after the test loop above or belongs to a separate
# branch.
envs = make_vec_env(args, N_ENVS)