# lykkegard_scheichl.jl — recursive multilevel delayed-acceptance (MLDA) sampler
"""
    LykkegaardScheichl(proposal, sublen=1, saveproxies::Bool=true)

Recursive (multilevel) delayed-acceptance MCMC sampler.

One step on level `L` runs a sub-chain on level `L-1` starting from the current state.
On level 1 this sub-chain is Metropolis–Hastings with `proposal`; on higher levels this
sampler is applied recursively. The end point of the sub-chain is then accepted or
rejected on level `L`.

# Fields
- `proposal`: proposal used by the level-1 Metropolis–Hastings sub-chains; it is also
  sampled (`rand(rng, proposal)`) to obtain a starting point when no `x0` is given.
- `sublen`: number of sub-chain iterations per step, either an `Integer` or a
  distribution that is drawn from (`rand(rng, sublen)`) at every step.

# Type parameters
- `saveproxies`: if `true`, the states visited by the coarse sub-chains ("proxies") are
  stored in the output alongside the top-level states; if `false`, only top-level
  states are stored.

# References
* Mikkel B. Lykkegaard, T. Dodwell, C. Fox, Grigorios Mingas, Robert Scheichl (2022).
  "Multilevel Delayed Acceptance MCMC". SIAM/ASA J. Uncertain. Quantification.
"""
struct LykkegaardScheichl{saveproxies, P <: AbstractProposal, N} <: RejectionBasedSampler
    proposal :: P
    sublen :: N # sub-chain length (may be a distribution)
end
LykkegaardScheichl(p,s=1,saveproxies::Bool=true) = LykkegaardScheichl{saveproxies, typeof(p), typeof(s)}(p,s)

# Length of the next sub-chain: a fresh draw from `sublen` when it is a distribution,
# otherwise the fixed integer `sublen` (the RNG is then not used).
function subchainlength(rng::AbstractRNG, s::LykkegaardScheichl{S,P,<:Distribution}) where {S,P}
    return rand(rng, s.sublen)
end
function subchainlength(::AbstractRNG, s::LykkegaardScheichl{S,P,<:Integer}) where {S,P}
    return s.sublen
end


## Only save top level

# Create the empty containers for a chain that stores only top-level states.
# `sample` is the first step result `(accept, x, f_x, r_x, Y)`; only the types of
# `x`, `f_x` and `r_x` are used here. The positional argument must be called `model`
# because the keyword default `L=length(model)` refers to it.
function AbstractMCMC.samples(sample, model::AbstractMultilevelModel, sampler::LykkegaardScheichl{false,P,N}; L=length(model), kwargs...) where {P,N}
    _, x, f_x, r_x, _ = sample
    return (;
        states = typeof(x)[],
        logprobs = typeof(f_x)[],
        rejections = typeof(r_x)[],
    )
end

# Store one step result `(accept, x, f_x, r_x, Y)` in a chain that keeps only
# top-level states (proxies `Y` are ignored).
# Keywords:
#   L – level of this chain; rejections are counted in entry `L` of the rejection
#       vector. Forwarded by AbstractMCMC together with the other sampling kwargs,
#       which is why this method must accept `kwargs...`.
function AbstractMCMC.save!!(samples, sample, ::Integer, model::AbstractMultilevelModel, sampler::LykkegaardScheichl{false, P,N}; L=length(model), kwargs...) where {P,N}
    accept, x, f_x, r_x, _ = sample

    if accept # save new entry
        push!(samples.states, x)
        push!(samples.logprobs, f_x)
        push!(samples.rejections, r_x)
    else
        # Rejected on level L: the chain stays at the last stored state, so only its
        # rejection counter is incremented (only top-level states are stored here).
        samples.rejections[end][L] += 1
    end
    return samples
end

## Save proxies

# Create the empty containers for a chain that also stores proxies.
# `sample` is the first step result `(accept, x, f_x, r_x, Y)`; only the types of
# `x`, `f_x` and `r_x` are used. `current` holds the index (into `states`) of the most
# recent top-level state, so that rejections can still be credited to it after proxies
# have been appended behind it; it starts at 0 (nothing stored yet).
function AbstractMCMC.samples(sample, model::AbstractMultilevelModel, sampler::LykkegaardScheichl{true,P,N}; L=length(model), kwargs...) where {P,N}
    _, state0, logp0, rej0, _ = sample
    return (
        states = Vector{typeof(state0)}(),
        logprobs = Vector{typeof(logp0)}(),
        rejections = Vector{typeof(rej0)}(),
        current = Ref(0),
    )
end


# Store one step result `(accept, x, f_x, r_x, Y)` in a chain that keeps proxies.
# The proxies `Y` produced by the step's sub-chain are always appended first. An
# accepted state is then pushed and becomes the `current` top-level state; a rejection
# increments entry `L` of the current top-level state's rejection vector.
function AbstractMCMC.save!!(samples, sample, iter::Integer, model::AbstractMultilevelModel, sampler::LykkegaardScheichl{true, P,N}; L=length(model), kwargs... ) where {P,N}
    accepted, newstate, newlogp, newrej, proxies = sample

    append!(samples.states, proxies.states)
    append!(samples.logprobs, proxies.logprobs)
    append!(samples.rejections, proxies.rejections)

    if !accepted
        # the current top-level state may no longer be the last entry (proxies were
        # appended after it), hence the stored index
        samples.rejections[samples.current[]][L] += 1
        return samples
    end

    push!(samples.states, newstate)
    push!(samples.logprobs, newlogp)
    push!(samples.rejections, newrej)
    samples.current[] = lastindex(samples.states)
    return samples
end

# Log densities supplied by the caller for a given starting point `x0`.
# A vector contributes its first `L` entries; a plain number is only meaningful when
# there is a single level. Anything else is rejected explicitly, instead of failing
# later on an unassigned variable.
_ls_given_logprobs(f0::AbstractVector, L) = collect(f0[1:L])
function _ls_given_logprobs(f0::Number, L)
    L == 1 || throw(ArgumentError("a scalar f0 is only valid for L == 1, got L = $L; pass a vector of $L log densities"))
    return [f0]
end
_ls_given_logprobs(f0, L) = throw(ArgumentError("f0 must be a vector of log densities (or a number when L == 1), got $(typeof(f0))"))

# Empty proxy container: no sub-chain has been run at initialisation.
_ls_no_proxies(x, f_x) = (; states=typeof(x)[], logprobs=typeof(f_x)[], rejections=Vector{Int}[])

# Initialize chain.
# Starts from `x0` if given, otherwise from a draw of `sampler.proposal`. The level log
# densities `f_x[1:L]` are taken from `f0` only when both `x0` and `f0` are given
# (an `f0` without `x0` cannot belong to the random start and is ignored); otherwise
# they are evaluated. The first step is always reported as accepted with zero
# rejections and no proxies.
function AbstractMCMC.step(rng::AbstractRNG, model::AbstractMultilevelModel, sampler::LykkegaardScheichl; x0=nothing, f0=nothing, L=length(model), kwargs...)
    x = isnothing(x0) ? rand(rng, sampler.proposal) : x0
    f_x = if !isnothing(x0) && !isnothing(f0)
        _ls_given_logprobs(f0, L)
    else
        [logdensity(model, x; level=l) for l in 1:L]
    end
    return (true, x, f_x, zeros(Int, L), _ls_no_proxies(x, f_x)), (x, f_x)
end

# Same as above for a `MultilevelSampledLogDensity`: the levels are evaluated bottom-up
# and each level receives the previous level's value as `cache`
# (presumably so the model can reuse coarse-level work — confirm against the model).
function AbstractMCMC.step(rng::AbstractRNG, model::MultilevelSampledLogDensity, sampler::LykkegaardScheichl; x0=nothing, f0=nothing, L=length(model), kwargs...)
    x = isnothing(x0) ? rand(rng, sampler.proposal) : x0
    if !isnothing(x0) && !isnothing(f0)
        f_x = _ls_given_logprobs(f0, L)
    else
        f_x = [logdensity(model, x; level=1)]
        sizehint!(f_x, L)
        for l in 2:L
            push!(f_x, logdensity(model, x; level=l, cache=f_x[end]))
        end
    end
    return (true, x, f_x, zeros(Int, L), _ls_no_proxies(x, f_x)), (x, f_x)
end
 

# One level-`L` transition of multilevel delayed acceptance, starting from the current
# state `state = (x, f_x)`, where `f_x[l]` is the level-`l` log density at `x`.
#
# A sub-chain of `n = 1 + subchainlength(rng, sampler)` samples, the first of which is
# `x` itself, is run on level `L-1`:
#   - Metropolis–Hastings with `sampler.proposal` when `L == 2`;
#   - this sampler recursively (keyword `L = L-1`) otherwise.
# Its last level-(L-1) state `y` is accepted on level `L` with probability
#     min(1, π_L(y) π_{L-1}(x) / (π_L(x) π_{L-1}(y))),
# evaluated in log space. The other sub-chain entries are returned as proxies `Y`,
# whose rejection vectors get an extra `0` entry for level `L`.
#
# Returns `((accepted, state, logprobs, rejections, Y), (state, logprobs))`. On
# rejection, or when the sub-chain never left `x`, the current `(x, f_x)` is returned
# with `accepted == false`.
# Throws `ArgumentError` if `L < 2` (there is no coarser level to propose from).
function AbstractMCMC.step(rng::AbstractRNG, model::AbstractMultilevelModel, sampler::LykkegaardScheichl, state; L=length(model), kwargs...)
    L >= 2 || throw(ArgumentError("LykkegaardScheichl requires at least 2 levels, got L = $L"))
    x, f_x = state

    # Sample length at random (+1: the sub-chain's first sample is x itself)
    n = 1 + subchainlength(rng, sampler)

    if L == 2 # reduces to iterated MH on level 1
        mh = MetropolisHastings(sampler.proposal)
        model_1 = LogDensity(z -> logdensity(model, z, level=1))
        c = sample(rng, model_1, mh, n, chain_type=NamedTuple, x0=x, f0=f_x[1])

        # chain in between (c.states[1] is x, c.states[end] is the proposal)
        Y = (; states     = c.states[2:end-1],
               logprobs   = map(lp -> [lp], c.logprobs[2:end-1]),
               rejections = map(r -> [r, 0], c.rejections[2:end-1]),
            )

        # Sub-chain never left x. Test c.states rather than Y: a sub-chain that moved
        # exactly once has an empty Y but still yields a genuine proposal.
        if length(c.states) < 2
            return (false, x, f_x, Vector{Int}[], Y), (x, f_x)
        end

        y = c.states[end]
        f_y = [c.logprobs[end], logdensity(model, y, level=2)]
        r_y = [c.rejections[end], 0]

    else # Recursion: same sampler, one level fewer
        # `sample(rng, model, sampler, N)` argument order; NamedTuple *type* so that
        # `bundle_samples(..., ::Type{<:NamedTuple})` applies. No `discard_initial`:
        # the first stored entry must be x itself, as in the L == 2 branch.
        c = sample(rng, model, sampler, n, chain_type=NamedTuple, L=L-1, x0=x, f0=f_x[1:L-1])

        # Level-(L-1) states carry exactly L-1 log densities; proxies from deeper levels
        # (stored when saveproxies == true) carry fewer. The proposal y is the last
        # genuine level-(L-1) state, which need not be the last entry of c.
        iy = findlast(lp -> length(lp) == L - 1, c.logprobs)
        between = [2:iy-1; iy+1:lastindex(c.states)]
        Y = (; states     = c.states[between],
               logprobs   = c.logprobs[between],
               rejections = map(r -> [r; 0], c.rejections[between]),
            )

        # Sub-chain never left x.
        if iy == 1
            return (false, x, f_x, Vector{Int}[], Y), (x, f_x)
        end

        y = c.states[iy]
        f_y = [c.logprobs[iy]; logdensity(model, y, level=L)]
        r_y = [c.rejections[iy]; 0]
    end

    # accept/reject step (log of the delayed-acceptance ratio, capped at 0)
    logα = min(f_y[L] - f_x[L] - f_y[L-1] + f_x[L-1], 0)
    if log(rand(rng)) < logα
        return (true, y, f_y, r_y, Y), (y, f_y)
    else
        return (false, x, f_x, Vector{Int}[], Y), (x, f_x)
    end
end

# Return the collected records as a NamedTuple with fields `states`, `logprobs` and
# `rejections`. Any other fields of `samples` (such as the `current` index kept when
# proxies are saved) are dropped.
function AbstractMCMC.bundle_samples(samples, model::AbstractMultilevelModel, sampler::LykkegaardScheichl, state, chain_type::Type{<:NamedTuple}; kwargs...)
    fields = (samples.states, samples.logprobs, samples.rejections)
    return NamedTuple{(:states, :logprobs, :rejections)}(fields)
end

# Summary statistics of a bundled chain.
#   :chain_length   – number of stored records plus all rejections counted on them
#   :rejection_rate – per-level rejection counts divided by :chain_length
#   :total_costs    – only for a MultilevelSampledLogDensity, see note below
function _chain_info(samples, model::AbstractMultilevelModel, sampler::LykkegaardScheichl)
    recs = samples.rejections
    n_total = length(recs) + sum(sum.(recs))
    n_reject = sum(recs)   # element-wise: rejections per level
    info = Dict(
        :chain_length   => n_total,
        :rejection_rate => n_reject ./ n_total,
    )

    model isa MultilevelSampledLogDensity || return info

    # NOTE(review): this evaluation count was copied from the Christen–Fox (two-stage)
    # sampler, where level l is evaluated once for every proposal not already rejected
    # on a coarser level. Whether it holds for sub-chain based MLDA is unverified.
    evals_per_level = n_total .- cumsum([0; n_reject[1:end-1]])
    # model.nlevels appears to hold cumulative per-level costs; the differences give
    # the incremental cost of each level (TODO confirm)
    costs_per_level = model.nlevels .- [0, model.nlevels[1:end-1]...]
    info[:total_costs] = sum(evals_per_level .* costs_per_level)
    return info
end