Thinking in JAX
This is a beginner-level book to help you get into the groove with Google JAX.
What is JAX
We attempt to describe JAX by comparing it with its predecessors and peers.
In short, JAX is a new Python library for numerical computing.
It provides a NumPy-like API for numerical computing. This makes it easy for existing NumPy users to transition to JAX.
At the same time, it is built from the ground up using functional programming principles.
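A minimal sketch of the NumPy-like API, assuming `numpy` and `jax` are installed. The same expression is written against `numpy` and `jax.numpy`; the last few lines preview the immutability discussed in the next paragraph, using the `.at[...].set(...)` update syntax that JAX offers in place of in-place item assignment. The variable names are illustrative only.

```python
import numpy as np
import jax.numpy as jnp

# The same computation written against NumPy and against jax.numpy.
x_np = np.linspace(0.0, 1.0, 5)
x_jnp = jnp.linspace(0.0, 1.0, 5)

y_np = np.sum(np.sin(x_np) ** 2)
y_jnp = jnp.sum(jnp.sin(x_jnp) ** 2)

# NumPy arrays are mutable: this changes x_np in place.
x_np[0] = 10.0

# JAX arrays are immutable: item assignment raises a TypeError.
try:
    x_jnp[0] = 10.0
    mutated = True
except TypeError:
    mutated = False  # JAX refused the in-place update

# Instead, .at[...].set(...) returns a new array; x_jnp is left unchanged.
x_jnp_updated = x_jnp.at[0].set(10.0)
```

Note that `jax.numpy` defaults to 32-bit floats while NumPy defaults to 64-bit, so `y_np` and `y_jnp` agree only up to single-precision rounding.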
Because of this functional design, data structures like JAX arrays are immutable. So, although its API resembles NumPy's, JAX is not quite the same: in NumPy, arrays are mutable.
A major focus area for JAX is deep learning, and it is regularly compared with TensorFlow and PyTorch.
However, JAX itself provides the low-level plumbing for building machine learning libraries. The libraries that provide the actual ML building blocks include dm-haiku, flax, rlax, and trax.
The JAX core can be used for any scientific computing problem.
JAX provides built-in support for automatic differentiation (AD), which is key to successfully implementing large deep learning networks.
JAX provides a just-in-time (JIT) compiler. Thanks to XLA, it can compile Python code written according to JAX conventions (functional programming, etc.) into efficient machine code for a variety of hardware architectures.
JAX enables you to write Python code that runs efficiently across CPU, GPU, and TPU architectures.
JAX has built-in support for vectorizing a function over different dimensions of input data.
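The three capabilities listed above (automatic differentiation, JIT compilation, and vectorization) correspond to the `jax.grad`, `jax.jit`, and `jax.vmap` transformations. The sketch below applies them to a hypothetical squared-error loss for a one-parameter linear model; the function `loss` and the data are illustrative assumptions, not taken from any particular library.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    """Mean squared error of an illustrative linear model y_hat = w * x."""
    return jnp.mean((w * x - y) ** 2)

# Automatic differentiation: grad returns a new function computing
# d(loss)/dw, i.e. the derivative with respect to the first argument.
grad_loss = jax.grad(loss)

# JIT compilation: jit traces the function for the given input shapes and
# dtypes and compiles it with XLA for the available backend (CPU/GPU/TPU).
fast_grad_loss = jax.jit(grad_loss)

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([2.0, 4.0, 6.0])
g = fast_grad_loss(1.0, x, y)

# Vectorization: vmap evaluates loss for a batch of weights (mapped over
# axis 0 of ws) while sharing the same x and y (in_axes=None for them).
batched_loss = jax.vmap(loss, in_axes=(0, None, None))
ws = jnp.array([0.0, 1.0, 2.0])
losses = batched_loss(ws, x, y)
```

Because the transformations compose, `jax.jit(jax.grad(loss))` is itself just another function; the same pattern scales from this toy loss to the loss of a deep learning model.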