MCPcopy
hub / github.com/DingXiaoH/RepVGG / main_worker

Function main_worker

quantization/quant_qat_train.py:169–315  ·  view source on GitHub ↗
(gpu, ngpus_per_node, args)

Source from the content-addressed store, hash-verified

167 return trans
168
169def main_worker(gpu, ngpus_per_node, args):
170 global best_acc1
171 args.gpu = gpu
172 log_file = 'quant_{}_exp.txt'.format(args.tag)
173
174 if args.gpu is not None:
175 print("Use GPU: {} for training".format(args.gpu))
176
177 if args.distributed:
178 if args.dist_url == "env://" and args.rank == -1:
179 args.rank = int(os.environ["RANK"])
180 if args.multiprocessing_distributed:
181 # For multiprocessing distributed training, rank needs to be the
182 # global rank among all the processes
183 args.rank = args.rank * ngpus_per_node + gpu
184 dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
185 world_size=args.world_size, rank=args.rank)
186
187 # 1. Build and load base model
188 from repvgg import get_RepVGG_func_by_name
189 repvgg_build_func = get_RepVGG_func_by_name(args.arch)
190 base_model = repvgg_build_func(deploy=True)
191 from tools.insert_bn import directly_insert_bn_without_init
192 directly_insert_bn_without_init(base_model)
193 if args.base_weights is not None:
194 load_checkpoint(base_model, args.base_weights)
195
196 # 2.
197 if not args.fpfinetune:
198 from quantization.repvgg_quantized import RepVGGWholeQuant
199 qat_model = RepVGGWholeQuant(repvgg_model=base_model, quantlayers=args.quantlayers)
200 qat_model.prepare_quant()
201 else:
202 qat_model = base_model
203 log_msg('===================== not QAT, just full-precision finetune ===========', log_file)
204
205 #===================================================
206 # From now on, the code will be very similar to ordinary training
207 # ===================================================
208
209 is_main = not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0)
210
211 if is_main:
212 for n, p in qat_model.named_parameters():
213 print(n, p.size())
214 for n, p in qat_model.named_buffers():
215 print(n, p.size())
216 log_msg('epochs {}, lr {}, weight_decay {}'.format(args.epochs, args.lr, args.weight_decay), log_file)
217 # You will see it now has quantization-related parameters (zero-points and scales)
218
219 if not torch.cuda.is_available():
220 print('using CPU, this will be slow')
221 elif args.distributed:
222 if args.gpu is not None:
223 torch.cuda.set_device(args.gpu)
224 qat_model.cuda(args.gpu)
225 args.batch_size = int(args.batch_size / ngpus_per_node)
226 args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)

Callers 1

mainFunction · 0.85

Calls 12

get_RepVGG_func_by_nameFunction · 0.90
RepVGGWholeQuantClass · 0.90
load_checkpointFunction · 0.85
log_msgFunction · 0.85
sgd_optimizerFunction · 0.85
trainFunction · 0.85
prepare_quantMethod · 0.80
set_epochMethod · 0.80
validateFunction · 0.70
save_checkpointFunction · 0.70

Tested by

no test coverage detected