# christen_fox.jl
"""
    ChristenFox{saveproxies, P <: AbstractProposal} <: RejectionBasedSampler
    ChristenFox(proposal::AbstractProposal, saveproxies::Bool=false)

Delayed Acceptance sampler built around a `proposal`.

The Boolean type parameter `saveproxies` controls what is stored with every transition:
if `saveproxies == true` the log-densities of the lower levels are saved as well,
otherwise only the value of the last level is kept. Saving the lower levels is
particularly useful for MLMC integration.

# References

  * Christen, J.A. and Fox, C. (2005). "Markov chain Monte Carlo Using an Approximation",
    Journal of Computational and Graphical Statistics.
"""
struct ChristenFox{saveproxies, P <: AbstractProposal} <: RejectionBasedSampler
    proposal::P
end

# Convenience constructor: lifts the `saveproxies` flag into the type domain.
function ChristenFox(proposal::AbstractProposal, saveproxies::Bool=false)
    return ChristenFox{saveproxies, typeof(proposal)}(proposal)
end


# Initialize samples container
# Container for a sampler that keeps the log-densities of all levels.
# Starts with one zeroed rejection counter per level and a single transition made of
# everything in `sample` after its first entry.
function AbstractMCMC.samples(
    sample, model::AbstractMultilevelModel, ::ChristenFox{true, P}; kwargs...
) where {P}
    counters = zeros(Int, length(model))
    first_transition = sample[2:end]
    return (; rejections=[counters], transitions=[first_transition])
end
# Container for a sampler that does not keep lower-level log-densities.
# The stored transition holds `sample[2]`, only the last entry of `sample[3]`, and any
# remaining entries of `sample` from index 4 on.
function AbstractMCMC.samples(
    sample, model::AbstractMultilevelModel, ::ChristenFox{false, P}; kwargs...
) where {P}
    counters = zeros(Int, length(model))
    first_transition = (sample[2], sample[3][end], sample[4:end]...)
    return (; rejections=[counters], transitions=[first_transition])
end


# Store sample to container
# Store `sample` in `samples` for a sampler that keeps the log-densities of all levels.
#
# `sample[1]` encodes the outcome of the step: the value `length(model)+1` marks an
# accepted proposal, any other value is used as the index of the level whose rejection
# counter is incremented.
#
# Accepted: a fresh zeroed rejection counter and the transition `sample[2:end]` are
# appended. Rejected: the counter of the most recently stored transition is incremented.
# Returns the (mutated) `samples` container.
function AbstractMCMC.save!!(samples, sample, ::Integer, model::AbstractMultilevelModel, ::ChristenFox{true, P}; kwargs... ) where {P}
    if sample[1] == length(model)+1 # sample was accepted
        push!(samples.rejections, zeros(Int, length(model)))
        push!(samples.transitions, sample[2:end])
    else
        samples.rejections[end][sample[1]] += 1 #TODO: set rejection counter of last top level sample instead!
    end
    return samples
end
# Store `sample` in `samples` for a sampler that does not keep lower-level log-densities.
#
# `sample[1] == length(model)+1` marks an accepted proposal; any other value is used as
# the index of the level whose rejection counter is incremented.
# Returns the (mutated) `samples` container.
function AbstractMCMC.save!!(
    samples, sample, ::Integer, model::AbstractMultilevelModel, ::ChristenFox{false, P};
    kwargs...
) where {P}
    if sample[1] == length(model) + 1
        # Accepted: open a new rejection counter and keep only the last log-density.
        fresh_counter = zeros(Int, length(model))
        transition = (sample[2], sample[3][end], sample[4:end]...)
        push!(samples.rejections, fresh_counter)
        push!(samples.transitions, transition)
    else
        # Rejected: count it against the most recently stored transition.
        samples.rejections[end][sample[1]] += 1
    end
    return samples
end


# Initialize chain 
# First step of a chain: draw the starting point from the proposal and evaluate the
# log-density on every level `1:length(model)`.
# Returns `(sample, state)` with `sample = (length(model), x, f_x)` and `state = (x, f_x)`;
# both share the same log-density vector.
function AbstractMCMC.step(
    rng::AbstractRNG, model::AbstractMultilevelModel, sampler::ChristenFox; kwargs...
)
    start = rand(rng, sampler.proposal)
    logdensities = [logdensity(model, start; level=lvl) for lvl in 1:length(model)]
    sample = (length(model), start, logdensities)
    state = (start, logdensities)
    return sample, state
end
# First step of a chain for a `MultilevelSampledLogDensity`.
#
# Draws the starting point from the proposal, evaluates level 1 directly and then every
# further level with the previous level's value passed as `cache`.
# NOTE(review): presumably `cache` lets the model reuse work from the lower level —
# confirm against `logdensity` for `MultilevelSampledLogDensity`.
#
# Returns `(sample, state)` with `sample = (length(model), x, f_x)` and `state = (x, f_x)`.
function AbstractMCMC.step(rng::AbstractRNG, model::MultilevelSampledLogDensity, sampler::ChristenFox; kwargs...)
    x   = rand(rng, sampler.proposal)
    f_x = [ logdensity(model, x; level=1) ]
    sizehint!(f_x, length(model)) # final length is known: one value per level
    for l=2:length(model)
        push!(f_x, logdensity(model, x; level=l, cache=f_x[end]))
    end
    return (length(model), x, f_x), (x, f_x)
end

# Chain stepping

# One delayed-acceptance transition starting from `state = (x, f_x)`, where `f_x` holds
# the log-density of `x` on every level.
#
# A candidate `y` is proposed and screened level by level; higher levels are evaluated
# only if all lower levels accepted.
#
# Returns `(sample, state)`. The first entry of `sample` is the level at which the
# candidate was rejected, or `length(model)+1` if it passed every level; the new state is
# `(y, f_y)` on acceptance and the unchanged `(x, f_x)` otherwise.
function AbstractMCMC.step(rng::AbstractRNG, model::AbstractMultilevelModel, sampler::ChristenFox, state; kwargs...)
    x, f_x = state
    y = propose(rng, sampler.proposal, x)
    f_y = [ logdensity(model, y; level=1) ]
    # NOTE(review): `q` is assumed to be the log proposal-density correction for the
    # move x -> y — confirm against `logpratio`.
    q = logpratio(sampler.proposal, x, y)

    # Level 1: log acceptance probability on the coarsest level.
    A_1 = min( f_y[1] - f_x[1] + q, 0)
    accept = log(rand(rng)) < A_1
    if !accept return (1, x, f_x), (x, f_x) end

    for l = 2:length(model)
        push!(f_y, logdensity(model, y; level=l) )
        # Level l: the log-density ratio on level l is corrected by *dividing out* the
        # ratio already used on level l-1, i.e. the lower-level difference is subtracted.
        A_l = (f_y[l] - f_x[l]) - (f_y[l-1] - f_x[l-1])
        accept = log(rand(rng)) < A_l
        if !accept return (l, y, f_y), (x, f_x) end
    end
    return (length(model)+1, y, f_y), (y, f_y)
end

# One delayed-acceptance transition for a `MultilevelSampledLogDensity`, starting from
# `state = (x, f_x)`.
#
# Each level above the first is evaluated with the candidate's previous-level value
# passed as `cache`; otherwise the level-by-level screening is the same as for a generic
# multilevel model.
#
# Returns `(sample, state)`. The first entry of `sample` is the level at which the
# candidate was rejected, or `length(model)+1` if it passed every level; the new state is
# `(y, f_y)` on acceptance and the unchanged `(x, f_x)` otherwise.
function AbstractMCMC.step(rng::AbstractRNG, model::MultilevelSampledLogDensity, sampler::ChristenFox, state; kwargs...)
    x, f_x = state
    y = propose(rng, sampler.proposal, x)
    f_y = [ logdensity(model, y; level=1) ]
    # NOTE(review): `q` is assumed to be the log proposal-density correction for the
    # move x -> y — confirm against `logpratio`.
    q = logpratio(sampler.proposal, x, y)

    # Level 1: log acceptance probability on the coarsest level.
    A_1 = min( f_y[1] - f_x[1] + q, 0)
    accept = log(rand(rng)) < A_1
    if !accept return (1, x, f_x), (x, f_x) end

    for l = 2:length(model)
        push!(f_y, logdensity(model, y; level=l, cache=f_y[end]) )
        # Level l: log-density difference on level l minus the one used on level l-1.
        A_l = (f_y[l] - f_x[l]) - (f_y[l-1] - f_x[l-1])
        accept = log(rand(rng)) < A_l
        if !accept return (l, y, f_y), (x, f_x) end
    end
    return (length(model)+1, y, f_y), (y, f_y)
end
# Compute one integer cost per chain in `c` and return them as a `Vector{Int}`.
#
# For chain `i`:
#   * `N` is the recorded chain length `c.info[:chain_length][i]`;
#   * `r` collects the `:reject` field of every stored sample, and `sum(r)` adds them up;
#   * `nevals[l]` is `N` minus the summed rejections of all entries before `l`;
#   * `nlevels[l]` is the increment `m.nlevels[l] - m.nlevels[l-1]` (with 0 before the
#     first entry);
#   * the chain's cost is `sum(nlevels .* nevals)`.
#
# NOTE(review): this reads as "each level is evaluated only for candidates not rejected
# on a lower level, at the incremental cost of that level" — confirm the meaning of
# `m.nlevels` and of the `:reject` field against their definitions.
function total_costs(c::RejectionChains, m::MultilevelSampledLogDensity)
    costs = Int[]; sizehint!(costs, length(c))
    for i = 1:length(c)
        N = c.info[:chain_length][i]
        r = getfield.(c.samples[i], :reject)
        nevals = N .- cumsum([0, sum(r)[1:end-1]...])
        nlevels = m.nlevels[1:end] .- [0, m.nlevels[1:end-1]...]
        push!(costs, sum(nlevels .* nevals))
    end
    return costs
end
# Bundle the collected `samples` into a `RejectionChains` object.
#
# The first and second entry of every stored transition and the rejection counters are
# passed to `RejectionChains`. When the model is a `MultilevelSampledLogDensity`, the
# result of `total_costs` is additionally stored under `c.info[:total_costs]`.
# `state`, `chain_type` and `kwargs` are not used here.
function AbstractMCMC.bundle_samples(samples, m::AbstractMultilevelModel, ::ChristenFox, state, chain_type::Type; kwargs...)
    s = getindex.(samples.transitions, 1)
    l = getindex.(samples.transitions, 2)
    r = samples.rejections
    c = RejectionChains(s,l,r)

    if m isa MultilevelSampledLogDensity
        c.info[:total_costs] = total_costs(c, m)
    end

    return c
end