
Benchmarking JAX, A GPU-Accelerated Version Of NumPy

Measuring the performance gained by switching a simple benchmark to the GPU

Introduction

This week I wanted to try out JAX, a GPU-accelerated implementation of NumPy designed to significantly improve performance for complex computations. I use NumPy at work, and I was interested in whether JAX could be used in some of our computer vision projects for better performance.

Writing A Benchmark

I wanted to try out JAX with a quick benchmark, so I wrote the following code (based on this code):

import numpy as np
import timeit

def benchmark():
    size = 1024
    A, B = np.random.random((size, size)), np.random.random((size, size))

    N = 20

    for i in range(N):
        np.dot(A, B)

time = timeit.timeit(benchmark, number=50)

print(time)
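One thing worth noting about the benchmark above: the timed function regenerates A and B on every call, so array creation is included in the measurement. A sketch of a variant using timeit's setup argument to exclude it (the benchmark function and its parameters are my own, chosen to mirror the original sizes):

```python
import timeit

def benchmark(size=1024, reps=20, number=50):
    # setup runs outside the timed region, so only the dot products are measured
    setup = (
        "import numpy as np\n"
        f"A = np.random.random(({size}, {size}))\n"
        f"B = np.random.random(({size}, {size}))\n"
    )
    stmt = f"for _ in range({reps}): np.dot(A, B)"
    return timeit.timeit(stmt, setup=setup, number=number)

# Smaller size here to keep the demo quick; the post uses size=1024, number=50
print(benchmark(size=256, number=10))
```

Since the post's matrices are large, setup cost is small relative to the matrix products, but separating the two makes the comparison cleaner.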

Porting The Benchmark To JAX

To port the benchmark to JAX I still had to rely on the original NumPy for input generation, since jax.numpy does not provide a drop-in random.random:

import jax.numpy as np
import numpy as oldnp  # jax.numpy has no random.random, so inputs come from CPU NumPy
import timeit

def benchmark():
    size = 1024
    A, B = oldnp.random.random((size, size)), oldnp.random.random((size, size))

    N = 20

    for i in range(N):
        # JAX dispatches work asynchronously; a stricter benchmark would call
        # .block_until_ready() on the result to wait for the computation to finish
        np.dot(A, B)

time = timeit.timeit(benchmark, number=50)
print(time)
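Mixing the two modules works, but JAX also ships its own explicit PRNG (jax.random), which avoids the round-trip through CPU NumPy for input generation. A sketch, guarded so it still runs where JAX is not installed (the make_inputs helper is my own name, not part of either library):

```python
try:
    from jax import random

    def make_inputs(size=1024):
        # JAX PRNG state is explicit: create one key, split it so each
        # array gets an independent stream
        key = random.PRNGKey(0)
        key_a, key_b = random.split(key)
        return (random.uniform(key_a, (size, size)),
                random.uniform(key_b, (size, size)))
except ImportError:
    # Fall back to CPU NumPy when JAX is unavailable
    import numpy as np

    def make_inputs(size=1024):
        return np.random.random((size, size)), np.random.random((size, size))
```

The explicit key handling is more verbose than NumPy's global state, but it keeps the arrays on-device from the start.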

The Results

JAX was amazingly fast, finishing in less than half the time!

The results in seconds:

> python numpy_cpu.py
27.300557171000037

> python numpy_gpu.py
12.697763189000398

Conclusion

I am amazed by JAX, and it will likely see use in my future projects. Since it's niche, it might be best to wrap the import in a guard that falls back to the original NumPy, though this could become a problem if the code needs too many checks for APIs that JAX is missing:

# Attempt to import GPU numpy (JAX); if not present, use CPU operations
try:
    import jax.numpy as np
except ImportError:
    import numpy as np
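If only a handful of functions are needed, one way to keep those checks from spreading through the code is to resolve each function once, falling back to CPU NumPy per-API rather than guarding at every call site. A sketch (the resolve helper is illustrative, not a standard pattern):

```python
try:
    import jax.numpy as xp  # accelerated backend when available
except ImportError:
    import numpy as xp

import numpy as cpu_np

def resolve(name):
    # Prefer the accelerated implementation; fall back to CPU NumPy
    # for APIs the accelerated backend does not provide.
    return getattr(xp, name, getattr(cpu_np, name))

# Resolved once at import time; the rest of the code just calls dot()
dot = resolve("dot")
```

This keeps the capability checks in one place, at the cost of routing calls through module-level names instead of the `np.` namespace.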
This post is licensed under CC BY 4.0 by the author.