#= 
------------------------------------------------------
Functions used for horse-race counterfactuals.
All of these are accessed through the entry point 
model_solve.jl
------------------------------------------------------
Counterfactuals are named as follows:
    1. Fixed investment policy function.
    2. Fixed entry policy function.
    3. Fixed investment and entry policy function.
------------------------------------------------------
Tristany Armangué-Jubert, June 2024
------------------------------------------------------
=#

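# Dependencies: no packages are loaded here. The code below uses `Random.seed!`,
# `ordinalrank` (presumably from StatsBase) and the model helpers `u`, `solveWP` and
# `solveWP_cf2`, which are assumed to be brought into scope by model_solve.jl.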

### COUNTERFACTUAL WITH FIXED INVESTMENT POLICY
# Function to solve the VFs given some measures of workers/entrepreneurs and an investment policy function
#
# Fixed-point iteration on the entrepreneur value V and the worker value Ũ, with the
# investment policy ρᶻ held fixed (fixed investment counterfactual). Each step re-solves the
# firm problem with `solveWP` to get wages W_p, profits Π_j and labour L_j.
#   μ, ϕ : normalized measures of entrepreneurs / workers over the (z, a) grid
#   M    : not used here (kept so the counterfactual solvers share a call shape)
#   L, E : aggregate labour and entrepreneur masses
#   tol  : iterate until max|TŨ - Ũ| + max|TV - V| is no longer above tol (no iteration cap)
#   ρᶻ   : fixed weight on the investing value Vᴵ versus the non-investing value Vᴺ
# Returns (V, U, Ũ, Π_j, L_j, W_p), where U[zi, ai, zj, aj] is the value of a worker of type
# (zi, ai) employed at firm (zj, aj) and Ũ = ν·log(E·Σ exp(U/ν)·μ) over firms.
function solveVF_cf1(μ::Matrix{Float64}, ϕ::Matrix{Float64}, M::Matrix{Float64}, L::Float64, E::Float64, grid_z::Vector{Float64}, grid_a::Vector{Float64}, tol::Float64, par, ρᶻ::Array)
    # Size of grids
    SIZE_Z = size(grid_z, 1)
    SIZE_A = size(grid_a, 1)

    # Row of the next-higher / next-lower productivity level, reflected at the grid edges
    # (this also lets a single-point z grid work instead of indexing out of bounds)
    up   = min.((1:SIZE_Z) .+ 1, SIZE_Z)
    down = max.((1:SIZE_Z) .- 1, 1)

    # Guess values and initial wage schedule
    V = ones(SIZE_Z, SIZE_A)
    Ũ = ones(SIZE_Z, SIZE_A)
    W_0 = ones(SIZE_Z, SIZE_A) .* 0.5

    # Amenity of each firm cell: A_j[zj, aj] = grid_a[aj]
    A_j = repeat(grid_a', SIZE_Z)

    while true
        # Values one productivity step up / down
        V₊ = V[up, :]
        V₋ = V[down, :]
        Ũ₊ = Ũ[up, :]
        Ũ₋ = Ũ[down, :]

        # Solve the firm problem given measures and value functions
        W_p, Π_j, L_j = solveWP(μ, ϕ, grid_z, grid_a, 1e-4, V, Ũ, W_0, par, L, E, V₊, V₋, Ũ₊, Ũ₋, A_j)

        # Next period every agent takes the better of entrepreneurship and work
        best₊ = max.(V₊, Ũ₊)
        best₋ = max.(V₋, Ũ₋)

        # Entrepreneur value when investing (flow u(Π_j - c_z), moves up w.p. p_i)
        # and when not investing (flow u(Π_j), moves up w.p. p_n)
        Vᴵ = par.ϵᴸ .* u.(Π_j .- par.c_z) .+ A_j .+ par.β .* (1 .- par.δ) .* (par.p_i .* best₊ .+ (1 - par.p_i) .* best₋)
        Vᴺ = par.ϵᴸ .* u.(Π_j)            .+ A_j .+ par.β .* (1 .- par.δ) .* (par.p_n .* best₊ .+ (1 - par.p_n) .* best₋)

        # Update V with the fixed investment policy
        TV = ρᶻ .* Vᴵ .+ (1 .- ρᶻ) .* Vᴺ

        # Worker value: the flow depends only on the firm (zj, aj) and the continuation only on
        # the worker's own state (zi, ai, moving up w.p. p_n). Each is computed once and
        # combined by broadcasting: U[zi, ai, zj, aj] = flow[zj, aj] + cont[zi, ai]
        flow = par.ϵᴸ .* u.(W_p) .+ A_j
        cont = par.β * (1 - par.δ) .* (par.p_n .* best₊ .+ (1 - par.p_n) .* best₋)
        U = reshape(flow, 1, 1, SIZE_Z, SIZE_A) .+ cont

        # Update Ũ: log-sum-exp of U over firms, weighted by the entrepreneur measure μ
        TŨ = similar(Ũ)
        for ai = 1:SIZE_A, zi = 1:SIZE_Z
            TŨ[zi, ai] = par.ν * log(E * sum(exp.(view(U, zi, ai, :, :) ./ par.ν) .* μ))
        end

        # Compute deviation; `dev > tol ||` (not `dev <= tol &&`) so a NaN deviation also stops
        dev = maximum(abs.(TŨ - Ũ)) + maximum(abs.(TV - V))
        dev > tol || return TV, U, TŨ, Π_j, L_j, W_p

        # TŨ and TV are fresh arrays each pass, so plain rebinding is safe
        Ũ = TŨ
        V = TV
        W_0 = copy(W_p)
    end
end

# Function to solve for equilibrium entry given distribution
#
# Searches for an entry policy ρᵉ (1 = entrepreneur, 0 = worker) consistent with the value
# functions it induces (ρᵉ = 1 exactly where V > Ũ), holding the investment policy ρᶻ fixed.
# Each round solves the value functions for the current guess (solveVF_cf1), then:
#   - accepts if no cell wants to switch, or if a single cell keeps flipping back and forth;
#   - otherwise, among the switching cells, those whose Π_j rank (1 = largest) falls in the
#     best `frac` share of the switchers' rank range become entrepreneurs and the rest become
#     workers; if the number of switchers repeats, the implied policy is adopted wholesale.
# A first search uses frac = 3/4. If it does not settle within 25 rounds, a second search
# restarts from the same initial guess with frac = 1/4.
# Returns (ρᵉ, ρᶻ, Π_j, L_j, W_p, V, U, Ũ), or `nothing` (with a warning) if both fail.
function solveEq_cf1(M::Matrix{Float64}, grid_z::Vector{Float64}, grid_a::Vector{Float64}, tol::Float64, par, ρᶻ::Array)
    for frac in (3/4, 1/4)
        result = _entry_search_cf1(M, grid_z, grid_a, tol, par, ρᶻ, frac)
        result === nothing || return result
    end
    @warn "solveEq_cf1: no entry equilibrium found with either split rule"
    return nothing
end

# One search for the entry fixed point used by solveEq_cf1. `frac` places the split threshold
# within the switchers' rank range. Returns `nothing` after `maxiter` rounds without settling.
function _entry_search_cf1(M::Matrix{Float64}, grid_z::Vector{Float64}, grid_a::Vector{Float64}, tol::Float64, par, ρᶻ::Array, frac::Real; maxiter::Integer = 25)
    SIZE_Z = size(grid_z, 1)
    SIZE_A = size(grid_a, 1)

    # Initial guess: the upper half of both grids enters
    ρᵉ = zeros(SIZE_Z, SIZE_A)
    ρᵉ[end-div(SIZE_Z, 2):end, end-div(SIZE_A, 2):end] .= 1

    prev = 9e10
    for _ = 1:maxiter
        # Split the distribution into entrepreneurs (μ) and workers (ϕ)
        μ = zeros(size(M))
        ϕ = zeros(size(M))
        μ[ρᵉ .== 1] .= M[ρᵉ .== 1]
        ϕ[ρᵉ .== 0] .= M[ρᵉ .== 0]

        # Aggregate masses, then normalize to distributions
        L = sum(ϕ) * par.Λ
        E = sum(μ) * par.Λ
        μ ./= sum(μ)
        ϕ ./= sum(ϕ)

        # Solve the value functions and wage schedule
        V, U, Ũ, Π_j, L_j, W_p = solveVF_cf1(μ, ϕ, M, L, E, grid_z, grid_a, tol, par, ρᶻ)

        # Entry policy implied by the value functions, and number of cells that disagree
        tρᵉ = zeros(size(M))
        tρᵉ[V .> Ũ] .= 1
        dev = sum(abs.(tρᵉ - ρᵉ))

        # Converged, or stuck flipping one cell back and forth
        if dev == 0 || prev == dev == 1
            return tρᵉ, ρᶻ, Π_j, L_j, W_p, V, U, Ũ
        end

        if dev != prev
            changed = ρᵉ .!= tρᵉ
            ranking = ordinalrank(Π_j, rev = true)   # 1 = largest Π_j
            R = zeros(size(M))
            R[changed] .= ranking[changed]
            maxr = maximum(R)            # non-switchers hold 0, so this is the switchers' max
            R[.!changed] .= Inf
            minr = minimum(R)            # non-switchers hold Inf, so this is the switchers' min
            splt = (maxr - minr) * frac + minr
            to_entry = R .<= splt        # best-ranked switchers enter
            R[.!changed] .= -Inf
            to_work = R .> splt          # remaining switchers become workers
            ρᵉ[to_entry] .= 1.0
            ρᵉ[to_work] .= 0.0
        else
            # Same number of switchers as last round: take the implied policy as is
            ρᵉ = copy(tρᵉ)
        end
        prev = dev
    end
    return nothing
end

# Function to solve for equilibrium (fixed inv counterfactual)
#
# Outer fixed point on the (z, a) distribution M, holding the investment policy ρᶻ fixed:
#   1. given M, solve for the equilibrium entry policy (solveEq_cf1);
#   2. push M forward 100 periods with the law of motion implied by (ρᵉ, ρᶻ);
#   3. stop when sum|TM - M| < Mtol. Otherwise take M = TM if the deviation fell by more
#      than 10%, else average M and TM 50/50.
# Arguments
#   Ψ       : distribution of newborns (inflow share δ), also the initial guess for M
#   tol     : tolerance passed to the value-function solver
#   maxiter : maximum number of outer iterations
#   ρᶻ      : fixed investment policy
#   Mtol    : threshold on the distribution deviation (default 0.05, the previous hard-coded value)
# Returns (M, ρᵉ, ρᶻ, Π_j, L_j, W_p, V, U, Ũ, TM), or `nothing` with a warning if the
# distribution has not converged after `maxiter` iterations.
function solve_cf1(Ψ::Matrix{Float64}, grid_z::Vector{Float64}, grid_a::Vector{Float64}, tol::Float64, par, maxiter::Integer, ρᶻ; Mtol::Real = 0.05)
    # Fix the seed
    Random.seed!(8)

    # Mass arriving in each z-row from the row below (row 1 keeps its own downward mass)
    # and from the row above (the top row keeps its own upward mass)
    from_below(up, down) = [down[1:1, :]; up[1:end-1, :]]
    from_above(up, down) = [down[2:end, :]; up[end:end, :]]

    # Initial guess of distribution = Psi
    M = copy(Ψ)

    prev = 9e5
    for iter = 1:maxiter
        # Solve for entry
        eq = solveEq_cf1(M, grid_z, grid_a, tol, par, ρᶻ)
        eq === nothing && error("solve_cf1: no entry equilibrium found at outer iteration $iter")
        ρᵉ, ρᶻ, Π_j, L_j, W_p, V, U, Ũ = eq

        # Share of each cell moving with p_n (workers and non-investing entrepreneurs)
        # and with p_i (investing entrepreneurs)
        share_n = (1 .- ρᶻ) .* ρᵉ .+ (1 .- ρᵉ)
        share_i = ρᶻ .* ρᵉ

        # Update M: iterate the law of motion 100 periods
        TM = M
        for _ = 1:100
            up_n   = par.p_n .* share_n .* TM
            up_i   = par.p_i .* share_i .* TM
            down_n = (1 - par.p_n) .* share_n .* TM
            down_i = (1 - par.p_i) .* share_i .* TM
            TM = par.δ .* Ψ .+ (1 - par.δ) .* (from_below(up_n, down_n) .+ from_below(up_i, down_i) .+
                                               from_above(up_n, down_n) .+ from_above(up_i, down_i))
        end

        # Check convergence
        dev = sum(abs.(TM - M))
        println(dev)

        if dev < Mtol
            return M, ρᵉ, ρᶻ, Π_j, L_j, W_p, V, U, Ũ, TM
        end
        # Full step while the deviation shrinks fast enough, otherwise damp
        M = (dev < prev && dev / prev < 0.9) ? TM : 0.5 .* M + 0.5 .* TM
        prev = dev
    end
    @warn "solve_cf1: distribution did not converge within $maxiter iterations"
    return nothing
end

# COUNTERFACTUAL WITH FIXED ENTRY POLICY
# Function to solve the VFs given some measures of workers/entrepreneurs
# (fixed entry counterfactual)
#
# Fixed-point iteration on the entrepreneur value V and the worker value Ũ with the entry
# policy ρᵉ held fixed. Entrant cells choose investment optimally (V = max(Vᴵ, Vᴺ)). Non-entrant
# cells get V = 0. Continuation values weight V and Ũ of the neighbouring productivity cell by
# that cell's fixed entry policy, instead of taking max(V, Ũ). Each step re-solves the firm
# problem with `solveWP_cf2`.
#   μ, ϕ : normalized measures of entrepreneurs / workers over the (z, a) grid
#   M    : not used here (kept so the counterfactual solvers share a call shape)
#   L, E : aggregate labour and entrepreneur masses
#   tol  : iterate until max|TŨ - Ũ| + max|TV - V| is no longer above tol (no iteration cap)
#   ρᵉ   : fixed entry policy (1 = entrepreneur, 0 = worker)
# Returns (V, U, Ũ, Π_j, L_j, W_p, Vᴵ, Vᴺ), where U[zi, ai, zj, aj] is the value of a worker of
# type (zi, ai) employed at firm (zj, aj).
function solveVF_cf2(μ::Matrix{Float64}, ϕ::Matrix{Float64}, M::Matrix{Float64}, L::Float64, E::Float64, grid_z::Vector{Float64}, grid_a::Vector{Float64}, tol::Float64, par, ρᵉ::Array)
    # Size of grids
    SIZE_Z = size(grid_z, 1)
    SIZE_A = size(grid_a, 1)

    # Row of the next-higher / next-lower productivity level, reflected at the grid edges
    # (this also lets a single-point z grid work instead of indexing out of bounds)
    up   = min.((1:SIZE_Z) .+ 1, SIZE_Z)
    down = max.((1:SIZE_Z) .- 1, 1)

    # Entry policy one productivity step up / down (loop-invariant)
    ρᵉ₊ = Float64.(ρᵉ[up, :])
    ρᵉ₋ = Float64.(ρᵉ[down, :])

    # Guess values and initial wage schedule
    V = ones(SIZE_Z, SIZE_A)
    Ũ = ones(SIZE_Z, SIZE_A)
    W_0 = ones(SIZE_Z, SIZE_A) .* 0.5

    # Amenity of each firm cell: A_j[zj, aj] = grid_a[aj]
    A_j = repeat(grid_a', SIZE_Z)

    while true
        # Remove negatives in V
        V = max.(0, V)

        # Values one productivity step up / down
        V₊ = V[up, :]
        V₋ = V[down, :]
        Ũ₊ = Ũ[up, :]
        Ũ₋ = Ũ[down, :]

        # Solve the firm problem given measures and value functions
        W_p, Π_j, L_j = solveWP_cf2(μ, ϕ, grid_z, grid_a, 1e-4, V, Ũ, W_0, par, L, E, V₊, V₋, Ũ₊, Ũ₋, A_j, ρᵉ)

        # Value of landing one step up / down, given the fixed occupation choice there
        EV₊ = ρᵉ₊ .* V₊ .+ (1 .- ρᵉ₊) .* Ũ₊
        EV₋ = ρᵉ₋ .* V₋ .+ (1 .- ρᵉ₋) .* Ũ₋

        # Entrepreneur value when investing (flow u(Π_j - c_z), moves up w.p. p_i)
        # and when not investing (flow u(Π_j), moves up w.p. p_n)
        Vᴵ = ρᵉ .* (par.ϵᴸ .* u.(Π_j .- par.c_z) .+ A_j) .+ par.β .* (1 .- par.δ) .* (par.p_i .* EV₊ .+ (1 - par.p_i) .* EV₋)
        Vᴺ = ρᵉ .* (par.ϵᴸ .* u.(Π_j)            .+ A_j) .+ par.β .* (1 .- par.δ) .* (par.p_n .* EV₊ .+ (1 - par.p_n) .* EV₋)

        # Update V: optimal investment on entrant cells, zero elsewhere
        TV = ρᵉ .* max.(Vᴵ, Vᴺ)

        # Worker value: the flow depends only on the firm (zj, aj) and the continuation only on
        # the worker's own state (zi, ai, moving up w.p. p_n). Each is computed once and
        # combined by broadcasting: U[zi, ai, zj, aj] = flow[zj, aj] + cont[zi, ai]
        flow = par.ϵᴸ .* u.(W_p) .+ A_j
        cont = par.β * (1 - par.δ) .* (par.p_n .* EV₊ .+ (1 - par.p_n) .* EV₋)
        U = reshape(flow, 1, 1, SIZE_Z, SIZE_A) .+ cont

        # Update Ũ: log-sum-exp of U over firms, weighted by the entrepreneur measure μ
        TŨ = similar(Ũ)
        for ai = 1:SIZE_A, zi = 1:SIZE_Z
            TŨ[zi, ai] = par.ν * log(E * sum(exp.(view(U, zi, ai, :, :) ./ par.ν) .* μ))
        end

        # Compute deviation; `dev > tol ||` (not `dev <= tol &&`) so a NaN deviation also stops
        dev = maximum(abs.(TŨ - Ũ)) + maximum(abs.(TV - V))
        dev > tol || return TV, U, TŨ, Π_j, L_j, W_p, Vᴵ, Vᴺ

        # TŨ and TV are fresh arrays each pass, so plain rebinding is safe
        Ũ = TŨ
        V = TV
        W_0 = copy(W_p)
    end
end

# Function to solve for equilibrium entry given distribution (fixed entry counterfactual)
#
# The entry policy ρᵉ is exogenous: cells with ρᵉ == 1 form the entrepreneur measure μ, and
# cells with ρᵉ == 0 form the worker measure ϕ. The value functions are solved once. The
# investment policy is read off as the entrant cells where investing beats not investing.
# Returns (ρᵉ, ρᶻ, Π_j, L_j, W_p, V, U, Ũ).
function solveEq_cf2(M::Matrix{Float64}, grid_z::Vector{Float64}, grid_a::Vector{Float64}, tol::Float64, par, ρᵉ::Array)
    is_entrepreneur = ρᵉ .== 1
    is_worker = ρᵉ .== 0

    # Measures of entrepreneurs and workers (unnormalized for now)
    μ = zeros(size(M))
    μ[is_entrepreneur] .= M[is_entrepreneur]
    ϕ = zeros(size(M))
    ϕ[is_worker] .= M[is_worker]

    # Aggregate labour / entrepreneur masses, taken before normalization
    L = sum(ϕ) * par.Λ
    E = sum(μ) * par.Λ

    # Turn both measures into distributions
    μ ./= sum(μ)
    ϕ ./= sum(ϕ)

    V, U, Ũ, Π_j, L_j, W_p, Vᴵ, Vᴺ = solveVF_cf2(μ, ϕ, M, L, E, grid_z, grid_a, tol, par, ρᵉ)

    # Invest wherever the investing value is strictly higher, for entrants only
    invests = Vᴵ .> Vᴺ
    ρᶻ = invests .* ρᵉ

    return ρᵉ, ρᶻ, Π_j, L_j, W_p, V, U, Ũ
end


# Function to solve for equilibrium (fixed entry counterfactual)
#
# Outer fixed point on the (z, a) distribution M, holding the entry policy ρᵉ fixed:
#   1. given M, solve the value functions and the implied investment policy (solveEq_cf2);
#   2. push M forward 100 periods with the law of motion implied by (ρᵉ, ρᶻ);
#   3. stop when sum|TM - M| < tol. Otherwise take M = TM if the deviation fell by more than
#      10%, else move 10% of the way from M towards TM.
# Arguments
#   Ψ       : distribution of newborns (inflow share δ), also the initial guess for M
#   tol     : tolerance for both the value functions and the distribution
#   maxiter : maximum number of outer iterations
#   ρᵉ      : fixed entry policy
# Returns (M, ρᵉ, ρᶻ, Π_j, L_j, W_p, V, U, Ũ, TM), or `nothing` with a warning if the
# distribution has not converged after `maxiter` iterations.
function solve_cf2(Ψ::Matrix{Float64}, grid_z::Vector{Float64}, grid_a::Vector{Float64}, tol::Float64, par, maxiter::Integer, ρᵉ)
    # Fix the seed
    Random.seed!(8)

    # Mass arriving in each z-row from the row below (row 1 keeps its own downward mass)
    # and from the row above (the top row keeps its own upward mass)
    from_below(up, down) = [down[1:1, :]; up[1:end-1, :]]
    from_above(up, down) = [down[2:end, :]; up[end:end, :]]

    # Initial guess of distribution = Psi
    M = copy(Ψ)

    prev = 9e5
    for _ = 1:maxiter
        # Solve for the investment policy and values given the fixed entry policy
        ρᵉ, ρᶻ, Π_j, L_j, W_p, V, U, Ũ = solveEq_cf2(M, grid_z, grid_a, tol, par, ρᵉ)

        # Share of each cell moving with p_n (workers and non-investing entrepreneurs)
        # and with p_i (investing entrepreneurs)
        share_n = (1 .- ρᶻ) .* ρᵉ .+ (1 .- ρᵉ)
        share_i = ρᶻ .* ρᵉ

        # Update M: iterate the law of motion 100 periods
        TM = M
        for _ = 1:100
            up_n   = par.p_n .* share_n .* TM
            up_i   = par.p_i .* share_i .* TM
            down_n = (1 - par.p_n) .* share_n .* TM
            down_i = (1 - par.p_i) .* share_i .* TM
            TM = par.δ .* Ψ .+ (1 - par.δ) .* (from_below(up_n, down_n) .+ from_below(up_i, down_i) .+
                                               from_above(up_n, down_n) .+ from_above(up_i, down_i))
        end

        # Check convergence
        dev = sum(abs.(TM - M))
        println(dev)

        if dev < tol
            return M, ρᵉ, ρᶻ, Π_j, L_j, W_p, V, U, Ũ, TM
        end
        # Full step while the deviation shrinks fast enough, otherwise a heavily damped step
        M = (dev < prev && dev / prev < 0.9) ? TM : 0.9 .* M + 0.1 .* TM
        prev = dev
    end
    @warn "solve_cf2: distribution did not converge within $maxiter iterations"
    return nothing
end

### COUNTERFACTUAL WITH FIXED ENTRY AND FIXED INVESTMENT POLICY
# Function to solve the VFs given some measures of workers/entrepreneurs, with both the entry
# policy ρᵉ and the investment policy ρᶻ held fixed.
#
# Fixed-point iteration on the entrepreneur value V and the worker value Ũ. Continuation values
# weight V and Ũ of the neighbouring productivity cell by that cell's fixed entry policy. Each
# step re-solves the firm problem with `solveWP_cf2`.
#   μ, ϕ : normalized measures of entrepreneurs / workers over the (z, a) grid
#   M    : not used here (kept so the counterfactual solvers share a call shape)
#   L, E : aggregate labour and entrepreneur masses
#   tol  : iterate until max|TŨ - Ũ| + max|TV - V| is no longer above tol (no iteration cap)
#   ρᵉ   : fixed entry policy (1 = entrepreneur, 0 = worker)
#   ρᶻ   : fixed weight on the investing value Vᴵ versus the non-investing value Vᴺ
# Returns (V, U, Ũ, Π_j, L_j, W_p), where U[zi, ai, zj, aj] is the value of a worker of type
# (zi, ai) employed at firm (zj, aj).
function solveVF_cf3(μ::Matrix{Float64}, ϕ::Matrix{Float64}, M::Matrix{Float64}, L::Float64, E::Float64, grid_z::Vector{Float64}, grid_a::Vector{Float64}, tol::Float64, par, ρᵉ::Array, ρᶻ::Array)
    # Size of grids
    SIZE_Z = size(grid_z, 1)
    SIZE_A = size(grid_a, 1)

    # Row of the next-higher / next-lower productivity level, reflected at the grid edges
    # (this also lets a single-point z grid work instead of indexing out of bounds)
    up   = min.((1:SIZE_Z) .+ 1, SIZE_Z)
    down = max.((1:SIZE_Z) .- 1, 1)

    # Entry policy one productivity step up / down (loop-invariant)
    ρᵉ₊ = Float64.(ρᵉ[up, :])
    ρᵉ₋ = Float64.(ρᵉ[down, :])

    # Guess values and initial wage schedule
    V = ones(SIZE_Z, SIZE_A)
    Ũ = ones(SIZE_Z, SIZE_A)
    W_0 = ones(SIZE_Z, SIZE_A) .* 0.5

    # Amenity of each firm cell: A_j[zj, aj] = grid_a[aj]
    A_j = repeat(grid_a', SIZE_Z)

    while true
        # Values one productivity step up / down
        V₊ = V[up, :]
        V₋ = V[down, :]
        Ũ₊ = Ũ[up, :]
        Ũ₋ = Ũ[down, :]

        # Solve the firm problem given measures and value functions
        W_p, Π_j, L_j = solveWP_cf2(μ, ϕ, grid_z, grid_a, 1e-4, V, Ũ, W_0, par, L, E, V₊, V₋, Ũ₊, Ũ₋, A_j, ρᵉ)

        # Value of landing one step up / down, given the fixed occupation choice there
        EV₊ = ρᵉ₊ .* V₊ .+ (1 .- ρᵉ₊) .* Ũ₊
        EV₋ = ρᵉ₋ .* V₋ .+ (1 .- ρᵉ₋) .* Ũ₋

        # Entrepreneur value when investing (flow u(Π_j - c_z), moves up w.p. p_i)
        # and when not investing (flow u(Π_j), moves up w.p. p_n), zeroed off entrant cells
        Vᴵ = ρᵉ .* (par.ϵᴸ .* u.(Π_j .- par.c_z) .+ A_j .+ par.β .* (1 .- par.δ) .* (par.p_i .* EV₊ .+ (1 - par.p_i) .* EV₋))
        Vᴺ = ρᵉ .* (par.ϵᴸ .* u.(Π_j)            .+ A_j .+ par.β .* (1 .- par.δ) .* (par.p_n .* EV₊ .+ (1 - par.p_n) .* EV₋))

        # Update V with the fixed investment policy, floored at zero
        TV = max.(0, ρᵉ .* (ρᶻ .* Vᴵ .+ (1 .- ρᶻ) .* Vᴺ))

        # Worker value: the flow depends only on the firm (zj, aj) and the continuation only on
        # the worker's own state (zi, ai, moving up w.p. p_n). Each is computed once and
        # combined by broadcasting: U[zi, ai, zj, aj] = flow[zj, aj] + cont[zi, ai]
        flow = par.ϵᴸ .* u.(W_p) .+ A_j
        cont = par.β * (1 - par.δ) .* (par.p_n .* EV₊ .+ (1 - par.p_n) .* EV₋)
        U = reshape(flow, 1, 1, SIZE_Z, SIZE_A) .+ cont

        # Update Ũ: log-sum-exp of U over firms, weighted by the entrepreneur measure μ
        TŨ = similar(Ũ)
        for ai = 1:SIZE_A, zi = 1:SIZE_Z
            TŨ[zi, ai] = par.ν * log(E * sum(exp.(view(U, zi, ai, :, :) ./ par.ν) .* μ))
        end

        # Compute deviation; `dev > tol ||` (not `dev <= tol &&`) so a NaN deviation also stops
        dev = maximum(abs.(TŨ - Ũ)) + maximum(abs.(TV - V))
        dev > tol || return TV, U, TŨ, Π_j, L_j, W_p

        # TŨ and TV are fresh arrays each pass, so plain rebinding is safe
        Ũ = TŨ
        V = TV
        W_0 = copy(W_p)
    end
end


# Function to solve for equilibrium entry given distribution (fixed entry and inv counterfactual)
#
# Both policies are exogenous here. This builds the entrepreneur measure μ (cells with
# ρᵉ == 1) and the worker measure ϕ (cells with ρᵉ == 0), solves the value functions once,
# and passes ρᵉ and ρᶻ through unchanged.
# Returns (ρᵉ, ρᶻ, Π_j, L_j, W_p, V, U, Ũ).
function solveEq_cf3(M::Matrix{Float64}, grid_z::Vector{Float64}, grid_a::Vector{Float64}, tol::Float64, par, ρᵉ::Array, ρᶻ::Array)
    entrants = ρᵉ .== 1
    workers = ρᵉ .== 0

    # Measures of entrepreneurs and workers (unnormalized for now)
    μ = zeros(size(M))
    μ[entrants] .= M[entrants]
    ϕ = zeros(size(M))
    ϕ[workers] .= M[workers]

    # Aggregate labour / entrepreneur masses, taken before normalization
    L = sum(ϕ) * par.Λ
    E = sum(μ) * par.Λ

    # Turn both measures into distributions
    μ ./= sum(μ)
    ϕ ./= sum(ϕ)

    V, U, Ũ, Π_j, L_j, W_p = solveVF_cf3(μ, ϕ, M, L, E, grid_z, grid_a, tol, par, ρᵉ, ρᶻ)
    return ρᵉ, ρᶻ, Π_j, L_j, W_p, V, U, Ũ
end

# Function to solve for equilibrium (fixed entry and inv counterfactual)
#
# With both the entry policy ρᵉ and the investment policy ρᶻ fixed, only the (z, a)
# distribution M has to be solved for. Each outer iteration pushes M forward 100 periods with
# the law of motion:
#   - a share δ of the mass is replaced by newborns drawn from Ψ;
#   - surviving workers and non-investing entrepreneurs move one z-step up w.p. p_n, else down;
#   - investing entrepreneurs move up w.p. p_i, else down;
#   - moves past either end of the z grid stay at the edge.
# It stops when sum|TM - M| < tol. Otherwise it takes M = TM if the deviation fell by more than
# 10%, else averages M and TM 50/50.
# Returns (M, ρᵉ, ρᶻ, TM), or `nothing` with a warning if not converged after `maxiter`
# iterations.
function solve_cf3(Ψ::Matrix{Float64}, grid_z::Vector{Float64}, grid_a::Vector{Float64}, tol::Float64, par, maxiter::Integer, ρᵉ, ρᶻ)
    # Fix the seed
    Random.seed!(8)

    # Mass arriving in each z-row from the row below (row 1 keeps its own downward mass)
    # and from the row above (the top row keeps its own upward mass)
    from_below(up, down) = [down[1:1, :]; up[1:end-1, :]]
    from_above(up, down) = [down[2:end, :]; up[end:end, :]]

    # Both policies are fixed, so the mover shares never change:
    # share_n moves with p_n (workers and non-investing entrepreneurs), share_i with p_i
    share_n = (1 .- ρᶻ) .* ρᵉ .+ (1 .- ρᵉ)
    share_i = ρᶻ .* ρᵉ

    # Initial guess of distribution = Psi
    M = copy(Ψ)

    prev = 9e5
    for _ = 1:maxiter
        # Update M: iterate the law of motion 100 periods
        TM = M
        for _ = 1:100
            up_n   = par.p_n .* share_n .* TM
            up_i   = par.p_i .* share_i .* TM
            down_n = (1 - par.p_n) .* share_n .* TM
            down_i = (1 - par.p_i) .* share_i .* TM
            TM = par.δ .* Ψ .+ (1 - par.δ) .* (from_below(up_n, down_n) .+ from_below(up_i, down_i) .+
                                               from_above(up_n, down_n) .+ from_above(up_i, down_i))
        end

        # Check convergence
        dev = sum(abs.(TM - M))
        println(dev)

        if dev < tol
            return M, ρᵉ, ρᶻ, TM
        end
        # Full step while the deviation shrinks fast enough, otherwise damp
        M = (dev < prev && dev / prev < 0.9) ? TM : 0.5 .* M + 0.5 .* TM
        prev = dev
    end
    @warn "solve_cf3: distribution did not converge within $maxiter iterations"
    return nothing
end


