(model, data, epoch, optimizer, scaler, scheduler, scheduler_l0, args, tb_writer=None, start_iter=0, zs=None)
| 82 | |
| 83 | |
| 84 | def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, scheduler_l0, args, tb_writer=None, start_iter=0, zs=None): |
| 85 | |
| 86 | global NAN_LOSS_CNT |
| 87 | |
| 88 | device = torch.device(args.device) |
| 89 | autocast = get_autocast(args.precision) |
| 90 | |
| 91 | image_autocast = get_autocast(args.image_precision) |
| 92 | text_autocast = get_autocast(args.text_precision) |
| 93 | logit_autocast = get_autocast(args.logit_precision) |
| 94 | |
| 95 | model.set_autocast( |
| 96 | image_autocast=image_autocast, |
| 97 | text_autocast=text_autocast, |
| 98 | logit_autocast=logit_autocast) |
| 99 | |
| 100 | teacher_autocast = torch.cuda.amp.autocast |
| 101 | |
| 102 | model_without_ddp = unwrap_model(model) |
| 103 | |
| 104 | distillation = args.distillation |
| 105 | if distillation: |
| 106 | teacher_model = model_without_ddp.teacher[0] |
| 107 | |
| 108 | model.train() |
| 109 | loss_kwargs = dict( |
| 110 | local_loss=args.local_loss, |
| 111 | gather_with_grad=args.gather_with_grad, |
| 112 | cache_labels=True, |
| 113 | rank=args.rank, |
| 114 | world_size=args.world_size, |
| 115 | use_horovod=args.horovod) |
| 116 | |
| 117 | if start_iter == 0: |
| 118 | # set epoch in process safe manner via sampler or shared_epoch |
| 119 | data['train'].set_epoch(epoch) |
| 120 | dataloader = data['train'].dataloader |
| 121 | |
| 122 | dataloader.device = args.device |
| 123 | if distillation: |
| 124 | soft_loss_fn = ClipSoftLoss(**loss_kwargs) # , ignore_diag=True) |
| 125 | else: |
| 126 | soft_loss_fn = None |
| 127 | |
| 128 | hard_loss_fn = ClipLoss(**loss_kwargs) |
| 129 | |
| 130 | dataloader, sampler = data['train'].dataloader, data['train'].sampler |
| 131 | if args.distributed and sampler is not None and start_iter == 0: |
| 132 | # [DO NOT REMOVE IT] it will call set_epoch even if sampler is not a DistributedSampler. |
| 133 | sampler.set_epoch(epoch) |
| 134 | |
| 135 | num_batches_per_epoch = dataloader.num_batches |
| 136 | sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10)) |
| 137 | |
| 138 | loss_m = AverageMeter() |
| 139 | metrics = defaultdict(AverageMeter) |
| 140 | end = time.time() |
| 141 | batch_size = dataloader.batch_size |
no test coverage detected