import networkx as nx
import osmnx as ox
import matplotlib.pyplot as plt
import matplotlib as mpl
import pandas as pd
import numpy as np
import warnings
from mpl_toolkits.axes_grid1 import make_axes_locatable


def map_highway_to_number_of_lanes(highway_type):
    """Return a default lane count for a highway type.

    Args:
        highway_type: Highway classification, normally a string such as
            ``"motorway"`` or ``"primary"``. Any value that does not compare
            equal to one of the known names is treated as a minor road.

    Returns:
        int: 4 for motorway/trunk, 3 for primary, 2 for secondary and for the
        motorway/trunk/primary link roads, and 1 for everything else.
    """
    # Tuples (not sets) so that unhashable inputs, e.g. a list of highway
    # types, simply fail every comparison and fall through to 1.
    four_lane_types = ("motorway", "trunk")
    two_lane_types = ("secondary", "motorway_link", "trunk_link", "primary_link")

    if highway_type in four_lane_types:
        return 4
    if highway_type == "primary":
        return 3
    if highway_type in two_lane_types:
        return 2
    return 1


# def set_number_of_lanes(G):
#    edges = ox.graph_to_gdfs(G, nodes=False)
#    lanes = {e : map_highway_to_number_of_lanes(highway_type) for e, highway_type in edges['highway'].items()}
#    nx.set_edge_attributes(G, lanes, 'lanes')
#    return G


def _resolve_lane_count(raw_lanes, highway_type):
    """Collapse one edge's raw ``lanes`` value into a single number.

    Args:
        raw_lanes: The edge's ``lanes`` value: a list of numeric strings, a
            single numeric string, a number, or a missing value (``None`` or
            NaN).
        highway_type: The edge's ``highway`` value, used only when
            ``raw_lanes`` is missing.

    Returns:
        The mean of the list entries, the parsed string, the default for the
        highway type when the value is missing, or ``raw_lanes`` unchanged
        when it is already a number.

    Raises:
        ValueError: If a string entry cannot be parsed by ``float``.
    """
    if isinstance(raw_lanes, list):
        return np.mean([float(entry) for entry in raw_lanes])
    if isinstance(raw_lanes, str):
        return float(raw_lanes)
    # None is checked first because np.isnan(None) raises TypeError.
    if raw_lanes is None or np.isnan(raw_lanes):
        return map_highway_to_number_of_lanes(highway_type)
    return raw_lanes


def set_number_of_lanes(G):
    """Write a single numeric ``lanes`` attribute onto every edge of ``G``.

    Existing ``lanes`` values are normalised (lists are averaged, strings are
    parsed); edges without a value get a default derived from their
    ``highway`` type via ``map_highway_to_number_of_lanes``.

    Args:
        G: Graph accepted by ``ox.graph_to_gdfs``; its edges must carry a
            ``highway`` attribute. A ``lanes`` attribute is optional.

    Returns:
        The same graph ``G``, modified in place.
    """
    edges = ox.graph_to_gdfs(G, nodes=False)

    # If no edge has a lanes tag the column does not exist at all; treat every
    # edge as missing instead of failing with a KeyError.
    if "lanes" in edges.columns:
        raw_values = edges["lanes"]
    else:
        raw_values = [None] * len(edges)

    # Iterate the two columns by value rather than indexing row Series by
    # position (``row[0]``), which pandas deprecates for label-indexed Series.
    edges["lanes"] = [
        _resolve_lane_count(raw_lanes, highway_type)
        for raw_lanes, highway_type in zip(raw_values, edges["highway"])
    ]

    # Keys are the edge index entries of the GeoDataFrame.
    nx.set_edge_attributes(G, edges["lanes"].to_dict(), "lanes")
    return G


def set_actual_speed_diff(G, load_kw="load", alpha=50.0, kph_min=10):
    """Compute load-dependent speeds and their gap to the speed limit.

    For every edge the raw speed is ``alpha * lanes * length / load - 9``,
    clamped to the range ``[kph_min, speed_kph]``. Edges with zero load are
    assigned their speed limit directly.

    Args:
        G: Multigraph whose edges carry ``length``, ``lanes``, ``speed_kph``
            and the attribute named by ``load_kw``.
        load_kw: Name of the load attribute. ``"load"`` writes
            ``actual_speed`` / ``speed_diff``; ``"load_flood"`` writes
            ``actual_speed_flood`` / ``speed_diff_flood``. Any other value
            prints a message and leaves ``G`` untouched.
        alpha: Scaling factor of the lanes * length / load term.
        kph_min: Lower bound applied to the computed speed.

    Returns:
        The same graph ``G``, modified in place for a valid ``load_kw``.
    """
    # Validate first: previously this check ran only after all lookups, so an
    # unknown key normally died with a KeyError before the message was shown.
    if load_kw == "load":
        suffix = ""
    elif load_kw == "load_flood":
        suffix = "_flood"
    else:
        print("wrong load_kw. Returning G")
        return G

    loads = nx.get_edge_attributes(G, load_kw)
    lengths = nx.get_edge_attributes(G, "length")
    lanes = nx.get_edge_attributes(G, "lanes")
    speed_lim = nx.get_edge_attributes(G, "speed_kph")

    actual_speeds = {}
    speed_diff = {}
    for e in G.edges(keys=True):
        limit = speed_lim[e]
        load = loads[e]
        if load == 0:
            # No load: use the speed limit, as set_daganzo_velocities does,
            # instead of dividing by zero (RuntimeWarning, NaN for 0/0).
            speed = limit
        else:
            # NOTE(review): the offset 9 equals 3.6 * 5/2, the constant used
            # in set_daganzo_velocities - presumably the same model term;
            # confirm.
            raw_speed = alpha * lanes[e] * lengths[e] / load - 9
            if raw_speed < kph_min:
                speed = kph_min
            elif raw_speed > limit:
                speed = limit
            else:
                speed = raw_speed
        actual_speeds[e] = speed
        speed_diff[e] = limit - speed

    nx.set_edge_attributes(G, actual_speeds, "actual_speed" + suffix)
    nx.set_edge_attributes(G, speed_diff, "speed_diff" + suffix)
    return G


def set_daganzo_velocities(G, alpha, kph_min):
    """Store a load-dependent ``effective_velocity`` on every edge of ``G``.

    Lane counts are (re)computed first via ``set_number_of_lanes``. For an
    edge with non-zero load the velocity is
    ``3.6 * (alpha * lanes * length / (2 * load) - 5/2)``, clamped to
    ``[kph_min, speed_kph]`` and rounded to one decimal when it lies inside
    that range. Edges with zero load start from their speed limit.

    Args:
        G: Multigraph whose edges carry ``load``, ``length`` (in metres),
            ``speed_kph`` and ``highway``.
        alpha: Scaling factor of the lanes * length / load term.
        kph_min: Lower bound applied to the resulting velocity.

    Returns:
        The same graph ``G``, modified in place. Returned for consistency
        with the other attribute setters in this module.
    """
    G = set_number_of_lanes(G)
    loads = nx.get_edge_attributes(G, "load")
    lengths = nx.get_edge_attributes(G, "length")  # length in m
    lanes = nx.get_edge_attributes(G, "lanes")
    speed_lim = nx.get_edge_attributes(G, "speed_kph")

    daganzo_velo = {}
    for e in G.edges(keys=True):
        limit = speed_lim[e]
        if loads[e] == 0:
            velocity = limit
        else:
            # 60 * 60 / 1000 == 3.6; presumably converts the bracketed term
            # from m/s to km/h - confirm against the model definition.
            velocity = (
                60
                * 60
                / 1000
                * (alpha * np.divide(lanes[e] * lengths[e], loads[e] * 2) - 5 / 2)
            )

        # The lower bound is checked first, so it wins if kph_min > limit.
        if velocity < kph_min:
            daganzo_velo[e] = kph_min
        elif velocity > limit:
            daganzo_velo[e] = limit
        else:
            daganzo_velo[e] = np.round(velocity, 1)

    nx.set_edge_attributes(G, daganzo_velo, "effective_velocity")
    return G


# def actual_speed(G, alpha = 50., kph_min = 10):
#    loads = nx.get_edge_attributes(G, 'load')
#    lengths = nx.get_edge_attributes(G, 'length')
#    lanes = nx.get_edge_attributes(G, 'lanes')
#    speed_lim = { e : v for e, v in nx.get_edge_attributes(G, 'speed_kph').items()}
#    speeds = { e : alpha*lanes[e]*lengths[e]/loads[e] for e in G.edges(keys=True) }
#
#    actual_speeds = { e : kph_min if v < kph_min else speed_lim[e] if v > speed_lim[e]  else v for e, v in speeds.items() }
#
#    nx.set_edge_attributes(G, actual_speeds, 'actual_speed')
#    return G


def set_speed_diff(g):
    """Store ``speed_kph - actual_speed`` as ``speed_diff`` on each edge.

    Only edges that carry an ``actual_speed`` attribute are processed; each
    of them must also carry ``speed_kph``.

    Args:
        g: Graph whose edges carry ``actual_speed`` and ``speed_kph``.

    Returns:
        The same graph ``g``, modified in place.
    """
    limits = nx.get_edge_attributes(g, "speed_kph")
    actual_speeds = nx.get_edge_attributes(g, "actual_speed")

    differences = {}
    for edge, actual in actual_speeds.items():
        differences[edge] = limits[edge] - actual

    nx.set_edge_attributes(g, differences, "speed_diff")
    return g


def split_graph_attributes(
    G, attr="speed_diff", vmin=0, vmax=0, cmapsteps=8, cmap="viridis"
):
    """Plot ``G`` with edges coloured by ``attr`` and drawn by road class.

    Every edge gets an RGBA colour stored in its ``ec`` attribute. The whole
    graph is drawn with thin lines, then motorway/trunk/primary edges
    (including their links) are overdrawn with width 3 and
    secondary/tertiary edges (including links) with width 1. A colour bar is
    attached on the right.

    Args:
        G: Graph accepted by ``ox.plot_graph``; edges must carry ``highway``
            and the attribute named by ``attr``.
        attr: Name of the numeric edge attribute to colour by.
        vmin: Lower end of the colour scale.
        vmax: Upper end of the colour scale. If both ``vmin`` and ``vmax``
            are 0, the range of the attribute values is used instead.
        cmapsteps: Number of bin boundaries of the discrete colour scale.
        cmap: Name of the Matplotlib colormap to discretise.

    Returns:
        tuple: ``(fig, ax)`` of the plot. Nothing is shown or saved here, so
        the caller uses these to display or store the figure.
    """
    edges = ox.graph_to_gdfs(G, nodes=False)

    vals = pd.Series(nx.get_edge_attributes(G, attr))

    # plt.cm.get_cmap was removed in Matplotlib 3.9; plt.get_cmap is available
    # in old and new releases alike. The colormap is only read from here.
    base_cmap = plt.get_cmap(cmap)
    cmaplist = [base_cmap(i) for i in range(base_cmap.N)]

    cmap_discrete = mpl.colors.LinearSegmentedColormap.from_list(
        "viridis_discrete", cmaplist, base_cmap.N
    )

    # Colours for values below / above the outermost bin boundaries.
    cmap_discrete.set_under("dimgrey")
    cmap_discrete.set_over("lightgrey")

    # define the bins and normalize
    if vmin == 0 and vmax == 0:
        lower, upper = min(vals), max(vals)
    else:
        lower, upper = vmin, vmax
    bounds = np.ceil(np.linspace(lower, upper, cmapsteps))
    norm = mpl.colors.BoundaryNorm(bounds, cmap_discrete.N)

    scalar_mapper = plt.cm.ScalarMappable(cmap=cmap_discrete, norm=norm)

    ec = dict(vals.map(scalar_mapper.to_rgba))
    nx.set_edge_attributes(G, ec, "ec")

    # Split the graph by road class. Tuples are used for the membership test
    # so that list-valued highway entries simply match neither class.
    major_types = (
        "motorway",
        "trunk",
        "primary",
        "motorway_link",
        "trunk_link",
        "primary_link",
    )
    medium_types = ("secondary", "secondary_link", "tertiary", "tertiary_link")

    highway = edges["highway"]
    major_edges = [edge for edge, hw in highway.items() if hw in major_types]
    medium_edges = [edge for edge, hw in highway.items() if hw in medium_types]

    # (subgraph, line width), drawn in this order on top of the full graph.
    overlays = [
        (G.edge_subgraph(major_edges), 3),
        (G.edge_subgraph(medium_edges), 1),
    ]

    # Silence DeprecationWarnings only while plotting; a bare
    # warnings.filterwarnings call would change the filters for the whole
    # process permanently.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=DeprecationWarning)

        fig, ax = ox.plot_graph(
            G,
            figsize=(20, 12),
            edge_color=pd.Series(nx.get_edge_attributes(G, "ec")),
            edge_linewidth=0.2,
            node_size=0,
            show=False,
        )

        for subgraph, linewidth in overlays:
            # A network without roads of this class yields an edgeless
            # subgraph, which cannot be plotted; skip it.
            if subgraph.number_of_edges() == 0:
                continue
            ox.plot_graph(
                subgraph,
                ax=ax,
                figsize=(20, 12),
                edge_color=pd.Series(nx.get_edge_attributes(subgraph, "ec")),
                edge_linewidth=linewidth,
                node_size=0,
                show=False,
            )

        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="2%", pad=-3)

        # White label and ticks for the colour bar.
        cb = fig.colorbar(scalar_mapper, cax=cax)
        cb.set_label(r"$\Delta V_{i,j}$ [km/h]", color="w")
        cb.ax.yaxis.set_tick_params(color="w")
        plt.setp(plt.getp(cb.ax.axes, "yticklabels"), color="w")
        cb.ax.tick_params(width=1.0)

    return fig, ax