hub / github.com/microsoft/Cream / OrderedDistributedSampler

Class OrderedDistributedSampler

TinyViT/data/augmentation/distributed_sampler.py:7–51

Sampler that restricts data loading to a subset of the dataset. It is especially useful in conjunction with :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each process can pass a DistributedSampler instance as a DataLoader sampler, and load a subset of the original dataset that is exclusive to it.

Source from the content-addressed store, hash-verified

# Imports inferred from the names used below (math, dist, Sampler).
import math

import torch.distributed as dist
from torch.utils.data import Sampler


class OrderedDistributedSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.
    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
    process can pass a DistributedSampler instance as a DataLoader sampler,
    and load a subset of the original dataset that is exclusive to it.
    .. note::
        Dataset is assumed to be of constant size.
    Arguments:
        dataset: Dataset used for sampling.
        num_replicas (optional): Number of processes participating in
            distributed training.
        rank (optional): Rank of the current process within num_replicas.
    """

    def __init__(self, dataset, num_replicas=None, rank=None):
        # Fall back to the default process group when values are not given.
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        # Samples per process, rounded up so every process gets the same count.
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # Sequential indices; no shuffling is applied.
        indices = list(range(len(self.dataset)))

        # add extra samples to make it evenly divisible
        # (wraps around to the start; the assert below fails if the padding
        # needed exceeds len(self.dataset))
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample: this rank takes every num_replicas-th index
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples


class RepeatAugSampler(Sampler):
    ...  # body not shown in this excerpt
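
The padding and strided subsampling in `__iter__` can be traced without a process group. The following is a minimal sketch, not code from the repository: `ordered_shard` is a hypothetical helper that mirrors the index arithmetic above, and the sizes (10 samples, 4 processes) are made up for illustration.

```python
import math


def ordered_shard(dataset_len, num_replicas, rank):
    """Hypothetical helper: indices one rank would see, mirroring __iter__."""
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    indices = list(range(dataset_len))
    # Pad by wrapping around to the start so the list splits evenly.
    indices += indices[: total_size - len(indices)]
    # Strided slice: rank r takes positions r, r + num_replicas, ...
    return indices[rank:total_size:num_replicas]


# Illustrative sizes only: 10 samples shared by 4 processes.
dataset_len, num_replicas = 10, 4
shards = [ordered_shard(dataset_len, num_replicas, r) for r in range(num_replicas)]
for rank, shard in enumerate(shards):
    print(rank, shard)
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6, 0]
# 3 [3, 7, 1]

# Interleaving the per-rank shards and trimming the padding recovers
# the original dataset order.
merged = [i for group in zip(*shards) for i in group][:dataset_len]
print(merged == list(range(dataset_len)))  # True
```

Ranks 2 and 3 each receive one wrapped-around index (0 and 1), so those two samples are seen twice across the group. Anything that aggregates per-sample results over all ranks would presumably need to trim the padded tail, as the last step of the sketch does.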

Callers 1

create_loader · Function · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected