| 458 | |
| 459 | @staticmethod |
def pad_batch(token_batch, target_batch, loss_mask_batch, position_id_batch):
    """Right-pad every sample of a batch with zeros up to the longest token sequence.

    The target length is ``max(len(tokens) for tokens in token_batch)``. When
    the batch is empty, or every token sequence already has the same length,
    the four inputs are returned unchanged (the very same objects).

    Args:
        token_batch: sequence of 1-D arrays; their lengths decide the padded
            length.
        target_batch: sequence of 1-D arrays, padded to the same length.
        loss_mask_batch: sequence of 1-D arrays, padded to the same length.
        position_id_batch: sequence of 2-D arrays padded along axis 1 (the
            original code assumed exactly two rows; any row count is accepted
            here).

    Returns:
        A tuple ``(token_batch, target_batch, loss_mask_batch,
        position_id_batch)``. When padding was needed, each element is a new
        list of freshly concatenated arrays.

    Raises:
        ValueError: if a target / loss-mask / position-id entry is longer than
            the longest token sequence (the pad width would be negative).
            NOTE(review): entries are presumably aligned with their token
            sequence in length -- confirm against the caller.
    """
    seq_lengths = [len(tokens) for tokens in token_batch]
    # An empty batch or a batch of uniform length needs no padding. Using a
    # set also avoids indexing seq_lengths[0], which failed on empty input.
    if len(set(seq_lengths)) <= 1:
        return token_batch, target_batch, loss_mask_batch, position_id_batch

    max_length = max(seq_lengths)
    # np.long was removed in NumPy 1.24 and maps to a 32-bit C long on some
    # platforms in NumPy 2.x; an explicit 64-bit integer is unambiguous.
    pad_dtype = np.int64

    def _pad_1d(sequence):
        # Append zeros so that the sequence reaches max_length.
        padding = np.zeros(max_length - len(sequence), dtype=pad_dtype)
        return np.concatenate((sequence, padding))

    def _pad_position_ids(position_ids):
        # Append zero columns along axis 1; the row count follows the input.
        rows, length = position_ids.shape[0], position_ids.shape[1]
        padding = np.zeros((rows, max_length - length), dtype=pad_dtype)
        return np.concatenate((position_ids, padding), axis=1)

    token_batch = [_pad_1d(tokens) for tokens in token_batch]
    target_batch = [_pad_1d(targets) for targets in target_batch]
    loss_mask_batch = [_pad_1d(loss_masks) for loss_masks in loss_mask_batch]
    position_id_batch = [_pad_position_ids(position_ids) for position_ids in position_id_batch]
    return token_batch, target_batch, loss_mask_batch, position_id_batch