"""
    ChristenFox(proposal::AbstractProposal, saveproxies::Bool=false)

Delayed Acceptance algorithm.

A proposed state is first tested against the level-1 log-density; it is only
evaluated on level `l` if it passed the acceptance test on level `l-1`, and it is
accepted as the new chain state only after passing the test on every level.

If `saveproxies == true` save accepted states on lower levels. This is
particularly useful for MLMC integration.

# References
* Christen, J.A. and Fox, C. (2005). "Markov chain Monte Carlo Using an
  Approximation". Journal of Computational and Graphical Statistics.
"""
struct ChristenFox{saveproxies, P <: AbstractProposal} <: RejectionBasedSampler
    proposal :: P
end

ChristenFox(proposal::AbstractProposal, saveproxies::Bool=false) =
    ChristenFox{saveproxies, typeof(proposal)}(proposal)

# Initialize samples container.
# `sample` is the tuple `(level, state, logdensity)` returned by `step`. The
# container holds, for each stored transition, a vector of per-level rejection
# counts and the `(state, logdensity)` pair.
function AbstractMCMC.samples(sample, model::AbstractMultilevelModel,
                              ::ChristenFox{false, P}; kwargs...) where {P}
    return (; rejections=[zeros(Int, length(model))], transitions=[sample[2:end]])
end

# Store sample to container.
# `sample[1]` is the level at which the proposal was rejected, or
# `length(model) + 1` when it passed every level (accepted).
function AbstractMCMC.save!!(samples, sample, ::Integer, model::AbstractMultilevelModel,
                             ::ChristenFox{false, P}; kwargs...) where {P}
    if sample[1] == length(model) + 1
        # accepted: start a fresh rejection counter and record the new transition
        push!(samples.rejections, zeros(Int, length(model)))
        push!(samples.transitions, sample[2:end])
    else
        # rejected on level `sample[1]`: only count it
        samples.rejections[end][sample[1]] += 1
    end
    return samples
end

# Initial chain step: draw a starting point from the proposal and evaluate its
# log-density on every level.
# NOTE(review): the sample reports level `length(model)` rather than
# `length(model) + 1`; if `save!!` is also called on this initial sample it is
# counted as a top-level rejection — confirm this is intended.
function AbstractMCMC.step(rng::AbstractRNG, model::AbstractMultilevelModel,
                           sampler::ChristenFox{false, P}; kwargs...) where {P}
    x = rand(rng, sampler.proposal)
    f_x = [logdensity(model, x, l) for l in 1:length(model)]
    return (length(model), x, f_x[end]), (x, f_x)
end

# Subsequent chain step (delayed acceptance).
# Returns `(sample, state)` where `sample = (level, point, logdensity)`; `level`
# is the level on which the proposal was rejected, or `length(model) + 1` if it
# was accepted. `state = (x, f_x)` carries the current point and its
# log-densities on all levels.
function AbstractMCMC.step(rng::AbstractRNG, model::AbstractMultilevelModel,
                           sampler::ChristenFox{false, P}, state; kwargs...) where {P}
    x, f_x = state
    y = propose(rng, sampler.proposal, x)
    f_y = [logdensity(model, y, 1)]

    # Level 1: Metropolis–Hastings test on the level-1 density, including the
    # proposal correction `q`.
    q = logpratio(sampler.proposal, x, y)
    A_1 = min(f_y[1] - f_x[1] + q, 0)
    accept = log(rand(rng)) < A_1
    if !accept
        return (1, x, f_x[end]), (x, f_x)
    end

    # Levels 2..L: the level-(l-1) ratio already acted as the proposal filter,
    # so it has to be divided out (Christen & Fox 2005):
    #   log α_l = [f_l(y) - f_l(x)] - [f_{l-1}(y) - f_{l-1}(x)]
    for l in 2:length(model)
        push!(f_y, logdensity(model, y, l))
        A_l = (f_y[l] - f_x[l]) - (f_y[l-1] - f_x[l-1])
        accept = log(rand(rng)) < A_l
        if !accept
            return (l, y, f_y[end]), (x, f_x)
        end
    end

    return (length(model) + 1, y, f_y[end]), (y, f_y)
end

# Collect samples into a `SimpleChains`, attaching per-level rejection rates and,
# for `MultilevelSampledLogDensity` models, a weighted evaluation count.
function AbstractMCMC.bundle_samples(samples, m::AbstractMultilevelModel,
                                     ::ChristenFox{false, P}, state, chain_type::Type;
                                     kwargs...) where {P}
    states = getindex.(samples.transitions, 1)
    logprobs = getindex.(samples.transitions, 2)

    # per-level rejection counts summed over the whole chain
    rejections_per_level = sum(samples.rejections)
    # one per stored transition plus every recorded rejection
    N = sum(sum.(samples.rejections) .+ 1)

    info = Dict{Symbol, Any}()
    info[:rejection_rate] = rejections_per_level ./ N
    if m isa MultilevelSampledLogDensity
        # level l is evaluated for every step not rejected on levels 1..l-1
        nevals = N .- cumsum(vcat(0, rejections_per_level[1:end-1]))
        # NOTE(review): `m.nlevels` presumably weights evaluations per level — confirm
        info[:evaluations] = sum(m.nlevels .* nevals)
    end
    return SimpleChains(states, logprobs, samples.rejections; info...)
end

## Modified to save lower level accepted steps