"""
    RejectionChains{S,L,R} <: AbstractMCMC.AbstractChains

Collection of MCMC chains whose samples carry a `state::S`, a `logprob::L` and a
`reject::R` entry. Each chain may have a different length.

# Fields
- `samples`: one vector of `(state, logprob, reject)` named tuples per chain.
- `info`: per-chain summary vectors keyed by name (one entry per chain).
"""
struct RejectionChains{S,L,R} <: AbstractMCMC.AbstractChains
    samples::Vector{Vector{@NamedTuple{state::S, logprob::L, reject::R}}}
    info::Dict{Symbol, Vector}
end

"""
    RejectionChains(samples)

Build chains from `samples` and calculate the rejection info.

Two entries are stored in `info`, each with one element per chain:
- `:chain_length`: `sum(1 .+ sum.(reject))` over the chain's samples.
- `:rejection_rate`: the summed `reject` entries divided by `:chain_length`.
"""
function RejectionChains(
    samples::Vector{Vector{@NamedTuple{state::S, logprob::L, reject::R}}}
) where {S,L,R}
    rejections = [getfield.(s, :reject) for s in samples]
    data = Dict{Symbol, Vector}()
    data[:chain_length] = [sum(1 .+ sum.(r)) for r in rejections]
    data[:rejection_rate] = sum.(rejections) ./ data[:chain_length]
    return RejectionChains(samples, data)
end

"""
    RejectionChains(s, l, r)

Construct a single chain from parallel vectors of states `s`, log-probabilities `l`
and rejection entries `r`.

The vectors are zipped, so the chain is as long as the shortest input.
"""
function RejectionChains(s::Vector{S}, l::Vector{L}, r::Vector{R}) where {S,L,R}
    NT = @NamedTuple{state::S, logprob::L, reject::R}
    # A typed comprehension pins the element type, so empty or abstractly-typed
    # inputs still dispatch to the constructor that computes the rejection info.
    samples = NT[NT((si, li, ri)) for (si, li, ri) in zip(s, l, r)]
    return RejectionChains([samples])
end

# NOTE(review): `length`, `eltype`, `getindex` and `convert` are extended unqualified;
# this presumably relies on an `import Base: ...` elsewhere in the file -- confirm.

# Number of chains.
length(chains::RejectionChains) = length(chains.samples)

# Element type is the state type `S`.
eltype(chains::RejectionChains{S,L,R}) where {S,L,R} = S

# A single index returns the raw sample vector of that chain.
getindex(chains::RejectionChains, id::Integer) = chains.samples[id]

# A range of indices returns a new `RejectionChains` with recomputed info.
getindex(chains::RejectionChains, ids::OrdinalRange) = RejectionChains(chains.samples[ids])

"""
    AbstractMCMC.chainscat(c::RejectionChains...)

Concatenate several `RejectionChains` into one.

Samples are concatenated chain-wise. The `info` dictionaries are merged key by key;
when an input lacks a key, its slots are filled with `missing` (one per chain).
"""
function AbstractMCMC.chainscat(c::RejectionChains...)
    infos = getfield.(c, :info)
    lens = length.(c)
    dict = Dict{Symbol, Vector}()
    for k in union(keys.(infos)...)
        vals = [
            haskey(info, k) ? info[k] : fill(missing, len) for (info, len) in zip(infos, lens)
        ]
        dict[k] = vcat(vals...)
    end
    return RejectionChains(vcat(getfield.(c, :samples)...), dict)
end

# Chains count as multilevel when the rejection entry type is a `Vector`.
is_multilevel(chains::RejectionChains{S,L,R}) where {S,L,R} = R <: Vector

"""
    states(chains::RejectionChains)

Return one `Vector{S}` of states per chain, repeating rejected values: each sample's
state appears `1 + sum(reject)` times.
"""
function states(chains::RejectionChains{S,L,R}) where {S,L,R}
    V = Vector{S}[]
    sizehint!(V, length(chains))
    for chain in chains.samples
        v = S[]
        for sample in chain
            # Appending per sample avoids splatting a long vector into `vcat`.
            append!(v, fill(sample.state, 1 + sum(sample.reject)))
        end
        push!(V, v)
    end
    return V  # optionally: vcat(V...)
end

# Get info on chain
info(chains::RejectionChains) = chains.info

"""
    convert(::Type{<:DataFrame}, chains::RejectionChains)

Convert to a table with columns `state_i`, `logprob_i` and `reject_i` for every
chain `i`.

Chains may differ in length, so columns of shorter chains are padded with `missing`
up to the length of the longest chain.
"""
function convert(::Type{<:DataFrame}, chains::RejectionChains)
    df = DataFrame()
    nrows = isempty(chains.samples) ? 0 : maximum(length, chains.samples)
    for (i, chain) in enumerate(chains.samples)
        for v in (:state, :logprob, :reject)
            col = getfield.(chain, v)
            if length(col) < nrows
                col = vcat(col, fill(missing, nrows - length(col)))
            end
            df[!, "$(v)_$i"] = col
        end
    end
    return df
end