How much is the flash attention algorithm tied to the hardware? For example, in this announcement they mention taking advantage of the async capabilities of the H100 GPUs, which I assume means you won't get those speedups on non-H-series cards. Second, the actual flash attention library requires CUDA, although the algorithm has apparently been ported to Metal[0]. I would imagine if the algorithm was literally just a pure function it could be implemented for any GPU/ML framework?
Looking at the docs, in reality most of the time you want this to call out to FA2, which optimizes the kernels on the device to split the softmax ops over tiles of the (causally masked) triangular score matrix, and to avoid shuttling unnecessary batches of floating point numbers back and forth between the GPU's HBM and its on-chip memory.
FlashAttention's algorithmic improvement is mostly just splitting/combining the softmax part of attention, and is itself not totally novel. The overwhelming contribution is implementing that, and all its fiddly pieces, efficiently on Nvidia hardware.
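To make the split/combine concrete, here is a minimal NumPy sketch (toy code of my own, not from the FA repo) of the running-max/running-sum rescaling that lets you process the scores block by block and still recover the exact softmax-weighted output:

```python
import numpy as np

def streaming_softmax_matmul(score_blocks, value_blocks):
    """Compute softmax(scores) @ V one key/value block at a time, keeping only
    a running row-max m, row-sum l, and un-normalized accumulator acc.
    This is the algebraic core FlashAttention tiles around."""
    m = l = acc = None
    for s, v in zip(score_blocks, value_blocks):
        m_blk = s.max(axis=-1, keepdims=True)
        p = np.exp(s - m_blk)                  # block-local exponentials
        l_blk = p.sum(axis=-1, keepdims=True)
        o_blk = p @ v                          # un-normalized block output
        if m is None:
            m, l, acc = m_blk, l_blk, o_blk
        else:
            m_new = np.maximum(m, m_blk)
            alpha = np.exp(m - m_new)          # rescale old statistics
            beta = np.exp(m_blk - m_new)       # rescale the new block
            l = alpha * l + beta * l_blk
            acc = alpha * acc + beta * o_blk
            m = m_new
    return acc / l                             # normalize once at the end

rng = np.random.default_rng(0)
scores = rng.normal(size=(4, 8))               # (query rows, key columns)
values = rng.normal(size=(8, 16))              # (key columns, head dim)

p_ref = np.exp(scores - scores.max(-1, keepdims=True))
ref = (p_ref / p_ref.sum(-1, keepdims=True)) @ values

tiled = streaming_softmax_matmul(np.split(scores, 2, axis=-1),
                                 np.split(values, 2, axis=0))
assert np.allclose(ref, tiled)
```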
To clarify further, flash attention explicitly targets a compute engine with: separate MMA and "scalar" vector execution units that let you post-process the MMA outputs without involving memory bandwidth (though arithmetic intensity, especially the relative intensity of the MMA and "scalar" instructions, is a concern); a substantial amount of manually managed L1D$ to use as a sub-matrix accumulator; and a linear-in-context-length amount of "VRAM" that requires sensible arithmetic intensity to avoid becoming a bandwidth bottleneck (IIRC in the hundreds when counting the scalar multiplies hiding in the MMA instructions).
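A rough back-of-the-envelope (my numbers, not the paper's) of why keeping the N x N score matrix out of that "VRAM" matters for arithmetic intensity:

```python
# Sketch only: fp16 attention for one head, ignoring K/V re-reads and masking.
N, d, bytes_per_el = 4096, 128, 2              # context length, head dim, fp16

flops = 4 * N * N * d                          # Q@K^T plus P@V, ~2*N*N*d each

# Unfused: the N x N score matrix S and its softmax P round-trip through HBM.
unfused_bytes = (4 * N * d + 4 * N * N) * bytes_per_el   # Q,K,V,O plus S/P traffic
# Fused/tiled: only Q, K, V, O ever touch HBM.
fused_bytes = 4 * N * d * bytes_per_el

print("unfused FLOPs/byte:", flops / unfused_bytes)   # ~d/2: bandwidth bound
print("fused   FLOPs/byte:", flops / fused_bytes)     # ~N/2: compute bound
```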
This v3 with async might for once be so tied to Hopper that it's not trivially portable to another platform that has the mentioned hardware blocks (AFAIK every AMD GCN card that can do compute shaders would qualify, though they do lack a specialized MMA unit).
Given the question: "How much is the flash attention algorithm tied to the hardware?"
The answer is 0.
e.g. you can find generic flash attention recently added in llama.cpp and in ONNX (MS needed it for Phi-3, which is needed for Recall).
As an aside on novelty, which I have no direct knowledge of: IMHO asking that question devolves the way novelty arguments do in any field. There's always someone who can claim they did 80% of $X via $X-1, and therefore $X is by and large not novel. Ad infinitum.
I think the right analogy for FA is high-quality cache-aware BLAS kernel implementations. The algorithm(s) is (are) clever and (as you note) completely independent of hardware. However, a hardware-naive implementation is approximately worthless. Most of the value of MKL, or Accelerate, or FA is in the careful matching of the parameters and implementation of the algorithm to the capabilities of the hardware it's going to run on.
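As a toy illustration of that analogy (not code from MKL/Accelerate/FA), the blocking below is trivially portable; the only interesting knob is the tile size, which is exactly the kind of parameter you would tune to a specific chip's cache hierarchy:

```python
import numpy as np

def blocked_matmul(A, B, tile=64):
    """Cache-blocked matmul. The algorithm is hardware-independent; picking
    `tile` so the working set fits in L1/L2 (or shared memory) is where the
    MKL/Accelerate/FA-style value lives. 64 is a placeholder, not a tuned value."""
    n, k = A.shape
    _, m = B.shape
    C = np.zeros((n, m), dtype=A.dtype)
    for i in range(0, n, tile):
        for j in range(0, m, tile):
            for p in range(0, k, tile):
                C[i:i+tile, j:j+tile] += A[i:i+tile, p:p+tile] @ B[p:p+tile, j:j+tile]
    return C

A, B = np.random.rand(256, 256), np.random.rand(256, 256)
assert np.allclose(blocked_matmul(A, B), A @ B)
```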
I definitely don't mean to take away from Tri/FA by mentioning novelty - I'm just repeating from the paper, which refers back to algebraic aggregates[0] in its discussion of their tiled softmax.
> However, a hardware-naive implementation is approximately worthless.
This isn’t true when there is one vendor with 90% of the market and 2, maybe 3, generations of hardware to consider. Support the A100 and H100 and you are supporting most of the current market.
> How much is the flash attention algorithm tied to the hardware?
The original FA, almost none.
For the latest versions it depends on your abstraction: ThunderKittens[0] provides about the same speedup over FA2 (1.3x-2x) as the article, but is relatively universal across GPUs. For any new hardware there may be hardware-specific features that let it edge out more performance; usually vendors will adopt any new feature that seems to beat them, but you do get fragmented APIs/libraries (which is already true for CUDA).
What do you mean by "relatively universal"? This is CUDA only [0], with a promise of a ROCm backend eventually. There's only one project I'm aware of that seriously tries to address the CUDA issue in ML [1].
I mean they're building an API to abstract away some of the SKU-to-SKU differences, but the broader point cuts the other way, I think:
> In fact, more broadly we believe we should really reorient our ideas of AI around what maps well onto the hardware. How big should a recurrent state be? As big can fit onto an SM. How dense should the compute be? No less so than what the hardware demands. An important future direction of this work for us is to use our learnings about the hardware to help us design the AI to match.
The value is in adapting the implementation (either manually at write-time or programmatically at run-time) to the specifics of the hardware.
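A hypothetical sketch of what that run-time adaptation boundary tends to look like (none of these names come from FA or ThunderKittens; the dispatch logic is illustrative only):

```python
from dataclasses import dataclass
from typing import Callable

@dataclass
class Device:
    arch: str             # e.g. "sm90", "sm80", "gfx942", "cpu"
    smem_bytes: int       # manually managed on-chip memory per SM/CU
    has_async_copy: bool  # Hopper-style async copy/MMA pipelining

def hopper_attention(q, k, v): ...    # warp-specialized, async pipelined
def ampere_attention(q, k, v): ...    # FA2-style, no async pipeline
def generic_attention(q, k, v): ...   # plain tiled softmax, any backend

def select_attention_kernel(dev: Device) -> Callable:
    """Pick a kernel variant for the detected device; the algorithm is the
    same in all three, only the implementation details change."""
    if dev.arch.startswith("sm9") and dev.has_async_copy:
        return hopper_attention
    if dev.arch.startswith("sm8"):
        return ampere_attention
    return generic_attention          # still correct, just slower

kernel = select_attention_kernel(Device("sm90", 228 * 1024, True))
```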
Also, great line:
> And we ask: if your matrix multiply is smaller than 16x16, are you sure what you’re doing is AI?
Conceptually, just a bit; practically (in terms of implementation), a lot. The standard Python implementation internally compiles a kernel for your specific hardware.
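For instance (a hedged sketch, assuming PyTorch >= 2.0 and a CUDA device), the same Python call dispatches to a FlashAttention kernel when the device/dtype/shape support it and falls back to a plain math path otherwise:

```python
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim), fp16 on GPU so the fused kernel is eligible
q, k, v = (torch.randn(1, 8, 1024, 64, dtype=torch.float16, device="cuda")
           for _ in range(3))

# Same call on any supported backend; which kernel actually runs is decided
# at dispatch time based on the hardware and inputs.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```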
To add to the discussion, from a practical perspective: AMD hardware totally sucks here and has yet to get a proper FlashAttention-2 implementation. ROCm is slowly moving toward usable, but is not close to being comparable with CUDA.
[0]: https://github.com/philipturner/metal-flash-attention