(self, x:Tensor)
| 232 | return self.attn_output(attn) |
| 233 | |
def _init_state(self, x:Tensor):
    """Lazily create the ``cache_k`` buffer and the ``freqs_cis`` table.

    Does nothing when ``self.cache_k`` already exists. Otherwise both
    attributes are created on ``x.device``, and ``x.shape[0]`` supplies the
    leading dimension of the cache.
    """
    # NOTE(review): the listing this was recovered from carried no
    # indentation; both assignments are assumed to sit under the hasattr
    # guard -- confirm against the real file.
    if hasattr(self, "cache_k"):
        return

    cfg = self.config
    # Last dimension is kv_lora_rank + rope_dim, stored side by side.
    cache_shape = (x.shape[0], 1, cfg.max_context, cfg.kv_lora_rank + cfg.rope_dim)
    self.cache_k = Tensor.empty(*cache_shape, device=x.device)
    self.freqs_cis = precompute_freqs_cis(
        cfg.rope_dim,
        cfg.max_context,
        cfg.rope_theta,
        device=x.device,
    )
| 238 | |
| 239 | class GatedDeltaNetBlock(FFNBlock): |
| 240 | def __init__(self, config:TransformerConfig, ssm:SSMConfig): |
nothing calls this directly
no test coverage detected