Skip to content
Snippets Groups Projects
multi.jl 1.33 KiB
Newer Older
"""
    MultiChains{X, L<:Number}

Sample storage for a rejection-based sampler that additionally records an
integer level for every sample.

All three matrices share one layout: rows are samples and columns are chains
(single-chain objects are built as `N × 1` matrices, and `chainscat` joins
chains with `hcat`).

# Fields
- `states`: sampled states, element type `X`.
- `logprobs`: log-probability value of each sample, element type `L`.
- `levels`: integer level recorded for each sample.
- `info`: one `NamedTuple` of per-chain summary statistics per column.
"""
struct MultiChains{X, L<:Number} <: RejectionBasedChains
    states::Matrix{X}
    logprobs::Matrix{L}
    levels::Matrix{<:Integer}
    info::Vector{<:NamedTuple}
end
# The element type of a `MultiChains` is the type `X` of its stored states.
Base.eltype(::MultiChains{X}) where {X} = X
# Total number of stored samples over all chains (rows × columns of `states`).
Base.length(c::MultiChains) = prod(size(c.states))


"""
    MultiChains(states, logprobs, levels, rejections; info...)

Build a single-chain `MultiChains` from run-length-encoded sampler output.

Entry `i` of `states`, `logprobs` and `levels` describes one accepted sample.
`rejections[i]` (presumably the number of proposals rejected while the sampler
stayed on that sample) makes it appear `sum(rejections[i]) + 1` times in the
expanded chain. The resulting matrices have a single column and
`N = sum(sum.(rejections) .+ 1)` rows.

# Keywords
Any keyword arguments are merged into the chain's info `NamedTuple`, after the
computed `rejection_rate = sum(rejections) ./ N`.

# Throws
- `DimensionMismatch` if the four input vectors do not have the same length.
"""
function MultiChains(states::Vector, logprobs::Vector, levels::Vector, rejections::Vector; info... )
    n = length(states)
    if !(length(logprobs) == length(levels) == length(rejections) == n)
        throw(DimensionMismatch(
            "states, logprobs, levels and rejections must have equal lengths; got " *
            "$(n), $(length(logprobs)), $(length(levels)), $(length(rejections))"))
    end
    # `sum.` lets each entry be either a scalar count or a collection of counts.
    repetitions = sum.(rejections) .+ 1
    # Expanded chain length: every accepted sample once, plus each of its repeats.
    N = sum(repetitions)
    return MultiChains(
        reshape(_repeat_runs(states, repetitions, N), N, 1),
        reshape(_repeat_runs(logprobs, repetitions, N), N, 1),
        reshape(_repeat_runs(levels, repetitions, N), N, 1),
        [(; rejection_rate = sum(rejections) ./ N, info...)],
    )
end

# Expand run-length-encoded data: `values[i]` is written `counts[i]` times, in order,
# into a new vector of length `total` (which must equal `sum(counts)`).
# Preallocating avoids splatting a long vector into `vcat` and works for empty input.
function _repeat_runs(values::AbstractVector, counts::AbstractVector, total::Integer)
    out = Vector{eltype(values)}(undef, total)
    k = 0
    for (v, r) in zip(values, counts)
        for _ in 1:r
            k += 1
            out[k] = v
        end
    end
    return out
end



"""
    AbstractMCMC.chainscat(chains::MultiChains...)

Concatenate several `MultiChains` side by side. The `states`, `logprobs` and
`levels` matrices are joined column-wise with `hcat`, and the per-chain `info`
vectors are joined with `vcat`.
"""
function AbstractMCMC.chainscat(chains::MultiChains ... )
    # Rebuild a `MultiChains` (previously `DefaultChains`): these four arguments are
    # exactly the `MultiChains` fields, including the `levels` matrix.
    return MultiChains(
        hcat((c.states for c in chains)...),
        hcat((c.logprobs for c in chains)...),
        hcat((c.levels for c in chains)...),
        vcat((c.info for c in chains)...),
    )
end

# Print a summary: a header with the chain's size and element type, then one line
# per field name of the `info` NamedTuple element type. When the chain has a
# single column the raw values are printed; otherwise their mean ± std over the
# `info` entries.
# NOTE(review): this method dispatches on `SimpleChains`, not `MultiChains`, even
# though it sits next to the `MultiChains` definitions — confirm this is intended.
function Base.show(io::IO, c::SimpleChains)
    println(io, "Chain ", size(c), " {", eltype(c), "}")
    fields = eltype(c.info).parameters[1]   # names tuple of the NamedTuple type
    for field in fields
        column = getfield.(c.info, field)
        parts = size(c, 2) == 1 ? (column,) : (mean(column), " ± ", std(column))
        println(io, "  ", field, " ", parts...)
    end
end