MCPcopy Index your code
hub / github.com/deepspeedai/DeepSpeedExamples / DistributedBatchSampler

Class DistributedBatchSampler

Megatron-LM/data_utils/samplers.py:75–139  ·  view source on GitHub ↗

Similar to the standard distributed sampler implementation, except that it operates at the batch-sampler level rather than at the sampler level. This allows arbitrary data samplers (sequential, random, WeightedRandomSampler, etc.) to be wrapped with this batch sampler.

Source from the content-addressed store, hash-verified

73 self.epoch = epoch
74
75class DistributedBatchSampler(data.sampler.BatchSampler):
76 """
77 similar to normal implementation of distributed sampler, except implementation is at the
78 batch sampler level, instead of just the sampler level. This allows wrapping of arbitrary
79 data samplers (sequential, random, WeightedRandomSampler, etc.) with this batch sampler.
80 """
81 def __init__(self, sampler, batch_size, drop_last, rank=-1, world_size=2, wrap_last=False):
82 super(DistributedBatchSampler, self).__init__(sampler, batch_size, drop_last)
83 if rank == -1:
84 assert False, 'should not be here'
85 rank = torch.distributed.get_rank()
86 self.rank = rank
87 self.world_size = world_size
88 self.sampler.wrap_around = 0
89 self.wrap_around = 0
90 self.wrap_last = wrap_last
91 self.start_iter = 0
92
93 def __iter__(self):
94 batch = []
95 last_batch = None
96 i = 0
97 for idx in self.data_iterator(self.sampler, wrap_around=False):
98 batch.append(idx)
99 if len(batch) == self.batch_size:
100 tbatch = self._batch(batch)
101 if i >= self.start_iter:
102 yield tbatch
103 self.start_iter = 0
104 i += 1
105 last_batch = np.array(list(tbatch))
106 batch = []
107 batch_len = len(batch)
108 if batch_len > 0 and not self.drop_last:
109 if self.wrap_last:
110 self.sampler.wrap_around -= (self.batch_size)
111 self.wrap_around += (len(batch))
112 self.wrap_around %= self.batch_size
113 if isinstance(self.sampler, TransposedSampler):
114 for i, idx in enumerate(self.data_iterator(self.sampler, wrap_around=True)):
115 if i == 0:
116 continue
117 batch.append(idx)
118 new_batch_len = len(batch)
119 if len(batch) == self.batch_size:
120 break
121 yield self._batch(batch)
122 if self.wrap_last:
123 self.sampler.wrap_around += self.batch_size
124
125 def data_iterator(self, _iter, wrap_around=False):
126 """iterates through data and handles wrap around"""
127 for i, idx in enumerate(_iter):
128 if i < self.wrap_around%self.batch_size:
129 continue
130 if wrap_around:
131 self.wrap_around += 1
132 self.wrap_around %= self.batch_size

Callers 1

make_data_loader_Function · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected