Repository: github.com/zai-org/CodeGeeX

Function generate_samples_interactive

codegeex/megatron/code_generation_utils.py, lines 377–484
(model, print_frequency=24)


```python
def generate_samples_interactive(model, print_frequency=24):
    args = get_args()
    tokenizer = get_tokenizer()

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            terminate_runs = 0
            raw_text_len = 0

            if (
                mpu.is_pipeline_first_stage()
                and mpu.get_tensor_model_parallel_rank() == 0
            ):
                os.system("clear")
                raw_text = input("\nContext prompt (stop to exit) >>> ")
                while not raw_text:
                    print("Prompt should not be empty!")
                    raw_text = input("\nContext prompt (stop to exit) >>> ")
                raw_text_len = len(raw_text)

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print(
                            "\nContext length",
                            context_length,
                            "\nPlease give smaller context (half of the "
                            "sequence length)!",
                            flush=True,
                        )
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = 0

            input_info = [terminate_runs, raw_text_len, context_length]
            input_info_tensor = torch.cuda.LongTensor(input_info)
            torch.distributed.all_reduce(
                input_info_tensor, group=mpu.get_model_parallel_group()
            )
            terminate_runs = input_info_tensor[0].item()
            raw_text_len = input_info_tensor[1].item()
            context_length = input_info_tensor[2].item()

            if terminate_runs == 1:
                return

            # For pipeline parallel we send context tokens to other stages
            # so they get the lengths correct
            if (
                mpu.get_tensor_model_parallel_rank() == 0
                and args.pipeline_model_parallel_size > 1
```
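
The rank-0 input checks and the `all_reduce` handshake above can be simulated on a CPU in plain Python. This is a minimal sketch, not CodeGeeX code: the whitespace `tokenize`, `SEQ_LENGTH = 8`, and the helpers `rank0_header`, `all_reduce_sum` and `handshake` are hypothetical stand-ins. It relies on `torch.distributed.all_reduce` summing by default (`ReduceOp.SUM`): every rank except the prompting one contributes zeros, so the sum equals that rank's values and the reduction effectively broadcasts `[terminate_runs, raw_text_len, context_length]` to the model-parallel group. As a simplification, the sketch reports a context length of 0 when terminating, and it models a rejected prompt as `None` (in the excerpt, rank 0 re-prompts via `continue` while the other ranks wait inside the collective).

```python
from typing import List, Optional

# Hypothetical settings for this sketch only: a tiny sequence length and a
# whitespace "tokenizer" standing in for the real model tokenizer.
SEQ_LENGTH = 8


def tokenize(text: str) -> List[str]:
    return text.split()


def rank0_header(raw_text: str, seq_length: int = SEQ_LENGTH) -> Optional[List[int]]:
    """Build [terminate_runs, raw_text_len, context_length] the way rank 0 does.

    Returns None when the prompt is rejected; in the excerpt, rank 0 then
    re-prompts (`continue`) before reaching the collective.
    """
    if "stop" in raw_text:  # substring test, as in the excerpt
        return [1, len(raw_text), 0]  # context_length 0 is a simplification
    context_length = len(tokenize(raw_text))
    if context_length >= seq_length // 2:  # must be under half the sequence length
        return None
    return [0, len(raw_text), context_length]


def all_reduce_sum(per_rank: List[List[int]]) -> List[int]:
    """Element-wise SUM over ranks (the default op of torch.distributed.all_reduce)."""
    return [sum(column) for column in zip(*per_rank)]


def handshake(raw_text: str, world_size: int = 4) -> Optional[List[int]]:
    """Simulate one loop iteration across a model-parallel group of world_size ranks."""
    header = rank0_header(raw_text)
    if header is None:
        return None  # the other ranks would keep waiting in all_reduce
    # All other ranks take the else-branch and contribute zeros, so the
    # summed result equals rank 0's header: the all_reduce acts as a broadcast.
    per_rank = [header] + [[0, 0, 0] for _ in range(world_size - 1)]
    return all_reduce_sum(per_rank)


if __name__ == "__main__":
    print(handshake("def add(a, b):"))    # [0, 14, 3]
    print(handshake("def stopwatch():"))  # [1, 16, 0]: "stop" is a substring
    print(handshake("a b c d"))           # None: 4 tokens >= 8 // 2
```

The `"def stopwatch():"` case illustrates that the exit check in the excerpt is a substring test, so any prompt containing `stop` ends the session.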

Callers

nothing calls this directly

Calls (6)

get_args · Function · 0.90
get_tokenizer · Function · 0.90
get_token_stream · Function · 0.70
eval · Method · 0.45
tokenize · Method · 0.45
detokenize · Method · 0.45

Tested by

no test coverage detected