Skip to content
Snippets Groups Projects
Commit 946f2102 authored by Luca Lenz's avatar Luca Lenz
Browse files

fixed bugs

parent 28f9d2da
No related branches found
No related tags found
No related merge requests found
......@@ -26,7 +26,9 @@ export sample
export MetropolisHastings, ChristenFox
export total_costs
export nchains, info, states
export nchains, info
export is_multilevel
export states
export autocov, logpart, empirical_cdf
export RejectionChains
......
......@@ -41,7 +41,7 @@ function AbstractMCMC.samples(sample, model::AbstractMultilevelModel, ::Christen
return (;
rejections = [zeros(Int, length(model))],
transitions = [sample[2:end]],
last_accepted = Ptr{Int}(0)
last_accepted = Ref{Int}(0)
)
end
function AbstractMCMC.save!!(samples, sample, ::Integer, model::AbstractMultilevelModel, ::ChristenFox{true, P}; kwargs... ) where {P}
......
......@@ -28,56 +28,3 @@ function Base.display(c::AbstractRejectionChains)
end
end
#=abstract type AbstractRejectionChains <: AbstractMCMC.AbstractChains end
#include("simple.jl")
#include("multi.jl")
#=
struct RejectionChains{X, L, R, I} <: AbstractRejectionChains
states::Matrix{X}
logprobs::Matrix{L}
rejections::Matrix{R}
end
eltype(c::RejectionChains{X,L,R}) where {X,L,R} = X
size(c::RejectionChains, args...) = size(c.states, args...)
function getindex(c::RejectionChains, n, m )
return RejectionChains(
c.states[n, m],
c.logprobs[n, m],
c.rejections[n, m]
)
end
function states(c::RejectionChains)
return fill.(c.states, c.rejections)
end
function get_info(c::RejectionChains)
s = size(c)
rejection_rate
(; )
end
=#
function Base.show(io::IO, c::RejectionBasedChains)
print(io, "Chain ", size(c), " {", eltype(c), "}")
end
function Base.display(c::RejectionBasedChains)
println(c)
info = get_info(c)
nchains = size(c, 2)
for k = keys(info)
if nchains == 1
println(" ", k, " : ", info[k][1])
else
m, s = mean(info[k]), std(info[k])
println(" ", k, " : ", m, " ± ", s)
end
end
end
=#
\ No newline at end of file
......@@ -9,16 +9,22 @@ using Revise
using MultilevelChainSampler
using CSV, DataFrames
w = CyclicWalk()
s = MetropolisHastings(w)
f = LogDensity(x->-x^2/2)
c = sample(f, s, MCMCSerial(), 100, 4)
c = sample(f, s, MCMCSerial(), 300, 4)
# Test conversion
df = convert(Vector{DataFrame}, c)
c2 = convert(RejectionChains, df)
# Test
autocov(c)
# Test multilevel
@test all( abs.( empirical_cdf(c, .5) .- 0.78045 ) .< 0.1 )
g = MultilevelLogDensity([f,f])
s3 = ChristenFox(w, true)
c3 = sample(g, s3, MCMCSerial(), 300, 4)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment