MCPcopy Index your code
hub / github.com/MeiGen-AI/InfiniteTalk / Resample

Class Resample

wan/modules/vae.py:66–183  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

64
65
66class Resample(nn.Module):
67
68 def __init__(self, dim, mode):
69 assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
70 'downsample3d')
71 super().__init__()
72 self.dim = dim
73 self.mode = mode
74
75 # layers
76 if mode == 'upsample2d':
77 self.resample = nn.Sequential(
78 Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
79 nn.Conv2d(dim, dim // 2, 3, padding=1))
80 elif mode == 'upsample3d':
81 self.resample = nn.Sequential(
82 Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
83 nn.Conv2d(dim, dim // 2, 3, padding=1))
84 self.time_conv = CausalConv3d(
85 dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
86
87 elif mode == 'downsample2d':
88 self.resample = nn.Sequential(
89 nn.ZeroPad2d((0, 1, 0, 1)),
90 nn.Conv2d(dim, dim, 3, stride=(2, 2)))
91 elif mode == 'downsample3d':
92 self.resample = nn.Sequential(
93 nn.ZeroPad2d((0, 1, 0, 1)),
94 nn.Conv2d(dim, dim, 3, stride=(2, 2)))
95 self.time_conv = CausalConv3d(
96 dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
97
98 else:
99 self.resample = nn.Identity()
100
101 def forward(self, x, feat_cache=None, feat_idx=[0]):
102 b, c, t, h, w = x.size()
103 if self.mode == 'upsample3d':
104 if feat_cache is not None:
105 idx = feat_idx[0]
106 if feat_cache[idx] is None:
107 feat_cache[idx] = 'Rep'
108 feat_idx[0] += 1
109 else:
110
111 cache_x = x[:, :, -CACHE_T:, :, :].clone()
112 if cache_x.shape[2] < 2 and feat_cache[
113 idx] is not None and feat_cache[idx] != 'Rep':
114 # cache last frame of last two chunk
115 cache_x = torch.cat([
116 feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
117 cache_x.device), cache_x
118 ],
119 dim=2)
120 if cache_x.shape[2] < 2 and feat_cache[
121 idx] is not None and feat_cache[idx] == 'Rep':
122 cache_x = torch.cat([
123 torch.zeros_like(cache_x).to(cache_x.device),

Callers 2

__init__Method · 0.85
__init__Method · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected