From 5783f9cb24d747ffae6695d5b937c26e3578ef77 Mon Sep 17 00:00:00 2001
From: Fize Jacques <jacques.fize@cirad.fr>
Date: Fri, 15 Mar 2019 17:41:17 +0100
Subject: [PATCH] Spatial Relation extraction faster !

---
 setup.py                               |   4 +-
 strpython/eval/automatic_annotation.py |  26 +--
 strpython/helpers/collision.py         |  10 +-
 strpython/models/spatial_relation.py   | 268 +++++++++++++++++++++++++
 strpython/models/str.py                |   4 +-
 5 files changed, 292 insertions(+), 20 deletions(-)
 create mode 100644 strpython/models/spatial_relation.py

diff --git a/setup.py b/setup.py
index d1fb08b..4a646c0 100644
--- a/setup.py
+++ b/setup.py
@@ -7,14 +7,14 @@ setup(
     version='0.1',
     packages=['strpython', 'strpython.nlp', 'strpython.nlp.ner', 'strpython.nlp.exception', 'strpython.nlp.pos_tagger',
               'strpython.nlp.disambiguator', 'strpython.nlp.disambiguator.models',
-              'strpython.nlp.disambiguator.delozier', 'strpython.eval', 'strpython.tt4py', 'strpython.config',
+              'strpython.nlp.disambiguator.delozier', 'strpython.eval', 'strpython.config',
               'strpython.models', 'strpython.models.transformation', 'strpython.helpers'],
     url='',
     license='MIT',
     author='Jacques Fize',
     author_email='jacques.fize@cirad.fr',
     description="Module developed in the context of a thesis. This module comprise all implementation of algorithms, "
-                "model for text matching based on spatial features "
+                "model for text matching based on spatial features ", install_requires=['tqdm']
 )
 # Put default config file if not exists
 home = str(Path.home())
diff --git a/strpython/eval/automatic_annotation.py b/strpython/eval/automatic_annotation.py
index fcd487e..784c493 100644
--- a/strpython/eval/automatic_annotation.py
+++ b/strpython/eval/automatic_annotation.py
@@ -1,4 +1,6 @@
 # coding = utf-8
+import json
+
 import numpy as np
 
 from ..models.str import STR
@@ -17,7 +19,8 @@ class AnnotationAutomatic(object):
         self.matching_cache = MatchingCache(dataset)
         self.adj_rel_db = AdjacencyRelation()
         self.inc_rel_db = InclusionRelation()
-
+        with open("notebooks/inclusion.json") as f: self.inclusion = json.load(f)
+        with open("notebooks/adjacency.json") as f: self.adjacency = json.load(f)
     def all(self, str1, str2, id1=None, id2=None):
         """
 
@@ -54,7 +57,6 @@ class AnnotationAutomatic(object):
         -------
 
         """
-        print("CRIT1")
         return int(len(set(str1.graph.nodes.keys()) & set(str2.graph.nodes.keys())) > 0)
 
     def criterion2(self, str1: STR, str2: STR):
@@ -69,14 +71,14 @@ class AnnotationAutomatic(object):
         -------
 
         """
-        print("CRIT2")
+
         stop_en = set(str1.graph.nodes.keys()) & set(str2.graph.nodes.keys())
         for es in str1.spatial_entities:
             for es2 in str2.spatial_entities:
                 if not es in stop_en and not es2 in stop_en and es != es2:
-                    if self.inc_rel_db.is_relation(es, es2):
+                    if self.inclusion[es][es2]:
                         return 1
-                    if self.adj_rel_db.is_relation(es, es2):
+                    if self.adjacency[es][es2]:
                         return 1
         return 0
 
@@ -96,7 +98,7 @@ class AnnotationAutomatic(object):
         -------
 
         """
-        print("CRIT3")
+
         try:
             c1 = str1.get_cluster(id1)
         except:
@@ -115,12 +117,12 @@ class AnnotationAutomatic(object):
         mean = np.mean(c1.area_)
         c1=c1[c1.area_ >= mean]
         return int(c1.intersects(c2).any())
-        for ind, rows in c1.iterrows():
-            if rows.area < mean:
-                break
-            for ind2, rows2 in c2.iterrows():
-                if rows.geometry.intersects(rows2.geometry):
-                    return 1
+        # for ind, rows in c1.iterrows():
+        #     if rows.area < mean:
+        #         break
+        #     for ind2, rows2 in c2.iterrows():
+        #         if rows.geometry.intersects(rows2.geometry):
+        #             return 1
 
         return 0
 
diff --git a/strpython/helpers/collision.py b/strpython/helpers/collision.py
index dccf8b8..25d4f95 100644
--- a/strpython/helpers/collision.py
+++ b/strpython/helpers/collision.py
@@ -112,11 +112,13 @@ def collide(se1, se2):
     if not type(data_se1) == gpd.GeoDataFrame or not type(data_se2) == gpd.GeoDataFrame:
         return False
     try:
-        if data_se1.intersects(data_se2):
-            return True
+        if data_se1.envelope.intersects(data_se2.envelope):
+            if data_se1.intersects(data_se2):
+                return True
     except:
-        if data_se1.intersects(data_se2).any():
-            return True
+        if data_se1.envelope.intersects(data_se2.envelope).any():
+            if data_se1.intersects(data_se2).any():
+                return True
     return False
 
 
diff --git a/strpython/models/spatial_relation.py b/strpython/models/spatial_relation.py
new file mode 100644
index 0000000..baa7c10
--- /dev/null
+++ b/strpython/models/spatial_relation.py
@@ -0,0 +1,268 @@
+# coding = utf-8
+from multiprocessing import cpu_count
+
+from shapely.geometry import Point, Polygon
+import geopandas as gpd
+import pandas as pd
+import numpy as np
+
+from joblib import Parallel, delayed
+
+from mytoolbox.env import in_notebook
+
+if in_notebook():
+    from tqdm._tqdm_notebook import tqdm_notebook as tqdm
+else:
+    from tqdm import tqdm
+
+from strpython.helpers.collision import getGEO
+from strpython.helpers.geodict_helpers import gazetteer
+
+from mytoolbox.structure.objectify import objectify
+
+class MetaCollector():
+    __cache_entity_data = {}
+
+    def __init__(self):
+        pass
+
+    def get_data(self, id_se):
+        """
+        Return an gazpy.Element object containing information about a spatial entity.
+
+        Parameters
+        ----------
+        id_se : str
+            Identifier of the spatial entity
+
+        Returns
+        -------
+        gazpy.Element
+            data
+        """
+
+        if id_se in MetaCollector.__cache_entity_data:
+            return MetaCollector.__cache_entity_data[id_se]
+        data = gazetteer.get_by_id(id_se)
+        if len(data) > 0:
+            MetaCollector.__cache_entity_data[id_se] = data[0]
+            return data[0]
+
+    def is_relation(self, id_se1: str, id_se2: str):
+        """
+        Return True if the relation defined exist between the two entities
+        Parameters
+        ----------
+        id_se1 : str
+            Identifier of the first spatial entity
+        id_se2 : str
+            Identifier of the second spatial entity
+
+        Returns
+        -------
+        bool
+            if relation exists
+        """
+        raise NotImplementedError()
+
+
+class RelationExtractor(MetaCollector):
+    __cache_entity_data = {}
+
+    def __init__(self, spatial_entities):
+        self.spatial_entities = spatial_entities
+
+        data = [[sp_id, getGEO(sp_id)] for sp_id in
+                             tqdm(spatial_entities, desc="Retrieving Geometries...")]
+
+        self.all_geometry = []
+        for i in data:
+            if not isinstance(i[1], gpd.GeoDataFrame) and not isinstance(i[1], gpd.GeoSeries):
+                self.all_geometry.append([i[0], Polygon()])
+            else:
+                self.all_geometry.append([i[0], i[1].geometry.values[0]])
+
+        self.adjacency_geom, self.inclusion_geom = {}, {}
+        self.adjacency_meta, self.inclusion_meta = {}, {}
+
+    def get_relation_geometry_based(self):
+        if not self.all_geometry:
+            raise ValueError("No geometry extracted. Check the `spatial_entities` arg during the initialization.")
+
+        gdf_intersect = gpd.GeoDataFrame(self.all_geometry, columns="id geometry".split())
+        for row in tqdm(gdf_intersect.itertuples(), total=len(gdf_intersect), desc="Computing intersections..."):
+            try:
+                gdf_intersect["{0}".format(row.id)] = gdf_intersect.intersects(row.geometry)
+            except Exception as e:
+                print(e)
+
+        gdf_within = gpd.GeoDataFrame(self.all_geometry, columns="id geometry".split())
+        for row in tqdm(gdf_within.itertuples(), total=len(gdf_within), desc="Computing contains..."):
+            try:
+                gdf_within["{0}".format(row.id)] = gdf_within.geometry.within(row.geometry)
+            except Exception as e:
+                print(e)
+
+        corr_ = gdf_intersect.iloc[:, 2:] ^ gdf_within.iloc[:,2:]  # intersects XOR within: a pair cannot hold both relation types
+        adj_ = gdf_intersect.iloc[:, 2:] & corr_  # adjacency = intersects but not within (inclusion implies intersection)
+
+        gdf_adjacency = gdf_within.iloc[:, :2]
+        gdf_adjacency = pd.concat((gdf_adjacency, adj_), axis=1)  # re-attach the id/geometry columns to the adjacency matrix
+
+        del gdf_adjacency["geometry"]
+        del gdf_within["geometry"]
+
+        # Index rows by entity id for fast lookup
+        self.adjacency_geom = gdf_adjacency.set_index("id")
+        self.inclusion_geom = gdf_within.set_index("id")
+
+    def get_relation_meta_based(self):
+        meta_adj_extractor = AdjacencyMetaRelation(self.spatial_entities)
+        meta_inc_extractor = InclusionMetaRelation(self.spatial_entities)
+
+        adj_res = {}
+        for i in tqdm(range(len(self.spatial_entities)), desc="Retrieve Adjacency based on meta-data"):
+            se1 = self.spatial_entities[i]
+            sub_spat = self.spatial_entities[i:len(self.spatial_entities)]
+            res = Parallel(n_jobs=4, backend="multiprocessing")(delayed(meta_adj_extractor.is_relation)(se1, se2) for se2 in sub_spat)
+            for j in range(len(sub_spat)):
+                se2 = sub_spat[j]
+                if not se1 in adj_res: adj_res[se1] = {}
+                if not se2 in adj_res: adj_res[se2] = {}
+                adj_res[se1][se2] = res[j]
+                adj_res[se2][se1] = res[j]
+
+        inc_res = {}
+        for se1 in tqdm(self.spatial_entities, desc="Retrieve Inclusion based on meta_data"):
+            res = Parallel(n_jobs=4, backend="multiprocessing")(delayed(meta_inc_extractor.is_relation)(se1, se2) for se2 in self.spatial_entities)
+            #res= [meta_inc_extractor.is_relation(se1, se2) for se2 in self.spatial_entities]
+            for i,se2 in enumerate(self.spatial_entities):
+                if not se1 in inc_res: inc_res[se1] = {}
+                inc_res[se1][se2] = res[i]
+
+        self.adjacency_meta = pd.DataFrame.from_dict(adj_res)
+        self.inclusion_meta = pd.DataFrame.from_dict(inc_res)
+
+    def fuse_meta_and_geom(self):
+        # To apply logical combination correctly !
+        self.adjacency_meta.sort_index(inplace= True)
+        self.inclusion_meta.sort_index(inplace=True)
+        self.adjacency_geom.sort_index(inplace=True)
+        self.inclusion_geom.sort_index(inplace=True)
+
+        self.adjacency_meta.sort_index(axis=1,inplace=True)
+        self.inclusion_meta.sort_index(axis=1,inplace=True)
+        self.adjacency_geom.sort_index(axis=1,inplace=True)
+        self.inclusion_geom.sort_index(axis=1,inplace=True)
+
+        df_adj = self.adjacency_meta.copy()
+        df_inc = self.inclusion_meta.copy()
+        df_adj.iloc[:,:] = self.adjacency_meta | self.adjacency_geom
+        df_inc.iloc[:,:] = self.inclusion_meta | self.inclusion_geom
+
+        return df_adj, df_inc
+
+class AdjacencyMetaRelation(MetaCollector):
+
+    def __init__(self, spatial_entities):
+        MetaCollector.__init__(self)
+        self.p_47_dict = {}
+        self.distances_is_inf_to = {}
+        self.get_all_p47(spatial_entities)
+        self.get_all_distances(spatial_entities)
+
+    def get_all_p47(self, ses):
+        p47_dict = {}
+        for es in tqdm(ses, "Extract all P47 data..."):
+            data = self.get_data(es)
+            p47se1 = []
+            if "P47" in data:
+                for el in data.other.P47:
+                    d = gazetteer.get_by_other_id(el, "wikidata")
+                    if not d: continue
+                    p47se1.append(d[0].id)
+            p47_dict[data.id] = p47se1
+        self.p_47_dict = p47_dict
+
+    def compute_dist(self, data_se1, data_se2):
+        # Euclidean distance (in degrees) between two entities; np.inf when either lacks coordinates.
+        if data_se1["lat"] == np.inf or data_se2["lat"] == np.inf:
+            return np.inf
+        return Point(data_se1["lon"], data_se1["lat"]).distance(Point(data_se2["lon"], data_se2["lat"]))
+
+    def get_all_distances(self, spatial_entities):
+        stop_class = {"A-PCLI", "A-ADM1"}  # Country or First Administrative cut
+
+        data = {}
+        for sp_ in spatial_entities:
+            dd=self.get_data(sp_)
+            if "coord" in  dd:
+                data[sp_]={"lat":dd.coord.lat,"lon":dd.coord.lon}
+            else:
+                data[sp_] = {"lat": np.inf, "lon": np.inf}
+
+        dist_all = {}
+        for es in tqdm(spatial_entities):
+            res_ = Parallel(n_jobs=4, backend="multiprocessing")(delayed(self.compute_dist)(data[es], data[es2]) for es2 in spatial_entities)
+            for ix,es2 in enumerate(spatial_entities):
+                if not es in dist_all:dist_all[es]={}
+                if not es2 in dist_all: dist_all[es2] = {}
+                dist_all[es][es2]=res_[ix]
+                dist_all[es2][es] = res_[ix]
+
+        for se1 in tqdm(spatial_entities, desc="Compute Distances"):
+            if not se1 in self.distances_is_inf_to: self.distances_is_inf_to[se1] = {}
+            for se2 in spatial_entities:
+                data_se1, data_se2 = data[se1],  data[se2]
+                if data_se1["lat"] != np.inf and data_se2["lat"] != np.inf:
+                    not_in_stop = len(set(self.get_data(se1).class_) & stop_class) < 1 and len(
+                        set(self.get_data(se2).class_) & stop_class) < 1
+                    self.distances_is_inf_to[se1][se2] = dist_all[se1][se2] < 1 and not_in_stop
+                else:
+                    self.distances_is_inf_to[se1][se2] = False
+
+    def is_relation(self, id_se1: str, id_se2: str):
+        if id_se1 in self.p_47_dict[id_se2]:
+            return True
+
+        elif id_se2 in self.p_47_dict[id_se1]:
+            return True
+
+        if self.distances_is_inf_to[id_se1][id_se2] or self.distances_is_inf_to[id_se2][id_se1]:
+            return True
+        return False
+
+
+class InclusionMetaRelation(MetaCollector):
+    _inc_chain_cache = {}
+
+    def __init__(self,spatial_entities):
+        MetaCollector.__init__(self)
+        self._inc_chain_cache = {}
+        for se in tqdm(spatial_entities,desc="Extract Inclusion Chains"):
+            inc_chain_P131, inc_chain_P706 = self.get_inclusion_chain(se, "P131"), self.get_inclusion_chain(se,"P706")
+            inc_chain = inc_chain_P131
+            inc_chain.extend(inc_chain_P706)
+            inc_chain = set(inc_chain)
+            self._inc_chain_cache[se] = inc_chain
+
+    def is_relation(self, id_se1: str, id_se2: str):
+
+        if id_se1 in self._inc_chain_cache[id_se2]:
+            return True
+
+        return False
+
+    def get_inclusion_chain(self, id_, prop):
+        """
+        For an entity return it geographical inclusion tree using a property.
+        """
+        arr__ = []
+        current_entity = gazetteer.get_by_id(id_)[0]
+        if "inc_" + prop in current_entity.other:
+            arr__ = current_entity.other["inc_" + prop]
+        elif "inc_geoname" in current_entity.other:
+            arr__ = current_entity.other.inc_geoname
+        if isinstance(arr__, str):
+            arr__ = [arr__]
+        return arr__
diff --git a/strpython/models/str.py b/strpython/models/str.py
index 837b327..6612e89 100644
--- a/strpython/models/str.py
+++ b/strpython/models/str.py
@@ -12,7 +12,7 @@ import networkx as nx
 import pandas as pd
 from shapely.geometry import MultiPoint, Polygon, Point, LineString
 from sklearn.cluster import MeanShift, estimate_bandwidth, dbscan
-import matplotlib.pyplot as plt
+
 
 from ..helpers.geodict_helpers import gazetteer
 from ..helpers.relation_extraction import AdjacencyRelation, InclusionRelation
@@ -567,7 +567,7 @@ class STR(object):
             Matplotlib figure instance
         """
 
-
+        import matplotlib.pyplot as plt
         world = gpd.read_file(gpd.datasets.get_path('naturalearth_lowres'))
         base = world.plot(color='white', edgecolor='black', figsize=(16, 9))
         points = []
-- 
GitLab