Hassan-Sarwat/efficient-speculative-decoding

Efficient Speculative Decoding: Chain of Draft vs. Chain of Thought


This project explores the efficiency of Speculative Decoding by comparing standard Chain of Thought (CoT) reasoning against a token-optimized Chain of Draft (CoD) approach.

Key Features

  • Pipeline Automation: End-to-end training and benchmarking.
  • Methodology Comparison: Direct comparison between CoT (verbose) and CoD (concise) reasoning styles.
  • Robust Environment: Uses uv for dependency management.

Installation & Setup

We use uv to manage a single unified environment (env/, Python 3.11) holding both Unsloth (training/quantization) and vLLM 0.19.1 (speculative-decoding inference).

Version pin note: vLLM is held at 0.19.1 because Unsloth 2026.4.8 requires torch<2.11.0 while vLLM 0.20 requires torch==2.11.0. See requirements.txt for the full rationale.
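The conflict behind the pin can be checked mechanically. The following is a minimal sketch, assuming plain numeric `major.minor.patch` version strings; `parse_version`, `unsloth_ok`, and `vllm_020_ok` are hypothetical helpers (not part of this repository), and the candidate torch versions are illustrative only.

```python
def parse_version(text: str) -> tuple:
    """Turn '2.10.0' into (2, 10, 0). Assumes purely numeric components."""
    return tuple(int(part) for part in text.split("."))


def unsloth_ok(torch_version: str) -> bool:
    # Constraint quoted in the note above: Unsloth 2026.4.8 requires torch<2.11.0.
    return parse_version(torch_version) < parse_version("2.11.0")


def vllm_020_ok(torch_version: str) -> bool:
    # Constraint quoted in the note above: vLLM 0.20 requires torch==2.11.0.
    return parse_version(torch_version) == parse_version("2.11.0")


# Illustrative candidates only; not a list of versions to install.
candidates = ["2.9.1", "2.10.0", "2.11.0"]
both = [v for v in candidates if unsloth_ok(v) and vllm_020_ok(v)]
print(both)  # no version can be both <2.11.0 and ==2.11.0
```

Comparing integer tuples rather than raw strings matters here: as strings, "2.9.1" would sort after "2.11.0".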

Setup Environment

bash scripts/uv_setup_env.sh

Secrets

Copy .env.example to .env and populate:

  • WANDB_API_KEY — used by src/train.py for logging to the peft_cob W&B project.
  • GEMINI_API_KEY — used by data_generation/ for Google Gemini Batch API calls.
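Before launching a long run it can help to confirm that both keys are set. This is a minimal sketch using only the standard library, assuming `.env` contains simple `KEY=VALUE` lines; `parse_env_text` and `missing_keys` are hypothetical helpers, not code from this repository.

```python
import os
from pathlib import Path

REQUIRED_KEYS = ("WANDB_API_KEY", "GEMINI_API_KEY")


def parse_env_text(text: str) -> dict:
    """Parse simple KEY=VALUE lines, skipping blanks and # comments."""
    values = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        # Drop optional surrounding quotes around the value.
        values[key.strip()] = value.strip().strip('"').strip("'")
    return values


def missing_keys(values: dict) -> list:
    """Return the required keys that are absent or empty."""
    return [key for key in REQUIRED_KEYS if not values.get(key)]


if __name__ == "__main__":
    env_path = Path(".env")
    file_values = parse_env_text(env_path.read_text()) if env_path.exists() else {}
    # Real environment variables take precedence over the file.
    from_env = {k: os.environ[k] for k in REQUIRED_KEYS if k in os.environ}
    print("missing:", missing_keys({**file_values, **from_env}))
```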

Usage

1. Train Models

Run the training pipeline to fine-tune the Target (Qwen3-14B) and Draft (Qwen3-0.6B) models. This script does NOT merge the models; it only saves the LoRA adapters.

bash scripts/train_pipeline.sh -t <type> -s <scenario>
  • -t: cot or cod
  • -s: easy (GSM8K), medium (MATH Lvl 1-2), hard (MATH Lvl 3-4)
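When driving the training script from Python, the two flags can be validated before the shell is invoked. A minimal sketch, assuming it is run from the repository root; `build_train_command` and `run_training` are hypothetical wrappers, and only the flags documented above are used.

```python
import subprocess

REASONING_TYPES = ("cot", "cod")
# Scenario names and the datasets they map to, as listed above.
SCENARIOS = {
    "easy": "GSM8K",
    "medium": "MATH Lvl 1-2",
    "hard": "MATH Lvl 3-4",
}


def build_train_command(reasoning_type: str, scenario: str) -> list:
    """Return the argv list for scripts/train_pipeline.sh after validating inputs."""
    if reasoning_type not in REASONING_TYPES:
        raise ValueError(f"-t must be one of {REASONING_TYPES}, got {reasoning_type!r}")
    if scenario not in SCENARIOS:
        raise ValueError(f"-s must be one of {tuple(SCENARIOS)}, got {scenario!r}")
    return ["bash", "scripts/train_pipeline.sh", "-t", reasoning_type, "-s", scenario]


def run_training(reasoning_type: str, scenario: str) -> None:
    """Launch the pipeline; raises CalledProcessError on a non-zero exit code."""
    subprocess.run(build_train_command(reasoning_type, scenario), check=True)


print(" ".join(build_train_command("cod", "easy")))
```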

2. Benchmark

Run the benchmark pipeline. This script handles ephemeral merging (Merge -> Benchmark -> Delete) to save space.

bash scripts/benchmark_pipeline.sh -t <type> -s <scenario> [-m <mode>] [-n <N>] [-k]
  • -t, -s: same as above.
  • -m: baseline, speculative, or both (default both).
  • -n: cap the number of evaluation samples (default: all).
  • -k: keep the temporary merged models on disk after benchmarking (default: delete).

Outputs land in outputs/: a per-sample CSV, a metrics JSON, and a comparison text file.
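The Merge -> Benchmark -> Delete pattern can be sketched as a `try`/`finally` around a temporary directory. This is an illustration of the idea, not the script's actual implementation; `merge` and `benchmark` are hypothetical callables supplied by the caller, and `keep` mirrors the `-k` flag.

```python
import shutil
import tempfile
from pathlib import Path
from typing import Callable


def benchmark_with_ephemeral_merge(
    merge: Callable[[Path], None],
    benchmark: Callable[[Path], dict],
    keep: bool = False,
) -> dict:
    """Merge into a temp directory, benchmark it, then delete it unless keep=True."""
    merged_dir = Path(tempfile.mkdtemp(prefix="merged_"))
    try:
        merge(merged_dir)             # hypothetical: writes merged weights here
        return benchmark(merged_dir)  # hypothetical: returns a metrics dict
    finally:
        # Cleanup runs even if the merge or the benchmark raises.
        if not keep:
            shutil.rmtree(merged_dir, ignore_errors=True)


# Stand-in callables so the sketch runs without any model weights.
seen = []


def fake_merge(path: Path) -> None:
    (path / "model.bin").write_bytes(b"\0" * 16)


def fake_benchmark(path: Path) -> dict:
    seen.append(path)
    return {"files": len(list(path.iterdir()))}


metrics = benchmark_with_ephemeral_merge(fake_merge, fake_benchmark)
print(metrics, seen[0].exists())
```

Placing the cleanup in `finally` means a crashed benchmark does not leave a full merged 14B checkpoint on disk.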

3. Untrained Baseline

Generate baselines using an untrained vanilla model.

bash scripts/untrained_pipeline.sh <scenario>

4. Run Nightly Queue

Run all experiments sequentially.

bash scripts/run_queue.sh > nightly_log.txt 2>&1
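The contents of `run_queue.sh` are not shown here. As an assumption only, one plausible reading of the queue is three untrained baselines plus the 2 x 3 grid of reasoning type and scenario, each grid cell being a train step followed by a benchmark step. The sketch below enumerates that hypothetical plan; `planned_jobs` is not a function from this repository.

```python
from itertools import product

TYPES = ("cot", "cod")
SCENARIOS = ("easy", "medium", "hard")


def planned_jobs() -> list:
    """Return a list of jobs; each job is a list of shell commands run in order."""
    # Assumption: one untrained-baseline job per scenario.
    jobs = [[f"bash scripts/untrained_pipeline.sh {s}"] for s in SCENARIOS]
    # Assumption: one train + benchmark job per (type, scenario) pair.
    for reasoning_type, scenario in product(TYPES, SCENARIOS):
        jobs.append([
            f"bash scripts/train_pipeline.sh -t {reasoning_type} -s {scenario}",
            f"bash scripts/benchmark_pipeline.sh -t {reasoning_type} -s {scenario}",
        ])
    return jobs


for index, job in enumerate(planned_jobs(), start=1):
    print(index, " && ".join(job))
```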

Project Structure

.
├── CLAUDE.md                 # Auto-loaded guidance for Claude Code
├── configs/                  # YAML configs (target_14b, draft_0-6b)
├── data/
│   ├── processed/            # Cleaned training data
│   ├── distilled/            # Synthetic data generated by Teacher
│   ├── tests/                # Held-out eval splits
│   └── raw/, analysis/       # Source dumps and analysis artifacts
├── data_generation/          # Gemini Batch API dataset generation
├── deployments/              # Dockerfile + docker-compose.yml
├── models/                   # Saved LoRA adapters (and temp merged models)
├── outputs/                  # Benchmark CSVs, metrics JSON, comparison text
├── scripts/
│   ├── train_pipeline.sh     # Training pipeline (No merge)
│   ├── benchmark_pipeline.sh # Benchmark pipeline (Ephemeral merge)
│   ├── untrained_pipeline.sh # Baseline generation
│   ├── run_queue.sh          # Run all 9 jobs sequentially
│   └── uv_setup_env.sh       # Environment setup
├── src/
│   ├── train.py              # SFTTrainer script
│   ├── distill_data.py       # Data generation script (resumable)
│   ├── distill_untrained.py  # Untrained-baseline data generation
│   ├── merge_adapter.py      # Adapter merging script
│   └── answer_utils.py       # Math answer extraction & comparison
└── tests/
    └── benchmark.py          # Speculative Decoding Benchmark
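To illustrate what math answer extraction and comparison typically involve, here is a minimal sketch. It is not the code in `src/answer_utils.py`; the function names, the preference for the last `\boxed{...}`, and the fallback to the last number in the text are all assumptions made for illustration.

```python
import re
from fractions import Fraction
from typing import Optional


def extract_boxed(text: str) -> Optional[str]:
    """Return the contents of the last \\boxed{...}, handling nested braces."""
    marker = "\\boxed{"
    start = text.rfind(marker)
    if start == -1:
        return None
    depth = 1
    out = []
    for ch in text[start + len(marker):]:
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return "".join(out).strip()
        out.append(ch)
    return None  # unbalanced braces


def extract_answer(text: str) -> Optional[str]:
    """Prefer a boxed answer; otherwise fall back to the last number in the text."""
    boxed = extract_boxed(text)
    if boxed is not None:
        return boxed
    # Removing commas handles thousands separators such as 1,250.
    numbers = re.findall(r"-?\d+(?:\.\d+)?(?:/\d+)?", text.replace(",", ""))
    return numbers[-1] if numbers else None


def answers_match(pred: Optional[str], gold: str) -> bool:
    """Compare numerically when both sides parse, else by exact string."""
    if pred is None:
        return False
    try:
        return Fraction(pred.replace(",", "")) == Fraction(gold.replace(",", ""))
    except (ValueError, ZeroDivisionError):
        return pred.strip() == gold.strip()


print(extract_answer("So the result is \\boxed{\\frac{1}{2}}"))
print(answers_match(extract_answer("The total is 1,250 apples."), "1250"))
```

Using `Fraction` avoids floating-point mismatches such as `0.50` versus `1/2`; symbolic answers fall back to exact string comparison in this sketch.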

Data

  • Easy: GSM8K
  • Medium: MATH (Levels 1-2)
  • Hard: MATH (Levels 3-4)

For data generation details, see data_generation/README.md.


Results

All benchmarks run on 1000 samples (easy/hard) and 655 samples (medium) using Qwen3-14B (target) + Qwen3-0.6B (draft), with K=5 speculative tokens.

Speculative Decoding Speedup

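| Scenario | CoT Speedup | CoD Speedup | CoT Acceptance | CoD Acceptance |
|----------|-------------|-------------|----------------|----------------|
| Easy     | 1.41x       | 1.53x       | 82.73%         | 88.23%         |
| Medium   | 1.50x       | 1.56x       | 84.97%         | 88.95%         |
| Hard     | 1.55x       | 1.64x       | 83.44%         | 87.98%         |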

CoD consistently achieves higher acceptance rates (~88–89% vs ~83–85% for CoT) and larger speedups (1.53–1.64x vs 1.41–1.55x), because the draft model more reliably predicts the concise, structured token sequences CoD produces.
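To see why a few points of acceptance rate matter, the textbook speculative-decoding estimate of tokens produced per target-model pass can be sketched as below. It assumes each draft token is accepted independently with probability `alpha`, which may not match how the benchmark defines its acceptance rate, and it ignores the draft model's own cost, so it is not a prediction of the measured speedups. `expected_tokens_per_step` is a hypothetical helper.

```python
def expected_tokens_per_step(alpha: float, k: int) -> float:
    """Expected tokens emitted per target pass with k draft tokens.

    Uses the geometric-series form (1 - alpha**(k + 1)) / (1 - alpha),
    valid under the independent-acceptance assumption.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be in [0, 1]")
    if alpha == 1.0:
        return float(k + 1)  # every draft token accepted, plus one bonus token
    return (1.0 - alpha ** (k + 1)) / (1.0 - alpha)


K = 5  # number of speculative tokens used in the benchmarks
for label, alpha in [("CoT easy", 0.8273), ("CoD easy", 0.8823)]:
    print(f"{label}: {expected_tokens_per_step(alpha, K):.2f} tokens per target pass")
```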

Accuracy

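| Scenario | Base (Untrained) | CoT Baseline | CoT + Spec | CoD Baseline | CoD + Spec |
|----------|------------------|--------------|------------|--------------|------------|
| Easy     | 24.9%            | 75.5%        | 75.7%      | 92.2%        | 92.1%      |
| Medium   | 21.2%            | 77.6%        | 77.9%      | 82.7%        | 83.1%      |
| Hard     | 11.2%            | 50.5%        | 50.8%      | 61.9%        | 61.7%      |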

Fine-tuning with LoRA (both CoT and CoD) produces large accuracy gains over the untrained base model. CoD achieves higher accuracy than CoT at every difficulty level, most notably on easy tasks (+16.7pp). Speculative decoding changes accuracy by at most 0.4pp for either method.
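The deltas quoted above can be recomputed directly from the accuracy table. A minimal sketch with the table values hard-coded; `spec_deltas` is a hypothetical helper.

```python
# scenario: (cot_baseline, cot_spec, cod_baseline, cod_spec), in percent
ACCURACY = {
    "easy": (75.5, 75.7, 92.2, 92.1),
    "medium": (77.6, 77.9, 82.7, 83.1),
    "hard": (50.5, 50.8, 61.9, 61.7),
}


def spec_deltas(row: tuple) -> tuple:
    """Accuracy change (in pp) from enabling speculative decoding, for CoT and CoD."""
    cot_base, cot_spec, cod_base, cod_spec = row
    return round(cot_spec - cot_base, 1), round(cod_spec - cod_base, 1)


for scenario, row in ACCURACY.items():
    cot_delta, cod_delta = spec_deltas(row)
    cod_gain = round(row[2] - row[0], 1)  # CoD baseline minus CoT baseline
    print(f"{scenario}: CoT spec {cot_delta:+.1f}pp, CoD spec {cod_delta:+.1f}pp, "
          f"CoD over CoT {cod_gain:+.1f}pp")

max_abs = max(abs(d) for row in ACCURACY.values() for d in spec_deltas(row))
print(f"largest speculative-decoding delta: {max_abs:.1f}pp")
```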

Token Efficiency

Fine-tuned models generate dramatically fewer tokens than the untrained base (~511 tokens), even in baseline (non-speculative) mode:

| Scenario | Base Tokens | CoT Tokens | CoD Tokens |
|----------|-------------|------------|------------|
| Easy     | 511         | 114        | 96         |
| Medium   | 512         | 225        | 165        |
| Hard     | 512         | 311        | 305        |

CoD reasoning is more token-efficient than CoT on easy and medium tasks (roughly 16–27% fewer tokens) but only marginally so on hard tasks (about 2% fewer). Both are more concise than the untrained model: about 4.5–5.3x on easy tasks and about 1.6–3.1x on medium and hard tasks. This more compact output likely contributes to CoD's higher draft-acceptance rate and greater latency reduction (35–39% vs 29–35% for CoT).
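The percentages in this section follow from the two tables by simple arithmetic. The sketch below recomputes them, assuming the reported speedup is a ratio of end-to-end times so that the latency reduction is 1 - 1/speedup; `percent_fewer` and `latency_reduction` are hypothetical helpers.

```python
# scenario: (base_tokens, cot_tokens, cod_tokens)
TOKENS = {"easy": (511, 114, 96), "medium": (512, 225, 165), "hard": (512, 311, 305)}
# scenario: (cot_speedup, cod_speedup)
SPEEDUP = {"easy": (1.41, 1.53), "medium": (1.50, 1.56), "hard": (1.55, 1.64)}


def percent_fewer(reference: float, value: float) -> float:
    """How many percent smaller value is than reference."""
    return 100.0 * (reference - value) / reference


def latency_reduction(speedup: float) -> float:
    """Percent of time saved, assuming speedup = old_time / new_time."""
    return 100.0 * (1.0 - 1.0 / speedup)


for scenario, (base, cot, cod) in TOKENS.items():
    cot_speedup, cod_speedup = SPEEDUP[scenario]
    print(
        f"{scenario}: CoD uses {percent_fewer(cot, cod):.1f}% fewer tokens than CoT; "
        f"base/CoT = {base / cot:.2f}x, base/CoD = {base / cod:.2f}x; "
        f"latency cut CoT {latency_reduction(cot_speedup):.1f}%, "
        f"CoD {latency_reduction(cod_speedup):.1f}%"
    )
```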

Key Takeaways

  1. CoD > CoT for speculative decoding: shorter, more predictable token sequences yield higher acceptance rates, larger speedups, and better accuracy.
  2. Speculative decoding is accuracy-neutral: enabling draft speculation changes accuracy by at most ±0.4pp.
  3. Fine-tuning is essential: untrained base models plateau at ~11–25% accuracy; LoRA fine-tuning pushes that to 51–92%.
  4. Latency gains are substantial: CoD + spec decoding cuts TTFT by 35–39% and boosts throughput to 363–435 tokens/sec (vs 238–278 tok/s baseline).

About

Improving both the reasoning speed of LLMs, using Chain of Draft fine-tuning, and their token output speed, using Speculative Decoding.
