/*
     ss.c
*/
/*
** This file is placed into the public domain by its author,
** Carey Bloodworth (Carey@Bloodworth.org) on May 20, 2001
**
** This multiplication demo is not designed for high performance.
** It's a tutorial program designed to be used with the information
** on my web site at www.Bloodworth.org
*/
/*
** This file demonstrates a very basic Schonhage-Strassen style
** multiplication.
**
** The code is written for simplicity and clarity, NOT for
** performance, so it is *EXTREMELY* slow.
**
** I wanted you to be able to relate this style to a regular NTT
** multiplication, so that's how I wrote it.  If you want better
** performance, take a look at the GMP program or Bruno Haible's
** CLN math package.  (And you'll see why I wanted this code
** to be readable!)
**
** To compile this using GCC:
** gcc -Wall main.c ss.c -o ss.exe
*/

/*
** Do we want to do this using scaling, like Knuth describes?
#define USE_SCALING 1
*/

#include 
#include 
#include 
#include 

#include "wmodmath.h"

/* Element type of the digit arrays handed to FFTMul() (Prod, Num1, Num2). */
typedef short int Short;

/*
** Number base of a single Short digit.  FFTMul() uses Log2(BASE) as the
** number of bits carried by each digit when sizing the transform.
** NOTE(review): no assignment to BASE is visible in this part of the
** file -- confirm where it is initialised.
*/
static int BASE;
/*
** NOTE(review): not referenced in the visible part of this file;
** presumably the digit count that goes with BASE -- confirm.
*/
static int BASE_DIG;

/*
** Working modulus.  FFTMul() allocates it and builds it twice: first as
** 1 shifted left by Log2(NTTLen) for the small product, then as 1 shifted
** left by PBits with 1 added for the transform.  Presumably the Mod...()
** routines reduce by this global -- verify against wmodmath.h.
*/
ModIntPtr Prime=NULL;
/*
** Root handed to ModIPow() by NTT() to derive the per-pass twiddle.
** FFTMul() builds it as 1 shifted left by 2*RootExp.
*/
ModIntPtr KRoot=NULL;
/*
** Filled in by FindInverse(MulInv,NTTLen,ModWords) and multiplied into
** every element after the inverse transform; presumably the modular
** inverse of NTTLen -- confirm against FindInverse().
*/
ModIntPtr MulInv=NULL;


/*
** Integer base-2 logarithm: the position of the highest set bit of the
** magnitude of Num, i.e. floor(log2(|Num|)).
**
** Zero has no set bit and is reported as 0, the same answer as for 1.
*/
int
Log2(int Num)
{
    int exponent = 0;
    int rest;

    /*
    ** Integer division truncates toward zero, so a negative argument
    ** shrinks to 0 in exactly as many halvings as its magnitude does.
    ** The first halving is done up front: it is the one that does not
    ** count toward the result.
    */
    for (rest = Num / 2; rest != 0; rest /= 2)
        exponent++;

    return exponent;
}

/*
** Return 2 raised to the power L as an int.
**
** Any L that is zero or negative yields 1.  The result is built by
** repeated doubling, so L must stay below the number of value bits in
** an int for the result to be representable.
*/
int
Pow2(int L)
{
    int result = 1;
    int doublings;

    /* For L <= 0 the loop body never runs and 1 is returned unchanged. */
    for (doublings = 0; doublings < L; doublings++)
        result *= 2;

    return result;
}

/*
** Permute Data, in place, into bit-reversed index order so that the
** iterative transform in NTT() can work on it.
**
** Data     - Len consecutive elements, each ModWords words wide.
** Len      - number of elements.  The index arithmetic below repeatedly
**            halves Len, so it is presumably a power of two -- confirm
**            against the callers.
** ModWords - distance, in words, between consecutive elements.
*/
static void
NTTReOrder(ModIntPtr Data, size_t Len,size_t ModWords)
{size_t Index,xednI,k;

/* xednI is kept as the bit-reversed counterpart of Index. */
xednI=0;
for (Index=0;Index<Len;Index++)
  {
   /* Swap a pair only when the mirror index is the larger of the two, */
   /* so each pair is exchanged exactly once.                          */
   if (xednI > Index)
     ModSwap(&Data[Index*ModWords],&Data[xednI*ModWords],ModWords);

   /* Increment xednI "from the top": clear set bits starting at the   */
   /* most significant one, then set the first clear bit found.  The   */
   /* k >= 1 test stops the scan once every bit has been cleared.      */
   k=Len/2;
   while ((k <= xednI) && (k >=1)) {xednI-=k;k/=2;}
   xednI+=k;
  }
}

static void
NTT(ModIntPtr Data, size_t Len,int Dir,size_t ModWords)
/*
** A Radix 2, decimation in time, iterative NTT.
**
** Data     - Len elements of ModWords words each; reordered in place by
**            NTTReOrder() before the passes start.
** Len      - transform length.
** Dir      - > 0 selects the forward root, otherwise the inverse one.
**            Ignored for the root choice when USE_SCALING is defined.
** ModWords - words per element.
**
** NOTE(review): the body of this function is damaged from the inner
** "for (j=0;" loop onwards -- see the note at that line.
*/
{size_t j,Step,HalfStep;
 ModIntPtr u,w,T,End;

NTTReOrder(Data,Len,ModWords);

Step=1;
/* Three single-element scratch values. */
u=ModMalloc(1,ModWords);
w=ModMalloc(1,ModWords);
T=ModMalloc(1,ModWords);
/* One past the last element of Data. */
End=&Data[Len*ModWords];

/* Each pass doubles Step; HalfStep keeps the value Step had before. */
while (Step < Len)
  {
   HalfStep=Step;
   Step*=2;

   /* Presumably u = 1 and w = KRoot raised to the given exponent -- */
   /* confirm against wmodmath.h.                                    */
   ModSet1(u,ModWords);
#ifdef USE_SCALING
   ModIPow(w,KRoot,Len/Step,ModWords); // Treat Dir as always positive.
#else
   /* The inverse direction uses exponent Len-Len/Step; presumably the */
   /* inverse of the forward root, if KRoot^Len is 1 -- confirm.        */
   if (Dir > 0) ModIPow(w,KRoot,Len/Step,ModWords);
   else         ModIPow(w,KRoot,Len-Len/Step,ModWords);
#endif

   /* NOTE(review): the loop header on the next line is truncated, and  */
   /* everything from there to the closing brace does not appear to     */
   /* belong to this function: it uses Prod, p, q, Temp, N1 and N2,     */
   /* none of which are declared here, while u, w and T are never       */
   /* released in what is visible.  Restore from a pristine copy.       */
   for (j=0;j= q) ModAdd(&Prod[p*ModWords],&Prod[p*ModWords],Temp,ModWords);
      else        ModSub(&Prod[p*ModWords],&Prod[p*ModWords],Temp,ModWords);
     }
  }

ModFree(N1);
ModFree(N2);
ModFree(Temp);
}

void
FFTMul(Short * Prod, Short *Num1, Short *Num2, int NumLen)
{int x;
int ModWords,SModWords;
ModIntPtr SProd,NTTNum1,NTTNum2;
ModIntPtr Theta;


int RootExp; // Root exponent.
int TotalBits,N,n; // Total bits in product.
int NTTLen,K,k; // number of groups the number is broken up into (NTT Len)
int BitsPerGroup,L,l; // number of bits in each group
// K <= L*2
// NTTLen <= Bits_Per_Group*2

int PBits; // How many bits in the pyramid.

// Zero pad for the double length product.
TotalBits   = N =NumLen*2*Log2(BASE);         n=Log2(N);
BitsPerGroup= L =Pow2(Log2(TotalBits) / 2);   l=Log2(L);
NTTLen      = K =TotalBits / L;               k=Log2(K);

#if 0
// Example of fixed NTT length.
NTTLen      = K =16;                          k=Log2(K);
BitsPerGroup= L =TotalBits/NTTLen;            l=Log2(L);
#endif

// N, n K, k, L, l are the Knuth vars.  Lower case is Log2(var)

PBits=BitsPerGroup*2;
if (PBits % NTTLen) {printf("\nPBits=%d NTTLen=%d\n",PBits,NTTLen);exit(0);}
RootExp=PBits/NTTLen; /* 2L+1-K */

#if 0
printf("\nTotalBits=%d NTTLen=%d BitsPerGroup=%d RootExp=%d PBits=%d\n",
       TotalBits,NTTLen,BitsPerGroup,RootExp,PBits);
#endif

// Sanity check...
if (NTTLen > BitsPerGroup*2)
  {
   printf("Wrong params.\n");
   printf("\nTotalBits=%d NTTLen=%d BitsPerGroup=%d RootExp=%d PBits=%d\n",
          TotalBits,NTTLen,BitsPerGroup,RootExp,PBits);
   exit(0);
  }

if (BitsPerGroup < 16) {printf("L too small.\n");exit(0);} // At least a word size.
if (NTTLen < 8) {printf("NTTLen (%d) too small.\n",NTTLen);exit(0);}

SModWords=ModBits2ModWords(NTTLen); // Small product.
if (SModWords < 2) SModWords=2; // ModMath.h needs at least two.

ModWords=ModBits2ModWords(PBits);
if (ModWords < 2) ModWords=2; // ModMath.h needs at least two.
InitModMath(ModWords+16);

/*
printf("\nTotalBits=%d NTTLen=%d BitsPerGroup=%d RootExp=%d PBits=%d\n",
       TotalBits,NTTLen,BitsPerGroup,RootExp,PBits);
*/

SProd=ModMalloc(NTTLen,SModWords);
NTTNum1=ModMalloc(NTTLen, ModWords);
NTTNum2=ModMalloc(NTTLen, ModWords+4); // extra just for testing / verification
Theta=ModMalloc(1,ModWords);

Prime=ModMalloc(1,ModWords);
KRoot=ModMalloc(1,ModWords);
MulInv=ModMalloc(1,ModWords);

#if 1
if (Num1 == Num2) {printf("Schonhage-Strassen is hardwired for two vars.\n");exit(0);}

// Do small product.
// Actually, you should use a small NTT to do this.
// I'm doing a simple schoolboy because it's easier to code and it
// shows there is nothing special.  It is just a standard negacyclic
// multiply modulo NTTLen
CB_Set(Prime,1,SModWords);CB_ShiftLx(Prime,Log2(NTTLen),SModWords);
NegaSchoolMul(SProd,Num1,Num2,NTTLen,BitsPerGroup,NumLen,SModWords);

// Set the prime.  Can't use Mod...() cause it needs Prime & Prime would reduce.
CB_Set(Prime,1,ModWords);CB_ShiftLx(Prime,PBits,ModWords);CB_AddInt(Prime,1,ModWords);
ModSet(Theta,1,ModWords);CB_ShiftLx(Theta,RootExp,ModWords);
ModSet(KRoot,1,ModWords);CB_ShiftLx(KRoot,2*RootExp,ModWords);

//ModIPow(MulInv, KRoot,NTTLen,ModWords);
//printf("Check  ");CB_Dump(MulInv,ModWords); // 1
//ModIPow(MulInv, Theta,NTTLen,ModWords);
//printf("Check  ");CB_Dump(MulInv,ModWords); // Single bit.

FindInverse(MulInv,NTTLen,ModWords);

//printf("Prime  ");CB_Dump(Prime,ModWords);
//printf("MulInv ");CB_Dump(MulInv,ModWords);
//printf("Theta  ");CB_Dump(Theta,ModWords);
//printf("PrimvR ");CB_Dump(KRoot,ModWords);

LoadData(NTTNum1,Num1,NTTLen,BitsPerGroup,NumLen,ModWords);
LoadData(NTTNum2,Num2,NTTLen,BitsPerGroup,NumLen,ModWords);
#ifdef USE_SCALING
Scale(NTTNum1,Theta,NTTLen,1,ModWords);
Scale(NTTNum2,Theta,NTTLen,1,ModWords);
#endif
NTT(NTTNum1, NTTLen, 1, ModWords);
NTT(NTTNum2, NTTLen, 1, ModWords);

for (x=0;x< NTTLen;x++)
  ModMul(&NTTNum1[x*ModWords],
         &NTTNum1[x*ModWords],
         &NTTNum2[x*ModWords],ModWords);

NTT(NTTNum1, NTTLen, -1, ModWords);

#ifdef USE_SCALING
Scale(NTTNum1,Theta,NTTLen,-1,ModWords);
#endif

for (x = 0;x < NTTLen; x++)
   ModMul(&NTTNum1[x*ModWords],&NTTNum1[x*ModWords],MulInv,ModWords);
#endif

#if 0
// Schoolboy
// This lets me test the basic ModMath, LoadData, NegaSchoolMul & CRT
// Everything except the NTT & root stuff.

CB_Set(Prime,1,SModWords);CB_ShiftLx(Prime,Log2(NTTLen),SModWords);
NegaSchoolMul(SProd,Num1,Num2,NTTLen,BitsPerGroup,NumLen,SModWords);

CB_Set(Prime,1,ModWords);CB_ShiftLx(Prime,PBits,ModWords);CB_AddInt(Prime,1,ModWords);
NegaSchoolMul(NTTNum1,Num1,Num2,NTTLen,BitsPerGroup,NumLen,ModWords);
#endif


#if 0
for (x=0;x < NTTLen; x++)
  {printf("x=%2d ",x);CB_Dump(&NTTNum1[x*ModWords],ModWords);}
exit(0);
#endif

#if 0
// Simple debugging verifcation of the two modulo'ed numbers.
{unsigned long long P1,P2;
 int BModWords=ModWords+4;

CB_Set(Prime,1,BModWords);CB_ShiftLx(Prime,PBits+Log2(K)*3,BModWords);
NegaSchoolMul(NTTNum2,Num1,Num2,NTTLen,BitsPerGroup,NumLen,BModWords);

P1=1;P1=P1< 30) {printf("K (%d) too big for int.\n",K);exit(0);}

PX=NumLen*2;
for (x = 0; x < NTTLen; x++)
  {
/* Get the small digit */
   W=CB_Get32(&SProd[x*SModWords],SModWords);

W=W % NTTLen;// force it.
// The modmath stuff isn't good at handling small primes, so force it.
if (W >= K) {printf("W too big. %u %d\n",(unsigned int)W,K);exit(0);}

/* Do the CRT */
   W=(W-CB_Get32(&NTTNum1[x*ModWords],ModWords)) % K;
   CB_Copy(Temp,PWords,PPrime,PWords);
   CB_MulUInt(Temp,W,PWords);

   CB_Copy(Sum,PWords,&NTTNum1[x*ModWords],ModWords);
   CB_Add(Sum,Sum,Temp,PWords);

// Normalize it. It might be 'negative'.  RP could be accumulated.
//   CB_Set(RP,1,PWords);CB_ShiftLx(RP,2*L,PWords);CB_MulUInt(RP,x+1,PWords);
   CB_Set(RP,1,PWords);CB_ShiftLx(RP,PBits,PWords);CB_MulUInt(RP,x+1,PWords);
   if (CB_Cmp(Sum,RP,PWords) >= 0) CB_Sub(Sum,Sum,PyramidMod,PWords);

   CB_Add(Pyramid,Pyramid,Sum,PWords);

/* Release our carries */
   for (y=0;y