Exploring Differentiable Programming
I have read a fair amount about differentiable programming, so I sort-of know what I’m talking about in broad strokes. However, I don’t really know what it takes in practice. How much data do you have to move around the system? What are the actual operations? How do you make a selection cut that is differentiable? How do you use JAX?
This book is me trying to teach myself step by step, so it is very basic! Comments and pull requests are welcome at the GitHub repo!
Goals
Construct a simple \(S/\sqrt{B}\) problem that needs to be optimized - a single selection cut, and a single signal and background. Solve it by brute force.
Learn the basics of JAX as a numpy replacement.
Figure out how to make a hard selection cut (data[data > cut]) differentiable w.r.t. cut.
Write a very simple gradient descent loop using JAX tools to solve this problem.
Compare \(S/\sqrt{B}\) to using a more standard ML loss function.
Explore how predicate push-down into a system that does not understand auto-diff might work with differentiable programming.
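To make the first few goals concrete, here is a minimal sketch of where this is heading. Everything in it is an assumption made for illustration, not something fixed by the text above: the toy distributions (Gaussian signal, flat background), the helper names (make_toy_data, soft_count, significance), the smoothing width, and the learning rate. The hard cut is replaced by a sigmoid weight, which is one common way to get a non-zero gradient with respect to the cut. It may not be the approach the later chapters settle on.

```python
import jax
import jax.numpy as jnp


def make_toy_data(key, n_sig=2000, n_bkg=20000):
    """Hypothetical toy sample: Gaussian signal on top of a flat background."""
    k_sig, k_bkg = jax.random.split(key)
    sig = 7.0 + jax.random.normal(k_sig, (n_sig,))
    bkg = jax.random.uniform(k_bkg, (n_bkg,), minval=0.0, maxval=10.0)
    return sig, bkg


def soft_count(data, cut, width=0.2):
    """Smooth stand-in for len(data[data > cut]).

    Each event gets a sigmoid weight: close to 1 well above the cut,
    close to 0 well below it. The width is an assumed smoothing scale.
    """
    return jnp.sum(jax.nn.sigmoid((data - cut) / width))


def significance(cut, sig, bkg):
    """S / sqrt(B) computed from the soft counts."""
    return soft_count(sig, cut) / jnp.sqrt(soft_count(bkg, cut))


sig, bkg = make_toy_data(jax.random.PRNGKey(0))

# Brute force: evaluate the significance on a grid of cut values.
cuts = jnp.linspace(0.0, 9.0, 181)
scan = jax.vmap(lambda c: significance(c, sig, bkg))(cuts)
best_scan_cut = cuts[jnp.argmax(scan)]

# Gradient ascent: step the cut in the direction that increases S / sqrt(B).
grad_fn = jax.jit(jax.grad(significance))  # derivative w.r.t. the first argument (cut)
cut = jnp.asarray(4.0)
for _ in range(300):
    cut = cut + 0.05 * grad_fn(cut, sig, bkg)

print("brute-force cut:", float(best_scan_cut))
print("gradient cut:   ", float(cut))
```

The two printed cut values should land close to each other, since both are maximizing the same smoothed function. A narrower width makes soft_count closer to the hard cut but gives noisier gradients, so the width is a trade-off rather than a free choice.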