/*
ss2.c
*/
/*
** This file is placed into the public domain by its author,
** Carey Bloodworth (Carey@Bloodworth.org) on July 16, 2001
**
** This multiplication demo is not designed for high performance.
** It's a tutorial program designed to be used with the information
** on my web site at www.Bloodworth.org
*/
/*
** This file demonstrates a very basic Schonhage-Strassen style
** multiplication.
**
** The code is written for simplicity and clarity, NOT for
** performance, so it is *EXTREMELY* slow.
**
** I wanted you to be able to relate this style to a regular NTT
** multiplication, so that's how I wrote it. If you want better
** performance, take a look at the GMP library or Bruno Haible's
** CLN math package. (And you'll see why I wanted this code
** to be readable!)
**
** To compile this using GCC:
** gcc -Wall main.c ss2.c -o ss2.exe
*/
/*
** Optional: to do the multiplication using scaling, as Knuth
** describes, move the following line out of this comment.
#define USE_SCALING 1
*/
#include
#include
#include
#include
#include "wmodmath.h"
/* Element type of the digit arrays handed to LoadData (Short *Num). */
typedef short Short;

/*
** Number base of the digit arrays and digits per base word.
** Log2(BASE) is used as the bit count of one digit, so BASE is
** presumably a power of two. Neither is assigned in the portion of
** the file shown here -- TODO confirm where they are set.
*/
static int BASE;
static int BASE_DIG;

/*
** Modulus used by the ModMath routines, and its power-of-two exponent.
** The multiply routine sets Prime to 2^PrimeExp for the small product
** and to 2^PrimeExp + 1 for the main SS transform.
*/
ModIntPtr Prime = NULL;
int PrimeExp = 0;

/*
** Root of unity for SS_NTT. It is a single power of two, 2^KRootExp,
** so twiddles are applied with 2^k shifts.
*/
ModIntPtr KRoot = NULL;
int KRootExp = 0;

/* Allocated by the multiply routine; only read by disabled debug code there. */
ModIntPtr MulInv = NULL;
/*
** Log2: floor of the base-2 logarithm of |Num|.
** By convention Log2(0) returns 0 (as do Log2(1) and Log2(-1)).
** Negative inputs work by magnitude because C integer division
** truncates toward zero.
*/
int
Log2(int Num)
{
  int Bits = 0;

  if (Num == 0)
    return 0;

  /* Every halving that still leaves a non-zero value is one more bit. */
  for (Num /= 2; Num != 0; Num /= 2)
    Bits++;

  return Bits;
}
/*
** Pow2: 2 raised to the power L, as an int.
** Any non-positive L yields 1.
** The result must fit in an int; larger L overflows (undefined
** behaviour for signed int), so callers must keep L small.
*/
int
Pow2(int L)
{
  int Result = 1;
  int Doublings;

  for (Doublings = 0; Doublings < L; Doublings++)
    Result *= 2;

  return Result;
}
/*
** NTTReOrder: permute Data into bit-reversed index order, in place.
**
** Data holds Len elements, each ModWords words wide. Element Index is
** exchanged with element xednI, where xednI is Index with its bits
** reversed. The swap is only performed when xednI > Index, so each
** pair is exchanged exactly once and self-mapped elements stay put.
** Len is assumed to be a power of two, as the radix-2 SS_NTT that
** calls this requires.
*/
static void
NTTReOrder(ModIntPtr Data, size_t Len,size_t ModWords)
{
  size_t Index,xednI,k;

  xednI=0;
  for (Index=0;Index<Len;Index++)
    {
      if (xednI > Index)
        ModSwap(&Data[Index*ModWords],&Data[xednI*ModWords],ModWords);

      /*
      ** Advance xednI to the bit-reversal of Index+1: this is a
      ** "reversed increment" -- clear set bits from the top down,
      ** then set the first clear one.
      */
      k=Len/2;
      while ((k <= xednI) && (k >=1)) {xednI-=k;k/=2;}
      xednI+=k;
    }
}
static void
SS_NTT(ModIntPtr Data, size_t Len,int Dir,size_t ModWords)
/*
** A Radix 2, decimation in time, iterative NTT.
** Special Schonhage-Strassen version.
**
** Data holds Len elements of ModWords words each and is transformed in
** place. The root of unity is a single power of two with exponent
** KRootExp (a file-level global), so twiddle factors are applied with
** ModMul2Exp / ModDiv2Exp shifts rather than general multiplies.
** Callers pass Dir = 1 for the forward transforms and Dir = -1 for the
** inverse. The positive direction presumably multiplies by the twiddle
** and the negative one divides -- TODO confirm once the lost
** direction test is restored.
**
** NOTE(review): the line beginning "for (j=0;j 0)" is corrupted. Lost
** are the for-loop bound and the declaration and initialisation of the
** L and R element pointers (they are not declared anywhere visible).
** The inner loop header and the start of the Dir test are lost too.
** The function does not compile as shown; restore it from the
** original source before use.
*/
{size_t j,Step,HalfStep;
ModIntPtr T,End;
int wExp;
/* Put the input into bit-reversed order for the DIT butterflies. */
NTTReOrder(Data,Len,ModWords);
Step=1;
/* Scratch element holding the twiddled value; freed before returning. */
T=ModMalloc(1,ModWords);
/* One past the last element -- presumably the bound of the lost inner loop. */
End=&Data[Len*ModWords];
/*
** NOTE(review): with USE_SCALING the direction is forced to 1, so the
** inverse call (Dir = -1) uses the forward twiddle path. Confirm this
** is intended.
*/
#ifdef USE_SCALING
Dir=1;
#endif
/* One pass per doubling of the butterfly span, until it covers Len. */
while (Step < Len)
{
HalfStep=Step;
Step*=2;
/* Twiddle exponent increment for this pass: KRootExp * (Len/Step). */
wExp=KRootExp*(Len/Step);
for (j=0;j 0)
ModMul2Exp(T,R,j*wExp,ModWords);
else
ModDiv2Exp(T,R,PrimeExp,j*wExp,ModWords);
/*
** Butterfly: with T = R scaled by 2^(j*wExp), set R = L - T and
** L = L + T. This assumes the ModMath destination-first argument
** order -- confirm against wmodmath.h.
*/
ModSub(R,L,T,ModWords);
ModAdd(L,L,T,ModWords);
/* Advance both pointers to the next butterfly pair, Step elements on. */
L+=(Step*ModWords);R+=(Step*ModWords);
}
}
}
ModFree(T);
}
#ifdef USE_SCALING
/*
** Scale: per-element power-of-two weighting for the USE_SCALING variant.
** It is compiled only when USE_SCALING is defined; the define near the
** top of this file is left commented out.
**
** The multiply routine calls this with Dir = 1 on both inputs before the
** forward transforms, and with Dir = -1 on the product after the inverse
** transform. For Dir < 0 the element order is first flipped (see the
** comment inside). Inside the weighting loop, the current element is
** either multiplied (ModMul2Exp) or divided (ModDiv2Exp, modulo via
** PrimeExp) by 2^(x*ThetaExp), and then Data advances one element.
** ThetaExp = KRootExp/2. The Theta argument itself is not referenced
** in the visible body.
**
** NOTE(review): the line beginning "for (x=1;x 0)" is corrupted. Lost
** are the order-flip loop, the closing of the Dir < 0 branch, the
** header of the weighting loop and the start of its direction test.
** The function does not compile as shown.
*/
void
Scale(ModIntPtr Data,ModIntPtr Theta,int NTTLen,int Dir,int ModWords)
{int x;
int ThetaExp=KRootExp/2;
if (Dir < 0)
{
/* Flip the order. Only index 0 & Len/2 stay where they are. */
for (x=1;x 0)
ModMul2Exp(Data,Data,x*ThetaExp,ModWords);
else
ModDiv2Exp(Data,Data,PrimeExp,x*ThetaExp,ModWords);
/* Move on to the next ModWords-wide element. */
Data+=ModWords;
}
}
#endif
void
LoadData(ModIntPtr NTTData,Short *Num, int NTTLen,
int BitsPerGroup,int NumLen,int ModWords)
/*
** The modmath package doesn't properly handle very very small
** primes so we have to normalize. And this routine itself
** doesn't handle cases where one 'base digit' would be split
** among two NTT elements.
*/
{int x,Nx,y;
int ShortsPerMod=BitsPerGroup/Log2(BASE);
ModIntPtr R;
if ((BitsPerGroup % Log2(BASE)) != 0)
{printf("Not multiple of base\n");exit(0);}
R=ModMalloc(1,ModWords);
for (x = 0; x < NTTLen; x++) ModClear(&NTTData[x*ModWords],ModWords);
Nx=NumLen;
for (x = 0; x < NTTLen/2; x++)
{
Nx-=ShortsPerMod;
ModClear(R,ModWords);
for (y=0;y BitsPerGroup*2)
{
printf("Wrong params.\n");
printf("\nTotalBits=%d NTTLen=%d BitsPerGroup=%d RootExp=%d PBits=%d\n",
TotalBits,NTTLen,BitsPerGroup,RootExp,PBits);
exit(0);
}
if (BitsPerGroup < 16) {printf("L too small.\n");exit(0);} /* At least a word size. */
if (NTTLen < 8) {printf("NTTLen (%d) too small.\n",NTTLen);exit(0);}
/* Small product. WModMath.h neds at least two. */
SModWords=ModBits2ModWords(NTTLen);
if (SModWords < 2) SModWords=2;
/* Large product. WModMath.h neds at least two. */
ModWords=ModBits2ModWords(PBits);
if (ModWords < 2) ModWords=2;
InitModMath(ModWords+16);
/*
printf("\nTotalBits=%d NTTLen=%d BitsPerGroup=%d RootExp=%d PBits=%d\n",
TotalBits,NTTLen,BitsPerGroup,RootExp,PBits);
*/
SProd=ModMalloc(NTTLen,SModWords);
NTTNum1=ModMalloc(NTTLen, ModWords);
NTTNum2=ModMalloc(NTTLen, ModWords);
Theta=ModMalloc(1,ModWords);
Prime=ModMalloc(1,ModWords);
KRoot=ModMalloc(1,ModWords);
MulInv=ModMalloc(1,ModWords);
if (Num1 == Num2) {printf("Schonhage-Strassen is hardwired for two vars.\n");exit(0);}
/* Do small product. */
CB_Set(Prime,1,SModWords);CB_ShiftLx(Prime,Log2(NTTLen),SModWords);PrimeExp=Log2(NTTLen);
SS_NTT2(SProd,Num1,Num2,NTTLen,BitsPerGroup,NumLen,SModWords);
/* Set up our primes and constants for the main SS NTT */
CB_Clear(Prime,ModWords);CB_SetBit(Prime,PBits,ModWords);
CB_SetBit(Prime,0,ModWords);
CB_Clear(Theta,ModWords);CB_SetBit(Theta,RootExp,ModWords);
CB_Clear(KRoot,ModWords);CB_SetBit(KRoot,2*RootExp,ModWords);
KRootExp=2*RootExp;
PrimeExp=PBits;
#if 0
ModIPow(MulInv, KRoot,NTTLen,ModWords);
printf("Check ");CB_Dump(MulInv,ModWords);
ModIPow(MulInv, Theta,NTTLen,ModWords);
printf("Check ");CB_Dump(MulInv,ModWords); /* Single bit. */
printf("Prime ");CB_Dump(Prime,ModWords);
printf("MulInv ");CB_Dump(MulInv,ModWords);
printf("Theta ");CB_Dump(Theta,ModWords);
printf("PrimvR ");CB_Dump(KRoot,ModWords);
#endif
LoadData(NTTNum1,Num1,NTTLen,BitsPerGroup,NumLen,ModWords);
LoadData(NTTNum2,Num2,NTTLen,BitsPerGroup,NumLen,ModWords);
#ifdef USE_SCALING
Scale(NTTNum1,Theta,NTTLen,1,ModWords);
Scale(NTTNum2,Theta,NTTLen,1,ModWords);
#endif
SS_NTT(NTTNum1, NTTLen, 1, ModWords);
SS_NTT(NTTNum2, NTTLen, 1, ModWords);
/*
** ****************NOTICE********************
** This convolution should actually be done by recursively
** calling Schonhage-Strassen until you get to a size small
** enough to do efficiently.
**
** HOWEVER, my generic math package here can't handle that,
** so I'm having to do this slowly. This kills the performance.
*/
for (x=0;x< NTTLen;x++)
ModMul(&NTTNum1[x*ModWords],
&NTTNum1[x*ModWords],
&NTTNum2[x*ModWords],ModWords);
SS_NTT(NTTNum1, NTTLen, -1, ModWords);
#ifdef USE_SCALING
Scale(NTTNum1,Theta,NTTLen,-1,ModWords);
#endif
for (x = 0;x < NTTLen; x++)
ModDiv2Exp(&NTTNum1[x*ModWords],&NTTNum1[x*ModWords],PrimeExp,Log2(NTTLen),ModWords);
/* Do CRT & release the carries */
{int PWords=ModWords+6; /* Pyramid words. */
int y,PX;
int ShortsPerMod=BitsPerGroup/Log2(BASE);
UINT32 W;
ModIntPtr RP,PyramidMod;
ModIntPtr Pyramid,Sum,Temp;
ModIntPtr PPrime;
Pyramid=ModMalloc(1,PWords);CB_Clear(Pyramid,PWords);
Sum=ModMalloc(1,PWords); CB_Clear(Sum,PWords);
Temp=ModMalloc(1,PWords); CB_Clear(Temp,PWords);
PPrime=ModMalloc(1,PWords); CB_Copy(PPrime,PWords,Prime,ModWords);
RP=ModMalloc(1,PWords); /* RP=r*(2^2L) */
PyramidMod=ModMalloc(1,PWords); /* PyramidMod=2^k * (2^(2L)+1) */
CB_Copy(PyramidMod,PWords,PPrime,PWords);CB_ShiftLx(PyramidMod,Log2(K),PWords);
/* PyramidMod=Prime*NTTLen */
if (Log2(K) > 30) {printf("K (%d) too big for int.\n",K);exit(0);}
PX=NumLen*2;
for (x = 0; x < NTTLen; x++)
{
/* Get the small digit */
W=CB_Get32(&SProd[x*SModWords],SModWords);
/* The modmath stuff isn't good at handling small primes, so make sure */
/* W=W % NTTLen; */
if (W >= K) {printf("W too big. %u %d\n",(unsigned int)W,K);exit(0);}
/* Do the CRT */
W=(W-CB_Get32(&NTTNum1[x*ModWords],ModWords)) % K;
CB_Copy(Temp,PWords,PPrime,PWords);
CB_MulUInt(Temp,W,PWords);
CB_Copy(Sum,PWords,&NTTNum1[x*ModWords],ModWords);
CB_Add(Sum,Sum,Temp,PWords);
/* Normalize it. It might be 'negative'. RP could be accumulated. */
CB_Set(RP,1,PWords);CB_ShiftLx(RP,PBits,PWords);CB_MulUInt(RP,x+1,PWords);
if (CB_Cmp(Sum,RP,PWords) >= 0) CB_Sub(Sum,Sum,PyramidMod,PWords);
CB_Add(Pyramid,Pyramid,Sum,PWords);
/* Release our carries */
for (y=0;y