diff --git a/criteria_cache.py b/criteria_cache.py
new file mode 100644
index 0000000000000000000000000000000000000000..b04c925052d0b07819b6625710669a8049194629
--- /dev/null
+++ b/criteria_cache.py
@@ -0,0 +1,67 @@
+# -*- coding: utf-8 -*-
+
+import argparse
+import glob
+import os
+import re
+import warnings
+
+import pandas as pd
+import networkx as nx
+import numpy as np
+
+from strpython.eval.automatic_annotation import AnnotationAutomatic, save_cache, add_cache
+from strpython.models.str import STR
+from tqdm import tqdm, TqdmSynchronisationWarning
+
+warnings.simplefilter("ignore", TqdmSynchronisationWarning)
+tqdm.pandas()
+
+annotater = AnnotationAutomatic()
+
+parser = argparse.ArgumentParser()
+parser.add_argument("csvinputdir")
+parser.add_argument("graph_dir")
+parser.add_argument("output_file")
+args = parser.parse_args()
+
+if not os.path.exists(args.csvinputdir) or not os.path.exists(args.graph_dir):
+    raise FileNotFoundError("Input not found: {0} or {1}".format(args.csvinputdir, args.graph_dir))
+
+# Collect every (G1, G2) pair found in the input CSVs. Each pair is stored under
+# a canonical "id1_id2" key (ids sorted) so (a, b) and (b, a) are processed once.
+all_cp = []
+for fn in glob.glob("{0}/*".format(args.csvinputdir)):
+    df = pd.read_csv(fn)
+    cps = df[["G1", "G2"]].apply(lambda x: "_".join(np.sort(x.values).astype(str)), axis=1).values
+    all_cp.extend(cps.tolist())
+all_cp = set(all_cp)
+df = pd.DataFrame([cp.split("_") for cp in all_cp], columns=["G1", "G2"])
+str_graph_path = args.graph_dir
+
+# Load every STR graph; a graph id is the last number in its filename.
+strs = {}
+for file in tqdm(glob.glob(os.path.join(str_graph_path, "*.gexf")), desc="Load Graphs"):
+    id_ = int(re.findall(r"\d+", file)[-1])
+    strs[id_] = STR.from_networkx_graph(nx.read_gexf(file))
+
+
+def compute_criteria(x):
+    """Compute the four criteria for a pair of STRs, caching zeros when a graph is missing."""
+    try:
+        return annotater.all(strs[int(x.G1)], strs[int(x.G2)], int(x.G1), int(x.G2))
+    except KeyError:
+        add_cache(int(x.G1), int(x.G2), [0, 0, 0, 0])
+        return [0, 0, 0, 0]
+
+
+df["res"] = df.progress_apply(compute_criteria, axis=1)
+df["res"] = df.res.apply(lambda x: list(map(int, x)) if x else [])
+df["c1"] = df.res.apply(lambda x: x[0] if len(x) > 0 else 0)
+df["c2"] = df.res.apply(lambda x: x[1] if len(x) > 0 else 0)
+df["c3"] = df.res.apply(lambda x: x[2] if len(x) > 0 else 0)
+df["c4"] = df.res.apply(lambda x: x[3] if len(x) > 0 else 0)
+
+del df["res"]
+save_cache()
+df.to_csv(args.output_file)
diff --git a/strpython/eval/automatic_annotation.py b/strpython/eval/automatic_annotation.py
index b9f5fc1dd90443debafed84a47de60a51375c0c0..89f918289851bd430d3cef6b976a6416805d9104 100644
--- a/strpython/eval/automatic_annotation.py
+++ b/strpython/eval/automatic_annotation.py
@@ -56,9 +56,7 @@ class AnnotationAutomatic(object):
     def all(self,str1,str2,id1=None,id2=None):
         cache_data=get_from_cache(id1,id2)
         if not cache_data:
-            crit_ = [self.criterion1(str1, str2), self.criterion2(str1, str2),
-                     self.criterion3(str1, str2, id1, id2),
-                     self.criterion4(str1, str2, id1, id2)]
+            crit_ = [self.criterion1(str1, str2), self.criterion2(str1, str2),self.criterion3(str1, str2, id1, id2),self.criterion4(str1, str2, id1, id2)]
             add_cache(id1,id2,crit_)
             return crit_
         return cache_data
diff --git a/strpython/helpers/collision.py b/strpython/helpers/collision.py
index 5c1b1345b4afab70c043d240634d4285b14372b7..7f4647b166d8986ad3bd512bb855bb990c66bf3e 100644
--- a/strpython/helpers/collision.py
+++ b/strpython/helpers/collision.py
@@ -80,9 +80,9 @@ def getGEO(id_se):
         return None
     data=data[0]
-    if "path" in data:
+    if "path" in data.other:
         return explode(gpd.read_file(os.path.join(config.osm_boundaries_directory,
                                                   data.other["path"]))).convex_hull
-    elif "coord" in data:
+    elif "coord" in data.other:
         return gpd.GeoDataFrame(gpd.GeoSeries([Point(data.coord.lon, data.coord.lat).buffer(1.0)])).rename(
             columns={0: 'geometry'})
     return None
diff --git a/strpython/models/str.py b/strpython/models/str.py
index 2d673ba3fb27a977de11d77eed0463a0962e91d8..72aa9caf1ce9cc3dc1519e99026eb2b9493a75d9 100644
--- a/strpython/models/str.py
+++ b/strpython/models/str.py
@@ -96,7 +96,7 @@ class STR(object):
             except KeyError: # If no label found, grab one from the geo-database
                 data = gazetteer.get_by_id(nod)
                 if data:
-                    sp_en[nod] = data[0].name
+                    sp_en[nod] = data[0].label
 
         str_ = STR(tagged_, sp_en,toponym_first=False)
         str_.set_graph(g)
@@ -310,6 +310,7 @@ class STR(object):
         data = gazetteer.get_by_id(id_se)
         if len(data) > 0:
             STR.__cache_entity_data[id_se] = data[0]
+            return data[0]
 
     def transform_spatial_entities(self, transform_map: dict):
         """
@@ -492,17 +493,17 @@ class STR(object):
 
         data_se1, data_se2 = self.get_data(se1), self.get_data(se2)
 
-        if "P47" in data_se2 and se1 in get_p47_adjacency_data(data_se2):
+        if "P47" in data_se2.other and se1 in get_p47_adjacency_data(data_se2):
             return True
             # print("P47")
-        elif "P47" in data_se1 and se2 in get_p47_adjacency_data(data_se1):
+        elif "P47" in data_se1.other and se2 in get_p47_adjacency_data(data_se1):
             return True
             # print("P47")
 
         if collisionTwoSEBoundaries(se1, se2):
             return True
 
-        if "coord" in data_se1 and "coord" in data_se2:
+        if data_se1 and data_se2 and "coord" in data_se1.other and "coord" in data_se2.other:
             if Point(data_se1.coord.lon, data_se1.coord.lat).distance(
                     Point(data_se2.coord.lon, data_se2.coord.lat)) < 1 and len(
                     set(data_se1.class_) & stop_class) < 1 and len(set(data_se2.class_) & stop_class) < 1:
@@ -642,7 +643,7 @@ class STR(object):
             data = gazetteer.get_by_id(se)[0]
             try:
                 points.append(Point(data.coord.lon, data.coord.lat))
-                label.append(data.name)
+                label.append(data.label)
                 # class_.append(most_common(data["class"]))
             except KeyError:
                 pass
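Reviewer note: the `.name` → `.label` and `data` → `data.other` changes above all assume the same gazetteer record shape. Below is a minimal runnable sketch of that assumed shape, for review context only; the record construction and all example values are hypothetical stand-ins, not the actual `strpython` gazetteer API.

```python
# Sketch of the gazetteer record shape this patch assumes (hypothetical
# stand-in, not the real strpython gazetteer API): the display name lives in
# `label`, and optional properties such as "path", "coord" and "P47" are
# looked up in the `other` mapping rather than on the record object itself.
from types import SimpleNamespace

record = SimpleNamespace(
    label="Toulouse",                                # read as data[0].label above
    coord=SimpleNamespace(lon=1.4442, lat=43.6047),  # read as data.coord.lon / .lat
    other={
        "path": "france/toulouse.geojson",           # OSM boundary file, when available
        "coord": {"lon": 1.4442, "lat": 43.6047},
        "P47": ["GD402", "GD811"],                   # "shares border with" neighbour ids
    },
)

# The patched membership tests probe the property bag, mirroring getGEO():
if "path" in record.other:
    print("boundary file:", record.other["path"])
elif "coord" in record.other:
    print("fallback point:", record.coord.lon, record.coord.lat)
```

Under this assumed shape, `getGEO` prefers the polygon boundary loaded from `other["path"]` and only falls back to a 1-degree buffered point when just coordinates are known.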