Function get_batch

Megatron-LM/pretrain_bert.py:168–202 (github.com/deepspeedai/DeepSpeedExamples)

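get_batch reads the next batch from data_iterator, timing the read under the 'data loader' timer. It then broadcasts the batch across the model-parallel group with mpu.broadcast_data. Finally, it unpacks the batch into six tensors: tokens, token types, next-sentence labels, loss mask, masked-LM labels, and padding mask. The docstring describes something else: splitting source data into args.seq_length chunks along dimension 0. That text appears to be carried over from a language-model example. The body does no chunking, and args is not a parameter.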
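Here is a single-process sketch of the fetch, broadcast, and unpack steps, written without torch.distributed. It assumes a two-rank group where only rank 0 owns a data iterator. fake_broadcast and simulated_get_batch are hypothetical stand-ins that use plain lists instead of tensors; they are not Megatron APIs, and the batch values are toy data.

```python
from typing import Dict, Iterator, List, Optional

# Field names copied from get_batch; the values used below are toy data.
KEYS = ['text', 'types', 'is_random', 'mask', 'mask_labels', 'pad_mask']


def fake_broadcast(keys: List[str],
                   per_rank: List[Optional[Dict[str, List[int]]]]) -> List[Dict[str, List[int]]]:
    """Hypothetical stand-in for mpu.broadcast_data: copy rank 0's batch to every rank."""
    source = per_rank[0]
    if source is None:
        raise ValueError("the source rank (assumed to be rank 0) must have data")
    return [{k: list(source[k]) for k in keys} for _ in per_rank]


def simulated_get_batch(iterators: List[Optional[Iterator[Dict[str, List[int]]]]]):
    # Step 1: only ranks that own an iterator read from it; the others hold None.
    per_rank = [next(it) if it is not None else None for it in iterators]
    # Step 2: after the "broadcast", every rank holds an identical copy of the batch.
    received = fake_broadcast(KEYS, per_rank)
    # Step 3: unpack and cast, mirroring .long() / .float() in the original.
    results = []
    for data_b in received:
        results.append((
            [int(x) for x in data_b['text']],
            [int(x) for x in data_b['types']],
            [int(x) for x in data_b['is_random']],
            [float(x) for x in data_b['mask']],
            [int(x) for x in data_b['mask_labels']],
            [int(x) for x in data_b['pad_mask']],
        ))
    return results


toy_batch = {'text': [101, 7, 8, 102], 'types': [0, 0, 1, 1], 'is_random': [0],
             'mask': [0, 1, 0, 0], 'mask_labels': [-1, 7, -1, -1], 'pad_mask': [0, 0, 0, 0]}
# Two ranks in the group: rank 0 owns the loader, rank 1 does not.
outputs = simulated_get_batch([iter([toy_batch]), None])
print(outputs[1][3])  # loss mask on rank 1, received via the "broadcast": [0.0, 1.0, 0.0, 0.0]
```

Rank 1 never touches a data loader, yet it ends up with the same six fields as rank 0. That is the point of routing the batch through a broadcast.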

(data_iterator, timers)


```python
# Module-level imports this excerpt relies on; `mpu` is Megatron's
# model-parallel package (exact import path assumed from the repo layout).
import torch
import mpu


def get_batch(data_iterator, timers):
    ''' get_batch subdivides the source data into chunks of
    length args.seq_length. If source is equal to the example
    output of the data loading example, with a seq_length limit
    of 2, we'd get the following two Variables for i = 0:
    ┌ a g m s ┐ ┌ b h n t ┐
    └ b h n t ┘ └ c i o u ┘
    Note that despite the name of the function, the subdivision of data is not
    done along the batch dimension (i.e. dimension 1), since that was handled
    by the data loader. The chunks are along dimension 0, corresponding
    to the seq_len dimension in the LSTM. A Variable representing an appropriate
    shard reset mask of the same dimensions is also returned.
    '''
    # Items and their type.
    keys = ['text', 'types', 'is_random', 'mask', 'mask_labels', 'pad_mask']
    datatype = torch.int64

    # Read the next batch, timing only the data-loader call. Ranks without a
    # data iterator (data_iterator is None) get the batch from the broadcast.
    timers('data loader').start()
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    timers('data loader').stop()

    # Broadcast the int64 tensors named in `keys` across the model-parallel
    # group so every rank in the group works on the same batch.
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack and cast: integer fields as int64, the loss mask as float,
    # and the padding mask as uint8.
    tokens = data_b['text'].long()
    types = data_b['types'].long()
    next_sentence = data_b['is_random'].long()
    loss_mask = data_b['mask'].float()
    lm_labels = data_b['mask_labels'].long()
    padding_mask = data_b['pad_mask'].byte()

    return tokens, types, next_sentence, loss_mask, lm_labels, padding_mask
```
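The chunking that the docstring describes can be reproduced in a few lines of plain Python. This sketches the docstring's behavior only, since get_batch itself does no chunking. batchify and chunk are hypothetical names. The sketch also assumes that the "data loading example" lays 'a' through 'x' out in four columns; that layout is inferred from the matrices shown, not taken from the source.

```python
from typing import List, Tuple


def batchify(seq: List[str], n_batch: int) -> List[List[str]]:
    """Lay a flat sequence out as rows of n_batch columns (column j is one contiguous stream)."""
    n_rows = len(seq) // n_batch          # drop any remainder that does not fill a full row
    seq = seq[:n_rows * n_batch]
    # Column j holds the j-th contiguous slice of length n_rows; build it row by row.
    return [[seq[j * n_rows + r] for j in range(n_batch)] for r in range(n_rows)]


def chunk(source: List[List[str]], i: int,
          seq_length: int) -> Tuple[List[List[str]], List[List[str]]]:
    """Slice along dimension 0 (rows): inputs start at row i, targets are shifted down by one."""
    length = min(seq_length, len(source) - 1 - i)
    data = source[i:i + length]
    target = source[i + 1:i + 1 + length]
    return data, target


letters = [chr(ord('a') + k) for k in range(24)]   # 'a'..'x'
source = batchify(letters, 4)                       # 6 rows x 4 columns
data, target = chunk(source, 0, 2)
print(data)    # [['a', 'g', 'm', 's'], ['b', 'h', 'n', 't']]
print(target)  # [['b', 'h', 'n', 't'], ['c', 'i', 'o', 'u']]
```

Slicing along dimension 0 keeps each column an independent stream. This is why the docstring stresses that the batch dimension (dimension 1) is left untouched.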

Callers 1

forward_step · function · 0.70

Calls 2

start · method · 0.45
stop · method · 0.45
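The two calls listed above are the start and stop methods on the object returned by timers('data loader'). Below is a minimal sketch of that calling convention, assuming a callable registry keyed by name. Timers and _Timer are hypothetical illustrations, and the repository's own timer class may differ, for example in how it handles GPU timing.

```python
import time
from typing import Dict, Optional


class _Timer:
    """One named wall-clock timer that accumulates elapsed seconds across start/stop pairs."""

    def __init__(self) -> None:
        self.elapsed = 0.0
        self._started_at: Optional[float] = None

    def start(self) -> None:
        assert self._started_at is None, "timer already started"
        self._started_at = time.perf_counter()

    def stop(self) -> None:
        assert self._started_at is not None, "timer not started"
        self.elapsed += time.perf_counter() - self._started_at
        self._started_at = None


class Timers:
    """Callable registry: timers('name') returns the same _Timer each time, creating it on first use."""

    def __init__(self) -> None:
        self._timers: Dict[str, _Timer] = {}

    def __call__(self, name: str) -> _Timer:
        if name not in self._timers:
            self._timers[name] = _Timer()
        return self._timers[name]


timers = Timers()
timers('data loader').start()
batch = next(iter([[1, 2, 3]]))  # stands in for next(data_iterator)
timers('data loader').stop()
print(timers('data loader') is timers('data loader'))  # True
```

Looking timers up by name lets get_batch start and stop the same 'data loader' timer without holding a reference to it between the two calls.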

Tested by

no test coverage detected