MCPcopy
hub / github.com/alibaba/EasyCV / main

Function main

tools/prune.py:75–245  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

73
74
75def main():
76 args = parse_args()
77
78 if args.model_type is not None and args.config is None:
79 assert args.model_type in CONFIG_TEMPLATE_ZOO, 'model_type must be in [%s]' % (
80 ', '.join(CONFIG_TEMPLATE_ZOO.keys()))
81 print('model_type=%s, config file will be replaced by %s' %
82 (args.model_type, CONFIG_TEMPLATE_ZOO[args.model_type]))
83 args.config = CONFIG_TEMPLATE_ZOO[args.model_type]
84
85 if args.config.startswith('http'):
86
87 r = requests.get(args.config)
88 # download config in current dir
89 tpath = args.config.split('/')[-1]
90 while not osp.exists(tpath):
91 try:
92 with open(tpath, 'wb') as code:
93 code.write(r.content)
94 except:
95 pass
96
97 args.config = tpath
98
99 cfg = mmcv_config_fromfile(args.config)
100
101 if args.user_config_params is not None:
102 assert args.model_type is not None, 'model_type must be setted'
103 # rebuild config by user config params
104 cfg = rebuild_config(cfg, args.user_config_params)
105
106 # check oss_config and init oss io
107 if cfg.get('oss_io_config', None) is not None:
108 io.access_oss(**cfg.oss_io_config)
109
110 # set cudnn_benchmark
111 if cfg.get('cudnn_benchmark', False):
112 torch.backends.cudnn.benchmark = True
113
114 # update configs according to CLI args
115 if args.work_dir is not None:
116 cfg.work_dir = args.work_dir
117
118 # if `work_dir` is oss path, redirect `work_dir` to local path, add `oss_work_dir` point to oss path,
119 # and use osssync hook to upload log and ckpt in work_dir to oss_work_dir
120 if cfg.work_dir.startswith('oss://'):
121 cfg.oss_work_dir = cfg.work_dir
122 cfg.work_dir = osp.join('work_dirs',
123 cfg.work_dir.replace('oss://', ''))
124 else:
125 cfg.oss_work_dir = None
126
127 # create work_dir
128 if not io.exists(cfg.work_dir):
129 io.makedirs(cfg.work_dir)
130
131 # init the logger before other steps
132 timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())

Callers 1

prune.py (File) · 0.70

Calls 15

mmcv_config_fromfile (Function) · 0.90
rebuild_config (Function) · 0.90
get_root_logger (Function) · 0.90
get_prune_layer (Function) · 0.90
set_random_seed (Function) · 0.90
build_model (Function) · 0.90
load_checkpoint (Function) · 0.90
build_yolo_optimizer (Function) · 0.90
build_optimizer (Function) · 0.90
build_dataset (Function) · 0.90
build_dataloader (Function) · 0.90
get_num_gpu_per_node (Function) · 0.90

Tested by

no test coverage detected