| 184 | |
| 185 | |
class ResidualBlock(nn.Module):
    """Residual block built from two norm -> SiLU -> causal-conv stages.

    The residual branch is
    ``RMS_norm -> SiLU -> CausalConv3d -> RMS_norm -> SiLU -> Dropout ->
    CausalConv3d``. Its output is added to a shortcut of the input: a
    kernel-size-1 ``CausalConv3d`` when ``in_dim != out_dim``, otherwise
    ``nn.Identity``.

    Args:
        in_dim: Size passed as the first argument of the first norm/conv
            (presumably the input channel count -- confirm against
            ``CausalConv3d``).
        out_dim: Size produced by the convolutions and the shortcut.
        dropout: Probability handed to ``nn.Dropout`` in the second stage.
    """

    def __init__(self, in_dim, out_dim, dropout=0.0):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # layers
        self.residual = nn.Sequential(
            RMS_norm(in_dim, images=False),
            nn.SiLU(),
            CausalConv3d(in_dim, out_dim, 3, padding=1),
            RMS_norm(out_dim, images=False),
            nn.SiLU(),
            nn.Dropout(dropout),
            CausalConv3d(out_dim, out_dim, 3, padding=1),
        )
        # Only project the shortcut when the sizes differ.
        self.shortcut = (
            CausalConv3d(in_dim, out_dim, 1)
            if in_dim != out_dim else nn.Identity()
        )

    def forward(self, x, feat_cache=None, feat_idx=None):
        """Apply the block, optionally using a per-conv feature cache.

        Args:
            x: 5-D tensor; dimension 2 is the axis that gets cached
                (the frame axis, per the caching logic below).
            feat_cache: Optional mutable sequence with one slot per
                ``CausalConv3d`` in the residual branch. Each slot is read
                as the second argument of that conv and then overwritten
                with the trailing ``CACHE_T`` frames of the conv's input.
                When ``None``, every layer is called with ``x`` only.
            feat_idx: One-element list holding the index of the next cache
                slot. It is advanced in place once per cached conv so the
                caller can thread it through several blocks. When omitted,
                a fresh ``[0]`` is used for this call.

        Returns:
            The residual branch output plus the shortcut of the input.
        """
        # A literal ``[0]`` default would be shared and mutated across
        # calls; build a new counter per call instead.
        if feat_idx is None:
            feat_idx = [0]

        h = self.shortcut(x)
        for layer in self.residual:
            if feat_cache is not None and isinstance(layer, CausalConv3d):
                idx = feat_idx[0]
                # Snapshot the tail of this conv's input before it is
                # consumed; it becomes the cache for the next chunk.
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # Fewer than two frames in this chunk: prepend the last
                    # frame of the previous cache so two frames are kept.
                    cache_x = torch.cat(
                        [
                            feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                                cache_x.device),
                            cache_x,
                        ],
                        dim=2,
                    )
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x + h
| 221 | |
| 222 | |
| 223 | class AttentionBlock(nn.Module): |