MaxText

MaxText is a high-performance, highly scalable, open-source LLM library and reference implementation written in pure Python/JAX. It targets Google Cloud TPUs and GPUs for training.

High-performance

MaxText achieves high Model FLOPs Utilization (MFU) and tokens/second from a single host to very large clusters, while staying simple and largely "optimization-free" thanks to the power of JAX and the XLA compiler.
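As a rough illustration of what MFU measures, the sketch below estimates it from training throughput. It assumes the common approximation of about 6 FLOPs per parameter per token for a dense transformer's forward and backward pass (attention FLOPs ignored). The function name, the model size, the throughput, and the per-chip peak are all hypothetical values chosen for the example, not MaxText APIs or measured results.

```python
def model_flops_utilization(num_params, tokens_per_second, num_chips, peak_flops_per_chip):
    """Estimate MFU as achieved model FLOPs/s divided by peak hardware FLOPs/s."""
    # Assumption: ~6 FLOPs per parameter per token (forward + backward),
    # which ignores attention and other non-matmul costs.
    achieved_flops_per_second = 6.0 * num_params * tokens_per_second
    peak_flops_per_second = num_chips * peak_flops_per_chip
    return achieved_flops_per_second / peak_flops_per_second


# Hypothetical numbers, for illustration only.
mfu = model_flops_utilization(
    num_params=7e9,            # 7B-parameter model
    tokens_per_second=1.0e6,   # aggregate throughput across the cluster
    num_chips=256,
    peak_flops_per_chip=4.0e14,
)
print(f"MFU = {mfu:.1%}")
```

Because MFU counts only the FLOPs the model itself requires, it lets runs on different hardware or cluster sizes be compared on the same scale.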

Pre-training

MaxText provides opinionated implementations that show how to achieve optimal performance across dimensions such as sharding, quantization, and checkpointing.

Post-training

MaxText provides a scalable framework to fine-tune proprietary or OSS models using state-of-the-art Reinforcement Learning (RL) algorithms (e.g., GRPO) and techniques (e.g., SFT and Knowledge Distillation).
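To give a sense of the core idea behind GRPO, the sketch below computes group-relative advantages: several responses are sampled for the same prompt, and each response's reward is normalized against the mean and standard deviation of its group. This is a minimal standalone illustration, not MaxText's implementation. The function name, the use of the population standard deviation, and the `eps` value are assumptions made for the example.

```python
import statistics


def group_relative_advantages(rewards, eps=1e-6):
    """Normalize each reward against the other samples in its group."""
    mean = statistics.fmean(rewards)
    # Assumption: population standard deviation; implementations may differ.
    std = statistics.pstdev(rewards)
    # eps guards against division by zero when all rewards are equal.
    return [(r - mean) / (std + eps) for r in rewards]


# Hypothetical rewards for four responses sampled from one prompt.
rewards = [1.0, 0.0, 0.0, 1.0]
advantages = group_relative_advantages(rewards)
print(advantages)
```

Responses that score above the group average receive a positive advantage and the rest a negative one, so no separate learned value model is needed to provide a baseline.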