Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
include("utils/all.jl")
# Shared location/scale parameters of the target modes: peaks at ±μ with scale σ.
μ = 0.5
σ = 0.1

# Reference distribution: a two-component mixture of truncated normals, one on each
# half of [-1, 1] (no prior vector is passed, so Distributions.jl uses equal weights).
π_ref = MixtureModel([
    TruncatedNormal(-μ, σ, -1, 0),  # left mode at -μ, truncated to [-1, 0]
    TruncatedNormal(μ, σ, 0, 1),    # right mode at +μ, truncated to [0, 1]
])

# A single truncated-normal peak at +μ on [-1, 1], together with its exact distances
# to `π_ref` (presumably Kolmogorov–Smirnov and Cramér–von Mises — TODO confirm in
# utils). NOTE(review): not used later in this script; presumably kept as a baseline.
single_peak = TruncatedNormal(μ, σ, -1, 1)
dist_single_peak = (
    ks = exact_distance_ks(single_peak, π_ref),
    cm = exact_distance_cm(single_peak, π_ref),
)

# Sampled proxy of the target built from 2000 Halton points
# (NOTE(review): meaning of the second argument `3` — confirm in utils).
z = halton_points(2000, 3)
f = normalmix_proxy(z, μ, σ)

# Multilevel log density wrapping `f`; assumed to use 100 and 2000 points per level.
l = MultilevelSampledLogDensity(f, [100, 2000])

# Proposal walk; presumably cyclic on [-1, 1] with step parameter 1.0 — confirm.
w = CyclicWalk(-1, 1, 1.0)
# Sample the target twice — with MetropolisHastings on the proxy density `f` and
# with ChristenFox on the multilevel density `l`, both using the walk `w` — then
# display each chain, save it under `data/`, and report the evaluation cost
# ChristenFox saved relative to MetropolisHastings.
# NOTE(review): the positional arguments `100000, 100` to `sample` are assumed to be
# sample count and a second size/thinning parameter — confirm against `sample`.
begin
    println("Sampling chains...\n")
    # Create the output directory up front so a missing `data/` folder does not
    # make `write_chain` fail only after the expensive sampling has finished.
    mkpath(joinpath(@__DIR__, "data"))

    println("MetropolisHastings")
    s_mh = MetropolisHastings(w)
    @time c_mh = sample(f, s_mh, MCMCSerial(), 100000, 100)
    display(c_mh)
    @time write_chain(joinpath(@__DIR__, "data", "mh.csv"), c_mh)

    println("ChristenFox")
    s_cf = ChristenFox(w)
    @time c_cf = sample(l, s_cf, MCMCSerial(), 100000, 100)
    display(c_cf)
    @time write_chain(joinpath(@__DIR__, "data", "cf.csv"), c_cf)

    # Fraction of the MetropolisHastings evaluation cost that ChristenFox needed.
    relative_cost = mean(get_info(c_cf).evaluations) / mean(get_info(c_mh).evaluations)
    # The saving is the complement of the relative cost; printing the ratio itself
    # under the "Saved costs" label reported the cost *spent*, not the cost saved.
    saved = 1 - relative_cost
    println("\n => Saved costs ", round(saved * 100, digits=2), " % \n")
end
# Call `write_metrics` on each chain file against the reference mixture `π_ref`,
# timing each call. NOTE(review): presumably reads the CSV written by the sampling
# block and writes its metrics — confirm in utils; the files must already exist.
# Paths are built with `joinpath` so they are correct on every platform.
begin
    @time write_metrics(joinpath(@__DIR__, "data", "mh.csv"), π_ref)
    @time write_metrics(joinpath(@__DIR__, "data", "cf.csv"), π_ref)
end
#=
fig = Figure(); ax = Axis(fig[1,1])
plot_pdf!(ax, π_ref, color=:black, linestyle=:dash, label=L"π_{∞}")
plot_pdf!(ax, f)
#hist!(ax, c.states[:,1], normalization=:pdf, label="MH")
Legend(fig[1,2], ax)
display(fig)
=#