MCPcopy
hub / github.com/z-lab/dflash / _GDNStateCapture

Class _GDNStateCapture

dflash/model_mlx.py:293–397  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

291
292
293class _GDNStateCapture:
def __init__(self):
    """Start capturing GatedDeltaNet intermediates by patching its ``__call__``.

    Initializes the capture buffers, then acquires the global
    ``_GDN_PATCH_LOCK`` and installs the patch via ``self._patch()``.

    On success the lock is deliberately still held when this method
    returns. Releasing it is left to code outside ``__init__``
    (presumably a close/restore method on this class -- confirm).
    If patching fails for any reason, the lock is released before the
    exception propagates. That way a failed construction cannot leave the
    lock held with no object around to release it.
    """
    # (conv_input, conv_kernel_size) tuples appended by the patched call.
    self.conv_data = []
    # (q, k, v, a, b, A_log, dt_bias, state, mask) tuples appended by the
    # patched call, one per patched GatedDeltaNet invocation.
    self._gdn_inputs = []
    # Patched class and its original __call__, filled in by _patch().
    self._gdn_cls = None
    self._orig_call = None
    # NOTE(review): not assigned in the visible part of _patch; presumably
    # set later to the installed wrapper -- confirm.
    self._patched_call = None
    # NOTE(review): presumably flipped by a close/restore method so that
    # cleanup runs once -- confirm.
    self._closed = False
    # Acquire outside the try: if acquire() itself fails there is nothing
    # to release.
    _GDN_PATCH_LOCK.acquire()
    try:
        self._patch()
    except BaseException:
        # BaseException rather than Exception: a KeyboardInterrupt or
        # SystemExit raised mid-patch (e.g. while importing mlx_lm) must
        # not leave the global lock held forever and block later captures.
        _GDN_PATCH_LOCK.release()
        raise
307
308 def _patch(self):
309 from mlx_lm.models.qwen3_5 import GatedDeltaNet
310 self._gdn_cls = GatedDeltaNet
311 self._orig_call = GatedDeltaNet.__call__
312 capture = self
313
314 def _capturing_gdn_call(self_layer, inputs, mask=None, cache=None):
315 B, S, _ = inputs.shape
316 if self_layer.sharding_group is not None:
317 from mlx_lm.models.qwen3_5 import sum_gradients
318 inputs = sum_gradients(self_layer.sharding_group)(inputs)
319 qkv = self_layer.in_proj_qkv(inputs)
320 z = self_layer.in_proj_z(inputs).reshape(B, S, self_layer.num_v_heads, self_layer.head_v_dim)
321 b, a = self_layer.in_proj_b(inputs), self_layer.in_proj_a(inputs)
322 conv_state = cache[0] if (cache is not None and cache[0] is not None) else mx.zeros((B, self_layer.conv_kernel_size - 1, self_layer.conv_dim), dtype=inputs.dtype)
323 if mask is not None:
324 qkv = mx.where(mask[..., None], qkv, 0)
325 conv_input = mx.concatenate([conv_state, qkv], axis=1)
326 capture.conv_data.append((conv_input, self_layer.conv_kernel_size))
327 if cache is not None:
328 cache[0] = conv_input[:, -(self_layer.conv_kernel_size - 1):]
329 conv_out = nn.silu(self_layer.conv1d(conv_input))
330 q, k, v = [
331 t.reshape(B, S, h, d)
332 for t, h, d in zip(
333 mx.split(conv_out, [self_layer.key_dim, 2 * self_layer.key_dim], -1),
334 [self_layer.num_k_heads, self_layer.num_k_heads, self_layer.num_v_heads],
335 [self_layer.head_k_dim, self_layer.head_k_dim, self_layer.head_v_dim],
336 )
337 ]
338 state = cache[1] if cache else None
339 inv_scale = k.shape[-1] ** -0.5
340 q = (inv_scale ** 2) * mx.fast.rms_norm(q, None, 1e-6)
341 k = inv_scale * mx.fast.rms_norm(k, None, 1e-6)
342 capture._gdn_inputs.append((q, k, v, a, b, self_layer.A_log, self_layer.dt_bias, state, mask))
343 out, new_state = _gd_mod.gated_delta_update(
344 q, k, v, a, b, self_layer.A_log, self_layer.dt_bias, state, mask, use_kernel=True
345 )
346 if cache is not None:
347 cache[1] = new_state
348 out = self_layer.norm(out, z)
349 out = self_layer.out_proj(out.reshape(B, S, -1))
350 if self_layer.sharding_group is not None:

Callers 1

stream_generate (function) · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected