import osmnx as ox
import networkx as nx
import pandas as pd

# import geopandas as gpd
import numpy as np

# import matplotlib.pyplot as plt
import os

from pathlib import Path

from shapely.strtree import STRtree

# import src.PopulationModule as pm
import src.GravityModule2 as gm

def add_emergencies(G, north, east, south, west, date, key, tag):
    """Mark the graph node nearest to every OSM feature tagged ``key=tag``.

    The bounding box is cut out of the Germany extract for ``date`` with
    ``osmium`` (cached under ``data/cache/osmfiles``), filtered down to
    features carrying ``key``, and each remaining feature's centroid is
    snapped to the nearest graph node. That node's ``key`` attribute is set
    to the feature's value for ``key``; nodes never matched are ``False``
    (values already present from an earlier call are kept).

    Parameters
    ----------
    G : networkx graph accepted by ``ox.graph_to_gdfs``
    north, east, south, west : bounding box edges, used both in the osmium
        ``-b`` option and in the cache directory name
    date : sub-directory name of the historic extract and of the cache
    key : OSM tag key, e.g. ``"amenity"``
    tag : OSM tag value; ``"*"`` selects every value of ``key``

    Returns
    -------
    The same graph ``G``; it is modified in place.
    """
    nodes = ox.graph_to_gdfs(G, edges=False)

    # "*" cannot be part of a file name, so wildcard extracts are cached as "all".
    filetag = "all" if tag == "*" else tag

    cache_dir = Path(f"data/cache/osmfiles/{date}/{north}-{south}-{west}-{east}")
    unfiltered_pbf = cache_dir / "nofilter.osm.pbf"
    filtered_stem = cache_dir / f"filter-{key}-{filetag}"
    filtered_osm = Path(f"{filtered_stem}.osm")

    # Cut the bounding box out of the country extract unless it is cached.
    if not unfiltered_pbf.is_file():
        # parents=True: the per-date directory may not exist yet.
        cache_dir.mkdir(parents=True, exist_ok=True)
        os.system(
            f"osmium extract -b {west},{south},{east},{north} "
            f"data/historic-data.nosync/{date}/germany.osm.pbf "
            f"-o {unfiltered_pbf} --overwrite"
        )

    # Filter the extract down to the requested tag unless it is cached.
    if not filtered_osm.is_file():
        # NOTE(review): the filter expression was missing from the command;
        # "{key}={tag}" is reconstructed from how key/tag are used below.
        # It is quoted so the shell does not glob-expand a "*" tag.
        os.system(
            f'osmium tags-filter {unfiltered_pbf} "{key}={tag}" '
            f"-o {filtered_stem}.osm.pbf --overwrite"
        )
        # Convert pbf -> bzip2-compressed XML -> plain .osm XML.
        os.system(
            f"osmium cat {filtered_stem}.osm.pbf "
            f"-o {filtered_stem}.osm.bz2 --overwrite"
        )
        os.system(f"bzip2 -d {filtered_stem}.osm.bz2 -f")

    # get fire-stations/hospitals
    df = ox.geometries_from_xml(str(filtered_osm))

    # set true/false for emergency nodes
    if key not in nodes.columns:
        nodes[key] = np.full(len(nodes), False)

    if len(df) == 0:
        # Nothing to snap; the attribute is still written below so that
        # every node carries `key`.
        print(f"No {key}_{tag} in specified area.")
    else:
        emergencies = df[~df[key].isna()]
        node_tree = STRtree(nodes.geometry)
        key_col = nodes.columns.get_loc(key)

        for geo, value in zip(emergencies.geometry, emergencies[key]):
            # nearest() is used as a positional index into `nodes`
            # (the shapely >= 2.0 return convention).
            idx = node_tree.nearest(geo.centroid)
            nodes.iloc[idx, key_col] = value

    nx.set_node_attributes(G, nodes[key], name=key)
    return G

def emergency_graph_from_osmfile(
    osm_file, emergency, west, south, east, north, date=None, roads="all"
):
    """Build a road graph with emergency facilities and population attached.

    Steps, in order:

    1. build the road graph with ``gm.graph_from_osmfile``;
    2. filter the cached bounding-box extract down to ``emergency``
       amenities with ``osmium`` (skipped when the ``.osm`` file exists);
    3. set the boolean node column ``"emergency"`` to ``True`` on the node
       nearest to each facility centroid;
    4. add edge speeds and travel times with osmnx;
    5. attach population via ``pm.population_from_voronoi_polys``;
    6. store ``[north, west, south, east]`` on ``G.bbox``.

    Parameters
    ----------
    osm_file : passed through to ``gm.graph_from_osmfile``
    emergency : value of the OSM ``amenity`` tag to look for
    west, south, east, north : bounding box edges
    date : cache sub-directory name (``None`` yields a literal ``"None"``)
    roads : passed through to ``gm.graph_from_osmfile``

    Returns
    -------
    The graph returned by ``pm.population_from_voronoi_polys``.
    """
    # The module-level import of PopulationModule is commented out in this
    # file, which left `pm` undefined here; import it where it is needed.
    import src.PopulationModule as pm

    # NOTE(review): presumably this call also creates nofilter.osm.pbf in
    # the cache directory used below - confirm in GravityModule2.
    G = gm.graph_from_osmfile(
        osm_file, west, south, east, north, date=date, roads=roads
    )

    cache_dir = f"data/cache/osmfiles/{date}/{north}-{south}-{west}-{east}"
    filtered_stem = f"{cache_dir}/filter-{emergency}"
    filtered_osm = f"{filtered_stem}.osm"

    # check if emergency-filtered osm file exists already
    if not Path(filtered_osm).is_file():
        # NOTE(review): the filter expression was missing from the command;
        # "amenity={emergency}" is reconstructed from the amenity selection
        # applied to the parsed features below.
        os.system(
            f"osmium tags-filter {cache_dir}/nofilter.osm.pbf "
            f"amenity={emergency} "
            f"-o {filtered_stem}.osm.pbf --overwrite"
        )
        # Convert pbf -> bzip2-compressed XML -> plain .osm XML.
        os.system(
            f"osmium cat {filtered_stem}.osm.pbf "
            f"-o {filtered_stem}.osm.bz2 --overwrite"
        )
        os.system(f"bzip2 -d {filtered_stem}.osm.bz2 -f")

    nodes, edges = ox.graph_to_gdfs(G)

    # get fire-stations/hospitals
    df = ox.geometries_from_xml(filtered_osm)
    emergencies = df[df["amenity"] == emergency]

    # set true/false for emergency nodes
    nodes["emergency"] = np.full(len(nodes), False)

    # One shared list so the objects stored in the tree are the very same
    # objects used for the id() lookup.
    node_geoms = list(nodes.geometry)
    node_tree = STRtree(node_geoms)
    index_by_id = {id(pt): i for i, pt in enumerate(node_geoms)}
    emergency_col = nodes.columns.get_loc("emergency")

    for geo in emergencies.geometry:
        nearest = node_tree.nearest(geo.centroid)
        # shapely >= 2.0 returns the position of the nearest geometry,
        # older releases return the geometry object itself.
        if isinstance(nearest, (int, np.integer)):
            idx = nearest
        else:
            idx = index_by_id[id(nearest)]
        nodes.iloc[idx, emergency_col] = True

    G = ox.graph_from_gdfs(nodes, edges)

    # impute speed on all edges missing data
    G = ox.add_edge_speeds(G)
    # calculate travel time (seconds) for all edges
    G = ox.add_edge_travel_times(G)

    # population estimate for the bounding box, distributed over the graph
    pop = pm.pop_guess_from_poi(north, south, east, west)
    R = gm.residential_graph(date, north, south, west, east)
    G = pm.population_from_voronoi_polys(G, R, pop, box_bins=10)

    G.bbox = [north, west, south, east]
    return G

def set_shortest_path_to_emergency(G, weight="travel_time"):
    """Return each node's shortest-path cost to the nearest emergency node.

    Every node whose ``"emergency"`` attribute is truthy is used as a
    source of a multi-source Dijkstra search over ``G``.

    Parameters
    ----------
    G : graph whose nodes carry a boolean ``"emergency"`` attribute
    weight : edge attribute used as path cost

    Returns
    -------
    dict mapping each reachable node to its cost divided by 60, i.e.
    minutes when ``weight`` is a travel time in seconds. Despite the
    function name, nothing is written back to ``G``.
    """
    nodes = ox.graph_to_gdfs(G, edges=False)
    emergency_nodes = list(nodes[nodes["emergency"]].index)

    path_length_emerg = nx.multi_source_dijkstra_path_length(
        G, emergency_nodes, weight=weight
    )
    # Rescale the accumulated weight (seconds -> minutes for travel times).
    return {node: cost / 60 for node, cost in path_length_emerg.items()}

# def add_emergencies_to_gdfs(nodes, edges, emergency_df):
#    """
#    Add fire-stations or hospitals to the nodes and edges.
#    Split existing edges connecting the new emergency roads in a smart way.
#    May only return an UNDIRECTED graph, as only one road is split.
#    """
#    #warnings.filterwarnings("ignore", category=FutureWarning) #ignore deprecation warnings
#    nodes.loc[:, 'ref'] = 0
#    #edges_updated = edges.copy()
#    edg_tree = STRtree(edges.geometry)
#    index_by_id = dict((id(pt), i) for i, pt in enumerate(edges.geometry))
#    for i, v in emergency_df.iterrows():
#        #find the road nearest to emergency i
#        emergency_point = v.geometry.centroid
#        nearest_road = edg_tree.nearest(emergency_point)
#        #find index and edges u,v,k
#        idx = index_by_id[id(nearest_road)]
#        e_idx = edges.iloc[idx].name
#        #split road
#        p, q = nearest_points(nearest_road, emergency_point)
#        #new road to emergency
#        emergency_linestring = LineString([p, q])
#        #update indices
#        p_id = id(p)
#        q_id = id(q)
#        pi_id = (e_idx[0], p_id, 0)
#        pj_id = (p_id, e_idx[1], 0)
#        pq_id = (p_id, q_id, 0)
#        #add emergency-linestring to df
#        edg_e = edges.iloc[idx].copy()
#        edg_e.name = pq_id
#        edg_e['osmid'] = id(emergency_linestring)
#        edg_e['length'] *= emergency_linestring.length / nearest_road.length
#        edg_e['geometry'] = emergency_linestring
#        #edg_e['from'] = q_id
#        #edg_e['to'] = p_id
#        #update nodes
#        nd_p = nodes.iloc[0].copy()
#        nd_q = nd_p.copy()
#        nd_p.name = p_id
#        nd_q.name = q_id
#        nd_p['x'] = p.x
#        nd_p['y'] = p.y
#        nd_p['geometry'] = p
#        nd_q['x'] = q.x
#        nd_q['y'] = q.y
#        nd_q['ref'] = 'emergency'
#        nd_q['geometry'] = q
#        nodes = nodes.append([nd_p, nd_q])
#        #split edge
#        #allow small inaccuracies at split
#        buff = p.buffer(0.000001)
#        road_split = split(nearest_road, buff)
#        #check if road was split, or if nearest point to emergency is at one endpoint of the road
#        #this can be done by comparing the max split length to the nearest road_geometry length.
#        #if road was split, update dataframes
#        split_lengths = (road.length for road in road_split)
#        split_compare = abs(max(split_lengths) - nearest_road.length)
#        if split_compare > 1e-5:
#            first_seg, buff_seg, last_seg = road_split
#            line = LineString(list(first_seg.coords) + list(p.coords) + list(last_seg.coords))
#            a, b = split(line, p)
#            #update dataframes with split edges
#            edg_a = edges.iloc[idx].copy()
#            edg_a.name = pi_id
#            edg_a['osmid'] = id(a)
#            edg_a['length'] *= a.length/(a.length+b.length)
#            edg_a['geometry'] = a
#            edg_b = edges.iloc[idx].copy()
#            edg_b.name = pj_id
#            edg_b['osmid'] = id(b)
#            edg_b['length'] *= b.length/(a.length+b.length)
#            edg_b['geometry'] = b
#            edges = edges.append([edg_a, edg_b, edg_e])
#            #drop old edge
#            edges = edges.drop(e_idx)
#    return nodes, edges

# def prepare_emergency_graph(osm_file, emergency, west, south, east, north, GCC=True):
#    #warnings.filterwarnings("ignore", category=FutureWarning) #ignore deprecation warnings
#    #check if unfiltered osm-file with respected bounding box exists
#    if not Path(f"data/cache/osmfiles/{north}-{south}-{west}-{east}/nofilter.osm.pbf").is_file():
#        os.system(f"mkdir data/cache/osmfiles/{north}-{south}-{west}-{east}")
#        os.system(f"osmium extract -b {west},{south},{east},{north} {osm_file} -o data/cache/osmfiles/{north}-{south}-{west}-{east}/nofilter.osm.pbf --overwrite")
#    #check if highway-filtered file osm exists already
#    if not Path(f"data/cache/osmfiles/{north}-{south}-{west}-{east}/filter-highway.osm").is_file():
#        os.system(f"osmium tags-filter data/cache/osmfiles/{north}-{south}-{west}-{east}/nofilter.osm.pbf w/highway=motorway,trunk,primary,secondary,tertiary,unclassified,residential,motorway_link,trunk_link,primary_link,secondary_link,tertiary_link,living_street -o data/cache/osmfiles/{north}-{south}-{west}-{east}/filter-highway.osm.pbf --overwrite")
#        os.system(f"osmium cat data/cache/osmfiles/{north}-{south}-{west}-{east}/filter-highway.osm.pbf -o data/cache/osmfiles/{north}-{south}-{west}-{east}/filter-highway.osm.bz2 --overwrite")
#        os.system(f"bzip2 -d data/cache/osmfiles/{north}-{south}-{west}-{east}/filter-highway.osm.bz2 -f")
#    #prepare graph from osmfile
#    G = ox.graph_from_xml(f"data/cache/osmfiles/{north}-{south}-{west}-{east}/filter-highway.osm", retain_all=True, simplify=True)
#    if GCC:
#        G = ox.utils_graph.get_largest_component(G, strongly=True)
#    #check if emergency file exists already
#    if not Path(f"data/cache/osmfiles/{north}-{south}-{west}-{east}/filter-{emergency}.osm").is_file():
#        os.system(f"osmium tags-filter data/cache/osmfiles/{north}-{south}-{west}-{east}/nofilter.osm.pbf emergency={emergency} -o data/cache/osmfiles/{north}-{south}-{west}-{east}/filter-{emergency}.osm.pbf --overwrite")
#        os.system(f"osmium cat data/cache/osmfiles/{north}-{south}-{west}-{east}/filter-{emergency}.osm.pbf -o data/cache/osmfiles/{north}-{south}-{west}-{east}/filter-{emergency}.osm.bz2 --overwrite")
#        os.system(f"bzip2 -d data/cache/osmfiles/{north}-{south}-{west}-{east}/filter-{emergency}.osm.bz2 -f")
#    #get df with emergencies
#    emergency_df = ox.geometries_from_xml(f"data/cache/osmfiles/{north}-{south}-{west}-{east}/filter-{emergency}.osm")
#    U = ox.get_undirected(G)
#    nodes, edges = ox.graph_to_gdfs(U, fill_edge_geometry=True)
#    #add emergency nodes to graph and split nearest edges to link them to existing graph
#    nodes, edges = add_emergencies_to_gdfs(nodes, edges, emergency_df)
#    G = ox.graph_from_gdfs(nodes, edges)
#    G = ox.get_undirected(G)
#    # impute speed on all edges missing data
#    G = ox.add_edge_speeds(G)
#    # calculate travel time (seconds) for all edges
#    G = ox.add_edge_travel_times(G)
#    return G