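import AbstractMCMC

# `RejectionBasedChains` is not defined in this file; it is assumed to come
# from elsewhere in the package.

# Chains stored as N × C matrices (N samples, C chains). A rejected proposal
# is encoded by repeating the current state, so every row is a sample.
struct SimpleChains{X, L <: Number} <: RejectionBasedChains
    states   :: Matrix{X}             # sampled states
    logprobs :: Matrix{L}             # log-probability of each stored state
    info     :: Vector{<:NamedTuple}  # one record of sampler info per chain
end

Base.eltype(c::SimpleChains{X, L}) where {X, L} = X
Base.size(c::SimpleChains, args...) = size(c.states, args...)

# Turn the per-chain vector of NamedTuples into a NamedTuple of vectors,
# e.g. [(rejections = 3,), (rejections = 4,)] -> (rejections = [3, 4],).
function get_info(c::SimpleChains)
    fields = fieldnames(eltype(c.info))
    (; (f => getfield.(c.info, f) for f in fields)...)
end

# Build a single chain from the accepted states, their log-probabilities and
# the rejections recorded after each acceptance. Each accepted state is
# repeated once per rejected proposal plus once for the acceptance itself.
function SimpleChains(states::Vector, logprobs::Vector, rejections::Vector; info...)
    repetitions = sum.(rejections) .+ 1
    N = sum(repetitions)
    states   = reshape(vcat(fill.(states, repetitions)...), N, 1)
    logprobs = reshape(vcat(fill.(logprobs, repetitions)...), N, 1)
    info = [(; rejections = sum(rejections), info...)]
    SimpleChains(states, logprobs, info)
end

# Concatenate chains column-wise; all chains must have the same length.
function AbstractMCMC.chainscat(chains::SimpleChains...)
    SimpleChains(
        hcat((c.states for c in chains)...),
        hcat((c.logprobs for c in chains)...),
        vcat((c.info for c in chains)...),
    )
end

#=
# Disabled. Re-enabling it requires `using Statistics` for `mean` and `std`.
function Base.show(io::IO, c::SimpleChains)
    println(io, "Chain ", size(c), " {", eltype(c), "}")
    nt = eltype(c.info)
    for param in nt.parameters[1]
        x = getfield.(c.info, param)
        if size(c, 2) == 1
            println(io, " ", param, " ", x)
        else
            println(io, " ", param, " ", mean(x), " ± ", std(x))
        end
    end
end
=#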
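
#=
# Usage sketch (illustrative only, not part of the original code).
# The run-length expansion done by the `SimpleChains` constructor is copied
# into a hypothetical standalone helper, `expand_rejections`, so it can run
# without `RejectionBasedChains` or AbstractMCMC. It assumes scalar
# rejection counts. It also shows the column-wise concatenation that
# `chainscat` performs and the NamedTuple reshaping that `get_info` performs.

# Hypothetical helper: repeat each accepted state (rejections + 1) times.
function expand_rejections(states::Vector, rejections::Vector)
    repetitions = sum.(rejections) .+ 1            # copies of each state
    N = sum(repetitions)                            # total chain length
    expanded = vcat(fill.(states, repetitions)...)  # run-length expansion
    return reshape(expanded, N, 1)                  # N × 1 matrix, one chain
end

chain1 = expand_rejections([0.1, 0.5, 0.9], [2, 0, 1])  # 3 + 1 + 2 = 6 rows
chain2 = expand_rejections([0.2, 0.4], [1, 3])          # 2 + 4 = 6 rows
# chain1[:, 1] == [0.1, 0.1, 0.1, 0.5, 0.9, 0.9]
# chain2[:, 1] == [0.2, 0.2, 0.4, 0.4, 0.4, 0.4]

# Column-wise concatenation needs chains of equal length.
combined = hcat(chain1, chain2)                          # 6 × 2 matrix

# Per-chain info records reshaped into a NamedTuple of vectors.
infos = [(rejections = 3,), (rejections = 4,)]
fields = fieldnames(eltype(infos))
info_summary = (; (f => getfield.(infos, f) for f in fields)...)
# info_summary == (rejections = [3, 4],)
=#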