```python
import torch as th


def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """
    Extract values from a 1-D numpy array for a batch of indices.

    :param arr: the 1-D numpy array.
    :param timesteps: a tensor of indices into the array to extract.
    :param broadcast_shape: a larger shape of K dimensions with the batch
                            dimension equal to the length of timesteps.
    :return: a tensor of shape broadcast_shape (K dims) in which batch element
             i holds arr[timesteps[i]], repeated over the trailing dimensions.
    """
    # Move the array to the indices' device, gather one value per batch
    # element, and cast to float32.
    res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
    # Append singleton trailing dims until res has K dims: [N] -> [N, 1, ..., 1].
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    # expand() returns a broadcast view (no copy) of shape broadcast_shape.
    return res.expand(broadcast_shape)
```
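A typical use of a helper like this is looking up per-timestep schedule coefficients for a batch and broadcasting them against image-shaped tensors. The sketch below assumes a DDPM-style forward noising step, x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, with a hypothetical linear beta schedule of 1000 steps; the schedule values, tensor shapes, and variable names are illustrative assumptions, not taken from the source. The function body is repeated so the example runs on its own.

```python
import numpy as np
import torch as th


def _extract_into_tensor(arr, timesteps, broadcast_shape):
    # Same logic as the function above, repeated so this example is self-contained.
    res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res.expand(broadcast_shape)


# Hypothetical linear beta schedule (an assumption for illustration only).
betas = np.linspace(1e-4, 0.02, 1000, dtype=np.float64)
alphas_cumprod = np.cumprod(1.0 - betas)
sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - alphas_cumprod)

x_start = th.randn(4, 3, 8, 8)      # batch of 4 images, shape [N, C, H, W]
t = th.tensor([0, 10, 500, 999])    # one timestep index per batch element
noise = th.randn_like(x_start)

# Each coefficient is gathered per sample, then broadcast to x_start's shape.
coef_signal = _extract_into_tensor(sqrt_alphas_cumprod, t, x_start.shape)
coef_noise = _extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x_start.shape)
x_t = coef_signal * x_start + coef_noise * noise
```

Because the trailing singleton dimensions are appended after the batch axis, the batch dimension of the gathered values lines up with dimension 0 of `broadcast_shape`. Since `expand` returns a view with zero strides on the expanded dimensions, the result should be treated as read-only; writing to it in place would alias across those dimensions.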