MCPcopy Index your code
hub / github.com/THUDM/GLM / read_context

Function read_context

generate_samples.py:185–236  ·  view source on GitHub ↗
(tokenizer, args, output)

Source from the content-addressed store, hash-verified

183
184
def _prompt_for_context(tokenizer, args, output):
    """Read prompts from stdin until one is usable, then tokenize it.

    Loops until the user enters a non-empty prompt whose token count is
    strictly below ``args.seq_length``. When ``args.block_lm`` is set, a
    generation mask is appended to prompts that contain no ``MASK]`` marker,
    the ``ENC`` command id is prepended, and the ``eos`` command id is
    appended unless the prompt ends with ``[gMASK]``.

    Args:
        tokenizer: Object providing ``EncodeAsIds(text).tokenization`` and
            ``get_command(name).Id``.
        args: Namespace read for ``task_mask``, ``block_lm`` and
            ``seq_length``.
        output: Writable object; each non-empty, non-"stop" prompt is passed
            to ``output.write``.

    Returns:
        ``(raw_text, context_tokens)`` for an accepted prompt, or ``None``
        when the user typed ``stop`` or stdin reached end-of-file.
    """
    while True:
        try:
            raw_text = input("\nContext prompt (stop to exit) >>> ")
        except EOFError:
            # stdin was closed (Ctrl-D, exhausted pipe). Treat it like "stop"
            # so the caller can still broadcast the terminate flag; raising
            # here would skip that broadcast while the other model-parallel
            # ranks are waiting in it.
            return None
        if not raw_text:
            print('Prompt should not be empty!')
            continue
        if raw_text == "stop":
            return None

        generation_mask = '[gMASK]' if args.task_mask else '[MASK]'
        if args.block_lm and 'MASK]' not in raw_text:
            raw_text += ' ' + generation_mask
        # NOTE(review): the prompt is written before the length check below,
        # so an over-long prompt that gets rejected has still been written to
        # `output` -- confirm this is intended.
        output.write(raw_text)

        context_tokens = tokenizer.EncodeAsIds(raw_text).tokenization
        if args.block_lm:
            context_tokens = [tokenizer.get_command('ENC').Id] + context_tokens
            if not raw_text.endswith('[gMASK]'):
                context_tokens = context_tokens + [tokenizer.get_command('eos').Id]

        context_length = len(context_tokens)
        if context_length >= args.seq_length:
            print("\nContext length", context_length,
                  "\nPlease give smaller context than the window length!")
            continue
        return raw_text, context_tokens


def read_context(tokenizer, args, output):
    """Obtain a prompt on model-parallel rank 0 and broadcast it to the group.

    Rank 0 (per ``mpu.get_model_parallel_rank()``) prompts on stdin; every
    other rank goes straight to the broadcasts. Three broadcasts are issued
    in order from ``mpu.get_model_parallel_src_rank()`` over
    ``mpu.get_model_parallel_group()``: the terminate flag, the context
    length, and the token ids.

    Args:
        tokenizer: Tokenizer used to encode the prompt on rank 0 and to
            decode the received ids on the other ranks.
        args: Namespace read for ``task_mask``, ``block_lm`` and
            ``seq_length``.
        output: Writable object that receives the prompt text on rank 0.

    Returns:
        ``(terminate_runs, raw_text, context_tokens_tensor, context_length)``.
        If the user asked to stop (or stdin hit EOF), ``terminate_runs`` is 1
        and the other three items are ``None``. Otherwise ``terminate_runs``
        is 0, ``context_tokens_tensor`` is a ``torch.cuda.LongTensor`` of the
        token ids and ``context_length`` is their count. ``raw_text`` is the
        typed prompt (with any appended mask) on rank 0 and the result of
        ``tokenizer.DecodeIds`` on the broadcast ids on every other rank.
    """
    is_first_rank = mpu.get_model_parallel_rank() == 0

    terminate_runs = 0
    raw_text = None
    context_tokens = None
    context_length = 0
    if is_first_rank:
        prompt = _prompt_for_context(tokenizer, args, output)
        if prompt is None:
            terminate_runs = 1
        else:
            raw_text, context_tokens = prompt
            context_length = len(context_tokens)

    # Every rank must learn whether to stop before any payload is exchanged.
    terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
    torch.distributed.broadcast(terminate_runs_tensor, mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    terminate_runs = terminate_runs_tensor[0].item()

    if terminate_runs == 1:
        return terminate_runs, None, None, None

    # The length goes first so the receiving ranks can size their buffer.
    context_length_tensor = torch.cuda.LongTensor([context_length])
    torch.distributed.broadcast(context_length_tensor, mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    context_length = context_length_tensor[0].item()

    if is_first_rank:
        context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    else:
        # Zero-filled placeholder, overwritten in place by the broadcast.
        context_tokens_tensor = torch.cuda.LongTensor([0] * context_length)
    torch.distributed.broadcast(context_tokens_tensor, mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())

    if not is_first_rank:
        # Other ranks never saw the typed text; rebuild it from the token ids.
        raw_text = tokenizer.DecodeIds(context_tokens_tensor.tolist())
    return terminate_runs, raw_text, context_tokens_tensor, context_length
237
238
239def generate_samples(model, tokenizer, args, device):

Callers 1

generate_samplesFunction · 0.85

Calls 4

get_commandMethod · 0.80
writeMethod · 0.45
EncodeAsIdsMethod · 0.45
DecodeIdsMethod · 0.45

Tested by

no test coverage detected