Skip to content

Commit 8d232e7

Browse files
Adam Staniszewskiclaude
andcommitted
refactor(pld): extract replay_ssm_forward to importable utils module
Address jjang-ai review feedback on PR jjang-ai#149: tests were exercising a local copy of _replay_ssm_forward rather than the production code path, meaning tests could pass while the production method diverged silently. Fix: - New vmlx_engine/utils/pld_replay.py: canonical replay_ssm_forward() with minimal deps (lazy mlx_lm imports, contextlib fallback for generation_stream) - Scheduler._replay_ssm_forward becomes a 2-line delegation wrapper - tests/test_pld_ssm_replay.py imports from vmlx_engine.utils.pld_replay directly — tests now exercise the actual production code path 6/6 tests passing. Addresses the merge blocker from the jjang-ai review. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 6eac716 commit 8d232e7

3 files changed

Lines changed: 116 additions & 114 deletions

File tree

tests/test_pld_ssm_replay.py

Lines changed: 4 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,8 @@
33
44
Tests the hybrid partial-accept replay path without requiring real model
55
weights or a full mlx-lm/transformers environment. Uses _FakeSSMLayer /
6-
_FakeKVCache stubs and a mock model callable. The static method is
7-
replicated directly in this file to isolate the test from the full
8-
scheduler import chain.
6+
_FakeKVCache stubs and a mock model callable. Imports replay_ssm_forward
7+
from vmlx_engine.utils.pld_replay to test the real production code path.
98
109
Run:
1110
.venv/bin/python -m pytest tests/test_pld_ssm_replay.py -v
@@ -21,56 +20,8 @@
2120
import mlx.core as mx
2221

2322

24-
# ---------------------------------------------------------------------------
25-
# Standalone implementation of _replay_ssm_forward for testing
26-
# (mirrors the logic in Scheduler._replay_ssm_forward without importing
27-
# the full scheduler module which pulls in mlx_lm/transformers)
28-
# ---------------------------------------------------------------------------
29-
30-
def _replay_ssm_forward(model, kv_cache, saved_array_caches, accepted_tokens,
31-
pre_verify_offset):
32-
"""Test-local copy of Scheduler._replay_ssm_forward logic."""
33-
import numpy as _np_local
34-
35-
def _rewind_kv_to(kv_cache, target_offset):
36-
for c in kv_cache:
37-
if not c.is_trimmable() or c.offset == 0:
38-
continue
39-
if c.offset <= target_offset:
40-
continue
41-
if isinstance(c.keys, mx.array):
42-
_kd, _vd = c.keys.dtype, c.values.dtype
43-
_ka = c.keys.astype(mx.float16) if "bfloat16" in str(_kd) else c.keys
44-
_va = c.values.astype(mx.float16) if "bfloat16" in str(_vd) else c.values
45-
_k, _v = _np_local.array(_ka), _np_local.array(_va)
46-
c.keys = mx.array(_k[..., :target_offset, :]).astype(_kd)
47-
c.values = mx.array(_v[..., :target_offset, :]).astype(_vd)
48-
c.offset = target_offset
49-
if hasattr(c, "_idx"):
50-
c._idx = target_offset
51-
52-
try:
53-
for i, c in enumerate(kv_cache):
54-
if i in saved_array_caches:
55-
c.cache = saved_array_caches[i]
56-
_rewind_kv_to(kv_cache, pre_verify_offset)
57-
58-
replay_input = mx.array([accepted_tokens])
59-
_ = model(replay_input, cache=kv_cache)
60-
mx.eval(kv_cache)
61-
62-
return True
63-
64-
except Exception as exc:
65-
# Best-effort restore
66-
try:
67-
for i, c in enumerate(kv_cache):
68-
if i in saved_array_caches:
69-
c.cache = saved_array_caches[i]
70-
_rewind_kv_to(kv_cache, pre_verify_offset)
71-
except Exception:
72-
pass
73-
return False
23+
# Import the production replay helper directly — tests exercise the real code path.
24+
from vmlx_engine.utils.pld_replay import replay_ssm_forward as _replay_ssm_forward
7425

7526

7627
# ---------------------------------------------------------------------------

vmlx_engine/scheduler.py

Lines changed: 7 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -3791,69 +3791,15 @@ def _replay_ssm_forward(
37913791
) -> bool:
37923792
"""Replay accepted_tokens through model to advance SSM+KV caches to N+K'.
37933793
3794-
After hybrid partial rejection, restores both caches to N, then runs a
3795-
single forward pass over accepted_tokens to reach N+num_accept. The logits
3796-
are discarded; only the cache side-effect matters.
3794+
Delegates to vmlx_engine.utils.pld_replay.replay_ssm_forward so the
3795+
production code path is directly importable and testable without pulling
3796+
in the full Scheduler module.
37973797
3798-
Returns True on success; False on failure (caches left at pre_verify_offset).
3798+
Returns True on success; False on failure (caches at pre_verify_offset).
37993799
"""
3800-
import mlx.core as mx
3801-
import numpy as _np_local
3802-
from mlx_lm.generate import generation_stream as _gen_stream
3803-
3804-
try:
3805-
from mlx_lm.models.cache import CacheList as _CL_inner
3806-
except ImportError:
3807-
_CL_inner = None
3808-
3809-
def _rewind_kv_to(kv_cache, target_offset):
3810-
for c in kv_cache:
3811-
if not c.is_trimmable() or c.offset == 0:
3812-
continue
3813-
if c.offset <= target_offset:
3814-
continue
3815-
if _CL_inner is not None and isinstance(c, _CL_inner):
3816-
c.trim(c.offset - target_offset)
3817-
continue
3818-
if isinstance(c.keys, mx.array):
3819-
_kd, _vd = c.keys.dtype, c.values.dtype
3820-
_ka = c.keys.astype(mx.float16) if "bfloat16" in str(_kd) else c.keys
3821-
_va = c.values.astype(mx.float16) if "bfloat16" in str(_vd) else c.values
3822-
_k, _v = _np_local.array(_ka), _np_local.array(_va)
3823-
c.keys = mx.array(_k[..., :target_offset, :]).astype(_kd)
3824-
c.values = mx.array(_v[..., :target_offset, :]).astype(_vd)
3825-
c.offset = target_offset
3826-
if hasattr(c, "_idx"):
3827-
c._idx = target_offset
3828-
3829-
try:
3830-
# 1. Restore ArraysCache layers to pre-verify snapshot
3831-
for i, c in enumerate(kv_cache):
3832-
if i in saved_array_caches:
3833-
c.cache = saved_array_caches[i]
3834-
3835-
# 2. Rewind KV layers to pre_verify_offset
3836-
_rewind_kv_to(kv_cache, pre_verify_offset)
3837-
3838-
# 3. Replay forward: shape (1, num_accept) — advances caches to N+num_accept
3839-
replay_input = mx.array([accepted_tokens])
3840-
with mx.stream(_gen_stream):
3841-
_ = model(replay_input, cache=kv_cache)
3842-
mx.eval(kv_cache)
3843-
3844-
return True
3845-
3846-
except Exception as exc:
3847-
logger.warning("[PLD-replay] SSM replay failed: %s", exc, exc_info=False)
3848-
# Best-effort restore: re-apply snapshot, re-rewind KV
3849-
try:
3850-
for i, c in enumerate(kv_cache):
3851-
if i in saved_array_caches:
3852-
c.cache = saved_array_caches[i]
3853-
_rewind_kv_to(kv_cache, pre_verify_offset)
3854-
except Exception:
3855-
pass
3856-
return False
3800+
from vmlx_engine.utils.pld_replay import replay_ssm_forward
3801+
return replay_ssm_forward(model, kv_cache, saved_array_caches,
3802+
accepted_tokens, pre_verify_offset)
38573803

38583804
def _extract_cache_states(self, raw_cache: List[Any]) -> List[Dict[str, Any]]:
38593805
"""

vmlx_engine/utils/pld_replay.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
"""Hybrid SSM partial-accept replay helper — issue #134.
3+
4+
Factored out of Scheduler._replay_ssm_forward so tests can import and
5+
exercise the production code path directly without pulling in the full
6+
vmlx_engine.scheduler module.
7+
"""
8+
from __future__ import annotations
9+
10+
import contextlib
11+
import logging
12+
from typing import Dict, List
13+
14+
logger = logging.getLogger(__name__)
15+
16+
17+
def replay_ssm_forward(
18+
model,
19+
kv_cache: list,
20+
saved_array_caches: Dict[int, list],
21+
accepted_tokens: List[int],
22+
pre_verify_offset: int,
23+
) -> bool:
24+
"""Replay accepted_tokens through model to advance SSM+KV caches to N+K'.
25+
26+
After hybrid partial rejection, restores both caches to N, then runs a
27+
single forward pass over accepted_tokens to reach N+num_accept. The logits
28+
are discarded; only the cache side-effect matters.
29+
30+
Args:
31+
model: The language model callable (model(input, cache=...) -> logits).
32+
kv_cache: Per-layer cache list (mix of SSM ArraysCache + KVCache).
33+
saved_array_caches: Snapshot dict {layer_idx: list_of_arrays} captured
34+
before the verify forward pass.
35+
accepted_tokens: Draft tokens that were accepted (length = num_accept).
36+
pre_verify_offset: KV offset N before the verify forward ran.
37+
38+
Returns:
39+
True on success (caches at N+num_accept).
40+
False on failure (caches restored to pre_verify_offset by except handler).
41+
"""
42+
import mlx.core as mx
43+
import numpy as _np_local
44+
45+
# Lazy import: generation_stream may not be available in minimal test envs.
46+
try:
47+
from mlx_lm.generate import generation_stream as _gen_stream
48+
_stream_ctx = mx.stream(_gen_stream)
49+
except Exception:
50+
_stream_ctx = contextlib.nullcontext()
51+
52+
# Lazy import: CacheList for RotatingKVCache-based lists.
53+
try:
54+
from mlx_lm.models.cache import CacheList as _CL
55+
except ImportError:
56+
_CL = None
57+
58+
def _rewind_kv_to(target_offset: int) -> None:
59+
for c in kv_cache:
60+
if not c.is_trimmable() or c.offset == 0:
61+
continue
62+
if c.offset <= target_offset:
63+
continue
64+
if _CL is not None and isinstance(c, _CL):
65+
c.trim(c.offset - target_offset)
66+
continue
67+
if isinstance(c.keys, mx.array):
68+
_kd, _vd = c.keys.dtype, c.values.dtype
69+
_ka = c.keys.astype(mx.float16) if "bfloat16" in str(_kd) else c.keys
70+
_va = c.values.astype(mx.float16) if "bfloat16" in str(_vd) else c.values
71+
_k, _v = _np_local.array(_ka), _np_local.array(_va)
72+
c.keys = mx.array(_k[..., :target_offset, :]).astype(_kd)
73+
c.values = mx.array(_v[..., :target_offset, :]).astype(_vd)
74+
c.offset = target_offset
75+
if hasattr(c, "_idx"):
76+
c._idx = target_offset
77+
78+
try:
79+
# 1. Restore ArraysCache layers to pre-verify snapshot
80+
for i, c in enumerate(kv_cache):
81+
if i in saved_array_caches:
82+
c.cache = saved_array_caches[i]
83+
84+
# 2. Rewind KV layers to pre_verify_offset
85+
_rewind_kv_to(pre_verify_offset)
86+
87+
# 3. Replay forward: shape (1, num_accept) — advances caches to N+num_accept
88+
replay_input = mx.array([accepted_tokens])
89+
with _stream_ctx:
90+
_ = model(replay_input, cache=kv_cache)
91+
mx.eval(kv_cache)
92+
93+
return True
94+
95+
except Exception as exc:
96+
logger.warning("[PLD-replay] SSM replay failed: %s", exc, exc_info=False)
97+
# Best-effort restore: re-apply snapshot, re-rewind KV
98+
try:
99+
for i, c in enumerate(kv_cache):
100+
if i in saved_array_caches:
101+
c.cache = saved_array_caches[i]
102+
_rewind_kv_to(pre_verify_offset)
103+
except Exception:
104+
pass
105+
return False

0 commit comments

Comments
 (0)