r""" Args: x(Tensor): Shape [B, L, num_heads, C / num_heads] seq_lens(Tensor): Shape [B] grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
(self, x, seq_lens, grid_sizes, freqs)
| 124 | self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() |
| 125 | |
def forward(self, x, seq_lens, grid_sizes, freqs):
    r"""
    Multi-head self-attention with rotary position embeddings on q and k.

    Args:
        x(Tensor): Shape [B, L, C] with C = num_heads * head_dim. The q/k/v
            projections below reshape it to [B, L, num_heads, head_dim]
            (i.e. [B, L, num_heads, C / num_heads]).
        seq_lens(Tensor): Shape [B]. Forwarded to ``flash_attention`` as
            ``k_lens`` (presumably the number of valid key tokens per
            sample -- confirm against ``flash_attention``).
        grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
        freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]

    Returns:
        Tensor: the attention output with its head dimensions merged
        (``flatten(2)``) and passed through the output projection ``self.o``.
    """
    b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim

    # Project the input and split the channel dimension into heads.
    # Only q and k are normalized; norm_k (and presumably norm_q) is
    # nn.Identity when qk_norm is disabled. v is used unnormalized.
    q = self.norm_q(self.q(x)).view(b, s, n, d)
    k = self.norm_k(self.k(x)).view(b, s, n, d)
    v = self.v(x).view(b, s, n, d)

    # Rotary position embeddings are applied to queries and keys only.
    x = flash_attention(
        q=rope_apply(q, grid_sizes, freqs),
        k=rope_apply(k, grid_sizes, freqs),
        v=v,
        k_lens=seq_lens,
        window_size=self.window_size)

    # Merge the head dimensions back into channels, then project out.
    x = x.flatten(2)
    return self.o(x)
| 156 | |
| 157 | |
| 158 | class WanCrossAttention(WanSelfAttention): |
nothing calls this directly
no test coverage detected