/* ss2.c */
/*
** This file is placed into the public domain by its author,
** Carey Bloodworth (Carey@Bloodworth.org) on July 16, 2001
**
** This multiplication demo is not designed for high performance.
** It's a tutorial program designed to be used with the information
** on my web site at www.Bloodworth.org
*/
/*
** This file demonstrates a very basic Schonhage-Strassen style
** multiplication.
**
** The code is written for simplicity and clarity, NOT for
** performance, so it is *EXTREMELY* slow.
**
** I wanted you to be able to relate this style to a regular NTT
** multiplication, so that's how I wrote it. If you want better
** performance, take a look at the GMP program or Bruno Haible's
** CLN math package. (And you'll see why I wanted this code
** to be readable!)
**
** To compile this using GCC:
** gcc -Wall main.c ss2.c -o ss2.exe
*/
/*
** Do we want to do this using scaling, like Knuth describes?
#define USE_SCALING 1
*/
/*
** NOTE(review): the bare #include directives below have lost their
** <...> header names (angle-bracketed text was stripped when this copy
** of the file was captured). Code in this file calls printf() and
** exit(), so presumably <stdio.h> and <stdlib.h> were among them.
** Restore the original list from the upstream ss2.c before building.
*/
#include#include #include #include #include "wmodmath.h"

typedef short int Short;

/*
** Digit base of the Short-array operands. LoadData packs Log2(BASE)
** bits per Short into each NTT element. Where BASE is assigned is not
** visible in this copy of the file.
*/
static int BASE;
/* Not referenced in the visible part of this file. */
static int BASE_DIG;

/*
** Modulus used by the Mod* routines, and its exponent.
** For the small product the driver sets Prime = 2^Log2(NTTLen).
** For the main product it sets Prime = 2^PBits + 1.
** In both cases PrimeExp holds that power of two.
*/
ModIntPtr Prime=NULL;
int PrimeExp=0;
/*
** Root of unity for the main NTT. It is set to 2^(2*RootExp), and
** KRootExp holds the exponent 2*RootExp, so SS_NTT can apply twiddle
** factors as shifts (ModMul2Exp / ModDiv2Exp) instead of full multiplies.
*/
ModIntPtr KRoot=NULL;
int KRootExp=0;
/* Allocated by the driver; only used by disabled (#if 0) debug code
** in the visible source. */
ModIntPtr MulInv=NULL;

/*
** Returns floor(log2(Num)), i.e. the index of the highest set bit.
** Returns 0 for Num == 0, the same result as for Num == 1.
** Because C division truncates toward zero, a negative Num gives the
** result for |Num|.
*/
int Log2(int Num)
{int x=-1;
 if (Num==0) return 0;
 while (Num) {x++;Num/=2;}   /* count halvings until Num reaches 0 */
 return x;
}

/*
** Returns 2^L. Any L <= 0 yields 1.
** NOTE(review): there is no overflow guard. If 2^L does not fit in an
** int (L >= 31 for a 32-bit int), p*=2 is signed overflow, which is
** undefined behavior.
*/
int Pow2(int L)
{int p=1;
 if (L<=0) return p;
 while (L--) p*=2;
 return p;
}

/*
** Permutes the Len elements of Data (each ModWords words wide) into
** bit-reversed index order. SS_NTT calls this first, before its
** decimation-in-time butterfly passes. Len is presumably a power of
** two (TODO confirm), since the reversed-index update halves k from Len/2.
**
** NOTE(review): part of the loop header and the start of the loop body
** below were stripped when this file was captured (text between a '<'
** and a '>' is missing). The surviving pieces match the standard
** bit-reversal loop, which would read
**     for (Index=0;Index<Len;Index++) { if (xednI > Index) ModSwap(...); ...
** Restore the exact text from the original source before building.
*/
static void NTTReOrder(ModIntPtr Data, size_t Len,size_t ModWords)
{size_t Index,xednI,k;   /* xednI tracks the bit-reversed value of Index */
 xednI=0;
 for (Index=0;Index Index) ModSwap(&Data[Index*ModWords],&Data[xednI*ModWords],ModWords);
   /*
   ** Advance xednI to the bit-reversed successor of its value: starting
   ** from the top bit (Len/2), clear set bits while moving down, then set
   ** the first clear bit found (a reversed-binary increment).
   */
   k=Len/2;
   while ((k <= xednI) && (k >=1)) {xednI-=k;k/=2;}
   xednI+=k;
  }
}

static void SS_NTT(ModIntPtr Data, size_t Len,int Dir,size_t ModWords)
/*
** A Radix 2, decimation in time, iterative NTT.
** Special Schonhage-Strassen version.
*/ {size_t j,Step,HalfStep; ModIntPtr T,End; int wExp; NTTReOrder(Data,Len,ModWords); Step=1; T=ModMalloc(1,ModWords); End=&Data[Len*ModWords]; #ifdef USE_SCALING Dir=1; #endif while (Step < Len) { HalfStep=Step; Step*=2; wExp=KRootExp*(Len/Step); for (j=0;j 0) ModMul2Exp(T,R,j*wExp,ModWords); else ModDiv2Exp(T,R,PrimeExp,j*wExp,ModWords); ModSub(R,L,T,ModWords); ModAdd(L,L,T,ModWords); L+=(Step*ModWords);R+=(Step*ModWords); } } } ModFree(T); } #ifdef USE_SCALING void Scale(ModIntPtr Data,ModIntPtr Theta,int NTTLen,int Dir,int ModWords) {int x; int ThetaExp=KRootExp/2; if (Dir < 0) { /* Flip the order. Only index 0 & Len/2 stay where they are. */ for (x=1;x 0) ModMul2Exp(Data,Data,x*ThetaExp,ModWords); else ModDiv2Exp(Data,Data,PrimeExp,x*ThetaExp,ModWords); Data+=ModWords; } } #endif void LoadData(ModIntPtr NTTData,Short *Num, int NTTLen, int BitsPerGroup,int NumLen,int ModWords) /* ** The modmath package doesn't properly handle very very small ** primes so we have to normalize. And this routine itself ** doesn't handle cases where one 'base digit' would be split ** among two NTT elements. */ {int x,Nx,y; int ShortsPerMod=BitsPerGroup/Log2(BASE); ModIntPtr R; if ((BitsPerGroup % Log2(BASE)) != 0) {printf("Not multiple of base\n");exit(0);} R=ModMalloc(1,ModWords); for (x = 0; x < NTTLen; x++) ModClear(&NTTData[x*ModWords],ModWords); Nx=NumLen; for (x = 0; x < NTTLen/2; x++) { Nx-=ShortsPerMod; ModClear(R,ModWords); for (y=0;y BitsPerGroup*2) { printf("Wrong params.\n"); printf("\nTotalBits=%d NTTLen=%d BitsPerGroup=%d RootExp=%d PBits=%d\n", TotalBits,NTTLen,BitsPerGroup,RootExp,PBits); exit(0); } if (BitsPerGroup < 16) {printf("L too small.\n");exit(0);} /* At least a word size. */ if (NTTLen < 8) {printf("NTTLen (%d) too small.\n",NTTLen);exit(0);} /* Small product. WModMath.h neds at least two. */ SModWords=ModBits2ModWords(NTTLen); if (SModWords < 2) SModWords=2; /* Large product. WModMath.h neds at least two. 
*/ ModWords=ModBits2ModWords(PBits); if (ModWords < 2) ModWords=2; InitModMath(ModWords+16); /* printf("\nTotalBits=%d NTTLen=%d BitsPerGroup=%d RootExp=%d PBits=%d\n", TotalBits,NTTLen,BitsPerGroup,RootExp,PBits); */ SProd=ModMalloc(NTTLen,SModWords); NTTNum1=ModMalloc(NTTLen, ModWords); NTTNum2=ModMalloc(NTTLen, ModWords); Theta=ModMalloc(1,ModWords); Prime=ModMalloc(1,ModWords); KRoot=ModMalloc(1,ModWords); MulInv=ModMalloc(1,ModWords); if (Num1 == Num2) {printf("Schonhage-Strassen is hardwired for two vars.\n");exit(0);} /* Do small product. */ CB_Set(Prime,1,SModWords);CB_ShiftLx(Prime,Log2(NTTLen),SModWords);PrimeExp=Log2(NTTLen); SS_NTT2(SProd,Num1,Num2,NTTLen,BitsPerGroup,NumLen,SModWords); /* Set up our primes and constants for the main SS NTT */ CB_Clear(Prime,ModWords);CB_SetBit(Prime,PBits,ModWords); CB_SetBit(Prime,0,ModWords); CB_Clear(Theta,ModWords);CB_SetBit(Theta,RootExp,ModWords); CB_Clear(KRoot,ModWords);CB_SetBit(KRoot,2*RootExp,ModWords); KRootExp=2*RootExp; PrimeExp=PBits; #if 0 ModIPow(MulInv, KRoot,NTTLen,ModWords); printf("Check ");CB_Dump(MulInv,ModWords); ModIPow(MulInv, Theta,NTTLen,ModWords); printf("Check ");CB_Dump(MulInv,ModWords); /* Single bit. */ printf("Prime ");CB_Dump(Prime,ModWords); printf("MulInv ");CB_Dump(MulInv,ModWords); printf("Theta ");CB_Dump(Theta,ModWords); printf("PrimvR ");CB_Dump(KRoot,ModWords); #endif LoadData(NTTNum1,Num1,NTTLen,BitsPerGroup,NumLen,ModWords); LoadData(NTTNum2,Num2,NTTLen,BitsPerGroup,NumLen,ModWords); #ifdef USE_SCALING Scale(NTTNum1,Theta,NTTLen,1,ModWords); Scale(NTTNum2,Theta,NTTLen,1,ModWords); #endif SS_NTT(NTTNum1, NTTLen, 1, ModWords); SS_NTT(NTTNum2, NTTLen, 1, ModWords); /* ** ****************NOTICE******************** ** This convolution should actually be done by recursively ** calling Schonhage-Strassen until you get to a size small ** enough to do efficiently. ** ** HOWEVER, my generic math package here can't handle that, ** so I'm having to do this slowly. 
This kills the performance. */ for (x=0;x< NTTLen;x++) ModMul(&NTTNum1[x*ModWords], &NTTNum1[x*ModWords], &NTTNum2[x*ModWords],ModWords); SS_NTT(NTTNum1, NTTLen, -1, ModWords); #ifdef USE_SCALING Scale(NTTNum1,Theta,NTTLen,-1,ModWords); #endif for (x = 0;x < NTTLen; x++) ModDiv2Exp(&NTTNum1[x*ModWords],&NTTNum1[x*ModWords],PrimeExp,Log2(NTTLen),ModWords); /* Do CRT & release the carries */ {int PWords=ModWords+6; /* Pyramid words. */ int y,PX; int ShortsPerMod=BitsPerGroup/Log2(BASE); UINT32 W; ModIntPtr RP,PyramidMod; ModIntPtr Pyramid,Sum,Temp; ModIntPtr PPrime; Pyramid=ModMalloc(1,PWords);CB_Clear(Pyramid,PWords); Sum=ModMalloc(1,PWords); CB_Clear(Sum,PWords); Temp=ModMalloc(1,PWords); CB_Clear(Temp,PWords); PPrime=ModMalloc(1,PWords); CB_Copy(PPrime,PWords,Prime,ModWords); RP=ModMalloc(1,PWords); /* RP=r*(2^2L) */ PyramidMod=ModMalloc(1,PWords); /* PyramidMod=2^k * (2^(2L)+1) */ CB_Copy(PyramidMod,PWords,PPrime,PWords);CB_ShiftLx(PyramidMod,Log2(K),PWords); /* PyramidMod=Prime*NTTLen */ if (Log2(K) > 30) {printf("K (%d) too big for int.\n",K);exit(0);} PX=NumLen*2; for (x = 0; x < NTTLen; x++) { /* Get the small digit */ W=CB_Get32(&SProd[x*SModWords],SModWords); /* The modmath stuff isn't good at handling small primes, so make sure */ /* W=W % NTTLen; */ if (W >= K) {printf("W too big. %u %d\n",(unsigned int)W,K);exit(0);} /* Do the CRT */ W=(W-CB_Get32(&NTTNum1[x*ModWords],ModWords)) % K; CB_Copy(Temp,PWords,PPrime,PWords); CB_MulUInt(Temp,W,PWords); CB_Copy(Sum,PWords,&NTTNum1[x*ModWords],ModWords); CB_Add(Sum,Sum,Temp,PWords); /* Normalize it. It might be 'negative'. RP could be accumulated. */ CB_Set(RP,1,PWords);CB_ShiftLx(RP,PBits,PWords);CB_MulUInt(RP,x+1,PWords); if (CB_Cmp(Sum,RP,PWords) >= 0) CB_Sub(Sum,Sum,PyramidMod,PWords); CB_Add(Pyramid,Pyramid,Sum,PWords); /* Release our carries */ for (y=0;y