MCPcopy
hub / github.com/Audio-AGI/AudioSep / train_one_epoch

Function train_one_epoch

models/CLAP/training/train.py:48–264  ·  view source on GitHub ↗
(
    model, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None
)

Source from the content-addressed store, hash-verified

46
47
48def train_one_epoch(
49 model, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None
50):
51 device = torch.device(args.device)
52 autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress
53 model.train()
54 loss = ClipLoss(
55 local_loss=args.local_loss,
56 gather_with_grad=args.gather_with_grad,
57 cache_labels=True,
58 rank=args.rank,
59 world_size=args.world_size,
60 use_horovod=args.horovod,
61 mlp_loss=args.clap_mlploss,
62 weight_loss_kappa=args.kappa,
63 )
64
65 dataloader, sampler = data["train"].dataloader, data["train"].sampler
66 if args.distributed and sampler is not None:
67 sampler.set_epoch(epoch)
68 num_batches_per_epoch = dataloader.num_batches
69 sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10))
70
71 # for toy dataset
72 if args.dataset_type == "toy":
73 dataloader.dataset.generate_queue()
74
75 loss_m = AverageMeter()
76 batch_time_m = AverageMeter()
77 data_time_m = AverageMeter()
78 end = time.time()
79
80 for i, batch in enumerate(dataloader):
81 # logging.info(f"batch {i} of {num_batches_per_epoch}")
82 step = num_batches_per_epoch * epoch + i
83 if isinstance(scheduler, dict):
84 for s in scheduler.values():
85 s(step)
86 else:
87 scheduler(step)
88 audios = batch # contains mel_spec, wavform, and longer list
89 texts = batch["text"]
90 # audios = audios.to(device=device, non_blocking=True)
91 # texts = texts.to(device=device, non_blocking=True)
92
93 data_time_m.update(time.time() - end)
94 if isinstance(optimizer, dict):
95 for o_ in optimizer.values():
96 o_.zero_grad()
97 else:
98 optimizer.zero_grad()
99
100 with autocast():
101 (
102 audio_features,
103 text_features,
104 audio_features_mlp,
105 text_features_mlp,

Callers (1)

main (function) · 0.90

Calls (7)

update (method) · 0.95
reset (method) · 0.95
ClipLoss (class) · 0.90
is_master (function) · 0.85
generate_queue (method) · 0.80
AverageMeter (class) · 0.70
unwrap_model (function) · 0.70

Tested by

no test coverage detected