Commit 95aef1b

Adam Staniszewski and claude committed
test(pld): TurboQuant short-circuit + logit-equivalence + invariant tests
PR jjang-ai#172 Phase C + test layers 1/2/3/5.

**TurboQuant cache short-circuit** (Phase C, vmlx_engine/mllm_batch_generator.py): PLD's per-row writeback doesn't preserve `left_padding` on TurboQuant KV caches, causing `c.extract(idx)` to crash on request completion (observed on smolvlm with `--kv-cache-quantization q4`). Detect TQ via `type(c).__name__ == "TurboQuantKVCache"` at the entry to `_step_speculative` and fall back to standard `_step`. Logs once at INFO. PLD + TQ + MLLM compatibility is deferred to a follow-up; short-term workaround: use `--kv-cache-quantization none/q8` for PLD.

**Logit-equivalence tests (Layer 1, slow, real-model):** `tests/test_mllm_pld_logits_equivalence.py`: 3 tests on `mlx-community/Llama-3.2-1B-Instruct-4bit` (small pure-attention LLM, verified deterministic at T=0):
- argmax match between multi-token and sequential forward
- logits max-diff within 1e-2 FP tolerance
- cache state (offset + last-position keys) matches

All 3 PASS on Llama, proving the multi-token forward IS mathematically correct in mlx_lm. The smolvlm-specific divergence (PR jjang-ai#172 Phase B diagnostic) is therefore in mlx_vlm's wrapper layer, not mlx_lm core.

**TQ + invariant + edge case unit tests (Layers 2/3/5, mock):** `tests/test_mllm_pld_tq_and_invariants.py`: 11 tests covering:
- TQ cache detection short-circuits (no spec steps, no crash)
- TQ skip log fires once
- scratch_extra_tokens initialized to None
- Cooldown decrement invariant + no underflow
- Offset monotonicity under sequential verify
- Acceptance histogram increments per row per step
- Telemetry counters strictly monotonic
- K=0 / B=0 / prefill-phase early returns

**Live byte-equality on Llama-3.2 (PR jjang-ai#149 simple-engine path):** Discovered that PR jjang-ai#149's PLD path is NOT reachable via `vmlx serve`: both batched and non-batched modes route through different schedulers that bypass `Scheduler._try_speculative_decode`. PR jjang-ai#149's `pld_ssm_replay` counters remain 0 in all tested configurations. PR jjang-ai#149 was validated via mock unit tests; live activation needs a deeper investigation (likely needs `vmlx_engine` API call directly with `AsyncEngineCore(..., scheduler=Scheduler(...))`).

**Total: 117 tests passing** (95 mock unit + 3 slow logit-equivalence + 11 new TQ/invariant + 8 new B>1 from PR jjang-ai#171).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
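The short-circuit described under Phase C can be illustrated with a minimal sketch. It is not the code from `vmlx_engine/mllm_batch_generator.py`: `StepDispatcher`, `has_turboquant_cache`, the logger name, and the two stand-in cache classes are hypothetical names introduced only for illustration. The only detail taken from the commit message is the name-based check `type(c).__name__ == "TurboQuantKVCache"` and the log-once behaviour.

```python
import logging

logger = logging.getLogger("pld_sketch")  # hypothetical logger name


class TurboQuantKVCache:
    """Stand-in for the real cache class; only the class name matters here."""


class KVCache:
    """Stand-in for an unquantized cache."""


def has_turboquant_cache(caches):
    # Name-based check as quoted in the commit message; it needs no import
    # of the real TurboQuant class.
    return any(type(c).__name__ == "TurboQuantKVCache" for c in caches)


class StepDispatcher:
    """Hypothetical wrapper showing the fallback plus log-once behaviour."""

    def __init__(self):
        self._tq_skip_logged = False

    def step(self, caches):
        if has_turboquant_cache(caches):
            if not self._tq_skip_logged:
                logger.info("TurboQuant KV cache detected; using the standard step")
                self._tq_skip_logged = True
            return "standard"  # placeholder for the non-speculative path
        return "speculative"  # placeholder for the PLD path
```

Calling `StepDispatcher().step(...)` repeatedly with a TurboQuant cache in the list returns `"standard"` every time while emitting the INFO record only on the first call, which is the property the "TQ skip log fires once" test is described as checking.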
1 parent 3b04068 commit 95aef1b

3 files changed

Lines changed: 511 additions & 0 deletions

Lines changed: 188 additions & 0 deletions
@@ -0,0 +1,188 @@
# SPDX-License-Identifier: Apache-2.0
"""Logit-equivalence tests for sequential vs multi-token verify (PR #172 Layer 1).

Pins the root cause of MLLM PLD output drift: multi-token forward produces
different logits than K+1 sequential single-token forwards at the same input
positions. The sequential-verify fix in `_step_speculative` (PR #172) trades
K+1× kernel-launch cost for byte-equivalence with standalone decode.

These tests load a real model (small pure-attention LLM) and directly
compare logits position-by-position. Marked @pytest.mark.slow; require
model to be cached or downloadable.

Run:
    .venv/bin/python -m pytest tests/test_mllm_pld_logits_equivalence.py \\
        -v --run-slow
"""

from __future__ import annotations

import copy

import mlx.core as mx
import pytest


# Small pure-attention LLM. mlx-community/Llama-3.2-1B-Instruct-4bit is the
# same model used by tests/test_batching_deterministic.py (proven deterministic
# at T=0 in vmlx). 600 MB cached.
TEST_MODEL = "mlx-community/Llama-3.2-1B-Instruct-4bit"


@pytest.fixture(scope="module")
def model_and_tokenizer():
    try:
        from mlx_lm import load
        model, tokenizer = load(TEST_MODEL)
        return model, tokenizer
    except Exception as e:
        pytest.skip(f"Could not load model {TEST_MODEL}: {e}")


def _prefill_and_cache(model, prompt_ids):
    """Prefill the model on prompt_ids, return the cache."""
    if hasattr(model, "make_cache"):
        cache = model.make_cache()
    else:
        from mlx_lm.models.cache import KVCache
        n_layers = (
            len(model.layers) if hasattr(model, "layers") else
            len(model.model.layers) if hasattr(model, "model") else 1
        )
        cache = [KVCache() for _ in range(n_layers)]
    prefill_input = mx.array([prompt_ids])
    _ = model(prefill_input, cache=cache)
    mx.eval([c.keys if hasattr(c, "keys") and c.keys is not None else c for c in cache])
    return cache


@pytest.mark.slow
class TestLogitsEquivalence:
    """Validate that sequential K+1 single-token forwards produce equivalent
    output to a single multi-token forward of shape (B, K+1).

    NOTE: On some models (e.g., smolvlm), these tests FAIL; that's the bug
    PR #172's sequential verify works around. The tests document the
    expected behaviour and serve as regression detection if mlx_lm
    upstream is fixed.
    """

    def test_sequential_matches_multi_token_via_argmax(self, model_and_tokenizer):
        """argmax token at each position must match between multi-token
        and sequential forwards."""
        model, tokenizer = model_and_tokenizer
        prompt = "The quick brown fox"
        prompt_ids = tokenizer.encode(prompt) if hasattr(tokenizer, "encode") else list(
            tokenizer(prompt).input_ids
        )

        cache_for_prefill = _prefill_and_cache(model, prompt_ids)

        # 3 draft tokens (ASCII range, safe for most tokenizers)
        drafts = [100, 101, 102]

        # Path 1: multi-token forward
        cache_M = copy.deepcopy(cache_for_prefill)
        out_M = model(mx.array([drafts]), cache=cache_M)
        logits_M = out_M.logits if hasattr(out_M, "logits") else out_M
        mx.eval(logits_M)
        assert logits_M.shape[0] == 1
        assert logits_M.shape[1] == 3, f"expected T=3, got {logits_M.shape}"

        # Path 2: sequential single-token forwards
        cache_S = copy.deepcopy(cache_for_prefill)
        seq_logits = []
        for t in drafts:
            out_S = model(mx.array([[t]]), cache=cache_S)
            logits_t = out_S.logits if hasattr(out_S, "logits") else out_S
            mx.eval(logits_t)
            seq_logits.append(logits_t[:, -1:, :])
        logits_S = mx.concatenate(seq_logits, axis=1)
        assert logits_S.shape[1] == 3

        # Compare argmax per position
        for j in range(3):
            argmax_M = int(mx.argmax(logits_M[:, j, :], axis=-1).item())
            argmax_S = int(mx.argmax(logits_S[:, j, :], axis=-1).item())
            # On Llama-3.2 we expect these to match (deterministic at T=0).
            # On smolvlm they don't; bug documented in PR #172.
            assert argmax_M == argmax_S, (
                f"argmax mismatch at pos {j}: M={argmax_M} S={argmax_S}. "
                f"Sequential verify is the workaround."
            )

    def test_sequential_matches_multi_token_via_logits_tolerance(self, model_and_tokenizer):
        """logits should match within FP tolerance (1e-2 absolute).

        On smolvlm this fails (max_diff up to 0.375). On Llama-3.2 it should
        pass, meaning Llama's multi-token forward IS correct, and the
        smolvlm-specific bug is in mlx_vlm's wrapper, not mlx_lm core.
        """
        model, tokenizer = model_and_tokenizer
        prompt = "The quick brown fox"
        prompt_ids = tokenizer.encode(prompt) if hasattr(tokenizer, "encode") else list(
            tokenizer(prompt).input_ids
        )
        cache_for_prefill = _prefill_and_cache(model, prompt_ids)
        drafts = [100, 101, 102]

        cache_M = copy.deepcopy(cache_for_prefill)
        out_M = model(mx.array([drafts]), cache=cache_M)
        logits_M = out_M.logits if hasattr(out_M, "logits") else out_M
        mx.eval(logits_M)

        cache_S = copy.deepcopy(cache_for_prefill)
        seq_logits = []
        for t in drafts:
            out_S = model(mx.array([[t]]), cache=cache_S)
            logits_t = out_S.logits if hasattr(out_S, "logits") else out_S
            mx.eval(logits_t)
            seq_logits.append(logits_t[:, -1:, :])
        logits_S = mx.concatenate(seq_logits, axis=1)

        max_diff_overall = 0.0
        per_pos_diffs = []
        for j in range(3):
            d = mx.abs(logits_M[:, j, :] - logits_S[:, j, :]).max().item()
            per_pos_diffs.append(d)
            max_diff_overall = max(max_diff_overall, d)

        # Tolerance: 1e-2 covers normal MLX FP variance on Apple Silicon.
        # If this fails, the model has a known multi-token bug → use
        # sequential verify as workaround.
        assert max_diff_overall < 1e-2, (
            f"Multi-token logits diverge from sequential: "
            f"max_diff={max_diff_overall:.6e}, per_pos={per_pos_diffs}. "
            f"Sequential verify is the correctness workaround."
        )

    def test_sequential_matches_multi_token_cache_state(self, model_and_tokenizer):
        """Ending cache state (keys at last position) should match between
        multi-token and sequential paths."""
        model, tokenizer = model_and_tokenizer
        prompt_ids = tokenizer.encode("Hello world") if hasattr(tokenizer, "encode") else [1, 2]
        cache_for_prefill = _prefill_and_cache(model, prompt_ids)
        drafts = [100, 101, 102]

        cache_M = copy.deepcopy(cache_for_prefill)
        _ = model(mx.array([drafts]), cache=cache_M)
        mx.eval(cache_M[0].keys if hasattr(cache_M[0], "keys") else cache_M[0])

        cache_S = copy.deepcopy(cache_for_prefill)
        for t in drafts:
            _ = model(mx.array([[t]]), cache=cache_S)
        mx.eval(cache_S[0].keys if hasattr(cache_S[0], "keys") else cache_S[0])

        # Compare final offset
        if hasattr(cache_M[0], "offset") and hasattr(cache_S[0], "offset"):
            off_M = int(cache_M[0].offset) if not isinstance(cache_M[0].offset, mx.array) else int(cache_M[0].offset.item())
            off_S = int(cache_S[0].offset) if not isinstance(cache_S[0].offset, mx.array) else int(cache_S[0].offset.item())
            assert off_M == off_S, f"cache offsets differ: M={off_M} S={off_S}"

        # Compare last-position keys
        if hasattr(cache_M[0], "keys") and cache_M[0].keys is not None:
            last_M = cache_M[0].keys[..., -1:, :]
            last_S = cache_S[0].keys[..., -1:, :]
            key_diff = mx.abs(last_M - last_S).max().item()
            # Expect <= 1e-3 even if logits diverge; cache writes should match
            assert key_diff < 1e-3, f"cache key writes diverge: max_diff={key_diff}"
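The per-position checks in these tests (argmax agreement and maximum absolute logit difference) can be sketched without MLX or a model. The block below is an illustrative pure-Python version over nested lists. `compare_logits`, `argmax`, and the sample numbers are made up for the example and are not part of the test file; only the 1e-2 default tolerance mirrors the assertion above.

```python
def argmax(row):
    # Index of the largest value; the first index wins on ties.
    return max(range(len(row)), key=row.__getitem__)


def compare_logits(multi, sequential, tol=1e-2):
    """Compare two [positions][vocab] logit tables position by position."""
    if len(multi) != len(sequential):
        raise ValueError("position counts differ")
    per_pos_diffs = []
    argmax_mismatches = []
    for j, (m_row, s_row) in enumerate(zip(multi, sequential)):
        # Largest absolute element-wise difference at this position.
        per_pos_diffs.append(max(abs(m - s) for m, s in zip(m_row, s_row)))
        if argmax(m_row) != argmax(s_row):
            argmax_mismatches.append(j)
    max_diff = max(per_pos_diffs, default=0.0)
    return {
        "max_diff": max_diff,
        "per_pos_diffs": per_pos_diffs,
        "argmax_mismatches": argmax_mismatches,
        "within_tolerance": max_diff < tol,
    }


# Made-up logits: position 0 agrees exactly, position 1 diverges.
multi = [[0.10, 2.00, -1.00], [1.50, 0.20, 0.30]]
sequential = [[0.10, 2.00, -1.00], [1.00, 0.20, 1.25]]
report = compare_logits(multi, sequential)
```

Reporting `per_pos_diffs` alongside the overall maximum, as the tolerance test does in its assertion message, shows whether divergence starts at the first draft position or only appears at later ones.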
