Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
"""
    RejectionChains{S,L,R} <: AbstractRejectionChains

Container for one or more sampler chains. Every stored sample is a named tuple
`(state, logprob, reject)` whose field types are `S`, `L` and `R`. The `reject` field is
summed by `repetitions` to decide how often the state is repeated when a chain is expanded
by `states`. Chains may have different lengths.

# Fields
- `samples`: one vector of samples per chain.
- `info`: metadata vectors keyed by `Symbol`; `getindex` indexes them by chain number, so
  presumably they hold one entry per chain.
"""
struct RejectionChains{S,L,R} <: AbstractRejectionChains
    samples::Vector{Vector{NamedTuple{(:state, :logprob, :reject),Tuple{S,L,R}}}}
    info::Dict{Symbol,Vector}
end
"""
    RejectionChains(samples)

Wrap already-assembled per-chain sample vectors, starting with an empty `info` dictionary.
"""
RejectionChains(
    samples::Vector{Vector{NamedTuple{(:state, :logprob, :reject),Tuple{S,L,R}}}},
) where {S,L,R} = RejectionChains(samples, Dict{Symbol,Vector}())
"""
    RejectionChains(s::Vector{S}, l::Vector{L}, r::Vector{R})

Build a `RejectionChains` holding a single chain from the parallel vectors of states `s`,
log-probabilities `l` and rejection records `r`. The `info` dictionary starts empty.

Throws `DimensionMismatch` if the three vectors differ in length (previously `zip`
silently truncated to the shortest one).
"""
function RejectionChains(s::Vector{S}, l::Vector{L}, r::Vector{R}) where {S,L,R}
    if !(length(l) == length(s) && length(r) == length(s))
        throw(DimensionMismatch(
            "state, logprob and reject vectors must have equal lengths, " *
            "got $(length(s)), $(length(l)) and $(length(r))"))
    end
    # Fix the element type explicitly: this keeps the result dispatchable to the typed
    # constructor even for empty input (where splatting gave `Any[]`) or abstract S/L/R.
    T = NamedTuple{(:state, :logprob, :reject),Tuple{S,L,R}}
    samples = T[T((x, y, z)) for (x, y, z) in zip(s, l, r)]
    return RejectionChains([samples])
end
# concatenation method
"""
    AbstractMCMC.chainscat(c::RejectionChains...)

Concatenate several `RejectionChains`, appending their chains in argument order.

`info` entries are merged by key. If a key is absent from one input, that input
contributes `missing` once per chain, so every merged info vector stays aligned with the
chain index of the concatenated result.
"""
function AbstractMCMC.chainscat(c::RejectionChains...)
    infos = getfield.(c, :info)
    # Info vectors hold one entry per chain (`getindex` indexes them by chain id), so pad
    # with the number of chains -- not `length`, which counts all stored samples.
    counts = nchains.(c)
    dict = Dict{Symbol,Vector}()
    for k in union(keys.(infos)...)
        vals = [
            haskey(infos[i], k) ? infos[i][k] : fill(missing, counts[i])
            for i in eachindex(c)
        ]
        dict[k] = vcat(vals...)
    end
    return RejectionChains(vcat(getfield.(c, :samples)...), dict)
end
# Properties
"Number of chains stored in `chains`."
nchains(chains::RejectionChains) = length(chains.samples)

"Total number of stored (unique) samples over all chains, i.e. the storage cost."
length(chains::RejectionChains) = mapreduce(length, +, chains.samples; init = 0)

"Type of the stored states (the `S` parameter)."
eltype(::RejectionChains{S}) where {S} = S
# Get info on chain
"""
    info(chains::RejectionChains)

Return the metadata dictionary stored in `chains` (the stored object itself, not a copy).
"""
function info(chains::RejectionChains)
    return chains.info
end
# Get subset from list of chains samples
"""
    getindex(chains::RejectionChains, id::Integer)

Return a new `RejectionChains` containing only chain `id`. The result's `info` maps every
key to a one-element vector holding the `id`-th entry of the original info vector.
"""
function getindex(chains::RejectionChains, id::Integer)
    samples = [chains.samples[id]]
    # The struct constructor requires exactly `Dict{Symbol,Vector}`; an untyped `Dict(...)`
    # would infer e.g. `Dict{Symbol,Vector{Int}}` (or `Dict{Any,Any}` when info is empty)
    # and fail with a MethodError.
    info = Dict{Symbol,Vector}(k => [v[id]] for (k, v) in chains.info)
    return RejectionChains(samples, info)
end
"""
    getindex(chains::RejectionChains, ids::OrdinalRange)

Return a new `RejectionChains` containing the chains selected by `ids`, with every `info`
vector sliced by the same range.
"""
function getindex(chains::RejectionChains, ids::OrdinalRange)
    samples = chains.samples[ids]
    # Typed Dict: the constructor only accepts `Dict{Symbol,Vector}`.
    info = Dict{Symbol,Vector}(k => v[ids] for (k, v) in chains.info)
    return RejectionChains(samples, info)
end
# Multilevel logprobs or rejection
"""
    is_multilevel(chains::RejectionChains)

Return `true` when the `logprob` type `L` is a vector, tuple or named tuple, i.e. each
sample stores a collection of log-probabilities rather than a single number.
"""
is_multilevel(::RejectionChains{S,L,R}) where {S,L,R} =
    L <: Union{AbstractVector, <:Tuple, <:NamedTuple}

"""
    levels(chains::RejectionChains)

For scalar (`Number`) log-probabilities: a vector of `1`s, one per chain.
For collection-valued log-probabilities: one vector per chain holding `length(logprob)`
for every stored sample.

NOTE(review): the two methods return differently shaped results (one value per chain vs.
one vector per chain) -- confirm callers expect this.
"""
levels(chains::RejectionChains{S,L,R}) where {S, L <: Number, R} =
    fill(1, nchains(chains))

function levels(
    chains::RejectionChains{S,L,R},
) where {S, L <: Union{AbstractVector, <:Tuple, <:NamedTuple}, R}
    return [length.(getfield.(chain, :logprob)) for chain in chains.samples]
end
# get number of repetitions
"""
    repetitions(chains::RejectionChains)

For each chain, return a vector with one entry per stored sample: `1 + sum(reject)`.
This is the number of times that sample's state appears when the chain is expanded.
"""
function repetitions(chains::RejectionChains)
    return [1 .+ sum.(getfield.(chain, :reject)) for chain in chains.samples]
end
# Chains of states, expanded so that rejected proposals repeat the current state
"""
    states(chains::RejectionChains{S}) -> Vector{Vector{S}}

Expand every chain into its full state sequence. Each stored state is repeated
`repetitions(chains)[i][j]` times (1 + the sum of its `reject` field). One vector is
returned per chain; use `reduce(vcat, states(chains))` for a single flat vector.
"""
function states(chains::RejectionChains{S,L,R}) where {S,L,R}
    reps = repetitions(chains)
    V = Vector{Vector{S}}(undef, nchains(chains))
    for (i, chain) in enumerate(chains.samples)
        r = reps[i]
        v = Vector{S}()
        sizehint!(v, foldl(+, r; init = 0))
        # Push each state `n` times rather than `vcat(fill.(s, r)...)`, which splats one
        # argument per sample into `vcat` and scales poorly for long chains.
        for (sample, n) in zip(chain, r)
            for _ in 1:n
                push!(v, sample.state)
            end
        end
        V[i] = v
    end
    return V
end