hub / github.com/microsoft/Swin-Transformer / train_one_epoch

Function train_one_epoch

main_simmim_ft.py:155–226
(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, scaler)

Source from the content-addressed store, hash-verified

def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, scaler):
    model.train()
    optimizer.zero_grad()

    logger.info(f'Current learning rate for different parameter groups: {[it["lr"] for it in optimizer.param_groups]}')

    num_steps = len(data_loader)
    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    norm_meter = AverageMeter()
    loss_scale_meter = AverageMeter()

    start = time.time()
    end = time.time()
    for idx, (samples, targets) in enumerate(data_loader):
        samples = samples.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        outputs = model(samples)

        if config.TRAIN.ACCUMULATION_STEPS > 1:
            loss = criterion(outputs, targets)
            loss = loss / config.TRAIN.ACCUMULATION_STEPS
            scaler.scale(loss).backward()
            if config.TRAIN.CLIP_GRAD:
                scaler.unscale_(optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD)
            else:
                grad_norm = get_grad_norm(model.parameters())
            if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
                scaler.step(optimizer)
                optimizer.zero_grad()
                scaler.update()
                lr_scheduler.step_update(epoch * num_steps + idx)
        else:
            loss = criterion(outputs, targets)
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            if config.TRAIN.CLIP_GRAD:
                scaler.unscale_(optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD)
            else:
                grad_norm = get_grad_norm(model.parameters())
            scaler.step(optimizer)
            scaler.update()
            lr_scheduler.step_update(epoch * num_steps + idx)

        torch.cuda.synchronize()

        loss_meter.update(loss.item(), targets.size(0))
        norm_meter.update(grad_norm)
        loss_scale_meter.update(scaler.get_scale())
        batch_time.update(time.time() - end)
        end = time.time()
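
In the accumulation branch of the listing, the optimizer steps only when `(idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0`, and the scheduler receives `epoch * num_steps + idx`. The pure-Python sketch below models only that schedule, to make the step boundaries easier to see. `accumulation_schedule` is a hypothetical helper written for illustration; it is not part of the repository and does not touch PyTorch.

```python
def accumulation_schedule(epoch, num_steps, accumulation_steps):
    """Hypothetical helper: return (idx, global_step) pairs at which the
    optimizer would step, mirroring the branch structure in the listing."""
    boundaries = []
    for idx in range(num_steps):
        global_step = epoch * num_steps + idx
        if accumulation_steps > 1:
            # Step only on every accumulation_steps-th micro-batch.
            if (idx + 1) % accumulation_steps == 0:
                boundaries.append((idx, global_step))
        else:
            # No accumulation: every iteration is an optimizer step.
            boundaries.append((idx, global_step))
    return boundaries


if __name__ == "__main__":
    print(accumulation_schedule(epoch=2, num_steps=8, accumulation_steps=4))
    # prints [(3, 19), (7, 23)]
```

If `num_steps` is not a multiple of the accumulation count, the trailing micro-batches never reach a boundary in this model; for example, 10 steps with an accumulation count of 4 step only at `idx` 3 and 7. Separately, PyTorch's AMP documentation describes calling `scaler.unscale_(optimizer)` once per optimizer step, after accumulation has finished; the sketch does not model unscaling.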

Callers 1

main · Function · 0.70

Calls 2

get_grad_norm · Function · 0.90
backward · Method · 0.45
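
The body of `get_grad_norm` is not shown on this page. As a rough illustration of what such a helper typically computes, the sketch below assumes it returns a single global p-norm over all parameter gradients, comparable to the value `clip_grad_norm_` returns. `global_grad_norm` is a hypothetical stand-in that works on plain lists of floats, not the repository's implementation.

```python
def global_grad_norm(grads, norm_type=2.0):
    """Hypothetical stand-in: p-norm over every gradient entry.

    grads is a list with one entry per parameter; each entry is a list of
    floats, or None for a parameter that has no gradient.
    """
    total = 0.0
    for g in grads:
        if g is None:
            # Parameters without gradients contribute nothing.
            continue
        total += sum(abs(x) ** norm_type for x in g)
    return total ** (1.0 / norm_type)


if __name__ == "__main__":
    print(global_grad_norm([[3.0], None, [4.0]]))
```

In the listing, the non-clipping branches call `get_grad_norm` without a preceding `scaler.unscale_(optimizer)`, so the value recorded in `norm_meter` there would be measured on gradients still multiplied by the loss scale.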

Tested by

no test coverage detected