(context_tokens, device, args)
| 60 | |
| 61 | |
def get_batch(context_tokens, device, args):
    """Build the model inputs for one generation step.

    ``context_tokens`` is reshaped to ``(args.batch_size, -1)``, made
    contiguous and moved to ``device``.

    When ``args.block_lm`` is set:
      * ``attention_mask`` is a one-element long tensor holding the
        sequence length (``tokens.size(1)``).
      * ``position_ids`` is ``arange(seq_length)`` with a leading
        dimension of size 1. Unless ``args.no_block_position`` is set, a
        row of zero block-position ids is stacked underneath it first,
        giving shape ``(1, 2, seq_length)`` instead of ``(1, seq_length)``.

    Otherwise both values come from ``get_masks_and_position_ids`` (called
    with position/attention resets and the loss mask disabled); the loss
    mask it returns is discarded.

    Returns:
        A ``(tokens, attention_mask, position_ids)`` tuple.
    """
    tokens = context_tokens.view(args.batch_size, -1).contiguous().to(device)

    if not args.block_lm:
        # The helper also returns a loss mask, which generation does not use.
        attention_mask, _, position_ids = get_masks_and_position_ids(
            tokens,
            args.eod_token,
            reset_position_ids=False,
            reset_attention_mask=False,
            set_loss_mask=False,
            mem_length=args.mem_length,
        )
        return tokens, attention_mask, position_ids

    seq_length = tokens.size(1)
    long_on_device = {"device": device, "dtype": torch.long}

    attention_mask = torch.tensor([seq_length], **long_on_device)
    position_ids = torch.arange(seq_length, **long_on_device)
    if not args.no_block_position:
        # Pair every position id with a block-position id of zero.
        block_position_ids = torch.zeros(seq_length, **long_on_device)
        position_ids = torch.stack((position_ids, block_position_ids), dim=0)

    return tokens, attention_mask, position_ids.unsqueeze(0)
| 85 | |
| 86 | |
| 87 | def sample_sequence(model, tokenizer, context_tokens, context_length, args, device, mems=None, end_tokens=None): |
no test coverage detected