Skip to content
Snippets Groups Projects
Commit ea04fa80 authored by Luca Lenz's avatar Luca Lenz
Browse files

fixed more chain size errors, updated tests

parent b0890a60
No related branches found
No related tags found
No related merge requests found
......@@ -21,18 +21,24 @@ DelayedAcceptance(p, n::Distribution, args...) = DelayedAcceptance(p, rand(n), a
function AbstractMCMC.step(rng::AbstractRNG, model::AbstractLogDensity, da::DelayedAcceptance; x0=nothing, f0=nothing, kwargs...)
function AbstractMCMC.step(rng::AbstractRNG, model::AbstractLogDensity, da::DelayedAcceptance; x0=nothing, logp_x0=nothing, kwargs...)
if isnothing(x0)
x0 = propose(rng, da.proposals)
f0 = Tuple( evaluate(model, x0[1:l]) for l=1:length(x0) )
elseif isnothing(f0)
f0 = Tuple( evaluate(model, x0[1:l]) for l=1:length(x0) )
logp_x0 = Tuple( evaluate(model, x0[1:l]) for l=1:length(x0) )
elseif isnothing(logp_x0)
logp_x0 = Tuple( evaluate(model, x0[1:l]) for l=1:length(x0) )
end
# create subchain from course components
subchain = Tuple( [z] for z in zip(x0[1:end-1], f0[1:end-1]) )
#subchain = Tuple( [z] for z in zip(x0[1:end-1], logp_x0[1:end-1]) )
state = (x0, f0, true)
#create empy subchain
subchain = Tuple( Vector{Tuple{X,L}}() for (X,L) in zip(eltype.(x0[1:end-1]), eltype.(logp_x0[1:end-1])) )
state = (x0, logp_x0, true)
return (state[1:2]..., subchain ) , state
end
......@@ -43,11 +49,11 @@ function AbstractMCMC.step(rng::AbstractRNG, model::AbstractLogDensity, da::Dela
if length(da.proposals) == 2
subsampler = MetropolisHastings(da.proposals[1], true)
submodel = LogDensity( x -> evaluate(model, (x,)) )
subchain = sample(rng, submodel, subsampler, da.subchainlen[1]; x0=x[1], f0=logp_x[1])
subchain = sample(rng, submodel, subsampler, da.subchainlen[1]; x0=x[1], logp_x0=logp_x[1], discard_initial=1)
subchain = (subchain,)
else
subsampler = DelayedAcceptance( MultiProposal(da.proposals[1:end-1]...), da.subchainlen[1:end-1], true)
subchain = sample(rng, model, subsampler, da.subchainlen[end]; x0=x[1:end-1], f0=logp_x[1:end-1])
subchain = sample(rng, model, subsampler, da.subchainlen[end]; x0=x[1:end-1], logp_x0=logp_x[1:end-1], discard_initial=1)
end
# proposed sample is final sample of subchain
......@@ -86,8 +92,17 @@ function AbstractMCMC.samples(sample, model::AbstractLogDensity, da::DelayedAcce
end
return s
end
function AbstractMCMC.save!!(samples::Tuple, sample, iteration::Int, model::AbstractLogDensity, da::DelayedAcceptance{P,L,true}, N::Integer; kwargs...) where {P,L}
s = BangBang.append!!.(samples, (sample[3] ... , [ (sample[1][end], sample[2][end])] ))
# append subchain
Y_subchain = getindex.(sample[3], UnitRange.(1, lastindex.(sample[3]) .- 1) )
Y_subchain = (Y_subchain ... , Vector{eltype(samples[end])}())
s = BangBang.append!!.(samples, Y_subchain)
# append current state
Y_sample = map(x->[x...], zip.(sample[1], sample[2]))
s = BangBang.append!!.(s, Y_sample)
n = (da.subchainlen... , N)
n = reverse(cumprod(reverse(n)))
any(s .!== samples) && sizehint!.(s, n)
......@@ -108,7 +123,16 @@ function AbstractMCMC.samples(sample, model::AbstractLogDensity, da::DelayedAcce
return s
end
function AbstractMCMC.save!!(samples::Tuple, sample, iteration::Int, model::AbstractLogDensity, da::DelayedAcceptance{P,L,false}, N::Integer; kwargs...) where {P,L}
s = BangBang.append!!.(samples, (map(x -> first.(x), sample[3]) ... , [sample[1][end]] ))
# append subchain
subsamp = map(x->first.(x), sample[3])
Y_subchain = getindex.(subsamp, UnitRange.(1, lastindex.(subsamp) .- 1) )
Y_subchain = (Y_subchain ... , Vector{eltype(samples[end])}())
s = BangBang.append!!.(samples, Y_subchain)
# append current state
Y_sample = map(x->[x], sample[1])
s = BangBang.append!!.(s, Y_sample)
n = (da.subchainlen... , N)
n = reverse(cumprod(reverse(n)))
any(s .!== samples) && sizehint!.(s, n)
......
......@@ -8,15 +8,15 @@ MetropolisHastings(p::AbstractProposal, return_logprob=false) = MetropolisHastin
## Intitialize chain
function AbstractMCMC.step(rng::AbstractRNG, model::AbstractLogDensity, mh::MetropolisHastings; x0=nothing, f0=nothing, kwargs...)
function AbstractMCMC.step(rng::AbstractRNG, model::AbstractLogDensity, mh::MetropolisHastings; x0=nothing, logp_x0=nothing, kwargs...)
if isnothing(x0)
x0 = propose(rng, mh.proposal)
f0 = evaluate(model, x0)
elseif isnothing(f0)
f0 = evaluate(model, x0)
logp_x0 = evaluate(model, x0)
elseif isnothing(logp_x0)
logp_x0 = evaluate(model, x0)
end
state = (x0, f0, true)
state = (x0, logp_x0, true)
return state[1:2], state
end
......
......@@ -78,10 +78,10 @@ end
p3 = vcat(p,p,p)
da = DelayedAcceptance(p3, (2,3), true)
c = sample(t, da, 101)
@test length.(c) == (1+100*2+100*3,1+100*3, 1+100)
@test length.(c) == (1+100*2*3,1+100*3, 1+100)
c = sample(t, da, 100, discard_initial=1)
@test length.(c) == (100*2+100*3,100*3, 100)
@test length.(c) == (100*2*3,100*3, 100)
end
#@which sample(LogDensity(energy), da, 100)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment