Skip to content
Snippets Groups Projects
Commit 2bc10141 authored by Luca Lenz's avatar Luca Lenz
Browse files

changed chain structure, added total_costs() utility function

parent a8f39678
No related branches found
No related tags found
No related merge requests found
......@@ -106,28 +106,28 @@ function AbstractMCMC.step(rng::AbstractRNG, model::MultilevelSampledLogDensity,
end
#=
# Collect samples
function AbstractMCMC.bundle_samples(samples, m::AbstractMultilevelModel, ::ChristenFox, state, chain_type::Type; kwargs ... )
states = getindex.(samples.transitions, 1)
logprobs = getindex.(samples.transitions, 2)
info = Dict()
N = sum( sum.(samples.rejections) .+ 1)
info[:rejection_rate] = sum(samples.rejections) ./ N
if m isa MultilevelSampledLogDensity
nevals = N .- cumsum([0, sum(samples.rejections)[1:end-1]...])
function total_costs(c::RejectionChains, m::MultilevelSampledLogDensity)
costs = Int[]; sizehint!(costs, length(c))
for i = 1:length(c)
N = c.info[:chain_length][i]
r = getfield.(c.samples[i], :reject)
nevals = N .- cumsum([0, sum(r)[1:end-1]...])
nlevels = m.nlevels[1:end] .- [0, m.nlevels[1:end-1]...]
info[:evaluations] = sum( nlevels .* nevals )
push!(costs, sum(nlevels .* nevals))
end
return SimpleChains(states, logprobs, samples.rejections; info...)
return costs
end
=#
function AbstractMCMC.bundle_samples(samples, m::AbstractModel, ::ChristenFox, state, chain_type::Type; kwargs...)
function AbstractMCMC.bundle_samples(samples, m::AbstractMultilevelModel, ::ChristenFox, state, chain_type::Type; kwargs...)
s = getindex.(samples.transitions, 1)
l = getindex.(samples.transitions, 2)
r = samples.rejections
return RejectionChains(s,l,r)
c = RejectionChains(s,l,r)
if m isa MultilevelSampledLogDensity
c.info[:total_costs] = total_costs(c, m)
end
return c
end
......@@ -51,25 +51,19 @@ function AbstractMCMC.step(rng::AbstractRNG, model::AbstractModel, sampler::Metr
end
end
function total_costs(c::RejectionChains, m::SampledLogDensity)
return c.info[:chain_length] .* length(m)
end
function AbstractMCMC.bundle_samples(samples, m::AbstractModel, ::MetropolisHastings, state, chain_type::Type; kwargs...)
s = getindex.(samples.transitions, 1)
l = getindex.(samples.transitions, 2)
r = samples.rejections
return RejectionChains(s,l,r)
end
c = RejectionChains(s,l,r)
#=
# Collect samples
function AbstractMCMC.bundle_samples(samples, m::AbstractModel, ::MetropolisHastings, state, chain_type::Type; kwargs ... )
states = getindex.(samples.transitions, 1)
logprobs = getindex.(samples.transitions, 2)
info = Dict()
N = sum( sum.(samples.rejections) .+ 1)
info[:rejection_rate] = sum(samples.rejections) ./ N
if m isa SampledLogDensity
info[:evaluations] = N * length(m)
c.info[:total_costs] = total_costs(c,m)
end
return SimpleChains(states, logprobs, samples.rejections; info...)
end
=#
return c
end
\ No newline at end of file
......@@ -9,12 +9,23 @@ struct RejectionChains{S,L,R} <: AbstractMCMC.AbstractChains
end
}
}
info::Dict{Symbol, Vector}
end
# calculate rejection info
function RejectionChains(samples::Vector{Vector{@NamedTuple{state::S, logprob::L, reject::R}}}) where {S,L,R}
rejections = [ getfield.(s, :reject) for s=samples ]
data = Dict{Symbol, Vector}()
data[:chain_length] = [ sum(1 .+ sum.(r)) for r=rejections ]
data[:rejection_rate] = sum.(rejections) ./ data[:chain_length]
RejectionChains(samples, data)
end
# construct single chain
function RejectionChains(s::Vector{S}, l::Vector{L}, r::Vector{R}) where {S,L,R}
samples = NamedTuple{ (:state, :logprob, :reject)}.([zip(s, l, r)...] )
return RejectionChains{S,L,R}([samples])
return RejectionChains([samples])
end
length(chains::RejectionChains) = length(chains.samples)
......@@ -24,9 +35,22 @@ getindex(chains::RejectionChains, id::Integer) = chains.samples[id]
getindex(chains::RejectionChains, ids::OrdinalRange) = RejectionChains(chains.samples[ids])
function AbstractMCMC.chainscat(c::RejectionChains ... )
RejectionChains( vcat( getfield.(c, :samples) ... ) )
infos = getfield.(c, :info)
lens = length.(c)
dict = Dict{Symbol, Vector}()
for k in union(keys.(infos) ... )
vals = [ haskey(infos[i],k) ?
infos[i][k] : repeat([missing], lens[i])
for i = 1:length(c)
]
println
dict[Symbol(k)] = vcat(vals ... )
end
RejectionChains( vcat( getfield.(c, :samples) ... ), dict)
end
is_multilevel(chains::RejectionChains{S,L,R}) where {S,L,R} = R <: Vector
# Repeats rejected values
function states(chains::RejectionChains{S,L,R}) where {S,L,R}
V = Vector{S}[]; sizehint!(V, length(chains))
......@@ -42,13 +66,7 @@ function states(chains::RejectionChains{S,L,R}) where {S,L,R}
end
# Get info on chain
function info(chains::RejectionChains)
data = Dict()
rejections = [getfield.(x, :reject) for x=chains.samples]
data[:chain_length] = [ sum(1 .+ sum.(r)) for r=rejections ]
data[:rejection_rate] = sum.(rejections) ./ data[:chain_length]
return data
end
info(chains::RejectionChains) = chains.info
# Convert to table
function convert(::Type{<:DataFrame}, chains::RejectionChains)
......@@ -60,7 +78,3 @@ function convert(::Type{<:DataFrame}, chains::RejectionChains)
end
return df
end
function is_multilevel(chains::RejectionChains{S,L,R}) where {S,L,R}
return R <: Vector
end
\ No newline at end of file
......@@ -12,20 +12,20 @@ w = CyclicWalk()
s = ChristenFox(w)
f = LogDensity(x->-x^2/2)
g = LogDensity(x->-max(0,abs(x)/sqrt(2)))
g = LogDensity(x->-min(10,abs(x)/sqrt(2)))
l = MultilevelLogDensity([g,f])
q = (x,z)-> - 4 / pi * ((z[1]^2 + z[2]^2) < x^2) / 2
l = MultilevelSampledLogDensity(q, [100, 1000], 2)
c = sample(l, s, 1000)
#display(c)
display(c)
x = vcat(states(c)... )
@test abs(mean(x)) < 0.1
@test abs(std(x) - 0.539560) < 0.1
c = sample(l, s, MCMCSerial(), 100, 5)
#display(c)
display(c)
x = hcat(states(c)...)
@test abs(mean(x)) < 0.1
@test abs(std(x) - 0.539560) < 0.1
......@@ -10,16 +10,33 @@ using MultilevelChainSampler
w = CyclicWalk()
s = MetropolisHastings(w)
# Test analytical log density
f = LogDensity(x->-x^2/2)
c = sample(f, s, 1000)
#display(c)
display(c)
x = vcat(states(c)... )
@test abs(mean(x)) < 0.1
@test abs(std(x) - 0.539560) < 0.1
c = sample(f, s, MCMCSerial(), 100, 5)
#display(c)
display(c)
x = hcat(states(c)...)
@test abs(mean(x)) < 0.1
@test abs(std(x) - 0.539560) < 0.1
# Test sampled log density
g = SampledLogDensity(
[ rand(2) for i=1:100] ,
(x,z)->-4/pi*(sum(z.^2) < x^2) / 2
)
c = sample(g, s, MCMCSerial(), 100, 4)
display(c)
x = hcat(states(c)...)
@test abs(mean(x)) < 0.1
@test abs(std(x) - 0.539560) < 0.1
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment