import torch as th


def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """
    Extract values from a 1-D numpy array for a batch of indices.

    :param arr: the 1-D numpy array.
    :param timesteps: a tensor of indices into the array to extract.
    :param broadcast_shape: a larger shape of K dimensions with the batch
                            dimension equal to the length of timesteps.
    :return: a tensor of shape broadcast_shape (K dims) in which every
             element of batch slice i equals arr[timesteps[i]].
    """
    # Move the numpy array to the indices' device, gather one value per
    # batch element, and cast the result to float32.
    res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
    # Append trailing singleton dims until res has K dims,
    # e.g. [N] -> [N, 1, 1, 1] for a batch of shape [N, C, H, W].
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    # expand() broadcasts the size-1 dims as a view, without copying data.
    return res.expand(broadcast_shape)
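The gather-then-broadcast mechanics above can be checked without PyTorch. The sketch below is a NumPy-only analogue, not the original helper: the name `extract_into_array`, the `schedule` values, and the shapes are hypothetical, chosen only for illustration. Here `np.broadcast_to` plays the role of `expand()`, and a single `reshape` replaces the `res[..., None]` loop.

```python
import numpy as np


def extract_into_array(arr, timesteps, broadcast_shape):
    """Hypothetical NumPy-only analogue of _extract_into_tensor.

    Gathers arr[timesteps] (one value per batch element), appends singleton
    axes until the result has len(broadcast_shape) dims, then broadcasts.
    """
    # Gather and cast to float32, mirroring `[timesteps].float()`.
    res = np.asarray(arr)[np.asarray(timesteps)].astype(np.float32)
    # Same effect as the `while ...: res = res[..., None]` loop:
    # reshape [N] -> [N, 1, ..., 1] with K - 1 trailing singleton axes.
    res = res.reshape(res.shape + (1,) * (len(broadcast_shape) - res.ndim))
    # Like expand(), broadcast_to returns a view; NumPy's view is read-only.
    return np.broadcast_to(res, broadcast_shape)


# Hypothetical 1-D schedule of per-timestep coefficients (length 10).
schedule = np.linspace(1.0, 0.1, 10)
timesteps = np.array([0, 9, 4])   # one timestep index per batch element
x = np.ones((3, 2, 4, 4))         # a batch of shape [N, C, H, W]

coef = extract_into_array(schedule, timesteps, x.shape)
print(coef.shape)                 # (3, 2, 4, 4)
scaled = coef * x                 # each sample scaled by its own coefficient
```

Both `expand()` and `np.broadcast_to` return views over the N gathered values rather than full copies, so the result is cheap to build; copy it before writing into it.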