/* mont62.c */
/*
** This file is placed into the public domain by its author,
** Carey Bloodworth (Carey@Bloodworth.org) on July 16, 2001
**
** This multiplication demo is not designed for high performance.
** It's a tutorial program designed to be used with the information
** on my web site at www.Bloodworth.org
*/
/*
** This file demonstrates a very basic NTT multiply.  It uses a single
** 62 bit prime.
**
** Multiplications are done using Montgomery modular multiplication.
**
** To compile this using GCC:
** gcc main.c mont62.c -o mont62.exe
*/

/*
** NOTE(review): the header names of the five include directives below are
** missing from this copy of the file.  The visible code calls printf,
** malloc, free and exit and uses NULL, so the stdio and stdlib headers are
** needed at minimum; the remaining three cannot be determined from here
** and must be restored from the original source.
*/
#include#include #include #include #include

/*
** Number of ModInt elements in each transform buffer for an input of
** _NumLen digits; InitFFT multiplies this by sizeof(ModInt) to size its
** allocations.
** NOTE(review): the expansion divides by BASE_DIG, so it is only usable
** after InitFFT has stored a non-zero BaseDig there.
*/
#define CalcNTTLen(_NumLen) ((((_NumLen)*BASE_DIG)*2)/BASE_DIG)
/* NumLen*BaseDig*ZeroPadding/Dig_Per_FFT */

/*
** Enables the Montgomery code path.  The test further down is #ifdef, so
** the value is irrelevant: defining this as 0 still enables it.  Remove
** the definition to fall back to plain ModMul / ModPow.
*/
#define USE_MONTGOMERY 1

typedef short int Short;
typedef signed long INT32; /* 32/31 bit signed int */
typedef unsigned long UINT32; /* 32 bit unsigned int */
/*
** NOTE(review): C only guarantees that long is at least 32 bits wide.  On
** targets where unsigned long is 64 bits the (UINT32) cast in MONTLOW does
** not truncate to the low word, which the MontMul routines rely on.
** Confirm the target ABI before trusting the results.
*/
typedef unsigned long long UINT64;
typedef UINT64 ModInt;

/* Work buffers; allocated by InitFFT and released by DeInitFFT. */
static ModInt *NTTNum1=NULL, *NTTNum2=NULL;
/* Copies of the Base and BaseDig arguments passed to InitFFT. */
static int BASE;
static int BASE_DIG;
/*
** Prime and PrimvRoot are assigned in InitFFT.  MulInv is not referenced
** anywhere in the code visible in this copy.
*/
ModInt Prime,PrimvRoot,MulInv;
/* Multipliers used by the ModMul-based ToMont() / FromMont(); set in InitFFT. */
ModInt FromMontC,ToMontC;
UINT32 MontMulC; /* NOTE: This is our word size, not a ModInt */
/*ModInt MontToMontC,MontFromMontC;*/

/* z converted to UINT32 (the low word when UINT32 really is 32 bits). */
#define MONTLOW(z) ((UINT32)(z))
/* z widened to UINT64 and shifted right by 32 bits. */
#define MONTHIGH(z) (((UINT64)(z))>>32)
/*
** 64-bit product of z and q after casting each to UINT64.  The arguments
** are not parenthesised inside the casts, so pass simple operands only.
*/
#define MUL64(z,q) (((UINT64)z)*((UINT64)q))

/*
** ModMul: multiplies a by b modulo the global Prime using shift-and-add,
** one bit of a per iteration, most significant bit first.
** Returns the reduced product.
** NOTE(review): only one conditional subtraction follows each addition of
** b, so b is presumably expected to be below Prime already -- confirm.
*/
ModInt ModMul(ModInt a, ModInt b)
/* This can be *extremely* simple because it'll be rarely used */
/* Just a plain 'bit by bit' multiply. */
{ int i = 64;
  ModInt Prod;
  Prod=0;
  while (i--) {
    /*
    ** Double the running product.  If its top bit is set the shift drops
    ** a bit, and Prime is subtracted right after the shift.
    */
    if (Prod & 0x8000000000000000) { Prod<<=1; Prod-=Prime; }
    else { Prod<<=1; }
    if (Prod >= Prime) Prod-=Prime;
    /* Add b when the current top bit of a is set; T < Prod means the add wrapped. */
    if (a & 0x8000000000000000)
      {ModInt T;
       T=Prod+b;
       if (T < Prod) T-=Prime;
       Prod=T;
      }
    a<<=1; /* bring the next bit of a into the top position */
    if (Prod >= Prime) Prod-=Prime;
  }
  return Prod;
}

/*
** ModAdd: returns a+b reduced with at most two subtractions of Prime,
** one if the 64-bit addition wrapped and one if the sum is >= Prime.
*/
ModInt ModAdd(ModInt a,ModInt b)
{ModInt Sum; /* unsigned data type */
  Sum=a+b;
  if (Sum < a) Sum-=Prime; /* the unsigned add wrapped around */
  if (Sum >= Prime) Sum-=Prime;
  return Sum;
}

/*
** ModSub: returns a-b, adding Prime back when the unsigned subtraction
** wrapped (that is, when b was larger than a).
*/
ModInt ModSub(ModInt a,ModInt b)
{ModInt Dif; /* unsigned data type */
  Dif=a-b;
  if (Dif > a) Dif+=Prime; /* result wrapped below zero */
  return Dif;
}

/*
** ModPow: raises base to the power expon using repeated ModMul squaring,
** consuming expon from its least significant bit.  Returns 1 when expon
** is 0.
*/
ModInt ModPow(ModInt base,ModInt expon)
{ModInt prod,b;
  if (expon==0) return 1;
  b=base;
  /* Square once for every trailing zero bit of expon. */
  while (!(expon&1)) { b=ModMul(b,b); expon>>=1; }
  prod=b;
  /* Remaining bits: square each round, multiply in when the bit is set. */
  while (expon>>=1) { b=ModMul(b,b); if (expon&1) prod=ModMul(prod,b); }
  return prod;
}

/*
** FindInverse: returns Num raised to Prime-2 modulo Prime, which is the
** multiplicative inverse of Num when Prime is prime (Fermat's little
** theorem).
*/
ModInt FindInverse(ModInt Num)
/* Or we could use an extended gcd function */
{ return ModPow(Num,Prime-2); }

/*
** Montgomery modular multiplication stuff.
*/
#ifdef USE_MONTGOMERY
/* Conversions into and out of Montgomery form, done with the slow ModMul. */
#define ToMont(_x) ModMul(ToMontC,_x)
#define FromMont(_x) ModMul(FromMontC,_x)
/*
** ToMont and FromMont are a little slow.  After all, they
** are using our dead dog slow ModMul routine.  If we wanted
** to do a little preparation, we can speed them up.
** We can convert those constants into Montgomery space
** and that way we can use our fast MontMul with them.

ModInt MontToMontC,MontFromMontC;
MontToMontC=ModMul(ToMontC,ToMontC);
MontFromMontC=ModMul(ToMontC,FromMontC);

** And then we can do:

#define ToMont(_x)   MontMul(MontToMontC,_x)
#define FromMont(_x) MontMul(MontFromMontC,_x)

** I'm not actually doing it this way because you need to see
** the regular way.
*/

#if 0
/*
** Disabled version of MontMul (compiled out by the #if 0 above).
** It reduces one word at a time: P0 and P1 are the low word of the
** running sum multiplied by MontMulC, and P*Prime is then added.  The two
** printf/exit checks abort if the low word of the sum is not zero
** afterwards.
*/
ModInt MontMul(ModInt Num1,ModInt Num2)
/* Pretty standard Montgomery Modular Multiplication */
{UINT64 x;
 UINT32 A0,A1,B0,B1,P0,P1,D0,D1;
  /* Split both operands and the modulus into low and high words. */
  A0=MONTLOW(Num1); A1=MONTHIGH(Num1);
  B0=MONTLOW(Num2); B1=MONTHIGH(Num2);
  D0=MONTLOW(Prime);D1=MONTHIGH(Prime);
  x=MUL64(A0,B0);
  P0=MONTLOW((INT32)x) * MontMulC;P0=MONTLOW(P0);
  x+=MUL64(P0,D0);
  if (MONTLOW(x)) {printf("x1 is not zero.\n");exit(0);}
  x=MONTHIGH(x);
  x=x+MUL64(P0,D1)+MUL64(A0,B1)+MUL64(A1,B0);
  P1=MONTLOW((INT32)x) * MontMulC;P1=MONTLOW(P1);
  x+=MUL64(P1,D0);
  if (MONTLOW(x)) {printf("x2 is not zero.\n");exit(0);}
  x=MONTHIGH(x);
  x+=MUL64(P1,D1)+MUL64(A1,B1);
  if (x >= ((UINT64)Prime)) x-=Prime; /* final conditional subtraction */
  return x;
}
#endif

#if 1
/*
** Active version of MontMul: Montgomery product of Num1 and Num2 with
** respect to the global Prime, built from four partial products of the
** operand halves plus two products with the high word of Prime.
** NOTE(review): instead of multiplying by MontMulC and by the low word of
** Prime, this version negates the low word of the running sum and adds it
** directly.  That appears to rely on the low 32-bit word of Prime being 1,
** which holds for the value assigned in InitFFT -- confirm before changing
** the prime.  Prime0 is assigned but never read.
*/
ModInt MontMul(ModInt Num1,ModInt Num2)
/* More advanced MontMul requiring only 6 muls. */
{UINT64 Sum,P00,P01,P10,P11,T64;
 UINT32 A0,A1,B0,B1,Prime0,Prime1;UINT32 T;
  A0=MONTLOW(Num1); A1=MONTHIGH(Num1);
  B0=MONTLOW(Num2); B1=MONTHIGH(Num2);
  Prime0=MONTLOW(Prime);Prime1=MONTHIGH(Prime);
  /* The four partial products of the operand halves. */
  P00=MUL64(A0,B0);
  P10=MUL64(A1,B0);
  P01=MUL64(A0,B1);
  P11=MUL64(A1,B1);
  /* First word: T cancels the low word of P00, then that word is dropped. */
  T= MONTLOW(-((UINT32)MONTLOW(P00)));
  T64=MUL64(T,Prime1);
  Sum=P00+T;Sum=MONTHIGH(Sum);
  Sum=Sum+T64+P01+P10;
  /* Second word: same cancellation on the new low word. */
  T= MONTLOW(-((UINT32)MONTLOW(Sum)));
  Sum=Sum+T;Sum=MONTHIGH(Sum);
  Sum=Sum+MUL64(T,Prime1);
  Sum=Sum+P11;
  if (Sum >= ((UINT64)Prime)) Sum-=Prime; /* final conditional subtraction */
  return Sum;
}
#endif

/*
** MontPow: same square-and-multiply loop as ModPow, but using MontMul, so
** base is expected in Montgomery form and the result is too.
** NOTE(review): unlike ModPow there is no expon==0 guard; with expon == 0
** the first loop never terminates.
*/
ModInt MontPow(ModInt base,ModInt expon)
/*
** Our Montgomery radix is a full 32 bits, so we can just take
** the low word of these multiplications.
*/
{ModInt prod,b;
  b=base;
  while (!(expon&1)) { b=MontMul(b,b); expon>>=1; }
  prod=b;
  while (expon>>=1) { b=MontMul(b,b); if (expon&1) prod=MontMul(prod,b); }
  return prod;
}

#else
/* Without USE_MONTGOMERY everything maps onto the plain modular routines. */
#define ToMont(_x) (_x)
#define FromMont(_x) (_x)
#define MontMul(q1,q2) ModMul((q1),(q2))
#define MontPow(_a,_b) ModPow(_a,_b)
#endif

/*
** NTTReorder: permutes Data[0..Len-1] in place.  xednI is kept as the
** bit-reversed counterpart of Index, and each pair is swapped once (only
** when xednI > Index).
** Presumably Len is a power of two -- TODO confirm against the callers.
*/
static void NTTReorder(ModInt *Data, int Len)
{int Index,xednI,k;
  xednI = 0;
  for (Index = 0;Index < Len;Index++)
    {
      if (xednI > Index)
        {ModInt Temp;
         Temp=Data[xednI];
         Data[xednI]=Data[Index];
         Data[Index]=Temp;
        }
      /* Step xednI to the next bit-reversed value, carrying from the top bit down. */
      k=Len/2;
      while ((k <= xednI) && (k >=1)) {xednI-=k;k/=2;}
      xednI+=k;
    }
}

/*
** NTT: in-place transform of Data[0..Len-1].  It reorders the data with
** NTTReorder and then runs passes with step doubling from 2 up to Len.
** For each pass, w is PrimvRoot (in Montgomery form) raised to
** Prime-1-((Prime-1)/step) when Dir > 0, or to (Prime-1)/step otherwise.
** NOTE(review): the inner loops and the end of this function are missing
** from this copy (see the note further down), so the rest of its
** behaviour, including which Dir is forward and which is inverse, cannot
** be determined from here.
*/
void NTT(ModInt *Data, int Len, int Dir)
/* A simple minded, generic transform */
{int j,step,halfstep;
 int index,index2;
 ModInt u,w,temp;
  NTTReorder(Data,Len);
  step=1;
  while (step < Len)
    {
      halfstep=step;
      step*=2;
      u=1;
      u=ToMont(u); /* Montgomery form of 1; with the ModMul-based ToMont this is just ToMontC */
/*
** We could do this...

      if (Dir > 0) w=ModPow(PrimvRoot,Prime-1-((Prime-1)/step));
      else         w=ModPow(PrimvRoot,(Prime-1)/step);
      w=ToMont(w);

** BUT plain ModMul is *VERY* slow.  That's why we are doing
** Montgomery multiplication!  So we have got to avoid it.
*/
      if (Dir > 0) w=MontPow(ToMont(PrimvRoot),Prime-1-((Prime-1)/step));
      else w=MontPow(ToMont(PrimvRoot),(Prime-1)/step);
      /*
      ** NOTE(review): source text has been lost on the next line.
      ** Everything between "for (j=0;j" and " 0; x--)" is missing: the
      ** loop over j, the rest of NTT, and the header and opening part of
      ** a following function.  What remains appears to be the tail of a
      ** carry loop that reads NTTNum1[] and writes Prod[] in base BASE;
      ** Pyramid, Carry, Len2, Prod and x are not declared anywhere in the
      ** visible text.  The file cannot compile until this is restored
      ** from the original source.
      */
      for (j=0;j 0; x--)
        {
          Pyramid = NTTNum1[Len2 - x] + Carry;
          Carry = Pyramid / BASE;
          Prod[x - 1] = Pyramid % BASE;
        }
    }

/*
** InitFFT: stores Base and BaseDig, allocates the two work buffers with
** CalcNTTLen(Len) ModInt elements each, and sets the modulus constants.
** On allocation failure it prints a message and calls exit.
**
** NOTE(review):
**  - Bytes is an int, so the size_t product is narrowed and can overflow
**    for large Len.
**  - exit(0) reports success to the caller's environment even though this
**    is a failure path.
**  - If only one of the two allocations fails, the other is not freed
**    before exiting.
*/
void InitFFT(unsigned long int Len,int Base,int BaseDig)
{int Bytes;
  BASE=Base;
  BASE_DIG=BaseDig; /* must be set before CalcNTTLen is expanded below */
  Bytes=sizeof(ModInt)*CalcNTTLen(Len);
  NTTNum1=(ModInt*)malloc(Bytes);
  NTTNum2=(ModInt*)malloc(Bytes);
  if ((NTTNum1==NULL) || (NTTNum2==NULL))
    {
      printf("Unable to allocate memory for NTTNum.\n");
      printf("Len=%d Bytes=%d\n",(int)Len,(int)Bytes);
      exit(0);
    }
  Prime=0x3fdc000000000001ULL; /* 4087*2^50+1 Bits=61.99682653*/
  PrimvRoot=3;
  /* 0-Prime wraps to 2^64-Prime, so this is 2^64 reduced modulo Prime. */
  ToMontC=(0-Prime) % Prime;
  /* Modular inverse of ToMontC, used by FromMont. */
  FromMontC=FindInverse(ToMontC) % Prime;
  /* Only read by the MontMul variant that is compiled out with #if 0. */
  MontMulC=MONTLOW(Prime-2);
  /*
  MontToMontC=ModMul(ToMontC,ToMontC);
  MontFromMontC=ModMul(ToMontC,FromMontC);
  */
}

/*
** DeInitFFT: frees the two work buffers allocated by InitFFT.  Len is
** unused.
** NOTE(review): NTTNum1 and NTTNum2 are not reset to NULL, so they dangle
** after this call.
*/
void DeInitFFT(unsigned long int Len)
{ free(NTTNum1);free(NTTNum2); }