(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_upsample=[False, True, True],
dropout=0.0)
| 369 | class Decoder3d(nn.Module): |
| 370 | |
def __init__(self,
             dim=128,
             z_dim=4,
             dim_mult=[1, 2, 4, 4],
             num_res_blocks=2,
             attn_scales=[],
             temperal_upsample=[False, True, True],
             dropout=0.0):
    """Build the decoder's layers: init conv, middle, upsample stack, head.

    Args:
        dim: Base channel width; each stage width is ``dim`` times an
            entry of ``dim_mult``.
        z_dim: Input channel count of the first convolution.
        dim_mult: Per-stage width multipliers. They are walked in reverse
            order, with the last multiplier repeated at the front.
        num_res_blocks: Every stage gets ``num_res_blocks + 1`` residual
            blocks.
        attn_scales: Scales at which an ``AttentionBlock`` is appended
            after each residual block.
        temperal_upsample: One flag per non-final stage; a truthy flag
            selects ``'upsample3d'``, otherwise ``'upsample2d'``.
        dropout: Forwarded to every ``ResidualBlock``.
    """
    super().__init__()

    # The signature keeps its list defaults for interface compatibility,
    # but the lists are copied here so that instances never alias the
    # shared default objects (or a list the caller may mutate later).
    # Copying also lets callers pass tuples.
    dim_mult = list(dim_mult)
    attn_scales = list(attn_scales)
    temperal_upsample = list(temperal_upsample)

    self.dim = dim
    self.z_dim = z_dim
    self.dim_mult = dim_mult
    self.num_res_blocks = num_res_blocks
    self.attn_scales = attn_scales
    self.temperal_upsample = temperal_upsample

    # dimensions: widest stage first, with the top width duplicated so
    # that consecutive pairs give (in_dim, out_dim) for every stage
    dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
    # starting scale; doubled after every Resample appended below
    scale = 1.0 / 2**(len(dim_mult) - 2)

    # init block
    self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)

    # middle blocks
    self.middle = nn.Sequential(
        ResidualBlock(dims[0], dims[0], dropout),
        AttentionBlock(dims[0]),
        ResidualBlock(dims[0], dims[0], dropout))

    # upsample blocks
    upsamples = []
    last_stage = len(dim_mult) - 1
    for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
        # residual (+attention) blocks
        # NOTE(review): stages 1-3 take half the nominal input width,
        # presumably because the preceding Resample halves the channel
        # count -- confirm against Resample. The indices are hard-coded.
        if i in (1, 2, 3):
            in_dim = in_dim // 2
        for _ in range(num_res_blocks + 1):
            upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
            if scale in attn_scales:
                upsamples.append(AttentionBlock(out_dim))
            # only the first block of a stage changes the width
            in_dim = out_dim

        # upsample block (every stage except the last one)
        if i != last_stage:
            mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
            upsamples.append(Resample(out_dim, mode=mode))
            scale *= 2.0
    self.upsamples = nn.Sequential(*upsamples)

    # output blocks: out_dim is the width of the final stage
    self.head = nn.Sequential(
        RMS_norm(out_dim, images=False),
        nn.SiLU(),
        CausalConv3d(out_dim, 3, 3, padding=1))
| 422 | |
| 423 | def forward(self, x, feat_cache=None, feat_idx=[0]): |
| 424 | ## conv1 |
nothing calls this directly
no test coverage detected