Repository: github.com/deepspeedai/DeepSpeedExamples

Function get_batch

Megatron-LM/pretrain_gpt2.py, lines 231–272

get_batch reads one batch of token ids on ranks that hold a data iterator, shares it across the model-parallel group with mpu.broadcast_data, splits each row into model inputs and next-token labels, and returns them together with the loss mask, attention mask and position ids built by get_masks_and_position_ids.

(data_iterator, args, timers)
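The data_iterator parameter may be None: get_batch only calls next() when an iterator is present and relies on mpu.broadcast_data to deliver the batch to the other ranks of the model-parallel group. The single-process sketch below simulates that pattern with plain lists. simulated_broadcast and read_local_batch are hypothetical helpers, not the mpu API, and the sketch ignores the size exchange, dtype handling and process groups a real distributed broadcast needs.

```python
from typing import Dict, Iterator, List, Optional

# A batch maps a key such as 'text' to rows of token ids (lists stand in for tensors).
Batch = Dict[str, List[List[int]]]


def read_local_batch(data_iterator: Optional[Iterator[Batch]]) -> Optional[Batch]:
    # Same branch as get_batch: only ranks that hold an iterator read data.
    return next(data_iterator) if data_iterator is not None else None


def simulated_broadcast(per_rank: List[Optional[Batch]], keys: List[str],
                        src: int = 0) -> List[Batch]:
    # Hypothetical stand-in for mpu.broadcast_data: every rank ends up with
    # its own copy of the source rank's data for the requested keys.
    source = per_rank[src]
    if source is None:
        raise ValueError("the source rank must have read a batch")
    return [{key: [row[:] for row in source[key]] for key in keys}
            for _ in per_rank]


# Four ranks in one model-parallel group; only rank 0 owns an iterator.
iterators = [iter([{'text': [[5, 6, 0, 7], [1, 2, 3, 4]]}]), None, None, None]
local = [read_local_batch(it) for it in iterators]
received = simulated_broadcast(local, keys=['text'])
```

Reading on a single source rank and broadcasting keeps every model-parallel rank on identical tokens without each rank needing its own data loader.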


```python
# Excerpt from Megatron-LM/pretrain_gpt2.py; torch, mpu and
# get_masks_and_position_ids are defined or imported at module level.
def get_batch(data_iterator, args, timers):
    '''Fetch one batch of token ids and build the GPT-2 model inputs.

    Ranks that were given a data_iterator read the next batch;
    mpu.broadcast_data then distributes the 'text' tensor from the
    source rank of the model-parallel group, so ranks passed
    data_iterator=None still receive it. Each row is split along
    dimension 1 (the sequence dimension; dimension 0, the batch, is
    untouched) into inputs (all tokens except the last) and labels
    (all tokens except the first), so labels[:, t] is the token that
    follows tokens[:, t]. The attention mask, loss mask and position
    ids come from get_masks_and_position_ids; the attention mask is
    cast to half precision when args.fp16 is set.
    '''
    # Keys to broadcast and their dtype.
    keys = ['text']
    datatype = torch.int64

    # Read a batch on ranks that hold an iterator, then broadcast it.
    timers('data loader').start()
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    timers('data loader').stop()
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack: inputs drop the last token, labels drop the first.
    tokens_ = data_b['text'].long()
    labels = tokens_[:, 1:].contiguous()
    tokens = tokens_[:, :-1].contiguous()

    # Get the masks and position ids.
    attention_mask, loss_mask, position_ids = get_masks_and_position_ids(
        tokens,
        args.eod_token,
        args.reset_position_ids,
        args.reset_attention_mask)
    # Cast the attention mask to half precision for fp16 training.
    if args.fp16:
        attention_mask = attention_mask.half()

    return tokens, labels, loss_mask, attention_mask, position_ids
```
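get_batch builds next-token labels purely by slicing and delegates mask construction to get_masks_and_position_ids, whose body is not shown on this page. The stdlib sketch below works on a single row of Python ints. The slicing matches get_batch; the mask logic is an assumption about what such a helper typically does (a causal attention mask, EOD positions excluded from the loss, and optional resets at EOD boundaries when args.reset_position_ids or args.reset_attention_mask is set). split_inputs_labels and sketch_masks_and_position_ids are hypothetical names.

```python
from typing import List, Tuple


def split_inputs_labels(sample: List[int]) -> Tuple[List[int], List[int]]:
    # One-row version of tokens_[:, :-1] and tokens_[:, 1:] in get_batch.
    return sample[:-1], sample[1:]


def sketch_masks_and_position_ids(tokens: List[int], eod_token: int,
                                  reset_position_ids: bool,
                                  reset_attention_mask: bool):
    """Hypothetical single-row stand-in for get_masks_and_position_ids."""
    n = len(tokens)
    # Causal mask: row i may attend to columns 0..i (1 = allowed in this sketch).
    attention = [[1 if j <= i else 0 for j in range(n)] for i in range(n)]
    # Assumption: positions whose input token is EOD are excluded from the loss.
    loss_mask = [0.0 if t == eod_token else 1.0 for t in tokens]
    position_ids: List[int] = []
    pos = 0
    for k, t in enumerate(tokens):
        position_ids.append(pos)
        pos += 1
        if t != eod_token:
            continue
        if reset_position_ids:
            pos = 0  # the next document restarts at position 0
        if reset_attention_mask:
            # Tokens after this EOD may not attend to anything up to and including it.
            for row in range(k + 1, n):
                for col in range(k + 1):
                    attention[row][col] = 0
    return attention, loss_mask, position_ids


EOD = 0
inputs, labels = split_inputs_labels([5, 6, EOD, 7, 8, 9])
attention, loss_mask, position_ids = sketch_masks_and_position_ids(
    inputs, EOD, reset_position_ids=True, reset_attention_mask=True)
```

Here inputs is [5, 6, 0, 7, 8] and labels is [6, 0, 7, 8, 9]; with both resets enabled the position ids become [0, 1, 2, 0, 1], so the document after the EOD token is positioned and masked as if it started a fresh sequence.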

Callers 1

forward_step (function) · 0.70
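forward_step is listed as the only caller. A common way for a training step to use the loss_mask that get_batch returns is a mask-weighted mean of per-token losses, so masked positions contribute nothing. The sketch below assumes that reduction; masked_mean_loss is a hypothetical helper over plain lists, not code taken from forward_step.

```python
from typing import Sequence


def masked_mean_loss(losses: Sequence[float], loss_mask: Sequence[float]) -> float:
    # Weighted average of per-token losses; positions with mask 0 are ignored.
    if len(losses) != len(loss_mask):
        raise ValueError("losses and loss_mask must have the same length")
    total_weight = sum(loss_mask)
    if total_weight == 0:
        raise ValueError("loss_mask excludes every position")
    return sum(l * m for l, m in zip(losses, loss_mask)) / total_weight


# The masked third position does not affect the result: (2.0 + 4.0) / 2.
loss = masked_mean_loss([2.0, 4.0, 100.0], [1.0, 1.0, 0.0])
```

Dividing by the mask sum rather than the sequence length keeps the loss scale independent of how many EOD positions a batch happens to contain.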

Calls 4

half (method) · 0.80
start (method) · 0.45
stop (method) · 0.45
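The start and stop calls come from the timers argument: get_batch calls timers('data loader') to get a named timer and brackets the iterator read with start() and stop(). The sketch below is a hypothetical minimal Timers class that supports exactly that usage, accumulating wall-clock time with time.perf_counter; it is not Megatron's own Timers class and omits any reporting or logging.

```python
import time
from typing import Dict, Optional


class _Timer:
    """Hypothetical named timer exposing the start()/stop() calls used by get_batch."""

    def __init__(self) -> None:
        self.elapsed = 0.0  # accumulated seconds across start/stop pairs
        self._started_at: Optional[float] = None

    def start(self) -> None:
        if self._started_at is not None:
            raise RuntimeError("timer already started")
        self._started_at = time.perf_counter()

    def stop(self) -> None:
        if self._started_at is None:
            raise RuntimeError("timer was not started")
        self.elapsed += time.perf_counter() - self._started_at
        self._started_at = None


class Timers:
    """Calling timers(name) returns the same timer for a name, creating it on first use."""

    def __init__(self) -> None:
        self._timers: Dict[str, _Timer] = {}

    def __call__(self, name: str) -> _Timer:
        if name not in self._timers:
            self._timers[name] = _Timer()
        return self._timers[name]


timers = Timers()
timers('data loader').start()
batch = next(iter([[1, 2, 3]]))  # stand-in for next(data_iterator)
timers('data loader').stop()
```

Returning the same object for a given name is what lets get_batch call timers('data loader') twice and have stop() close the interval that start() opened.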

Tested by

no test coverage detected