feat(pld): hybrid partial-accept replay for SSM models (#134)#149
feat(pld): hybrid partial-accept replay for SSM models (#134)#149st-adam wants to merge 3 commits into
Conversation
|
Rebased onto current Verified post-rebase:
PR is |
|
Rebased onto v1.5.44 (upstream/main as of 2026-05-20). Resolved conflicts with hybrid TQ KV path commit (c61b6e4). Tests: 6/6 pass. |
|
This is the strongest issue-fix candidate in the current PR batch. Thank you for keeping it rebased and tying it to #134. I am not merging it directly into the active release-hardening tree yet because that tree is dirty with DSV4/cache/tool/settings work, and this touches
Keeping this marked as a candidate fix rather than a feature-only PR. |
|
I did one more concrete review pass because this is the strongest issue-fix candidate in the open PR batch. One thing that needs fixing before merge: Please retarget those tests so they execute the actual production method, or factor the replay helper into a small importable function that both scheduler and tests call. The current fake-cache approach is fine, but the assertion needs to cover the real code path that will ship. After that, I still want the isolated validation mentioned above: adjacent scheduler/cache tests plus a live hybrid SSM PLD proof showing replay attempts/emitted counters and no output regression. |
|
Reviewed for the current release-hardening pass. I am not merging/adapting this into the immediate release branch yet.\n\nThe problem statement is credible: hybrid SSM/attention PLD partial accept can lose accepted draft tokens when cache rollback has to rewind to N. Credit to @st-adam for isolating that design and proposing replay.\n\nReason held for this release:\n- this changes scheduler cache replay behavior in the same hot area as current prefix/paged/L2/hybrid-SSM release risks;\n- the submitted tests include standalone/copied replay logic rather than proving the production Scheduler path end-to-end;\n- no current live proof here shows a hybrid SSM model exercising 0 < num_accept < K replay with healthy follow-up generation and unchanged cache telemetry;\n- /health pld_ssm_replay would need source/API contract coverage before shipping.\n\nCurrent-source guard verification I ran instead: existing PLD/spec guard slice passed (7 passed, 581 deselected), including the PLD non-MLLM short-circuit, current speculative continuous-batching warnings, DSV4 runtime speculative suppression, and prompt-lookup docs matching current scheduler integration.\n\nLeaving open as future feature work; not safe to close as merged/adapted. |
Address jjang-ai review feedback on PR jjang-ai#149: tests were exercising a local copy of _replay_ssm_forward rather than the production code path, meaning tests could pass while the production method diverged silently. Fix: - New vmlx_engine/utils/pld_replay.py: canonical replay_ssm_forward() with minimal deps (lazy mlx_lm imports, contextlib fallback for generation_stream) - Scheduler._replay_ssm_forward becomes a 2-line delegation wrapper - tests/test_pld_ssm_replay.py imports from vmlx_engine.utils.pld_replay directly — tests now exercise the actual production code path 6/6 tests passing. Addresses the merge blocker from the jjang-ai review. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
|
Addressing review feedback — production code path now tested directly. The test-local copy of
If the production method diverges, the import itself will surface it (wrong signature → TypeError; wrong module path → ImportError). Tests: 6/6 pass (same assertions, now over real code). Re: isolated worktree validation + live PLD proof The live hybrid SSM proof (showing python tests/benchmark/test_pld_acceptance.py --port 8080
curl -s http://127.0.0.1:8080/health | python3 -m json.tool | grep -A5 pld_ssm_replayExpected after Task 5: |
|
Follow-up: special-token filter added. Latent issue identified during PR #150 review work: New commit on this branch:
Tests: 20/20 pass ( |
…jang-ai#134, jjang-ai#135) Flesh out _step_speculative with full in-batch verify, per-row accept/reject, and per-row rollback. PLD candidate source is now wired end-to-end: configure_pld_spec() sets the flag, _pld_drafts_for_request() finds K candidates per row, _step_speculative builds the batched verify input and emits accepted drafts via response.extra_tokens. - Build (B, K+1) verify input from [last_token, d0, ..., d_{K-1}] per row. - Single batched forward through self.language_model — same call signature as _step (verified V0.2/V0.4). - Greedy per-row accept loop: compare argmax(logits[i,j]) vs drafts[i][j]. - Snapshot per-row SSM state via _snapshot_ssm_per_row BEFORE verify; restore rejected rows via _restore_ssm_rows after accept/reject. - KV cache rollback: per-row offset.tolist() arithmetic, rebuild mx.array. Handles both batched (offset is mx.array) and single-request (scalar) cases. Syncs _idx for RotatingKVCache compatibility. - Hybrid partial-accept simplification: when 0 < n_accept < K on hybrid, drop accepted drafts and emit correction only (predicted at pos 0). Per-row replay forward to recover those drafts is a future enhancement (C.5) — would require per-row mini-batch + write-back into BatchMambaCache. - Set req.scratch_extra_tokens per request for downstream emission. - Telemetry: _spec_batched_steps, _spec_batched_tokens, EMA acceptance rate. - Dispatch in _next() now routes PLD-enabled steps to _step_speculative with K=2 for hybrid models, K=5 otherwise (mirrors Scheduler choice). Draft-model spec (when --speculative-model + VMLX_ENABLE_BATCHED_SPEC=1) takes precedence; PLD activates when only --enable-pld is set. - Response builder: appends primary token first then extras (correct chronological order: seed processed → accepted drafts → next-step bonus). Populates response.extra_tokens. Truncates on stop token or max_tok cap. 10 unit tests in tests/test_mllm_step_speculative.py covering: - Fallback paths (no drafts, prefill state, K=0) - Pure-attention: full accept / full reject / partial accept w/ correct KV offset rewind - Hybrid: full accept (no rollback) / partial accept (drop drafts + SSM restore) - Telemetry counter increments 69/69 tests pass across all PR jjang-ai#149/jjang-ai#150 suites. No real model required — tests use a mock language_model that simulates BatchKVCache offset advance. Remaining work for full PR jjang-ai#150 (next session): - Per-row replay forward for hybrid partial-accept (recover dropped drafts) - MLLMScheduler wiring to call configure_pld_spec() at startup - Live validation on Heretic2 workload - /health pld_ssm_replay telemetry passthrough Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
, jjang-ai#135) Hook up MLLMBatchGenerator's PLD speculative decoding path through the scheduler and the /health endpoint. - MLLMScheduler._build_pld_excluded_token_ids(): mirrors Scheduler version from PR jjang-ai#149. Collects pad/image-pad/vision-start/vision-end IDs from tokenizer + processor. Returns None if no special IDs present. - MLLMScheduler.__init__: when config.pld_enabled AND batch generator is hybrid, calls configure_pld_spec(enabled=True, excluded_token_ids=...). Failures logged but don't crash startup (falls back to standard decode). - Logs "[PLD] MLLM batched PLD enabled — K=2 (hybrid), excluded_token_ids=N" at startup so the activation path is visible in server logs. - server.py /health endpoint: now probes pld_ssm_replay counters from THREE possible sources (Scheduler, MLLMScheduler, MLLMBatchGenerator) and exposes whichever has them. Also surfaces batched-spec counters (_spec_batched_steps, tokens, acceptance EMA) under speculative_decoding.batched. 69/69 tests still pass. Live validation on Heretic2 workload (next step): launch server with --enable-pld, hit /health, expect pld_ssm_replay.enabled true on Qwen3.6-27B hybrid; verify speculative_decoding.batched.steps > 0 under load. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Address jjang-ai review feedback on PR jjang-ai#149: tests were exercising a local copy of _replay_ssm_forward rather than the production code path, meaning tests could pass while the production method diverged silently. Fix: - New vmlx_engine/utils/pld_replay.py: canonical replay_ssm_forward() with minimal deps (lazy mlx_lm imports, contextlib fallback for generation_stream) - Scheduler._replay_ssm_forward becomes a 2-line delegation wrapper - tests/test_pld_ssm_replay.py imports from vmlx_engine.utils.pld_replay directly — tests now exercise the actual production code path 6/6 tests passing. Addresses the merge blocker from the jjang-ai review. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…jang-ai#134, jjang-ai#135) Flesh out _step_speculative with full in-batch verify, per-row accept/reject, and per-row rollback. PLD candidate source is now wired end-to-end: configure_pld_spec() sets the flag, _pld_drafts_for_request() finds K candidates per row, _step_speculative builds the batched verify input and emits accepted drafts via response.extra_tokens. - Build (B, K+1) verify input from [last_token, d0, ..., d_{K-1}] per row. - Single batched forward through self.language_model — same call signature as _step (verified V0.2/V0.4). - Greedy per-row accept loop: compare argmax(logits[i,j]) vs drafts[i][j]. - Snapshot per-row SSM state via _snapshot_ssm_per_row BEFORE verify; restore rejected rows via _restore_ssm_rows after accept/reject. - KV cache rollback: per-row offset.tolist() arithmetic, rebuild mx.array. Handles both batched (offset is mx.array) and single-request (scalar) cases. Syncs _idx for RotatingKVCache compatibility. - Hybrid partial-accept simplification: when 0 < n_accept < K on hybrid, drop accepted drafts and emit correction only (predicted at pos 0). Per-row replay forward to recover those drafts is a future enhancement (C.5) — would require per-row mini-batch + write-back into BatchMambaCache. - Set req.scratch_extra_tokens per request for downstream emission. - Telemetry: _spec_batched_steps, _spec_batched_tokens, EMA acceptance rate. - Dispatch in _next() now routes PLD-enabled steps to _step_speculative with K=2 for hybrid models, K=5 otherwise (mirrors Scheduler choice). Draft-model spec (when --speculative-model + VMLX_ENABLE_BATCHED_SPEC=1) takes precedence; PLD activates when only --enable-pld is set. - Response builder: appends primary token first then extras (correct chronological order: seed processed → accepted drafts → next-step bonus). Populates response.extra_tokens. Truncates on stop token or max_tok cap. 10 unit tests in tests/test_mllm_step_speculative.py covering: - Fallback paths (no drafts, prefill state, K=0) - Pure-attention: full accept / full reject / partial accept w/ correct KV offset rewind - Hybrid: full accept (no rollback) / partial accept (drop drafts + SSM restore) - Telemetry counter increments 69/69 tests pass across all PR jjang-ai#149/jjang-ai#150 suites. No real model required — tests use a mock language_model that simulates BatchKVCache offset advance. Remaining work for full PR jjang-ai#150 (next session): - Per-row replay forward for hybrid partial-accept (recover dropped drafts) - MLLMScheduler wiring to call configure_pld_spec() at startup - Live validation on Heretic2 workload - /health pld_ssm_replay telemetry passthrough Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
, jjang-ai#135) Hook up MLLMBatchGenerator's PLD speculative decoding path through the scheduler and the /health endpoint. - MLLMScheduler._build_pld_excluded_token_ids(): mirrors Scheduler version from PR jjang-ai#149. Collects pad/image-pad/vision-start/vision-end IDs from tokenizer + processor. Returns None if no special IDs present. - MLLMScheduler.__init__: when config.pld_enabled AND batch generator is hybrid, calls configure_pld_spec(enabled=True, excluded_token_ids=...). Failures logged but don't crash startup (falls back to standard decode). - Logs "[PLD] MLLM batched PLD enabled — K=2 (hybrid), excluded_token_ids=N" at startup so the activation path is visible in server logs. - server.py /health endpoint: now probes pld_ssm_replay counters from THREE possible sources (Scheduler, MLLMScheduler, MLLMBatchGenerator) and exposes whichever has them. Also surfaces batched-spec counters (_spec_batched_steps, tokens, acceptance EMA) under speculative_decoding.batched. 69/69 tests still pass. Live validation on Heretic2 workload (next step): launch server with --enable-pld, hit /health, expect pld_ssm_replay.enabled true on Qwen3.6-27B hybrid; verify speculative_decoding.batched.steps > 0 under load. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…or truncate
Two changes from the live byte-equality validation cycle on smolvlm
(pure-attention MLLM, mlx-community/SmolVLM-Instruct-bf16):
1. **MLLM PLD default-off** (correctness gate). Live byte-equality test
showed PLD on pure-attention MLLM ALSO produces output drift at T=0
("TheTheThe number number is is"), not just hybrid. The PR jjang-ai#150
multi-token verify forward produces logits that don't byte-match
sequential single-token forwards for the SAME input position. Affects
both hybrid and pure-attention MLLM paths.
New gate: `VMLX_ENABLE_MLLM_PLD=1` (experimental opt-in for either
architecture). Old `VMLX_ENABLE_MLLM_PLD_HYBRID=1` kept for backward
compat. Without the flag, --enable-pld is silently ignored on MLLM
path (logs the reason).
Simple-engine PLD (PR jjang-ai#149's Scheduler path) is unaffected — PR jjang-ai#26's
+4-7% baseline holds for the non-MLLM hybrid case.
2. **Per-n_accept histogram** in /health. Diagnostic telemetry:
`speculative_decoding.batched.accept_histogram[n]` = count of rounds
where exactly n drafts accepted. Lets ops debug workload-specific
acceptance patterns (e.g. code shows partial accepts; JSON shows full
accepts).
3. **B=1 tensor truncate on pure-attention KV rewind**. When verify
advances cache to N+K+1 but rollback rewinds offset to N+1+n, the
keys/values tensor still has stale verify content at the trailing
positions. Some MLX attention paths may read the full tensor (not
just :offset). For B=1 we truncate the tensor; for B>1 we preserve
the multi-row structure (best-effort, attention mask must respect
per-row offset).
Tests: 95/95 pass.
Live validation (mlx-community/SmolVLM-Instruct-bf16, max_num_seqs=1):
- PLD off vs PLD on at T=0: 1/4 prompts byte-equal (the short one);
3/4 diverge. Confirms PLD on MLLM has unresolved correctness issue.
- With this commit, default behavior is PLD-off on MLLM (byte-equal
to PLD-off baseline by construction).
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…ests PR jjang-ai#172 Phase C + test layers 1/2/3/5. **TurboQuant cache short-circuit** (Phase C, vmlx_engine/mllm_batch_generator.py): PLD's per-row writeback doesn't preserve `left_padding` on TurboQuant KV caches, causing `c.extract(idx)` to crash on request completion (observed on smolvlm with `--kv-cache-quantization q4`). Detect TQ via `type(c).__name__ == "TurboQuantKVCache"` at the entry to `_step_speculative` and fall back to standard `_step`. Logs once at INFO. PLD + TQ + MLLM compatibility is deferred to a follow-up; short-term workaround: use `--kv-cache-quantization none/q8` for PLD. **Logit-equivalence tests (Layer 1, slow, real-model):** `tests/test_mllm_pld_logits_equivalence.py` — 3 tests on `mlx-community/Llama-3.2-1B-Instruct-4bit` (small pure-attention LLM, verified deterministic at T=0): - argmax match between multi-token and sequential forward - logits max-diff within 1e-2 FP tolerance - cache state (offset + last-position keys) matches All 3 PASS on Llama, proving the multi-token forward IS mathematically correct in mlx_lm. The smolvlm-specific divergence (PR jjang-ai#172 Phase B diagnostic) is therefore in mlx_vlm's wrapper layer, not mlx_lm core. **TQ + invariant + edge case unit tests (Layers 2/3/5, mock):** `tests/test_mllm_pld_tq_and_invariants.py` — 11 tests covering: - TQ cache detection short-circuits (no spec steps, no crash) - TQ skip log fires once - scratch_extra_tokens initialized to None - Cooldown decrement invariant + no underflow - Offset monotonicity under sequential verify - Acceptance histogram increments per row per step - Telemetry counters strictly monotonic - K=0 / B=0 / prefill-phase early returns **Live byte-equality on Llama-3.2 (PR jjang-ai#149 simple-engine path):** Discovered that PR jjang-ai#149's PLD path is NOT reachable via `vmlx serve` — both batched and non-batched modes route through different schedulers that bypass `Scheduler._try_speculative_decode`. PR jjang-ai#149's `pld_ssm_replay` counters remain 0 in all tested configurations. PR jjang-ai#149 was validated via mock unit tests; live activation needs a deeper investigation (likely needs `vmlx_engine` API call directly with `AsyncEngineCore(..., scheduler=Scheduler(...))`). **Total: 117 tests passing** (95 mock unit + 3 slow logit-equivalence + 11 new TQ/invariant + 8 new B>1 from PR jjang-ai#171). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Address jjang-ai review feedback on PR jjang-ai#149: tests were exercising a local copy of _replay_ssm_forward rather than the production code path, meaning tests could pass while the production method diverged silently. Fix: - New vmlx_engine/utils/pld_replay.py: canonical replay_ssm_forward() with minimal deps (lazy mlx_lm imports, contextlib fallback for generation_stream) - Scheduler._replay_ssm_forward becomes a 2-line delegation wrapper - tests/test_pld_ssm_replay.py imports from vmlx_engine.utils.pld_replay directly — tests now exercise the actual production code path 6/6 tests passing. Addresses the merge blocker from the jjang-ai review. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…jang-ai#134, jjang-ai#135) Flesh out _step_speculative with full in-batch verify, per-row accept/reject, and per-row rollback. PLD candidate source is now wired end-to-end: configure_pld_spec() sets the flag, _pld_drafts_for_request() finds K candidates per row, _step_speculative builds the batched verify input and emits accepted drafts via response.extra_tokens. - Build (B, K+1) verify input from [last_token, d0, ..., d_{K-1}] per row. - Single batched forward through self.language_model — same call signature as _step (verified V0.2/V0.4). - Greedy per-row accept loop: compare argmax(logits[i,j]) vs drafts[i][j]. - Snapshot per-row SSM state via _snapshot_ssm_per_row BEFORE verify; restore rejected rows via _restore_ssm_rows after accept/reject. - KV cache rollback: per-row offset.tolist() arithmetic, rebuild mx.array. Handles both batched (offset is mx.array) and single-request (scalar) cases. Syncs _idx for RotatingKVCache compatibility. - Hybrid partial-accept simplification: when 0 < n_accept < K on hybrid, drop accepted drafts and emit correction only (predicted at pos 0). Per-row replay forward to recover those drafts is a future enhancement (C.5) — would require per-row mini-batch + write-back into BatchMambaCache. - Set req.scratch_extra_tokens per request for downstream emission. - Telemetry: _spec_batched_steps, _spec_batched_tokens, EMA acceptance rate. - Dispatch in _next() now routes PLD-enabled steps to _step_speculative with K=2 for hybrid models, K=5 otherwise (mirrors Scheduler choice). Draft-model spec (when --speculative-model + VMLX_ENABLE_BATCHED_SPEC=1) takes precedence; PLD activates when only --enable-pld is set. - Response builder: appends primary token first then extras (correct chronological order: seed processed → accepted drafts → next-step bonus). Populates response.extra_tokens. Truncates on stop token or max_tok cap. 10 unit tests in tests/test_mllm_step_speculative.py covering: - Fallback paths (no drafts, prefill state, K=0) - Pure-attention: full accept / full reject / partial accept w/ correct KV offset rewind - Hybrid: full accept (no rollback) / partial accept (drop drafts + SSM restore) - Telemetry counter increments 69/69 tests pass across all PR jjang-ai#149/jjang-ai#150 suites. No real model required — tests use a mock language_model that simulates BatchKVCache offset advance. Remaining work for full PR jjang-ai#150 (next session): - Per-row replay forward for hybrid partial-accept (recover dropped drafts) - MLLMScheduler wiring to call configure_pld_spec() at startup - Live validation on Heretic2 workload - /health pld_ssm_replay telemetry passthrough Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
, jjang-ai#135) Hook up MLLMBatchGenerator's PLD speculative decoding path through the scheduler and the /health endpoint. - MLLMScheduler._build_pld_excluded_token_ids(): mirrors Scheduler version from PR jjang-ai#149. Collects pad/image-pad/vision-start/vision-end IDs from tokenizer + processor. Returns None if no special IDs present. - MLLMScheduler.__init__: when config.pld_enabled AND batch generator is hybrid, calls configure_pld_spec(enabled=True, excluded_token_ids=...). Failures logged but don't crash startup (falls back to standard decode). - Logs "[PLD] MLLM batched PLD enabled — K=2 (hybrid), excluded_token_ids=N" at startup so the activation path is visible in server logs. - server.py /health endpoint: now probes pld_ssm_replay counters from THREE possible sources (Scheduler, MLLMScheduler, MLLMBatchGenerator) and exposes whichever has them. Also surfaces batched-spec counters (_spec_batched_steps, tokens, acceptance EMA) under speculative_decoding.batched. 69/69 tests still pass. Live validation on Heretic2 workload (next step): launch server with --enable-pld, hit /health, expect pld_ssm_replay.enabled true on Qwen3.6-27B hybrid; verify speculative_decoding.batched.steps > 0 under load. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…jang-ai#153) Documents the batched-MLLM PLD path added by PRs jjang-ai#149/jjang-ai#150: - Dispatch path: when each of {draft-spec, PLD, standard _step} activates - In-batch verify mechanics (B, K+1) input shape - Per-row rollback variants (full accept / pure-attention reject / hybrid) - Hybrid Mamba limitation: parallel-scan kernel vs recurrent-replay state divergence; default-disabled with VMLX_ENABLE_MLLM_PLD_HYBRID=1 opt-in - Deferred attention-only verify proposal (issue jjang-ai#134 original design) with cost/benefit analysis and decision criteria for future implementers - Full env-var reference table (VMLX_DISABLE_PLD_REPLAY, VMLX_ENABLE_MLLM_PLD_HYBRID, VMLX_PLD_MIN_ACCEPTANCE, VMLX_PLD_WARMUP_STEPS, VMLX_PLD_PROBE_INTERVAL, VMLX_PLD_DEBUG) - Telemetry format on /health (pld_ssm_replay + speculative_decoding.batched) - 3-layer validation strategy: unit (~95 tests), live byte-equality script, live benchmark Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…or truncate
Two changes from the live byte-equality validation cycle on smolvlm
(pure-attention MLLM, mlx-community/SmolVLM-Instruct-bf16):
1. **MLLM PLD default-off** (correctness gate). Live byte-equality test
showed PLD on pure-attention MLLM ALSO produces output drift at T=0
("TheTheThe number number is is"), not just hybrid. The PR jjang-ai#150
multi-token verify forward produces logits that don't byte-match
sequential single-token forwards for the SAME input position. Affects
both hybrid and pure-attention MLLM paths.
New gate: `VMLX_ENABLE_MLLM_PLD=1` (experimental opt-in for either
architecture). Old `VMLX_ENABLE_MLLM_PLD_HYBRID=1` kept for backward
compat. Without the flag, --enable-pld is silently ignored on MLLM
path (logs the reason).
Simple-engine PLD (PR jjang-ai#149's Scheduler path) is unaffected — PR jjang-ai#26's
+4-7% baseline holds for the non-MLLM hybrid case.
2. **Per-n_accept histogram** in /health. Diagnostic telemetry:
`speculative_decoding.batched.accept_histogram[n]` = count of rounds
where exactly n drafts accepted. Lets ops debug workload-specific
acceptance patterns (e.g. code shows partial accepts; JSON shows full
accepts).
3. **B=1 tensor truncate on pure-attention KV rewind**. When verify
advances cache to N+K+1 but rollback rewinds offset to N+1+n, the
keys/values tensor still has stale verify content at the trailing
positions. Some MLX attention paths may read the full tensor (not
just :offset). For B=1 we truncate the tensor; for B>1 we preserve
the multi-row structure (best-effort, attention mask must respect
per-row offset).
Tests: 95/95 pass.
Live validation (mlx-community/SmolVLM-Instruct-bf16, max_num_seqs=1):
- PLD off vs PLD on at T=0: 1/4 prompts byte-equal (the short one);
3/4 diverge. Confirms PLD on MLLM has unresolved correctness issue.
- With this commit, default behavior is PLD-off on MLLM (byte-equal
to PLD-off baseline by construction).
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…ests PR jjang-ai#172 Phase C + test layers 1/2/3/5. **TurboQuant cache short-circuit** (Phase C, vmlx_engine/mllm_batch_generator.py): PLD's per-row writeback doesn't preserve `left_padding` on TurboQuant KV caches, causing `c.extract(idx)` to crash on request completion (observed on smolvlm with `--kv-cache-quantization q4`). Detect TQ via `type(c).__name__ == "TurboQuantKVCache"` at the entry to `_step_speculative` and fall back to standard `_step`. Logs once at INFO. PLD + TQ + MLLM compatibility is deferred to a follow-up; short-term workaround: use `--kv-cache-quantization none/q8` for PLD. **Logit-equivalence tests (Layer 1, slow, real-model):** `tests/test_mllm_pld_logits_equivalence.py` — 3 tests on `mlx-community/Llama-3.2-1B-Instruct-4bit` (small pure-attention LLM, verified deterministic at T=0): - argmax match between multi-token and sequential forward - logits max-diff within 1e-2 FP tolerance - cache state (offset + last-position keys) matches All 3 PASS on Llama, proving the multi-token forward IS mathematically correct in mlx_lm. The smolvlm-specific divergence (PR jjang-ai#172 Phase B diagnostic) is therefore in mlx_vlm's wrapper layer, not mlx_lm core. **TQ + invariant + edge case unit tests (Layers 2/3/5, mock):** `tests/test_mllm_pld_tq_and_invariants.py` — 11 tests covering: - TQ cache detection short-circuits (no spec steps, no crash) - TQ skip log fires once - scratch_extra_tokens initialized to None - Cooldown decrement invariant + no underflow - Offset monotonicity under sequential verify - Acceptance histogram increments per row per step - Telemetry counters strictly monotonic - K=0 / B=0 / prefill-phase early returns **Live byte-equality on Llama-3.2 (PR jjang-ai#149 simple-engine path):** Discovered that PR jjang-ai#149's PLD path is NOT reachable via `vmlx serve` — both batched and non-batched modes route through different schedulers that bypass `Scheduler._try_speculative_decode`. PR jjang-ai#149's `pld_ssm_replay` counters remain 0 in all tested configurations. PR jjang-ai#149 was validated via mock unit tests; live activation needs a deeper investigation (likely needs `vmlx_engine` API call directly with `AsyncEngineCore(..., scheduler=Scheduler(...))`). **Total: 117 tests passing** (95 mock unit + 3 slow logit-equivalence + 11 new TQ/invariant + 8 new B>1 from PR jjang-ai#171). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Address jjang-ai review feedback on PR jjang-ai#149: tests were exercising a local copy of _replay_ssm_forward rather than the production code path, meaning tests could pass while the production method diverged silently. Fix: - New vmlx_engine/utils/pld_replay.py: canonical replay_ssm_forward() with minimal deps (lazy mlx_lm imports, contextlib fallback for generation_stream) - Scheduler._replay_ssm_forward becomes a 2-line delegation wrapper - tests/test_pld_ssm_replay.py imports from vmlx_engine.utils.pld_replay directly — tests now exercise the actual production code path 6/6 tests passing. Addresses the merge blocker from the jjang-ai review. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…jang-ai#134, jjang-ai#135) Flesh out _step_speculative with full in-batch verify, per-row accept/reject, and per-row rollback. PLD candidate source is now wired end-to-end: configure_pld_spec() sets the flag, _pld_drafts_for_request() finds K candidates per row, _step_speculative builds the batched verify input and emits accepted drafts via response.extra_tokens. - Build (B, K+1) verify input from [last_token, d0, ..., d_{K-1}] per row. - Single batched forward through self.language_model — same call signature as _step (verified V0.2/V0.4). - Greedy per-row accept loop: compare argmax(logits[i,j]) vs drafts[i][j]. - Snapshot per-row SSM state via _snapshot_ssm_per_row BEFORE verify; restore rejected rows via _restore_ssm_rows after accept/reject. - KV cache rollback: per-row offset.tolist() arithmetic, rebuild mx.array. Handles both batched (offset is mx.array) and single-request (scalar) cases. Syncs _idx for RotatingKVCache compatibility. - Hybrid partial-accept simplification: when 0 < n_accept < K on hybrid, drop accepted drafts and emit correction only (predicted at pos 0). Per-row replay forward to recover those drafts is a future enhancement (C.5) — would require per-row mini-batch + write-back into BatchMambaCache. - Set req.scratch_extra_tokens per request for downstream emission. - Telemetry: _spec_batched_steps, _spec_batched_tokens, EMA acceptance rate. - Dispatch in _next() now routes PLD-enabled steps to _step_speculative with K=2 for hybrid models, K=5 otherwise (mirrors Scheduler choice). Draft-model spec (when --speculative-model + VMLX_ENABLE_BATCHED_SPEC=1) takes precedence; PLD activates when only --enable-pld is set. - Response builder: appends primary token first then extras (correct chronological order: seed processed → accepted drafts → next-step bonus). Populates response.extra_tokens. Truncates on stop token or max_tok cap. 10 unit tests in tests/test_mllm_step_speculative.py covering: - Fallback paths (no drafts, prefill state, K=0) - Pure-attention: full accept / full reject / partial accept w/ correct KV offset rewind - Hybrid: full accept (no rollback) / partial accept (drop drafts + SSM restore) - Telemetry counter increments 69/69 tests pass across all PR jjang-ai#149/jjang-ai#150 suites. No real model required — tests use a mock language_model that simulates BatchKVCache offset advance. Remaining work for full PR jjang-ai#150 (next session): - Per-row replay forward for hybrid partial-accept (recover dropped drafts) - MLLMScheduler wiring to call configure_pld_spec() at startup - Live validation on Heretic2 workload - /health pld_ssm_replay telemetry passthrough Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
, jjang-ai#135) Hook up MLLMBatchGenerator's PLD speculative decoding path through the scheduler and the /health endpoint. - MLLMScheduler._build_pld_excluded_token_ids(): mirrors Scheduler version from PR jjang-ai#149. Collects pad/image-pad/vision-start/vision-end IDs from tokenizer + processor. Returns None if no special IDs present. - MLLMScheduler.__init__: when config.pld_enabled AND batch generator is hybrid, calls configure_pld_spec(enabled=True, excluded_token_ids=...). Failures logged but don't crash startup (falls back to standard decode). - Logs "[PLD] MLLM batched PLD enabled — K=2 (hybrid), excluded_token_ids=N" at startup so the activation path is visible in server logs. - server.py /health endpoint: now probes pld_ssm_replay counters from THREE possible sources (Scheduler, MLLMScheduler, MLLMBatchGenerator) and exposes whichever has them. Also surfaces batched-spec counters (_spec_batched_steps, tokens, acceptance EMA) under speculative_decoding.batched. 69/69 tests still pass. Live validation on Heretic2 workload (next step): launch server with --enable-pld, hit /health, expect pld_ssm_replay.enabled true on Qwen3.6-27B hybrid; verify speculative_decoding.batched.steps > 0 under load. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Address jjang-ai review feedback on PR jjang-ai#149: tests were exercising a local copy of _replay_ssm_forward rather than the production code path, meaning tests could pass while the production method diverged silently. Fix: - New vmlx_engine/utils/pld_replay.py: canonical replay_ssm_forward() with minimal deps (lazy mlx_lm imports, contextlib fallback for generation_stream) - Scheduler._replay_ssm_forward becomes a 2-line delegation wrapper - tests/test_pld_ssm_replay.py imports from vmlx_engine.utils.pld_replay directly — tests now exercise the actual production code path 6/6 tests passing. Addresses the merge blocker from the jjang-ai review. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Address jjang-ai review feedback on PR jjang-ai#149: tests were exercising a local copy of _replay_ssm_forward rather than the production code path, meaning tests could pass while the production method diverged silently. Fix: - New vmlx_engine/utils/pld_replay.py: canonical replay_ssm_forward() with minimal deps (lazy mlx_lm imports, contextlib fallback for generation_stream) - Scheduler._replay_ssm_forward becomes a 2-line delegation wrapper - tests/test_pld_ssm_replay.py imports from vmlx_engine.utils.pld_replay directly — tests now exercise the actual production code path 6/6 tests passing. Addresses the merge blocker from the jjang-ai review. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…jang-ai#134, jjang-ai#135) Flesh out _step_speculative with full in-batch verify, per-row accept/reject, and per-row rollback. PLD candidate source is now wired end-to-end: configure_pld_spec() sets the flag, _pld_drafts_for_request() finds K candidates per row, _step_speculative builds the batched verify input and emits accepted drafts via response.extra_tokens. - Build (B, K+1) verify input from [last_token, d0, ..., d_{K-1}] per row. - Single batched forward through self.language_model — same call signature as _step (verified V0.2/V0.4). - Greedy per-row accept loop: compare argmax(logits[i,j]) vs drafts[i][j]. - Snapshot per-row SSM state via _snapshot_ssm_per_row BEFORE verify; restore rejected rows via _restore_ssm_rows after accept/reject. - KV cache rollback: per-row offset.tolist() arithmetic, rebuild mx.array. Handles both batched (offset is mx.array) and single-request (scalar) cases. Syncs _idx for RotatingKVCache compatibility. - Hybrid partial-accept simplification: when 0 < n_accept < K on hybrid, drop accepted drafts and emit correction only (predicted at pos 0). Per-row replay forward to recover those drafts is a future enhancement (C.5) — would require per-row mini-batch + write-back into BatchMambaCache. - Set req.scratch_extra_tokens per request for downstream emission. - Telemetry: _spec_batched_steps, _spec_batched_tokens, EMA acceptance rate. - Dispatch in _next() now routes PLD-enabled steps to _step_speculative with K=2 for hybrid models, K=5 otherwise (mirrors Scheduler choice). Draft-model spec (when --speculative-model + VMLX_ENABLE_BATCHED_SPEC=1) takes precedence; PLD activates when only --enable-pld is set. - Response builder: appends primary token first then extras (correct chronological order: seed processed → accepted drafts → next-step bonus). Populates response.extra_tokens. Truncates on stop token or max_tok cap. 10 unit tests in tests/test_mllm_step_speculative.py covering: - Fallback paths (no drafts, prefill state, K=0) - Pure-attention: full accept / full reject / partial accept w/ correct KV offset rewind - Hybrid: full accept (no rollback) / partial accept (drop drafts + SSM restore) - Telemetry counter increments 69/69 tests pass across all PR jjang-ai#149/jjang-ai#150 suites. No real model required — tests use a mock language_model that simulates BatchKVCache offset advance. Remaining work for full PR jjang-ai#150 (next session): - Per-row replay forward for hybrid partial-accept (recover dropped drafts) - MLLMScheduler wiring to call configure_pld_spec() at startup - Live validation on Heretic2 workload - /health pld_ssm_replay telemetry passthrough Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
, jjang-ai#135) Hook up MLLMBatchGenerator's PLD speculative decoding path through the scheduler and the /health endpoint. - MLLMScheduler._build_pld_excluded_token_ids(): mirrors Scheduler version from PR jjang-ai#149. Collects pad/image-pad/vision-start/vision-end IDs from tokenizer + processor. Returns None if no special IDs present. - MLLMScheduler.__init__: when config.pld_enabled AND batch generator is hybrid, calls configure_pld_spec(enabled=True, excluded_token_ids=...). Failures logged but don't crash startup (falls back to standard decode). - Logs "[PLD] MLLM batched PLD enabled — K=2 (hybrid), excluded_token_ids=N" at startup so the activation path is visible in server logs. - server.py /health endpoint: now probes pld_ssm_replay counters from THREE possible sources (Scheduler, MLLMScheduler, MLLMBatchGenerator) and exposes whichever has them. Also surfaces batched-spec counters (_spec_batched_steps, tokens, acceptance EMA) under speculative_decoding.batched. 69/69 tests still pass. Live validation on Heretic2 workload (next step): launch server with --enable-pld, hit /health, expect pld_ssm_replay.enabled true on Qwen3.6-27B hybrid; verify speculative_decoding.batched.steps > 0 under load. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…jang-ai#153) Documents the batched-MLLM PLD path added by PRs jjang-ai#149/jjang-ai#150: - Dispatch path: when each of {draft-spec, PLD, standard _step} activates - In-batch verify mechanics (B, K+1) input shape - Per-row rollback variants (full accept / pure-attention reject / hybrid) - Hybrid Mamba limitation: parallel-scan kernel vs recurrent-replay state divergence; default-disabled with VMLX_ENABLE_MLLM_PLD_HYBRID=1 opt-in - Deferred attention-only verify proposal (issue jjang-ai#134 original design) with cost/benefit analysis and decision criteria for future implementers - Full env-var reference table (VMLX_DISABLE_PLD_REPLAY, VMLX_ENABLE_MLLM_PLD_HYBRID, VMLX_PLD_MIN_ACCEPTANCE, VMLX_PLD_WARMUP_STEPS, VMLX_PLD_PROBE_INTERVAL, VMLX_PLD_DEBUG) - Telemetry format on /health (pld_ssm_replay + speculative_decoding.batched) - 3-layer validation strategy: unit (~95 tests), live byte-equality script, live benchmark Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…or truncate
Two changes from the live byte-equality validation cycle on smolvlm
(pure-attention MLLM, mlx-community/SmolVLM-Instruct-bf16):
1. **MLLM PLD default-off** (correctness gate). Live byte-equality test
showed PLD on pure-attention MLLM ALSO produces output drift at T=0
("TheTheThe number number is is"), not just hybrid. The PR jjang-ai#150
multi-token verify forward produces logits that don't byte-match
sequential single-token forwards for the SAME input position. Affects
both hybrid and pure-attention MLLM paths.
New gate: `VMLX_ENABLE_MLLM_PLD=1` (experimental opt-in for either
architecture). Old `VMLX_ENABLE_MLLM_PLD_HYBRID=1` kept for backward
compat. Without the flag, --enable-pld is silently ignored on MLLM
path (logs the reason).
Simple-engine PLD (PR jjang-ai#149's Scheduler path) is unaffected — PR jjang-ai#26's
+4-7% baseline holds for the non-MLLM hybrid case.
2. **Per-n_accept histogram** in /health. Diagnostic telemetry:
`speculative_decoding.batched.accept_histogram[n]` = count of rounds
where exactly n drafts accepted. Lets ops debug workload-specific
acceptance patterns (e.g. code shows partial accepts; JSON shows full
accepts).
3. **B=1 tensor truncate on pure-attention KV rewind**. When verify
advances cache to N+K+1 but rollback rewinds offset to N+1+n, the
keys/values tensor still has stale verify content at the trailing
positions. Some MLX attention paths may read the full tensor (not
just :offset). For B=1 we truncate the tensor; for B>1 we preserve
the multi-row structure (best-effort, attention mask must respect
per-row offset).
Tests: 95/95 pass.
Live validation (mlx-community/SmolVLM-Instruct-bf16, max_num_seqs=1):
- PLD off vs PLD on at T=0: 1/4 prompts byte-equal (the short one);
3/4 diverge. Confirms PLD on MLLM has unresolved correctness issue.
- With this commit, default behavior is PLD-off on MLLM (byte-equal
to PLD-off baseline by construction).
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…i#134) On hybrid SSM/ATT models with 0 < num_accept < K, the PLD path previously discarded accepted drafts and emitted only a correction token. This PR adds _replay_ssm_forward() which restores caches to N, replays the accepted tokens through the model, and emits num_accept+1 tokens instead of 1. - New Scheduler._replay_ssm_forward() staticmethod (scheduler.py) - Modify case (b): try replay first, fall back to correction-only on failure - Add _pld_replay_{enabled,attempts,emitted,failures} counters - Add pld_ssm_replay telemetry to /health endpoint (server.py) - Document the fix in notes/prompt-lookup-decoding.md - 6 unit tests in tests/test_pld_ssm_replay.py - New partial_accept_stress benchmark in tests/benchmark/test_pld_acceptance.py Expected gain: +5-10% on top of PR jjang-ai#26's +4-7% on hybrid models. Disable: VMLX_DISABLE_PLD_REPLAY=1 Closes jjang-ai#134 Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Address jjang-ai review feedback on PR jjang-ai#149: tests were exercising a local copy of _replay_ssm_forward rather than the production code path, meaning tests could pass while the production method diverged silently. Fix: - New vmlx_engine/utils/pld_replay.py: canonical replay_ssm_forward() with minimal deps (lazy mlx_lm imports, contextlib fallback for generation_stream) - Scheduler._replay_ssm_forward becomes a 2-line delegation wrapper - tests/test_pld_ssm_replay.py imports from vmlx_engine.utils.pld_replay directly — tests now exercise the actual production code path 6/6 tests passing. Addresses the merge blocker from the jjang-ai review. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…ai#134) PLD n-gram lookup over the prompt could propose model-special tokens (image-pad, vision-start, pad) as drafts. The verify forward usually rejects them, but truncating drafts at the first excluded token avoids the risk entirely and saves a wasted verify pass. - prompt_lookup.find_draft_tokens(exclude_token_ids=None) added - NgramIndex.find_drafts(exclude_token_ids=None) added - New _truncate_at_excluded() helper - Scheduler.__init__ builds _pld_excluded_token_ids set from tokenizer/processor special-token attributes (pad, image-pad, vision markers, additional_special_tokens). EOS/BOS NOT excluded - those are legitimate end-of-decode signals already gated by verify. - Wired through both PLD call sites (lines 5193, 6626) - 14 unit tests in tests/test_prompt_lookup_filter.py Default behaviour unchanged when exclude_token_ids is None (back-compat). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…jang-ai#134, jjang-ai#135) Flesh out _step_speculative with full in-batch verify, per-row accept/reject, and per-row rollback. PLD candidate source is now wired end-to-end: configure_pld_spec() sets the flag, _pld_drafts_for_request() finds K candidates per row, _step_speculative builds the batched verify input and emits accepted drafts via response.extra_tokens. - Build (B, K+1) verify input from [last_token, d0, ..., d_{K-1}] per row. - Single batched forward through self.language_model — same call signature as _step (verified V0.2/V0.4). - Greedy per-row accept loop: compare argmax(logits[i,j]) vs drafts[i][j]. - Snapshot per-row SSM state via _snapshot_ssm_per_row BEFORE verify; restore rejected rows via _restore_ssm_rows after accept/reject. - KV cache rollback: per-row offset.tolist() arithmetic, rebuild mx.array. Handles both batched (offset is mx.array) and single-request (scalar) cases. Syncs _idx for RotatingKVCache compatibility. - Hybrid partial-accept simplification: when 0 < n_accept < K on hybrid, drop accepted drafts and emit correction only (predicted at pos 0). Per-row replay forward to recover those drafts is a future enhancement (C.5) — would require per-row mini-batch + write-back into BatchMambaCache. - Set req.scratch_extra_tokens per request for downstream emission. - Telemetry: _spec_batched_steps, _spec_batched_tokens, EMA acceptance rate. - Dispatch in _next() now routes PLD-enabled steps to _step_speculative with K=2 for hybrid models, K=5 otherwise (mirrors Scheduler choice). Draft-model spec (when --speculative-model + VMLX_ENABLE_BATCHED_SPEC=1) takes precedence; PLD activates when only --enable-pld is set. - Response builder: appends primary token first then extras (correct chronological order: seed processed → accepted drafts → next-step bonus). Populates response.extra_tokens. Truncates on stop token or max_tok cap. 10 unit tests in tests/test_mllm_step_speculative.py covering: - Fallback paths (no drafts, prefill state, K=0) - Pure-attention: full accept / full reject / partial accept w/ correct KV offset rewind - Hybrid: full accept (no rollback) / partial accept (drop drafts + SSM restore) - Telemetry counter increments 69/69 tests pass across all PR jjang-ai#149/jjang-ai#150 suites. No real model required — tests use a mock language_model that simulates BatchKVCache offset advance. Remaining work for full PR jjang-ai#150 (next session): - Per-row replay forward for hybrid partial-accept (recover dropped drafts) - MLLMScheduler wiring to call configure_pld_spec() at startup - Live validation on Heretic2 workload - /health pld_ssm_replay telemetry passthrough Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
, jjang-ai#135) Hook up MLLMBatchGenerator's PLD speculative decoding path through the scheduler and the /health endpoint. - MLLMScheduler._build_pld_excluded_token_ids(): mirrors Scheduler version from PR jjang-ai#149. Collects pad/image-pad/vision-start/vision-end IDs from tokenizer + processor. Returns None if no special IDs present. - MLLMScheduler.__init__: when config.pld_enabled AND batch generator is hybrid, calls configure_pld_spec(enabled=True, excluded_token_ids=...). Failures logged but don't crash startup (falls back to standard decode). - Logs "[PLD] MLLM batched PLD enabled — K=2 (hybrid), excluded_token_ids=N" at startup so the activation path is visible in server logs. - server.py /health endpoint: now probes pld_ssm_replay counters from THREE possible sources (Scheduler, MLLMScheduler, MLLMBatchGenerator) and exposes whichever has them. Also surfaces batched-spec counters (_spec_batched_steps, tokens, acceptance EMA) under speculative_decoding.batched. 69/69 tests still pass. Live validation on Heretic2 workload (next step): launch server with --enable-pld, hit /health, expect pld_ssm_replay.enabled true on Qwen3.6-27B hybrid; verify speculative_decoding.batched.steps > 0 under load. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…jang-ai#153) Documents the batched-MLLM PLD path added by PRs jjang-ai#149/jjang-ai#150: - Dispatch path: when each of {draft-spec, PLD, standard _step} activates - In-batch verify mechanics (B, K+1) input shape - Per-row rollback variants (full accept / pure-attention reject / hybrid) - Hybrid Mamba limitation: parallel-scan kernel vs recurrent-replay state divergence; default-disabled with VMLX_ENABLE_MLLM_PLD_HYBRID=1 opt-in - Deferred attention-only verify proposal (issue jjang-ai#134 original design) with cost/benefit analysis and decision criteria for future implementers - Full env-var reference table (VMLX_DISABLE_PLD_REPLAY, VMLX_ENABLE_MLLM_PLD_HYBRID, VMLX_PLD_MIN_ACCEPTANCE, VMLX_PLD_WARMUP_STEPS, VMLX_PLD_PROBE_INTERVAL, VMLX_PLD_DEBUG) - Telemetry format on /health (pld_ssm_replay + speculative_decoding.batched) - 3-layer validation strategy: unit (~95 tests), live byte-equality script, live benchmark Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…or truncate
Two changes from the live byte-equality validation cycle on smolvlm
(pure-attention MLLM, mlx-community/SmolVLM-Instruct-bf16):
1. **MLLM PLD default-off** (correctness gate). Live byte-equality test
showed PLD on pure-attention MLLM ALSO produces output drift at T=0
("TheTheThe number number is is"), not just hybrid. The PR jjang-ai#150
multi-token verify forward produces logits that don't byte-match
sequential single-token forwards for the SAME input position. Affects
both hybrid and pure-attention MLLM paths.
New gate: `VMLX_ENABLE_MLLM_PLD=1` (experimental opt-in for either
architecture). Old `VMLX_ENABLE_MLLM_PLD_HYBRID=1` kept for backward
compat. Without the flag, --enable-pld is silently ignored on MLLM
path (logs the reason).
Simple-engine PLD (PR jjang-ai#149's Scheduler path) is unaffected — PR jjang-ai#26's
+4-7% baseline holds for the non-MLLM hybrid case.
2. **Per-n_accept histogram** in /health. Diagnostic telemetry:
`speculative_decoding.batched.accept_histogram[n]` = count of rounds
where exactly n drafts accepted. Lets ops debug workload-specific
acceptance patterns (e.g. code shows partial accepts; JSON shows full
accepts).
3. **B=1 tensor truncate on pure-attention KV rewind**. When verify
advances cache to N+K+1 but rollback rewinds offset to N+1+n, the
keys/values tensor still has stale verify content at the trailing
positions. Some MLX attention paths may read the full tensor (not
just :offset). For B=1 we truncate the tensor; for B>1 we preserve
the multi-row structure (best-effort, attention mask must respect
per-row offset).
Tests: 95/95 pass.
Live validation (mlx-community/SmolVLM-Instruct-bf16, max_num_seqs=1):
- PLD off vs PLD on at T=0: 1/4 prompts byte-equal (the short one);
3/4 diverge. Confirms PLD on MLLM has unresolved correctness issue.
- With this commit, default behavior is PLD-off on MLLM (byte-equal
to PLD-off baseline by construction).
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…ests PR jjang-ai#172 Phase C + test layers 1/2/3/5. **TurboQuant cache short-circuit** (Phase C, vmlx_engine/mllm_batch_generator.py): PLD's per-row writeback doesn't preserve `left_padding` on TurboQuant KV caches, causing `c.extract(idx)` to crash on request completion (observed on smolvlm with `--kv-cache-quantization q4`). Detect TQ via `type(c).__name__ == "TurboQuantKVCache"` at the entry to `_step_speculative` and fall back to standard `_step`. Logs once at INFO. PLD + TQ + MLLM compatibility is deferred to a follow-up; short-term workaround: use `--kv-cache-quantization none/q8` for PLD. **Logit-equivalence tests (Layer 1, slow, real-model):** `tests/test_mllm_pld_logits_equivalence.py` — 3 tests on `mlx-community/Llama-3.2-1B-Instruct-4bit` (small pure-attention LLM, verified deterministic at T=0): - argmax match between multi-token and sequential forward - logits max-diff within 1e-2 FP tolerance - cache state (offset + last-position keys) matches All 3 PASS on Llama, proving the multi-token forward IS mathematically correct in mlx_lm. The smolvlm-specific divergence (PR jjang-ai#172 Phase B diagnostic) is therefore in mlx_vlm's wrapper layer, not mlx_lm core. **TQ + invariant + edge case unit tests (Layers 2/3/5, mock):** `tests/test_mllm_pld_tq_and_invariants.py` — 11 tests covering: - TQ cache detection short-circuits (no spec steps, no crash) - TQ skip log fires once - scratch_extra_tokens initialized to None - Cooldown decrement invariant + no underflow - Offset monotonicity under sequential verify - Acceptance histogram increments per row per step - Telemetry counters strictly monotonic - K=0 / B=0 / prefill-phase early returns **Live byte-equality on Llama-3.2 (PR jjang-ai#149 simple-engine path):** Discovered that PR jjang-ai#149's PLD path is NOT reachable via `vmlx serve` — both batched and non-batched modes route through different schedulers that bypass `Scheduler._try_speculative_decode`. PR jjang-ai#149's `pld_ssm_replay` counters remain 0 in all tested configurations. PR jjang-ai#149 was validated via mock unit tests; live activation needs a deeper investigation (likely needs `vmlx_engine` API call directly with `AsyncEngineCore(..., scheduler=Scheduler(...))`). **Total: 117 tests passing** (95 mock unit + 3 slow logit-equivalence + 11 new TQ/invariant + 8 new B>1 from PR jjang-ai#171). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Summary
0 < num_accept < Kin PLD partial-reject pathScheduler._replay_ssm_forward()to restore caches to N, replay accepted tokens, advance to N+K', emit K'+1 tokens instead of 1VMLX_DISABLE_PLD_REPLAY=1/healthfieldpld_ssm_replay.{enabled,attempts,emitted,failures}tests/test_pld_ssm_replay.pyProblem
PR #26 PLD on hybrid models (48 GatedDeltaNet + 16 full-attention): with K=2, a partial accept (
num_accept=1) still emits only 1 correction token because SSM state cannot be trimmed — both caches must rewind to N.Solution
After partial rejection, restore to N, replay
drafts[:num_accept]forward through the full model. Both caches reach N+num_accept. Emitdrafts[:num_accept] + [bonus_token]— same as the full-accept path, minus the extra K-K' tokens.Expected gain
+5-10% on top of PR #26's +4-7% on hybrid models. Full PLD target for hybrid moves from +4-7% toward the +15-25% cited in #134.
Test plan
pytest tests/test_pld_ssm_replay.py -v— 6 unit tests passpytest tests/test_ssm_companion_cache.py -v— existing tests unaffectedVMLX_DISABLE_PLD_REPLAY=1vs unset — byte-equal at T=0, higher tok/s unsetFixes #134
🤖 Generated with Claude Code