"""
    LykkegaardScheichl(proposal, sublen=1, saveproxies::Bool=true)

Recursive Delayed Acceptance sampler.

Each step on `L` levels runs a sub-chain with `n = 1 + subchainlength(rng, sampler)`
samples on levels `1:L-1` and accepts or rejects the sub-chain's end point with the
delayed-acceptance ratio between levels `L` and `L-1`. The sub-chain uses
`MetropolisHastings(proposal)` on level 1 when `L == 2`, and this sampler recursively
otherwise.

# Arguments
- `proposal`: proposal of the level-1 Metropolis–Hastings chain; also used to draw the
  initial state when no `x0` is given.
- `sublen`: sub-chain length; an `Integer`, or a `Distribution` sampled once per step.
- `saveproxies`: if `true`, the intermediate (proxy) states of the sub-chains are stored
  together with the top-level states; if `false`, only top-level states are stored.

# References
* Mikkel B. Lykkegaard, T. Dodwell, C. Fox, Grigorios Mingas, Robert Scheichl (2022).
  "Multilevel Delayed Acceptance MCMC". SIAM/ASA J. Uncertain. Quantification.
"""
struct LykkegaardScheichl{saveproxies, P <: AbstractProposal, N} <: RejectionBasedSampler
    proposal :: P
    sublen :: N # sub-chain length (may be a distribution)
end

LykkegaardScheichl(p, s=1, saveproxies::Bool=true) =
    LykkegaardScheichl{saveproxies, typeof(p), typeof(s)}(p, s)

# Sub-chain length for one step: a fresh draw when `sublen` is a distribution,
# the fixed value when it is an integer.
subchainlength(rng::AbstractRNG,
        s::LykkegaardScheichl{saveproxies, P, <:Distribution}) where {saveproxies, P} =
    rand(rng, s.sublen)
subchainlength(rng::AbstractRNG,
        s::LykkegaardScheichl{saveproxies, P, <:Integer}) where {saveproxies, P} =
    s.sublen

# A sample produced by `step` is the tuple `(accept, x, f_x, r_x, Y)`:
# - `accept`: whether a new top-level state was accepted in this step;
# - `x`, `f_x`: current state and its vector of per-level log-densities;
# - `r_x`: per-level rejection counters of `x` (only stored when `accept` is true);
# - `Y`: named tuple `(; states, logprobs, rejections)` of sub-chain proxies.
#
# NOTE(review): recent AbstractMCMC versions call `samples`/`save!!` with a trailing
# `N::Integer` argument when the number of samples is known; confirm that the methods
# below are reached (e.g. through a forwarding method defined elsewhere).

## Only save top level

"""
Create the sample container for `LykkegaardScheichl{false}`: only top-level states,
their per-level log-densities and per-level rejection counters are stored.
"""
function AbstractMCMC.samples(sample, model::AbstractMultilevelModel,
        sampler::LykkegaardScheichl{false, P, N};
        L=length(model), kwargs...) where {P, N}
    _, x, f_x, r_x, _ = sample
    return (;
        states = typeof(x)[],
        logprobs = typeof(f_x)[],
        rejections = typeof(r_x)[],
    )
end

"""
Store `sample` for `LykkegaardScheichl{false}`. An accepted state is appended; a
rejection increments the level-`L` rejection counter of the last stored state.
"""
function AbstractMCMC.save!!(samples, sample, ::Integer, model::AbstractMultilevelModel,
        sampler::LykkegaardScheichl{false, P, N};
        L=length(model), kwargs...) where {P, N}
    accept, x, f_x, r_x, _ = sample
    if accept
        push!(samples.states, x)
        push!(samples.logprobs, f_x)
        push!(samples.rejections, r_x)
    elseif !isempty(samples.rejections)
        samples.rejections[end][L] += 1
    end
    # A rejection before any stored state (possible when the initial state is discarded,
    # as in the recursive sub-chain) has no state to be attributed to and is skipped;
    # `step` then treats an empty sub-chain as a rejection.
    return samples
end

## Save proxies

"""
Create the sample container for `LykkegaardScheichl{true}`: sub-chain proxies are
stored alongside the top-level states. `current` holds the index of the current
top-level state (0 while none has been stored).
"""
function AbstractMCMC.samples(sample, model::AbstractMultilevelModel,
        sampler::LykkegaardScheichl{true, P, N};
        L=length(model), kwargs...) where {P, N}
    _, x, f_x, r_x, _ = sample
    return (;
        states = typeof(x)[],
        logprobs = typeof(f_x)[],
        rejections = typeof(r_x)[],
        current = Ref{Int}(0), # reference to current top level state
    )
end

"""
Store `sample` for `LykkegaardScheichl{true}`: the sub-chain proxies `Y` are appended
first, then an accepted state is appended and becomes the current top-level state. A
rejection increments the level-`L` counter of the current top-level state.
"""
function AbstractMCMC.save!!(samples, sample, iter::Integer, model::AbstractMultilevelModel,
        sampler::LykkegaardScheichl{true, P, N};
        L=length(model), kwargs...) where {P, N}
    accept, x, f_x, r_x, Y = sample
    # Save proxies
    append!(samples.states, Y.states)
    append!(samples.logprobs, Y.logprobs)
    append!(samples.rejections, Y.rejections)
    if accept
        push!(samples.states, x)
        push!(samples.logprobs, f_x)
        push!(samples.rejections, r_x)
        samples.current[] = length(samples.states) # update reference
    elseif samples.current[] > 0
        samples.rejections[samples.current[]][L] += 1
    end
    # As above: a rejection before any top-level state is stored is skipped.
    return samples
end

# Per-level log-densities from a user-supplied `f0`: a scalar is only valid for a
# single level, a vector is truncated to its first `L` entries.
_mlda_given_logprobs(f0::Number, L::Integer) =
    L == 1 ? [f0] : throw(ArgumentError("a scalar f0 requires L == 1, got L = $L"))
_mlda_given_logprobs(f0::AbstractVector, L::Integer) = f0[1:L]

# First (sample, state) pair of a chain started at `x`: accepted, with zero rejection
# counters on all `L` levels and no proxies.
function _mlda_initial_sample(x, f_x, L::Integer)
    Y = (; states=typeof(x)[], logprobs=typeof(f_x)[], rejections=Vector{Int}[])
    return (true, x, f_x, zeros(Int, L), Y), (x, f_x)
end

# Initialize chain
"""
Initial step: start at `x0` (or a draw from `sampler.proposal`). The per-level
log-densities are taken from `f0` when both `x0` and `f0` are given, and otherwise
evaluated on levels `1:L`.
"""
function AbstractMCMC.step(rng::AbstractRNG, model::AbstractMultilevelModel,
        sampler::LykkegaardScheichl;
        x0=nothing, f0=nothing, L=length(model), kwargs...)
    x = isnothing(x0) ? rand(rng, sampler.proposal) : x0
    f_x = if !isnothing(x0) && !isnothing(f0)
        _mlda_given_logprobs(f0, L)
    else
        [logdensity(model, x; level=l) for l in 1:L]
    end
    return _mlda_initial_sample(x, f_x, L)
end

"""
Initial step for a `MultilevelSampledLogDensity`. Same as the generic initial step,
except that level `l > 1` is evaluated with `cache` set to the level-`l-1` value.
"""
function AbstractMCMC.step(rng::AbstractRNG, model::MultilevelSampledLogDensity,
        sampler::LykkegaardScheichl;
        x0=nothing, f0=nothing, L=length(model), kwargs...)
    x = isnothing(x0) ? rand(rng, sampler.proposal) : x0
    f_x = if !isnothing(x0) && !isnothing(f0)
        _mlda_given_logprobs(f0, L)
    else
        logprobs = [logdensity(model, x; level=1)]
        sizehint!(logprobs, L)
        for l in 2:L
            push!(logprobs, logdensity(model, x; level=l, cache=logprobs[end]))
        end
        logprobs
    end
    return _mlda_initial_sample(x, f_x, L)
end

"""
One delayed-acceptance step on levels `1:L` (requires `L >= 2`).

A sub-chain on levels `1:L-1` is started at the current state. Its end point `y` is
accepted with probability
`min(1, exp(f_y[L] - f_x[L] - f_y[L-1] + f_x[L-1]))`.
The sub-chain's intermediate states are returned as proxies `Y` in either case.
"""
function AbstractMCMC.step(rng::AbstractRNG, model::AbstractMultilevelModel,
        sampler::LykkegaardScheichl, state; L=length(model), kwargs...)
    # With L < 2 the recursion below would never reach the L == 2 base case.
    L >= 2 || throw(ArgumentError("LykkegaardScheichl needs at least 2 levels, got L = $L"))
    x, f_x = state
    # Sample length at random
    n = 1 + subchainlength(rng, sampler)

    if L == 2
        # Base case: the sub-chain is plain Metropolis-Hastings on level 1.
        mh = MetropolisHastings(sampler.proposal)
        model_1 = LogDensity(z -> logdensity(model, z, level=1))
        # progress=false: avoid a progress bar for every inner sub-chain.
        c = sample(rng, model_1, mh, n;
            chain_type=NamedTuple, x0=x, f0=f_x[1], progress=false)
        # chain in between; proxies get a zero level-2 rejection counter
        Y = (;
            states = c.states[2:end-1],
            logprobs = map(lp -> [lp], c.logprobs[2:end-1]),
            rejections = map(r -> [r, 0], c.rejections[2:end-1]),
        )
        # The first entry is the starting state `x`; with fewer than two entries the
        # sub-chain produced no new end point to test.
        if length(c.states) < 2
            return (false, x, f_x, Int[], Y), (x, f_x)
        end
        y = c.states[end]
        f_y = [c.logprobs[end], logdensity(model, y, level=2)]
        r_y = [c.rejections[end], zeros(Int, L-1)...]
    else
        # Recursion: the sub-chain is this sampler on levels 1:L-1.
        # NOTE(review): with discard_initial=1 the starting state is presumably not in
        # `c`, so `c.states[1]` may be a genuine sub-chain state dropped from `Y` below —
        # confirm against the AbstractMCMC version in use.
        c = sample(rng, model, sampler, n;
            chain_type=NamedTuple, L=L-1, x0=x, f0=f_x[1:L-1],
            discard_initial=1, progress=false)
        # chain in between; proxies get a zero level-L rejection counter
        Y = (;
            states = c.states[2:end-1],
            logprobs = c.logprobs[2:end-1],
            rejections = map(r -> vcat(r, 0), c.rejections[2:end-1]),
        )
        # end of chain
        if isempty(c.states)
            return (false, x, f_x, Int[], Y), (x, f_x)
        end
        y = c.states[end]
        f_y = vcat(c.logprobs[end], logdensity(model, y, level=L))
        r_y = vcat(c.rejections[end], 0)
    end

    # accept/reject step: fine-level ratio corrected by the coarse-level ratio that the
    # sub-chain already targeted
    A = min(f_y[L] - f_x[L] - f_y[L-1] + f_x[L-1], 0)
    accept = log(rand(rng)) < A
    if accept
        return (true, y, f_y, r_y, Y), (y, f_y)
    else
        # `Int[]` keeps the rejection slot the same type as an accepted `r_y`.
        return (false, x, f_x, Int[], Y), (x, f_x)
    end
end

"""
Return the stored `states`, `logprobs` and `rejections` as a `NamedTuple`.
"""
function AbstractMCMC.bundle_samples(samples, model::AbstractMultilevelModel,
        sampler::LykkegaardScheichl, state, chain_type::Type{<:NamedTuple}; kwargs...)
    return (;
        states = samples.states,
        logprobs = samples.logprobs,
        rejections = samples.rejections,
    )
end

"""
Summary statistics of a sampled chain:
- `:chain_length`: one per stored state plus all recorded rejections;
- `:rejection_rate`: per-level rejections divided by the chain length;
- `:total_costs`: only added for a `MultilevelSampledLogDensity` model.

Assumes `samples.rejections` is non-empty.
"""
function _chain_info(samples, model::AbstractMultilevelModel, sampler::LykkegaardScheichl)
    n_total = sum(1 .+ sum.(samples.rejections))
    n_reject = sum(samples.rejections) # element-wise: per-level rejection counts
    info = Dict(
        :chain_length => n_total,
        :rejection_rate => n_reject ./ n_total,
    )
    if model isa MultilevelSampledLogDensity
        # NOTE(review): copied from the Christen-Fox sampler; validity for this sampler
        # is unconfirmed. Presumably `model.nlevels` holds cumulative per-level costs —
        # TODO confirm.
        evals_per_level = n_total .- cumsum([0, n_reject[1:end-1]...])
        costs_per_level = model.nlevels .- [0, model.nlevels[1:end-1]...]
        info[:total_costs] = sum(evals_per_level .* costs_per_level)
    end
    return info
end