Skip to content
Snippets Groups Projects
samplebased.jl 1007 B
Newer Older
"""
    SampledLogDensity{T, F<:Function} <: AbstractLogDensity

Log-density object that pairs a function with a stored vector of sample points.

# Fields
- `samples::Vector{T}`: the stored sample points.
- `func::F`: the function associated with the samples.
"""
struct SampledLogDensity{T,F<:Function} <: AbstractLogDensity
    samples::Vector{T}
    func::F
end
"""
    SampledLogDensity(f::Function, n::Int=32, d::Int=1, a=0, b=1)

Build a `SampledLogDensity` around `f` whose samples are `n` tuples with `d` entries each.

Entry `j` of every tuple comes from the first `n` elements of `Halton(base_j)`. The
bases are the first `d` entries of `primes(d^2 + 1)`, reordered in place by `shuffle!`
(called without an explicit RNG argument).

# Arguments
- `f`: function stored in the `func` field.
- `n`: number of sample points to generate.
- `d`: number of entries per sample point.
- `a`, `b`: accepted for interface compatibility; not used by this method.
"""
function SampledLogDensity(f::Function, n::Int=32, d::Int=1, a=0, b=1)
    # One Halton base per coordinate, in random order.
    bases = primes(d^2 + 1)[1:d]
    shuffle!(bases)

    # Each column holds the first `n` Halton points for one base.
    columns = [Halton(bases[j])[1:n] for j in 1:d]
    grid = hcat(columns...)

    # Row `row` of the grid becomes one sample tuple.
    points = [Tuple(grid[row, :]) for row in 1:n]
    return SampledLogDensity(points, f)
end

"""
    length(m::SampledLogDensity)

Return the number of sample points stored in `m.samples`.
"""
# Qualified with `Base.` so this always adds a method to `Base.length`, even when
# the enclosing module has no `import Base: length`. An unqualified definition
# would otherwise create a new module-local `length` and break the inner call.
Base.length(m::SampledLogDensity) = length(m.samples)

"""
    logdensity(m::SampledLogDensity, x; level=length(m), cache=nothing)

Forward to `_logdensity(m, x, level, cache)`; the method that runs is chosen by
dispatch on the types of `level` and `cache`.

# Keywords
- `level`: defaults to `length(m)`.
- `cache`: defaults to `nothing`.
"""
function logdensity(m::SampledLogDensity, x; level=length(m), cache=nothing)
    return _logdensity(m, x, level, cache)
end
function _logdensity(m::SampledLogDensity, x, level::Int64, cache::Nothing)  
    y = [m.func(x, z) for z=m.samples[1:level]] #TODO: parallelize evaluations of m.func !
function _logdensity(m::SampledLogDensity, x, level::Int64, cache::Tuple{Int64, <:Number})  
    if cache[1] <= level
        y = [m.func(x, z) for z=m.samples[cache[1]+1:level]]
        return mean(y) * length(y)/level + cache[2] * cache[1]/level
    else
        return logdensity(m, x; level, nothing)