(self, x:Tensor)
| 184 | return self.attn_output(attn if not self.config.attn_output_gate else (attn * gate.sigmoid())) |
| 185 | |
| | # Create the per-instance state this module keeps between calls: `cache_kv` (an uninitialized buffer) and `freqs_cis` (the result of precompute_freqs_cis). |
| | # NOTE(review): indentation is not visible in this listing, so it is unclear whether the `freqs_cis` assignment is inside the `hasattr` guard or runs on every call — confirm against the real file. |
| 186 | def _init_state(self, x:Tensor): |
| | # Allocate only when `cache_kv` does not exist yet; once it exists, a later `x` with a different shape[0] or device does not trigger a re-allocation here. |
| 187 | if not hasattr(self, "cache_kv"): |
| 188 | # TODO: how is the dtype of this determined? |
| | # No dtype argument is passed below, so the dtype is whatever Tensor.empty picks by default — it is not taken from `x` here. |
| | # Buffer dims: (2, x.shape[0], n_kv_heads, max_context, head_dim), placed on x.device. The leading 2 presumably separates keys and values, and x.shape[0] is presumably the batch size — TODO confirm. |
| 189 | self.cache_kv = Tensor.empty(2, x.shape[0], self.config.n_kv_heads, self.config.max_context, self.config.head_dim, device=x.device) |
| | # Built from config.rope_dim, config.max_context and config.rope_theta on x.device; presumably rotary position-embedding frequencies — verify against precompute_freqs_cis. |
| 190 | self.freqs_cis = precompute_freqs_cis(self.config.rope_dim, self.config.max_context, self.config.rope_theta, device=x.device) |
| 191 | |
| 192 | class MLATransformerBlock(FFNBlock): |
| 193 | def __init__(self, config:TransformerConfig): |
nothing calls this directly
no test coverage detected