MaxText
High-performance, highly scalable, open-source LLM library and reference implementation, written in pure Python/JAX and targeting Google Cloud TPUs and GPUs for training.
High-performance
MaxText achieves high Model FLOPs Utilization (MFU) and tokens/second from a single host to very large clusters, while staying simple and largely "optimization-free" thanks to the power of JAX and the XLA compiler.
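MFU is the ratio of the model FLOPs actually computed per second to the hardware's theoretical peak FLOP/s. The sketch below is a minimal illustration, assuming a dense decoder-only transformer and the common rule of thumb of about 6 FLOPs per parameter per training token (roughly 2 for the forward pass and 4 for the backward pass), with attention FLOPs ignored. MaxText's own MFU accounting may differ. The function names and every number (model size, chip count, per-chip peak, throughput) are hypothetical, not measured results.

```python
def training_flops_per_token(num_params: float) -> float:
    """Approximate training FLOPs per token for a dense transformer (6 * N rule of thumb)."""
    # ~2N FLOPs for the forward pass plus ~4N for the backward pass; attention FLOPs ignored.
    return 6.0 * num_params


def model_flops_utilization(
    tokens_per_second: float,
    num_params: float,
    num_chips: int,
    peak_flops_per_chip: float,
) -> float:
    """Return MFU as a fraction: achieved model FLOP/s divided by aggregate peak FLOP/s."""
    achieved = tokens_per_second * training_flops_per_token(num_params)
    peak = num_chips * peak_flops_per_chip
    return achieved / peak


# Hypothetical run: 8B-parameter model on 256 chips at an assumed 1e15 peak FLOP/s each.
mfu = model_flops_utilization(
    tokens_per_second=2.0e6, num_params=8e9, num_chips=256, peak_flops_per_chip=1e15
)
print(f"MFU: {mfu:.1%}")  # prints "MFU: 37.5%"
```

Because MFU counts only the FLOPs the model mathematically requires, extra work such as recomputing activations during rematerialization does not raise it, which makes MFU a fairer comparison across implementations than raw hardware utilization.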
Pre-training
MaxText provides opinionated implementations for achieving optimal performance across a wide variety of dimensions, such as sharding, quantization, and checkpointing.
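In JAX, sharding is expressed by arranging devices into a named mesh and mapping each array dimension to a mesh axis. The snippet below is plain JAX, not a MaxText configuration or API: it runs on however many devices are available (a single CPU device is enough), and the axis names `data` and `model` are illustrative choices. It shards the batch dimension along `data` and replicates the feature dimension.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Arrange all available devices into a 2-D mesh: N devices on "data", 1 on "model".
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("data", "model"))

# A toy batch whose leading (batch) dimension is divisible by the number of devices.
batch = jnp.ones((8 * devices.size, 4))

# Split dimension 0 across the "data" mesh axis; replicate dimension 1.
sharded_batch = jax.device_put(batch, NamedSharding(mesh, P("data", None)))

# jit-compiled functions consume the sharded array directly; XLA propagates the
# sharding and inserts collectives only where a computation needs them.
doubled = jax.jit(lambda x: x * 2.0)(sharded_batch)
print(doubled.sharding)
```

On a real multi-chip slice the same code splits the batch across chips without changing the model function, which is the property that lets a single program scale from one host to many.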
Post-training
MaxText provides a scalable framework to fine-tune proprietary or OSS models using state-of-the-art Reinforcement Learning (RL) algorithms (e.g., GRPO) and techniques (e.g., SFT and Knowledge Distillation).
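GRPO (Group Relative Policy Optimization) replaces a learned value baseline with a per-prompt group statistic: several completions are sampled for the same prompt, and each completion's reward is normalized by the group's mean and standard deviation. The sketch below covers only that advantage step, in plain Python. The function name and the `eps` guard are hypothetical choices, this is not the Tunix or MaxText API, and the rest of the GRPO objective (a clipped policy-ratio loss plus a KL penalty against a reference model) is omitted.

```python
import statistics
from typing import List, Sequence


def group_relative_advantages(rewards: Sequence[float], eps: float = 1e-6) -> List[float]:
    """Normalize rewards of completions sampled for ONE prompt: (r - mean) / (std + eps)."""
    mean = statistics.fmean(rewards)
    # Population standard deviation over the group; eps avoids division by zero
    # when every completion receives the same reward.
    std = statistics.pstdev(rewards)
    return [(r - mean) / (std + eps) for r in rewards]


# Four completions for one prompt, scored 1.0 (correct) or 0.0 (incorrect).
advantages = group_relative_advantages([1.0, 0.0, 1.0, 0.0])
print([round(a, 3) for a in advantages])  # prints [1.0, -1.0, 1.0, -1.0]
```

Completions that beat their siblings receive positive advantages and the rest receive negative ones, so no separate value network has to be trained; a group in which every completion scores the same contributes zero advantage.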
JAX AI Stack
The JAX AI Stack is a curated collection of libraries that researchers and engineers, both inside and outside of Google, have found useful for implementing and deploying the models behind generative AI tools like Imagen, Gemini, and more.
- JAX - Core array operations and program transformations
- Flax - Building neural networks
- Orbax - Checkpointing and persistence utilities
- Optax - Gradient processing and optimization
- Tunix - A JAX library with the latest experimental algorithms and post-training techniques
- ml_dtypes - NumPy dtype extensions for machine learning
- MaxText - Model library for JAX LLMs, highly optimized for TPUs
- vLLM on TPU - High-performance sampling (inference) for Reinforcement Learning (RL)
- Pathways - Multi-host inference (sampling) and highly efficient weight transfer
- Grain or tf.data - Optional data loading libraries