Run one episode. Each episode is a loop which first interacts with the environment to get an observation and then gives that observation to the agent in order to retrieve an action. Returns: An instance of `loggers.LoggingData`.
(self)
| 74 | self._observers = observers |
| 75 | |
| 76 | def run_episode(self) -> loggers.LoggingData: |
| 77 | """Run one episode. |
| 78 | |
| 79 | Each episode is a loop which interacts first with the environment to get an |
| 80 | observation and then give that observation to the agent in order to retrieve |
| 81 | an action. |
| 82 | |
| 83 | Returns: |
| 84 | An instance of `loggers.LoggingData`. |
| 85 | """ |
| 86 | # Reset any counts and start the environment. |
| 87 | start_time = time.time() |
| 88 | episode_steps = 0 |
| 89 | |
| 90 | # For evaluation, this keeps track of the total undiscounted reward |
| 91 | # accumulated during the episode. |
| 92 | episode_return = tree.map_structure(_generate_zeros_from_spec, |
| 93 | self._environment.reward_spec()) |
| 94 | timestep = self._environment.reset() |
| 95 | # Make the first observation. |
| 96 | self._actor.observe_first(timestep) |
| 97 | for observer in self._observers: |
| 98 | # Initialize the observer with the current state of the env after reset |
| 99 | # and the initial timestep. |
| 100 | observer.observe_first(self._environment, timestep) |
| 101 | |
| 102 | # Run an episode. |
| 103 | while not timestep.last(): |
| 104 | # Generate an action from the agent's policy and step the environment. |
| 105 | action = self._actor.select_action(timestep.observation) |
| 106 | timestep = self._environment.step(action) |
| 107 | |
| 108 | # Have the agent observe the timestep and let the actor update itself. |
| 109 | self._actor.observe(action, next_timestep=timestep) |
| 110 | for observer in self._observers: |
| 111 | # One environment step was completed. Observe the current state of the |
| 112 | # environment, the current timestep and the action. |
| 113 | observer.observe(self._environment, timestep, action) |
| 114 | if self._should_update: |
| 115 | self._actor.update() |
| 116 | |
| 117 | # Book-keeping. |
| 118 | episode_steps += 1 |
| 119 | |
| 120 | # Equivalent to: episode_return += timestep.reward |
| 121 | # We capture the return value because if timestep.reward is a JAX |
| 122 | # DeviceArray, episode_return will not be mutated in-place. (In all other |
| 123 | # cases, the returned episode_return will be the same object as the |
| 124 | # argument episode_return.) |
| 125 | episode_return = tree.map_structure(operator.iadd, |
| 126 | episode_return, |
| 127 | timestep.reward) |
| 128 | |
| 129 | # Record counts. |
| 130 | counts = self._counter.increment(episodes=1, steps=episode_steps) |
| 131 | |
| 132 | # Collect the results and combine with counts. |
| 133 | steps_per_second = episode_steps / (time.time() - start_time) |