MCPcopy
hub / github.com/descriptinc/descript-audio-codec / load

Function load

scripts/train_no_adv.py:120–175  ·  view source on GitHub ↗
(
    args,
    accel: ml.Accelerator,
    tracker: Tracker,
    save_path: str,
    resume: bool = False,
    tag: str = "latest",
    load_weights: bool = False,
)

Source from the content-addressed store, hash-verified

118
@argbind.bind(without_prefix=True)
def load(
    args,
    accel: ml.Accelerator,
    tracker: Tracker,
    save_path: str,
    resume: bool = False,
    tag: str = "latest",
    load_weights: bool = False,
):
    """Build the generator-only training state, optionally resuming from disk.

    Args:
        args: Parsed argbind arguments. They are entered via ``argbind.scope``
            under the ``"generator"`` prefix while the optimizer/scheduler are
            built, and under ``"train"`` / ``"val"`` while the datasets are built.
        accel: Accelerator used to prepare the model (``accel.prepare_model``),
            to unwrap it for ``sample_rate``, and whose ``use_ddp`` flag is
            forwarded to ``AdamW`` as ``use_zero``.
        tracker: Training tracker used for log output; when resuming, its
            state is restored from ``tracker.pth`` if the checkpoint has one.
        save_path: Root directory that contains tagged checkpoint folders.
        resume: If True, look for a checkpoint in ``{save_path}/{tag}``. If
            that folder has no ``dac`` sub-folder, a message is printed and a
            freshly constructed ``DAC()`` is used instead.
        tag: Name of the checkpoint sub-folder to resume from.
        load_weights: Forwarded to ``DAC.load_from_folder`` as
            ``package=not load_weights``.

    Returns:
        State: the prepared generator, its optimizer and scheduler, the L1 /
        multi-scale STFT / mel losses, the tracker, and the train/val datasets.
    """
    generator, g_extra = None, {}

    if resume:
        folder = f"{save_path}/{tag}"
        # Path(...).absolute() is correct for both relative and absolute
        # save_path values; manually prefixing the CWD produced a bogus path
        # (e.g. "/cwd//abs/path") whenever save_path was already absolute.
        tracker.print(f"Resuming from {Path(folder).absolute()}")
        if (Path(folder) / "dac").exists():
            generator, g_extra = DAC.load_from_folder(
                folder=folder,
                map_location="cpu",
                package=not load_weights,
            )
        else:
            # Previously this fell through silently; make the fallback visible.
            tracker.print(
                f"No 'dac' checkpoint found in {folder}; starting from a fresh model"
            )

    generator = DAC() if generator is None else generator
    generator = accel.prepare_model(generator)

    with argbind.scope(args, "generator"):
        optimizer_g = AdamW(generator.parameters(), use_zero=accel.use_ddp)
        scheduler_g = ExponentialLR(optimizer_g)

    # Restore any auxiliary state stored next to the weights. g_extra is only
    # non-empty when a checkpoint was actually loaded above.
    for key, component in (
        ("optimizer.pth", optimizer_g),
        ("scheduler.pth", scheduler_g),
        ("tracker.pth", tracker),
    ):
        if key in g_extra:
            component.load_state_dict(g_extra[key])

    # The datasets are built at the generator's native sample rate; unwrap in
    # case prepare_model wrapped the module.
    sample_rate = accel.unwrap(generator).sample_rate
    with argbind.scope(args, "train"):
        train_data = build_dataset(sample_rate)
    with argbind.scope(args, "val"):
        val_data = build_dataset(sample_rate)

    waveform_loss = losses.L1Loss()
    stft_loss = losses.MultiScaleSTFTLoss()
    mel_loss = losses.MelSpectrogramLoss()

    return State(
        generator=generator,
        optimizer_g=optimizer_g,
        scheduler_g=scheduler_g,
        waveform_loss=waveform_loss,
        stft_loss=stft_loss,
        mel_loss=mel_loss,
        tracker=tracker,
        train_data=train_data,
        val_data=val_data,
    )
176
177

Callers 1

train (function) · 0.70

Calls 4

DAC (class) · 0.85
ExponentialLR (function) · 0.70
build_dataset (function) · 0.70
State (class) · 0.70

Tested by

no test coverage detected