% l1eqmodcs_large.m
%
% Solve Modified CS for large size data
% This function is written by modifying l1eq_pd.m function from Candes' l1-magic
% package
%
% min_x ||x_{T^c}||_1  s.t.  Ax = b
%
% Recast as linear program
% min_{x,u} sum(u)  s.t.  -u <= x_{T^c} <= u,  Ax=b
% and use primal-dual interior point method
%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%% Parameters%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% x0 - Nx1 vector, initial point.
%
% A - a function handle to a function that takes a N vector and returns a K 
%     vector.  The algorithm operates in "largescale" mode, solving the Newton systems via the
%     Conjugate Gradients algorithm.
%
% At - Handle to a function that takes a K vector and returns an N vector.
%      This function only accepts A as a function handle (a matrix A is
%      rejected), so At is always required.
%
% b - Kx1 vector of observations.
%
% P - Indicator vector of the same length as the signal: P(i) = 1 for
%     indices i in T^c (the entries penalized by the l1 norm) and P(i) = 0
%     for indices i in T.
%
% mask - Sampling mask
%
% pdtol - Tolerance for primal-dual algorithm (algorithm terminates if
%     the surrogate duality gap is less than pdtol).
%     Default = 1e-4.
%
% pdmaxiter - Maximum number of primal-dual iterations.  
%     Default = 20.
%
% cgtol - Tolerance for Conjugate Gradients.
%     Default = 1e-9.
%
% cgmaxiter - Maximum number of iterations for Conjugate Gradients.
%     Default = 800.
%    



function xp = l1eqmodcs_large(x0, A, At, b, P, mask, pdtol, pdmaxiter, cgtol, cgmaxiter)
% L1EQMODCS_LARGE  Primal-dual interior point solver for
%     min_x ||x_{T^c}||_1  s.t.  A(x) = b
% with A and At supplied as function handles.
%
% Inputs
%   x0        - initial point (length-N vector).
%   A, At     - function handles for the forward operator and the operator
%               applied to the dual variable.
%   b         - observations.
%   P         - indicator vector: entries equal to 1 mark T^c, entries
%               equal to 0 mark T.
%   mask      - forwarded unchanged to hessianmxFT.
%   pdtol     - surrogate duality gap tolerance      (default 1e-4).
%   pdmaxiter - maximum primal-dual iterations       (default 20).
%   cgtol     - tolerance passed to cgsolve          (default 1e-9).
%   cgmaxiter - iteration limit passed to cgsolve    (default 800).
%   Omitted or empty optional arguments take their default.
%
% Output
%   xp - the final iterate. The most recent accepted iterate is returned
%        when the Newton system cannot be solved or the line search gets
%        stuck; x0 is returned when A is not a function handle or when the
%        termination test already holds for x0.

largescale = isa(A,'function_handle');

if ~largescale
    disp('A must be a function handle to implement large scale data')
    % Assign the output before leaving; otherwise a caller that requests xp
    % gets an "output argument not assigned" error on top of the message.
    xp = x0;
    return
end

if (nargin < 7)  || isempty(pdtol),     pdtol = 1e-4;     end
if (nargin < 8)  || isempty(pdmaxiter), pdmaxiter = 20;   end
if (nargin < 9)  || isempty(cgtol),     cgtol = 1e-9;     end
if (nargin < 10) || isempty(cgmaxiter), cgmaxiter = 800;  end

N = length(x0);

alpha = 0.01;   % sufficient-decrease factor used by the line search
beta = 0.5;     % step shrink factor used by the line search
mu = 5;         % scales the barrier parameter tau

x = x0;
% Make sure the output exists even if the loop below never executes
% (the termination test can already hold for the starting point).
xp = x;

Tnz = find(P==0);   % indices in T
Tc = find(P==1);    % indices in T^c
nTc = length(Tc);

u = (0.95)*abs(x0(Tc)) + (0.1)*max(abs(x0(Tc)));

gradf0 = [zeros(N,1); ones(nTc,1)];

% inequality constraint values; both must stay negative
fu1 = x(Tc) - u;
fu2 = -x(Tc) - u;

% initial multipliers, embedded in length(x) vectors so that A can act on them
lamu1full = zeros(length(x),1);
lamu2full = zeros(length(x),1);
lamu1full(Tc) = -(1./fu1);
lamu2full(Tc) = -(1./fu2);

lamu1 = lamu1full(Tc);
lamu2 = lamu2full(Tc);

v = -A(lamu1full-lamu2full);
Atv = At(v);
Ax = A(x);
K = length(Ax);     % measurement length, reused when splitting CG vectors
rpri = Ax - b;

sdg = -(fu1'*lamu1 + fu2'*lamu2);
tau = mu*2*N/sdg*4;

rcent = [-(lamu1.*fu1); -(lamu2.*fu2)] - (1/tau);
rdual = gradf0 + [P0t(lamu1-lamu2,Tc,N); (-lamu1-lamu2)] + [Atv; zeros(nTc,1)];
resnorm = norm([rdual; rcent; rpri]);

pditer = 0;
% NOTE(review): the primal-residual threshold here (0.01) differs from the
% one used inside the loop (0.001); kept as in the original - confirm intent.
done = (sdg < pdtol) || (pditer >= pdmaxiter) || (norm(rpri)<0.01);

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%function handle definition
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
AsubT = @(z) Grsub(z,A,Tnz,N);                 %A_T
AtsubT = @(z) Grtsub(z,At,Tnz);                %A'_T
AsubTc= @(z) Grsub(z,A,Tc,N);                  %A_(Tc)
AtsubTc= @(z) Grtsub(z,At,Tc);                 %A'_(Tc)

while (~done)

  pditer = pditer + 1;

  w2 = -1 - 1/tau*(1./fu1 + 1./fu2);
  w3 = -rpri;

  sig1 = (-lamu1./fu1 - lamu2./fu2);
  sig2 = (lamu1./fu1 - lamu2./fu2);
  sigx = sig1-(sig2./sig1).*sig2;

  % reduced Newton system in (dx_T, dv), solved with conjugate gradients
  AhessianFT = @(z) hessianmxFT(z,AsubT,AtsubT,AsubTc,AtsubTc,sigx,Tnz,Tc,mask);
  faitb = @(z) ([AtsubT(z(1:K)); ...
                 -AsubTc((1./sigx).*AtsubTc(z(1:K))) + AsubT(z(K+1:end))]);

  C0 = AsubTc((1./sigx).*(-1/tau*(1./fu1-1./fu2)+(sig2./sig1).*w2)) ...
       + AsubTc((1./sigx).*AtsubTc(v)) + w3;
  C = [C0; -AtsubT(v)];
  w1p = real(faitb(C));

  [dxTv, cgres, cgiter] = cgsolve(AhessianFT, w1p, cgtol, cgmaxiter, 0);

  if (cgres > 1/2)
    disp('Primal-dual: Cannot solve system.  Returning previous iterate.');
    xp = x;
    return
  end

  % recover the full step from the CG solution
  dxT = dxTv(1:length(Tnz));
  dv = dxTv(length(Tnz)+1:end);
  dxTc = (-1./sigx).*AtsubTc(v+dv) + (1./sigx).*((1./fu1-1./fu2)/tau-(sig2./sig1).*w2);
  dx = zeros(length(x0),1);
  dx(Tnz) = dxT;
  dx(Tc) = dxTc;
  du = (1./sig1).*(w2-sig2.*dxTc);
  Adx = A(dx);
  Atdv = At(dv);

  dlamu1 = (lamu1./fu1).*(-P0(dx,Tc)+du) - lamu1 - (1/tau)./fu1;
  dlamu2 = (lamu2./fu2).*(P0(dx,Tc)+du) - lamu2 - 1/tau./fu2;

  % make sure that the step is feasible: keeps lamu1,lamu2 > 0, fu1,fu2 < 0
  indp = find(dlamu1 < 0);
  indn = find(dlamu2 < 0);
  s = min([1; -lamu1(indp)./dlamu1(indp); -lamu2(indn)./dlamu2(indn)]);
  indp = find((dx(Tc)-du) > 0);
  indn = find((-dx(Tc)-du) > 0);
  s = (0.99)*min([s; -fu1(indp)./(dx(Tc(indp))-du(indp)); -fu2(indn)./(-dx(Tc(indn))-du(indn))]);

  % backtracking line search: shrink s until the residual norm decreases enough
  backiter = 0;
  while true
    xp = x + s*dx;  up = u + s*du;
    vp = v + s*dv;  Atvp = Atv + s*Atdv;
    lamu1p = lamu1 + s*dlamu1;  lamu2p = lamu2 + s*dlamu2;
    fu1p = xp(Tc) - up;  fu2p = -xp(Tc) - up;
    rdp = gradf0 + [P0t(lamu1p-lamu2p,Tc,N); (-lamu1p-lamu2p)] + [Atvp; zeros(nTc,1)];
    rcp = [-lamu1p.*fu1p; -lamu2p.*fu2p] - (1/tau);
    rpp = rpri + s*Adx;

    % written as ~(a > b) so that a NaN norm ends the search, exactly as a
    % "while (a > b)" loop condition would
    if ~(norm([rdp; rcp; rpp]) > (1-alpha*s)*resnorm)
      break
    end

    backiter = backiter + 1;
    if (backiter > 50)
      disp('Stuck backtracking, returning last iterate.')
      xp = x;
      return
    end
    s = beta*s;
  end

  % next iteration
  x = xp;  u = up;
  v = vp;  Atv = Atvp;
  lamu1 = lamu1p;  lamu2 = lamu2p;
  fu1 = fu1p;  fu2 = fu2p;

  % surrogate duality gap
  sdg = -(fu1'*lamu1 + fu2'*lamu2);
  tau = mu*2*N/sdg*4;
  rpri = rpp;
  rcent = [-(lamu1.*fu1); -(lamu2.*fu2)] - (1/tau);
  rdual = gradf0 + [P0t(lamu1-lamu2,Tc,N); (-lamu1-lamu2)] + [Atv; zeros(nTc,1)];
  resnorm = norm([rdual; rcent; rpri]);

  done = (sdg < pdtol) || (pditer >= pdmaxiter) || (norm(rpri)<0.001);

  fprintf('Iteration = %d, Primal = %8.3e, PDGap = %8.3e, Dual res = %8.3e, Primal res = %8.3e\n',...
    pditer,  sum(u), sdg, norm(rdual), norm(rpri));

  % report the conjugate gradient statistics returned by cgsolve
  fprintf('                  CG Res = %8.3e, CG Iter = %d\n', cgres, cgiter);

end

%%%% P0 extracts the entries of a vector x indexed by T^c
function y= P0(x,Tc)
% P0  Restriction operator: return the entries of x at the positions in Tc.
%   x  - input vector.
%   Tc - index set to keep.
%   y  - x(Tc).
picked = x(Tc);
y = picked;

%%% P0t creates a length-N zero vector x and places y at the indices T^c
function x= P0t(y,Tc,N)
% P0t  Zero-padding operator: place y at the positions Tc of an N-by-1
%      vector whose remaining entries are zero.
%   y  - values to insert.
%   Tc - positions that receive the values of y.
%   N  - length of the output column vector.
padded = zeros(N,1);
padded(Tc) = y;
x = padded;