MCPcopy Index your code
hub / github.com/InternLM/InternLM / DataParallelSampler

Class DataParallelSampler

internlm/data/batch_sampler.py:21–107  ·  view source on GitHub ↗

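A data sampler for distributed data parallelism. Args: dataset (:class:`torch.utils.data.Dataset`): The Dataset for sampling. shuffle (bool, optional): Whether to shuffle data, defaults to False. seed (int, optional): The random seed used for sampling, defaults to 0. drop_last (bool, optional): Controls how a dataset whose length is not divisible by the number of data-parallel replicas is split. If True, each rank receives the rounded-down share, floor(len(dataset) / num_replicas), and the tail samples are dropped. If False, each rank receives ceil(len(dataset) / num_replicas) samples. Either way, every rank gets the same number of samples. Defaults to False.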

Source from the content-addressed store, hash-verified

19
20
21class DataParallelSampler(Sampler):
22 """A data sampler for distributed data parallelism.
23
24 Args:
25 dataset (:class:`torch.utils.data.Dataset`): The Dataset for sampling.
26 shuffle (bool, optional): Whether to shuffle data, defaults to False.
27 seed (int, optional): The random seed used for sampling, defaults to 0.
28 drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
29 is not divisible by the batch size. If False and the size of dataset is not divisible by
30 the batch size, then the last batch will be smaller, defaults to False.
31 """
32
33 def __init__(
34 self,
35 dataset: Dataset,
36 shuffle: bool = False,
37 seed: int = 0,
38 drop_last: bool = False,
39 ) -> None:
40 self.dataset = dataset
41 self.num_replicas = gpc.get_world_size(ParallelMode.DATA)
42 self.rank = gpc.get_local_rank(ParallelMode.DATA)
43 self.epoch = 0
44 self.drop_last = drop_last
45 # If the dataset length is evenly divisible by # of replicas, then there
46 # is no need to drop any data, since the dataset will be split equally.
47 # type: ignore[arg-type]
48 if self.drop_last and len(self.dataset) % self.num_replicas != 0:
49 # Split to nearest available length that is evenly divisible.
50 # This is to ensure each rank receives the same amount of data when
51 # using this Sampler.
52 self.num_samples = math.ceil(
53 # `type:ignore` is required because Dataset cannot provide a default __len__
54 # see NOTE in pytorch/torch/utils/data/sampler.py
55 (len(self.dataset) - self.num_replicas)
56 / self.num_replicas # type: ignore[arg-type]
57 )
58 else:
59 self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
60 self.total_size = self.num_samples * self.num_replicas
61 self.shuffle = shuffle
62 self.seed = seed
63
64 def __iter__(self) -> Iterator[T_co]:
65 if self.shuffle:
66 # deterministically shuffle based on epoch and seed
67 g = torch.Generator()
68 g.manual_seed(self.seed + self.epoch)
69 # type: ignore[arg-type]
70 indices = torch.randperm(len(self.dataset), generator=g).tolist()
71
72 # update for next epoch so that there is no need to call
73 # set_epoch manually
74 self.epoch += 1
75 else:
76 indices = list(range(len(self.dataset))) # type: ignore[arg-type]
77
78 if not self.drop_last:

Callers 1

get_dpsampler_dataloader · Function · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected