Sampler that restricts data loading to a subset of the dataset for distributed training, with repeated augmentation. It ensures that each augmented version of a sample is visible to a different process (GPU). Heavily based on torch.utils.data.DistributedSampler.
| 6 | |
| 7 | |
class RASampler(torch.utils.data.Sampler):
    """Sampler that restricts data loading to a subset of the dataset for distributed
    training, with repeated augmentation.

    Every dataset index is repeated ``num_repeats`` times consecutively and the
    resulting list is dealt out to ranks with a stride of ``num_replicas``. As a
    result, the copies of a sample land on different processes (GPUs) whenever
    ``num_repeats <= num_replicas``.
    Heavily based on ``torch.utils.data.DistributedSampler``.

    Args:
        dataset: Dataset to sample from; only ``len(dataset)`` is used.
        num_replicas: Number of participating processes. Defaults to
            ``dist.get_world_size()``.
        rank: Rank of the current process, in ``[0, num_replicas)``. Defaults to
            ``dist.get_rank()``.
        shuffle: If True, indices are permuted with a generator seeded by the
            current epoch (see ``set_epoch``); otherwise they are in order.
        num_repeats: Number of times each index is repeated; must be >= 1.
        selected_round: Controls how many indices each rank yields per epoch.
            If non-zero, the dataset length is first rounded down to a multiple
            of ``selected_round`` and then split across replicas, i.e.
            ``(len(dataset) // selected_round * selected_round) // num_replicas``.
            The default of 256 matches the original behaviour; note that it
            yields an empty epoch for datasets shorter than ``selected_round``.
            If 0, each rank yields ``ceil(len(dataset) / num_replicas)`` indices.

    Raises:
        RuntimeError: If ``num_replicas`` or ``rank`` is omitted and
            ``torch.distributed`` is not available.
        ValueError: If ``rank`` is outside ``[0, num_replicas)``, if
            ``num_repeats < 1``, or if ``selected_round < 0``.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, num_repeats: int = 3,
                 selected_round: int = 256):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        # Validate early: an out-of-range rank would otherwise only surface as an
        # opaque assertion failure inside __iter__.
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]")
        if num_repeats < 1:
            raise ValueError("num_repeats should be greater than 0")
        if selected_round < 0:
            raise ValueError("selected_round should be non-negative")
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.num_repeats = num_repeats
        self.epoch = 0
        # Per-rank share of the repeated index list (padded so it divides evenly).
        self.num_samples = int(math.ceil(len(self.dataset) * self.num_repeats / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        # Number of indices each rank actually yields per epoch: a prefix of its
        # share, so an epoch covers roughly one pass over the original dataset.
        if selected_round:
            rounded_len = len(self.dataset) // selected_round * selected_round
            self.num_selected_samples = rounded_len // self.num_replicas
        else:
            self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
        self.shuffle = shuffle

    def __iter__(self):
        """Yield this rank's indices for the current epoch."""
        if self.shuffle:
            # deterministically shuffle based on epoch, so every rank computes the
            # same permutation for a given epoch
            g = torch.Generator()
            g.manual_seed(self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g)
        else:
            indices = torch.arange(start=0, end=len(self.dataset))

        # repeat each index consecutively: [a, b] -> [a, a, a, b, b, b]
        indices = torch.repeat_interleave(indices, repeats=self.num_repeats, dim=0).tolist()
        # add extra samples to make it evenly divisible
        padding_size: int = self.total_size - len(indices)
        if padding_size > 0:
            # padding_size > 0 implies a non-empty dataset, so len(indices) >= 1
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                # tiny dataset: the padding is longer than the list, so cycle it
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        assert len(indices) == self.total_size

        # subsample with a stride of num_replicas so consecutive copies of an
        # index go to different ranks
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices[:self.num_selected_samples])

    def __len__(self):
        """Return the number of indices this rank yields per epoch."""
        return self.num_selected_samples

    def set_epoch(self, epoch):
        """Set the epoch used to seed the shuffle in ``__iter__``."""
        self.epoch = epoch