MCPcopy
hub / github.com/google-deepmind/acme / __call__

Method __call__

acme/tf/networks/stochastic.py:70–105  ·  view source on GitHub ↗
(self, inputs: types.NestedTensor)

Source from the content-addressed store, hash-verified

68 self._beta = beta
69
def __call__(self, inputs: types.NestedTensor) -> tf.Tensor:
    """Pick one action per batch element via Boltzmann resampling over Q-values.

    The inputs are tiled N = `self._num_action_samples` times, the actor
    network proposes an action for every tiled copy, the critic network scores
    each (input, action) pair, and one of the N candidates is then drawn per
    batch element from a categorical distribution whose logits are the
    Q-values divided by `self._beta`.

    Args:
      inputs: Nested tensors with a leading batch dimension, i.e. [B, ...].

    Returns:
      The resampled actions, with leading dimension B ([B, ...]).
    """
    num_samples = self._num_action_samples

    # Tile the [B, ...] inputs so that they become [N, B, ...].
    tiled = tf2_utils.tile_nested(inputs, num_samples)
    leading_shape = tf.shape(tree.flatten(tiled)[0])
    n = leading_shape[0]
    b = leading_shape[1]
    tf.debugging.assert_equal(
        n, num_samples, 'Internal Error. Unexpected tiled_inputs shape.')

    # Only the [N, B] shape of this tensor matters: it is handed to
    # `snt.split_leading_dim` below as the template for un-merging N * B.
    nb_template = tf.zeros((n, b))

    def _merge_n_and_b(x):
        # Collapse the two leading dims: [N, B, ...] -> [N * B, ...].
        return snt.merge_leading_dims(x, 2)

    flat_inputs = tree.map_structure(_merge_n_and_b, tiled)

    # One candidate action for every tiled copy of the inputs.
    flat_actions = self._actor_network(flat_inputs)

    # Q-values for each candidate, tempered by beta to form the logits.
    q_values = self._critic_network(flat_inputs, flat_actions)
    logits = q_values / self._beta

    # Back to [N, B], then swap to [B, N] so each row holds one batch element's
    # candidates. The fixed perm=(1, 0) requires the logits to be rank 2 here.
    logits = snt.split_leading_dim(logits, nb_template, 2)
    logits = tf.transpose(logits, perm=(1, 0))

    # Draw one candidate index per batch element from the Boltzmann
    # distribution.
    chosen = tfp.distributions.Categorical(logits=logits).sample()
    # [B, 2]: column 0 enumerates the batch positions 0, 1, 2, ..., column 1
    # holds the sampled candidate index for that position.
    gather_idx = tf.stack((tf.range(b), chosen), axis=1)

    # Un-merge the actions to [N, B, ...] and move the batch axis to the front,
    # leaving any trailing action dimensions in place.
    actions = snt.split_leading_dim(flat_actions, nb_template, 2)
    rank = len(actions.get_shape().as_list())
    batch_major_perm = [1, 0] + list(range(2, rank))
    actions = tf.transpose(actions, perm=batch_major_perm)

    # [B, ...]
    return tf.gather_nd(actions, gather_idx)

Callers

nothing calls this directly

Calls 1

sampleMethod · 0.80

Tested by

no test coverage detected