Newer
Older
"""
    RejectionChains{S,L,R} <: AbstractMCMC.AbstractChains

Collection of sampler chains. Each chain is a vector of `(state, logprob, reject)`
named tuples with `state::S`, `logprob::L` and `reject::R`; chains may differ in length.

`info` holds named summary vectors (the `samples` constructor fills `:chain_length` and
`:rejection_rate` with one entry per chain).
"""
struct RejectionChains{S,L,R} <: AbstractMCMC.AbstractChains
    samples::Vector{Vector{@NamedTuple{state::S, logprob::L, reject::R}}}
    info::Dict{Symbol, Vector}
end
"""
    RejectionChains(samples)

Build a `RejectionChains` from per-chain sample vectors and derive its `info`:

- `:chain_length`: per chain, the sum over samples of `1 + sum(reject)`, i.e. each
  stored sample counts once plus its rejection count;
- `:rejection_rate`: per chain, the summed rejections divided by `:chain_length`.
"""
function RejectionChains(samples::Vector{Vector{@NamedTuple{state::S, logprob::L, reject::R}}}) where {S,L,R}
    # per-chain vectors of the `reject` entries (scalars, or vectors when multilevel)
    per_chain_rejects = map(chain -> getfield.(chain, :reject), samples)
    proposals = map(rs -> sum(1 .+ sum.(rs)), per_chain_rejects)
    summary = Dict{Symbol, Vector}(
        :chain_length => proposals,
        :rejection_rate => sum.(per_chain_rejects) ./ proposals,
    )
    return RejectionChains(samples, summary)
end
"""
    RejectionChains(s, l, r)

Build a single-chain `RejectionChains` from parallel vectors of states `s`, log
probabilities `l` and rejection records `r`; sample `i` is
`(state = s[i], logprob = l[i], reject = r[i])`.

Throws `DimensionMismatch` if the three vectors do not share the same indices.
"""
function RejectionChains(s::Vector{S}, l::Vector{L}, r::Vector{R}) where {S,L,R}
    # `eachindex` with several arrays throws DimensionMismatch on unequal lengths,
    # where `zip` would silently truncate to the shortest input
    idx = eachindex(s, l, r)
    # typed comprehension: no varargs splat of the whole chain, and an empty chain
    # still has the concrete element type the `samples` constructor dispatches on
    T = NamedTuple{(:state, :logprob, :reject), Tuple{S, L, R}}
    samples = T[T((s[i], l[i], r[i])) for i in idx]
    return RejectionChains([samples])
end
# These extend Base explicitly: an unqualified `length(x::T) = ...` creates a new
# module-local function (unless the name was imported), which then shadows Base's
# generic methods for every other argument type used in this module.

# Number of chains held.
Base.length(chains::RejectionChains) = length(chains.samples)
# Element type is the state type `S`; defined on the type so that both
# `eltype(chains)` and `eltype(typeof(chains))` work.
Base.eltype(::Type{<:RejectionChains{S}}) where {S} = S
# The `id`-th chain, as its vector of `(state, logprob, reject)` named tuples.
Base.getindex(chains::RejectionChains, id::Integer) = chains.samples[id]
# A subset of chains (any integer index vector, including ranges); `info` is
# recomputed from the selected samples by the `samples` constructor.
Base.getindex(chains::RejectionChains, ids::AbstractVector{<:Integer}) =
    RejectionChains(chains.samples[ids])
"""
    AbstractMCMC.chainscat(c::RejectionChains...)

Concatenate several `RejectionChains` into one that holds all of their chains in
argument order.

`info` entries are merged key by key. When an input lacks a key, it contributes
`missing` once per chain it holds, so the merged vector stays aligned with the
concatenated chains.
"""
function AbstractMCMC.chainscat(c::RejectionChains...)
    infos = getfield.(c, :info)
    lens = length.(c)  # number of chains in each input
    merged = Dict{Symbol, Vector}()
    for k in union(keys.(infos)...)
        parts = [haskey(infos[i], k) ? infos[i][k] : fill(missing, lens[i])
                 for i in eachindex(c)]
        merged[k] = vcat(parts...)
    end
    # NOTE(review): inputs with different type parameters S/L/R produce an abstractly
    # typed sample vector that the default constructor cannot accept — confirm callers
    # only concatenate homogeneous chains.
    return RejectionChains(vcat(getfield.(c, :samples)...), merged)
end
# `true` when the per-sample rejection record type `R` is a `Vector` (presumably one
# entry per level — confirm), `false` otherwise.
is_multilevel(::RejectionChains{<:Any, <:Any, R}) where {R} = R <: Vector
"""
    states(chains::RejectionChains)

Return one `Vector{S}` per chain. Within each vector, every stored state is repeated
`1 + sum(reject)` times (once for itself plus once per recorded rejection), so
rejections show up as repeated values.
"""
function states(chains::RejectionChains{S,L,R}) where {S,L,R}
    V = Vector{S}[]
    sizehint!(V, length(chains))
    for chain in chains.samples
        v = S[]
        for x in chain
            # append in place instead of splatting every per-sample block into
            # `vcat`, which builds a varargs call as long as the chain
            append!(v, fill(x.state, 1 + sum(x.reject)))
        end
        push!(V, v)
    end
    return V
end
"""
    info(chains::RejectionChains)

Return the `Dict{Symbol, Vector}` of summary vectors stored in `chains`.
"""
function info(chains::RejectionChains)
    return chains.info
end
"""
    convert(::Type{<:DataFrame}, chains::RejectionChains)

Build a wide `DataFrame` with columns `state_i`, `logprob_i` and `reject_i` for each
chain `i`.

DataFrame columns must all have the same length, so this only succeeds when every
chain holds the same number of samples.
"""
function Base.convert(::Type{<:DataFrame}, chains::RejectionChains)
    df = DataFrame()
    for (i, chain) in enumerate(chains.samples)
        for field in (:state, :logprob, :reject)
            # `df[!, col] = v` adds a column; the older `df[col] = v` form is unsupported
            df[!, Symbol(field, "_", i)] = getfield.(chain, field)
        end
    end
    return df
end