# multilevel.jl
"""
    AbstractMultilevelModel <: AbstractModel

Abstract supertype for models that can be evaluated at several levels.

The concrete subtypes defined in this file implement `length(m)`, the number of
levels. They also implement `logdensity(m, x, level)`, where `level` defaults to
`length(m)`.
"""
abstract type AbstractMultilevelModel <: AbstractModel end

"""
    MultilevelLogDensity(proxies::Tuple)
    MultilevelLogDensity(proxies::AbstractModel...)
    MultilevelLogDensity(proxies::AbstractVector{<:AbstractModel})

Multilevel model assembled from an ordered collection of `proxies`, one model per
level: level `i` is handled by `proxies[i]`.

The proxies are stored as a tuple so that each level keeps its own concrete model
type.
"""
# `Tuple{Vararg{AbstractModel}}` already admits heterogeneous model types because
# tuple types are covariant; the `Vararg{<:AbstractModel}` spelling is deprecated.
struct MultilevelLogDensity{L <: Tuple{Vararg{AbstractModel}}} <: AbstractMultilevelModel
    proxies :: L
end

# Varargs form, e.g. `MultilevelLogDensity(m1, m2)`; each argument is one level.
MultilevelLogDensity(v::AbstractModel...) = MultilevelLogDensity(v)

# Vector form; converted to a tuple to match the field type.
MultilevelLogDensity(v::AbstractVector{<:AbstractModel}) = MultilevelLogDensity(Tuple(v))

# Number of levels, i.e. the number of proxy models.
# Qualified as `Base.length` (like `Base.show` below) so this extends Base's function
# rather than creating a module-local `length`.
Base.length(m::MultilevelLogDensity) = length(m.proxies)

"""
    logdensity(m::MultilevelLogDensity, x, level=length(m))

Evaluate the log-density at `x` with the proxy model of the given `level`
(`m.proxies[level]`). By default the last proxy is used.
"""
logdensity(m::MultilevelLogDensity, x, level=length(m)) = logdensity(m.proxies[level], x)

# One-line summary of a `SampledLogDensity`: its length and the element type of its
# `samples` field.
function Base.show(io::IO, m::SampledLogDensity)
    # The label names the type actually being shown. `print` rather than `println`:
    # a 2-arg `show` must not append a newline (the REPL and `string` handle layout).
    print(io, "SampledLogDensity ", length(m), " {", eltype(m.samples), "} samples")
    return nothing
end

"""
    MultilevelSampledLogDensity(density::SampledLogDensity, nlevels::Vector{Int})
    MultilevelSampledLogDensity(f::Function, n::AbstractVector{<:Integer}, d::Int=1, a=0, b=1)

Multilevel view of a single `SampledLogDensity`. Level `i` evaluates `density` with
`nlevels[i]` passed as its level argument (presumably a number of samples —
confirm against `SampledLogDensity`).

The convenience form builds the density as `SampledLogDensity(f, maximum(n), d, a, b)`.
It forwards `d`, `a` and `b` unchanged.

Throws `ArgumentError` if `n` is empty.
"""
struct MultilevelSampledLogDensity{X, F <: Function} <: AbstractMultilevelModel
    density :: SampledLogDensity{X, F}
    # Expected invariant (not enforced): maximum(nlevels) <= length(density).
    nlevels :: Vector{Int}
end

function MultilevelSampledLogDensity(f::Function, n::AbstractVector{<:Integer},
                                     d::Int=1, a=0, b=1)
    isempty(n) && throw(ArgumentError("need at least one level, got an empty `n`"))
    # The field is `Vector{Int}`: convert so inputs like `Int32[...]` or `1:3` are
    # accepted (`convert` returns `n` itself, uncopied, when it is already Vector{Int}).
    density = SampledLogDensity(f, maximum(n), d, a, b)
    return MultilevelSampledLogDensity(density, convert(Vector{Int}, n))
end

# Number of levels, i.e. the number of entries in `m.nlevels`.
Base.length(m::MultilevelSampledLogDensity) = length(m.nlevels)

"""
    logdensity(m::MultilevelSampledLogDensity, x, level=length(m), cache=nothing)

Evaluate the underlying sampled density at `x` for level index `level`.

The index is always translated through `m.nlevels[level]` before being forwarded to
`m.density`. The result depends on `cache`:

- `nothing`: a plain evaluation.
- a `Number`: a previously computed value, passed on together with `level - 1`.
- anything else: forwarded unchanged as the last argument.
"""
logdensity(m::MultilevelSampledLogDensity, x, level=length(m), cache=nothing) =
    logdensity(m.density, x, m.nlevels[level], cache)

# `Integer` (not just `Int`) so that e.g. Int32 level indices take these methods
# instead of the generic fallback.
logdensity(m::MultilevelSampledLogDensity, x, level::Integer, cache::Nothing) =
    logdensity(m.density, x, m.nlevels[level])

# NOTE(review): this forwards the previous level's *index* (`level - 1`), not
# `m.nlevels[level - 1]`; confirm which one SampledLogDensity's cached method expects.
logdensity(m::MultilevelSampledLogDensity, x, level::Integer, cache::Number) =
    logdensity(m.density, x, m.nlevels[level], (level - 1, cache))