"""
    MetropolisHastings(proposal)

Standard Metropolis-Hastings sampler driven by the proposal `proposal`.

Each transition draws a candidate `y` via `propose(rng, proposal, x)` from the current
state `x` and accepts it with probability `min(1, exp(f(y) - f(x) + q))`, where `f` is
`logdensity(model, ⋅)` and `q` is the correction term `logpratio(proposal, x, y)`.

# References
* Hastings, W.K. (1970). "Monte Carlo Sampling Methods Using Markov Chains and Their
  Applications". Biometrika, Volume 57, Issue 1
"""
struct MetropolisHastings{P <: AbstractProposal} <: RejectionBasedSampler
    proposal::P
end

# Initialize the samples container.
#
# The chain is stored run-length encoded: `states[i]` (with log-density `logprobs[i]`)
# occupies `1 + rejections[i]` consecutive chain positions. Element types are fixed by the
# first saved sample; later states/log-densities are converted to them by `push!`.
function AbstractMCMC.samples(
    sample,
    model::AbstractModel,
    sampler::MetropolisHastings;
    kwargs...
)
    _, x, f_x = sample
    return (;
        states = typeof(x)[],
        logprobs = typeof(f_x)[],
        rejections = Int[],
    )
end

# Store one `(accept, x, f_x)` sample in the run-length-encoded container and return it.
function AbstractMCMC.save!!(
    samples,
    sample,
    ::Integer,
    ::AbstractModel,
    ::RejectionBasedSampler;
    kwargs...
)
    accept, x, f_x = sample
    # A rejection normally means the chain stayed at the last stored state, so only that
    # state's repeat counter grows. `x` must be stored as a new state instead when the
    # container is still empty, or when `x` is not the last stored state. Both can happen
    # when steps run between saved samples without being saved. Presumably that is
    # AbstractMCMC's `discard_initial`/`thinning`; confirm against the driver.
    if accept || isempty(samples.states) || samples.states[end] !== x
        push!(samples.states, x)
        push!(samples.logprobs, f_x)
        push!(samples.rejections, 0)
    else
        samples.rejections[end] += 1
    end
    return samples
end

# Initial step: draw the starting state from the proposal unless `x0` is given.
# Evaluate its log-density unless `f0` is given together with `x0`.
# The initial sample is always reported as accepted.
function AbstractMCMC.step(
    rng::AbstractRNG,
    model::AbstractModel,
    sampler::MetropolisHastings;
    x0=nothing,
    f0=nothing,
    kwargs...
)
    resample = isnothing(x0)
    x = resample ? rand(rng, sampler.proposal) : x0
    # `f0` is only trusted when it belongs to a user-supplied `x0`.
    f_x = (resample || isnothing(f0)) ? logdensity(model, x) : f0
    return (true, x, f_x), (x, f_x)
end

# Transition step: propose a candidate and accept or reject it.
# Returns `(accept, state, logdensity)` as the sample and `(state, logdensity)` as the state.
function AbstractMCMC.step(
    rng::AbstractRNG,
    model::AbstractModel,
    sampler::MetropolisHastings,
    state;
    kwargs...
)
    x, f_x = state

    y = propose(rng, sampler.proposal, x)
    f_y = logdensity(model, y)

    # Log acceptance probability, capped at log(1) = 0.
    q_ratio = logpratio(sampler.proposal, x, y)
    logA = min(f_y - f_x + q_ratio, 0)

    # Accept with probability exp(logA); a NaN `logA` compares false, i.e. rejects.
    if log(rand(rng)) < logA
        return (true, y, f_y), (y, f_y)
    else
        return (false, x, f_x), (x, f_x)
    end
end

# When a NamedTuple chain is requested, the run-length-encoded container is returned as is.
function AbstractMCMC.bundle_samples(
    samples,
    ::AbstractModel,
    ::MetropolisHastings,
    state,
    chain_type::Type{<:NamedTuple};
    kwargs...
)
    return samples
end

# Summarize a run-length-encoded chain:
# - `:chain_length`: total chain positions (stored states plus their repeats);
# - `:rejection_rate`: fraction of those positions that were rejections
#   (NaN for an empty chain, since it is 0 / 0);
# - `:total_costs` (only for `SampledLogDensity` models): `chain_length * length(model)`.
function _chain_info(samples, model::AbstractModel, ::MetropolisHastings)
    n_reject = sum(samples.rejections)
    # Equals sum(1 .+ samples.rejections) without allocating a temporary array.
    n_total = length(samples.rejections) + n_reject
    info = Dict{Symbol,Real}(
        :chain_length => n_total,
        :rejection_rate => n_reject / n_total,
    )
    # Runtime check rather than dispatch: it does not require `SampledLogDensity` to be
    # defined before this method is.
    if model isa SampledLogDensity
        info[:total_costs] = n_total * length(model)
    end
    return info
end