Neural networks with JAX
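Working with neural networks in JAX is quite different from working with them in PyTorch. A PyTorch model is a stateful `nn.Module` object that owns its parameters. In pure JAX, the parameters are usually plain arrays (or nested containers of arrays) passed explicitly to pure functions, and gradients come from transforming those functions with `jax.grad`. Here is a tutorial on training a simple neural network in pure JAX.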
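To make the pure-JAX style concrete, here is a minimal sketch (not the linked tutorial's code) of a small fully connected network fitted to a toy regression problem with plain gradient descent. The helper names (`init_params`, `forward`, `loss_fn`, `sgd_step`), the layer sizes, the learning rate and the `sin(3x)` target are all hypothetical choices made for illustration. The parameters are a list of `(weights, bias)` tuples, and each update returns a new list instead of modifying the old one in place.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.01  # hypothetical step size for plain gradient descent


def init_params(key, layer_sizes):
    """Return a list of (weights, bias) pairs for a fully connected network."""
    params = []
    keys = jax.random.split(key, len(layer_sizes) - 1)
    for k, n_in, n_out in zip(keys, layer_sizes[:-1], layer_sizes[1:]):
        # He-style scaled normal weights; biases start at zero.
        w = jax.random.normal(k, (n_in, n_out)) * jnp.sqrt(2.0 / n_in)
        params.append((w, jnp.zeros(n_out)))
    return params


def forward(params, x):
    """Apply the network to a batch x of shape (batch, n_in)."""
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return x @ w + b  # linear output layer


def loss_fn(params, x, y):
    """Mean squared error between predictions and targets."""
    return jnp.mean((forward(params, x) - y) ** 2)


@jax.jit
def sgd_step(params, x, y):
    # value_and_grad returns the loss and a gradient pytree shaped like params.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Functional update: a new parameter pytree is built and returned.
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss


# Toy regression problem (hypothetical): fit y = sin(3x) on [-1, 1].
x = jnp.linspace(-1.0, 1.0, 128).reshape(-1, 1)
y = jnp.sin(3.0 * x)
params = init_params(jax.random.PRNGKey(0), [1, 32, 32, 1])

initial_loss = loss_fn(params, x, y)
for _ in range(300):
    params, _ = sgd_step(params, x, y)
final_loss = loss_fn(params, x, y)
print(f"loss: {float(initial_loss):.4f} -> {float(final_loss):.4f}")
```

Because `sgd_step` is a pure function of its inputs, `jax.jit` can compile it once and reuse it on every iteration. The training loop simply rebinds `params` to the returned value.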
Several libraries, such as Equinox and Flax, make working with neural networks in JAX easier. In this course, we recommend Equinox; here is a tutorial on training a convolutional neural network on the MNIST dataset using Equinox.