进行多字整数运算的一种方法是使用double-double arithmetic。让我们从一些双倍乘法代码开始
#include <math.h>
typedef struct {
double hi;
double lo;
} doubledouble;
static doubledouble quick_two_sum(double a, double b) {
double s = a + b;
double e = b - (s - a);
return (doubledouble){s, e};
}
static doubledouble two_prod(double a, double b) {
double p = a*b;
double e = fma(a, b, -p);
return (doubledouble){p, e};
}
doubledouble df64_mul(doubledouble a, doubledouble b) {
doubledouble p = two_prod(a.hi, b.hi);
p.lo += a.hi*b.lo;
p.lo += a.lo*b.hi;
return quick_two_sum(p.hi, p.lo);
}
函数two_prod 可以在两条指令中执行整数 53bx53b -> 106b。函数df64_mul可以做整数106bx106b -> 106b。
让我们将其与整数 128bx128b -> 128b 与整数硬件进行比较。
__int128 mul128(__int128 a, __int128 b) {
return a*b;
}
mul128 的程序集
imul rsi, rdx
mov rax, rdi
imul rcx, rdi
mul rdx
add rcx, rsi
add rdx, rcx
df64_mul 的程序集(使用 gcc -O3 -S i128.c -masm=intel -mfma -ffp-contract=off 编译)
vmulsd xmm4, xmm0, xmm2
vmulsd xmm3, xmm0, xmm3
vmulsd xmm1, xmm2, xmm1
vfmsub132sd xmm0, xmm4, xmm2
vaddsd xmm3, xmm3, xmm0
vaddsd xmm1, xmm3, xmm1
vaddsd xmm0, xmm1, xmm4
vsubsd xmm4, xmm0, xmm4
vsubsd xmm1, xmm1, xmm4
mul128 执行 3 次标量乘法和 2 次标量加法/减法,而 df64_mul 执行 3 次 SIMD 乘法、1 次 SIMD FMA 和 5 次 SIMD 加法/减法。我没有分析这些方法,但在我看来,df64_mul 在每个 AVX 寄存器中使用 4-doubles 可以胜过mul128(将sd 更改为pd 和xmm 到ymm),这似乎不是不合理的。
很容易说问题是切换回整数域。但为什么这是必要的?您可以在浮点域中做任何事情。让我们看一些例子。我发现使用float 进行单元测试比使用double 更容易。
doublefloat two_prod(float a, float b) {
float p = a*b;
float e = fma(a, b, -p);
return (doublefloat){p, e};
}
//3202129*4807935=15395628093615
x = two_prod(3202129,4807935)
int64_t hi = p, lo = e, s = hi+lo
//p = 1.53956280e+13, e = 1.02575000e+05
//hi = 15395627991040, lo = 102575, s = 15395628093615
//1450779*1501672=2178594202488
y = two_prod(1450779, 1501672)
int64_t hi = p, lo = e, s = hi+lo
//p = 2.17859424e+12, e = -4.00720000e+04
//hi = 2178594242560 lo = -40072, s = 2178594202488
所以我们最终得到不同的范围,在第二种情况下,错误 (e) 甚至是负数,但总和仍然正确。我们甚至可以将两个 doublefloat 值 x 和 y 相加(一旦我们知道如何进行双双加法 - 参见最后的代码)并得到 15395628093615+2178594202488。无需对结果进行归一化。
但是加法带来了双双运算的主要问题。即,加法/减法很慢,例如128b+128b -> 128b needs at least 11 floating point additions 而整数只需要两个(add 和 adc)。
因此,如果一个算法重乘法而轻加法,那么使用 double-double 进行多字整数运算可能会赢。
附带说明,C 语言足够灵活,可以实现整数完全通过浮点硬件实现的实现。 int 可以是 24 位(来自单个浮点),long 可以是 54 位。 (来自双浮点),long long 可能是 106 位(来自双双)。 C 甚至不需要二进制的补码,因此整数可以像通常使用浮点一样对负数使用带符号的幅度。
这是使用双倍乘法和加法的 C 代码(我还没有实现除法或其他操作,例如 sqrt,但有论文展示了如何做到这一点),以防有人想玩它。看看这是否可以针对整数进行优化会很有趣。
//if compiling with -mfma you must also use -ffp-contract=off
//float-float is easier to debug. If you want double-double replace
//all float words with double and fmaf with fma
#include <stdio.h>
#include <math.h>
#include <inttypes.h>
#include <x86intrin.h>
#include <stdlib.h>
//#include <float.h>
typedef struct {
float hi;
float lo;
} doublefloat;
typedef union {
float f;
int i;
struct {
unsigned mantisa : 23;
unsigned exponent: 8;
unsigned sign: 1;
};
} float_cast;
void print_float(float_cast a) {
printf("%.8e, 0x%x, mantisa 0x%x, exponent 0x%x, expondent-127 %d, sign %u\n", a.f, a.i, a.mantisa, a.exponent, a.exponent-127, a.sign);
}
void print_doublefloat(doublefloat a) {
float_cast hi = {a.hi};
float_cast lo = {a.lo};
printf("hi: "); print_float(hi);
printf("lo: "); print_float(lo);
}
doublefloat quick_two_sum(float a, float b) {
float s = a + b;
float e = b - (s - a);
return (doublefloat){s, e};
// 3 add
}
doublefloat two_sum(float a, float b) {
float s = a + b;
float v = s - a;
float e = (a - (s - v)) + (b - v);
return (doublefloat){s, e};
// 6 add
}
doublefloat df64_add(doublefloat a, doublefloat b) {
doublefloat s, t;
s = two_sum(a.hi, b.hi);
t = two_sum(a.lo, b.lo);
s.lo += t.hi;
s = quick_two_sum(s.hi, s.lo);
s.lo += t.lo;
s = quick_two_sum(s.hi, s.lo);
return s;
// 2*two_sum, 2 add, 2*quick_two_sum = 2*6 + 2 + 2*3 = 20 add
}
doublefloat split(float a) {
//#define SPLITTER (1<<27) + 1
#define SPLITTER (1<<12) + 1
float t = (SPLITTER)*a;
float hi = t - (t - a);
float lo = a - hi;
return (doublefloat){hi, lo};
// 1 mul, 3 add
}
doublefloat split_sse(float a) {
__m128 k = _mm_set1_ps(4097.0f);
__m128 a4 = _mm_set1_ps(a);
__m128 t = _mm_mul_ps(k,a4);
__m128 hi4 = _mm_sub_ps(t,_mm_sub_ps(t, a4));
__m128 lo4 = _mm_sub_ps(a4, hi4);
float tmp[4];
_mm_storeu_ps(tmp, hi4);
float hi = tmp[0];
_mm_storeu_ps(tmp, lo4);
float lo = tmp[0];
return (doublefloat){hi,lo};
}
float mult_sub(float a, float b, float c) {
doublefloat as = split(a), bs = split(b);
//print_doublefloat(as);
//print_doublefloat(bs);
return ((as.hi*bs.hi - c) + as.hi*bs.lo + as.lo*bs.hi) + as.lo*bs.lo;
// 4 mul, 4 add, 2 split = 6 mul, 10 add
}
doublefloat two_prod(float a, float b) {
float p = a*b;
float e = mult_sub(a, b, p);
return (doublefloat){p, e};
// 1 mul, one mult_sub
// 7 mul, 10 add
}
float mult_sub2(float a, float b, float c) {
doublefloat as = split(a);
return ((as.hi*as.hi -c ) + 2*as.hi*as.lo) + as.lo*as.lo;
}
doublefloat two_sqr(float a) {
float p = a*a;
float e = mult_sub2(a, a, p);
return (doublefloat){p, e};
}
doublefloat df64_mul(doublefloat a, doublefloat b) {
doublefloat p = two_prod(a.hi, b.hi);
p.lo += a.hi*b.lo;
p.lo += a.lo*b.hi;
return quick_two_sum(p.hi, p.lo);
//two_prod, 2 add, 2mul, 1 quick_two_sum = 9 mul, 15 add
//or 1 mul, 1 fma, 2add 2mul, 1 quick_two_sum = 3 mul, 1 fma, 5 add
}
doublefloat df64_sqr(doublefloat a) {
doublefloat p = two_sqr(a.hi);
p.lo += 2*a.hi*a.lo;
return quick_two_sum(p.hi, p.lo);
}
int float2int(float a) {
int M = 0xc00000; //1100 0000 0000 0000 0000 0000
a += M;
float_cast x;
x.f = a;
return x.i - 0x4b400000;
}
doublefloat add22(doublefloat a, doublefloat b) {
float r = a.hi + b.hi;
float s = fabsf(a.hi) > fabsf(b.hi) ?
(((a.hi - r) + b.hi) + b.lo ) + a.lo :
(((b.hi - r) + a.hi) + a.lo ) + b.lo;
return two_sum(r, s);
//11 add
}
int main(void) {
//print_float((float_cast){1.0f});
//print_float((float_cast){-2.0f});
//print_float((float_cast){0.0f});
//print_float((float_cast){3.14159f});
//print_float((float_cast){1.5f});
//print_float((float_cast){3.0f});
//print_float((float_cast){7.0f});
//print_float((float_cast){15.0f});
//print_float((float_cast){31.0f});
//uint64_t t = 0xffffff;
//print_float((float_cast){1.0f*t});
//printf("%" PRId64 " %" PRIx64 "\n", t*t,t*t);
/*
float_cast t1;
t1.mantisa = 0x7fffff;
t1.exponent = 0xfe;
t1.sign = 0;
print_float(t1);
*/
//doublefloat z = two_prod(1.0f*t, 1.0f*t);
//print_doublefloat(z);
//double z2 = (double)z.hi + (double)z.lo;
//printf("%.16e\n", z2);
doublefloat s = {0};
int64_t si = 0;
for(int i=0; i<100000; i++) {
int ai = rand()%0x800, bi = rand()%0x800000;
float a = ai, b = bi;
doublefloat z = two_prod(a,b);
int64_t zi = (int64_t)ai*bi;
//print_doublefloat(z);
//s = df64_add(s,z);
s = add22(s,z);
si += zi;
print_doublefloat(z);
printf("%d %d ", ai,bi);
int64_t h = z.hi;
int64_t l = z.lo;
int64_t t = h+l;
//if(t != zi) printf("%" PRId64 " %" PRId64 "\n", h, l);
printf("%" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 "\n", zi, h, l, h+l);
h = s.hi;
l = s.lo;
t = h + l;
//if(si != t) printf("%" PRId64 " %" PRId64 "\n", h, l);
if(si > (1LL<<48)) {
printf("overflow after %d iterations\n", i); break;
}
}
print_doublefloat(s);
printf("%" PRId64 "\n", si);
int64_t x = s.hi;
int64_t y = s.lo;
int64_t z = x+y;
//int hi = float2int(s.hi);
printf("%" PRId64 " %" PRId64 " %" PRId64 "\n", z,x,y);
}