MCPcopy Index your code
hub / github.com/z-lab/dflash / rollback

Method rollback

dflash/model_mlx.py:374–397  ·  view source on GitHub ↗
(self, cache, accepted, trim)

Source from the content-addressed store, hash-verified

372 _GDN_PATCH_LOCK.release()
373
def rollback(self, cache, accepted, trim):
    """Rewind every entry of ``cache`` after a partially accepted step.

    Trimmable entries are rewound by calling ``trim(trim)`` on them.
    Each non-trimmable entry is rebuilt from the inputs captured in
    ``self._gdn_inputs`` and ``self.conv_data``: ``cache[1]`` receives
    the state returned by ``gated_delta_update`` run over the first
    ``accepted + 1`` positions (axis 1) of the captured inputs, and
    ``cache[0]`` receives the matching window of the captured conv input.

    Args:
        cache: Re-iterable collection of cache entries exposing
            ``is_trimmable()``; trimmable ones also expose ``trim()``,
            the others an indexable ``cache`` attribute.
        accepted: Count used to size the replay; ``accepted + 1``
            positions are kept. Presumably the number of accepted draft
            tokens -- confirm against the caller.
        trim: Amount forwarded unchanged to ``trim()``.

    Raises:
        AssertionError: If the number of non-trimmable entries differs
            from the number of captured GDN inputs.
    """
    non_trimmable_total = len(
        [entry for entry in cache if not entry.is_trimmable()]
    )
    captured_inputs = self._gdn_inputs
    assert non_trimmable_total == len(captured_inputs), (
        f"non-trimmable cache count ({non_trimmable_total}) != "
        f"captured GDN inputs ({len(captured_inputs)}); "
        "DFlash MLX rollback assumes every non-trimmable cache is a GatedDeltaNet layer"
    )

    # Captured inputs are stored in the order the non-trimmable entries
    # appear in ``cache``, so a running index pairs them up.
    gdn_index = 0
    for entry in cache:
        if entry.is_trimmable():
            entry.trim(trim)
            continue

        (q, k, v, a, b, A_log, dt_bias, init_state, mask) = captured_inputs[
            gdn_index
        ]
        keep = accepted + 1
        q_kept = q[:, :keep]
        k_kept = k[:, :keep]
        v_kept = v[:, :keep]
        a_kept = a[:, :keep]
        b_kept = b[:, :keep]
        mask_kept = mask[:, :keep] if mask is not None else None

        # Replay only the kept prefix from the captured initial state;
        # the first return value is not needed here.
        _, replayed_state = _gd_mod.gated_delta_update(
            q_kept,
            k_kept,
            v_kept,
            a_kept,
            b_kept,
            A_log,
            dt_bias,
            init_state,
            mask_kept,
            use_kernel=True,
        )
        entry.cache[1] = replayed_state

        # Window of at most K - 1 positions starting right after the
        # kept prefix (K is presumably the conv kernel size -- confirm).
        conv_input, K = self.conv_data[gdn_index]
        entry.cache[0] = conv_input[:, keep : accepted + K]

        gdn_index += 1
398
399
400@dataclass

Callers 1

stream_generate · Function · 0.80

Calls

no outgoing calls

Tested by

no test coverage detected