# (web-extraction residue removed: "Newer"/"Older" navigation labels and a line-number gutter)
## Multi-level state
# Abstract supertype declared under the "Multi-level state" heading.
# It is itself a subtype of `AbstractState`, which is defined outside this section.
# NOTE(review): no concrete subtype is active in this section; the only candidate
# (`MultilevelState`) appears in the commented-out drafts that follow — confirm
# whether a live subtype exists elsewhere.
abstract type HierarchicalState <: AbstractState end
#=
# by default, the multilevel state wraps a tuple of states
struct MultilevelState{T <: Tuple} <: HierarchicalState
levels :: T
end
MultilevelState(levels...; kwargs...) = MultilevelState(levels)
course(s::MultilevelState) = MultilevelState(s.levels[1:end-1])
fine(s::MultilevelState) = s.levels[end]
combine(s::MultilevelState, fine_state) = MultilevelState(s.levels..., fine_state)
level(s::MultilevelState) = length(s.levels)
# Multi-level log density
abstract type HierarchicalLogDensity{S <: HierarchicalState} <: AbstractLogDensity{S} end
struct MultilevelLogDensity{F, S} <: HierarchicalLogDensity{S}
logdensities :: F
end
logdensity(d::HierarchicalLogDensity, state::HierarchicalState; level=length(state)) = logdensity(d, state)
=#
#level(s::MultilevelState) = length(s.states)
#Base.eltype(states::MultilevelState) = eltype(states.states)
## construct from single level states
#MultilevelState(states::State ... ) = MultilevelState(getindex.([states ... ], :state))
## recursive construction
#course(states::MultilevelState) = MultilevelState(states[1:end-1])
#fine(states::MultilevelState) = states[end]
#combine(states::MultilevelState, state) = MultilevelState(vcat(states.states ... , state))
#=
## Multi-level log density
abstract type HierarchicalLogDensity{S} <: AbstractLogDensity{S} end
struct MultilevelLogDensity{F, S} <: HierarchicalLogDensity{S}
logdensities :: F
end
logdensity(d::MultilevelLogDensity, state::MultilevelState) = logdensity(d.logdensities, state.states)
#logdensity(d::HierarchicalLogDensity, state::HierarchicalState, level::Int)
function logdensity(d::MultilevelLogDensity, state::HierarchicalState, level::Int)
logdensity(d.logdensities[level], state.states[level])
end
#MultilevelLogDensity(logdensities::AbstractVector{<:AbstractLogDensity}) = MultilevelLogDensity(logdensities)
=#
#abstract type AbstractMultilevelLogDensity <: LogDensity{S} end
#abstract type AbstractMultilevelLogDensity <: LogDensity{S} end
#struct MultilevelLogDensity <: AbstractMultilevelLogDensity
#end