What is the appropriate way to load data for the recent AVX-VNNI and Arm Neon MMLA instructions?
For example, the description of SMMLA is:
Signed 8-bit integer matrix multiply-accumulate. This instruction multiplies the 2×8 matrix of signed 8-bit integer values in the first source vector by the 8×2 matrix of signed 8-bit integer values in the second source vector. The resulting 2×2 32-bit integer matrix […]
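A scalar model of what SMMLA computes may help pin down the layout. The sketch below is written from the description above and from my reading of the architecture pseudocode, so treat the indexing as an assumption to check against the Arm ARM. Bytes 0 to 7 of the first operand are row 0 of the 2x8 matrix and bytes 8 to 15 are row 1. Bytes 0 to 7 of the second operand are column 0 of the 8x2 matrix and bytes 8 to 15 are column 1. The 2x2 result occupies the four 32-bit lanes in row-major order. The name smmla_ref is hypothetical.

```c
#include <assert.h>
#include <stdint.h>

/* Scalar model of SMMLA (hypothetical helper name smmla_ref).
 * a:   2x8 signed bytes, row-major:    a[8*i + k] is row i, element k.
 * b:   8x2 signed bytes, column-major: b[8*j + k] is column j, element k.
 * acc: 2x2 int32 accumulator, row-major: acc[2*i + j] is C[i][j]. */
static void smmla_ref(int32_t acc[4], const int8_t a[16], const int8_t b[16]) {
    for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 2; j++) {
            int32_t dot = 0; /* |dot| <= 8 * 128 * 128, so it fits in int32 */
            for (int k = 0; k < 8; k++)
                dot += (int32_t)a[8 * i + k] * (int32_t)b[8 * j + k];
            /* accumulate with 32-bit wraparound, assumed to match the hardware */
            acc[2 * i + j] = (int32_t)((uint32_t)acc[2 * i + j] + (uint32_t)dot);
        }
    }
}
```

Read this way, both operands need 8 consecutive K-values from two different rows (of A) or columns (of B) side by side in one register. With A row-major and B column-major, those two 8-byte pieces are K bytes apart in memory, which is why a single 16-byte load does not fit unless the data is reordered.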
Similarly, the description of _mm256_dpbusd_epi32 is:
Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding signed 8-bit integers in b, producing 4 intermediate signed 16-bit results. Sum these 4 results with the corresponding 32-bit integer in src, and store the packed 32-bit results in dst.
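For comparison, here is a scalar model of _mm256_dpbusd_epi32 written from the description above (dpbusd_ref is a hypothetical name). Unlike SMMLA, each 32-bit lane is an independent dot product of 4 unsigned bytes of a with 4 signed bytes of b, so one 256-bit instruction performs eight 1x4 by 4x1 products rather than one small matrix product. Accumulation in this sketch wraps modulo 2^32; the saturating variant is _mm256_dpbusds_epi32.

```c
#include <assert.h>
#include <stdint.h>

/* Scalar model of _mm256_dpbusd_epi32 (hypothetical helper name dpbusd_ref).
 * Lane l uses bytes 4*l .. 4*l+3 of a (unsigned) and of b (signed). */
static void dpbusd_ref(int32_t dst[8], const int32_t src[8],
                       const uint8_t a[32], const int8_t b[32]) {
    for (int l = 0; l < 8; l++) {
        int32_t dot = 0; /* each u8*s8 product fits in 16 bits; 4 of them fit in 32 */
        for (int t = 0; t < 4; t++)
            dot += (int32_t)a[4 * l + t] * (int32_t)b[4 * l + t];
        /* non-saturating accumulate into the matching lane of src */
        dst[l] = (int32_t)((uint32_t)src[l] + (uint32_t)dot);
    }
}
```

One layout sometimes used with this instruction (an assumption here, not something the description prescribes) is to broadcast 4 consecutive K-values of one row of A to all lanes, for example with _mm256_set1_epi32, and to pack B so that each lane holds 4 consecutive K-values of a different column. One instruction then updates 8 outputs in one row of C.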
It seems that they all require inputs of the form 2[4]x8 and 8x[4]2 (2x8 and 8x2 for MMLA, 4x8 and 8x4 for AVX) and produce outputs of the form 2[4]x[4]2. How can I efficiently load and store data for these instructions?
I see three broad ways to use these instructions, none of them appealing:
- [Split and Combine] I load two consecutive 128-bit vectors and then split them. Similarly, for AVX, I would load four 128- or 256-bit vectors and then split them. Storing is equally “complicated”, since I need to extract the relevant parts of the 2[4]x[4]2 matrix before storing it. My code ends up cluttered with splitting/merging instructions.
- [Smaller Vectors] Alternatively, I could load smaller portions, but that seems inefficient too.
- [Reorder Input Data] Of course, I could reorder the input data so that the vectorized loads already span multiple rows or columns. Is that the intended use?
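The third option is essentially the packing step used by many GEMM implementations. Here is a minimal scalar sketch, assuming the row-pair layout described in its comment (pack_row_pairs is a hypothetical name). For every pair of rows and every block of 8 K-values, it writes one 16-byte 2x8 MMLA operand. Because B is column-major, its columns are rows of Bᵀ, so the same routine could pack B by passing B, the number of columns, K, and the column stride as ld.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Hypothetical packing routine. src is an m x k row-major byte matrix with
 * leading dimension ld; m must be even and k a multiple of 8. For each pair
 * of rows (2p, 2p+1) and each 8-column block kb, it writes 16 contiguous
 * bytes: 8 bytes of row 2p followed by 8 bytes of row 2p+1, i.e. one 2x8
 * MMLA operand. The blocks of one row pair are stored in increasing kb order. */
static void pack_row_pairs(uint8_t *dst, const uint8_t *src,
                           size_t m, size_t k, size_t ld) {
    assert(m % 2 == 0 && k % 8 == 0);
    for (size_t p = 0; p < m / 2; p++) {
        const uint8_t *r0 = src + (2 * p) * ld; /* upper row of the pair */
        const uint8_t *r1 = r0 + ld;            /* lower row of the pair */
        for (size_t kb = 0; kb < k; kb += 8) {
            for (size_t t = 0; t < 8; t++) {
                dst[t] = r0[kb + t];
                dst[8 + t] = r1[kb + t];
            }
            dst += 16; /* next 16-byte operand */
        }
    }
}
```

With both inputs packed this way, each operand in the inner loop becomes a single vld1q_u8 of 16 contiguous bytes followed by a pointer increment of 16, with no vcombine_u8. The packing cost is usually only worthwhile when the packed panels are reused across many output tiles.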
Here is example code for the inner loop (the reduction over K) for a small 4xK matrix A (row-major) and a Kx4 matrix B (column-major), with K = 64:
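```c
#include <arm_neon.h>  // vmmlaq_u32 requires the I8MM extension (e.g. -march=armv8.2-a+i8mm)
#include <stddef.h>
#include <stdint.h>

// A is 4x64 row-major (rows row0..row3), B is 64x4 column-major (columns col0..col3).
// acc[] holds the four 2x2 tiles of C, each row-major within its uint32x4_t.
void kernel_4x4(const uint8_t *row0, const uint8_t *row1,
                const uint8_t *row2, const uint8_t *row3,
                const uint8_t *col0, const uint8_t *col1,
                const uint8_t *col2, const uint8_t *col3, uint32x4_t acc[4]) {
    uint32x4_t out01x01 = acc[0], out01x23 = acc[1];
    uint32x4_t out23x01 = acc[2], out23x23 = acc[3];
    for (size_t k = 0; k < 64; k += 8) {
        // Build each 2x8 operand from two 8-byte loads from different rows/columns.
        uint8x8_t low = vld1_u8(row0);
        uint8x8_t high = vld1_u8(row1);
        uint8x16_t row01x01234567 = vcombine_u8(low, high);
        row0 += 8;
        row1 += 8;
        low = vld1_u8(row2);
        high = vld1_u8(row3);
        uint8x16_t row23x01234567 = vcombine_u8(low, high);
        row2 += 8;
        row3 += 8;
        low = vld1_u8(col0);
        high = vld1_u8(col1);
        uint8x16_t col01x01234567 = vcombine_u8(low, high);
        col0 += 8;
        col1 += 8;
        low = vld1_u8(col2);
        high = vld1_u8(col3);
        uint8x16_t col23x01234567 = vcombine_u8(low, high);
        col2 += 8;
        col3 += 8;
        // Four 2x2 tiles: rows {0,1}/{2,3} times columns {0,1}/{2,3}.
        out01x01 = vmmlaq_u32(out01x01, row01x01234567, col01x01234567);
        out01x23 = vmmlaq_u32(out01x23, row01x01234567, col23x01234567);
        out23x01 = vmmlaq_u32(out23x01, row23x01234567, col01x01234567);
        out23x23 = vmmlaq_u32(out23x23, row23x01234567, col23x01234567);
    }
    acc[0] = out01x01;
    acc[1] = out01x23;
    acc[2] = out23x01;
    acc[3] = out23x23;
}
```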
The result is correct, but the code seems terribly inefficient. It is only an example; in practice I would use larger tiles to make full use of the registers.
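The store side can likewise be written once for any number of tiles. The scalar sketch below stands in for the final vector stores (store_tiles and its tile ordering are hypothetical). It assumes the accumulators are kept as 2x2 tiles in the row-major lane order of SMMLA/UMMLA, with tile (ti, tj) covering rows 2ti and 2ti+1 and columns 2tj and 2tj+1 of a row-major C.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Hypothetical write-back. tiles holds (m/2) x (n/2) accumulator tiles;
 * tile (ti, tj) starts at tiles[(ti * (n / 2) + tj) * 4] and is laid out as
 * {C[2ti][2tj], C[2ti][2tj+1], C[2ti+1][2tj], C[2ti+1][2tj+1]}.
 * c is an m x n row-major output with leading dimension ldc. */
static void store_tiles(uint32_t *c, size_t ldc, const uint32_t *tiles,
                        size_t m, size_t n) {
    assert(m % 2 == 0 && n % 2 == 0);
    for (size_t ti = 0; ti < m / 2; ti++) {
        for (size_t tj = 0; tj < n / 2; tj++) {
            const uint32_t *t = tiles + (ti * (n / 2) + tj) * 4;
            uint32_t *top = c + (2 * ti) * ldc + 2 * tj; /* first output row */
            uint32_t *bottom = top + ldc;                 /* second output row */
            top[0] = t[0];    /* low 64 bits of the tile */
            top[1] = t[1];
            bottom[0] = t[2]; /* high 64 bits of the tile */
            bottom[1] = t[3];
        }
    }
}
```

In Neon, the two halves of a tile correspond to vget_low_u32/vget_high_u32 followed by vst1_u32. Alternatively, two horizontally adjacent tiles (such as out01x01 and out01x23 in the kernel above) can be merged into full output rows with vzip1q_u64 and vzip2q_u64 on the tiles reinterpreted as uint64x2_t, which allows 128-bit stores.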