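# generate_chains.jl
#
# Samples a proxy of a bimodal mixture target with MetropolisHastings and
# ChristenFox, writes both chain sets to data/mh.csv and data/cf.csv, reports
# the relative evaluation cost of the two samplers, and writes metrics of each
# chain file against the reference mixture π_ref.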
include("utils/all.jl")

# Reference target: a two-component mixture of normals centred at -μ and +μ
# with common standard deviation σ, each component truncated to its own half of
# [-1, 1]. No prior vector is passed, so Distributions' MixtureModel uses its
# default uniform (equal-weight) prior.
μ = 0.5
σ = 0.1
π_ref = MixtureModel([TruncatedNormal(-μ, σ, -1, 0), TruncatedNormal(μ, σ, 0, 1)])

# Unimodal comparison density: a single normal at +μ truncated to all of [-1, 1].
# Its exact distances to π_ref are stored under `ks` and `cm`
# (presumably Kolmogorov–Smirnov and Cramér–von Mises — confirm in utils).
single_peak = TruncatedNormal(μ, σ, -1, 1)
dist_single_peak = (
    ks = exact_distance_ks(single_peak, π_ref),
    cm = exact_distance_cm(single_peak, π_ref),
)

# Sampled proxy of the target built on 2000 Halton points, plus a multilevel
# wrapper of that proxy with levels [100, 2000].
# NOTE(review): the levels look like subsample sizes of `z` (top level = all
# 2000 points), and the `3` argument of halton_points is not explained here —
# confirm both against utils.
z = halton_points(2000, 3)
f = normalmix_proxy(z, μ, σ)
l = MultilevelSampledLogDensity(f, [100, 2000])

# Proposal walk shared by both samplers below; arguments presumably the domain
# bounds -1, 1 and a step parameter 1.0 — see CyclicWalk.
w = CyclicWalk(-1, 1, 1.0)


begin
    # Sample the proxy density `f` with MetropolisHastings and the multilevel
    # density `l` with ChristenFox, using the same walk `w` and identical chain
    # settings. Each set of chains is displayed and written under data/, and the
    # two samplers' average evaluation counts are compared.

    # Shared chain settings, so both runs (and the cost comparison) always use
    # the same configuration. Assumes the AbstractMCMC signature
    # `sample(model, sampler, ensemble, N, nchains)`.
    n_samples = 100_000  # draws per chain
    n_chains = 100       # chains, run one after another via MCMCSerial()

    # Create the output directory before the expensive sampling, in case
    # write_chain does not create it itself; mkpath is a no-op if it exists.
    data_dir = "$(@__DIR__)/data"
    mkpath(data_dir)

    println("Sampling chains...\n")

    println("MetropolisHastings")
    s_mh = MetropolisHastings(w)
    @time c_mh = sample(f, s_mh, MCMCSerial(), n_samples, n_chains)
    display(c_mh)
    @time write_chain("$data_dir/mh.csv", c_mh)

    println("ChristenFox")
    s_cf = ChristenFox(w)
    @time c_cf = sample(l, s_cf, MCMCSerial(), n_samples, n_chains)
    display(c_cf)
    @time write_chain("$data_dir/cf.csv", c_cf)

    # ChristenFox's mean evaluation count as a fraction of MetropolisHastings'.
    # This is the share of the MH cost that ChristenFox still spends; the
    # saving is its complement. (Assumes `evaluations` counts the costly
    # density evaluations — confirm in get_info.)
    rel_cost = mean(get_info(c_cf).evaluations) / mean(get_info(c_mh).evaluations)
    saved = 1 - rel_cost
    println("\n => Relative cost ", round(rel_cost * 100, digits=2), " %")
    println(" => Saved costs ", round(saved * 100, digits=2), " % \n")
end


# Compute metrics for both saved chain files: write_metrics is called on each
# CSV under data/ together with the reference mixture π_ref (MH first, then CF).
let data_dir = "$(@__DIR__)/data"
    @time write_metrics("$data_dir/mh.csv", π_ref)
    @time write_metrics("$data_dir/cf.csv", π_ref)
end


#= 
fig = Figure(); ax = Axis(fig[1,1])
plot_pdf!(ax, π_ref, color=:black, linestyle=:dash, label=L"π_{∞}")
plot_pdf!(ax, f)
#hist!(ax, c.states[:,1], normalization=:pdf, label="MH")
Legend(fig[1,2], ax)
display(fig)
=#