MCPcopy
hub / github.com/MeiGen-AI/InfiniteTalk / WanVAE

Class WanVAE

wan/modules/vae.py:619–663  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

617
618
class WanVAE:
    """Inference wrapper around the Wan video VAE.

    Stores per-channel latent normalisation statistics and a frozen
    ``_video_vae`` model. It exposes list-in / list-out ``encode`` and
    ``decode`` helpers that run the model under autocast at ``self.dtype``.
    """

    def __init__(self,
                 z_dim=16,
                 vae_pth='cache/vae_step_411000.pth',
                 dtype=torch.float,
                 device="cuda"):
        """Build the VAE model and its normalisation tensors.

        Args:
            z_dim: Number of latent channels, forwarded to ``_video_vae``.
            vae_pth: Checkpoint path, forwarded to ``_video_vae`` as
                ``pretrained_path``.
            dtype: Autocast dtype used by ``encode``/``decode``. It is also
                the dtype of the ``mean``/``std`` tensors.
            device: Device that the ``mean``/``std`` tensors and the model
                are placed on.
        """
        self.dtype = dtype
        self.device = device

        # Per-channel latent statistics, 16 values each.
        # NOTE(review): these presumably correspond to the default z_dim=16;
        # a different z_dim would need matching statistics -- confirm.
        mean = [
            -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
            0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
        ]
        std = [
            2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
            3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
        ]
        self.mean = torch.tensor(mean, dtype=dtype, device=device)
        self.std = torch.tensor(std, dtype=dtype, device=device)
        # Handed to the model's encode/decode as [mean, 1 / std].
        self.scale = [self.mean, 1.0 / self.std]

        # Build the model frozen (eval mode, no parameter grads) on `device`.
        self.model = _video_vae(
            pretrained_path=vae_pth,
            z_dim=z_dim,
        ).eval().requires_grad_(False).to(device)

    def _autocast(self):
        """Return a CUDA autocast context for ``self.dtype``.

        ``torch.autocast(device_type="cuda", ...)`` is the supported
        replacement for the deprecated ``torch.cuda.amp.autocast(...)``. It
        is pinned to CUDA, so ops on other devices are not autocast.
        """
        return torch.autocast(device_type="cuda", dtype=self.dtype)

    def encode(self, videos):
        """Encode each video independently.

        Args:
            videos: A list of videos each with shape [C, T, H, W].

        Returns:
            A list with one float32 latent per input video. Each video is
            given a leading batch dim for the model call, and that dim is
            squeezed away again afterwards.
        """
        with self._autocast():
            return [
                self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
                for u in videos
            ]

    def decode(self, zs):
        """Decode each latent independently.

        Args:
            zs: A list of latents, each without a batch dimension.

        Returns:
            A list with one float32 output per latent. Values are clamped
            in place to [-1, 1], and the temporary leading batch dim is
            squeezed away.
        """
        with self._autocast():
            return [
                self.model.decode(u.unsqueeze(0),
                                  self.scale).float().clamp_(-1, 1).squeeze(0)
                for u in zs
            ]

Callers 6

__init__Method · 0.85
__init__Method · 0.85
__init__Method · 0.85
__init__Method · 0.85
mp_workerMethod · 0.85
__init__Method · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected