Computes a forward step for a single element. The observation and state are packed together in order to use `tf.vectorized_map` to handle batches of observations. See this module's __call__() function. Args: observation_and_state: the observation and state packed in a tuple.
(
self, observation_and_state: Tuple[types.NestedTensor,
PolicyCriticRNNState]
)
| 200 | return tf.vectorized_map(self._call, (observation, prev_state)) |
| 201 | |
| 202 | def _call( |
| 203 | self, observation_and_state: Tuple[types.NestedTensor, |
| 204 | PolicyCriticRNNState] |
| 205 | ) -> Tuple[types.NestedTensor, PolicyCriticRNNState]: |
| 206 | """Computes a forward step for a single element. |
| 207 | |
| 208 | The observation and state are packed together in order to use |
| 209 | `tf.vectorized_map` to handle batches of observations. |
| 210 | See this module's __call__() function. |
| 211 | |
| 212 | Args: |
| 213 | observation_and_state: the observation and state packed in a tuple. |
| 214 | |
| 215 | Returns: |
| 216 | The selected action and the corresponding state. |
| 217 | """ |
| 218 | observation, prev_state = observation_and_state |
| 219 | |
| 220 | # Tile input observations and states to allow multiple policy predictions. |
| 221 | tiled_observation, tiled_prev_state = utils.tile_nested( |
| 222 | (observation, prev_state), self._num_action_samples) |
| 223 | actions, policy_states = self._policy_network( |
| 224 | tiled_observation, tiled_prev_state.policy) |
| 225 | |
| 226 | # Evaluate multiple critic predictions with the sampled actions. |
| 227 | value_distribution, critic_states = self._critic_network( |
| 228 | tiled_observation, actions, tiled_prev_state.critic) |
| 229 | value_estimate = value_distribution.mean() |
| 230 | |
| 231 | # Resample a single action of the sampled actions according to logits given |
| 232 | # by the tempered Q-values. |
| 233 | selected_action_idx = tfp.distributions.Categorical( |
| 234 | probs=tf.nn.softmax(value_estimate / self._temperature_beta)).sample() |
| 235 | selected_action = actions[selected_action_idx] |
| 236 | |
| 237 | # Select and return the RNN state that corresponds to the selected action. |
| 238 | states = PolicyCriticRNNState( |
| 239 | policy=policy_states, critic=critic_states) |
| 240 | selected_state = tree.map_structure( |
| 241 | lambda x: x[selected_action_idx], states) |
| 242 | |
| 243 | return selected_action, selected_state |
| 244 | |
| 245 | def initial_state(self, batch_size: int) -> PolicyCriticRNNState: |
| 246 | return PolicyCriticRNNState( |
nothing calls this directly
no test coverage detected