Introduction

This week I wanted to try out JAX, a library with a numpy-compatible API that can run on GPUs and is designed to significantly speed up heavy numerical computations. I use numpy at work, and I was curious whether JAX could be dropped into some of our computer vision projects for better performance.

Writing A Benchmark

I wanted to try out JAX with a quick benchmark, so I first wrote the following numpy baseline (based on this code):

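import numpy as np
import timeit

def benchmark():
    # Two random 1024x1024 matrices, regenerated on every call
    size = 1024
    A, B = np.random.random((size, size)), np.random.random((size, size))

    # Multiply them N times; the result is discarded
    N = 20

    for _ in range(N):
        np.dot(A, B)

# Total wall-clock seconds for 50 calls to benchmark()
time = timeit.timeit(benchmark, number=50)

print(time)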

Porting The Benchmark To JAX

To port the benchmark to JAX I still had to rely on the original numpy for random number generation, since jax.numpy has no random submodule:

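import jax.numpy as np   # JAX's numpy-like API takes over the name np
import numpy as oldnp    # original numpy, used only to generate the inputs
import timeit

def benchmark():
    # Inputs are still created with numpy, then handed to JAX
    size = 1024
    A, B = oldnp.random.random((size, size)), oldnp.random.random((size, size))

    N = 20

    for _ in range(N):
        np.dot(A, B)

# Total wall-clock seconds for 50 calls to benchmark()
time = timeit.timeit(benchmark, number=50)
print(time)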
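
One caveat worth checking with a benchmark like this: JAX dispatches work asynchronously, so a call such as np.dot can return before the result has actually been computed, and numpy inputs may be transferred to the device again on every call. Below is a minimal sketch of a variant that waits for each result and converts the inputs once. The names time_dot, sync, a_dev and b_dev are my own for illustration, the matrix size is reduced to keep it quick, and the JAX part is skipped if JAX is not installed. JAX also defaults to 32-bit floats, so the two backends are not doing identical work.

```python
import timeit

import numpy as onp


def time_dot(dot, a, b, repeats=20, sync=lambda result: result):
    """Return the seconds taken to run dot(a, b) `repeats` times.

    `sync` is called on every result; for JAX it should block until the
    computation has really finished.
    """
    def run():
        for _ in range(repeats):
            sync(dot(a, b))

    return timeit.timeit(run, number=1)


size = 256  # smaller than the 1024 used above, just to keep the sketch fast
a = onp.random.random((size, size))
b = onp.random.random((size, size))

# Plain numpy is synchronous, so the default no-op sync is enough
cpu_seconds = time_dot(onp.dot, a, b)

try:
    import jax.numpy as jnp
except ImportError:
    jax_seconds = None  # JAX not installed; nothing to compare against
else:
    # Convert once, outside the timed region, so transfers are not measured
    a_dev, b_dev = jnp.asarray(a), jnp.asarray(b)
    jax_seconds = time_dot(
        jnp.dot, a_dev, b_dev,
        sync=lambda result: result.block_until_ready(),
    )

print("numpy:", cpu_seconds)
print("jax:  ", jax_seconds)
```

Generating the inputs outside the timed function also means only the matrix multiplications are measured, which the scripts above do not do.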

The Results

JAX was amazingly fast, finishing the benchmark in a little under half the time of plain numpy!

The results in seconds:

> python numpy_cpu.py
27.300557171000037
> python numpy_gpu.py
12.697763189000398
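
For a sense of scale, the two timings can be turned into a speedup factor and an average time per benchmark() call. This is only arithmetic on the numbers above, with variable names of my own choosing. Each call also includes generating the two random matrices, so the per-call figure is not pure matrix multiplication time.

```python
# Timings copied from the two runs above, in seconds
cpu_seconds = 27.300557171000037
jax_seconds = 12.697763189000398

calls = 50  # the number= argument passed to timeit in both scripts

speedup = cpu_seconds / jax_seconds
print(f"speedup: {speedup:.2f}x")  # prints "speedup: 2.15x"
print(f"numpy per call: {cpu_seconds / calls:.3f} s")
print(f"JAX per call:   {jax_seconds / calls:.3f} s")
```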

Conclusion

I am amazed by JAX, and it will likely see use in my future projects. Since it is still a niche dependency, it is probably best to import it behind a guard that falls back to the original numpy. The catch is that jax.numpy does not cover every numpy API, so the code could end up littered with checks for which backend was actually imported:

# Attempt to import GPU numpy (JAX), if not present use CPU operations
try:
    import jax.numpy as np
except ImportError:
    import numpy as np
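
One way to keep those checks in one place, sketched here under my own assumptions rather than as a recommended pattern: record which backend was imported, always keep plain numpy around under a second name, and hide the APIs jax.numpy lacks behind small helpers. HAVE_JAX, onp and random_matrix are hypothetical names for this example.

```python
import numpy as onp  # plain numpy, always available for APIs jax.numpy lacks

# Attempt to import GPU numpy (JAX), if not present use CPU operations
try:
    import jax.numpy as np
    HAVE_JAX = True
except ImportError:
    import numpy as np
    HAVE_JAX = False


def random_matrix(size):
    """Return a size x size random matrix as an array of the active backend."""
    # jax.numpy has no `random` submodule, so generate with numpy and convert
    return np.asarray(onp.random.random((size, size)))


# The rest of the code can use np without caring which backend is active
product = np.dot(random_matrix(4), random_matrix(4))
print("using JAX:", HAVE_JAX, "result shape:", product.shape)
```

With this layout the backend check happens once at import time, and only the helper functions need to know that two libraries are involved.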