Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""
    RejectionChains{S,L,R}

Collection of sampling chains. `samples[i]` is chain `i`, stored as a vector of named
tuples `(state, logprob, reject)` whose field types are `S`, `L` and `R` respectively.
Different chains may have different lengths.
"""
struct RejectionChains{S,L,R} <: AbstractMCMC.AbstractChains
    samples::Vector{Vector{@NamedTuple{state::S, logprob::L, reject::R}}}
end
"""
    RejectionChains(s::Vector{S}, l::Vector{L}, r::Vector{R})

Build a `RejectionChains{S,L,R}` holding a single chain, where entry `k` of `s`, `l` and
`r` gives the state, log-probability and rejection record of sample `k`.

Throws `DimensionMismatch` if the three vectors do not have the same length.
"""
function RejectionChains(s::Vector{S}, l::Vector{L}, r::Vector{R}) where {S,L,R}
    # `zip` would silently stop at the shortest input and drop samples, so check first.
    if !(length(s) == length(l) == length(r))
        throw(DimensionMismatch(
            "states, logprobs and rejects must have equal lengths; got " *
            "$(length(s)), $(length(l)) and $(length(r))"))
    end
    Sample = @NamedTuple{state::S, logprob::L, reject::R}
    # Typed comprehension: fixes the element type and avoids splatting the whole chain.
    samples = Sample[(state = x, logprob = y, reject = z) for (x, y, z) in zip(s, l, r)]
    return RejectionChains{S,L,R}([samples])
end
# Collection interface. The methods are qualified with `Base.` so that they extend Base's
# generic functions rather than creating module-local functions that shadow them.

# Number of chains stored.
Base.length(chains::RejectionChains) = length(chains.samples)

# The element type reported for the container is the state type `S`. It is defined on the
# type, so `eltype(typeof(chains))` works too; Base's `eltype(x)` forwards to the type.
Base.eltype(::Type{<:RejectionChains{S}}) where {S} = S

# Samples of chain `id`: a vector of `(state, logprob, reject)` named tuples.
Base.getindex(chains::RejectionChains, id::Integer) = chains.samples[id]

# Subset of chains (a range, an index vector or a Bool mask) as a new RejectionChains.
Base.getindex(chains::RejectionChains, ids::AbstractVector{<:Integer}) =
    RejectionChains(chains.samples[ids])
# Concatenate several RejectionChains into one, keeping every chain of each argument in
# argument order. Only the outer vector is new; the per-chain sample vectors are shared.
function AbstractMCMC.chainscat(c::RejectionChains...)
    pieces = map(chain -> chain.samples, c)
    return RejectionChains(vcat(pieces...))
end
"""
    states(chains::RejectionChains{S,L,R})

Return one `Vector{S}` per chain in which every sample's state is repeated
`1 + sum(reject)` times, so rejected proposals re-count the current state.
"""
function states(chains::RejectionChains{S,L,R}) where {S,L,R}
    expanded = Vector{S}[]
    sizehint!(expanded, length(chains.samples))
    for chain in chains.samples
        v = S[]
        for sample in chain
            # Build incrementally instead of splatting one array per sample into `vcat`,
            # which is costly for long chains. `fill` keeps the original error for a
            # negative repeat count.
            append!(v, fill(sample.state, 1 + sum(sample.reject)))
        end
        push!(expanded, v)
    end
    return expanded
end
"""
    info(chains::RejectionChains)

Summarise every chain. Returns a `Dict` with two entries:
- `:chain_length`: per chain, the sum over samples of `1 + sum(reject)`.
- `:rejection_rate`: per chain, the total of `sum(reject)` over samples divided by
  `:chain_length`.
"""
function info(chains::RejectionChains)
    # Per-sample rejection counts. Both statistics are derived from these, so a `reject`
    # field that holds a collection is reduced per sample before being summed per chain.
    counts = [sum.(getfield.(x, :reject)) for x in chains.samples]
    data = Dict()
    data[:chain_length] = [sum(1 .+ c) for c in counts]
    data[:rejection_rate] = sum.(counts) ./ data[:chain_length]
    return data
end
"""
    convert(::Type{<:DataFrame}, chains::RejectionChains)

Return a wide `DataFrame` with the columns `state_i`, `logprob_i` and `reject_i` for each
chain `i`, in that order. Chains may differ in length, so the columns of any chain shorter
than the longest one are padded at the end with `missing`.
"""
function Base.convert(::Type{<:DataFrame}, chains::RejectionChains)
    df = DataFrame()
    nrows = isempty(chains.samples) ? 0 : maximum(length, chains.samples)
    for (i, chain) in enumerate(chains.samples)
        for v in (:state, :logprob, :reject)
            col = getfield.(chain, v)
            # All DataFrame columns must have the same length.
            if length(col) < nrows
                col = vcat(col, fill(missing, nrows - length(col)))
            end
            # `df[!, name] = col` adds a column; the old `df[name] = col` form was removed.
            df[!, "$(v)_$i"] = col
        end
    end
    return df
end