`sample(self, batch_size, device)`: importance-sample timesteps for a batch.
```python
        """

    def sample(self, batch_size, device):
        """
        Importance-sample timesteps for a batch.

        :param batch_size: the number of timesteps.
        :param device: the torch device to save to.
        :return: a tuple (timesteps, weights):
                 - timesteps: a tensor of timestep indices.
                 - weights: a tensor of weights to scale the resulting losses.
        """
        # Unnormalized per-timestep weights (a numpy array, one entry per step).
        w = self.weights()
        # Normalize into a sampling distribution p over the len(p) timesteps.
        p = w / np.sum(w)
        # Draw batch_size timestep indices i.i.d. (with replacement) from p.
        indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
        indices = th.from_numpy(indices_np).long().to(device)
        # Importance weights 1 / (N * p_t): they keep the weighted loss an
        # unbiased estimate of the loss averaged uniformly over timesteps.
        weights_np = 1 / (len(p) * p[indices_np])
        weights = th.from_numpy(weights_np).float().to(device)
        return indices, weights


class UniformSampler(ScheduleSampler):
```
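Why `sample` returns the weights `1 / (len(p) * p[indices_np])`: write N = `len(p)` for the number of timesteps, p_t for the normalized sampling probability of timestep t, and L_t for the loss at timestep t (a symbol introduced here for the derivation). Scaling each sampled loss by 1 / (N p_t) makes its expectation under p equal to the plain average over all timesteps, provided every p_t > 0:

$$
p_t = \frac{w_t}{\sum_{s=0}^{N-1} w_s}, \qquad
\mathbb{E}_{t \sim p}\!\left[\frac{L_t}{N\,p_t}\right]
  = \sum_{t=0}^{N-1} p_t \,\frac{L_t}{N\,p_t}
  = \frac{1}{N}\sum_{t=0}^{N-1} L_t
  = \mathbb{E}_{t \sim \mathrm{Uniform}\{0,\dots,N-1\}}\!\left[L_t\right].
$$

So changing `self.weights()` changes which timesteps are drawn more often, but not the expected value of the weighted loss.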
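The sampling logic in `sample` can be checked outside the class with a minimal NumPy-only sketch. It assumes a hypothetical helper `importance_sample_timesteps` that takes the weight array directly instead of calling `self.weights()`, uses `np.random.Generator.choice` in place of the global `np.random.choice`, and skips the conversion to torch tensors. A second hypothetical helper, `exact_weighted_mean`, sums over every timestep to confirm that the 1 / (N p_t) weights recover the uniform average. The weights and losses in the demo are made-up values.

```python
import numpy as np


def importance_sample_timesteps(w, batch_size, rng=None):
    """Hypothetical helper: draw timestep indices with probability proportional to w.

    Returns (indices, weights) with weights = 1 / (N * p[indices]), mirroring
    the arithmetic in ScheduleSampler.sample but without torch tensors.
    """
    rng = np.random.default_rng() if rng is None else rng
    w = np.asarray(w, dtype=np.float64)
    p = w / np.sum(w)                       # normalize into a distribution
    indices = rng.choice(len(p), size=(batch_size,), p=p)
    weights = 1.0 / (len(p) * p[indices])   # importance weights
    return indices, weights


def exact_weighted_mean(w, losses):
    """Hypothetical helper: E_{t~p}[L_t / (N p_t)], summed exactly over all t."""
    w = np.asarray(w, dtype=np.float64)
    losses = np.asarray(losses, dtype=np.float64)
    p = w / np.sum(w)
    return float(np.sum(p * (1.0 / (len(p) * p)) * losses))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    w = np.array([1.0, 2.0, 3.0, 4.0])       # made-up per-timestep weights
    losses = np.array([0.5, 1.0, 2.0, 4.0])  # made-up per-timestep losses
    idx, iw = importance_sample_timesteps(w, batch_size=100_000, rng=rng)
    print("uniform mean:", losses.mean())
    print("exact weighted mean:", exact_weighted_mean(w, losses))
    print("Monte Carlo estimate:", np.mean(iw * losses[idx]))
```

For these made-up values the uniform mean is 1.875; the exact weighted sum matches it up to floating-point rounding, and the Monte Carlo average of `iw * losses[idx]` lands close to it even though timestep 3 is drawn four times as often as timestep 0.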