Skip to content
Snippets Groups Projects
rejection_chains.jl 2.45 KiB
Newer Older
# A collection of sampling chains. Every stored sample is a named tuple with the
# sample `state`, its `logprob`, and `reject`. Since `states` repeats each stored
# state `1 + sum(reject)` times, `reject` presumably counts the rejected proposals
# at that state (vector-valued when multilevel) — TODO confirm with the sampler.
# Each chain may have a different length.
#
# Fields:
# - `samples`: one vector of samples per chain.
# - `info`: summary vectors keyed by symbol (e.g. `:chain_length`,
#   `:rejection_rate` when built by the one-argument constructor).
struct RejectionChains{S,L,R} <: AbstractMCMC.AbstractChains
    samples::Vector{Vector{@NamedTuple{state::S, logprob::L, reject::R}}}
    info::Dict{Symbol, Vector}
end
# Build a collection from per-chain sample vectors and derive per-chain summaries
# from the `reject` field of the samples:
#   :chain_length   — sum over the chain's samples of `1 + sum(reject)`
#   :rejection_rate — `sum` of the chain's `reject` values divided by :chain_length
function RejectionChains(
    samples::Vector{Vector{@NamedTuple{state::S, logprob::L, reject::R}}},
) where {S,L,R}
    rejects = map(chain -> getfield.(chain, :reject), samples)
    lengths = map(rj -> sum(1 .+ sum.(rj)), rejects)
    summary = Dict{Symbol, Vector}(
        :chain_length => lengths,
        :rejection_rate => sum.(rejects) ./ lengths,
    )
    return RejectionChains(samples, summary)
end


# Construct a collection holding a single chain from parallel vectors of states
# `s`, log-probabilities `l` and rejection values `r`.
# Throws `DimensionMismatch` if the three vectors differ in length.
function RejectionChains(s::Vector{S}, l::Vector{L}, r::Vector{R}) where {S,L,R}
    n = length(s)
    # `zip` would otherwise silently truncate to the shortest input.
    if length(l) != n || length(r) != n
        throw(DimensionMismatch(
            "states, logprobs and rejections must have equal lengths, " *
            "got $(n), $(length(l)) and $(length(r))"))
    end
    # Typed comprehension: the element type is fixed by S, L, R (even for empty
    # inputs or abstract S/L/R), so it always matches the per-chain constructor.
    T = @NamedTuple{state::S, logprob::L, reject::R}
    samples = T[(state = x, logprob = lp, reject = k) for (x, lp, k) in zip(s, l, r)]
    return RejectionChains([samples])
end

# Number of chains in the collection. Extends `Base.length` explicitly rather
# than creating a module-local `length` that would shadow Base for other types.
Base.length(chains::RejectionChains) = length(chains.samples)
# Reported element type: the state type `S`. Defined on the type so that
# `eltype(typeof(chains))` works; `eltype(chains)` reaches it via Base's fallback.
Base.eltype(::Type{<:RejectionChains{S}}) where {S} = S

# Index one chain: returns that chain's vector of samples.
Base.getindex(chains::RejectionChains, id::Integer) = chains.samples[id]
# Index several chains (range, index vector or logical mask): returns a new
# collection built by the one-argument `RejectionChains(samples)` constructor,
# so `info` holds only the recomputed summaries, not other keys.
Base.getindex(chains::RejectionChains, ids::AbstractVector{<:Integer}) =
    RejectionChains(chains.samples[ids])

# Concatenate several chain collections into one. `info` entries are
# concatenated key by key; a collection lacking a key contributes one `missing`
# per chain so every entry stays aligned with the concatenated chains.
function AbstractMCMC.chainscat(c::RejectionChains...)
    infos = getfield.(c, :info)
    lens = length.(c)
    dict = Dict{Symbol, Vector}()
    for k in union(keys.(infos)...)
        vals = [haskey(infos[i], k) ? infos[i][k] : fill(missing, lens[i])
                for i in eachindex(c)]
        dict[k] = vcat(vals...)
    end
    return RejectionChains(vcat(getfield.(c, :samples)...), dict)
end
# True when the rejection type parameter `R` is a `Vector`, i.e. each sample
# carries vector-valued rejections (presumably one entry per level).
is_multilevel(::RejectionChains{<:Any, <:Any, R}) where {R} = R <: Vector

# Expand each chain into the full sequence of visited states: every stored state
# is repeated `1 + sum(reject)` times. Returns one `Vector{S}` per chain; use
# `reduce(vcat, states(chains))` to obtain a single vector.
function states(chains::RejectionChains{S,L,R}) where {S,L,R}
    out = Vector{S}[]
    sizehint!(out, length(chains.samples))
    for chain in chains.samples
        # `sum` collapses scalar or vector-valued `reject` to a single count.
        counts = 1 .+ sum.(getfield.(chain, :reject))
        expanded = S[]
        for (sample, k) in zip(chain, counts)
            # Append in place: splatting one `fill` per sample into `vcat`
            # scales badly with chain length and promotes mixed-type states
            # (e.g. Int and Float64 under an abstract S).
            append!(expanded, Iterators.repeated(sample.state, k))
        end
        push!(out, expanded)
    end
    return out
end

# Return the summary dictionary stored on the collection (for collections built
# from samples alone it holds per-chain `:chain_length` and `:rejection_rate`).
function info(chains::RejectionChains)
    return getfield(chains, :info)
end

# Convert to a wide table with columns `state_i`, `logprob_i` and `reject_i`
# for each chain `i`. Chains may differ in length, so shorter chains are padded
# with `missing` up to the length of the longest chain.
function Base.convert(::Type{<:DataFrame}, chains::RejectionChains)
    df = DataFrame()
    nrows = isempty(chains.samples) ? 0 : maximum(length, chains.samples)
    for (i, chain) in enumerate(chains.samples)
        padding = fill(missing, nrows - length(chain))
        for v in (:state, :logprob, :reject)
            col = getfield.(chain, v)
            # `df[!, name] = col` adds a column; the old `df[name] = col` form
            # was removed from DataFrames.
            df[!, Symbol(v, "_", i)] = isempty(padding) ? col : vcat(col, padding)
        end
    end
    return df
end