Skip to content
Snippets Groups Projects
files.jl 1.72 KiB
Newer Older

"""
    write_chain(csv_path, data)

Export the chains held in `data` as two CSV files whose names are derived from
`csv_path` (which must end in `.csv`): `<base>.states.csv` receives
`data.states`, and `<base>.info.csv` receives `get_info(data)`.
Energies are not exported.

Throws an `ErrorException` if `csv_path` does not have a `.csv` extension.
Returns `true` when both output paths exist after writing, `false` otherwise.
"""
function write_chain(csv_path::String, data::MultilevelChainSampler.SimpleChains)
    chain_states = data.states
    chain_info = get_info(data)

    base, extension = splitext(csv_path)
    if extension != ".csv"
        error("Only CSV files allowed.")
    end

    states_file = string(base, ".states.csv")
    info_file = string(base, ".info.csv")
    CSV.write(states_file, DataFrame(chain_states, :auto))
    CSV.write(info_file, DataFrame(chain_info))
    # Energies are intentionally left out of the export.

    # Report whether both files are now present on disk.
    return all(ispath, (states_file, info_file))
end

"""
    load_data(csv_path, tabs = ["states", "info", "metrics"])

Read the tables `<base>.<tab>.csv` for every `tab` in `tabs`. Here `<base>` is
`csv_path` with its `.csv` extension removed.

Returns a `Dict{String,Any}` that maps each tab name either to the `DataFrame`
read from its file or to `nothing` if that file does not exist. A warning is
logged for each missing file.

Throws an `ErrorException` if `csv_path` does not have a `.csv` extension.
"""
function load_data(csv_path::String, tabs = [ "states", "info", "metrics" ])
    name, ext = splitext(csv_path)
    ext == ".csv" || error("Only CSV files allowed.")

    # The value type must admit both `nothing` and a loaded table. A dict built
    # only from `k => nothing` pairs infers `Dict{String,Nothing}`, and storing
    # a DataFrame into it would throw a conversion error.
    vals = Dict{String,Any}(k => nothing for k in tabs)
    for k in tabs
        f = "$name.$k.csv"
        # `isfile` rather than `ispath`: a directory with this name cannot be read.
        if isfile(f)
            vals[k] = CSV.read(f, DataFrame)
        else
            @warn "File '$f' not found"
        end
    end
    return vals
end

"""
    calculate_metrics(csv_path, d, Ns = nothing)

Evaluate distance metrics between the chains stored in `<base>.states.csv` and
the distribution `d`, then write them to `<base>.metrics.csv`. Here `<base>` is
`csv_path` without its `.csv` extension.

Each column of the states table is treated as one chain. For every sample size
`N` in `Ns`, the metrics are computed on the first `N` samples of each chain.
If `Ns` is `nothing`, sample sizes are spaced logarithmically between 1 and the
chain length, with duplicates removed.

The output table has one column per chain for each metric (`ks_1 … ks_n`,
`cm_1 … cm_n`) plus a column `N` holding the sample sizes. Returns the value of
`CSV.write`.

Throws an `ArgumentError` in three cases: the states file is missing, it holds
no samples, or some `N` lies outside `1:size(states, 1)`.
"""
function calculate_metrics(csv_path::String, d::Distribution, Ns=nothing)
    # `load_data` returns a Dict of tables; extract the states table itself.
    states_df = load_data(csv_path, ["states"])["states"]
    isnothing(states_df) &&
        throw(ArgumentError("no states file found for '$csv_path'"))
    states = Matrix(states_df)

    Nmax, n_chains = size(states)
    # With Nmax == 0, log(Nmax) == -Inf and the default grid below would hit
    # round(Int, NaN).
    Nmax > 0 || throw(ArgumentError("states file for '$csv_path' holds no samples"))

    # Logarithmic time scale by default. Rounding maps many small exponents to
    # the same integer, so drop the repeated sample sizes.
    if isnothing(Ns)
        Ns = unique(round.(Int, exp.((0:0.01:1) .* log(Nmax))))
    end
    all(N -> 1 <= N <= Nmax, Ns) ||
        throw(ArgumentError("all sample sizes must lie in 1:$Nmax, got $Ns"))

    # Calculate the metrics
    ks = zeros(length(Ns), n_chains)
    cm = zeros(length(Ns), n_chains)
    for n in ProgressBar(1:n_chains)
        for (i, N) in enumerate(Ns)
            chain = states[1:N, n]
            ks[i, n] = distance_ks(chain, d)
            # NOTE(review): the `cm` column is filled by `distance_ks`, so it
            # duplicates `ks`. It presumably should use a Cramér–von Mises
            # distance. Confirm which function is intended and replace this call.
            cm[i, n] = distance_ks(chain, d)
        end
    end

    # Write to file. Column names are ks_1…ks_n followed by cm_1…cm_n, matching
    # the column order of hcat(ks, cm). `$(l)` is needed because `$l_` would
    # interpolate a variable named `l_`.
    columns = ["$(l)_$i" for l in ("ks", "cm") for i in 1:n_chains]
    df = DataFrame(hcat(ks, cm), columns)
    df.N = Ns
    name, _ = splitext(csv_path)
    return CSV.write("$name.metrics.csv", df)
end