Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
"""
    SimpleChains{X, L}

Sample container holding states, their log-probabilities, and per-chain metadata.

# Fields
- `states::Matrix{X}`: sampled states. Presumably one row per draw and one column per
  chain: the vector constructor builds an `N×1` matrix and `chainscat` joins chains
  with `hcat`.
- `logprobs::Matrix{L}`: numeric values laid out like `states` (presumably the log
  density of each state — confirm).
- `info::Vector{<:NamedTuple}`: one metadata `NamedTuple` per chain.

NOTE(review): `Vector{<:NamedTuple}` is not a concrete field type, so accesses to
`info` are not type-stable; parametrizing it would change the struct's type
parameters, so it is left as is.
"""
struct SimpleChains{X, L <: Number} <: RejectionBasedChains
    states::Matrix{X}
    logprobs::Matrix{L}
    info::Vector{<:NamedTuple}
end
"Return `X`, the element type of the `states` matrix of a `SimpleChains{X,L}`."
Base.eltype(::SimpleChains{X}) where {X} = X
"""
    size(chains::SimpleChains, dims...)

Forward to `size` of the underlying `states` matrix, with any extra arguments
(e.g. a dimension index) passed through unchanged.
"""
function Base.size(chains::SimpleChains, dims...)
    states = chains.states
    return size(states, dims...)
end
"""
    get_info(c::SimpleChains)

Turn the vector of per-chain metadata `NamedTuple`s in `c.info` into a single
`NamedTuple` of vectors. The result has one entry per field name of `eltype(c.info)`,
and each entry holds that field gathered from every element of `c.info`.

NOTE(review): `fieldnames` needs a concrete `NamedTuple` element type; an abstract
`eltype(c.info)` (e.g. after mixing metadata with different fields) presumably throws.
"""
function get_info(c::SimpleChains)
    names = fieldnames(eltype(c.info))
    # One column per field name, in field order.
    columns = map(name -> getfield.(c.info, name), names)
    return NamedTuple{names}(columns)
end
"""
    SimpleChains(states::Vector, logprobs::Vector, rejections::Vector; info...)

Build a single-chain `SimpleChains` from per-step `states` and `logprobs`.

`rejections[i]` is reduced with `sum`, so it may be a count or a collection of
rejection indicators. State `i` (and its log-probability) is then repeated
`sum(rejections[i]) + 1` times. Presumably this stands for the iterations where the
sampler stayed at that state — confirm. The expanded values are stored as `N×1`
matrices, where `N` is the total number of repetitions.

The chain's `info` is a one-element vector holding a `NamedTuple` with
`rejections` (the total rejection count over all states) followed by any extra
keyword arguments.

# Throws
- `DimensionMismatch` if `states`, `logprobs` and `rejections` differ in length.
"""
function SimpleChains(states::Vector, logprobs::Vector, rejections::Vector; info...)
    if !(length(states) == length(logprobs) == length(rejections))
        throw(DimensionMismatch(
            "states, logprobs and rejections must have equal lengths, got " *
            "$(length(states)), $(length(logprobs)) and $(length(rejections))"))
    end
    # Per-state rejection count; `sum` handles both a plain number and a
    # collection of rejection indicators.
    counts = sum.(rejections)
    repetitions = counts .+ 1
    N = sum(repetitions)
    # `reduce(vcat, …)` avoids splatting one argument per state into `vcat`.
    expanded_states = reshape(reduce(vcat, fill.(states, repetitions)), N, 1)
    expanded_logprobs = reshape(reduce(vcat, fill.(logprobs, repetitions)), N, 1)
    # Total the per-state counts: `sum(rejections)` would instead add the
    # per-state collections element-wise when they are not plain numbers.
    chain_info = [(; rejections = sum(counts), info...)]
    return SimpleChains(expanded_states, expanded_logprobs, chain_info)
end
"""
    AbstractMCMC.chainscat(chains::SimpleChains...)

Concatenate `chains` side by side:
- the `states` and `logprobs` matrices are joined column-wise with `hcat`;
- the per-chain `info` vectors are joined with `vcat`.

The result is a `SimpleChains`, the same chain type as the inputs.

All inputs must have the same number of rows; `hcat` throws a `DimensionMismatch`
otherwise.
"""
function AbstractMCMC.chainscat(chains::SimpleChains...)
    states = hcat(map(c -> c.states, chains)...)
    logprobs = hcat(map(c -> c.logprobs, chains)...)
    info = vcat(map(c -> c.info, chains)...)
    # Concatenating SimpleChains must yield a SimpleChains, not another chain type.
    return SimpleChains(states, logprobs, info)
end
#=
function Base.show(io::IO, c::SimpleChains)
println(io, "Chain ", size(c), " {", eltype(c), "}")
nt = eltype(c.info)
for param in nt.parameters[1]
x = getfield.(c.info, param)
if size(c, 2) == 1
println(io, " ", param, " ", x)
else
println(io, " ", param, " ", mean(x), " ± ", std(x) )
end
end
end
=#