An Introduction to Neural Ordinary Differential Equations [pdf]
diposit.ub.edu
78 points by gballan 3 days ago
Looks like a nice overview. I've implemented neural ODEs in JAX for low-dimensional problems and it works well, but I keep looking for a fast, CPU-first implementation suited to models that fit in cache and don't need a GPU or heavy Torch/TF machinery.
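For anyone wondering what a minimal CPU-only version looks like without extra libraries, here is a rough sketch in plain JAX (not diffrax). It uses a tiny tanh MLP as the vector field, a hand-rolled fixed-step RK4 integrator built on `jax.lax.scan`, and plain SGD through the solver. The helper names (`init_mlp`, `odeint_rk4`, `sgd_step`), the step count, and the learning rate are all assumptions for illustration. `jax.jit` compiles the whole training step into one XLA program, which runs on the CPU when no accelerator is available.

```python
import jax
import jax.numpy as jnp

LR = 1e-2  # hypothetical learning rate


def init_mlp(key, sizes):
    """Initialise a small MLP as a list of (W, b) pairs."""
    keys = jax.random.split(key, len(sizes) - 1)
    params = []
    for k, (n_in, n_out) in zip(keys, zip(sizes[:-1], sizes[1:])):
        w = jax.random.normal(k, (n_in, n_out)) / jnp.sqrt(n_in)
        params.append((w, jnp.zeros(n_out)))
    return params


def vector_field(params, t, y):
    """Learned dynamics f_theta(t, y); t is unused (autonomous system)."""
    h = y
    for w, b in params[:-1]:
        h = jnp.tanh(h @ w + b)
    w, b = params[-1]
    return h @ w + b


def odeint_rk4(f, y0, t0, t1, n_steps):
    """Fixed-step classical RK4 via lax.scan; returns only y(t1)."""
    dt = (t1 - t0) / n_steps

    def step(carry, _):
        t, y = carry
        k1 = f(t, y)
        k2 = f(t + dt / 2, y + dt / 2 * k1)
        k3 = f(t + dt / 2, y + dt / 2 * k2)
        k4 = f(t + dt, y + dt * k3)
        y_next = y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        return (t + dt, y_next), None

    t_init = jnp.asarray(t0, dtype=y0.dtype)  # keep carry dtypes consistent
    (_, y1), _ = jax.lax.scan(step, (t_init, y0), None, length=n_steps)
    return y1


def loss(params, y0, y_target):
    """Squared error between the ODE solution at t=1 and a target state."""
    f = lambda t, y: vector_field(params, t, y)
    y1 = odeint_rk4(f, y0, 0.0, 1.0, n_steps=20)
    return jnp.mean((y1 - y_target) ** 2)


@jax.jit
def sgd_step(params, y0, y_target):
    """One SGD step; gradients flow back through every RK4 step."""
    grads = jax.grad(loss)(params, y0, y_target)
    return jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)


# Usage: learn dynamics that carry (1, 0) towards (0, 1) over t in [0, 1].
params = init_mlp(jax.random.PRNGKey(0), [2, 16, 2])
y0, y_target = jnp.array([1.0, 0.0]), jnp.array([0.0, 1.0])
for _ in range(100):
    params = sgd_step(params, y0, y_target)
```

Gradients here come from backpropagating through the 20 RK4 steps (discretise-then-optimise), so memory grows with the step count. The adjoint method from the original neural ODE paper trades that memory for a second backward solve. For anything beyond toy problems, diffrax offers adaptive solvers and configurable adjoint methods.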
Did you use https://github.com/patrick-kidger/diffrax ?
JAX Talk: Diffrax https://www.youtube.com/watch?v=Jy5Jw8hNiAQ
Anecdotally, I used diffrax (and equinox) throughout last year for a project based on Dynamic Field Theory [1], after jumping between a few differential equation solvers in Python. I only scratched the surface, but so far it's been a pleasure to use, and it's quite fast. It also introduced me to equinox [2], by the same author, which I'm using to get JAX-friendly equivalents of dataclasses.
`vmap`-able differential equation solving is really cool.
[1]: https://dynamicfieldtheory.org/
[2]: https://github.com/patrick-kidger/equinox
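To make the `vmap` point concrete, here is a rough sketch in plain JAX rather than diffrax. It assumes a 1-D Amari-style neural field (the field equation Dynamic Field Theory builds on), integrated with explicit Euler via `jax.lax.scan`, and batches the whole simulation over several stimulus amplitudes with `jax.vmap`. The grid, kernel shape, and every parameter value are made up for illustration and are not taken from the project described above.

```python
import jax
import jax.numpy as jnp

# Hypothetical 1-D field on a grid; all values are illustrative, not tuned.
N = 101
x = jnp.linspace(-10.0, 10.0, N)
dx = x[1] - x[0]
TAU = 10.0   # time constant
H = -5.0     # resting level
BETA = 4.0   # steepness of the sigmoid output function

# Interaction kernel: local Gaussian excitation minus constant global inhibition,
# pre-multiplied by dx so that W @ rate approximates the integral over x'.
dist = x[:, None] - x[None, :]
W = (8.0 * jnp.exp(-dist**2 / 2.0) - 1.0) * dx


def rhs(u, amplitude):
    """tau du/dt = -u + h + s(x) + int w(x - x') sigma(u(x')) dx'."""
    stimulus = amplitude * jnp.exp(-x**2 / 2.0)  # Gaussian input centred at x = 0
    return (-u + H + stimulus + W @ jax.nn.sigmoid(BETA * u)) / TAU


def simulate(amplitude, dt=1.0, n_steps=500):
    """Explicit Euler from the resting state; returns the field after n_steps."""
    def step(u, _):
        return u + dt * rhs(u, amplitude), None

    u0 = jnp.full(N, H, dtype=jnp.float32)
    u_final, _ = jax.lax.scan(step, u0, None, length=n_steps)
    return u_final


# vmap turns the single-trajectory solver into a batched one over amplitudes.
simulate_batch = jax.jit(jax.vmap(simulate))
amplitudes = jnp.array([0.0, 3.0, 8.0])
fields = simulate_batch(amplitudes)  # shape (3, N): one final field per amplitude
```

With these made-up values, the zero-amplitude run should stay at the resting level, while the strongest stimulus should push the centre of the field above zero and recruit the excitatory kernel into a localised peak. The same pattern extends to batching over initial conditions or parameter pytrees.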
Thanks, that looks neat.
Kidger's thesis is wonderful https://arxiv.org/abs/2202.02435