//
//  quantum register wavefunction implementation
//
//  
//
//
//  last modified 19/07/2004

#ifndef _QUBITS_H_
#define _QUBITS_H_


#include <stdlib.h>
#include <stdio.h>
#include <ctype.h>
#include <math.h>
#include <complex>
#include <string.h>
#include "qInlines.h"

using std::complex;

// Basis states are written |x_(n-1) x_(n-2) ... x_1 x_0>: qubit j is bit j of the amplitude index (x_0 is the least significant bit).

class QBitsWaveFunction
{
  public:
  int n;                // number of qubits
  long N;               // number of amplitudes, always 1L<<n
  complex<double> *wf;  // amplitude buffer of length N (calloc-allocated, freed by the destructor)
  char WFname[16];      // name of Wave Function (optional, truncated to 15 chars)

//-------- Constructor of wavefunction object, nq - number of qubits.
//  Every amplitude starts at zero (the buffer comes from calloc); call e.g.
//  SetZero() to get a normalized basis state. A null name means "unnamed".
//  Aborts if nq is negative or too large for 1L<<nq to fit in a long, or if
//  the allocation fails.
  QBitsWaveFunction(int nq, const char* name=0)
    { SetName(name ? name : "unnamed");
      n=nq;
      N=DimensionFor(nq);
      AllocBuffer();
    }

//-------- Cloning of wave function: deep copy of QBwf's amplitudes under a
//  new name (truncated to 15 chars; null means "unnamed").
  QBitsWaveFunction(const QBitsWaveFunction& QBwf, const char* name="unnamed")
    { SetName(name ? name : "unnamed");
      n=QBwf.n;
      N=QBwf.N;
      AllocBuffer();
      memcpy(wf,QBwf.wf,N*sizeof(complex<double>));
    }

  ~QBitsWaveFunction() { free(wf); }

// Equating two wave functions: copies QBwf's amplitudes (reallocating when the
// sizes differ) and renames this object "equal_tmp". Self-assignment is safe.
  QBitsWaveFunction& operator = (const QBitsWaveFunction& QBwf)
    { if(this!=&QBwf)   // memcpy of a buffer onto itself is undefined behaviour
        { ResizeLike(QBwf);
          memcpy(wf,QBwf.wf,N*sizeof(complex<double>));
        }
      strcpy(WFname,"equal_tmp");
      return *this;
    }

// Rescaling of wave function: multiplies every amplitude by the scalar a.
  QBitsWaveFunction& RescaleBy(complex<double> a)
    { if(wf==0)
        { printf("error: wave function %s is not inited\n",WFname);
          AbortFlushed();
        }
      for(long i=0; i<N; i++) wf[i]*=a;
      return *this;
    }

// These functions exist for efficiency only, to avoid excessive temporary objects.

// *this = QBwf1 + QBwf2 (amplitude-wise). The operands must have the same
// length. *this is resized to match them first; previously a smaller target
// buffer was overrun. Either operand may alias *this.
  QBitsWaveFunction& Sum(const QBitsWaveFunction& QBwf1,const QBitsWaveFunction& QBwf2)
    { if(QBwf1.N!=QBwf2.N)
        { printf("error: sum of vectors %s,%s of different lengths\n",QBwf1.WFname,QBwf2.WFname);
          AbortFlushed();
        }
      ResizeLike(QBwf1);   // no-op when aliased: sizes already agree
      for(long i=0; i<N; i++) wf[i]=QBwf1.wf[i]+QBwf2.wf[i];
      return *this;
    }

// *this = a * QBwf (goes through operator=, so the name becomes "equal_tmp").
  QBitsWaveFunction& Product(complex<double> a, const QBitsWaveFunction& QBwf)
    { *this=QBwf;
      RescaleBy(a);
      return *this;
    }

// Basis state |1...1> (only the last amplitude is 1).
  QBitsWaveFunction& Allign()
    { for(long i=0; i<N; i++) wf[i]=0;
      wf[N-1]=1;
      return *this;
    }

// Basis state |0...0> (only the first amplitude is 1).
  QBitsWaveFunction& SetZero()
    { for(long i=0; i<N; i++) wf[i]=0;
      wf[0]=1;
      return *this;
    }

// Sets every amplitude to 1 (not normalized).
  QBitsWaveFunction& Allone()
    { for(long i=0; i<N; i++) wf[i]=1;
      return *this;
    }

// Sets amplitude i to the value i (not normalized; handy for debugging).
  QBitsWaveFunction& Alli()
    { for(long i=0; i<N; i++) wf[i]=double(i);
      return *this;
    }


//------------------------------ operations on j-th qubit
// In each gate below, a0/a1 denote the amplitudes of two basis states that
// differ only in qubit j (bit j of the index clear / set).

// Pauli X: swaps a0 and a1.
  QBitsWaveFunction& SigmaX(int j)
    { CheckQBit(j,"X","j");
      const long i1=1L<<j;
      for(long i=0; i<N; i++)
        if(i&i1)
          { const complex<double> a1=wf[i];
            wf[i]=wf[i-i1];
            wf[i-i1]=a1;
          }
      return *this;
    }

// Pauli Y: a1' = I*a0, a0' = -I*a1.
  QBitsWaveFunction& SigmaY(int j)
    { CheckQBit(j,"Y","j");
      const long i1=1L<<j;
      for(long i=0; i<N; i++)
        if(i&i1)
          { const complex<double> a1=wf[i];
            wf[i]=I*wf[i-i1];
            wf[i-i1]=-(I*a1);
          }
      return *this;
    }

// Pauli Z: negates a1.
  QBitsWaveFunction& SigmaZ(int j)
    { CheckQBit(j,"Z","j");
      const long i1=1L<<j;
      for(long i=0; i<N; i++)
        if(i&i1) wf[i]=-wf[i];
      return *this;
    }

// Walsh-Hadamard: (a0,a1) -> ((a0+a1)/sqrt2, (a0-a1)/sqrt2).
  QBitsWaveFunction& WH_tr(int j)
    { CheckQBit(j,"WH","j");
      const long i1=1L<<j;
      const double sqrt12=1./sqrt(2.);
      for(long i=0; i<N; i++)
        if(i&i1)
          { const complex<double> a0=wf[i-i1], a1=wf[i];
            wf[i-i1]=(a0+a1)*sqrt12;
            wf[i]=(a0-a1)*sqrt12;
          }
      return *this;
    }

// Rotation of qubit j: a0 *= exp(I*Angle/2), a1 *= conj of that factor.
  QBitsWaveFunction& RotateQBit(int j, double Angle)
    { CheckQBit(j,"rotQBit","j");
      const long i1=1L<<j;
      const complex<double> U0=exp(I*(Angle/2.)), U1=conj(U0);
      for(long i=0; i<N; i++)
        if(i&i1)
          { wf[i-i1]*=U0;
            wf[i]*=U1;
          }
      return *this;
    }

// Two-qubit phase interaction: amplitudes where qubits j1 and j2 differ are
// multiplied by U0=exp(I*g), those where they agree by conj(U0).
// j1 must differ from j2; with j1==j2 the old code indexed wf[i-2*i1] < 0.
  QBitsWaveFunction& InteractQBits(int j1, int j2, double g)
    { CheckQBit(j1,"InteractBits","j1");
      CheckQBit(j2,"InteractBits","j2");
      if(j1==j2)
        { printf("InteractBits: j1=j2=%d, qubits must differ\n",j1);
          AbortFlushed();
        }
      const long i1=1L<<j1, i2=1L<<j2;
      const complex<double> U0=exp(I*g), U1=conj(U0);
      for(long i=0; i<N; i++)
        if((i&i1) && (i&i2))   // visit each 4-state group once, via its |..1..1..> member
          { wf[i-i1]*=U0;
            wf[i]*=U1;
            wf[i-i2]*=U0;
            wf[i-i1-i2]*=U1;
          }
      return *this;
    }

// Controlled NOT: flips qubit j2 on the states where qubit j1 is 1.
  QBitsWaveFunction& Cnot_tr(int j1,int j2)
    { CheckQBit(j1,"Cnot","j1");
      CheckQBit(j2,"Cnot","j2");
      const long i1=1L<<j1, i2=1L<<j2;
      for(long i=0; i<N; i++)
        if((i&i1) && (i&i2))
          { const complex<double> t=wf[i];
            wf[i]=wf[i-i2];
            wf[i-i2]=t;
          }
      return *this;
    }

// Toffoli: flips qubit j3 on the states where qubits j1 and j2 are both 1.
  QBitsWaveFunction& CCnot_tr(int j1,int j2, int j3)
    { CheckQBit(j1,"CCnot","j1");
      CheckQBit(j2,"CCnot","j2");
      CheckQBit(j3,"CCnot","j3");
      const long i1=1L<<j1, i2=1L<<j2, i3=1L<<j3;
      for(long i=0; i<N; i++)
        if((i&i1) && (i&i2) && (i&i3))
          { const complex<double> t=wf[i];
            wf[i]=wf[i-i3];
            wf[i-i3]=t;
          }
      return *this;
    }

// Controlled phase: multiplies amplitudes with qubits i and j both 1 by
// exp(I*Angle). Both indices are validated (i was previously unchecked).
  QBitsWaveFunction& CRotQBit(int i, int j, double Angle)
    { CheckQBit(i,"CRotQBit","i");
      CheckQBit(j,"CRotQBit","j");
      const long i1=1L<<i, i2=1L<<j;
      const complex<double> U0=exp(I*(Angle));
      for(long k=0; k<N; k++)
        if((k&i1) && (k&i2))
          wf[k]*=U0;
      return *this;
    }


//=========Quantum Fourier Transform with gates H and rotations==============
// For each qubit from the most significant down: Hadamard, then controlled
// phases 2*pi/2^j from the less significant qubits; finally reverse the qubit
// order by 3-CNOT swaps.
  QBitsWaveFunction& Qft_tr()
    { const double twoPi=6.283185307179586476925;
      for(int i=n-1; i>=0; i--)
        { WH_tr(i);
          for(int j=2; j<i+2; j++)
            CRotQBit(i-j+1,i,twoPi/double(1L<<j));
        }
      // Integer n/2: for odd n the middle qubit must stay put. The old bound
      // n*.5 reached it and applied Cnot_tr(k,k) three times, i.e. an X gate.
      for(int i=0; i<n/2; i++)
        { Cnot_tr(i,n-1-i); Cnot_tr(n-1-i,i); Cnot_tr(i,n-1-i); }
      return *this;
    }


//=========Quantum Fourier Transform with FFT ==============
// In-place radix-2 FFT (bit-reversal permutation followed by butterflies)
// with kernel exp(sgn*2*pi*I*x*k/N), normalized by 1/sqrt(N). sgn must be
// +1 or -1.
  QBitsWaveFunction& Fft_tr(int sgn)
    { if(sgn!=1 && sgn!=-1)
        { printf("Fft_tr: argument must be 1 or -1\n");
          AbortFlushed();
        }
      const complex<double> sgnipi=3.1415926535897932384626*sgn*I;

      // bit-reversal permutation
      long j=0;
      for(long i=0; i<N; i++)
        { if(j>i)
            { const complex<double> t=wf[i];
              wf[i]=wf[j];
              wf[j]=t;
            }
          long m=N/2;
          while((m>1)&&(j>=m))
            { j-=m;
              m/=2;
            }
          j+=m;
        }

      // butterflies: each pass merges transforms of length mmax into length 2*mmax
      for(long mmax=1; mmax<N; mmax*=2)
        { const long pas=2*mmax;
          const complex<double> wwp=exp(sgnipi*(1.0/mmax));
          complex<double> w=1.;
          for(long m=0; m<mmax; m++)
            { for(long i=m; i<N; i+=pas)
                { const complex<double> att=w*wf[i+mmax];
                  wf[i+mmax]=wf[i]-att;
                  wf[i]+=att;
                }
              w*=wwp;
            }
        }

      const double norme=1./sqrt(double(N));
      for(long i=0; i<N; i++)
        wf[i]*=norme;
      return *this;
    }

//--------------------------

//Participation ratio: sum|a_i|^2 / sum|a_i|^4 (equal to 1/sum|a_i|^4 for a
//normalized state). Prints "Qubit is zero" and returns 0 for the null vector.
//NOTE(review): for a non-normalized state this is not the scale-invariant
//(sum|a_i|^2)^2 / sum|a_i|^4 form; confirm which one callers expect.
  double participation() const
    { double num=0, den=0;
      for(long i=0; i<N; i++)
        { const double p=std::norm(wf[i]);   // |a_i|^2; std:: needed, member norm() hides it
          num+=p;
          den+=p*p;
        }
      if(den==0)
        { printf("Qubit is zero\n");
          return 0.0;
        }
      return num/den;
    }

//Squared norm sum|a_i|^2, i.e. <psi|psi>. The square root is not taken.
  double norm() const
    { double s=0;
      for(long k=0; k<N; k++) s+=std::norm(wf[k]);
      return s;
    }

// Prints the name followed by one "|index>  re+i*im" line per amplitude.
  void print() const
    { printf("%s =\n",WFname);
      for(long i=0; i<N; i++)
        printf("|%ld>\t  %f+i*%f\n",i,wf[i].real(),wf[i].imag());
    }

  private:
  // Flush stdout before aborting. Every diagnostic above is printed with printf,
  // and abort() is not guaranteed to flush stdio buffers.
  static void AbortFlushed()
    { fflush(stdout);
      abort();
    }

  // Copy at most 15 characters of name into WFname and NUL-terminate it.
  void SetName(const char* name)
    { strncpy(WFname,name,sizeof(WFname)-1);
      WFname[sizeof(WFname)-1]=0;
    }

  // Number of amplitudes for nq qubits. Aborts unless 1L<<nq fits in a long;
  // the former int shift 1<<n overflowed for n>=31.
  long DimensionFor(int nq) const
    { if(nq<0 || nq>int(8*sizeof(long))-2)
        { printf("Fail to init wave function %s: invalid number of qubits %d\n",WFname,nq);
          AbortFlushed();
        }
      return 1L<<nq;
    }

  // Allocate a zero-filled buffer of N amplitudes into wf; aborts on failure.
  void AllocBuffer()
    { wf=static_cast<complex<double>*>(calloc(N,sizeof(complex<double>)));
      if(wf==0)
        { printf("Fail to init wave function %s\n",WFname);
          AbortFlushed();
        }
    }

  // Give this register QBwf's size. When the size changes, the new buffer is
  // zero-filled and the old contents are discarded.
  void ResizeLike(const QBitsWaveFunction& QBwf)
    { if(N==QBwf.N) return;
      free(wf);
      n=QBwf.n;
      N=QBwf.N;
      AllocBuffer();
    }

  // Abort with "<op>: <var>=<j> out of range" unless 0 <= j < n.
  void CheckQBit(int j, const char* op, const char* var) const
    { if(j<0 || j>=n)
        { printf("%s: %s=%d out of range\n",op,var,j);
          AbortFlushed();
        }
    }

}; //=============== end of class "QBitsWaveFunction"
 

// right-side multiplication by scalar: returns a copy of QBwf, named "mult_tmp",
// with every amplitude multiplied by a. Declared inline because it is defined
// in a header. Without inline, every translation unit that includes the header
// emits its own definition and the link fails (one-definition rule).
inline QBitsWaveFunction operator * (const QBitsWaveFunction& QBwf, complex<double> a)
{ QBitsWaveFunction QBwf1(QBwf,"mult_tmp");
  QBwf1.RescaleBy(a);
  return QBwf1;
}

// left-side multiplication by scalar: returns a copy of QBwf, named "mult_tmp",
// with every amplitude multiplied by a. QBwf is taken by const reference so
// const and temporary wave functions can be scaled too. Declared inline because
// it is defined in a header (one-definition rule).
inline QBitsWaveFunction operator * (complex<double> a, const QBitsWaveFunction& QBwf)
{ QBitsWaveFunction QBwf1(QBwf,"mult_tmp");
  QBwf1.RescaleBy(a);
  return QBwf1;
}

// scalar product of two wave functions: <QBwf1|QBwf2> = sum conj(a1_i)*a2_i
// (the first argument is conjugated). Aborts if the lengths differ.
// Declared inline because it is defined in a header (one-definition rule).
// The parameters deliberately stay non-const references: making them const
// would let `2.0*qb` also match this overload through the non-explicit
// QBitsWaveFunction(int, ...) constructor, which makes that call ambiguous.
inline complex<double> operator * (QBitsWaveFunction& QBwf1,QBitsWaveFunction& QBwf2)
{ if(QBwf1.N!=QBwf2.N)
    { printf("error: scalar product of vectors %s,%s of different lengths\n",QBwf1.WFname,QBwf2.WFname);
      abort();
    }
  complex<double> s=0;
  for(long i=0; i<QBwf1.N; i++) s+=conj(QBwf1.wf[i])*QBwf2.wf[i];
  return s;
}

// sum of two wave functions: returns a new register named "sum_tmp" holding
// the amplitude-wise sum. Aborts if the lengths differ. The operands are taken
// by const reference, so const and temporary wave functions can be added.
// Declared inline because it is defined in a header (one-definition rule).
inline QBitsWaveFunction operator + (const QBitsWaveFunction& QBwf1,const QBitsWaveFunction& QBwf2)
{ if(QBwf1.N!=QBwf2.N)
    { printf("error: sum of vectors %s,%s of different lengths\n",QBwf1.WFname,QBwf2.WFname);
      abort();
    }
  QBitsWaveFunction QBwf(QBwf1.n,"sum_tmp");
  for(long i=0; i<QBwf1.N; i++) QBwf.wf[i]=QBwf1.wf[i]+QBwf2.wf[i];
  return QBwf;
}

#endif
