""" Delayed Acceptance algorithm If `saveproxies == true` save log-density lower levels. This is particularly useful for Multilevel integration. References * Christen, J.A. and Fox, C. (2005). "Markov chain Monte Carlo Using an Approximation" Journal of Computational and Graphical Statistics """ struct ChristenFox{saveproxies, P <: AbstractProposal} <: RejectionBasedSampler proposal :: P end ChristenFox(proposal::AbstractProposal, saveproxies::Bool=false) = ChristenFox{saveproxies, typeof(proposal)}(proposal) ## Save samples only for highest level function AbstractMCMC.samples(sample, model::AbstractMultilevelModel, ::ChristenFox{false, P}; kwargs...) where {P} (level, x, f_x) = sample return (; states = typeof(x)[], logprobs = typeof(f_x)[], rejections = Vector{Int}[], ) end function AbstractMCMC.save!!(samples, sample, ::Integer, model::AbstractMultilevelModel, ::ChristenFox{false, P}; kwargs... ) where {P} (level, x, f_x) = sample # accept (at hightest level) if level == length(model) push!(samples.states, x) push!(samples.logprobs, f_x) push!(samples.rejections, zeros(Int, length(model))) # reject else samples.rejections[end][level+1] += 1 end return samples end ## Save each level function AbstractMCMC.samples(sample, model::AbstractMultilevelModel, ::ChristenFox{true, P}; kwargs...) where {P} (level, x, f_x) = sample return (; states = typeof(x)[], logprobs = typeof(f_x)[], rejections = Vector{Int}[], current = Ref{Int}(0) # reference to current top level state ) end function AbstractMCMC.save!!(samples, sample, ::Integer, model::AbstractMultilevelModel, ::ChristenFox{true, P}; kwargs... ) where {P} (level, x, f_x) = sample # accept if level == length(model) push!(samples.states, x) push!(samples.logprobs, f_x) push!(samples.rejections, zeros(Int, length(model))) samples.current[] = length(samples.states) # reset current to accepted else samples.rejections[samples.current[]][level+1] += 1 # update rejection counter # store promoted if level > 0 push!(samples.states, x) push!(samples.logprobs, f_x) push!(samples.rejections, zeros(Int, length(model))) end end return samples end ## General chain stepping function AbstractMCMC.step(rng::AbstractRNG, model::AbstractModel, sampler::ChristenFox; x0=nothing, f0=nothing, kwargs...) # Initialize states _resample = isnothing(x0) x = _resample ? rand(rng, sampler.proposal) : x0 # Initialize logprobs _recompute = _resample || isnothing(f0) f_x = _recompute ? [ logdensity(model, x; level=l) for l=1:length(model) ] : f0 return (length(model), x, f_x), (x, f_x) end function AbstractMCMC.step(rng::AbstractRNG, model::AbstractMultilevelModel, sampler::ChristenFox, state; kwargs...) # Load old state x, f_x = state # Propose y = propose(rng, sampler.proposal, x) f_y = [ logdensity(model, y; level=1) ] q = logpratio(sampler.proposal, x, y) # Promotion probability A_1 = min( f_y[1] - f_x[1] + q, 0) accept = log(rand(rng)) < A_1 if !accept return (0, x, f_x), (x, f_x) end # completly reject, never promoted # Promotion loop for l = 2:length(model) push!(f_y, logdensity(model, y; level=l) ) # Next promotion A_l = f_y[l] - f_x[l] + f_y[l-1] - f_x[l-1] accept = log(rand(rng)) < A_l if !accept return (l-1, y, f_y), (x, f_x) end # rejected at level l, promoted to l-1 end # Accept at highest level return (length(model), y, f_y), (y, f_y) end ## Chain stepping specialized for Sampled-Based LogDensities, uses cached evaluation function AbstractMCMC.step(rng::AbstractRNG, model::MultilevelSampledLogDensity, sampler::ChristenFox; x0=nothing, f0=nothing, kwargs...) # Initialize states _resample = isnothing(x0) x = _resample ? rand(rng, sampler.proposal) : x0 # Initialize logprobs _recompute = _resample || isnothing(f0) if _recompute # Recursively cached f_x = [ logdensity(model, x; level=1) ] sizehint!(f_x, length(model)) for l=2:length(model) push!(f_x, logdensity(model, x; level=l, cache=f_x[end])) end else f_x = f0 end return (length(model), x, f_x), (x, f_x) end function AbstractMCMC.step(rng::AbstractRNG, model::MultilevelSampledLogDensity, sampler::ChristenFox, state; kwargs...) # Load old state x, f_x = state # Propose y = propose(rng, sampler.proposal, x) f_y = [ logdensity(model, y; level=1) ] q = logpratio(sampler.proposal, x, y) # Promotion probability A_1 = min( f_y[1] - f_x[1] + q, 0) accept = log(rand(rng)) < A_1 if !accept return (0, x, f_x), (x, f_x) end # completly reject, never promoted # Promotion loop for l = 2:length(model) push!(f_y, logdensity(model, y; level=l, cache=f_y[end]) ) # > Only this line changes! # Next promotion A_l = f_y[l] - f_x[l] + f_y[l-1] - f_x[l-1] accept = log(rand(rng)) < A_l if !accept return (l-1, y, f_y), (x, f_x) end # rejected at level l, promoted to l-1 end # Accept at highest level return (length(model), y, f_y), (y, f_y) end function AbstractMCMC.bundle_samples(samples, ::AbstractMultilevelModel, ::ChristenFox, state, chain_type::Type{<:NamedTuple}; kwargs...) return (; states = samples.states, logprobs = samples.logprobs, rejections = samples.rejections) end function _chain_info(samples, model::AbstractMultilevelModel, ::ChristenFox) n_total = sum( 1 .+ sum.(samples.rejections)) n_reject = sum(samples.rejections) info = Dict( :chain_length => n_total, :rejection_rate => n_reject ./ n_total, ) if model isa MultilevelSampledLogDensity evals_per_level = n_total .- cumsum([0, n_reject[1:end-1]... ]) costs_per_level = model.nlevels .- [0, model.nlevels[1:end-1]...] info[:total_costs] = sum( evals_per_level .* costs_per_level ) end return info end