Skip to content
Snippets Groups Projects
Commit 1b3dd496 authored by Luca Lenz's avatar Luca Lenz
Browse files

updated sampling, added chains container

parent 0924d12e
No related branches found
No related tags found
No related merge requests found
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
julia_version = "1.10.0-rc2" julia_version = "1.10.0-rc2"
manifest_format = "2.0" manifest_format = "2.0"
project_hash = "bc36870fa7535d5a2c031fdc22e8505968678ceb" project_hash = "12bd3e4726969de53e935d129f12bb2195eaf926"
[[deps.ADTypes]] [[deps.ADTypes]]
git-tree-sha1 = "41c37aa88889c171f1300ceac1313c06e891d245" git-tree-sha1 = "41c37aa88889c171f1300ceac1313c06e891d245"
......
...@@ -5,6 +5,7 @@ version = "0.1.0" ...@@ -5,6 +5,7 @@ version = "0.1.0"
[deps] [deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2" GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2"
......
using Revise
using MultilevelChainSampler
include("proxies.jl")
include("analysis.jl")
include("visualize.jl")
module MultilevelChainSampler module MultilevelChainSampler
#using AdvancedMH using AbstractMCMC: AbstractMCMC, AbstractModel
#using AdvancedMH: DensityModelOrLogDensityModel using BangBang
using AbstractMCMC: AbstractMCMC, AbstractModel, AbstractSampler, sample
using Distributions using Distributions
using Random, StatsBase using Random, StatsBase
using Primes, HaltonSequences using Primes, HaltonSequences
import Base: rand, length import Base: rand, length, size
import Distributions: logpdf import Distributions: logpdf
import StatsBase: sample import StatsBase: sample
#import AdvancedMH.logdensity
#export CyclicWalk
#export SampledDensityModel
export LogDensity, SampledLogDensity export LogDensity, SampledLogDensity
export MultilevelLogDensity, MultilevelSampledLogDensity
export logdensity export logdensity
export propose, logpratio export propose, logpratio
export MetropolisHastings export MetropolisHastings, ChristenFox
export RandomWalk, CyclicWalk export RandomWalk, CyclicWalk
export sample export sample
export DefaultChains
include("models/all.jl") include("models/all.jl")
include("proposals/all.jl") include("proposals/all.jl")
include("algos/metropolis.jl") include("chains/all.jl")
#include("models/sample_based.jl") include("algos/all.jl")
end end
......
abstract type RejectionBasedSampler <: AbstractMCMC.AbstractSampler end
# Ignore chain length by default
function AbstractMCMC.save!!(samples, sample, iterations::Integer, model::AbstractModel, sampler::RejectionBasedSampler, ::Integer; kwargs...)
AbstractMCMC.save!!(samples, sample, iterations, model, sampler; kwargs...)
end
# Size hint sample container chain length
function AbstractMCMC.samples(sample, model::AbstractModel, sampler::RejectionBasedSampler, N::Integer; kwargs...)
samples = AbstractMCMC.samples(sample, model, sampler; kwargs...)
T = typeof(samples)
for (k,t) in zip(T.parameters[1], T.types)
if t <: AbstractVector
sizehint!(getfield(samples, k), N)
end
end
return samples
end
include("metropolis_hastings.jl")
include("christen_fox.jl")
"""
Delayed Acceptance algorithm
If `saveproxies == true` save accepted states on lower levels.
This is particularly useful for MLMC integration.
References
* Christen, J.A. and Fox, C. (2005). "Markov chain Monte Carlo Using an Approximation"
Journal of Computational and Graphical Statistics
"""
struct ChristenFox{saveproxies, P <: AbstractProposal} <: RejectionBasedSampler
proposal :: P
end
ChristenFox(proposal::AbstractProposal, saveproxies::Bool=false) = ChristenFox{saveproxies, typeof(proposal)}(proposal)
# Initialize samples container
function AbstractMCMC.samples(sample, model::AbstractMultilevelModel, ::ChristenFox{false, P}; kwargs...) where {P}
return (; rejections=[zeros(Int, length(model))], transitions=[sample[2:end]])
end
# Store sample to container
function AbstractMCMC.save!!(samples, sample, ::Integer, model::AbstractMultilevelModel, ::ChristenFox{false, P}; kwargs... ) where {P}
if sample[1] == length(model)+1 # sample was accepted
push!(samples.rejections, zeros(Int, length(model)))
push!(samples.transitions, sample[2:end])
else
samples.rejections[end][sample[1]] += 1
end
return samples
end
# Chain step
function AbstractMCMC.step(rng::AbstractRNG, model::AbstractMultilevelModel, sampler::ChristenFox{false, P}; kwargs...) where {P}
x = rand(rng, sampler.proposal)
f_x = [ logdensity(model, x, i) for i=1:length(model) ]
return (length(model), x, f_x[end]), (x, f_x)
end
function AbstractMCMC.step(rng::AbstractRNG, model::AbstractMultilevelModel, sampler::ChristenFox{false, P}, state; kwargs...) where {P}
x, f_x = state
y = propose(rng, sampler.proposal, x)
f_y = [ logdensity(model, y, 1) ]
q = logpratio(sampler.proposal, x, y)
A_1 = min( f_y[1] - f_x[1] + q, 0)
accept = log(rand(rng)) < A_1
if !accept return (1, x, f_x[end]), (x, f_x) end
for l = 2:length(model)
push!(f_y, logdensity(model, y, l) )
A_l = f_y[l] - f_x[l] + f_y[l-1] - f_x[l-1]
accept = log(rand(rng)) < A_l
if !accept return (l, y, f_y[end]), (x, f_x) end
end
return (length(model)+1, y, f_y[end]), (y, f_y)
end
# Collect samples
function AbstractMCMC.bundle_samples(samples, m::AbstractMultilevelModel, ::ChristenFox{false,P}, state, chain_type::Type; kwargs ... ) where {P}
states = getindex.(samples.transitions, 1)
logprobs = getindex.(samples.transitions, 2)
info = Dict()
N = sum( sum.(samples.rejections) .+ 1)
info[:rejection_rate] = sum(samples.rejections) ./ N
if m isa MultilevelSampledLogDensity
nevals = N .- cumsum([0, sum(samples.rejections)[1:end-1]...])
info[:evaluations] = sum( m.nlevels .* nevals )
end
return SimpleChains(states, logprobs, samples.rejections; info...)
end
## Modified to save lower level accepted steps
abstract type RejectionBasedSampler <: AbstractSampler end
struct MetropolisHastings{P <: AbstractProposal} <: RejectionBasedSampler
proposal :: P
end
function AbstractMCMC.step(rng::AbstractRNG, model::AbstractModel, sampler::MetropolisHastings; kwargs...)
x = rand(rng, sampler.proposal)
f_x = logdensity(model, x)
return (true, x, f_x), (x, f_x)
end
function AbstractMCMC.step(rng::AbstractRNG, model::AbstractModel, sampler::MetropolisHastings, state; kwargs...)
x, f_x = state
y = propose(rng, sampler.proposal, x)
f_y = logdensity(model, y)
q = logpratio(sampler.proposal, x, y)
A = min( f_y - f_x + q, 0)
accept = log(rand(rng)) < A
if accept
return (true, y, f_y), (y, f_y)
else
return (false, x, f_x), (x, f_x)
end
end
\ No newline at end of file
"""
Standard Metropolis-Hastings algorithm
References
* Hastings, W.K. (1970). "Monte Carlo Sampling Methods Using Markov Chains and Their Applications".
Biometrika, Volume 57, Issue 1
"""
struct MetropolisHastings{P <: AbstractProposal} <: RejectionBasedSampler
proposal :: P
end
# Initialize samples container
function AbstractMCMC.samples(sample, model::AbstractModel, sampler::MetropolisHastings; kwargs...)
return (; rejections=[0], transitions=[sample[2:end]])
end
# Store sample to container
function AbstractMCMC.save!!(samples, sample, ::Integer, ::AbstractModel, ::RejectionBasedSampler; kwargs... )
if sample[1] # accepted
push!(samples.rejections, 0)
push!(samples.transitions, sample[2:end])
else # rejected
samples.rejections[end] += 1
end
return samples
end
# Initial step
function AbstractMCMC.step(rng::AbstractRNG, model::AbstractModel, sampler::MetropolisHastings; kwargs...)
x = rand(rng, sampler.proposal)
f_x = logdensity(model, x)
return (true, x, f_x), (x, f_x)
end
# Accept / Reject step
function AbstractMCMC.step(rng::AbstractRNG, model::AbstractModel, sampler::MetropolisHastings, state; kwargs...)
x, f_x = state
y = propose(rng, sampler.proposal, x)
f_y = logdensity(model, y)
q = logpratio(sampler.proposal, x, y)
A = min( f_y - f_x + q, 0)
if log(rand(rng)) < A
return (true, y, f_y), (y, f_y) # accept
else
return (false, x, f_x), (x, f_x) # reject
end
end
# Collect samples
function AbstractMCMC.bundle_samples(samples, m::AbstractModel, ::MetropolisHastings, state, chain_type::Type; kwargs ... )
states = getindex.(samples.transitions, 1)
logprobs = getindex.(samples.transitions, 2)
info = Dict()
N = sum( sum.(samples.rejections) .+ 1)
info[:rejection_rate] = sum(samples.rejections) ./ N
if m isa SampledLogDensity
info[:evaluations] = N * length(m)
end
return SimpleChains(states, logprobs, samples.rejections; info...)
end
abstract type RejectionBasedChains <: AbstractMCMC.AbstractChains end
include("simple.jl")
#include("multi.jl")
function Base.show(io::IO, c::RejectionBasedChains)
print(io, "Chain ", size(c), " {", eltype(c), "}")
end
function Base.display(c::RejectionBasedChains)
println(c)
info = get_info(c)
nchains = size(c, 2)
for k = keys(info)
if nchains == 1
println(" ", k, " : ", info[k][1])
else
m, s = mean(info[k]), std(info[k])
println(" ", k, " : ", m, " ± ", s)
end
end
end
\ No newline at end of file
struct MultiChains{X, L <: Number} <: RejectionBasedChains
states :: Matrix{X}
logprobs :: Matrix{L}
levels :: Matrix{<:Integer}
info :: Vector{<:NamedTuple}
end
Base.eltype(c::MultiChains{X,L}) where {X,L} = X
Base.length(c::MultiChains) = length(c.states)
function MultiChains(states::Vector, logprobs::Vector, levels::Vector, rejections::Vector; info... )
N = sum(rejections)
repetitions = sum.(rejections) .+ 1
MultiChains(
reshape(vcat(fill.(states, repetitions)), N, 1),
reshape(vcat(fill.(logprobs, repetitions)), N, 1),
reshape(vcat(fill.(levels, repetitions)), N, 1),
[(; rejection_rate = sum(rejections) ./ N, info...) ]
)
end
function AbstractMCMC.chainscat(chains::MultiChains ... )
DefaultChains(
hcat((c.states for c in chains) ... ),
hcat((c.logprobs for c in chains) ... ),
hcat((c.levels for c in chains) ... ),
vcat((c.info for c in chains)...)
)
end
function Base.show(io::IO, c::SimpleChains)
println(io, "Chain ", size(c), " {", eltype(c), "}")
nt = eltype(c.info)
for param in nt.parameters[1]
x = getfield.(c.info, param)
if size(c, 2) == 1
println(io, " ", param, " ", x)
else
println(io, " ", param, " ", mean(x), " ± ", std(x) )
end
end
end
struct SimpleChains{X, L <: Number} <: RejectionBasedChains
states :: Matrix{X}
logprobs :: Matrix{L}
info :: Vector{<:NamedTuple}
end
Base.eltype(c::SimpleChains{X,L}) where {X,L} = X
Base.size(c::SimpleChains, args...) = size(c.states, args...)
function get_info(c::SimpleChains)
fields = fieldnames(eltype(c.info))
(; (f => getfield.(c.info, f) for f=fields) ... )
end
function SimpleChains(states::Vector, logprobs::Vector, rejections::Vector; info... )
repetitions = sum.(rejections) .+ 1
N = sum(repetitions)
states = reshape(vcat(fill.(states, repetitions)...), N, 1)
logprobs = reshape(vcat(fill.(logprobs, repetitions)...), N, 1)
info = [(; rejections = sum(rejections), info...) ]
SimpleChains(states, logprobs, info)
end
function AbstractMCMC.chainscat(chains::SimpleChains ... )
DefaultChains(
hcat((c.states for c in chains) ... ),
hcat((c.logprobs for c in chains) ... ),
vcat((c.info for c in chains)...)
)
end
#=
function Base.show(io::IO, c::SimpleChains)
println(io, "Chain ", size(c), " {", eltype(c), "}")
nt = eltype(c.info)
for param in nt.parameters[1]
x = getfield.(c.info, param)
if size(c, 2) == 1
println(io, " ", param, " ", x)
else
println(io, " ", param, " ", mean(x), " ± ", std(x) )
end
end
end
=#
struct LogDensity{L} <: AbstractModel abstract type AbstractLogDensity <: AbstractModel end
struct LogDensity{L} <: AbstractLogDensity
density :: L density :: L
end end
...@@ -6,4 +8,5 @@ logdensity(model::LogDensity{<:Function}, x) = model.density(x) ...@@ -6,4 +8,5 @@ logdensity(model::LogDensity{<:Function}, x) = model.density(x)
logdensity(model::LogDensity{<:Distribution}, x) = logpdf(model.density, x) logdensity(model::LogDensity{<:Distribution}, x) = logpdf(model.density, x)
logdensity(model::AbstractMCMC.LogDensityModel, x) = LogDensityProblems.logdensity(model.logdensity, x) logdensity(model::AbstractMCMC.LogDensityModel, x) = LogDensityProblems.logdensity(model.logdensity, x)
include("sample_based.jl") include("samplebased.jl")
\ No newline at end of file include("multilevel.jl")
\ No newline at end of file
abstract type AbstractMultilevelModel <: AbstractModel end
struct MultilevelLogDensity{ L <: Tuple{Vararg{<:AbstractModel}}} <: AbstractMultilevelModel
proxies :: L
end
MultilevelLogDensity(v::L... ) where {L <: Tuple{Vararg{<:AbstractModel}}} = MultilevelLogDensity(v)
MultilevelLogDensity(v::Vector{ <: AbstractModel} ) = MultilevelLogDensity( tuple(v...) )
length(m::MultilevelLogDensity) = length(m.proxies)
logdensity(m::MultilevelLogDensity, x, level=length(m)) = logdensity(m.proxies[level], x)
function Base.show(io::IO, m::SampledLogDensity)
println(io, "MultilevelLogDensity ", length(m), " {", eltype(m.samples), "} samples")
end
struct MultilevelSampledLogDensity{X, F <: Function} <: AbstractMultilevelModel
density :: SampledLogDensity{X, F}
nlevels :: Vector{Int}
# @assert maximum(nlevels) <= length(density)
end
MultilevelSampledLogDensity(f::Function, n::Vector{<:Integer}, d::Int=1, a=0, b=1) = MultilevelSampledLogDensity(SampledLogDensity(f, maximum(n), d, a, b), n)
length(m::MultilevelSampledLogDensity) = length(m.nlevels)
logdensity(m::MultilevelSampledLogDensity, x, level=length(m), cache=nothing) = logdensity(m.density, x, level, cache)
logdensity(m::MultilevelSampledLogDensity, x, level::Int, cache::Nothing) = logdensity(m.density, x, m.nlevels[level])
logdensity(m::MultilevelSampledLogDensity, x, level::Int, cache::Number) = logdensity(m.density, x, m.nlevels[level], (level-1, cache))
struct SampledLogDensity{T, F} <: AbstractModel struct SampledLogDensity{T, F <: Function} <: AbstractLogDensity
samples :: Vector{T} samples :: Vector{T}
func :: F func :: F
end end
function SampledLogDensity(f::Function, n::Int=32, d::Int=1, a=0, b=1)
p = shuffle!(primes(d^2+1)[1:d])
z = hcat((Halton(p[i])[1:n] for i=1:d) ... )
z = [Tuple(z[i,:]) for i=1:n]
return SampledLogDensity(z, f)
end
length(m::SampledLogDensity) = length(m.samples) length(m::SampledLogDensity) = length(m.samples)
logdensity(m::SampledLogDensity, x, level=length(m), cache=nothing) = logdensity(m, x, level, cache) logdensity(m::SampledLogDensity, x, level=length(m), cache=nothing) = logdensity(m, x, level, cache)
function logdensity(m::SampledLogDensity, x, level::Int64, cache::Nothing) function logdensity(m::SampledLogDensity, x, level::Int64, cache::Nothing)
y = [m.func(x, z) for z=m.samples[1:level]] y = [m.func(x, z) for z=m.samples[1:level]] #TODO: parallelize evaluations of m.func !
return mean(y) return mean(y)
end end
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment