Motivation: This repo was originally a series of lessons introducing Jax concepts, and it can still be used that way. However, it has since grown beyond that to include example code solving various real-world problems. These case studies can serve as a follow-up lesson for those who have finished the introductory lessons.
Blurb: Jax is often thought of as Numpy for the GPU, but it is so much more, both in terms of features and sharp edges. The tutorials presented here (one aimed at a general audience, the other at computational neuroscientists) were inspired by a roadblock I encountered in my research: an LIF (leaky integrate-and-fire) simulation that, despite using vectorized Numpy, took excessively long to run. By incorporating Jax into my workflow and iterating on it, I reduced the runtime from ~10 seconds to ~0.2 seconds.
The `exercises` folder contains the code, structured as a series of exercises for you to work through to reinforce the concepts (a short illustrative sketch follows the list below):
- Using `jit` - Understanding when to use `jit`, a.k.a. why not `jit` everything?
- Timing `jax` - In prior notebooks we introduced methods to speed up code, including JIT compilation. Let's investigate whether, and by how much, they actually speed up code!
- `fori_loop`, `while_loop`, `scan` - Reading Haskell-like function signatures, and making your code look more like the math described in the papers.
- `einsum` - Einsum isn't specific to Jax, but it's still useful to know!
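The sketch below is illustrative rather than taken from any notebook (the toy `update` function and the array sizes are made up); it shows `jit`, honest timing with `block_until_ready`, a `lax.scan` recurrence, and `einsum` side by side:

```python
# Illustrative sketch only: jit, timing, lax.scan, and einsum on toy data.
import time

import jax
import jax.numpy as jnp
from jax import lax


def update(x):
    # Toy elementwise update; stands in for a real simulation step.
    return jnp.tanh(x) * 0.5 + x * 0.1


fast_update = jax.jit(update)
x = jnp.ones((1_000_000,))

# Timing JAX honestly: the first jitted call includes compilation, and results
# are returned asynchronously, so call block_until_ready() before stopping the clock.
fast_update(x).block_until_ready()  # warm-up / compile
start = time.perf_counter()
fast_update(x).block_until_ready()
print(f"jitted call took {time.perf_counter() - start:.6f}s")


# lax.scan lets the code mirror a recurrence x_{t+1} = f(x_t) from a paper.
# Its docs give a Haskell-like signature:
#   scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
def step(carry, _):
    new_carry = update(carry)
    return new_carry, new_carry[0]  # (next carry, per-step value to stack)


final_state, trace = lax.scan(step, x, xs=None, length=100)
print(final_state.shape, trace.shape)  # (1000000,) (100,)

# einsum is not JAX-specific, but it reads like the index notation in papers:
A = jnp.ones((8, 3, 4))
B = jnp.ones((8, 4, 5))
C = jnp.einsum("bij,bjk->bik", A, B)  # batched matrix product, shape (8, 3, 5)
```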
Case studies build on the exercises and rely on concepts covered in the lessons, applying them to real-world problems (illustrative sketches follow the list below):
- Randomness in JAX - Reproducible randomness across machines and across accelerators.
- `pmap` - Introduction to parallel execution, with examples using multiple devices. That said, this should still work even with a single device.
- `grad` - `grad` and `grad_and_val` to get the gradient.
- GMM Advanced - Add in `rng` and `pmap` to a separate notebook.
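A minimal sketch of the randomness and `pmap` ideas, assuming only the standard `jax.random` and `jax.pmap` APIs (the seed and shapes here are arbitrary):

```python
# Illustrative only: explicit PRNG keys give reproducible randomness, and pmap
# replicates work across devices (it degenerates gracefully to a single device).
import jax

key = jax.random.PRNGKey(0)                  # same seed -> same stream everywhere
key, sub_w, sub_n = jax.random.split(key, 3)

weights = jax.random.normal(sub_w, (4, 4))   # reproducible across machines and accelerators
noise = jax.random.uniform(sub_n, (4,))

n_dev = jax.local_device_count()
per_device_keys = jax.random.split(key, n_dev)
# One independent sample per device; with a single device this is just a batch of 1.
samples = jax.pmap(lambda k: jax.random.normal(k, (2,)))(per_device_keys)
print(weights[0, 0], noise[0], samples.shape)  # samples.shape == (n_dev, 2)
```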
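The notebook refers to `grad_and_val`; in the JAX API the corresponding functions are `jax.grad` and `jax.value_and_grad`. A minimal sketch with a made-up least-squares loss:

```python
# Illustrative sketch (toy loss, made-up data): grad and value_and_grad.
import jax
import jax.numpy as jnp


def loss(w, x, y):
    pred = x @ w                      # linear model standing in for the case-study models
    return jnp.mean((pred - y) ** 2)  # least-squares loss


x = jnp.ones((10, 3))
y = jnp.zeros((10,))
w = jnp.ones((3,))

grad_fn = jax.grad(loss)                      # gradient w.r.t. the first argument (w)
value_and_grad_fn = jax.value_and_grad(loss)  # loss value and gradient in one pass

g = grad_fn(w, x, y)
val, g2 = value_and_grad_fn(w, x, y)
print(val, g.shape)  # scalar loss, gradient with shape (3,)
```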
If you use this software in your research, please cite it as follows:
```bibtex
@misc{numpy_to_jax,
  title = {Numpy To Jax},
  author = {Ian Quah},
  year = {2024},
  url = {https://github.com/IanQS/numpy_to_jax},
  version = {1.0.0},
  note = {Jax is often thought of as Numpy for the GPU, but it is so much more (both in terms of features, and sharp edges). The tutorials presented here—one aimed at a general audience and the other at computational neuroscientists—were inspired by a roadblock I encountered in my research}
}
```