"""
    _csv_stem(csv_path) -> String

Return `csv_path` with its `.csv` extension stripped. Raise an `ErrorException` if
the extension is anything other than `.csv`. Shared by `write_chain`,
`load_data` and `calculate_metrics`, which all derive sibling file names
(`<stem>.<table>.csv`) from one base path.
"""
function _csv_stem(csv_path::AbstractString)
    name, ext = splitext(csv_path)
    ext == ".csv" || error("Only CSV files allowed.")
    return name
end

"""
    write_chain(csv_path, data) -> Bool

Write `data.states` to `<stem>.states.csv` and `get_info(data)` to
`<stem>.info.csv`, where `<stem>` is `csv_path` without its `.csv` extension.
Energies are not written.

Return `true` if both files exist afterwards.

# Throws
- `ErrorException` if `csv_path` does not end in `.csv`.
"""
function write_chain(csv_path::String, data::MultilevelChainSampler.SimpleChains)
    # Validate the path before doing any work.
    name = _csv_stem(csv_path)
    states = data.states
    info = get_info(data)

    states_file = "$name.states.csv"
    info_file = "$name.info.csv"
    CSV.write(states_file, DataFrame(states, :auto))
    CSV.write(info_file, DataFrame(info))

    # Confirm that both files were actually created.
    return ispath(states_file) && ispath(info_file)
end

"""
    load_data(csv_path, tabs = ["states", "info", "metrics"]) -> Dict

For every table name `k` in `tabs`, read `<stem>.<k>.csv` into a `DataFrame`.
`<stem>` is `csv_path` without its `.csv` extension.

Return a Dict mapping each name in `tabs` to its `DataFrame`. A table whose
file does not exist maps to `nothing`, and a warning is logged for it.

# Throws
- `ErrorException` if `csv_path` does not end in `.csv`.
"""
function load_data(csv_path::String, tabs = ["states", "info", "metrics"])
    name = _csv_stem(csv_path)
    # The value type must be explicit: building the Dict from `k => nothing`
    # pairs alone infers `Dict{K,Nothing}`, which rejects the DataFrames stored
    # below.
    vals = Dict{eltype(tabs),Union{Nothing,DataFrame}}(k => nothing for k in tabs)
    for k in tabs
        f = "$name.$k.csv"
        if ispath(f)
            vals[k] = CSV.read(f, DataFrame)
        else
            @warn "File '$f' not found"
        end
    end
    return vals
end

"""
    calculate_metrics(csv_path, d, Ns = nothing)

Load the states table written next to `csv_path`. Each column of that table
is treated as one chain. For each chain and each prefix length `N` in `Ns`,
compute distances between the first `N` states and the distribution `d`.
Write the result to `<stem>.metrics.csv`.

By default `Ns` is 101 roughly log-spaced lengths from 1 to the number of rows.

The output table has columns `ks_1 … ks_n`, then `cm_1 … cm_n` (one per chain),
plus a column `N`. Return the value of `CSV.write` (the written file name).

# Throws
- `ErrorException` if `csv_path` does not end in `.csv`, if no states table is
  found, or if the states table is empty and `Ns` was not supplied.
"""
function calculate_metrics(csv_path::String, d::Distribution, Ns = nothing)
    name = _csv_stem(csv_path)
    # `load_data` returns a Dict of tables; extract the states table itself.
    states = load_data(csv_path, ["states"])["states"]
    isnothing(states) && error("No states table found for '$csv_path'.")
    Nmax, n_chains = size(states)

    # Logarithmic time scale by default.
    if isnothing(Ns)
        # log(0) would make the grid NaN and `round(Int, NaN)` would throw.
        Nmax > 0 || error("States table for '$csv_path' is empty.")
        Ns = round.(Int, exp.(collect(0:0.01:1) .* log(Nmax)))
    end

    ks = zeros(length(Ns), n_chains)
    cm = zeros(length(Ns), n_chains)
    for n in ProgressBar(1:n_chains)
        for (i, N) in enumerate(Ns)
            samples = states[1:N, n]
            ks[i, n] = distance_ks(samples, d)
            # NOTE(review): the "cm" column is computed with `distance_ks`, the
            # same as "ks". It was probably meant to be a different metric
            # (e.g. a `distance_cm`), but no such function is visible here.
            # Confirm and fix.
            cm[i, n] = distance_ks(samples, d)
        end
    end

    # Flat name list in the same order as `hcat(ks, cm)`: ks_1..ks_n, cm_1..cm_n.
    columns = ["$(l)_$i" for l in ("ks", "cm") for i in 1:n_chains]
    df = DataFrame(hcat(ks, cm), columns)
    df.N = Ns
    return CSV.write("$name.metrics.csv", df)
end