#= 
Functions to solve the model. All these are accessed through the entry point 
model_solve.jl

Tristany Armangué-Jubert, Nov 2023
=#

### DEPENDENCIES
using Distributions, Random, Optim, StatsBase, DataFrames, CSV, Tables
include("counterfactual_functions_xi.jl")

"""
    u(x)

Period utility: `log(x)` for `x > 0`. For `x <= 0` it returns the large finite penalty
`-9e10` instead of `-Inf`.
"""
u(x) = x <= 0 ? -9e10 : log(x)

"""
    F(l, xi = par.xi)

Production function: `l^xi` for positive labour input `l`, and zero when `l <= 0`.

`xi` defaults to the curvature stored in the global `par`, so existing one-argument calls
are unchanged. Passing `xi` explicitly removes the dependence on that global. The zero
branch returns `zero(float(l))`, so the result type does not depend on the sign of `l`.
"""
function F(l, xi = par.xi)
    if l <= 0
        return zero(float(l))
    end
    return l^xi
end

"""
    dist_bounds(par, N_a)

Truncation bounds for discretising the productivity and amenity distributions.

Both marginals are `LogNormal`:
- productivity uses `(par.μ_z, par.σ_z)` and is cut at its 1st and 99th percentiles;
- amenities use `(par.μ_a, par.σ_a)` and are cut at the `1/N_a` quantile and the 98th
  percentile.

Returns `(lbz, ubz, lba, uba)`.
"""
function dist_bounds(par, N_a)
    prod_dist = LogNormal(par.μ_z, par.σ_z)
    amen_dist = LogNormal(par.μ_a, par.σ_a)

    z_lo, z_hi = quantile(prod_dist, 0.01), quantile(prod_dist, 0.99)
    a_lo, a_hi = quantile(amen_dist, 1 / N_a), quantile(amen_dist, 0.98)

    return z_lo, z_hi, a_lo, a_hi
end

"""
    discretize_joint_dist(dist, p, N_a, bounds)

Discretise the joint (productivity, amenity) density `dist` on a rectangular grid.

- Productivity nodes form a geometric sequence. It starts at `bounds[1]`, has ratio
  `1 + p`, and has enough nodes to reach `1.5 * bounds[2]`; the last node may overshoot.
- Amenity nodes are `N_a` equally spaced points from `bounds[3]` to `bounds[4]`.

Returns `(range_z, range_a, Z)`. `Z[i, j]` is `pdf(dist, [range_z[i], range_a[j]])`
normalised so the matrix sums to one.
"""
function discretize_joint_dist(dist, p, N_a, bounds)
    z_lo, z_hi = bounds[1], bounds[2] * 1.5
    a_lo, a_hi = bounds[3], bounds[4]

    # Geometric steps needed to cover [z_lo, z_hi], plus the starting node.
    n_z = Int(ceil((log(z_hi) - log(z_lo)) / (log(1 + p)))) + 1
    range_z = zeros(n_z)
    range_z[1] = z_lo
    for k in 2:n_z
        range_z[k] = range_z[k - 1] * (1 + p)
    end

    range_a = collect(LinRange(a_lo, a_hi, N_a))

    density = [pdf(dist, [z, a]) for z in range_z, a in range_a]
    return range_z, range_a, density ./ sum(density)
end

"""
    solveWP(μ, ϕ, grid_z, grid_a, tol, V, Ũ, W_0, par, L, E, V₊, V₋, Ũ₊, Ũ₋, A_j)

Update the wage schedule and solve the firms' problem, holding value functions fixed.

All matrices are `SIZE_Z × SIZE_A` (productivity node × amenity node).

# Arguments
- `μ`, `ϕ`: distributions used as weights in the `λ` sum and the `Θ` sum respectively.
- `grid_z`, `grid_a`: productivity and amenity grids.
- `W_0`: current wage guess.
- `A_j`: amenity value of each node.
- `V₊`, `V₋`, `Ũ₊`, `Ũ₋`: `V` and `Ũ` shifted one productivity node up/down.
- `par`: parameters; `ϵᴸ`, `β`, `δ`, `p_n`, `ν`, `xi`, `A` and `c` are read here.
- `L`: scales `Θ` and labour demand.
- `E`: scales `λ`.
- `tol`, `V`, `Ũ`: kept for interface compatibility; not used.

# Returns
`(W_p, Π_j, L_j)`: updated wages, profits and labour demand at every node.
"""
function solveWP(μ::Matrix{Float64}, ϕ::Matrix{Float64}, grid_z::Vector{Float64}, grid_a::Vector{Float64}, tol::Float64, V::Matrix{Float64}, Ũ::Matrix{Float64}, W_0::Matrix{Float64}, par, L::Float64, E::Float64, V₊::Matrix{Float64}, V₋::Matrix{Float64}, Ũ₊::Matrix{Float64}, Ũ₋::Matrix{Float64}, A_j::Matrix{Float64})
    SIZE_Z = size(grid_z, 1)
    SIZE_A = size(grid_a, 1)

    # Flow value of each destination node at the current wage guess. It does not depend
    # on the origin node, so it is built once instead of once per node.
    flow = par.ϵᴸ .* u.(W_0) .+ A_j

    # For each origin node i, disc is the discounted continuation value: productivity moves
    # up with probability p_n and down otherwise, taking max(V, Ũ) at the new node. Then
    #   λ[i] = E Σ_j exp((flow_j + disc_i)/ν) μ_j   and   ξ[i] = exp(disc_i/ν).
    λ = zeros(SIZE_Z, SIZE_A)
    ξ = zeros(SIZE_Z, SIZE_A)
    for idx_ai in 1:SIZE_A, idx_zi in 1:SIZE_Z
        cont = par.p_n * max(V₊[idx_zi, idx_ai], Ũ₊[idx_zi, idx_ai]) +
               (1 - par.p_n) * max(V₋[idx_zi, idx_ai], Ũ₋[idx_zi, idx_ai])
        disc = par.β * (1 - par.δ) * cont
        λ[idx_zi, idx_ai] = E * sum(exp.((flow .+ disc) ./ par.ν) .* μ)
        ξ[idx_zi, idx_ai] = exp(disc / par.ν)
    end

    Θ = sum((ξ ./ λ) .* ϕ)

    # Closed-form wage, implied labour demand, and profit at every firm node.
    denom = 1 + (1 - par.xi) * par.ϵᴸ
    W_p = zeros(SIZE_Z, SIZE_A)
    L_j = zeros(SIZE_Z, SIZE_A)
    Π_j = zeros(SIZE_Z, SIZE_A)
    for idx_aj in 1:SIZE_A, idx_zj in 1:SIZE_Z
        z_j = grid_z[idx_zj]
        a_j = grid_a[idx_aj]
        wj = exp(
            1 / par.ϵᴸ * ((par.ϵᴸ / denom) * log(par.A * z_j) - ((1 - par.xi) * par.ϵᴸ / denom) * a_j + 1 / denom * (par.ϵᴸ * log(par.xi) + log(L) + log(Θ) - par.ϵᴸ * log((1 + par.ϵᴸ) / par.ϵᴸ)) - log(L) - log(Θ))
        )
        lj = L * Θ * exp(par.ϵᴸ * log(wj) + a_j)
        W_p[idx_zj, idx_aj] = wj
        L_j[idx_zj, idx_aj] = lj
        Π_j[idx_zj, idx_aj] = par.A * z_j * F(lj) - wj * lj - par.c
    end

    return W_p, Π_j, L_j
end



"""
    solveWP_cf2(μ, ϕ, grid_z, grid_a, tol, V, Ũ, W_0, par, L, E, V₊, V₋, Ũ₊, Ũ₋, A_j, ρᵉ)

Counterfactual variant of the wage/firm update with a different continuation value.
Instead of `max(V, Ũ)` at the neighbouring node, it uses the mix
`ρᵉ[i]·V + (1 - ρᵉ[i])·Ũ` there, with `ρᵉ` evaluated at the origin node `i`.

All matrices are `SIZE_Z × SIZE_A` (productivity node × amenity node).

# Arguments
- `μ`, `ϕ`: weights for the `λ` sum and the `Θ` sum respectively.
- `W_0`: current wage guess.
- `A_j`: amenity value of each node.
- `V₊`, `V₋`, `Ũ₊`, `Ũ₋`: `V` and `Ũ` shifted one productivity node up/down.
- `ρᵉ`: indexed as a `SIZE_Z × SIZE_A` array of weights.
- `L`: scales `Θ` and labour demand.
- `E`: scales `λ`.
- `tol`, `V`, `Ũ`: kept for interface compatibility; not used.

# Returns
`(W_p, Π_j, L_j)`: updated wages, profits and labour demand at every node.
"""
function solveWP_cf2(μ::Matrix{Float64}, ϕ::Matrix{Float64}, grid_z::Vector{Float64}, grid_a::Vector{Float64}, tol::Float64, V::Matrix{Float64}, Ũ::Matrix{Float64}, W_0::Matrix{Float64}, par, L::Float64, E::Float64, V₊::Matrix{Float64}, V₋::Matrix{Float64}, Ũ₊::Matrix{Float64}, Ũ₋::Matrix{Float64}, A_j::Matrix{Float64}, ρᵉ::Array)
    SIZE_Z = size(grid_z, 1)
    SIZE_A = size(grid_a, 1)

    # Flow value of each destination node; independent of the origin node, so built once.
    flow = par.ϵᴸ .* u.(W_0) .+ A_j

    # λ[i] = E Σ_j exp((flow_j + disc_i)/ν) μ_j  and  ξ[i] = exp(disc_i/ν), where disc_i
    # is the discounted ρᵉ-weighted continuation value at origin node i.
    λ = zeros(SIZE_Z, SIZE_A)
    ξ = zeros(SIZE_Z, SIZE_A)
    for idx_ai in 1:SIZE_A, idx_zi in 1:SIZE_Z
        ρ = ρᵉ[idx_zi, idx_ai]
        cont = par.p_n * (ρ * V₊[idx_zi, idx_ai] + (1 - ρ) * Ũ₊[idx_zi, idx_ai]) +
               (1 - par.p_n) * (ρ * V₋[idx_zi, idx_ai] + (1 - ρ) * Ũ₋[idx_zi, idx_ai])
        disc = par.β * (1 - par.δ) * cont
        λ[idx_zi, idx_ai] = E * sum(exp.((flow .+ disc) ./ par.ν) .* μ)
        ξ[idx_zi, idx_ai] = exp(disc / par.ν)
    end

    Θ = sum((ξ ./ λ) .* ϕ)

    # Closed-form wage, implied labour demand, and profit at every firm node.
    denom = 1 + (1 - par.xi) * par.ϵᴸ
    W_p = zeros(SIZE_Z, SIZE_A)
    L_j = zeros(SIZE_Z, SIZE_A)
    Π_j = zeros(SIZE_Z, SIZE_A)
    for idx_aj in 1:SIZE_A, idx_zj in 1:SIZE_Z
        z_j = grid_z[idx_zj]
        a_j = grid_a[idx_aj]
        wj = exp(
            1 / par.ϵᴸ * ((par.ϵᴸ / denom) * log(par.A * z_j) - ((1 - par.xi) * par.ϵᴸ / denom) * a_j + 1 / denom * (par.ϵᴸ * log(par.xi) + log(L) + log(Θ) - par.ϵᴸ * log((1 + par.ϵᴸ) / par.ϵᴸ)) - log(L) - log(Θ))
        )
        lj = L * Θ * exp(par.ϵᴸ * log(wj) + a_j)
        W_p[idx_zj, idx_aj] = wj
        L_j[idx_zj, idx_aj] = lj
        Π_j[idx_zj, idx_aj] = par.A * z_j * F(lj) - wj * lj - par.c
    end

    return W_p, Π_j, L_j
end



"""
    solveVF(μ, ϕ, M, L, E, grid_z, grid_a, tol, par; maxiter = typemax(Int))

Solve the value functions by fixed-point iteration, given `μ`, `ϕ`, `L` and `E`.

Each iteration calls `solveWP` to update wages, then forms:
- `Vᴵ`: pays `par.c_z` out of profits and moves up a productivity node with
  probability `par.p_i`;
- `Vᴺ`: no payment and moves up with probability `par.p_n`;
- `V = max(Vᴵ, Vᴺ)`;
- `U[zi, ai, zj, aj]`: continuation value at own node `(zi, ai)` plus the flow value of
  the wage and amenity at node `(zj, aj)`;
- `Ũ[zi, ai] = ν · log(E · Σ_{zj,aj} exp(U[zi, ai, zj, aj]/ν) · μ[zj, aj])`.

The iteration stops when `max|ΔŨ| + max|ΔV|` is no longer above `tol`. `M` is kept for
interface compatibility and is not used. `maxiter` caps the number of iterations. The
default is effectively unbounded, matching the previous behaviour; exceeding it throws
an `ErrorException`.

Returns `(V, U, Ũ, Π_j, L_j, W_p, Vᴵ, Vᴺ)`.
"""
function solveVF(μ::Matrix{Float64}, ϕ::Matrix{Float64}, M::Matrix{Float64}, L::Float64, E::Float64, grid_z::Vector{Float64}, grid_a::Vector{Float64}, tol::Float64, par; maxiter::Integer = typemax(Int))
    SIZE_Z = size(grid_z, 1)
    SIZE_A = size(grid_a, 1)

    # Initial guesses for the value functions and the wage schedule.
    V = ones(SIZE_Z, SIZE_A)
    Ũ = ones(SIZE_Z, SIZE_A)
    W_0 = ones(SIZE_Z, SIZE_A) .* 0.5

    # Amenity value of every node: A_j[zi, ai] == grid_a[ai].
    A_j = repeat(grid_a', SIZE_Z)

    for _ in 1:maxiter
        # Shift one productivity node up (₊) or down (₋). The top and bottom rows map to
        # themselves, so agents at the grid edges stay put.
        V₊ = vcat(V[2:end, :], V[end:end, :])
        V₋ = vcat(V[1:1, :], V[1:end-1, :])
        Ũ₊ = vcat(Ũ[2:end, :], Ũ[end:end, :])
        Ũ₋ = vcat(Ũ[1:1, :], Ũ[1:end-1, :])

        # Wages, profits and labour demand given the current value functions.
        W_p, Π_j, L_j = solveWP(μ, ϕ, grid_z, grid_a, 1e-4, V, Ũ, W_0, par, L, E, V₊, V₋, Ũ₊, Ũ₋, A_j)

        # Value of the better option after moving up / down.
        best₊ = max.(V₊, Ũ₊)
        best₋ = max.(V₋, Ũ₋)

        Vᴵ = par.ϵᴸ .* u.(Π_j .- par.c_z) .+ A_j .+ par.β .* (1 .- par.δ) .* (par.p_i .* best₊ .+ (1 - par.p_i) .* best₋)
        Vᴺ = par.ϵᴸ .* u.(Π_j) .+ A_j .+ par.β .* (1 .- par.δ) .* (par.p_n .* best₊ .+ (1 - par.p_n) .* best₋)
        TV = max.(Vᴵ, Vᴺ)

        # U[zi, ai, zj, aj] = disc[zi, ai] + flow[zj, aj]. disc uses the same p_n law of
        # motion as Vᴺ, which is exactly the edge-aware neighbour lookup on V and Ũ.
        disc = par.β .* (1 .- par.δ) .* (par.p_n .* best₊ .+ (1 - par.p_n) .* best₋)
        flow = par.ϵᴸ .* u.(W_p) .+ A_j
        U = reshape(flow, 1, 1, SIZE_Z, SIZE_A) .+ disc

        # ν-scaled log-sum over destination nodes, weighted by μ and scaled by E.
        TŨ = [par.ν * log(E * sum(exp.(view(U, idx_zi, idx_ai, :, :) ./ par.ν) .* μ)) for idx_zi in 1:SIZE_Z, idx_ai in 1:SIZE_A]

        dev = maximum(abs.(TŨ - Ũ)) + maximum(abs.(TV - V))
        # Written as !(dev > tol) so that a NaN deviation returns, as before.
        if !(dev > tol)
            return TV, U, TŨ, Π_j, L_j, W_p, Vᴵ, Vᴺ
        end

        # These arrays are freshly allocated each iteration, so no copy is needed.
        Ũ, V, W_0 = TŨ, TV, W_p
    end
    error("solveVF: value functions did not converge within $maxiter iterations")
end

# One attempt at the entry fixed point.
#
# Starts from the guess that the upper half of both grids enters (ρᵉ = 1 there) and runs
# at most `maxit` iterations. When the set of changers differs from the previous
# iteration's deviation, only part of it is moved:
# - changers whose profit rank lies within `split_frac` of the changers' rank range are
#   set to enter;
# - the remaining changers are set to exit.
# Otherwise it jumps straight to the new entry rule.
# Returns the 10-tuple documented on `solveEq`, or `nothing` if it does not converge.
function _solveEq_attempt(M, grid_z, grid_a, tol, par, maxit::Integer, split_frac::Real)
    SIZE_Z = size(grid_z, 1)
    SIZE_A = size(grid_a, 1)

    ρᵉ = zeros(SIZE_Z, SIZE_A)
    ρᵉ[end-Int(floor(SIZE_Z/2)):end, end-Int(floor(SIZE_A/2)):end] .= 1

    prev = 9e10
    prev2 = 9e9
    for _ in 1:maxit
        # Split M into entrepreneurs (μ) and workers (ϕ) under the current entry rule.
        entrants = ρᵉ .== 1
        workers = ρᵉ .== 0
        μ = zeros(size(M))
        ϕ = zeros(size(M))
        μ[entrants] .= M[entrants]
        ϕ[workers] .= M[workers]

        # Aggregate masses are taken before normalisation.
        L = sum(ϕ) * par.Λ
        E = sum(μ) * par.Λ
        μ .= μ ./ sum(μ)
        ϕ .= ϕ ./ sum(ϕ)

        V, U, Ũ, Π_j, L_j, W_p, Vᴵ, Vᴺ = solveVF(μ, ϕ, M, L, E, grid_z, grid_a, tol, par)

        # Rank 1 = highest profit.
        ranking = ordinalrank(Π_j, rev = true)

        # Implied entry rule and investment policy.
        tρᵉ = zeros(size(M))
        tρᵉ[V .> Ũ] .= 1
        ρᶻ = (Vᴵ .> Vᴺ) .* tρᵉ

        # Accept an exact fixed point, or a one-node oscillation (1,1 or 1,2,1).
        dev = sum(abs.(tρᵉ - ρᵉ))
        if (dev == 0) | (prev == dev == 1) | ((dev == prev2 == 1) & (prev == 2))
            return tρᵉ, ρᶻ, Π_j, L_j, W_p, V, U, Ũ, Vᴵ, Vᴺ
        end

        changed = ρᵉ .!= tρᵉ
        unchanged = ρᵉ .== tρᵉ
        if dev != prev
            # Damped update: threshold the changers' profit ranks.
            R = zeros(size(M))
            R[changed] .= ranking[changed]
            maxr = maximum(R)
            R[unchanged] .= Inf
            minr = minimum(R)
            splt = (maxr - minr) * split_frac + minr
            to_enter = R .<= splt          # non-changers are +Inf here, so excluded
            R[unchanged] .= -Inf
            to_exit = R .> splt            # non-changers are -Inf here, so excluded
            ρᵉ[to_enter] .= 1.0
            ρᵉ[to_exit] .= 0.0
        else
            ρᵉ = copy(tρᵉ)
        end

        prev2 = prev
        prev = dev
    end
    return nothing
end

"""
    solveEq(M, grid_z, grid_a, tol, par)

Solve for the equilibrium entry rule given the distribution `M` over (z, a) nodes.

Each iteration:
1. splits `M` into entrepreneurs (`ρᵉ == 1`) and workers;
2. solves `solveVF`;
3. sets the new entry rule to 1 where `V > Ũ`, and investment `ρᶻ = (Vᴵ > Vᴺ)` among
   entrants.

The first attempt runs up to 25 iterations with a 3/4 rank threshold. If it fails, a
fallback from the same initial guess runs up to 35 iterations with a 1/4 threshold.

Returns `(ρᵉ, ρᶻ, Π_j, L_j, W_p, V, U, Ũ, Vᴵ, Vᴺ)`, or `nothing` if neither attempt
converges.
"""
function solveEq(M::Matrix{Float64}, grid_z::Vector{Float64}, grid_a::Vector{Float64}, tol::Float64, par)
    result = _solveEq_attempt(M, grid_z, grid_a, tol, par, 25, 3 / 4)
    result === nothing || return result
    return _solveEq_attempt(M, grid_z, grid_a, tol, par, 35, 1 / 4)
end

# Apply the law of motion once to the measure `TM`.
# - A share `par.δ` of agents is replaced by draws from `Ψ`.
# - Survivors move up one productivity node with probability `par.p_i` (entrants with
#   ρᶻ = 1) or `par.p_n` (everyone else), and down otherwise.
# - Moves that would leave the grid keep the agent at the top/bottom node.
function _apply_law_of_motion(TM, Ψ, ρᵉ, ρᶻ, par)
    SIZE_Z, SIZE_A = size(Ψ)
    TTM = zeros(size(TM))
    for i = 1:SIZE_Z
        for j = 1:SIZE_A
            if i == 1
                TTM[i,j] = par.δ * Ψ[i,j] + (1-par.δ) * (
                    (1 - par.p_n) * ((1-ρᶻ[i,j])*ρᵉ[i,j] + (1-ρᵉ[i,j])) * TM[i,j] +
                    (1 - par.p_i) * (ρᶻ[i,j]*ρᵉ[i,j]) * TM[i,j] +
                    (1 - par.p_n) * ((1-ρᶻ[i+1,j])*ρᵉ[i+1,j] + (1-ρᵉ[i+1,j])) * TM[i+1,j] +
                    (1 - par.p_i) * (ρᶻ[i+1,j]*ρᵉ[i+1,j]) * TM[i+1,j]
                )
            elseif i == SIZE_Z
                TTM[i,j] = par.δ * Ψ[i,j] + (1-par.δ) * (
                    par.p_n * ((1-ρᶻ[i-1,j])*ρᵉ[i-1,j] + (1-ρᵉ[i-1,j])) * TM[i-1,j] +
                    par.p_i * (ρᶻ[i-1,j]*ρᵉ[i-1,j]) * TM[i-1,j] +
                    par.p_n * ((1-ρᶻ[i,j])*ρᵉ[i,j] + (1-ρᵉ[i,j])) * TM[i,j] +
                    par.p_i * (ρᶻ[i,j]*ρᵉ[i,j]) * TM[i,j]
                )
            else
                TTM[i,j] = par.δ * Ψ[i,j] + (1-par.δ) * (
                    par.p_n * ((1-ρᶻ[i-1,j])*ρᵉ[i-1,j] + (1-ρᵉ[i-1,j])) * TM[i-1,j] +
                    par.p_i * ρᶻ[i-1,j] * ρᵉ[i-1,j] * TM[i-1,j] +
                    (1-par.p_n) * ((1-ρᶻ[i+1,j])*ρᵉ[i+1,j] + (1-ρᵉ[i+1,j])) * TM[i+1,j] +
                    (1-par.p_i) * ρᶻ[i+1,j] * ρᵉ[i+1,j] * TM[i+1,j]
                )
            end
        end
    end
    return TTM
end

"""
    solve(Ψ, grid_z, grid_a, tol, par, maxiter)

Find the stationary equilibrium distribution, starting from `M = Ψ`.

Each of at most `maxiter` outer iterations:
1. solves `solveEq(M, …, 1e-5, par)` for the entry (`ρᵉ`) and investment (`ρᶻ`) policies;
2. applies the law of motion 100 times to `M`, giving `TM`;
3. stops if `sum(|TM - M|) < tol`;
4. otherwise sets `M = TM` when the deviation fell below 90% of the previous one, and
   `M = (M + TM)/2` otherwise.

Seeds the global RNG with 8.

Returns `(M, ρᵉ, ρᶻ, Π_j, L_j, W_p, V, U, Ũ, Vᴵ, Vᴺ, TM)` on convergence, and `nothing`
if `maxiter` iterations pass without convergence.
"""
function solve(Ψ::Matrix{Float64}, grid_z::Vector{Float64}, grid_a::Vector{Float64}, tol::Float64, par, maxiter::Integer)
    Random.seed!(8)

    M = copy(Ψ)
    prev = 9e5
    for _ in 1:maxiter
        ρᵉ, ρᶻ, Π_j, L_j, W_p, V, U, Ũ, Vᴵ, Vᴺ = solveEq(M, grid_z, grid_a, 1e-5, par)

        # Push the current measure 100 steps forward under the fixed policies.
        # _apply_law_of_motion never mutates its input, so starting from M is safe.
        TM = M
        for _ in 1:100
            TM = _apply_law_of_motion(TM, Ψ, ρᵉ, ρᶻ, par)
        end

        dev = sum(abs.(TM - M))
        if dev < tol
            return M, ρᵉ, ρᶻ, Π_j, L_j, W_p, V, U, Ũ, Vᴵ, Vᴺ, TM
        end

        # Take the full step when it clearly reduces the deviation; otherwise damp.
        if (dev < prev) & ((dev / prev) < 0.9)
            M = TM
        else
            M = 0.5 .* M + 0.5 .* TM
        end
        prev = dev
    end
    # No convergence within `maxiter` outer iterations.
    return nothing
end

"""
    computeMoments(M, Ψ, grid_z, grid_a, ρᵉ, ρᶻ, Π_j, L_j, W_p, par, ext, rtn_gdp=false)

Compute static and simulated moments at an equilibrium.

Entrepreneurs are the nodes with `ρᵉ == 1`; the "cond." moments restrict to
entrepreneurs with `L_j >= 5`. The table is written to
`../../RR_2025/robustness_xi/counter_out/moments_<ext>.csv`.

Returns `nothing`. When `rtn_gdp == true`, it instead returns the GDP level `gdp`
(whose log is reported as "Log GDPpc") without writing the file.
"""
function computeMoments(M::Matrix{Float64}, Ψ, grid_z::Vector{Float64}, grid_a::Vector{Float64}, ρᵉ::Matrix{Float64}, ρᶻ::Matrix{Float64}, Π_j::Matrix{Float64}, L_j::Matrix{Float64}, W_p::Matrix{Float64}, par, ext, rtn_gdp=false)
    SIZE_A = size(grid_a, 1)

    # Entrepreneur / worker masks and the "5+ employees" subset of entrepreneurs.
    ent = ρᵉ .== 1
    wrk = ρᵉ .== 0
    big = ent .& (L_j .>= 5)

    μ = zeros(size(M))
    ϕ = zeros(size(M))
    μ[ent] .= M[ent]
    ϕ[wrk] .= M[wrk]

    # Aggregate masses, taken before μ is normalised.
    L = sum(ϕ)
    E = sum(μ)
    μ .= μ ./ sum(μ)

    # μ-weights over all firms and over 5+ firms, each summing to one.
    w_ent = μ[ent] ./ sum(μ[ent])
    w_big = μ[big] ./ sum(μ[big])

    # Employment-weighted mean log wage.
    lj = L_j[ent] ./ sum(L_j[ent])
    mean_log_wage = sum(log.(W_p[ent]) .* lj)

    mean_firm_size = sum(L_j[ent] .* w_ent)
    mean_firm_size_cond = sum(L_j[big] .* w_big)

    std_log_firm_size = std(log.(L_j[ent]), AnalyticWeights(μ[ent]))
    std_log_firm_size_cond = std(log.(L_j[big]), AnalyticWeights(μ[big]))

    share_entrep_invest = sum((ρᶻ[ent] .== 1) .* w_ent)
    share_entrep_invest_cond = sum((ρᶻ[big] .== 1) .* w_big)

    std_log_wage = std(log.(W_p[ent]), AnalyticWeights(w_ent))
    std_log_wage_cond = std(log.(W_p[big]), AnalyticWeights(w_big))

    # Firm-size wage premium: weighted OLS slope of log wage on log size (5+ firms).
    Y = vec(log.(W_p[big]))
    X = vec(log.(L_j[big]))
    varcovar = cov([Y X], AnalyticWeights(w_big))
    beta_hat_fswp = varcovar[1, 2] / var(X, AnalyticWeights(w_big))

    # Dynamic moments from 10000 simulated agents over 5000 periods.
    mean_firm_age, mean_firm_age_cut, mean_empl_gro, fs_rd_premium, mean_firm_size_, share_firms_30a, share_firms_30_60a, share_firms_60a, logwFE_amen_cov = simulate_paths(Ψ, ρᵉ, ρᶻ, par, L_j, grid_z, grid_a, 10000, 5000, ext)

    # GDP: E times the μ-weighted net output of entrepreneurs.
    Z_j = repeat(grid_z', SIZE_A)'
    gdp = sum(E .* μ[ent] .* (
        par.A .* Z_j[ent] .* F.(L_j[ent]) .- par.c .- par.c_z .* ρᶻ[ent]
        ))
    loggdp = log(gdp)
    gdppw = gdp ./ L
    loggdppw = log(gdppw)

    # μ-weighted mean productivity of entrepreneurs.
    avg_Z = sum(μ[ent] .* Z_j[ent])

    # Size distribution among 5+ firms.
    sizes_big = L_j[big]
    share_firms_20l = sum((sizes_big .<= 20) .* w_big)
    share_firms_20_100 = sum(((sizes_big .> 20) .& (sizes_big .<= 100)) .* w_big)
    share_firms_100p = sum((sizes_big .> 100) .* w_big)

    if rtn_gdp == true
        # GDP level; its log is what the table reports as "Log GDPpc".
        return gdp
    end

    names = [
        "LSE", "Mean log-wage", "Std log-wage", "Std log-wage (cond.)",
        "Log GDPpc", "Log GDPpw", "Share Entrepreneurs who Invest", "Share Entrepreneurs who Invest (cond.)",
        "Mean firm size", "Mean firm size (cond.)", "Std log firm size", "Std log firm size (cond)",
        "Mean productivity", "Share firms size<=20", "Share firms size 21-100", "Share firms size>100",
        "Mean firm age", "Mean firm age (cond.)" , "Mean employment growth", "Firm size RD premium", "Firm size wage premium",
        "Share firms age<=30","Share firms age 31-60",  "Share firms age>60"
    ]
    vals = [
        par.ϵᴸ, mean_log_wage, std_log_wage, std_log_wage_cond,
        loggdp, loggdppw, share_entrep_invest, share_entrep_invest_cond,
        mean_firm_size, mean_firm_size_cond, std_log_firm_size, std_log_firm_size_cond,
        avg_Z, share_firms_20l, share_firms_20_100, share_firms_100p,
        mean_firm_age, mean_firm_age_cut, mean_empl_gro, fs_rd_premium, beta_hat_fswp,
        share_firms_30a, share_firms_30_60a, share_firms_60a
    ]
    df = DataFrame(var=names, value=vals)

    CSV.write("../../RR_2025/robustness_xi/counter_out/moments_" * ext * ".csv", df)

    return nothing
end



"""
    computeMoments_cf3(M, Ψ, grid_z, grid_a, ρᵉ, ρᶻ, par, ext, rtn_gdp=false)

Reduced moment export for counterfactual 3. It writes two rows to
`../../RR_2025/robustness_xi/counter_out/moments_<ext>.csv`:
- `par.ϵᴸ`, labelled "LSE";
- the μ-weighted mean productivity of entrepreneurs (`ρᵉ == 1`).

`Ψ`, `ρᶻ` and `rtn_gdp` are accepted for signature compatibility and are not used.
Returns `nothing`.
"""
function computeMoments_cf3(M::Matrix{Float64}, Ψ, grid_z::Vector{Float64}, grid_a::Vector{Float64}, ρᵉ::Matrix{Float64}, ρᶻ::Matrix{Float64}, par, ext, rtn_gdp=false)
    is_entrepreneur = ρᵉ .== 1

    # Entrepreneurs' share of M, normalised to sum to one.
    μ = zeros(size(M))
    μ[is_entrepreneur] .= M[is_entrepreneur]
    μ .= μ ./ sum(μ)

    # Productivity value attached to every (z, a) node.
    Z_j = repeat(grid_z', size(grid_a, 1))'
    avg_Z = sum(μ[is_entrepreneur] .* Z_j[is_entrepreneur])

    df = DataFrame(var = ["LSE", "Mean productivity"], value = [par.ϵᴸ, avg_Z])
    CSV.write("../../RR_2025/robustness_xi/counter_out/moments_" * ext * ".csv", df)

    return nothing
end



"""
    estimationMoments(M, grid_z, grid_a, ρᵉ, ρᶻ, Π_j, L_j, W_p, Ψ, par, ext)

Compute the moments targeted in estimation. The static moments use entrepreneurs
(`ρᵉ == 1`) with `L_j >= 5`, weighted by their normalised mass in `M`. The dynamic
moments come from `simulate_paths` (10000 agents, 5000 periods).

Returns a 6-element vector:
`[mean_firm_size, share_entrep_invest, mean_firm_age, std_log_firm_size, std_log_wage,
mean_empl_gro]`.

`Π_j` is accepted for interface compatibility and is not used.
"""
function estimationMoments(M::Matrix{Float64}, grid_z::Vector{Float64}, grid_a::Vector{Float64}, ρᵉ::Matrix{Float64}, ρᶻ::Matrix{Float64}, Π_j::Matrix{Float64}, L_j::Matrix{Float64},  W_p::Matrix{Float64}, Ψ::Matrix{Float64},par,ext)
    ent = ρᵉ .== 1
    big = ent .& (L_j .>= 5)

    # Entrepreneurs' share of M, normalised to sum to one.
    μ = zeros(size(M))
    μ[ent] .= M[ent]
    μ .= μ ./ sum(μ)

    # Weights over 5+ firms, summing to one.
    w_big = μ[big] ./ sum(μ[big])

    mean_firm_size = sum(L_j[big] .* w_big)
    std_log_firm_size = std(log.(L_j[big]), AnalyticWeights(μ[big]))
    share_entrep_invest = sum((ρᶻ[big] .== 1) .* w_big)
    std_log_wage = std(log.(W_p[big]), AnalyticWeights(w_big))

    # simulate_paths returns a 9-element vector; only positions 1 (mean firm age) and
    # 3 (mean employment growth) are estimation targets.
    dyn = simulate_paths(Ψ, ρᵉ, ρᶻ, par, L_j, grid_z, grid_a, 10000, 5000, ext)
    mean_firm_age = dyn[1]
    mean_empl_gro = dyn[3]

    return [mean_firm_size, share_entrep_invest, mean_firm_age, std_log_firm_size, std_log_wage, mean_empl_gro]
end



"""
    estimationMoments_correl(M, grid_z, grid_a, ρᵉ, ρᶻ, Π_j, L_j, W_p, Ψ, par, ext)

Same targets as `estimationMoments`, plus the wage-fixed-effect / amenity moment
produced by `simulate_paths`.

Returns a 7-element vector:
`[mean_firm_size, share_entrep_invest, mean_firm_age, std_log_firm_size, std_log_wage,
mean_empl_gro, logwFE_amen_cov]`.

`Π_j` is accepted for interface compatibility and is not used.
"""
function estimationMoments_correl(M::Matrix{Float64}, grid_z::Vector{Float64}, grid_a::Vector{Float64}, ρᵉ::Matrix{Float64}, ρᶻ::Matrix{Float64}, Π_j::Matrix{Float64}, L_j::Matrix{Float64},  W_p::Matrix{Float64}, Ψ::Matrix{Float64},par,ext)
    ent = ρᵉ .== 1
    big = ent .& (L_j .>= 5)

    # Entrepreneurs' share of M, normalised to sum to one.
    μ = zeros(size(M))
    μ[ent] .= M[ent]
    μ .= μ ./ sum(μ)

    # Weights over 5+ firms, summing to one.
    w_big = μ[big] ./ sum(μ[big])

    mean_firm_size = sum(L_j[big] .* w_big)
    std_log_firm_size = std(log.(L_j[big]), AnalyticWeights(μ[big]))
    share_entrep_invest = sum((ρᶻ[big] .== 1) .* w_big)
    std_log_wage = std(log.(W_p[big]), AnalyticWeights(w_big))

    # simulate_paths returns 9 values. logwFE_amen_cov is the 9th; destructuring only
    # six names previously picked up the 6th (share of firms aged <= 30) instead.
    dyn = simulate_paths(Ψ, ρᵉ, ρᶻ, par, L_j, grid_z, grid_a, 10000, 5000, ext)
    mean_firm_age = dyn[1]
    mean_empl_gro = dyn[3]
    logwFE_amen_cov = dyn[9]

    return [mean_firm_size, share_entrep_invest, mean_firm_age, std_log_firm_size, std_log_wage, mean_empl_gro, logwFE_amen_cov]
end


"""
    simulate_paths(Ψ, ρᵉ, ρᶻ, par, L_j, grid_z, grid_a, N, T, ext)

Simulate `N` agents for `T` periods under fixed policies to compute dynamic moments.

Agents start at (z, a) grid nodes drawn from `Ψ`. Each later period:
- with probability `par.δ` the agent is replaced by a fresh draw from `Ψ`;
- otherwise its productivity index moves up one node with probability `par.p_i` (when
  `ρᵉ == ρᶻ == 1` at its node) or `par.p_n` (all other cases), and down otherwise,
  staying put at the grid edges. The amenity index never changes.

An agent is an entrepreneur when `ρᵉ == 1` at its current node.

Returns a 9-element vector computed over agents who are entrepreneurs in the final
period:
1. mean firm age (periods since becoming an entrepreneur);
2. the same value again (no separate conditional age is computed);
3. mean log employment growth since becoming an entrepreneur;
4. `cov(log size, rd) / var(log size)`;
5. mean firm size;
6. share of firms aged ≤ 30;
7. share of firms aged 31–60;
8. share of firms aged > 60;
9. `cov(logwFE, amenity) / var(logwFE)`, where `logwFE` uses the wage formula's
   dependence on productivity and amenity values.

Reseeds the global RNG with 8. `ext` is not used.
"""
function simulate_paths(Ψ, ρᵉ, ρᶻ, par, L_j, grid_z, grid_a, N, T, ext)
    Random.seed!(8)

    SIZE_Z = size(grid_z, 1)
    SIZE_A = size(grid_a, 1)

    # (z-index, a-index) of every node, in the same column-major order as vec(Ψ), and
    # the sampling weights. Both are built once and reused for every draw.
    nodes = vec([(i, j) for i in 1:SIZE_Z, j in 1:SIZE_A])
    node_w = Weights(vec(Ψ))

    # Initial population. Z and A hold grid INDICES, not values.
    X = sample(nodes, node_w, N)
    Z = [x[1] for x in X]
    A = [x[2] for x in X]

    ent = zeros(N)       # 1 while the agent is an entrepreneur
    age_firm = zeros(N)  # periods since becoming an entrepreneur
    l0 = zeros(N)        # employment when the agent became an entrepreneur
    l1 = zeros(N)        # current employment
    rd = zeros(N)        # current ρᶻ

    for t in 2:T, n in 1:N
        if rand() < par.δ
            # Death: replace with a fresh draw from Ψ.
            Z[n], A[n] = sample(nodes, node_w, 1)[1]
            age_firm[n] = 0
            if ρᵉ[Z[n], A[n]] == 1
                ent[n] = 1
                l0[n] = L_j[Z[n], A[n]]
                l1[n] = L_j[Z[n], A[n]]
                rd[n] = ρᶻ[Z[n], A[n]]
            else
                ent[n] = 0
                l0[n] = 0
                l1[n] = 0
                rd[n] = 0
            end
        else
            # Productivity shock: p_i for investing entrepreneurs, p_n otherwise.
            z, a = Z[n], A[n]
            p_up = (ρᵉ[z, a] == 1) & (ρᶻ[z, a] == 1) ? par.p_i : par.p_n
            if rand() < p_up
                Z[n] = min(z + 1, SIZE_Z)
            else
                Z[n] = max(z - 1, 1)
            end

            z = Z[n]
            if ρᵉ[z, a] == 1
                # A new entrepreneur records its starting size.
                if ent[n] == 0
                    l0[n] = L_j[z, a]
                end
                ent[n] = 1
                l1[n] = L_j[z, a]
                age_firm[n] += 1
                rd[n] = ρᶻ[z, a]
            else
                ent[n] = 0
                age_firm[n] = 0
                l0[n] = 0
                l1[n] = 0
                rd[n] = 0
            end
        end
    end

    # Statistics over agents who are entrepreneurs in the final period.
    firms = ent .== 1
    ages = age_firm[firms]
    sizes = l1[firms]

    mean_firm_size = mean(sizes)
    mean_firm_age = mean(ages)
    mean_firm_age_cut = mean_firm_age

    n_firms = count(>=(0), ages)
    share_firms_30a = count(<=(30), ages) / n_firms
    share_firms_30_60a = count(x -> 30 < x <= 60, ages) / n_firms
    share_firms_60a = count(>(60), ages) / n_firms

    lnsize = log.(sizes)
    mean_empl_gro = mean(lnsize .- log.(l0[firms]))
    fs_rd_premium = cov(lnsize, rd[firms]) / var(lnsize)

    # Wage fixed effect from the closed-form wage, evaluated at grid VALUES. The
    # amenity index must be mapped through grid_a, like the productivity index.
    amenities = grid_a[A[firms]]
    denom = 1 + (1 - par.xi) * par.ϵᴸ
    logwFE = 1 / denom .* log.(grid_z[Z[firms]]) .- (1 - par.xi) / denom .* amenities
    logwFE_amen_cov = cov(logwFE, amenities) / var(logwFE)

    return [mean_firm_age, mean_firm_age_cut, mean_empl_gro, fs_rd_premium, mean_firm_size, share_firms_30a, share_firms_30_60a, share_firms_60a, logwFE_amen_cov]
end



"""
    evaluate(x, par, N_a, p, targets, rtn_obj=false)

Solve the model at parameter vector `x` and return the weighted moment-matching loss
against `targets`.

`x = [c, c_z, p_i, p_n, σ_z, σ_a]`. These values are written into `par` in place, so `par`
keeps the last evaluated parameters. The loss is the sum of squared relative deviations
of simulated moments 1–6 from `targets[1:6]`, with moments 1 and 3 weighted by 10.

Returns
- the loss, after writing the moments and loss to `estim_out/bestfit.csv` (overwritten on
  every call);
- the full equilibrium tuple from `solve` instead of the loss when `rtn_obj == true`;
- `9e10` if the productivity grid has more than 80 points;
- `9e15` if any step throws. The error is logged with `@warn`; Ctrl-C is re-thrown.
"""
function evaluate(x, par, N_a, p, targets, rtn_obj=false)
    # Unpack the candidate parameter vector into the (mutable) parameter object
    par.c = x[1]
    par.c_z = x[2]
    par.p_i = x[3]
    par.p_n = x[4]
    par.σ_z = x[5]
    par.σ_a = x[6]

    try
        # Joint distribution of productivities and amenities (independent log-normals).
        # NOTE(review): σ_z and σ_a enter Σ as variances here, while dist_bounds passes
        # them to LogNormal as standard deviations. Confirm this is intended.
        μ = [par.μ_z, par.μ_a]
        Σ = [par.σ_z 0; 0 par.σ_a]
        Dist = MvLogNormal(MvNormal(μ, Σ))

        # Discretize the distribution onto the (z, a) grid and get the PMF matrix
        bounds = dist_bounds(par, N_a)
        Grid_z, Grid_a, Psi = discretize_joint_dist(Dist, p, N_a, bounds)

        # Penalize parameterizations whose productivity grid is too large to solve
        if length(Grid_z) > 80
            return 9e10
        end

        # Solve for the stationary equilibrium
        M, bρᵉ, bρᶻ, Π_j, L_j, W_p, V, U, Ũ, Vᴵ, Vᴺ, TM = solve(Psi, Grid_z, Grid_a, 1e-6, par, 20)

        # Simulated estimation moments
        ext = "estim"
        sim = estimationMoments(TM, Grid_z, Grid_a, bρᵉ, bρᶻ, Π_j, L_j, W_p, Psi, par, ext)
        println("****************************************************")
        println(x)
        println(round.(sim, digits=3))
        println(round.(targets, digits=3))

        # Squared relative deviations; moments 1 and 3 carry weight 10
        weights = [10, 1, 10, 1, 1, 1]
        loss = sum([weights[k] * ((sim[k] - targets[k]) / targets[k])^2 for k in 1:6])

        if rtn_obj == true
            return M, bρᵉ, bρᶻ, Π_j, L_j, W_p, V, U, Ũ, Vᴵ, Vᴺ, TM
        end
        println(loss)

        # Persist this evaluation's moments and loss.
        # NOTE(review): verify these labels match the order of the vector returned by
        # estimationMoments before relying on the column names.
        labels = ["mean_firm_size", "share_entrep_invest", "mean_firm_age", "std_log_firm_size", "std_log_wage", "mean_empl_gro", "loss"]
        vals = [sim[1], sim[2], sim[3], sim[4], sim[5], sim[6], loss]
        df = DataFrame(var=labels, value=vals)
        CSV.write("../../RR_2025/robustness_xi/estim_out/bestfit.csv", df)

        return loss
    catch err
        # Never swallow Ctrl-C, otherwise a long optimization run cannot be interrupted
        err isa InterruptException && rethrow()
        # Other failures are penalized, but logged so genuine bugs stay visible
        @warn "evaluate: model evaluation failed; returning penalty 9e15" exception=err
        return 9e15
    end
end



"""
    jacobian(x, par, N_a, p, targets, ext, rtn_obj=false)

Solve the model at `x = [c, c_z, p_i, p_n, σ_z, σ_a]` and return the unweighted
moment-matching loss. All output is tagged with `ext`.

This function evaluates a single parameter vector. Presumably the caller builds a
Jacobian from repeated calls at perturbed `x` (confirm against the driver script).
The parameters are written into `par` in place.

Returns
- the sum of squared relative deviations of moments 1–6 from `targets[1:6]`, after writing
  them to `estim_out/jacobian<ext>.csv`;
- the full equilibrium tuple from `solve` when `rtn_obj == true`;
- `9e10` if the productivity grid has more than 80 points;
- `9e15` if any step throws. The error is logged with `@warn`; Ctrl-C is re-thrown.
"""
function jacobian(x, par, N_a, p, targets, ext, rtn_obj=false)
    # Unpack the candidate parameter vector into the (mutable) parameter object
    par.c   = x[1]
    par.c_z = x[2]
    par.p_i = x[3]
    par.p_n = x[4]
    par.σ_z = x[5]
    par.σ_a = x[6]

    try
        # Joint distribution of productivities and amenities (independent log-normals).
        # NOTE(review): σ_z and σ_a enter Σ as variances here, but as standard deviations
        # in dist_bounds. Confirm this is intended.
        μ = [par.μ_z, par.μ_a]
        Σ = [par.σ_z 0; 0 par.σ_a]
        Dist = MvLogNormal(MvNormal(μ, Σ))

        # Discretize the distribution onto the (z, a) grid and get the PMF matrix
        bounds = dist_bounds(par, N_a)
        Grid_z, Grid_a, Psi = discretize_joint_dist(Dist, p, N_a, bounds)

        # Penalize parameterizations whose productivity grid is too large to solve
        if length(Grid_z) > 80
            return 9e10
        end

        # Solve for the stationary equilibrium
        M, bρᵉ, bρᶻ, Π_j, L_j, W_p, V, U, Ũ, Vᴵ, Vᴺ, TM = solve(Psi, Grid_z, Grid_a, 1e-6, par, 20)

        # Simulated estimation moments
        sim = estimationMoments(TM, Grid_z, Grid_a, bρᵉ, bρᶻ, Π_j, L_j, W_p, Psi, par, ext)
        println("****************************************************")
        println(round.(x, digits=2))
        println(round.(sim, digits=3))
        println(round.(targets, digits=3))

        # Unweighted squared relative deviations of moments 1-6
        loss = sum([((sim[k] - targets[k]) / targets[k])^2 for k in 1:6])

        if rtn_obj == true
            return M, bρᵉ, bρᶻ, Π_j, L_j, W_p, V, U, Ũ, Vᴵ, Vᴺ, TM
        end
        println(loss)

        labels = ["mean_firm_size", "share_entrep_invest", "mean_firm_age", "std_log_firm_size", "std_log_wage", "mean_empl_gro", "loss"]
        vals = [sim[1], sim[2], sim[3], sim[4], sim[5], sim[6], loss]
        df = DataFrame(var=labels, value=vals)
        CSV.write("../../RR_2025/robustness_xi/estim_out/jacobian" * ext * ".csv", df)

        return loss
    catch err
        # Never swallow Ctrl-C; log other failures before returning the penalty
        err isa InterruptException && rethrow()
        @warn "jacobian $(ext): model evaluation failed; returning penalty 9e15" exception=err
        return 9e15
    end
end




"""
    evaluate_counter_GRE(x, y, par, N_a, p, targets, rtn_obj=false, ext="counter_GRE")

Counterfactual objective. Sets `par.A = x[1]`, holds the estimated parameters
`y = [c, c_z, p_i, p_n, σ_z, σ_a]` fixed, solves the model, and returns the squared
relative deviation of simulated moment 7 from `targets[7]`.

`ext` tags the output produced by `estimationMoments`. Before this fix the function
referenced an undefined `ext` and called `targets(7)`, so it could only return the
failure penalty.

Returns
- the loss, after writing the moments to `estim_out/counter_GRE_bestfit.csv`;
- the full equilibrium tuple from `solve` when `rtn_obj == true`;
- `9e10` if the productivity grid has more than 80 points;
- `9e15` if any step throws. The error is logged with `@warn`; Ctrl-C is re-thrown.
"""
function evaluate_counter_GRE(x, y, par, N_a, p, targets, rtn_obj=false, ext="counter_GRE")
    # Counterfactual parameter being searched over
    par.A   = x[1]

    # Estimated parameters, held fixed
    par.c   = y[1]
    par.c_z = y[2]
    par.p_i = y[3]
    par.p_n = y[4]
    par.σ_z = y[5]
    par.σ_a = y[6]

    try
        # Joint distribution of productivities and amenities (independent log-normals)
        μ = [par.μ_z, par.μ_a]
        Σ = [par.σ_z 0; 0 par.σ_a]
        Dist = MvLogNormal(MvNormal(μ, Σ))

        # Discretize the distribution onto the (z, a) grid and get the PMF matrix
        bounds = dist_bounds(par, N_a)
        Grid_z, Grid_a, Psi = discretize_joint_dist(Dist, p, N_a, bounds)

        # Penalize parameterizations whose productivity grid is too large to solve
        if length(Grid_z) > 80
            return 9e10
        end

        # Solve for the stationary equilibrium
        M, bρᵉ, bρᶻ, Π_j, L_j, W_p, V, U, Ũ, Vᴵ, Vᴺ, TM = solve(Psi, Grid_z, Grid_a, 1e-6, par, 20)

        # Simulated moments
        sim = estimationMoments(TM, Grid_z, Grid_a, bρᵉ, bρᶻ, Π_j, L_j, W_p, Psi, par, ext)
        println("****************************************************")
        println(round.(x, digits=2))
        println(round.(sim, digits=3))
        println(round.(targets, digits=3))

        # Squared so the minimizer drives sim[7] toward targets[7]. A signed deviation
        # is unbounded below and would just push sim[7] down.
        loss = ((sim[7] - targets[7]) / targets[7])^2

        if rtn_obj == true
            return M, bρᵉ, bρᶻ, Π_j, L_j, W_p, V, U, Ũ, Vᴵ, Vᴺ, TM
        end
        println(loss)

        labels = ["mean_firm_size", "share_entrep_invest", "mean_firm_age", "std_log_firm_size", "std_log_wage", "mean_empl_gro", "gdp","loss"]
        vals = [sim[1], sim[2], sim[3], sim[4], sim[5], sim[6], sim[7], loss]
        df = DataFrame(var=labels, value=vals)
        CSV.write("../../RR_2025/robustness_xi/estim_out/counter_GRE_bestfit.csv", df)

        return loss
    catch err
        # Never swallow Ctrl-C; log other failures before returning the penalty
        err isa InterruptException && rethrow()
        @warn "evaluate_counter_GRE: model evaluation failed; returning penalty 9e15" exception=err
        return 9e15
    end
end


"""
    evaluate_correl(x, par, N_a, p, targets, rtn_obj=false, ext="correl")

Variant of the estimation objective that allows correlated productivities and
amenities. `x = [c, c_z, p_i, p_n, σ_z, σ_a, σ_za]`, where `σ_za` is the off-diagonal
entry of the covariance matrix of the underlying normal.

`ext` tags the output produced by `estimationMoments_correl`. Previously `ext` was
undefined here, so every call failed and returned the penalty.

Returns
- the unweighted sum of squared relative deviations of moments 1–7 from `targets[1:7]`,
  after writing them to `estim_out/fit_withcorrelation.csv`;
- the full equilibrium tuple from `solve` when `rtn_obj == true`;
- `9e10` if the productivity grid has more than 80 points;
- `9e15` if any step throws, including a non-positive-definite Σ. The error is logged with
  `@warn`; Ctrl-C is re-thrown.
"""
function evaluate_correl(x, par, N_a, p, targets, rtn_obj=false, ext="correl")
    # Unpack the candidate parameter vector into the (mutable) parameter object
    par.c    = x[1]
    par.c_z  = x[2]
    par.p_i  = x[3]
    par.p_n  = x[4]
    par.σ_z  = x[5]
    par.σ_a  = x[6]
    par.σ_za = x[7]

    try
        # Joint distribution of productivities and amenities with covariance σ_za
        μ = [par.μ_z, par.μ_a]
        Σ = [par.σ_z par.σ_za; par.σ_za par.σ_a]
        Dist = MvLogNormal(MvNormal(μ, Σ))

        # Discretize the distribution onto the (z, a) grid and get the PMF matrix
        bounds = dist_bounds(par, N_a)
        Grid_z, Grid_a, Psi = discretize_joint_dist(Dist, p, N_a, bounds)

        # Penalize parameterizations whose productivity grid is too large to solve
        if length(Grid_z) > 80
            return 9e10
        end

        # Solve for the stationary equilibrium
        M, bρᵉ, bρᶻ, Π_j, L_j, W_p, V, U, Ũ, Vᴵ, Vᴺ, TM = solve(Psi, Grid_z, Grid_a, 1e-6, par, 20)

        # Simulated moments, including the wage-FE/amenities moment
        sim = estimationMoments_correl(TM, Grid_z, Grid_a, bρᵉ, bρᶻ, Π_j, L_j, W_p, Psi, par, ext)
        println("****************************************************")
        println(round.(x, digits=2))
        println(round.(sim, digits=3))
        println(round.(targets, digits=3))

        # Unweighted squared relative deviations of moments 1-7
        loss = sum([((sim[k] - targets[k]) / targets[k])^2 for k in 1:7])

        if rtn_obj == true
            return M, bρᵉ, bρᶻ, Π_j, L_j, W_p, V, U, Ũ, Vᴵ, Vᴺ, TM
        end
        println(loss)

        labels = ["mean_firm_size", "share_entrep_invest", "mean_firm_age", "std_log_firm_size", "std_log_wage", "mean_empl_gro", "wage FE - amenities correl","loss"]
        vals = [sim[1], sim[2], sim[3], sim[4], sim[5], sim[6], sim[7], loss]
        df = DataFrame(var=labels, value=vals)
        CSV.write("../../RR_2025/robustness_xi/estim_out/fit_withcorrelation.csv", df)

        return loss
    catch err
        # Never swallow Ctrl-C; log other failures before returning the penalty
        err isa InterruptException && rethrow()
        @warn "evaluate_correl: model evaluation failed; returning penalty 9e15" exception=err
        return 9e15
    end
end


# Shared body of the sensitivity_* functions below. It writes
# `x = [c, c_z, p_i, p_n, σ_z, σ_a]` into `par`, solves the model, prints and saves the
# fit of moments 1–6, and returns the unweighted moment-matching loss.
#
# The output tag is `prefix * string(x[idx])` with '.' replaced by '_'. For example,
# prefix "c_" with x[1] == 0.5 gives "c_0_5"; results go to `estim_out/fit_<tag>.csv`.
#
# Returns the loss, or the full `solve` tuple when `rtn_obj == true`. Returns 9e10 if the
# productivity grid has more than 80 points, and 9e15 (logged) on any other failure.
function _sensitivity_fit(x, par, N_a, p, targets, rtn_obj, idx, prefix)
    par.c = x[1]
    par.c_z = x[2]
    par.p_i = x[3]
    par.p_n = x[4]
    par.σ_z = x[5]
    par.σ_a = x[6]

    ext = prefix * replace(string(x[idx]), "." => "_")

    try
        # Joint distribution of productivities and amenities (independent log-normals)
        μ = [par.μ_z, par.μ_a]
        Σ = [par.σ_z 0; 0 par.σ_a]
        Dist = MvLogNormal(MvNormal(μ, Σ))

        # Discretize the distribution onto the (z, a) grid and get the PMF matrix
        bounds = dist_bounds(par, N_a)
        Grid_z, Grid_a, Psi = discretize_joint_dist(Dist, p, N_a, bounds)

        # Penalize parameterizations whose productivity grid is too large to solve
        if length(Grid_z) > 80
            return 9e10
        end

        # Solve for the stationary equilibrium
        M, bρᵉ, bρᶻ, Π_j, L_j, W_p, V, U, Ũ, Vᴵ, Vᴺ, TM = solve(Psi, Grid_z, Grid_a, 1e-6, par, 20)

        # Simulated estimation moments
        sim = estimationMoments(TM, Grid_z, Grid_a, bρᵉ, bρᶻ, Π_j, L_j, W_p, Psi, par, ext)
        println("****************************************************")
        println(round.(x, digits=2))
        println(round.(sim, digits=3))
        println(round.(targets, digits=3))

        # Unweighted squared relative deviations of moments 1-6
        loss = sum([((sim[k] - targets[k]) / targets[k])^2 for k in 1:6])

        if rtn_obj == true
            return M, bρᵉ, bρᶻ, Π_j, L_j, W_p, V, U, Ũ, Vᴵ, Vᴺ, TM
        end
        println(loss)

        labels = [ "mean_firm_size", "share_entrep_invest", "mean_firm_age", "std_log_firm_size", "std_log_wage", "mean_empl_gro", "loss"]
        vals = [sim[1], sim[2], sim[3], sim[4], sim[5], sim[6], loss]
        df = DataFrame(var=labels, value=vals)
        CSV.write("../../RR_2025/robustness_xi/estim_out/fit_" * ext * ".csv", df)

        return loss
    catch err
        # Never swallow Ctrl-C; log other failures before returning the penalty
        err isa InterruptException && rethrow()
        @warn "sensitivity run $(ext) failed; returning penalty 9e15" exception=err
        return 9e15
    end
end

"""
    sensitivity_c(x, par, N_a, p, targets, rtn_obj=false)

Sensitivity fit at `x`, tagged by `c = x[1]` (output `estim_out/fit_c_<x[1]>.csv`).
"""
sensitivity_c(x, par, N_a, p, targets, rtn_obj=false) =
    _sensitivity_fit(x, par, N_a, p, targets, rtn_obj, 1, "c_")

"""
    sensitivity_cz(x, par, N_a, p, targets, rtn_obj=false)

Sensitivity fit at `x`, tagged by `c_z = x[2]` (output `estim_out/fit_cz_<x[2]>.csv`).
"""
sensitivity_cz(x, par, N_a, p, targets, rtn_obj=false) =
    _sensitivity_fit(x, par, N_a, p, targets, rtn_obj, 2, "cz_")

"""
    sensitivity_pi(x, par, N_a, p, targets, rtn_obj=false)

Sensitivity fit at `x`, tagged by `p_i = x[3]` (output `estim_out/fit_pi_<x[3]>.csv`).
"""
sensitivity_pi(x, par, N_a, p, targets, rtn_obj=false) =
    _sensitivity_fit(x, par, N_a, p, targets, rtn_obj, 3, "pi_")

"""
    sensitivity_pn(x, par, N_a, p, targets, rtn_obj=false)

Sensitivity fit at `x`, tagged by `p_n = x[4]` (output `estim_out/fit_pn_<x[4]>.csv`).
"""
sensitivity_pn(x, par, N_a, p, targets, rtn_obj=false) =
    _sensitivity_fit(x, par, N_a, p, targets, rtn_obj, 4, "pn_")

"""
    sensitivity_sigmaz(x, par, N_a, p, targets, rtn_obj=false)

Sensitivity fit at `x`, tagged by `σ_z = x[5]` (output `estim_out/fit_sigmaz_<x[5]>.csv`).
"""
sensitivity_sigmaz(x, par, N_a, p, targets, rtn_obj=false) =
    _sensitivity_fit(x, par, N_a, p, targets, rtn_obj, 5, "sigmaz_")

"""
    sensitivity_sigmaa(x, par, N_a, p, targets, rtn_obj=false)

Sensitivity fit at `x`, tagged by `σ_a = x[6]` (output `estim_out/fit_sigmaa_<x[6]>.csv`).
"""
sensitivity_sigmaa(x, par, N_a, p, targets, rtn_obj=false) =
    _sensitivity_fit(x, par, N_a, p, targets, rtn_obj, 6, "sigmaa_")




"""
    estimate(par, p, N_a, targets, x0)

Estimate `x = [c, c_z, p_i, p_n, σ_z, σ_a]` by minimizing `evaluate` with particle swarm
(20 particles, at most 30 iterations) inside a ±1% box around `x0`.

Prints the result and the minimizer, then returns the `Optim` result object. Previously
the function returned `nothing`.
"""
function estimate(par, p, N_a, targets, x0)
    # Random seed for replicability
    Random.seed!(8)

    # ±1% box around x0; min/max keep lb ≤ ub even if an entry of x0 is negative
    lb = min.(0.99 .* x0, 1.01 .* x0)
    ub = max.(0.99 .* x0, 1.01 .* x0)

    res = Optim.optimize(x -> evaluate(x, par, N_a, p, targets), x0, ParticleSwarm(lb, ub, 20),
                         Optim.Options(x_tol=5e-2, iterations=30, show_trace=true, extended_trace=true))
    println(res)
    println(res.minimizer)
    return res
end


"""
    estimate_2(par, p, N_a, targets, x0)

Estimate `x` by minimizing `evaluate` from the starting point `x0`. No method is passed,
so Optim's default for gradient-free objectives is used. Runs at most 30 iterations.

Prints the result and the minimizer, then returns the `Optim` result object.
"""
function estimate_2(par, p, N_a, targets, x0)
    res = Optim.optimize(x -> evaluate(x, par, N_a, p, targets), x0,
                         Optim.Options(x_tol=5e-2, show_trace=true, extended_trace=true, iterations=30))
    println(res)
    println(res.minimizer)
    return res
end


"""
    estimate_counter_GRE(par, p, N_a, targets, y0, x0)

Search over the counterfactual parameter(s) `x` (starting at `x0`) by minimizing
`evaluate_counter_GRE`. The estimated parameters `y0` are held fixed. Uses particle swarm
(20 particles, at most 30 iterations) inside a ±10% box around `x0`.

Note the argument order `(y0, x0)`. `estimate_2_counter_GRE` takes `(x0, y0)`.
Prints the result and the minimizer, then returns the `Optim` result object.
"""
function estimate_counter_GRE(par, p, N_a, targets, y0, x0)
    # Random seed for replicability
    Random.seed!(8)

    # ±10% box around x0; min/max keep lb ≤ ub even if an entry of x0 is negative
    lb = min.(0.90 .* x0, 1.10 .* x0)
    ub = max.(0.90 .* x0, 1.10 .* x0)

    res = Optim.optimize(x -> evaluate_counter_GRE(x, y0, par, N_a, p, targets), x0, ParticleSwarm(lb, ub, 20),
                         Optim.Options(x_tol=5e-2, iterations=30, show_trace=true, extended_trace=true))
    println(res)
    println(res.minimizer)
    return res
end


"""
    estimate_2_counter_GRE(par, p, N_a, targets, x0, y0)

Search over the counterfactual parameter(s) `x` (starting at `x0`) by minimizing
`evaluate_counter_GRE`, with the estimated parameters `y0` held fixed. No method is
passed, so Optim's default is used. Runs at most 30 iterations.

Note the argument order `(x0, y0)`. `estimate_counter_GRE` takes `(y0, x0)`.
Prints the result and the minimizer, then returns the `Optim` result object.
"""
function estimate_2_counter_GRE(par, p, N_a, targets, x0, y0)
    res = Optim.optimize(x -> evaluate_counter_GRE(x, y0, par, N_a, p, targets), x0,
                         Optim.Options(x_tol=5e-2, show_trace=true, extended_trace=true, iterations=30))
    println(res)
    println(res.minimizer)
    return res
end


# Write diagnostics of a solved equilibrium to counter_out/*_<ext>.csv:
#   - per-productivity and per-amenity summaries (profits, log employment, investment
#     and entrepreneur shares, mean Vᴵ/Vᴺ);
#   - the raw entry (bρᵉ) and investment (bρᶻ) policy matrices;
#   - row means of V and Ũ.
# Cells with bρᵉ == 0 are masked out (NaN or 0) of the entrepreneur-only statistics.
function _save_counterfactual_matrices(TM, bρᵉ, bρᶻ, Π_j, L_j, V, Ũ, Vᴵ, Vᴺ, Grid_z, Grid_a, ext)
    outdir = "../../RR_2025/robustness_xi/counter_out/"
    is_ent = bρᵉ .== 1
    not_ent = bρᵉ .== 0

    # Measure of entrepreneurs on the (z, a) grid; zero elsewhere
    μ = zeros(size(TM))
    μ[is_ent] .= TM[is_ent]

    # Profits at the first and 10th amenity grid points.
    # NOTE(review): column 10 is hard-coded, so this assumes N_a ≥ 10; otherwise the whole
    # counterfactual fails with a BoundsError.
    profits_l = Π_j[:, 1]
    profits_h = Π_j[:, 10]

    # Entrepreneur employment: NaN elsewhere for log-means, 0 elsewhere for sums
    lj = copy(L_j)
    lj[not_ent] .= NaN
    lj_ = copy(L_j)
    lj_[not_ent] .= 0
    lj_ = lj_ .* μ
    lj_z = sum(lj_, dims=2)
    lj_a = sum(lj_, dims=1)
    log_lj = log.(lj)
    logL = nanmean(log_lj, 2)
    logL_a = nanmean(log_lj, 1)

    # Investment policy restricted to entrepreneurs
    invs = copy(bρᶻ)
    invs[not_ent] .= 0

    # Within-row (productivity) weights: shares of entrepreneurs and investors by z
    W_z = TM .* (1 ./ sum(TM, dims=2))
    entry_z = sum(bρᵉ .* W_z, dims=2)
    invs_z = sum(invs .* W_z, dims=2)

    # Within-column (amenity) weights: shares of entrepreneurs and investors by a
    W_a = TM .* (1 ./ sum(TM, dims=1))
    entry_a = sum(bρᵉ .* W_a, dims=1)
    invs_a = sum(invs .* W_a, dims=1)

    # Row means of Vᴵ, Vᴺ and their gap (floored at 0) over entrepreneur cells only
    vi = copy(Vᴵ)
    vn = copy(Vᴺ)
    vi[not_ent] .= NaN
    vn[not_ent] .= NaN
    vivn = vi .- vn
    vivn[vivn .< 0] .= 0
    vi = nanmean(vi, 2)
    vn = nanmean(vn, 2)
    vivn = nanmean(vivn, 2)

    # Summaries over z
    df = DataFrame(prod=vec(Grid_z), profit_l=vec(profits_l), profit_h=vec(profits_h), mean_log_emp=vec(logL), mean_inv=vec(invs_z), share_entrep=vec(entry_z), emp_share=vec(lj_z), vi=vec(vi), vn=vec(vn), vivn=vec(vivn))
    CSV.write(outdir * "mean_profits_" * ext * ".csv", df)
    # Summaries over a
    df = DataFrame(amenities=vec(Grid_a), mean_log_emp=vec(logL_a), mean_inv=vec(invs_a), emp_share=vec(lj_a), share_entrep=vec(entry_a))
    CSV.write(outdir * "mean_profits_amenities_" * ext * ".csv", df)

    # Raw policy matrices
    CSV.write(outdir * "entry_" * ext * ".csv",  Tables.table(bρᵉ), writeheader=false)
    CSV.write(outdir * "investment_" * ext * ".csv",  Tables.table(bρᶻ), writeheader=false)

    # Row means of V (very negative sentinel values treated as missing) and of Ũ
    VV = copy(V)
    VV[VV .< -1e5] .= NaN
    v_mean = nanmean(VV, 2)
    u_mean = nanmean(Ũ, 2)
    df = DataFrame(prod=vec(Grid_z), mean_v=vec(v_mean), mean_utilde=vec(u_mean))
    CSV.write(outdir * "mean_v_u_" * ext * ".csv", df)
    return nothing
end

"""
    counterfactual(x, par, N_a, p, targets, eps, ext, tol, rtn_gdp=false, save_matrices=false)

Solve the model at the estimated parameters `x = [c, c_z, p_i, p_n, σ_z, σ_a]` with
`par.ϵᴸ = eps`, using solver tolerance `tol`. Moments are computed with
`computeMoments` and tagged with `ext`. `targets` is unused.

If `save_matrices` is true, policy and value diagnostics are also written to
`counter_out/*_<ext>.csv`.

Returns
- the value of `computeMoments(...; rtn_gdp=true)` when `rtn_gdp == true`;
- otherwise `(bρᵉ, bρᶻ, L_j)`, the entry policy, investment policy and employment;
- `9e10` if the productivity grid has more than 60 points;
- `9e15` if any step throws. The error is logged with `@warn`; Ctrl-C is re-thrown.
"""
function counterfactual(x, par, N_a, p, targets, eps, ext, tol, rtn_gdp=false, save_matrices=false)
    # Estimated parameters
    par.c = x[1]
    par.c_z = x[2]
    par.p_i = x[3]
    par.p_n = x[4]
    par.σ_z = x[5]
    par.σ_a = x[6]

    # Counterfactual value of ϵᴸ
    par.ϵᴸ = eps

    try
        # Joint distribution of productivities and amenities (independent log-normals)
        μ = [par.μ_z, par.μ_a]
        Σ = [par.σ_z 0; 0 par.σ_a]
        Dist = MvLogNormal(MvNormal(μ, Σ))

        # Discretize the distribution onto the (z, a) grid and get the PMF matrix
        bounds = dist_bounds(par, N_a)
        Grid_z, Grid_a, Psi = discretize_joint_dist(Dist, p, N_a, bounds)

        if length(Grid_z) > 60
            return 9e10
        end

        # Solve for the stationary equilibrium
        M, bρᵉ, bρᶻ, Π_j, L_j, W_p, V, U, Ũ, Vᴵ, Vᴺ, TM = solve(Psi, Grid_z, Grid_a, tol, par, 20)

        if save_matrices == true
            _save_counterfactual_matrices(TM, bρᵉ, bρᶻ, Π_j, L_j, V, Ũ, Vᴵ, Vᴺ, Grid_z, Grid_a, ext)
        end

        if rtn_gdp == true
            gdppc = computeMoments(M, Psi, Grid_z, Grid_a, bρᵉ, bρᶻ, Π_j, L_j, W_p, par, ext, rtn_gdp)
            return gdppc
        end
        computeMoments(M, Psi, Grid_z, Grid_a, bρᵉ, bρᶻ, Π_j, L_j, W_p, par, ext, rtn_gdp)
        return bρᵉ, bρᶻ, L_j
    catch err
        # Never swallow Ctrl-C; log other failures before returning the penalty
        err isa InterruptException && rethrow()
        @warn "counterfactual $(ext) failed; returning 9e15" exception=err
        return 9e15
    end
end

"""
    counterfactual_cf1(x, par, N_a, p, targets, eps, ext, ρᶻ)

Counterfactual with the investment policy `ρᶻ` held fixed. Sets the estimated parameters
from `x = [c, c_z, p_i, p_n, σ_z, σ_a]` and `par.ϵᴸ = eps`, solves with `solve_cf1`, and
returns the result of `computeMoments` tagged with `ext`. `targets` is unused.

Returns `9e10` if the productivity grid has more than 60 points, and `9e15` if any step
throws. The error is logged with `@warn`; Ctrl-C is re-thrown.
"""
function counterfactual_cf1(x, par, N_a, p, targets, eps, ext, ρᶻ)
    # Estimated parameters
    par.c = x[1]
    par.c_z = x[2]
    par.p_i = x[3]
    par.p_n = x[4]
    par.σ_z = x[5]
    par.σ_a = x[6]

    # Counterfactual value of ϵᴸ
    par.ϵᴸ = eps

    try
        # Joint distribution of productivities and amenities (independent log-normals)
        μ = [par.μ_z, par.μ_a]
        Σ = [par.σ_z 0; 0 par.σ_a]
        Dist = MvLogNormal(MvNormal(μ, Σ))

        # Discretize the distribution onto the (z, a) grid and get the PMF matrix
        bounds = dist_bounds(par, N_a)
        Grid_z, Grid_a, Psi = discretize_joint_dist(Dist, p, N_a, bounds)

        if length(Grid_z) > 60
            return 9e10
        end

        # Solve for the stationary equilibrium given the fixed investment policy
        M, bρᵉ, bρᶻ, Π_j, L_j, W_p, V, U, Ũ, TM = solve_cf1(Psi, Grid_z, Grid_a, 1e-6, par, 100000, ρᶻ)

        return computeMoments(M, Psi, Grid_z, Grid_a, bρᵉ, bρᶻ, Π_j, L_j, W_p, par, ext)
    catch err
        # Never swallow Ctrl-C; log other failures before returning the penalty
        err isa InterruptException && rethrow()
        @warn "counterfactual_cf1 $(ext) failed; returning 9e15" exception=err
        return 9e15
    end
end

"""
    counterfactual_cf2(x, par, N_a, p, targets, eps, ext, bρᵉ)

Counterfactual with the entry policy `bρᵉ` held fixed. Sets the estimated parameters from
`x = [c, c_z, p_i, p_n, σ_z, σ_a]` and `par.ϵᴸ = eps`, solves with `solve_cf2`, and
returns the result of `computeMoments` tagged with `ext`. `targets` is unused.

Returns `9e10` if the productivity grid has more than 60 points, and `9e15` if any step
throws. The error is logged with `@warn`; Ctrl-C is re-thrown.
"""
function counterfactual_cf2(x, par, N_a, p, targets, eps, ext, bρᵉ)
    # Estimated parameters
    par.c = x[1]
    par.c_z = x[2]
    par.p_i = x[3]
    par.p_n = x[4]
    par.σ_z = x[5]
    par.σ_a = x[6]

    # Counterfactual value of ϵᴸ
    par.ϵᴸ = eps

    try
        # Joint distribution of productivities and amenities (independent log-normals)
        μ = [par.μ_z, par.μ_a]
        Σ = [par.σ_z 0; 0 par.σ_a]
        Dist = MvLogNormal(MvNormal(μ, Σ))

        # Discretize the distribution onto the (z, a) grid and get the PMF matrix
        bounds = dist_bounds(par, N_a)
        Grid_z, Grid_a, Psi = discretize_joint_dist(Dist, p, N_a, bounds)

        if length(Grid_z) > 60
            return 9e10
        end

        # Solve for the stationary equilibrium given the fixed entry policy.
        # NOTE(review): tolerance is 1e-3 here vs 1e-6 in cf1/cf3, presumably for
        # convergence. Confirm this is intended.
        M, ρᵉ, ρᶻ, Π_j, L_j, W_p, V, U, Ũ, TM = solve_cf2(Psi, Grid_z, Grid_a, 1e-3, par, 100000, bρᵉ)

        return computeMoments(M, Psi, Grid_z, Grid_a, ρᵉ, ρᶻ, Π_j, L_j, W_p, par, ext)
    catch err
        # Never swallow Ctrl-C; log other failures before returning the penalty
        err isa InterruptException && rethrow()
        @warn "counterfactual_cf2 $(ext) failed; returning 9e15" exception=err
        return 9e15
    end
end

"""
    counterfactual_cf3(x, par, N_a, p, targets, eps, ext, ρᵉ, ρᶻ)

Counterfactual with both the entry policy `ρᵉ` and the investment policy `ρᶻ` held fixed.
Sets the estimated parameters from `x = [c, c_z, p_i, p_n, σ_z, σ_a]` and
`par.ϵᴸ = eps`, solves with `solve_cf3`, and returns the result of `computeMoments_cf3`
tagged with `ext`. `targets` is unused.

Returns `9e10` if the productivity grid has more than 60 points, and `9e15` if any step
throws. The error is logged with `@warn`; Ctrl-C is re-thrown.
"""
function counterfactual_cf3(x, par, N_a, p, targets, eps, ext, ρᵉ, ρᶻ)
    # Estimated parameters
    par.c = x[1]
    par.c_z = x[2]
    par.p_i = x[3]
    par.p_n = x[4]
    par.σ_z = x[5]
    par.σ_a = x[6]

    # Counterfactual value of ϵᴸ
    par.ϵᴸ = eps

    try
        # Joint distribution of productivities and amenities (independent log-normals)
        μ = [par.μ_z, par.μ_a]
        Σ = [par.σ_z 0; 0 par.σ_a]
        Dist = MvLogNormal(MvNormal(μ, Σ))

        # Discretize the distribution onto the (z, a) grid and get the PMF matrix
        bounds = dist_bounds(par, N_a)
        Grid_z, Grid_a, Psi = discretize_joint_dist(Dist, p, N_a, bounds)

        if length(Grid_z) > 60
            return 9e10
        end

        # Solve for the stationary distribution given both fixed policies
        M, bρᵉ, bρᶻ, TM = solve_cf3(Psi, Grid_z, Grid_a, 1e-6, par, 100000, ρᵉ, ρᶻ)

        return computeMoments_cf3(M, Psi, Grid_z, Grid_a, bρᵉ, bρᶻ, par, ext)
    catch err
        # Never swallow Ctrl-C; log other failures before returning the penalty
        err isa InterruptException && rethrow()
        @warn "counterfactual_cf3 $(ext) failed; returning 9e15" exception=err
        return 9e15
    end
end

# Helper Functions

"""
    nanmean(x)

Mean of the elements of `x` that are not NaN. Returns NaN when no such element exists.
"""
function nanmean(x)
    kept = x[.!isnan.(x)]
    return mean(kept)
end

"""
    nanmean(x, dims)

Apply `nanmean` to each slice of `x` along `dims`. That dimension is reduced to length 1.
"""
nanmean(x, y) = mapslices(slice -> nanmean(slice), x; dims=y)