Skip to content
Snippets Groups Projects
stats.jl 3.15 KiB
Newer Older
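# Autocovariance of the chain states, one result per chain (methods below).
# The block-commented function is a disabled reference implementation; because it
# divides by the lag-0 sum of squares, it actually returns the lag-`lag`
# autocorrelation rather than the autocovariance.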
#=
function autocov(x::AbstractVector, lag::Int)
    m = mean(x)
    cov = (x[lag+1:end] .- m) .* (x[1:end-lag] .- m)
    cor = (x .- m) .^ 2
    return sum(cov)/sum(cor)
end
=#

"""
    autocov(c::RejectionChains)

Call `autocov` on the state sequence of each chain of `c` (entry `k` of
`states(c)` for `k` in `1:nchains(c)`) and collect the per-chain results
into a vector, in chain order.
"""
function autocov(c::RejectionChains)
    chain_states = states(c)
    return map(k -> autocov(chain_states[k]), 1:nchains(c))
end

"""
    autocov(c::RejectionChains, lags::AbstractVector{<:AbstractVector})

Per-chain lags: call `autocov(states(c)[k], lags[k])` for every chain `k` in
`1:nchains(c)` and return the results as a vector, in chain order.
"""
function autocov(c::RejectionChains, lags::AbstractVector{<:AbstractVector})
    chain_states = states(c)
    return map(k -> autocov(chain_states[k], lags[k]), 1:nchains(c))
end
"""
    autocov(c::RejectionChains, lags::AbstractVector{<:Integer})

Shared lags: call `autocov(states(c)[k], lags)` with the same `lags` for every
chain `k` in `1:nchains(c)` and return the results as a vector, in chain order.
"""
function autocov(c::RejectionChains, lags::AbstractVector{<:Integer})
    chain_states = states(c)
    return map(k -> autocov(chain_states[k], lags), 1:nchains(c))
end


# Log partition function, constant factor for normalize log density 
function logpart(logp::AbstractVector)
    return -log( sum( exp.(-logp) ) )
end

# Repeat each `values[k]` exactly `counts[k]` times and concatenate the runs in
# order. Replaces `vcat(fill.(values, counts)...)`, which splats one argument per
# sample and returns an untyped `Any[]` when there are no samples.
function _expand_by_counts(values, counts)
    length(values) == length(counts) || throw(DimensionMismatch(
        "got $(length(values)) values but $(length(counts)) repetition counts"))
    return [v for (v, n) in zip(values, counts) for _ in 1:n]
end

# Log partition constant of level `l` for a single chain. It is computed from the
# samples whose level is `<= l`, using component `l` of each sample's `logprob`,
# with every sample weighted by its repetition count.
function _chain_level_logpart(samples, counts, sample_levels, l)
    mask = sample_levels .<= l   # element-wise: one Bool per sample
    f_l = getindex.(getfield.(samples[mask], :logprob), l)
    return logpart(_expand_by_counts(f_l, counts[mask]))
end

"""
    logpart(c::RejectionChains)

Compute the log partition constant(s) of every chain in `c`. Each sample's
`logprob` enters the computation `repetitions(c)[i][k]` times.

# Returns
- Single-level chains: a vector with one constant per chain, `logpart` of that
  chain's repeated `logprob` values.
- Multi-level chains: a vector with, per chain, a vector of length `maxlvl`
  (the largest level over all samples). Entry `l` is computed from the samples
  whose level is `<= l`, using component `l` of their `logprob`.
"""
function logpart(c::RejectionChains)
    multilevel = is_multilevel(c)
    reps = repetitions(c)

    if !multilevel
        # single level: one constant per chain
        return [
            logpart(_expand_by_counts(getfield.(c.samples[i], :logprob), reps[i]))
            for i in 1:nchains(c)
        ]
    end

    # multi level: one constant per level for each chain
    lvls = levels(c)
    # `lvls[i]` is used as a per-sample mask of `c.samples[i]`, so `maximum(lvls)`
    # would compare whole vectors lexicographically; take the largest level over
    # all samples of all chains instead.
    maxlvl = maximum(maximum, lvls)
    return [
        [_chain_level_logpart(c.samples[i], reps[i], lvls[i], l) for l in 1:maxlvl]
        for i in 1:nchains(c)
    ]
end

# Indicator vector: entry `k` is `true` when `states[k] .<= x` holds for every
# component (broadcast comparison, so scalars and arrays both work).
_cdf_indicator(states, x) = [all(X .<= x) for X in states]

# Repetition-weighted fraction of `states` that satisfy `all(X .<= x)`.
_weighted_cdf(w, states, x) = sum(w .* _cdf_indicator(states, x)) / sum(w)

# Multi-level CDF terms for one chain.
# - Entry 1 is the weighted CDF of the samples with level <= 1.
# - Entry `l >= 2` is the weighted mean of
#   `1{X <= x} * (1 - exp(f_l - f_{l-1} - Z_l + Z_{l-1}))` over the samples with
#   level <= `l`. Here `f` is the sample's `logprob` and `Z` holds this chain's
#   per-level `logpart` constants.
function _chain_multilevel_cdf(samples, w, sample_levels, Z, maxlvl, x)
    mask = sample_levels .<= 1
    terms = [_weighted_cdf(w[mask], getfield.(samples[mask], :state), x)]
    sizehint!(terms, maxlvl)
    for l in 2:maxlvl
        mask = sample_levels .<= l          # element-wise: one Bool per sample
        selected = samples[mask]
        wl = w[mask]
        # NOTE(review): the original flagged this selection with "?????"; it is kept
        # as-is (samples of level <= l) — confirm it is the intended sample set.
        f = getfield.(selected, :logprob)
        f0 = getindex.(f, l - 1)
        f1 = getindex.(f, l)
        Z0, Z1 = Z[l-1], Z[l]
        # Fully broadcast: f0/f1 are per-sample vectors, Z0/Z1 are scalars.
        correction = 1 .- exp.(f1 .- f0 .- Z1 .+ Z0)
        below = _cdf_indicator(getfield.(selected, :state), x)
        push!(terms, sum(wl .* below .* correction) / sum(wl))
    end
    return terms
end

"""
    empirical_cdf(c::RejectionChains, x)

Estimate, for every chain in `c`, the empirical CDF of the chain states at `x`.
This is the repetition-weighted fraction of states `X` with `all(X .<= x)`.
`x` is first converted with `convert(eltype(c), x)`.

# Returns
- Single-level chains: a vector with one estimate per chain.
- Multi-level chains: a vector with, per chain, a vector of `maxlvl` terms.
  - Term 1 is the weighted CDF of the samples with level <= 1.
  - Term `l >= 2` is the weighted mean of
    `1{X <= x} * (1 - exp(f_l - f_{l-1} - Z_l + Z_{l-1}))` over the samples with
    level <= `l`. Here `f` is the sample `logprob` and `Z = logpart(c)[i]`.
  NOTE(review): presumably the caller combines (sums) these terms — confirm.
"""
function empirical_cdf(c::RejectionChains, x)
    # Fresh binding instead of reassigning `x`: a reassigned variable captured by
    # the comprehensions below would be boxed.
    xc = convert(eltype(c), x)
    multilevel = is_multilevel(c)
    reps = repetitions(c)

    if !multilevel
        return [
            _weighted_cdf(reps[i], getfield.(c.samples[i], :state), xc)
            for i in 1:nchains(c)
        ]
    end

    lvls = levels(c)
    # `lvls[i]` masks `c.samples[i]` per sample, so take the largest level over
    # all samples of all chains (`maximum(lvls)` would compare whole vectors).
    maxlvl = maximum(maximum, lvls)
    logparts = logpart(c)
    return [
        _chain_multilevel_cdf(c.samples[i], reps[i], lvls[i], logparts[i], maxlvl, xc)
        for i in 1:nchains(c)
    ]
end