| 372 | _GDN_PATCH_LOCK.release() |
| 373 | |
def rollback(self, cache, accepted, trim):
    """Roll every entry of ``cache`` back after a partially accepted step.

    Entries reporting ``is_trimmable()`` are shortened with ``trim(trim)``.
    Every other entry is paired, in iteration order, with one record of
    ``self._gdn_inputs`` and one record of ``self.conv_data``:

    * ``cache[1]`` is replaced by the state that
      ``_gd_mod.gated_delta_update`` returns when it is re-run on the first
      ``accepted + 1`` positions (axis 1) of the captured inputs, starting
      from the captured ``init_state``.
    * ``cache[0]`` is replaced by
      ``conv_input[:, accepted + 1 : accepted + K]``.

    Args:
        cache: Iterable of cache objects; it is walked twice, so it must be
            re-iterable (presumably a list of per-layer caches -- confirm
            against the caller).
        accepted: Integer used to build the ``accepted + 1`` prefix length
            (presumably the number of accepted draft tokens -- confirm).
        trim: Amount forwarded unchanged to ``trim()`` of trimmable entries.

    Raises:
        AssertionError: If the number of non-trimmable entries differs from
            ``len(self._gdn_inputs)``.
    """
    # The rollback relies on a one-to-one pairing between non-trimmable
    # caches and captured GDN inputs; check it before mutating anything.
    gdn_layer_count = len([entry for entry in cache if not entry.is_trimmable()])
    assert gdn_layer_count == len(self._gdn_inputs), (
        f"non-trimmable cache count ({gdn_layer_count}) != "
        f"captured GDN inputs ({len(self._gdn_inputs)}); "
        "DFlash MLX rollback assumes every non-trimmable cache is a GatedDeltaNet layer"
    )

    gdn_idx = 0  # position within the captured per-layer records
    for entry in cache:
        if entry.is_trimmable():
            entry.trim(trim)
            continue

        (
            q_in, k_in, v_in, a_in, b_in,
            A_log, dt_bias, init_state, mask,
        ) = self._gdn_inputs[gdn_idx]

        # Only the first `accepted + 1` positions along axis 1 are replayed.
        keep = accepted + 1
        q_head = q_in[:, :keep]
        k_head = k_in[:, :keep]
        v_head = v_in[:, :keep]
        a_head = a_in[:, :keep]
        b_head = b_in[:, :keep]
        if mask is None:
            mask_head = None
        else:
            mask_head = mask[:, :keep]

        _, replayed_state = _gd_mod.gated_delta_update(
            q_head, k_head, v_head, a_head, b_head,
            A_log, dt_bias, init_state,
            mask_head,
            use_kernel=True,
        )
        entry.cache[1] = replayed_state

        # Slot 0 is rebuilt from the captured conv input: the K - 1 columns
        # that start right after the replayed prefix.
        conv_input, K = self.conv_data[gdn_idx]
        entry.cache[0] = conv_input[:, keep : accepted + K]

        gdn_idx += 1
| 398 | |
| 399 | |
| 400 | @dataclass |