/*
mont31.c
*/
/*
** This file is placed into the public domain by its author,
** Carey Bloodworth (Carey@Bloodworth.org) on July 16, 2001
**
** This multiplication demo is not designed for high performance.
** It's a tutorial program designed to be used with the information
** on my web site at www.Bloodworth.org
*/
/*
** This file demonstrates a very basic NTT multiply. It uses a single
** 31 bit prime. That means you can only put a single decimal digit in
** each NTT element without the convolution sums overflowing quickly.
** It does, however, mean that everybody can experiment with this.
**
** Multiplications are done using Montgomery modular multiplication.
**
** To compile this using GCC:
** gcc main.c mont31.c -o mont31.exe
*/
#include
#include
#include
#include
#include
/*
** Number of NTT elements for an operand of NumLen words:
** NumLen*BaseDig digits, times 2 for zero padding (room for the full
** product), divided by the digits stored per NTT element (BASE_DIG).
*/
#define CalcNTTLen(NumLen) ((((NumLen)*BASE_DIG)*2)/BASE_DIG)
#define USE_MONTGOMERY 1
typedef short int Short;
/*
** INT32/UINT32 must be exactly 32 bits: ModAdd and ModMul detect carries
** through unsigned wraparound, and MontMul takes the low 32-bit word of a
** 64-bit product.  'long' is 64 bits on LP64 systems (e.g. Linux/macOS
** on x86-64), whereas 'int' is 32 bits on the usual ILP32, LP64 and
** LLP64 data models.
*/
typedef signed int INT32;    /* 32/31 bit signed int */
typedef unsigned int UINT32; /* 32 bit unsigned int */
typedef UINT32 ModInt;
#ifdef USE_MONTGOMERY
typedef unsigned long long UINT64; /* for MontMul */
#endif
static ModInt *NTTNum1=NULL, *NTTNum2=NULL; /* transform buffers, allocated by InitFFT */
static int BASE;     /* number base of a digit, set by InitFFT */
static int BASE_DIG; /* digits per element, set by InitFFT */
ModInt Prime,PrimvRoot,MulInv;
ModInt FromMontC,ToMontC;
UINT32 MontMulC; /* NOTE: This is our word size, not a ModInt */
/*ModInt MontToMontC,MontFromMontC;*/
ModInt
ModMul(ModInt a, ModInt b)
/*
** Returns a*b mod Prime with a plain shift-and-add ("bit by bit")
** multiply over the low 32 bits of a, most significant bit first.
** Deliberately simple and slow because it is rarely used; the fast
** multiply is MontMul.  b is expected to be already reduced below Prime.
*/
{
ModInt Acc = 0;
ModInt Mask;
for (Mask = 0x80000000; Mask != 0; Mask >>= 1)
  {
   /* Acc = 2*Acc mod Prime.  If the top bit is about to be shifted out,
   ** subtracting Prime (which also wraps) yields the true residue. */
   if (Acc & 0x80000000)
     Acc = (Acc << 1) - Prime;
   else
     Acc <<= 1;
   if (Acc >= Prime) Acc -= Prime;
   /* Acc += b mod Prime when this bit of a is set. */
   if (a & Mask)
     {
      ModInt Sum = Acc + b;
      if (Sum < Acc) Sum -= Prime; /* the sum wrapped past the word size */
      Acc = Sum;
     }
   if (Acc >= Prime) Acc -= Prime;
  }
return Acc;
}
ModInt ModAdd(ModInt a,ModInt b)
/*
** Returns (a + b) mod Prime for a, b already reduced below Prime.
** If the unsigned sum wraps past the word size, subtracting Prime
** (which wraps back) recovers the true residue.
*/
{
ModInt Total = a + b;
if (Total < a)           /* carry out of the word */
  Total -= Prime;
return (Total >= Prime) ? Total - Prime : Total;
}
ModInt ModSub(ModInt a,ModInt b)
/*
** Returns (a - b) mod Prime for a, b already reduced below Prime.
** When b > a the unsigned difference wraps; adding Prime (which wraps
** back) gives the correct residue a - b + Prime.
*/
{
if (a >= b)
  return a - b;
return a - b + Prime;
}
ModInt ModPow(ModInt base,ModInt expon)
/*
** Returns base^expon mod Prime on plain (non-Montgomery) residues,
** using binary exponentiation with ModMul.  base^0 is 1.
*/
{
ModInt Result, Sq;
if (expon == 0)
  return 1;
/* Trailing zero bits of expon only square the base. */
for (Sq = base; (expon & 1) == 0; expon >>= 1)
  Sq = ModMul(Sq, Sq);
/* The lowest set bit seeds the result; fold in every higher set bit. */
Result = Sq;
for (expon >>= 1; expon != 0; expon >>= 1)
  {
   Sq = ModMul(Sq, Sq);
   if (expon & 1)
     Result = ModMul(Result, Sq);
  }
return Result;
}
ModInt
FindInverse(ModInt Num)
/*
** Multiplicative inverse of Num modulo Prime via Fermat's little
** theorem: Num^(Prime-2) * Num == Num^(Prime-1) == 1 (mod Prime).
** Requires Prime to be prime and Num not a multiple of it (Num == 0
** yields 0, which is not an inverse).  An extended gcd would also work.
*/
{
ModInt Exponent = Prime - 2;
return ModPow(Num, Exponent);
}
/*
** Montgomery modular multiplication stuff.
*/
#ifdef USE_MONTGOMERY
/*
** Convert into and out of Montgomery form (x*R and x*R^-1 mod Prime,
** with R = 2^32) using the plain, slow ModMul and the precomputed
** constants ToMontC = R mod Prime and FromMontC = R^-1 mod Prime.
*/
#define ToMont(x) ModMul(ToMontC,(x))
#define FromMont(x) ModMul(FromMontC,(x))
/*
** ToMont and FromMont are a little slow. After all, they
** are using our dead dog slow ModMul routine. If we wanted
** to do a little preparation, we could speed them up.
** We can convert those constants into Montgomery space
** and that way we can use our fast MontMul with them.
ModInt MontToMontC,MontFromMontC;
MontToMontC=ModMul(ToMontC,ToMontC);
MontFromMontC=ModMul(ToMontC,FromMontC);
** And then we can do:
#define ToMont(_x) MontMul(MontToMontC,_x)
#define FromMont(_x) MontMul(MontFromMontC,_x)
** I'm not actually doing it this way because you need to see
** the regular way.
*/
ModInt MontMul(UINT32 Num1,UINT32 Num2)
/*
** Montgomery modular multiplication with radix R = 2^32:
** returns Num1*Num2*R^-1 mod Prime for Num1, Num2 < Prime.
** MontMulC = -Prime^-1 mod 2^32 makes Num1*Num2 + m*Prime an exact
** multiple of 2^32, so dividing by R is just taking the high word.
** Every 32-bit truncation is done with an explicit mask, so the routine
** stays correct even if UINT32 is wider than 32 bits (for example an
** 'unsigned long' on LP64 systems).
*/
{UINT64 Prod,Red,T;
 UINT32 m;
 Prod=((UINT64)Num1)*((UINT64)Num2);
 /* m = (low word of Prod) * MontMulC mod 2^32 */
 m=(UINT32)((Prod*MontMulC) & 0xFFFFFFFFUL);
 Red=((UINT64)m)*((UINT64)Prime);
 /* (Prod+Red)/2^32 computed word-wise: the two low words sum to either 0
 ** or exactly 2^32, and the latter contributes a carry of 1. */
 T=(Prod>>32)+(Red>>32)+(((Prod&0xFFFFFFFFUL)+(Red&0xFFFFFFFFUL))>>32);
 /* T < 2*Prime, so one conditional subtraction fully reduces it. */
 if (T >= Prime) T-=Prime;
 return (ModInt)T;
}
ModInt MontPow(ModInt base,UINT32 expon)
/*
** Returns base^expon, where base and the result are in Montgomery form
** and expon is an ordinary integer.  Same binary exponentiation as
** ModPow, but every product uses the fast MontMul.
*/
{ModInt prod,b;
 /*
 ** x^0 == 1.  Without this check the trailing-zero loop below would spin
 ** forever for expon == 0.  1 in Montgomery form is R mod Prime, which is
 ** ToMontC.
 */
 if (expon==0) return ToMontC;
 b=base;
 while (!(expon&1))
   {
    b=MontMul(b,b);
    expon>>=1;
   }
 prod=b;
 while (expon>>=1)
   {
    b=MontMul(b,b);
    if (expon&1) prod=MontMul(prod,b);
   }
 return prod;
}
#else
/*
** Montgomery arithmetic disabled: values stay as ordinary residues, so
** the conversions are identities and the Mont* operations fall back to
** the plain ModMul/ModPow routines.
*/
#define ToMont(x) (x)
#define FromMont(x) (x)
#define MontMul(x,y) ModMul((x),(y))
#define MontPow(x,e) ModPow((x),(e))
#endif
static void
NTTReorder(ModInt *Data, int Len)
/*
** Permutes Data[0..Len-1] into bit-reversed index order, the input order
** the iterative NTT expects.  Len is presumably a power of two.
*/
{
int Pos, Rev = 0, Bit;
for (Pos = 0; Pos < Len; Pos++)
  {
   /* Swap each pair only once, when the reversed index is the larger. */
   if (Rev > Pos)
     {
      ModInt Swap = Data[Rev];
      Data[Rev] = Data[Pos];
      Data[Pos] = Swap;
     }
   /* Advance Rev as a bit-reversed counter: clear the leading one bits
   ** from the top down, then set the first zero bit. */
   for (Bit = Len / 2; (Bit >= 1) && (Bit <= Rev); Bit /= 2)
     Rev -= Bit;
   Rev += Bit;
  }
}
void NTT(ModInt *Data, int Len, int Dir)
/*
** A simple minded, generic transform: an in-place radix-2 NTT over
** Data[0..Len-1] modulo Prime.  Data is first permuted into bit-reversed
** order by NTTReorder; each pass then doubles 'step' and uses the root
** w = PrimvRoot^(Prime-1-(Prime-1)/step) when Dir > 0, or
** PrimvRoot^((Prime-1)/step) otherwise, both in Montgomery form.
** Len is presumably a power of two (the bit reversal assumes it) -- confirm.
**
** NOTE(review): this copy of the file is corrupted inside NTT.  The line
** 'for (j=0;j 0; x--)' has lost text (apparently everything between a '<'
** and a '>', like the empty #include lines at the top), so the butterfly
** loop, the end of NTT and the start of the following routine are missing.
** The lines after it look like the carry-release tail of a multiply
** routine that turns NTTNum1 back into BASE digits in Prod[].  Restore
** this region from the original source before compiling.
*/
{int j,step,halfstep;
int index,index2;
ModInt u,w,temp;
NTTReorder(Data,Len);
step=1;
while (step < Len)
{
halfstep=step;
step*=2;
u=1;
u=ToMont(u); /* 1 in Montgomery form, i.e. ToMontC (plain 1 without USE_MONTGOMERY) */
/*
** We could do this...
if (Dir > 0) w=ModPow(PrimvRoot,Prime-1-((Prime-1)/step));
else w=ModPow(PrimvRoot,(Prime-1)/step);
w=ToMont(w);
** BUT plain ModMul is *VERY* slow. That's why we are doing
** Montgomery multiplication! So we have got to avoid it.
*/
if (Dir > 0) w=MontPow(ToMont(PrimvRoot),Prime-1-((Prime-1)/step));
else w=MontPow(ToMont(PrimvRoot),(Prime-1)/step);
for (j=0;j 0; x--)
{
/* Carry propagation: keep the low BASE digit, pass the rest on. */
Pyramid = NTTNum1[Len2 - x] + Carry;
Carry = Pyramid / BASE;
Prod[x - 1] = Pyramid % BASE;
}
}
void
InitFFT(unsigned long int Len,int Base,int BaseDig)
{int Bytes;
BASE=Base;
BASE_DIG=BaseDig;
Bytes=sizeof(ModInt)*CalcNTTLen(Len);
if ( (BaseDig > 1) || (BASE > 10))
{
printf("Error: The ntt is hardwired for just 1 digit in the base.\n");
printf("In 'main.c' please set BASE to 10 and BASE_DIG to 1\n");
exit(0);
}
NTTNum1=(ModInt*)malloc(Bytes);
NTTNum2=(ModInt*)malloc(Bytes);
if ((NTTNum1==NULL) || (NTTNum2==NULL))
{
printf("Unable to allocate memory for NTTNum.\n");
printf("Len=%d Bytes=%d\n",(int)Len,(int)Bytes);
exit(0);
}
Prime = 2130706433;PrimvRoot=3;
ToMontC=((UINT32)(0-Prime)) % Prime;
FromMontC =FindInverse(ToMontC);
MontMulC=Prime-2;
/*
MontToMontC=ModMul(ToMontC,ToMontC);
MontFromMontC=ModMul(ToMontC,FromMontC);
*/
}
void DeInitFFT(unsigned long int Len)
/*
** Releases the transform buffers allocated by InitFFT.  The pointers are
** reset to NULL so that calling this again (or after a later InitFFT
** failure path) cannot free the same memory twice.  Len is unused here.
*/
{
 (void)Len;
 free(NTTNum1);
 free(NTTNum2);
 NTTNum1=NULL;
 NTTNum2=NULL;
}