(gpu, ngpus_per_node, args)
| 167 | return trans |
| 168 | |
| 169 | def main_worker(gpu, ngpus_per_node, args): |
| 170 | global best_acc1 |
| 171 | args.gpu = gpu |
| 172 | log_file = 'quant_{}_exp.txt'.format(args.tag) |
| 173 | |
| 174 | if args.gpu is not None: |
| 175 | print("Use GPU: {} for training".format(args.gpu)) |
| 176 | |
| 177 | if args.distributed: |
| 178 | if args.dist_url == "env://" and args.rank == -1: |
| 179 | args.rank = int(os.environ["RANK"]) |
| 180 | if args.multiprocessing_distributed: |
| 181 | # For multiprocessing distributed training, rank needs to be the |
| 182 | # global rank among all the processes |
| 183 | args.rank = args.rank * ngpus_per_node + gpu |
| 184 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, |
| 185 | world_size=args.world_size, rank=args.rank) |
| 186 | |
| 187 | # 1. Build and load base model |
| 188 | from repvgg import get_RepVGG_func_by_name |
| 189 | repvgg_build_func = get_RepVGG_func_by_name(args.arch) |
| 190 | base_model = repvgg_build_func(deploy=True) |
| 191 | from tools.insert_bn import directly_insert_bn_without_init |
| 192 | directly_insert_bn_without_init(base_model) |
| 193 | if args.base_weights is not None: |
| 194 | load_checkpoint(base_model, args.base_weights) |
| 195 | |
| 196 | # 2. |
| 197 | if not args.fpfinetune: |
| 198 | from quantization.repvgg_quantized import RepVGGWholeQuant |
| 199 | qat_model = RepVGGWholeQuant(repvgg_model=base_model, quantlayers=args.quantlayers) |
| 200 | qat_model.prepare_quant() |
| 201 | else: |
| 202 | qat_model = base_model |
| 203 | log_msg('===================== not QAT, just full-precision finetune ===========', log_file) |
| 204 | |
| 205 | #=================================================== |
| 206 | # From now on, the code will be very similar to ordinary training |
| 207 | # =================================================== |
| 208 | |
| 209 | is_main = not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0) |
| 210 | |
| 211 | if is_main: |
| 212 | for n, p in qat_model.named_parameters(): |
| 213 | print(n, p.size()) |
| 214 | for n, p in qat_model.named_buffers(): |
| 215 | print(n, p.size()) |
| 216 | log_msg('epochs {}, lr {}, weight_decay {}'.format(args.epochs, args.lr, args.weight_decay), log_file) |
| 217 | # You will see it now has quantization-related parameters (zero-points and scales) |
| 218 | |
| 219 | if not torch.cuda.is_available(): |
| 220 | print('using CPU, this will be slow') |
| 221 | elif args.distributed: |
| 222 | if args.gpu is not None: |
| 223 | torch.cuda.set_device(args.gpu) |
| 224 | qat_model.cuda(args.gpu) |
| 225 | args.batch_size = int(args.batch_size / ngpus_per_node) |
| 226 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) |
no test coverage detected