Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
"""
    RejectionBasedSampler

Abstract supertype shared by the samplers in this file.

Any `AbstractMCMC.AbstractSampler` subtyping it picks up the shared methods of
`AbstractMCMC.save!!`, `AbstractMCMC.samples` and `AbstractMCMC.bundle_samples`
defined alongside it. Those methods bundle results into `RejectionChains`.
"""
abstract type RejectionBasedSampler <: AbstractMCMC.AbstractSampler
end
# Ignore chain length by default: the trailing `::Integer` argument (the chain
# length) is dropped, and storing the sample is delegated to the length-agnostic
# `save!!` method with the same samples, sample, iteration, model and sampler.
AbstractMCMC.save!!(
    samples, sample, iterations::Integer, model::AbstractModel,
    sampler::RejectionBasedSampler, ::Integer; kwargs...
) = AbstractMCMC.save!!(samples, sample, iterations, model, sampler; kwargs...)
# Size hint sample container chain length.
#
# Builds the container with the length-agnostic `AbstractMCMC.samples` method. It then
# calls `sizehint!(field, N)` on every field whose declared type is an `AbstractVector`,
# so the container can reserve room for the `N` expected samples.
#
# Fields are enumerated with the public reflection API (`fieldnames`/`fieldtypes`)
# instead of the internal `DataType` fields `parameters`/`types`. The old
# `T.parameters[1]` trick only yields field names for `NamedTuple`; this version also
# works when the container is a plain struct or a tuple.
#
# Returns the (same) container object.
function AbstractMCMC.samples(
    sample, model::AbstractModel, sampler::RejectionBasedSampler, N::Integer; kwargs...
)
    samples = AbstractMCMC.samples(sample, model, sampler; kwargs...)
    T = typeof(samples)
    for (name, FT) in zip(fieldnames(T), fieldtypes(T))
        # Only vector fields can usefully reserve capacity.
        if FT <: AbstractVector
            sizehint!(getfield(samples, name), N)
        end
    end
    return samples
end
# Catch-all for any requested chain type: rejection-based samplers bundle into
# `RejectionChains`; the `chain_type` that was passed in is not used.
# NOTE(review): this signature also matches `chain_type = NamedTuple`, which the
# `Type{<:AbstractRejectionChains}` method requests. Presumably a more specific
# `NamedTuple` method exists elsewhere (e.g. in the included sampler files);
# otherwise the two methods would recurse into each other. Confirm.
AbstractMCMC.bundle_samples(
    samples, m::AbstractModel, s::RejectionBasedSampler, state, chain_type::Type; kwargs...
) = AbstractMCMC.bundle_samples(samples, m, s, state, RejectionChains; kwargs...)
# Extra metadata copied into the bundled chain's `info` by `bundle_samples`.
# The default for rejection-based samplers is an empty `Dict{Symbol, Vector}`;
# presumably concrete samplers specialize this (not visible here).
_chain_info(samples, model::AbstractModel, sampler::RejectionBasedSampler) =
    Dict{Symbol, Vector}()
# Bundle `samples` into a `RejectionChains`.
#
# The samples are first bundled as a `NamedTuple`, whose `states`, `logprobs` and
# `rejections` fields feed the `RejectionChains` constructor. Every entry returned by
# `_chain_info` is then stored in the chain's `info`, wrapped in a one-element
# vector (presumably one slot per chain, TODO confirm).
# The concrete type built is always `RejectionChains`, whichever subtype of
# `AbstractRejectionChains` was requested.
function AbstractMCMC.bundle_samples(
    samples,
    m::AbstractModel,
    s::RejectionBasedSampler,
    state,
    chain_type::Type{<:AbstractRejectionChains};
    kwargs...
)
    nt = AbstractMCMC.bundle_samples(samples, m, s, state, NamedTuple; kwargs...)
    chain = RejectionChains(nt.states, nt.logprobs, nt.rejections)
    for (key, value) in pairs(_chain_info(samples, m, s))
        chain.info[key] = [value]
    end
    return chain
end
# Sampler implementations (names suggest Metropolis-Hastings, Christen-Fox and
# Lykkegaard-Scheichl variants). They are included after the shared
# `RejectionBasedSampler` definitions above.
include("metropolis_hastings.jl")
include("christen_fox.jl")
include("lykkegard_scheichl.jl")