# samplebased.jl (1.16 KiB)
"""
    SampledLogDensity{T, F<:Function} <: AbstractLogDensity

Pair a fixed vector of sample points with a function.

# Fields
- `samples::Vector{T}`: the stored sample points.
- `func::F`: the function stored alongside the samples; the evaluation
  methods in this file call it as `func(x, z)` for each sample `z`.
"""
struct SampledLogDensity{T, F<:Function} <: AbstractLogDensity
    samples::Vector{T}
    func::F
end
"""
    SampledLogDensity(f::Function, n::Int=32, d::Int=1, a=0, b=1)

Build a `SampledLogDensity` around `f` with `n` sample points, each a
`d`-tuple whose `k`-th coordinate is taken from the first `n` elements of
`Halton(base_k)`.

The bases are the first `d` entries of `primes(d^2 + 1)`, put into random
order by `shuffle!`, so the assignment of bases to coordinates depends on the
global RNG state.

NOTE(review): `a` and `b` are accepted but never used in this method; they
look like interval bounds for rescaling the samples — confirm the intent.
"""
function SampledLogDensity(f::Function, n::Int=32, d::Int=1, a=0, b=1)
    # One base per coordinate, in shuffled order.
    bases = shuffle!(primes(d^2 + 1)[1:d])
    # Column k holds the first n elements of the sequence for bases[k].
    columns = map(k -> Halton(bases[k])[1:n], 1:d)
    grid = hcat(columns...)
    # Each row of the n-by-d grid becomes one sample tuple.
    points = map(r -> Tuple(grid[r, :]), 1:n)
    return SampledLogDensity(points, f)
end
"""
    eltype(m::SampledLogDensity)

Return the sample type `T`, i.e. the element type of `m.samples`.
"""
eltype(::SampledLogDensity{T}) where {T} = T
"""
    length(m::SampledLogDensity)

Return the number of stored samples.
"""
function length(m::SampledLogDensity)
    return length(m.samples)
end
"""
    Base.show(io::IO, m::SampledLogDensity)

Print a one-line summary of `m`: the number of samples and their element type.
"""
function Base.show(io::IO, m::SampledLogDensity)
    nsamples = length(m)
    sampletype = eltype(m)
    print(io, "SampledLogDensity ", nsamples, " {", sampletype, "} samples ")
    return nothing
end
"""
    logdensity(m::SampledLogDensity, x; level=length(m), cache=nothing)

Evaluate `m` at `x` by forwarding to `_logdensity`, which dispatches on the
type of `cache`.

# Keywords
- `level`: passed through as the number of leading samples to use; defaults
  to all of them.
- `cache`: passed through unchanged; `nothing` (the default) selects the
  method that evaluates from scratch.
"""
function logdensity(m::SampledLogDensity, x; level=length(m), cache=nothing)
    return _logdensity(m, x, level, cache)
end
"""
    _logdensity(m::SampledLogDensity, x, level::Int64, cache::Nothing)

Return the mean of `m.func(x, z)` over the first `level` samples `z` of `m`.
This is the method used when no cached value is supplied.
"""
function _logdensity(m::SampledLogDensity, x, level::Int64, cache::Nothing)
    # TODO: parallelize evaluations of m.func !
    values = map(z -> m.func(x, z), m.samples[1:level])
    return mean(values)
end
"""
    _logdensity(m::SampledLogDensity, x, level::Int64, cache::Tuple{Int64, <:Number})

Return the mean of `m.func(x, z)` over the first `level` samples of `m`,
reusing a previously computed value.

`cache` is treated as `(ncached, cachedmean)`: `cachedmean` is taken to be the
mean over the first `ncached` samples, so only samples `ncached+1:level` are
evaluated and the two means are combined with weights proportional to their
sample counts.

- If `ncached == level`, nothing new has to be evaluated and the cached mean
  is returned (as a float).
- If `ncached > level`, the cache covers more samples than requested and
  cannot be reused; the value is recomputed from scratch via `logdensity`.
"""
function _logdensity(m::SampledLogDensity, x, level::Int64, cache::Tuple{Int64, <:Number})
    ncached, cachedmean = cache
    if ncached > level
        # The cache cannot be reduced to fewer samples: recompute without it.
        return logdensity(m, x; level=level, cache=nothing)
    elseif ncached == level
        # No new samples: averaging an empty batch would yield NaN (or throw),
        # so return the cached mean directly.
        return float(cachedmean)
    end
    fresh = [m.func(x, z) for z in view(m.samples, (ncached + 1):level)]
    # Weighted combination of the new batch mean and the cached mean.
    return mean(fresh) * length(fresh)/level + cachedmean * ncached/level
end