(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0)
| 265 | class Encoder3d(nn.Module): |
| 266 | |
def __init__(self,
             dim=128,
             z_dim=4,
             dim_mult=(1, 2, 4, 4),
             num_res_blocks=2,
             attn_scales=(),
             temperal_downsample=(True, True, False),
             dropout=0.0):
    """Build the encoder's layers.

    Layers, in registration order:
      * ``conv1``: ``CausalConv3d(3, dim, 3, padding=1)`` input block.
      * ``downsamples``: for each entry of ``dim_mult``, ``num_res_blocks``
        ``ResidualBlock``s (each optionally followed by an
        ``AttentionBlock``), then a ``Resample`` between consecutive stages.
      * ``middle``: ResidualBlock -> AttentionBlock -> ResidualBlock at the
        final stage width.
      * ``head``: ``RMS_norm`` -> ``SiLU`` -> ``CausalConv3d`` to ``z_dim``.

    Args:
        dim: Base channel width; stage ``i`` outputs ``dim * dim_mult[i]``.
        z_dim: Output channels of the final ``CausalConv3d`` in ``head``.
        dim_mult: Per-stage channel multipliers (any sequence; copied to a
            list). Must be non-empty.
        num_res_blocks: Number of ``ResidualBlock``s per stage.
        attn_scales: Scale values at which each residual block is followed
            by an ``AttentionBlock``. The tracked scale starts at 1.0 and
            is halved after every ``Resample``. Copied to a list.
        temperal_downsample: One flag per stage transition. ``True`` builds
            ``Resample(..., mode='downsample3d')``, ``False`` builds
            ``mode='downsample2d'``. Needs at least ``len(dim_mult) - 1``
            entries. Copied to a list.
        dropout: Dropout value forwarded to every ``ResidualBlock``.

    Raises:
        ValueError: If ``dim_mult`` is empty, or ``temperal_downsample`` has
            fewer than ``len(dim_mult) - 1`` entries.
    """
    super().__init__()

    # Copy the sequence arguments so that:
    #  - the stored attributes never alias a caller's list or a default;
    #  - tuples are accepted (``[1] + dim_mult`` below needs a list).
    dim_mult = list(dim_mult)
    attn_scales = list(attn_scales)
    temperal_downsample = list(temperal_downsample)

    # Without at least one stage, ``out_dim`` below would never be bound.
    if not dim_mult:
        raise ValueError('dim_mult must contain at least one multiplier')
    # One flag is indexed per stage transition.
    if len(temperal_downsample) < len(dim_mult) - 1:
        raise ValueError(
            'temperal_downsample needs at least len(dim_mult) - 1 = '
            f'{len(dim_mult) - 1} entries, got {len(temperal_downsample)}')

    self.dim = dim
    self.z_dim = z_dim
    self.dim_mult = dim_mult
    self.num_res_blocks = num_res_blocks
    self.attn_scales = attn_scales
    self.temperal_downsample = temperal_downsample

    # dims[0] is the input-block width; dims[i + 1] is stage i's width.
    dims = [dim * u for u in [1] + dim_mult]
    # Current scale, compared against attn_scales; halved per Resample.
    scale = 1.0

    # init block
    self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)

    # downsample blocks
    downsamples = []
    last_stage = len(dim_mult) - 1
    for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
        # residual (+attention) blocks
        for _ in range(num_res_blocks):
            downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
            if scale in attn_scales:
                downsamples.append(AttentionBlock(out_dim))
            # Only the first block of a stage changes the channel width.
            in_dim = out_dim

        # Resample between stages only; none after the last stage.
        if i != last_stage:
            mode = ('downsample3d'
                    if temperal_downsample[i] else 'downsample2d')
            downsamples.append(Resample(out_dim, mode=mode))
            scale /= 2.0
    self.downsamples = nn.Sequential(*downsamples)

    # middle blocks (out_dim is the last stage's width)
    self.middle = nn.Sequential(
        ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
        ResidualBlock(out_dim, out_dim, dropout))

    # output blocks
    self.head = nn.Sequential(
        RMS_norm(out_dim, images=False), nn.SiLU(),
        CausalConv3d(out_dim, z_dim, 3, padding=1))
| 317 | |
| 318 | def forward(self, x, feat_cache=None, feat_idx=[0]): |
| 319 | if feat_cache is not None: |
nothing calls this directly
no test coverage detected