MCPcopy Index your code
hub / github.com/THUDM/GLM / sample_sequence

Function sample_sequence

generate_samples.py:87–182  ·  view source on GitHub ↗
(model, tokenizer, context_tokens, context_length, args, device, mems=None, end_tokens=None)

Source from the content-addressed store, hash-verified

85
86
87def sample_sequence(model, tokenizer, context_tokens, context_length, args, device, mems=None, end_tokens=None):
88 if not args.block_lm:
89 context_tokens, attention_mask, position_ids = get_batch(context_tokens, device, args)
90 tokens = torch.empty((args.num_beams, 0), device=context_tokens.device, dtype=torch.long)
91 else:
92 tokens = context_tokens.new_full((1, 1), tokenizer.get_command('sop').Id)
93 counter = 0
94 if mems is None:
95 mems = []
96 if end_tokens is None:
97 end_tokens = [args.eod_token]
98 if args.num_beams > 1:
99 beam_scorer = BeamSearchScorer(
100 batch_size=1,
101 max_length=args.out_seq_length,
102 num_beams=args.num_beams,
103 device=context_tokens.device,
104 length_penalty=args.length_penalty,
105 do_early_stopping=False,
106 )
107 beam_scores = torch.zeros(1, dtype=torch.float, device=context_tokens.device)
108 last_beam_num = 1
109 while counter < args.out_seq_length:
110 if counter == 0 and not args.block_lm:
111 next_token_logits, *mems = model(context_tokens, position_ids, attention_mask, *mems)
112 else:
113 if args.block_lm:
114 if args.no_block_position:
115 position_ids = context_tokens.new_full((last_beam_num, 1), context_length + counter)
116 else:
117 position_ids = context_tokens.new_ones(last_beam_num, 2, 1)
118 position_ids[:, 0] = context_length
119 position_ids[:, 1] = counter + 1
120 attention_mask = context_tokens.new_zeros([1], device=context_tokens.device, dtype=torch.long)
121 else:
122 position_ids = context_tokens.new_ones((last_beam_num, 1)) * (context_length + counter - 1)
123 attention_mask = context_tokens.new_ones(last_beam_num, 1, 1, args.mem_length + 1,
124 device=context_tokens.device, dtype=torch.float)
125 last_token = tokens[:, -1:]
126 next_token_logits, *mems = model(last_token, position_ids, attention_mask, *mems)
127 next_token_logits = next_token_logits[:, -1]
128 if args.num_beams > 1:
129 next_token_scores = F.log_softmax(next_token_logits, dim=-1)
130 next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)
131 vocab_size = next_token_scores.shape[-1]
132 next_token_scores = next_token_scores.view(1, last_beam_num * vocab_size)
133
134 probs = F.softmax(next_token_scores, dim=-1)
135 next_tokens = torch.multinomial(probs, num_samples=2 * args.num_beams)
136 next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
137 next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
138 next_tokens = torch.gather(next_tokens, -1, _indices)
139
140 next_indices = next_tokens // vocab_size
141 next_tokens = next_tokens % vocab_size
142 # stateless
143 tokens = tokens.expand((args.num_beams, -1))
144 beam_outputs = beam_scorer.process(

Callers 1

generate_samples · Function · 0.85

Calls 7

process · Method · 0.95
finalize · Method · 0.95
BeamSearchScorer · Class · 0.90
top_k_logits · Function · 0.90
get_command · Method · 0.80
get_batch · Function · 0.70
DecodeIds · Method · 0.45

Tested by

no test coverage detected