Skip to content
Snippets Groups Projects
Commit 1d5abea8 authored by Luca Lenz's avatar Luca Lenz
Browse files

fixed logdensity keyword args

parent 9fd88371
No related branches found
No related tags found
No related merge requests found
"""
Delayed Acceptance algorithm
If `saveproxies == true` save accepted states on lower levels.
If `saveproxies == true` save log-density lower levels.
This is particularly useful for MLMC integration.
References
......@@ -17,13 +17,16 @@ ChristenFox(proposal::AbstractProposal, saveproxies::Bool=false) = ChristenFox{s
# Initialize samples container
function AbstractMCMC.samples(sample, model::AbstractMultilevelModel, ::ChristenFox{false, P}; kwargs...) where {P}
function AbstractMCMC.samples(sample, model::AbstractMultilevelModel, ::ChristenFox{true, P}; kwargs...) where {P}
return (; rejections=[zeros(Int, length(model))], transitions=[sample[2:end]])
end
function AbstractMCMC.samples(sample, model::AbstractMultilevelModel, ::ChristenFox{false, P}; kwargs...) where {P}
return (; rejections=[zeros(Int, length(model))], transitions=[(sample[2], sample[3][end], sample[4:end]...)])
end
# Store sample to container
function AbstractMCMC.save!!(samples, sample, ::Integer, model::AbstractMultilevelModel, ::ChristenFox{false, P}; kwargs... ) where {P}
function AbstractMCMC.save!!(samples, sample, ::Integer, model::AbstractMultilevelModel, ::ChristenFox{true, P}; kwargs... ) where {P}
if sample[1] == length(model)+1 # sample was accepted
push!(samples.rejections, zeros(Int, length(model)))
push!(samples.transitions, sample[2:end])
......@@ -33,36 +36,78 @@ function AbstractMCMC.save!!(samples, sample, ::Integer, model::AbstractMultilev
return samples
end
# Chain step
function AbstractMCMC.save!!(samples, sample, ::Integer, model::AbstractMultilevelModel, ::ChristenFox{false, P}; kwargs... ) where {P}
if sample[1] == length(model)+1 # sample was accepted
push!(samples.rejections, zeros(Int, length(model)))
push!(samples.transitions, ( sample[2], sample[3][end], sample[4:end]...))
else
samples.rejections[end][sample[1]] += 1
end
return samples
end
# Initialize chain
function AbstractMCMC.step(rng::AbstractRNG, model::AbstractMultilevelModel, sampler::ChristenFox; kwargs...)
x = rand(rng, sampler.proposal)
f_x = [ logdensity(model, x; level=l) for l=1:length(model) ]
return (length(model), x, f_x), (x, f_x)
end
function AbstractMCMC.step(rng::AbstractRNG, model::AbstractMultilevelModel, sampler::ChristenFox{false, P}; kwargs...) where {P}
function AbstractMCMC.step(rng::AbstractRNG, model::MultilevelSampledLogDensity, sampler::ChristenFox; kwargs...)
x = rand(rng, sampler.proposal)
f_x = [ logdensity(model, x, i) for i=1:length(model) ]
return (length(model), x, f_x[end]), (x, f_x)
f_x = [ logdensity(model, x; level=1) ]
sizehint!(f_x, length(model))
for l=2:length(model)
push!(f_x, logdensity(model, x; level=l, cache=f_x[end]))
end
return (length(model), x, f_x), (x, f_x)
end
function AbstractMCMC.step(rng::AbstractRNG, model::AbstractMultilevelModel, sampler::ChristenFox{false, P}, state; kwargs...) where {P}
# Chain stepping
function AbstractMCMC.step(rng::AbstractRNG, model::AbstractMultilevelModel, sampler::ChristenFox, state; kwargs...)
x, f_x = state
y = propose(rng, sampler.proposal, x)
f_y = [ logdensity(model, y, 1) ]
f_y = [ logdensity(model, y; level=1) ]
q = logpratio(sampler.proposal, x, y)
A_1 = min( f_y[1] - f_x[1] + q, 0)
accept = log(rand(rng)) < A_1
if !accept return (1, x, f_x[end]), (x, f_x) end
if !accept return (1, x, f_x), (x, f_x) end
for l = 2:length(model)
push!(f_y, logdensity(model, y, l) )
push!(f_y, logdensity(model, y; level=l) )
A_l = f_y[l] - f_x[l] + f_y[l-1] - f_x[l-1]
accept = log(rand(rng)) < A_l
if !accept return (l, y, f_y[end]), (x, f_x) end
if !accept return (l, y, f_y), (x, f_x) end
end
return (length(model)+1, y, f_y), (y, f_y)
end
function AbstractMCMC.step(rng::AbstractRNG, model::MultilevelSampledLogDensity, sampler::ChristenFox, state; kwargs...)
x, f_x = state
y = propose(rng, sampler.proposal, x)
f_y = [ logdensity(model, y; level=1) ]
q = logpratio(sampler.proposal, x, y)
A_1 = min( f_y[1] - f_x[1] + q, 0)
accept = log(rand(rng)) < A_1
if !accept return (1, x, f_x), (x, f_x) end
for l = 2:length(model)
push!(f_y, logdensity(model, y; level=l, cache=f_y[end]) )
A_l = (f_y[l] - f_x[l]) - (f_y[l-1] - f_x[l-1])
accept = log(rand(rng)) < A_l
if !accept return (l, y, f_y), (x, f_x) end
end
return (length(model)+1, y, f_y[end]), (y, f_y)
return (length(model)+1, y, f_y), (y, f_y)
end
# Collect samples
function AbstractMCMC.bundle_samples(samples, m::AbstractMultilevelModel, ::ChristenFox{false,P}, state, chain_type::Type; kwargs ... ) where {P}
function AbstractMCMC.bundle_samples(samples, m::AbstractMultilevelModel, ::ChristenFox, state, chain_type::Type; kwargs ... )
states = getindex.(samples.transitions, 1)
logprobs = getindex.(samples.transitions, 2)
info = Dict()
......@@ -71,7 +116,8 @@ function AbstractMCMC.bundle_samples(samples, m::AbstractMultilevelModel, ::Chri
info[:rejection_rate] = sum(samples.rejections) ./ N
if m isa MultilevelSampledLogDensity
nevals = N .- cumsum([0, sum(samples.rejections)[1:end-1]...])
info[:evaluations] = sum( m.nlevels .* nevals )
nlevels = m.nlevels[1:end] .- [0, m.nlevels[1:end-1]...]
info[:evaluations] = sum( nlevels .* nevals )
end
return SimpleChains(states, logprobs, samples.rejections; info...)
end
......
......@@ -4,9 +4,9 @@ struct LogDensity{L} <: AbstractLogDensity
density :: L
end
logdensity(model::LogDensity{<:Function}, x) = model.density(x)
logdensity(model::LogDensity{<:Distribution}, x) = logpdf(model.density, x)
logdensity(model::AbstractMCMC.LogDensityModel, x) = LogDensityProblems.logdensity(model.logdensity, x)
logdensity(model::LogDensity{<:Function}, x; kwargs...) = model.density(x)
logdensity(model::LogDensity{<:Distribution}, x; kwargs...) = logpdf(model.density, x)
logdensity(model::AbstractMCMC.LogDensityModel, x; kwargs...) = LogDensityProblems.logdensity(model.logdensity, x)
include("samplebased.jl")
include("multilevel.jl")
\ No newline at end of file
......@@ -7,7 +7,7 @@ MultilevelLogDensity(v::L... ) where {L <: Tuple{Vararg{<:AbstractModel}}} = Mul
MultilevelLogDensity(v::Vector{ <: AbstractModel} ) = MultilevelLogDensity( tuple(v...) )
length(m::MultilevelLogDensity) = length(m.proxies)
logdensity(m::MultilevelLogDensity, x, level=length(m)) = logdensity(m.proxies[level], x)
logdensity(m::MultilevelLogDensity, x; level=length(m), kwargs...) = logdensity(m.proxies[level], x)
function Base.show(io::IO, m::SampledLogDensity)
println(io, "MultilevelLogDensity ", length(m), " {", eltype(m.samples), "} samples")
......@@ -21,10 +21,10 @@ end
MultilevelSampledLogDensity(f::Function, n::Vector{<:Integer}, d::Int=1, a=0, b=1) = MultilevelSampledLogDensity(SampledLogDensity(f, maximum(n), d, a, b), n)
length(m::MultilevelSampledLogDensity) = length(m.nlevels)
logdensity(m::MultilevelSampledLogDensity, x, level=length(m), cache=nothing) = logdensity(m.density, x, level, cache)
logdensity(m::MultilevelSampledLogDensity, x; level=length(m), cache=nothing, kwargs...) = _logdensity(m, x, level, cache)
logdensity(m::MultilevelSampledLogDensity, x, level::Int, cache::Nothing) = logdensity(m.density, x, m.nlevels[level])
logdensity(m::MultilevelSampledLogDensity, x, level::Int, cache::Number) = logdensity(m.density, x, m.nlevels[level], (level-1, cache))
_logdensity(m::MultilevelSampledLogDensity, x, level::Int, cache::Nothing) = logdensity(m.density, x; level=m.nlevels[level])
_logdensity(m::MultilevelSampledLogDensity, x, level::Int, cache::Number) = logdensity(m.density, x; level=m.nlevels[level], cache=(level-1, cache))
......
......@@ -12,18 +12,18 @@ end
length(m::SampledLogDensity) = length(m.samples)
logdensity(m::SampledLogDensity, x, level=length(m), cache=nothing) = logdensity(m, x, level, cache)
logdensity(m::SampledLogDensity, x; level=length(m), cache=nothing) = _logdensity(m, x, level, cache)
function logdensity(m::SampledLogDensity, x, level::Int64, cache::Nothing)
function _logdensity(m::SampledLogDensity, x, level::Int64, cache::Nothing)
y = [m.func(x, z) for z=m.samples[1:level]] #TODO: parallelize evaluations of m.func !
return mean(y)
end
function logdensity(m::SampledLogDensity, x, level::Int64, cache::Tuple{Int64, <:Number})
function _logdensity(m::SampledLogDensity, x, level::Int64, cache::Tuple{Int64, <:Number})
if cache[1] <= level
y = [m.func(x, z) for z=m.samples[cache[1]+1:level]]
return mean(y) * length(y)/level + cache[2] * cache[1]/level
else
return logdensity(m, x, level, nothing)
return logdensity(m, x; level, nothing)
end
end
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment