(
args,
accel: ml.Accelerator,
tracker: Tracker,
save_path: str,
resume: bool = False,
tag: str = "latest",
load_weights: bool = False,
)
| 118 | |
@argbind.bind(without_prefix=True)
def load(
    args,
    accel: ml.Accelerator,
    tracker: Tracker,
    save_path: str,
    resume: bool = False,
    tag: str = "latest",
    load_weights: bool = False,
):
    """Assemble the training state, optionally resuming from a checkpoint.

    Parameters
    ----------
    args
        Argument mapping handed to ``argbind.scope`` to select the
        ``"generator"``, ``"train"`` and ``"val"`` scopes.
    accel : ml.Accelerator
        Used to prepare/unwrap the generator and to read ``use_ddp``.
    tracker : Tracker
        Receives the resume message and, when the checkpoint provides
        ``"tracker.pth"``, has its state restored from it.
    save_path : str
        Root folder of saved runs; the checkpoint is read from
        ``f"{save_path}/{tag}"``.
    resume : bool
        When true, try to load the generator from the checkpoint folder.
        If that folder has no ``dac`` entry, a fresh ``DAC()`` is built.
    tag : str
        Name of the checkpoint sub-folder under ``save_path``.
    load_weights : bool
        Forwarded inverted as ``package=not load_weights`` to
        ``DAC.load_from_folder``.
        NOTE(review): presumably selects weights-only loading instead of
        the packaged model -- confirm against ``load_from_folder``.

    Returns
    -------
    State
        Bundle holding the generator, its optimizer and scheduler, the
        three loss objects, the tracker and the train/val datasets.
    """
    generator, g_extra = None, {}

    if resume:
        kwargs = {
            "folder": f"{save_path}/{tag}",
            "map_location": "cpu",
            "package": not load_weights,
        }
        checkpoint_dir = Path(kwargs["folder"])
        # Make the folder itself absolute instead of gluing the working
        # directory in front of it: with an absolute ``save_path`` the old
        # message reported a path that does not exist.
        tracker.print(f"Resuming from {checkpoint_dir.absolute()}")
        if (checkpoint_dir / "dac").exists():
            generator, g_extra = DAC.load_from_folder(**kwargs)

    if generator is None:
        generator = DAC()
    generator = accel.prepare_model(generator)

    with argbind.scope(args, "generator"):
        optimizer_g = AdamW(generator.parameters(), use_zero=accel.use_ddp)
        scheduler_g = ExponentialLR(optimizer_g)

    # Restore whichever pieces of saved state the checkpoint carried;
    # ``g_extra`` stays empty when nothing was loaded.
    restorable = (
        ("optimizer.pth", optimizer_g),
        ("scheduler.pth", scheduler_g),
        ("tracker.pth", tracker),
    )
    for key, target in restorable:
        if key in g_extra:
            target.load_state_dict(g_extra[key])

    # The datasets are built at the generator's own sample rate.
    sample_rate = accel.unwrap(generator).sample_rate
    with argbind.scope(args, "train"):
        train_data = build_dataset(sample_rate)
    with argbind.scope(args, "val"):
        val_data = build_dataset(sample_rate)

    waveform_loss = losses.L1Loss()
    stft_loss = losses.MultiScaleSTFTLoss()
    mel_loss = losses.MelSpectrogramLoss()

    return State(
        generator=generator,
        optimizer_g=optimizer_g,
        scheduler_g=scheduler_g,
        waveform_loss=waveform_loss,
        stft_loss=stft_loss,
        mel_loss=mel_loss,
        tracker=tracker,
        train_data=train_data,
        val_data=val_data,
    )
| 176 | |
| 177 |
no test coverage detected