MCPcopy
hub / github.com/MeiGen-AI/InfiniteTalk / ResidualBlock

Class ResidualBlock

wan/modules/vae.py:186–220  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

184
185
class ResidualBlock(nn.Module):
    """Residual block built from normalization, SiLU and causal 3D convolutions.

    Residual branch::

        RMS_norm -> SiLU -> CausalConv3d(k=3) ->
        RMS_norm -> SiLU -> Dropout -> CausalConv3d(k=3)

    Shortcut: a 1x1x1 ``CausalConv3d`` when ``in_dim != out_dim``, otherwise
    ``nn.Identity``. The block output is ``residual(x) + shortcut(x)``.

    Args:
        in_dim: Number of input channels.
        out_dim: Number of output channels.
        dropout: Dropout probability applied before the second convolution.
    """

    def __init__(self, in_dim, out_dim, dropout=0.0):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # layers
        self.residual = nn.Sequential(
            RMS_norm(in_dim, images=False), nn.SiLU(),
            CausalConv3d(in_dim, out_dim, 3, padding=1),
            RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
            CausalConv3d(out_dim, out_dim, 3, padding=1))
        # Only project the input when the channel count changes.
        self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
            if in_dim != out_dim else nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=None):
        """Apply the block, optionally reading/updating a per-conv feature cache.

        Args:
            x: Input tensor. The slicing below fixes it as rank 5, with axis 2
                treated as the frame axis; presumably
                (batch, channels, time, height, width) -- TODO confirm.
            feat_cache: Optional mutable list of cached tensors (or ``None``
                entries), indexed by ``feat_idx[0]``. For each
                ``CausalConv3d`` in the residual branch:
                - the slot's previous value is passed to the conv;
                - the slot is then overwritten with the trailing frames of
                  that conv's input.
            feat_idx: One-element list used as a shared, mutable cursor into
                ``feat_cache``. It is advanced once per cached conv, so the
                caller can keep threading it through subsequent blocks.
                Defaults to a fresh ``[0]`` on every call.

        Returns:
            ``residual(x) + shortcut(x)``.
        """
        if feat_idx is None:
            # Fresh cursor per call; a shared ``[0]`` default would carry
            # increments over from earlier calls that omitted ``feat_idx``.
            feat_idx = [0]
        h = self.shortcut(x)
        for layer in self.residual:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                # Keep the trailing frames of this conv's *input* as its
                # context for the next chunk; clone so the cache holds a copy
                # rather than a view into the full input tensor.
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # The chunk has fewer than two frames: prepend the last
                    # frame of the previously cached context so the saved
                    # cache still spans two frames.
                    cache_x = torch.cat([
                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                            cache_x.device), cache_x
                    ],
                                        dim=2)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x + h
221
222
223class AttentionBlock(nn.Module):

Callers 2

__init__ · Method · 0.85
__init__ · Method · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected