Skip to content
Snippets Groups Projects
Commit 1d5abea8 authored by Luca Lenz's avatar Luca Lenz
Browse files

fixed logdensity keyword args

parent 9fd88371
No related branches found
No related tags found
No related merge requests found
""" """
Delayed Acceptance algorithm Delayed Acceptance algorithm
If `saveproxies == true` save accepted states on lower levels. If `saveproxies == true` save log-density lower levels.
This is particularly useful for MLMC integration. This is particularly useful for MLMC integration.
References References
...@@ -17,13 +17,16 @@ ChristenFox(proposal::AbstractProposal, saveproxies::Bool=false) = ChristenFox{s ...@@ -17,13 +17,16 @@ ChristenFox(proposal::AbstractProposal, saveproxies::Bool=false) = ChristenFox{s
# Initialize samples container # Initialize samples container
function AbstractMCMC.samples(sample, model::AbstractMultilevelModel, ::ChristenFox{false, P}; kwargs...) where {P} function AbstractMCMC.samples(sample, model::AbstractMultilevelModel, ::ChristenFox{true, P}; kwargs...) where {P}
return (; rejections=[zeros(Int, length(model))], transitions=[sample[2:end]]) return (; rejections=[zeros(Int, length(model))], transitions=[sample[2:end]])
end end
function AbstractMCMC.samples(sample, model::AbstractMultilevelModel, ::ChristenFox{false, P}; kwargs...) where {P}
return (; rejections=[zeros(Int, length(model))], transitions=[(sample[2], sample[3][end], sample[4:end]...)])
end
# Store sample to container # Store sample to container
function AbstractMCMC.save!!(samples, sample, ::Integer, model::AbstractMultilevelModel, ::ChristenFox{false, P}; kwargs... ) where {P} function AbstractMCMC.save!!(samples, sample, ::Integer, model::AbstractMultilevelModel, ::ChristenFox{true, P}; kwargs... ) where {P}
if sample[1] == length(model)+1 # sample was accepted if sample[1] == length(model)+1 # sample was accepted
push!(samples.rejections, zeros(Int, length(model))) push!(samples.rejections, zeros(Int, length(model)))
push!(samples.transitions, sample[2:end]) push!(samples.transitions, sample[2:end])
...@@ -33,36 +36,78 @@ function AbstractMCMC.save!!(samples, sample, ::Integer, model::AbstractMultilev ...@@ -33,36 +36,78 @@ function AbstractMCMC.save!!(samples, sample, ::Integer, model::AbstractMultilev
return samples return samples
end end
# Chain step function AbstractMCMC.save!!(samples, sample, ::Integer, model::AbstractMultilevelModel, ::ChristenFox{false, P}; kwargs... ) where {P}
if sample[1] == length(model)+1 # sample was accepted
push!(samples.rejections, zeros(Int, length(model)))
push!(samples.transitions, ( sample[2], sample[3][end], sample[4:end]...))
else
samples.rejections[end][sample[1]] += 1
end
return samples
end
# Initialize chain
function AbstractMCMC.step(rng::AbstractRNG, model::AbstractMultilevelModel, sampler::ChristenFox; kwargs...)
x = rand(rng, sampler.proposal)
f_x = [ logdensity(model, x; level=l) for l=1:length(model) ]
return (length(model), x, f_x), (x, f_x)
end
function AbstractMCMC.step(rng::AbstractRNG, model::AbstractMultilevelModel, sampler::ChristenFox{false, P}; kwargs...) where {P} function AbstractMCMC.step(rng::AbstractRNG, model::MultilevelSampledLogDensity, sampler::ChristenFox; kwargs...)
x = rand(rng, sampler.proposal) x = rand(rng, sampler.proposal)
f_x = [ logdensity(model, x, i) for i=1:length(model) ] f_x = [ logdensity(model, x; level=1) ]
return (length(model), x, f_x[end]), (x, f_x) sizehint!(f_x, length(model))
for l=2:length(model)
push!(f_x, logdensity(model, x; level=l, cache=f_x[end]))
end
return (length(model), x, f_x), (x, f_x)
end end
function AbstractMCMC.step(rng::AbstractRNG, model::AbstractMultilevelModel, sampler::ChristenFox{false, P}, state; kwargs...) where {P}
# Chain stepping
function AbstractMCMC.step(rng::AbstractRNG, model::AbstractMultilevelModel, sampler::ChristenFox, state; kwargs...)
x, f_x = state x, f_x = state
y = propose(rng, sampler.proposal, x) y = propose(rng, sampler.proposal, x)
f_y = [ logdensity(model, y, 1) ] f_y = [ logdensity(model, y; level=1) ]
q = logpratio(sampler.proposal, x, y) q = logpratio(sampler.proposal, x, y)
A_1 = min( f_y[1] - f_x[1] + q, 0) A_1 = min( f_y[1] - f_x[1] + q, 0)
accept = log(rand(rng)) < A_1 accept = log(rand(rng)) < A_1
if !accept return (1, x, f_x[end]), (x, f_x) end if !accept return (1, x, f_x), (x, f_x) end
for l = 2:length(model) for l = 2:length(model)
push!(f_y, logdensity(model, y, l) ) push!(f_y, logdensity(model, y; level=l) )
A_l = f_y[l] - f_x[l] + f_y[l-1] - f_x[l-1] A_l = f_y[l] - f_x[l] + f_y[l-1] - f_x[l-1]
accept = log(rand(rng)) < A_l accept = log(rand(rng)) < A_l
if !accept return (l, y, f_y[end]), (x, f_x) end if !accept return (l, y, f_y), (x, f_x) end
end
return (length(model)+1, y, f_y), (y, f_y)
end
function AbstractMCMC.step(rng::AbstractRNG, model::MultilevelSampledLogDensity, sampler::ChristenFox, state; kwargs...)
x, f_x = state
y = propose(rng, sampler.proposal, x)
f_y = [ logdensity(model, y; level=1) ]
q = logpratio(sampler.proposal, x, y)
A_1 = min( f_y[1] - f_x[1] + q, 0)
accept = log(rand(rng)) < A_1
if !accept return (1, x, f_x), (x, f_x) end
for l = 2:length(model)
push!(f_y, logdensity(model, y; level=l, cache=f_y[end]) )
A_l = (f_y[l] - f_x[l]) - (f_y[l-1] - f_x[l-1])
accept = log(rand(rng)) < A_l
if !accept return (l, y, f_y), (x, f_x) end
end end
return (length(model)+1, y, f_y[end]), (y, f_y) return (length(model)+1, y, f_y), (y, f_y)
end end
# Collect samples # Collect samples
function AbstractMCMC.bundle_samples(samples, m::AbstractMultilevelModel, ::ChristenFox{false,P}, state, chain_type::Type; kwargs ... ) where {P} function AbstractMCMC.bundle_samples(samples, m::AbstractMultilevelModel, ::ChristenFox, state, chain_type::Type; kwargs ... )
states = getindex.(samples.transitions, 1) states = getindex.(samples.transitions, 1)
logprobs = getindex.(samples.transitions, 2) logprobs = getindex.(samples.transitions, 2)
info = Dict() info = Dict()
...@@ -71,7 +116,8 @@ function AbstractMCMC.bundle_samples(samples, m::AbstractMultilevelModel, ::Chri ...@@ -71,7 +116,8 @@ function AbstractMCMC.bundle_samples(samples, m::AbstractMultilevelModel, ::Chri
info[:rejection_rate] = sum(samples.rejections) ./ N info[:rejection_rate] = sum(samples.rejections) ./ N
if m isa MultilevelSampledLogDensity if m isa MultilevelSampledLogDensity
nevals = N .- cumsum([0, sum(samples.rejections)[1:end-1]...]) nevals = N .- cumsum([0, sum(samples.rejections)[1:end-1]...])
info[:evaluations] = sum( m.nlevels .* nevals ) nlevels = m.nlevels[1:end] .- [0, m.nlevels[1:end-1]...]
info[:evaluations] = sum( nlevels .* nevals )
end end
return SimpleChains(states, logprobs, samples.rejections; info...) return SimpleChains(states, logprobs, samples.rejections; info...)
end end
......
...@@ -4,9 +4,9 @@ struct LogDensity{L} <: AbstractLogDensity ...@@ -4,9 +4,9 @@ struct LogDensity{L} <: AbstractLogDensity
density :: L density :: L
end end
logdensity(model::LogDensity{<:Function}, x) = model.density(x) logdensity(model::LogDensity{<:Function}, x; kwargs...) = model.density(x)
logdensity(model::LogDensity{<:Distribution}, x) = logpdf(model.density, x) logdensity(model::LogDensity{<:Distribution}, x; kwargs...) = logpdf(model.density, x)
logdensity(model::AbstractMCMC.LogDensityModel, x) = LogDensityProblems.logdensity(model.logdensity, x) logdensity(model::AbstractMCMC.LogDensityModel, x; kwargs...) = LogDensityProblems.logdensity(model.logdensity, x)
include("samplebased.jl") include("samplebased.jl")
include("multilevel.jl") include("multilevel.jl")
\ No newline at end of file
...@@ -7,7 +7,7 @@ MultilevelLogDensity(v::L... ) where {L <: Tuple{Vararg{<:AbstractModel}}} = Mul ...@@ -7,7 +7,7 @@ MultilevelLogDensity(v::L... ) where {L <: Tuple{Vararg{<:AbstractModel}}} = Mul
MultilevelLogDensity(v::Vector{ <: AbstractModel} ) = MultilevelLogDensity( tuple(v...) ) MultilevelLogDensity(v::Vector{ <: AbstractModel} ) = MultilevelLogDensity( tuple(v...) )
length(m::MultilevelLogDensity) = length(m.proxies) length(m::MultilevelLogDensity) = length(m.proxies)
logdensity(m::MultilevelLogDensity, x, level=length(m)) = logdensity(m.proxies[level], x) logdensity(m::MultilevelLogDensity, x; level=length(m), kwargs...) = logdensity(m.proxies[level], x)
function Base.show(io::IO, m::SampledLogDensity) function Base.show(io::IO, m::SampledLogDensity)
println(io, "MultilevelLogDensity ", length(m), " {", eltype(m.samples), "} samples") println(io, "MultilevelLogDensity ", length(m), " {", eltype(m.samples), "} samples")
...@@ -21,10 +21,10 @@ end ...@@ -21,10 +21,10 @@ end
MultilevelSampledLogDensity(f::Function, n::Vector{<:Integer}, d::Int=1, a=0, b=1) = MultilevelSampledLogDensity(SampledLogDensity(f, maximum(n), d, a, b), n) MultilevelSampledLogDensity(f::Function, n::Vector{<:Integer}, d::Int=1, a=0, b=1) = MultilevelSampledLogDensity(SampledLogDensity(f, maximum(n), d, a, b), n)
length(m::MultilevelSampledLogDensity) = length(m.nlevels) length(m::MultilevelSampledLogDensity) = length(m.nlevels)
logdensity(m::MultilevelSampledLogDensity, x, level=length(m), cache=nothing) = logdensity(m.density, x, level, cache) logdensity(m::MultilevelSampledLogDensity, x; level=length(m), cache=nothing, kwargs...) = _logdensity(m, x, level, cache)
logdensity(m::MultilevelSampledLogDensity, x, level::Int, cache::Nothing) = logdensity(m.density, x, m.nlevels[level]) _logdensity(m::MultilevelSampledLogDensity, x, level::Int, cache::Nothing) = logdensity(m.density, x; level=m.nlevels[level])
logdensity(m::MultilevelSampledLogDensity, x, level::Int, cache::Number) = logdensity(m.density, x, m.nlevels[level], (level-1, cache)) _logdensity(m::MultilevelSampledLogDensity, x, level::Int, cache::Number) = logdensity(m.density, x; level=m.nlevels[level], cache=(level-1, cache))
......
...@@ -12,18 +12,18 @@ end ...@@ -12,18 +12,18 @@ end
length(m::SampledLogDensity) = length(m.samples) length(m::SampledLogDensity) = length(m.samples)
logdensity(m::SampledLogDensity, x, level=length(m), cache=nothing) = logdensity(m, x, level, cache) logdensity(m::SampledLogDensity, x; level=length(m), cache=nothing) = _logdensity(m, x, level, cache)
function logdensity(m::SampledLogDensity, x, level::Int64, cache::Nothing) function _logdensity(m::SampledLogDensity, x, level::Int64, cache::Nothing)
y = [m.func(x, z) for z=m.samples[1:level]] #TODO: parallelize evaluations of m.func ! y = [m.func(x, z) for z=m.samples[1:level]] #TODO: parallelize evaluations of m.func !
return mean(y) return mean(y)
end end
function logdensity(m::SampledLogDensity, x, level::Int64, cache::Tuple{Int64, <:Number}) function _logdensity(m::SampledLogDensity, x, level::Int64, cache::Tuple{Int64, <:Number})
if cache[1] <= level if cache[1] <= level
y = [m.func(x, z) for z=m.samples[cache[1]+1:level]] y = [m.func(x, z) for z=m.samples[cache[1]+1:level]]
return mean(y) * length(y)/level + cache[2] * cache[1]/level return mean(y) * length(y)/level + cache[2] * cache[1]/level
else else
return logdensity(m, x, level, nothing) return logdensity(m, x; level, nothing)
end end
end end
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment