/*
     MPNTT31.C
*/
/*
** This file is placed into the public domain by its author,
** Carey Bloodworth (Carey@Bloodworth.org) on July 16, 2001
**
** This multiplication demo is not designed for high performance.
** It's a tutorial program designed to be used with the information
** on my web site at www.Bloodworth.org
*/
/*
** This file demonstrates a very basic multi-prime NTT multiply.
** It uses up to four 31 bit primes.  It puts 16 decimals per
** NTT element and can theoretically multiply numbers up to
** 16 million decimals.  (The program certainly isn't optimized
** for tests that large, though.)
**
** The example program (main.c) might have 1, 2, or 4 decimals
** per Short, so I have to be flexible.  We always put 16 decimals
** into each multi-prime NTT element, though.
**
** The program uses very generic code, so it should work with
** every compiler.
**
** To compile this using GCC:
** gcc main.c mpntt31.c -o mpntt31.exe
*/
#include 
#include 
#include 
#include 
#include 


#define KNUTH_CRT 1
/* Selects the CRT reconstruction style (and the Inverses[] layout below). */
/* Knuth style CRT requires primes to be in ascending order. */
/* Non-Knuth can do primes in any order. */
/* Comment out the #define above to use the non-Knuth style. */

/*
** Integer type names used throughout this file.
** NOTE(review): 'long' is only guaranteed to be *at least* 32 bits, so on
** LP64 platforms (64-bit Linux/macOS) UINT32/INT32 -- and therefore
** ModInt -- are really 64 bits wide despite their names.  ModMul's
** remainder trick appears to tolerate this; confirm before relying on
** the exact width anywhere else.
*/
typedef unsigned short     int UINT16;
typedef unsigned long      int UINT32;
typedef   signed short     int  INT16;
typedef   signed long      int  INT32;
typedef UINT32 ModInt;      /* A residue modulo one of the NTT primes. */
typedef short int Short;    /* One limb of the caller's number; presumably
                               holds BASE_DIG decimals (see file header). */

/*
** Radix and decimals-per-Short of the caller's representation.  These are
** assigned at run time by code not visible here.  BASE_DIG must be nonzero
** and must divide DigitsPerMod, because ShortsPerMod divides by it.
*/
static int BASE;
static int BASE_DIG;

/*
** These are fixed *by design* of the multi-prime NTT: every NTT element
** carries DigitsPerMod decimals, and NPrimes primes are used.
** DigitsPerPrime is simply DigitsPerMod/NPrimes (= 4); its uses are not
** visible in this part of the file.
*/
#define DigitsPerMod  16
#define NPrimes        4
#define DigitsPerPrime (DigitsPerMod/NPrimes)

#define ShortsPerMod   (DigitsPerMod/BASE_DIG)
/*
** ShortsPerMod is the number of caller Shorts packed into one NTT element.
** BASE_DIG must divide DigitsPerMod (16) exactly.  For example, 3 digits
** per Short is not allowed, because the integer division would silently
** truncate.  A division inside a macro is not ideal, and a real program
** would hard-wire the value.  This example program, however, may use 1, 2,
** or 4 decimals per Short, and has to allow for all of them.
*/

#include "crt.h"

/*
** CalcNTTLen(Len_): number of NTT elements for a product whose inputs are
** presumably Len_ Shorts each.  The product has 2*Len_ Shorts, and each
** NTT element packs ShortsPerMod of them.
** The argument is fully parenthesized, so expressions such as
** CalcNTTLen(a+b) expand correctly.  The old body expanded that to
** (a+b*2)/...  The parameter is also no longer a reserved
** (_Uppercase) identifier.
** Equivalent formulation: ((Len_)*2*BASE_DIG)/DigitsPerMod.
*/
#define CalcNTTLen(Len_) (((Len_)*2)/ShortsPerMod)


/*
** Per-prime constants, one entry per NTT prime.  RPrime is used by ModMul
** as the reciprocal of Prime (it estimates a*b/Prime as a*b*RPrime).
** NOTE(review): MulInv is presumably the inverse of the transform length
** for the inverse NTT; its use is not visible here.
*/
typedef struct {ModInt Prime,PrimvRoot,MulInv;double RPrime;} ConstList;
ConstList Consts[NPrimes];

/*
** Constants of the *currently active* prime.  ModMul, ModAdd, ModSub,
** ModPow, FindInverse and NTT all read these globals, so they must be
** loaded before any modular arithmetic is done for a given prime.
*/
ModInt Prime,PrimvRoot,MulInv;
double RPrime;

/* Work buffers for the two operands; malloc'd by the init code. */
static ModInt *NTTNum1=NULL, *NTTNum2=NULL;

/* CRT inverse table.  Its shape depends on the CRT style chosen above. */
#ifdef KNUTH_CRT
ModInt Inverses[NPrimes][NPrimes];
#else
ModInt Inverses[NPrimes];
#endif
/* CRTNum and CRT_LEN presumably come from crt.h. */
CRTNum PProds[NPrimes+1][CRT_LEN];

/*
** The NTT primes.  They must stay in ascending order while KNUTH_CRT is
** defined.  Each has the form k*2^m+1:
**   27*2^26+1, 15*2^27+1, 63*2^25+1, 127*2^24+1.
** NTT takes (Prime-1)/step, so step must divide Prime-1.  That caps the
** transform length at 2^24 for this set.
** NOTE(review): with 16 decimals per element, each convolution term is
** below (10^16)^2 ~ 2^106.3.  The ~123.55-bit CRT range therefore leaves
** only about 2^17 terms of headroom.  Confirm this against the header's
** "16 million decimals" claim.
*/
ModInt PrimeList[NPrimes]=
  {
   1811939329, /* Bits= 30.73 */
   2013265921, /* Bits= 30.89 */
   2113929217, /* Bits= 30.96 */
   2130706433  /* Bits= 30.97 */
  };     /* Total Bits=123.55 */

/*
** The primitive roots (generators) for the primes above, in the same order.
** NTT derives each step-th root of unity as PrimvRoot^((Prime-1)/step).
*/
ModInt PrimvRootList[NPrimes]={13,31,5,3};

/*
** Reference table only; it is compiled out.  It lists other NTT-friendly
** primes together with a primitive root (pr=) for each.  To use one,
** change PrimeList and PrimvRootList together, and keep PrimeList
** ascending if KNUTH_CRT is defined.  The 32-bit entries cannot be used
** with the current ModMul, which is limited to 31 bits.
*/
#if 0
/* For 31 & 32 bit primes, you have a practical choice of:  */
   1811939329,   /* Bits= 30.73 pr=13*/
   2013265921,   /* Bits= 30.89 pr=31*/
   2113929217,   /* Bits= 30.96 pr=5 */
   2130706433,   /* Bits= 30.97 pr=3 */

/* Warning: 32 bit primes.  ModMul limited to 31 */
   2885681153UL, /* Bits= 31.42 pr=3 */
   3942645761UL, /* Bits= 31.87 pr=3 */
   4076863489UL, /* Bits= 31.92 pr=7 */
   4106223617UL, /* Bits= 31.93 pr=3 */
   4194304001UL, /* Bits= 31.96 pr=3 */
   4253024257UL, /* Bits= 31.98 pr=5 */
#endif


ModInt
ModMul(ModInt a, ModInt b)
/*
** Returns (a*b) mod Prime, for a and b already reduced below Prime.
** Reads the globals Prime and RPrime; RPrime is expected to be ~1.0/Prime.
**
** Method: a double computes q = round(a*b*RPrime), an estimate of the
** quotient.  The integer expression a*b - Prime*q is then formed.  Any
** wraparound in the two integer products cancels, because the true
** remainder is small (about within +/- Prime/2), so it fits in a signed
** INT32.  A negative remainder is finally folded back into [0, Prime).
**
** Limited to 31 bits: Prime must be below 2^31 so that the signed
** remainder fits.  Storing the unsigned result into INT32 relies on
** two's-complement wraparound.  ISO C makes that conversion
** implementation-defined, but it is universal in practice.
** floor() requires <math.h>.
*/
{INT32 rem;
rem = a * b;
rem = rem - Prime * ((ModInt) floor(0.5+RPrime * ((double) a) * ((double) b)));
if (rem < 0) rem +=Prime;
return rem;
}

ModInt ModAdd(ModInt a, ModInt b)
/*
** Returns (a+b) mod Prime, for a and b already reduced below Prime.
** The sum is formed in double so it cannot overflow ModInt, even for primes
** close to 2^32.  Every value involved is represented exactly in a double.
*/
{
double sum = (double) a + (double) b;

return (ModInt) ((sum >= Prime) ? sum - Prime : sum);
}

ModInt ModSub(ModInt a, ModInt b)
/*
** Returns (a-b) mod Prime, for a and b already reduced below Prime.
** The difference is formed in double, so it may go negative without
** unsigned wraparound before it is corrected by adding Prime.
*/
{
double diff = (double) a - (double) b;

return (ModInt) ((diff < 0) ? diff + Prime : diff);
}


ModInt
ModPow(ModInt Base,UINT32 Expon)
/*
** Returns Base^Expon mod Prime, with every product done by ModMul.
** Base^0 is 1.  This is right-to-left binary exponentiation.
*/
{ModInt result, sq;

if (Expon == 0)
   return 1;

/* Square past the low zero bits first.  The accumulator can then start
   at a power of Base instead of 1, which saves one multiply. */
for (sq = Base; (Expon & 1) == 0; Expon >>= 1)
   sq = ModMul(sq, sq);

result = sq;
for (Expon >>= 1; Expon != 0; Expon >>= 1)
  {
   sq = ModMul(sq, sq);
   if (Expon & 1)
      result = ModMul(result, sq);
  }
return result;
}

ModInt
FindInverse(ModInt Num)
/*
** Modular inverse by Fermat's little theorem.  When Prime is prime and
** Num is not divisible by it, Num^(Prime-2) * Num == 1 (mod Prime).
** Num == 0 has no inverse; the result is then 0^(Prime-2) = 0.
*/
{
UINT32 FermatExp = Prime - 2;

return ModPow(Num, FermatExp);
}


static void
NTTReorder(ModInt *Data, int Len)
/*
** Standard FFT/NTT data scrambling: permutes Data[0..Len-1] into
** bit-reversed index order, ready for the in-place iterative transform.
** Len is expected to be a power of two; this is not checked here.
*/
{int Pos, Rev, Bit;

Rev = 0;
for (Pos = 0; Pos < Len; Pos++)
  {
   /* Swap each pair only once: when the reversed index is the larger one. */
   if (Rev > Pos)
     {ModInt Hold = Data[Pos];
      Data[Pos] = Data[Rev];
      Data[Rev] = Hold;
     }

   /* Advance Rev as a bit-reversed counter.  Clear the leading 1 bits
      from the top down, then set the first 0 bit found. */
   for (Bit = Len / 2; Bit >= 1 && Bit <= Rev; Bit /= 2)
      Rev -= Bit;
   Rev += Bit;
  }
}

void NTT(ModInt *Data, int Len, int Dir)
/* A simple minded, generic transform */
{int j,step,halfstep;
 int index,index2;
 ModInt u,w,temp;

NTTReorder(Data,Len);

step=1;
while (step < Len)
  {
   halfstep=step;
   step*=2;

   u=1;
   if (Dir > 0) w=ModPow(PrimvRoot,Prime-1-((Prime-1)/step));
   else         w=ModPow(PrimvRoot,(Prime-1)/step);

   for (j=0;j 1) || (BASE > 10))
  {
   printf("Error:  The ntt is hardwired for just 1 digit in the base.\n");
   printf("In 'main.c' please set BASE to 10 and BASE_DIG to 1\n");
   exit(0);
  }
#endif

NTTNum1=(ModInt*)malloc(Bytes);
NTTNum2=(ModInt*)malloc(Bytes);

if ((NTTNum1==NULL) || (NTTNum2==NULL))
   {
    printf("Unable to allocate memory for NTTNum.\n");
    printf("Len=%d Bytes=%d\n",(int)Len,(int)Bytes);
    exit(0);
   }

/* Setup our primes */
  for (x=0;x