"""
    RejectionChains{S,L,R} <: AbstractMCMC.AbstractChains

Container for one or more sampler chains; each chain may have a different length.

`samples[i]` is chain `i`: a vector of named tuples `(state, logprob, reject)` whose
fields have types `S`, `L` and `R`. `reject` is used as a rejection count attached to
its state (`states` repeats each state `1 + sum(reject)` times), so it may be a scalar
or, for multilevel chains (see `is_multilevel`), a vector of counts.
"""
struct RejectionChains{S,L,R} <: AbstractMCMC.AbstractChains
    samples::Vector{Vector{@NamedTuple{state::S, logprob::L, reject::R}}}
end

"""
    RejectionChains(s::Vector{S}, l::Vector{L}, r::Vector{R})

Build a `RejectionChains` holding a single chain whose `k`-th sample is
`(state = s[k], logprob = l[k], reject = r[k])`.

Throws a `DimensionMismatch` if the three vectors do not have the same length.
"""
function RejectionChains(s::Vector{S}, l::Vector{L}, r::Vector{R}) where {S,L,R}
    # `zip` would silently truncate to the shortest input; refuse instead.
    if !(length(s) == length(l) == length(r))
        throw(DimensionMismatch(
            "state, logprob and reject vectors must have equal lengths, got " *
            "$(length(s)), $(length(l)) and $(length(r))"))
    end
    NT = @NamedTuple{state::S, logprob::L, reject::R}
    # Typed comprehension: the element type is exactly the struct's field element type,
    # even for empty input or abstract S/L/R.
    samples = NT[(state = x, logprob = y, reject = z) for (x, y, z) in zip(s, l, r)]
    return RejectionChains{S,L,R}([samples])
end

# Number of chains (not samples). Qualified with `Base.` so Base's generic function is
# extended rather than shadowed by a module-local `length`.
Base.length(chains::RejectionChains) = length(chains.samples)

# Element type is the state type `S`. Defined on the type so that both
# `eltype(chains)` and `eltype(typeof(chains))` return `S`.
Base.eltype(::Type{<:RejectionChains{S}}) where {S} = S

# `chains[i]` returns the vector of samples of chain `i`.
Base.getindex(chains::RejectionChains, id::Integer) = chains.samples[id]

# `chains[ids]` returns a new `RejectionChains` containing only the selected chains.
function Base.getindex(
    chains::RejectionChains{S,L,R}, ids::AbstractVector{<:Integer}
) where {S,L,R}
    return RejectionChains{S,L,R}(chains.samples[ids])
end

# Concatenate the chains of all arguments into a single `RejectionChains`.
function AbstractMCMC.chainscat(c::RejectionChains...)
    return RejectionChains(vcat(getfield.(c, :samples)...))
end

"""
    states(chains::RejectionChains{S}) -> Vector{Vector{S}}

Return one state vector per chain in which every stored state is repeated
`1 + sum(reject)` times, i.e. once for the state itself plus once per recorded
rejection. Use `reduce(vcat, states(chains))` to obtain a single flat vector.
"""
function states(chains::RejectionChains{S,L,R}) where {S,L,R}
    return Vector{S}[
        S[sample.state for sample in chain for _ in 1:(1 + sum(sample.reject))]
        for chain in chains.samples
    ]
end

"""
    info(chains::RejectionChains) -> Dict{Symbol,Any}

Summarize each chain. The returned dictionary has the keys:

- `:chain_length`: per chain, `sum(1 + sum(reject))` over its samples (the length of
  the corresponding vector returned by `states`).
- `:rejection_rate`: per chain, `sum` of its `reject` values divided by its
  `:chain_length`. For vector-valued `reject` the sum is elementwise, so each entry is
  a vector of rates.
"""
function info(chains::RejectionChains)
    rejections = [getfield.(chain, :reject) for chain in chains.samples]
    chain_length = [sum(1 .+ sum.(r)) for r in rejections]
    data = Dict{Symbol,Any}()
    data[:chain_length] = chain_length
    data[:rejection_rate] = sum.(rejections) ./ chain_length
    return data
end

"""
    convert(::Type{<:DataFrame}, chains::RejectionChains)

Convert to a wide `DataFrame` with columns `state_i`, `logprob_i` and `reject_i` for
each chain `i`. Because chains may differ in length, columns of shorter chains are
padded with `missing` up to the length of the longest chain.
"""
function Base.convert(::Type{<:DataFrame}, chains::RejectionChains)
    df = DataFrame()
    nrows = maximum(length, chains.samples; init = 0)
    for (i, chain) in enumerate(chains.samples)
        pad = nrows - length(chain)
        for v in (:state, :logprob, :reject)
            col = getfield.(chain, v)
            if pad > 0
                # Widens the column eltype to Union{Missing,T} only when padding is needed.
                col = vcat(col, fill(missing, pad))
            end
            df[!, "$(v)_$i"] = col
        end
    end
    return df
end

"""
    is_multilevel(chains::RejectionChains) -> Bool

Return `true` when the per-sample `reject` field is vector-valued (presumably one
rejection count per level), decided from the type parameter `R` alone.
"""
is_multilevel(::RejectionChains{S,L,R}) where {S,L,R} = R <: AbstractVector