(model, tokenizer, context_tokens, context_length, args, device, mems=None, end_tokens=None)
| 85 | |
| 86 | |
| 87 | def sample_sequence(model, tokenizer, context_tokens, context_length, args, device, mems=None, end_tokens=None): |
| 88 | if not args.block_lm: |
| 89 | context_tokens, attention_mask, position_ids = get_batch(context_tokens, device, args) |
| 90 | tokens = torch.empty((args.num_beams, 0), device=context_tokens.device, dtype=torch.long) |
| 91 | else: |
| 92 | tokens = context_tokens.new_full((1, 1), tokenizer.get_command('sop').Id) |
| 93 | counter = 0 |
| 94 | if mems is None: |
| 95 | mems = [] |
| 96 | if end_tokens is None: |
| 97 | end_tokens = [args.eod_token] |
| 98 | if args.num_beams > 1: |
| 99 | beam_scorer = BeamSearchScorer( |
| 100 | batch_size=1, |
| 101 | max_length=args.out_seq_length, |
| 102 | num_beams=args.num_beams, |
| 103 | device=context_tokens.device, |
| 104 | length_penalty=args.length_penalty, |
| 105 | do_early_stopping=False, |
| 106 | ) |
| 107 | beam_scores = torch.zeros(1, dtype=torch.float, device=context_tokens.device) |
| 108 | last_beam_num = 1 |
| 109 | while counter < args.out_seq_length: |
| 110 | if counter == 0 and not args.block_lm: |
| 111 | next_token_logits, *mems = model(context_tokens, position_ids, attention_mask, *mems) |
| 112 | else: |
| 113 | if args.block_lm: |
| 114 | if args.no_block_position: |
| 115 | position_ids = context_tokens.new_full((last_beam_num, 1), context_length + counter) |
| 116 | else: |
| 117 | position_ids = context_tokens.new_ones(last_beam_num, 2, 1) |
| 118 | position_ids[:, 0] = context_length |
| 119 | position_ids[:, 1] = counter + 1 |
| 120 | attention_mask = context_tokens.new_zeros([1], device=context_tokens.device, dtype=torch.long) |
| 121 | else: |
| 122 | position_ids = context_tokens.new_ones((last_beam_num, 1)) * (context_length + counter - 1) |
| 123 | attention_mask = context_tokens.new_ones(last_beam_num, 1, 1, args.mem_length + 1, |
| 124 | device=context_tokens.device, dtype=torch.float) |
| 125 | last_token = tokens[:, -1:] |
| 126 | next_token_logits, *mems = model(last_token, position_ids, attention_mask, *mems) |
| 127 | next_token_logits = next_token_logits[:, -1] |
| 128 | if args.num_beams > 1: |
| 129 | next_token_scores = F.log_softmax(next_token_logits, dim=-1) |
| 130 | next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores) |
| 131 | vocab_size = next_token_scores.shape[-1] |
| 132 | next_token_scores = next_token_scores.view(1, last_beam_num * vocab_size) |
| 133 | |
| 134 | probs = F.softmax(next_token_scores, dim=-1) |
| 135 | next_tokens = torch.multinomial(probs, num_samples=2 * args.num_beams) |
| 136 | next_token_scores = torch.gather(next_token_scores, -1, next_tokens) |
| 137 | next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1) |
| 138 | next_tokens = torch.gather(next_tokens, -1, _indices) |
| 139 | |
| 140 | next_indices = next_tokens // vocab_size |
| 141 | next_tokens = next_tokens % vocab_size |
| 142 | # stateless |
| 143 | tokens = tokens.expand((args.num_beams, -1)) |
| 144 | beam_outputs = beam_scorer.process( |
no test coverage detected