MCPcopy
hub / github.com/deepspeedai/DeepSpeedExamples / main

Function main

BingBertSquad/nvidia_run_squad_baseline.py:741–1059  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

739
740
741def main():
742 parser = get_argument_parser()
743 args = parser.parse_args()
744
745 check_early_exit_warning(args)
746
747 if args.local_rank == -1 or args.no_cuda:
748 device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
749 n_gpu = torch.cuda.device_count()
750 else:
751 torch.cuda.set_device(args.local_rank)
752 device = torch.device("cuda", args.local_rank)
753 n_gpu = 1
754 # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
755 torch.distributed.init_process_group(backend='nccl')
756 logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
757 device, n_gpu, bool(args.local_rank != -1), args.fp16))
758
759 if args.gradient_accumulation_steps < 1:
760 raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
761 args.gradient_accumulation_steps))
762
763 args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
764
765 random.seed(args.seed)
766 np.random.seed(args.seed)
767 torch.manual_seed(args.seed)
768 if n_gpu > 0:
769 torch.cuda.manual_seed_all(args.seed)
770
771 if not args.do_train and not args.do_predict:
772 raise ValueError("At least one of `do_train` or `do_predict` must be True.")
773
774 if args.do_train:
775 if not args.train_file:
776 raise ValueError(
777 "If `do_train` is True, then `train_file` must be specified.")
778 if args.do_predict:
779 if not args.predict_file:
780 raise ValueError(
781 "If `do_predict` is True, then `predict_file` must be specified.")
782
783 if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
784 raise ValueError("Output directory () already exists and is not empty.")
785 os.makedirs(args.output_dir, exist_ok=True)
786
787 # Prepare Summary writer
788 if torch.distributed.get_rank() == 0 and args.job_name is not None:
789 args.summary_writer = get_summary_writer(name=args.job_name,
790 base=args.output_dir)
791 else:
792 args.summary_writer = None
793
794 tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
795
796 train_examples = None
797 num_train_steps = None
798 if args.do_train:

Callers 1

Calls 15

step — Method · 0.95
step — Method · 0.95
get_argument_parser — Function · 0.90
check_early_exit_warning — Function · 0.90
get_summary_writer — Function · 0.90
BertConfig — Class · 0.90
BertAdam — Class · 0.90
RandomSampler — Class · 0.90
write_summary_events — Function · 0.90
is_time_to_exit — Function · 0.90
GradientClipper — Class · 0.85

Tested by

no test coverage detected