# Skip to content
# Snippets Groups Projects
# run_tests.jl 1.48 KiB
# Newer Older

using StatsBase, Distributions
using AbstractMCMC: MCMCSerial #, MCMCThreads, MCMCDistributed
using CairoMakie

using Revise
# Luca Lenz's avatar
# Luca Lenz committed
#using Pkg; Pkg.activate(".")
using MultilevelChainSampler

"""
    analyse(chains)

Print summary statistics for a collection of MCMC chains and plot the first
chain against the reference density.

Every sample of every chain is indexed with `getindex(sample, 1)` and
`getindex(sample, 2)`. The first component is averaged as an acceptance
indicator and the second is treated as the sampled state (presumably the
sampler emits `(accepted, state)` pairs — confirm against the sampler).

The reference distribution is a standard normal truncated to `[-1, 1]`.
The function prints, in order:

- mean and standard deviation of the reference distribution,
- rejection rate in percent, as `average ± spread` across chains,
- mean and standard deviation of the states, as `average ± spread` across chains.

The spread is `std` across chains, so it is `NaN` when only one chain is given.

Finally a pdf-normalised histogram of the first chain's states is drawn with
the reference pdf on top. Returns whatever `display(fig)` returns.
"""
function analyse(chains)
    # `truncated(Normal(...), lo, hi)` replaces the deprecated
    # four-argument `TruncatedNormal` constructor.
    target = truncated(Normal(0, 1), -1, 1)
    println("  Target: mean = ", mean(target), ", std = ", std(target))

    accept = [getindex.(c, 1) for c in chains]
    states = [getindex.(c, 2) for c in chains]

    rejection_rate = 1 .- mean.(accept)
    println(
        "  Rejection rate: ", mean(rejection_rate) * 100,
        " ± ", std(rejection_rate) * 100, " % ",
    )

    chain_means = mean.(states)
    chain_stds = std.(states)
    println("  Mean ", mean(chain_means), " ± ", std(chain_means))
    println("  Std  ", mean(chain_stds), " ± ", std(chain_stds))
    println("\n")

    fig = Figure()
    ax = Axis(fig[1, 1])
    # Only the first chain is drawn to keep the histogram readable.
    for chain_states in first(states, 1)
        hist!(ax, chain_states; normalization=:pdf)
    end

    # The range is used directly; no need to splat it into a vector.
    x = -1:0.01:1
    lines!(ax, x, pdf.(target, x); label="N(0,1) truncated to [-1,1]")
    # Without a legend the label above would never be shown.
    axislegend(ax)
    return display(fig)
end

begin # Vanilla MetropolisHastings

    # Proposal and sampler. `CyclicWalk(-1, 1, 0.5)` is presumably a random walk
    # on [-1, 1] with wrap-around and step size 0.5 — confirm against the package.
    w = CyclicWalk(-1, 1, 0.5)
    s = MetropolisHastings(w)

    # Unnormalised log density of a standard normal.
    f = LogDensity(x -> -x^2 / 2)

    print("Sampling π_∞")
    # AbstractMCMC signature: 10000 samples per chain, 4 chains, run serially.
    chains = @time sample(f, s, MCMCSerial(), 10000, 4)

    analyse(chains)
end

begin # Sample Based Energy Function

    # Same proposal and sampler setup as the vanilla run.
    w = CyclicWalk(-1, 1, 0.5)
    s = MetropolisHastings(w)

    # `n` auxiliary points, each drawn uniformly from the square [-1, 1]^2.
    n = 100
    z = [2 .* rand(2) .- 1 for _ in 1:n]

    # Per-sample energy: -1/2 when the auxiliary point lies strictly inside the
    # disc of radius |x|, otherwise 0. How `SampledLogDensity` combines the
    # per-sample values is defined by the package.
    f = SampledLogDensity(z, (x, zi) -> -(sum(abs2, zi) < x^2) / 2)

    print("Sampling π_$n")
    # AbstractMCMC signature: 10000 samples per chain, 4 chains, run serially.
    chains = @time sample(f, s, MCMCSerial(), 10000, 4)

    analyse(chains)
end