Why is numerical analysis still a big issue?
Hint: it's the speed.
It’s a Saturday at the start (middle?) of perf season, so like many others I’m procrastinating. My manager and reports probably won’t appreciate my spending an hour writing this. Nevertheless, there was an interesting thread on Hacker News about NSan, a numerical sanitizer you can enable in LLVM when running your unit tests. I haven’t read it yet… the idea doesn’t seem to be new, though the fact that it’s fast and may end up in LLVM is awesome.
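As I understand it, the basic idea behind a numerical sanitizer is shadow computation. Every floating-point operation is repeated in higher precision, and you flag results where the two copies disagree. NSan does this automatically through compiler instrumentation. The toy below does it by hand in pure Python instead: float32 is emulated with `struct`, and a float64 value serves as the shadow. The class name `Shadowed`, the `suspicious` helper, and its tolerance are all hypothetical and exist only for illustration. None of this is NSan's actual interface.

```python
import struct


def to_f32(x):
    """Round a binary64 Python float to the nearest binary32 value."""
    return struct.unpack('<f', struct.pack('<f', x))[0]


class Shadowed:
    """Hypothetical: a float32 'application' value carried alongside a float64 shadow."""

    def __init__(self, value, shadow=None):
        # With no explicit shadow, the shadow starts from the unrounded input.
        self.shadow = value if shadow is None else shadow
        # Computing in float64 and then rounding once gives the correctly
        # rounded float32 result for +, - and *.
        self.value = to_f32(value)

    def __add__(self, other):
        return Shadowed(self.value + other.value, self.shadow + other.shadow)

    def __sub__(self, other):
        return Shadowed(self.value - other.value, self.shadow - other.shadow)

    def __mul__(self, other):
        return Shadowed(self.value * other.value, self.shadow * other.shadow)

    def relative_error(self):
        """Relative disagreement between the float32 result and its float64 shadow."""
        if self.shadow == 0.0:
            return abs(self.value)
        return abs(self.value - self.shadow) / abs(self.shadow)


def suspicious(x, tolerance=1e-5):
    """Hypothetical check: flag results whose float32 value drifted from the shadow."""
    return x.relative_error() > tolerance


one, tiny = Shadowed(1.0), Shadowed(1e-8)
cancelled = (one + tiny) - one  # float32 gives 0.0; the float64 shadow keeps ~1e-8
benign = Shadowed(0.5) * Shadowed(0.25)  # exact in both precisions
```

The cancellation case is the classic failure: `1 + 1e-8` rounds to exactly `1.0` in float32, so the subtraction loses everything. The shadow catches it without anyone having to reason about the expression by hand.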
I commented on the thread and wanted to expand on that comment here. ggm followed up, mentioning the great care the NAG library maintainers took to make sure their routines were rigorously correct:
If they stuffed it up, bridges fell down during construction or rockets blew up, and they knew it.
This is absolutely true, and folks doing mechanical engineering, working with EDA tools, solving PDEs, etc., need to be confident that the numerical libraries they use are correct up to machine precision. They’re then responsible for combining those calculations in loops, data structures, etc., in such a way that precision does not degrade below their required tolerances (this is something like “business logic” interacting with numerical computing).
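Here is one classic example of that "business logic" side: summing many values in a loop. This is my own illustration and does not come from the thread. Naive accumulation silently drops terms that are smaller than half an ulp of the running total. Kahan (compensated) summation carries the lost low-order bits forward. The data below is made up to trigger the effect in double precision.

```python
import math


def naive_sum(values):
    """Plain left-to-right accumulation."""
    total = 0.0
    for v in values:
        total += v
    return total


def kahan_sum(values):
    """Compensated (Kahan) summation: carries the rounding error of each add forward."""
    total = 0.0
    compensation = 0.0  # running estimate of the low-order bits lost so far
    for v in values:
        y = v - compensation  # re-inject what was lost on the previous step
        t = total + y
        # (t - total) is what was actually added; subtracting y leaves the rounding error.
        compensation = (t - total) - y
        total = t
    return total


# 1.0 followed by ten terms that are each below half an ulp of 1.0.
values = [1.0] + [1e-16] * 10
```

Python's `math.fsum` returns a correctly rounded sum, so it makes a convenient reference. Here `naive_sum(values)` stays at exactly `1.0`, while `kahan_sum(values)` lands on (or within an ulp of) `math.fsum(values)`.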
But Deep Learning folks don’t care. At least not until it bites them in the butt. And in my extremely limited view of the software engineering industry, DL is the hottest and most important area. Right? I mean, just look at how much ML researchers make!
Well, obviously some deep learning folks care. Big shout out to the TFP team and their PyTorch counterparts[1] for (almost) always caring about numerical precision. A bunch of TFP folks even submit HLO implementations[2] for special functions so that their code can be XLA jitted as well as run on TPUs. And thanks to folks like the ruy developers for making sure TFLite matmuls are accurate on edge devices. Here you’re seeing people who care about numerical precision, performance, and neural networks all at once. One can only hope they get paid handsomely for critical ML infra work that (I’m guessing) would be boring for most CS majors. I certainly care; I spent a lot of dev cycles porting various cephes special functions to Eigen[3] so we could expose them in TF on CPUs and GPUs, and using them in a library that eventually became TFP.
But most don’t care. The primary thing they care about is speed.
“Theoretical” NN researchers who want to build new NN architectures or describe and solve interesting problems. They don’t usually need more than a few bits of precision in their output. For example, Mixed Precision Training from Baidu+NVIDIA (convincingly) made the case that you can get comparable model quality using mostly float16 (IEEE half precision)[4]. These researchers care about the time it takes to train a new model and compare it to a baseline (and about training bigger models on the same GPUs). In other words, lower precision is a feature, not a bug. And if a model doesn’t work, they have a big pool of creative modifications that implicitly correct for bad precision. Residual networks, various normalization layers, mixture-of-experts, Adam, dropout, weight decay, … all of these tricks probably help to correct for some numerical precision issues, or just work around them. We’re still in the golden age of graduate student gradient descent[5], where neat architecture ideas can make a big difference in model quality, grow the field, and (last but not least) help you publish your next paper.
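Why keep "a few key accumulators" in float32? A toy experiment shows it. The sketch below sums 10,000 small float16 terms twice: once rounding the running total to float16 after every add, and once rounding it to float32. I use IEEE float16 only because Python's `struct` module can round to it (format `'e'`). bfloat16 has even fewer mantissa bits, so it would stall sooner. The terms are made up, and this is not how any framework implements its kernels.

```python
import struct


def round_f16(x):
    """Round a binary64 Python float to the nearest IEEE-754 binary16 value."""
    return struct.unpack('<e', struct.pack('<e', x))[0]


def round_f32(x):
    """Round a binary64 Python float to the nearest IEEE-754 binary32 value."""
    return struct.unpack('<f', struct.pack('<f', x))[0]


def accumulate(values, round_acc):
    """Sum values, rounding the running total to the accumulator's format after every add."""
    total = 0.0
    for v in values:
        total = round_acc(total + v)
    return total


# 10,000 small terms, each already stored in float16 (about 0.0010004).
terms = [round_f16(1e-3)] * 10_000

f16_total = accumulate(terms, round_f16)  # stalls at 4.0
f32_total = accumulate(terms, round_f32)  # close to the true sum, about 10.004
```

The float16 total gets stuck because at 4.0 the spacing between representable values is 2**-8 (about 0.0039). Adding roughly 0.001 is therefore less than half a step and rounds straight back to 4.0. The float32 accumulator has plenty of spare bits at that magnitude.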
“More applied” researchers who primarily focus on bringing NN models to real products[6]. Speed is just as important here. Oftentimes, product needs come down to classification, ranking, regression, and contextual bandit models being good enough to improve the current metric. How many bits of precision do you need for your classifier? Hint: just quantize better and you can often do it all in int8, maybe after distilling the big model. You have lots of data. You care more about being able to train on the data firehose before the next model gets pushed out in 6, 12, or 24 hours. And about long-tail embedding distributions. And about making inference work within budget using CPUs, GPUs, TPUs, and ScaNN.
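To see how little precision a classifier can get away with, consider the simplest scheme I know: symmetric, per-tensor int8 quantization. Every value shares one scale, chosen so the largest magnitude maps to 127. This is a toy sketch with made-up logits and hypothetical helper names. Real quantizers (per-channel scales, calibration, zero points, quantization-aware training) are more involved, and nothing here reflects a particular library's API.

```python
def quantize_int8(xs):
    """Symmetric, per-tensor int8 quantization. Returns (codes, scale)."""
    max_abs = max(abs(x) for x in xs)
    # One scale maps the largest magnitude to 127; guard against an all-zero tensor.
    scale = max_abs / 127.0 if max_abs > 0.0 else 1.0
    codes = [max(-127, min(127, round(x / scale))) for x in xs]
    return codes, scale


def dequantize_int8(codes, scale):
    """Map int8 codes back to approximate real values."""
    return [q * scale for q in codes]


def argmax(xs):
    """Index of the largest element (the predicted class)."""
    return max(range(len(xs)), key=lambda i: xs[i])


logits = [2.3, -1.1, 0.7, 4.05, -3.2]  # made-up classifier scores
codes, scale = quantize_int8(logits)
approx = dequantize_int8(codes, scale)
# Each value is off by at most scale / 2; here the predicted class is unchanged.
```

With 8 bits, each logit is off by at most half a quantization step. For ranking or picking the top class, that is usually far below the gap between competing scores.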
Research organization leads care about two things[7]: the speed with which their researchers can experiment with ideas (and help client organizations get research into products), and “OMG how much am I gonna have to pay for all those new GPUs?”[8]
There you have it. Hopefully I’ve convinced you that speed takes priority over numerical accuracy for most folks working in the new hotness, and that this is probably just fine.
That said, with numerical computing, NNs, etc. increasingly needing to move towards SIMD (and its beefier cousins on GPU and TPU), the lack of comprehensive work on keeping computations accurate to machine epsilon creates all sorts of problems. Let’s take Softplus and XLA as an example. Just a couple of PRs over the years:
(There are a couple of others; this was a quick search.) You can bet a few internal and external bugs were filed that led to each of these fixes. How many cumulative SWE-hours (days?) did it take to go from “This model didn’t port well to bfloat16” slash “Why is the output wonky when I move to JAX/enable JIT compilation (god, I just want it to be fast!)” to each fix?
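For concreteness, here is the textbook version of the Softplus problem in pure Python. Python floats are double precision, so the failures show up at larger |x| than they would in float32 or bfloat16. The naive formula log(1 + exp(x)) has two problems: it overflows for large inputs, and it returns exactly zero for very negative ones. Rewriting it as max(x, 0) + log1p(exp(-|x|)) avoids both. This is one common stable formulation and not necessarily what those XLA PRs actually landed.

```python
import math


def softplus_naive(x):
    """log(1 + e^x) as written.

    Overflows for large x, and 1 + e^x rounds to 1 for very negative x.
    """
    return math.log(1.0 + math.exp(x))


def softplus_stable(x):
    """Algebraically identical: log(1 + e^x) = max(x, 0) + log1p(e^(-|x|)).

    exp(-|x|) never exceeds 1, so it cannot overflow, and log1p keeps the
    tiny result for very negative x instead of rounding it to zero.
    """
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))
```

At x = -40 the naive version returns 0.0, while the stable one returns about 4.25e-18 (essentially exp(-40)). At x = 1000 the naive version raises `OverflowError` inside `math.exp`, while the stable one returns 1000.0. For moderate inputs the two agree to within rounding.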
The trade-off is real. I recently heard about an applied contextual bandits team telling the TFP folks that the LambertW implementation blows up and slows down their XLA CPU-jitted code. The current implementation uses a broadcasted Halley’s method to iterate the value until convergence. My initial thought was to just set maximum_iterations to 2 or 4 for the loop. After all, from a decent initial guess, Halley’s method can probably get you within machine epsilon in 2-4 iterations. Well, as it turns out, even with a single while_loop iteration the XLA HLO balloons. Maybe the stopping criterion somehow gets compiled wonkily? The only way to find out is to debug the HLO, and that’s above my pay grade.
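Here is what the "just cap the iterations" idea looks like in scalar form: a minimal pure-Python sketch of the principal-branch Lambert W with a fixed Halley trip count and no convergence test. I made several assumptions for illustration. The sketch only handles x >= 0, so there is no special handling near the branch point at -1/e. The initial guess is my own choice. `lambert_w0` and `num_iters` are hypothetical names. This is not TFP's implementation, and it says nothing about what XLA does with the resulting HLO.

```python
import math


def lambert_w0(x, num_iters=4):
    """Principal-branch Lambert W for x >= 0 via a fixed number of Halley steps.

    Solves w * exp(w) = x. There is no data-dependent stopping criterion:
    the loop always runs num_iters times, so the work is known up front.
    """
    if x < 0.0:
        raise ValueError("this sketch only handles x >= 0")
    # Initial guess (an assumption of this sketch): log1p(x) for small x,
    # and the asymptotic form L1 - log(L1) with L1 = log(x) for larger x.
    if x < math.e:
        w = math.log1p(x)
    else:
        l1 = math.log(x)
        w = l1 - math.log(l1)
    for _ in range(num_iters):
        ew = math.exp(w)
        f = w * ew - x  # residual of w * e^w = x
        # Halley update for f(w) = w e^w - x, using f' = e^w (w + 1)
        # and f'' = e^w (w + 2).
        w -= f / (ew * (w + 1.0) - (w + 2.0) * f / (2.0 * w + 2.0))
    return w
```

With these starting points the iteration's cubic convergence does the rest: for inputs such as 0.5, 10, 100, or 1e6, four steps are enough to bring w * exp(w) back to x to within a few ulps. A convergence check would only save work on easy inputs.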
Maybe (probably) NSan will be part of the solution. Maybe we should also think about other interesting approaches that will get NN researchers excited? For example, what if we could write better graph transformations or MLIR passes that make your code faster while keeping it functionally equivalent AND numerically stable? Or just do a better job of pruning dead numerical branches? Or maybe we can use ML? We’re already exploring automated graph rewriting a little for ML compilers, and in the not-too-distant future we may want to look into this for LLVM.
All we have to do is convince the org leads…
Thanks to Ed for reviewing this post and pointing out the Mixed Precision Training paper.
Just guessing here.
Yes, this PR was submitted by a SWE on the TensorFlow probability team.
Huge thanks to the original cephes author, Stephen L. Moshier, who gave us permission to make the port!
With float32 used in a few key accumulators.
Not limited to graduate students; remember Transformer?
When they’re thinking about numerical precision in their copious free time…
Money = Time, right?