Parse all the args.
()
| 396 | |
| 397 | |
| 398 | def get_args(): |
| 399 | """Parse all the args.""" |
| 400 | |
| 401 | parser = argparse.ArgumentParser(description='PyTorch BERT Model') |
| 402 | parser = add_model_config_args(parser) |
| 403 | parser = add_fp16_config_args(parser) |
| 404 | parser = add_training_args(parser) |
| 405 | parser = add_evaluation_args(parser) |
| 406 | parser = add_text_generate_args(parser) |
| 407 | parser = add_data_args(parser) |
| 408 | parser = add_finetune_config_args(parser) |
| 409 | |
| 410 | # Include DeepSpeed configuration arguments |
| 411 | parser = deepspeed.add_config_arguments(parser) |
| 412 | |
| 413 | args = parser.parse_args() |
| 414 | if not args.train_data and not args.data_dir: |
| 415 | print('WARNING: No training data specified') |
| 416 | |
| 417 | args.cuda = torch.cuda.is_available() |
| 418 | |
| 419 | args.rank = int(os.getenv('RANK', '0')) |
| 420 | args.world_size = int(os.getenv("WORLD_SIZE", '1')) |
| 421 | if hasattr(args, 'deepspeed_mpi') and args.deepspeed_mpi: |
| 422 | mpi_define_env(args) |
| 423 | elif os.getenv('OMPI_COMM_WORLD_LOCAL_RANK'): |
| 424 | # We are using (OpenMPI) mpirun for launching distributed data parallel processes |
| 425 | local_rank = int(os.getenv('OMPI_COMM_WORLD_LOCAL_RANK')) |
| 426 | local_size = int(os.getenv('OMPI_COMM_WORLD_LOCAL_SIZE')) |
| 427 | |
| 428 | # Possibly running with Slurm |
| 429 | num_nodes = int(os.getenv('SLURM_JOB_NUM_NODES', '1')) |
| 430 | nodeid = int(os.getenv('SLURM_NODEID', '0')) |
| 431 | |
| 432 | args.local_rank = local_rank |
| 433 | args.rank = nodeid * local_size + local_rank |
| 434 | args.world_size = num_nodes * local_size |
| 435 | |
| 436 | args.model_parallel_size = min(args.model_parallel_size, args.world_size) |
| 437 | if args.rank == 0: |
| 438 | print('using world size: {} and model-parallel size: {} '.format( |
| 439 | args.world_size, args.model_parallel_size)) |
| 440 | |
| 441 | args.dynamic_loss_scale = False |
| 442 | if args.loss_scale is None: |
| 443 | args.dynamic_loss_scale = True |
| 444 | if args.rank == 0: |
| 445 | print(' > using dynamic loss scaling') |
| 446 | |
| 447 | # The args fp32_* or fp16_* meant to be active when the |
| 448 | # args fp16 is set. So the default behaviour should all |
| 449 | # be false. |
| 450 | if not args.fp16: |
| 451 | args.fp32_embedding = False |
| 452 | args.fp32_tokentypes = False |
| 453 | args.fp32_layernorm = False |
| 454 | |
| 455 | if hasattr(args, "deepspeed") and args.deepspeed and args.deepspeed_config is not None: |
no test coverage detected