Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
"""
    write_chain(csv_path, data)

Persist a sampled chain as two companion CSV files derived from `csv_path`:
`<name>.states.csv` (the `states` field of `data`, with auto-generated column names)
and `<name>.info.csv` (the table returned by `get_info(data)`).

`csv_path` must end in `.csv`; otherwise an error is raised.
Returns `true` when both files exist on disk after writing.
"""
function write_chain(csv_path::String, data::MultilevelChainSampler.SimpleChains)
    chain_states = data.states
    chain_info = get_info(data)

    base, extension = splitext(csv_path)
    if extension != ".csv"
        error("Only CSV files allowed.")
    end

    states_file = string(base, ".states.csv")
    info_file = string(base, ".info.csv")
    CSV.write(states_file, DataFrame(chain_states, :auto))
    CSV.write(info_file, DataFrame(chain_info))
    # Energies are not written by this function.

    # Report success as "both output files are present".
    return ispath(states_file) && ispath(info_file)
end
"""
    load_data(csv_path, tabs = ["states", "info", "metrics"])

Load the companion tables `<name>.<tab>.csv` for every `tab` in `tabs`, where
`<name>` is `csv_path` with its `.csv` extension removed.

Returns a `Dict` mapping each tab to its `DataFrame`, or to `nothing` (after emitting a
warning) when the corresponding file does not exist.
Throws an error if `csv_path` does not end in `.csv`.
"""
function load_data(csv_path::String, tabs = ["states", "info", "metrics"])
    name, ext = splitext(csv_path)
    ext == ".csv" || error("Only CSV files allowed.")
    # The value type must admit both outcomes: a Dict built from `k => nothing` pairs
    # alone is Dict{K,Nothing} and cannot store a DataFrame (convert would fail).
    vals = Dict{eltype(tabs),Union{Nothing,DataFrame}}(k => nothing for k in tabs)
    for k in tabs
        f = "$name.$k.csv"
        if isfile(f)
            vals[k] = CSV.read(f, DataFrame)
        else
            @warn "File '$f' not found"
        end
    end
    return vals
end
"""
    calculate_metrics(csv_path, d, Ns = nothing)

For each chain (column) stored in `<name>.states.csv`, evaluate distance metrics between
the chain's first `N` states and the distribution `d`, for every sample size `N` in `Ns`.
The results are written to `<name>.metrics.csv` with columns `ks_1 … ks_n`,
`cm_1 … cm_n` (one per chain) and `N`.

If `Ns` is `nothing`, a logarithmically spaced set of distinct sample sizes between 1 and
the number of stored states is used.

Returns the value of `CSV.write` (the path of the written metrics file).
Throws an error if the states file is missing or contains no rows.
"""
function calculate_metrics(csv_path::String, d::Distribution, Ns=nothing)
    states_df = load_data(csv_path, ["states"])["states"]
    isnothing(states_df) && error("No states file found for '$csv_path'.")
    # One column per chain, one row per stored state.
    states = Matrix(states_df)
    Nmax, n_chains = size(states)
    Nmax > 0 || error("States file for '$csv_path' contains no rows.")

    # Logarithmic time scale by default. Rounding exp(t * log(Nmax)) repeats small sizes
    # (1, 1, 1, 2, ...), so keep only distinct values.
    if isnothing(Ns)
        Ns = unique(round.(Int, exp.((0:0.01:1) .* log(Nmax))))
    end

    ks = zeros(length(Ns), n_chains)
    cm = zeros(length(Ns), n_chains)
    for n in ProgressBar(1:n_chains)
        for (i, N) in enumerate(Ns)
            prefix = states[1:N, n]
            ks[i, n] = distance_ks(prefix, d)
            # NOTE(review): assumes `distance_cm` (presumably Cramér–von Mises) is
            # defined alongside `distance_ks` — confirm.
            cm[i, n] = distance_cm(prefix, d)
        end
    end

    # Column order matches hcat(ks, cm): ks_1..ks_n followed by cm_1..cm_n.
    columns = ["$(l)_$i" for l in ("ks", "cm") for i in 1:n_chains]
    df = DataFrame(hcat(ks, cm), columns)
    df.N = Ns
    name, _ = splitext(csv_path)
    return CSV.write("$name.metrics.csv", df)
end