/*
     ss2.c
*/
/*
** This file is placed into the public domain by its author,
** Carey Bloodworth (Carey@Bloodworth.org) on July 16, 2001
**
** This multiplication demo is not designed for high performance.
** It's a tutorial program designed to be used with the information
** on my web site at www.Bloodworth.org
*/
/*
** This file demonstrates a very basic Schonhage-Strassen style
** multiplication.
**
** The code is written for simplicity and clarity, NOT for
** performance, so it is *EXTREMELY* slow.
**
** I wanted you to be able to relate this style to a regular NTT
** multiplication, so that's how I wrote it.  If you want better
** performance, take a look at the GMP program or Bruno Haible's
** CLN math package.  (And you'll see why I wanted this code
** to be readable!)
**
** To compile this using GCC:
** gcc -Wall main.c ss2.c -o ss2.exe
*/

/*
** Do we want to do this using scaling, like Knuth describes?
#define USE_SCALING 1
*/

#include 
#include 
#include 
#include 

#include "wmodmath.h"

typedef short int Short;

static int BASE;
static int BASE_DIG;

ModIntPtr Prime=NULL;    int PrimeExp=0;
ModIntPtr KRoot=NULL;    int KRootExp=0;
ModIntPtr MulInv=NULL;


int
Log2(int Num)
/*
** Integer base-2 logarithm, rounded toward zero.
**
** Counts how many times Num can be halved before nothing is left,
** which gives the position of the top bit of Num's magnitude.
** Zero yields 0 (the same answer as for 1).  Negative values are
** handled by magnitude because integer division truncates toward
** zero, e.g. Log2(-8) is 3.
*/
{
  int Halvings = 0;

  /* The first halving is not counted: 1 -> 0 must give a result of 0. */
  for (Num /= 2; Num != 0; Num /= 2)
    Halvings++;

  return Halvings;
}

int
Pow2(int L)
/*
** Returns 2 raised to the power L, built by repeated doubling.
** Any L that is zero or negative gives 1.
** There is no overflow check; the caller must keep the result
** within the range of an int.
*/
{
  int Result = 1;
  int Count;

  /* For L <= 0 the loop body never runs and 1 is returned. */
  for (Count = 0; Count < L; Count++)
    Result *= 2;

  return Result;
}

static void
NTTReOrder(ModIntPtr Data, size_t Len,size_t ModWords)
/*
** Permute the elements of Data, in place, into the index order the
** iterative transform in SS_NTT expects (each index paired with its
** bit-reversed counterpart).
**
** Data      Array of Len elements, each occupying ModWords words.
** Len       Number of elements.  Intended to be a power of two; the
**           counter logic below halves Len repeatedly.
** ModWords  Width of one element, in words.
*/
{size_t Index,xednI,k;

xednI=0;
for (Index=0;Index<Len;Index++)
  {
   /* Swap only when the mirrored index is the larger one, so that */
   /* every pair is exchanged exactly once. */
   if (xednI > Index)
     ModSwap(&Data[Index*ModWords],&Data[xednI*ModWords],ModWords);

   /* Advance xednI as a mirrored counter: clear set bits from the */
   /* top down, then set the first clear bit found. */
   k=Len/2;
   while ((k <= xednI) && (k >=1)) {xednI-=k;k/=2;}
   xednI+=k;
  }
}

static void
SS_NTT(ModIntPtr Data, size_t Len,int Dir,size_t ModWords)
/*
** A Radix 2, decimation in time, iterative NTT.
** Special Schonhage-Strassen version.
**
** Data      Len elements of ModWords words each; transformed in place.
** Len       Transform length.
** Dir       > 0 uses ModMul2Exp for the twiddle step, anything else
**           uses ModDiv2Exp.  Forced to 1 when USE_SCALING is defined.
** ModWords  Width of one element, in words.
**
** Reads the file-level KRootExp and PrimeExp.  The twiddle step is
** done with ModMul2Exp / ModDiv2Exp rather than a general multiply;
** presumably these multiply / divide by a power of two modulo the
** current prime (confirm in wmodmath.h).
**
** NOTE(review): the header of the 'j' loop, the set-up of the L/R
** butterfly pointers and the inner 'while' were rebuilt from the
** surviving loop body and brace layout.  Verify against a pristine
** copy of this file.
*/
{size_t j,Step,HalfStep;
 ModIntPtr T,End;
 int wExp;

NTTReOrder(Data,Len,ModWords);

Step=1;
T=ModMalloc(1,ModWords);   /* Scratch element for the twiddled value. */
End=&Data[Len*ModWords];   /* One past the last element. */

#ifdef USE_SCALING
Dir=1;
#endif

while (Step < Len)
  {
   HalfStep=Step;
   Step*=2;

/* Exponent increment for this stage's twiddle factors. */
   wExp=KRootExp*(Len/Step);

   for (j=0;j<HalfStep;j++)
     {ModIntPtr L,R;

/* L and R are the two legs of the butterfly, HalfStep apart. */
      L=&Data[j*ModWords];
      R=&Data[(j+HalfStep)*ModWords];
      while (L < End)
        {
         if (Dir > 0)
           ModMul2Exp(T,R,j*wExp,ModWords);
         else
           ModDiv2Exp(T,R,PrimeExp,j*wExp,ModWords);
/* R must be written before L, because L is still needed as an input. */
         ModSub(R,L,T,ModWords);
         ModAdd(L,L,T,ModWords);
         L+=(Step*ModWords);R+=(Step*ModWords);
        }
     }
  }
ModFree(T);
}

#ifdef USE_SCALING
void
Scale(ModIntPtr Data,ModIntPtr Theta,int NTTLen,int Dir,int ModWords)
/*
** Scaling step used only when USE_SCALING is defined.
**
** Data      Elements of ModWords words each; modified in place.
** Theta     Not referenced in the lines visible here.
** NTTLen    Number of elements (per the parameter name; the loops
**           that use it are in the damaged line below).
** Dir       < 0 first runs the "flip the order" step; > 0 scales
**           with ModMul2Exp, otherwise with ModDiv2Exp.
** ModWords  Width of one element, in words.
**
** Each element is scaled using the exponent x*ThetaExp, where
** ThetaExp is half of the file-level KRootExp.
**
** NOTE(review): the line starting "for (x=1;x" has lost text,
** apparently everything between a '<' and the next '>'.  Presumably
** that covered the body of the flip loop, the close of the
** 'if (Dir < 0)' block, and the header of the per-element scaling
** loop up to "if (Dir >".  Restore from a pristine copy before
** enabling USE_SCALING.
*/
{int x;
 int ThetaExp=KRootExp/2;

if (Dir < 0)
  {
/* Flip the order.  Only index 0 & Len/2 stay where they are. */
   for (x=1;x 0)
      ModMul2Exp(Data,Data,x*ThetaExp,ModWords);
   else
      ModDiv2Exp(Data,Data,PrimeExp,x*ThetaExp,ModWords);
/* Step to the next element. */
   Data+=ModWords;
  }
}
#endif

void
LoadData(ModIntPtr NTTData,Short *Num, int NTTLen,
         int BitsPerGroup,int NumLen,int ModWords)
/*
** The modmath package doesn't properly handle very very small
** primes so we have to normalize.  And this routine itself
** doesn't handle cases where one 'base digit' would be split
** among two NTT elements.
*/
{int x,Nx,y;
 int ShortsPerMod=BitsPerGroup/Log2(BASE);
 ModIntPtr R;

 if ((BitsPerGroup % Log2(BASE)) != 0)
   {printf("Not multiple of base\n");exit(0);}

 R=ModMalloc(1,ModWords);
 for (x = 0; x < NTTLen; x++) ModClear(&NTTData[x*ModWords],ModWords);

 Nx=NumLen;
 for (x = 0; x < NTTLen/2; x++)
   {
    Nx-=ShortsPerMod;
    ModClear(R,ModWords);
    for (y=0;y BitsPerGroup*2)
  {
   printf("Wrong params.\n");
   printf("\nTotalBits=%d NTTLen=%d BitsPerGroup=%d RootExp=%d PBits=%d\n",
          TotalBits,NTTLen,BitsPerGroup,RootExp,PBits);
   exit(0);
  }

if (BitsPerGroup < 16) {printf("L too small.\n");exit(0);} /* At least a word size. */
if (NTTLen < 8) {printf("NTTLen (%d) too small.\n",NTTLen);exit(0);}

/* Small product.  WModMath.h needs at least two. */
SModWords=ModBits2ModWords(NTTLen);
if (SModWords < 2) SModWords=2;

/* Large product.  WModMath.h needs at least two. */
ModWords=ModBits2ModWords(PBits);
if (ModWords < 2) ModWords=2;
InitModMath(ModWords+16);

/*
printf("\nTotalBits=%d NTTLen=%d BitsPerGroup=%d RootExp=%d PBits=%d\n",
       TotalBits,NTTLen,BitsPerGroup,RootExp,PBits);
*/

SProd=ModMalloc(NTTLen,SModWords);
NTTNum1=ModMalloc(NTTLen, ModWords);
NTTNum2=ModMalloc(NTTLen, ModWords);
Theta=ModMalloc(1,ModWords);

Prime=ModMalloc(1,ModWords);
KRoot=ModMalloc(1,ModWords);
MulInv=ModMalloc(1,ModWords);

if (Num1 == Num2) {printf("Schonhage-Strassen is hardwired for two vars.\n");exit(0);}

/* Do small product. */
CB_Set(Prime,1,SModWords);CB_ShiftLx(Prime,Log2(NTTLen),SModWords);PrimeExp=Log2(NTTLen);
SS_NTT2(SProd,Num1,Num2,NTTLen,BitsPerGroup,NumLen,SModWords);

/* Set up our primes and constants for the main SS NTT */
CB_Clear(Prime,ModWords);CB_SetBit(Prime,PBits,ModWords);
                         CB_SetBit(Prime,0,ModWords);
CB_Clear(Theta,ModWords);CB_SetBit(Theta,RootExp,ModWords);
CB_Clear(KRoot,ModWords);CB_SetBit(KRoot,2*RootExp,ModWords);

KRootExp=2*RootExp;
PrimeExp=PBits;

#if 0
ModIPow(MulInv, KRoot,NTTLen,ModWords);
printf("Check  ");CB_Dump(MulInv,ModWords);
ModIPow(MulInv, Theta,NTTLen,ModWords);
printf("Check  ");CB_Dump(MulInv,ModWords); /* Single bit. */

printf("Prime  ");CB_Dump(Prime,ModWords);
printf("MulInv ");CB_Dump(MulInv,ModWords);
printf("Theta  ");CB_Dump(Theta,ModWords);
printf("PrimvR ");CB_Dump(KRoot,ModWords);
#endif

LoadData(NTTNum1,Num1,NTTLen,BitsPerGroup,NumLen,ModWords);
LoadData(NTTNum2,Num2,NTTLen,BitsPerGroup,NumLen,ModWords);
#ifdef USE_SCALING
Scale(NTTNum1,Theta,NTTLen,1,ModWords);
Scale(NTTNum2,Theta,NTTLen,1,ModWords);
#endif
SS_NTT(NTTNum1, NTTLen, 1, ModWords);
SS_NTT(NTTNum2, NTTLen, 1, ModWords);

/*
** ****************NOTICE********************
** This convolution should actually be done by recursively
** calling Schonhage-Strassen until you get to a size small
** enough to do efficiently.
**
** HOWEVER, my generic math package here can't handle that,
** so I'm having to do this slowly.  This kills the performance.
*/
for (x=0;x< NTTLen;x++)
  ModMul(&NTTNum1[x*ModWords],
         &NTTNum1[x*ModWords],
         &NTTNum2[x*ModWords],ModWords);

SS_NTT(NTTNum1, NTTLen, -1, ModWords);

#ifdef USE_SCALING
Scale(NTTNum1,Theta,NTTLen,-1,ModWords);
#endif

for (x = 0;x < NTTLen; x++)
   ModDiv2Exp(&NTTNum1[x*ModWords],&NTTNum1[x*ModWords],PrimeExp,Log2(NTTLen),ModWords);


/* Do CRT & release the carries */
{int PWords=ModWords+6; /* Pyramid words. */
 int y,PX;
 int ShortsPerMod=BitsPerGroup/Log2(BASE);
 UINT32 W;
 ModIntPtr RP,PyramidMod;
 ModIntPtr Pyramid,Sum,Temp;
 ModIntPtr PPrime;

Pyramid=ModMalloc(1,PWords);CB_Clear(Pyramid,PWords);
Sum=ModMalloc(1,PWords);    CB_Clear(Sum,PWords);
Temp=ModMalloc(1,PWords);   CB_Clear(Temp,PWords);
PPrime=ModMalloc(1,PWords); CB_Copy(PPrime,PWords,Prime,ModWords);

RP=ModMalloc(1,PWords);         /* RP=r*(2^2L) */
PyramidMod=ModMalloc(1,PWords); /* PyramidMod=2^k * (2^(2L)+1) */
CB_Copy(PyramidMod,PWords,PPrime,PWords);CB_ShiftLx(PyramidMod,Log2(K),PWords);
/* PyramidMod=Prime*NTTLen */

if (Log2(K) > 30) {printf("K (%d) too big for int.\n",K);exit(0);}

PX=NumLen*2;
for (x = 0; x < NTTLen; x++)
  {
/* Get the small digit */
   W=CB_Get32(&SProd[x*SModWords],SModWords);

/* The modmath stuff isn't good at handling small primes, so make sure */
/* W=W % NTTLen; */
if (W >= K) {printf("W too big. %u %d\n",(unsigned int)W,K);exit(0);}

/* Do the CRT */
   W=(W-CB_Get32(&NTTNum1[x*ModWords],ModWords)) % K;
   CB_Copy(Temp,PWords,PPrime,PWords);
   CB_MulUInt(Temp,W,PWords);

   CB_Copy(Sum,PWords,&NTTNum1[x*ModWords],ModWords);
   CB_Add(Sum,Sum,Temp,PWords);

/* Normalize it. It might be 'negative'.  RP could be accumulated. */
   CB_Set(RP,1,PWords);CB_ShiftLx(RP,PBits,PWords);CB_MulUInt(RP,x+1,PWords);
   if (CB_Cmp(Sum,RP,PWords) >= 0) CB_Sub(Sum,Sum,PyramidMod,PWords);

   CB_Add(Pyramid,Pyramid,Sum,PWords);

/* Release our carries */
   for (y=0;y