# samplebased.jl

"""
    SampledLogDensity{T, F}

Log density approximated from a fixed collection of sample points: `logdensity` averages
`func(x, z)` over (a prefix of) `samples`.

# Fields
- `samples::Vector{T}`: the stored sample points `z`.
- `func::F`: a function called as `func(x, z)` for each sample point `z`.
"""
struct SampledLogDensity{T,F<:Function} <: AbstractLogDensity
    samples::Vector{T}
    func::F
end
"""
    SampledLogDensity(f::Function, n::Int=32, d::Int=1, a=0, b=1)

Construct a `SampledLogDensity` whose `samples` are the first `n` points of a
`d`-dimensional Halton sequence. Each point is stored as a `d`-tuple, and every coordinate
is mapped affinely from the Halton unit range onto `[a, b]`. `f` becomes the `func` field.

The Halton bases are the first `d` primes, permuted with `shuffle!` (global RNG). The base
assigned to each dimension therefore varies between calls unless the RNG is seeded.

Throws `ArgumentError` if `n < 0` or `d < 1`.
"""
function SampledLogDensity(f::Function, n::Int=32, d::Int=1, a=0, b=1)
    n >= 0 || throw(ArgumentError("n must be non-negative, got $n"))
    d >= 1 || throw(ArgumentError("d must be at least 1, got $d"))
    # primes(d^2 + 1) contains at least d primes for every d ≥ 1, so [1:d] is in bounds.
    bases = shuffle!(primes(d^2 + 1)[1:d])
    # n × d matrix: column i holds the first n Halton points in base bases[i].
    # (hcat rather than reduce(hcat, …) so that d == 1 still yields an n × 1 matrix.)
    unit = hcat((Halton(p)[1:n] for p in bases)...)
    # Halton points lie in [0, 1); assumes the `Halton` in scope follows that
    # convention — TODO confirm. With the defaults a = 0, b = 1 this map is the identity.
    scaled = a .+ (b - a) .* unit
    # One d-tuple per row, i.e. per sample point.
    samples = [Tuple(row) for row in eachrow(scaled)]
    return SampledLogDensity(samples, f)
end
# Element type `T` of the stored sample points. This extends `Base.eltype` explicitly (as
# `Base.show` is extended below) and is defined on the type, so both `eltype(m)` and
# `eltype(typeof(m))` return `T`.
Base.eltype(::Type{<:SampledLogDensity{T}}) where {T} = T
# Number of stored sample points; `logdensity` uses it as the default `level`.
Base.length(m::SampledLogDensity) = length(m.samples)

# Compact one-line summary: the sample count followed by the sample element type,
# e.g. `SampledLogDensity 32 {T} samples `.
function Base.show(io::IO, m::SampledLogDensity)
    count, sampletype = length(m), eltype(m)
    for piece in ("SampledLogDensity ", count, " {", sampletype, "} samples ")
        print(io, piece)
    end
end

"""
    logdensity(m::SampledLogDensity, x; level=length(m), cache=nothing)

Approximate the log density at `x` from the first `level` samples of `m` (all samples by
default). The work is forwarded to `_logdensity`, which dispatches on the type of `cache`.
The methods defined alongside accept `nothing` or a `(k, mean_over_first_k)` tuple.
"""
function logdensity(m::SampledLogDensity, x; level=length(m), cache=nothing)
    return _logdensity(m, x, level, cache)
end

"""
    _logdensity(m::SampledLogDensity, x, level::Int64, cache::Nothing)

Uncached evaluation: return the mean of `m.func(x, z)` over the first `level` stored
samples `z`. Throws `BoundsError` if `level > length(m.samples)`.
"""
function _logdensity(m::SampledLogDensity, x, level::Int64, cache::Nothing)
    # A view avoids copying the sample prefix. The values are still collected into `y`,
    # so `m.func` is called exactly once per sample.
    y = [m.func(x, z) for z in view(m.samples, 1:level)] #TODO: parallelize evaluations of m.func !
    return mean(y)
end

"""
    _logdensity(m::SampledLogDensity, x, level::Int64, cache::Tuple{Int64, <:Number})

Incremental evaluation. `cache = (k, v)` is taken to hold the mean `v` of `m.func(x, z)`
over the first `k` samples. It presumably comes from an earlier call at level `k` for the
same `x`; the caller must ensure this.

- `k < level`: only samples `k+1:level` are evaluated and combined with the cached mean.
- `k == level`: the cached mean `v` is returned as is.
- `k > level`: the cache cannot be reused, and the mean is recomputed from scratch.
"""
function _logdensity(m::SampledLogDensity, x, level::Int64, cache::Tuple{Int64, <:Number})
    k, v = cache
    if k > level
        # The cached prefix is longer than requested, so recompute without a cache.
        return logdensity(m, x; level=level, cache=nothing)
    elseif k == level
        # Nothing new to evaluate. Taking `mean` of the empty `y` would give NaN or throw.
        return v
    end
    y = [m.func(x, z) for z in view(m.samples, (k + 1):level)]
    # The cached mean times k recovers the sum over the first k samples; add the sum of
    # the new samples and divide by the total count.
    return (sum(y) + v * k) / level
end