/* Copyright (C) 2011  Beat Röthlisberger
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

#include <octave/oct.h>
#include <math.h>

using namespace std;

/* C++ implementation of buildUnitary_script.m  (OCTAVE port)
 *
 * Usage from Octave:  U = buildUnitary(X, m, n)
 *
 *   X  real row or column vector holding exactly n(2m - n) angles
 *   m  number of rows of the result
 *   n  number of columns of the result (n > 0, m >= n)
 *
 * Returns the complex m-by-n matrix U assembled from the angles.
 *
 * Layout of X, with P = m*n - n(n + 1)/2 (all offsets zero-based):
 *   X[0   .. P)       rotation angles, one per 2x2 rotation
 *   X[P   .. 2P)      phase angles,    one per 2x2 rotation
 *   X[2P  .. 2P + n)  phases placed on the diagonal of U
 *
 * Construction: U starts as zero with exp(i*X[2P + j]) at position (j, j).
 * Then, for i = n-1 down to 0 and k = i up to m-2, rows k and k+1 of
 * columns i..n-1 are multiplied from the left by the 2x2 matrix
 *
 *     [ cos(t)*conj(p)   -sin(t)*conj(p) ]
 *     [ sin(t)*p           cos(t)*p      ]
 *
 * where t is the next unused rotation angle and p = exp(i*phase) is built
 * from the next unused phase angle.  Exactly P rotations are applied.
 */
DEFUN_DLD(buildUnitary, rhs, nlhs, "*****\nIn Octave, please type 'help buildUnitary.m'\n*****\n\n")
{
    const int nrhs = rhs.length();

    if (nrhs != 3)
    {
        error("Three input arguments required.");
        return octave_value_list();
    }
    else if (nlhs > 1)
    {
        error("Too many output arguments.");
        return octave_value_list();
    }

    // const: element reads then use the const accessor, so the data shared
    // with the caller's value is never duplicated by copy-on-write.
    const Matrix X = rhs(0).matrix_value();
    const int m = rhs(1).int_value();
    const int n = rhs(2).int_value();

    if (n <= 0 || m < n)
    {
        error("Illegal input. Need n > 0, m >= n");
        return octave_value_list();
    }

    if (X.rows() != 1 && X.cols() != 1)
    {
        error("Angles must be in vector form.");
        return octave_value_list();
    }

    // Do the index arithmetic in octave_idx_type so that large m, n cannot
    // overflow a plain int.
    const octave_idx_type n_rows = m;
    const octave_idx_type n_cols = n;
    const octave_idx_type n_angles = n_cols*(2*n_rows - n_cols);

    // Compare the element count, not the larger dimension: an empty 0x1 or
    // 1x0 input has a "largest dimension" of 1 and would otherwise slip
    // through for m == n == 1 and be read out of bounds below.
    if (X.numel() != n_angles)
    {
        error("Wrong number of angles. Need n(2m - n).");
        return octave_value_list();
    }

    // Offsets of the three sections of X (see the header comment).
    const octave_idx_type rot_off = 0;
    const octave_idx_type phase_off = n_rows*n_cols - (n_cols*(n_cols + 1))/2;
    const octave_idx_type diag_off = 2*phase_off;

    // The algorithm relies on every off-diagonal entry starting at zero,
    // so request the fill explicitly instead of relying on the allocator.
    ComplexMatrix U(n_rows, n_cols, Complex(0.0, 0.0));

    // Fill diagonal with phases
    for (octave_idx_type j = 0; j < n_cols; j++)
        U(j, j) = std::exp(Complex(0, X(diag_off + j)));

    octave_idx_type cnt = 0;   // index of the next unused (angle, phase) pair

    for (octave_idx_type i = n_cols - 1; i >= 0; i--)
    {
        for (octave_idx_type k = i; k < n_rows - 1; k++)
        {
            const double c = std::cos(X(rot_off + cnt));
            const double s = std::sin(X(rot_off + cnt));

            const Complex p = std::exp(Complex(0, X(phase_off + cnt)));
            const Complex p_conj = std::conj(p);

            // Entries of the 2x2 rotation; they do not depend on the column,
            // so compute them once per (i, k).  The products are grouped the
            // same way as when written inline, so results are unchanged.
            const Complex r00 = c*p_conj;
            const Complex r01 = s*p_conj;
            const Complex r10 = s*p;
            const Complex r11 = c*p;

            // Columns left of i are still zero in rows k and k+1, so only
            // columns i..n-1 need updating.
            for (octave_idx_type l = i; l < n_cols; l++)
            {
                const Complex upper = U(k,     l);
                const Complex lower = U(k + 1, l);

                U(k,     l) = r00*upper - r01*lower;
                U(k + 1, l) = r10*upper + r11*lower;
            }

            cnt++;
        }
    }

    octave_value_list ret;
    ret.append(U);

    return ret;
}
