MCPcopy
hub / github.com/deepspeedai/DeepSpeedExamples / RandomSampler

Class RandomSampler

Megatron-LM/data_utils/samplers.py:24–73  ·  view source on GitHub ↗

Based on PyTorch's RandomSampler and DistributedSampler. Essentially a RandomSampler, but this class lets the user set an epoch, like DistributedSampler. Samples elements randomly. If without replacement, then sample from a shuffled dataset. If with replacement, then the user can specify ``num_samples`` to draw.

Source from the content-addressed store, hash-verified

22import numpy as np
23
class RandomSampler(data.sampler.Sampler):
    r"""Sample dataset indices at random, seeded by a user-settable epoch.

    Based off of pytorch RandomSampler and DistributedSampler. Essentially a
    RandomSampler, but this class lets the user set an epoch like
    DistributedSampler: once ``set_epoch`` has been called with a
    non-negative value, that value seeds the generator used by ``__iter__``,
    so iterating twice under the same epoch yields the same index order.

    If without replacement, the indices are a random permutation of
    ``range(len(data_source))``. If with replacement, ``num_samples`` indices
    are drawn uniformly from ``[0, len(data_source))``.

    Arguments:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn with replacement if ``True``,
            default=False
        num_samples (int): number of samples to draw, default=len(dataset).
            May only be given together with ``replacement=True``.

    Raises:
        ValueError: if ``num_samples`` is given while ``replacement`` is
            ``False``, if the effective ``num_samples`` is not a positive
            integer, or if ``replacement`` is not a ``bool``.
    """

    def __init__(self, data_source, replacement=False, num_samples=None):
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        # A negative epoch means "no epoch has been set yet"; __iter__ only
        # seeds its generator when the epoch is >= 0.
        self.epoch = -1

        if self._num_samples is not None and replacement is False:
            raise ValueError("With replacement=False, num_samples should not be specified, "
                             "since a random permute will be performed.")

        # Read the property once: it falls back to len(data_source) when no
        # explicit num_samples was supplied.
        effective_num_samples = self.num_samples
        if not isinstance(effective_num_samples, int) or effective_num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(effective_num_samples))
        if not isinstance(self.replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(self.replacement))

    @property
    def num_samples(self):
        """Number of indices one pass yields (explicit value or dataset length)."""
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self):
        """Return an iterator over freshly drawn indices as Python ints."""
        n = len(self.data_source)
        g = torch.Generator()
        if self.epoch >= 0:
            # Same epoch -> same seed -> same index order.
            g.manual_seed(self.epoch)
        # NOTE(review): when no epoch has been set the generator is never
        # explicitly seeded; a fresh torch.Generator presumably starts from a
        # fixed default seed, so such passes may repeat one order -- confirm
        # this is intended.
        if self.replacement:
            drawn = torch.randint(high=n, size=(self.num_samples,),
                                  dtype=torch.int64, generator=g)
        else:
            drawn = torch.randperm(n, generator=g)
        return iter(drawn.tolist())

    def __len__(self):
        """Return the number of indices produced by one pass."""
        return self.num_samples

    def set_epoch(self, epoch):
        """Record ``epoch``; non-negative values seed subsequent iterations."""
        self.epoch = epoch
74
75class DistributedBatchSampler(data.sampler.BatchSampler):
76 """

Callers 3

get_dataloaderFunction · 0.90
mainFunction · 0.90
mainFunction · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected