MCPcopy
hub / github.com/Robbyant/lingbot-world / forward

Method forward

wan/modules/vae2_2.py:111–168  ·  view source on GitHub ↗
(self, x, feat_cache=None, feat_idx=[0])

Source from the content-addressed store, hash-verified

109 self.resample = nn.Identity()
110
def forward(self, x, feat_cache=None, feat_idx=None):
    """Apply ``self.resample`` per frame, with optional temporal resampling.

    ``x`` is unpacked as ``(b, c, t, h, w)``. The temporal step depends on
    ``self.mode`` and only runs when a ``feat_cache`` list is supplied:

    * ``"upsample3d"`` (before ``self.resample``): the first time a cache
      slot is visited it is marked with the ``"Rep"`` sentinel and
      ``self.time_conv`` is skipped. On later visits ``self.time_conv`` is
      applied (receiving the previous slot contents as its second argument
      unless the slot still holds ``"Rep"``), and its output is split into
      two channel groups that are interleaved along time, doubling ``t``.
      The reshape requires ``self.time_conv`` to return ``2 * c`` channels
      and ``t`` frames.
    * ``"downsample3d"`` (after ``self.resample``): the first visit caches
      the whole input and skips ``self.time_conv``; later visits prepend the
      last cached frame to ``x`` before ``self.time_conv`` and cache the
      last frame of the current input.

    ``self.resample`` receives the ``(b * t, c, h, w)`` view obtained by
    folding frames into the batch axis.

    Args:
        x: 5-D tensor unpacked as ``(b, c, t, h, w)``.
        feat_cache: Optional list of per-layer cache slots indexed by
            ``feat_idx[0]``; updated in place.
        feat_idx: One-element list holding the current slot index; advanced
            by one each time this layer consumes a slot. When omitted, a
            fresh ``[0]`` is used for this call only.

    Returns:
        The resampled 5-D tensor.
    """
    if feat_idx is None:
        # Fresh counter per call: a shared mutable default would carry the
        # slot index over from one call to the next.
        feat_idx = [0]
    b, c, t, h, w = x.size()

    if self.mode == "upsample3d" and feat_cache is not None:
        idx = feat_idx[0]
        prev = feat_cache[idx]
        if prev is None:
            # First visit: mark the slot and leave the frame count unchanged.
            feat_cache[idx] = "Rep"
            feat_idx[0] += 1
        else:
            # Test the sentinel by type instead of comparing a tensor to a str.
            is_rep = isinstance(prev, str) and prev == "Rep"
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2:
                # Fewer than two frames to cache: prepend context, either the
                # last frame cached by the previous chunk or zeros when the
                # slot only holds the "Rep" sentinel.
                if is_rep:
                    pad = torch.zeros_like(cache_x)
                else:
                    pad = prev[:, :, -1:, :, :].to(cache_x.device)
                cache_x = torch.cat([pad, cache_x], dim=2)
            if is_rep:
                x = self.time_conv(x)
            else:
                x = self.time_conv(x, prev)
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
            # Interleave the two channel groups along time: t -> 2 * t.
            x = x.reshape(b, 2, c, t, h, w)
            x = torch.stack((x[:, 0], x[:, 1]), 3)
            x = x.reshape(b, c, t * 2, h, w)

    # Fold frames into the batch axis for the per-frame resample.
    t = x.shape[2]
    x = rearrange(x, "b c t h w -> (b t) c h w")
    x = self.resample(x)
    x = rearrange(x, "(b t) c h w -> b c t h w", t=t)

    if self.mode == "downsample3d" and feat_cache is not None:
        idx = feat_idx[0]
        prev = feat_cache[idx]
        if prev is None:
            # First visit: cache the input and skip the temporal convolution.
            feat_cache[idx] = x.clone()
            feat_idx[0] += 1
        else:
            # Save the current last frame before x is replaced by the conv output.
            cache_x = x[:, :, -1:, :, :].clone()
            x = self.time_conv(torch.cat([prev[:, :, -1:, :, :], x], 2))
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
    return x

Callers

nothing calls this directly

Calls 2

size · Method · 0.80
to · Method · 0.80

Tested by

no test coverage detected