MCPcopy
hub / github.com/Robbyant/lingbot-world / forward

Method forward

wan/modules/model.py:127–156  ·  view source on GitHub ↗

Args: x (Tensor): shape [B, L, num_heads, C / num_heads]; seq_lens (Tensor): shape [B]; grid_sizes (Tensor): shape [B, 3], where the second dimension contains (F, H, W); freqs (Tensor): RoPE frequencies, shape [1024, C / num_heads / 2].

(self, x, seq_lens, grid_sizes, freqs)

Source from the content-addressed store, hash-verified

125 self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
126
def forward(self, x, seq_lens, grid_sizes, freqs):
    r"""
    Multi-head self-attention with rotary position embeddings on q and k.

    Args:
        x(Tensor): Input tokens. Only ``x.shape[:2]`` (B, L) is read
            directly; each of the q/k/v projections must yield
            ``B * L * num_heads * head_dim`` elements so it can be viewed
            as [B, L, num_heads, head_dim]. Presumably x is [B, L, C] with
            C the model dim (TODO confirm against the caller) -- a 4-D
            [B, L, num_heads, C / num_heads] input would not fit that view.
        seq_lens(Tensor): Shape [B]; forwarded to ``flash_attention`` as
            ``k_lens`` (presumably the valid key length of each sample).
        grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
        freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]

    Returns:
        Tensor: ``self.o`` applied to the attention output after its
        dimensions from index 2 onward are flattened (presumably
        [B, L, num_heads, head_dim] -> [B, L, num_heads * head_dim]).
    """
    b, s = x.shape[:2]
    n, d = self.num_heads, self.head_dim

    # Project to queries/keys/values and split channels into heads.
    # q and k additionally pass through the (optional) QK normalization.
    q = self.norm_q(self.q(x)).view(b, s, n, d)
    k = self.norm_k(self.k(x)).view(b, s, n, d)
    v = self.v(x).view(b, s, n, d)

    # Rotary embeddings are applied to q and k only; v is used as-is.
    x = flash_attention(
        q=rope_apply(q, grid_sizes, freqs),
        k=rope_apply(k, grid_sizes, freqs),
        v=v,
        k_lens=seq_lens,
        window_size=self.window_size)

    # Merge the per-head dims back together, then apply the output projection.
    x = x.flatten(2)
    return self.o(x)
157
158
159class WanCrossAttention(WanSelfAttention):

Callers

nothing calls this directly

Calls 3

flash_attention · Function · 0.85
rope_apply · Function · 0.70
qkv_fn · Function · 0.50

Tested by

no test coverage detected