# base.jl

using Random
using Distributions
using AbstractMCMC

## === Sample 
# Root of the sample type hierarchy.
abstract type AbstractSample end

# Single-level sample: abstract supertype of the samples used by the
# single-level ("simple") proposals and log densities.
abstract type SimpleSample <: AbstractSample end

""" simplest sample, just wraps a value
   (this is used for conventional single-level MCMC)"""
struct Sample{X} <: SimpleSample
    # the wrapped value; its type `X` is the element type of the sample
    sample::X
end
# Element type of a `Sample` is the type of the wrapped value.
# The type-level method is the form Julia expects new types to define; without
# it `eltype(Sample{X})` silently falls back to `Any`.
Base.eltype(::Type{Sample{X}}) where {X} = X
# Instance-level method kept explicitly so existing `eltype(s)` calls are unchanged.
Base.eltype(s::Sample{X}) where {X} = X
# Unwrap the value stored in a sample.
value(s::Sample) = s.sample


## === Proposal
# A proposal is parametrised by the sample type `S` it produces and by a
# type-level flag `issymmetric` (`true` / `false`) recording whether the
# transition density is symmetric.
abstract type AbstractProposal{S<:AbstractSample,issymmetric} end

# Aliases selecting on the symmetry flag. Declared `const` so that they are
# proper type aliases rather than reassignable, untyped globals.
const SymmetricProposal = AbstractProposal{S,true} where {S<:AbstractSample}
const AsymmetricProposal = AbstractProposal{S,false} where {S<:AbstractSample}

# Sample type produced by the proposal.
Base.eltype(p::AbstractProposal{S,issymmetric}) where {S,issymmetric} = S
# Symmetry flag, read back from the type parameter.
is_symmetric(p::AbstractProposal{S,issymmetric}) where {S,issymmetric} = issymmetric

"""
    propose([rng], p)
    propose([rng], p, x)

To define a proposal implement the sampling methods
  - initialization x₀ ~ p₀(⋅) by `x₀ = propose(rng, p)`
  - and transition y ~ q(⋅|x) by `y = propose(rng, p, x)`

The forms without `rng` forward to the forms above using `Random.default_rng()`.

For asymmetric proposals also implement the method
for calculating q(y|x) by `logpdf(p, x, y)`.
"""
function propose end

# Convenience fallbacks: supply the default RNG when the caller passes none.
propose(p::AbstractProposal; kwargs...) = propose(Random.default_rng(), p; kwargs...)
function propose(p::AbstractProposal, x::AbstractSample; kwargs...)
    return propose(Random.default_rng(), p, x; kwargs...)
end

# Log-ratio of the proposal densities between `x` and `y`.
# For a symmetric proposal the two densities cancel, so the log-ratio is zero.
function logpdfratio(p::SymmetricProposal, x::AbstractSample, y::AbstractSample)
    return 0
end

# For an asymmetric proposal the log-ratio is `logpdf(p, x, y) - logpdf(p, y, x)`.
function logpdfratio(p::AsymmetricProposal, x::AbstractSample, y::AbstractSample)
    forward = logpdf(p, x, y)
    backward = logpdf(p, y, x)
    return forward - backward
end

## Single-level proposals
# Aliases restricting the proposal hierarchy to single-level (`SimpleSample`)
# samples. Declared `const` so they are fixed type aliases, not untyped globals.
const SimpleProposal = AbstractProposal{S,issymmetric} where {S<:SimpleSample,issymmetric}
const SimpleSymmetricProposal = SymmetricProposal{S} where {S<:SimpleSample}


# Static proposal: samples independently of the previous state.
# `D` is the wrapped distribution type, `S` the sample type it produces.
# NOTE(review): an independence proposal satisfies q(y|x) = q(x|y) only when
# `distribution` is uniform; for other distributions the zero log-ratio implied
# by the symmetric supertype looks incorrect — confirm against the sampler.
struct StaticProposal{D<:Distribution,S} <: SimpleSymmetricProposal{S}
    distribution::D
end

# Build a `StaticProposal` from a distribution. The sample type is derived from
# the type of one draw of `d` (this consumes one value from the global RNG).
function StaticProposal(d::Distribution)
    S = Sample{typeof(rand(d))}
    return StaticProposal{typeof(d),S}(d)
end
# Initial state: one draw from the wrapped distribution, wrapped in a `Sample`.
function propose(rng::AbstractRNG, p::StaticProposal)
    draw = rand(rng, p.distribution)
    return Sample(draw)
end

# Transition: the current state `x` is ignored and a fresh draw is returned.
function propose(rng::AbstractRNG, p::StaticProposal, x)
    return propose(rng, p)
end
#logpdf(p::StaticProposal, x::Sample) = logpdf(p.distribution, x.sample) # won't be used


# Random walk: the initial state is drawn from `init`; each transition draws a
# step from `step` and combines it with the current value through `agg`.
# (Caveat: the type declares the proposal symmetric, but for additive steps this
#  only holds if the step distribution is symmetric around zero — in particular
#  it must have zero mean!)
struct RandomWalk{
    D_init<:Distribution,D_step<:Distribution,F<:Function,S
} <: SimpleSymmetricProposal{S}
    init::D_init
    step::D_step
    agg::F
end

# Build a `RandomWalk`, validating that `agg` can combine a state with a step
# and that doing so does not change the state type.
# Throws `ArgumentError` if either check fails. The checks draw from `init` and
# `step` using the global RNG.
function RandomWalk(init::Distribution, step::Distribution, agg::Function=(+))
    Tinit = typeof(rand(init))
    Tstep = typeof(rand(step))

    # check that the aggregation method accepts a (state, step) pair
    if !hasmethod(agg, (Tinit, Tstep))
        throw( ArgumentError("Aggregator $agg does not support ($Tinit, $Tstep) arguments.") )
    end

    # check that the state type is stable over the chain
    T = typeof(agg(rand(init), rand(step)))
    if T != Tinit
        throw( ArgumentError("Aggreted type $T differes from initial type $Tinit.") )
    end

    return RandomWalk{typeof(init),typeof(step),typeof(agg),Sample{T}}(init, step, agg)
end
# Initial state: one draw from the `init` distribution, wrapped in a `Sample`.
function propose(rng::AbstractRNG, p::RandomWalk)
    start = rand(rng, p.init)
    return Sample(start)
end

# Transition: combine the current value with a fresh step through `p.agg`.
function propose(rng::AbstractRNG, p::RandomWalk, x::Sample)
    increment = rand(rng, p.step)
    return Sample(p.agg(x.sample, increment))
end

# Random walk on the periodic interval between `a` and `b`: the initial state is
# drawn from `Uniform(a, b)` and every step is wrapped back into the interval
# with `mod`.
# The step is drawn from a window of total width `s` centred on zero. A one-sided
# window `Uniform(0, s)` has non-zero mean, which makes the walk asymmetric even
# though `RandomWalk` is typed as a symmetric proposal (see the caveat there).
function CyclicWalk(a::Real, b::Real, s::Real=(b - a) / 2)
    period = b - a
    wrap = (x, y) -> mod(x + y - a, period) + a
    return RandomWalk(Uniform(a, b), Uniform(-s / 2, s / 2), wrap)
end


## === LogDensity
# Abstract log density over samples of type `S`; it is an
# `AbstractMCMC.AbstractModel`.
abstract type AbstractLogDensity{S<:AbstractSample} <: AbstractMCMC.AbstractModel end

# The element type of a log density is the sample type `S` it is defined on.
function Base.eltype(d::AbstractLogDensity{S}) where {S}
    return S
end

"""
    evaluate(f, x)

Interface for an unnormalized logarithmic probability density `f(x)`:
implement the method `evaluate(f, x)`, where `x` is a sample that matches
the sample type `eltype(f)`.
"""
function evaluate end

# Calling a log density on a sample of its own sample type forwards to
# `evaluate`, passing any keyword arguments through.
function (d::AbstractLogDensity{S})(x::S; kwargs...) where {S}
    return evaluate(d, x; kwargs...)
end


## Single-level log density 
# Abstract log density restricted to single-level samples (`S <: SimpleSample`).
abstract type SimpleLogDensity{S <: SimpleSample} <: AbstractLogDensity{S} end

# Wrap a distribution whose `logpdf` is known.
# Log density backed by a distribution `d`; evaluation delegates to `logpdf`.
struct DistrLogDensity{D<:Distribution,S<:SimpleSample} <: SimpleLogDensity{S}
    d::D
end

# Build a log density from a distribution. The sample type is `Sample` of the
# type of one draw of `d`.
function LogDensity(d::Distribution)
    draw = rand(d)
    return DistrLogDensity{typeof(d),Sample{typeof(draw)}}(d)
end

# Evaluate the wrapped distribution's `logpdf` at the value held by the sample.
function evaluate(f::DistrLogDensity{D,S}, x::S) where {D<:Distribution,S<:Sample}
    return logpdf(f.d, x.sample)
end

# Wrap a specific method of a function.
# Log density backed by a plain function.
# `F` is the type of the wrapped function, `S` the sample type it is evaluated on.
struct FuncLogDensity{F <: Function, S <: SimpleSample} <: SimpleLogDensity{S}
    # the wrapped callable
    f :: F
end
# Type of the value wrapped by a plain `Sample{X}`; `nothing` for any other type.
_wrappedtype(::Type{Sample{X}}) where {X} = X
_wrappedtype(::Type) = nothing

# Build a log density from a function `f` and a type `T`.
#
# `T` is either a sample type (`T <: SimpleSample`), used as is, or a raw value
# type, which is wrapped into `Sample{T}`.
#   - If `f` has a method for the sample type, `f` is used directly.
#   - Otherwise, if the sample type is a `Sample{X}` and `f` has a method for
#     the wrapped value type `X`, `f` is applied to the unwrapped value.
#   - Otherwise an `ArgumentError` is thrown.
function LogDensity(f::Function, T::Type)
    # `T` is a type, so the check must be a subtype test, not `isa`
    S = T <: SimpleSample ? T : Sample{T}

    if hasmethod(f, (S,))
        return FuncLogDensity{typeof(f),S}(f)
    end

    X = _wrappedtype(S)
    if X !== nothing && hasmethod(f, (X,))
        # adapt a function on raw values to the sample-based interface
        g = x -> f(x.sample)
        return FuncLogDensity{typeof(g),S}(g)
    end

    throw(ArgumentError("Function $f does not support $S argument."))
end
# Evaluate a function-backed log density by calling the wrapped function on the sample.
function evaluate(d::FuncLogDensity{F,S}, x::S) where {F<:Function,S}
    return d.f(x)
end