Awesome JAX 

JAX brings automatic differentiation and the XLA compiler together through a NumPy-like API for high-performance machine learning research on accelerators like GPUs and TPUs.
This is a curated list of awesome JAX libraries, projects, and other resources. Contributions are welcome!
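As a quick orientation before the list, the sketch below shows the two ingredients named above working together: `jax.grad` for automatic differentiation and `jax.jit` for XLA compilation, applied to ordinary `jax.numpy` code. It is a minimal illustration that assumes JAX is installed. The loss function, data, and variable names are hypothetical examples, not taken from any library in this list.

```python
import jax
import jax.numpy as jnp

# Hypothetical example: mean squared error of a linear model, written with the NumPy-like API.
def loss(w, x, y):
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

# Automatic differentiation: gradient of `loss` with respect to its first argument, w.
grad_loss = jax.grad(loss)

# XLA compilation: jax.jit traces the function and compiles it for the given input shapes/dtypes.
fast_grad = jax.jit(grad_loss)

# Small hypothetical dataset: 3 samples, 2 features.
x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([1.0, 2.0, 3.0])
w = jnp.zeros(2)

# With w = 0 the residual is -y, so the gradient equals -(2/3) * x.T @ y.
g = fast_grad(w, x, y)
print(g)
```

The same `loss` function is used unchanged by both transformations; JAX's transformations compose, so `jax.jit(jax.grad(loss))` is just a function that can be called like any other.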
Contents
- Libraries
- Models and Projects
- Videos
- Papers
- Tutorials and Blog Posts
- Books
- Community
Libraries
- Neural Network Libraries
- Flax - Centered on flexibility and clarity.
- Flax NNX - An evolution of Flax by the same team.
- Haiku - Focused on simplicity, created by the authors of Sonnet at DeepMind.
- Objax - Has an object-oriented design similar to PyTorch.
- Elegy - A high-level API for deep learning in JAX. Supports Flax, Haiku, and Optax.
- Trax - "Batteries included" deep learning library focused on providing solutions for common workloads.
- Jraph - Lightweight graph neural network library.
- Neural Tangents - High-level API for specifying neural networks of both finite and infinite width.
- HuggingFace Transformers - Ecosystem of pretrained Transformers for a wide range of natural language tasks (Flax).
- Equinox - Callable PyTrees and filtered JIT/grad transformations => neural networks in JAX.
- Scenic - A JAX library for computer vision research and beyond.
- Penzai - Prioritizes legibility, visualization, and easy editing of neural network models with composable tools and a simple mental model.
- Levanter - Legible, Scalable, Reproducible Foundation Models with Named Tensors and JAX.
- EasyLM - LLMs made easy: Pre-training, finetuning, evaluating and serving LLMs in JAX/Flax.
- NumPyro - Probabilistic programming based on the Pyro library.
- Chex - Utilities to write and test reliable JAX code.
- Optax - Gradient processing and optimization library.
- RLax - Library for implementing reinforcement learning agents.
- JAX, M.D. - Accelerated, differentiable molecular dynamics.