# Collection of one or more sample chains from a rejection-based sampler.
# Each chain may have a different length. Every sample stores
#   - `state`:   the sampled state (element type `S`),
#   - `logprob`: its log-probability; a scalar, or a vector/tuple/named tuple when the
#                chain is multilevel (see `is_multilevel`),
#   - `reject`:  rejection record; `1 + sum(reject)` is how often the state is
#                repeated in the expanded chain returned by `states`.
# `info` maps a key to a vector holding one entry per chain (it is indexed by chain
# id in `getindex`).
struct RejectionChains{S,L,R} <: AbstractRejectionChains
    samples::Vector{Vector{@NamedTuple{state::S, logprob::L, reject::R}}}
    info::Dict{Symbol, Vector}
end

# Construct from per-chain sample vectors with an empty `info` dictionary.
function RejectionChains(
    samples::Vector{Vector{@NamedTuple{state::S, logprob::L, reject::R}}},
) where {S,L,R}
    return RejectionChains(samples, Dict{Symbol, Vector}())
end

# Construct a single chain from equally long vectors of states, log-probabilities
# and rejection records.
# Throws `DimensionMismatch` if the three vectors differ in length (previously the
# extra entries were silently dropped by `zip`).
function RejectionChains(s::Vector{S}, l::Vector{L}, r::Vector{R}) where {S,L,R}
    if !(length(s) == length(l) == length(r))
        throw(DimensionMismatch(
            "states, logprobs and rejects must have equal length; " *
            "got $(length(s)), $(length(l)) and $(length(r))"))
    end
    NT = @NamedTuple{state::S, logprob::L, reject::R}
    # Typed comprehension: the element type is `NT` even for empty input, so the
    # result always matches the samples-only constructor's signature.
    samples = NT[NT((x, y, z)) for (x, y, z) in zip(s, l, r)]
    return RejectionChains([samples])
end

# Concatenation method: join several `RejectionChains` into one, keeping chain order.
# An `info` key that is absent from some inputs is padded with `missing`, one entry
# per chain of each input lacking the key, so every `info` vector stays aligned
# with the chains.
function AbstractMCMC.chainscat(c::RejectionChains...)
    infos = map(x -> x.info, c)
    # `info` vectors hold one entry per chain, so pad by the number of chains
    # (not by `length`, which counts stored samples).
    counts = map(nchains, c)
    dict = Dict{Symbol, Vector}()
    for k in union(map(keys, infos)...)
        vals = [haskey(d, k) ? d[k] : fill(missing, n) for (d, n) in zip(infos, counts)]
        dict[k] = vcat(vals...)
    end
    samples = vcat(map(x -> x.samples, c)...)
    return RejectionChains(samples, dict)
end

# Properties

# Number of chains.
nchains(chains::RejectionChains) = length(chains.samples)

# Total number of stored (unique) samples over all chains, i.e. the storage cost;
# repeated states from rejections are not counted.
length(chains::RejectionChains) = sum(length.(chains.samples))

# The element type is the state type `S`.
eltype(chains::RejectionChains{S,L,R}) where {S,L,R} = S
eltype(::Type{<:RejectionChains{S}}) where {S} = S

# Get the per-chain metadata dictionary.
info(chains::RejectionChains) = chains.info

# Select a single chain; the result is a `RejectionChains` holding that one chain
# together with the matching entry of every `info` vector.
function getindex(chains::RejectionChains, id::Integer)
    samples = [chains.samples[id]]
    # explicitly typed so it matches the struct's `info::Dict{Symbol, Vector}` field
    subinfo = Dict{Symbol, Vector}(k => [v[id]] for (k, v) in chains.info)
    return RejectionChains(samples, subinfo)
end

# Select a subset of chains (any integer/boolean index vector, e.g. a range) along
# with the matching entries of every `info` vector.
function getindex(chains::RejectionChains, ids::AbstractVector{<:Integer})
    samples = chains.samples[ids]
    subinfo = Dict{Symbol, Vector}(k => v[ids] for (k, v) in chains.info)
    return RejectionChains(samples, subinfo)
end

# Multilevel logprobs or rejection

# Whether log-probabilities are stored per level (vector, tuple or named tuple).
function is_multilevel(chains::RejectionChains{S,L,R}) where {S,L,R}
    return L <: Union{AbstractVector, Tuple, NamedTuple}
end

# Number of levels for single-level chains (scalar `logprob`).
# NOTE(review): this method returns one `1` per chain, while the multilevel method
# below returns, per chain, a vector with one entry per sample. Confirm which shape
# callers expect before relying on both.
levels(chains::RejectionChains{S,L,R}) where {S, L<:Number, R} = ones(Int, nchains(chains))

# Number of levels for multilevel chains: for each chain, the length of every
# sample's `logprob`.
function levels(
    chains::RejectionChains{S,L,R},
) where {S, L<:Union{AbstractVector, Tuple, NamedTuple}, R}
    return [[length(x.logprob) for x in chain] for chain in chains.samples]
end

# Number of repetitions of every stored sample, per chain: `1 + sum(reject)`.
function repetitions(chains::RejectionChains)
    return [[1 + sum(x.reject) for x in chain] for chain in chains.samples]
end

# Expanded chains of states: each stored state is repeated as often as
# `repetitions` says (i.e. once more for every rejection). Returns one
# `Vector{S}` per chain; use `reduce(vcat, states(chains))` for a single vector.
function states(chains::RejectionChains{S,L,R}) where {S,L,R}
    reps = repetitions(chains)
    return Vector{S}[
        S[x.state for (x, n) in zip(chain, rep) for _ in 1:n]
        for (chain, rep) in zip(chains.samples, reps)
    ]
end