(model, print_frequency=24)
| 375 | |
| 376 | |
| 377 | def generate_samples_interactive(model, print_frequency=24): |
| 378 | args = get_args() |
| 379 | tokenizer = get_tokenizer() |
| 380 | |
| 381 | context_count = 0 |
| 382 | model.eval() |
| 383 | with torch.no_grad(): |
| 384 | while True: |
| 385 | terminate_runs = 0 |
| 386 | raw_text_len = 0 |
| 387 | |
| 388 | if ( |
| 389 | mpu.is_pipeline_first_stage() |
| 390 | and mpu.get_tensor_model_parallel_rank() == 0 |
| 391 | ): |
| 392 | os.system("clear") |
| 393 | raw_text = input("\nContext prompt (stop to exit) >>> ") |
| 394 | while not raw_text: |
| 395 | print("Prompt should not be empty!") |
| 396 | raw_text = input("\nContext prompt (stop to exit) >>> ") |
| 397 | raw_text_len = len(raw_text) |
| 398 | |
| 399 | if "stop" in raw_text: |
| 400 | terminate_runs = 1 |
| 401 | else: |
| 402 | context_tokens = tokenizer.tokenize(raw_text) |
| 403 | context_length = len(context_tokens) |
| 404 | |
| 405 | if context_length >= (args.seq_length // 2): |
| 406 | print( |
| 407 | "\nContext length", |
| 408 | context_length, |
| 409 | "\nPlease give smaller context (half of the " |
| 410 | "sequence length)!", |
| 411 | flush=True, |
| 412 | ) |
| 413 | continue |
| 414 | else: |
| 415 | context_tokens = tokenizer.tokenize("EMPTY TEXT") |
| 416 | context_length = 0 |
| 417 | |
| 418 | input_info = [terminate_runs, raw_text_len, context_length] |
| 419 | input_info_tensor = torch.cuda.LongTensor(input_info) |
| 420 | torch.distributed.all_reduce( |
| 421 | input_info_tensor, group=mpu.get_model_parallel_group() |
| 422 | ) |
| 423 | terminate_runs = input_info_tensor[0].item() |
| 424 | raw_text_len = input_info_tensor[1].item() |
| 425 | context_length = input_info_tensor[2].item() |
| 426 | |
| 427 | if terminate_runs == 1: |
| 428 | return |
| 429 | |
| 430 | # For pipeline parallel we send context tokens to other stages |
| 431 | # so they get the lengths correct |
| 432 | if ( |
| 433 | mpu.get_tensor_model_parallel_rank() == 0 |
| 434 | and args.pipeline_model_parallel_size > 1 |
nothing calls this directly
no test coverage detected