#* Functions for plotting a power grid described by a network data dictionary
#*------------------------------------------------------------------------------

#* NOTE: Utility functions used are defined in PlotUtils.jl

#*------------------------------------------------------------------------------

#=
Draws the overhead transmission line segments of the network data dictionary
(NDD) together with its buses and saves the figure to `figpath`.

Arguments:
- network_data: NDD whose buses are drawn and whose overhead transmission
  lines are split into segments by calc_overhead_tl_segments.
- settings: plot settings, merged over
  _default_settings(:plot_pg_overhead_tl_segments). Keys read here: "d_twrs",
  "node_size", "alpha", "draw_ticks", "xlabel", "ylabel", "draw_legend" and
  the optional "area" ([xmin, xmax, ymin, ymax]).
- mode: forwarded to calc_overhead_tl_segments (:sum or :single).
- figpath: file the figure is written to.

Returns nothing; all open figures are closed after saving.
=#
function plot_pg_overhead_tl_segments(
        network_data::Dict{String,<:Any},
        settings = Dict{String,Any}();
        mode = :sum, # :sum or :single
        figpath::String
    )

    ### Figure with equal aspect ratio, 1.5 times matplotlib's 2/3 figaspect
    fig, ax = plt.subplots()::Tuple{Figure,PyObject}
    ax.set_aspect("equal")
    width, height = plt.figaspect(2/3)::Vector{Float64}
    fig.set_size_inches(1.5width, 1.5height)

    ### User settings override the defaults of this plot type
    settings = _recursive_merge(
        _default_settings(:plot_pg_overhead_tl_segments), settings
    )

    nx = pyimport("networkx")
    G::PyObject = nx.Graph()

    ### Buses first; their markers and labels are reused for the legend
    G, bus_markers, bus_labels = _draw_buses!(G, network_data, settings)

    ### Number all segments consecutively (1, 2, ...) and store their positions
    segments = calc_overhead_tl_segments(
        network_data, settings["d_twrs"], mode=mode
    )
    seg_pos = Dict{Int64,Tuple{Float64,Float64}}()
    node_id = 0
    for line in values(segments), k in 1:line["N_seg"]
        node_id += 1
        seg_pos[node_id] = (line["seg_lons"][k], line["seg_lats"][k])
    end

    ### Segments are drawn as plain nodes (no edges)
    nx.draw_networkx_nodes(
        G, seg_pos,
        nodelist = collect(1:length(seg_pos)),
        node_size = settings["node_size"],
        alpha = settings["alpha"]
    )

    ### Optional fixed plot window given as [xmin, xmax, ymin, ymax]
    if haskey(settings, "area")
        xmin, xmax, ymin, ymax = settings["area"]
        plt.xlim(xmin, xmax)
        plt.ylim(ymin, ymax)
    end

    ### Tick visibility flags: left, bottom, left labels, bottom labels
    show_left, show_bottom, label_left, label_bottom = settings["draw_ticks"]
    ax.tick_params(
        left = show_left,
        bottom = show_bottom,
        labelleft = label_left,
        labelbottom = label_bottom
    )
    plt.xlabel(settings["xlabel"])
    plt.ylabel(settings["ylabel"], rotation=90)

    settings["draw_legend"] == true && plt.legend(bus_markers, bus_labels)

    plt.savefig(figpath, bbox_inches="tight")
    plt.close("all")

    return nothing
end

#*------------------------------------------------------------------------------

#=
Plots the power grid described by the network data dictionary (NDD) on a
cartopy map (PlateCarree projection) and saves it to `figpath`.

The map extent is the grid area returned by _get_pg_area, enlarged by
settings["area_offset"]. The map shows land, ocean, coastlines, country
borders and state/province borders (Natural Earth, 50m scale, gray).
Longitude/latitude labels are shown only at the bottom and left, and no grid
lines are drawn.

Arguments:
- settings: plot settings, merged over _default_settings(:plot_pg).
- wind: tuple (path to .nc wind file, frame); an empty path draws no wind field.
- figpath: file the figure is saved to (forwarded to _plot_pg!).

Returns nothing.
=#
function plot_pg_map(
        network_data::Dict{String,<:Any},
        settings = Dict{String,Any}(); # dictionary containing plot settings
        wind = ("", 0), # optional .nc file with wind data and frame to plot
        figpath = "" # where to save the figure
    )

    ### Python imports
    cartopy = pyimport("cartopy")

    ### Setup figure and plot geographic map
    fig::Figure = plt.figure()
    w::Float64, h::Float64 = plt.figaspect(2/3)::Vector{Float64}
    fig.set_size_inches(1.5w, 1.5h)
    ax = fig.add_subplot(projection=cartopy.crs.PlateCarree())
    ax.set_aspect("equal")

    ### Set plot settings
    settings = _recursive_merge(_default_settings(:plot_pg), settings)
    ### Restrict the map to the grid area plus settings["area_offset"]
    xmin, xmax, ymin, ymax = _get_pg_area(network_data, settings["area_offset"])
    ax.set_extent([xmin, xmax, ymin, ymax])

    ### Get state borders
    states_provinces = cartopy.feature.NaturalEarthFeature(
        category="cultural",
        name="admin_1_states_provinces_lines",
        scale="50m",
        facecolor="none"
    )
    ### Add wanted features to plot
    ax.add_feature(cartopy.feature.LAND)
    ax.add_feature(cartopy.feature.OCEAN)
    ax.add_feature(cartopy.feature.COASTLINE)
    ax.add_feature(cartopy.feature.BORDERS)
    ax.add_feature(states_provinces, edgecolor="gray")

    ### Show longitude and latitude values (bottom/left only, no grid lines)
    gl = ax.gridlines(crs=cartopy.crs.PlateCarree(), draw_labels=true)
    # `top_labels`/`right_labels` are the attribute names since cartopy 0.18.
    # On versions without the old names, setting those is a silent no-op.
    # The old `xlabels_top`/`ylabels_right` are still set so that older cartopy
    # versions behave as before.
    gl.top_labels = false
    gl.right_labels = false
    gl.xlabels_top = false
    gl.ylabels_right = false
    gl.xlines = false
    gl.ylines = false
    gl.xformatter = cartopy.mpl.gridliner.LONGITUDE_FORMATTER
    gl.yformatter = cartopy.mpl.gridliner.LATITUDE_FORMATTER

    _plot_pg!(ax, network_data, settings, figpath, wind)

    return nothing
end

#*------------------------------------------------------------------------------

#=
Plots the power grid described by the network data dictionary (NDD) on plain
matplotlib axes (equal aspect ratio) and saves it to `figpath`.

Arguments:
- settings: plot settings, merged over _default_settings(:plot_pg); see
  _default_settings for the available options.
- wind: tuple (path to .nc wind file, frame); an empty path draws no wind field.
- figpath: file the figure is saved to (forwarded to _plot_pg!).

Returns nothing.
=#
function plot_pg(
        network_data::Dict{String,<:Any},
        settings = Dict{String,Any}(); # dictionary containing plot settings
        wind = ("", 0), # optional .nc file with wind data and frame to plot
        figpath = "" # where to save the figure
    )

    ### Figure sized 1.5 times matplotlib's 2/3 figaspect
    fig, ax = plt.subplots()::Tuple{Figure,PyObject}
    width, height = plt.figaspect(2/3)::Vector{Float64}
    fig.set_size_inches(1.5width, 1.5height)
    ax.set_aspect("equal")

    ### User settings take precedence over the defaults
    merged = _recursive_merge(_default_settings(:plot_pg), settings)
    _plot_pg!(ax, network_data, merged, figpath, wind)

    return nothing
end

#=
Draws the power grid described by the network data dictionary (NDD) onto the
existing axes `ax` (via _draw_pg!), saves the current figure to `figpath` and
closes all figures.

Arguments:
- ax: axes to draw onto.
- settings: complete plot settings (already merged with the defaults).
- figpath: output file; must be non-empty.
- wind: tuple (path to .nc wind file, frame); an empty path draws no wind field.

Throws an ArgumentError if `figpath` is empty. All figures are closed even if
drawing or saving fails, so no figure is left open after an error.
Returns nothing.
=#
function _plot_pg!(
        ax::PyObject, # axes to draw power grid onto
        network_data::Dict{String,<:Any},
        settings::Dict{String,<:Any}, # dictionary containing plot settings
        figpath::String, # where to save the figure
        wind = ("", 0) # optional .nc file with wind data and frame to plot
    )

    try
        # plot_pg/plot_pg_map default to figpath = "", which matplotlib cannot
        # write to; fail early with a clear message instead of a Python error.
        isempty(figpath) && throw(ArgumentError(
            "No output file given: pass a non-empty `figpath`."
        ))

        ### Draw power grid graph
        nx = pyimport("networkx")
        G::PyObject = nx.Graph() # empty graph
        _draw_pg!(ax, G, network_data, settings, wind) # draw power grid onto ax

        plt.savefig(figpath, bbox_inches="tight") # save figure
    finally
        plt.close("all") # close figure(s), also after an error
    end
    return nothing
end

#*------------------------------------------------------------------------------

#=
Draws the power grid described by the NDD into graph G on axes `ax`, using the
options in `settings` (see _default_settings). It draws, in order:
- the wind field, if wind[1] is a non-empty file path;
- the buses (_draw_buses!) and the branches (_draw_branches!);
- the optional fixed plot area settings["area"] = [xmin, xmax, ymin, ymax];
- tick settings, axis labels, and a legend combining bus and branch entries
  if settings["draw_legend"] is true.

Returns the tuple (ax, G).
=#
function _draw_pg!(
        ax::PyObject, # axes to draw power grid onto
        G::PyObject, # power grid graph
        network_data::Dict{String,<:Any},
        settings::Dict{String,<:Any}, # dictionary containing plot settings
        wind = ("", 0) # optional .nc file with wind data and frame to plot
    )

    ### Wind field is drawn first, underneath the grid
    if !isempty(wind[1])
        ax, _ = _draw_wind!(ax, settings, wind)
    end

    ### Grid itself: buses as nodes, branches as edges
    G, bus_markers, bus_labels = _draw_buses!(G, network_data, settings)
    G, br_markers, br_labels, _ = _draw_branches!(G, network_data, settings)

    ### Optional fixed plot window
    if haskey(settings, "area")
        xmin, xmax, ymin, ymax = settings["area"]
        plt.xlim(xmin, xmax)
        plt.ylim(ymin, ymax)
    end

    ### Tick visibility flags: left, bottom, left labels, bottom labels
    show_left, show_bottom, label_left, label_bottom = settings["draw_ticks"]
    ax.tick_params(
        left = show_left,
        bottom = show_bottom,
        labelleft = label_left,
        labelbottom = label_bottom
    )
    plt.xlabel(settings["xlabel"])
    plt.ylabel(settings["ylabel"], rotation=90)

    ### One legend for buses and branches together
    if settings["draw_legend"] == true
        plt.legend(
            vcat(bus_markers, br_markers),
            vcat(bus_labels, br_labels)
        )
    end

    return ax, G
end

#*------------------------------------------------------------------------------

#=
Draws one frame of wind speeds as a filled contour plot onto `ax` and adds a
colorbar for it.

Arguments:
- ax: axes to draw onto.
- settings: plot settings; reads settings["Wind"]["levels"] (number of contour
  levels), ["cmap"], ["alpha"] and ["cbar_label"].
- wind: tuple (path to .nc wind file, frame), passed to get_windfield.

The contour levels and colorbar ticks span 0 to the maximum wind speed
returned by get_winddata for the file. That maximum is presumably taken over
all frames, giving a common color scale (confirm in get_winddata).

Returns the tuple (ax, cbar).
=#
function _draw_wind!(
        ax::PyObject, # axes to draw wind field onto
        settings::Dict{String,<:Any}, # dictionary containing plot settings
        wind = ("", 0) # .nc file with wind data and frame to plot
    )

    wind_lons, wind_lats, wind_speeds = get_windfield(wind[1], wind[2])
    _, max_ws = get_winddata(wind[1]) # the number of frames is not needed here
    wind_settings = settings["Wind"]

    wind_speeds = transpose(wind_speeds) # correct dimensions for contour plot
    levels = LinRange(0, max_ws, wind_settings["levels"])
    cs = ax.contourf(
        wind_lons, wind_lats, wind_speeds,
        cmap = wind_settings["cmap"],
        alpha = wind_settings["alpha"],
        levels = levels
    )
    # Colorbar ticks every 5 units from 0 to max_ws. `collect` turns the range
    # into a flat vector of tick positions. The former `[0:5:max_ws]` was a
    # one-element vector containing the range, i.e. a nested list in Python.
    cbar = plt.colorbar(cs, ticks=collect(0:5:max_ws))
    cbar.ax.set_ylabel(
        wind_settings["cbar_label"], rotation=-90, va="bottom"
    )

    return ax, cbar
end

#*------------------------------------------------------------------------------

#=
Draws the buses contained in the NDD as nodes of graph G. Nodes are placed at
their ("bus_lon", "bus_lat") coordinates and grouped by bus type as returned
by get_bustypes. Each type is drawn with the options in
settings["Buses"][type]: "marker", "size", "color", "alpha". A type is drawn
only if its "show" is true.

If settings["Buses"]["Empty bus"]["show_isolated"] is false, the isolated
buses from _get_isolated_buses are removed in place from the "Empty bus"
group and drawn with alpha = 0, i.e. invisibly. Presumably this is so they
still count for the axis limits; confirm.

Returns (G, bus_markers, bus_labels). These are legend handles and labels for
every bus type whose "label" is not "nolabel" and whose "show" is true.
=#
function _draw_buses!(
        G::PyObject, # power grid graph
        network_data::Dict{String,<:Any},
        settings::Dict{String,<:Any} # dictionary containing plot settings 
    )

    nx = pyimport("networkx")
    mlines = pyimport("matplotlib.lines")

    groups = get_bustypes(network_data) # bus type => buses of that type
    pos = Dict(
        bus["index"] => (bus["bus_lon"], bus["bus_lat"])
        for bus in values(network_data["bus"])
    ) # geographic bus locations

    ### Legend entries: every shown bus type that has a real label
    legend_settings = [
        bs for bs in values(settings["Buses"])
        if bs["label"] != "nolabel" && bs["show"] == true
    ]
    bus_markers = [
        mlines.Line2D([], [], color=bs["color"], marker=bs["marker"], ls="None")
        for bs in legend_settings
    ]
    bus_labels = [bs["label"] for bs in legend_settings]

    ### One node collection per bus type
    for (kind, members) in groups
        opts = settings["Buses"][kind]

        if kind == "Empty bus" && opts["show_isolated"] == false
            isolated = _get_isolated_buses(network_data)
            # `members` is the very vector stored under groups["Empty bus"]
            filter!(i -> i ∉ isolated, members)
            nx.draw_networkx_nodes(
                G, pos,
                nodelist = isolated,
                node_shape = opts["marker"],
                node_size = opts["size"],
                node_color = opts["color"],
                alpha = 0 # invisible
            )
        end

        opts["show"] == true || continue
        nx.draw_networkx_nodes(
            G, pos,
            nodelist = members,
            node_shape = opts["marker"],
            node_size = opts["size"],
            node_color = opts["color"],
            alpha = opts["alpha"]
        )
    end

    return G, bus_markers, bus_labels
end

#*------------------------------------------------------------------------------

#=
Draws the branches contained in the NDD as edges of graph G. The coloring mode
settings["Branches"]["br_coloring"] selects the drawing routine, and each
routine receives settings["Branches"]:
- "equal" -> _draw_br_equal!
- "voltage" -> _draw_br_voltage!
- "MW-loading", "Mvar-loading", "MVA-loading" -> _draw_br_branchloads!

Returns (G, br_markers, br_labels, cbar) as produced by that routine.
Throws an ArgumentError for an unknown coloring mode.
=#
function _draw_branches!(
        G::PyObject, # power grid graph
        network_data::Dict{String,<:Any},
        settings::Dict{String,<:Any} # dictionary containing plot settings 
    )

    br_settings = settings["Branches"]
    br_coloring = br_settings["br_coloring"]

    ### Pick the drawing routine for the requested coloring mode
    drawer = if br_coloring == "equal"
        _draw_br_equal!
    elseif br_coloring == "voltage"
        _draw_br_voltage!
    elseif br_coloring in ("MW-loading", "Mvar-loading", "MVA-loading")
        _draw_br_branchloads!
    else
        throw(ArgumentError("Unknown branch coloring $br_coloring."))
    end

    G, br_markers, br_labels, cbar = drawer(G, network_data, br_settings)
    return G, br_markers, br_labels, cbar
end

#*------------------------------------------------------------------------------

#=
Draws the branches contained in the NDD with coloring mode "equal". Every
branch is drawn with the same color br_settings["br_color"], line width
"br_lw" and transparency "br_alpha".

br_settings["br_status"] selects which branches are drawn:
- "active": branches with br_status == 1
- "inactive": branches with br_status == 0
- "all": every branch

Returns (G, [], [], nothing): no legend entries and no colorbar.
Throws an ArgumentError for an unknown "br_status".
=#
function _draw_br_equal!(
        G::PyObject, # power grid graph
        network_data::Dict{String,<:Any},
        br_settings::Dict{String,<:Any} # dictionary containing plot settings 
    )

    nx = pyimport("networkx")

    bus_coords = Dict(
        bus["index"] => (bus["bus_lon"], bus["bus_lat"])
        for bus in values(network_data["bus"])
    ) # geographic bus locations
    all_branches = collect(values(network_data["branch"]))

    ### Predicate selecting the branches to draw
    status = br_settings["br_status"]
    wanted = if status == "active"
        br -> br["br_status"] == 1
    elseif status == "inactive"
        br -> br["br_status"] == 0
    elseif status == "all"
        br -> true
    else
        throw(ArgumentError("Unknown branch status $status."))
    end
    edges = [(br["f_bus"], br["t_bus"]) for br in all_branches if wanted(br)]

    nx.draw_networkx_edges(
        G, bus_coords,
        edgelist = edges,
        width = br_settings["br_lw"],
        edge_color = br_settings["br_color"],
        alpha = br_settings["br_alpha"]
    )

    return G, [], [], nothing
end

#=
Draws the branches contained in the NDD with coloring mode "MW-loading",
"Mvar-loading" or "MVA-loading". Each branch is colored by the value stored
in the branch dictionary under that key (br_settings["br_coloring"]). The
colormap is inferno_r on the fixed range [0, 1]; values outside are clipped
to the colormap ends. A colorbar labeled with the loading F_ij/C_ij is added.

br_settings keys read: "br_coloring", "br_status" ("active", "inactive" or
"all"), "br_lw", "br_alpha".

Returns (G, [], [], cbar). Throws an ArgumentError for an unknown
"br_status".
=#
function _draw_br_branchloads!(
        G::PyObject, # power grid graph
        network_data::Dict{String,<:Any},
        br_settings::Dict{String,<:Any} # dictionary containing plot settings
    )

    ### Python imports
    nx = pyimport("networkx")

    pos = Dict(
        b["index"] => (b["bus_lon"], b["bus_lat"])
        for b in collect(values(network_data["bus"]))
    ) # geographic bus locations
    branches = collect(values(network_data["branch"])) # branch dictionaries
    br_coloring = br_settings["br_coloring"] # branch key holding the loading
    br_status = br_settings["br_status"]

    ### Select the branches to draw
    if br_status == "active" # only plot active branches
        selected = [b for b in branches if b["br_status"] == 1]
    elseif br_status == "inactive" # only plot inactive branches
        selected = [b for b in branches if b["br_status"] == 0]
    elseif br_status == "all" # plot all branches
        selected = branches
    else
        throw(ArgumentError("Unknown branch status $br_status."))
    end
    edges = [(b["f_bus"], b["t_bus"]) for b in selected]
    branchloads = [b[br_coloring] for b in selected]

    ### Draw edges colored by loading
    cmap = plt.cm.inferno_r
    vmin, vmax = 0., 1.
    nx.draw_networkx_edges(
        G, pos,
        edgelist = edges,
        width = br_settings["br_lw"],
        edge_color = branchloads,
        edge_cmap = cmap,
        edge_vmin = vmin,
        edge_vmax = vmax,
        alpha = br_settings["br_alpha"]
    )

    ### Add colorbar
    # The ScalarMappable only carries cmap and norm and is attached to no axes.
    # The axes to take space from is therefore passed explicitly, because
    # recent matplotlib versions no longer fall back to the current axes.
    # An empty data array is set for older matplotlib versions, which reject
    # a mappable without one.
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin, vmax))
    sm.set_array(Float64[])
    cbar = plt.colorbar(sm, ax=plt.gca())
    cbar.ax.set_ylabel(
        "Line $br_coloring " * L"$F_{ij}/C_{ij}$", rotation=-90, va="bottom"
    )

    return G, [], [], cbar
end

#=
Draws the branches contained in the NDD with coloring mode "voltage". Each
branch is colored by its "tl_voltage" value:
- Branches whose voltage parses to zero (transformers) are drawn in black.
- Every other voltage level gets a tableau color. Orange and green are
  skipped because they are used for buses. Colors are assigned in ascending
  numeric order of the levels, cycling through the palette if there are more
  levels than colors.

br_settings keys read: "br_status" ("active", "inactive" or "all"),
"br_lw", "br_alpha".

Returns (G, br_markers, br_labels, nothing). br_markers and br_labels hold one
legend entry ("<voltage> kV") per non-transformer voltage level.
Throws an ArgumentError for an unknown "br_status".
=#
function _draw_br_voltage!(
        G::PyObject, # power grid graph
        network_data::Dict{String,<:Any},
        br_settings::Dict{String,<:Any} # dictionary containing plot settings
    )

    ### Python imports
    nx = pyimport("networkx")
    mlines = pyimport("matplotlib.lines")

    pos = Dict(
        b["index"] => (b["bus_lon"], b["bus_lat"])
        for b in collect(values(network_data["bus"]))
    ) # geographic bus locations
    branches = collect(values(network_data["branch"])) # branch dictionaries
    br_markers = Array{PyObject,1}() # markers for legend
    br_labels = Array{String,1}() # labels for legend

    ### Select the branches to draw and get their voltage levels
    br_status = br_settings["br_status"]
    if br_status == "active" # only plot active branches
        selected = [b for b in branches if b["br_status"] == 1]
    elseif br_status == "inactive" # only plot inactive branches
        selected = [b for b in branches if b["br_status"] == 0]
    elseif br_status == "all" # plot all branches
        selected = branches
    else
        throw(ArgumentError("Unknown branch status $br_status."))
    end
    edges = [(b["f_bus"], b["t_bus"]) for b in selected]
    voltages = [string(b["tl_voltage"]) for b in selected]

    ### Assign colors to voltage levels and add markers and labels for legend
    # Tableau colors in matplotlib's order, without orange and green (used for
    # buses). They are listed explicitly because PyCall auto-converts Python
    # dicts (including the OrderedDict matplotlib.colors.TABLEAU_COLORS) to
    # Julia Dicts, whose key order does not follow matplotlib's color order.
    palette = [
        "tab:blue", "tab:red", "tab:purple", "tab:brown",
        "tab:pink", "tab:gray", "tab:olive", "tab:cyan",
    ]
    numeric(v) = tryparse(Float64, v)
    is_transformer(v) = numeric(v) == 0 # zero voltage marks transformers
    # Sort levels numerically (string order would put "1000.0" before
    # "220.0"); non-numeric levels go last, ordered by their text.
    level_key(v) = (something(numeric(v), Inf), v)
    levels = sort!(unique(filter(!is_transformer, voltages)); by = level_key)

    color_of = Dict{String,String}()
    for (i, v) in enumerate(levels)
        color = palette[mod1(i, length(palette))] # cycle if levels > colors
        color_of[v] = color
        push!(br_markers, mlines.Line2D([], [], color=color, ls="-"))
        push!(br_labels, v * " kV")
    end
    edge_colors = [is_transformer(v) ? "k" : color_of[v] for v in voltages]

    ### Draw edges
    nx.draw_networkx_edges(
        G, pos,
        edgelist = edges,
        width = br_settings["br_lw"],
        edge_color = edge_colors,
        alpha = br_settings["br_alpha"]
    )

    return G, br_markers, br_labels, nothing
end