class SubsetRandomSampler(torch.utils.data.Sampler):
    r"""Samples elements randomly from a given list of indices, without replacement.

    Every call to ``__iter__`` draws a fresh random permutation of
    ``range(len(indices))``. Each pass therefore yields every element of
    ``indices`` exactly once, in random order.

    Arguments:
        indices (sequence): a sequence of indices
        generator (torch.Generator, optional): RNG used to draw the
            permutation. Defaults to ``None``, meaning torch's global RNG is
            used (the original behavior).
    """

    def __init__(self, indices, generator=None):
        # NOTE(review): ``epoch`` is only recorded (see ``set_epoch``); it is
        # not used to seed the shuffle. Presumably it exists so training loops
        # that call ``sampler.set_epoch(e)`` work unchanged -- confirm intent.
        self.epoch = 0
        self.indices = indices
        self.generator = generator

    def __iter__(self):
        # Convert the permutation to plain Python ints once, rather than
        # indexing ``indices`` with a 0-dim tensor for every element.
        order = torch.randperm(len(self.indices), generator=self.generator).tolist()
        return (self.indices[i] for i in order)

    def __len__(self):
        """Return the number of indices this sampler yields per pass."""
        return len(self.indices)

    def set_epoch(self, epoch):
        """Record the current epoch number on ``self.epoch``.

        The stored value is not read by ``__iter__``, so it does not affect
        the sampling order.
        """
        self.epoch = epoch