Skip to content
Snippets Groups Projects
tables.jl 756 B
Newer Older


# Write to table
"""
    convert(::Type{<:AbstractVector{<:DataFrame}}, chains::RejectionChains)

Write `chains` to a `Vector{DataFrame}` holding one table per chain index
`1:nchains(chains)`.

Each table has the columns `:state`, `:logprob` and `:reject`. Each column holds the field
of the same name from every sample in `chains.samples[i]`, in sample order.
"""
function convert(::Type{<:AbstractVector{<:DataFrame}}, chains::RejectionChains)
    num = nchains(chains)
    dfs = DataFrame[]
    # Use the same count for the size hint and the loop bound.
    sizehint!(dfs, num)
    for i in 1:num
        samples = chains.samples[i]
        df = DataFrame()
        for k in (:state, :logprob, :reject)
            # `getfield.` already allocates a fresh vector, so store it without copying.
            df[!, k] = getfield.(samples, k)
        end
        push!(dfs, df)
    end
    return dfs
end

# Read from table
"""
    convert(::Type{<:RejectionChains}, df::DataFrame)

Rebuild a `RejectionChains` from the `:state`, `:logprob` and `:reject` columns of `df`.

The columns are read in that order and passed positionally to the `RejectionChains`
constructor. Each one is taken with `df[:, col]`, which returns a copy, so the result
does not alias the table's storage.

NOTE(review): presumably `df` holds a single chain, like the tables produced by the
`AbstractVector{<:DataFrame}` conversion. Confirm against the constructor.
"""
function convert(::Type{<:RejectionChains}, df::DataFrame)
    states, logprobs, rejects = map(col -> df[:, col], (:state, :logprob, :reject))
    return RejectionChains(states, logprobs, rejects)
end

"""
    convert(::Type{<:RejectionChains}, df::AbstractVector{<:DataFrame})

Convert each table in `df` to a `RejectionChains` using the single-`DataFrame` method.
Then concatenate the results, in order, with `AbstractMCMC.chainscat`.

Throws an `ArgumentError` if `df` is empty.
"""
function convert(::Type{<:RejectionChains}, df::AbstractVector{<:DataFrame})
    # `chainscat` receives one argument per table. With no tables there is nothing to
    # concatenate, so fail with a clear message rather than calling `chainscat()`.
    if isempty(df)
        throw(ArgumentError("cannot convert an empty vector of DataFrames to RejectionChains"))
    end
    return AbstractMCMC.chainscat(convert.(RejectionChains, df)...)
end