| 68 | self._beta = beta |
| 69 | |
def __call__(self, inputs: types.NestedTensor) -> tf.Tensor:
  """Resamples one action per batch element from a Boltzmann distribution.

  The inputs are tiled `self._num_action_samples` (N) times, the actor
  network proposes an action for every tiled copy, and the critic network
  scores each proposal. The scores divided by `self._beta` serve as logits
  of a categorical distribution over the N proposals, from which a single
  proposal is drawn for each batch element.

  Args:
    inputs: Nested tensors with a leading batch dimension, i.e. [B, ...].

  Returns:
    The selected actions, one per batch element: [B, ...].
  """
  num_samples = self._num_action_samples

  # [B, ...] -> [N, B, ...].
  tiled = tf2_utils.tile_nested(inputs, num_samples)
  leading_dims = tf.shape(tree.flatten(tiled)[0])
  n = leading_dims[0]
  b = leading_dims[1]
  tf.debugging.assert_equal(
      n, num_samples, 'Internal Error. Unexpected tiled_inputs shape.')

  # Zero tensor of shape (n, b); it is only ever handed to
  # `snt.split_leading_dim` as its second argument when undoing the merge
  # performed below.
  n_by_b = tf.zeros((n, b))

  def unmerge(tensor):
    return snt.split_leading_dim(tensor, n_by_b, 2)

  # [N, B, ...] -> [N * B, ...].
  flat_inputs = tree.map_structure(
      lambda x: snt.merge_leading_dims(x, 2), tiled)

  candidate_actions = self._actor_network(flat_inputs)

  # Score every candidate and temper the scores with beta.
  q_values = self._critic_network(flat_inputs, candidate_actions)
  logits = q_values / self._beta

  # Undo the merge, then put the batch dimension first: [B, N].
  logits = tf.transpose(unmerge(logits), perm=[1, 0])

  # Draw one candidate index for each batch element.
  chosen = tfp.distributions.Categorical(logits=logits).sample()
  # [B, 2]: column 0 enumerates the batch (0, 1, 2, ...), column 1 holds the
  # drawn candidate index.
  gather_idx = tf.stack([tf.range(b), chosen], axis=1)

  # Undo the merge on the candidates and swap the two leading axes so the
  # batch dimension comes first; any trailing axes keep their order.
  per_sample_actions = unmerge(candidate_actions)
  rank = len(per_sample_actions.get_shape().as_list())
  per_batch_actions = tf.transpose(
      per_sample_actions, perm=[1, 0, *range(2, rank)])

  # [B, ...].
  return tf.gather_nd(per_batch_actions, gather_idx)