Commit 7d76d6c6 authored by Paul Schultz

numba now uses caching

parent 71cec808
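numba's cache=True can only persist compiled machine code for functions defined in an importable source file, which is why this commit writes the generated right-hand side to a module under __psdcache__ and imports it from there instead of exec'ing it from a string. A minimal sketch of the caching mechanism, using a toy function (rhs_example is illustrative, not from this repository):

    import numpy as np
    from numba import njit, float64

    # cache=True stores the compiled result on disk next to this source file,
    # so later interpreter runs skip recompilation.
    @njit(float64[:](float64[:], float64), cache=True)
    def rhs_example(y, t):
        # toy right-hand side; the real network_rhs_numba is generated per network
        return -t * y

    print(rhs_example(np.ones(4), 0.5))  # first call compiles, later runs reuse the cache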
@@ -4,6 +4,7 @@ from numba import njit, float64, complex128, void, int32
# from assimulo.solvers import CVODE
# from assimulo.problem import Explicit_Problem
import scipy.sparse as sps
import os
# The dynamics should be structured according to what is calculated on the network links.
@@ -154,24 +155,33 @@ def define_network_rhs_codegen(node_list, Y):
l_indptr = len(indptr)
indices = coupling_sp.indices
file_header = """
import numpy as np
from numba import njit, float64, complex128, void, int32
"""
def_network_rhs_string = """
@njit(float64[:](float64[:], float64), cache=True)
def network_rhs_numba(y, t):
dydt = np.empty(total_length + length)
dydt = np.empty({total_length} + {length})
v = y[:total_length].view(np.complex128)
omega = y[total_length:]
dv = dydt[:total_length].view(np.complex128)
domega = dydt[total_length:]
v = y[:{total_length}].view(np.complex128)
omega = y[{total_length}:]
dv = dydt[:{total_length}].view(np.complex128)
domega = dydt[{total_length}:]
i = np.zeros(l_indptr - 1, dtype=np.complex128)
i = np.zeros({l_indptr} - 1, dtype=np.complex128)
index = 0
for row, number_of_entries in enumerate(indptr[1:]):
while index < number_of_entries:
i[row] += data[index] * v[indices[index]]
index += 1
"""
""".format(total_length=total_length,
length=length,
l_indptr=l_indptr
)
def_network_rhs_string += "".join([node.node_dynamics_string(j) for j, node in enumerate(node_list)])
@@ -179,10 +189,23 @@ def network_rhs_numba(y, t):
return dydt
"""
context = globals()
context.update(locals())
exec def_network_rhs_string in context
cdir = os.path.join(os.getcwd(), "__psdcache__")
if not os.path.exists(cdir):
os.mkdir(cdir)
os.mknod(os.path.join(cdir, "__init__.py"))
with open(os.path.join(cdir, "compile_function.py"), "w") as text_file:
text_file.write(file_header)
text_file.write("\n")
text_file.write("data=np.array({})".format(data.tolist()))
text_file.write("\n")
text_file.write("indptr=np.array({})".format(indptr.tolist()))
text_file.write("\n")
text_file.write("indices=np.array({})".format(indices.tolist()))
text_file.write("\n")
text_file.write(def_network_rhs_string)
from __psdcache__.compile_function import network_rhs_numba
return network_rhs_numba
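This generated-module approach appears to be what makes numba's on-disk cache effective here: the data, indptr, and indices arrays are baked into __psdcache__/compile_function.py as literals, an empty __init__.py makes the directory importable as a package, and network_rhs_numba is then imported from that file so its compiled form can be reused between runs.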
@@ -45,6 +45,7 @@ def define_gen_rc(brp, rhs, init=None, method="krylov"):
ic[2 * system_size + batch + 1] += .1 * (1. - 2. * np.random.random())
else:
print("failed")
ic = init
return ic, ()
return generate_run_conditions
@@ -59,9 +60,10 @@ def main(sim_dir=default_dir, create_test_data=True, run_test=True, flag_baobab=
import time
print "start", "{0.tm_year}-{0.tm_mon}.{0.tm_mday}.-{0.tm_hour}h{0.tm_min}m.timestamp".format(time.localtime(time.time()))
#node_list, Y = load_PYPSA('microgrid_testcase.npz')
# node_list, Y = load_PYPSA('microgrid_testcase.npz')
# load_flow_sol = None
node_list, Y, load_flow_sol = load_PyPSA_df("bus_admittance.npy", "bus_parameters")
#rhs = define_network_rhs(node_list, Y)
# rhs = define_network_rhs(node_list, Y)
rhs = define_network_rhs_codegen(node_list, Y)
print "compilation finished", "{0.tm_year}-{0.tm_mon}.{0.tm_mday}.-{0.tm_hour}h{0.tm_min}m.timestamp".format(time.localtime(time.time()))