模块化算法和 NTT(有限域 DFT)优化

我想使用 NTT 进行快速平方(请参阅快速 bignum 平方计算),但即使对于非常大的数字(超过 12000 位),结果也很慢.

I wanted to use NTT for fast squaring (see Fast bignum square computation), but the result is slow even for really big numbers .. more than 12000 bits.

所以我的问题是:

  1. 有没有办法优化我的 NTT 转换?我并不是要通过并行(线程)来加速它;这只是低级层.
  2. 有没有办法加快我的模块化算术的速度?

这是我在 C++ 中为 NTT 编写的(已经优化的)源代码(它是完整的并且 100% 在 C++ 中工作,不需要第三方库,并且还应该是线程安全的.注意源数组被用作临时!!!, 也不能将数组转化为自身).

This is my (already optimized) source code in C++ for NTT (it's complete and 100% working in C++ without any need for third-party libs and should also be thread-safe. Beware: the source array is used as a temporary!!! Also, it cannot transform the array to itself).

//---------------------------------------------------------------------------
class fourier_NTT                                    // Number theoretic transform
    {

public:
    DWORD r,L,p,N;
    DWORD W,iW,rN;
    fourier_NTT(){ r=0; L=0; p=0; W=0; iW=0; rN=0; }

    // main interface
    void  NTT(DWORD *dst,DWORD *src,DWORD n=0);               // DWORD dst[n] = fast  NTT(DWORD src[n])
    void INTT(DWORD *dst,DWORD *src,DWORD n=0);               // DWORD dst[n] = fast INTT(DWORD src[n])

    // Helper functions
    bool init(DWORD n);                                       // init r,L,p,W,iW,rN
    void  NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD w);    // DWORD dst[n] = fast  NTT(DWORD src[n])

    // Only for testing
    void  NTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w);    // DWORD dst[n] = slow  NTT(DWORD src[n])
    void INTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w);    // DWORD dst[n] = slow INTT(DWORD src[n])

    // DWORD arithmetics
    DWORD shl(DWORD a);
    DWORD shr(DWORD a);

    // Modular arithmetics
    DWORD mod(DWORD a);
    DWORD modadd(DWORD a,DWORD b);
    DWORD modsub(DWORD a,DWORD b);
    DWORD modmul(DWORD a,DWORD b);
    DWORD modpow(DWORD a,DWORD b);
    };

//---------------------------------------------------------------------------
// Forward transform: dst[N] = NTT(src[N]).
// A non-zero n (re)initialises the constants for that size first; n==0 keeps the
// size and roots from the previous init(). NTT_fast writes into src while it
// works, so src is destroyed and must not be the same array as dst.
// For cross-checking, NTT_slow(dst,src,N,W) can be called here instead.
void fourier_NTT:: NTT(DWORD *dst,DWORD *src,DWORD n)
    {
    if (n!=0)
        {
        init(n);
        }
    NTT_fast(dst,src,N,W);
    }

//---------------------------------------------------------------------------
// Inverse transform: dst[N] = INTT(src[N]).
// Runs the same fast transform with the inverse root iW and then multiplies
// every output element by rN. A non-zero n re-runs init(n) first.
// src is overwritten by NTT_fast and must differ from dst.
// For cross-checking, INTT_slow(dst,src,N,W) can be called here instead.
void fourier_NTT::INTT(DWORD *dst,DWORD *src,DWORD n)
    {
    if (n!=0)
        {
        init(n);
        }
    NTT_fast(dst,src,N,iW);
    DWORD k=0;
    while (k<N)
        {
        dst[k]=modmul(dst[k],rN);
        k++;
        }
    }

//---------------------------------------------------------------------------
// Set up r,L,p,N,W,iW,rN for a transform of n elements.
// Returns false (and zeroes every constant, including N) when n is outside 2..0x10000000.
//
// Caller must keep (max(src[])^2)*n < p, otherwise the NTT can overflow.
//
// Active configuration, "32:30 bit, best for unsigned 32 bit":
//     r=2  p=0xC0000001  n<=0x10000000  L=0x30000000/n
// Alternative configurations kept for reference (not compiled):
//     r=2  p=0x78000001  n<=0x04000000  L=0x3c000000/n    31:27 bit, best for signed 32 bit
//     r=2  p=0x00010001  n<=0x00000020  L=0x00000020/n    17:16 bit, best for 16 bit
//     r=2  p=0x0a000001  n<=0x01000000  L=0x01000000/n    28:25 bit
bool fourier_NTT::init(DWORD n)
    {
    const bool accepted=(n>=2)&&(n<=0x10000000);
    if (!accepted)
        {
        r=0; L=0; p=0;
        W=0; iW=0; rN=0;
        N=0;
        return false;
        }
    r=2;
    p=0xC0000001;
    L=0x30000000/n;
    N=n;                    // size of vectors [DWORDs]
    W =modpow(r,L);         // Wn for NTT
    iW=modpow(r,p-1-L);     // Wn for INTT
    rN=modpow(n,p-2);       // scale for INTT
    return true;
    }

//---------------------------------------------------------------------------
// Recursive radix-2 transform: dst[n] = NTT(src[n]) using root w.
// src doubles as scratch space: the two half-size sub-transforms are written
// back into it, so its original contents are lost and dst must not alias it.
void fourier_NTT:: NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD w)
    {
    if (n==0) return;
    if (n==1) { dst[0]=src[0]; return; }

    const DWORD half=n>>1;
    const DWORD wsq=modmul(w,w);        // root for the half-size transforms
    DWORD k;

    // split: even-indexed inputs into dst[0..half), odd-indexed into dst[half..n)
    for (k=0;k<half;k++) dst[k]=src[k+k];
    for (k=half;k<n;k++) dst[k]=src[((k-half)<<1)+1];

    // transform both halves, results land back in src
    NTT_fast(src     ,dst     ,half,wsq);   // even
    NTT_fast(src+half,dst+half,half,wsq);   // odd

    // butterflies: combine the halves with running powers of w
    DWORD tw=1;
    for (k=0;k<half;k++)
        {
        const DWORD even=src[k];
        const DWORD odd =modmul(src[k+half],tw);
        dst[k]     =modadd(even,odd);
        dst[k+half]=modsub(even,odd);
        tw=modmul(tw,w);
        }
    }

//---------------------------------------------------------------------------
// Reference transform by direct summation (two nested loops over n):
// dst[row] = sum over col of src[col] * (w^row)^col, all modulo p.
// Intended for testing the fast version only; src is left untouched.
void fourier_NTT:: NTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w)
    {
    DWORD rowRoot=1;                    // w^row
    for (DWORD row=0;row<n;row++)
        {
        DWORD acc=0;
        DWORD tw=1;                     // rowRoot^col
        for (DWORD col=0;col<n;col++)
            {
            acc=modadd(acc,modmul(tw,src[col]));
            tw=modmul(tw,rowRoot);
            }
        dst[row]=acc;
        rowRoot=modmul(rowRoot,w);
        }
    }

//---------------------------------------------------------------------------
// Reference inverse transform by direct summation, scaled by rN.
// Note that the parameter w is not read: the row root always advances by the
// member iW. Intended for testing only; src is left untouched.
void fourier_NTT::INTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w)
    {
    DWORD rowRoot=1;                    // iW^row
    for (DWORD row=0;row<n;row++)
        {
        DWORD acc=0;
        DWORD tw=1;                     // rowRoot^col
        for (DWORD col=0;col<n;col++)
            {
            acc=modadd(acc,modmul(tw,src[col]));
            tw=modmul(tw,rowRoot);
            }
        dst[row]=modmul(acc,rN);
        rowRoot=modmul(rowRoot,iW);
        }
    }

//---------------------------------------------------------------------------
// Shift a left by one bit; the mask keeps bits 1..31 of the result and clears bit 0.
DWORD fourier_NTT::shl(DWORD a)
    {
    const DWORD shifted=a<<1;
    return shifted&0xFFFFFFFE;
    }
// Shift a right by one bit; the mask keeps bits 0..30 of the result and clears bit 31.
DWORD fourier_NTT::shr(DWORD a)
    {
    const DWORD shifted=a>>1;
    return shifted&0x7FFFFFFF;
    }

//---------------------------------------------------------------------------
// Reduce a modulo p by shift-and-subtract.
// The subtrahend starts at p and is doubled while it is still below a and its
// top bit is clear; it is then subtracted where it fits and halved again until
// it is back at p.
DWORD fourier_NTT::mod(DWORD a)
    {
    DWORD step=p;
    while ((a>step)&&((step&0x80000000)==0)) step=shl(step);
    for (;;step=shr(step))
        {
        if (a>=step) a-=step;
        if (step==p) break;
        }
    return a;
    }

//---------------------------------------------------------------------------
// (a+b) mod p. Both operands are reduced first. The carry out of bit 31 of the
// sum is reconstructed from the halved operands plus the carry of their lowest
// bits, and a carried sum is corrected by subtracting p.
DWORD fourier_NTT::modadd(DWORD a,DWORD b)
    {
    const DWORD x=mod(a);
    const DWORD y=mod(b);
    DWORD sum=x+y;
    const DWORD lowCarry=shr((x&1)+(y&1));
    const DWORD carry=(shr(x)+shr(y)+lowCarry)&0x80000000;
    if (carry!=0) sum-=p;
    if (sum>=p) sum-=p;
    return sum;
    }

//---------------------------------------------------------------------------
// (a-b) mod p. Both operands are reduced first; when the subtraction would go
// negative, p is added back.
DWORD fourier_NTT::modsub(DWORD a,DWORD b)
    {
    const DWORD x=mod(a);
    const DWORD y=mod(b);
    DWORD diff=x-y;
    if (x<y) diff+=p;
    return (diff>=p) ? diff-p : diff;
    }

//---------------------------------------------------------------------------
// (a*b) mod p by shift-and-add: for each of the 32 bits of the reduced a, add
// the current multiple of b when the bit is set, then double that multiple.
// b does not need to be reduced by the caller (modadd reduces it).
DWORD fourier_NTT::modmul(DWORD a,DWORD b)
    {
    DWORD bits=mod(a);
    DWORD acc=0;
    DWORD term=b;
    for (int left=32;left>0;left--)
        {
        if ((bits&1)!=0) acc=modadd(acc,term);
        bits=shr(bits);
        term=modadd(term,term);
        }
    return acc;
    }

//---------------------------------------------------------------------------
// (a^b) mod p by square-and-multiply, scanning the 32 bits of b from the most
// significant one down. Neither a nor b needs to be reduced by the caller.
DWORD fourier_NTT::modpow(DWORD a,DWORD b)
    {
    DWORD acc=1;
    for (DWORD bit=0x80000000;bit!=0;bit=shr(bit))
        {
        acc=modmul(acc,acc);
        if ((b&bit)!=0) acc=modmul(acc,a);
        }
    return acc;
    }
//---------------------------------------------------------------------------

我的 NTT 类的使用示例:

Example of usage of my NTT class:

fourier_NTT ntt;
const DWORD n=32;
DWORD x[n],y[n],z[n],i;
for (i=0;i<n;i++) { x[i]=i; y[i]=n+i; }    // x[]={0,1,2,...,31}, y[]={32,33,34,...,63}

ntt.NTT(z,x,n);    // z[n]=NTT(x[n]), also init constants for n (x[] is used as a temporary)
ntt.NTT(x,y);      // x[n]=NTT(y[n]), no recompute of constants, use last n
// modular convolution y[]=z[].x[]
for (i=0;i<n;i++) y[i]=ntt.modmul(z[i],x[i]);
ntt.INTT(x,y);     // x[n]=INTT(y[n]), no recompute of constants, use last n
// x[]=convolution of original x[].y[]

优化前的一些测量(非 NTT 类):

Some measurements before optimizations (non Class NTT):

a = 0.98765588997654321000 | 389*32 bits
looped 1x times
sqr1[ 3.177 ms ] fast sqr
sqr2[ 720.419 ms ] NTT sqr
mul1[ 5.588 ms ] simpe mul
mul2[ 3.172 ms ] karatsuba mul
mul3[ 1053.382 ms ] NTT mul

我优化后的一些测量(当前代码、较低的递归参数大小/计数和更好的模块化算法):

Some measurements after my optimizations (current code, lower recursion parameter size/count, and better modular arithmetics):

a = 0.98765588997654321000 | 389*32 bits
looped 1x times
sqr1[ 3.214 ms ] fast sqr
sqr2[ 208.298 ms ] NTT sqr
mul1[ 5.564 ms ] simpe mul
mul2[ 3.113 ms ] karatsuba mul
mul3[ 302.740 ms ] NTT mul

检查 NTT mul 和 NTT sqr 时间(我的优化加快了 3 倍多一点).它只有 1 倍循环,所以它不是很精确(误差 ~ 10%),但即使现在加速也很明显(通常我循环它 1000 倍甚至更多,但我的 NTT 太慢了).

Check the NTT mul and NTT sqr times (my optimizations speed it up little over 3x times). It's only 1x times loop so it's not very precise (error ~ 10%), but the speedup is noticeable even now (normally I loop it 1000x and more, but my NTT is too slow for that).

您可以自由使用我的代码...只需将我的昵称和/或指向此页面的链接保留在某处(rem in code、readme.txt、about 或其他内容).我希望它有帮助......(我没有在任何地方看到快速 NTT 的 C++ 源代码,所以我不得不自己编写).对所有接受的 N 都测试了统一根,请参阅 fourier_NTT::init(DWORD n) 函数.

You can use my code freely... Just keep my nick and/or link to this page somewhere (rem in code, readme.txt, about or whatever). I hope it helps... (I did not see C++ source for fast NTTs anywhere so I had to write it by myself). Roots of unity were tested for all accepted N, see the fourier_NTT::init(DWORD n) function.

P.S.:有关 NTT 的更多信息,请参阅从复杂 FFT 到有限域 FFT 的转换.此代码基于我在该链接中的帖子.

P.S.: For more information about NTT, see Translation from Complex-FFT to Finite-Field-FFT. This code is based on my posts inside that link.

[edit1:] 代码的进一步变化

通过利用模素数始终为 0xC0000001 并消除不必要的调用,我设法进一步优化了我的模算术.由此产生的加速现在令人惊叹(超过 40 倍),并且在大约 1500 * 32 位阈值之后,NTT 乘法比 karatsuba 更快.顺便说一句,我的 NTT 的速度现在与我在 64 位双打上优化的 DFFT 的速度相同.

I managed to further optimize my modular arithmetics by exploiting the fact that the modulo prime is always 0xC0000001 and by eliminating unnecessary calls. The resulting speedup is stunning (more than 40x) and NTT multiplication is now faster than karatsuba above a threshold of about 1500 * 32 bits. BTW, the speed of my NTT is now the same as my optimized DFFT on 64-bit doubles.

一些测量:

a = 0.98765588997654321000 | 1553*32bits
looped 10x times
mul2[ 28.585 ms ] karatsuba mul
mul3[ 26.311 ms ] NTT mul

模块化算术的新源代码:

New source code for modular arithmetics:

//---------------------------------------------------------------------------
// Reduce a into 0..p-1 with a single conditional subtraction.
// One subtraction is enough only because p is 0xC0000001, so every 32-bit value
// is below 2*p. The comparison is >= so that a==p is reduced to 0 (the previous
// a>p test returned p itself).
DWORD fourier_NTT::mod(DWORD a)
    {
    if (a>=p) a-=p;
    return a;
    }

//---------------------------------------------------------------------------
// (a+b) mod p for p=0xC0000001.
// Operands are brought into 0..p-1 with one conditional subtraction each. The
// 32-bit sum may wrap; the lost carry is reconstructed from the halved operands
// and, when set, corrected by subtracting p. All comparisons use >= so that a
// value equal to p is reduced to 0; with the previous > tests,
// modadd(p-1,1) returned p instead of 0.
DWORD fourier_NTT::modadd(DWORD a,DWORD b)
    {
    DWORD d,cy;
    if (a>=p) a-=p;
    if (b>=p) b-=p;
    d=a+b;
    cy=((a>>1)+(b>>1)+(((a&1)+(b&1))>>1))&0x80000000;   // carry out of bit 31 of a+b
    if (cy   ) d-=p;
    if (d>=p) d-=p;
    return d;
    }

//---------------------------------------------------------------------------
// (a-b) mod p for p=0xC0000001.
// Operands are brought into 0..p-1 with one conditional subtraction each; when
// a<b the wrapped difference is corrected by adding p. Comparisons use >= so
// that an operand or result equal to p becomes 0 (the previous > tests let p
// through unchanged).
DWORD fourier_NTT::modsub(DWORD a,DWORD b)
    {
    DWORD d;
    if (a>=p) a-=p;
    if (b>=p) b-=p;
    d=a-b;
    if (a<b) d+=p;
    if (d>=p) d-=p;
    return d;
    }

//---------------------------------------------------------------------------
// (a*b) mod p using a full 64-bit product.
// The hand-written x86 path is kept for the 32-bit Borland-style compiler it
// was written for; every other toolchain (64-bit targets, compilers without
// `asm { }` blocks) takes the portable branch, which computes the same value.
// In the assembly path at least one operand must be below p: `div` raises a
// divide error when the quotient does not fit in 32 bits.
DWORD fourier_NTT::modmul(DWORD a,DWORD b)
    {
#if defined(__BORLANDC__) && !defined(__clang__) && !defined(_WIN64)
    DWORD _a,_b,_p;
    _a=a;
    _b=b;
    _p=p;
    asm    {
        mov    eax,_a
        mov    ebx,_b
        mul    ebx        // H(edx),L(eax) = eax * ebx
        mov    ebx,_p
        div    ebx        // eax = H(edx),L(eax) / ebx
        mov    _a,edx     // edx = H(edx),L(eax) % ebx
        }
    return _a;
#else
    const unsigned long long product=static_cast<unsigned long long>(a)*b;
    return DWORD(product%p);
#endif
    }

//---------------------------------------------------------------------------
// (a^b) mod p by square-and-multiply over the 32 bits of b, most significant
// bit first. b is a plain exponent and is not reduced modulo p. a is reduced
// with one conditional subtraction.
DWORD fourier_NTT::modpow(DWORD a,DWORD b)
    {
    if (a>p) a-=p;
    DWORD acc=1;
    for (DWORD bit=0x80000000;bit!=0;bit>>=1)
        {
        acc=modmul(acc,acc);
        if ((b&bit)!=0) acc=modmul(acc,a);
        }
    return acc;
    }

//---------------------------------------------------------------------------

如您所见,函数 shl 和 shr 不再使用.我认为 modpow 可以进一步优化,但它不是一个关键函数,因为它只被调用了很少的次数.最关键的函数是 modmul,它似乎处于最佳状态.

As you can see, functions shl and shr are no more used. I think that modpow can be further optimized, but it's not a critical function because it is called only very few times. The most critical function is modmul, and that seems to be in the best shape possible.

其他问题:

  • 还有其他方法可以加速 NTT 吗?
  • 我对模块化算法的优化安全吗?(结果似乎是一样的,但我可能会遗漏一些东西.)

[edit2] 新的优化

a = 0.99991970486 | 2000*32 bits
looped 10x
sqr1[  13.908 ms ] fast sqr
sqr2[  13.649 ms ] NTT sqr
mul1[  19.726 ms ] simpe mul
mul2[  31.808 ms ] karatsuba mul
mul3[  19.373 ms ] NTT mul

我从你的所有评论中实现了所有可用的东西(感谢你的洞察力).

I implemented all the usable stuff from all of your comments (thanks for the insight).

加速:

  • 通过移除不必要的安全模组(Mandalf The Beige)+2.5%
  • +34.9% 通过使用预先计算的 W,iW 功率(神秘)
  • +35% 总计

实际完整源代码:

//---------------------------------------------------------------------------
//--- Number theoretic transforms: 2.03 -------------------------------------
//---------------------------------------------------------------------------
#ifndef _fourier_NTT_h
#define _fourier_NTT_h
//---------------------------------------------------------------------------
//---------------------------------------------------------------------------
class fourier_NTT        // Number theoretic transform
    {
public:
    DWORD r,L,p,N;
    DWORD W,iW,rN;        // W=(r^L) mod p, iW=inverse W, rN = inverse N
    DWORD *WW,*iWW,NN;    // Precomputed (W,iW)^(0,..,NN-1) powers

    // Internals
    fourier_NTT(){ r=0; L=0; p=0; W=0; iW=0; rN=0; WW=NULL; iWW=NULL; NN=0; }
    ~fourier_NTT(){ _free(); }
    void _free();                                            // Free precomputed W,iW powers tables
    void _alloc(DWORD n);                                    // Allocate and precompute W,iW powers tables

    // Main interface
    void  NTT(DWORD *dst,DWORD *src,DWORD n=0);                // DWORD dst[n] = fast  NTT(DWORD src[n])
    void iNTT(DWORD *dst,DWORD *src,DWORD n=0);               // DWORD dst[n] = fast INTT(DWORD src[n])

    // Helper functions
    bool init(DWORD n);                                          // init r,L,p,W,iW,rN
    void  NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD w);    // DWORD dst[n] = fast  NTT(DWORD src[n])
    void  NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD *w2,DWORD i2);

    // Only for testing
    void  NTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w);    // DWORD dst[n] = slow  NTT(DWORD src[n])
    void iNTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w);    // DWORD dst[n] = slow INTT(DWORD src[n])

    // Modular arithmetics (optimized, but it works only for p >= 0x80000000!!!)
    DWORD mod(DWORD a);
    DWORD modadd(DWORD a,DWORD b);
    DWORD modsub(DWORD a,DWORD b);
    DWORD modmul(DWORD a,DWORD b);
    DWORD modpow(DWORD a,DWORD b);
    };
//---------------------------------------------------------------------------

//---------------------------------------------------------------------------
// Release both power tables and mark them as empty (NN=0, pointers NULL).
void fourier_NTT::_free()
    {
    delete[] WW;        // delete[] on a NULL pointer does nothing
    delete[] iWW;
    WW=NULL;
    iWW=NULL;
    NN=0;
    }

//---------------------------------------------------------------------------
// Make sure WW[] and iWW[] hold at least n powers of the CURRENT W and iW:
// WW[i] = W^i mod p, iWW[i] = iW^i mod p for i in 0..NN-1.
//
// The tables depend on W/iW, which init() changes whenever the transform size
// changes. Entry 1 of each table is the root it was built from, so a mismatch
// with W/iW means the tables are stale and must be rebuilt from scratch.
// (Previously only the size was checked, so a second init() with a different n
// kept, or extended, powers of the old root and produced wrong transforms.)
// When the roots match, existing entries are kept and only new ones are added.
void fourier_NTT::_alloc(DWORD n)
    {
    if ((NN>1)&&((WW[1]!=W)||(iWW[1]!=iW))) _free();    // built for another root
    if (n<=NN) return;

    DWORD *tmp,i,w;

    // forward powers
    tmp=new DWORD[n];
    if (WW) for (i=0;i<NN;i++) tmp[i]=WW[i];
    delete[] WW; WW=tmp;
    WW[0]=1;
    for (i=NN?NN:1,w=WW[i-1];i<n;i++) { w=modmul(w,W); WW[i]=w; }

    // inverse powers
    tmp=new DWORD[n];
    if (iWW) for (i=0;i<NN;i++) tmp[i]=iWW[i];
    delete[] iWW; iWW=tmp;
    iWW[0]=1;
    for (i=NN?NN:1,w=iWW[i-1];i<n;i++) { w=modmul(w,iW); iWW[i]=w; }

    NN=n;
    }

//---------------------------------------------------------------------------
// Forward transform: dst[N] = NTT(src[N]), roots taken from the WW[] table.
// A non-zero n re-runs init(n) first; n==0 keeps the previous size and tables.
// src is overwritten by NTT_fast and must differ from dst.
// Alternatives for testing: NTT_fast(dst,src,N,W) or NTT_slow(dst,src,N,W).
void fourier_NTT:: NTT(DWORD *dst,DWORD *src,DWORD n)
    {
    if (n!=0)
        {
        init(n);
        }
    NTT_fast(dst,src,N,WW,1);
    }

//---------------------------------------------------------------------------
// Inverse transform: dst[N] = INTT(src[N]), roots taken from the iWW[] table,
// followed by scaling every element with rN.
// A non-zero n re-runs init(n) first. src is overwritten and must differ from dst.
// Alternatives for testing: NTT_fast(dst,src,N,iW) plus scaling, or iNTT_slow(dst,src,N,W).
void fourier_NTT::iNTT(DWORD *dst,DWORD *src,DWORD n)
    {
    if (n!=0)
        {
        init(n);
        }
    NTT_fast(dst,src,N,iWW,1);
    DWORD k=0;
    while (k<N)
        {
        dst[k]=modmul(dst[k],rN);
        k++;
        }
    }

//---------------------------------------------------------------------------
// Set up r,L,p,N,W,iW,rN for a transform of n elements and make sure the power
// tables hold n/2 entries. Returns false (zeroing every constant, including N,
// and leaving the tables alone) when n is outside 2..0x10000000.
//
// Caller must keep (max(src[])^2)*n < p, otherwise the NTT can overflow.
//
// Active configuration, "32:30 bit, best for unsigned 32 bit":
//     r=2  p=0xC0000001  n<=0x10000000  L=0x30000000/n
// Alternative configurations kept for reference (not compiled):
//     r=2  p=0x78000001  n<=0x04000000  L=0x3c000000/n    31:27 bit, best for signed 32 bit
//     r=2  p=0x00010001  n<=0x00000020  L=0x00000020/n    17:16 bit, best for 16 bit
//     r=2  p=0x0a000001  n<=0x01000000  L=0x01000000/n    28:25 bit
bool fourier_NTT::init(DWORD n)
    {
    const bool accepted=(n>=2)&&(n<=0x10000000);
    if (!accepted)
        {
        r=0; L=0; p=0;
        W=0; iW=0; rN=0;
        N=0;
        return false;
        }
    r=2;
    p=0xC0000001;
    L=0x30000000/n;
    N=n;                    // Size of vectors [DWORDs]
    W =modpow(r,L);         // Wn for NTT
    iW=modpow(r,p-1-L);     // Wn for INTT
    rN=modpow(n,p-2);       // Scale for INTT
    _alloc(n>>1);           // Precompute W,iW powers
    return true;
    }

//---------------------------------------------------------------------------
// Recursive radix-2 transform: dst[n] = NTT(src[n]) using root w, with the
// powers of w computed on the fly (older variant, kept for testing).
// src doubles as scratch space, so its contents are lost and dst must not alias it.
void fourier_NTT:: NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD w)
    {
    if (n==0) return;
    if (n==1) { dst[0]=src[0]; return; }

    const DWORD half=n>>1;
    const DWORD wsq=modmul(w,w);        // root for the half-size transforms
    DWORD k;

    // Split: even-indexed inputs into dst[0..half), odd-indexed into dst[half..n)
    for (k=0;k<half;k++) dst[k]=src[k+k];
    for (k=half;k<n;k++) dst[k]=src[((k-half)<<1)+1];

    // Transform both halves, results land back in src
    NTT_fast(src     ,dst     ,half,wsq);   // Even
    NTT_fast(src+half,dst+half,half,wsq);   // Odd

    // Butterflies: combine the halves with running powers of w
    DWORD tw=1;
    for (k=0;k<half;k++)
        {
        const DWORD even=src[k];
        const DWORD odd =modmul(src[k+half],tw);
        dst[k]     =modadd(even,odd);
        dst[k+half]=modsub(even,odd);
        tw=modmul(tw,w);
        }
    }

//---------------------------------------------------------------------------
// Recursive radix-2 transform reading its roots from a precomputed table:
// dst[n] = NTT(src[n]), where butterfly k uses w2[k*i2]. Each recursion level
// doubles the stride, so all levels share the same table.
// src doubles as scratch space, so its contents are lost and dst must not alias it.
void fourier_NTT:: NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD *w2,DWORD i2)
    {
    if (n==0) return;
    if (n==1) { dst[0]=src[0]; return; }

    const DWORD half=n>>1;
    const DWORD stride=i2<<1;           // table stride for the half-size transforms
    DWORD k;

    // Split: even-indexed inputs into dst[0..half), odd-indexed into dst[half..n)
    for (k=0;k<half;k++) dst[k]=src[k+k];
    for (k=half;k<n;k++) dst[k]=src[((k-half)<<1)+1];

    // Transform both halves, results land back in src
    NTT_fast(src     ,dst     ,half,w2,stride);   // Even
    NTT_fast(src+half,dst+half,half,w2,stride);   // Odd

    // Butterflies: combine the halves, stepping through the table by i2
    DWORD at=0;
    for (k=0;k<half;k++)
        {
        const DWORD even=src[k];
        const DWORD odd =modmul(src[k+half],w2[at]);
        dst[k]     =modadd(even,odd);
        dst[k+half]=modsub(even,odd);
        at+=i2;
        }
    }

//---------------------------------------------------------------------------
// Reference transform by direct summation (two nested loops over n):
// dst[row] = sum over col of src[col] * (w^row)^col, all modulo p.
// Intended for testing the fast version only; src is left untouched.
void fourier_NTT:: NTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w)
    {
    DWORD rowRoot=1;                    // w^row
    for (DWORD row=0;row<n;row++)
        {
        DWORD acc=0;
        DWORD tw=1;                     // rowRoot^col
        for (DWORD col=0;col<n;col++)
            {
            acc=modadd(acc,modmul(tw,src[col]));
            tw=modmul(tw,rowRoot);
            }
        dst[row]=acc;
        rowRoot=modmul(rowRoot,w);
        }
    }

//---------------------------------------------------------------------------
// Reference inverse transform by direct summation, scaled by rN.
// Note that the parameter w is not read: the row root always advances by the
// member iW. Intended for testing only; src is left untouched.
void fourier_NTT::iNTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w)
    {
    DWORD rowRoot=1;                    // iW^row
    for (DWORD row=0;row<n;row++)
        {
        DWORD acc=0;
        DWORD tw=1;                     // rowRoot^col
        for (DWORD col=0;col<n;col++)
            {
            acc=modadd(acc,modmul(tw,src[col]));
            tw=modmul(tw,rowRoot);
            }
        dst[row]=modmul(acc,rN);
        rowRoot=modmul(rowRoot,iW);
        }
    }

//---------------------------------------------------------------------------
// Reduce a into 0..p-1 with a single conditional subtraction.
// One subtraction is enough only because p >= 0x80000000, so every 32-bit value
// is below 2*p. The comparison is >= so that a==p is reduced to 0 (the previous
// a>p test returned p itself).
DWORD fourier_NTT::mod(DWORD a)
    {
    if (a>=p) a-=p;
    return a;
    }

//---------------------------------------------------------------------------
// (a+b) mod p for p >= 0x80000000.
// Both operands must already be below p; the input reductions were removed on
// purpose for speed. The 32-bit sum may wrap; the lost carry is reconstructed
// from the halved operands and, when set, corrected by subtracting p. The final
// comparison is >= so that a sum equal to p becomes 0; with the previous > test,
// modadd(p-1,1) returned p instead of 0.
DWORD fourier_NTT::modadd(DWORD a,DWORD b)
    {
    DWORD d,cy;
    d=a+b;
    cy=((a>>1)+(b>>1)+(((a&1)+(b&1))>>1))&0x80000000;   // carry out of bit 31 of a+b
    if (cy   ) d-=p;
    if (d>=p) d-=p;
    return d;
    }

//---------------------------------------------------------------------------
// (a-b) mod p for p >= 0x80000000.
// Both operands are expected to be below p already; the input reductions were
// removed on purpose for speed. When a<b the wrapped difference is corrected by
// adding p. The final comparison is >= so that a result equal to p (possible
// only for an unreduced operand) becomes 0 instead of being returned as p.
DWORD fourier_NTT::modsub(DWORD a,DWORD b)
    {
    DWORD d;
    d=a-b;
    if (a<b) d+=p;
    if (d>=p) d-=p;
    return d;
    }

//---------------------------------------------------------------------------
// (a*b) mod p using a full 64-bit product.
// The hand-written x86 path is kept for the 32-bit Borland-style compiler it
// was written for; every other toolchain (64-bit targets, compilers without
// `asm { }` blocks) takes the portable branch, which computes the same value.
// In the assembly path at least one operand must be below p: `div` raises a
// divide error when the quotient does not fit in 32 bits.
DWORD fourier_NTT::modmul(DWORD a,DWORD b)
    {
#if defined(__BORLANDC__) && !defined(__clang__) && !defined(_WIN64)
    DWORD _a,_b,_p;
    _a=a;
    _b=b;
    _p=p;
    asm    {
        mov    eax,_a
        mov    ebx,_b
        mul    ebx        // H(edx),L(eax) = eax * ebx
        mov    ebx,_p
        div    ebx        // eax = H(edx),L(eax) / ebx
        mov    _a,edx     // edx = H(edx),L(eax) % ebx
        }
    return _a;
#else
    const unsigned long long product=static_cast<unsigned long long>(a)*b;
    return DWORD(product%p);
#endif
    }

//---------------------------------------------------------------------------
// (a^b) mod p by square-and-multiply over the 32 bits of b, most significant
// bit first. b is a plain exponent and is not reduced modulo p; a is passed to
// modmul as given (its input reduction was removed on purpose).
DWORD fourier_NTT::modpow(DWORD a,DWORD b)
    {
    DWORD acc=1;
    for (DWORD bit=0x80000000;bit!=0;bit>>=1)
        {
        acc=modmul(acc,acc);
        if ((b&bit)!=0) acc=modmul(acc,a);
        }
    return acc;
    }
//---------------------------------------------------------------------------
//---------------------------------------------------------------------------
#endif
//---------------------------------------------------------------------------
//---------------------------------------------------------------------------

通过将 NTT_fast 分成两个函数,仍然有可能使用更少的堆垃圾.一个带有 WW[],另一个带有 iWW[],这导致递归调用中的一个参数减少.但我对它的期望并不高(仅限 32 位指针),而是有一个功能可以在未来更好地管理代码.许多函数现在处于休眠状态(用于测试),例如慢变体、mod 和较旧的快速函数(使用 w 参数代替 *w2,i2).

There is still the possibility to use less heap trashing by separating NTT_fast to two functions. One with WW[] and the other with iWW[] which leads to one parameter less in recursion calls. But I do not expect much from it (32-bit pointer only) and rather have one function for better code management in the future. Many functions are dormant now (for testing) Like slow variants, mod and the older fast function (with w parameter instead of *w2,i2).

为避免大数据集溢出,将输入数字限制为 p/4 位,其中 p 是每个 NTT 元素的位数;因此对于此 32 位版本,请使用最大 (32 位/4 -> 8 位) 的输入值.

To avoid overflows for big datasets, limit input numbers to p/4 bits. Where p is number of bits per NTT element so for this 32 bit version use max (32 bit/4 -> 8 bit) input values.

[edit3] 用于测试的简单字符串 bigint 乘法

[edit3] Simple string bigint multiplication for testing

//---------------------------------------------------------------------------
// Multiply two non-negative decimal integers given as digit strings, using
// NTT -> pointwise product -> inverse NTT -> carry propagation.
//
//   sx, sy   zero-terminated strings of decimal digits, most significant first
//   returns  newly allocated (new char[]) zero-terminated string of exactly n
//            decimal digits, left-padded with '0'; the caller must delete[] it
//
// n is the smallest power of 2 >= the longer operand's length, doubled, so the
// product (at most len(sx)+len(sy) digits) always fits and the convolution
// does not wrap around. n has to stay within the sizes accepted by
// fourier_NTT::init().
//
// Fixes over the previous version: the length of sy was measured on sx (which
// read past the end of a shorter sy or dropped digits of a longer one), and n
// was derived from only one operand instead of the longer of the two.
char* mul_NTT(const char *sx,const char *sy)
    {
    char *s;
    int i,j,k,n;

    // i,j = number of digits in sx,sy
    for (i=0;sx[i];i++);
    for (j=0;sy[j];j++);

    // n = min power of 2 >= 2 * max length(x,y)
    k=(i>j)?i:j;
    for (n=1;n<k;n<<=1);
    n<<=1;

    // i,j = index of the least significant digit (-1 for an empty string)
    i--; j--;

    DWORD *x,*y,*xx,*yy,a;
    x=new DWORD[n]; xx=new DWORD[n];
    y=new DWORD[n]; yy=new DWORD[n];

    // Digits, least significant first, then zero padding up to n
    for (k=0;i>=0;i--,k++) x[k]=sx[i]-'0'; for (;k<n;k++) x[k]=0;
    for (k=0;j>=0;j--,k++) y[k]=sy[j]-'0'; for (;k<n;k++) y[k]=0;

    // NTT (x,y are used as temporaries by the transform)
    fourier_NTT ntt;
    ntt.NTT(xx,x,n);
    ntt.NTT(yy,y);

    // Convolution
    for (i=0;i<n;i++) xx[i]=ntt.modmul(xx[i],yy[i]);

    // INTT
    ntt.iNTT(yy,xx);

    // Carry propagation into decimal digits, most significant digit at s[0]
    a=0; s=new char[n+1];
    for (i=0;i<n;i++) { a+=yy[i]; s[n-i-1]=(a%10)+'0'; a/=10; }
    s[n]=0;

    delete[] x; delete[] xx;
    delete[] y; delete[] yy;

    return s;
    }
//---------------------------------------------------------------------------

我使用AnsiString,所以我希望将它移植到char*,我没有做错.看起来它工作正常(与 AnsiString 版本相比).

I use AnsiStrings, so I ported it to char*; hopefully I did not make a mistake. It looks like it works properly (in comparison to the AnsiString version).

  • sx,sy 是十进制整数
  • 返回分配的字符串(char*)=sx*sy
  • sx,sy are decimal integer numbers
  • Returns allocated string (char*)=sx*sy

每 32 位数据字只有 ~4 位,因此没有溢出的风险,但当然速度较慢.在我的 bignum 库中,我使用二进制表示,并为 NTT 每 32 位 WORD 使用 8 位 块.如果 N 很大,那么风险更大......

This is only ~4 bit per 32 bit data word so there is no risk of overflow, but it is slower of course. In my bignum lib I use a binary representation and use 8 bit chunks per 32-bit WORD for NTT. More than that is risky if N is big ...

玩得开心

推荐答案

首先,非常感谢您发布并免费使用它.我真的很感激.

First off, thank you very much for posting and making it free to use. I really appreciate that.

我能够使用一些小技巧来消除一些分支,重新排列主循环并修改程序集,并且能够获得 1.35 倍的加速.

I was able to use some bit tricks to eliminate some branching, rearranged the main loop, and modified the assembly, and was able to get a 1.35x speedup.

另外,我为 64 位添加了一个预处理器条件,因为 Visual Studio 不允许在 64 位模式下进行内联汇编(谢谢微软;你可以自己搞砸了).

Also, I added a preprocessor condition for 64 bit, seeing as Visual Studio doesn't allow inline assembly in 64 bit mode (thank you Microsoft; feel free to go screw yourself).

在优化 modsub() 函数时发生了一些奇怪的事情.我像 modadd 一样使用 bit hacks 重写了它(速度更快).但出于某种原因,modsub 的位明智版本较慢.不知道为什么.可能只是我的电脑.

Something strange happened when I was optimizing the modsub() function. I rewrote it using bit hacks like I did modadd (which was faster). But for some reason, the bit wise version of modsub was slower. Not sure why. Might just be my computer.

//
// Mandalf The Beige
// Based on:
// Spektre
// http://stackoverflow.com/questions/18577076/modular-arithmetics-and-ntt-finite-field-dft-optimizations
//
// This code may be freely used however you choose, so long as it is accompanied by this notice.
//




#ifndef H__OPTIMIZED_NUMBER_THEORETIC_TRANSFORM__HDR
#define H__OPTIMIZED_NUMBER_THEORETIC_TRANSFORM__HDR

#include <string.h>

#ifndef uint32
#define uint32 unsigned long int
#endif

#ifndef uint64
#define uint64 unsigned long long int
#endif


class fast_ntt                                   // number theoretic transform
{
    public:
    fast_ntt()
    {
        r = 0; L = 0;
        W = 0; iW = 0; rN = 0;
    }
    // main interface
    void  NTT(uint32 *dst, uint32 *src, uint32 n = 0);             // uint32 dst[n] = fast  NTT(uint32 src[n])
    void INTT(uint32 *dst, uint32 *src, uint32 n = 0);             // uint32 dst[n] = fast INTT(uint32 src[n])
    // helper functions

    private:
    bool init(uint32 n);                                     // init r,L,p,W,iW,rN
    void NTT_calc(uint32 *dst, uint32 *src, uint32 n, uint32 w);  // uint32 dst[n] = fast  NTT(uint32 src[n])

    void  NTT_fast(uint32 *dst, uint32 *src, uint32 n, uint32 w);  // uint32 dst[n] = fast  NTT(uint32 src[n])
    void NTT_fast(uint32 *dst, const uint32 *src, uint32 n, uint32 w);
    // only for testing
    void  NTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w);  // uint32 dst[n] = slow  NTT(uint32 src[n])
    void INTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w);  // uint32 dst[n] = slow INTT(uint32 src[n])
    // uint32 arithmetics


    // modular arithmetics
    inline uint32 modadd(uint32 a, uint32 b);
    inline uint32 modsub(uint32 a, uint32 b);
    inline uint32 modmul(uint32 a, uint32 b);
    inline uint32 modpow(uint32 a, uint32 b);

    uint32 r, L, N;//, p;
    uint32 W, iW, rN;

    const uint32 p = 0xC0000001;
};

//---------------------------------------------------------------------------
// Forward transform: dst[N] = NTT(src[N]).
// A non-zero n (re)initialises the constants for that size first; n == 0 keeps
// the size and roots from the previous init().
// For cross-checking, NTT_slow(dst, src, N, W) can be called here instead.
void fast_ntt::NTT(uint32 *dst, uint32 *src, uint32 n)
{
    if (n != 0) init(n);
    NTT_fast(dst, src, N, W);
}

//---------------------------------------------------------------------------
// Inverse transform: dst[N] = INTT(src[N]).
// Runs the fast transform with the inverse root iW, then multiplies every
// output element by rN. A non-zero n re-runs init(n) first.
// For cross-checking, INTT_slow(dst, src, N, W) can be called here instead.
void fast_ntt::INTT(uint32 *dst, uint32 *src, uint32 n)
{
    if (n != 0) init(n);
    NTT_fast(dst, src, N, iW);
    uint32 k = 0;
    while (k < N)
    {
        dst[k] = modmul(dst[k], rN);
        k++;
    }
}

//---------------------------------------------------------------------------
// Set up r, L, N, W, iW, rN for a transform of n elements (p is fixed).
// Returns false (zeroing every constant, including N) when n is outside
// 2..0x10000000.
//
// Caller must keep (max(src[])^2)*n < p, otherwise the NTT can overflow.
//
// Active configuration, "32:30 bit, best for unsigned 32 bit":
//     r=2  p=0xC0000001  n<=0x10000000  L=0x30000000/n
// Alternative configurations from the original code (they would need a
// non-const p):
//     r=2  p=0x78000001  n<=0x04000000  L=0x3c000000/n    31:27 bit, best for signed 32 bit
//     r=2  p=0x00010001  n<=0x00000020  L=0x00000020/n    17:16 bit, best for 16 bit
//     r=2  p=0x0a000001  n<=0x01000000  L=0x01000000/n    28:25 bit
bool fast_ntt::init(uint32 n)
{
    const bool accepted = (n >= 2) && (n <= 0x10000000);
    if (!accepted)
    {
        r = 0; L = 0;
        W = 0; iW = 0; rN = 0;
        N = 0;
        return false;
    }
    r = 2;
    L = 0x30000000 / n;
    N = n;                          // size of vectors [uint32s]
    W = modpow(r, L);               // Wn for NTT
    iW = modpow(r, p - 1 - L);      // Wn for INTT
    rN = modpow(n, p - 2);          // scale for INTT
    return true;
}

//---------------------------------------------------------------------------

// Entry point for the recursive transform: dst[n] = NTT(src[n]) with root w.
// NTT_calc needs distinct input and output arrays, so when dst and src are the
// same pointer the result is computed into a temporary buffer and copied back.
// In both cases the contents of src are overwritten by NTT_calc.
void fast_ntt::NTT_fast(uint32 *dst, uint32 *src, uint32 n, uint32 w)
{
    if (n == 0) return;
    if (n == 1)
    {
        dst[0] = src[0];
        return;
    }
    if (dst != src)
    {
        NTT_calc(dst, src, n, w);
        return;
    }
    uint32 *scratch = new uint32[n];
    NTT_calc(scratch, src, n, w);
    memcpy(dst, scratch, n * sizeof(uint32));
    delete [] scratch;
}

// Overload for read-only input: dst[n] = NTT(src[n]) with root w.
// NTT_calc overwrites its input, so it is given a private copy of src.
void fast_ntt::NTT_fast(uint32 *dst, const uint32 *src, uint32 n, uint32 w)
{
    if (n == 0) return;
    if (n == 1)
    {
        dst[0] = src[0];
        return;
    }
    uint32 *copy = new uint32[n];
    memcpy(copy, src, n * sizeof(uint32));
    NTT_calc(dst, copy, n, w);
    delete[] copy;
}



// Recursive radix-2 worker: dst[n] = NTT(src[n]) with root w. Does nothing for n <= 1.
// src doubles as scratch space: the two half-size sub-transforms are written
// back into it, so its original contents are lost.
void fast_ntt::NTT_calc(uint32 *dst, uint32 *src, uint32 n, uint32 w)
{
    if (n <= 1) return;

    const uint32 half = n >> 1;
    const uint32 wsq = modmul(w, w);    // root for the half-size transforms
    uint32 k;

    // split: even-indexed inputs into dst[0..half), odd-indexed into dst[half..n)
    for (k = 0; k < half; k++) dst[k] = src[k + k];
    for (k = half; k < n; k++) dst[k] = src[((k - half) << 1) + 1];

    // transform both halves, results land back in src
    if (half > 1)
    {
        NTT_calc(src, dst, half, wsq);                 // even
        NTT_calc(src + half, dst + half, half, wsq);   // odd
    }
    else
    {
        // half == 1: a one-element transform is just a copy
        src[0] = dst[0];
        src[1] = dst[1];
    }

    // first butterfly uses w^0 = 1, so the multiplication is skipped
    uint32 even = src[0];
    uint32 odd = src[half];
    dst[0] = modadd(even, odd);
    dst[half] = modsub(even, odd);

    // remaining butterflies with running powers of w
    uint32 tw = 1;
    for (k = 1; k < half; k++)
    {
        tw = modmul(tw, w);
        even = src[k];
        odd = modmul(src[k + half], tw);
        dst[k] = modadd(even, odd);
        dst[k + half] = modsub(even, odd);
    }
}

//---------------------------------------------------------------------------
// Reference transform by direct summation (two nested loops over n):
// dst[row] = sum over col of src[col] * (w^row)^col, all modulo p.
// Intended for testing the fast version only; src is left untouched.
void fast_ntt::NTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w)
{
    uint32 rowRoot = 1;                 // w^row
    for (uint32 row = 0; row < n; row++)
    {
        uint32 acc = 0;
        uint32 tw = 1;                  // rowRoot^col
        for (uint32 col = 0; col < n; col++)
        {
            acc = modadd(acc, modmul(tw, src[col]));
            tw = modmul(tw, rowRoot);
        }
        dst[row] = acc;
        rowRoot = modmul(rowRoot, w);
    }
}

//---------------------------------------------------------------------------
// Reference inverse transform by direct summation, scaled by rN.
// Note that the parameter w is not read: the row root always advances by the
// member iW. Intended for testing only; src is left untouched.
void fast_ntt::INTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w)
{
    uint32 rowRoot = 1;                 // iW^row
    for (uint32 row = 0; row < n; row++)
    {
        uint32 acc = 0;
        uint32 tw = 1;                  // rowRoot^col
        for (uint32 col = 0; col < n; col++)
        {
            acc = modadd(acc, modmul(tw, src[col]));
            tw = modmul(tw, rowRoot);
        }
        dst[row] = modmul(acc, rN);
        rowRoot = modmul(rowRoot, iW);
    }
}


//---------------------------------------------------------------------------
// (a + b) mod p for operands already below p.
// A sum smaller than a means the addition wrapped; subtracting p (again with
// wrap-around) then gives the reduced value. A final conditional subtraction
// handles sums in p..2^32-1.
uint32 fast_ntt::modadd(uint32 a, uint32 b)
{
    uint32 sum = a + b;
    const bool wrapped = sum < a;
    if (wrapped) sum -= p;
    if (sum >= p) sum -= p;
    return sum;
}

//---------------------------------------------------------------------------
// (a - b) mod p for operands already below p.
// A difference larger than a means the subtraction wrapped below zero; adding
// p (again with wrap-around) brings it back into range.
uint32 fast_ntt::modsub(uint32 a, uint32 b)
{
    const uint32 diff = a - b;
    const bool wrapped = diff > a;
    return wrapped ? diff + p : diff;
}

//---------------------------------------------------------------------------
// (a * b) mod p using a full 64-bit product.
// The inline-assembly path is compiled only by MSVC for 32-bit x86, the one
// target where `__asm` blocks are available; every other build (including
// 64-bit MSVC) uses the portable expression, which computes the same value.
// The result is now returned explicitly instead of being left in EAX with no
// return statement.
// In the assembly path at least one operand must be below p: `div` raises a
// divide error when the quotient does not fit in 32 bits.
uint32 fast_ntt::modmul(uint32 a, uint32 b)
{
#if defined(_MSC_VER) && defined(_M_IX86)
    uint32 _a = a;
    uint32 _b = b;
    uint32 _p = p;
    __asm
    {
        mov eax, _a
        mul _b          // edx:eax = _a * _b
        div _p          // eax = quotient, edx = remainder
        mov _a, edx
    }
    return _a;
#else
    const uint64 product = static_cast<uint64>(a) * static_cast<uint64>(b);
    return static_cast<uint32>(product % p);
#endif
}


// (a ^ b) mod p by square-and-multiply, scanning b from its least significant
// bit upwards and stopping as soon as no set bits remain.
// The lowest bit is handled up front: the result starts as a when b is odd and
// as 1 otherwise.
uint32 fast_ntt::modpow(uint32 a, uint32 b)
{
    uint64 base = a;
    uint64 result = ((b & 1) != 0) ? base : 1;

    for (b >>= 1; b != 0; b >>= 1)
    {
        base = modmul(base, base);
        if ((b & 1) != 0)
        {
            result = modmul(result, base);
        }
    }
    return (uint32)result;
}

新模块

// Replacement modmul: (a * b) mod p without a `div` instruction.
// The constants are hard-coded: 3221225473 is 0xC0000001 (the class's p) and
// 2863311530 is 0xAAAAAAAA.
// NOTE(review): this looks like a reciprocal-multiplication (Barrett-style)
// reduction - 0xAAAAAAAA / 2^31 is about 4/3, i.e. about 2^32 / p - followed by
// correction steps; the exact error bound has not been verified here.
// NOTE(review): there is no return statement; the result is left in EAX at the
// end of the __asm block, which presumably relies on the compiler's handling of
// inline-assembly functions - confirm for the toolchain in use.
// The locals _a and _b are not referenced by the assembly.
uint32 fast_ntt::modmul(uint32 a, uint32 b)
{
    uint32 _a = a;
    uint32 _b = b;   

    __asm
    {
    // ecx:ebx = 64-bit product a * b
    mov eax, a;
    mul b;
    mov ebx, eax;
    mov eax, 2863311530;
    mov ecx, edx;
    // edx = quotient estimate: (high dword of product * 0xAAAAAAAA) >> 31
    mul edx;
    shld edx, eax, 1;
    mov eax, 3221225473;

    // ecx:ebx = product - estimate * p
    mul edx;
    sub ebx, eax;
    mov eax, 3221225473;
    sbb ecx, edx;
    // borrow: the estimate was too large, add p back once (eax still holds p)
    jc addback;

            // if the high dword of the remainder is 1, subtract p
            neg ecx;
            and ecx, eax;
            sub ebx, ecx;

    // subtract p once more; eax becomes p if that borrowed, otherwise 0
    sub ebx, eax;
    sbb edx, edx;
    and eax, edx;
            addback:
    // eax = final result
    add eax, ebx;          
    };  
}

Spektre,根据您的反馈,我更改了 modadd &modsub 回到原来的状态.我也意识到我对递归 NTT 函数做了一些我不应该做的更改.

Spektre, based on your feedback I changed the modadd & modsub back to their original. I also realized I made some changes to the recursive NTT function I shouldn't have.

删除了不需要的 if 语句和按位函数.

Removed unneeded if statements and bitwise functions.

添加了新的 modmul 内联程序集.

Added new modmul inline assembly.

相关文章