Hello everyone,
After recreating the accuracy and rough speed of David Page's implementation in hlb-CIFAR10 0.1.0 (18.1s on an A100 SXM4 in Colab), it came down to some basic NVIDIA kernel profiling to figure out which operations were the long poles in the tent. Perhaps (somewhat?) unsurprisingly, the NCHW <-> NHWC thrash was the worst part, but unfortunately the GhostBatchNorm was a barrier even when using the faster-on-Ampere channels_last memory format.
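For anyone unfamiliar with the channels_last trick mentioned above, here's a minimal sketch (not the project's actual code, just an illustrative toy model) of putting both a module and its input into NHWC layout so the layers stay in one memory format instead of converting back and forth:

```python
import torch
import torch.nn as nn

# Hypothetical toy stack, not hlb-CIFAR10 itself: move the weights and the
# input tensor to channels_last (NHWC), which Ampere-era cuDNN kernels prefer.
model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.ReLU(),
).to(memory_format=torch.channels_last)

x = torch.randn(8, 3, 32, 32).to(memory_format=torch.channels_last)
y = model(x)

# The conv output stays channels_last, so no NCHW <-> NHWC thrash mid-network.
assert y.is_contiguous(memory_format=torch.channels_last)
```

The key point is that both sides have to agree: a channels_last model fed a contiguous NCHW tensor (or vice versa) is exactly what triggers the layout conversions the profiler flagged.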
A quick note before continuing -- some may find the use of a convolutional network on CIFAR10 to be curious. A quick answer would be that by doing research that optimizes well-known problems (especially when the testing loop is incredibly rapid), we get much clearer pictures of the fundamental information-learning limits of systems like this, as well as stable prototypes that can then be translated (potentially somewhat analogously) into other modalities. You can see this practice with a few researchers -- Hinton comes to mind, though his work is much more fundamental and experimental than this is. Back to the release notes.
Ultimately, however, we were able to get a similar level of regularization to the original GhostBatchNorm (called GhostNorm in the code), which allowed us to remove it along with a bunch of tensor allocation/contiguous-tensor calls, saving us about 5 seconds (!!!!).
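For context on why removing it helped so much, here's a rough sketch of the GhostBatchNorm idea (this is my own illustrative version, not the project's exact implementation): batch norm statistics are computed over small "ghost" sub-batches rather than the full batch, which regularizes via noisier statistics but costs extra kernel launches and allocations per chunk.

```python
import torch
import torch.nn as nn

class GhostBatchNorm(nn.Module):
    """Illustrative sketch (not the project's code): apply BatchNorm2d over
    small 'ghost' sub-batches so the normalization statistics come from,
    e.g., 32 samples instead of the full batch -- noisier and regularizing."""
    def __init__(self, num_features, ghost_size=32):
        super().__init__()
        self.ghost_size = ghost_size
        self.bn = nn.BatchNorm2d(num_features)

    def forward(self, x):
        if self.training and x.shape[0] > self.ghost_size:
            # One BN call per chunk plus a concat -- the extra launches and
            # allocations are exactly the overhead removing it avoided.
            chunks = x.split(self.ghost_size, dim=0)
            return torch.cat([self.bn(c) for c in chunks], dim=0)
        return self.bn(x)

gbn = GhostBatchNorm(16, ghost_size=32)
out = gbn(torch.randn(64, 16, 4, 4))  # stats computed per 32-sample chunk
```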
Replacing the nn.AdaptiveMaxPool2d((1, 1)) call with a torch.amax(dim=(2, 3)) shaved an additional .5 seconds off the clock, bringing us down below Thomas Germer (@99991)'s excellently quick implementation of the same base method (https://github.com/99991/cifar10-fast-simple) and giving us the new world record.
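The two are numerically equivalent for global pooling -- a quick sanity check (toy shapes, not the project's code) looks like this:

```python
import torch
import torch.nn as nn

x = torch.randn(8, 64, 4, 4)

# Global max pooling two ways: the adaptive-pooling module vs a plain
# reduction over the spatial dims. Same values, simpler kernel.
pooled_module = nn.AdaptiveMaxPool2d((1, 1))(x)            # (8, 64, 1, 1)
pooled_amax = torch.amax(x, dim=(2, 3), keepdim=True)      # (8, 64, 1, 1)

assert torch.equal(pooled_module, pooled_amax)
```

Since the output size is fixed at (1, 1), the adaptive pool's window arithmetic is pure overhead, and amax maps to a straightforward reduction kernel.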
This work is pretty simple on its own -- though the various NVIDIA profilers can be very daunting to use, and I can post snippets of the simplest route that I've found (via torch.profiler) if someone asks/is curious. That said, looking at kernel execution order and times, in conjunction with good research engineering practices, can really and truly do a lot to quickly improve a network.
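Since it's short, here's the shape of that simplest torch.profiler route (the model and input here are placeholders, not the project's code):

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Placeholder workload for illustration.
model = torch.nn.Conv2d(3, 16, 3, padding=1)
x = torch.randn(8, 3, 32, 32)

activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities) as prof:
    model(x)

# Sort by time to see which kernels are the long poles in the tent.
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
```

On a GPU run you'd sort by "self_cuda_time_total" instead, and warm up with a few untimed iterations first so one-time kernel compilation doesn't dominate the table.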
This is what I'm pretty good at doing, so getting to flex a bit on a spare-time project is fun. I'm consistently storing up time saves in a draft bin of sorts and plan on releasing them in related/clustered releases as I'm able to appropriately polish them to whatever their capabilities seem to be. There is a lot of room to grow, and I think we now definitely have a good chance of hitting that 94% accuracy in under ~2 seconds mark within a few years!
This work is meant to be a living resume for me -- feel free to check out my README.md for more info. I love a lot of aspects of the technical/nitty-gritty side of the fusion of neural network engineering and the edge of research, particularly when it comes to speed, so this is my strong area. I'm certainly happy to answer whatever reasonable questions anyone might have, and to help get this project going for you (or other related stuff -- feel free to ask! <3 :)))) )
I imagine this means that the entire training set was ingested in about 12.4 seconds.