MCPcopy
hub / github.com/MeiGen-AI/InfiniteTalk / __init__

Method __init__

wan/modules/vae.py:371–421  ·  view source on GitHub ↗
(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_upsample=[False, True, True],
                 dropout=0.0)

Source from the content-addressed store, hash-verified

369class Decoder3d(nn.Module):
370
def __init__(self,
             dim=128,
             z_dim=4,
             dim_mult=[1, 2, 4, 4],
             num_res_blocks=2,
             attn_scales=[],
             temperal_upsample=[False, True, True],
             dropout=0.0):
    """Build the decoder: stem conv, middle blocks, upsampling stages, head.

    Layout built here:
      * ``conv1``: ``CausalConv3d`` from ``z_dim`` to the deepest width.
      * ``middle``: ResidualBlock -> AttentionBlock -> ResidualBlock at the
        deepest width.
      * ``upsamples``: one stage per entry of ``dim_mult``. Each stage has
        ``num_res_blocks + 1`` residual blocks (optionally each followed by
        attention), and every stage but the last ends with a ``Resample``.
      * ``head``: RMS_norm -> SiLU -> ``CausalConv3d`` with 3 output
        channels (presumably RGB — confirm against callers).

    Args:
        dim: Base channel width; stage widths are ``dim * m`` for each
            multiplier ``m`` in ``dim_mult``.
        z_dim: Number of channels of the latent fed to ``conv1``.
        dim_mult: Channel multipliers, listed shallow-to-deep and consumed
            here in reverse order. Any sequence is accepted (it is copied
            to a list).
        num_res_blocks: Each stage stacks ``num_res_blocks + 1`` residual
            blocks.
        attn_scales: Scale values at which an ``AttentionBlock`` follows
            every residual block of a stage (matched with ``in`` against
            float scales).
        temperal_upsample: One flag per upsampling stage; at least
            ``len(dim_mult) - 1`` entries are indexed. ``True`` selects
            ``'upsample3d'`` and ``False`` selects ``'upsample2d'``. The
            misspelled name is kept for API compatibility.
        dropout: Dropout rate passed to every ``ResidualBlock``.
    """
    super().__init__()
    # Copy the sequence arguments. Instances then never alias the shared
    # mutable default lists (or a caller's list). Tuples also become
    # acceptable: the list concatenation below would fail on a tuple.
    dim_mult = list(dim_mult)
    attn_scales = list(attn_scales)
    temperal_upsample = list(temperal_upsample)

    self.dim = dim
    self.z_dim = z_dim
    self.dim_mult = dim_mult
    self.num_res_blocks = num_res_blocks
    self.attn_scales = attn_scales
    self.temperal_upsample = temperal_upsample

    # Stage widths, deepest first. The middle blocks run at the deepest
    # width; each upsampling stage then steps down through the reversed
    # multipliers.
    dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
    # Relative scale of the first upsampling stage. It is doubled after
    # every Resample and compared against ``attn_scales``.
    scale = 1.0 / 2**(len(dim_mult) - 2)

    # init block
    self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)

    # middle blocks
    self.middle = nn.Sequential(
        ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
        ResidualBlock(dims[0], dims[0], dropout))

    # upsample blocks
    num_stages = len(dim_mult)
    upsamples = []
    for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
        # Every stage after the first is preceded by the Resample appended
        # at the end of the previous stage. That Resample is assumed to
        # emit half of its input channels (TODO confirm against Resample),
        # so this stage's first residual block takes in_dim // 2.
        # For the default 4-stage config this equals the former
        # ``i == 1 or i == 2 or i == 3`` check. Unlike that check, it
        # also stays consistent for deeper ``dim_mult``.
        if i > 0:
            in_dim = in_dim // 2
        for _ in range(num_res_blocks + 1):
            upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
            if scale in attn_scales:
                upsamples.append(AttentionBlock(out_dim))
            in_dim = out_dim

        # Upsample between stages, but not after the last one.
        if i != num_stages - 1:
            mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
            upsamples.append(Resample(out_dim, mode=mode))
            scale *= 2.0
    self.upsamples = nn.Sequential(*upsamples)

    # Output blocks run at the last stage width. Use dims[-1] explicitly
    # rather than the loop variable leaking out of the for-loop (same value).
    self.head = nn.Sequential(
        RMS_norm(dims[-1], images=False), nn.SiLU(),
        CausalConv3d(dims[-1], 3, 3, padding=1))
422
423 def forward(self, x, feat_cache=None, feat_idx=[0]):
424 ## conv1

Callers

nothing calls this directly

Calls (6)

CausalConv3d (class) · 0.85
ResidualBlock (class) · 0.85
Resample (class) · 0.85
RMS_norm (class) · 0.85
AttentionBlock (class) · 0.70
__init__ (method) · 0.45

Tested by

no test coverage detected