Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""
    MetropolisHastings(proposal)

Standard Metropolis-Hastings sampler whose candidate states are generated by
`proposal`.

# References
* Hastings, W.K. (1970). "Monte Carlo Sampling Methods Using Markov Chains and Their
  Applications". Biometrika, Volume 57, Issue 1.
"""
struct MetropolisHastings{P<:AbstractProposal} <: RejectionBasedSampler
    # Proposal used to draw the initial state (`rand`), to propose candidates
    # (`propose`) and to compute the acceptance correction (`logpratio`).
    proposal::P
end
# Create the empty container that `save!!` fills while sampling.
# `sample` is the first step's output `(accepted, state, logprob)`; only its state and
# log density are inspected, to fix the element types of the `states` and `logprobs`
# vectors. `rejections[i]` will count how many proposals were rejected while the chain
# sat at `states[i]`.
function AbstractMCMC.samples(
    sample, model::AbstractModel, sampler::MetropolisHastings; kwargs...
)
    _, first_state, first_logprob = sample
    StateType = typeof(first_state)
    LogprobType = typeof(first_logprob)
    return (; states=StateType[], logprobs=LogprobType[], rejections=Int[])
end
# Record one step's output `sample = (accepted, state, logprob)` into the container
# `samples` (fields `states`, `logprobs`, `rejections`) and return it.
# Only accepted states are stored. A rejected step does not duplicate the current
# state; instead it increments the rejection counter of the most recently stored one.
# Throws an `ArgumentError` if a rejection arrives before any state has been stored,
# because there is then no state to attribute it to.
function AbstractMCMC.save!!(
    samples, sample, ::Integer, ::AbstractModel, ::RejectionBasedSampler; kwargs...
)
    accepted, x, f_x = sample
    if accepted
        push!(samples.states, x)
        push!(samples.logprobs, f_x)
        push!(samples.rejections, 0)
    else
        # Without this check, `rejections[end]` on an empty vector gives an opaque
        # BoundsError.
        isempty(samples.rejections) && throw(ArgumentError(
            "cannot record a rejected step before any state has been accepted",
        ))
        samples.rejections[end] += 1
    end
    return samples
end
# Initial step of the chain.
# The starting state is `x0`; if `x0` is not given, it is drawn from the proposal with
# `rand(rng, sampler.proposal)`. The supplied log density `f0` is reused only when both
# `x0` and `f0` are given; otherwise `logdensity(model, x)` is evaluated.
# The first sample is always reported as accepted.
# Returns the sample `(true, x, f_x)` and the state `(x, f_x)` for the next step.
function AbstractMCMC.step(
    rng::AbstractRNG, model::AbstractModel, sampler::MetropolisHastings;
    x0=nothing, f0=nothing, kwargs...
)
    if isnothing(x0)
        # A freshly drawn state has no known log density, so `f0` is ignored here.
        x = rand(rng, sampler.proposal)
        f_x = logdensity(model, x)
    else
        x = x0
        f_x = isnothing(f0) ? logdensity(model, x) : f0
    end
    return (true, x, f_x), (x, f_x)
end
# One Metropolis-Hastings transition from `state = (x, f_x)`.
# 1. Draw a candidate `y` from the proposal and evaluate its log density.
# 2. Accept `y` with probability `min(1, exp(f_y - f_x + q_ratio))`. Here `q_ratio` is
#    `logpratio(sampler.proposal, x, y)`, presumably the log proposal-density
#    correction for asymmetric proposals (confirm against its definition).
# Returns the sample `(accepted, state, logprob)` and the next state
# `(state, logprob)`. On rejection the chain stays at `(x, f_x)`.
function AbstractMCMC.step(
    rng::AbstractRNG, model::AbstractModel, sampler::MetropolisHastings, state;
    kwargs...
)
    x, f_x = state
    proposal = sampler.proposal

    y = propose(rng, proposal, x)
    f_y = logdensity(model, y)
    log_alpha = min(f_y - f_x + logpratio(proposal, x, y), 0)

    # Work on the log scale: accept iff log(u) < log(alpha), with u ~ U[0, 1).
    accepted = log(rand(rng)) < log_alpha
    next_state = accepted ? (y, f_y) : (x, f_x)
    return (accepted, next_state...), next_state
end
# When a `NamedTuple` chain type is requested, return the collected samples container
# itself; no conversion is performed.
function AbstractMCMC.bundle_samples(
    samples, ::AbstractModel, ::MetropolisHastings, state, ::Type{<:NamedTuple};
    kwargs...
)
    return samples
end
# Summarize a chain stored in the `(; states, logprobs, rejections)` container.
# Each stored state corresponds to one accepted step plus `rejections[i]` rejected
# steps, so the total number of steps is `length(rejections) + sum(rejections)`.
#
# Returns a `Dict` with:
# - `:chain_length`   total number of steps (accepted + rejected);
# - `:rejection_rate` fraction of steps that were rejections (`NaN` for an empty
#   chain, since it is 0 / 0);
# - `:total_costs`    only when `model isa SampledLogDensity`: `chain_length` times
#   `length(model)` (presumably the per-evaluation cost of the sampled density;
#   confirm).
function _chain_info(samples, model::AbstractModel, ::MetropolisHastings)
    rejections = samples.rejections
    n_reject = sum(rejections)
    # Same value as `sum(1 .+ rejections)`, without allocating a temporary vector.
    n_total = length(rejections) + n_reject
    info = Dict(
        :chain_length => n_total,
        :rejection_rate => n_reject / n_total,
    )
    if model isa SampledLogDensity
        info[:total_costs] = n_total * length(model)
    end
    return info
end