I’ve been playing around with cuBLAS’s strided batched matrix-matrix multiplication and was surprised to find that, for small matrices, enabling the transposition flag drastically reduces performance.
I’ve attached some code to demonstrate my findings. Apologies that it is in Julia using CUDA.jl; I assume it is still perfectly understandable to a C++ CUDA developer. The code snippet compares:
- gemm without transposition
- gemm with transposition of B
- manually transposing B using a custom kernel and then using gemm with no transposition
using BenchmarkTools
using CUDA
N = 1000
D = 3
A = CUDA.randn(D, D, N)
B = CUDA.randn(D, D, N)
# Simple kernel to transpose the first two dimensions of a 3D array
function batch_transpose!(B, B_T)
    index = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    stride = blockDim().x * gridDim().x
    # Grid-stride loop over the batch dimension; each thread swaps
    # the first two dimensions of its assigned slices
    for i in index:stride:size(B, 3)
        for j in 1:size(B, 2)
            for k in 1:size(B, 1)
                B_T[k, j, i] = B[j, k, i]
            end
        end
    end
    return nothing
end
# Compute A * B_T by first transposing B explicitly
function transpose_first(A, B)
    B_T = CuArray{Float32}(undef, size(B, 2), size(B, 1), size(B, 3))
    @cuda threads = 256 blocks = ceil(Int, size(B, 3) / 256) batch_transpose!(B, B_T)
    return CUDA.CUBLAS.gemm_strided_batched('N', 'N', A, B_T)
end
# Benchmark the three approaches
bench1 = @benchmark CUDA.@sync CUDA.CUBLAS.gemm_strided_batched('N', 'N', $A, $B)
bench2 = @benchmark CUDA.@sync CUDA.CUBLAS.gemm_strided_batched('N', 'T', $A, $B)
bench3 = @benchmark CUDA.@sync transpose_first($A, $B)
# Print the results
# BenchmarkTools reports times in nanoseconds; divide by 1e3 for μs
println("Non-transposed gemm: $(median(bench1.times) / 1e3) μs")
println("Transposed gemm:     $(median(bench2.times) / 1e3) μs")
println("Transposed manually: $(median(bench3.times) / 1e3) μs")
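For C++ readers, the operation being benchmarked is sketched below as a naive CPU model in plain C, using cuBLAS-style column-major storage. The function name and signature are mine for illustration, not the cuBLAS API; the real call would go through cublasSgemmStridedBatched.

```c
#include <stddef.h>

/* Naive CPU model of a strided batched GEMM: C[b] = A[b] * op(B[b]) for
   b = 0..batch-1, where every matrix is d x d, stored column-major, and
   op is identity or transpose depending on transB.
   Illustrative sketch only; not the cuBLAS API. */
static void gemm_strided_batched_ref(int transB, int d, int batch,
                                     const float *A, const float *B, float *C)
{
    for (int b = 0; b < batch; ++b) {
        const float *Ab = A + (size_t)b * d * d; /* stride between batches */
        const float *Bb = B + (size_t)b * d * d;
        float *Cb = C + (size_t)b * d * d;
        for (int j = 0; j < d; ++j) {        /* column of C */
            for (int i = 0; i < d; ++i) {    /* row of C */
                float acc = 0.0f;
                for (int k = 0; k < d; ++k) {
                    /* op(B)[k,j] is B[j,k] when transposed, else B[k,j] */
                    float bkj = transB ? Bb[j + k * d] : Bb[k + j * d];
                    acc += Ab[i + k * d] * bkj;
                }
                Cb[i + j * d] = acc;
            }
        }
    }
}
```

Note that with transB set, the inner loop walks a row of B instead of a column, so in the column-major layout the accesses to B are strided rather than contiguous.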
The benchmarks were run on an RTX 4090.
When D=3, the output is
Non-transposed gemm: 9.854 μs
Transposed gemm:     25.248 μs
Transposed manually: 14.011 μs
That is, gemm with transposition is much slower than without, and is comfortably beaten by my thrown-together kernel that transposes the batch directly first.
When D=30, the output is
Non-transposed gemm: 38.468 μs
Transposed gemm:     38.357 μs
Transposed manually: 309.046 μs
The gap between non-transposed and transposed gemm has closed, and my hacky approach is slow (as expected).
I’m surprised that cuBLAS performs so poorly on an operation as common as batched multiplication with transposition.
Is there any intuition as to why this is the case?