```python
import torch as th


def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """
    Extract values from a 1-D numpy array for a batch of indices.

    :param arr: the 1-D numpy array.
    :param timesteps: a tensor of indices into the array to extract.
    :param broadcast_shape: a larger shape of K dimensions with the batch
                            dimension equal to the length of timesteps.
    :return: a float32 tensor of shape broadcast_shape, in which each batch
             element's value is repeated across the remaining K - 1 dims.
    """
    # Move the array to the indices' device, gather one value per index,
    # and cast to float32 (a float64 numpy array would otherwise stay float64).
    res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
    # Append trailing singleton dims until res has K dims: [batch, 1, ..., 1].
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    # Adding zeros of the target shape broadcasts res to broadcast_shape.
    return res + th.zeros(broadcast_shape, device=timesteps.device)
```
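A minimal usage sketch, assuming the helper is used to look up per-timestep coefficients from a diffusion-style noise schedule. The linear beta schedule, the hypothetical `q_sample` helper, and the tensor sizes below are illustrative assumptions, not taken from this source. The function is repeated so the block runs on its own.

```python
import numpy as np
import torch as th


def _extract_into_tensor(arr, timesteps, broadcast_shape):
    # Same logic as the function above: gather arr[timesteps] as float32,
    # pad with trailing singleton dims, then broadcast to broadcast_shape.
    res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res + th.zeros(broadcast_shape, device=timesteps.device)


# Assumed linear beta schedule (illustrative values only).
num_steps = 1000
betas = np.linspace(1e-4, 0.02, num_steps, dtype=np.float64)
alphas_cumprod = np.cumprod(1.0 - betas)
sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - alphas_cumprod)


def q_sample(x_start, t, noise):
    # Hypothetical helper: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps,
    # where each per-sample coefficient is broadcast over the C, H, W dims.
    return (
        _extract_into_tensor(sqrt_alphas_cumprod, t, x_start.shape) * x_start
        + _extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
    )


x_start = th.randn(4, 3, 8, 8)           # batch of 4 "images"
t = th.tensor([0, 10, 500, 999])          # one timestep per batch element
noise = th.randn_like(x_start)

coef = _extract_into_tensor(sqrt_alphas_cumprod, t, x_start.shape)
print(coef.shape)  # torch.Size([4, 3, 8, 8])
x_t = q_sample(x_start, t, noise)
print(x_t.shape)   # torch.Size([4, 3, 8, 8])
```

Adding a zeros tensor allocates a full `broadcast_shape` tensor. When only elementwise arithmetic follows, returning the padded `[batch, 1, ..., 1]` tensor (or `res.expand(broadcast_shape)`, which returns a view) would avoid that allocation while still broadcasting correctly.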