/*
     NTT64.C
*/
/*
** This file is placed into the public domain by its author,
** Carey Bloodworth (Carey@Bloodworth.org) on July 16, 2001
**
** This multiplication demo is not designed for high performance.
** It's a tutorial program designed to be used with the information
** on my web site at www.Bloodworth.org
*/
/*
** This file demonstrates a very basic NTT based multiply.
**
** To compile this using GCC:
** gcc main.c ntt64.c -o ntt64.exe
*/
#include 
#include 
#include 
#include 
#include 

/*
** Number of NTT elements needed for NumLen-element operands:
**   NumLen*BaseDig*ZeroPadding/Dig_Per_FFT
** with a zero-padding factor of 2 and BASE_DIG digits per NTT element.
** It is written out in full to show where the terms come from; for
** non-zero BASE_DIG it reduces to 2*NumLen.  BASE_DIG must be non-zero
** because it is the divisor.
*/
#define CalcNTTLen(NumLen) ((((NumLen)*BASE_DIG)*2)/BASE_DIG)

typedef short int Short;
typedef unsigned long long UINT64; /* 64 bit unsigned int */
typedef UINT64 ModInt;             /* a residue modulo Prime */
/*
** Not all compilers have a 64-bit integer data type, though most do.
** GNU C calls it "long long"; other compilers may name it differently.
**
** The C99 standard requires "long long" to be at least 64 bits, so every
** C99 compiler provides one.  Older compilers, such as Borland C 5.5,
** might not.
*/

/* Work buffers for the two operands, allocated by InitFFT and
** released by DeInitFFT. */
static ModInt *NTTNum1=NULL;
static ModInt *NTTNum2=NULL;
static int BASE;      /* number base of the operand digits, set by InitFFT */
static int BASE_DIG;  /* digits in the base (BaseDig), set by InitFFT */

ModInt Prime;     /* NTT modulus, set by InitFFT */
ModInt PrimvRoot; /* primitive root used to build roots of unity, set by InitFFT */
ModInt MulInv;    /* multiplier applied to inverse-transform output;
                  ** presumably the inverse of the transform length -- TODO confirm */

ModInt ModMulP(ModInt a,ModInt b,ModInt Prime)
{
 UINT64 al, ah, bl, bh, rl, rh, c;

    /* double-width multiplication */

    al = a & 0xFFFFFFFF;
    ah = a >> 32;

    bl = b & 0xFFFFFFFF;
    bh = b >> 32;

    rl = al * bl;

    c = al * bh;
    rh = c >> 32;
    c <<= 32;
    rl += c;
    if (rl < c) rh++;

    c = ah * bl;
    rh += c >> 32;
    c <<= 32;
    rl += c;
    if (rl < c) rh++;

    rh += ah * bh;

    /* modulo reduction */

/*  if (Prime == 18446744069414584321ULL) */
  if (Prime == 0xffffffff00000001ULL)
    {
        /* modulus == 2^64-2^32+1 */
        UINT64 t;

        /* 1st shift */
        t = rh;
        c = rh << 32;
        rh >>= 32;
        t = rl - t;
        if (t > rl) rh--;

        rl = t + c;
        if (rl < t) rh++;

        /* 2nd shift */
        t = rh;
        c = rh << 32;
        rh >>= 32;
        t = rl - t;
        if (t > rl) rh--;

        rl = t + c;
        if (rl < t) rh++;

        /* Final check */
        return (rh || rl >= Prime ? rl - Prime : rl);
    }
   else
    {
     printf("Unknown modulus in ModMulP.\n");
     exit(1);
    }
}

/* ModMul: (a*b) mod the module-wide Prime chosen by InitFFT. */
ModInt ModMul(ModInt a,ModInt b)
{
  ModInt Product = ModMulP(a, b, Prime);

  return Product;
}

/*
** ModAdd: returns (a+b) mod Prime.
** When a and b are already reduced (< Prime), the result is fully
** reduced to [0, Prime), so repeated additions cannot drift out of range.
*/
ModInt ModAdd(ModInt a,ModInt b)
{ModInt Sum;
Sum=a+b;
/*
** Sum < a means the 64-bit addition wrapped, so the true sum is Sum+2^64.
** Subtracting Prime modulo 2^64 then gives exactly a+b-Prime.
** Without a wrap the sum can still be >= Prime and needs the same
** single subtraction to stay canonical.
*/
if ((Sum < a) || (Sum >= Prime)) Sum-=Prime;
return Sum;
}

/*
** ModSub: returns (a-b) mod Prime; intended for a, b already in [0, Prime).
*/
ModInt ModSub(ModInt a,ModInt b)
{
  /* b > a exactly when the unsigned difference borrows; add Prime back. */
  if (b > a)
    return (a - b) + Prime;

  return a - b;
}


/*
** ModPowP: returns Base raised to the power Expon, modulo Prime, using
** square-and-multiply while scanning Expon from its least significant
** bit.  Expon == 0 yields 1.
*/
ModInt
ModPowP(ModInt Base,ModInt Expon,ModInt Prime)
{
 ModInt Square = Base;
 ModInt Result;

 /* ModInt is unsigned, so "<= 0" can only mean zero. */
 if (Expon == 0)
   return 1;

 /* Square through the trailing zero bits; the lowest set bit seeds Result. */
 for (; (Expon & 1) == 0; Expon >>= 1)
   Square = ModMulP(Square, Square, Prime);
 Result = Square;

 /* Remaining bits: square every step, multiply in where the bit is set. */
 for (Expon >>= 1; Expon != 0; Expon >>= 1)
   {
    Square = ModMulP(Square, Square, Prime);
    if (Expon & 1)
      Result = ModMulP(Result, Square, Prime);
   }
 return Result;
}

/* ModPow: Base^Expon modulo the module-wide Prime chosen by InitFFT. */
ModInt
ModPow(ModInt Base,ModInt Expon)
{
  ModInt Modulus = Prime;

  return ModPowP(Base, Expon, Modulus);
}

/*
** FindInverse: returns the multiplicative inverse of Num modulo the
** prime Modulus, using Fermat's little theorem:
**   Num^(Modulus-2) == Num^-1 (mod Modulus).
** The result is verified with a full-width modular multiply.  If
** Num*i != 1 (mod Modulus), for example when Num == 0 mod Modulus and
** no inverse exists, an error is printed and the program exits.
*/
ModInt
FindInverse(ModInt Num, ModInt Modulus)
{ModInt i;
i=ModPowP(Num,Modulus-2,Modulus);
/* ModMulP forms the 128-bit product, so this check cannot overflow. */
if (ModMulP(Num,i,Modulus) != 1)
  {
   fprintf(stderr,"Unable to find Mul inverse for %llu mod %llu\n",
           (unsigned long long)Num,(unsigned long long)Modulus);
   exit(1);
  }
return i;
}

/*
** NTTReorder: in-place bit-reversal permutation of Data[0..Len-1].
** 'rev' is a counter that is incremented with the carry travelling
** from the top bit downward.  Each pair is swapped only once, when
** rev > i.  Len is presumably a power of two -- TODO confirm at callers.
*/
static void
NTTReorder(ModInt *Data, int Len)
{
 int i, rev = 0, bit;

 for (i = 0; i < Len; i++)
   {
    if (rev > i)
      {
       ModInt Swap = Data[rev];

       Data[rev] = Data[i];
       Data[i] = Swap;
      }
    /* Bit-reversed increment: clear leading ones, then set the next bit. */
    for (bit = Len / 2; bit >= 1 && bit <= rev; bit /= 2)
      rev -= bit;
    rev += bit;
   }
}

/*
** NTT: in-place number theoretic transform of Data[0..Len-1] modulo
** Prime, with roots of unity taken as powers of PrimvRoot.
**   Data - Len residues, reordered and transformed in place
**   Len  - transform length (presumably a power of two, as the
**          bit-reversal in NTTReorder expects -- TODO confirm at callers)
**   Dir  - Dir > 0 uses PrimvRoot^(Prime-1-(Prime-1)/step) as each
**          pass's twiddle, otherwise PrimvRoot^((Prime-1)/step); the
**          former is the inverse of the latter when Prime is prime.
**
** NOTE(review): this copy of the function is corrupted.  The butterfly
** loop header "for (j=0;j" runs straight into "0; x--)", and the lines
** after it (Pyramid, Carry, Prod, Len2, MulInv, BASE) look like the tail
** of a separate result/carry routine, not part of the transform.
** Those names are not declared here, so this will not compile until the
** missing text is restored from the original source.
*/
void NTT(ModInt *Data, int Len, int Dir)
/* A simple minded, generic transform */
{int j,step,halfstep;
 int index,index2;
 ModInt u,w,temp;

/* Bit-reversal permutation of the input (see NTTReorder). */
NTTReorder(Data,Len);

/* Each pass doubles the butterfly span: step = 2, 4, 8, ... <= Len. */
step=1;
while (step < Len)
  {
   halfstep=step;
   step*=2;

   /* u: running twiddle power (starts at 1); w: this pass's root. */
   u=1;
   if (Dir > 0) w=ModPow(PrimvRoot,Prime-1-((Prime-1)/step));
   else         w=ModPow(PrimvRoot,(Prime-1)/step);

   /* NOTE(review): corrupted line -- the butterfly loop is missing here. */
   for (j=0;j 0; x--)
    {
      Pyramid = ModMul(NTTNum1[Len2 - x],MulInv) + Carry;
      Carry       = Pyramid / BASE;
      Prod[x - 1] = Pyramid % BASE;
    }
}

/*
** InitFFT: prepares the NTT module for operands of Len elements in
** base Base, with BaseDig digits in the base:
**   - records Base and BaseDig in BASE and BASE_DIG,
**   - allocates the work buffers NTTNum1 and NTTNum2, each holding
**     CalcNTTLen(Len) residues,
**   - selects the modulus 2^64-2^32+1 with primitive root 7.
** On any error (BaseDig out of range, size overflow, out of memory) it
** prints a message and terminates with a failure exit status.
*/
void
InitFFT(unsigned long int Len,int Base,int BaseDig)
{unsigned long NTTLen;
 size_t Bytes;

BASE=Base;
BASE_DIG=BaseDig;

/* Validate before CalcNTTLen, which divides by BASE_DIG. */
if (BaseDig > 4)
  {
   printf("Error:  The fft is slightly hardwired for <= 4 digits in the base.\n");
   exit(EXIT_FAILURE);
  }
if (BaseDig < 1)
  {
   printf("Error:  BaseDig must be at least 1.\n");
   exit(EXIT_FAILURE);
  }

/* Refuse lengths whose element count or byte count would overflow. */
if ((Len > ((unsigned long)-1)/(2UL*(unsigned long)BaseDig)) ||
    (CalcNTTLen(Len) > ((size_t)-1)/sizeof(ModInt)))
  {
   printf("Error:  Len=%lu is too large for NTTNum.\n",Len);
   exit(EXIT_FAILURE);
  }
NTTLen=CalcNTTLen(Len);
Bytes=(size_t)NTTLen*sizeof(ModInt);

NTTNum1=(ModInt*)malloc(Bytes);
NTTNum2=(ModInt*)malloc(Bytes);

if ((NTTNum1==NULL) || (NTTNum2==NULL))
   {
    printf("Unable to allocate memory for NTTNum.\n");
    printf("Len=%lu Bytes=%lu\n",Len,(unsigned long)Bytes);
    exit(EXIT_FAILURE);
   }
Prime    =0xffffffff00000001ULL;   /* 2^64-2^32+1 */
PrimvRoot=7;
}

/*
** DeInitFFT: releases the work buffers allocated by InitFFT.
** The pointers are reset to NULL, so calling this twice is harmless
** and the freed buffers cannot be reused by mistake.  Len is unused;
** it is kept for the existing interface.
*/
void DeInitFFT(unsigned long int Len)
{
(void)Len;
free(NTTNum1);NTTNum1=NULL;
free(NTTNum2);NTTNum2=NULL;
}