MCPcopy Index your code
hub / github.com/MeiGen-AI/InfiniteTalk / RMS_norm

Class RMS_norm

wan/modules/vae.py:39–54  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

37
38
class RMS_norm(nn.Module):
    """Root-mean-square normalization with a learnable per-channel gain.

    ``forward`` L2-normalizes the input along the channel axis with
    ``F.normalize`` and multiplies by ``sqrt(dim)``. This is equivalent to
    dividing by the root-mean-square of the channel values, apart from the
    small ``eps`` clamp inside ``F.normalize``. The result is then scaled by a
    learnable ``gamma`` (initialized to ones), and ``bias`` is added.

    Args:
        dim: Number of channels being normalized.
        channel_first: If True, channels live on axis 1 and the parameters
            carry trailing singleton axes so they broadcast over the remaining
            axes. If False, channels are the last axis and the parameters have
            shape ``(dim,)``.
        images: Only consulted when ``channel_first`` is True. If True the
            parameters are ``(dim, 1, 1)``, otherwise ``(dim, 1, 1, 1)``.
        bias: If True, register a learnable additive bias (initialized to
            zeros). Otherwise ``self.bias`` is the plain float ``0.``.
    """

    def __init__(self, dim, channel_first=True, images=True, bias=False):
        super().__init__()
        if channel_first:
            # Trailing singleton axes let (dim, ...) parameters broadcast
            # against channel-first inputs. Presumably two spatial axes for
            # images and three otherwise -- TODO confirm against callers.
            trailing = (1, 1) if images else (1, 1, 1)
            param_shape = (dim, *trailing)
        else:
            param_shape = (dim,)

        self.channel_first = channel_first
        # sqrt(dim) rescales the unit-L2 output of F.normalize to x / RMS(x).
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(param_shape))
        if bias:
            self.bias = nn.Parameter(torch.zeros(param_shape))
        else:
            # A float 0. lets forward() add the bias unconditionally.
            self.bias = 0.

    def forward(self, x):
        """Normalize ``x`` along its channel axis, then apply gain and bias."""
        channel_axis = 1 if self.channel_first else -1
        unit = F.normalize(x, dim=channel_axis)
        return unit * self.scale * self.gamma + self.bias
55
56
57class Upsample(nn.Upsample):

Callers 4

__init__ · Method · 0.85
__init__ · Method · 0.85
__init__ · Method · 0.85
__init__ · Method · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected