Skip to content
Snippets Groups Projects
Commit f054e449 authored by Luca Lenz's avatar Luca Lenz
Browse files

fixed multilevel chains; implemented univariate auto correlation

parent ebb358ad
No related branches found
No related tags found
No related merge requests found
......@@ -26,10 +26,12 @@ export saves_reject
export saves_subchain
export has_fixed_length
export autocor
# source files
include("proposal.jl")
include("logdensity.jl")
#include("chains.jl")
include("chains.jl")
include("samplers/metropolis_hastings.jl")
include("samplers/delayed_acceptance.jl")
......
struct MLChains{S <: Vector{<:Vector}, saves_logprob, saves_reject, fixed_length} <: AbstractMCMC.AbstractChains
struct MultilevelChains{S <: Vector{<:Vector}, save_logprob, save_reject, fixed_length} <: AbstractMCMC.AbstractChains
samples::S
end
function MLChains(s::Vector{<:Vector}, saves_logprob::Bool=false, saves_reject::Bool=false, fixed_length::Bool=false)
return MLChains{typeof(s), saves_logprob, saves_reject, fixed_length}(s)
function MultilevelChains(s::Vector{<:Vector}; save_logprob::Bool=false, save_reject::Bool=false, fixed_length::Bool=false)
return MultilevelChains{typeof(s), save_logprob, save_reject, fixed_length}(s)
end
size(c::MLChains) = (length(c.samples) ..., )
Base.size(c::MultilevelChains) = (length(c.samples) ..., )
function subchain_indices(c::MLChains{S, saves_logprob, saves_reject, true}, level::Integer) where {S, saves_logprob, saves_reject}
# TODO: implement AbstractMCMC.chainscat( ... )
function subchain_indices(c::MultilevelChains{S, save_logprob, save_reject, true}, level::Integer) where {S, save_logprob, save_reject}
if level == 1 return 1:length(c.samples[1]) end
shape = length.(c.samples)
......@@ -30,218 +32,41 @@ function subchain_indices(c::MLChains{S, saves_logprob, saves_reject, true}, lev
end
function subchain(c::MultilevelChains{S, saves_logprob, saves_reject, fixed_length}, level::Integer) where {S,saves_logprob, saves_reject, fixed_length}
function subchain(c::MultilevelChains{S, save_logprob, save_reject, fixed_length}, level::Integer) where {S,save_logprob, save_reject, fixed_length}
if level == 1 return c.samples[1] end
# multilevel states
ns = subchain_indices(c, level)
if ! saves_logprob && ! saves_reject && fixed_length
end
x = map( i->first.( c.samples[i][ ns[i] ]), 1:level )
x = [ zip(x...) ... ]
end
#=
struct MultilevelChains{S <: Tuple{Vararg{Vector{<:NamedTuple}}}, fixed_length} <: AbstractMCMC.AbstractChains
samples::S
end
function MultilevelChains(s::Tuple{Vararg{Vector{<:NamedTuple}}})
fixed_length = ! any( hasfield.(eltype.(s), :n))
return MultilevelChains{typeof(s), fixed_length}(s)
end
# Has subchain length :n
function subchain_indices(c::MultilevelChains{S, false}, level::Integer) where {S}
if level == 1 return 1:length(c.samples[1]) end #all lowest level samples
# indices of subchain
ns = [ getfield.(c.samples[level], :n) ]
for j=level-1:-1:1
pushfirst!(ns, getfield.(c.samples[j][ns[1]], :n))
end
return ns
end
level > length(c.samples) && throw(ArgumentError("Level $level is larger than maximal level $(length(c.samples))"))
# Fixed size delayed acceptance
function subchain_indices(c::MultilevelChains{S, true}, level::Integer) where {S}
if level == 1 return 1:length(c.samples[1]) end # all lowest level samples
shape = length.(c.samples[1:level])
if all( mod.(shape[1:end-1], shape[2:end]) .== 0 ) # if discard_initial >= 1
N = shape[level]
ns = StepRange.(1, shape[1:level] .÷ N, shape)
return ns
elseif all( mod.(shape[1:end-1].-1, shape[2:end].-1) .== 0 )
N = shape[level]
shape = (shape[1:level] .- 1) .÷ (N-1)
ns1 = Base.vect.(ones(Int, level))
ns2 = collect.(StepRange.(2, shape, length.(c.samples[1:level])))
return vcat.(ns1, ns2)
else
throw(ErrorException("Subchain lengths $(length.(c.samples)) are not consistent"))
end
end
function subchain(c::MultilevelChains, level::Integer)
if level == 1 return c.samples[1] end
# multilevel states
# get samples
ns = subchain_indices(c, level)
x = map( i->getfield.( c.samples[i][ ns[i] ], :x ), 1:level )
x = [ zip(x...) ... ]
# parameters except for :x and :n
params = fieldnames(eltype(c.samples[1]))[2:end]
vals = map(param -> getfield.(c.samples[level], param), params)
# convert to vector of named tuples
T = NamedTuple{(:x, params... ), Tuple{eltype(x), eltype.(vals)... }}
return map(T, zip(x, vals...))
end
=#
#=
function MultilevelChains{S <: Tuple{Vararg{Vector{<:NamedTuple}}}}
samples::S
end
function view(c::MultilevelChains, i::Integer)
params = fieldnames(eltype(c.samples[1]))
if i == 1
return (;
k => view( getfield.(c.samples[1], f) )
for f in fieldnames(eltype(c.samples[1]))
)
samples = getindex.(c.samples[1:level], ns)
if fixed_length && ! save_logprob && ! save_reject
return [zip(samples... ) ... ]
end
params = fieldnames(eltype(c.samples[1]))
ns =
x = [
view(getfield.(c.samples[j], :x))
for j = 1:i
]
#x = view(getfield.(c.samples[j], :x))
getfield.(mc.samples[1])
end
function bundle_samples(samples, model::AbstractLogDensity, da::DelayedAcceptance, state, type=; kwargs...)
x =
if has_deterministic_length(da)
if discard_initial <= 1
initial = samples[]
end
# multilevel states, will always be returned
x = [ zip(map(t->first.(t), samples)) ... ]
# log density and rejection flag can be returned optionally
if save_logprob
logp_x = getindex.(c.samples[level], 2)
if ! save_reject return [ zip(x, logp_x) ... ] end
end
end
struct MultilevelChains{C <: Tuple{Vararg{Vector{<:NamedTuple}}}, L <: Tuple{Vararg{Vector{<:Integer}}}}
samples :: C
sublengths :: L
end
function MultilevelChains(samples)
typeof(samples) <: Tuple{Vararg{Vector{<:NamedTuple}}} || throw(ArgumentError("Got $typeof(samples) but expected Tuple{Vararg{Vector{<:NamedTuple}}}"))
# if subsamples samples have random length
has_random_length = any( hasfield.(eltype.(c.samples[1:i]), :n) )
if has_random_length
ns = [ [1:length(c.samples[end])...], getfield.(c.samples[end], :n) ]
for i=length(c.samples)-1:-1:2
pushfirst!(ns, getfield.(c.samples[j][ns[1]], :n))
end
samples ! # TODO remove :n field from samples
return MultilevelChains{typeof(samples), typeof(ns)}(samples, ns)
else
shape = length.(c.samples)
# if first sample was discarded
if all(i-> mod( shape[i], shape[i-1] ) == 0, 2:length(shape) )
ns = [ shape[end] ]
for i=length(c.samples)-1:-1:1
pushfirst!(ns, shape[i]÷shape[i+1])
end
ns = collect.(StepRange.(1, ns, shape))
ns = Tuple(ns)
return MultilevelChains{typeof(samples), typeof(ns)}(samples, ns)
# if first sample has not beed proposed by multilevel sampler
elseif all(i-> mod( shape[i]-1, shape[i-1]-1 ) == 0, 2:length(shape) )
ns = [ shape[end]-1 ]
for i=length(c.samples)-1:-1:1
pushfirst!(ns, (shape[i]-1)÷(shape[i+1]-1))
end
ns = vcat.( ones(Int, length(shape)), collect.(StepRange.(2, ns, shape)) )
ns = Tuple(ns)
return MultilevelChains{typeof(samples), typeof(ns)}(samples, ns)
else
throw(ArgumentError("Subchain lengths $(length.(c.samples)) are not consistent"))
end
if save_reject
reject = getindex.(c.samples[level], 2 + save_logprob)
if ! save_logp return [ zip(x, reject) ... ] end
end
end
function Base.getindex(m::MultilevelChains, i::Integer)
if i == 1 return c.samples[1] end
! # TODO: complete this function
end
function AbstractMCMC.bundle_samples(samples::Tuple, model::AbstractLogDensity, da::DelayedAcceptance, state, chain_type::MultilevelChains; kwargs...)
return MultilevelChains(samples)
end
=#
#=
function Base.getindex(m::MultilevelChains, i::Integer)
if i == 1 return c.samples[1] end
return [ zip(x, logp_x, reject) ... ]
has_random_length = any( hasfield.(eltype.(c.samples[1:i]), :n) )
if ! has_random_length
shape = length.(c.samples[1:i])
if mod( shape[i-1], shape[i] ) == 0
elseif mod( shape[i-1]-1 , shape[i]-1 ) == 0
else
throw(ArgumentError("Subchain lengths $(length.(c.samples)) are not consistent"))
end
else
ns = [ getfield.(c.samples[i], :n) ]
for j=i-1:-1:2
pushfirst!(ns, getfield.(c.samples[j][ns[1]], :n))
end
param = ( eltype(c.samples[1]).parameters[1][2:end] ... , )
#if :logp_x in param getfield(c.samples[i], :logp_x)
x = hcat( (getfield.(c.samples[i][ns[i]], :x) for i=1:length(c.samples) ) ... )
x = [ Tuple(x[i,:]) for i=1:size(x,1) ]
data = Dict{Symbol, Vector}( :x=>x )
if :logp_x in param data[:logp_x] = getfield.(c.samples[i], :logp_x) end
if :reject in param data[:reject] = getfield.(c.samples[i], :reject) end
states = NamedTuple{param}.(data)
end
end
# Auto correlation
function autocor(x::Vector, lag::Int)
m = mean(x)
d = x .- m
v = sum(d[1+lag:end] .* d[1:end-lag])
c = v / sum(d[1].^2)
return c
end
=#
\ No newline at end of file
autocor(x::Vector, lag::AbstractRange) = [autovar(x,i) for i in lag]
autocor(x::Vector, lag::Vector{<:Integer}) = [autovar(x,i) for i in lag]
\ No newline at end of file
......@@ -229,18 +229,12 @@ function AbstractMCMC.save!!(samples::Vector{<:Vector}, sample::Tuple, iteration
s = [s[1], ( (x..., n+1) for (x,n) in zip(s[2:end], length.(samples)[1:end-1]) )... ]
samples .= BangBang.push!!.(samples, s)
#println("\nIteration $iteration , level $(length(sample[5])+1)")
if sample[4]-1 > 0
s = _unpack_subsample(da, sample)
#println(" ", which(_unpack_subsample, typeof((da, sample)) ) )
#println("\tsubsample ", typeof.(s), " ", length.(s))
#println("\t s ", s)
#println("\tBefore ", typeof.(samples), " ", length.(samples))
s[2:end] .= [ map( t->(t[1:end-1]..., t[end]+n), x) for (x,n) in zip(s[2:end], length.(samples[1:end-1])) ]
samples[1] = BangBang.append!!(samples[1], s[1])
samples[2:end-1] .= BangBang.append!!.(samples[2:end-1], s[2:end])
end
#println("\tAfter ", typeof.(samples), " ", length.(samples))
return samples
end
......@@ -248,13 +242,10 @@ end
function AbstractMCMC.samples(sample::Tuple, model::AbstractLogDensity, da::MLDA{P,L,save_logprob,save_reject,true}, N::Integer; kwargs...) where {P,L <: Tuple{Vararg{<:Distribution}},save_logprob, save_reject}
return AbstractMCMC.samples(sample, model, da; kwargs...)
end
function AbstractMCMC.save!!(samples::Vector{<:Vector}, sample::Tuple, iteration::Integer, mode::AbstractLogDensity, da::MLDA{P,L,saves_logprob,saves_reject,true}, N::Integer; kwargs...) where {P,L <: Tuple{Vararg{<:Distribution}},saves_logprob, saves_reject}
function AbstractMCMC.save!!(samples::Vector{<:Vector}, sample::Tuple, iteration::Integer, mode::AbstractLogDensity, da::MLDA{P,L,save_logprob,saves_reject,true}, N::Integer; kwargs...) where {P,L <: Tuple{Vararg{<:Distribution}},save_logprob, save_reject}
return AbstractMCMC.save!!(samples, sample, iteration, mode, da; kwargs...)
end
#function AbstractMCMC.bundle_samples(samples::Vector{<:Vector}, model::AbstractLogDensity, da::MLDA{P,L,save_logprob,save_reject,true}, state, ::Type{MultilevelChains}; kwargs...) where {P,L,save_logprob,save_reject}
# return MultilevelChains(Tuple(samples...))
#end
AbstractMCMC.bundle_samples(samples::Vector{<:Vector}, model::AbstractLogDensity, da::MLDA{P,L,save_logprob,save_reject,true}, state, ::Type{MultilevelChains}; kwargs...) where {P,L <: Tuple{Vararg{<:Integer}},save_logprob,save_reject} = MultilevelChains(samples; fixed_length=true, save_logprob, save_reject)
AbstractMCMC.bundle_samples(samples::Vector{<:Vector}, model::AbstractLogDensity, da::MLDA{P,L,save_logprob,save_reject,true}, state, ::Type{MultilevelChains}; kwargs...) where {P,L <: Tuple{Vararg{<:Distribution}},save_logprob,save_reject} = MultilevelChains(samples; fixed_length=false, save_logprob, save_reject)
......@@ -91,26 +91,37 @@ end
c = sample(t, da, 100, discard_initial=1)
@test all( typeof.(c) == [ Vector{Tuple{Float64}}, Vector{Tuple{Float64, Int}}] )
@test length(c[2]) == 100
@test all( unique( last.(c[2]) ) .== last.(c[2]) )
da = MLDA(p2, (d,), true, false; save_subchain=true)
c = sample(t, da, 100, discard_initial=1)
@test all( typeof.(c) == [ Vector{Tuple{Float64,Float64}}, Vector{Tuple{Float64, Float64, Int}}] )
@test length(c[2]) == 100
@test issorted(last.(c[2]))
@test all( unique( last.(c[2]) ) .== last.(c[2]) )
da = MLDA(p2, (d,), false, true; save_subchain=true)
c = sample(t, da, 100, discard_initial=1)
@test all( typeof.(c) == [ Vector{Tuple{Float64, Bool}}, Vector{Tuple{Float64, Bool, Int}}] )
@test length(c[2]) == 100
@test issorted(last.(c[2]))
@test all( unique( last.(c[2]) ) .== last.(c[2]) )
da = MLDA(p2, (d,), true, true; save_subchain=true)
c = sample(t, da, 100, discard_initial=1)
@test all( typeof.(c) == [ Vector{Tuple{Float64,Float64,Bool}}, Vector{Tuple{Float64, Float64,Bool, Int}}] )
@test length(c[2]) == 100
@test issorted(last.(c[2]))
@test all( unique( last.(c[2]) ) .== last.(c[2]) )
da = MLDA(p3, (d,d,), false, false; save_subchain=true)
c = sample(t, da, 100, discard_initial=1)
@test all( typeof.(c) == [ Vector{Tuple{Float64}}, Vector{Tuple{Float64, Int}}, Vector{Tuple{Float64, Int}}] )
@test length(c[3]) == 100
@test issorted(last.(c[2])) && issorted(last.(c[3]))
@test all( unique( last.(c[2]) ) .== last.(c[2]) ) && all( unique( last.(c[3]) ) .== last.(c[3]) )
da = MLDA(p3, (d,d,), false, true; save_subchain=true)
......@@ -118,15 +129,31 @@ end
println(typeof.(c))
@test all( typeof.(c) .== [ Vector{Tuple{Float64, Bool}}, Vector{Tuple{Float64, Bool, Int}}, Vector{Tuple{Float64, Bool, Int}}] )
@test length(c[3]) == 100
@test issorted(last.(c[2])) && issorted(last.(c[3]))
@test all( unique( last.(c[2]) ) .== last.(c[2]) ) && all( unique( last.(c[3]) ) .== last.(c[3]) )
end
#=
@testset "chains" begin
energy(x::Tuple) = -sum(x.^2)
t = LogDensity(energy)
p = CyclicWalk(-1, 1, .1)
p2 = vcat(p,p)
# bundle samples as MultilevelChains
mlc = sample(t, da, 100; chain_type=MultilevelChains, discard_initial=1)
@test typeof(mlc) <: MultilevelChains
@test length(subchain(mlc,1)) == 600
@test length(subchain(mlc,2)) == 300
da = MLDA(p2, (3,); save_subchain=true)
c = sample(t, da, 100; chain_type=MultilevelChains, discard_initial=1)
@test typeof(c) <: MultilevelChains
s1 = subchain(c,1)
s2 = subchain(c,2)
@test length(s1) == 300
@test length(s2) == 100
@test eltype(s2) == Tuple{Float64, Float64}
#=
mlc = sample(t, da, 100 + 1; chain_type=MultilevelChains)
@test typeof(mlc) <: MultilevelChains
@test length(subchain(mlc,1)) == 600 + 1
......
using Test
using Distributions
using Statistics
using AbstractMCMC
using Revise
using MultilevelChainSampler
@testset "hastings" begin
# behaves like truncation on cyclic unit interval
t = LogDensity(Normal(0,1))
p = CyclicWalk(-1, 1, .1)
# ignore log probability
mh = MetropolisHastings(p, false)
c = sample(t, mh, 100000)
x = getfield.(c, :x)
@test length(c) == 100000
# sample multiple chains at once
mh = MetropolisHastings(p, false)
c = sample(t, mh, AbstractMCMC.MCMCSerial(), 1000, 3)
@test length(c) == 3
@test all(length.(c) .== 1000)
# sample multiple chains at once
mh = MetropolisHastings(p, true)
c = sample(t, mh, AbstractMCMC.MCMCSerial(), 1000, 3)
@test length(c) == 3
@test all( length(getfield.(c[1], :x)) == 1000)
# save logprob and rejection
mh = MetropolisHastings(p, true, true)
c = sample(t, mh, 1000)
@test length(c[1]) == 3
@test length(getfield.(c,:x))==1000
end
@testset "delayed_acceptance" begin
p = CyclicWalk(-1, 1, .1)
p2 = vcat(p,p)
energy(x::Tuple) = -sum(x.^2)
t = LogDensity(energy)
# sample from two level, ignore subchains
da = MLDA(p2, (2,))
c = sample(t, da, 100)
@test length(c) == 100
@test typeof(c[1]) <: NamedTuple
@test length(c[1]) == 1
@test length(c[1].x) == 2
# save save logprobs
da = MLDA(p2, (3,), true)
c = sample(t, da, 100)
@test length(c) == 100
@test typeof(c[1]) <: NamedTuple
@test length(c[1]) == 2
# save subchains
da = MLDA(p2, (3,); save_subchain=true)
c = sample(t, da, 100; discard_initial=1)
@test length.(c) == 100 .* size(da)
# sample from three level
p3 = vcat(p,p,p)
da = MLDA(p3, (2,3), true; save_subchain=true)
c = sample(t, da, 100, discard_initial=1)
@test length.(c) == 100 .* size(da)
# bundle samples as MultilevelChains
mlc = sample(t, da, 100; chain_type=MultilevelChains, discard_initial=1)
@test typeof(mlc) <: MultilevelChains
@test length(subchain(mlc,1)) == 600
@test length(subchain(mlc,2)) == 300
mlc = sample(t, da, 100 + 1; chain_type=MultilevelChains)
@test typeof(mlc) <: MultilevelChains
@test length(subchain(mlc,1)) == 600 + 1
@test length(subchain(mlc,2)) == 300 + 1
@test length(subchain(mlc,3)) == 100 + 1
end
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment