Skip to content
Snippets Groups Projects
rejection.jl 3 KiB
Newer Older
"""
    RejectionChains{S,L,R} <: AbstractRejectionChains

Storage for one or more chains of unique samples; the chains may differ in length.

# Fields
- `samples`: one vector per chain; every sample is a named tuple
  `(state::S, logprob::L, reject::R)`.
- `info`: auxiliary vectors keyed by `Symbol` (`getindex` indexes them by chain id).
"""
struct RejectionChains{S,L,R} <: AbstractRejectionChains
    samples::Vector{Vector{@NamedTuple{state::S, logprob::L, reject::R}}}
    info::Dict{Symbol, Vector}
end


"""
    RejectionChains(samples)

Wrap already-assembled per-chain `samples` in a `RejectionChains` whose `info`
dictionary starts out empty.
"""
function RejectionChains(
    samples::Vector{Vector{@NamedTuple{state::S, logprob::L, reject::R}}},
) where {S,L,R}
    emptyinfo = Dict{Symbol, Vector}()
    return RejectionChains(samples, emptyinfo)
end


"""
    RejectionChains(s::Vector{S}, l::Vector{L}, r::Vector{R})

Construct a `RejectionChains` holding a single chain whose `j`-th sample is
`(state = s[j], logprob = l[j], reject = r[j])`, with an empty `info` dictionary.

Throws `DimensionMismatch` if `s`, `l` and `r` do not all have the same length.
"""
function RejectionChains(s::Vector{S}, l::Vector{L}, r::Vector{R}) where {S,L,R}
    if !(length(s) == length(l) == length(r))
        throw(DimensionMismatch(
            "state, logprob and reject vectors must have equal lengths, got " *
            "$(length(s)), $(length(l)) and $(length(r))",
        ))
    end
    # Build the element type from the declared eltypes instead of inferring it from the
    # values: value-inferred NamedTuples are not concrete for empty inputs or abstract
    # eltypes, and then no `RejectionChains` constructor matches.
    T = @NamedTuple{state::S, logprob::L, reject::R}
    samples = T[T(t) for t in zip(s, l, r)]
    return RejectionChains([samples])
end

# concatenation method
"""
    AbstractMCMC.chainscat(c::RejectionChains...)

Concatenate several `RejectionChains` into one, appending their chains in order.

`info` entries are merged by key. For every key present in any input, the per-chain
vectors are concatenated. An input that lacks the key contributes `missing` once per
chain, so the entries stay aligned with the concatenated chains.
"""
function AbstractMCMC.chainscat(c::RejectionChains...)
    infos = getfield.(c, :info)
    # `info` vectors hold one entry per chain (`getindex` indexes them by chain id), so
    # padding must use the chain count. `length` is the total number of stored samples
    # and would misalign the merged entries.
    nch = nchains.(c)
    dict = Dict{Symbol, Vector}()
    for k in union(keys.(infos)...)
        vals = [
            haskey(infos[i], k) ? infos[i][k] : fill(missing, nch[i])
            for i in eachindex(c)
        ]
        dict[k] = vcat(vals...)
    end
    return RejectionChains(vcat(getfield.(c, :samples)...), dict)
end


# Properties

"""
    nchains(chains::RejectionChains)

Return the number of chains stored in `chains`.
"""
nchains(chains::RejectionChains) = size(chains.samples, 1)
"""
    length(chains::RejectionChains)

Return the total number of stored (unique) samples summed over all chains, i.e. the
storage cost; repeated states from rejections are not counted.
"""
length(chains::RejectionChains) = mapreduce(length, +, chains.samples; init = 0)
"""
    eltype(chains::RejectionChains)

Return the state type `S` of `chains`.
"""
eltype(::RejectionChains{S}) where {S} = S

"""
    info(chains::RejectionChains)

Return the auxiliary `info` dictionary of `chains` (the stored object, not a copy).
"""
function info(chains::RejectionChains)
    return getfield(chains, :info)
end

# Get a subset of the chains
"""
    getindex(chains::RejectionChains, id::Integer)

Return a new `RejectionChains` containing only chain `id`, together with the `id`-th
entry of every `info` vector (each wrapped in a one-element vector).
"""
function getindex(chains::RejectionChains, id::Integer)
    samples = [chains.samples[id]]
    # The field is declared `Dict{Symbol, Vector}` and Dict types are invariant. A
    # value-inferred `Dict{Symbol, Vector{Float64}}` (or `Dict{Any,Any}` for an empty
    # info) would not match the constructor, so build the exact type explicitly.
    subinfo = Dict{Symbol, Vector}(k => [v[id]] for (k, v) in chains.info)
    return RejectionChains(samples, subinfo)
end
"""
    getindex(chains::RejectionChains, ids::AbstractVector)

Return a new `RejectionChains` with the chains selected by `ids` (a range such as `2:4`
or any vector of chain indices) and the matching entries of every `info` vector.
"""
function getindex(chains::RejectionChains, ids::AbstractVector)
    samples = chains.samples[ids]
    # Build the exact field type `Dict{Symbol, Vector}` (Dict types are invariant, so an
    # inferred value type would not match the constructor).
    subinfo = Dict{Symbol, Vector}(k => v[ids] for (k, v) in chains.info)
    return RejectionChains(samples, subinfo)
end

# Multilevel logprobs or rejection
"""
    is_multilevel(chains::RejectionChains)

Return `true` when the `logprob` type `L` is a vector, tuple or named tuple (the
multilevel case, where `levels` counts entries of `logprob`), otherwise `false`.
"""
function is_multilevel(::RejectionChains{S,L,R}) where {S,L,R}
    multilevel = L <: Union{AbstractVector, <:Tuple, <:NamedTuple}
    return multilevel
end
"""
    levels(chains::RejectionChains{S,L,R}) where {L<:Number}

Single-level case: return a `Vector{Int}` of ones with one entry per chain.

NOTE(review): the multilevel method returns one vector of per-sample level counts per
chain; confirm callers really expect this per-chain shape here.
"""
levels(chains::RejectionChains{S,L,R}) where {S, L <: Number, R} = fill(1, nchains(chains))

"""
    levels(chains::RejectionChains{S,L,R}) where {L<:Union{AbstractVector,Tuple,NamedTuple}}

Multilevel case: return, for each chain, a vector with `length(logprob)` of every
stored sample.
"""
function levels(chains::RejectionChains{S,L,R}) where {S, L <: Union{AbstractVector, <:Tuple, <:NamedTuple}, R}
    return map(chains.samples) do chain
        logprobs = getfield.(chain, :logprob)
        length.(logprobs)
    end
end

# get number of repetitions
"""
    repetitions(chains::RejectionChains)

Return, for each chain, a vector holding `1 + sum(reject)` for every stored sample,
i.e. how many times that sample's state is repeated by `states`.
"""
function repetitions(chains::RejectionChains)
    return map(chains.samples) do chain
        rejects = getfield.(chain, :reject)
        1 .+ sum.(rejects)
    end
end

# Chain of states, with each state repeated once per rejection that kept it
"""
    states(chains::RejectionChains{S,L,R})

Return one `Vector{S}` per chain in which every stored state appears
`1 + sum(reject)` times (the counts from `repetitions`), in storage order.

Use `reduce(vcat, states(chains))` to obtain a single flattened vector.
"""
function states(chains::RejectionChains{S,L,R}) where {S,L,R}
    reps = repetitions(chains)
    # Expand each state by its repetition count with a flattening comprehension rather
    # than splatting one `fill` result per sample into `vcat`: splatting a chain-length
    # collection into varargs is slow for long chains.
    return Vector{S}[
        S[sample.state for (sample, n) in zip(chain, r) for _ in 1:n]
        for (chain, r) in zip(chains.samples, reps)
    ]
end