JAX arrays

JAX arrays are similar to NumPy arrays in many ways but differ in some crucial ones. Because JAX is built around functional programming, JAX arrays are immutable: once an array is created, its contents cannot be changed in place. Porting NumPy code that relies on in-place updates to JAX may therefore require a significant amount of rewriting.
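Here is a minimal sketch of what immutability means in practice. It assumes a reasonably recent JAX release. NumPy lets you assign to an element in place. A JAX array raises a TypeError instead. JAX provides the functional `.at[...].set(...)` update for this case, which returns a new array and leaves the original untouched.

```python
import numpy as np
import jax.numpy as jnp

# NumPy: arrays are mutable, so assignment modifies x_np in place
x_np = np.zeros(3)
x_np[0] = 1.0

# JAX: arrays are immutable, so item assignment raises a TypeError
x = jnp.zeros(3)
try:
    x[0] = 1.0
    mutated = True
except TypeError:
    mutated = False

# The functional alternative builds a new array with the update applied
y = x.at[0].set(1.0)

print(x_np.tolist())  # [1.0, 0.0, 0.0]
print(mutated)        # False
print(x.tolist())     # [0.0, 0.0, 0.0]  (original is unchanged)
print(y.tolist())     # [1.0, 0.0, 0.0]
```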

In this chapter, we will go through a number of examples explaining the similarities and differences.

Importing

JAX consists of several modules. The one that most closely resembles NumPy is jax.numpy. By convention, numpy is imported under the shorthand np (import numpy as np). In the same way, jax.numpy is conventionally imported as jnp:

import jax.numpy as jnp
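Because jax.numpy mirrors much of the NumPy API, the two are often used side by side. The following minimal sketch assumes both numpy and jax are installed. It moves data between the two libraries with `jnp.asarray` and `np.asarray`.

```python
import numpy as np
import jax.numpy as jnp

# The same constructor and method names exist in both libraries
a_np = np.arange(6).reshape(2, 3)
a_jnp = jnp.arange(6).reshape(2, 3)

# jnp.asarray converts a NumPy array into a JAX array
from_np = jnp.asarray(a_np)

# np.asarray converts a JAX array back into a NumPy ndarray
back_to_np = np.asarray(a_jnp)

print(type(back_to_np).__name__)         # ndarray
print(np.array_equal(back_to_np, a_np))  # True
```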

1-D vectors

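>>> z = jnp.zeros(4)
>>> print(z)
[0. 0. 0. 0.]
>>> print(z.dtype, z.shape)
float32 (4,)
>>> jnp.ones(4)
DeviceArray([1., 1., 1., 1.], dtype=float32)
>>> # JAX has no uninitialized memory, so empty() returns zeros
>>> jnp.empty(4)
DeviceArray([0., 0., 0., 0.], dtype=float32)
>>> # Python's int maps to int32 unless 64-bit support is enabled
>>> jnp.ones(4, dtype=int)
DeviceArray([1, 1, 1, 1], dtype=int32)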
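These constructors have the same names as NumPy's, but their default dtype differs. The following minimal sketch compares the two. It assumes JAX's default configuration, with 64-bit support not enabled.

```python
import numpy as np
import jax.numpy as jnp

z_np = np.zeros(4)
z_jnp = jnp.zeros(4)

# Same values and shape ...
print(np.array_equal(np.asarray(z_jnp), z_np))  # True
print(z_np.shape == z_jnp.shape)                 # True

# ... but different default dtypes: NumPy uses 64 bits, JAX uses 32
print(z_np.dtype)   # float64
print(z_jnp.dtype)  # float32 (unless 64-bit support is enabled)
```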

A range of integers

>>> a = jnp.arange(5)
>>> print(a)
[0 1 2 3 4]
>>> print(a.dtype, a.shape)
int32 (5,)
>>> # start and stop (stop is excluded)
>>> jnp.arange(2, 8)
DeviceArray([2, 3, 4, 5, 6, 7], dtype=int32)
>>> # start, stop and step size
>>> jnp.arange(2, 8, 2)
DeviceArray([2, 4, 6], dtype=int32)
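As a quick check of these semantics, the following sketch assumes the default 32-bit configuration. `jnp.arange` follows the same start/stop/step rules as Python's built-in `range`. Unlike `range`, it also accepts a floating-point step.

```python
import jax.numpy as jnp

# Like Python's range, arange includes start and excludes stop
evens = jnp.arange(2, 8, 2)
print(evens.tolist() == list(range(2, 8, 2)))  # True

# Unlike range, arange also accepts a floating-point step
quarters = jnp.arange(0, 1, 0.25)
print(quarters.tolist())  # [0.0, 0.25, 0.5, 0.75]
print(quarters.dtype)     # float32
```

For floating-point ranges, `jnp.linspace` is usually the safer choice. With `arange` and a non-integer step, rounding can affect how many elements are produced.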

Linearly spaced values

>>> jnp.linspace(0, 1, num=5)
DeviceArray([0.  , 0.25, 0.5 , 0.75, 1.  ], dtype=float32)
>>> # exclude the endpoint
>>> jnp.linspace(0, 1, num=5, endpoint=False)
DeviceArray([0. , 0.2, 0.4, 0.6, 0.8], dtype=float32)
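The `endpoint` flag changes the spacing between the values. With the endpoint included, `num` points split the interval into `num - 1` gaps. Without it, the interval is split into `num` gaps and `stop` itself is dropped. The following sketch verifies both step sizes with `jnp.diff`.

```python
import jax.numpy as jnp

start, stop, num = 0.0, 1.0, 5

with_end = jnp.linspace(start, stop, num=num)
without_end = jnp.linspace(start, stop, num=num, endpoint=False)

# With the endpoint, num points leave num - 1 gaps
step_with = (stop - start) / (num - 1)   # 0.25
# Without it, the interval is divided into num gaps
step_without = (stop - start) / num      # 0.2

print(bool(jnp.allclose(jnp.diff(with_end), step_with)))        # True
print(bool(jnp.allclose(jnp.diff(without_end), step_without)))  # True
```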

Boolean vectors

>>> jnp.ones(4, dtype=bool)
DeviceArray([ True,  True,  True,  True], dtype=bool)
>>> jnp.zeros(4, dtype=bool)
DeviceArray([False, False, False, False], dtype=bool)
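Boolean arrays more often come from element-wise comparisons than from `ones` or `zeros`. The following minimal sketch uses such a mask in two ways. Selecting elements with `a[mask]` works only outside `jit`, because the result's shape depends on the data. `jnp.where` keeps the original shape and so also works under `jit`.

```python
import jax.numpy as jnp

a = jnp.arange(5)

# Comparisons produce boolean arrays, element by element
mask = a > 2
print(mask.tolist())  # [False, False, False, True, True]

# Boolean-mask selection (eager mode only; the output shape is data-dependent)
print(a[mask].tolist())  # [3, 4]

# jnp.where keeps the shape, replacing unselected entries with 0
print(jnp.where(mask, a, 0).tolist())  # [0, 0, 0, 3, 4]

# Summing a boolean array counts its True entries
print(int(mask.sum()))  # 2
```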

2-D Matrices

>>> jnp.zeros((4, 4))
DeviceArray([[0., 0., 0., 0.],
             [0., 0., 0., 0.],
             [0., 0., 0., 0.],
             [0., 0., 0., 0.]], dtype=float32)
>>> jnp.ones((4, 4))
DeviceArray([[1., 1., 1., 1.],
             [1., 1., 1., 1.],
             [1., 1., 1., 1.],
             [1., 1., 1., 1.]], dtype=float32)
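Besides `zeros` and `ones`, a few other NumPy-style constructors are handy for building matrices. The following minimal sketch uses `jnp.eye`, `jnp.full`, and `reshape`, all of which follow NumPy's conventions.

```python
import jax.numpy as jnp

# 3 x 3 identity matrix
identity = jnp.eye(3)
print(identity.tolist())  # [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]

# A 2 x 3 matrix filled with a constant value
sevens = jnp.full((2, 3), 7)
print(sevens.shape)  # (2, 3)

# Reshape a 1-D range into a 2 x 3 matrix, filled row by row as in NumPy
m = jnp.arange(6).reshape(2, 3)
print(m.tolist())  # [[0, 1, 2], [3, 4, 5]]
```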

64 Bit Support

By default, JAX creates integer and floating-point arrays with 32-bit precision, and all calculations are carried out in 32 bits. If you need 64-bit integers or floats in your calculations, you must enable 64-bit support explicitly.
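The following minimal sketch shows what happens when 64-bit support is not enabled. It assumes a fresh Python process with JAX's default configuration. A request for a 64-bit dtype is narrowed to 32 bits, and JAX emits a UserWarning about the truncation.

```python
import warnings
import jax.numpy as jnp

# Without 64-bit support, the requested int64/float64 dtypes are
# narrowed to int32/float32; the UserWarning JAX emits is silenced here
with warnings.catch_warnings():
    warnings.simplefilter("ignore", UserWarning)
    ints = jnp.ones(4, dtype=jnp.int64)
    floats = jnp.ones(4, dtype=jnp.float64)

print(ints.dtype)    # int32
print(floats.dtype)  # float32
```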

Enable 64-bit support at the start of your program, before creating any arrays, and do not toggle the setting partway through.

>>> # enabling 64-bit support
>>> import jax
>>> jax.config.update("jax_enable_x64", True)
>>> jnp.ones(4, dtype=jnp.int64)
DeviceArray([1, 1, 1, 1], dtype=int64)
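Enabling the flag also changes the default dtypes, not just which dtypes may be requested. The following minimal sketch assumes a fresh Python process in which no arrays have been created yet. It shows that defaults become 64-bit and that 32-bit arrays remain available on request.

```python
import jax

# Enable 64-bit support before creating any arrays
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

# Default dtypes are now 64-bit, matching NumPy's defaults
print(jnp.zeros(4).dtype)   # float64
print(jnp.arange(5).dtype)  # int64

# 32-bit arrays can still be requested explicitly
print(jnp.zeros(4, dtype=jnp.float32).dtype)  # float32
```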