"""
    ChristenFox(proposal::AbstractProposal, saveproxies::Bool=false)

Multilevel delayed-acceptance Metropolis–Hastings sampler.

A candidate `y` drawn from `proposal` is first screened on level 1 with an ordinary
Metropolis–Hastings test, accepted if `log(u) < min(f_1(y) - f_1(x) + q, 0)`, where
`f_l` is `logdensity(model, ⋅; level=l)` and `q = logpratio(proposal, x, y)`.
Only candidates that survive are evaluated on the next level. On every level `l ≥ 2`
the candidate is accepted if

    log(u) < (f_l(y) - f_l(x)) - (f_{l-1}(y) - f_{l-1}(x)),

which corrects for the approximation used on the level below. A candidate becomes the
new chain state only if it passes every level.

If `saveproxies == true`, the log-densities of all levels (including the lower-level
proxies) are stored with each accepted sample. This is particularly useful for MLMC
integration. Otherwise only the top-level log-density is kept.

# References
* Christen, J.A. and Fox, C. (2005). "Markov chain Monte Carlo Using an Approximation".
  Journal of Computational and Graphical Statistics.
"""
struct ChristenFox{saveproxies, P <: AbstractProposal} <: RejectionBasedSampler
    proposal :: P
end

ChristenFox(proposal::AbstractProposal, saveproxies::Bool=false) =
    ChristenFox{saveproxies, typeof(proposal)}(proposal)

# Turn a step output `(level, x, f_x)` into the transition stored in the chain.
# With proxies saved, the whole per-level log-density vector `f_x` is kept.
# Without proxies, only its last (top-level) entry is kept.
_christenfox_transition(::ChristenFox{true}, sample) = sample[2:end]
_christenfox_transition(::ChristenFox{false}, sample) =
    (sample[2], sample[3][end], sample[4:end]...)

# Initialize samples container.
# `rejections[k][l]` counts the candidates rejected on level `l` while the chain sat at
# `transitions[k]`. Rejections are counted rather than stored as repeated samples.
function AbstractMCMC.samples(sample, model::AbstractMultilevelModel, sampler::ChristenFox;
                              kwargs...)
    return (; rejections=[zeros(Int, length(model))],
              transitions=[_christenfox_transition(sampler, sample)])
end

# Store sample to container.
# `sample[1]` is the level at which the candidate was rejected.
# `length(model) + 1` means it passed every level and was accepted.
function AbstractMCMC.save!!(samples, sample, ::Integer, model::AbstractMultilevelModel,
                             sampler::ChristenFox; kwargs...)
    level = sample[1]
    if level == length(model) + 1
        # Accepted: start a new transition with fresh per-level rejection counters.
        push!(samples.rejections, zeros(Int, length(model)))
        push!(samples.transitions, _christenfox_transition(sampler, sample))
    else
        # Rejected: the chain stays at the last stored transition; count it there.
        # TODO: set rejection counter of last top level sample instead!
        samples.rejections[end][level] += 1
    end
    return samples
end

# Log-density of `x` on `level ≥ 2`, given `f` = log-densities of `x` on the lower
# levels, already computed. Generic multilevel models evaluate each level on its own.
_christenfox_logdensity(model, x, level::Integer, f) = logdensity(model, x; level=level)

# Sampled log-densities receive the previous level's value as `cache`.
_christenfox_logdensity(model::MultilevelSampledLogDensity, x, level::Integer, f) =
    logdensity(model, x; level=level, cache=f[end])

# Draw the initial state from the proposal and evaluate it on every level.
# Returns `(sample, state)` with `sample = (length(model), x, f_x)` and `state = (x, f_x)`.
function _christenfox_initial(rng::AbstractRNG, model, sampler::ChristenFox)
    x = rand(rng, sampler.proposal)
    f_x = [logdensity(model, x; level=1)]
    sizehint!(f_x, length(model))
    for l in 2:length(model)
        push!(f_x, _christenfox_logdensity(model, x, l, f_x))
    end
    return (length(model), x, f_x), (x, f_x)
end

# One delayed-acceptance transition from `state = (x, f_x)`.
# The returned sample is `(level, y, f_y)`, where `level` is the level at which `y` was
# rejected, or `length(model) + 1` if it was accepted. The returned state is the
# (possibly unchanged) chain state.
function _christenfox_step(rng::AbstractRNG, model, sampler::ChristenFox, state)
    x, f_x = state
    y = propose(rng, sampler.proposal, x)
    f_y = [logdensity(model, y; level=1)]

    # Level 1: Metropolis–Hastings test on the cheapest level, with proposal correction.
    q = logpratio(sampler.proposal, x, y)
    A_1 = min(f_y[1] - f_x[1] + q, 0)
    accept = log(rand(rng)) < A_1
    accept || return (1, x, f_x), (x, f_x)

    # Levels ≥ 2: correct for the level below.
    # The ratio on level l-1 is divided out (subtracted in log space).
    L = length(model)
    sizehint!(f_y, L)
    for l in 2:L
        push!(f_y, _christenfox_logdensity(model, y, l, f_y))
        A_l = (f_y[l] - f_x[l]) - (f_y[l-1] - f_x[l-1])
        accept = log(rand(rng)) < A_l
        accept || return (l, y, f_y), (x, f_x)
    end
    return (L + 1, y, f_y), (y, f_y)
end

# Initialize chain
function AbstractMCMC.step(rng::AbstractRNG, model::AbstractMultilevelModel,
                           sampler::ChristenFox; kwargs...)
    return _christenfox_initial(rng, model, sampler)
end

function AbstractMCMC.step(rng::AbstractRNG, model::MultilevelSampledLogDensity,
                           sampler::ChristenFox; kwargs...)
    return _christenfox_initial(rng, model, sampler)
end

# Chain stepping
function AbstractMCMC.step(rng::AbstractRNG, model::AbstractMultilevelModel,
                           sampler::ChristenFox, state; kwargs...)
    return _christenfox_step(rng, model, sampler, state)
end

function AbstractMCMC.step(rng::AbstractRNG, model::MultilevelSampledLogDensity,
                           sampler::ChristenFox, state; kwargs...)
    return _christenfox_step(rng, model, sampler, state)
end

"""
    total_costs(c::RejectionChains, m::MultilevelSampledLogDensity)

Return one integer cost per chain in `c`.

For chain `i`:
* `nevals[l]` is `c.info[:chain_length][i]` minus the total number of rejections on
  levels `1..l-1`, i.e. the number of candidates that reached level `l`. This assumes
  `chain_length` counts proposals; TODO confirm.
* Each `nevals[l]` is weighted by the increment `m.nlevels[l] - m.nlevels[l-1]`.
  This presumes `m.nlevels` is cumulative per level; verify against
  `MultilevelSampledLogDensity`.
"""
function total_costs(c::RejectionChains, m::MultilevelSampledLogDensity)
    # Loop-invariant: per-level increment of `m.nlevels`.
    increments = m.nlevels .- [0, m.nlevels[1:end-1]...]
    costs = Int[]
    sizehint!(costs, length(c))
    for i = 1:length(c)
        N = c.info[:chain_length][i]
        r = getfield.(c.samples[i], :reject)
        rejected_per_level = sum(r)
        nevals = N .- cumsum([0, rejected_per_level[1:end-1]...])
        push!(costs, sum(increments .* nevals))
    end
    return costs
end

# Attach cost information to the bundled chains when the model supports it.
_christenfox_attach_costs!(c, m) = c
function _christenfox_attach_costs!(c, m::MultilevelSampledLogDensity)
    c.info[:total_costs] = total_costs(c, m)
    return c
end

function AbstractMCMC.bundle_samples(samples, m::AbstractMultilevelModel, ::ChristenFox,
                                     state, chain_type::Type; kwargs...)
    s = getindex.(samples.transitions, 1)
    l = getindex.(samples.transitions, 2)
    c = RejectionChains(s, l, samples.rejections)
    _christenfox_attach_costs!(c, m)
    return c
end