()
| 143 | |
| 144 | |
| 145 | def insert_bn(): |
| 146 | args = parser.parse_args() |
| 147 | |
| 148 | repvgg_build_func = get_RepVGG_func_by_name(args.arch) |
| 149 | |
| 150 | model = repvgg_build_func(deploy=True).cuda() |
| 151 | |
| 152 | load_checkpoint(model, args.weights) |
| 153 | |
| 154 | switch_repvggblock_to_bnstat(model) |
| 155 | |
| 156 | cudnn.benchmark = True |
| 157 | |
| 158 | trans = get_default_train_trans(args) |
| 159 | print('data aug: ', trans) |
| 160 | |
| 161 | train_dataset = get_ImageNet_train_dataset(args, trans) |
| 162 | |
| 163 | train_loader = torch.utils.data.DataLoader( |
| 164 | train_dataset, |
| 165 | batch_size=args.batch_size, shuffle=False, |
| 166 | num_workers=args.workers, pin_memory=True) |
| 167 | |
| 168 | batch_time = AverageMeter('Time', ':6.3f') |
| 169 | losses = AverageMeter('Loss', ':.4e') |
| 170 | top1 = AverageMeter('Acc@1', ':6.2f') |
| 171 | top5 = AverageMeter('Acc@5', ':6.2f') |
| 172 | |
| 173 | progress = ProgressMeter( |
| 174 | min(len(train_loader), args.num_batches), |
| 175 | [batch_time, losses, top1, top5], |
| 176 | prefix='BN stat: ') |
| 177 | |
| 178 | criterion = nn.CrossEntropyLoss().cuda() |
| 179 | |
| 180 | with torch.no_grad(): |
| 181 | end = time.time() |
| 182 | for i, (images, target) in enumerate(train_loader): |
| 183 | if i >= args.num_batches: |
| 184 | break |
| 185 | images = images.cuda(non_blocking=True) |
| 186 | target = target.cuda(non_blocking=True) |
| 187 | |
| 188 | # compute output |
| 189 | output = model(images) |
| 190 | loss = criterion(output, target) |
| 191 | |
| 192 | # measure accuracy and record loss |
| 193 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) |
| 194 | losses.update(loss.item(), images.size(0)) |
| 195 | top1.update(acc1[0], images.size(0)) |
| 196 | top5.update(acc5[0], images.size(0)) |
| 197 | |
| 198 | # measure elapsed time |
| 199 | batch_time.update(time.time() - end) |
| 200 | end = time.time() |
| 201 | |
| 202 | if i % 10 == 0: |
no test coverage detected