Skip to content
Snippets Groups Projects
Commit 3ed30a51 authored by Luca Lenz's avatar Luca Lenz
Browse files

tested MH sampler, added multilevel state and proposal

parent 9c50c233
No related branches found
No related tags found
No related merge requests found
......@@ -2,9 +2,12 @@
module MultilevelChainSampler
include("base.jl")
export value, propose, evaluate
export value, propose, is_symmetric, evaluate
export Sample
export StaticProposal, RandomWalk, CyclicWalk
export LogDensity
include("hastings.jl")
export MetropolisHastings
end
\ No newline at end of file
......@@ -23,7 +23,9 @@ value(s::Sample) = s.sample
abstract type AbstractProposal{S<:AbstractSample,issymmetric} end
SymmetricProposal = AbstractProposal{S,true} where {S<:AbstractSample}
AsymmetricProposal = AbstractProposal{S,false} where {S<:AbstractSample}
Base.eltype(p::AbstractProposal{S, issymmetric}) where {S, issymmetric} = S
is_symmetric(p::AbstractProposal{S,issymmetric}) where {S, issymmetric} = issymmetric
"""
To define a proposal implement the sampling methods
- initialization x₀ ~ p₀(⋅) by `x₀ = propose(p)`
......
using AbstractMCMC
using Distributions
using Random
include("base.jl")
struct DelayedAcceptance{P <: Tuple{Vararg{<:AbstractProposal}}} <: AbstractMCMC.AbstractSampler
end
\ No newline at end of file
include("base.jl")
#include("propose.jl")
using AbstractMCMC
using Distributions
using Random
......@@ -11,8 +8,8 @@ struct MetropolisHastings{P <: AbstractProposal} <: AbstractMCMC.AbstractSampler
end
function AbstractMCMC.step(rng::AbstractRNG, model::AbstractLogDensity, s::MetropolisHastings; kwargs...)
sample = initialize(rng, s.proposal)
logp = logdensity(model, sample)
sample = propose(rng, s.proposal)
logp = evaluate(model, sample)
state = (sample, logp, true)
return sample, state
end
......@@ -20,8 +17,8 @@ end
function AbstractMCMC.step(rng::AbstractRNG, model::AbstractLogDensity, s::MetropolisHastings, state; kwargs...)
sample, logp, accept = state
new_sample = propose(s.proposal, sample)
new_logp = logdensity(model, new_sample)
q_logratio = logdensityratio(s.proposal, new_sample, sample)
new_logp = evaluate(model, new_sample)
q_logratio = logpdfratio(s.proposal, new_sample, sample)
if log(rand(rng)) < new_logp - logp + q_logratio
return new_sample, (new_sample, new_logp, true)
......@@ -30,10 +27,3 @@ function AbstractMCMC.step(rng::AbstractRNG, model::AbstractLogDensity, s::Metro
end
end
## =============== tests ===========================
#using Test
#p = StaticProposal(Uniform(-1, 1))
#s = MetropolisHastings(p)
#m = LogDensity(Normal(0, 1))
#@time c = sample(m, s, 100000)
\ No newline at end of file
## Multi-level state
abstract type HierarchicalState <: AbstractState end
#=
# by default, the multilevel state wraps a tuple of states
struct MultilevelState{T <: Tuple} <: HierarchicalState
levels :: T
end
MultilevelState(levels...; kwargs...) = MultilevelState(levels)
course(::MultilevelState) = MultilevelState(s.levels[1:end-1])
fine(s::MultilevelState) = s.levels[end]
combine(s::MultilevelState, fine_state) = MultilevelState(s.levels..., fine_state)
level(s::MultilevelState) = length(s.levels)
# Multi-level log density
abstract type HierarchicalLogDensity{S <: HierarchicalState} <: AbstractLogDensity{S} end
struct MultilevelLogDensity{F, S} <: HierarchicalLogDensity{S}
logdensities :: F
end
logdensity(d::HierarchicalLogDensity, state::HierarchicalState; level=length(state)) = logdensity(d, state)
=#
#level(s::MultilevelState) = length(s.states)
#Base.eltype(states::MultilevelState) = eltype(states.states)
## construct from single level states
#MultilevelState(states::State ... ) = MultilevelState(getindex.([states ... ], :state))
include("base.jl")
## recursive construction
#course(states::MultilevelState) = MultilevelState(states[1:end-1])
#fine(states::MultilevelState) = states[end]
#combine(states::MultilevelState, state) = MultilevelState(vcat(states.states ... , state))
# For hierarchical multilevel implement these methods
function combine end
function course end
function fine end
# === Multilevel sample
abstract type HierarchicalSample <: AbstractSample end
#=
## Multi-level log density
abstract type HierarchicalLogDensity{S} <: AbstractLogDensity{S} end
struct MultilevelLogDensity{F, S} <: HierarchicalLogDensity{S}
logdensities :: F
struct MultilevelSample{S <: Tuple{Vararg{<:SimpleSample}}} <: HierarchicalSample
samples :: S
end
logdensity(d::MultilevelLogDensity, state::MultilevelState) = logdensity(d.logdensities, state.states
MultilevelSample(mls::SimpleSample...) = MultilevelSample{typeof(mls)}(mls)
Base.length(s::MultilevelSample{S}) where {S} = length(S.types)
combine(mls::MultlevelSample, fine_sample::SimpleSample) = MultilevelSample(mls.samples..., fine_sample)
combine(samples::SimpleSample ... ) = MultlevelSample(samples... )
course(mls::MultilevelSample) = length(mls) == 2 ? mls.samples[1] : MultilevelSample(mls.samples[1:end-1]...)
fine(mls::MultilevelSample) = mls.samples[end]
#logdensity(d::HierarchicalLogDensity, state::HierarchicalState, level::Int)
function logdensity(d::MultilevelLogDensity, state::HierarchicalState, level::Int)
logdensity(d.logdensities[level], state.states[level]
## === Multilevel proposal
abstract type HierarchicalProposal{S <: HierarchicalState, issymmetric} <: AbstractProposal{S, issymmetric} end
struct MultilevelProposal{P <: Tuple{Vararg{<:AbstractProposal}}, S, issymmetric} <: HierarchicalProposal{S, issymmetric}
proposals :: P
end
#MultilevelLogDensity(logdensities::AbstractVector{<:AbstractLogDensity}) = MultilevelLogDensity(logdensities
=#
function MultilevelProposal(proposals::AbstractProposal...)
issymmetric = all(is_symmetric.(proposals))
MultilevelProposal{typeof(proposals), issymmetric}(proposals)
end
Base.length(mlp::MultilevelProposal) = length(mlp.proposals)
#abstract type AbstractMultilevelLogDensity <: LogDensity{S} end
combine(mlp::MulilevelProposal, fine_proposal::SimpleProposal) = MultilevelProposal(mlp.proposals..., fine_proposal)
course(mlp::MultilevelProposal) = MultilevelProposal(mlp.proposals[1:end-1] ... )
fine(mlp::MultilevelProposal) = mlp.proposals[end]
#abstract type AbstractMultilevelLogDensity <: LogDensity{S} end
propose(rng::AbstractRNG, mlp::MultilevelProposal) = MultlevelSample([ propose(rng, p) for p in mlp.proposals]...)
propose(rng::AbstractRNG, mlp::MultilevelProposal{S}, mls::S) where {S<:MultilevelSample} = S([ propose(rng, p, s) for (p,s) in zip(mlp.proposals, mls.samples)]...)
logpdfratio(mlp::MultilevelProposal{S, false}, X::S, Y::S) where {S <: MultilevelSample} = sum([logpdfratio(p, x, y) for (p,x,y) in zip(mlp.proposals, x.samples, y.samples)])
#struct MultilevelLogDensity <: AbstractMultilevelLogDensity
#end
## === Multilevel log density
#abstract type HierarchicalLogDensity{S <: HierarchicalState} <: AbstractLogDensity{S} end
......@@ -36,3 +36,13 @@ d = LogDensity(Uniform(0,1))
@test evaluate(d, Sample(0.0)) == d(Sample(0.0))
end
@testset "hastings" begin
p = RandomWalk(Dirac(0.0), Uniform(-.1, .1))
mh = MetropolisHastings(p)
t = LogDensity(Normal(0,1))
c = sample(t, mh, 100000)
c = value.(c)
@test ( mean(c), 0.0 , atol=0.1 )
@test ( std(c), 1.0 , atol=0.1 )
end
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment