I got a bit too fancy and tried:
* LoopVectorization.jl - `@turbo` choked on the loop
* a direct `llvmcall` to use the AVX-512 popcount - I malformed the types for the instruction
* Defining `db` as a plain matrix, to avoid the vec-of-vecs structure:

```julia
db = [rand(Int8) for _ in 1:64, _ in 1:(10^6)];
```

and then

```julia
function my_cluster!(db, query, k)
    db .= query .⊻ db                                   # XOR every column with the query (clobbers db)
    popcounts = mapreduce(count_ones, +, db, dims = 1)  # per-column popcount = Hamming distance
    results = reshape(popcounts, last(size(db)))
    idx = collect(eachindex(results))                   # index buffer for partialsortperm!
    partialsortperm!(idx, results, 1:k)                 # only the k smallest land in idx[1:k]
    @views idx[begin:k]                                 # indices of the k closest vectors
end
```
...which I couldn't get to be faster than your version. If you use `partialsortperm!` and reuse the same cache array, I suspect you'll get good speedups, since you won't be fully sorting the array every time. This is the classic `nth_element` algorithm.
The above is not the most amazing code, but I suspect the lack of indexing will make it ridiculously friendly for a GPU (Edit: Nope, it chokes on `partialsortperm!`).
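Roughly what I mean by reusing a cache array with `partialsortperm!` (an untested sketch; `KnnCache` and `k_closest_cached!` are made-up names, not from the article):

```julia
# Preallocate the distance and index buffers once, reuse them for every query.
struct KnnCache
    dists::Vector{Int}
    idx::Vector{Int}
end
KnnCache(n) = KnnCache(Vector{Int}(undef, n), Vector{Int}(undef, n))

function k_closest_cached!(cache::KnnCache, db, query, k)
    for j in axes(db, 2)   # one Hamming distance per column
        cache.dists[j] = sum(count_ones(db[i, j] ⊻ query[i]) for i in axes(db, 1))
    end
    # Partial sort: only the k smallest entries are placed; returns a view of their indices.
    partialsortperm!(cache.idx, cache.dists, 1:k)
end
```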
I'm guessing a manual, loopy approach should be just as good, but I battled hard just to get this version somewhat competitive in 6 lines of code:
```
# @be my_cluster!(X2, q1, 5)
Benchmark: 3 samples with 1 evaluation
42.883 ms (17 allocs: 15.259 MiB)
45.711 ms (17 allocs: 15.259 MiB)
46.670 ms (17 allocs: 15.259 MiB)

# @be k_closest(X1, q1, 5)
Benchmark: 4 samples with 1 evaluation
27.994 ms (2 allocs: 176 bytes)
28.733 ms (2 allocs: 176 bytes)
29.000 ms (2 allocs: 176 bytes)
30.709 ms (2 allocs: 176 bytes)
```
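For reference, the "manual loopy approach" I have in mind is roughly this (untested sketch over the same 64×N matrix layout; `distances_loop!` is a made-up name):

```julia
# One popcount-accumulating inner loop per column; fills `dists` with Hamming distances.
function distances_loop!(dists, db, query)
    @inbounds for j in axes(db, 2)
        acc = 0
        for i in axes(db, 1)
            acc += count_ones(db[i, j] ⊻ query[i])
        end
        dists[j] = acc
    end
    return dists
end
```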
I also didn't try `FixedSizeArrays.jl`, which Mosè Giordano recommended in my livestream chat.
Under the hood it's doing the same thing with a vector of ints (64 bits per chunk for bitvectors), and all the bulk manipulation is handled at that level, so SIMD is inherent as well. Worth a shot.
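If that's a reference to `BitVector`, a quick sketch of what that looks like (hedged, not from the article):

```julia
using Random

a = bitrand(1024)       # BitVector: 1024 bits packed into UInt64 chunks
b = bitrand(1024)
count(map(⊻, a, b))     # Hamming distance; map and count work chunk-by-chunk on BitVectors
```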
There is, it's called `count_ones`. I wouldn't be surprised if LLVM could sometimes optimize these loops into a `popcnt` on its own, but I'm sure it would be brittle.
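For what it's worth, on machine integers `count_ones` lowers directly to LLVM's `ctpop` intrinsic, which becomes a `popcnt` instruction on a typical x86-64 build; easy to check:

```julia
count_ones(0b1011)      # == 3 set bits

# Inspect the generated code (output elided here):
@code_llvm debuginfo=:none count_ones(UInt64(1))   # a call to the llvm.ctpop.i64 intrinsic
@code_native count_ones(UInt64(1))                 # popcnt, when the CPU feature is available
```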
I think you may need to update the figures in the rest of the article. At some point you mention it should take around 128ns but with the new benchmark that's probably closer to 64*1.25=80ns.
For those like me who are not familiar with the field... The article assumes you know the entire context - as far as I could see there is no explanation of any part except the technical details.
RAG = Retrieval-Augmented Generation
The field is machine learning. Retrieval = get relevant documents. Generation = create answer for user (based on the docs).
I'm not sure what is meant by "exact" here - do they describe their binarisation process at all? This seems more like an XOR benchmark than a RAG benchmark; there's no mention of recall or other relevant retrieval metrics.
Some (not all) of your questions may be answered by the linked article near the top of the submitted article, which goes into more detail about how much is lost quantizing to 1 bit (and 1 byte): https://huggingface.co/blog/embedding-quantization
"Exact" in this case means that every vector in the database is compared against the query vector, whereas other search methods, such as HNSW, are approximate.
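In code, "exact" just means a brute-force scan over the whole database; a minimal sketch, assuming each entry is a bit-packed `Vector{UInt64}` (`exact_knn` and `hamming` are made-up names):

```julia
hamming(a, b) = sum(count_ones(x ⊻ y) for (x, y) in zip(a, b))

function exact_knn(db::Vector{Vector{UInt64}}, query::Vector{UInt64}, k)
    dists = [hamming(v, query) for v in db]   # compare the query against *every* stored vector
    partialsortperm(dists, 1:k)               # indices of the k nearest
end
```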
I used `hamming_bitwise(x::Union{UInt, UInt128, BigInt}, y::Union{UInt, UInt128, BigInt}) = count_ones(x ⊻ y)` to get a fast Hamming distance with "binary vectors" encoded as ints.
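For example (made-up values, eight "dimensions" packed into one integer):

```julia
hamming_bitwise(x::Union{UInt, UInt128, BigInt}, y::Union{UInt, UInt128, BigInt}) = count_ones(x ⊻ y)

a = UInt(0b1011_0010)
b = UInt(0b1001_0110)
hamming_bitwise(a, b)   # == 2: the two "vectors" differ in two bit positions
```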
I loved this post <3 The Hamming distance is one of my favorite demos of the conciseness of Julia:

```julia
hamming_distance(s1, s2) = mapreduce(!=, +, s1, s2)
```
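For example:

```julia
hamming_distance(s1, s2) = mapreduce(!=, +, s1, s2)   # as above

hamming_distance("karolin", "kathrin")          # == 3
hamming_distance([0, 1, 1, 0], [1, 1, 0, 0])    # == 2
```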
I'm a bit swamped at the moment, but I'll write a response article later - there's still some juicy perf on the table here.
Thanks for the post, such a good showcase.