# coding = utf-8
import copy
import os
import time
import warnings


from tqdm import tqdm
import folium
import geopandas as gpd
import networkx as nx
import pandas as pd
from shapely.geometry import MultiPoint, Polygon, Point, LineString
from sklearn.cluster import MeanShift, estimate_bandwidth, dbscan


from ..helpers.geodict_helpers import gazetteer
from ..helpers.relation_extraction import AdjacencyRelation, InclusionRelation


def get_inclusion_chain(id_, prop):
    """
    Return the geographical inclusion chain of an entity for a given property.

    The chain is read from the entity's ``other`` field under the key
    ``"inc_" + prop``, falling back to ``"inc_geoname"`` when that key is
    absent. A single string value is wrapped in a list.

    Parameters
    ----------
    id_ : str
        identifier of the entity in the gazetteer
    prop : str
        suffix of the inclusion property to read (``"inc_<prop>"``)

    Returns
    -------
    list
        the inclusion chain; empty when the entity is unknown or has no
        inclusion property
    """
    results = gazetteer.get_by_id(id_)
    if not results:
        return []
    other = results[0].other

    key = "inc_" + prop
    if key in other:
        chain = other[key]
    elif "inc_geoname" in other:
        # Use subscript access, as for the primary key above: `other` supports
        # `in` and `[...]`, while attribute access fails on a plain dict.
        chain = other["inc_geoname"]
    else:
        chain = []

    if isinstance(chain, str):
        chain = [chain]
    return chain


class STR(object):
    """
    Basic STR structure.

    Holds the spatial entities of a text (identifier -> label), the adjacency
    and inclusion relationships found between them, and a
    ``networkx.MultiDiGraph`` view of both. In that graph, adjacency edges carry
    ``color="green"`` and inclusion edges carry ``color="red"``.
    """
    __cache_entity_data = {}  # Gazetteer lookups keyed by entity id, shared by every STR instance

    def __init__(self, tagged_text, spatial_entities, toponym_first=True):
        """
        Constructor

        Parameters
        ----------
        tagged_text : list
            Text in forms of token associated with tag (2D array 2*t where t == |tokens| )
        spatial_entities : dict
            spatial entities associated with a text. Follow the structure
            {"<label>": "<id>"} when `toponym_first` is True, and
            {"<id>": "<label>"} otherwise.
        toponym_first : bool, optional
            if the keys of `spatial_entities` are toponyms rather than
            identifiers (the default is True)

        Notes
        -----
        `spatial_entities` is copied, never modified. Only entities whose
        identifier starts with "GD" are kept.
        """
        self.tagged_text = tagged_text

        if toponym_first:
            entities = {id_: topo for topo, id_ in spatial_entities.items()}
        else:
            entities = dict(spatial_entities)

        # Filter on identifiers (never on toponyms), working on a copy so the
        # caller's dict is left untouched.
        self.spatial_entities = {
            id_: label
            for id_, label in entities.items()
            if isinstance(id_, str) and id_.startswith("GD")
        }

        self.adjacency_relationships = {}
        self.inclusion_relationships = {}

        self.adj_rel_db = AdjacencyRelation()
        self.inc_rel_db = InclusionRelation()

        self.graph = nx.MultiDiGraph()

    @staticmethod
    def from_networkx_graph(g: nx.Graph, tagged_: list = None):
        """
        Build a STR based on networkx graph

        Parameters
        ----------
        g : nx.Graph
            input graph
        tagged_ : list, optional
            tagged text (the default is None, i.e. an empty list). A 2D array 2*t where t == |tokens|.

        Returns
        -------
        STR
            resulting STR
        """
        sp_en = {}
        for nod in g:
            attrs = g.nodes[nod]
            if "label" in attrs:
                sp_en[nod] = attrs["label"]
                continue
            # No label on the node: grab one from the geo-database, if known.
            data = gazetteer.get_by_id(nod)
            if data:
                sp_en[nod] = data[0].label

        str_ = STR([] if tagged_ is None else tagged_, sp_en, toponym_first=False)
        str_.set_graph(g)
        return str_

    @staticmethod
    def from_dict(spat_ent: dict, tagged_: list = None):
        """
        Build a STR from a dict of spatial entities and compute its relationships.

        Parameters
        ----------
        spat_ent : dict
            Dict of spatial entities associated with a text. Follow this structure {"<id>: <label>"}
        tagged_ : list, optional
            tagged text (the default is None, i.e. an empty list). A 2D array 2*t where t == |tokens|.

        Returns
        -------
        STR
            resulting STR
        """
        str_ = STR([] if tagged_ is None else tagged_, dict(spat_ent), toponym_first=False)
        str_.build()
        return str_

    @staticmethod
    def from_pandas(dataf: pd.DataFrame, tagged: list = None):
        """
        Build a STR from a Pandas Dataframe with two column : id and label.

        Parameters
        ----------
        dataf : pd.DataFrame
            dataframe containing the spatial entities
        tagged : list, optional
            tagged text (the default is None, i.e. an empty list). A 2D array 2*t where t == |tokens|.

        Returns
        -------
        STR
            resulting STR
        """
        return STR.from_dict(pd.Series(dataf.label.values, index=dataf.id).to_dict(), tagged)

    def set_graph(self, g):
        """
        Apply changes to the current STR based on Networkx Graph.

        Edges whose "color" attribute is "green" are registered as adjacency
        relationships, and "red" ones as inclusion relationships. Edges
        without a "color" attribute are ignored.

        Parameters
        ----------
        g : networkx.Graph
            input graph

        """
        self.graph = g
        for se1, se2, attrs in self.graph.edges(data=True):
            color = attrs.get("color")
            if color == "green":
                self.add_adjacency_rel(se1, se2)
            elif color == "red":
                self.add_inclusion_rel(se1, se2)

    def add_spatial_entity(self, id, label=None, v=True):
        """
        Add a spatial entity to the current STR

        Parameters
        ----------
        id : str
            identifier of the spatial entity in Geodict
        label : str, optional
            label to use; when empty, the @en label from the geo-database
            is used (the default is None)
        v : bool, optional
            emit a warning when falling back to the database label (the default is True)

        Returns
        -------
        bool
            False if `id` is unknown to the geo-database, True otherwise
        """
        data_ = self.get_data(id)
        if data_ is None:
            warnings.warn("{0} wasn't found in Geo-Database".format(id))
            return False
        if not label:
            if v:
                warnings.warn("Label empty. @en label from Geo-Database will be used.")
            # get_data() already returns the first gazetteer element; the @en
            # label is read the same way as in transform_spatial_entities.
            label = data_.label.en
        self.spatial_entities[id] = label
        self.graph.add_node(id, label=label)
        return True

    def add_spatial_entities(self, ids: list, labels: list = None):
        """
        Add spatial entities to the current STR

        Parameters
        ----------
        ids : list
            list of identifiers of each spatial entity
        labels : list, optional
            list of labels of each spatial entity. Missing labels are taken
            from the geo-database (@en).

        """
        labels = labels if labels is not None else []
        if not labels:
            warnings.warn("Labels list is empty. @en labels from Geo-Database will be used by default")
        for i, id_se in enumerate(ids):
            label = labels[i] if i < len(labels) else None
            self.add_spatial_entity(id_se, label, False)

    def add_adjacency_rel(self, se1, se2):
        """
        Add a adjacency relationship to the current STR.

        The relationship is stored in both directions.

        Parameters
        ----------
        se1 : str
            Identifier of the first spatial entity
        se2 : str
            Identifier of the second spatial entity

        """
        self.adjacency_relationships.setdefault(se1, {})[se2] = True
        self.adjacency_relationships.setdefault(se2, {})[se1] = True

    def add_inclusion_rel(self, se1, se2):
        """
        Add a inclusion relationship to the current STR.

        Only the direction se1 -> se2 is stored.

        Parameters
        ----------
        se1 : str
            Identifier of the first spatial entity
        se2 : str
            Identifier of the second spatial entity

        """
        self.inclusion_relationships.setdefault(se1, {})[se2] = True

    def get_data(self, id_se):
        """
        Return an gazpy.Element object containing information about a spatial entity.

        Successful lookups are cached at class level; unknown identifiers
        are not cached.

        Parameters
        ----------
        id_se : str
            Identifier of the spatial entity

        Returns
        -------
        gazpy.Element or None
            first gazetteer element for `id_se`, None if not found
        """
        cache = STR.__cache_entity_data
        if id_se in cache:
            return cache[id_se]
        data = gazetteer.get_by_id(id_se)
        if data:
            cache[id_se] = data[0]
            return data[0]
        return None

    def transform_spatial_entities(self, transform_map: dict):
        """
        Replace or delete certain spatial entities based on a transformation map

        Every old entity is removed from both `spatial_entities` and the graph.
        When the new identifier exists in the geo-database, the old node is
        relabelled to it (keeping its edges) and labelled with the @en
        database label. Otherwise a warning is emitted and the old entity is
        simply dropped. An entity that is itself the target of a mapping
        (e.g. {"A": "A"}) is kept.

        Parameters
        ----------
        transform_map : dict
            New mapping for the spatial entities in the current STR. Format required : {"<id of the old spatial entity>":"<id of the new spatial entity>"}

        """
        final_transform_map = {}
        new_label = {}
        for old_se, new_se in transform_map.items():
            data = self.get_data(new_se)  # already the first gazetteer element
            if data is None:
                warnings.warn("{0} doesn't exists in the geo database!".format(new_se))
                continue
            final_transform_map[old_se] = new_se
            new_label[new_se] = data.label.en
            if new_se not in self.spatial_entities:
                self.add_spatial_entity(new_se, new_label[new_se])

        # An entity that is the target of some mapping must survive deletion.
        targets = set(final_transform_map.values())
        to_del = set(transform_map) - targets

        for old_se in to_del:
            self.spatial_entities.pop(old_se, None)

        self.graph = nx.relabel_nodes(self.graph, final_transform_map)

        for es in to_del:
            if es in self.graph:
                self.graph.remove_node(es)

        for se_, label in new_label.items():
            if se_ in self.graph:
                self.graph.nodes[se_]["label"] = label

    def _add_relationship_edges(self, relationships, color):
        """Add a `color` edge for each true relationship whose two ends are graph nodes."""
        for se1, targets in relationships.items():
            for se2, flag in targets.items():
                if flag and se1 in self.graph and se2 in self.graph:
                    self.graph.add_edge(se1, se2, key=0, color=color)

    def update(self):
        """
        Update the relationship between spatial entities in the STR. Used when transforming the STR.

        All edges are dropped, node attributes are kept, and the relationships
        are recomputed. Only relationships between existing nodes become edges.
        """
        # Snapshot node attributes before clearing: the graph is rebuilt from scratch.
        nodes = [(node, copy.deepcopy(attrs)) for node, attrs in self.graph.nodes(data=True)]
        self.graph.clear()
        self.graph.add_nodes_from(nodes)

        self.get_inclusion_relationships()
        self._add_relationship_edges(self.inclusion_relationships, "red")

        self.get_adjacency_relationships()
        self._add_relationship_edges(self.adjacency_relationships, "green")

    def get_inclusion_relationships(self):
        """
        Find all the inclusion relationships between the spatial entities declared in the current STR.

        Every ordered pair of distinct entities is checked against the
        inclusion relation database. Results are accumulated into
        `inclusion_relationships`.
        """
        for se_ in tqdm(self.spatial_entities, desc="Extract Inclusion"):
            for se2_ in self.spatial_entities:
                if se_ != se2_ and self.inc_rel_db.is_relation(se_, se2_):
                    self.add_inclusion_rel(se_, se2_)

    def get_adjacency_relationships(self):
        """
        Find all the adjacency relationships between the spatial entities declared in the current STR.

        Every ordered pair of distinct entities is checked against the
        adjacency relation database. Results are accumulated into
        `adjacency_relationships`.
        """
        for se1 in tqdm(self.spatial_entities, desc="Extract Adjacency Relationship"):
            for se2 in self.spatial_entities:
                if se1 != se2 and self.adj_rel_db.is_relation(se1, se2):
                    self.add_adjacency_rel(se1, se2)

    def build(self, inc=True, adj=True, verbose=False):
        """
        Build the STR

        Parameters
        ----------
        inc : bool, optional
            if inclusion relationship have to be included in the STR (the default is True)
        adj : bool, optional
            if adjacency relationship have to be included in the STR (the default is True)
        verbose : bool, optional
            currently unused (the default is False)

        Returns
        -------
        networkx.Graph
            graph representing the STR
        """
        graph = nx.MultiDiGraph()
        graph.add_nodes_from((k, {"label": v}) for k, v in self.spatial_entities.items())

        if adj:
            self.get_adjacency_relationships()
            for se1, targets in self.adjacency_relationships.items():
                for se2, flag in targets.items():
                    if flag:
                        graph.add_edge(se1, se2, key=0, color="green")
                        graph.add_edge(se2, se1, key=0, color="green")

        if inc:
            self.get_inclusion_relationships()
            for se1, targets in self.inclusion_relationships.items():
                for se2, flag in targets.items():
                    if flag:
                        graph.add_edge(se1, se2, key=0, color="red")

        self.graph = graph
        return graph

    def save_graph_fig(self, output_fn, format="svg"):
        """
        Save the graphiz reprensentation of the STR graph.

        Failures are reported on stdout rather than raised.

        Parameters
        ----------
        output_fn : string
            Output filename
        format : str
            Output format (svg or pdf)

        """
        try:
            pydot_graph = nx.nx_pydot.to_pydot(self.graph)
            if format == "pdf":
                pydot_graph.write_pdf(output_fn)
            else:
                pydot_graph.write_svg(output_fn)
        except Exception as err:
            print("Error while saving STR to {0}: {1}".format(format, err))

    def get_undirected(self, simple_graph=True):
        """
        Return the Undirected form of a STR graph.

        Parameters
        ----------
        simple_graph : bool, optional
            return a simple ``nx.Graph`` (parallel edges merged) instead of a
            ``nx.MultiGraph`` (the default is True)

        Returns
        -------
        networkx.Graph
            unidirected graph
        """
        if simple_graph:
            return nx.Graph(self.graph)
        return nx.MultiGraph(self.graph)

    def _coords_of(self, id_se):
        """Return the (lon, lat) floats of a spatial entity, or None when unavailable."""
        data = self.get_data(id_se)
        if data is None:
            return None
        try:
            return float(data.coord.lon), float(data.coord.lat)
        except (AttributeError, KeyError, TypeError, ValueError):
            return None

    def _relationship_geometries(self):
        """Return (points, adjacency lines, inclusion lines) for entities with known coordinates."""
        points = []
        for se in self.spatial_entities:
            xy = self._coords_of(se)
            if xy is not None:
                points.append(Point(*xy))

        def lines_of(relationships):
            lines = []
            for se1, targets in relationships.items():
                xy1 = self._coords_of(se1)
                for se2, flag in targets.items():
                    if not flag:
                        continue
                    xy2 = self._coords_of(se2)
                    if xy1 is not None and xy2 is not None:
                        lines.append(LineString([xy1, xy2]))
            return lines

        return (points,
                lines_of(self.adjacency_relationships),
                lines_of(self.inclusion_relationships))

    def get_geo_data_of_se(self):
        """
        Return Geographical information for each spatial entities in the STR

        Entities that are unknown to the gazetteer, or that lack coordinates
        or a label, are skipped.

        Returns
        -------
        geopandas.GeoDataFrame
            dataframe containing geographical information of each entity in the STR
        """
        points, labels = [], []
        for se in self.spatial_entities:
            data = self.get_data(se)
            xy = self._coords_of(se)
            if data is None or xy is None:
                continue
            try:
                label = data.label
            except (AttributeError, KeyError):
                continue
            # Append both only once both are known so the columns stay aligned.
            points.append(Point(*xy))
            labels.append(label)

        df = gpd.GeoDataFrame({"geometry": points, "label": labels})
        df["x"] = df.geometry.apply(lambda p: p.x)
        df["y"] = df.geometry.apply(lambda p: p.y)
        return df

    def get_cluster(self, id_=None):
        """
        Return the cluster detected using spatial entities position.

        MeanShift is tried first. If it fails (e.g. a zero bandwidth
        estimated from a single or duplicated point), DBSCAN is used instead.

        Parameters
        ----------
        id_ : temp_file_id, optional
            identifier used to read/write a cached result in
            ./temp_cluster/<id_>.geojson; no caching when falsy (the default is None)

        Returns
        -------
        gpd.GeoDataFrame
            cluster geometry
        """
        cache_path = "./temp_cluster/{0}.geojson".format(id_)
        if id_ and os.path.exists(cache_path):
            return gpd.read_file(cache_path)

        data = self.get_geo_data_of_se()
        X = data[["x", "y"]].values
        if len(X) == 0:  # if zero samples return Empty GeoDataFrame
            return gpd.GeoDataFrame()
        try:
            bandwidth = estimate_bandwidth(X)
            ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
            ms.fit(X)
            data["cluster"] = ms.labels_
        except Exception:
            _, labels = dbscan(X)
            data["cluster"] = labels

        geo = data.groupby("cluster").apply(to_Polygon)
        cluster_polybuff = gpd.GeoDataFrame(geometry=geo)
        if id_:
            os.makedirs(os.path.dirname(cache_path), exist_ok=True)
            # Explicit driver: older geopandas defaults to ESRI Shapefile.
            cluster_polybuff.to_file(cache_path, driver="GeoJSON")
        return cluster_polybuff

    def to_folium(self):
        """
        Use the folium package to project the STR on a map

        Entities are drawn as red points, adjacency relationships as green
        lines and inclusion relationships as red lines.

        Returns
        -------
        folium.Map
            folium map instance
        """
        points, lines_adj, lines_inc = self._relationship_geometries()

        def to_fol(geometries, color="#ff0000"):
            df = gpd.GeoDataFrame(geometry=gpd.GeoSeries(geometries), crs="EPSG:4326")
            return folium.features.GeoJson(df.to_json(), style_function=lambda _feature: {'color': color})

        folium_map = folium.Map()
        folium_map.add_child(to_fol(points))
        folium_map.add_child(to_fol(lines_adj, color='#00ff00'))
        folium_map.add_child(to_fol(lines_inc))
        return folium_map

    def map_projection(self, plt=False):
        """
        Return a matplotlib figure of the STR

        Parameters
        ----------
        plt : bool, optional
            if True, display the figure with ``plt.show()`` and return None
            (the default is False)

        Returns
        -------
        matplotlib.axes.Axes
            axes holding the world map and the STR, when `plt` is False
        """
        # Aliased so the import does not shadow the `plt` parameter.
        import matplotlib.pyplot as pyplot

        # NOTE(review): geopandas.datasets was removed in geopandas 1.0; this
        # needs an external Natural Earth file on recent versions.
        world = gpd.read_file(gpd.datasets.get_path('naturalearth_lowres'))
        base = world.plot(color='white', edgecolor='black', figsize=(16, 9))

        points, lines_adj, lines_inc = self._relationship_geometries()
        gpd.GeoSeries(points).plot(ax=base, marker='o', markersize=5, color="blue")
        gpd.GeoSeries(lines_adj).plot(ax=base, color="green")
        gpd.GeoSeries(lines_inc).plot(ax=base, color="red")

        if not plt:
            return base
        pyplot.show()


def to_Polygon(x):
    """
    Return a polygon buffered representation for a set of points.

    The convex hull of the points is buffered by 1. This gives a disc for a
    single point, a buffered segment for two points or collinear points, and a
    buffered hull polygon otherwise. Joining the points in row order instead
    could produce self-intersecting (invalid) polygons, because rows are not
    sorted around the cluster.

    Parameters
    ----------
    x : pandas.DataFrame
        rows of a cluster; only the "x" and "y" columns are read

    Returns
    -------
    shapely.geometry.Polygon
        buffered geometry (distance 1, in the units of x/y); empty when `x`
        has no rows
    """
    coords = [tuple(xy) for xy in x[["x", "y"]].values]
    if not coords:
        return Polygon()
    return MultiPoint(coords).convex_hull.buffer(1)