% Monte Carlo comparison of dynamic sparse-recovery algorithms:
% reset the session, put the third-party solvers on the path, load the
% pre-generated problem instances and fix the experiment constants.
clear
clc

% Solver code from other groups, plus the CVX modelling toolbox.
addpath(genpath('code-from-other-groups'));
cvx_setup;

% Problem instances; expected to provide the cell arrays A0_res, A_res,
% y0_res, y_res and x_res that are indexed per Monte Carlo run below.
load('data/generated_data_it_model_new.mat');

% ---- Dimensions ----------------------------------------------------------
global m;          % signal length; declared global (presumably read by helper functions - confirm)
m      = 256;
n0     = 180;      % number of measurements for the first 2 signals
n      = 33;       % number of measurements for the remaining signals
seqlen = 100;      % number of time frames per sequence
s      = 26;       % support size
mc     = 100;      % number of Monte Carlo runs

% ---- Measurement noise levels -------------------------------------------
sigma_obs0 = 0.001;   % noise level for the first 2 signals
sigma_obs  = 0.02;    % noise level for the remaining signals

% ---- Support-detection threshold ----------------------------------------
a   = 2;
rho = a/2;            % entries with |xhat| > rho are treated as support

% Main Monte Carlo loop. Run c loads one problem realisation and feeds it to
% every algorithm in turn. For each method it stores the per-frame estimates,
% the relative errors ||xhat(:,t)-x(:,t)||/||x(:,t)|| and the wall-clock time.
%
% xhat / err / Nhat are shared scratch buffers: a method that starts at frame
% 2 or 3 inherits the earlier columns from the method that ran before it
% (flagged with NOTE(review) where that happens).
Ieye = eye(m); % identity: its rows/columns build support selectors and KF covariances
for c=1:mc

    A0 = A0_res{c};
    A  = A_res{c};
    y0 = y0_res{c};
    y  = y_res{c};
    x  = x_res{c};

    % Per-run scratch buffers. CS below rewrites every column before it is
    % read, so this only avoids growing the arrays on the first run.
    % (err is deliberately not called `error`, which would shadow the builtin.)
    xhat = zeros(m, seqlen);
    err  = zeros(1, seqlen);
    Nhat = cell(1, seqlen);

    %% CS: BPDN solved independently on every frame
    tic;
    % Regularisation weights for the two measurement regimes (loop-invariant).
    gamma0     = max(1e-2*max(max(abs(A0'*y0(:,1:2)))), sigma_obs0*sqrt(log(m)));
    gamma_rest = max(1e-2*max(max(abs(A'*y))), sigma_obs*sqrt(log(m)));
    for i=1:seqlen
        if i<=2
            y2 = y0(:,i);
            A2 = A0;
            gamma = gamma0;
        else
            y2 = y(:,i);
            A2 = A;
            gamma = gamma_rest;
        end

        cvx_begin
            variable b(m,1)
            minimize( gamma*norm(b,1) + 0.5*square_pos(norm(y2-A2*b)) )
        cvx_end

        if any(isnan(b)) || isnan(cvx_optval) % cvx failed: fall back to the zero vector
            xhat(:,i) = 0;
        else
            xhat(:,i) = b;
        end

        err(i)  = norm(xhat(:,i)-x(:,i))/norm(x(:,i));
        Nhat{i} = find(abs(xhat(:,i))>rho);
    end
    xhat_cs_cvx(:,:,c) = xhat;
    error_cs_cvx(:,c)  = err;
    Nhat_cs_cvx        = Nhat;
    t_cs_cvx(c)        = toc;

    %%% Model-parameter estimates from the first two CS frames
    Tinit  = find(abs(xhat_cs_cvx(:,1,c))>rho);            % support estimate of frame 1
    deltax = xhat_cs_cvx(:,2,c) - xhat_cs_cvx(:,1,c);
    deltax(setdiff(Nhat_cs_cvx{2},Nhat_cs_cvx{1})) = 0;    % drop indices that enter the support at frame 2
    sigma_sys_hat = sqrt(sum(deltax.*deltax)/length(Nhat_cs_cvx{2}));

    %% CS-residual: BPDN on the measurement residual of the previous estimate
    % (reuses the last gamma set by the CS loop)
    tic;
    for i=2:seqlen
        if i==2
            xhat(:,i-1) = xhat_cs_cvx(:,i-1,c);
            y2 = y0(:,i) - A0*xhat(:,i-1);
            A2 = A0;
        else
            y2 = y(:,i) - A*xhat(:,i-1);
            A2 = A;
        end

        cvx_begin
            variable b(m,1)
            minimize( gamma*norm(b,1) + 0.5*square_pos(norm(y2-A2*b)) )
        cvx_end

        if any(isnan(b)) || isnan(cvx_optval) % cvx failed: keep the previous estimate
            xhat(:,i) = xhat(:,i-1);
        else
            xhat(:,i) = xhat(:,i-1) + b;
        end

        err(i)  = norm(xhat(:,i)-x(:,i))/norm(x(:,i));
        Nhat{i} = find(abs(xhat(:,i))>rho);
    end
    xhat_cs_res_cvx(:,:,c) = xhat;
    error_cs_res_cvx(:,c)  = err;
    t_cs_res_cvx(c)        = toc;

    %% Reg-Mod-BPDN (cvx)
    % Frames 1-2: gamma_star picks (gamma, lambda) from lambdaset, then solve.
    lambdaset = [0.5 0.2 0.1 0.05 0.01 0.005 0.001 0.0001];
    for i=1:2
        if i==1
            mu_hat = xhat_cs_cvx(:,i,c);
            T = Nhat_cs_cvx{1};
        else
            mu_hat = xhat(:,i-1);
            T = Nhat{i-1};
        end
        Delta = setdiff(Nhat_cs_cvx{i},T);
        [gammahat(i),lambdahat(i),g(i)] = gamma_star(A,T,Delta,y(:,i)-A*xhat_cs_cvx(:,i,c),y(:,i),xhat_cs_cvx(:,i,c),mu_hat,lambdaset);
        Tc = setdiff(1:m,T);

        cvx_begin
            variable b(m,1)
            minimize( gammahat(i)*norm(b(Tc),1) + 0.5*lambdahat(i)*square_pos(norm(b(T)-mu_hat(T))) + 0.5*square_pos(norm(y(:,i)-A*b)) )
        cvx_end

        if any(isnan(b)) || isnan(cvx_optval)
            xhat(:,i) = 0;
            Nhat{i} = [];
        else
            xhat(:,i) = b;
            Nhat{i} = find(abs(xhat(:,i))>rho);
        end
        err(i) = norm(xhat(:,i)-x(:,i))/norm(x(:,i));
    end

    % Frames 3..seqlen: track with the frame-2 (gamma, lambda).
    gamma  = gammahat(2);
    lambda = lambdahat(2);
    tic;
    for i=3:seqlen
        mu_hat = xhat(:,i-1);
        T = Nhat{i-1};
        % Augmented system: the extra rows lambda*I(T,:) pull xh(T) toward mu_hat(T).
        % NOTE(review): the same prior also appears explicitly in the objective
        % (weight lambda), so it is effectively counted twice - confirm intended.
        y2 = [y(:,i); lambda*mu_hat(T)];
        A2 = [A; lambda*Ieye(T,:)];
        Tc = setdiff(1:m,T);

        cvx_begin
            variables xh(m,1)
            minimize( gamma*norm(xh(Tc),1) + 0.5*lambda*square_pos(norm(xh(T)-mu_hat(T))) + 0.5*square_pos(norm(y2-A2*xh)) )
        cvx_end

        if any(isnan(xh)) || isnan(cvx_optval)
            xhat(:,i) = 0;
        else
            xhat(:,i) = xh;
        end
        Nhat{i} = find(abs(xhat(:,i))>rho);
        err(i)  = norm(xhat(:,i)-x(:,i))/norm(x(:,i));
    end
    xhat_reg_mod_bpdn_cvx(:,:,c) = xhat;
    error_reg_mod_bpdn_cvx(:,c)  = err;
    t_reg_mod_bpdn_cvx(c)        = toc;

    %% Mod-BPDN
    % Frame 2: gamma_star with a single lambda value picks gamma, then solve.
    % NOTE(review): frame 1 is not recomputed, so the stored mod-BPDN result for
    % frame 1 is the one left by Reg-Mod-BPDN - confirm intended.
    lambdaset = 0.0001;
    i = 2;
    mu_hat = xhat_cs_cvx(:,i-1,c);
    T      = Nhat_cs_cvx{i-1};
    Delta  = setdiff(Nhat_cs_cvx{i},T);
    [gammahat(i),lambdahat(i),g(i)] = gamma_star(A,T,Delta,y(:,i)-A*xhat_cs_cvx(:,i,c),y(:,i),xhat_cs_cvx(:,i,c),mu_hat,lambdaset);
    Tc = setdiff(1:m,T);
    cvx_begin
        variable b(m,1)
        minimize( gammahat(i)*norm(b(Tc),1) + 0.5*square_pos(norm(y(:,i)-A*b)) )
    cvx_end
    if any(isnan(b)) || isnan(cvx_optval)
        xhat(:,i) = 0;
        Nhat{i} = [];
    else
        xhat(:,i) = b;
        Nhat{i} = find(abs(xhat(:,i))>rho);
    end
    err(i) = norm(xhat(:,i)-x(:,i))/norm(x(:,i));

    % Frames 3..seqlen: l1 penalty only outside the previous support.
    gamma = gammahat(2);
    tic;
    for i=3:seqlen
        T  = Nhat{i-1};
        Tc = setdiff(1:m,T);

        cvx_begin
            variable b(m,1)
            minimize( gamma*norm(b(Tc),1) + 0.5*square_pos(norm(y(:,i)-A*b)) )
        cvx_end

        if any(isnan(b)) || isnan(cvx_optval)
            xhat(:,i) = 0;
            Nhat{i} = [];
        else
            xhat(:,i) = b;
            Nhat{i} = find(abs(xhat(:,i))>rho);
        end
        err(i) = norm(xhat(:,i)-x(:,i))/norm(x(:,i));
    end
    xhat_mod_bpdn_cvx(:,:,c) = xhat;
    error_mod_bpdn_cvx(:,c)  = err;
    t_mod_bpdn_cvx(c)        = toc;

    %% KF-ModCS: Mod-CS on the residual, followed by a KF update on the detected support
    tic;
    for i=2:seqlen
        if i==2
            % The prior for frame 2 is the frame-1 CS estimate; its support
            % Tinit also defines the initial covariance P_prev.
            xhat(:,i-1) = xhat_cs_cvx(:,i-1,c);
            err(i-1) = norm(xhat(:,i-1)-x(:,i-1))/norm(x(:,i-1));
            y2 = y0(:,i) - A0*xhat(:,i-1);
            A2 = A0;
            T  = Tinit;
            P_prev = sigma_sys_hat^2*Ieye(:,T)*Ieye(:,T)';
        else
            y2 = y(:,i) - A*xhat(:,i-1);
            A2 = A;
            T  = Nhat{i-1};
            P_prev = P_t;
        end
        Tc = setdiff(1:m,T);

        cvx_begin
            variable b(m,1)
            minimize( gamma*norm(b(Tc),1) + 0.5*square_pos(norm(y2-A2*b)) )
        cvx_end

        if any(isnan(b)) || isnan(cvx_optval)
            xhat_modcs(:,i) = xhat(:,i-1);
        else
            xhat_modcs(:,i) = xhat(:,i-1) + b;
        end
        Nhat{i} = find(abs(xhat_modcs(:,i))>rho);

        if i==2
            xhat(:,i) = xhat_modcs(:,i);
            P_t = sigma_sys_hat^2*Ieye(:,Nhat{i})*Ieye(:,Nhat{i})';
        else
            Qhat   = sigma_sys_hat^2*Ieye(:,Nhat{i})*Ieye(:,Nhat{i})';
            P_pred = P_prev + Qhat;
            % Kalman gain via a linear solve (mrdivide) instead of an explicit inverse.
            K   = (P_pred*A') / (A*P_pred*A' + sigma_obs^2*eye(n));
            IKA = Ieye - K*A;
            P_t = IKA*P_pred;
            xhat(:,i) = IKA*xhat(:,i-1) + K*y(:,i);
        end
        err(i) = norm(xhat(:,i)-x(:,i))/norm(x(:,i));
    end
    xhat_kf_modcs_cvx(:,:,c) = xhat;
    error_kf_modcs_cvx(:,c)  = err;
    t_kf_modcs_cvx(c)        = toc;

    % =================== T-MSBL on blocks of 5 frames =================
    tic;
    X_tsbl5 = [];
    for i=1:seqlen/5
        cols = (i-1)*5+1 : i*5;
        % Noise option chosen for the SNR range (see TMSBL's help for details).
        X_tsbl1 = TMSBL(A, y(:,cols), 'noise','mild');
        X_tsbl5 = [X_tsbl5, X_tsbl1];
    end
    t_tsbl5(c) = toc;

    for t = 1:seqlen
        error_tsbl5(c,t) = norm(x(:,t)-X_tsbl5(:,t)) / norm(x(:,t));
    end

    % =================== l1 homotopy =================
    addpath(genpath('L1_homotopy_v2.0'))
    tic;
    %% Initialise with a reweighted L1 solve on frame 1
    in = [];
    tau = max(1e-2*max(max(abs(A'*y))),sigma_obs*sqrt(log(m)));
    in.tau = tau; W = tau;
    in.W = W;
    in.delx_mode = 'mil';
    in.debias = 0;
    in.verbose = 0;
    in.plots = 0;
    in.record = 0;
    alpha = 1; epsilon = 1;   % reweighting constants

    for wt_itr = 1:5
        out = l1homotopy(A,y(:,1),in);
        xh  = out.x_out;

        % New weights, and a warm start anchored at xh for the same y(:,1).
        beta = n*(norm(xh,2)/norm(xh,1))^2;
        W    = tau/alpha./(beta*abs(xh)+epsilon);
        Atr  = A'*(A*xh-y(:,1));
        u    = -W.*sign(xh)-Atr;

        in = out;
        in.xh_old = xh;
        in.pk_old = Atr+u;
        in.u      = u;
        in.W_old  = W;
        in.W      = W;
    end

    % Track frame by frame. Each solve is warm-started from the previous
    % solution/weights, re-anchored to the current measurement y(:,i).
    % (For i = 1 this reproduces exactly the state left by the initialisation.)
    for i=1:seqlen
        yi  = y(:,i);
        Atr = A'*(A*xh-yi);
        u   = -W.*sign(xh)-Atr;
        in.xh_old = xh;
        in.pk_old = Atr+u;
        in.u      = u;
        in.W_old  = W;
        in.W      = W;

        out = l1homotopy(A,yi,in);
        xh  = out.x_out;

        xhat(:,i) = xh;
        error_l1homotopy(c,i) = norm(xhat(:,i) - x(:,i)) / norm(x(:,i));

        % Reweight around the new solution (same rule as the initialisation).
        beta = n*(norm(xh,2)/norm(xh,1))^2;
        W    = tau/alpha./(beta*abs(xh)+epsilon);

        in = out;
        supp = out.gamma;   % active set reported by l1homotopy
        in.iAtA = pinv(A(:,supp)'*A(:,supp));
    end
    xhat_l1homotopy(:,:,c) = xhat;
    t_l1homotopy(c) = toc;

    % =================== Hierarchical KF =================
    tic;
    for i=1:seqlen
        yv{i}  = y(:,i);
        Phi{i} = A;
    end

    w_recovered = HierarchicalKalmanFilter(yv,Phi,0.05);

    for i=1:seqlen
        xhat(:,i) = w_recovered(:,i);
        err(i) = norm(x(:,i)-xhat(:,i))/norm(x(:,i));
    end
    xhat_hkf(:,:,c) = xhat;
    error_hkf(:,c)  = err;
    t_hkf(c)        = toc;

    %% CS MUSIC
    addpath CSMUSIC
    tic;
    k = 5;
    [actSet, Xh] = CSMUSIC(y, A, k, sigma_obs*sqrt(n));
    xhat_csmusic(:,:,c) = Xh;
    err = sqrt(sum(abs(Xh-x).^2))./sqrt(sum(abs(x.^2)));   % per-frame relative error (row vector)
    error_csmusic(:,c) = err';
    t_csmusic(c) = toc;

    %% DCS-AMP
    addpath DCS_AMP_v1_0/Functions
    addpath DCS_AMP_v1_0/ClassDefs
    tic;
    for t=1:seqlen
        if t<=2
            A_dcs{t} = A0;
            y_dcs{t} = y0(:,t);
        else
            A_dcs{t} = A;
            y_dcs{t} = y(:,t);
        end
    end

    % Support-process parameters are probabilities (cf. the library defaults
    % lambda_0 = 1/16, pz1 = 0.01, p1z = lambda_0*pz1/(1-lambda_0)), so the
    % support-size / support-change counts from the first two CS frames are
    % normalised. Assumed meaning (confirm against sp_multi_frame_fxn help):
    % pz1 = Pr{s(t)=0 | s(t-1)=1}, p1z = Pr{s(t)=1 | s(t-1)=0}.
    N1 = Nhat_cs_cvx{1};
    N2 = Nhat_cs_cvx{2};
    Params.lambda_0 = (length(N1)+length(N2))/2/m*ones(m,1);
    Params.pz1 = length(setdiff(N1,N2))/max(length(N1),1)*ones(m,1);
    Params.p1z = length(setdiff(N2,N1))/max(m-length(N1),1)*ones(m,1);
    % Amplitude parameters
    Params.eta_0 = mean(xhat_cs_cvx(:,1:2,c),2);
    N0 = intersect(N1,N2);
    Params.kappa_0 = mean(var(xhat_cs_cvx(N0,1:2,c)'));
    % NOTE(review): if DCS-AMP's alpha is the innovation rate of the amplitude
    % process (frame-to-frame correlation 1-alpha), this should be 1-corr - confirm.
    Params.alpha = corr(xhat_cs_cvx(:,1,c),xhat_cs_cvx(:,2,c));
    Params.rho = var(xhat_cs_cvx(N0,2,c)' - xhat_cs_cvx(N0,1,c)');
    Params.eps = 1e-7;
    Params.sig2e = sigma_obs^2;
    Options.smooth_iter = -1;       % # of fwd/bwd passes (-1 to filter)
    Options.eq_iter = 10;           % # of inner AMP iterations
    Options.alg = 2;                % AMP
    Options.update = 1;             % update hyperparameters during execution
    Options.upd_groups{1} = 1:m;
    Options.verbose = 1;            % print messages

    [x_hat, v_hat, lambda_hat] = sp_multi_frame_fxn(y_dcs, A_dcs, Params, Options);

    for t = 1:seqlen
        xhat_dcsamp(:,t,c) = x_hat{t};
        error_dcsamp(t,c)  = norm(x_hat{t}-x(:,t))/norm(x(:,t));
    end
    t_dcsamp(c) = toc;

    %% PM-CS-KF
    addpath PM-CS-KF
    tic;
    Pk0 = eye(m)*1e2;
    beta_k = zeros(m,1);
    p = 0.5;
    n_iterations = 100;
    if p==0.5
        Re_p = 20000^2;
    elseif p==0.7
        Re_p = 1000^2;
    elseif p==0
        Re_p = 100^2;
    end

    Q = eye(m);
    R = eye(n)*sigma_obs^2;
    Pk = Pk0;
    % NOTE(review): the loop starts at frame 2, so xhat(:,1)/err(1) stored for
    % this method are left over from earlier methods - confirm intended.
    for i=2:seqlen
        [beta_k, Pk] = CSKF_p(beta_k, Pk, y(:,i), A, R, Re_p, n_iterations, p);  % CSKF-p step
        Pk = Pk + Q;
        xhat(:,i) = beta_k;
        err(i) = norm(x(:,i)-beta_k)/norm(x(:,i));
    end

    xhat_kf_cs(:,:,c) = xhat;
    error_kf_cs(:,c)  = err;
    t_kf_cs(c)        = toc;

end


% Frames on which DCS-AMP produced NaN are scored as relative error 1.
T = find(isnan(error_dcsamp));
error_dcsamp(T) = 1;

% NRMSE averaged over Monte Carlo runs, one curve per algorithm.
% Columns: mean curve | line spec | legend label (drawn in this order).
tx = 1:size(error_cs_cvx,1);
curves = { ...
    mean(error_l1homotopy,1),            'r*-', 'l1-Homotopy'; ...
    mean(error_tsbl5,1),                 'b>-', 'T-SBL'; ...
    mean(error_reg_mod_bpdn_cvx,2),      'r+-', 'Reg-Mod-BPDN'; ...
    mean(error_mod_bpdn_cvx(tx,:),2),    'bo-', 'mod-BPDN'; ...
    mean(error_cs_res_cvx(tx,:),2),      'k-',  'CS-res(BPDN-res)'; ...
    mean(error_cs_cvx(tx,:),2),          'gs-', 'BPDN'; ...
    mean(error_csmusic(tx,:),2),         'b<-', 'CS-MUSIC'; ...
    mean(error_dcsamp(tx,:),2),          'bs-', 'DCS-AMP'; ...
    mean(error_kf_cs(tx,:),2),           'r<-', 'PM-CS-KF'; ...
    mean(error_kf_modcs_cvx(tx,:),2),    'b^-', 'KMoCS'};

figure;
for k = 1:size(curves,1)
    plot(tx, curves{k,1}, curves{k,2});
    if k == 1
        hold on;   % keep the first curve while adding the rest
    end
end
legend(curves{:,3});
xlabel('t')
ylabel('NRMSE')
