# multilevel.jl
# Supertype of the multilevel models defined in this file. The concrete subtypes
# below each provide `length(m)` (number of levels) and
# `logdensity(m, x; level=...)`.
abstract type AbstractMultilevelModel <: AbstractModel end

"""
    MultilevelLogDensity(proxies::Tuple)
    MultilevelLogDensity(models::AbstractModel...)
    MultilevelLogDensity(models::AbstractVector{<:AbstractModel})

Multilevel model holding one `AbstractModel` per level in the tuple `proxies`;
level `i` is `proxies[i]`.

The models may have different concrete types. A vector of models is converted
to a tuple before it is stored.
"""
struct MultilevelLogDensity{L <: Tuple{Vararg{AbstractModel}}} <: AbstractMultilevelModel
    proxies::L
end
# Annotating the varargs with `AbstractModel` itself (rather than with one type
# parameter shared by all arguments) lets the levels have different concrete
# types, which the tuple-typed field already permits.
MultilevelLogDensity(v::AbstractModel...) = MultilevelLogDensity(v)
MultilevelLogDensity(v::AbstractVector{<:AbstractModel}) = MultilevelLogDensity(Tuple(v))

"""
    length(m::MultilevelLogDensity)

Return the number of levels of `m`, i.e. the number of stored proxy models.
"""
# Qualified with `Base.` so that this is a method of `Base.length` regardless of
# whether `length` has been imported elsewhere in the module.
Base.length(m::MultilevelLogDensity) = length(m.proxies)
"""
    logdensity(m::MultilevelLogDensity, x; level=length(m), kwargs...)

Evaluate the proxy model stored at `level` at the point `x`. By default the last
level is used. Any additional keyword arguments are accepted but not forwarded
to the proxy.
"""
function logdensity(m::MultilevelLogDensity, x; level=length(m), kwargs...)
    proxy = m.proxies[level]
    return logdensity(proxy, x)
end

# Multi-line display: a header line with the number of levels, followed by one
# indented line per proxy model, in level order.
function Base.show(io::IO, m::MultilevelLogDensity)
    println(io, "MultilevelLogDensity ", length(m))
    for proxy in m.proxies
        println(io, "  ", proxy)
    end
    return nothing
end

"""
    MultilevelSampledLogDensity(density::SampledLogDensity, nlevels::Vector{Int})
    MultilevelSampledLogDensity(f::Function, n::AbstractVector{<:Integer}, d::Int=1, a=0, b=1)

Multilevel model backed by a single `SampledLogDensity`, with one entry of
`nlevels` per level.

The second form builds the wrapped density as
`SampledLogDensity(f, maximum(n), d, a, b)` and stores `n` (converted to
`Vector{Int}`) as `nlevels`.

# Throws
- `ArgumentError` if `n` is empty.
"""
struct MultilevelSampledLogDensity{X, F <: Function} <: AbstractMultilevelModel
    density::SampledLogDensity{X, F}
    nlevels::Vector{Int}
    # NOTE(review): the invariant `maximum(nlevels) <= length(density)` appears to
    # be intended but is not enforced here — confirm before relying on it.
end
function MultilevelSampledLogDensity(
    f::Function, n::AbstractVector{<:Integer}, d::Int=1, a=0, b=1
)
    isempty(n) && throw(ArgumentError("`n` must contain at least one level"))
    # The implicit outer constructor of a parametric struct does not convert its
    # arguments, so `n` has to be a `Vector{Int}` before it is handed over.
    # `convert` returns `n` itself when it already has that type.
    nlevels = convert(Vector{Int}, n)
    density = SampledLogDensity(f, maximum(nlevels), d, a, b)
    return MultilevelSampledLogDensity(density, nlevels)
end

# One-line display: the entries of `nlevels`, then `eltype(m.density)` in braces.
function Base.show(io::IO, m::MultilevelSampledLogDensity)
    sample_type = eltype(m.density)
    return print(
        io, "MultilevelSampledLogDensity ", m.nlevels, " {", sample_type, "} samples "
    )
end

"""
    length(m::MultilevelSampledLogDensity)

Return the number of levels of `m`, i.e. the number of entries in `m.nlevels`.
"""
# Qualified with `Base.` so that this is a method of `Base.length` regardless of
# whether `length` has been imported elsewhere in the module.
Base.length(m::MultilevelSampledLogDensity) = length(m.nlevels)
"""
    logdensity(m::MultilevelSampledLogDensity, x; level=length(m), cache=nothing, kwargs...)

Evaluate `m` at `x` on the given `level` (the last level by default) by calling
`_logdensity(m, x, level, cache)`. Any additional keyword arguments are accepted
but not forwarded.
"""
function logdensity(
    m::MultilevelSampledLogDensity, x; level=length(m), cache=nothing, kwargs...
)
    return _logdensity(m, x, level, cache)
end

# Without a cached value: evaluate the wrapped density with the entry of
# `nlevels` selected by `level`.
function _logdensity(m::MultilevelSampledLogDensity, x, level::Int, ::Nothing)
    inner_level = m.nlevels[level]
    return logdensity(m.density, x; level=inner_level)
end

# With a numeric cached value: additionally pass `(level - 1, cache)` to the
# wrapped density through its `cache` keyword.
function _logdensity(m::MultilevelSampledLogDensity, x, level::Int, cache::Number)
    inner_level = m.nlevels[level]
    tagged_cache = (level - 1, cache)
    return logdensity(m.density, x; level=inner_level, cache=tagged_cache)
end