Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
"""
    MultiChains{X,L<:Number} <: RejectionBasedChains

Matrix-backed chain container with state type `X` and log-probability type `L`.

The vector-based convenience constructor in this file produces single-column
matrices, and `AbstractMCMC.chainscat` joins the matrices of several chains
column-wise with `hcat`.

# Fields
- `states`: matrix of states of type `X`.
- `logprobs`: matrix of numbers of type `L` (presumably the log-probability of
  the corresponding state; confirm against the sampler).
- `levels`: matrix of integers (presumably the level each state belongs to;
  confirm against the sampler).
- `info`: vector of named tuples; `chainscat` concatenates these with `vcat`.

NOTE(review): `levels` and `info` are declared with non-concrete field types
(`Matrix{<:Integer}`, `Vector{<:NamedTuple}`); extra type parameters would make
field access type-stable if that ever matters.
"""
struct MultiChains{X,L<:Number} <: RejectionBasedChains
    states::Matrix{X}
    logprobs::Matrix{L}
    levels::Matrix{<:Integer}
    info::Vector{<:NamedTuple}
end
# The element type of a `MultiChains` is its state type `X`.
# The type-level method is the form Julia's iteration interface queries
# (`eltype(typeof(c))`); without it that call falls back to `Any`.
# The instance-level method is kept so `eltype(c)` behaves exactly as before.
Base.eltype(::Type{<:MultiChains{X}}) where {X} = X
Base.eltype(::MultiChains{X}) where {X} = X
# Number of stored states, counted over every entry of the `states` matrix
# (rows times columns).
function Base.length(c::MultiChains)
    return length(c.states)
end
"""
    MultiChains(states::Vector, logprobs::Vector, levels::Vector, rejections::Vector; info...)

Build a single-column `MultiChains` from run-length encoded samples.

Entry `i` of `states`, `logprobs` and `levels` is repeated
`sum(rejections[i]) + 1` times, the expanded sequences are concatenated, and
each is stored as an `N × 1` matrix, where `N` is the total number of expanded
samples.

# Arguments
- `states`, `logprobs`, `levels`: one entry per distinct sample.
- `rejections`: one entry per distinct sample; each entry is reduced with `sum`,
  so it may be a single count or a collection of counts (presumably one per
  level; confirm against the caller).

# Keywords
- `info...`: extra fields stored next to `rejection_rate` in the chain's info
  named tuple.

# Throws
- `DimensionMismatch` if the four vectors do not have the same length.
"""
function MultiChains(states::Vector, logprobs::Vector, levels::Vector, rejections::Vector; info...)
    n = length(states)
    if !(length(logprobs) == length(levels) == length(rejections) == n)
        throw(DimensionMismatch(
            "states, logprobs, levels and rejections must have the same length"
        ))
    end

    # Every distinct sample appears once, plus once per rejection attributed to it.
    repetitions = sum.(rejections) .+ 1
    # Length of the expanded chain. This is NOT `sum(rejections)`: each distinct
    # sample also contributes one occurrence of its own.
    N = sum(repetitions)

    # `reduce(vcat, ...)` actually concatenates the per-sample runs; a bare
    # `vcat(v)` on a vector of vectors returns it unchanged.
    expand(v) = reshape(reduce(vcat, fill.(v, repetitions)), N, 1)

    return MultiChains(
        expand(states),
        expand(logprobs),
        expand(levels),
        # Rejected proposals relative to the expanded chain length; `./` keeps
        # this working when `sum(rejections)` is itself a vector of counts.
        [(; rejection_rate = sum(rejections) ./ N, info...)],
    )
end
# Combine several `MultiChains` into one container: the `states`, `logprobs`
# and `levels` matrices are joined column-wise, and the `info` vectors are
# concatenated in argument order.
function AbstractMCMC.chainscat(chains::MultiChains...)
    all_states = hcat(map(ch -> ch.states, chains)...)
    all_logprobs = hcat(map(ch -> ch.logprobs, chains)...)
    all_levels = hcat(map(ch -> ch.levels, chains)...)
    all_info = vcat(map(ch -> ch.info, chains)...)
    # NOTE(review): the result is wrapped in `DefaultChains`, not `MultiChains`;
    # `DefaultChains` is defined outside this view, so confirm this is intended.
    return DefaultChains(all_states, all_logprobs, all_levels, all_info)
end
# Print a short summary of a `SimpleChains`.
#
# The first line shows the chain's size and element type. It is followed by
# one line per field of the named tuples in `c.info`:
#   * if the chain has a single column (`size(c, 2) == 1`), the raw collected
#     values of that field are printed;
#   * otherwise the mean and standard deviation of the field across `c.info`
#     are printed.
#
# Field names are obtained with `fieldnames` rather than the internal
# `.parameters` of the type, so this also works when the element type of
# `c.info` is a named-tuple `UnionAll` with fixed names (e.g. entries whose
# value types differ).
function Base.show(io::IO, c::SimpleChains)
    println(io, "Chain ", size(c), " {", eltype(c), "}")
    # The column count does not change while printing; decide the layout once.
    single_column = size(c, 2) == 1
    for param in fieldnames(eltype(c.info))
        x = getfield.(c.info, param)
        if single_column
            println(io, " ", param, " ", x)
        else
            println(io, " ", param, " ", mean(x), " ± ", std(x))
        end
    end
    return nothing
end