|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +"""Logit-equivalence tests for sequential vs multi-token verify (PR #172 Layer 1). |
| 3 | +
|
| 4 | +Pins the root cause of MLLM PLD output drift: multi-token forward produces |
| 5 | +different logits than K+1 sequential single-token forwards at the same input |
| 6 | +positions. The sequential-verify fix in `_step_speculative` (PR #172) trades |
| 7 | +K+1× kernel-launch cost for byte-equivalence with standalone decode. |
| 8 | +
|
| 9 | +These tests load a real model (small pure-attention LLM) and directly |
| 10 | +compare logits position-by-position. They are marked @pytest.mark.slow and |
| 11 | +require the model to be cached or downloadable. |
| 12 | +
|
| 13 | +Run: |
| 14 | + .venv/bin/python -m pytest tests/test_mllm_pld_logits_equivalence.py \\ |
| 15 | + -v --run-slow |
| 16 | +""" |
| 17 | + |
| 18 | +from __future__ import annotations |
| 19 | + |
| 20 | +import copy |
| 21 | + |
| 22 | +import mlx.core as mx |
| 23 | +import pytest |
| 24 | + |
| 25 | + |
| 26 | +# Small pure-attention LLM used as the reference model for these tests. |
| 27 | +# tests/test_batching_deterministic.py uses the same checkpoint (proven |
| 28 | +# deterministic at T=0 in vmlx). About 600 MB when cached. |
| 29 | +TEST_MODEL = "mlx-community/Llama-3.2-1B-Instruct-4bit" |
| 30 | + |
| 31 | + |
@pytest.fixture(scope="module")
def model_and_tokenizer():
    """Provide ``(model, tokenizer)`` for TEST_MODEL, loaded once per module.

    Any failure while importing ``mlx_lm`` or loading the model (including an
    unexpected return shape from ``load``) skips the dependent tests instead
    of failing them.
    """
    try:
        from mlx_lm import load

        # Unpacking stays inside the try so a malformed return also skips.
        model, tokenizer = load(TEST_MODEL)
    except Exception as e:
        pytest.skip(f"Could not load model {TEST_MODEL}: {e}")
    else:
        return model, tokenizer
| 40 | + |
| 41 | + |
def _prefill_and_cache(model, prompt_ids):
    """Prefill ``model`` on ``prompt_ids`` and return the populated cache.

    Args:
        model: Callable as ``model(tokens, cache=cache)``. If it exposes
            ``make_cache()`` that is used; otherwise one ``KVCache`` per
            layer is created.
        prompt_ids: Sequence of token ids for a single prompt (a batch
            dimension of 1 is added here).

    Returns:
        The list of per-layer cache entries after the prefill forward pass,
        with their key/value arrays already evaluated.
    """
    if hasattr(model, "make_cache"):
        cache = model.make_cache()
    else:
        from mlx_lm.models.cache import KVCache

        if hasattr(model, "layers"):
            n_layers = len(model.layers)
        elif hasattr(model, "model"):
            n_layers = len(model.model.layers)
        else:
            n_layers = 1
        cache = [KVCache() for _ in range(n_layers)]

    # The prefill logits are not needed; only the cache side effects are.
    model(mx.array([prompt_ids]), cache=cache)

    # MLX is lazy: force BOTH keys and values so that later deep copies of the
    # cache hold materialised arrays rather than pending graph nodes. Entries
    # that expose neither attribute (or hold None) contribute nothing, instead
    # of handing a non-array cache object to mx.eval.
    pending = [
        arr
        for entry in cache
        for arr in (getattr(entry, "keys", None), getattr(entry, "values", None))
        if arr is not None
    ]
    mx.eval(pending)
    return cache
| 57 | + |
| 58 | + |
def _encode_prompt(tokenizer, prompt):
    """Return token ids for ``prompt`` via ``encode`` or the call interface."""
    if hasattr(tokenizer, "encode"):
        return tokenizer.encode(prompt)
    return list(tokenizer(prompt).input_ids)


def _as_logits(output):
    """Unwrap ``output.logits`` when the model returns a wrapper object."""
    return output.logits if hasattr(output, "logits") else output


def _multi_token_forward(model, prefill_cache, drafts):
    """Feed all ``drafts`` in one forward pass on a private copy of the cache.

    Returns ``(logits, cache)`` where ``cache`` is the advanced copy.
    """
    cache = copy.deepcopy(prefill_cache)
    logits = _as_logits(model(mx.array([drafts]), cache=cache))
    mx.eval(logits)
    return logits, cache


def _sequential_forward(model, prefill_cache, drafts):
    """Feed ``drafts`` one token at a time on a private copy of the cache.

    Returns ``(logits, cache)`` where ``logits`` concatenates the last-position
    logits of every step along axis 1.
    """
    cache = copy.deepcopy(prefill_cache)
    steps = []
    for token in drafts:
        step_logits = _as_logits(model(mx.array([[token]]), cache=cache))
        mx.eval(step_logits)
        steps.append(step_logits[:, -1:, :])
    return mx.concatenate(steps, axis=1), cache


def _cache_offset(entry):
    """Return ``entry.offset`` as a plain int (it may be an ``mx.array``)."""
    offset = entry.offset
    if isinstance(offset, mx.array):
        return int(offset.item())
    return int(offset)


def _last_written_keys(entry):
    """Return the keys slice for the most recently written position.

    mlx_lm's ``KVCache`` is understood to pre-allocate its ``keys`` buffer in
    fixed-size steps, so index ``-1`` of the raw buffer is padding rather than
    the last token. When a usable ``offset`` is available the slice is taken
    at ``offset - 1``; otherwise it falls back to the final buffer position.
    """
    keys = entry.keys
    if hasattr(entry, "offset"):
        end = _cache_offset(entry)
        if 0 < end <= keys.shape[-2]:
            return keys[..., end - 1:end, :]
    return keys[..., -1:, :]


@pytest.mark.slow
class TestLogitsEquivalence:
    """Validate that sequential K+1 single-token forwards produce equivalent
    output to a single multi-token forward of shape (B, K+1).

    NOTE: On some models (e.g., smolvlm), these tests FAIL — that's the bug
    PR #172's sequential verify works around. The tests document the
    expected behaviour and serve as regression detection if mlx_lm
    upstream is fixed.
    """

    def test_sequential_matches_multi_token_via_argmax(self, model_and_tokenizer):
        """argmax token at each position must match between multi-token
        and sequential forwards."""
        model, tokenizer = model_and_tokenizer
        prompt_ids = _encode_prompt(tokenizer, "The quick brown fox")
        cache_for_prefill = _prefill_and_cache(model, prompt_ids)

        # Three arbitrary low token ids used as draft tokens; assumed to be
        # valid ids for the model's vocabulary.
        drafts = [100, 101, 102]
        n_drafts = len(drafts)

        # Path 1: multi-token forward
        logits_M, _ = _multi_token_forward(model, cache_for_prefill, drafts)
        assert logits_M.shape[0] == 1
        assert logits_M.shape[1] == n_drafts, (
            f"expected T={n_drafts}, got {logits_M.shape}"
        )

        # Path 2: sequential single-token forwards
        logits_S, _ = _sequential_forward(model, cache_for_prefill, drafts)
        assert logits_S.shape[1] == n_drafts

        # Compare argmax per position
        for j in range(n_drafts):
            argmax_M = int(mx.argmax(logits_M[:, j, :], axis=-1).item())
            argmax_S = int(mx.argmax(logits_S[:, j, :], axis=-1).item())
            # On Llama-3.2 we expect these to match (deterministic at T=0).
            # On smolvlm they don't — bug documented in PR #172.
            assert argmax_M == argmax_S, (
                f"argmax mismatch at pos {j}: M={argmax_M} S={argmax_S}. "
                f"Sequential verify is the workaround."
            )

    def test_sequential_matches_multi_token_via_logits_tolerance(self, model_and_tokenizer):
        """logits should match within FP tolerance (1e-2 ABS).

        On smolvlm this fails (max_diff up to 0.375). On Llama-3.2 it should
        pass — meaning Llama's multi-token forward IS correct, and the
        smolvlm-specific bug is in mlx_vlm's wrapper, not mlx_lm core.
        """
        model, tokenizer = model_and_tokenizer
        prompt_ids = _encode_prompt(tokenizer, "The quick brown fox")
        cache_for_prefill = _prefill_and_cache(model, prompt_ids)
        drafts = [100, 101, 102]

        logits_M, _ = _multi_token_forward(model, cache_for_prefill, drafts)
        logits_S, _ = _sequential_forward(model, cache_for_prefill, drafts)

        per_pos_diffs = [
            mx.abs(logits_M[:, j, :] - logits_S[:, j, :]).max().item()
            for j in range(len(drafts))
        ]
        max_diff_overall = max(per_pos_diffs, default=0.0)

        # Tolerance: 1e-2 covers normal MLX FP variance on Apple Silicon.
        # If this fails, the model has a known multi-token bug → use
        # sequential verify as workaround.
        assert max_diff_overall < 1e-2, (
            f"Multi-token logits diverge from sequential: "
            f"max_diff={max_diff_overall:.6e}, per_pos={per_pos_diffs}. "
            f"Sequential verify is the correctness workaround."
        )

    def test_sequential_matches_multi_token_cache_state(self, model_and_tokenizer):
        """Ending cache state (offset and keys at the last written position of
        the first layer) should match between multi-token and sequential paths."""
        model, tokenizer = model_and_tokenizer
        prompt_ids = tokenizer.encode("Hello world") if hasattr(tokenizer, "encode") else [1, 2]
        cache_for_prefill = _prefill_and_cache(model, prompt_ids)
        drafts = [100, 101, 102]

        _, cache_M = _multi_token_forward(model, cache_for_prefill, drafts)
        _, cache_S = _sequential_forward(model, cache_for_prefill, drafts)
        first_M, first_S = cache_M[0], cache_S[0]

        # Make sure the first layer's key buffers are materialised before
        # they are compared.
        mx.eval([
            entry.keys
            for entry in (first_M, first_S)
            if getattr(entry, "keys", None) is not None
        ])

        # Compare final offset
        if hasattr(first_M, "offset") and hasattr(first_S, "offset"):
            off_M = _cache_offset(first_M)
            off_S = _cache_offset(first_S)
            assert off_M == off_S, f"cache offsets differ: M={off_M} S={off_S}"

        # Compare keys at the last WRITTEN position. Slicing the raw buffer
        # with -1 would compare padding on pre-allocated caches and pass
        # vacuously.
        if hasattr(first_M, "keys") and first_M.keys is not None:
            last_M = _last_written_keys(first_M)
            last_S = _last_written_keys(first_S)
            key_diff = mx.abs(last_M - last_S).max().item()
            # Expect <= 1e-3 even if logits diverge; cache writes should match.
            # NOTE(review): this threshold was previously only exercised
            # against buffer padding — confirm it holds on real key values.
            assert key_diff < 1e-3, f"cache key writes diverge: max_diff={key_diff}"
0 commit comments