Skip to content

feat(pld): hybrid partial-accept replay for SSM models (#134)#149

Open
st-adam wants to merge 3 commits into
jjang-ai:mainfrom
st-adam:pld-ssm-replay
Open

feat(pld): hybrid partial-accept replay for SSM models (#134)#149
st-adam wants to merge 3 commits into
jjang-ai:mainfrom
st-adam:pld-ssm-replay

Conversation

@st-adam

@st-adam st-adam commented May 7, 2026

Copy link
Copy Markdown

Summary

  • Fixes the token loss on hybrid SSM/ATT models when 0 < num_accept < K in PLD partial-reject path
  • Adds Scheduler._replay_ssm_forward() to restore caches to N, replay accepted tokens, advance to N+K', emit K'+1 tokens instead of 1
  • Default ON; opt-out: VMLX_DISABLE_PLD_REPLAY=1
  • New /health field pld_ssm_replay.{enabled,attempts,emitted,failures}
  • 6 unit tests in tests/test_pld_ssm_replay.py

Problem

PR #26 PLD on hybrid models (48 GatedDeltaNet + 16 full-attention): with K=2, a partial accept (num_accept=1) still emits only 1 correction token because SSM state cannot be trimmed — both caches must rewind to N.

Solution

After partial rejection, restore to N, replay drafts[:num_accept] forward through the full model. Both caches reach N+num_accept. Emit drafts[:num_accept] + [bonus_token] — same as the full-accept path, minus the extra K-K' tokens.

Expected gain

+5-10% on top of PR #26's +4-7% on hybrid models. Full PLD target for hybrid moves from +4-7% toward the +15-25% cited in #134.

Test plan

  • pytest tests/test_pld_ssm_replay.py -v — 6 unit tests pass
  • pytest tests/test_ssm_companion_cache.py -v — existing tests unaffected
  • Live model: VMLX_DISABLE_PLD_REPLAY=1 vs unset — byte-equal at T=0, higher tok/s unset

Fixes #134

🤖 Generated with Claude Code

@st-adam

st-adam commented May 12, 2026

Copy link
Copy Markdown
Author

Rebased onto current main (1.5.32 series, base commit 9cfbeb24) per @jjang-ai's note on #134 ("PRs in most recent stable versions"). No code changes — clean rebase over the 50 intervening commits.

Verified post-rebase:

  • tests/test_pld_ssm_replay.py — 6/6 passing
  • vmlx_engine.scheduler imports cleanly; Scheduler._replay_ssm_forward present
  • Upstream's memory_limits refactor (v1.5.31) integrated cleanly into our scheduler.py touch points — no shadowing of get_metal_ws_guard_threshold / get_effective_metal_working_set_bytes

PR is CLEAN / MERGEABLE against the 1.5.32 base.

@st-adam

st-adam commented May 20, 2026

Copy link
Copy Markdown
Author

Rebased onto v1.5.44 (upstream/main as of 2026-05-20). Resolved conflicts with hybrid TQ KV path commit (c61b6e4). Tests: 6/6 pass.

@jjang-ai

Copy link
Copy Markdown
Owner

This is the strongest issue-fix candidate in the current PR batch. Thank you for keeping it rebased and tying it to #134.

I am not merging it directly into the active release-hardening tree yet because that tree is dirty with DSV4/cache/tool/settings work, and this touches vmlx_engine/scheduler.py plus /health telemetry. The right next step is an isolated worktree validation:

  • re-check the current scheduler/cache interaction against hybrid SSM/TQ paths;
  • run the PR unit tests plus adjacent scheduler/cache tests;
  • run a live hybrid SSM PLD proof showing partial-accept replay counters and no output regression;
  • only then pull it into the release lane.

Keeping this marked as a candidate fix rather than a feature-only PR.

@jjang-ai

Copy link
Copy Markdown
Owner

I did one more concrete review pass because this is the strongest issue-fix candidate in the open PR batch.

One thing that needs fixing before merge: tests/test_pld_ssm_replay.py copies a local _replay_ssm_forward() implementation instead of importing/exercising Scheduler._replay_ssm_forward. That means the tests can pass while the production scheduler method diverges or is wired incorrectly.

Please retarget those tests so they execute the actual production method, or factor the replay helper into a small importable function that both scheduler and tests call. The current fake-cache approach is fine, but the assertion needs to cover the real code path that will ship.

After that, I still want the isolated validation mentioned above: adjacent scheduler/cache tests plus a live hybrid SSM PLD proof showing replay attempts/emitted counters and no output regression.

@jjang-ai

Copy link
Copy Markdown
Owner

Reviewed for the current release-hardening pass. I am not merging/adapting this into the immediate release branch yet.\n\nThe problem statement is credible: hybrid SSM/attention PLD partial accept can lose accepted draft tokens when cache rollback has to rewind to N. Credit to @st-adam for isolating that design and proposing replay.\n\nReason held for this release:\n- this changes scheduler cache replay behavior in the same hot area as current prefix/paged/L2/hybrid-SSM release risks;\n- the submitted tests include standalone/copied replay logic rather than proving the production Scheduler path end-to-end;\n- no current live proof here shows a hybrid SSM model exercising 0 < num_accept < K replay with healthy follow-up generation and unchanged cache telemetry;\n- /health pld_ssm_replay would need source/API contract coverage before shipping.\n\nCurrent-source guard verification I ran instead: existing PLD/spec guard slice passed (7 passed, 581 deselected), including the PLD non-MLLM short-circuit, current speculative continuous-batching warnings, DSV4 runtime speculative suppression, and prompt-lookup docs matching current scheduler integration.\n\nLeaving open as future feature work; not safe to close as merged/adapted.

st-adam pushed a commit to st-adam/vmlx that referenced this pull request May 22, 2026
Address jjang-ai review feedback on PR jjang-ai#149: tests were exercising a local
copy of _replay_ssm_forward rather than the production code path, meaning
tests could pass while the production method diverged silently.

Fix:
- New vmlx_engine/utils/pld_replay.py: canonical replay_ssm_forward() with
  minimal deps (lazy mlx_lm imports, contextlib fallback for generation_stream)
- Scheduler._replay_ssm_forward becomes a 2-line delegation wrapper
- tests/test_pld_ssm_replay.py imports from vmlx_engine.utils.pld_replay
  directly — tests now exercise the actual production code path

6/6 tests passing. Addresses the merge blocker from the jjang-ai review.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@st-adam

st-adam commented May 22, 2026

Copy link
Copy Markdown
Author

Addressing review feedback — production code path now tested directly.

The test-local copy of _replay_ssm_forward has been removed. The fix:

  • New vmlx_engine/utils/pld_replay.py — canonical replay_ssm_forward() with minimal deps (lazy mlx_lm imports, contextlib.nullcontext fallback if generation_stream unavailable in test env)
  • Scheduler._replay_ssm_forward is now a 2-line delegation wrapper calling replay_ssm_forward from the utils module
  • tests/test_pld_ssm_replay.py imports from vmlx_engine.utils.pld_replay import replay_ssm_forward — tests exercise the real production code path

If the production method diverges, the import itself will surface it (wrong signature → TypeError; wrong module path → ImportError).

Tests: 6/6 pass (same assertions, now over real code).


Re: isolated worktree validation + live PLD proof

The live hybrid SSM proof (showing pld_ssm_replay.attempts/emitted counters) requires a server started with a hybrid SSM model (e.g. Qwen3.6-27B). That cannot be automated here — it needs the target hardware and model weights. tests/benchmark/test_pld_acceptance.py includes the partial_accept_stress task (Task 5) specifically for this; the check is:

python tests/benchmark/test_pld_acceptance.py --port 8080
curl -s http://127.0.0.1:8080/health | python3 -m json.tool | grep -A5 pld_ssm_replay

Expected after Task 5: pld_ssm_replay.attempts > 0, emitted > attempts (each replay emits K'+1 tokens, not 1). Happy to coordinate on hardware access if that helps move the validation forward.

@st-adam

st-adam commented May 22, 2026

Copy link
Copy Markdown
Author

Follow-up: special-token filter added.

Latent issue identified during PR #150 review work: prompt_lookup.find_draft_tokens() and NgramIndex.find_drafts() had no special-token filtering. n-gram lookup over the prompt could propose tokens like pad/image-pad/vision-start as drafts. Verify usually rejects, but truncating at the first excluded ID is cheaper and removes the risk entirely.

New commit on this branch:

  • prompt_lookup._truncate_at_excluded() — helper
  • find_draft_tokens(exclude_token_ids=None) — backward-compatible param
  • NgramIndex.find_drafts(exclude_token_ids=None) — same
  • Scheduler._build_pld_excluded_token_ids() — collects IDs from tokenizer/processor (pad_token_id, image_token_id, image_pad_id, vision_start_token_id, vision_end_token_id, additional_special_tokens_ids)
  • Wired through both PLD call sites (scheduler.py:5193 and :6626)
  • EOS/BOS intentionally NOT excluded — legitimate end-of-decode signals already gated by verify

Tests: 20/20 pass (tests/test_prompt_lookup_filter.py 14 new + test_pld_ssm_replay.py 6 existing).

st-adam pushed a commit to st-adam/vmlx that referenced this pull request May 22, 2026
…jang-ai#134, jjang-ai#135)

Flesh out _step_speculative with full in-batch verify, per-row accept/reject,
and per-row rollback. PLD candidate source is now wired end-to-end:
configure_pld_spec() sets the flag, _pld_drafts_for_request() finds K
candidates per row, _step_speculative builds the batched verify input and
emits accepted drafts via response.extra_tokens.

- Build (B, K+1) verify input from [last_token, d0, ..., d_{K-1}] per row.
- Single batched forward through self.language_model — same call signature
  as _step (verified V0.2/V0.4).
- Greedy per-row accept loop: compare argmax(logits[i,j]) vs drafts[i][j].
- Snapshot per-row SSM state via _snapshot_ssm_per_row BEFORE verify;
  restore rejected rows via _restore_ssm_rows after accept/reject.
- KV cache rollback: per-row offset.tolist() arithmetic, rebuild mx.array.
  Handles both batched (offset is mx.array) and single-request (scalar)
  cases. Syncs _idx for RotatingKVCache compatibility.
- Hybrid partial-accept simplification: when 0 < n_accept < K on hybrid,
  drop accepted drafts and emit correction only (predicted at pos 0).
  Per-row replay forward to recover those drafts is a future enhancement
  (C.5) — would require per-row mini-batch + write-back into BatchMambaCache.
- Set req.scratch_extra_tokens per request for downstream emission.
- Telemetry: _spec_batched_steps, _spec_batched_tokens, EMA acceptance rate.
- Dispatch in _next() now routes PLD-enabled steps to _step_speculative
  with K=2 for hybrid models, K=5 otherwise (mirrors Scheduler choice).
  Draft-model spec (when --speculative-model + VMLX_ENABLE_BATCHED_SPEC=1)
  takes precedence; PLD activates when only --enable-pld is set.
- Response builder: appends primary token first then extras (correct
  chronological order: seed processed → accepted drafts → next-step bonus).
  Populates response.extra_tokens. Truncates on stop token or max_tok cap.

10 unit tests in tests/test_mllm_step_speculative.py covering:
  - Fallback paths (no drafts, prefill state, K=0)
  - Pure-attention: full accept / full reject / partial accept w/ correct
    KV offset rewind
  - Hybrid: full accept (no rollback) / partial accept (drop drafts +
    SSM restore)
  - Telemetry counter increments

69/69 tests pass across all PR jjang-ai#149/jjang-ai#150 suites. No real model required —
tests use a mock language_model that simulates BatchKVCache offset advance.

Remaining work for full PR jjang-ai#150 (next session):
  - Per-row replay forward for hybrid partial-accept (recover dropped drafts)
  - MLLMScheduler wiring to call configure_pld_spec() at startup
  - Live validation on Heretic2 workload
  - /health pld_ssm_replay telemetry passthrough

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
st-adam pushed a commit to st-adam/vmlx that referenced this pull request May 22, 2026
, jjang-ai#135)

Hook up MLLMBatchGenerator's PLD speculative decoding path through the
scheduler and the /health endpoint.

- MLLMScheduler._build_pld_excluded_token_ids(): mirrors Scheduler version
  from PR jjang-ai#149. Collects pad/image-pad/vision-start/vision-end IDs from
  tokenizer + processor. Returns None if no special IDs present.
- MLLMScheduler.__init__: when config.pld_enabled AND batch generator is
  hybrid, calls configure_pld_spec(enabled=True, excluded_token_ids=...).
  Failures logged but don't crash startup (falls back to standard decode).
- Logs "[PLD] MLLM batched PLD enabled — K=2 (hybrid), excluded_token_ids=N"
  at startup so the activation path is visible in server logs.
- server.py /health endpoint: now probes pld_ssm_replay counters from THREE
  possible sources (Scheduler, MLLMScheduler, MLLMBatchGenerator) and
  exposes whichever has them. Also surfaces batched-spec counters
  (_spec_batched_steps, tokens, acceptance EMA) under
  speculative_decoding.batched.

69/69 tests still pass. Live validation on Heretic2 workload (next step):
launch server with --enable-pld, hit /health, expect pld_ssm_replay.enabled
true on Qwen3.6-27B hybrid; verify speculative_decoding.batched.steps > 0
under load.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
st-adam pushed a commit to st-adam/vmlx that referenced this pull request May 22, 2026
Address jjang-ai review feedback on PR jjang-ai#149: tests were exercising a local
copy of _replay_ssm_forward rather than the production code path, meaning
tests could pass while the production method diverged silently.

Fix:
- New vmlx_engine/utils/pld_replay.py: canonical replay_ssm_forward() with
  minimal deps (lazy mlx_lm imports, contextlib fallback for generation_stream)
- Scheduler._replay_ssm_forward becomes a 2-line delegation wrapper
- tests/test_pld_ssm_replay.py imports from vmlx_engine.utils.pld_replay
  directly — tests now exercise the actual production code path

6/6 tests passing. Addresses the merge blocker from the jjang-ai review.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
st-adam pushed a commit to st-adam/vmlx that referenced this pull request May 22, 2026
…jang-ai#134, jjang-ai#135)

Flesh out _step_speculative with full in-batch verify, per-row accept/reject,
and per-row rollback. PLD candidate source is now wired end-to-end:
configure_pld_spec() sets the flag, _pld_drafts_for_request() finds K
candidates per row, _step_speculative builds the batched verify input and
emits accepted drafts via response.extra_tokens.

- Build (B, K+1) verify input from [last_token, d0, ..., d_{K-1}] per row.
- Single batched forward through self.language_model — same call signature
  as _step (verified V0.2/V0.4).
- Greedy per-row accept loop: compare argmax(logits[i,j]) vs drafts[i][j].
- Snapshot per-row SSM state via _snapshot_ssm_per_row BEFORE verify;
  restore rejected rows via _restore_ssm_rows after accept/reject.
- KV cache rollback: per-row offset.tolist() arithmetic, rebuild mx.array.
  Handles both batched (offset is mx.array) and single-request (scalar)
  cases. Syncs _idx for RotatingKVCache compatibility.
- Hybrid partial-accept simplification: when 0 < n_accept < K on hybrid,
  drop accepted drafts and emit correction only (predicted at pos 0).
  Per-row replay forward to recover those drafts is a future enhancement
  (C.5) — would require per-row mini-batch + write-back into BatchMambaCache.
- Set req.scratch_extra_tokens per request for downstream emission.
- Telemetry: _spec_batched_steps, _spec_batched_tokens, EMA acceptance rate.
- Dispatch in _next() now routes PLD-enabled steps to _step_speculative
  with K=2 for hybrid models, K=5 otherwise (mirrors Scheduler choice).
  Draft-model spec (when --speculative-model + VMLX_ENABLE_BATCHED_SPEC=1)
  takes precedence; PLD activates when only --enable-pld is set.
- Response builder: appends primary token first then extras (correct
  chronological order: seed processed → accepted drafts → next-step bonus).
  Populates response.extra_tokens. Truncates on stop token or max_tok cap.

10 unit tests in tests/test_mllm_step_speculative.py covering:
  - Fallback paths (no drafts, prefill state, K=0)
  - Pure-attention: full accept / full reject / partial accept w/ correct
    KV offset rewind
  - Hybrid: full accept (no rollback) / partial accept (drop drafts +
    SSM restore)
  - Telemetry counter increments

69/69 tests pass across all PR jjang-ai#149/jjang-ai#150 suites. No real model required —
tests use a mock language_model that simulates BatchKVCache offset advance.

Remaining work for full PR jjang-ai#150 (next session):
  - Per-row replay forward for hybrid partial-accept (recover dropped drafts)
  - MLLMScheduler wiring to call configure_pld_spec() at startup
  - Live validation on Heretic2 workload
  - /health pld_ssm_replay telemetry passthrough

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
st-adam pushed a commit to st-adam/vmlx that referenced this pull request May 22, 2026
, jjang-ai#135)

Hook up MLLMBatchGenerator's PLD speculative decoding path through the
scheduler and the /health endpoint.

- MLLMScheduler._build_pld_excluded_token_ids(): mirrors Scheduler version
  from PR jjang-ai#149. Collects pad/image-pad/vision-start/vision-end IDs from
  tokenizer + processor. Returns None if no special IDs present.
- MLLMScheduler.__init__: when config.pld_enabled AND batch generator is
  hybrid, calls configure_pld_spec(enabled=True, excluded_token_ids=...).
  Failures logged but don't crash startup (falls back to standard decode).
- Logs "[PLD] MLLM batched PLD enabled — K=2 (hybrid), excluded_token_ids=N"
  at startup so the activation path is visible in server logs.
- server.py /health endpoint: now probes pld_ssm_replay counters from THREE
  possible sources (Scheduler, MLLMScheduler, MLLMBatchGenerator) and
  exposes whichever has them. Also surfaces batched-spec counters
  (_spec_batched_steps, tokens, acceptance EMA) under
  speculative_decoding.batched.

69/69 tests still pass. Live validation on Heretic2 workload (next step):
launch server with --enable-pld, hit /health, expect pld_ssm_replay.enabled
true on Qwen3.6-27B hybrid; verify speculative_decoding.batched.steps > 0
under load.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
st-adam pushed a commit to st-adam/vmlx that referenced this pull request May 23, 2026
…or truncate

Two changes from the live byte-equality validation cycle on smolvlm
(pure-attention MLLM, mlx-community/SmolVLM-Instruct-bf16):

1. **MLLM PLD default-off** (correctness gate). Live byte-equality test
   showed PLD on pure-attention MLLM ALSO produces output drift at T=0
   ("TheTheThe number number is is"), not just hybrid. The PR jjang-ai#150
   multi-token verify forward produces logits that don't byte-match
   sequential single-token forwards for the SAME input position. Affects
   both hybrid and pure-attention MLLM paths.

   New gate: `VMLX_ENABLE_MLLM_PLD=1` (experimental opt-in for either
   architecture). Old `VMLX_ENABLE_MLLM_PLD_HYBRID=1` kept for backward
   compat. Without the flag, --enable-pld is silently ignored on MLLM
   path (logs the reason).

   Simple-engine PLD (PR jjang-ai#149's Scheduler path) is unaffected — PR jjang-ai#26's
   +4-7% baseline holds for the non-MLLM hybrid case.

2. **Per-n_accept histogram** in /health. Diagnostic telemetry:
   `speculative_decoding.batched.accept_histogram[n]` = count of rounds
   where exactly n drafts accepted. Lets ops debug workload-specific
   acceptance patterns (e.g. code shows partial accepts; JSON shows full
   accepts).

3. **B=1 tensor truncate on pure-attention KV rewind**. When verify
   advances cache to N+K+1 but rollback rewinds offset to N+1+n, the
   keys/values tensor still has stale verify content at the trailing
   positions. Some MLX attention paths may read the full tensor (not
   just :offset). For B=1 we truncate the tensor; for B>1 we preserve
   the multi-row structure (best-effort, attention mask must respect
   per-row offset).

Tests: 95/95 pass.

Live validation (mlx-community/SmolVLM-Instruct-bf16, max_num_seqs=1):
- PLD off vs PLD on at T=0: 1/4 prompts byte-equal (the short one);
  3/4 diverge. Confirms PLD on MLLM has unresolved correctness issue.
- With this commit, default behavior is PLD-off on MLLM (byte-equal
  to PLD-off baseline by construction).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
st-adam pushed a commit to st-adam/vmlx that referenced this pull request May 23, 2026
…ests

PR jjang-ai#172 Phase C + test layers 1/2/3/5.

**TurboQuant cache short-circuit** (Phase C, vmlx_engine/mllm_batch_generator.py):
PLD's per-row writeback doesn't preserve `left_padding` on TurboQuant KV
caches, causing `c.extract(idx)` to crash on request completion (observed
on smolvlm with `--kv-cache-quantization q4`). Detect TQ via
`type(c).__name__ == "TurboQuantKVCache"` at the entry to
`_step_speculative` and fall back to standard `_step`. Logs once at
INFO. PLD + TQ + MLLM compatibility is deferred to a follow-up;
short-term workaround: use `--kv-cache-quantization none/q8` for PLD.

**Logit-equivalence tests (Layer 1, slow, real-model):**
`tests/test_mllm_pld_logits_equivalence.py` — 3 tests on
`mlx-community/Llama-3.2-1B-Instruct-4bit` (small pure-attention LLM,
verified deterministic at T=0):
  - argmax match between multi-token and sequential forward
  - logits max-diff within 1e-2 FP tolerance
  - cache state (offset + last-position keys) matches

All 3 PASS on Llama, proving the multi-token forward IS mathematically
correct in mlx_lm. The smolvlm-specific divergence (PR jjang-ai#172 Phase B
diagnostic) is therefore in mlx_vlm's wrapper layer, not mlx_lm core.

**TQ + invariant + edge case unit tests (Layers 2/3/5, mock):**
`tests/test_mllm_pld_tq_and_invariants.py` — 11 tests covering:
  - TQ cache detection short-circuits (no spec steps, no crash)
  - TQ skip log fires once
  - scratch_extra_tokens initialized to None
  - Cooldown decrement invariant + no underflow
  - Offset monotonicity under sequential verify
  - Acceptance histogram increments per row per step
  - Telemetry counters strictly monotonic
  - K=0 / B=0 / prefill-phase early returns

**Live byte-equality on Llama-3.2 (PR jjang-ai#149 simple-engine path):**
Discovered that PR jjang-ai#149's PLD path is NOT reachable via `vmlx serve` —
both batched and non-batched modes route through different schedulers
that bypass `Scheduler._try_speculative_decode`. PR jjang-ai#149's `pld_ssm_replay`
counters remain 0 in all tested configurations. PR jjang-ai#149 was validated
via mock unit tests; live activation needs a deeper investigation (likely
needs `vmlx_engine` API call directly with `AsyncEngineCore(..., scheduler=Scheduler(...))`).

**Total: 117 tests passing** (95 mock unit + 3 slow logit-equivalence
+ 11 new TQ/invariant + 8 new B>1 from PR jjang-ai#171).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
st-adam pushed a commit to st-adam/vmlx that referenced this pull request May 24, 2026
Address jjang-ai review feedback on PR jjang-ai#149: tests were exercising a local
copy of _replay_ssm_forward rather than the production code path, meaning
tests could pass while the production method diverged silently.

Fix:
- New vmlx_engine/utils/pld_replay.py: canonical replay_ssm_forward() with
  minimal deps (lazy mlx_lm imports, contextlib fallback for generation_stream)
- Scheduler._replay_ssm_forward becomes a 2-line delegation wrapper
- tests/test_pld_ssm_replay.py imports from vmlx_engine.utils.pld_replay
  directly — tests now exercise the actual production code path

6/6 tests passing. Addresses the merge blocker from the jjang-ai review.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
st-adam pushed a commit to st-adam/vmlx that referenced this pull request May 24, 2026
…jang-ai#134, jjang-ai#135)

Flesh out _step_speculative with full in-batch verify, per-row accept/reject,
and per-row rollback. PLD candidate source is now wired end-to-end:
configure_pld_spec() sets the flag, _pld_drafts_for_request() finds K
candidates per row, _step_speculative builds the batched verify input and
emits accepted drafts via response.extra_tokens.

- Build (B, K+1) verify input from [last_token, d0, ..., d_{K-1}] per row.
- Single batched forward through self.language_model — same call signature
  as _step (verified V0.2/V0.4).
- Greedy per-row accept loop: compare argmax(logits[i,j]) vs drafts[i][j].
- Snapshot per-row SSM state via _snapshot_ssm_per_row BEFORE verify;
  restore rejected rows via _restore_ssm_rows after accept/reject.
- KV cache rollback: per-row offset.tolist() arithmetic, rebuild mx.array.
  Handles both batched (offset is mx.array) and single-request (scalar)
  cases. Syncs _idx for RotatingKVCache compatibility.
- Hybrid partial-accept simplification: when 0 < n_accept < K on hybrid,
  drop accepted drafts and emit correction only (predicted at pos 0).
  Per-row replay forward to recover those drafts is a future enhancement
  (C.5) — would require per-row mini-batch + write-back into BatchMambaCache.
- Set req.scratch_extra_tokens per request for downstream emission.
- Telemetry: _spec_batched_steps, _spec_batched_tokens, EMA acceptance rate.
- Dispatch in _next() now routes PLD-enabled steps to _step_speculative
  with K=2 for hybrid models, K=5 otherwise (mirrors Scheduler choice).
  Draft-model spec (when --speculative-model + VMLX_ENABLE_BATCHED_SPEC=1)
  takes precedence; PLD activates when only --enable-pld is set.
- Response builder: appends primary token first then extras (correct
  chronological order: seed processed → accepted drafts → next-step bonus).
  Populates response.extra_tokens. Truncates on stop token or max_tok cap.

10 unit tests in tests/test_mllm_step_speculative.py covering:
  - Fallback paths (no drafts, prefill state, K=0)
  - Pure-attention: full accept / full reject / partial accept w/ correct
    KV offset rewind
  - Hybrid: full accept (no rollback) / partial accept (drop drafts +
    SSM restore)
  - Telemetry counter increments

69/69 tests pass across all PR jjang-ai#149/jjang-ai#150 suites. No real model required —
tests use a mock language_model that simulates BatchKVCache offset advance.

Remaining work for full PR jjang-ai#150 (next session):
  - Per-row replay forward for hybrid partial-accept (recover dropped drafts)
  - MLLMScheduler wiring to call configure_pld_spec() at startup
  - Live validation on Heretic2 workload
  - /health pld_ssm_replay telemetry passthrough

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
st-adam pushed a commit to st-adam/vmlx that referenced this pull request May 24, 2026
, jjang-ai#135)

Hook up MLLMBatchGenerator's PLD speculative decoding path through the
scheduler and the /health endpoint.

- MLLMScheduler._build_pld_excluded_token_ids(): mirrors Scheduler version
  from PR jjang-ai#149. Collects pad/image-pad/vision-start/vision-end IDs from
  tokenizer + processor. Returns None if no special IDs present.
- MLLMScheduler.__init__: when config.pld_enabled AND batch generator is
  hybrid, calls configure_pld_spec(enabled=True, excluded_token_ids=...).
  Failures logged but don't crash startup (falls back to standard decode).
- Logs "[PLD] MLLM batched PLD enabled — K=2 (hybrid), excluded_token_ids=N"
  at startup so the activation path is visible in server logs.
- server.py /health endpoint: now probes pld_ssm_replay counters from THREE
  possible sources (Scheduler, MLLMScheduler, MLLMBatchGenerator) and
  exposes whichever has them. Also surfaces batched-spec counters
  (_spec_batched_steps, tokens, acceptance EMA) under
  speculative_decoding.batched.

69/69 tests still pass. Live validation on Heretic2 workload (next step):
launch server with --enable-pld, hit /health, expect pld_ssm_replay.enabled
true on Qwen3.6-27B hybrid; verify speculative_decoding.batched.steps > 0
under load.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
st-adam pushed a commit to st-adam/vmlx that referenced this pull request May 24, 2026
…jang-ai#153)

Documents the batched-MLLM PLD path added by PRs jjang-ai#149/jjang-ai#150:

- Dispatch path: when each of {draft-spec, PLD, standard _step} activates
- In-batch verify mechanics (B, K+1) input shape
- Per-row rollback variants (full accept / pure-attention reject / hybrid)
- Hybrid Mamba limitation: parallel-scan kernel vs recurrent-replay state
  divergence; default-disabled with VMLX_ENABLE_MLLM_PLD_HYBRID=1 opt-in
- Deferred attention-only verify proposal (issue jjang-ai#134 original design)
  with cost/benefit analysis and decision criteria for future implementers
- Full env-var reference table (VMLX_DISABLE_PLD_REPLAY,
  VMLX_ENABLE_MLLM_PLD_HYBRID, VMLX_PLD_MIN_ACCEPTANCE,
  VMLX_PLD_WARMUP_STEPS, VMLX_PLD_PROBE_INTERVAL, VMLX_PLD_DEBUG)
- Telemetry format on /health (pld_ssm_replay + speculative_decoding.batched)
- 3-layer validation strategy: unit (~95 tests), live byte-equality script,
  live benchmark

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
st-adam pushed a commit to st-adam/vmlx that referenced this pull request May 24, 2026
…or truncate

Two changes from the live byte-equality validation cycle on smolvlm
(pure-attention MLLM, mlx-community/SmolVLM-Instruct-bf16):

1. **MLLM PLD default-off** (correctness gate). Live byte-equality test
   showed PLD on pure-attention MLLM ALSO produces output drift at T=0
   ("TheTheThe number number is is"), not just hybrid. The PR jjang-ai#150
   multi-token verify forward produces logits that don't byte-match
   sequential single-token forwards for the SAME input position. Affects
   both hybrid and pure-attention MLLM paths.

   New gate: `VMLX_ENABLE_MLLM_PLD=1` (experimental opt-in for either
   architecture). Old `VMLX_ENABLE_MLLM_PLD_HYBRID=1` kept for backward
   compat. Without the flag, --enable-pld is silently ignored on MLLM
   path (logs the reason).

   Simple-engine PLD (PR jjang-ai#149's Scheduler path) is unaffected — PR jjang-ai#26's
   +4-7% baseline holds for the non-MLLM hybrid case.

2. **Per-n_accept histogram** in /health. Diagnostic telemetry:
   `speculative_decoding.batched.accept_histogram[n]` = count of rounds
   where exactly n drafts accepted. Lets ops debug workload-specific
   acceptance patterns (e.g. code shows partial accepts; JSON shows full
   accepts).

3. **B=1 tensor truncate on pure-attention KV rewind**. When verify
   advances cache to N+K+1 but rollback rewinds offset to N+1+n, the
   keys/values tensor still has stale verify content at the trailing
   positions. Some MLX attention paths may read the full tensor (not
   just :offset). For B=1 we truncate the tensor; for B>1 we preserve
   the multi-row structure (best-effort, attention mask must respect
   per-row offset).

Tests: 95/95 pass.

Live validation (mlx-community/SmolVLM-Instruct-bf16, max_num_seqs=1):
- PLD off vs PLD on at T=0: 1/4 prompts byte-equal (the short one);
  3/4 diverge. Confirms PLD on MLLM has unresolved correctness issue.
- With this commit, default behavior is PLD-off on MLLM (byte-equal
  to PLD-off baseline by construction).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
st-adam pushed a commit to st-adam/vmlx that referenced this pull request May 24, 2026
…ests

PR jjang-ai#172 Phase C + test layers 1/2/3/5.

**TurboQuant cache short-circuit** (Phase C, vmlx_engine/mllm_batch_generator.py):
PLD's per-row writeback doesn't preserve `left_padding` on TurboQuant KV
caches, causing `c.extract(idx)` to crash on request completion (observed
on smolvlm with `--kv-cache-quantization q4`). Detect TQ via
`type(c).__name__ == "TurboQuantKVCache"` at the entry to
`_step_speculative` and fall back to standard `_step`. Logs once at
INFO. PLD + TQ + MLLM compatibility is deferred to a follow-up;
short-term workaround: use `--kv-cache-quantization none/q8` for PLD.

**Logit-equivalence tests (Layer 1, slow, real-model):**
`tests/test_mllm_pld_logits_equivalence.py` — 3 tests on
`mlx-community/Llama-3.2-1B-Instruct-4bit` (small pure-attention LLM,
verified deterministic at T=0):
  - argmax match between multi-token and sequential forward
  - logits max-diff within 1e-2 FP tolerance
  - cache state (offset + last-position keys) matches

All 3 PASS on Llama, proving the multi-token forward IS mathematically
correct in mlx_lm. The smolvlm-specific divergence (PR jjang-ai#172 Phase B
diagnostic) is therefore in mlx_vlm's wrapper layer, not mlx_lm core.

**TQ + invariant + edge case unit tests (Layers 2/3/5, mock):**
`tests/test_mllm_pld_tq_and_invariants.py` — 11 tests covering:
  - TQ cache detection short-circuits (no spec steps, no crash)
  - TQ skip log fires once
  - scratch_extra_tokens initialized to None
  - Cooldown decrement invariant + no underflow
  - Offset monotonicity under sequential verify
  - Acceptance histogram increments per row per step
  - Telemetry counters strictly monotonic
  - K=0 / B=0 / prefill-phase early returns

**Live byte-equality on Llama-3.2 (PR jjang-ai#149 simple-engine path):**
Discovered that PR jjang-ai#149's PLD path is NOT reachable via `vmlx serve` —
both batched and non-batched modes route through different schedulers
that bypass `Scheduler._try_speculative_decode`. PR jjang-ai#149's `pld_ssm_replay`
counters remain 0 in all tested configurations. PR jjang-ai#149 was validated
via mock unit tests; live activation needs a deeper investigation (likely
needs `vmlx_engine` API call directly with `AsyncEngineCore(..., scheduler=Scheduler(...))`).

**Total: 117 tests passing** (95 mock unit + 3 slow logit-equivalence
+ 11 new TQ/invariant + 8 new B>1 from PR jjang-ai#171).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
st-adam pushed a commit to st-adam/vmlx that referenced this pull request May 24, 2026
Address jjang-ai review feedback on PR jjang-ai#149: tests were exercising a local
copy of _replay_ssm_forward rather than the production code path, meaning
tests could pass while the production method diverged silently.

Fix:
- New vmlx_engine/utils/pld_replay.py: canonical replay_ssm_forward() with
  minimal deps (lazy mlx_lm imports, contextlib fallback for generation_stream)
- Scheduler._replay_ssm_forward becomes a 2-line delegation wrapper
- tests/test_pld_ssm_replay.py imports from vmlx_engine.utils.pld_replay
  directly — tests now exercise the actual production code path

6/6 tests passing. Addresses the merge blocker from the jjang-ai review.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
st-adam pushed a commit to st-adam/vmlx that referenced this pull request May 24, 2026
…jang-ai#134, jjang-ai#135)

Flesh out _step_speculative with full in-batch verify, per-row accept/reject,
and per-row rollback. PLD candidate source is now wired end-to-end:
configure_pld_spec() sets the flag, _pld_drafts_for_request() finds K
candidates per row, _step_speculative builds the batched verify input and
emits accepted drafts via response.extra_tokens.

- Build (B, K+1) verify input from [last_token, d0, ..., d_{K-1}] per row.
- Single batched forward through self.language_model — same call signature
  as _step (verified V0.2/V0.4).
- Greedy per-row accept loop: compare argmax(logits[i,j]) vs drafts[i][j].
- Snapshot per-row SSM state via _snapshot_ssm_per_row BEFORE verify;
  restore rejected rows via _restore_ssm_rows after accept/reject.
- KV cache rollback: per-row offset.tolist() arithmetic, rebuild mx.array.
  Handles both batched (offset is mx.array) and single-request (scalar)
  cases. Syncs _idx for RotatingKVCache compatibility.
- Hybrid partial-accept simplification: when 0 < n_accept < K on hybrid,
  drop accepted drafts and emit correction only (predicted at pos 0).
  Per-row replay forward to recover those drafts is a future enhancement
  (C.5) — would require per-row mini-batch + write-back into BatchMambaCache.
- Set req.scratch_extra_tokens per request for downstream emission.
- Telemetry: _spec_batched_steps, _spec_batched_tokens, EMA acceptance rate.
- Dispatch in _next() now routes PLD-enabled steps to _step_speculative
  with K=2 for hybrid models, K=5 otherwise (mirrors Scheduler choice).
  Draft-model spec (when --speculative-model + VMLX_ENABLE_BATCHED_SPEC=1)
  takes precedence; PLD activates when only --enable-pld is set.
- Response builder: appends primary token first then extras (correct
  chronological order: seed processed → accepted drafts → next-step bonus).
  Populates response.extra_tokens. Truncates on stop token or max_tok cap.

10 unit tests in tests/test_mllm_step_speculative.py covering:
  - Fallback paths (no drafts, prefill state, K=0)
  - Pure-attention: full accept / full reject / partial accept w/ correct
    KV offset rewind
  - Hybrid: full accept (no rollback) / partial accept (drop drafts +
    SSM restore)
  - Telemetry counter increments

69/69 tests pass across all PR jjang-ai#149/jjang-ai#150 suites. No real model required —
tests use a mock language_model that simulates BatchKVCache offset advance.

Remaining work for full PR jjang-ai#150 (next session):
  - Per-row replay forward for hybrid partial-accept (recover dropped drafts)
  - MLLMScheduler wiring to call configure_pld_spec() at startup
  - Live validation on Heretic2 workload
  - /health pld_ssm_replay telemetry passthrough

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
st-adam pushed a commit to st-adam/vmlx that referenced this pull request May 24, 2026
, jjang-ai#135)

Hook up MLLMBatchGenerator's PLD speculative decoding path through the
scheduler and the /health endpoint.

- MLLMScheduler._build_pld_excluded_token_ids(): mirrors Scheduler version
  from PR jjang-ai#149. Collects pad/image-pad/vision-start/vision-end IDs from
  tokenizer + processor. Returns None if no special IDs present.
- MLLMScheduler.__init__: when config.pld_enabled AND batch generator is
  hybrid, calls configure_pld_spec(enabled=True, excluded_token_ids=...).
  Failures logged but don't crash startup (falls back to standard decode).
- Logs "[PLD] MLLM batched PLD enabled — K=2 (hybrid), excluded_token_ids=N"
  at startup so the activation path is visible in server logs.
- server.py /health endpoint: now probes pld_ssm_replay counters from THREE
  possible sources (Scheduler, MLLMScheduler, MLLMBatchGenerator) and
  exposes whichever has them. Also surfaces batched-spec counters
  (_spec_batched_steps, tokens, acceptance EMA) under
  speculative_decoding.batched.

69/69 tests still pass. Live validation on Heretic2 workload (next step):
launch server with --enable-pld, hit /health, expect pld_ssm_replay.enabled
true on Qwen3.6-27B hybrid; verify speculative_decoding.batched.steps > 0
under load.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
st-adam pushed a commit to st-adam/vmlx that referenced this pull request May 24, 2026
Address jjang-ai review feedback on PR jjang-ai#149: tests were exercising a local
copy of _replay_ssm_forward rather than the production code path, meaning
tests could pass while the production method diverged silently.

Fix:
- New vmlx_engine/utils/pld_replay.py: canonical replay_ssm_forward() with
  minimal deps (lazy mlx_lm imports, contextlib fallback for generation_stream)
- Scheduler._replay_ssm_forward becomes a 2-line delegation wrapper
- tests/test_pld_ssm_replay.py imports from vmlx_engine.utils.pld_replay
  directly — tests now exercise the actual production code path

6/6 tests passing. Addresses the merge blocker from the jjang-ai review.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
st-adam pushed a commit to st-adam/vmlx that referenced this pull request May 24, 2026
Address jjang-ai review feedback on PR jjang-ai#149: tests were exercising a local
copy of _replay_ssm_forward rather than the production code path, meaning
tests could pass while the production method diverged silently.

Fix:
- New vmlx_engine/utils/pld_replay.py: canonical replay_ssm_forward() with
  minimal deps (lazy mlx_lm imports, contextlib fallback for generation_stream)
- Scheduler._replay_ssm_forward becomes a 2-line delegation wrapper
- tests/test_pld_ssm_replay.py imports from vmlx_engine.utils.pld_replay
  directly — tests now exercise the actual production code path

6/6 tests passing. Addresses the merge blocker from the jjang-ai review.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
st-adam pushed a commit to st-adam/vmlx that referenced this pull request May 24, 2026
…jang-ai#134, jjang-ai#135)

Flesh out _step_speculative with full in-batch verify, per-row accept/reject,
and per-row rollback. PLD candidate source is now wired end-to-end:
configure_pld_spec() sets the flag, _pld_drafts_for_request() finds K
candidates per row, _step_speculative builds the batched verify input and
emits accepted drafts via response.extra_tokens.

- Build (B, K+1) verify input from [last_token, d0, ..., d_{K-1}] per row.
- Single batched forward through self.language_model — same call signature
  as _step (verified V0.2/V0.4).
- Greedy per-row accept loop: compare argmax(logits[i,j]) vs drafts[i][j].
- Snapshot per-row SSM state via _snapshot_ssm_per_row BEFORE verify;
  restore rejected rows via _restore_ssm_rows after accept/reject.
- KV cache rollback: per-row offset.tolist() arithmetic, rebuild mx.array.
  Handles both batched (offset is mx.array) and single-request (scalar)
  cases. Syncs _idx for RotatingKVCache compatibility.
- Hybrid partial-accept simplification: when 0 < n_accept < K on hybrid,
  drop accepted drafts and emit correction only (predicted at pos 0).
  Per-row replay forward to recover those drafts is a future enhancement
  (C.5) — would require per-row mini-batch + write-back into BatchMambaCache.
- Set req.scratch_extra_tokens per request for downstream emission.
- Telemetry: _spec_batched_steps, _spec_batched_tokens, EMA acceptance rate.
- Dispatch in _next() now routes PLD-enabled steps to _step_speculative
  with K=2 for hybrid models, K=5 otherwise (mirrors Scheduler choice).
  Draft-model spec (when --speculative-model + VMLX_ENABLE_BATCHED_SPEC=1)
  takes precedence; PLD activates when only --enable-pld is set.
- Response builder: appends primary token first then extras (correct
  chronological order: seed processed → accepted drafts → next-step bonus).
  Populates response.extra_tokens. Truncates on stop token or max_tok cap.

10 unit tests in tests/test_mllm_step_speculative.py covering:
  - Fallback paths (no drafts, prefill state, K=0)
  - Pure-attention: full accept / full reject / partial accept w/ correct
    KV offset rewind
  - Hybrid: full accept (no rollback) / partial accept (drop drafts +
    SSM restore)
  - Telemetry counter increments

69/69 tests pass across all PR jjang-ai#149/jjang-ai#150 suites. No real model required —
tests use a mock language_model that simulates BatchKVCache offset advance.

Remaining work for full PR jjang-ai#150 (next session):
  - Per-row replay forward for hybrid partial-accept (recover dropped drafts)
  - MLLMScheduler wiring to call configure_pld_spec() at startup
  - Live validation on Heretic2 workload
  - /health pld_ssm_replay telemetry passthrough

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
st-adam pushed a commit to st-adam/vmlx that referenced this pull request May 24, 2026
, jjang-ai#135)

Hook up MLLMBatchGenerator's PLD speculative decoding path through the
scheduler and the /health endpoint.

- MLLMScheduler._build_pld_excluded_token_ids(): mirrors Scheduler version
  from PR jjang-ai#149. Collects pad/image-pad/vision-start/vision-end IDs from
  tokenizer + processor. Returns None if no special IDs present.
- MLLMScheduler.__init__: when config.pld_enabled AND batch generator is
  hybrid, calls configure_pld_spec(enabled=True, excluded_token_ids=...).
  Failures logged but don't crash startup (falls back to standard decode).
- Logs "[PLD] MLLM batched PLD enabled — K=2 (hybrid), excluded_token_ids=N"
  at startup so the activation path is visible in server logs.
- server.py /health endpoint: now probes pld_ssm_replay counters from THREE
  possible sources (Scheduler, MLLMScheduler, MLLMBatchGenerator) and
  exposes whichever has them. Also surfaces batched-spec counters
  (_spec_batched_steps, tokens, acceptance EMA) under
  speculative_decoding.batched.

69/69 tests still pass. Live validation on Heretic2 workload (next step):
launch server with --enable-pld, hit /health, expect pld_ssm_replay.enabled
true on Qwen3.6-27B hybrid; verify speculative_decoding.batched.steps > 0
under load.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
st-adam pushed a commit to st-adam/vmlx that referenced this pull request May 24, 2026
…jang-ai#153)

Documents the batched-MLLM PLD path added by PRs jjang-ai#149/jjang-ai#150:

- Dispatch path: when each of {draft-spec, PLD, standard _step} activates
- In-batch verify mechanics (B, K+1) input shape
- Per-row rollback variants (full accept / pure-attention reject / hybrid)
- Hybrid Mamba limitation: parallel-scan kernel vs recurrent-replay state
  divergence; default-disabled with VMLX_ENABLE_MLLM_PLD_HYBRID=1 opt-in
- Deferred attention-only verify proposal (issue jjang-ai#134 original design)
  with cost/benefit analysis and decision criteria for future implementers
- Full env-var reference table (VMLX_DISABLE_PLD_REPLAY,
  VMLX_ENABLE_MLLM_PLD_HYBRID, VMLX_PLD_MIN_ACCEPTANCE,
  VMLX_PLD_WARMUP_STEPS, VMLX_PLD_PROBE_INTERVAL, VMLX_PLD_DEBUG)
- Telemetry format on /health (pld_ssm_replay + speculative_decoding.batched)
- 3-layer validation strategy: unit (~95 tests), live byte-equality script,
  live benchmark

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
st-adam pushed a commit to st-adam/vmlx that referenced this pull request May 24, 2026
…or truncate

Two changes from the live byte-equality validation cycle on smolvlm
(pure-attention MLLM, mlx-community/SmolVLM-Instruct-bf16):

1. **MLLM PLD default-off** (correctness gate). Live byte-equality test
   showed PLD on pure-attention MLLM ALSO produces output drift at T=0
   ("TheTheThe number number is is"), not just hybrid. The PR jjang-ai#150
   multi-token verify forward produces logits that don't byte-match
   sequential single-token forwards for the SAME input position. Affects
   both hybrid and pure-attention MLLM paths.

   New gate: `VMLX_ENABLE_MLLM_PLD=1` (experimental opt-in for either
   architecture). Old `VMLX_ENABLE_MLLM_PLD_HYBRID=1` kept for backward
   compat. Without the flag, --enable-pld is silently ignored on MLLM
   path (logs the reason).

   Simple-engine PLD (PR jjang-ai#149's Scheduler path) is unaffected — PR jjang-ai#26's
   +4-7% baseline holds for the non-MLLM hybrid case.

2. **Per-n_accept histogram** in /health. Diagnostic telemetry:
   `speculative_decoding.batched.accept_histogram[n]` = count of rounds
   where exactly n drafts accepted. Lets ops debug workload-specific
   acceptance patterns (e.g. code shows partial accepts; JSON shows full
   accepts).

3. **B=1 tensor truncate on pure-attention KV rewind**. When verify
   advances cache to N+K+1 but rollback rewinds offset to N+1+n, the
   keys/values tensor still has stale verify content at the trailing
   positions. Some MLX attention paths may read the full tensor (not
   just :offset). For B=1 we truncate the tensor; for B>1 we preserve
   the multi-row structure (best-effort, attention mask must respect
   per-row offset).

Tests: 95/95 pass.

Live validation (mlx-community/SmolVLM-Instruct-bf16, max_num_seqs=1):
- PLD off vs PLD on at T=0: 1/4 prompts byte-equal (the short one);
  3/4 diverge. Confirms PLD on MLLM has unresolved correctness issue.
- With this commit, default behavior is PLD-off on MLLM (byte-equal
  to PLD-off baseline by construction).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Adam Staniszewski and others added 3 commits May 24, 2026 22:20
…i#134)

On hybrid SSM/ATT models with 0 < num_accept < K, the PLD path previously
discarded accepted drafts and emitted only a correction token. This PR adds
_replay_ssm_forward() which restores caches to N, replays the accepted tokens
through the model, and emits num_accept+1 tokens instead of 1.

- New Scheduler._replay_ssm_forward() staticmethod (scheduler.py)
- Modify case (b): try replay first, fall back to correction-only on failure
- Add _pld_replay_{enabled,attempts,emitted,failures} counters
- Add pld_ssm_replay telemetry to /health endpoint (server.py)
- Document the fix in notes/prompt-lookup-decoding.md
- 6 unit tests in tests/test_pld_ssm_replay.py
- New partial_accept_stress benchmark in tests/benchmark/test_pld_acceptance.py

Expected gain: +5-10% on top of PR jjang-ai#26's +4-7% on hybrid models.
Disable: VMLX_DISABLE_PLD_REPLAY=1

Closes jjang-ai#134

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Address jjang-ai review feedback on PR jjang-ai#149: tests were exercising a local
copy of _replay_ssm_forward rather than the production code path, meaning
tests could pass while the production method diverged silently.

Fix:
- New vmlx_engine/utils/pld_replay.py: canonical replay_ssm_forward() with
  minimal deps (lazy mlx_lm imports, contextlib fallback for generation_stream)
- Scheduler._replay_ssm_forward becomes a 2-line delegation wrapper
- tests/test_pld_ssm_replay.py imports from vmlx_engine.utils.pld_replay
  directly — tests now exercise the actual production code path

6/6 tests passing. Addresses the merge blocker from the jjang-ai review.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…ai#134)

PLD n-gram lookup over the prompt could propose model-special tokens
(image-pad, vision-start, pad) as drafts. The verify forward usually rejects
them, but truncating drafts at the first excluded token avoids the risk
entirely and saves a wasted verify pass.

- prompt_lookup.find_draft_tokens(exclude_token_ids=None) added
- NgramIndex.find_drafts(exclude_token_ids=None) added
- New _truncate_at_excluded() helper
- Scheduler.__init__ builds _pld_excluded_token_ids set from
  tokenizer/processor special-token attributes (pad, image-pad, vision
  markers, additional_special_tokens). EOS/BOS NOT excluded - those are
  legitimate end-of-decode signals already gated by verify.
- Wired through both PLD call sites (lines 5193, 6626)
- 14 unit tests in tests/test_prompt_lookup_filter.py

Default behaviour unchanged when exclude_token_ids is None (back-compat).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
st-adam pushed a commit to st-adam/vmlx that referenced this pull request May 24, 2026
…jang-ai#134, jjang-ai#135)

Flesh out _step_speculative with full in-batch verify, per-row accept/reject,
and per-row rollback. PLD candidate source is now wired end-to-end:
configure_pld_spec() sets the flag, _pld_drafts_for_request() finds K
candidates per row, _step_speculative builds the batched verify input and
emits accepted drafts via response.extra_tokens.

- Build (B, K+1) verify input from [last_token, d0, ..., d_{K-1}] per row.
- Single batched forward through self.language_model — same call signature
  as _step (verified V0.2/V0.4).
- Greedy per-row accept loop: compare argmax(logits[i,j]) vs drafts[i][j].
- Snapshot per-row SSM state via _snapshot_ssm_per_row BEFORE verify;
  restore rejected rows via _restore_ssm_rows after accept/reject.
- KV cache rollback: per-row offset.tolist() arithmetic, rebuild mx.array.
  Handles both batched (offset is mx.array) and single-request (scalar)
  cases. Syncs _idx for RotatingKVCache compatibility.
- Hybrid partial-accept simplification: when 0 < n_accept < K on hybrid,
  drop accepted drafts and emit correction only (predicted at pos 0).
  Per-row replay forward to recover those drafts is a future enhancement
  (C.5) — would require per-row mini-batch + write-back into BatchMambaCache.
- Set req.scratch_extra_tokens per request for downstream emission.
- Telemetry: _spec_batched_steps, _spec_batched_tokens, EMA acceptance rate.
- Dispatch in _next() now routes PLD-enabled steps to _step_speculative
  with K=2 for hybrid models, K=5 otherwise (mirrors Scheduler choice).
  Draft-model spec (when --speculative-model + VMLX_ENABLE_BATCHED_SPEC=1)
  takes precedence; PLD activates when only --enable-pld is set.
- Response builder: appends primary token first then extras (correct
  chronological order: seed processed → accepted drafts → next-step bonus).
  Populates response.extra_tokens. Truncates on stop token or max_tok cap.

10 unit tests in tests/test_mllm_step_speculative.py covering:
  - Fallback paths (no drafts, prefill state, K=0)
  - Pure-attention: full accept / full reject / partial accept w/ correct
    KV offset rewind
  - Hybrid: full accept (no rollback) / partial accept (drop drafts +
    SSM restore)
  - Telemetry counter increments

69/69 tests pass across all PR jjang-ai#149/jjang-ai#150 suites. No real model required —
tests use a mock language_model that simulates BatchKVCache offset advance.

Remaining work for full PR jjang-ai#150 (next session):
  - Per-row replay forward for hybrid partial-accept (recover dropped drafts)
  - MLLMScheduler wiring to call configure_pld_spec() at startup
  - Live validation on Heretic2 workload
  - /health pld_ssm_replay telemetry passthrough

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
st-adam pushed a commit to st-adam/vmlx that referenced this pull request May 24, 2026
, jjang-ai#135)

Hook up MLLMBatchGenerator's PLD speculative decoding path through the
scheduler and the /health endpoint.

- MLLMScheduler._build_pld_excluded_token_ids(): mirrors Scheduler version
  from PR jjang-ai#149. Collects pad/image-pad/vision-start/vision-end IDs from
  tokenizer + processor. Returns None if no special IDs present.
- MLLMScheduler.__init__: when config.pld_enabled AND batch generator is
  hybrid, calls configure_pld_spec(enabled=True, excluded_token_ids=...).
  Failures logged but don't crash startup (falls back to standard decode).
- Logs "[PLD] MLLM batched PLD enabled — K=2 (hybrid), excluded_token_ids=N"
  at startup so the activation path is visible in server logs.
- server.py /health endpoint: now probes pld_ssm_replay counters from THREE
  possible sources (Scheduler, MLLMScheduler, MLLMBatchGenerator) and
  exposes whichever has them. Also surfaces batched-spec counters
  (_spec_batched_steps, tokens, acceptance EMA) under
  speculative_decoding.batched.

69/69 tests still pass. Live validation on Heretic2 workload (next step):
launch server with --enable-pld, hit /health, expect pld_ssm_replay.enabled
true on Qwen3.6-27B hybrid; verify speculative_decoding.batched.steps > 0
under load.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
st-adam pushed a commit to st-adam/vmlx that referenced this pull request May 24, 2026
…jang-ai#153)

Documents the batched-MLLM PLD path added by PRs jjang-ai#149/jjang-ai#150:

- Dispatch path: when each of {draft-spec, PLD, standard _step} activates
- In-batch verify mechanics (B, K+1) input shape
- Per-row rollback variants (full accept / pure-attention reject / hybrid)
- Hybrid Mamba limitation: parallel-scan kernel vs recurrent-replay state
  divergence; default-disabled with VMLX_ENABLE_MLLM_PLD_HYBRID=1 opt-in
- Deferred attention-only verify proposal (issue jjang-ai#134 original design)
  with cost/benefit analysis and decision criteria for future implementers
- Full env-var reference table (VMLX_DISABLE_PLD_REPLAY,
  VMLX_ENABLE_MLLM_PLD_HYBRID, VMLX_PLD_MIN_ACCEPTANCE,
  VMLX_PLD_WARMUP_STEPS, VMLX_PLD_PROBE_INTERVAL, VMLX_PLD_DEBUG)
- Telemetry format on /health (pld_ssm_replay + speculative_decoding.batched)
- 3-layer validation strategy: unit (~95 tests), live byte-equality script,
  live benchmark

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
st-adam pushed a commit to st-adam/vmlx that referenced this pull request May 24, 2026
…or truncate

Two changes from the live byte-equality validation cycle on smolvlm
(pure-attention MLLM, mlx-community/SmolVLM-Instruct-bf16):

1. **MLLM PLD default-off** (correctness gate). Live byte-equality test
   showed PLD on pure-attention MLLM ALSO produces output drift at T=0
   ("TheTheThe number number is is"), not just hybrid. The PR jjang-ai#150
   multi-token verify forward produces logits that don't byte-match
   sequential single-token forwards for the SAME input position. Affects
   both hybrid and pure-attention MLLM paths.

   New gate: `VMLX_ENABLE_MLLM_PLD=1` (experimental opt-in for either
   architecture). Old `VMLX_ENABLE_MLLM_PLD_HYBRID=1` kept for backward
   compat. Without the flag, --enable-pld is silently ignored on MLLM
   path (logs the reason).

   Simple-engine PLD (PR jjang-ai#149's Scheduler path) is unaffected — PR jjang-ai#26's
   +4-7% baseline holds for the non-MLLM hybrid case.

2. **Per-n_accept histogram** in /health. Diagnostic telemetry:
   `speculative_decoding.batched.accept_histogram[n]` = count of rounds
   where exactly n drafts accepted. Lets ops debug workload-specific
   acceptance patterns (e.g. code shows partial accepts; JSON shows full
   accepts).

3. **B=1 tensor truncate on pure-attention KV rewind**. When verify
   advances cache to N+K+1 but rollback rewinds offset to N+1+n, the
   keys/values tensor still has stale verify content at the trailing
   positions. Some MLX attention paths may read the full tensor (not
   just :offset). For B=1 we truncate the tensor; for B>1 we preserve
   the multi-row structure (best-effort, attention mask must respect
   per-row offset).

Tests: 95/95 pass.

Live validation (mlx-community/SmolVLM-Instruct-bf16, max_num_seqs=1):
- PLD off vs PLD on at T=0: 1/4 prompts byte-equal (the short one);
  3/4 diverge. Confirms PLD on MLLM has unresolved correctness issue.
- With this commit, default behavior is PLD-off on MLLM (byte-equal
  to PLD-off baseline by construction).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
st-adam pushed a commit to st-adam/vmlx that referenced this pull request May 24, 2026
…ests

PR jjang-ai#172 Phase C + test layers 1/2/3/5.

**TurboQuant cache short-circuit** (Phase C, vmlx_engine/mllm_batch_generator.py):
PLD's per-row writeback doesn't preserve `left_padding` on TurboQuant KV
caches, causing `c.extract(idx)` to crash on request completion (observed
on smolvlm with `--kv-cache-quantization q4`). Detect TQ via
`type(c).__name__ == "TurboQuantKVCache"` at the entry to
`_step_speculative` and fall back to standard `_step`. Logs once at
INFO. PLD + TQ + MLLM compatibility is deferred to a follow-up;
short-term workaround: use `--kv-cache-quantization none/q8` for PLD.

**Logit-equivalence tests (Layer 1, slow, real-model):**
`tests/test_mllm_pld_logits_equivalence.py` — 3 tests on
`mlx-community/Llama-3.2-1B-Instruct-4bit` (small pure-attention LLM,
verified deterministic at T=0):
  - argmax match between multi-token and sequential forward
  - logits max-diff within 1e-2 FP tolerance
  - cache state (offset + last-position keys) matches

All 3 PASS on Llama, proving the multi-token forward IS mathematically
correct in mlx_lm. The smolvlm-specific divergence (PR jjang-ai#172 Phase B
diagnostic) is therefore in mlx_vlm's wrapper layer, not mlx_lm core.

**TQ + invariant + edge case unit tests (Layers 2/3/5, mock):**
`tests/test_mllm_pld_tq_and_invariants.py` — 11 tests covering:
  - TQ cache detection short-circuits (no spec steps, no crash)
  - TQ skip log fires once
  - scratch_extra_tokens initialized to None
  - Cooldown decrement invariant + no underflow
  - Offset monotonicity under sequential verify
  - Acceptance histogram increments per row per step
  - Telemetry counters strictly monotonic
  - K=0 / B=0 / prefill-phase early returns

**Live byte-equality on Llama-3.2 (PR jjang-ai#149 simple-engine path):**
Discovered that PR jjang-ai#149's PLD path is NOT reachable via `vmlx serve` —
both batched and non-batched modes route through different schedulers
that bypass `Scheduler._try_speculative_decode`. PR jjang-ai#149's `pld_ssm_replay`
counters remain 0 in all tested configurations. PR jjang-ai#149 was validated
via mock unit tests; live activation needs a deeper investigation (likely
needs `vmlx_engine` API call directly with `AsyncEngineCore(..., scheduler=Scheduler(...))`).

**Total: 117 tests passing** (95 mock unit + 3 slow logit-equivalence
+ 11 new TQ/invariant + 8 new B>1 from PR jjang-ai#171).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Perf: PLD verify-cost on hybrid SSM models — proposal for SSM checkpoint/replay (extends #26)

3 participants