Skip to content
Snippets Groups Projects
simple.jl 1.36 KiB
Newer Older
"""
    SimpleChains{X, L <: Number} <: RejectionBasedChains

Container for sampled chains laid out column-per-chain.

# Fields
- `states::Matrix{X}`: sampled states. The `Vector`-based constructor produces a
  single column, and `chainscat` joins chains side by side (one column each).
- `logprobs::Matrix{L}`: values aligned element-for-element with `states`
  (presumably log-densities of each state — confirm with the sampler).
- `info::Vector{<:NamedTuple}`: one NamedTuple of summary data per chain; the
  element type is left abstract so chains carrying different info fields can be
  concatenated.
"""
struct SimpleChains{X, L <: Number} <: RejectionBasedChains
    states::Matrix{X}
    logprobs::Matrix{L}
    info::Vector{<:NamedTuple}
end
# Element type of a chain is its state type `X`. Base recommends defining `eltype` on the
# *type* so that `eltype(SimpleChains{X,L})` works too; the instance method is kept (and
# simply forwards) so it still takes precedence over any instance-level method a
# supertype might define.
Base.eltype(::Type{<:SimpleChains{X}}) where {X} = X
Base.eltype(c::SimpleChains) = eltype(typeof(c))

# `size` (optionally with dimension arguments) is the size of the `states` matrix.
Base.size(c::SimpleChains, dims...) = size(c.states, dims...)

"""
    get_info(c::SimpleChains)

Transpose the per-chain `c.info` entries into a NamedTuple of vectors: for every
field name `f` of the info entries, the result holds `f => [entry.f for each chain]`.

Field names are read from `eltype(c.info)`, so this assumes all entries share one
NamedTuple type (NOTE(review): mixed info types after `chainscat` may not work here).
"""
function get_info(c::SimpleChains)
    names = fieldnames(eltype(c.info))
    # One column per field; `getfield.` broadcasts over entries, the Symbol stays scalar.
    columns = map(name -> getfield.(c.info, name), names)
    return NamedTuple{names}(columns)
end

# Return a vector with each `xs[i]` repeated `counts[i]` times, preserving order.
_repeat_each(xs, counts) = [x for (x, k) in zip(xs, counts) for _ in 1:k]

"""
    SimpleChains(states::Vector, logprobs::Vector, rejections::Vector; info...)

Build a single-chain `SimpleChains` from a sequence of states.

Each `states[i]` (and its `logprobs[i]`) is written `sum(rejections[i]) + 1` times,
i.e. once for itself plus once per rejection counted in `rejections[i]`; an entry of
`rejections` may be a count or a collection (e.g. a vector of flags) — `sum` is
applied to it.

# Returns
A `SimpleChains` whose `states` and `logprobs` are `N × 1` matrices (`N` = total number
of repetitions) and whose `info` is a one-element vector holding
`(; rejections = N - length(states), info...)`.

# Throws
- `DimensionMismatch` if `states`, `logprobs` and `rejections` differ in length.
"""
function SimpleChains(states::Vector, logprobs::Vector, rejections::Vector; info...)
    n = length(states)
    # Without this check broadcasting would silently stretch a length-1 argument.
    if length(logprobs) != n || length(rejections) != n
        throw(DimensionMismatch(
            "states, logprobs and rejections must have equal lengths; got " *
            "$n, $(length(logprobs)) and $(length(rejections))"
        ))
    end

    repetitions = sum.(rejections) .+ 1
    N = sum(repetitions)

    chain_states   = reshape(_repeat_each(states, repetitions), N, 1)
    chain_logprobs = reshape(_repeat_each(logprobs, repetitions), N, 1)

    # Total rejections derived from `repetitions`, so it is a scalar count that agrees
    # with the expansion above even when each `rejections[i]` is a collection.
    total_rejections = N - n
    chain_info = [(; rejections = total_rejections, info...)]

    return SimpleChains(chain_states, chain_logprobs, chain_info)
end

"""
    AbstractMCMC.chainscat(chain::SimpleChains, chains::SimpleChains...)

Concatenate chains side by side and return a `SimpleChains` (the chain type is
preserved). The `states` and `logprobs` matrices are joined column-wise, so all inputs
must have the same number of rows; the `info` vectors are concatenated in argument
order.

At least one chain is required, so this method does not capture a bare
`chainscat()` call.
"""
function AbstractMCMC.chainscat(chain::SimpleChains, chains::SimpleChains...)
    all_chains = (chain, chains...)
    return SimpleChains(
        hcat(map(c -> c.states, all_chains)...),
        hcat(map(c -> c.logprobs, all_chains)...),
        vcat(map(c -> c.info, all_chains)...),
    )
end

#=
function Base.show(io::IO, c::SimpleChains)
    println(io, "Chain ", size(c), " {", eltype(c), "}")
    nt = eltype(c.info)
    for param in nt.parameters[1]
        x = getfield.(c.info, param)
        if size(c, 2) == 1 
            println(io, "  ", param, " ", x)
        else
            println(io, "  ", param, " ", mean(x), " ± ", std(x) )
        end
    end
end
=#