MCPcopy
hub / github.com/Wan-Video/Wan2.2 / forward

Method forward

wan/modules/t5.py:86–120  ·  view source on GitHub ↗

x: [B, L1, C]. context: [B, L2, C] or None. mask: [B, L2] or [B, L1, L2] or None.

(self, x, context=None, mask=None, pos_bias=None)

Source from the content-addressed store, hash-verified

84 self.dropout = nn.Dropout(dropout)
85
def forward(self, x, context=None, mask=None, pos_bias=None):
    """
    Multi-head attention without logit scaling (T5 style).

    Args:
        x: Query-side input, [B, L1, C], where C is the input feature size
            of ``self.q``.
        context: Key/value-side input, [B, L2, C], or None for
            self-attention (``context = x``).
        mask: Key mask; positions where ``mask == 0`` are suppressed.
            Shape [B, L2] or [B, L1, L2]. For both forms a leading batch
            size of 1 is broadcast across the batch. None disables masking.
        pos_bias: Optional additive bias on the attention logits. It must
            broadcast to [B, num_heads, L1, L2] (it is added in place).
            Presumably a relative-position bias supplied by the caller.

    Returns:
        ``self.dropout(self.o(h))``, where ``h`` is the attention result
        reshaped to [B, L1, num_heads * head_dim].

    Raises:
        ValueError: if ``mask`` is neither 2-D nor 3-D.
    """
    # Self-attention when no separate context is given.
    context = x if context is None else context
    b, n, c = x.size(0), self.num_heads, self.head_dim  # c is head_dim

    # Project and split into heads: [B, L, n, c].
    q = self.q(x).view(b, -1, n, c)
    k = self.k(context).view(b, -1, n, c)
    v = self.v(context).view(b, -1, n, c)

    # Additive bias on the attention logits: [B, n, L1, L2].
    attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
    if pos_bias is not None:
        attn_bias += pos_bias
    if mask is not None:
        if mask.ndim not in (2, 3):
            raise ValueError(
                f'mask must be [B, L2] or [B, L1, L2], got {mask.ndim}-D')
        # Insert the head axis (and the query axis for 2-D masks). Indexing
        # instead of view(b, ...) lets a batch-1 mask broadcast, matching
        # what the 3-D unsqueeze path already allows.
        mask = mask[:, None, None, :] if mask.ndim == 2 else mask.unsqueeze(1)
        # Masked keys get the most negative finite value of x's dtype, so
        # their softmax weight is effectively zero unless a whole row is
        # masked. NOTE(review): in fp16/bf16, adding a negative logit to
        # finfo.min may overflow to -inf; a fully masked row could then
        # yield NaN after softmax.
        attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)

    # Compute attention (T5 does not scale logits by 1/sqrt(head_dim)).
    attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
    # Softmax in float32, then cast back to the logits' dtype.
    attn = F.softmax(attn.float(), dim=-1).type_as(attn)
    x = torch.einsum('bnij,bjnc->binc', attn, v)

    # Merge heads, then output projection and dropout.
    x = x.reshape(b, -1, n * c)
    x = self.o(x)
    x = self.dropout(x)
    return x
121
122
123class T5FeedForward(nn.Module):

Callers

nothing calls this directly

Calls 2

size · Method · 0.80
type_as · Method · 0.80

Tested by

no test coverage detected