/* Copyright (C) 2011  Beat Röthlisberger
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

#include "mex.h"
#include <math.h>


/* C implementation of grad_buildUnitary_script.m
 *
 * The m-by-n complex matrix is assembled from the angle vector X
 * (length n(2m - n)), which is laid out as
 *   X[0      .. pp)        rotation angles          (enter via cos/sin)
 *   X[pp     .. 2*pp)      rotation phases          (enter via cos/sin)
 *   X[2*pp   .. 2*pp + n)  diagonal phases "chi"    (enter via cos/sin)
 * with pp = m*n - n*(n + 1)/2 (the number of 2x2 rotations applied).
 *
 * Every angle enters only through cos/sin, and
 *   cos(x + pi/2) = -sin(x) = d/dx cos(x),  sin(x + pi/2) = cos(x) = d/dx sin(x),
 * so the derivative of a single factor is that factor evaluated at its
 * angle + pi/2.
 */

/* Returns X[j], shifted by +pi/2 when j is the differentiated angle `ind`.
 * The shift is applied on read instead of being written into the input array:
 * MEX inputs must not be modified, and (x + pi/2) - pi/2 does not always give
 * back x exactly in floating point. */
static double shifted_angle(const double *X, int j, int ind)
{
    const double PI = 3.141592653589793;

    return (j == ind) ? X[j] + PI/2 : X[j];
}

/* Multiplies rows k and k+1, columns first..n-1, of the column-major complex
 * matrix (re, im) with m rows, in place, by the 2x2 block
 *   [ c*conj(p)  -s*conj(p) ]
 *   [ s*p         c*p       ]
 * where p = p_re + i*p_im. All other entries are left untouched. */
static void rotate_rows(double *re, double *im, int m, int n, int k, int first,
                        double c, double s, double p_re, double p_im)
{
    int col, pu, pl;
    double tvr, tvi, tUrl;

    for (col = first; col < n; col++)
    {
        pu = col*m + k;
        pl = col*m + k + 1;

        tvr = re[pu];
        tvi = im[pu];

        re[pu] = c*(p_re*tvr + p_im*tvi) - s*(p_re*re[pl] + p_im*im[pl]);
        im[pu] = c*(p_re*tvi - p_im*tvr) - s*(p_re*im[pl] - p_im*re[pl]);

        /* re[pl] is overwritten next but still needed for im[pl]. */
        tUrl   = re[pl];

        re[pl] = s*(p_re*tvr - p_im*tvi) + c*(p_re*re[pl] - p_im*im[pl]);
        im[pl] = s*(p_re*tvi + p_im*tvr) + c*(p_re*im[pl] + p_im*tUrl);
    }
}

/* Copies rows k and k+1, columns first..n-1, from (src_re, src_im) to
 * (dst_re, dst_im); both are column-major with m rows. */
static void copy_rows(const double *src_re, const double *src_im,
                      double *dst_re, double *dst_im,
                      int m, int n, int k, int first)
{
    int col, pu, pl;

    for (col = first; col < n; col++)
    {
        pu = col*m + k;
        pl = col*m + k + 1;

        dst_re[pu] = src_re[pu];
        dst_im[pu] = src_im[pu];
        dst_re[pl] = src_re[pl];
        dst_im[pl] = src_im[pl];
    }
}

/* MEX gateway.
 *   prhs[0]  angle vector X (row or column), length n(2m - n)
 *   prhs[1]  m, prhs[2] n   (scalars, n > 0, m >= n)
 *   prhs[3]  1-based index of the angle to differentiate with respect to
 *   plhs[0]  m-by-n complex matrix: the derivative w.r.t. that angle
 * The input arrays are only read, never modified. */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
    int a, cnt, i, k, ind, length_x;
    int m, n, pt, pp, pc;
    double tXc, tXs, pXr, pXi;
    double *Ur, *Ui, *Utmp_r, *Utmp_i;
    const double *X;
    mxArray *U, *Utmp_, *result;

    if (nrhs != 4)
    {
        mexErrMsgTxt("Four input arguments required.");
    }
    else if (nlhs > 1)
    {
        mexErrMsgTxt("Too many output arguments.");
    }

    /* Validate types before dereferencing data pointers: mxGetPr is only
     * meaningful for full double arrays, and reading [0] needs an element. */
    for (a = 0; a < 4; a++)
    {
        if (!mxIsDouble(prhs[a]) || mxIsComplex(prhs[a]) || mxIsSparse(prhs[a]))
            mexErrMsgTxt("All inputs must be real, full double arrays.");
    }
    for (a = 1; a < 4; a++)
    {
        if (mxGetNumberOfElements(prhs[a]) != 1)
            mexErrMsgTxt("m, n and the angle index must be scalars.");
    }

    X = mxGetPr(prhs[0]);
    m = (int) mxGetPr(prhs[1])[0];
    n = (int) mxGetPr(prhs[2])[0];
    length_x = (int) (mxGetM(prhs[0]) > mxGetN(prhs[0]) ? mxGetM(prhs[0]) : mxGetN(prhs[0]));

    /* Zero-based index of the angle with respect to which the derivative is taken */
    ind = (int) mxGetPr(prhs[3])[0] - 1;

    if (n <= 0 || m < n)
        mexErrMsgTxt("Illegal input. Need n > 0, m >= n");

    if (mxGetM(prhs[0]) != 1 && mxGetN(prhs[0]) != 1)
        mexErrMsgTxt("Angles must be in vector form.");

    if (length_x != n*(2*m - n))
        mexErrMsgTxt("Wrong number of angles. Need n(2m - n).");

    if (ind < 0 || ind >= length_x)
        mexErrMsgTxt("Illegal angle index.");

    pt = 0;                        /* offset of the rotation angles */
    pp = m*n - (n*(n + 1))/2;      /* offset of the rotation phases */
    pc = 2*pp;                     /* offset of the diagonal phases (chi) */

    /* U accumulates the plain product up to the differentiated factor.
     * Utmp_ starts at zero and receives the derivative: a factor's derivative
     * vanishes outside the entries that factor touches, so only those entries
     * are moved into Utmp_, and the remaining factors are applied there. */
    U = mxCreateDoubleMatrix(m, n, mxCOMPLEX);
    Ur = mxGetPr(U);
    Ui = mxGetPi(U);

    Utmp_ = mxCreateDoubleMatrix(m, n, mxCOMPLEX);
    Utmp_r = mxGetPr(Utmp_);
    Utmp_i = mxGetPi(Utmp_);

    result = U;

    /* Set up phases (depending on whether the derivative with respect to a chi is demanded or not). */
    if (ind >= pc)
    {
        /* Only the differentiated diagonal entry is non-zero in the derivative. */
        Utmp_r[(ind - pc)*(m + 1)] = cos(shifted_angle(X, ind, ind));
        Utmp_i[(ind - pc)*(m + 1)] = sin(shifted_angle(X, ind, ind));

        Ur = Utmp_r;
        Ui = Utmp_i;
        result = Utmp_;
    }
    else
    {
        for (i = 0; i < n; i++)
        {
            Ur[i*(m + 1)] = cos(shifted_angle(X, pc + i, ind));
            Ui[i*(m + 1)] = sin(shifted_angle(X, pc + i, ind));
        }
    }

    cnt = 0;

    for (i = n - 1; i >= 0; i--)
    {
        for (k = i; k < m - 1; k++)
        {
            tXc = cos(shifted_angle(X, pt + cnt, ind));
            tXs = sin(shifted_angle(X, pt + cnt, ind));

            pXr = cos(shifted_angle(X, pp + cnt, ind));
            pXi = sin(shifted_angle(X, pp + cnt, ind));

            rotate_rows(Ur, Ui, m, n, k, i, tXc, tXs, pXr, pXi);

            if (pt + cnt == ind || pp + cnt == ind)
            {
                /* The differentiated factor was just applied: move the rows it
                 * touched into the zero-initialized derivative matrix and
                 * continue the product there. */
                copy_rows(Ur, Ui, Utmp_r, Utmp_i, m, n, k, i);

                Ur = Utmp_r;
                Ui = Utmp_i;
                result = Utmp_;
            }

            cnt++;
        }
    }

    /* Release whichever matrix is not handed back to MATLAB. */
    if (result == Utmp_)
        mxDestroyArray(U);
    else
        mxDestroyArray(Utmp_);

    plhs[0] = result;
}
