/*
     mont62.c
*/
/*
** This file is placed into the public domain by its author,
** Carey Bloodworth (Carey@Bloodworth.org) on July 16, 2001
**
** This multiplication demo is not designed for high performance.
** It's a tutorial program designed to be used with the information
** on my web site at www.Bloodworth.org
*/
/*
** This file demonstrates a very basic NTT multiply.  It uses a single
** 62 bit prime.
**
** Multiplications are done using Montgomery modular multiplication.
**
** To compile this using GCC:
** gcc main.c mont62.c -o mont62.exe
*/
#include 
#include 
#include 
#include 
#include 

/*
** Number of ModInt elements needed to transform a number of _NumLen
** 'digits':  NumLen*BaseDig*ZeroPadding/Dig_Per_FFT, with a zero padding
** factor of 2.  BASE_DIG must be non-zero when this is evaluated.
*/
#define CalcNTTLen(_NumLen) ((((_NumLen)*BASE_DIG)*2)/BASE_DIG)


#define USE_MONTGOMERY 1

typedef short int Short;
/*
** The 32 bit types are built on 'int', not 'long'.  'long' is 64 bits
** wide on LP64 systems (64 bit Linux, BSD, macOS), which silently turned
** MONTLOW() into a no-op and broke the Montgomery multiply.
*/
typedef signed int     INT32; /* 32/31 bit signed int */
typedef unsigned int  UINT32; /* 32 bit unsigned int */
typedef unsigned long long UINT64;
typedef UINT64 ModInt;

/* Work buffers, allocated by InitFFT() and released by DeInitFFT(). */
static ModInt *NTTNum1=NULL, *NTTNum2=NULL;
static int BASE;
static int BASE_DIG;

/* Modulus and derived constants.  InitFFT() sets everything except MulInv. */
ModInt Prime,PrimvRoot,MulInv;
ModInt FromMontC,ToMontC;
UINT32 MontMulC; /* NOTE: This is our word size, not a ModInt */

/*ModInt MontToMontC,MontFromMontC;*/

/*
** Word helpers for the Montgomery code.
** MONTLOW  - low 32 bits of z.  The explicit mask keeps the result
**            correct even if UINT32 is wider than 32 bits.
** MONTHIGH - z shifted right by 32 bits (the high word of a 64 bit value).
** MUL64    - full 64 bit product of two values converted to UINT64.
*/
#define MONTLOW(z)  ((UINT32)(((UINT64)(z)) & 0xFFFFFFFFULL))
#define MONTHIGH(z) (((UINT64)(z))>>32)
#define MUL64(z,q) (((UINT64)(z))*((UINT64)(q)))

ModInt
ModMul(ModInt a, ModInt b)
/*
** Return (a*b) mod Prime.
**
** This can be *extremely* simple because it'll be rarely used:
** it is just a plain 'bit by bit' (shift and add) multiply that walks
** the 64 bits of 'a' from the most significant bit down.
**
** Both operands are reduced mod Prime on entry, so callers may pass any
** 64 bit value.  Without that reduction an operand >= Prime could leave
** the running product unreduced and produce a wrong answer.
** The global Prime must be non-zero.
*/
{
  int bit;
  ModInt Prod=0;

  a%=Prime;
  b%=Prime;

  for (bit=0;bit<64;bit++)
    {ModInt Carry;

      /*
      ** Prod = 2*Prod mod Prime.  Prod < Prime on entry, so one
      ** subtraction is enough.  If the doubling carried out of 64 bits
      ** the wrapped subtraction still yields the exact value 2*Prod-Prime.
      */
      Carry=Prod>>63;
      Prod<<=1;
      if (Carry || (Prod >= Prime)) Prod-=Prime;

      /* Add in b when the current (top) bit of a is set. */
      if (a & 0x8000000000000000ULL)
        {ModInt T;
         T=Prod+b;
         /* T < Prod means the addition wrapped; see the note above. */
         if ((T < Prod) || (T >= Prime)) T-=Prime;
         Prod=T;
        }
      a<<=1;
    }
 return Prod;
}

ModInt ModAdd(ModInt a,ModInt b)
/*
** Return a+b reduced by Prime.  At most two subtractions of Prime are
** applied: one if the unsigned 64 bit addition wrapped around, and one
** if the value is then still >= Prime.
*/
{
  const ModInt Raw=a+b;                              /* may wrap mod 2^64 */
  const ModInt Unwrapped=(Raw < a) ? Raw-Prime : Raw; /* Raw < a => wrapped */

  return (Unwrapped >= Prime) ? Unwrapped-Prime : Unwrapped;
}

ModInt ModSub(ModInt a,ModInt b)
/*
** Return a-b, adding Prime back once when the unsigned subtraction
** borrowed (which happens exactly when a < b).
*/
{
  const ModInt Delta=a-b; /* wraps mod 2^64 when a < b */

  return (a < b) ? Delta+Prime : Delta;
}

ModInt ModPow(ModInt base,ModInt expon)
/*
** Return base^expon using ModMul (binary square-and-multiply, scanning
** the exponent from its least significant bit).  expon == 0 yields 1.
*/
{
  ModInt Square,Result;

  if (!expon) return 1;

  /* Skip the trailing zero bits: each one is just a squaring. */
  Square=base;
  for (;(expon & 1) == 0;expon>>=1)
    Square=ModMul(Square,Square);

  /* The lowest set bit seeds the result ... */
  Result=Square;

  /* ... and every remaining set bit multiplies in the matching square. */
  for (expon>>=1;expon != 0;expon>>=1)
    {
      Square=ModMul(Square,Square);
      if (expon & 1) Result=ModMul(Result,Square);
    }
  return Result;
}

ModInt
FindInverse(ModInt Num)
/*
** Return Num^(Prime-2) mod Prime, which is the multiplicative inverse of
** Num when Prime is prime and Num is not a multiple of it.
** (Or we could use an extended gcd function.)
*/
{
  const ModInt InverseExpon=Prime-2;

  return ModPow(Num,InverseExpon);
}

/*
** Montgomery modular multiplication stuff.
*/

#ifdef USE_MONTGOMERY

/*
** Convert a plain residue into Montgomery form (multiply by ToMontC) and
** back again (multiply by FromMontC), using the slow ModMul.  InitFFT()
** sets ToMontC to 2^64 mod Prime and FromMontC to its modular inverse.
*/
#define ToMont(_x)   ModMul(ToMontC,_x)
#define FromMont(_x) ModMul(FromMontC,_x)
/*
** ToMont and FromMont are a little slow.  After all, they
** are using our dead dog slow ModMul routine.  If we wanted
** to do a little preparation, we could speed them up.
** We can convert those constants into Montgomery space
** and that way we can use our fast MontMul with them.

ModInt MontToMontC,MontFromMontC;

MontToMontC=ModMul(ToMontC,ToMontC);
MontFromMontC=ModMul(ToMontC,FromMontC);

** And then we can do:

#define ToMont(_x)   MontMul(MontToMontC,_x)
#define FromMont(_x) MontMul(MontFromMontC,_x)

** I'm not actually doing it this way because you need to see
** the regular way.
*/

#if 0
ModInt MontMul(ModInt Num1,ModInt Num2)
/*
** Pretty standard Montgomery Modular Multiplication.
**
** Reference version (compiled out by the enclosing #if 0).  The operands
** and Prime are split into low/high words with MONTLOW/MONTHIGH and the
** product is reduced one word at a time: P0 and P1 are picked, using
** MontMulC, so that adding P0*Prime and then P1*Prime clears the low word
** of the running sum x, which is then shifted down by 32 bits.
** The result gets one final conditional subtraction of Prime.
*/
{UINT64 x;
UINT32 A0,A1,B0,B1,P0,P1,D0,D1;

/* Split both operands and the modulus into low (0) and high (1) words. */
A0=MONTLOW(Num1); A1=MONTHIGH(Num1);
B0=MONTLOW(Num2); B1=MONTHIGH(Num2);
D0=MONTLOW(Prime);D1=MONTHIGH(Prime);

/* First word: low*low partial product, then cancel its low word. */
x=MUL64(A0,B0);
P0=MONTLOW((INT32)x) * MontMulC;P0=MONTLOW(P0);
x+=MUL64(P0,D0);
/* Sanity check: the low word must now be zero.  Note this exits with status 0. */
if (MONTLOW(x)) {printf("x1 is not zero.\n");exit(0);}
x=MONTHIGH(x);

/* Second word: add the cross products and P0*D1, then cancel again. */
x=x+MUL64(P0,D1)+MUL64(A0,B1)+MUL64(A1,B0);
P1=MONTLOW((INT32)x) * MontMulC;P1=MONTLOW(P1);
x+=MUL64(P1,D0);
if (MONTLOW(x)) {printf("x2 is not zero.\n");exit(0);}
x=MONTHIGH(x);

/* Remaining high parts. */
x+=MUL64(P1,D1)+MUL64(A1,B1);

/* Single conditional subtraction of the modulus. */
if (x >= ((UINT64)Prime)) x-=Prime;
return x;
}
#endif

#if 1
ModInt MontMul(ModInt Num1,ModInt Num2)
/*
** More advanced MontMul requiring only 6 muls: the four partial products
** of the operand words plus two products with the high word of Prime.
**
** The reduction factor is taken as the plain negation of the low word of
** the running sum, and only that factor itself (not factor*low word of
** Prime) is added to clear the low word.  That shortcut is only valid
** when the low 32 bits of Prime are 1, as they are for the prime that
** InitFFT() installs.
*/
{UINT32 Lo1,Hi1,Lo2,Hi2,PrimeHi,Fix;
 UINT64 LoLo,HiLo,LoHi,HiHi,Acc;

/* Operand words and the high word of the modulus. */
Lo1=MONTLOW(Num1);
Hi1=MONTHIGH(Num1);
Lo2=MONTLOW(Num2);
Hi2=MONTHIGH(Num2);
PrimeHi=MONTHIGH(Prime);

/* The four partial products (named <word of Num1><word of Num2>). */
LoLo=MUL64(Lo1,Lo2);
HiLo=MUL64(Hi1,Lo2);
LoHi=MUL64(Lo1,Hi2);
HiHi=MUL64(Hi1,Hi2);

/* First word: clear the low word of LoLo, shift down, add the middle terms. */
Fix=MONTLOW(-((UINT32)MONTLOW(LoLo)));
Acc=MONTHIGH(LoLo+Fix);
Acc+=MUL64(Fix,PrimeHi);
Acc+=LoHi;
Acc+=HiLo;

/* Second word: same again on the running sum, then add the top term. */
Fix=MONTLOW(-((UINT32)MONTLOW(Acc)));
Acc=MONTHIGH(Acc+Fix);
Acc+=MUL64(Fix,PrimeHi);
Acc+=HiHi;

/* Single conditional subtraction of the modulus. */
return (Acc >= ((UINT64)Prime)) ? Acc-Prime : Acc;
}
#endif


ModInt MontPow(ModInt base,ModInt expon)
/*
** Raise 'base' to the power 'expon' using MontMul (binary
** square-and-multiply, scanning the exponent from its least significant
** bit).  'base' must already be in Montgomery form and the result is
** returned in Montgomery form.
**
** expon == 0 returns 1 in Montgomery form, matching ModPow().  Without
** that guard the first loop below would never terminate, because a zero
** exponent never gets its low bit set.
*/
{ModInt prod,b;

if (expon==0) return ToMont((ModInt)1);

/* Skip the trailing zero bits: each one is just a squaring. */
b=base;
while (!(expon&1))
  {
   b=MontMul(b,b);
   expon>>=1;
  }
prod=b;

/* Every remaining set bit multiplies in the matching square. */
while (expon>>=1)
  {
   b=MontMul(b,b);
   if (expon&1) prod=MontMul(prod,b);
  }
return prod;
}

#else

/*
** Montgomery support disabled: values stay in plain form, so the
** conversions are no-ops and the Mont* operations map straight onto the
** ordinary modular routines.
*/
#define ToMont(_x)         (_x)
#define FromMont(_x)       (_x)
#define MontMul(q1,q2)     ModMul((q1),(q2))
#define MontPow(_a,_b)     ModPow(_a,_b)

#endif



static void
NTTReorder(ModInt *Data, int Len)
/*
** Permute Data[0..Len-1] in place into bit-reversed index order.
** 'Rev' is maintained as a bit-reversed counter that tracks 'Fwd'; each
** pair is swapped once, when Rev is the larger index.
** Presumably Len is a power of two -- confirm against the callers.
*/
{
  int Fwd;
  int Rev=0;

  for (Fwd=0;Fwd < Len;Fwd++)
    {
      int Bit;

      if (Fwd < Rev)
        {
          const ModInt Saved=Data[Fwd];
          Data[Fwd]=Data[Rev];
          Data[Rev]=Saved;
        }

      /*
      ** Add one to the reversed counter: clear set bits from the top
      ** down, then set the first clear one.
      */
      for (Bit=Len/2;(Bit >= 1) && (Rev >= Bit);Bit/=2)
        Rev-=Bit;
      Rev+=Bit;
    }
}

void NTT(ModInt *Data, int Len, int Dir)
/*
** A simple minded, generic transform.
**
** Data is first permuted with NTTReorder().  Then 'step' is doubled
** (2, 4, ...) for as long as the previous step is < Len, and for each pass
** a root w is computed in Montgomery form as
**     PrimvRoot^((Prime-1)/step)            when Dir <= 0
**     PrimvRoot^(Prime-1-((Prime-1)/step))  when Dir >  0
**
** NOTE(review): this function is truncated in this copy of the file; see
** the note at the damaged 'for' line below.
*/
{int j,step,halfstep;
 int index,index2;
 ModInt u,w,temp;

NTTReorder(Data,Len);

step=1;
while (step < Len)
  {
   halfstep=step;
   step*=2;

   u=1;
   u=ToMont(u); /* 1 in Montgomery form: ModMul(ToMontC,1), i.e. just ToMontC */
/*
** We could do this...
   if (Dir > 0) w=ModPow(PrimvRoot,Prime-1-((Prime-1)/step));
   else         w=ModPow(PrimvRoot,(Prime-1)/step);
   w=ToMont(w);

** BUT plain ModMul is *VERY* slow.  That's why we are doing
** Montgomery multiplication!  So we have got to avoid it.
*/
   if (Dir > 0) w=MontPow(ToMont(PrimvRoot),Prime-1-((Prime-1)/step));
   else         w=MontPow(ToMont(PrimvRoot),(Prime-1)/step);

/*
** NOTE(review): the next line is damaged.  Everything between "j" and
** "0; x--)" is missing: the rest of this loop, the end of NTT(), and the
** start of a later function.  The statements that follow use Pyramid,
** Carry, Len2, Prod and x, none of which are declared in NTT(), so they
** belong to that later function (a carry release loop over NTTNum1 in
** base BASE).  Restore the missing text from the original source.
*/
   for (j=0;j 0; x--)
    {
      Pyramid = NTTNum1[Len2 - x] + Carry;
      Carry       = Pyramid / BASE;
      Prod[x - 1] = Pyramid % BASE;
    }
}

void
InitFFT(unsigned long int Len,int Base,int BaseDig)
/*
** Set up the module for numbers of 'Len' digits in base 'Base' with
** 'BaseDig' decimal digits per element:
**   - records Base and BaseDig in BASE / BASE_DIG,
**   - allocates the two work buffers NTTNum1 and NTTNum2, each holding
**     CalcNTTLen(Len) ModInt elements (release them with DeInitFFT()),
**   - installs the prime, its primitive root and the Montgomery constants.
**
** BaseDig must be positive because CalcNTTLen() divides by it.  On a bad
** BaseDig or an allocation failure a message is printed and the program
** exits with EXIT_FAILURE.
*/
{unsigned long Count;
 size_t Bytes;

if (BaseDig <= 0)
   {
    printf("InitFFT: BaseDig must be positive.\n");
    exit(EXIT_FAILURE);
   }

BASE=Base;
BASE_DIG=BaseDig;

Count=CalcNTTLen(Len);

/*
** Work out the byte count in size_t and refuse sizes that do not fit.
** An 'int' byte count overflows for large Len.  If the size does not
** fit, Bytes stays saturated and both buffers stay NULL.
*/
NTTNum1=NULL;
NTTNum2=NULL;
Bytes=(size_t)-1;
if (Count <= ((size_t)-1)/sizeof(ModInt))
   {
    Bytes=(size_t)Count*sizeof(ModInt);
    NTTNum1=(ModInt*)malloc(Bytes);
    NTTNum2=(ModInt*)malloc(Bytes);
   }

if ((NTTNum1==NULL) || (NTTNum2==NULL))
   {
    printf("Unable to allocate memory for NTTNum.\n");
    printf("Len=%lu Bytes=%lu\n",Len,(unsigned long)Bytes);
    exit(EXIT_FAILURE);
   }

  Prime=0x3fdc000000000001ULL;  /* 4087*2^50+1 Bits=61.99682653*/
  PrimvRoot=3;

  /* 0-Prime wraps to 2^64-Prime, so this is 2^64 mod Prime. */
  ToMontC=(0-Prime) % Prime;
  FromMontC=FindInverse(ToMontC) % Prime;
  /* Low word of Prime-2.  The low word of Prime is 1, so this is 0xFFFFFFFF. */
  MontMulC=MONTLOW(Prime-2);

/*
  MontToMontC=ModMul(ToMontC,ToMontC);
  MontFromMontC=ModMul(ToMontC,FromMontC);
*/
}

void DeInitFFT(unsigned long int Len)
/*
** Release the work buffers allocated by InitFFT().  The pointers are
** reset to NULL so that a repeated call is harmless and nothing is left
** pointing at freed memory.  'Len' is unused.
*/
{
(void)Len;

free(NTTNum1);
NTTNum1=NULL;
free(NTTNum2);
NTTNum2=NULL;
}