#include <stdio.h>
#include <stdlib.h>
#include <vector>
#include <assert.h>
#include <time.h>

#define LIST std::vector<unsigned>

// Internal counter for step counting: incremented once per examined pair by the
// elimination loops (PartA/PartB) and used for their progress output; main() also
// mixes it into the plaintexts of the final filtering.
int count=0;

// The array of subkeys: KEY[j][r] holds subkey j (0..5) of round r (0..8),
// i.e. key-schedule subkey number 6*r+j (this is how key() fills it and how cip() reads it).
unsigned int KEY[6][9];

/* Multiplies two 8-bit IDEA words modulo 2^8+1 = 257.
   Only the low byte of each argument is used, and the word value 0 stands for 2^8 = 256,
   both for the inputs and for the returned result. */
unsigned int mul(unsigned int a, unsigned int b)
{
	const unsigned int x=a&0xff;
	const unsigned int y=b&0xff;
	int product;
	if(x==0 || y==0)
	{
		// 256 == -1 (mod 257), so 256*v == 257-v. When both are 0 this gives 257, i.e. 1.
		product=257-(int)((x==0)?y:x);
	}
	else
	{
		// Split x*y = hi*256 + lo; since 256 == -1 (mod 257) the residue is lo - hi.
		const unsigned int full=x*y;
		product=(int)(full&0xff)-(int)(full>>8);
		if(product<0)
			product+=257;
	}
	// 256 is folded back to the word value 0.
	return (unsigned int)(product&0xff);
}

/* Prints the key schedule on stdout: for each of the 8 groups of 8 subkeys it prints the
   inclusive, 1-based range of master-key bits "[first..last]" each subkey is taken from
   (ranges wrap from bit 64 back to bit 1). A line break follows every 6 subkeys (one round),
   which makes it easy to see where the 3.5-round attack's subkeys overlap. */
void key_sched_print()
{
	int printed=0;
	for(int group=0;group<8;group++)
	{
		// Each group starts 25 bits further along the (rotated) key.
		int first=(25*group)%64+1;
		for(int n=0;n<8;n++)
		{
			int last=(first+7>64)?(first-57):(first+7);
			printf("[%d..%d]",first,last);
			first=(last==64)?1:(last+1);
			printed++;
			if(printed%6==0)
				printf("\n");
		}
	}
	printf("\n");
}

/* The encipherer (taken from the paper describing IDEA), reduced to 3.5 rounds with 8-bit words:
   three full rounds followed by the output transformation.
   Input: the four words XX[0..3]; only the LOGICAL low byte (X&0xff) of each is used.
   Output: the four resulting 8-bit words in YY[0..3].
   Z[j][r] is subkey j (0..5) of round r. */
void cip(unsigned int XX[4], unsigned int YY[4], unsigned int Z[6][9])
{
	unsigned int w[4];
	for(int n=0;n<4;n++)
		w[n]=XX[n]&0xff;

	int round=0;
	for(;;)
	{
		// Key-mixing layer: multiply words 0 and 3, add to words 1 and 2.
		w[0]=mul(w[0],Z[0][round]);
		w[1]=(w[1]+Z[1][round])&0xff;
		w[2]=(w[2]+Z[2][round])&0xff;
		w[3]=mul(w[3],Z[3][round]);
		// The fourth key-mixing layer is the output transformation: stop after 3.5 rounds.
		if(round==3)
			break;
		// Multiply-add (MA) structure.
		unsigned int p=mul(Z[4][round],(w[0]^w[2]));
		unsigned int q=mul(Z[5][round],(p+(w[1]^w[3]))&0xff);
		unsigned int s=(p+q)&0xff;
		// Mix the MA outputs back in and swap the two middle words.
		unsigned int oldSecond=w[1];
		w[0]=w[0]^q;
		w[3]=w[3]^s;
		w[1]=w[2]^q;
		w[2]=oldSecond^s;
		round++;
	}

	for(int n=0;n<4;n++)
		YY[n]=w[n];
}

/* IDEA subkey generator (8-bit-word variant). The 64-bit key is uskey[0] (key bits 1..32,
   bit 1 being the most significant) followed by uskey[1] (bits 33..64). The key is cut into
   eight consecutive 8-bit subkeys, then rotated left by 25 bits and cut again the same way,
   until the 6x9 KEY table is full: subkey number s (0-based) is stored in KEY[s%6][s/6].
   Side effect: uskey is rotated in place (once per completed group of eight subkeys). */
void key(unsigned int uskey[2])
{
	int cr=0;
	int cc=0;
	for(int i=0;i<8;i++)
	{
		for(int j=0;j<8;j++)
		{
			// Byte j of the 64-bit key, counted from the most significant end. Extracted with
			// shifts: reinterpreting uskey through an unsigned char* depends on host byte order
			// and picked the wrong bytes on big-endian machines.
			unsigned int word=(j<4)?uskey[0]:uskey[1];
			KEY[cr][cc]=(word>>(8*(3-(j&3))))&0xff;
			cr++;
			cc=((cr>5)?cc+1:cc);
			cr%=6;
			if(cc>8)
				return; // all 54 subkeys produced
		}
		// Rotate the 64-bit key (uskey[0]:uskey[1]) left by 25 bits.
		unsigned int t1=uskey[0]&0xffffff80;
		t1=t1>>7;
		unsigned int t2=uskey[1]&0xffffff80;
		t2=t2>>7;
		uskey[0]=((uskey[0]&0x7f)<<25)|t2;
		uskey[1]=((uskey[1]&0x7f)<<25)|t1;
	}
}

/* Returns the multiplicative inverse of the 8-bit word x modulo 2^8+1 = 257, computed with the
   extended Euclidean algorithm. As in mul(), the word 0 stands for 256; its inverse is 256,
   which is returned as 0 again. */
unsigned int inv(unsigned int x)
{
	if(x==0)
		return 0;

	// Invariant: coef * x == rem (mod 257) for both the previous and the current row.
	int prevRem=257;
	int curRem=(int)x;
	int prevCoef=0;
	int curCoef=1;
	for(;;)
	{
		int remainder=prevRem%curRem;
		if(remainder==0)
			break; // curRem is the gcd (1), so curCoef is the inverse
		int quotient=(prevRem-remainder)/curRem;
		prevRem=curRem;
		curRem=remainder;
		int nextCoef=prevCoef-quotient*curCoef;
		prevCoef=curCoef;
		curCoef=nextCoef;
	}
	if(curCoef<0)
		curCoef+=257;
	return (unsigned int)(curCoef&0xff);
}

// Intended to hold the inp./outp. values: one chosen plaintext and its 3.5-round ciphertext,
// each packed as four 8-bit words with word 0 in the most significant byte
// (see BuildSDB in PartA/PartB).
struct Pair
{
	unsigned int inp; // packed plaintext: XX[0]<<24 | XX[1]<<16 | XX[2]<<8 | XX[3]
	unsigned int out; // packed ciphertext, same layout
};

namespace PartA
{
	// Will contain the number of structures used during the elimination process
	int NumStructUsed=0;
	// Will be a block of 2^24 bits each bit signifing one possibility for the first 24 bits
	// of the real key.
	int *Variants=NULL;
	// The name speakes for itself :-).
	unsigned int NumKeysLeft=0x1000000;
	// Array of inverses: (inv_arr[i]*i) mod (2^8+1) = 1.
	unsigned int inv_arr[256];

	// Sample Data Base, 2^16 entries
	Pair SDB[0x10000];

	/* Initialization: initialize the NumKeysLeft to 2^24, inv_arr to the intended values, NumStructUsed to 0 */
	/* and Variants array to a block of 2^24 zero bits. */
	void Init()
	{
		Variants=(int *)calloc(1<<19,4);
		NumKeysLeft=0x1000000;
		for(unsigned int i=0;i<256;i++)
			inv_arr[i]=inv(i);
		count=0;
		NumStructUsed=0;
	}

	/* Exludes the given improper key from the Variants array by setting it's bit to 1 */
	int Exclude(unsigned int ip_key)
	{
		int ret;
		if(!Variants)
			return -1;
		unsigned int rip_key=ip_key&0xffffff;
		int ind=(rip_key-(rip_key%32))/32;
		int mask=1<<(31-(rip_key%32));
		if(Variants[ind]&mask)
		{
			ret=0;
		}
		else
		{
			ret=1;
		}
		Variants[ind]|=mask;
		return ret;
	}

	/* Checks if the given key was not previosly excluded (intended for BUG detection) */
	int QueryKey(unsigned int key)
	{
		int ret;
		if(!Variants)
			return -1;
		unsigned int rkey=key&0xffffff;
		int ind=(rkey-(rkey%32))/32;
		int mask=1<<(31-(rkey%32));
		if(Variants[ind]&mask)
		{
			ret=0;
		}
		else
		{
			ret=1;
		}
		return ret;
	}

	/* Builds an array of all the keys not previosly excluded (the ones whose bit in Variants array is 0) */
	unsigned int *GetAllRemainingKeys()
	{
		unsigned int *ret=new unsigned[NumKeysLeft];
		unsigned int ind=0;
		for(unsigned int i=0;i<(1<<19);i++)
		{
			if(Variants[i]!=0xffffffff)
			{
				for(unsigned int j=0;j<32;j++)
				{
					if((Variants[i]&(1<<(31-j)))==0)
					{
						ret[ind]=i*32+j;
						ind++;
					}
				}
			}
		}

		return ret;
	}

	/* Get all bad First Key Halves (FKH), i.e. all the relevant 1-st round subkey bits which make a (A,0,A,0) input */
	/* difference for a given pair. Complexity: O(2^8). */
	LIST **GetAllBadFKH(unsigned int x1_mul, unsigned int x2_mul, unsigned int x1_add, unsigned int x2_add)
	{
		LIST **res=new LIST*[256];

		for(int k=0;k<256;k++)
			res[k]=new LIST;

		x1_mul&=0xff;
		x2_mul&=0xff;
		x1_add&=0xff;
		x2_add&=0xff;

		LIST lst_arr[256];

		for(unsigned int i=0;i<256;i++)
		{
			lst_arr[mul(x1_mul,i)^mul(x2_mul,i)].push_back(i);
		}

		for(unsigned int j=0;j<256;j++)
		{
			unsigned int tmp=((x1_add+j)&0xff)^((x2_add+j)&0xff);
			if(!(lst_arr[tmp].empty()))
			{
				for(int ind=0;ind<lst_arr[tmp].size();ind++)
				{
					unsigned int aux=lst_arr[tmp][ind];
					res[(((aux&0x3f)<<2)&0xff)|(((j&0xc0)>>6)&0x3)]->push_back(((aux<<8)&0xff00)|j);
				}
			}
		}

		return res;
	}

	/* Get all bad Last Key Halves (LKH), i.e. all the relevant 3.5-th round subkey bits which make a (B,B,0,0) output */
	/* difference for a given pair. Complexity: O(2^8). */
	LIST **GetAllBadLKH(unsigned int x1_mul, unsigned int x2_mul, unsigned int x1_add, unsigned int x2_add)
	{
		LIST **res=new LIST*[256];

		for(int k=0;k<256;k++)
			res[k]=new LIST;

		x1_mul&=0xff;
		x2_mul&=0xff;
		x1_add&=0xff;
		x2_add&=0xff;

		LIST lst_arr[256];

		for(unsigned int i=0;i<256;i++)
		{
			lst_arr[mul(x1_mul,inv_arr[i])^mul(x2_mul,inv_arr[i])].push_back(i);
		}

		for(unsigned int j=0;j<256;j++)
		{
			unsigned int tmp=((x1_add+(256-j))&0xff)^((x2_add+(256-j))&0xff);
			if(!(lst_arr[tmp].empty()))
			{
				for(int ind=0;ind<lst_arr[tmp].size();ind++)
				{
					unsigned int aux=lst_arr[tmp][ind];
					res[(aux&0xfc)|(j&0x3)]->push_back(((aux<<8)&0xff00)|j);
				}
			}
		}

		return res;
	}

	// Exclude all the improper (exhibiting the impossible 2.5 round differential) keys for a given pair, expected O(2^16) complexity
	void ExcludeResBadKeys(unsigned int x11_mul, unsigned int x21_mul, unsigned int x11_add, unsigned int x21_add, unsigned int x12_mul, unsigned int x22_mul, unsigned int x12_add, unsigned int x22_add)
	{
	// Getting the key halves...
		LIST **FKH=GetAllBadFKH(x11_mul,x21_mul,x11_add,x21_add);
		LIST **LKH=GetAllBadLKH(x12_mul,x22_mul,x12_add,x22_add);
	// Excluding all the keys made from one of the first halves and one of the second halves 
	// (of course the two halves must be COMPATABLE, i.e. have the same SHARED bits)
		for(int i=0;i<256;i++)
		{
			// For every 2 compatable halves ...
			if((!(FKH[i]->empty()))&&(!(LKH[i]->empty())))
			{
				for(int ind1=0;ind1<FKH[i]->size();ind1++)
				{
					unsigned int fh=(*(FKH[i]))[ind1];
					for(int ind2=0;ind2<LKH[i]->size();ind2++)
					{
						unsigned int lh=(*(LKH[i]))[ind2];
						// Generate the improper key
						unsigned int bk=(((fh&0xff00)<<8)|((lh&0x03fc)<<6)|(fh&0xff))&0xffffff;
						// Exclude it if not previously excluded
						if(Exclude(bk))
							NumKeysLeft--;
					}			
				}			
			}
		}
	// Clean up...
		for(int l=0;l<256;l++)
		{
			delete FKH[l];
			delete LKH[l];
		}

		delete[] FKH;
		delete[] LKH;
	}

	/* The Attack! It takes the real 24 first bits of the key for debugging */
	/* reference (i.e. to see if we didn't made an error and excluded the real key in some step) */

	void ExludeAllBadKeys(unsigned int real_key)
	{
		std::vector<Pair*> *tlst=new std::vector<Pair*>[0x10000];
		// Iterate on all elements of the structure (the current structure resides in the SDB inp./outp. pairs array)
		for(int i=0;i<0x10000;i++)
		{
			unsigned int ind=(SDB[i].out)&0x0000ffff;
			// if we have already seen elements with the same last 16 bits of the output as the cur. element
			if(!(tlst[ind].empty()))
			{
				// then we iterate on all of them and exclude all the improper keys for the resulting pairs (cur. elem. and the previouse ones)
				for(int r=0;r<tlst[ind].size();r++)
				{
					Pair *tmp=tlst[ind][r];
					// Make the exclusion
					ExcludeResBadKeys((SDB[i].inp&0xff000000)>>24,(tmp->inp&0xff000000)>>24,(SDB[i].inp&0xff00)>>8,(tmp->inp&0xff00)>>8,(SDB[i].out&0xff000000)>>24,(tmp->out&0xff000000)>>24,(SDB[i].out&0xff0000)>>16,(tmp->out&0xff0000)>>16);
					// Check if we didn't kill the real key due to some BUG
					int ret=QueryKey(real_key);
					if(ret!=1)
					{
						// Well we don't really want to be here ;-)
						printf("Ooops, we've just killed the real key! Aborting...\n");
						delete[] tlst;
						return;
					}
					if((count%1000)==0)
					{
						// if we've passed a 1000 iterations for a given structure then
						// output the current status (i.e. how many keys are still remaining)
						printf("Step %d - %d keys left...\n",count/1000,NumKeysLeft);
					}
					// If there no more then 2 possible keys left then return (from experimening
					// I got that we reach the 2 possib. in some reasonable time and then we wait 
					// and wait and ... after 90 structures we still get 2 possibilities, so why bother ;-)
					if(NumKeysLeft<=2)
					{
						printf("We have only two keys left. One of them is the right key. Terminating...\n");
						delete[] tlst;
						return;
					}
					count++;
				}
			}
			// whether it is the first element with the given last 16 bits or not, we add it
			// into the proper list
			tlst[ind].push_back(&SDB[i]);
		}
		delete[] tlst;
	}

	//////////////////////////////////////////////////////////
	/* This function receives x2=X and x4=Y and filles the Pairs array SDB with a structure */
	/* of all the possible inputs of the type *X*Y with their according outputs after 3.5 rounds */
	/* of IDEA */
	void BuildSDB(unsigned int x2,unsigned int x4)
	{
		// Iterate over all input possibilities
		for(int i=0;i<0x10000;i++)
		{
			unsigned int XX[4];
			unsigned int YY[4];
			// Generate the input
			XX[0]=(i&0xff00)>>8;
			XX[1]=x2;
			XX[2]=i&0xff;
			XX[3]=x4;
			// Get the output by enciphring the input using cip
			cip(XX,YY,KEY);
			// Put the result into the SDB
			SDB[i].inp=(XX[0]<<24)|(XX[1]<<16)|(XX[2]<<8)|(XX[3]);
			SDB[i].out=(YY[0]<<24)|(YY[1]<<16)|(YY[2]<<8)|(YY[3]);
		}
	}
};

namespace PartB
{
	// Will contain the number of structures used during the elimination process
	int NumStructUsed=0;
	// Will be a block of 2^10 bits each bit signifing one possibility for the first bits [25..34]
	// of the real key.
	int *Variants=NULL;
	// The name speakes for itself :-).
	unsigned int NumKeysLeft=0x400;
	// Array of inverses: (inv_arr[i]*i) mod (2^8+1) = 1.
	unsigned int inv_arr[256];

	// Sample Data Base, 2^16 entries, will be filled by elements of structures
	Pair SDB[0x10000];

	// Less Then Two Possibilities left (used by the elimination process)
	unsigned int lttpl=0;

	/* Initialization: initialize the NumKeysLeft to 2^10, inv_arr to the intended values */
	/* and Variants array to a block of 2^10 zero bits. */
	void Init()
	{
		lttpl=0;
		Variants=(int *)calloc(1<<5,4);
		NumKeysLeft=0x400;
		for(unsigned int i=0;i<256;i++)
			inv_arr[i]=inv(i);
		count=0;
	}

	/* Exludes the given improper key from the Variants array by setting it's bit to 1 */
	int Exclude(unsigned int ip_key)
	{
		int ret;
		if(!Variants)
			return -1;
		unsigned int rip_key=ip_key&0x3ff;
		int ind=(rip_key-(rip_key%32))/32;
		int mask=1<<(31-(rip_key%32));
		if(Variants[ind]&mask)
		{
			ret=0;
		}
		else
		{
			ret=1;
		}
		Variants[ind]|=mask;
		return ret;
	}

	/* Checks if the given key was not previosly excluded (intended for BUG detection) */
	int QueryKey(unsigned int key)
	{
		int ret;
		if(!Variants)
			return -1;
		unsigned int rkey=key&0x3ff;
		int ind=(rkey-(rkey%32))/32;
		int mask=1<<(31-(rkey%32));
		if(Variants[ind]&mask)
		{
			ret=0;
		}
		else
		{
			ret=1;
		}
		return ret;
	}

	/* Builds an array of all the keys not previosly excluded (the ones whose bit in Variants array is 0) */
	unsigned int *GetAllRemainingKeys()
	{
		unsigned int *ret=new unsigned[NumKeysLeft];
		unsigned int ind=0;
		for(unsigned int i=0;i<(1<<5);i++)
		{
			if(Variants[i]!=0xffffffff)
			{
				for(unsigned int j=0;j<32;j++)
				{
					if((Variants[i]&(1<<(31-j)))==0)
					{
						ret[ind]=i*32+j;
						ind++;
					}
				}
			}
		}

		return ret;
	}

	/* Get all bad First Key Halves (FKH), i.e. all the relevant 1-st round subkey bits which make a (0,A,0,A) input */
	/* difference for a given pair. Some subkey bits are dictated by already known (from Part A) key bits: kb_9_to_16 */
	/* (i.e. already known key bits 9 to 16). Complexity: O(2^8). */
	LIST *GetAllBadFKH(unsigned int x1_add, unsigned int x2_add, unsigned int x1_mul, unsigned int x2_mul, unsigned int kb_9_to_16)
	{
		LIST *res=new LIST;

		x1_mul&=0xff;
		x2_mul&=0xff;
		x1_add&=0xff;
		x2_add&=0xff;

		unsigned int aux=(((x1_add+kb_9_to_16)&0xff)^((x2_add+kb_9_to_16)&0xff));

		for(unsigned int j=0;j<256;j++)
		{
			unsigned int tmp=((mul(x1_mul,j)^mul(x2_mul,j))&0xff);
			if(aux==tmp)
			{
				res->push_back(j);
			}
		}

		return res;
	}

	/* Get all bad Last Key Halves (LKH), i.e. all the relevant 3.5-th round subkey bits which make a (0,0,B,B) output */
	/* difference for a given pair. Some subkey bits are dictated by already known (from Part A) key bits: kb_19_to_24 */
	/* (i.e. already known key bits 19 to 24). Complexity: O(2^8). */
	LIST **GetAllBadLKH(unsigned int x1_add, unsigned int x2_add, unsigned int x1_mul, unsigned int x2_mul, unsigned int kb_19_to_24)
	{
		LIST **res=new LIST*[256];

		for(int k=0;k<256;k++)
			res[k]=new LIST;

		x1_mul&=0xff;
		x2_mul&=0xff;
		x1_add&=0xff;
		x2_add&=0xff;

		LIST lst_arr[256];

		for(unsigned int i=0;i<4;i++)
		{
			unsigned int rk=(((kb_19_to_24&0x3f)<<2)|i);
			lst_arr[((x1_add+(256-rk))&0xff)^((x2_add+(256-rk))&0xff)].push_back(rk);
		}

		for(unsigned int j=0;j<256;j++)
		{
			unsigned int tmp=mul(x1_mul,inv_arr[j])^mul(x2_mul,inv_arr[j]);
			if(!(lst_arr[tmp].empty()))
			{
				for(int ind=0;ind<lst_arr[tmp].size();ind++)
				{
					unsigned int aux=lst_arr[tmp][ind];
					res[((aux&0x3)<<6)|((j&0xfc)>>2)]->push_back(((aux<<8)&0xff00)|j);
				}
			}
		}

		return res;
	}

	// Exclude all the improper (exhibiting the impossible 2.5 round differential 0A0A-x->00BB) keys for a given pair, expected O(2^16) complexity
	void ExcludeResBadKeys(unsigned int x11_add, unsigned int x21_add, unsigned int x11_mul, unsigned int x21_mul, unsigned int x12_add, unsigned int x22_add, unsigned int x12_mul, unsigned int x22_mul, unsigned int kb_9_to_16, unsigned int kb_19_to_24)
	{
	// Getting the key halves...
		LIST *FKH=GetAllBadFKH(x11_add,x21_add,x11_mul,x21_mul,kb_9_to_16);
		LIST **LKH=GetAllBadLKH(x12_add,x22_add,x12_mul,x22_mul,kb_19_to_24);
	// Excluding all the keys made from one of the first halves and one of the second halves 
	// (of course the two halves must be COMPATABLE, i.e. have the same SHARED bits)
		int fkh_sz=FKH->size();
		for(int i=0;i<fkh_sz;i++)
		{
			// For every 2 compatable halves ...
			if(!(LKH[(*FKH)[i]]->empty()))
			{
				unsigned int fh=(*FKH)[i];
				for(int ind2=0;ind2<LKH[fh]->size();ind2++)
				{
					unsigned int lh=(*(LKH[fh]))[ind2];
					// Generate the improper key
					unsigned int bk=((fh&0xff)<<2)|(lh&0x3);
					// Exclude it if not previously excluded
					if(Exclude(bk))
						NumKeysLeft--;
				}			
			}			
		}

	// Clean up...
		for(int l=0;l<256;l++)
		{
			delete LKH[l];
		}

		delete FKH;
		delete[] LKH;
	}

	/* The Attack! It takes the first 24 key bits found in part A (to speed up the search) */

	void ExludeAllBadKeys(unsigned int kb_1_to_24)
	{
		unsigned int kb_9_to_16=(kb_1_to_24&0x00ff00)>>8;
		unsigned int kb_19_to_24=(kb_1_to_24&0x00003f);
		std::vector<Pair*> *tlst=new std::vector<Pair*>[0x10000];
		// Iterate on all elements of the structure (the current structure resides in the SDB inp./outp. pairs array)
		for(int i=0;i<0x10000;i++)
		{
			unsigned int ind=((SDB[i].out)&0xffff0000)>>16;
			// if we have already seen elements with the same first 16 bits of the output as the cur. element
			if(!(tlst[ind].empty()))
			{
				// then we iterate on all of them and exclude all the improper keys for the resulting pairs (cur. elem. and the previouse ones)
				for(int r=0;r<tlst[ind].size();r++)
				{
					Pair *tmp=tlst[ind][r];
					// Make the exclusion
					ExcludeResBadKeys((SDB[i].inp&0xff0000)>>16,(tmp->inp&0xff0000)>>16,(SDB[i].inp&0xff),(tmp->inp&0xff),(SDB[i].out&0xff00)>>8,(tmp->out&0xff00)>>8,(SDB[i].out&0xff),(tmp->out&0xff),kb_9_to_16,kb_19_to_24);
					if((count%1000)==0)
					{
						// if we've passed a 1000 iterations for a given structure then
						// output the current status (i.e. how many keys are still remaining)
						printf("Step %d - %d keys left...\n",count/1000,NumKeysLeft);
					}
					// If there no more then 1 possible keys left then return
					if(NumKeysLeft==0)
					{
						delete[] tlst;
						return;
					}
					if(NumKeysLeft==1)
					{
						lttpl++;
						// add around 1 minute for possible gain of around 33 minutes
						if(lttpl>120000)
						{
							delete[] tlst;
							return;
						}
					}
					count++;
				}
			}
			// whether it is the first element with the given last 16 bits or not, we add it
			// into the proper list
			tlst[ind].push_back(&SDB[i]);
		}
		delete[] tlst;
	}

	//////////////////////////////////////////////////////////
	/* This function receives x1=X and x3=Y and filles the Pairs array SDB with a structure */
	/* of all the possible inputs of the type X*Y* with their according outputs after 3.5 rounds */
	/* of IDEA */
	void BuildSDB(unsigned int x1,unsigned int x3)
	{
		// Iterate over all input possibilities
		for(int i=0;i<0x10000;i++)
		{
			unsigned int XX[4];
			unsigned int YY[4];
			// Generate the input
			XX[1]=(i&0xff00)>>8;
			XX[0]=x1;
			XX[3]=i&0xff;
			XX[2]=x3;
			// Get the output by enciphring the input using cip
			cip(XX,YY,KEY);
			// Put the result into the SDB
			SDB[i].inp=(XX[0]<<24)|(XX[1]<<16)|(XX[2]<<8)|(XX[3]);
			SDB[i].out=(YY[0]<<24)|(YY[1]<<16)|(YY[2]<<8)|(YY[3]);
		}
	}
};
/* The following 2 functions are used for fast subkey generation in the exhaustive search */
// key4elimInit: called once before the exhaustive search with the known key bits 1..34
// (bits 1..32 in uk[0], bits 33..34 in the top two bits of uk[1]). It fills every subkey
// used by cip() that depends ONLY on those known bits.
void key4elimInit(unsigned int uk[2])
{
	const unsigned int hi=uk[0];
	const unsigned int lo=uk[1];

	// Round 1, subkeys 0..3 (key bits 1..32)
	KEY[0][0]=(hi>>24)&0xff;
	KEY[1][0]=(hi>>16)&0xff;
	KEY[2][0]=(hi>>8)&0xff;
	KEY[3][0]=hi&0xff;

	// Round 2, subkey 2 (key bits 26..33)
	KEY[2][1]=((hi&0x7f)<<1)|((lo>>31)&0x1);

	// Round 3, subkeys 1..3 (key bits 2..25)
	KEY[1][2]=(hi>>23)&0xff;
	KEY[2][2]=(hi>>15)&0xff;
	KEY[3][2]=(hi>>7)&0xff;

	// Output transformation, subkeys 0..3 (key bits 3..34)
	KEY[0][3]=(hi>>22)&0xff;
	KEY[1][3]=(hi>>14)&0xff;
	KEY[2][3]=(hi>>6)&0xff;
	KEY[3][3]=((hi&0x3f)<<2)|((lo>>30)&0x3);
}

// key4elim: used in the iteration over all the possibilities for key bits 35..64. uk holds the
// full candidate key (bits 1..34 known, bits 35..64 the current guess). It fills the remaining
// subkeys used by cip(), i.e. those that also depend on key bits 35..64.
void key4elim(unsigned int uk[2])
{
	const unsigned int hi=uk[0];
	const unsigned int lo=uk[1];

	// Round 1, subkeys 4..5
	KEY[4][0]=(lo>>24)&0xff;
	KEY[5][0]=(lo>>16)&0xff;

	// Round 2, subkeys 0, 1, 3, 4, 5
	KEY[0][1]=(lo>>8)&0xff;
	KEY[1][1]=lo&0xff;
	KEY[3][1]=(lo>>23)&0xff;
	KEY[4][1]=(lo>>15)&0xff;
	KEY[5][1]=(lo>>7)&0xff;

	// Round 3, subkeys 0, 4, 5 (0 and 5 wrap around into key bits 1..2)
	KEY[0][2]=((lo&0x7f)<<1)|((hi>>31)&0x1);
	KEY[4][2]=(lo>>6)&0xff;
	KEY[5][2]=((lo&0x3f)<<2)|((hi>>30)&0x3);
}

/* Returns a pseudo-random 16-bit value (0..0xffff) obtained by scaling rand() to 2^16. */
static unsigned int Rand16()
{
	return ((unsigned int)(((double)rand())/((double)RAND_MAX/0x10000)))&0xffff;
}

// The Main function :-)
/* Demonstrates the attack end to end on a random 64-bit key:
   Part A - recover key bits 1..24 (A*A* -x-> BB** impossible differential);
   Part B - recover key bits 25..34 (*A*A -x-> **BB impossible differential);
   Part C - exhaustive search over bits 35..64, then filter survivors with extra plaintexts. */
int main()
{
	unsigned int i=0;
	srand((unsigned)time(NULL)); // Seed the random number generator
//	key_sched_print(); // enable this to see the distribution of the key bits between the subkeys in the key schedule
	// Generate a random 64-bit key from four random 16-bit values (previously only the low
	// byte of each value was kept, which left 32 of the 64 key bits always zero).
	unsigned int k1=Rand16();
	unsigned int k2=Rand16();
	unsigned int k3=Rand16();
	unsigned int k4=Rand16();
	unsigned int uk[2]={((k1<<16)|k2),((k3<<16)|k4)};

	// Remember parts of the real key for later demonstration
	unsigned int real_key=(uk[0]&0xffffff00)>>8;
	unsigned int f32kb=uk[0];
	unsigned int l32kb=uk[1];
	unsigned int s2kb=((uk[1]&0xc0000000)>>30);

	// Generate all the subkeys (key() also rotates uk in place)
	key(uk);
	printf("Commencing Part A (A*A*-x->BB** impossible differential):\n");
	printf("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n");
	PartA::Init();
	for(i=0;i<90;i++) // Exclude improper keys over up to 90 structures of the type *X*Y (fixed X and Y per structure)
	{
		unsigned int tmp=Rand16(); // X = low byte, Y = high byte
		PartA::BuildSDB(tmp&0xff,(tmp&0xff00)>>8);
		PartA::NumStructUsed++;
		PartA::ExludeAllBadKeys(real_key);
		if(PartA::NumKeysLeft<=2) // If no more than 2 keys remain then break
			break;
	}
	// Report the findings
	unsigned int *rk_arr=PartA::GetAllRemainingKeys();
	if(PartA::Variants)
	{
		delete[] PartA::Variants; // CleanUp
		PartA::Variants=NULL;
	}
	printf(" Finished the elimination process (Part A).\nPart A RESULTS:\n ~~~~~~~~~~~~~~~\n");
	printf(" The real first 24 bits of the key were: %06X.\n",real_key);
	printf(" # of the remaining possible keys (bits 1 to 24) is %u.\n",PartA::NumKeysLeft);
	printf(" The remaining possibilities are:\n");
	for(i=0;i<PartA::NumKeysLeft;i++)
		printf(" %06X\n",rk_arr[i]);

	printf("Commencing Part B (*A*A-x->**BB impossible differential):\n");
	printf("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n");
	//first 32 bits
	LIST f32b;
	//bits 33,34
	LIST b33and34;

	for(unsigned int k=0;k<PartA::NumKeysLeft;k++)
	{
		printf("Possibility %u - first 24 bits are: %06X\n",k,rk_arr[k]);
		PartB::Init();
		for(i=0;i<90;i++) // Exclude improper keys over up to 90 structures of the type X*Y* (fixed X and Y per structure)
		{
			unsigned int tmp=Rand16(); // X = low byte, Y = high byte
			PartB::BuildSDB(tmp&0xff,(tmp&0xff00)>>8);
			PartB::NumStructUsed++;
			PartB::ExludeAllBadKeys(rk_arr[k]);
			if(PartB::NumKeysLeft==0) // wrong Part A candidate: nothing left to eliminate
				break;
			if((PartB::NumKeysLeft<=1)&&(PartB::lttpl>30000))
				break;
		}
		// Record every surviving combination of bits 1..24 (Part A) and 25..34 (Part B)
		if(PartB::NumKeysLeft>0)
		{
			unsigned int *arr=PartB::GetAllRemainingKeys();
			for(unsigned int j=0;j<PartB::NumKeysLeft;j++)
			{
				unsigned int kb_1_to_32=((rk_arr[k]&0xffffff)<<8)|((arr[j]&0x3fc)>>2);
				unsigned int kb_33_to_34=arr[j]&0x3;
				f32b.push_back(kb_1_to_32);
				b33and34.push_back(kb_33_to_34);
			}
			delete[] arr;
		}
		if(PartB::Variants)
		{
			delete[] PartB::Variants;
			PartB::Variants=NULL;
		}
	}

	printf(" Finished the elimination process (Part B).\n Part B RESULTS:\n ~~~~~~~~~~~~~~~\n");
	printf(" The real bits 1 to 34 of the key were: %08X%X.\n",f32kb,s2kb);
	printf(" # of the possible key bits 1 to 34 is %u.\n",(unsigned int)f32b.size());
	printf(" The remaining possibilities are:\n");
	for(i=0;i<f32b.size();i++)
		printf(" %08X%X\n",f32b[i],b33and34[i]);
	delete[] rk_arr;

	printf("Commencing Part C (exhaustive search on all remaining possible keys O(2^30)):\n");
	printf("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n");

	// Prepare a plain text/cipher text pair under the real subkeys (for primary filtering)
	unsigned int XX1[4];
	unsigned int YY1[4];
	XX1[0]=0x23;
	XX1[1]=0x5f;
	XX1[2]=0x72;
	XX1[3]=0xda;
	cip(XX1,YY1,KEY);
	unsigned int out1=(YY1[0]<<24)|(YY1[1]<<16)|(YY1[2]<<8)|(YY1[3]);

	unsigned int YY_T[4];
	LIST good_key_arr[2];

	// Perform the Exhaustive Search - O(2^30) per remaining 34-bit possibility
	for(i=0;i<f32b.size();i++)
	{
		uk[0]=f32b[i];
		uk[1]=b33and34[i]<<30;
		key4elimInit(uk);
		printf("Possibility # %u, with first 34 bits: %08X%X\n",i,f32b[i],b33and34[i]);
		for(unsigned int k=0;k<(1u<<30);k++)
		{
			uk[1]=(uk[1]&0xc0000000)|k;
			key4elim(uk);
			cip(XX1,YY_T,KEY);
			unsigned int tout=(YY_T[0]<<24)|(YY_T[1]<<16)|(YY_T[2]<<8)|(YY_T[3]);
			if(tout==out1)
			{
				good_key_arr[0].push_back(uk[0]);
				good_key_arr[1].push_back(uk[1]);
			}
			if(!(k%500000))
				printf("Step %u, # of good keys - %u\n",k/500000,(unsigned int)good_key_arr[0].size());
		}
	}
	printf("Part C RESULT:\n~~~~~~~~~~~~~~\n");
	printf("The Real Key: %08X%08X\n",f32kb,l32kb);

	// Do the final filtering: encrypt a further plaintext under the real key and drop every
	// candidate that disagrees, until a single candidate remains. The plaintext changes on
	// every pass (previously it never did) and the number of passes is bounded, so a missing
	// real key or two indistinguishable candidates can no longer hang the program.
	const unsigned int MAX_FILTER_PASSES=256;
	for(unsigned int pass=0;pass<MAX_FILTER_PASSES && !good_key_arr[0].empty();pass++)
	{
		uk[0]=f32kb;
		uk[1]=l32kb;
		key4elimInit(uk);
		key4elim(uk);
		unsigned int XX2[4];
		unsigned int YY2[4];
		XX2[0]=0xe5;
		XX2[1]=(count+pass+0x3f)%256; // vary the samples between passes
		XX2[2]=(count+pass+0xb1)%256; // vary the samples between passes
		XX2[3]=0xf9;
		cip(XX2,YY2,KEY);
		unsigned int out2=(YY2[0]<<24)|(YY2[1]<<16)|(YY2[2]<<8)|(YY2[3]);

		// Compact the agreeing candidates to the front of both lists (erasing while iterating
		// used to skip the element following every erased one).
		size_t kept=0;
		for(size_t c=0;c<good_key_arr[0].size();c++)
		{
			uk[0]=good_key_arr[0][c];
			uk[1]=good_key_arr[1][c];
			key4elimInit(uk);
			key4elim(uk);
			cip(XX2,YY_T,KEY);
			unsigned int tout=(YY_T[0]<<24)|(YY_T[1]<<16)|(YY_T[2]<<8)|(YY_T[3]);
			if(tout==out2)
			{
				good_key_arr[0][kept]=good_key_arr[0][c];
				good_key_arr[1][kept]=good_key_arr[1][c];
				kept++;
			}
		}
		good_key_arr[0].resize(kept);
		good_key_arr[1].resize(kept);
		if(kept==1)
			break;
	}

	if(good_key_arr[0].size()==1)
	{
		printf("The full 64 bit key that was found by the attack: %08X%08X\n",good_key_arr[0][0],good_key_arr[1][0]);
	}
	else if(good_key_arr[0].empty())
	{
		printf("The attack did not find the key (no candidate survived the filtering).\n");
	}
	else
	{
		printf("%u candidate keys could not be told apart:\n",(unsigned int)good_key_arr[0].size());
		for(size_t c=0;c<good_key_arr[0].size();c++)
			printf(" %08X%08X\n",good_key_arr[0][c],good_key_arr[1][c]);
	}
	printf("Part A used: %d structures (2^16 plaintexts each).\n",PartA::NumStructUsed);
	printf("Part B used: %d structures (2^16 plaintexts each).\n",PartB::NumStructUsed);
	return 0;
}
