(context_tokens, device, args)
| 49 | return logits |
| 50 | |
def get_batch(context_tokens, device, args):
    """Turn context token ids into a model-ready batch.

    Args:
        context_tokens: Tensor of token ids. A 1-D tensor is treated as a
            single sequence and given a leading batch dimension of size 1;
            any other tensor is flattened to ``(context_tokens.shape[0], -1)``.
        device: Target passed to ``Tensor.to`` for the resulting tokens.
        args: Not used by this function; kept so existing call sites that
            pass it continue to work.

    Returns:
        A tuple ``(tokens, attention_mask, position_ids)``. ``tokens`` is the
        2-D contiguous tensor moved to ``device``. ``attention_mask`` and
        ``position_ids`` are the first and third values returned by
        ``get_masks_and_position_ids(tokens)``.
    """
    tokens = context_tokens
    if len(tokens.shape) == 1:
        # Single sequence: add a batch dimension of size 1.
        tokens = tokens.unsqueeze(0).contiguous()
    else:
        # Use reshape rather than view. view() raises on non-contiguous
        # inputs whose trailing dims cannot be merged without a copy, while
        # reshape() copies only when it has to. contiguous() still
        # guarantees a dense layout afterwards.
        tokens = tokens.reshape(tokens.shape[0], -1).contiguous()
    tokens = tokens.to(device)

    # Get the masks and position ids; the loss mask is not needed here.
    attention_mask, _loss_mask, position_ids = get_masks_and_position_ids(
        tokens)
    return tokens, attention_mask, position_ids
| 63 | |
| 64 | def filling_sequence( |
| 65 | model, |
no test coverage detected