| 291 | |
| 292 | |
| 293 | class _GDNStateCapture: |
def __init__(self):
    """Create empty capture buffers and install the GatedDeltaNet patch.

    Acquires the module-level ``_GDN_PATCH_LOCK`` and then calls
    ``_patch()``. When ``_patch()`` succeeds the lock is still held on
    return; this method never releases it on the success path.
    NOTE(review): presumably a teardown method releases it -- confirm.

    Raises:
        Anything ``_patch()`` raises. The lock is released before the
        error propagates, so a failed construction does not keep it held.
    """
    # Receives one (conv_input, conv_kernel_size) tuple per patched call.
    self.conv_data = []
    # Receives one tuple of gated_delta_update arguments per patched call.
    self._gdn_inputs = []
    # Set by _patch(): the patched class and its original __call__.
    self._gdn_cls = None
    self._orig_call = None
    self._patched_call = None
    self._closed = False
    _GDN_PATCH_LOCK.acquire()
    try:
        self._patch()
    except BaseException:
        # BaseException, not Exception: KeyboardInterrupt / SystemExit
        # raised while patching (e.g. during the import) must also release
        # the lock, since the caller never receives an instance to clean up.
        _GDN_PATCH_LOCK.release()
        raise
| 307 | |
| 308 | def _patch(self): |
| 309 | from mlx_lm.models.qwen3_5 import GatedDeltaNet |
| 310 | self._gdn_cls = GatedDeltaNet |
| 311 | self._orig_call = GatedDeltaNet.__call__ |
| 312 | capture = self |
| 313 | |
| 314 | def _capturing_gdn_call(self_layer, inputs, mask=None, cache=None): |
| 315 | B, S, _ = inputs.shape |
| 316 | if self_layer.sharding_group is not None: |
| 317 | from mlx_lm.models.qwen3_5 import sum_gradients |
| 318 | inputs = sum_gradients(self_layer.sharding_group)(inputs) |
| 319 | qkv = self_layer.in_proj_qkv(inputs) |
| 320 | z = self_layer.in_proj_z(inputs).reshape(B, S, self_layer.num_v_heads, self_layer.head_v_dim) |
| 321 | b, a = self_layer.in_proj_b(inputs), self_layer.in_proj_a(inputs) |
| 322 | conv_state = cache[0] if (cache is not None and cache[0] is not None) else mx.zeros((B, self_layer.conv_kernel_size - 1, self_layer.conv_dim), dtype=inputs.dtype) |
| 323 | if mask is not None: |
| 324 | qkv = mx.where(mask[..., None], qkv, 0) |
| 325 | conv_input = mx.concatenate([conv_state, qkv], axis=1) |
| 326 | capture.conv_data.append((conv_input, self_layer.conv_kernel_size)) |
| 327 | if cache is not None: |
| 328 | cache[0] = conv_input[:, -(self_layer.conv_kernel_size - 1):] |
| 329 | conv_out = nn.silu(self_layer.conv1d(conv_input)) |
| 330 | q, k, v = [ |
| 331 | t.reshape(B, S, h, d) |
| 332 | for t, h, d in zip( |
| 333 | mx.split(conv_out, [self_layer.key_dim, 2 * self_layer.key_dim], -1), |
| 334 | [self_layer.num_k_heads, self_layer.num_k_heads, self_layer.num_v_heads], |
| 335 | [self_layer.head_k_dim, self_layer.head_k_dim, self_layer.head_v_dim], |
| 336 | ) |
| 337 | ] |
| 338 | state = cache[1] if cache else None |
| 339 | inv_scale = k.shape[-1] ** -0.5 |
| 340 | q = (inv_scale ** 2) * mx.fast.rms_norm(q, None, 1e-6) |
| 341 | k = inv_scale * mx.fast.rms_norm(k, None, 1e-6) |
| 342 | capture._gdn_inputs.append((q, k, v, a, b, self_layer.A_log, self_layer.dt_bias, state, mask)) |
| 343 | out, new_state = _gd_mod.gated_delta_update( |
| 344 | q, k, v, a, b, self_layer.A_log, self_layer.dt_bias, state, mask, use_kernel=True |
| 345 | ) |
| 346 | if cache is not None: |
| 347 | cache[1] = new_state |
| 348 | out = self_layer.norm(out, z) |
| 349 | out = self_layer.out_proj(out.reshape(B, S, -1)) |
| 350 | if self_layer.sharding_group is not None: |