hub / github.com/deepspeedai/DeepSpeedExamples / generate_samples

Function generate_samples

Megatron-LM/generate_samples.py:128–217
(model, tokenizer, args, device)

Source from the content-addressed store, hash-verified

def generate_samples(model, tokenizer, args, device):

    context_count=0
    model.eval()
    with torch.no_grad():
        while True:
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            terminate_runs=0

            if mpu.get_model_parallel_rank() == 0:
                raw_text = input("\nContext prompt (stop to exit) >>> ")
                while not raw_text:
                    print('Prompt should not be empty!')
                    raw_text = input("\nContext prompt (stop to exit) >>> ")

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.EncodeAsIds(raw_text).tokenization
                    context_length = len(context_tokens)

                    if context_length >=args.seq_length//2:
                        print("\nContext length", context_length, \
                            "\nPlease give smaller context (half of the sequence length)!")
                        continue
            else:
                context_tokens = tokenizer.EncodeAsIds("EMPTY TEXT").tokenization
                context_length = len(context_tokens)

            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
            torch.distributed.broadcast(terminate_runs_tensor, mpu.get_model_parallel_src_rank(), group=mpu.get_model_parallel_group())
            terminate_runs = terminate_runs_tensor[0].item()

            if terminate_runs == 1:
                return

            pad_id = tokenizer.get_command('pad').Id
            if context_length < args.seq_length:
                context_tokens.extend([pad_id] * (args.seq_length - context_length))

            context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
            context_length_tensor = torch.cuda.LongTensor([context_length])

            torch.distributed.broadcast(context_length_tensor, mpu.get_model_parallel_src_rank(), group=mpu.get_model_parallel_group())
            torch.distributed.broadcast(context_tokens_tensor, mpu.get_model_parallel_src_rank(), group=mpu.get_model_parallel_group())

            context_length = context_length_tensor[0].item()
            tokens, attention_mask, position_ids=get_batch(context_tokens_tensor, device, args)

            start_time = time.time()

            counter = 0
            org_context_length = context_length

            while counter < (org_context_length + args.out_seq_length):
                logits = model(tokens, position_ids, attention_mask)
                logits = logits[:, context_length - 1, :] / args.temperature
                logits = top_k_logits(logits, top_k=args.top_k, top_p=args.top_p)
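
The last lines of the listing divide the logits at the current position by args.temperature and pass them to top_k_logits. The sketch below shows that kind of step on a plain Python list, using only the standard library. It is not the repository's top_k_logits: the helper names (scale_and_filter_top_k, softmax, sample_next_token) are hypothetical, the top_p filter is left out, and keeping ties at the threshold is an assumption of this sketch.

```python
import math
import random


def scale_and_filter_top_k(logits, temperature=1.0, top_k=0):
    """Divide logits by temperature, then mask everything below the
    top_k-th largest value with -inf. top_k == 0 disables the filter."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    if 0 < top_k < len(scaled):
        # Threshold is the k-th largest scaled logit; ties are kept
        # (an assumption of this sketch).
        threshold = sorted(scaled, reverse=True)[top_k - 1]
        scaled = [x if x >= threshold else -math.inf for x in scaled]
    return scaled


def softmax(logits):
    """Numerically stable softmax; entries masked with -inf get probability 0."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_next_token(logits, temperature=1.0, top_k=0, rng=random):
    """Draw one token index from the filtered distribution."""
    probs = softmax(scale_and_filter_top_k(logits, temperature, top_k))
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


if __name__ == "__main__":
    toy_logits = [2.0, 1.0, 0.5, 3.0]
    print(scale_and_filter_top_k(toy_logits, temperature=0.5, top_k=2))
    print(sample_next_token(toy_logits, temperature=0.5, top_k=2))
```

A temperature below 1 sharpens the distribution and a temperature above 1 flattens it. With top_k=2 only the two highest-scoring token ids can be sampled, since every other entry ends up with probability 0.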

Callers 1

main · Function · 0.85

Calls 7

top_k_logits · Function · 0.85
eval · Method · 0.80
get_command · Method · 0.80
extend · Method · 0.80
get_batch · Function · 0.70
EncodeAsIds · Method · 0.45
DecodeIds · Method · 0.45

Tested by

no test coverage detected