%grad_convexSum    Gradient of a convex sum over a pure state entanglement measure.
%   G = grad_convexSum(U, PURE, GRAD_PURE, CHI, LAMBDA) calculates the gradient
%   of a convex sum h(U) at a specific pure-state decomposition parameterized by the
%   Stiefel matrix U. U must be of size k x r, where r is the length of LAMBDA, 
%   and k >= r. G has the same dimensions as U.
%
%   PURE is the pure-state measure, and GRAD_PURE its gradient. Both functions must 
%   take a state vector as their only argument. [CHI LAMBDA] is the eigendecomposition 
%   of a density matrix RHO, where LAMBDA is a column vector holding the positive eigenvalues 
%   of RHO and CHI the corresponding column matrix of eigenvectors.
%
%   See also: convexSum, densityEig

% Copyright (C) 2011 Beat Röthlisberger
%
% This program is free software: you can redistribute it and/or modify
% it under the terms of the GNU General Public License as published by
% the Free Software Foundation, either version 3 of the License, or
% (at your option) any later version.
% 
% This program is distributed in the hope that it will be useful,
% but WITHOUT ANY WARRANTY; without even the implied warranty of
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
% GNU General Public License for more details.
% 
% You should have received a copy of the GNU General Public License
% along with this program.  If not, see <http://www.gnu.org/licenses/>.

function res = grad_convexSum(U, PURE, GRAD_PURE, V, D)

% Gradient of the convex sum h(U) at the decomposition parameterized by U.
%
% Local names: V == CHI (eigenvector columns) and D == LAMBDA (the matching
% eigenvalues). Only the first r = length(D) columns of U and V are used;
% any further columns of U give zero entries in RES, which always has the
% size of U.
%
% For each row k, psDecomposition supplies a state psi(:, k) and a weight
% p(k). NOTE(review): the formula below assumes psi(:, k) is the normalized
% state and p(k) its weight; confirm against psDecomposition.
%
% Entry (k, l) of RES holds, in its real part,
%   2*D(l)*real(U(k,l))*val_p + <g, sqrt(D(l)*p(k))*V(:,l) - D(l)*real(U(k,l))*psi_k>
% and in its imaginary part
%   2*D(l)*imag(U(k,l))*val_p + <g, 1i*sqrt(D(l)*p(k))*V(:,l) - D(l)*imag(U(k,l))*psi_k>
% with g = GRAD_PURE(psi_k), val_p = PURE(psi_k) and <a, b> = real(a'*b).
% Collecting terms, a whole row can be formed at once:
%   res(k, 1:r) = (2*val_p - real(g'*psi_k)) * (D.' .* U(k, 1:r))
%                 + sqrt(D.' * p(k)) .* (V(:, 1:r)' * g).'
% The last term uses s*real(w) - 1i*s*imag(w) = s*conj(w) and
% conj(g'*V) = (V'*g).' (non-conjugating transpose).

r = length(D);

% Eigenvalues as a row vector, regardless of the orientation of D, so the
% element-wise products below never expand into an r x r matrix.
lambda = D(:).';
chi    = V(:, 1:r);

[psi, p] = psDecomposition(U(:, 1:r), V, D);

res = zeros(size(U));

for k = 1:size(U, 1)
    psi_k  = psi(:, k);
    val_p  = PURE(psi_k);
    grad_p = GRAD_PURE(psi_k);   % expected to be a column vector like psi_k

    % <g, psi_k> does not depend on l, so it is computed once per row.
    c = real(grad_p' * psi_k);

    res(k, 1:r) = (2*val_p - c) * (lambda .* U(k, 1:r)) ...
                  + sqrt(lambda * p(k)) .* (chi' * grad_p).';
end
