MCPcopy
hub / github.com/DingXiaoH/RepVGG / insert_bn

Function insert_bn

tools/insert_bn.py:145–211  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

143
144
145def insert_bn():
146 args = parser.parse_args()
147
148 repvgg_build_func = get_RepVGG_func_by_name(args.arch)
149
150 model = repvgg_build_func(deploy=True).cuda()
151
152 load_checkpoint(model, args.weights)
153
154 switch_repvggblock_to_bnstat(model)
155
156 cudnn.benchmark = True
157
158 trans = get_default_train_trans(args)
159 print('data aug: ', trans)
160
161 train_dataset = get_ImageNet_train_dataset(args, trans)
162
163 train_loader = torch.utils.data.DataLoader(
164 train_dataset,
165 batch_size=args.batch_size, shuffle=False,
166 num_workers=args.workers, pin_memory=True)
167
168 batch_time = AverageMeter('Time', ':6.3f')
169 losses = AverageMeter('Loss', ':.4e')
170 top1 = AverageMeter('Acc@1', ':6.2f')
171 top5 = AverageMeter('Acc@5', ':6.2f')
172
173 progress = ProgressMeter(
174 min(len(train_loader), args.num_batches),
175 [batch_time, losses, top1, top5],
176 prefix='BN stat: ')
177
178 criterion = nn.CrossEntropyLoss().cuda()
179
180 with torch.no_grad():
181 end = time.time()
182 for i, (images, target) in enumerate(train_loader):
183 if i >= args.num_batches:
184 break
185 images = images.cuda(non_blocking=True)
186 target = target.cuda(non_blocking=True)
187
188 # compute output
189 output = model(images)
190 loss = criterion(output, target)
191
192 # measure accuracy and record loss
193 acc1, acc5 = accuracy(output, target, topk=(1, 5))
194 losses.update(loss.item(), images.size(0))
195 top1.update(acc1[0], images.size(0))
196 top5.update(acc5[0], images.size(0))
197
198 # measure elapsed time
199 batch_time.update(time.time() - end)
200 end = time.time()
201
202 if i % 10 == 0:

Callers 1

insert_bn.py — File · 0.85

Calls 10

update — Method · 0.95
display — Method · 0.95
get_RepVGG_func_by_name — Function · 0.90
load_checkpoint — Function · 0.90
AverageMeter — Class · 0.90
ProgressMeter — Class · 0.90
accuracy — Function · 0.90
get_default_train_trans — Function · 0.85
switch_bnstat_to_convbn — Function · 0.85

Tested by

no test coverage detected