/*
ss.c
*/
/*
** This file is placed into the public domain by its author,
** Carey Bloodworth (Carey@Bloodworth.org) on May 20, 2001
**
** This multiplication demo is not designed for high performance.
** It's a tutorial program designed to be used with the information
** on my web site at www.Bloodworth.org
*/
/*
** This file demonstrates a very basic Schonhage-Strassen style
** multiplication.
**
** The code is written for simplicity and clarity, NOT for
** performance, so it is *EXTREMELY* slow.
**
** I wanted you to be able to relate this style to a regular NTT
** multiplication, so that's how I wrote it. If you want better
** performance, take a look at the GMP program or Bruno Haible's
** CLN math package. (And you'll see why I wanted this code
** to be readable!)
**
** To compile this using GCC:
** gcc -Wall main.c ss.c -o ss.exe
*/
/*
** Do we want to do this using scaling, like Knuth describes?
#define USE_SCALING 1
*/
#include
#include
#include
#include
#include "wmodmath.h"
/* Digit type of the number arrays that FFTMul takes and returns. */
typedef short int Short;
/*
** Number base of one Short digit. FFTMul uses Log2(BASE) as the bit count
** per digit. NOTE(review): no assignment to BASE is visible in this part
** of the file -- confirm where it is initialised.
*/
static int BASE;
/*
** NOTE(review): presumably the number of decimal digits held in one BASE
** digit; it is not referenced in the visible part of the file -- confirm.
*/
static int BASE_DIG;
/*
** Shared modular-arithmetic state. All three are allocated by FFTMul with
** ModMalloc() and are NULL until then.
*/
/*
** Current modulus. FFTMul first builds it as 1 shifted left by
** Log2(NTTLen) for the small product, then as (1 shifted left by PBits)+1
** for the NTT.
*/
ModIntPtr Prime=NULL;
/*
** Base of the twiddle factors that NTT() raises to a power with ModIPow().
** FFTMul builds it as 1 shifted left by 2*RootExp.
*/
ModIntPtr KRoot=NULL;
/*
** Filled in by FindInverse(MulInv,NTTLen,ModWords). FFTMul multiplies every
** element by it after the inverse NTT.
*/
ModIntPtr MulInv=NULL;
int
Log2(int Num)
/*
** Integer base-2 logarithm, rounded down.
**
** Returns the number of times Num can be halved (with truncation toward
** zero) before it reaches 0, minus one. Inputs of 0 and 1 both give 0.
** A negative Num gives the same result as its magnitude, because the
** division truncates toward zero.
*/
{
    int Bits = 0;

    /* The first halving is not counted; every later one adds a bit. */
    for (Num /= 2; Num != 0; Num /= 2)
        Bits++;
    return Bits;
}
int
Pow2(int L)
/*
** Returns 2 raised to the power L as an int.
**
** Any L <= 0 yields 1. L must be small enough for the result to fit in
** an int; nothing here checks for overflow.
*/
{
    int Result = 1;
    int Count;

    /* Double once per unit of L; the loop never runs for L <= 0. */
    for (Count = 0; Count < L; Count++)
        Result += Result;
    return Result;
}
static void
NTTReOrder(ModIntPtr Data, size_t Len,size_t ModWords)
/*
** Reorders Data into bit-reversed index order, in place, ready for the
** iterative radix 2 NTT.
**
** Data     - Len elements stored back to back, ModWords words apiece
**            (element i starts at Data[i*ModWords]).
** Len      - number of elements. The reversed counter below halves Len
**            repeatedly, so it is meant for a power of two.
** ModWords - width of one element, passed through to ModSwap().
*/
{size_t Index,xednI,k;
xednI=0;
for (Index=0;Index<Len;Index++)
  {
   /* Swap only when the mirrored index is the larger of the two, so each
      pair is exchanged exactly once and self-mirrored slots are left alone. */
   if (xednI > Index)
     ModSwap(&Data[Index*ModWords],&Data[xednI*ModWords],ModWords);
   /* Add one to xednI with the carry running from the top bit downwards:
      clear every leading set bit, then set the first clear one. */
   k=Len/2;
   while ((k <= xednI) && (k >=1)) {xednI-=k;k/=2;}
   xednI+=k;
  }
}
static void
NTT(ModIntPtr Data, size_t Len,int Dir,size_t ModWords)
/*
** A Radix 2, decimation in time, iterative NTT.
**
** Data     - Len elements of ModWords words each, transformed in place.
** Len      - transform length.
** Dir      - > 0 selects the forward root power, anything else the
**            inverse one (ignored for the root when USE_SCALING is set).
** ModWords - width of one element in words.
**
** NOTE(review): the text of this function is damaged. It breaks off in
** the middle of the 'for (j=0;j' statement below and continues with the
** tail of a different routine, so this block does not compile as it
** stands. Restore it from a pristine copy of ss.c.
*/
{size_t j,Step,HalfStep;
ModIntPtr u,w,T,End;
/* Put the input into the order the butterfly passes expect. */
NTTReOrder(Data,Len,ModWords);
Step=1;
/* Three single-element scratch values. */
u=ModMalloc(1,ModWords);
w=ModMalloc(1,ModWords);
T=ModMalloc(1,ModWords);
/* One past the last element of Data. */
End=&Data[Len*ModWords];
/* One pass per doubling of Step until it reaches Len. */
while (Step < Len)
{
HalfStep=Step;
Step*=2;
/* NOTE(review): presumably sets u to the value 1 -- confirm in wmodmath.h. */
ModSet1(u,ModWords);
#ifdef USE_SCALING
ModIPow(w,KRoot,Len/Step,ModWords); // Treat Dir as always positive.
#else
/* w = KRoot^(Len/Step) going forward, KRoot^(Len-Len/Step) going back.
   NOTE(review): the second form is the inverse power only if KRoot^Len
   is 1 -- confirm. */
if (Dir > 0) ModIPow(w,KRoot,Len/Step,ModWords);
else ModIPow(w,KRoot,Len-Len/Step,ModWords);
#endif
/* NOTE(review): source text is missing from the middle of the next line
   onward. The remaining lines use Prod, p, q, Temp, N1 and N2, none of
   which are declared in NTT(); they look like the end of another function
   (possibly the NegaSchoolMul that FFTMul calls -- confirm). */
for (j=0;j= q) ModAdd(&Prod[p*ModWords],&Prod[p*ModWords],Temp,ModWords);
else ModSub(&Prod[p*ModWords],&Prod[p*ModWords],Temp,ModWords);
}
}
ModFree(N1);
ModFree(N2);
ModFree(Temp);
}
void
FFTMul(Short * Prod, Short *Num1, Short *Num2, int NumLen)
{int x;
int ModWords,SModWords;
ModIntPtr SProd,NTTNum1,NTTNum2;
ModIntPtr Theta;
int RootExp; // Root exponent.
int TotalBits,N,n; // Total bits in product.
int NTTLen,K,k; // number of groups the number is broken up into (NTT Len)
int BitsPerGroup,L,l; // number of bits in each group
// K <= L*2
// NTTLen <= Bits_Per_Group*2
int PBits; // How many bits in the pyramid.
// Zero pad for the double length product.
TotalBits = N =NumLen*2*Log2(BASE); n=Log2(N);
BitsPerGroup= L =Pow2(Log2(TotalBits) / 2); l=Log2(L);
NTTLen = K =TotalBits / L; k=Log2(K);
#if 0
// Example of fixed NTT length.
NTTLen = K =16; k=Log2(K);
BitsPerGroup= L =TotalBits/NTTLen; l=Log2(L);
#endif
// N, n K, k, L, l are the Knuth vars. Lower case is Log2(var)
PBits=BitsPerGroup*2;
if (PBits % NTTLen) {printf("\nPBits=%d NTTLen=%d\n",PBits,NTTLen);exit(0);}
RootExp=PBits/NTTLen; /* 2L+1-K */
#if 0
printf("\nTotalBits=%d NTTLen=%d BitsPerGroup=%d RootExp=%d PBits=%d\n",
TotalBits,NTTLen,BitsPerGroup,RootExp,PBits);
#endif
// Sanity check...
if (NTTLen > BitsPerGroup*2)
{
printf("Wrong params.\n");
printf("\nTotalBits=%d NTTLen=%d BitsPerGroup=%d RootExp=%d PBits=%d\n",
TotalBits,NTTLen,BitsPerGroup,RootExp,PBits);
exit(0);
}
if (BitsPerGroup < 16) {printf("L too small.\n");exit(0);} // At least a word size.
if (NTTLen < 8) {printf("NTTLen (%d) too small.\n",NTTLen);exit(0);}
SModWords=ModBits2ModWords(NTTLen); // Small product.
if (SModWords < 2) SModWords=2; // ModMath.h needs at least two.
ModWords=ModBits2ModWords(PBits);
if (ModWords < 2) ModWords=2; // ModMath.h needs at least two.
InitModMath(ModWords+16);
/*
printf("\nTotalBits=%d NTTLen=%d BitsPerGroup=%d RootExp=%d PBits=%d\n",
TotalBits,NTTLen,BitsPerGroup,RootExp,PBits);
*/
SProd=ModMalloc(NTTLen,SModWords);
NTTNum1=ModMalloc(NTTLen, ModWords);
NTTNum2=ModMalloc(NTTLen, ModWords+4); // extra just for testing / verification
Theta=ModMalloc(1,ModWords);
Prime=ModMalloc(1,ModWords);
KRoot=ModMalloc(1,ModWords);
MulInv=ModMalloc(1,ModWords);
#if 1
if (Num1 == Num2) {printf("Schonhage-Strassen is hardwired for two vars.\n");exit(0);}
// Do small product.
// Actually, you should use a small NTT to do this.
// I'm doing a simple schoolboy because it's easier to code and it
// shows there is nothing special. It is just a standard negacyclic
// multiply modulo NTTLen
CB_Set(Prime,1,SModWords);CB_ShiftLx(Prime,Log2(NTTLen),SModWords);
NegaSchoolMul(SProd,Num1,Num2,NTTLen,BitsPerGroup,NumLen,SModWords);
// Set the prime. Can't use Mod...() cause it needs Prime & Prime would reduce.
CB_Set(Prime,1,ModWords);CB_ShiftLx(Prime,PBits,ModWords);CB_AddInt(Prime,1,ModWords);
ModSet(Theta,1,ModWords);CB_ShiftLx(Theta,RootExp,ModWords);
ModSet(KRoot,1,ModWords);CB_ShiftLx(KRoot,2*RootExp,ModWords);
//ModIPow(MulInv, KRoot,NTTLen,ModWords);
//printf("Check ");CB_Dump(MulInv,ModWords); // 1
//ModIPow(MulInv, Theta,NTTLen,ModWords);
//printf("Check ");CB_Dump(MulInv,ModWords); // Single bit.
FindInverse(MulInv,NTTLen,ModWords);
//printf("Prime ");CB_Dump(Prime,ModWords);
//printf("MulInv ");CB_Dump(MulInv,ModWords);
//printf("Theta ");CB_Dump(Theta,ModWords);
//printf("PrimvR ");CB_Dump(KRoot,ModWords);
LoadData(NTTNum1,Num1,NTTLen,BitsPerGroup,NumLen,ModWords);
LoadData(NTTNum2,Num2,NTTLen,BitsPerGroup,NumLen,ModWords);
#ifdef USE_SCALING
Scale(NTTNum1,Theta,NTTLen,1,ModWords);
Scale(NTTNum2,Theta,NTTLen,1,ModWords);
#endif
NTT(NTTNum1, NTTLen, 1, ModWords);
NTT(NTTNum2, NTTLen, 1, ModWords);
for (x=0;x< NTTLen;x++)
ModMul(&NTTNum1[x*ModWords],
&NTTNum1[x*ModWords],
&NTTNum2[x*ModWords],ModWords);
NTT(NTTNum1, NTTLen, -1, ModWords);
#ifdef USE_SCALING
Scale(NTTNum1,Theta,NTTLen,-1,ModWords);
#endif
for (x = 0;x < NTTLen; x++)
ModMul(&NTTNum1[x*ModWords],&NTTNum1[x*ModWords],MulInv,ModWords);
#endif
#if 0
// Schoolboy
// This lets me test the basic ModMath, LoadData, NegaSchoolMul & CRT
// Everything except the NTT & root stuff.
CB_Set(Prime,1,SModWords);CB_ShiftLx(Prime,Log2(NTTLen),SModWords);
NegaSchoolMul(SProd,Num1,Num2,NTTLen,BitsPerGroup,NumLen,SModWords);
CB_Set(Prime,1,ModWords);CB_ShiftLx(Prime,PBits,ModWords);CB_AddInt(Prime,1,ModWords);
NegaSchoolMul(NTTNum1,Num1,Num2,NTTLen,BitsPerGroup,NumLen,ModWords);
#endif
#if 0
for (x=0;x < NTTLen; x++)
{printf("x=%2d ",x);CB_Dump(&NTTNum1[x*ModWords],ModWords);}
exit(0);
#endif
#if 0
// Simple debugging verification of the two modulo'ed numbers.
{unsigned long long P1,P2;
int BModWords=ModWords+4;
CB_Set(Prime,1,BModWords);CB_ShiftLx(Prime,PBits+Log2(K)*3,BModWords);
NegaSchoolMul(NTTNum2,Num1,Num2,NTTLen,BitsPerGroup,NumLen,BModWords);
P1=1;P1=P1< 30) {printf("K (%d) too big for int.\n",K);exit(0);}
PX=NumLen*2;
for (x = 0; x < NTTLen; x++)
{
/* Get the small digit */
W=CB_Get32(&SProd[x*SModWords],SModWords);
W=W % NTTLen;// force it.
// The modmath stuff isn't good at handling small primes, so force it.
if (W >= K) {printf("W too big. %u %d\n",(unsigned int)W,K);exit(0);}
/* Do the CRT */
W=(W-CB_Get32(&NTTNum1[x*ModWords],ModWords)) % K;
CB_Copy(Temp,PWords,PPrime,PWords);
CB_MulUInt(Temp,W,PWords);
CB_Copy(Sum,PWords,&NTTNum1[x*ModWords],ModWords);
CB_Add(Sum,Sum,Temp,PWords);
// Normalize it. It might be 'negative'. RP could be accumulated.
// CB_Set(RP,1,PWords);CB_ShiftLx(RP,2*L,PWords);CB_MulUInt(RP,x+1,PWords);
CB_Set(RP,1,PWords);CB_ShiftLx(RP,PBits,PWords);CB_MulUInt(RP,x+1,PWords);
if (CB_Cmp(Sum,RP,PWords) >= 0) CB_Sub(Sum,Sum,PyramidMod,PWords);
CB_Add(Pyramid,Pyramid,Sum,PWords);
/* Release our carries */
for (y=0;y