MCPcopy Index your code
hub / github.com/zai-org/CogView / get_args

Function get_args

arguments.py:301–375  ·  view source on GitHub ↗

Parse all the args.

()

Source from the content-addressed store, hash-verified

299 return parser
300
301def get_args():
302 """Parse all the args."""
303
304 parser = argparse.ArgumentParser(description='PyTorch CogView Model')
305 parser = add_model_config_args(parser)
306 parser = add_fp16_config_args(parser)
307 parser = add_training_args(parser)
308 parser = add_evaluation_args(parser)
309 parser = add_text_generate_args(parser)
310 parser = add_data_args(parser)
311 parser = add_generation_api_args(parser)
312 parser = add_sparse_args(parser)
313
314 # Include DeepSpeed configuration arguments
315 parser = deepspeed.add_config_arguments(parser)
316
317 args = parser.parse_args()
318 if not args.train_data:
319 print('WARNING: No training data specified')
320 assert args.is_sparse != 1, 'use is-sparse == 2 for inference'
321 elif args.is_sparse == 1 and (args.max_position_embeddings - 1) % args.query_window != 0:
322 raise ValueError('During sparse training, the sequence length must be exactly divided by window_size.')
323
324 args.cuda = torch.cuda.is_available()
325
326 args.rank = int(os.getenv('RANK', '0'))
327 args.world_size = int(os.getenv("WORLD_SIZE", '1'))
328 if hasattr(args, 'deepspeed_mpi') and args.deepspeed_mpi:
329 mpi_define_env(args)
330 elif os.getenv('OMPI_COMM_WORLD_LOCAL_RANK'):
331 # We are using (OpenMPI) mpirun for launching distributed data parallel processes
332 local_rank = int(os.getenv('OMPI_COMM_WORLD_LOCAL_RANK'))
333 local_size = int(os.getenv('OMPI_COMM_WORLD_LOCAL_SIZE'))
334
335 # Possibly running with Slurm
336 num_nodes = int(os.getenv('SLURM_JOB_NUM_NODES', '1'))
337 nodeid = int(os.getenv('SLURM_NODEID', '0'))
338
339 args.local_rank = local_rank
340 args.rank = nodeid * local_size + local_rank
341 args.world_size = num_nodes * local_size
342
343 args.model_parallel_size = min(args.model_parallel_size, args.world_size)
344 if args.rank == 0:
345 print('using world size: {} and model-parallel size: {} '.format(
346 args.world_size, args.model_parallel_size))
347
348 args.dynamic_loss_scale = False
349 if args.loss_scale is None:
350 args.dynamic_loss_scale = True
351 if args.rank == 0:
352 print(' > using dynamic loss scaling')
353
354 # The args fp32_* or fp16_* meant to be active when the
355 # args fp16 is set. So the default behaviour should all
356 # be false.
357 if not args.fp16:
358 args.fp32_embedding = False

Callers 2

main (Function) · 0.90
main (Function) · 0.90

Calls 9

add_model_config_args (Function) · 0.85
add_fp16_config_args (Function) · 0.85
add_training_args (Function) · 0.85
add_evaluation_args (Function) · 0.85
add_text_generate_args (Function) · 0.85
add_data_args (Function) · 0.85
add_generation_api_args (Function) · 0.85
add_sparse_args (Function) · 0.85
mpi_define_env (Function) · 0.85

Tested by

no test coverage detected