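# Assumes the enclosing module makes `AbstractMCMC`, `Statistics.mean` and
# `Statistics.std` available and defines the abstract type `RejectionBasedChains`.

struct MultiChains{X, L <: Number} <: RejectionBasedChains
    states   :: Matrix{X}
    logprobs :: Matrix{L}
    levels   :: Matrix{<:Integer}
    info     :: Vector{<:NamedTuple}
end

Base.eltype(c::MultiChains{X,L}) where {X,L} = X
Base.length(c::MultiChains) = length(c.states)

# Build a single-column chain from accepted states and their rejection counts.
function MultiChains(states::Vector, logprobs::Vector, levels::Vector,
                     rejections::Vector; info...)
    repetitions = sum.(rejections) .+ 1
    # Total chain length. This must be the sum of the repetitions (not of the
    # rejections), otherwise the reshapes below get the wrong size.
    N = sum(repetitions)

    # `reduce(vcat, ...)` flattens the per-state blocks into one vector;
    # a bare `vcat(x)` on a single vector of vectors would return it unchanged.
    expand(v) = reshape(reduce(vcat, fill.(v, repetitions)), N, 1)

    MultiChains(
        expand(states),
        expand(logprobs),
        expand(levels),
        [(; rejection_rate = sum(rejections) ./ N, info...)]
    )
end

# Concatenate chains column-wise (one column per chain).
# `DefaultChains` is not defined in this file; it is assumed to be a chain type
# with the same field layout as `MultiChains`.
function AbstractMCMC.chainscat(chains::MultiChains...)
    DefaultChains(
        hcat((c.states for c in chains)...),
        hcat((c.logprobs for c in chains)...),
        hcat((c.levels for c in chains)...),
        vcat((c.info for c in chains)...)
    )
end

# `SimpleChains` is not defined in this file; it is assumed to provide
# `size` and an `info` field like the one above.
function Base.show(io::IO, c::SimpleChains)
    println(io, "Chain ", size(c), " {", eltype(c), "}")
    nt = eltype(c.info)
    # The first type parameter of a NamedTuple type is its tuple of field names.
    for param in nt.parameters[1]
        x = getfield.(c.info, param)
        if size(c, 2) == 1
            # Single chain: print the raw value(s).
            println(io, " ", param, " ", x)
        else
            # Several chains: summarise across chains.
            println(io, " ", param, " ", mean(x), " ± ", std(x))
        end
    end
end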
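
# Illustration only (a sketch using Base functions, not part of the package API):
# how per-state rejection counts expand accepted states into a full chain.
# The variable names and values below are hypothetical.
#
#     states      = [:a, :b, :c]
#     rejections  = [0, 2, 1]
#     repetitions = rejections .+ 1                # [1, 3, 2]
#     reduce(vcat, fill.(states, repetitions))     # [:a, :b, :b, :b, :c, :c]
#
# The expanded chain has sum(repetitions) == 6 entries.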