(
self, x: torch.Tensor, freqs_cis: torch.Tensor = None,
input_pos: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None
)
| 205 | self.resid_dropout = nn.Dropout(config.resid_dropout_p) |
| 206 | |
def forward(
    self, x: torch.Tensor, freqs_cis: Optional[torch.Tensor] = None,
    input_pos: Optional[torch.Tensor] = None,
    mask: Optional[torch.Tensor] = None
):
    """Run attention over ``x`` and return the projected result.

    Args:
        x: Rank-3 input; its shape is unpacked as ``(bsz, seqlen, _)``.
        freqs_cis: Passed unchanged to ``apply_rotary_emb`` for both the
            queries and the keys.
            NOTE(review): the default is ``None`` but the value is used
            unconditionally -- presumably callers always supply it; confirm.
        input_pos: Forwarded to ``self.kv_cache.update`` when a cache is
            attached; not read otherwise.
        mask: Optional mask handed to ``F.scaled_dot_product_attention`` as
            ``attn_mask``. When it is ``None`` the call uses
            ``is_causal=True`` instead.

    Returns:
        The attention output viewed as ``(bsz, seqlen, self.dim)``, passed
        through ``self.wo`` and then ``self.resid_dropout``.
    """
    bsz, seqlen, _ = x.shape
    kv_size = self.n_kv_head * self.head_dim

    # One fused projection, split along the last dim into q / k / v.
    xq, xk, xv = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

    xq = xq.view(bsz, seqlen, self.n_head, self.head_dim)
    xk = xk.view(bsz, seqlen, self.n_kv_head, self.head_dim)
    xv = xv.view(bsz, seqlen, self.n_kv_head, self.head_dim)

    # Rotary embedding is applied to queries and keys only, not values.
    xq = apply_rotary_emb(xq, freqs_cis)
    xk = apply_rotary_emb(xk, freqs_cis)

    # (bsz, seqlen, heads, head_dim) -> (bsz, heads, seqlen, head_dim)
    xq, xk, xv = (t.transpose(1, 2) for t in (xq, xk, xv))

    if self.kv_cache is not None:
        keys, values = self.kv_cache.update(input_pos, xk, xv)
    else:
        keys, values = xk, xv

    # Expand the kv heads along dim 1 so they line up with the query heads.
    # A repeat factor of 1 would only produce an identical copy, so skip it.
    n_rep = self.n_head // self.n_kv_head
    if n_rep != 1:
        keys = keys.repeat_interleave(n_rep, dim=1)
        values = values.repeat_interleave(n_rep, dim=1)

    output = F.scaled_dot_product_attention(
        xq, keys, values,
        attn_mask=mask,
        # An explicit mask (the KV-cache path) replaces built-in causal masking.
        is_causal=mask is None,
        # Attention dropout is active in training mode only.
        dropout_p=self.attn_dropout_p if self.training else 0.0)

    # Merge heads back: (bsz, heads, seqlen, head_dim) -> (bsz, seqlen, dim)
    output = output.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

    output = self.resid_dropout(self.wo(output))
    return output
| 242 | |
| 243 | |
| 244 | class TransformerBlock(nn.Module): |
nothing calls this directly
no test coverage detected