/* Copyright (C) 2011  Beat Röthlisberger
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

#include <octave/oct.h>

#include <algorithm>
#include <cmath>
#include <complex>
#include <iostream>
#include <vector>

using namespace std;


/* Forward declaration of auxiliary functions */
int getIndex(const vector<int> &inds, const vector<int> &dims);
vector<int> getIndices(const int index, const vector<int> &dims);


/* C++ implementation of pTrace_script.m */
octave_value_list fun_rec(const octave_value_list &rhs)
{
	int nrhs = rhs.length();

    if (nrhs != 3) 
    { 
        error("Three input arguments required"); 
        return octave_value_list();
    } 

    // Get the input arguments
    ComplexMatrix rho = rhs(0).complex_matrix_value();
    
    // These variables must be row vectors!
	Matrix systems = rhs(1).matrix_value();
    Matrix dims_tmp = rhs(2).matrix_value();

    // TODO: ALSO INSERT THIS IN THE C (MATLAB) FILE!
    if (systems.dims()(0) != 1 || dims_tmp.dims()(0) != 1)
    {
        error("Lists of systems to trace out and system dimensions must both be row vectors."); 
        return octave_value_list();
    }     

	vector<int> dimensions; 
	
	for (int i = 0; i < dims_tmp.dims()(1); ++i)
		dimensions.push_back(dims_tmp(0, i));
    
    int nrSystems = systems.dims()(1);

	ComplexMatrix ret;
    
    // If we only have to trace out one subsystem...
    if (nrSystems == 1)
    {
        // M: n = length(dimensions);
        int n = dimensions.size(); 
        int nr = n - 1;

       	int s = systems(0, 0) - 1;
        
        if (s >= n || s < 0)
	    {
	   	    error("Illegal system specified"); 
	        return octave_value_list();
	   	}     

        int m =  1;
        for (int i = 0; i < n; ++i)
            m = m*dimensions[i];
        
        int mc = m;

        if (mc != rho.dims()(0) || mc != rho.dims()(1))
	    {
	        error("Density matrix and specified dimensions are incompatible"); 
	        return octave_value_list();
	    }     

        m = m/dimensions[s];

        vector<int> dimred;
        
        for (int i = 0; i < n; ++i)
        {
        	if (i != s)
        		dimred.push_back(dimensions[i]);
        }
        
        ret = ComplexMatrix(m, m);

		vector<int> inds1, inds2, inds1c, inds2c;

        for (int i = 0; i < m; ++i)
        {
            inds1 = getIndices(i + 1, dimred);
            inds1c = inds1;
            
            vector<int>::iterator itr1 = inds1c.begin();
            for (int r = 0; r < s; ++r) ++itr1;
            inds1c.insert(itr1, 0);


            for (int k = 0; k < m; ++k)
            {
                inds2 = getIndices(k + 1, dimred);
                inds2c = inds2;
	
	            vector<int>::iterator itr2 = inds2c.begin();
	            for (int r = 0; r < s; ++r) ++itr2;
	            inds2c.insert(itr2, 0);

                for (int l = 1; l <= dimensions[s]; ++l)
                {
                    inds1c[s] = l;
                    inds2c[s] = l;
                    
                    ret(i, k) += rho(getIndex(inds1c, dimensions) - 1, getIndex(inds2c, dimensions) - 1);
                }
            }
        }
        
    }

    else
    {
    	ret = rho;
    	
    	for (int i = nrSystems; i > 0; --i)
    	{
	    	octave_value_list tmpList;
	    	tmpList.append(ret);
	    	tmpList.append(octave_value(systems(0, i-1)));
	    	
	    	Matrix dims_oct(1, dimensions.size());
	    	for (int k = 0; k < dimensions.size(); ++k) dims_oct(0, k) = dimensions[k];
	    	tmpList.append(dims_oct);


			tmpList = fun_rec(tmpList);
			
			ret = tmpList(0).complex_matrix_value();
			dimensions.pop_back();			    	
    	}
	}    	
    
    octave_value_list ret_list;
    ret_list.append(ret);
    return ret_list;
}    	




/* Convert 1-based per-subsystem indices into a 1-based linear index.
 *
 * inds - one 1-based index per entry of dims; only the first dims.size()
 *        entries are read.
 * dims - subsystem dimensions.
 *
 * The last subsystem varies fastest, i.e. the result is
 * 1 + sum_j (inds[j] - 1) * prod_{t > j} dims[t].
 * getIndices performs the inverse mapping.
 */
int getIndex(const vector<int> &inds, const vector<int> &dims)
{
    const int count = static_cast<int>(dims.size());

    // Horner evaluation of the mixed-radix number (inds - 1) in base dims.
    int linear = 0;
    for (int pos = 0; pos < count; ++pos)
        linear = linear * dims[pos] + (inds[pos] - 1);

    return linear + 1;
}




/* Convert a 1-based linear index into 1-based per-subsystem indices.
 *
 * index - 1-based linear index (1 .. prod(dims) for an in-range result).
 * dims  - subsystem dimensions.
 *
 * Inverse of getIndex: the last subsystem varies fastest. Returns one entry
 * per dimension; if dims is empty the result is the single value {index}.
 */
vector<int> getIndices(const int index, const vector<int> &dims)
{
    const int count = static_cast<int>(dims.size());

    // Start from the product of all dimensions and peel off one factor per
    // leading subsystem, so `stride` is the weight of the current digit.
    int stride = 1;
    for (vector<int>::const_iterator it = dims.begin(); it != dims.end(); ++it)
        stride *= *it;

    vector<int> result;
    int remaining = index;

    for (int pos = 0; pos + 1 < count; ++pos)
    {
        stride /= dims[pos];
        const int digit = (remaining - 1) / stride;
        result.push_back(digit + 1);
        remaining -= digit * stride;
    }

    // Whatever is left is the (1-based) index of the last subsystem.
    result.push_back(remaining);
    return result;
}


/* Octave entry point: pTrace(rho, systems, dims).
 *
 * Forwards the arguments unchanged to fun_rec. Requesting more than one
 * output is rejected because only the reduced matrix is produced.
 */
DEFUN_DLD(pTrace, rhs, nlhs, "*****\nIn Octave, please type 'help pTrace.m'\n*****\n\n")
{
    if (nlhs <= 1)
        return fun_rec(rhs);

    error("Too many output arguments");
    return octave_value_list();
}



