/* Copyright (C) 2011  Beat Röthlisberger
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

#include "mex.h"
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* MEX gateway function: the entry point MATLAB calls for this MEX file. */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]);


/* Forward declaration of auxiliary functions. Subscripts are 1-based and the
 * last position varies fastest:
 *   getIndex   - multi-index inds[0..n-1] over dims[0..n-1] -> linear index
 *   getIndices - linear index -> multi-index, returned in a newly
 *                mxCalloc'ed array of n ints */
int getIndex(const int *inds, const int n, const int *dims);
int *getIndices(const int index, const int n, const int *dims);


/* C implementation of pTrace_script.m */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
    bool isCmplx;
    
    int curr_s, i, ind, indx, j, k, l, m, mc, n, nr, s;
    int *dims, *dimred, *inds1, *inds1c, *inds2, *inds2c;
    
    double tmp_r, tmp_i;
    double *dimsTmp, *rhoR, *rhoI, *retR, *retI;
    
    mxArray *recSys, *ret;
    const mxArray *rho, *systems, *dimensions;

    mxArray **recArray;
    
    
    if (nrhs != 3) 
    { 
        mexErrMsgTxt("Three input arguments required"); 
    } 
    else if (nlhs > 1) 
    {
        mexErrMsgTxt("Too many output arguments"); 
    }     

    /* Get the input arguments */
    rho = prhs[0];
    rhoR = mxGetPr(rho);
    rhoI = mxGetPi(rho);
    
    /* These variables must be row vectors! */
    systems = prhs[1];
    dimensions = prhs[2];
    
    if (mxGetDimensions(systems)[1] == 1)
    {
        /* M: n = length(dimensions); */
        n = (int) mxGetDimensions(dimensions)[1]; 
        nr = n - 1;

        dims = mxCalloc(n, sizeof(int));
        for (i = 0; i < n; i = i + 1)
            dims[i] = (int) mxGetPr(dimensions)[i];
        
        s = (int) mxGetPr(systems)[0] - 1;
        
        if (s + 1 > n || s < 0)
            mxErrMsgTxt("Illegal system specified");

        m =  1;
        for (i = 0; i < n; i = i + 1)
            m = m*dims[i];
        
        mc = m;

        if (mc != mxGetDimensions(rho)[0] || mc != mxGetDimensions(rho)[1])
            mxErrMsgTxt("Density matrix and specified dimensions are incompatible");
        
        m = m/dims[s];

        dimred = mxCalloc(nr, sizeof(int));

        memcpy(dimred, dims, s*sizeof(int));
        memcpy(dimred + s, dims + s + 1, (n - s - 1)*sizeof(int));              
        
        ret = mxCreateDoubleMatrix(m, m, mxCOMPLEX);
        retR = mxGetPr(ret);
        retI = mxGetPi(ret);
        /* Assign return value */
        plhs[0] = ret;

        inds1  = mxCalloc(nr, sizeof(int));
        inds2  = mxCalloc(nr, sizeof(int));
        inds1c  = mxCalloc(n, sizeof(int));
        inds2c  = mxCalloc(n, sizeof(int));
        
        isCmplx = mxIsComplex(rho);
        
        for (i = 0; i < m; i = i + 1)
        {
            inds1 = getIndices(i + 1, nr, dimred);

            memcpy(inds1c, inds1, s*sizeof(int));
            memcpy(inds1c + s + 1, inds1 + s, (n - s - 1)*sizeof(int));              

            for (k = 0; k < m; k = k + 1)
            {
                ind  = m*k + i;
                
                inds2 = getIndices(k + 1, nr, dimred);
                memcpy(inds2c, inds2, s*sizeof(int));
                memcpy(inds2c + s + 1, inds2 + s, (n - s - 1)*sizeof(int));              
                
                tmp_r = 0;
                tmp_i = 0;
            
                for (l = 1; l <= dims[s]; l = l + 1)
                {
                    inds1c[s] = l;
                    inds2c[s] = l;
                    
                    indx = mc*(getIndex(inds2c, n, dims) - 1) + getIndex(inds1c, n, dims) - 1;
                    
                    retR[ind] = retR[ind] + rhoR[indx];
                    
                    if (isCmplx)
                        retI[ind] = retI[ind] + rhoI[indx];
                }
            }
        }
    }
    else
    {
        recArray = mxCalloc(3, sizeof(mxArray *));
        recSys = mxCreateDoubleMatrix(1,1, mxREAL);

        recArray[0] = prhs[0];
        recArray[1] = recSys;
        recArray[2] = dimensions;
        
        n = (int) mxGetDimensions(dimensions)[1];         

        for (i = mxGetDimensions(systems)[1] - 1; i >= 0; i = i - 1)
        {
            curr_s = mxGetPr(systems)[i];
            mxGetPr(recSys)[0] = curr_s;

            mexFunction(nlhs, plhs, 3, recArray);
            
            recArray[0] = plhs[0];
            
            dimsTmp = mxCalloc(n, sizeof(double));
            memcpy(dimsTmp, mxGetPr(recArray[2]), n*sizeof(double));

            n = n - 1;

            recArray[2] = mxCreateDoubleMatrix(1, n, mxREAL);
            memcpy(mxGetPr(recArray[2]), dimsTmp, (curr_s - 1)*sizeof(double));
            memcpy(mxGetPr(recArray[2]) + curr_s - 1, dimsTmp + curr_s, (n - curr_s + 1)*sizeof(double));
        }
    }
}



int getIndex(const int *inds, const int n, const int *dims)
{
    /* Convert a 1-based multi-index into a 1-based linear index.
     *
     * inds - n subscripts; inds[k] is expected to lie in 1..dims[k]
     * n    - number of elements in inds and dims
     * dims - extent of each position
     *
     * The last position varies fastest:
     *   result = 1 + sum_k (inds[k] - 1) * prod_{j > k} dims[j]
     * Returns 1 when n == 0.
     */
    int linear = 0;
    int k;

    /* Horner-style accumulation, starting from the most significant position. */
    for (k = 0; k < n; ++k)
        linear = linear * dims[k] + (inds[k] - 1);

    return linear + 1;
}



int *getIndices(const int index, const int n, const int *dims)
{
    /* Convert a 1-based linear index into a 1-based multi-index, using the
     * same ordering as getIndex (last position varies fastest).
     *
     * index - linear index, expected to lie in 1..prod(dims[0..n-1])
     * n     - number of elements in dims
     * dims  - extent of each position
     *
     * Returns a newly mxCalloc'ed array of n subscripts (at least one element
     * is allocated, so the pointer is valid even for n <= 0). The caller may
     * release it with mxFree.
     */
    int i, ix, prd, q;
    int *inds;

    inds = mxCalloc(n > 0 ? n : 1, sizeof *inds);

    /* No positions: there is nothing to fill. Writing inds[n - 1] here would
       be out of bounds. */
    if (n <= 0)
        return inds;

    /* prd = product of the extents of all positions after the current one.
       dims[0] is never needed, so it is not multiplied in (less overflow
       risk, no division by dims[0]). */
    prd = 1;
    for (i = 1; i < n; i++)
        prd *= dims[i];

    ix = index;
    for (i = 0; i < n - 1; i++) {
        q = (ix - 1) / prd;
        inds[i] = q + 1;
        ix -= q * prd;
        prd /= dims[i + 1];
    }

    /* The remainder is the subscript of the fastest-varying position. */
    inds[n - 1] = ix;

    return inds;
}

