MCPcopy
hub / github.com/zai-org/CogView / get_batch

Function get_batch

pretrain_gpt2.py:256–289  ·  view source on GitHub ↗
(data_iterator, args, timers)

Source from the content-addressed store, hash-verified

254
255
def get_batch(data_iterator, args, timers):
    """Fetch one batch, broadcast it, and build next-token training inputs.

    Args:
        data_iterator: Iterator yielding a batch containing the 'text' and
            'loss_mask' entries, or None. When None, ``data=None`` is handed
            to ``mpu.broadcast_data`` (presumably this rank then receives the
            batch from another rank -- confirm in ``mpu``).
        args: Run configuration. ``args.fp16`` is read here and the whole
            object is forwarded to ``get_masks_and_position_ids``.
        timers: Callable returning a timer for a given name; the
            'data loader' timer brackets only the iterator fetch.

    Returns:
        Tuple ``(tokens, labels, loss_mask, attention_mask, position_ids)``.
        ``tokens`` is 'text' without its last column and ``labels`` is 'text'
        without its first column, so the two are offset by one position.
        ``loss_mask``, ``attention_mask`` and ``position_ids`` are whatever
        ``get_masks_and_position_ids`` returns, with ``attention_mask``
        converted to half precision when ``args.fp16`` is set.
    """
    # Only the fetch from the local iterator is timed, not the broadcast.
    timers('data loader').start()
    data = next(data_iterator) if data_iterator is not None else None
    timers('data loader').stop()

    # Both entries are broadcast as int64; the mask becomes float afterwards.
    broadcast = mpu.broadcast_data(['text', 'loss_mask'], data, torch.int64)
    full_tokens = broadcast['text'].long()
    full_mask = broadcast['loss_mask'].float()

    # Shift by one along the second dimension: inputs drop the final column,
    # while labels and their loss mask drop the first.
    tokens = full_tokens[:, :-1].contiguous()
    labels = full_tokens[:, 1:].contiguous()
    shifted_mask = full_mask[:, 1:].contiguous()

    # No attention mask exists yet, so None is passed and the helper's
    # returned mask is used from here on.
    attention_mask, loss_mask, position_ids = get_masks_and_position_ids(
        tokens,
        loss_mask=shifted_mask,
        attention_mask=None,
        args=args
    )

    if args.fp16:
        attention_mask = attention_mask.half()

    return tokens, labels, loss_mask, attention_mask, position_ids
290
291
292def forward_step(data_iterator, model, args, timers, mems):

Callers 1

forward_stepFunction · 0.70

Calls 3

startMethod · 0.80
stopMethod · 0.80

Tested by

no test coverage detected