I’m currently writing an arbitrary precision library for fun (yes, I know that GMP exists) and I’m using it to generate large prime numbers (~5000 digits) with the Miller-Rabin primality test. I profiled the code and noticed that around 30 percent of the time was spent comparing two numbers as part of the modulo operation. This is what that part of the code looks like:
struct number {
    // maxDigits is a compile-time constant
    uint32_t digits[maxDigits];
    ...
    constexpr bool operator>=(const number& b) const {
        for (size_t i = maxDigits - 1; i != static_cast<size_t>(-1); --i) {
            if (const auto x = digits[i], y = b.digits[i]; x != y) return x > y;
        }
        return true;
    }
};
This is the simplest way to do this, but I imagine there are ways to do it with fewer branches, or maybe by somehow using SIMD instructions to compare multiple digits at once. I’ve tried making the loop’s upper bound the exact length of the number instead of the maximum size (roughly the variant sketched below), but that actually made it slower (maybe it couldn’t unroll the loop?). This seems like the lowest-hanging fruit to optimize without changing any of the algorithms used, but I can’t seem to make it more efficient. Any ideas would be greatly appreciated.
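For reference, the length-bounded variant I tried looks roughly like this, where length is a count of significant digits (a member not shown in the struct above) and every digit above it is zero:

// Slower in practice, presumably because the trip count is no longer
// a compile-time constant and the loop is harder to unroll.
constexpr bool operator>=(const number& b) const {
    const size_t n = length > b.length ? length : b.length;
    for (size_t i = n - 1; i != static_cast<size_t>(-1); --i) {
        if (const auto x = digits[i], y = b.digits[i]; x != y) return x > y;
    }
    return true;
}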
Here is a little experiment on godbolt:
https://godbolt.org/z/T6x6vaTPq
It begins with a few definitions:
static const int maxDigits = 1000;
typedef uint32_t digit;
typedef int8_t byte;
Then, your code is in function cmp1():
bool cmp1( digit* digits1, digit* digits2 )
{
    for (size_t i = maxDigits - 1; i != static_cast<size_t>(-1); --i)
        if (const auto x = digits1[i], y = digits2[i]; x != y) return x > y;
    return true;
}
which translates to:
mov eax, 999
jmp .L4
.L2:
sub rax, 1
jb .L7
.L4:
mov edx, DWORD PTR [rdi+rax*4]
mov ecx, DWORD PTR [rsi+rax*4]
cmp edx, ecx
je .L2
cmp ecx, edx
setb al
ret
.L7:
mov eax, 1
ret
Note that the 2nd cmp instruction (cmp ecx, edx) is redundant: the je .L2 instruction does not modify the flags register, so the flags from the first cmp edx, ecx are still valid and a seta al would suffice. This is probably a missed optimization in gcc, and it proves that if you are chasing clock cycles, you can do a bit better if you tweak the assembly by hand. But luckily, we are not usually chasing clock cycles.
Function cmp2() is the same function, just slightly more readable:
bool cmp2( digit* digits1, digit* digits2 )
{
    for( int i = maxDigits - 1; i >= 0; i-- )
    {
        const auto x = digits1[i];
        const auto y = digits2[i];
        if( x != y )
            return x > y;
    }
    return true;
}
This one translates to the following:
mov eax, 3996
jmp .L11
.L9:
sub rax, 4
cmp rax, -4
je .L13
.L11:
mov edx, DWORD PTR [rdi+rax]
mov ecx, DWORD PTR [rsi+rax]
cmp edx, ecx
je .L9
cmp ecx, edx
setb al
ret
.L13:
mov eax, 1
ret
As you can see, the loop is longer by one instruction, but two of the instructions use shorter forms, so I do not know which of the two is faster. Whichever it is, it would not make much difference.
I am not sure why this does cmp rax, -4 followed by je .L13 instead of cmp rax, 0 and jl .L13, but I suppose it does not make any difference anyway.
The only way I can think of to make this operation run faster is to re-engineer your arrays of digits so that the most significant digit is stored first instead of last.
If you do that, then the comparison function cmp3() becomes:
bool cmp3( digit* digits1, digit* digits2 )
{
    for( int i = 0; i < maxDigits; i++ )
    {
        if( *digits1 != *digits2 )
            return *digits1 > *digits2;
        digits1++;
        digits2++;
    }
    return true;
}
Which translates to this:
xor eax, eax
jmp .L17
.L15:
add rax, 4
cmp rax, 4000
je .L19
.L17:
mov edx, DWORD PTR [rdi+rax]
mov ecx, DWORD PTR [rsi+rax]
cmp edx, ecx
je .L15
cmp ecx, edx
setb al
ret
.L19:
mov eax, 1
ret
And it can even be further optimized into function cmp4():
bool cmp4( digit* digits1, digit* digits2 )
{
    return memcmp( digits1, digits2, maxDigits * sizeof(digit) ) >= 0;
}
which translates to this:
sub rsp, 8
mov edx, 4000
call memcmp
add rsp, 8
not eax
shr eax, 31
ret
However, I doubt that this will give you a tremendous performance increase either. It may well be that whatever you gain will be offset by the overhead of all the tricks you will have to do in the rest of your code in order to reverse the order of the digits. Also note that memcmp compares bytes, so on a little-endian machine it only gives the correct answer if the bytes within each digit are stored most-significant-first as well, not just the digits themselves.
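If you do go that route on a little-endian machine such as x86-64, a sketch of the extra conversion step could look like this (the helper name to_storage_order is just for illustration):

#include <cstdint>

// Store each digit byte-swapped so that the whole array becomes one big
// big-endian number that memcmp can compare directly.
// __builtin_bswap32 is a GCC/Clang builtin; MSVC has _byteswap_ulong.
static inline uint32_t to_storage_order(uint32_t v) {
    return __builtin_bswap32(v);
}

Every digit would then have to be written and read through such a helper, which is exactly the kind of overhead mentioned above.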
An AVX2 approach could use _mm256_cmpgt_epi32(x, y) in parallel with _mm256_cmpgt_epi32(y, x). Subtract the two results (using signed arithmetic, _mm256_sub_epi32), and you’ll get either -1, 0 or +1 depending on whether the particular digit in x was bigger, equal or smaller. It’s a bit backwards, but that’s because _mm256_cmpgt_epi32(x, y) returns all-zeroes or all-ones, and signed arithmetic interprets all-ones as -1. You can write this back to memory, find the first nonzero value (in digit order), and return that from operator<=>.
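Here is a minimal, untested sketch of that idea as a free function equivalent to the question’s operator>= (the name ge_avx2 and the assumption that maxDigits is a multiple of 8 are mine). One detail worth adding: _mm256_cmpgt_epi32 is a signed comparison, so the uint32_t digits need their sign bit flipped before comparing. Compile with -mavx2.

#include <immintrin.h>
#include <cstdint>
#include <cstddef>

bool ge_avx2(const uint32_t* a, const uint32_t* b, size_t maxDigits)
{
    const __m256i bias = _mm256_set1_epi32(INT32_MIN); // flips unsigned into signed order
    // Walk the arrays in blocks of 8 digits, most significant block first.
    for (size_t i = maxDigits; i != 0; i -= 8)
    {
        __m256i x = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(a + i - 8));
        __m256i y = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(b + i - 8));
        __m256i xs = _mm256_xor_si256(x, bias);
        __m256i ys = _mm256_xor_si256(y, bias);
        // Per 32-bit lane: -1 if a's digit is bigger, +1 if smaller, 0 if equal.
        __m256i d = _mm256_sub_epi32(_mm256_cmpgt_epi32(xs, ys),
                                     _mm256_cmpgt_epi32(ys, xs));
        if (_mm256_testz_si256(d, d)) continue; // all 8 digits equal, next block

        // Write the lane results back and scan them in digit order
        // (the most significant digit of the block sits in the highest lane).
        alignas(32) int32_t lanes[8];
        _mm256_store_si256(reinterpret_cast<__m256i*>(lanes), d);
        for (int j = 7; j >= 0; --j)
            if (lanes[j] != 0)
                return lanes[j] < 0; // -1 means a's digit was bigger, so a > b
    }
    return true; // all digits equal, so a >= b
}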
An efficient solution is indeed to vectorize the operation. One portable way to do this is to use the experimental C++ SIMD TS (std::experimental::simd). Here is an (untested) implementation:
#include <experimental/simd>
namespace stdx = std::experimental;
using vuint = stdx::native_simd<uint32_t>;
// Note: constexpr certainly does not mix well with the SIMD TS
bool number::operator>=(const number& b) const {
    const int vsize = vuint::size();
    int i = maxDigits - 1;
    // Compare blocks of vsize digits, starting from the most significant one
    for (; i > vsize - 2; i -= vsize)
    {
        vuint x(&digits[i - vsize + 1], stdx::element_aligned);
        vuint y(&b.digits[i - vsize + 1], stdx::element_aligned);
        if (stdx::any_of(x != y)) [[unlikely]]
        {
            // Non-SIMD code for the sake of simplicity
            // (less efficient, but executed at most once per call)
            for (int j = 0; j < vsize; ++j) {
                if (digits[i - j] != b.digits[i - j]) {
                    return digits[i - j] > b.digits[i - j];
                }
            }
            // Unreachable
            // In C++23, you can write: std::unreachable();
            return false;
        }
    }
    // Remaining digits
    for (; i > -1; --i)
        if (const auto x = digits[i], y = b.digits[i]; x != y)
            return x > y;
    return true;
}
This implementation has the benefit of generating relatively fast code on a wide range of CPU architectures: not just x86-64 CPUs supporting AVX2, but also ARM CPUs supporting Neon/SVE/SVE2, recent x86-64 CPUs supporting AVX-512, and even old x86 CPUs without AVX2. You just need to tweak the compilation flags (e.g. -march=native, -mavx, -mavx2, etc.).
Here is the assembly code generated by Clang (with -O3 -mavx2), which is relatively good (unrolled twice):
.LBB0_2:
cmp rax, 16
jb .LBB0_3
vmovdqu ymm1, ymmword ptr [rdi + 4*rax - 36]
vpcmpeqd ymm1, ymm1, ymmword ptr [rsi + 4*rax - 36]
vptest ymm1, ymm0
jae .LBB0_4
add rax, -18
vmovdqu ymm1, ymmword ptr [rdi + 4*rax]
vpcmpeqd ymm1, ymm1, ymmword ptr [rsi + 4*rax]
vptest ymm1, ymm0
jb .LBB0_2
Here is a complete program on GodBolt.
On Intel Ice Lake and newer CPUs, this takes about 3 cycles per iteration, and each iteration operates on 2 x 8 x 32-bit integers, so roughly 5 integers per cycle. Since the initial scalar code cannot be faster than 1 integer per cycle, this code should be significantly faster: about 5 times faster on this kind of CPU (assuming the code is not memory bound).