MCPcopy
hub / github.com/zai-org/CogVideo / get_args

Function get_args

sat/arguments.py:54–176  ·  view source on GitHub ↗

Parse all the args.

Signature: get_args(args_list=None, parser=None)

Source from the content-addressed store, hash-verified

52
53
54def get_args(args_list=None, parser=None):
55 """Parse all the args."""
56 if parser is None:
57 parser = argparse.ArgumentParser(description="sat")
58 else:
59 assert isinstance(parser, argparse.ArgumentParser)
60 parser = add_model_config_args(parser)
61 parser = add_sampling_config_args(parser)
62 parser = add_training_args(parser)
63 parser = add_evaluation_args(parser)
64 parser = add_data_args(parser)
65
66 import deepspeed
67
68 parser = deepspeed.add_config_arguments(parser)
69
70 args = parser.parse_args(args_list)
71 args = process_config_to_args(args)
72
73 if not args.train_data:
74 print_rank0("No training data specified", level="WARNING")
75
76 assert (args.train_iters is None) or (args.epochs is None), "only one of train_iters and epochs should be set."
77 if args.train_iters is None and args.epochs is None:
78 args.train_iters = 10000 # default 10k iters
79 print_rank0("No train_iters (recommended) or epochs specified, use default 10k iters.", level="WARNING")
80
81 args.cuda = torch.cuda.is_available()
82
83 args.rank = int(os.getenv("RANK", "0"))
84 args.world_size = int(os.getenv("WORLD_SIZE", "1"))
85 if args.local_rank is None:
86 args.local_rank = int(os.getenv("LOCAL_RANK", "0")) # torchrun
87
88 if args.device == -1:
89 if torch.cuda.device_count() == 0:
90 args.device = "cpu"
91 elif args.local_rank is not None:
92 args.device = args.local_rank
93 else:
94 args.device = args.rank % torch.cuda.device_count()
95
96 if args.local_rank != args.device and args.mode != "inference":
97 raise ValueError(
98 "LOCAL_RANK (default 0) and args.device inconsistent. "
99 "This can only happens in inference mode. "
100 "Please use CUDA_VISIBLE_DEVICES=x for single-GPU training. "
101 )
102
103 if args.rank == 0:
104 print_rank0("using world size: {}".format(args.world_size))
105
106 if args.train_data_weights is not None:
107 assert len(args.train_data_weights) == len(args.train_data)
108
109 if args.mode != "inference": # training with deepspeed
110 args.deepspeed = True
111 if args.deepspeed_config is None: # not specified

Callers 2

train_video.pyFile · 0.90
sample_video.pyFile · 0.90

Calls 5

add_model_config_argsFunction · 0.85
add_sampling_config_argsFunction · 0.85
process_config_to_argsFunction · 0.85
initialize_distributedFunction · 0.85
loadMethod · 0.80

Tested by

no test coverage detected