@imadmali
Created June 21, 2021 19:52
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# JAX is Fast\n",
"Author: Imad Ali  \n",
"Date: 2021-06-20\n",
"\n",
"---\n",
"\n",
"[JAX](https://jax.readthedocs.io/en/latest/index.html) can speed up numerical Python code. This notebook times four ways of applying the same small function to each of the 100 rows of a 100 × 5 array: a Python loop or `vmap`, each combined with either the base function or its jitted version. The table below summarizes the mean time per `%%timeit` run on CPU. Combining `jit` with `vmap` was about 95 times faster than looping over the base function in Python.\n",
"\n",
"| loop | jit | vmap | mean time per run |\n",
"| :----: | :----: | :----: | :----: |\n",
"| x | | | ~39.7 ms |\n",
"| x | x | | ~6.41 ms |\n",
"| | | x | ~1.42 ms |\n",
"| | x | x | ~0.419 ms |\n",
"\n",
"\n",
"Documentation for the JAX functions used in this notebook:\n",
"* `jnp`: the NumPy API implemented in JAX ([docs](https://jax.readthedocs.io/en/latest/jax.numpy.html?highlight=jax.numpy#module-jax.numpy)).\n",
"* `jit`: just-in-time compilation with XLA ([docs](https://jax.readthedocs.io/en/latest/jax.html?highlight=jit#jax.jit)).\n",
"* `vmap`: vectorizes a function by mapping it over array axes ([docs](https://jax.readthedocs.io/en/latest/jax.html?highlight=vmap#jax.vmap))."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from jax import vmap, jit\n",
"import jax.numpy as jnp"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
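"Generate some random data: `v` is a 100 × 5 matrix whose rows the function will be applied to, and `w` is a single vector of length 5."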
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"v = np.random.normal(size=(100, 5))  # 100 rows, each of length 5\n",
"w = np.random.normal(size=5)  # a single vector of length 5"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
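"Any function that can be applied to the data one row at a time will work here. The function defined below takes two 1D arrays of the same length and returns the sum of their element-wise product (that is, their dot product). It is then wrapped with `jit`, which compiles it with XLA, and with `vmap`, which maps it over the rows of its first argument."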
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
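"def f(a, b):\n",
"    # sum of the element-wise product of two 1D arrays\n",
"    return jnp.sum(jnp.multiply(a, b))\n",
"\n",
"jit_f = jit(f)  # same as decorating a function with @jit\n",
"# in_axes=(0, None): map over axis 0 (the rows) of the first argument\n",
"# and pass the second argument unchanged to every call\n",
"batch_f = vmap(f, in_axes=(0, None))\n",
"batch_jit_f = vmap(jit_f, in_axes=(0, None))"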
]
},
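{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell below is a small, self-contained sketch (not part of the original benchmark) that checks what `vmap(..., in_axes=(0, None))` computes. On a hypothetical 3 × 2 array `x` and vector `y`, the batched result should match an explicit Python loop over the rows of `x` and also the matrix-vector product `x @ y`. The names `dot`, `x`, `y`, `batched`, and `looped` are illustrative only and are not used elsewhere in the notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import jax.numpy as jnp\n",
"from jax import vmap\n",
"\n",
"def dot(a, b):\n",
"    # hypothetical copy of f: sum of the element-wise product\n",
"    return jnp.sum(jnp.multiply(a, b))\n",
"\n",
"x = np.arange(6.0).reshape(3, 2)  # 3 rows of length 2\n",
"y = np.array([1.0, 10.0])\n",
"\n",
"batched = vmap(dot, in_axes=(0, None))(x, y)  # one output per row of x\n",
"looped = jnp.array([dot(row, y) for row in x])  # explicit Python loop\n",
"\n",
"print(batched)  # [10. 32. 54.]\n",
"print(jnp.allclose(batched, looped), jnp.allclose(batched, x @ y))"
]
},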
{
"cell_type": "markdown",
"metadata": {},
"source": [
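"Check that the base function and the jitted function give the same result on the first row of `v`. The printed values agree to within float32 precision."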
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"base function: 0.7014111\n",
"jitted function: 0.701411\n"
]
}
],
"source": [
"print('base function:', f(v[0],w))\n",
"print('jitted function:', jit_f(v[0],w))"
]
},
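{
"cell_type": "markdown",
"metadata": {},
"source": [
"JAX traces and compiles a jitted function the first time it is called with a given combination of argument shapes and dtypes, then reuses the compiled version for later calls that match. The cell below is a minimal sketch (not part of the original benchmark) that times a first and a second call separately; the function `g` and the inputs `a` and `b` are hypothetical. `block_until_ready()` waits for JAX's asynchronous computation to finish, so each measurement covers the actual work. The first call is typically much slower because it includes compilation; exact numbers depend on your machine."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"import jax.numpy as jnp\n",
"from jax import jit\n",
"\n",
"@jit\n",
"def g(a, b):\n",
"    # hypothetical copy of f, compiled via the @jit decorator\n",
"    return jnp.sum(jnp.multiply(a, b))\n",
"\n",
"a = jnp.ones(5)\n",
"b = jnp.arange(5.0)\n",
"\n",
"t0 = time.perf_counter()\n",
"first = g(a, b).block_until_ready()  # trace + compile + run\n",
"t1 = time.perf_counter()\n",
"second = g(a, b).block_until_ready()  # same shapes and dtypes: reuses the compiled code\n",
"t2 = time.perf_counter()\n",
"\n",
"print(f'first call:  {(t1 - t0) * 1e3:.3f} ms')\n",
"print(f'second call: {(t2 - t1) * 1e3:.3f} ms')"
]
},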
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## base loop + base function\n",
"Loop over the rows of `v` in Python and apply the base function to each one."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"39.7 ms ± 336 µs per loop (mean ± std. dev. of 10 runs, 100 loops each)\n"
]
}
],
"source": [
"%%timeit -n 100 -r 10\n",
"result = []\n",
"for i in range(len(v)):\n",
"    result.append(f(v[i], w))\n",
"result = np.array(result)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## base loop + jit\n",
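"Loop over the rows of `v` in Python and apply the jitted function to each one. `jit_f` was already called on `v[0]` in the test cell above, and every row has the same shape and dtype, so these runs reuse the compiled function rather than recompiling it."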
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"6.41 ms ± 80.3 µs per loop (mean ± std. dev. of 10 runs, 100 loops each)\n"
]
}
],
"source": [
"%%timeit -n 100 -r 10\n",
"result = []\n",
"for i in range(len(v)):\n",
"    result.append(jit_f(v[i], w))\n",
"result = np.array(result)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## vmap + base function\n",
"Use `vmap` to map the base function over the first axis (the rows) of `v`. Because the second entry of `in_axes` is `None`, the same `w` is passed to every call."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.42 ms ± 117 µs per loop (mean ± std. dev. of 10 runs, 100 loops each)\n"
]
}
],
"source": [
"%%timeit -n 100 -r 10\n",
"result = batch_f(v, w)"
]
},
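{
"cell_type": "markdown",
"metadata": {},
"source": [
"`vmap` does not run a Python loop over the rows. Instead it transforms the function so that each operation acts on the whole batch at once. One way to see this is `jax.make_jaxpr`, which traces a function and prints the resulting program without evaluating it on real data. In the sketch below (not part of the original notebook; `dot`, `v_small`, and `w_small` are hypothetical), the multiply should act on the full (3, 5) batch and the sum should reduce each row, with no per-row calls. Because `batch_f` above is not jitted, JAX still executes these batched operations one at a time, which may be part of why adding `jit` helps further."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"def dot(a, b):\n",
"    # hypothetical copy of f\n",
"    return jnp.sum(jnp.multiply(a, b))\n",
"\n",
"v_small = np.ones((3, 5), dtype=np.float32)  # 3 rows of length 5\n",
"w_small = np.ones(5, dtype=np.float32)\n",
"\n",
"# Trace the batched function with abstract values and print the program\n",
"jaxpr = jax.make_jaxpr(jax.vmap(dot, in_axes=(0, None)))(v_small, w_small)\n",
"print(jaxpr)"
]
},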
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## vmap + jit\n",
"Use `vmap` to map the jitted function over the first axis (the rows) of `v`, again passing the same `w` to every call."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"419 µs ± 67.3 µs per loop (mean ± std. dev. of 10 runs, 100 loops each)\n"
]
}
],
"source": [
"%%timeit -n 100 -r 10\n",
"result = batch_jit_f(v, w)"
]
}
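,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A caveat on the timings above: JAX dispatches work asynchronously, so a call can return before its result has been computed. The two loop benchmarks wait for every result when they build `np.array(result)`, but the two `vmap` benchmarks do not, so their numbers may partly reflect dispatch time. The cell below is a minimal sketch (not part of the original benchmark) that times each function with `block_until_ready()` after a warm-up call, and also applies `jit` to the whole batched function, `jit(vmap(f, in_axes=(0, None)))`, so the entire batch compiles into one XLA computation. The helper `mean_seconds`, the data `v_data` and `w_data`, and the repeat count are illustrative assumptions; results will depend on your hardware."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"import numpy as np\n",
"import jax.numpy as jnp\n",
"from jax import jit, vmap\n",
"\n",
"def dot(a, b):\n",
"    # hypothetical copy of f\n",
"    return jnp.sum(jnp.multiply(a, b))\n",
"\n",
"# hypothetical stand-ins with the same shapes as v and w\n",
"v_data = np.random.normal(size=(100, 5))\n",
"w_data = np.random.normal(size=5)\n",
"\n",
"batch_dot = vmap(dot, in_axes=(0, None))  # like batch_f\n",
"jit_batch_dot = jit(batch_dot)  # compile the whole batched function at once\n",
"\n",
"def mean_seconds(fn, *args, repeats=100):\n",
"    # hypothetical helper: mean wall-clock seconds per call\n",
"    fn(*args).block_until_ready()  # warm-up call; compiles fn if it is jitted\n",
"    start = time.perf_counter()\n",
"    for _ in range(repeats):\n",
"        fn(*args).block_until_ready()  # wait until the result is computed\n",
"    return (time.perf_counter() - start) / repeats\n",
"\n",
"for name, fn in [('vmap(f)', batch_dot), ('jit(vmap(f))', jit_batch_dot)]:\n",
"    print(f'{name}: {mean_seconds(fn, v_data, w_data) * 1e6:.1f} µs per call')"
]
}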
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}