()
| 347 | |
| 348 | |
def export_decoder():
    """Trace the DVAE decoder and save it as a TorchScript file.

    Steps, in order:

    1. Build a ``DVAE`` from ``chattts_config.decoder`` and switch it to
       eval mode.
    2. Load the weights stored at the ``decoder_ckpt_path`` entry of
       ``chattts_config.path``.
    3. Disable gradients on every parameter.
    4. Trace the ``mode="decode"`` call with a random ``[1, 768, 1024]``
       example input.
    5. Write the traced function to ``{args.out_dir}/decoder_jit.pt``.

    Reads the module-level ``chattts_config`` and ``args``; returns nothing.
    """
    decoder_cfg = chattts_config.decoder
    dvae = DVAE(
        decoder_config=asdict(decoder_cfg),
        dim=decoder_cfg.idim,
    )
    dvae = dvae.eval()

    # Checkpoint is memory-mapped and restricted to plain tensor data.
    ckpt_path = asdict(chattts_config.path)["decoder_ckpt_path"]
    state = torch.load(ckpt_path, weights_only=True, mmap=True)
    dvae.load_state_dict(state)

    # Freeze all weights before tracing.
    for weight in dvae.parameters():
        weight.requires_grad_(False)

    example = torch.rand([1, 768, 1024], requires_grad=False)

    # The tracer needs a plain callable; this closure pins mode="decode".
    # Its name and argument name are kept as in the original in case the
    # tracer records them in the saved artifact.
    def mydec(_inp):
        return dvae(_inp, mode="decode")

    traced = jit.trace(mydec, [example])
    jit.save(traced, f"{args.out_dir}/decoder_jit.pt")
| 371 | |
| 372 | |
| 373 | def export_vocos(): |
no test coverage detected