# christen_fox.jl
"""
 Delayed Acceptance algorithm

 If `saveproxies == true` save log-density lower levels. 
 This is particularly useful for Multilevel integration.

 References
    
    * Christen, J.A. and Fox, C. (2005). "Markov chain Monte Carlo Using an Approximation"
      Journal of Computational and Graphical Statistics

""" 
struct ChristenFox{saveproxies, P <: AbstractProposal} <: RejectionBasedSampler
    proposal :: P
end 
ChristenFox(proposal::AbstractProposal, saveproxies::Bool=false) = ChristenFox{saveproxies, typeof(proposal)}(proposal)

## Save samples only for highest level
# Create the empty sample container used when only top-level states are stored.
#
# `sample` is destructured as `(level, x, f_x)`; only the types of `x` and `f_x` are
# used, to fix the element types of `states` and `logprobs`. `rejections` holds one
# integer counter vector per stored state.
function AbstractMCMC.samples(sample, model::AbstractMultilevelModel, ::ChristenFox{false, P}; kwargs...) where {P}
    _, x_first, logp_first = sample

    states = Vector{typeof(x_first)}()
    logprobs = Vector{typeof(logp_first)}()
    rejections = Vector{Vector{Int}}()

    return (states = states, logprobs = logprobs, rejections = rejections)
end
# Record one chain step when only top-level states are stored.
#
# `sample` is a `(level, x, f_x)` tuple. A sample whose `level` equals `length(model)`
# is appended as a new state together with a zeroed rejection counter of length
# `length(model)`. Any other sample only increments slot `level + 1` of the counter
# belonging to the most recently stored state.
function AbstractMCMC.save!!(samples, sample, ::Integer, model::AbstractMultilevelModel, ::ChristenFox{false, P}; kwargs... ) where {P}
    level, x, f_x = sample
    top = length(model)

    # Rejected below the highest level: charge it to the current state
    if level != top
        samples.rejections[end][level + 1] += 1
        return samples
    end

    # Accepted at the highest level: store it as a new state
    push!(samples.states, x)
    push!(samples.logprobs, f_x)
    push!(samples.rejections, zeros(Int, top))

    return samples
end


## Save each level

# Create the empty sample container used when promoted (lower-level) states are stored
# as well.
#
# Besides `states`, `logprobs` and `rejections`, the container carries `current`, a
# reference to the index of the current top-level state (0 until one is stored).
function AbstractMCMC.samples(sample, model::AbstractMultilevelModel, ::ChristenFox{true, P}; kwargs...) where {P}
    _, x_first, logp_first = sample

    states = Vector{typeof(x_first)}()
    logprobs = Vector{typeof(logp_first)}()
    rejections = Vector{Vector{Int}}()
    current = Ref{Int}(0)

    return (states = states, logprobs = logprobs, rejections = rejections, current = current)
end
# Record one chain step when promoted (lower-level) states are stored as well.
#
# `sample` is a `(level, x, f_x)` tuple:
#   * `level == length(model)`: the state is stored and becomes the current top-level
#     state (`samples.current[]` points at it).
#   * `0 < level < length(model)`: slot `level + 1` of the current top-level state's
#     rejection counter is incremented, then the promoted state is stored too.
#   * `level == 0`: only the rejection counter is incremented.
function AbstractMCMC.save!!(samples, sample, ::Integer, model::AbstractMultilevelModel, ::ChristenFox{true, P}; kwargs... ) where {P}
    level, x, f_x = sample
    accepted = level == length(model)

    if !accepted
        # Update the rejection counter of the current top-level state
        samples.rejections[samples.current[]][level + 1] += 1

        # Never promoted: nothing to store
        level > 0 || return samples
    end

    # Store the accepted or promoted state with a fresh rejection counter
    push!(samples.states, x)
    push!(samples.logprobs, f_x)
    push!(samples.rejections, zeros(Int, length(model)))

    # An accepted state becomes the new current top-level state
    if accepted
        samples.current[] = length(samples.states)
    end

    return samples
end



## General chain stepping

# Initial step: build the starting state of the chain.
#
# Keywords
#   * `x0`: starting point; when `nothing`, one is drawn with `rand(rng, sampler.proposal)`.
#   * `f0`: log-densities of `x0`; used only when `x0` is supplied as well, otherwise
#     `logdensity` is evaluated for every level in `1:length(model)`.
#
# Returns the sample `(length(model), x, f_x)` and the state `(x, f_x)`.
function AbstractMCMC.step(rng::AbstractRNG, model::AbstractModel, sampler::ChristenFox; x0=nothing, f0=nothing, kwargs...)

    # Initialize states
    if isnothing(x0)
        x = rand(rng, sampler.proposal)
        needs_eval = true # a freshly drawn point always needs its log-densities
    else
        x = x0
        needs_eval = isnothing(f0)
    end

    # Initialize logprobs
    f_x = if needs_eval
        [logdensity(model, x; level=lvl) for lvl in 1:length(model)]
    else
        f0
    end

    return (length(model), x, f_x), (x, f_x)
end


# One Delayed Acceptance transition.
#
# `state` is the `(x, f_x)` pair of the current point and its log-densities at every
# level. A candidate `y` is proposed and screened level by level, evaluating
# `logdensity` at level `l` only after the candidate has passed level `l-1`.
#
# Returns `(sample, new_state)` where `sample = (level, point, logprobs)`:
#   * `(0, x, f_x)`           candidate rejected at level 1 (never promoted)
#   * `(l-1, y, f_y)`         candidate rejected at level `l`, i.e. promoted to `l-1`;
#                             `f_y` then holds the evaluations for levels `1:l`
#   * `(length(model), y, f_y)` candidate accepted at the highest level
# The chain state only moves to `(y, f_y)` in the last case.
function AbstractMCMC.step(rng::AbstractRNG, model::AbstractMultilevelModel, sampler::ChristenFox, state; kwargs...)

    # Load old state
    x, f_x = state

    # Propose
    y = propose(rng, sampler.proposal, x)
    f_y = [ logdensity(model, y; level=1) ]
    q = logpratio(sampler.proposal, x, y) # log proposal ratio, convention as defined by `logpratio`

    # Promotion probability (first stage: Metropolis-Hastings on the level-1 density)
    A_1 = min( f_y[1] - f_x[1] + q, 0)
    accept = log(rand(rng)) < A_1
    if !accept return (0, x, f_x), (x, f_x) end # completely rejected, never promoted

    # Promotion loop
    for l = 2:length(model)
        push!(f_y, logdensity(model, y; level=l) )

        # Next promotion. The candidate has already been accepted under level l-1, so
        # that ratio has to be divided out (Christen & Fox, 2005):
        #   log A_l = [f_l(y) - f_l(x)] - [f_{l-1}(y) - f_{l-1}(x)]
        # Adding the lower-level ratio instead would leave the chain targeting the
        # wrong distribution.
        A_l = min( (f_y[l] - f_x[l]) - (f_y[l-1] - f_x[l-1]), 0)
        accept = log(rand(rng)) < A_l
        if !accept return (l-1, y, f_y), (x, f_x) end # rejected at level l, promoted to l-1
    end

    # Accept at highest level
    return (length(model), y, f_y), (y, f_y)
end


## Chain stepping specialized for sample-based log-densities; uses cached evaluations

# Initial step for sample-based log-densities: build the starting state of the chain.
#
# Keywords
#   * `x0`: starting point; when `nothing`, one is drawn with `rand(rng, sampler.proposal)`.
#   * `f0`: log-densities of `x0`; used only when `x0` is supplied as well.
#
# When the log-densities have to be computed, level 1 is evaluated directly and each
# further level receives the previous level's value through the `cache` keyword.
#
# Returns the sample `(length(model), x, f_x)` and the state `(x, f_x)`.
function AbstractMCMC.step(rng::AbstractRNG, model::MultilevelSampledLogDensity, sampler::ChristenFox; x0=nothing, f0=nothing, kwargs...)

    # Initialize states
    x = isnothing(x0) ? rand(rng, sampler.proposal) : x0

    # Initialize logprobs
    if !isnothing(x0) && !isnothing(f0)
        f_x = f0
    else
        # Recursively cached
        f_x = [ logdensity(model, x; level=1) ]
        sizehint!(f_x, length(model))
        for lvl in 2:length(model)
            previous = f_x[end]
            push!(f_x, logdensity(model, x; level=lvl, cache=previous))
        end
    end

    return (length(model), x, f_x), (x, f_x)
end


# One Delayed Acceptance transition for sample-based log-densities.
#
# Same scheme as the generic multilevel transition, except that the evaluation at
# level `l` is handed the level `l-1` value through the `cache` keyword.
#
# `state` is the `(x, f_x)` pair of the current point and its log-densities at every
# level. Returns `(sample, new_state)` where `sample = (level, point, logprobs)`:
#   * `(0, x, f_x)`           candidate rejected at level 1 (never promoted)
#   * `(l-1, y, f_y)`         candidate rejected at level `l`, i.e. promoted to `l-1`;
#                             `f_y` then holds the evaluations for levels `1:l`
#   * `(length(model), y, f_y)` candidate accepted at the highest level
# The chain state only moves to `(y, f_y)` in the last case.
function AbstractMCMC.step(rng::AbstractRNG, model::MultilevelSampledLogDensity, sampler::ChristenFox, state; kwargs...)
    # Load old state
    x, f_x = state

    # Propose
    y = propose(rng, sampler.proposal, x)
    f_y = [ logdensity(model, y; level=1) ]
    q = logpratio(sampler.proposal, x, y) # log proposal ratio, convention as defined by `logpratio`

    # Promotion probability (first stage: Metropolis-Hastings on the level-1 density)
    A_1 = min( f_y[1] - f_x[1] + q, 0)
    accept = log(rand(rng)) < A_1
    if !accept return (0, x, f_x), (x, f_x) end # completely rejected, never promoted

    # Promotion loop
    for l = 2:length(model)
        push!(f_y, logdensity(model, y; level=l, cache=f_y[end]) ) # reuse the level l-1 evaluation

        # Next promotion. The candidate has already been accepted under level l-1, so
        # that ratio has to be divided out (Christen & Fox, 2005):
        #   log A_l = [f_l(y) - f_l(x)] - [f_{l-1}(y) - f_{l-1}(x)]
        # Adding the lower-level ratio instead would leave the chain targeting the
        # wrong distribution.
        A_l = min( (f_y[l] - f_x[l]) - (f_y[l-1] - f_x[l-1]), 0)
        accept = log(rand(rng)) < A_l
        if !accept return (l-1, y, f_y), (x, f_x) end # rejected at level l, promoted to l-1
    end

    # Accept at highest level
    return (length(model), y, f_y), (y, f_y)
end


# Bundle the collected samples as a NamedTuple holding exactly the `states`, `logprobs`
# and `rejections` fields of `samples`; any other field of the container is left out.
function AbstractMCMC.bundle_samples(samples, ::AbstractMultilevelModel, ::ChristenFox, state, chain_type::Type{<:NamedTuple}; kwargs...)
    states = samples.states
    logprobs = samples.logprobs
    rejections = samples.rejections
    return (states = states, logprobs = logprobs, rejections = rejections)
end

# Summarize a chain from its rejection counters.
#
# Returns a `Dict` with
#   * `:chain_length`   — one per entry of `samples.rejections` plus all counted rejections,
#   * `:rejection_rate` — per-slot rejection counts divided by the chain length,
# and, for a `MultilevelSampledLogDensity` model only,
#   * `:total_costs`    — evaluations per level weighted by the increments of
#                         `model.nlevels` between consecutive levels.
function _chain_info(samples, model::AbstractMultilevelModel, ::ChristenFox)
    rejections = samples.rejections
    n_total = sum( 1 .+ sum.(rejections))
    n_reject = sum(rejections) # element-wise total over all counter vectors

    info = Dict(
        :chain_length   => n_total,
        :rejection_rate => n_reject ./ n_total,
    )

    # Cost accounting is only available for sample-based log-densities
    model isa MultilevelSampledLogDensity || return info

    # Level l is evaluated once per step that was not rejected at an earlier slot
    rejected_before = cumsum(vcat(0, n_reject[1:end-1]))
    evals_per_level = n_total .- rejected_before

    # Each level only pays for the increment over the previous one
    previous_nlevels = [0, model.nlevels[1:end-1]...]
    costs_per_level = model.nlevels .- previous_nlevels

    info[:total_costs] = sum( evals_per_level .* costs_per_level )
    return info
end