hub / github.com/microsoft/Cream / train_one_epoch_distill_using_saved_logits

Function train_one_epoch_distill_using_saved_logits

TinyViT/main.py:284–400
(args, config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, loss_scaler)

def train_one_epoch_distill_using_saved_logits(args, config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, loss_scaler):
    model.train()
    set_bn_state(config, model)
    optimizer.zero_grad()

    num_steps = len(data_loader)
    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    norm_meter = AverageMeter()
    scaler_meter = AverageMeter()
    meters = defaultdict(AverageMeter)

    start = time.time()
    end = time.time()
    data_tic = time.time()

    num_classes = config.MODEL.NUM_CLASSES
    topk = config.DISTILL.LOGITS_TOPK

    for idx, ((samples, targets), (logits_index, logits_value, seeds)) in enumerate(data_loader):
        normal_global_idx = epoch * NORM_ITER_LEN + \
            (idx * NORM_ITER_LEN // num_steps)

        samples = samples.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets, seeds)
            original_targets = targets.argmax(dim=1)
        else:
            original_targets = targets
        meters['data_time'].update(time.time() - data_tic)

        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
            outputs = model(samples)

        # recover teacher logits
        logits_index = logits_index.long()
        logits_value = logits_value.float()
        logits_index = logits_index.cuda(non_blocking=True)
        logits_value = logits_value.cuda(non_blocking=True)
        minor_value = (1.0 - logits_value.sum(-1, keepdim=True)
                       ) / (num_classes - topk)
        minor_value = minor_value.repeat_interleave(num_classes, dim=-1)
        outputs_teacher = minor_value.scatter_(-1, logits_index, logits_value)

        loss = criterion(outputs, outputs_teacher)
        loss = loss / config.TRAIN.ACCUMULATION_STEPS

        # this attribute is added by timm on one optimizer (adahessian)
        is_second_order = hasattr(
            optimizer, 'is_second_order') and optimizer.is_second_order
        grad_norm = loss_scaler(loss, optimizer, clip_grad=config.TRAIN.CLIP_GRAD,
                                parameters=model.parameters(), create_graph=is_second_order,
                                update_grad=(idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0)
        if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
            optimizer.zero_grad()
            lr_scheduler.step_update(
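
The "recover teacher logits" step in the listing above can be tried in isolation. The following is a minimal sketch, not the repository's code: the helper name `recover_teacher_probs` and the toy tensors are hypothetical, it runs on CPU, and it infers `topk` from the tensor shape instead of reading `config.DISTILL.LOGITS_TOPK`. It also assumes, as the `1.0 - sum` arithmetic suggests, that the saved values are probabilities whose top-k sum is at most 1.

```python
import torch


def recover_teacher_probs(logits_index, logits_value, num_classes):
    """Rebuild a dense teacher distribution from saved top-k (index, value) pairs.

    Hypothetical helper mirroring the tensor operations in the listing.
    """
    topk = logits_index.shape[-1]  # assumption: the last dimension holds the top-k entries
    logits_index = logits_index.long()
    logits_value = logits_value.float()
    # Mass not covered by the top-k entries, spread evenly over the
    # remaining (num_classes - topk) classes.
    minor_value = (1.0 - logits_value.sum(-1, keepdim=True)) / (num_classes - topk)
    # Expand the per-sample filler value to a full (batch, num_classes) row.
    dense = minor_value.repeat_interleave(num_classes, dim=-1)
    # Overwrite the top-k positions with their saved values.
    return dense.scatter_(-1, logits_index, logits_value)


# Toy batch with made-up numbers: 2 samples, 5 classes, top-2 saved per sample.
index = torch.tensor([[3, 0], [1, 4]])
value = torch.tensor([[0.6, 0.1], [0.5, 0.2]])
probs = recover_teacher_probs(index, value, num_classes=5)
```

Under these assumptions each row of `probs` sums to 1: the saved entries keep their values and every other class receives an equal share of the leftover mass.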

Callers 1

main · Function · 0.85

Calls 13

update · Method · 0.95
AverageMeter · Class · 0.90
accuracy · Function · 0.90
is_main_process · Function · 0.90
is_valid_grad_norm · Function · 0.85
zero_grad · Method · 0.80
step_update · Method · 0.80
set_bn_state · Function · 0.70
train · Method · 0.45
get · Method · 0.45
state_dict · Method · 0.45
log · Method · 0.45

Tested by

no test coverage detected