Skip to content
Snippets Groups Projects
christen_fox.jl 2.83 KiB
Newer Older
"""
 Delayed Acceptance algorithm

 If `saveproxies == true` save accepted states on lower levels. 
 This is particularly useful for MLMC integration.

 References
    
    * Christen, J.A. and Fox, C. (2005). "Markov chain Monte Carlo Using an Approximation"
      Journal of Computational and Graphical Statistics

""" 
struct ChristenFox{saveproxies, P <: AbstractProposal} <: RejectionBasedSampler
    proposal :: P
end 
ChristenFox(proposal::AbstractProposal, saveproxies::Bool=false) = ChristenFox{saveproxies, typeof(proposal)}(proposal)


# Create the sample container from the initial sample `(level, x, logdensity)`.
#   * `transitions[k]` is `(x, logdensity)` of the k-th stored (accepted) point.
#   * `rejections[k][l]` counts proposals rejected at level `l` while the chain sat at
#     that point.
# NOTE(review): the initial transition is already stored here. If the sampling loop
# also passes the initial sample to `save!!`, its level `length(model)` (not
# `length(model)+1`) causes one extra top-level rejection to be counted — confirm
# this is intended.
function AbstractMCMC.samples(sample, model::AbstractMultilevelModel, ::ChristenFox{false, P}; kwargs...) where {P}
    nlevels = length(model)
    first_transition = sample[2:end]
    first_counts = zeros(Int, nlevels)
    return (; rejections = [first_counts], transitions = [first_transition])
end


# Record one sample `(level, x, logdensity)` in the container built by `samples`.
# `level == length(model) + 1` marks a fully accepted proposal; it opens a new transition
# with fresh per-level rejection counters. Any other `level` is the stage that rejected
# the proposal; that rejection is charged to the most recent transition.
# Mutates `samples` in place and returns it.
function AbstractMCMC.save!!(samples, sample, ::Integer, model::AbstractMultilevelModel, ::ChristenFox{false, P}; kwargs...) where {P}
    level = sample[1]
    nlevels = length(model)

    if level != nlevels + 1
        samples.rejections[end][level] += 1
        return samples
    end

    push!(samples.rejections, zeros(Int, nlevels))
    push!(samples.transitions, sample[2:end])
    return samples
end

# Chain step

# Initial step: draw a starting point from the proposal and evaluate its log-density on
# every level `1:length(model)`.
# Returns `(sample, state)`:
#   * `sample = (length(model), x, log-density on the last level)`;
#   * `state  = (x, vector of per-level log-densities)`.
function AbstractMCMC.step(rng::AbstractRNG, model::AbstractMultilevelModel, sampler::ChristenFox{false, P}; kwargs...) where {P}
    nlevels = length(model)
    x0 = rand(rng, sampler.proposal)
    logdens = map(level -> logdensity(model, x0, level), 1:nlevels)
    first_sample = (nlevels, x0, last(logdens))
    return first_sample, (x0, logdens)
end

# Delayed-acceptance transition (Christen & Fox, 2005).
#
# `state = (x, f_x)` is the current point and its log-density on every level
# (`f_x[l]` for level `l`). A candidate `y` is drawn from the proposal and screened
# stage by stage. `y`'s log-density on level `l` is only evaluated after stages
# `1, …, l-1` have accepted it.
#
#   * Stage 1: Metropolis–Hastings test on level 1, including the proposal correction
#     `logpratio(sampler.proposal, x, y)`. That term is assumed to be
#     log q(x | y) - log q(y | x) — TODO confirm against the proposal implementation.
#   * Stage l ≥ 2: accept with probability
#     min(1, π_l(y) π_{l-1}(x) / (π_l(x) π_{l-1}(y))). This corrects the level-(l-1)
#     decision by the ratio of level-l to level-(l-1) densities.
#
# Returns `(sample, state)`, where `sample = (level, point, logdensity)`:
#   * rejected at stage 1: `(1, x, f_x[end])`, and the chain stays at `(x, f_x)`;
#   * rejected at stage l ≥ 2: `(l, y, f_y[l])`, and the chain stays at `(x, f_x)`;
#   * accepted by every stage: `(length(model)+1, y, f_y[end])`, and the chain moves
#     to `(y, f_y)`.
function AbstractMCMC.step(rng::AbstractRNG, model::AbstractMultilevelModel, sampler::ChristenFox{false, P}, state; kwargs...) where {P}
    x, f_x = state
    y = propose(rng, sampler.proposal, x)
    f_y = [ logdensity(model, y, 1) ]
    q = logpratio(sampler.proposal, x, y)

    A_1 = min( f_y[1] - f_x[1] + q, 0)
    accept = log(rand(rng)) < A_1
    if !accept return (1, x, f_x[end]), (x, f_x) end

    for l = 2:length(model)
        push!(f_y, logdensity(model, y, l) )
        # The level-(l-1) log-ratio must be *subtracted*. Stage l-1 has already accepted
        # according to it, so stage l only applies the correction towards level l.
        A_l = (f_y[l] - f_x[l]) - (f_y[l-1] - f_x[l-1])
        accept = log(rand(rng)) < A_l
        if !accept return (l, y, f_y[end]), (x, f_x) end
    end
    return (length(model)+1, y, f_y[end]), (y, f_y)
end


# Bundle the stored transitions into a `SimpleChains` object.
#
# States and their log-densities are taken from `samples.transitions`. The raw per-level
# rejection counters are passed through, together with keyword metadata:
#   * `rejection_rate`: total rejections at each level divided by the total number of
#     proposals;
#   * `evaluations`: only for `MultilevelSampledLogDensity` models. This is the number
#     of density evaluations on each level, multiplied elementwise by `m.nlevels` and
#     summed. (NOTE(review): the meaning of `m.nlevels` here is not visible in this
#     file — confirm.)
function AbstractMCMC.bundle_samples(samples, m::AbstractMultilevelModel, ::ChristenFox{false,P}, state, chain_type::Type; kwargs...) where {P}
    states   = map(t -> t[1], samples.transitions)
    logprobs = map(t -> t[2], samples.transitions)

    # Elementwise total over all transitions: entry `l` counts rejections at level `l`.
    rejected_per_level = sum(samples.rejections)
    # Each stored transition stands for one acceptance plus the rejections charged to it.
    nproposals = sum(r -> sum(r) + 1, samples.rejections)

    info = Dict()
    info[:rejection_rate] = rejected_per_level ./ nproposals

    if m isa MultilevelSampledLogDensity
        # Level l is evaluated for every proposal that survived stages 1, …, l-1.
        survivors = nproposals .- cumsum(vcat(0, rejected_per_level[1:end-1]))
        info[:evaluations] = sum(m.nlevels .* survivors)
    end

    return SimpleChains(states, logprobs, samples.rejections; info...)
end



## Modified to save lower level accepted steps