%Implementation of Qiu's interference model, adapted for IEEE 802.15.4

%Assumptions (some are simplifications):
% - each node has an equal, fixed packet length
% - each node has a single, fixed destination
% - states in which two or more nodes belong to the same synchronization group are not considered
% - packet loss due to the receiver sensitivity setting (as in 802.11) is not considered

function [pi L]=Qiu_Interference_Model(R, d, B_n, beta_m, delta_n, numNodes)
%QIU_INTERFERENCE_MODEL Qiu's interference model adapted for IEEE 802.15.4.
%
%Return:
%   pi(i, :): format: [real-state-i  stationary-prob.-of-state-i]
%             real-state-i is a bitmask: bit (k-1) set means node k is
%             transmitting. Only states that contain no two synchronized
%             nodes are kept; rows are sorted by real-state-i.
%   L(m): packet loss rate from node m to its destination
%arguments:
%   R: interference-free receiver-side power; R{m, n} format: [mean  var] in regular value (not in dBm)
%   d: traffic demand and destination; d{m} format: [d_mn  n(i.e., destination of m)]
%   B_n: background noise in regular value (not in dBm); format: [mean var]
%   beta_m: CCA threshold in regular value (not in dBm)
%   delta_n: SINR threshold in regular value (not in dB)
%   numNodes: number of nodes; states are encoded as numNodes-bit masks
%
%Side effects:
%   Writes the globals C, CW, OH, Q and T_mu_const, which the P_00/P_01/
%   P_10/P_11 sub-functions read while the transition matrix is built.
%
%Outline:
%   1. pairwise synchronization groups from the clear-channel prob. C_m_i;
%   2. enumeration of the states that contain no synchronized nodes;
%   3. fixed-point iteration of the sender model (CW, OH, transition matrix
%      M, stationary distribution) and the receiver model (L), relaxed with
%      ALPHA, until L and Q change by less than the termination ratio, or
%      until MAX_ROUNDS rounds have run (a warning is issued in that case).

global C CW OH Q T_mu_const;

%Analysis control
Constant_Q_m = 1; %do not adapt Q(m) as in Qiu's model, as a way of approximation to study the impact of traffic load

%%model parameters
SYNC_GROUP_CLEAR_PROB_THRESHOLD = 0.1;
MIN_TRANSITION_PROB = 0.001;
ITERATION_TERMINATION_CHANGE_RATIO = 0.001;
ALPHA = 0.9; %weight of the newly computed L/Q in the relaxation step
MAX_ROUNDS = 1000; %safety cap: the relaxed fixed-point iteration is not guaranteed to converge

%%802.15.4 parameters: 1 backoff slot is similar to 10-bytes transmission time at 250kbps rate
%CW_max = 32;%max backoff window size in # of backoff slots, with each slot being 20 symbol durations
%INIT_BF = 3; %initial exponent
OH_const = 2; %fixed contention window: 2 backoff slots, with each slot being 20 symbol durations
T_ACK = 1.5; %ack frame is 15 bytes long

%packet length
T_mu_const = 13;%  %31 bytes (symbols?)


%%Calculating synchronization groups:
%%  sync_grp(m) is a bitmask; bit (n-1) is set iff m and n are synchronized,
%%  i.e. each one's clear-channel prob. is below the threshold while the other transmits.
disp('Is calculating synchronization group ...');
sync_grp = zeros(1, numNodes);
for m=1:numNodes-1
    m_bit = bitshift(1, m-1);
    for n=m+1:numNodes
        n_bit = bitshift(1, n-1);
        if C_m_i(R, B_n, beta_m, delta_n, d, numNodes, m, n_bit) < SYNC_GROUP_CLEAR_PROB_THRESHOLD && ...
           C_m_i(R, B_n, beta_m, delta_n, d, numNodes, n, m_bit) < SYNC_GROUP_CLEAR_PROB_THRESHOLD
            sync_grp(m) = bitor(sync_grp(m), n_bit);
            sync_grp(n) = bitor(sync_grp(n), m_bit);
        end
    end
end
disp('   Done with calculating synchronization group.');


%%Prune # of states based on synchronization groups, s.t. a state does not
%%  contain any nodes that are synchronized. States are grown one
%%  transmitting node per round; the result is sorted in pi(:, 1).
disp('Pruning states that contain synchronizaed nodes ...');
allStates = 0; %the empty state (no node transmitting) is always kept
states_prev = 0;
for numBitOne=1:numNodes
    newStates = [];
    for indState=1:length(states_prev) %for each state found in the last round
        thisState = states_prev(indState);
        for indBit=0:numNodes-1
            tp = bitshift(1, indBit);
            %add node indBit+1 if it is new and not synchronized with any node already in the state
            if bitand(thisState, tp) == 0 && bitand(thisState, sync_grp(indBit+1)) == 0
                newStates = [newStates bitor(thisState, tp)];
            end
        end
    end
    states_prev = unique(newStates);
    allStates = [allStates states_prev];
end
allStates = unique(allStates);
numWorkingStates = length(allStates);
%initialize to equal prob.; needed for calculating CW(m) in the first round
pi = [allStates(:) repmat(1/numWorkingStates, numWorkingStates, 1)];
disp(['     numWorkingStates = ' num2str(numWorkingStates)]);
disp('   Done with pruning states.');


%%calculate C(m|S_i); format: C(m, ind_i), where ind_i is the index s.t. pi(ind_i, 1) = i
disp('Is calculating C(m|S_i) ...');
C = zeros(numNodes, numWorkingStates);
for m=1:numNodes
    for ind_i=1:numWorkingStates
        C(m, ind_i) = C_m_i(R, B_n, beta_m, delta_n, d, numNodes, m, pi(ind_i, 1));
    end
end
disp('   Done with Calculating C(m|S_i) for all m and i.');


%%%Calculate stationary state distribution using Qiu's interference model
disp('Calculate stationary state distribution ...');
L = zeros(1, numNodes);
Q = zeros(1, numNodes);
for m=1:numNodes
    if Constant_Q_m ~= 1
        Q(m) = 1;
    else
        Q(m) = d{m}(1);
    end
end
CW = zeros(1, numNodes);
OH = zeros(1, numNodes);
converged = 0;
iterRound = 0; %not named "round" to avoid shadowing the builtin
while converged == 0
    iterRound = iterRound + 1;
    if iterRound > MAX_ROUNDS
        warning('Qiu_Interference_Model:notConverged', ...
                'No convergence after %d rounds; returning the last iterate.', MAX_ROUNDS);
        break;
    end
    disp(['   Round ' num2str(iterRound) ' ...']);
    converged = 1;

    %%sender model
    %calc. CW(m), OH(m)
    for m=1:numNodes
        %clr: stationary-average clear-channel prob. of m, capped at 1
        clr = 0;
        for ind_i=1:numWorkingStates
            clr = clr + pi(ind_i, 2) * C(m, ind_i);
        end
        if clr > 1
            clr = 1;
        end
        %CW(m) = sum_{k=0..4} min(32, 2^(3+k))/2 * (1-clr)^k
        %(presumably the mean backoff over up to 5 CCA attempts)
        CW(m) = 0;
        for k=0:4
            CW(m) = CW(m) + (min(32, realpow(2, 3+k))/2) * realpow(1-clr, k);
        end
        %ACK time only counts for the fraction (1 - L(m)) of packets that are delivered
        OH(m) = OH_const + (1 - L(m)) * T_ACK;
    end

    %derive transition matrix M: M(ind1, ind2) is the product over nodes of
    %the per-node transition prob. (P_00, P_01, P_10 or P_11) from state
    %pi(ind1, 1) to state pi(ind2, 1)
    M = zeros(numWorkingStates);
    for ind1=1:numWorkingStates
        S_i = nodesOfAState(pi(ind1, 1), numNodes);
        idle_i = setdiff(1:numNodes, S_i);
        for ind2=1:numWorkingStates
            S_j = nodesOfAState(pi(ind2, 1), numNodes);
            idle_j = setdiff(1:numNodes, S_j);
            p = 1;
            nodeSet = intersect(idle_i, idle_j); %idle -> idle
            for ind=1:length(nodeSet)
                p = p * P_00(nodeSet(ind), ind1);
            end
            nodeSet = intersect(idle_i, S_j); %idle -> transmitting
            for ind=1:length(nodeSet)
                p = p * P_01(nodeSet(ind), ind1);
            end
            nodeSet = intersect(S_i, idle_j); %transmitting -> idle
            for ind=1:length(nodeSet)
                p = p * P_10(nodeSet(ind), ind1);
            end
            nodeSet = intersect(S_i, S_j); %transmitting -> transmitting
            for ind=1:length(nodeSet)
                p = p * P_11(nodeSet(ind), ind1);
            end
            M(ind1, ind2) = p;
        end
    end
    %prune too small transition prob.; an entry is zeroed only when its row
    %and its column each still keep an entry >= MIN_TRANSITION_PROB
    for ind1=1:numWorkingStates
        for ind2=1:numWorkingStates
            if M(ind1, ind2) < MIN_TRANSITION_PROB
                if any(M(ind1, :) >= MIN_TRANSITION_PROB) && any(M(:, ind2) >= MIN_TRANSITION_PROB)
                    M(ind1, ind2) = 0;
                end
            end
        end
        %normalization: needed due to approximation
        %         sumExitProb = sum(M(ind1, :));
        %         M(ind1, :) = M(ind1, :) * (1 / sumExitProb);
    end

    %compute stationary prob. pi_i: solve (M' - I) x = 0 together with
    %sum(x) = 1 as one overdetermined least-squares system
    nS = size(M, 1);
    A = [M' - eye(nS); ones(1, nS)];
    b = [zeros(nS, 1); 1];
    %direct least squares; lsqr's default cap of min(m,n,20) iterations can
    %stop before convergence once there are more than ~20 states
    x = A \ b;
    pi(:, 2) = x;
    disp(['        sumProb = ' num2str(sum(x))]);

    %compute Q_new(m) if need be (only applied when Constant_Q_m ~= 1)
    Q_new = zeros(1, numNodes);
    for m=1:numNodes
        if d{m}(1) < 1
            %t_m_prev: stationary prob. that node m is transmitting
            t_m_prev = 0;
            for ind_ii=1:numWorkingStates
                if bitand(pi(ind_ii, 1), bitshift(1, m-1)) ~= 0
                    t_m_prev = t_m_prev + pi(ind_ii, 2);
                end
            end
            Q_new(m) = Q(m) * (d{m}(1) / (1 - d{m}(1))) * ((1 - t_m_prev) / t_m_prev);
            if Q_new(m) > 1
                Q_new(m) = 1;
            end
        else
            Q_new(m) = Q(m);
        end
    end

    %%receiver model: compute L_mn (i.e. L)
    %note: in our case, only need to calc. L_mn_asyn, since both L_mn_rss and L_mn_syn are 0
    L_new = zeros(1, numNodes);
    for m=1:numNodes
        L_new(m) = L_mn_asyn(R, B_n, delta_n, d, pi, M, numNodes, m);
    end

    %%relaxation for quick convergence; and test for convergence
    %(relative change w.r.t. either the old or the new value)
    for m=1:numNodes
        oldL = L(m);
        L(m) = ALPHA * L_new(m) + (1 - ALPHA) * oldL;
        if (oldL ~= 0 && abs((L(m) - oldL) / oldL) > ITERATION_TERMINATION_CHANGE_RATIO) || ...
           (L(m) ~= 0 && abs((L(m) - oldL) / L(m)) > ITERATION_TERMINATION_CHANGE_RATIO)
            converged = 0;
        end

        oldQ = Q(m);
        if Constant_Q_m ~= 1
            Q(m) = ALPHA * Q_new(m) + (1 - ALPHA) * oldQ;
        end
        if (oldQ ~= 0 && abs((Q(m) - oldQ) / oldQ) > ITERATION_TERMINATION_CHANGE_RATIO) || ...
           (Q(m) ~= 0 && abs((Q(m) - oldQ) / Q(m)) > ITERATION_TERMINATION_CHANGE_RATIO)
            converged = 0;
        end
    end

    disp('      done with one round.');
end %converged == 0
disp('Done!!!');



%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%% SUB-FUNCTIONS
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function prob=L_mn_asyn(R, B_n, delta_n, d, pi, M, numNodes, m)
%L_MN_ASYN Overall asynchronous-interference loss rate of node m's link.
%Maps the state-averaged loss p = l_mn_asyn(...) to
%   prob = 1 - (1 - p) * exp(-p / (1 - p)).
lossPerState = l_mn_asyn(R, B_n, delta_n, d, pi, M, numNodes, m);
survive = 1 - lossPerState;
prob = 1 - survive * exp(-lossPerState / survive);

function prob=l_mn_asyn(R, B_n, delta_n, d, pi, M, numNodes, m)
%L_MN_ASYN Per-state loss of node m's link to d{m}(2), averaged over the
%states in which m transmits and weighted by their stationary prob.:
%   prob = sum_{i: m in S_i} pi_i * l_mn_Si(i) / sum_{i: m in S_i} pi_i
%Returns 0 when the states containing m have zero total prob.: the average
%is then undefined, and the NaN from 0/0 would otherwise flow into L(m),
%then OH(m), then every P_01, and corrupt the whole transition matrix.
numerator = 0;
denominator = 0;
for ind=1:size(pi, 1)
    if any(nodesOfAState(pi(ind, 1), numNodes) == m)
        tp = l_mn_Si(R, B_n, delta_n, d, pi, M, m, d{m}(2), pi(ind, 1), numNodes);
        numerator = numerator + pi(ind, 2) * tp;
        denominator = denominator + pi(ind, 2);
    end
end
if denominator == 0
    prob = 0;
else
    prob = numerator / denominator;
end

function prob=l_mn_Si(R, B_n, delta_n, d, pi, M, m, n, i, numNodes)
%L_MN_SI Loss prob. of link m->n in state i. Combines, as independent causes,
%the SINR failure within state i (l_mn_Si_C1) and the loss attributed to the
%state changing (l_mn_Si_C2_C3):
%   prob = 1 - (1 - p_C1) * (1 - p_C2C3)
successC1 = 1 - l_mn_Si_C1(R, B_n, delta_n, d, m, n, i, numNodes);
successC2C3 = 1 - l_mn_Si_C2_C3(R, B_n, delta_n, d, pi, M, m, n, i, numNodes);
prob = 1 - successC1 * successC2C3;

function prob=l_mn_Si_C1(R, B_n, delta_n, d, m, n, i, numNodes)
%L_MN_SI_C1 Prob. that the SINR of m's signal at receiver n is below delta_n
%in state i. Interference is the noise plus every sender of state i other
%than m and n. Signal and interference are both moment-matched to
%log-normals, so ln(SINR) is normal with mean E[ln R_mn] - mu_I and
%variance Var[ln R_mn] + var_I.
interferers = setdiff(nodesOfAState(i, numNodes), [m n]);
[mu_I, var_I] = sum_log_normal_variables(R, B_n, delta_n, d, numNodes, interferers, n, [], []);
sigMean = R{m, n}(1);
sigSecondMoment = R{m, n}(2) + realpow(sigMean, 2);
sigLogMean = 2 * log(sigMean) - 0.5 * log(sigSecondMoment);
sigLogVar = log(sigSecondMoment) - 2 * log(sigMean);
prob = normcdf(log(delta_n), sigLogMean - mu_I, sqrt(sigLogVar + var_I));

function prob=l_mn_Si_C2_C3(R, B_n, delta_n, d, pi, M, m, n, i, numNodes)
%L_MN_SI_C2_C3 Loss contribution of link m->n from leaving state i.
%For each sender g of state i, the chain may leave i by g stopping, i.e. by
%moving to i' = i without g. Given that it leaves i, this happens with prob.
%M(i, i') / (1 - M(i, i)). That prob. weights:
%  - g == m: l_mn_Si_C3 (link R{n, m}, evaluated at receiver m)
%  - g ~= m: l_mn_Si_C2_G with G = {g}
%Successor states that were pruned from pi are skipped. If state i is never
%left (1 - M(i, i) <= 0), no exit transition contributes and 0 is returned.
%Previously this case divided by zero.
senders = nodesOfAState(i, numNodes);
prob = 0;
ind_i = find(pi(:, 1) == i);
exitProb = 1 - M(ind_i, ind_i); %prob. of leaving state i in one step
if exitProb <= 0
    return;
end
for ind=1:length(senders) %each node is a sync-group based on the assumptions
    g = senders(ind);
    %g is in state i, so xor clears exactly its bit; this replaces
    %bitcmp(x, numNodes), a syntax newer MATLAB releases no longer accept
    i_prime = bitxor(i, bitshift(1, g-1));
    ind_i_prime = find(pi(:, 1) == i_prime);
    if isempty(ind_i_prime)
        continue
    end
    if g == m
        tp = l_mn_Si_C3(R, B_n, delta_n, d, m, n, i, numNodes);
    else
        tp = l_mn_Si_C2_G(R, B_n, delta_n, d, m, n, i, g, numNodes);
    end
    prob = prob + (M(ind_i, ind_i_prime) / exitProb) * tp;
end

function prob=l_mn_Si_C2_G(R, B_n, delta_n, d, m, n, i, G, numNodes)
%L_MN_SI_C2_G Prob. that the SINR of m's signal at receiver n is below
%delta_n in state i, with the node set G treated specially.
%The interference at n is made up of three parts:
%  - the noise;
%  - the full power of every sender of state i except m, n and the nodes of G;
%  - the power of each node of G (other than n), scaled by 1 minus the
%    l_mn_Si_C1 loss of that node's own link. Presumably this is the chance
%    that its exchange completes and an ACK follows -- confirm.
fullInterferers = setdiff(nodesOfAState(i, numNodes), union([m n], G));
partialInterferers = setdiff(G, n);
[mu_I, var_I] = sum_log_normal_variables(R, B_n, delta_n, d, numNodes, fullInterferers, n, partialInterferers, i);
sigMean = R{m, n}(1);
sigSecondMoment = R{m, n}(2) + realpow(sigMean, 2);
sigLogMean = 2 * log(sigMean) - 0.5 * log(sigSecondMoment);
sigLogVar = log(sigSecondMoment) - 2 * log(sigMean);
prob = normcdf(log(delta_n), sigLogMean - mu_I, sqrt(sigLogVar + var_I));

function prob=l_mn_Si_C3(R, B_n, delta_n, d, m, n, i, numNodes)
%L_MN_SI_C3 Prob. that the SINR of n's signal at receiver m (link R{n, m})
%is below delta_n in state i. Interference at m is the noise plus every
%sender of state i other than m and n (presumably the ACK from n back to m).
interferers = setdiff(nodesOfAState(i, numNodes), [m n]);
[mu_I, var_I] = sum_log_normal_variables(R, B_n, delta_n, d, numNodes, interferers, m, [], []);
sigMean = R{n, m}(1);
sigSecondMoment = R{n, m}(2) + realpow(sigMean, 2);
sigLogMean = 2 * log(sigMean) - 0.5 * log(sigSecondMoment);
sigLogVar = log(sigSecondMoment) - 2 * log(sigMean);
prob = normcdf(log(delta_n), sigLogMean - mu_I, sqrt(sigLogVar + var_I));

function [mu, sigma2]=sum_log_normal_variables(R, B_n, delta_n, d, numNodes, senders, rcver, senders_ack, i)
%SUM_LOG_NORMAL_VARIABLES Moment-match the total noise-plus-interference
%power at RCVER to a single log-normal variable.
%The total power is built from three parts:
%  - the noise B_n = [mean var];
%  - R{s, rcver} for every s in SENDERS (means and variances added);
%  - R{a, rcver} scaled by (1 - p_a) for every a in SENDERS_ACK, where p_a
%    is l_mn_Si_C1 of link a -> d{a}(2) in state i.
%Returns MU and SIGMA2, the mean and variance of the log of that power.
meanI = B_n(1);
varI = B_n(2);
for k = 1:length(senders)
    s = senders(k);
    meanI = meanI + R{s, rcver}(1);
    varI = varI + R{s, rcver}(2);
end
for k = 1:length(senders_ack)
    a = senders_ack(k);
    pLoss = l_mn_Si_C1(R, B_n, delta_n, d, a, d{a}(2), i, numNodes);
    meanI = meanI + R{a, rcver}(1) * (1 - pLoss);
    varI = varI + R{a, rcver}(2) * (1 - pLoss);
end
logFirst = log(meanI);
logSecond = log(varI + realpow(meanI, 2));
mu = 2 * logFirst - (logSecond / 2);
sigma2 = logSecond - (2 * logFirst);


%%%
function value=P_01(m, ind_i)
%P_01 Prob. that idle node m starts transmitting in one step from the state
%with index ind_i: clear-channel prob. C(m, ind_i) times Q(m), divided by
%CW(m) + OH(m). Reads the globals C, CW, OH and Q.
global C CW OH Q;
slotsPerAttempt = CW(m) + OH(m);
value = C(m, ind_i) * Q(m) / slotsPerAttempt;

function value=P_00(m, ind_i)
%P_00 Prob. that idle node m stays idle for one step from the state with
%index ind_i; the complement of P_01.
startProb = P_01(m, ind_i);
value = 1 - startProb;

function value=P_10(m, ind_i)
%P_10 Prob. that a transmitting node stops in one step: 1 / T_mu_const
%(the global packet length). It does not depend on m or ind_i.
global T_mu_const;
packetLength = T_mu_const;
value = 1 / packetLength;

function value=P_11(m, ind_i)
%P_11 Prob. that transmitting node m is still transmitting after one step.
%A transmitting node either stops (P_10) or continues, so P_11 = 1 - P_10.
%This mirrors P_00 = 1 - P_01 for idle nodes. The previous version returned
%P_10 itself, so per-node factors summed to 2/T instead of 1. Every
%transition-matrix row involving a transmitting node then summed to far
%less than 1.
value = 1 - P_10(m, ind_i);

%%%
function prob=C_m_i(R, B_n, beta_m, delta_n, d, numNodes, m, real_i)
%C_M_I Clear-channel prob. of node m in state real_i: the prob. that
%noise plus the power of every other sender of real_i, measured at m and
%moment-matched to a log-normal, stays below the CCA threshold beta_m.
others = setdiff(nodesOfAState(real_i, numNodes), m);
[logMu, logVar] = sum_log_normal_variables(R, B_n, delta_n, d, numNodes, others, m, [], []);
prob = normcdf(log(beta_m), logMu, sqrt(logVar));

%%%
function nodes=nodesOfAState(state, totalNumNodes)
%NODESOFASTATE Decode a bitmask state into the row of 1-based node indices
%whose bits are set (bit k-1 <-> node k). The low max(1, totalNumNodes)
%bits are inspected; [] is returned when none is set.
bitPositions = 1:max(1, totalNumNodes);
nodes = bitPositions(bitget(state, bitPositions) == 1);
if isempty(nodes)
    nodes = []; %keep the 0x0 empty that callers have always received
end
