Skip to content
Snippets Groups Projects
Commit ea04fa80 authored by Luca Lenz's avatar Luca Lenz
Browse files

fixed more chain size errors, updated tests

parent b0890a60
No related branches found
No related tags found
No related merge requests found
...@@ -21,18 +21,24 @@ DelayedAcceptance(p, n::Distribution, args...) = DelayedAcceptance(p, rand(n), a ...@@ -21,18 +21,24 @@ DelayedAcceptance(p, n::Distribution, args...) = DelayedAcceptance(p, rand(n), a
function AbstractMCMC.step(rng::AbstractRNG, model::AbstractLogDensity, da::DelayedAcceptance; x0=nothing, f0=nothing, kwargs...) function AbstractMCMC.step(rng::AbstractRNG, model::AbstractLogDensity, da::DelayedAcceptance; x0=nothing, logp_x0=nothing, kwargs...)
if isnothing(x0) if isnothing(x0)
x0 = propose(rng, da.proposals) x0 = propose(rng, da.proposals)
f0 = Tuple( evaluate(model, x0[1:l]) for l=1:length(x0) ) logp_x0 = Tuple( evaluate(model, x0[1:l]) for l=1:length(x0) )
elseif isnothing(f0) elseif isnothing(logp_x0)
f0 = Tuple( evaluate(model, x0[1:l]) for l=1:length(x0) ) logp_x0 = Tuple( evaluate(model, x0[1:l]) for l=1:length(x0) )
end end
# create subchain from course components # create subchain from course components
subchain = Tuple( [z] for z in zip(x0[1:end-1], f0[1:end-1]) ) #subchain = Tuple( [z] for z in zip(x0[1:end-1], logp_x0[1:end-1]) )
state = (x0, f0, true)
#create empy subchain
subchain = Tuple( Vector{Tuple{X,L}}() for (X,L) in zip(eltype.(x0[1:end-1]), eltype.(logp_x0[1:end-1])) )
state = (x0, logp_x0, true)
return (state[1:2]..., subchain ) , state return (state[1:2]..., subchain ) , state
end end
...@@ -43,11 +49,11 @@ function AbstractMCMC.step(rng::AbstractRNG, model::AbstractLogDensity, da::Dela ...@@ -43,11 +49,11 @@ function AbstractMCMC.step(rng::AbstractRNG, model::AbstractLogDensity, da::Dela
if length(da.proposals) == 2 if length(da.proposals) == 2
subsampler = MetropolisHastings(da.proposals[1], true) subsampler = MetropolisHastings(da.proposals[1], true)
submodel = LogDensity( x -> evaluate(model, (x,)) ) submodel = LogDensity( x -> evaluate(model, (x,)) )
subchain = sample(rng, submodel, subsampler, da.subchainlen[1]; x0=x[1], f0=logp_x[1]) subchain = sample(rng, submodel, subsampler, da.subchainlen[1]; x0=x[1], logp_x0=logp_x[1], discard_initial=1)
subchain = (subchain,) subchain = (subchain,)
else else
subsampler = DelayedAcceptance( MultiProposal(da.proposals[1:end-1]...), da.subchainlen[1:end-1], true) subsampler = DelayedAcceptance( MultiProposal(da.proposals[1:end-1]...), da.subchainlen[1:end-1], true)
subchain = sample(rng, model, subsampler, da.subchainlen[end]; x0=x[1:end-1], f0=logp_x[1:end-1]) subchain = sample(rng, model, subsampler, da.subchainlen[end]; x0=x[1:end-1], logp_x0=logp_x[1:end-1], discard_initial=1)
end end
# proposed sample is final sample of subchain # proposed sample is final sample of subchain
...@@ -86,8 +92,17 @@ function AbstractMCMC.samples(sample, model::AbstractLogDensity, da::DelayedAcce ...@@ -86,8 +92,17 @@ function AbstractMCMC.samples(sample, model::AbstractLogDensity, da::DelayedAcce
end end
return s return s
end end
function AbstractMCMC.save!!(samples::Tuple, sample, iteration::Int, model::AbstractLogDensity, da::DelayedAcceptance{P,L,true}, N::Integer; kwargs...) where {P,L} function AbstractMCMC.save!!(samples::Tuple, sample, iteration::Int, model::AbstractLogDensity, da::DelayedAcceptance{P,L,true}, N::Integer; kwargs...) where {P,L}
s = BangBang.append!!.(samples, (sample[3] ... , [ (sample[1][end], sample[2][end])] )) # append subchain
Y_subchain = getindex.(sample[3], UnitRange.(1, lastindex.(sample[3]) .- 1) )
Y_subchain = (Y_subchain ... , Vector{eltype(samples[end])}())
s = BangBang.append!!.(samples, Y_subchain)
# append current state
Y_sample = map(x->[x...], zip.(sample[1], sample[2]))
s = BangBang.append!!.(s, Y_sample)
n = (da.subchainlen... , N) n = (da.subchainlen... , N)
n = reverse(cumprod(reverse(n))) n = reverse(cumprod(reverse(n)))
any(s .!== samples) && sizehint!.(s, n) any(s .!== samples) && sizehint!.(s, n)
...@@ -108,7 +123,16 @@ function AbstractMCMC.samples(sample, model::AbstractLogDensity, da::DelayedAcce ...@@ -108,7 +123,16 @@ function AbstractMCMC.samples(sample, model::AbstractLogDensity, da::DelayedAcce
return s return s
end end
function AbstractMCMC.save!!(samples::Tuple, sample, iteration::Int, model::AbstractLogDensity, da::DelayedAcceptance{P,L,false}, N::Integer; kwargs...) where {P,L} function AbstractMCMC.save!!(samples::Tuple, sample, iteration::Int, model::AbstractLogDensity, da::DelayedAcceptance{P,L,false}, N::Integer; kwargs...) where {P,L}
s = BangBang.append!!.(samples, (map(x -> first.(x), sample[3]) ... , [sample[1][end]] )) # append subchain
subsamp = map(x->first.(x), sample[3])
Y_subchain = getindex.(subsamp, UnitRange.(1, lastindex.(subsamp) .- 1) )
Y_subchain = (Y_subchain ... , Vector{eltype(samples[end])}())
s = BangBang.append!!.(samples, Y_subchain)
# append current state
Y_sample = map(x->[x], sample[1])
s = BangBang.append!!.(s, Y_sample)
n = (da.subchainlen... , N) n = (da.subchainlen... , N)
n = reverse(cumprod(reverse(n))) n = reverse(cumprod(reverse(n)))
any(s .!== samples) && sizehint!.(s, n) any(s .!== samples) && sizehint!.(s, n)
......
...@@ -8,15 +8,15 @@ MetropolisHastings(p::AbstractProposal, return_logprob=false) = MetropolisHastin ...@@ -8,15 +8,15 @@ MetropolisHastings(p::AbstractProposal, return_logprob=false) = MetropolisHastin
## Intitialize chain ## Intitialize chain
function AbstractMCMC.step(rng::AbstractRNG, model::AbstractLogDensity, mh::MetropolisHastings; x0=nothing, f0=nothing, kwargs...) function AbstractMCMC.step(rng::AbstractRNG, model::AbstractLogDensity, mh::MetropolisHastings; x0=nothing, logp_x0=nothing, kwargs...)
if isnothing(x0) if isnothing(x0)
x0 = propose(rng, mh.proposal) x0 = propose(rng, mh.proposal)
f0 = evaluate(model, x0) logp_x0 = evaluate(model, x0)
elseif isnothing(f0) elseif isnothing(logp_x0)
f0 = evaluate(model, x0) logp_x0 = evaluate(model, x0)
end end
state = (x0, f0, true) state = (x0, logp_x0, true)
return state[1:2], state return state[1:2], state
end end
......
...@@ -78,10 +78,10 @@ end ...@@ -78,10 +78,10 @@ end
p3 = vcat(p,p,p) p3 = vcat(p,p,p)
da = DelayedAcceptance(p3, (2,3), true) da = DelayedAcceptance(p3, (2,3), true)
c = sample(t, da, 101) c = sample(t, da, 101)
@test length.(c) == (1+100*2+100*3,1+100*3, 1+100) @test length.(c) == (1+100*2*3,1+100*3, 1+100)
c = sample(t, da, 100, discard_initial=1) c = sample(t, da, 100, discard_initial=1)
@test length.(c) == (100*2+100*3,100*3, 100) @test length.(c) == (100*2*3,100*3, 100)
end end
#@which sample(LogDensity(energy), da, 100) #@which sample(LogDensity(energy), da, 100)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment