Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
(pretrained_path=None, z_dim=None, device='cpu', **kwargs)
| 590 | |
| 591 | |
| 592 | def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs): |
| 593 | """ |
| 594 | Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL. |
| 595 | """ |
| 596 | # params |
| 597 | cfg = dict( |
| 598 | dim=96, |
| 599 | z_dim=z_dim, |
| 600 | dim_mult=[1, 2, 4, 4], |
| 601 | num_res_blocks=2, |
| 602 | attn_scales=[], |
| 603 | temperal_downsample=[False, True, True], |
| 604 | dropout=0.0) |
| 605 | cfg.update(**kwargs) |
| 606 | |
| 607 | # init model |
| 608 | with torch.device('meta'): |
| 609 | model = WanVAE_(**cfg) |
| 610 | |
| 611 | # load checkpoint |
| 612 | logging.info(f'loading {pretrained_path}') |
| 613 | model.load_state_dict( |
| 614 | torch.load(pretrained_path, map_location=device), assign=True) |
| 615 | |
| 616 | return model |
| 617 | |
| 618 | |
| 619 | class WanVAE: |