Skip to content
Snippets Groups Projects
metropolis_hastings.jl 1.97 KiB
Newer Older
"""
    MetropolisHastings(proposal)

Standard Metropolis-Hastings sampler.

The initial state is drawn with `rand(rng, proposal)`. Each later step draws a
candidate with `propose(rng, proposal, x)` and accepts or rejects it using the
Metropolis-Hastings rule. In that rule, `logpratio(proposal, x, y)` is added to
the log-density difference as the proposal correction term.

# Fields
- `proposal`: the proposal mechanism, any subtype of `AbstractProposal`.

# References
- Hastings, W. K. (1970). "Monte Carlo Sampling Methods Using Markov Chains and
  Their Applications". Biometrika, Volume 57, Issue 1.
"""
struct MetropolisHastings{Prop<:AbstractProposal} <: RejectionBasedSampler
    proposal::Prop
end


# Initialize samples container
# Create the sample container for a `MetropolisHastings` run, seeded with the initial sample.
#
# `sample` is the `(accepted, x, f_x)` tuple returned by `step`. Only the `(x, f_x)`
# part is kept. The container is a NamedTuple with two fields:
#   - `transitions`: the distinct accepted states, stored as `(x, f_x)` tuples;
#   - `rejections`: `rejections[k]` counts the consecutive rejections that followed
#     `transitions[k]`. It starts at 0 and is incremented by `save!!`.
#
# NOTE(review): AbstractMCMC's generic sampling loop also passes the initial sample to
# `save!!` after creating this container. If that loop is the caller, the initial state
# may be recorded twice. Confirm against how sampling is driven in this package.
function AbstractMCMC.samples(sample, model::AbstractModel, sampler::MetropolisHastings; kwargs...)
    initial_transition = Base.tail(sample)   # drop the leading `accepted` flag
    container = (; rejections = [0], transitions = [initial_transition])
    return container
end

# Store sample to container
# Record one `(accepted, x, f_x)` sample into the container built by `samples`.
#
# An accepted sample starts a new entry: its `(x, f_x)` is appended to `transitions`
# with a fresh rejection count of 0. A rejected sample only increments the rejection
# count of the most recent accepted state. The container is mutated in place and
# returned.
function AbstractMCMC.save!!(samples, sample, ::Integer, ::AbstractModel, ::RejectionBasedSampler; kwargs... )
    accepted = first(sample)
    if accepted
        push!(samples.rejections, 0)
        push!(samples.transitions, Base.tail(sample))
        return samples
    end
    # Rejected: the chain stays at the last accepted state one step longer.
    samples.rejections[end] += 1
    return samples
end

# Initial step
# Initial step: draw a starting point from the proposal and evaluate its log density.
# Returns the sample `(true, x, f_x)`, which is always marked accepted, and the state
# `(x, f_x)` that is carried into the next step.
function AbstractMCMC.step(rng::AbstractRNG, model::AbstractModel, sampler::MetropolisHastings; kwargs...)
    start = rand(rng, sampler.proposal)
    state = (start, logdensity(model, start))
    return (true, state...), state
end

# Accept / Reject step
# Metropolis-Hastings transition from the current state `(x, f_x)`.
#
# A candidate `y` is proposed and its log density evaluated. The log acceptance
# probability is `min(f_y - f_x + logpratio(proposal, x, y), 0)`, and the candidate is
# accepted when `log(u)` falls below it, where `u = rand(rng)`. Returns the sample
# `(accepted, x', f_x')` and the next state `(x', f_x')`. On rejection `x'` is the
# unchanged current point.
function AbstractMCMC.step(rng::AbstractRNG, model::AbstractModel, sampler::MetropolisHastings, state; kwargs...)
    current, logp_current = state

    # Draw the candidate first; the uniform draw below must come after it on `rng`.
    candidate = propose(rng, sampler.proposal, current)
    logp_candidate = logdensity(model, candidate)
    log_correction = logpratio(sampler.proposal, current, candidate)

    log_alpha = min(logp_candidate - logp_current + log_correction, 0)
    accepted = log(rand(rng)) < log_alpha

    next_point, logp_next = accepted ? (candidate, logp_candidate) : (current, logp_current)
    return (accepted, next_point, logp_next), (next_point, logp_next)
end

# Total cost of a chain for a sampled log density: the stored `:chain_length` entry of
# `c.info` scaled element-wise by `length(m)`.
# NOTE(review): presumably `length(m)` is the cost of one density evaluation (e.g. the
# number of underlying samples). Confirm against `SampledLogDensity`.
function total_costs(c::RejectionChains, m::SampledLogDensity)
    lengths = c.info[:chain_length]
    return lengths .* length(m)
end

function AbstractMCMC.bundle_samples(samples, m::AbstractModel, ::MetropolisHastings, state, chain_type::Type; kwargs...)
    s = getindex.(samples.transitions, 1)
    l = getindex.(samples.transitions, 2)
    r = samples.rejections
    c = RejectionChains(s,l,r)
    if m isa SampledLogDensity
        c.info[:total_costs] = total_costs(c,m)