x: [B, L1, C]. context: [B, L2, C] or None. mask: [B, L2] or [B, L1, L2] or None.
(self, x, context=None, mask=None, pos_bias=None)
| 84 | self.dropout = nn.Dropout(dropout) |
| 85 | |
def forward(self, x, context=None, mask=None, pos_bias=None):
    """
    Multi-head attention with unscaled dot-product scores.

    x:        [B, L1, C], source of the queries.
    context:  [B, L2, C] or None; source of keys/values, falls back to
              ``x`` (self-attention) when None.
    mask:     [B, L2] or [B, L1, L2] or None; entries equal to 0 mark
              key positions that must not be attended to.
    pos_bias: optional additive bias, accumulated in place into a
              [B, N, L1, L2] buffer (so it must broadcast to that shape).

    Returns the attended values after the output projection and dropout.
    """
    # default to self-attention
    if context is None:
        context = x
    batch = x.size(0)
    heads, dim = self.num_heads, self.head_dim

    # project, then split the channel axis into heads: [B, L, N, D]
    per_head = (batch, -1, heads, dim)
    query = self.q(x).view(*per_head)
    key = self.k(context).view(*per_head)
    value = self.v(context).view(*per_head)

    # additive bias on the scores: positional term first, then the mask
    # overwrites masked positions with the most negative finite value
    bias = x.new_zeros(batch, heads, query.size(1), key.size(1))
    if pos_bias is not None:
        bias += pos_bias
    if mask is not None:
        assert mask.ndim in (2, 3)
        if mask.ndim == 2:
            # [B, L2] -> [B, 1, 1, L2]: same key mask for all heads/queries
            mask = mask.view(batch, 1, 1, -1)
        else:
            # [B, L1, L2] -> [B, 1, L1, L2]: same mask for all heads
            mask = mask.unsqueeze(1)
        bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)

    # scores are deliberately not scaled (T5 convention); the softmax is
    # evaluated in float32 and cast back to the score dtype
    scores = torch.einsum('binc,bjnc->bnij', query, key) + bias
    weights = F.softmax(scores.float(), dim=-1).type_as(scores)
    attended = torch.einsum('bnij,bjnc->binc', weights, value)

    # merge heads back into one channel axis, project, regularize
    merged = attended.reshape(batch, -1, heads * dim)
    return self.dropout(self.o(merged))
| 121 | |
| 122 | |
| 123 | class T5FeedForward(nn.Module): |