Skip to content
Snippets Groups Projects
Commit 946f2102 authored by Luca Lenz's avatar Luca Lenz
Browse files

fixed bugs

parent 28f9d2da
No related branches found
No related tags found
No related merge requests found
...@@ -26,7 +26,9 @@ export sample ...@@ -26,7 +26,9 @@ export sample
export MetropolisHastings, ChristenFox export MetropolisHastings, ChristenFox
export total_costs export total_costs
export nchains, info, states export nchains, info
export is_multilevel
export states
export autocov, logpart, empirical_cdf export autocov, logpart, empirical_cdf
export RejectionChains export RejectionChains
......
...@@ -41,7 +41,7 @@ function AbstractMCMC.samples(sample, model::AbstractMultilevelModel, ::Christen ...@@ -41,7 +41,7 @@ function AbstractMCMC.samples(sample, model::AbstractMultilevelModel, ::Christen
return (; return (;
rejections = [zeros(Int, length(model))], rejections = [zeros(Int, length(model))],
transitions = [sample[2:end]], transitions = [sample[2:end]],
last_accepted = Ptr{Int}(0) last_accepted = Ref{Int}(0)
) )
end end
function AbstractMCMC.save!!(samples, sample, ::Integer, model::AbstractMultilevelModel, ::ChristenFox{true, P}; kwargs... ) where {P} function AbstractMCMC.save!!(samples, sample, ::Integer, model::AbstractMultilevelModel, ::ChristenFox{true, P}; kwargs... ) where {P}
......
...@@ -28,56 +28,3 @@ function Base.display(c::AbstractRejectionChains) ...@@ -28,56 +28,3 @@ function Base.display(c::AbstractRejectionChains)
end end
end end
#=abstract type AbstractRejectionChains <: AbstractMCMC.AbstractChains end
#include("simple.jl")
#include("multi.jl")
#=
struct RejectionChains{X, L, R, I} <: AbstractRejectionChains
states::Matrix{X}
logprobs::Matrix{L}
rejections::Matrix{R}
end
eltype(c::RejectionChains{X,L,R}) where {X,L,R} = X
size(c::RejectionChains, args...) = size(c.states, args...)
function getindex(c::RejectionChains, n, m )
return RejectionChains(
c.states[n, m],
c.logprobs[n, m],
c.rejections[n, m]
)
end
function states(c::RejectionChains)
return fill.(c.states, c.rejections)
end
function get_info(c::RejectionChains)
s = size(c)
rejection_rate
(; )
end
=#
function Base.show(io::IO, c::RejectionBasedChains)
print(io, "Chain ", size(c), " {", eltype(c), "}")
end
function Base.display(c::RejectionBasedChains)
println(c)
info = get_info(c)
nchains = size(c, 2)
for k = keys(info)
if nchains == 1
println(" ", k, " : ", info[k][1])
else
m, s = mean(info[k]), std(info[k])
println(" ", k, " : ", m, " ± ", s)
end
end
end
=#
\ No newline at end of file
...@@ -9,16 +9,22 @@ using Revise ...@@ -9,16 +9,22 @@ using Revise
using MultilevelChainSampler using MultilevelChainSampler
using CSV, DataFrames using CSV, DataFrames
w = CyclicWalk() w = CyclicWalk()
s = MetropolisHastings(w) s = MetropolisHastings(w)
f = LogDensity(x->-x^2/2) f = LogDensity(x->-x^2/2)
c = sample(f, s, MCMCSerial(), 100, 4) c = sample(f, s, MCMCSerial(), 300, 4)
# Test conversion # Test conversion
df = convert(Vector{DataFrame}, c) df = convert(Vector{DataFrame}, c)
c2 = convert(RejectionChains, df) c2 = convert(RejectionChains, df)
# Test
autocov(c) autocov(c)
# Test multilevel
@test all( abs.( empirical_cdf(c, .5) .- 0.78045 ) .< 0.1 ) @test all( abs.( empirical_cdf(c, .5) .- 0.78045 ) .< 0.1 )
g = MultilevelLogDensity([f,f])
s3 = ChristenFox(w, true)
c3 = sample(g, s3, MCMCSerial(), 300, 4)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment