r""" Based off of pytorch RandomSampler and DistributedSampler. Essentially a RandomSampler, but this class lets the user set an epoch like DistributedSampler Samples elements randomly. If without replacement, then sample from a shuffled dataset. If with replacement, then user can sp
| 22 | import numpy as np |
| 23 | |
class RandomSampler(data.sampler.Sampler):
    r"""Sample dataset indices randomly, with an optional per-epoch seed.

    Based off of pytorch RandomSampler and DistributedSampler. Essentially a
    RandomSampler, but this class lets the user set an epoch like
    DistributedSampler. If without replacement, then sample from a shuffled
    dataset. If with replacement, then the user can specify ``num_samples``
    to draw.

    Seeding: once ``set_epoch`` has been called with a non-negative value,
    that epoch number is used as the generator seed, so every pass over the
    sampler during that epoch yields the same order. Before that
    (``epoch == -1``) a fresh seed is drawn from torch's global RNG on each
    pass, so successive passes differ but remain reproducible under
    ``torch.manual_seed``.

    Arguments:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn with replacement if ``True``,
            default=False
        num_samples (int): number of samples to draw, default=len(dataset).
            May only be given when ``replacement=True``.

    Raises:
        ValueError: if ``num_samples`` is given with ``replacement=False``,
            if the effective ``num_samples`` is not a positive int, or if
            ``replacement`` is not a bool.
    """

    def __init__(self, data_source, replacement=False, num_samples=None):
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        # -1 means "no epoch set yet"; __iter__ then seeds from the global RNG.
        self.epoch = -1

        if self._num_samples is not None and replacement is False:
            raise ValueError("With replacement=False, num_samples should not be specified, "
                             "since a random permute will be performed.")

        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(self.num_samples))
        if not isinstance(self.replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(self.replacement))

    @property
    def num_samples(self):
        """Indices yielded per pass: the explicit ``num_samples`` or len(data_source)."""
        # dataset size might change at runtime, so re-read it on every access
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self):
        """Return an iterator over sampled indices in ``range(len(data_source))``."""
        n = len(self.data_source)
        g = torch.Generator()
        if self.epoch >= 0:
            seed = self.epoch
        else:
            # A bare torch.Generator() starts from a fixed default seed, so
            # leaving it unseeded would repeat the same "random" order on
            # every pass. Draw the seed from the global RNG instead (this is
            # what torch's own RandomSampler does), which keeps results
            # reproducible under torch.manual_seed.
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
        g.manual_seed(seed)
        if self.replacement:
            return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64, generator=g).tolist())
        return iter(torch.randperm(n, generator=g).tolist())

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        """Fix the shuffle seed to ``epoch`` (non-negative) for subsequent passes."""
        self.epoch = epoch
| 74 | |
| 75 | class DistributedBatchSampler(data.sampler.BatchSampler): |
| 76 | """ |
no outgoing calls
no test coverage detected