| 619 | class WanVAE: |
| 620 | |
def __init__(self,
             z_dim=16,
             vae_pth='cache/vae_step_411000.pth',
             dtype=torch.float,
             device="cuda"):
    """Load the video VAE and prepare its latent normalisation statistics.

    Args:
        z_dim: Forwarded unchanged to ``_video_vae`` as ``z_dim``.
            NOTE(review): the hard-coded mean/std below have 16 entries each,
            matching only the default of 16 — confirm before using another
            value.
        vae_pth: Checkpoint path, forwarded to ``_video_vae`` as
            ``pretrained_path``.
        dtype: dtype used for the ``mean``/``std`` tensors and remembered on
            ``self.dtype``. It is not applied to ``self.model`` here.
        device: Device on which both the statistics and the model are placed.
    """
    self.dtype, self.device = dtype, device

    def as_stat_tensor(values):
        # All statistics share the caller-chosen dtype and device.
        return torch.tensor(values, dtype=dtype, device=device)

    # 16 values each; presumably one per latent channel — TODO confirm.
    latent_mean = as_stat_tensor([
        -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
        0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921,
    ])
    latent_std = as_stat_tensor([
        2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
        3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160,
    ])
    self.mean = latent_mean
    self.std = latent_std
    # Kept as [mean, 1/std]; presumably consumed by encode/decode (not
    # visible here) so normalisation can multiply instead of divide.
    self.scale = [latent_mean, 1.0 / latent_std]

    # Build the network, switch it to eval mode with gradients disabled,
    # then move it onto the requested device.
    vae = _video_vae(pretrained_path=vae_pth, z_dim=z_dim)
    vae = vae.eval().requires_grad_(False)
    self.model = vae.to(device)
| 647 | def encode(self, videos): |
| 648 | """ |