(data_iterator, args, timers)
| 254 | |
| 255 | |
def get_batch(data_iterator, args, timers):
    """Fetch one batch, broadcast it, and derive the model inputs.

    Args:
        data_iterator: Iterator yielding batches, or ``None`` when this
            caller has no data of its own; in that case ``None`` is handed
            to ``mpu.broadcast_data``.
        args: Run configuration. ``args.fp16`` is read here, and the whole
            object is forwarded to ``get_masks_and_position_ids``.
        timers: Callable returning a timer with ``start()``/``stop()`` for
            a given name; only the ``'data loader'`` timer is used.

    Returns:
        A tuple ``(tokens, labels, loss_mask, attention_mask,
        position_ids)``. ``tokens`` is the text without its last column and
        ``labels`` is the text without its first column. The two masks and
        the position ids are whatever ``get_masks_and_position_ids``
        returns, with the attention mask cast to half precision when
        ``args.fp16`` is set.

    Raises:
        StopIteration: Propagated from ``next(data_iterator)`` when the
            iterator is exhausted.
    """
    # Fields pulled from each batch and the dtype used for the broadcast.
    field_names = ['text', 'loss_mask']
    broadcast_dtype = torch.int64

    # Time only the fetch itself, not the broadcast that follows.
    timers('data loader').start()
    batch = next(data_iterator) if data_iterator is not None else None
    timers('data loader').stop()

    # NOTE(review): presumably distributes the batch to the ranks that did
    # not load it themselves -- confirm against mpu.broadcast_data.
    broadcast = mpu.broadcast_data(field_names, batch, broadcast_dtype)

    # Unpack the broadcast fields.
    full_tokens = broadcast['text'].long()
    full_loss_mask = broadcast['loss_mask'].float()

    # Shift by one position: inputs drop the final column, while the labels
    # and the loss mask drop the first column.
    tokens = full_tokens[:, :-1].contiguous()
    labels = full_tokens[:, 1:].contiguous()
    shifted_loss_mask = full_loss_mask[:, 1:].contiguous()

    # Get the masks and position ids. No attention mask exists yet, so None
    # is passed and the helper's result is used.
    attention_mask, loss_mask, position_ids = get_masks_and_position_ids(
        tokens,
        loss_mask=shifted_loss_mask,
        attention_mask=None,
        args=args,
    )

    # Convert the attention mask to half precision for fp16 runs.
    if args.fp16:
        attention_mask = attention_mask.half()

    return tokens, labels, loss_mask, attention_mask, position_ids
| 290 | |
| 291 | |
| 292 | def forward_step(data_iterator, model, args, timers, mems): |
no test coverage detected