MCPcopy
hub / github.com/google-deepmind/acme / _call

Method _call

acme/tf/networks/recurrence.py:202–243  ·  view source on GitHub ↗

Computes a forward step for a single element. The observation and state are packed together in order to use `tf.vectorized_map` to handle batches of observations. See this module's __call__() function. Args: observation_and_state: the observation and state packed in a tuple.

(
      self, observation_and_state: Tuple[types.NestedTensor,
                                         PolicyCriticRNNState]
  )

Source from the content-addressed store, hash-verified

200 return tf.vectorized_map(self._call, (observation, prev_state))
201
def _call(
    self, observation_and_state: Tuple[types.NestedTensor,
                                       PolicyCriticRNNState]
) -> Tuple[types.NestedTensor, PolicyCriticRNNState]:
  """Computes a forward step for a single element.

  The observation and state are packed together in order to use
  `tf.vectorized_map` to handle batches of observations.
  See this module's __call__() function.

  The observation and previous state are tiled `self._num_action_samples`
  times, the policy network proposes one action per tile, and the critic
  network scores each proposal. A single proposal is then resampled from a
  categorical distribution whose logits are the mean critic values divided
  by `self._temperature_beta`.

  Args:
    observation_and_state: the observation and state packed in a tuple.

  Returns:
    The selected action and the policy/critic recurrent state taken at the
    same index as the selected action.
  """
  observation, prev_state = observation_and_state

  # Tile input observations and states to allow multiple policy predictions.
  tiled_observation, tiled_prev_state = utils.tile_nested(
      (observation, prev_state), self._num_action_samples)
  actions, policy_states = self._policy_network(
      tiled_observation, tiled_prev_state.policy)

  # Evaluate multiple critic predictions with the sampled actions.
  value_distribution, critic_states = self._critic_network(
      tiled_observation, actions, tiled_prev_state.critic)
  value_estimate = value_distribution.mean()

  # Resample a single action of the sampled actions according to logits given
  # by the tempered Q-values. The logits are passed directly: Categorical
  # normalizes them itself, so an explicit softmax (which would then be
  # converted back to log-space internally) is redundant and loses precision.
  selected_action_idx = tfp.distributions.Categorical(
      logits=value_estimate / self._temperature_beta).sample()
  selected_action = actions[selected_action_idx]

  # Select and return the RNN state that corresponds to the selected action.
  states = PolicyCriticRNNState(
      policy=policy_states, critic=critic_states)
  selected_state = tree.map_structure(
      lambda x: x[selected_action_idx], states)

  return selected_action, selected_state
244
245 def initial_state(self, batch_size: int) -> PolicyCriticRNNState:
246 return PolicyCriticRNNState(

Callers

nothing calls this directly

Calls 3

_policy_networkMethod · 0.80
sampleMethod · 0.80

Tested by

no test coverage detected