(
model, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None
)
| 46 | |
| 47 | |
| 48 | def train_one_epoch( |
| 49 | model, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None |
| 50 | ): |
| 51 | device = torch.device(args.device) |
| 52 | autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress |
| 53 | model.train() |
| 54 | loss = ClipLoss( |
| 55 | local_loss=args.local_loss, |
| 56 | gather_with_grad=args.gather_with_grad, |
| 57 | cache_labels=True, |
| 58 | rank=args.rank, |
| 59 | world_size=args.world_size, |
| 60 | use_horovod=args.horovod, |
| 61 | mlp_loss=args.clap_mlploss, |
| 62 | weight_loss_kappa=args.kappa, |
| 63 | ) |
| 64 | |
| 65 | dataloader, sampler = data["train"].dataloader, data["train"].sampler |
| 66 | if args.distributed and sampler is not None: |
| 67 | sampler.set_epoch(epoch) |
| 68 | num_batches_per_epoch = dataloader.num_batches |
| 69 | sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10)) |
| 70 | |
| 71 | # for toy dataset |
| 72 | if args.dataset_type == "toy": |
| 73 | dataloader.dataset.generate_queue() |
| 74 | |
| 75 | loss_m = AverageMeter() |
| 76 | batch_time_m = AverageMeter() |
| 77 | data_time_m = AverageMeter() |
| 78 | end = time.time() |
| 79 | |
| 80 | for i, batch in enumerate(dataloader): |
| 81 | # logging.info(f"batch {i} of {num_batches_per_epoch}") |
| 82 | step = num_batches_per_epoch * epoch + i |
| 83 | if isinstance(scheduler, dict): |
| 84 | for s in scheduler.values(): |
| 85 | s(step) |
| 86 | else: |
| 87 | scheduler(step) |
| 88 | audios = batch # contains mel_spec, wavform, and longer list |
| 89 | texts = batch["text"] |
| 90 | # audios = audios.to(device=device, non_blocking=True) |
| 91 | # texts = texts.to(device=device, non_blocking=True) |
| 92 | |
| 93 | data_time_m.update(time.time() - end) |
| 94 | if isinstance(optimizer, dict): |
| 95 | for o_ in optimizer.values(): |
| 96 | o_.zero_grad() |
| 97 | else: |
| 98 | optimizer.zero_grad() |
| 99 | |
| 100 | with autocast(): |
| 101 | ( |
| 102 | audio_features, |
| 103 | text_features, |
| 104 | audio_features_mlp, |
| 105 | text_features_mlp, |
no test coverage detected