Differentiate this for me

Simple ideas in math often turn out to be the go-to tools for harder problems. In optimization (as in machine learning), gradient descent is hugely relevant: it is the basic method for finding a (local) minimum of a function, usually called the loss function in machine learning. It generates a sequence of points that converges to the minimum, provided that each step moves the current point in a descent direction, which is usually built from the gradient of the function. But to compute the gradient, we first need to differentiate the function.
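
As a minimal sketch of that sequence of points, assuming a hypothetical two-variable quadratic loss whose gradient we write by hand, plain gradient descent looks like this (the loss, the starting point, the step size, and the iteration count are all illustrative choices):

```python
# Gradient descent on a hypothetical quadratic loss with its minimum at (1, -3).
def loss(p):
    x, y = p
    return (x - 1.0) ** 2 + 2.0 * (y + 3.0) ** 2


def grad(p):
    # Partial derivatives of `loss`, computed by hand.
    x, y = p
    return (2.0 * (x - 1.0), 4.0 * (y + 3.0))


def gradient_descent(p0, step=0.1, iters=200):
    p = p0
    for _ in range(iters):
        g = grad(p)
        # Step against the gradient, i.e. in a descent direction.
        p = tuple(pi - step * gi for pi, gi in zip(p, g))
    return p


p_star = gradient_descent((5.0, 5.0))
print(p_star, loss(p_star))
```

With `step=0.1`, the distance to the minimizer shrinks by a constant factor per iteration (0.8 in the x-coordinate, 0.6 in the y-coordinate), so the iterates converge geometrically; with a step above 0.5, the y-coordinate would diverge instead.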

When the functions are simple enough, we can differentiate them by hand: there is nothing new in applying the chain rule repeatedly, which reduces the derivative of a composition to a product of derivatives of simpler functions that we already know how to differentiate. It turns out that neural networks are, mostly, iterated compositions of simple functions (such as activation functions and affine transformations), so differentiating them by hand is straightforward. And if we can do it by hand, we can also do it by computer...
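
To make the chain rule concrete, here is a small sketch assuming a hypothetical one-dimensional "network" made of an affine map, a `tanh` activation, and a square. The derivative is assembled by hand as a product of the three local derivatives and then compared against a central finite difference (the parameters `a`, `b` and the point `x0` are arbitrary):

```python
import math

# Hypothetical parameters of the affine map z = a * x + b.
a, b = 1.5, -0.5


def affine(x):
    return a * x + b


def activation(z):
    return math.tanh(z)


def square(u):
    return u * u


def f(x):
    # The composition square(tanh(a * x + b)).
    return square(activation(affine(x)))


def df(x):
    # Chain rule: f'(x) = square'(u) * tanh'(z) * affine'(x),
    # with z = a * x + b, u = tanh(z), and tanh'(z) = 1 - tanh(z)^2.
    z = affine(x)
    u = activation(z)
    return (2.0 * u) * (1.0 - u * u) * a


def central_difference(g, x, h=1e-6):
    # Numerical derivative, used only as a sanity check.
    return (g(x + h) - g(x - h)) / (2.0 * h)


x0 = 0.3
print(df(x0), central_difference(f, x0))
```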

In Higham, C. F., and Higham, D. J. (2019), the authors describe how to update the parameters of a neural network by computing the derivative of the loss function with respect to those parameters, and it is mainly a matter of applying the chain rule!
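
A minimal sketch of such a parameter update, loosely in that spirit: a single sigmoid neuron with a quadratic cost, where the chain rule gives the derivative of the loss with respect to the weight `w` and the bias `b`. The toy data, learning rate `eta`, and number of steps are hypothetical choices for illustration, not taken from the paper:

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


# Hypothetical toy data: pairs (input x, target t) with t in (0, 1).
data = [(-2.0, 0.1), (-1.0, 0.2), (1.0, 0.8), (2.0, 0.9)]


def loss(w, b):
    # Quadratic cost: half the sum of squared residuals.
    return 0.5 * sum((sigmoid(w * x + b) - t) ** 2 for x, t in data)


def grad(w, b):
    gw = gb = 0.0
    for x, t in data:
        y = sigmoid(w * x + b)
        # dL/dz for this sample: (y - t) * sigmoid'(z), with sigmoid'(z) = y * (1 - y).
        delta = (y - t) * y * (1.0 - y)
        gw += delta * x  # chain rule: dz/dw = x
        gb += delta      # chain rule: dz/db = 1
    return gw, gb


w, b, eta = 0.0, 0.0, 0.5
initial = loss(w, b)
for _ in range(500):
    gw, gb = grad(w, b)
    # Gradient descent step on the parameters.
    w, b = w - eta * gw, b - eta * gb
final = loss(w, b)
print(initial, final)
```

A full network repeats the same pattern layer by layer, passing each `delta` backwards through the affine maps; that is the backpropagation bookkeeping.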

Running down the hill

Another aspect[1] I don't want to delve into is the descent algorithm itself. Gradient descent is nice, but many other algorithms can be used to find the minimum of a function. A common framework nowadays is to minimize

$$\min_x \mathbb{E}\, f_t(x)$$

where $f_t$ is a function that depends on the parameters $x$ and the index $t$. The main goal now is to minimize the regret of the algorithm, which is defined as

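$$\mathsf{Regret} = \sum_{t=1}^T f_t(x_t) - \min_x \sum_{t=1}^T f_t(x).$$

Here $x_t$ is the point chosen by the algorithm at step $t$, before $f_t$ is revealed, and the second sum is the total loss of the best fixed point in hindsight.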
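
As a small worked instance of the regret, assume hypothetical losses $f_t(x) = (x - c_t)^2$ for a made-up stream of targets $c_t$, played by online gradient descent with a decaying step size (an assumed, common choice). The best fixed point in hindsight is then the mean of the $c_t$, so the second sum can be computed exactly:

```python
import math

# Hypothetical stream of targets c_t; round t reveals f_t(x) = (x - c_t)^2.
targets = [1.0, -1.0, 2.0, 0.5, 1.5, -0.5, 1.0, 0.0, 2.0, 1.0]


def f(t, x):
    return (x - targets[t]) ** 2


def grad_f(t, x):
    return 2.0 * (x - targets[t])


def online_gradient_descent(x0=0.0):
    xs, x = [], x0
    for t in range(len(targets)):
        xs.append(x)  # commit to x_t before f_t is revealed
        # t is zero-based here, so this is 0.5 / sqrt(t) in 1-based terms.
        eta = 0.5 / math.sqrt(t + 1)
        x = x - eta * grad_f(t, x)
    return xs


xs = online_gradient_descent()
T = len(targets)
# The mean of the targets minimizes the sum of squared distances.
x_best = sum(targets) / T
regret = sum(f(t, xs[t]) for t in range(T)) - sum(f(t, x_best) for t in range(T))
print(regret)
```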

A great summary and explanation of how the gradients are used in a cumulative way can be found in Reddi, Sashank J., Satyen Kale, and Sanjiv Kumar (2019).
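
One way to picture "using gradients cumulatively" is an AMSGrad-style update, the Adam variant proposed by Reddi, Kale, and Kumar: exponential moving averages of past gradients and squared gradients, with a running maximum of the latter. The sketch below is a simplified scalar version without bias correction; the hyperparameters and the test function $(x - 3)^2$ are illustrative assumptions:

```python
import math


def amsgrad(grad, x0, alpha=0.1, beta1=0.9, beta2=0.99, eps=1e-8, steps=2000):
    x = x0
    m = 0.0      # moving average of gradients (first moment)
    v = 0.0      # moving average of squared gradients (second moment)
    v_hat = 0.0  # running maximum of v: the AMSGrad change with respect to Adam
    for _ in range(steps):
        g = grad(x)
        m = beta1 * m + (1.0 - beta1) * g
        v = beta2 * v + (1.0 - beta2) * g * g
        v_hat = max(v_hat, v)
        x = x - alpha * m / (math.sqrt(v_hat) + eps)
    return x


# Minimize (x - 3)^2, whose gradient is 2 * (x - 3).
x_min = amsgrad(lambda x: 2.0 * (x - 3.0), x0=0.0)
print(x_min)
```

Because `v_hat` never decreases, the effective step size `alpha / sqrt(v_hat)` never increases, which is the property the AMSGrad analysis relies on.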

Roadmap to FFI4JAX

I like to learn by example, so while reading about JAX I ran into its foreign function interface (FFI), which allows us to call C++ functions from JAX code. I already knew about the Enzyme AD project, an LLVM-based automatic differentiation tool that can differentiate C++ code, so I thought it would be a good idea to combine the two!

It is possible to write complex C++ functions that Enzyme can differentiate automatically, wrap them with the XLA FFI, compile them into a Python extension module with Nanobind, and then call them from JAX code.
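
JAX cannot trace inside an opaque FFI call, so the derivative of such a function has to be supplied to it explicitly (for instance through `jax.custom_vjp`), and an Enzyme-generated gradient is a natural source for it. Before wiring that in, a cheap sanity check is to compare the supplied gradient with central finite differences. The sketch below uses plain Python stand-ins: `foreign_energy`, `foreign_energy_grad`, and `max_gradient_error` are hypothetical names, not part of any of these libraries:

```python
import math


# Hypothetical stand-in for a C++ function exposed through the XLA FFI.
def foreign_energy(xs):
    return sum(x * math.sin(x) for x in xs)


# Hypothetical stand-in for the gradient Enzyme would generate for it.
def foreign_energy_grad(xs):
    # d/dx [x * sin(x)] = sin(x) + x * cos(x)
    return [math.sin(x) + x * math.cos(x) for x in xs]


def max_gradient_error(f, grad_f, xs, h=1e-6):
    """Largest coordinate-wise gap between grad_f and central finite differences."""
    analytic = grad_f(xs)
    worst = 0.0
    for i in range(len(xs)):
        plus = list(xs)
        plus[i] += h
        minus = list(xs)
        minus[i] -= h
        numeric = (f(plus) - f(minus)) / (2.0 * h)
        worst = max(worst, abs(numeric - analytic[i]))
    return worst


print(max_gradient_error(foreign_energy, foreign_energy_grad, [0.1, -1.2, 2.5]))
```

A small error only says the two derivatives agree at the sampled points; it is a smoke test, not a proof that the compiled gradient is correct everywhere.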

Right now, after getting past some CMake impasses, I have a few working examples in a private repository. This feels quite promising, and it is a good excuse to learn more about JAX, Enzyme, C++, LLVM, Nanobind, FFI, and maybe some math as well!


[1] See this video...
CC BY-SA 4.0 Gabriel Pinochet-Soto. Last modified: July 04, 2025. Website built with Franklin.jl and the Julia programming language.