| 28 | |
@dataclass
class DFlashConfig:
    """Plain-data configuration record for DFlash.

    The class only declares fields; it defines no methods and performs no
    validation of the values it is given.  The declaration order below is
    also the positional-argument order of the generated ``__init__``, so
    fields must not be reordered.
    """

    # --- Required fields (no default; the caller must supply each one) ---
    hidden_size: int
    num_hidden_layers: int
    num_attention_heads: int
    num_key_value_heads: int
    head_dim: int
    intermediate_size: int
    vocab_size: int
    rms_norm_eps: float
    rope_theta: float
    max_position_embeddings: int
    block_size: int
    target_layer_ids: Tuple[int, ...]
    # NOTE(review): no consistency check against ``target_layer_ids`` is done
    # here -- confirm with the caller how the two are expected to relate.
    num_target_layers: int

    # --- Optional fields (defaults apply when omitted) ---
    mask_token_id: int = 0
    rope_scaling: Optional[Dict[str, Any]] = None
    # An empty tuple is immutable, so it can be used directly as the default.
    layer_types: Tuple[str, ...] = ()
    sliding_window: Optional[int] = None
    final_logit_softcapping: Optional[float] = None
| 49 | |
| 50 | |
| 51 | def _build_rope( |