MCPcopy
hub / github.com/hustvl/Vim / RASampler

Class RASampler

vim/samplers.py:8–64  ·  view source on GitHub ↗

Sampler that restricts data loading to a subset of the dataset for distributed training, with repeated augmentation. It ensures that each augmented version of a sample will be visible to a different process (GPU). Heavily based on torch.utils.data.DistributedSampler.

Source from the content-addressed store, hash-verified

6
7
class RASampler(torch.utils.data.Sampler):
    """Distributed sampler with repeated augmentation (RA).

    Restricts data loading to a per-process subset of the dataset for
    distributed training. Every dataset index is repeated ``num_repeats``
    times in a row and the resulting list is dealt out round-robin across
    ranks. The repeated (differently augmented) copies of a sample therefore
    land on different processes (GPUs) whenever ``num_repeats <= num_replicas``.
    Heavily based on ``torch.utils.data.DistributedSampler``.

    Each epoch yields only ``len(self)`` indices per rank. The dataset length
    is first rounded down to a multiple of ``sample_multiple`` (256 by
    default) and then split evenly across replicas. An epoch therefore covers
    roughly ``len(dataset)`` draws in total, not ``num_repeats`` times that.

    Args:
        dataset: Dataset to sample from; only ``len(dataset)`` is used.
        num_replicas: Number of participating processes. Defaults to
            ``torch.distributed.get_world_size()``.
        rank: Rank of the current process. Defaults to
            ``torch.distributed.get_rank()``.
        shuffle: If True, indices are shuffled with a generator seeded by the
            current epoch (see ``set_epoch``), so all ranks agree on the order.
        num_repeats: How many times each sample is repeated (must be >= 1).
        sample_multiple: The dataset length is rounded down to a multiple of
            this value before being split across replicas (must be >= 1).

    Raises:
        RuntimeError: If ``num_replicas`` or ``rank`` has to be inferred but
            ``torch.distributed`` is not available.
        ValueError: If ``num_repeats``, ``sample_multiple`` or
            ``num_replicas`` is < 1, or ``rank`` is not in ``[0, num_replicas)``.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, num_repeats: int = 3,
                 sample_multiple: int = 256):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if num_repeats < 1:
            raise ValueError("num_repeats should be greater than 0")
        if sample_multiple < 1:
            raise ValueError("sample_multiple should be greater than 0")
        if num_replicas < 1:
            raise ValueError("num_replicas should be greater than 0")
        if not 0 <= rank < num_replicas:
            raise ValueError(
                f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]"
            )
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.num_repeats = num_repeats
        self.sample_multiple = sample_multiple
        self.epoch = 0
        dataset_len = len(self.dataset)
        # Per-rank share of the repeated index list, rounded UP (integer ceil
        # division) so the padded list splits evenly across replicas.
        self.num_samples = (dataset_len * self.num_repeats + self.num_replicas - 1) // self.num_replicas
        self.total_size = self.num_samples * self.num_replicas
        # Indices actually yielded per rank each epoch: the dataset length is
        # rounded down to a multiple of ``sample_multiple`` and split across
        # replicas. Integer arithmetic avoids float rounding on huge datasets.
        self.num_selected_samples = (
            dataset_len // self.sample_multiple * self.sample_multiple // self.num_replicas
        )
        self.shuffle = shuffle

    def __iter__(self):
        """Return an iterator over this rank's indices for the current epoch."""
        if self.shuffle:
            # Deterministic shuffle seeded only by the epoch, so every rank
            # (given the same epoch) draws the same permutation.
            g = torch.Generator()
            g.manual_seed(self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g)
        else:
            indices = torch.arange(start=0, end=len(self.dataset))

        # Repeat each index consecutively: [a, b] -> [a, a, a, b, b, b].
        indices = torch.repeat_interleave(indices, repeats=self.num_repeats, dim=0).tolist()

        # Pad by wrapping around so the list splits evenly across replicas.
        # The padding can be longer than the list itself when the repeated
        # dataset is smaller than num_replicas, so tile the list first.
        padding_size: int = self.total_size - len(indices)
        if padding_size > 0:
            tiles = (padding_size + len(indices) - 1) // len(indices)
            indices += (indices * tiles)[:padding_size]
        assert len(indices) == self.total_size

        # Subsample: rank r takes every num_replicas-th index starting at r.
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        # Keep only the first ``len(self)`` indices for this epoch.
        return iter(indices[:self.num_selected_samples])

    def __len__(self):
        """Number of indices yielded per rank per epoch."""
        return self.num_selected_samples

    def set_epoch(self, epoch):
        """Set the epoch used to seed the shuffle in the next ``__iter__``."""
        self.epoch = epoch

Callers 1

mainFunction · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected