Repository: github.com/microsoft/Cream

Function train_one_epoch

TinyViT/main.py:195–281
(args, config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, loss_scaler)


```python
# Excerpt from TinyViT/main.py, lines 195–252 (the function continues to line 281).
# time, torch, AverageMeter, accuracy, set_bn_state, is_valid_grad_norm and
# NORM_ITER_LEN are defined or imported elsewhere in main.py.
def train_one_epoch(args, config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, loss_scaler):
    model.train()
    set_bn_state(config, model)
    optimizer.zero_grad()

    num_steps = len(data_loader)
    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    norm_meter = AverageMeter()
    scaler_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    start = time.time()
    end = time.time()
    for idx, (samples, targets) in enumerate(data_loader):
        # Training progress rescaled to NORM_ITER_LEN units per epoch,
        # independent of how many batches the loader yields.
        normal_global_idx = epoch * NORM_ITER_LEN + \
            (idx * NORM_ITER_LEN // num_steps)

        samples = samples.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        if mixup_fn is not None:
            # Mixup returns soft targets; argmax recovers hard labels for accuracy.
            samples, targets = mixup_fn(samples, targets)
            original_targets = targets.argmax(dim=1)
        else:
            original_targets = targets

        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
            outputs = model(samples)

        # Divide the loss so gradients accumulated over ACCUMULATION_STEPS batches average out.
        loss = criterion(outputs, targets)
        loss = loss / config.TRAIN.ACCUMULATION_STEPS

        # this attribute is added by timm on one optimizer (adahessian)
        is_second_order = hasattr(
            optimizer, 'is_second_order') and optimizer.is_second_order
        # Backward pass through the AMP loss scaler; update_grad marks the batches
        # on which the accumulated gradients are applied.
        grad_norm = loss_scaler(loss, optimizer, clip_grad=config.TRAIN.CLIP_GRAD,
                                parameters=model.parameters(), create_graph=is_second_order,
                                update_grad=(idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0)
        if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
            optimizer.zero_grad()
            lr_scheduler.step_update(
                (epoch * num_steps + idx) // config.TRAIN.ACCUMULATION_STEPS)
        loss_scale_value = loss_scaler.state_dict().get("scale", 1.0)

        with torch.no_grad():
            acc1, acc5 = accuracy(outputs, original_targets, topk=(1, 5))
        acc1_meter.update(acc1.item(), targets.size(0))
        acc5_meter.update(acc5.item(), targets.size(0))

        # Wait for queued GPU work so the batch timing below reflects finished kernels.
        torch.cuda.synchronize()

        # Note: this records the loss after division by ACCUMULATION_STEPS.
        loss_meter.update(loss.item(), targets.size(0))
        if is_valid_grad_norm(grad_norm):
            norm_meter.update(grad_norm)
        scaler_meter.update(loss_scale_value)
        batch_time.update(time.time() - end)
        # ... (lines 253–281 not shown in this excerpt)
```
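When `mixup_fn` is set, `targets` stops being a vector of class indices and becomes a matrix of soft labels, which is why the listing computes `original_targets = targets.argmax(dim=1)` before calling `accuracy`. Assuming `mixup_fn` is a timm-style `Mixup` (the excerpt does not show how it is built), each soft row blends a sample's one-hot label with the label of the sample at the mirrored batch position. The dependency-free sketch below uses hypothetical helpers `one_hot`, `mix_targets` and `argmax`, a single mixing weight `lam`, and no label smoothing; it shows that argmax returns the sample's own label only while `lam > 0.5`.

```python
from typing import List


def one_hot(label: int, num_classes: int) -> List[float]:
    return [1.0 if c == label else 0.0 for c in range(num_classes)]


def mix_targets(labels: List[int], lam: float, num_classes: int) -> List[List[float]]:
    """Blend each one-hot label with that of the sample at the mirrored batch position.

    Hypothetical, simplified stand-in for a timm-style mixup target:
    one lam for the whole batch, partner = batch flipped, no label smoothing.
    """
    partners = labels[::-1]
    return [
        [lam * a + (1.0 - lam) * b
         for a, b in zip(one_hot(y, num_classes), one_hot(p, num_classes))]
        for y, p in zip(labels, partners)
    ]


def argmax(row: List[float]) -> int:
    # Index of the largest entry, like targets.argmax(dim=1) applied to one row.
    return max(range(len(row)), key=lambda c: row[c])


labels = [0, 1, 2, 3]
print([argmax(r) for r in mix_targets(labels, lam=0.7, num_classes=4)])  # [0, 1, 2, 3]
print([argmax(r) for r in mix_targets(labels, lam=0.3, num_classes=4)])  # [3, 2, 1, 0]
```

Under this assumption, the top-1/top-5 accuracy logged during mixup training is measured against whichever label dominates each blended sample, not necessarily the sample's original label.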

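The loop juggles three indices: the batch index `idx`, the scheduler index passed to `lr_scheduler.step_update`, and `normal_global_idx`. The sketch below replays that arithmetic for one epoch without a model or GPU, so the gradient-accumulation schedule can be inspected directly. `replay_step_indices` is a hypothetical helper; its parameters stand in for `epoch`, `len(data_loader)`, `config.TRAIN.ACCUMULATION_STEPS` and `NORM_ITER_LEN`.

```python
from typing import List, Optional, Tuple


def replay_step_indices(
    epoch: int, num_steps: int, accumulation_steps: int, norm_iter_len: int
) -> List[Tuple[int, bool, Optional[int], int]]:
    """Replay the per-batch index arithmetic of train_one_epoch for one epoch.

    Returns one tuple per batch: (idx, update_grad, scheduler_index, normal_global_idx),
    where scheduler_index is None on batches that only accumulate gradients.
    """
    rows = []
    for idx in range(num_steps):
        # Epoch progress rescaled to norm_iter_len units per epoch.
        normal_global_idx = epoch * norm_iter_len + (idx * norm_iter_len // num_steps)
        # Gradients are applied (then zeroed) only on every accumulation_steps-th batch.
        update_grad = (idx + 1) % accumulation_steps == 0
        # step_update runs only on those batches, with the global batch
        # counter integer-divided by the accumulation factor.
        scheduler_index = (
            (epoch * num_steps + idx) // accumulation_steps if update_grad else None
        )
        rows.append((idx, update_grad, scheduler_index, normal_global_idx))
    return rows


for row in replay_step_indices(epoch=1, num_steps=8, accumulation_steps=4, norm_iter_len=1000):
    print(row)
```

Trying `num_steps=10` with `accumulation_steps=4` is instructive: updates fire at batches 3 and 7 only, so the scheduler sees indices 0, 1 in epoch 0 and 3, 4 in epoch 1 (index 2 is skipped). Assuming the unseen lines 253–281 do not flush leftover gradients, the gradients from the last two batches of each epoch are discarded by the `optimizer.zero_grad()` call at the top of the next `train_one_epoch`.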
Callers 1

main (function) · 0.70

Calls 13

update (method) · 0.95
AverageMeter (class) · 0.90
accuracy (function) · 0.90
is_main_process (function) · 0.90
is_valid_grad_norm (function) · 0.85
zero_grad (method) · 0.80
step_update (method) · 0.80
set_bn_state (function) · 0.70
train (method) · 0.45
get (method) · 0.45
state_dict (method) · 0.45
size (method) · 0.45
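`AverageMeter` and `accuracy` are called here but not defined in this excerpt. Assuming they behave like `timm.utils.AverageMeter` (a running mean weighted by the `n` passed to `update`) and `timm.utils.accuracy` (top-k hit rate in percent), a dependency-free sketch of both is below. The class is a hypothetical re-implementation and `topk_accuracy` is a hypothetical name; both work on plain lists rather than tensors.

```python
from typing import List, Sequence, Tuple


class AverageMeter:
    """Running weighted average (hypothetical re-implementation for illustration)."""

    def __init__(self) -> None:
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val: float, n: int = 1) -> None:
        # val is a per-batch mean, so weight it by the batch size n.
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def topk_accuracy(logits: Sequence[Sequence[float]], labels: Sequence[int],
                  topk: Tuple[int, ...] = (1, 5)) -> List[float]:
    """Percentage of rows whose label is among the k highest-scoring classes."""
    results = []
    for k in topk:
        hits = 0
        for row, label in zip(logits, labels):
            ranked = sorted(range(len(row)), key=lambda c: row[c], reverse=True)
            if label in ranked[:k]:
                hits += 1
        results.append(100.0 * hits / len(labels))
    return results


logits = [[0.1, 0.5, 0.2, 0.9, 0.0, 0.3],
          [0.8, 0.1, 0.0, 0.2, 0.4, 0.3]]
acc1, acc5 = topk_accuracy(logits, labels=[3, 5], topk=(1, 5))
print(acc1, acc5)  # 50.0 100.0

meter = AverageMeter()
meter.update(acc1, n=2)   # batch of 2 samples
meter.update(100.0, n=6)  # batch of 6 samples
print(meter.avg)          # 87.5
```

Weighting by `targets.size(0)`, as the listing does, keeps a smaller final batch from counting as much as a full one. timm's `accuracy` operates on tensors and returns tensors, which is why the listing calls `.item()` before updating the meters.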

Tested by

no test coverage detected