(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, scaler)
| 153 | |
| 154 | |
| 155 | def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, scaler): |
| 156 | model.train() |
| 157 | optimizer.zero_grad() |
| 158 | |
| 159 | logger.info(f'Current learning rate for different parameter groups: {[it["lr"] for it in optimizer.param_groups]}') |
| 160 | |
| 161 | num_steps = len(data_loader) |
| 162 | batch_time = AverageMeter() |
| 163 | loss_meter = AverageMeter() |
| 164 | norm_meter = AverageMeter() |
| 165 | loss_scale_meter = AverageMeter() |
| 166 | |
| 167 | start = time.time() |
| 168 | end = time.time() |
| 169 | for idx, (samples, targets) in enumerate(data_loader): |
| 170 | samples = samples.cuda(non_blocking=True) |
| 171 | targets = targets.cuda(non_blocking=True) |
| 172 | |
| 173 | if mixup_fn is not None: |
| 174 | samples, targets = mixup_fn(samples, targets) |
| 175 | |
| 176 | outputs = model(samples) |
| 177 | |
| 178 | if config.TRAIN.ACCUMULATION_STEPS > 1: |
| 179 | loss = criterion(outputs, targets) |
| 180 | loss = loss / config.TRAIN.ACCUMULATION_STEPS |
| 181 | scaler.scale(loss).backward() |
| 182 | if config.TRAIN.CLIP_GRAD: |
| 183 | scaler.unscale_(optimizer) |
| 184 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) |
| 185 | else: |
| 186 | grad_norm = get_grad_norm(model.parameters()) |
| 187 | if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0: |
| 188 | scaler.step(optimizer) |
| 189 | optimizer.zero_grad() |
| 190 | scaler.update() |
| 191 | lr_scheduler.step_update(epoch * num_steps + idx) |
| 192 | else: |
| 193 | loss = criterion(outputs, targets) |
| 194 | optimizer.zero_grad() |
| 195 | scaler.scale(loss).backward() |
| 196 | if config.TRAIN.CLIP_GRAD: |
| 197 | scaler.unscale_(optimizer) |
| 198 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) |
| 199 | else: |
| 200 | grad_norm = get_grad_norm(model.parameters()) |
| 201 | scaler.step(optimizer) |
| 202 | scaler.update() |
| 203 | lr_scheduler.step_update(epoch * num_steps + idx) |
| 204 | |
| 205 | torch.cuda.synchronize() |
| 206 | |
| 207 | loss_meter.update(loss.item(), targets.size(0)) |
| 208 | norm_meter.update(grad_norm) |
| 209 | loss_scale_meter.update(scaler.get_scale()) |
| 210 | batch_time.update(time.time() - end) |
| 211 | end = time.time() |
| 212 |
no test coverage detected