Skip to content
Snippets Groups Projects
metropolis_hastings.jl 2.43 KiB
Newer Older
"""
 Standard Metropolis-Hastings algorithm

 References
    
    * Hastings, W.K. (1970). "Monte Carlo Sampling Methods Using Markov Chains and Their Applications". 
      Biometrika, Volume 57, Issue 1

""" 
struct MetropolisHastings{P <: AbstractProposal} <: RejectionBasedSampler
    proposal :: P
end 

# Build the empty sample container for a Metropolis-Hastings chain.
#
# `sample` is the `(accepted, state, logprob)` tuple returned by `step`. Only the
# types of `state` and `logprob` are used here: they fix the element types of the
# container's vectors, which `save!!` then appends to.
function AbstractMCMC.samples(
    sample, model::AbstractModel, sampler::MetropolisHastings; kwargs...
)
    _, state, logprob = sample
    return (
        states = Vector{typeof(state)}(),
        logprobs = Vector{typeof(logprob)}(),
        rejections = Vector{Int}(),
    )
end

# Record one `step` output `(accepted, x, f_x)` in the container built by
# `AbstractMCMC.samples`.
#
# Storage layout:
# - each distinct state is stored only once, in `states[i]` / `logprobs[i]`;
# - `rejections[i]` counts how many rejected proposals followed it while the
#   chain stayed at that state.
# The number of recorded samples is therefore `sum(1 .+ rejections)`.
#
# A rejected sample also starts a new entry when the container is still empty.
# Without this, `rejections[end]` would be indexed on an empty vector and throw a
# `BoundsError`. That case arises when the first *saved* sample is a rejection,
# e.g. when initial steps are discarded before saving begins. A rejected sample
# carries the current chain position `x`, so recording it as a fresh entry keeps
# the sample count correct.
#
# NOTE(review): this assumes every saved sample directly follows the previously
# saved one. If intermediate steps are skipped (thinning), a rejection may be
# charged to a state other than the current one — confirm how callers drive this.
function AbstractMCMC.save!!(
    samples, sample, ::Integer, ::AbstractModel, ::RejectionBasedSampler; kwargs...
)
    accepted, x, f_x = sample

    if accepted || isempty(samples.rejections)
        # new chain position (or first entry of the container)
        push!(samples.states, x)
        push!(samples.logprobs, f_x)
        push!(samples.rejections, 0)
    else
        # chain stayed at the last stored state
        samples.rejections[end] += 1
    end
    return samples
end


# Initial transition of a Metropolis-Hastings chain.
#
# Keywords:
# - `x0`: starting state. When `nothing`, a state is drawn with
#   `rand(rng, sampler.proposal)`.
# - `f0`: log density at `x0`. It is used only when `x0` is also supplied;
#   otherwise `logdensity(model, x)` is evaluated.
#
# The initial state always counts as accepted. Returns `(sample, state)` with
# `sample = (true, x, f_x)` and `state = (x, f_x)`.
function AbstractMCMC.step(
    rng::AbstractRNG, model::AbstractModel, sampler::MetropolisHastings;
    x0=nothing, f0=nothing, kwargs...
)
    if isnothing(x0)
        # No user-supplied start: draw one and ignore any given `f0`.
        x = rand(rng, sampler.proposal)
        f_x = logdensity(model, x)
    else
        x = x0
        f_x = isnothing(f0) ? logdensity(model, x) : f0
    end

    return (true, x, f_x), (x, f_x)
end

# One Metropolis-Hastings transition from `state = (x, f_x)`, where `f_x` is the
# cached log density at `x`.
#
# Steps:
# 1. Draw a candidate with `propose(rng, sampler.proposal, x)`.
# 2. Accept it with probability `min(1, exp(f_y - f_x + q))`, where `q` is
#    `logpratio(sampler.proposal, x, y)`.
#    NOTE(review): `q` is presumably log q(x | y) - log q(y | x); confirm against
#    the proposal implementations.
# 3. On rejection, return the old state and its cached log density unchanged, so
#    `logdensity` is evaluated exactly once per transition.
#
# Returns `(sample, new_state)` with `sample = (accepted, state, logprob)` and
# `new_state = (state, logprob)`.
function AbstractMCMC.step(
    rng::AbstractRNG, model::AbstractModel, sampler::MetropolisHastings, state;
    kwargs...
)
    x, f_x = state

    candidate = propose(rng, sampler.proposal, x)
    f_candidate = logdensity(model, candidate)

    log_ratio = f_candidate - f_x + logpratio(sampler.proposal, x, candidate)
    log_accept = min(log_ratio, 0)

    # u ~ U[0, 1) is accepted iff log(u) < log acceptance probability.
    accepted = log(rand(rng)) < log_accept
    next_x, next_f = accepted ? (candidate, f_candidate) : (x, f_x)

    return (accepted, next_x, next_f), (next_x, next_f)
end


# Requesting a `NamedTuple` chain type returns the sample container unchanged.
# For this sampler the container is the `(; states, logprobs, rejections)`
# NamedTuple produced by `AbstractMCMC.samples`.
AbstractMCMC.bundle_samples(
    samples, ::AbstractModel, ::MetropolisHastings, state,
    chain_type::Type{<:NamedTuple}; kwargs...
) = samples

# Summary statistics for a Metropolis-Hastings chain stored in the compressed
# `(; states, logprobs, rejections)` layout, where `rejections[i]` counts the
# rejected proposals that followed the i-th stored state.
#
# Returns a `Dict` with:
# - `:chain_length`: total number of recorded samples, i.e. stored states plus
#   rejections;
# - `:rejection_rate`: fraction of recorded samples that were rejections. This is
#   `NaN` for an empty container, since it evaluates `0 / 0`;
# - `:total_costs`: only present when `model isa SampledLogDensity`, and equal to
#   `chain_length * length(model)`.
function _chain_info(samples, model::AbstractModel, ::MetropolisHastings)
    n_reject = sum(samples.rejections)
    # Each stored state contributes one sample plus its rejections. This equals
    # `sum(1 .+ rejections)` without materialising the temporary array.
    n_total = length(samples.rejections) + n_reject

    info = Dict(
        :chain_length   => n_total,
        :rejection_rate => n_reject / n_total,
    )

    if model isa SampledLogDensity
        # Scalar product; broadcasting (`.*`) is unnecessary here.
        info[:total_costs] = n_total * length(model)
    end

    return info
end