hub / github.com/THUDM/GLM / pad_batch

Method pad_batch

blocklm_utils.py:460–474
(token_batch, target_batch, loss_mask_batch, position_id_batch)

Source from the content-addressed store, hash-verified

@staticmethod
def pad_batch(token_batch, target_batch, loss_mask_batch, position_id_batch):
    seq_lengths = list(map(len, token_batch))
    if seq_lengths.count(seq_lengths[0]) != len(seq_lengths):
        max_length = max(seq_lengths)
        token_batch = [np.concatenate((tokens, np.zeros(max_length - len(tokens), dtype=np.long))) for tokens in
                       token_batch]
        target_batch = [np.concatenate((targets, np.zeros(max_length - len(targets), dtype=np.long))) for
                        targets in
                        target_batch]
        loss_mask_batch = [np.concatenate((loss_masks, np.zeros(max_length - len(loss_masks), dtype=np.long)))
                           for loss_masks in loss_mask_batch]
        position_id_batch = [
            np.concatenate((position_ids, np.zeros((2, max_length - position_ids.shape[1]), dtype=np.long)),
                           axis=1) for position_ids in position_id_batch]
    return token_batch, target_batch, loss_mask_batch, position_id_batch
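
The listing above right-pads each sample with zeros up to the longest sequence in the batch, and pads the position ids along axis 1 because each one is a 2-row array. The following is a minimal standalone sketch of that behavior, not the repository's code: the name `pad_batch_sketch`, the helpers `pad_1d` and `pad_2d`, and the toy batch are all hypothetical. It uses `np.int64` instead of `np.long`, because the `np.long` alias is unavailable in some NumPy releases (it was removed in NumPy 1.24).

```python
import numpy as np


def pad_batch_sketch(token_batch, target_batch, loss_mask_batch, position_id_batch):
    """Right-pad every sample with zeros to the longest sequence in the batch."""
    seq_lengths = [len(tokens) for tokens in token_batch]
    max_length = max(seq_lengths)
    if all(length == max_length for length in seq_lengths):
        # All samples already share one length: return the inputs unchanged.
        return token_batch, target_batch, loss_mask_batch, position_id_batch

    def pad_1d(array):
        # Hypothetical helper: append zeros to a 1-D array.
        padding = np.zeros(max_length - len(array), dtype=np.int64)
        return np.concatenate((array, padding))

    def pad_2d(array):
        # Hypothetical helper: append zero columns to a (2, length) array.
        padding = np.zeros((2, max_length - array.shape[1]), dtype=np.int64)
        return np.concatenate((array, padding), axis=1)

    return (
        [pad_1d(tokens) for tokens in token_batch],
        [pad_1d(targets) for targets in target_batch],
        [pad_1d(loss_masks) for loss_masks in loss_mask_batch],
        [pad_2d(position_ids) for position_ids in position_id_batch],
    )


# Hypothetical toy batch: two samples of length 3 and 5.
tokens = [np.array([5, 6, 7]), np.array([1, 2, 3, 4, 5])]
targets = [np.array([6, 7, 8]), np.array([2, 3, 4, 5, 6])]
loss_masks = [np.array([1, 1, 1]), np.array([1, 1, 1, 1, 1])]
position_ids = [
    np.stack([np.arange(3), np.zeros(3, dtype=np.int64)]),
    np.stack([np.arange(5), np.zeros(5, dtype=np.int64)]),
]

padded_tokens, padded_targets, padded_masks, padded_positions = pad_batch_sketch(
    tokens, targets, loss_masks, position_ids
)
print(padded_tokens[0])           # [5 6 7 0 0]
print(padded_masks[0])            # [1 1 1 0 0]
print(padded_positions[0].shape)  # (2, 5)
```

After padding, all four lists hold arrays of one common length, and the padded positions carry a loss mask of 0. When every sample already has the same length, both the listing and the sketch return the inputs untouched.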

Callers 1

construct_blocks · Method · 0.95

Calls

no outgoing calls

Tested by

no test coverage detected