Skip to content
Snippets Groups Projects
all.jl 1.57 KiB
Newer Older
"""
    RejectionBasedSampler <: AbstractMCMC.AbstractSampler

Abstract supertype for rejection-based samplers.

Subtypes share the `AbstractMCMC.save!!`, `AbstractMCMC.samples` and
`AbstractMCMC.bundle_samples` methods defined alongside this type. By default,
those methods bundle sampling results into `RejectionChains`.
"""
abstract type RejectionBasedSampler <: AbstractMCMC.AbstractSampler end

# Storing a sample does not depend on the total chain length, so the trailing
# `Integer` argument is discarded. Storage is delegated to the 5-argument
# `AbstractMCMC.save!!`, and its result (the possibly-new container) is returned.
AbstractMCMC.save!!(container, draw, iter::Integer, mdl::AbstractModel,
                    spl::RejectionBasedSampler, ::Integer; kwargs...) =
    AbstractMCMC.save!!(container, draw, iter, mdl, spl; kwargs...)

# Create the sample container for a chain of `N` samples and size-hint it to `N`.
#
# The container itself comes from the 3-argument `AbstractMCMC.samples` method
# for this `sampler`. If it is a plain `AbstractVector`, it is size-hinted
# directly. Otherwise, every field whose declared type is an `AbstractVector` is
# size-hinted (e.g. the columns of a NamedTuple of vectors). The container is
# returned unchanged apart from the size hints.
function AbstractMCMC.samples(sample, model::AbstractModel, sampler::RejectionBasedSampler, N::Integer; kwargs...)
    samples = AbstractMCMC.samples(sample, model, sampler; kwargs...)
    _sizehint_samples!(samples, N)
    return samples
end

# Size-hint a vector-shaped sample container directly.
function _sizehint_samples!(samples::AbstractVector, N::Integer)
    sizehint!(samples, N)
    return samples
end

# Size-hint every vector-typed field of a struct or NamedTuple sample container.
# This uses the public `fieldnames`/`fieldtypes` API rather than the internal
# `parameters`/`types` fields of the type object. Those internal fields only
# line up with field names/types for NamedTuples, and error for plain vectors.
function _sizehint_samples!(samples, N::Integer)
    T = typeof(samples)
    for (name, FT) in zip(fieldnames(T), fieldtypes(T))
        if FT <: AbstractVector
            sizehint!(getfield(samples, name), N)
        end
    end
    return samples
end

# Fallback for rejection-based samplers: a request for any chain type without a
# more specific method is redirected to `RejectionChains`. That request is
# presumably served by the `Type{<:AbstractRejectionChains}` method.
# NOTE(review): `NamedTuple` also matches this method. The
# `AbstractRejectionChains` method requests a `NamedTuple` bundle, so each
# concrete sampler presumably defines its own `Type{NamedTuple}` method to avoid
# infinite recursion — confirm.
AbstractMCMC.bundle_samples(samples, model::AbstractModel, sampler::RejectionBasedSampler,
                            state, chain_type::Type; kwargs...) =
    AbstractMCMC.bundle_samples(samples, model, sampler, state, RejectionChains; kwargs...)

# Hook supplying extra per-chain information as a `Dict{Symbol,Vector}`.
# Each entry is copied into the bundled chain's `info` by the
# `AbstractRejectionChains` bundling method. This default contributes nothing.
# NOTE(review): presumably specialised by concrete samplers — confirm.
_chain_info(samples, model::AbstractModel, sampler::RejectionBasedSampler) =
    Dict{Symbol,Vector}()

# Bundle the collected samples into a `RejectionChains` in three steps:
# 1. Request the `NamedTuple` bundle.
# 2. Build the chain from its `states`, `logprobs` and `rejections` fields.
# 3. Attach every entry from `_chain_info`, each wrapped in a one-element vector.
function AbstractMCMC.bundle_samples(samples, model::AbstractModel, sampler::RejectionBasedSampler, state, chain_type::Type{<:AbstractRejectionChains}; kwargs...)
    nt = AbstractMCMC.bundle_samples(samples, model, sampler, state, NamedTuple; kwargs...)
    chains = RejectionChains(nt.states, nt.logprobs, nt.rejections)
    for (key, value) in pairs(_chain_info(samples, model, sampler))
        # NOTE(review): wrapping presumably stores one entry per chain — confirm.
        chains.info[key] = [value]
    end
    return chains
end



# Concrete sampler implementations, included after the shared methods above.
# NOTE(review): presumably each file defines a `RejectionBasedSampler` subtype — confirm.
include("metropolis_hastings.jl")
include("christen_fox.jl")
include("lykkegard_scheichl.jl")