Commit 4da62e60 authored by Fize Jacques
Browse files

DEBUG

parent 755998a6
No related merge requests found
Showing with 77 additions and 10 deletions
+77 -10
# coding = utf-8
import argparse, os, re, json, glob
import pandas as pd
import networkx as nx
import numpy as np
from strpython.eval.automatic_annotation import AnnotationAutomatic,save_cache,add_cache
from strpython.models.str import STR
from tqdm import tqdm,TqdmSynchronisationWarning
from joblib import Parallel, delayed
from multiprocessing import cpu_count
import warnings
# Ignore tqdm's synchronisation warning and enable the tqdm/pandas integration
# (the script later calls `progress_apply` on a DataFrame).
warnings.simplefilter("ignore", TqdmSynchronisationWarning)
tqdm.pandas()

# Single annotator instance shared by every call to `foo`.
annotater = AnnotationAutomatic()

# Three positional arguments: the directory holding the input CSV files,
# the directory holding the graphs, and the path of the CSV file to write.
parser = argparse.ArgumentParser()
parser.add_argument("csvinputdir")
parser.add_argument("graph_dir")
parser.add_argument("output_file")
args = parser.parse_args()

# Stop immediately when one of the two input locations is missing.
if not all(os.path.exists(path) for path in (args.csvinputdir, args.graph_dir)):
    raise FileNotFoundError("Error in Input")
# Gather every distinct unordered (G1, G2) pair listed in the input CSV files.
# A pair is represented by its two identifiers, sorted and joined with "_",
# so that (a, b) and (b, a) produce the same key.
all_cp = []
fns = glob.glob("{0}/*".format(args.csvinputdir))
for fn in fns:
    df = pd.read_csv(fn)
    cps = df[["G1", "G2"]].apply(
        lambda x: "_".join(np.sort(x.values).astype(str)), axis=1
    ).values
    all_cp.extend(cps.tolist())
all_cp = set(all_cp)

# Build one row per distinct pair. The keys are iterated in sorted order
# because the iteration order of a set of strings is not stable from one run
# to the next, which would otherwise shuffle the rows of the output file.
df = pd.DataFrame(
    [cp.split("_") for cp in sorted(all_cp)],
    columns=["G1", "G2"],
)
# Load every "*.gexf" file of the graph directory into `strs`, keyed by the
# integer formed by the last run of digits found in the file path.
str_graph_path = args.graph_dir
strs = {}
for file in tqdm(glob.glob(os.path.join(str_graph_path, "*.gexf")), desc="Load Graphs"):
    # Raw string: "\d" inside a regular literal is an invalid escape sequence.
    id_ = int(re.findall(r"\d+", file)[-1])
    strs[id_] = STR.from_networkx_graph(nx.read_gexf(file))
def foo(x):
    """Return the annotator's result for the pair of graphs referenced by `x`.

    `x` exposes `G1` and `G2`, both convertible to `int`; the converted values
    are used as keys of the module-level `strs` mapping and passed on to
    `annotater.all`. If a `KeyError` is raised while doing so, four zeros are
    registered in the cache for the pair and four zeros are returned.
    """
    first_id, second_id = int(x.G1), int(x.G2)
    try:
        return annotater.all(strs[first_id], strs[second_id], first_id, second_id)
    except KeyError:
        add_cache(first_id, second_id, [0, 0, 0, 0])
        return [0, 0, 0, 0]
# Evaluate every pair row by row, with a progress bar.
df["res"] = df.progress_apply(foo, axis=1)

# Normalise each result to a list of ints; a falsy result becomes an empty list.
df.res = df.res.apply(lambda x: list(map(int, x)) if x else [])

# Spread the result over one column per criterion. Each position is checked
# against the list length on its own, so a result with fewer than four values
# yields 0 for the missing ones instead of raising IndexError.
for position, column in enumerate(("c1", "c2", "c3", "c4")):
    df[column] = df.res.apply(lambda x, i=position: x[i] if len(x) > i else 0)
del df["res"]

# Persist the annotation cache, then write the table to the requested path.
save_cache()
df.to_csv(args.output_file)
......@@ -56,9 +56,7 @@ class AnnotationAutomatic(object):
def all(self, str1, str2, id1=None, id2=None):
    """Return the four criteria for the pair (`str1`, `str2`).

    The cache is queried first with `get_from_cache(id1, id2)`; a truthy
    cached value is returned as is. Otherwise the four criteria are computed
    once, stored with `add_cache(id1, id2, ...)` and returned as a list
    ordered criterion1 .. criterion4.
    """
    cache_data = get_from_cache(id1, id2)
    if cache_data:
        return cache_data

    # Each criterion is evaluated exactly once.
    crit_ = [
        self.criterion1(str1, str2),
        self.criterion2(str1, str2),
        self.criterion3(str1, str2, id1, id2),
        self.criterion4(str1, str2, id1, id2),
    ]
    add_cache(id1, id2, crit_)
    return crit_
......
......@@ -80,9 +80,9 @@ def getGEO(id_se):
return None
data=data[0]
if "path" in data:
if "path" in data.other:
return explode(gpd.read_file(os.path.join(config.osm_boundaries_directory, data.other["path"]))).convex_hull
elif "coord" in data:
elif "coord" in data.other:
return gpd.GeoDataFrame(gpd.GeoSeries([Point(data.coord.lon, data.coord.lat).buffer(1.0)])).rename(
columns={0: 'geometry'})
return None
......
......@@ -96,7 +96,7 @@ class STR(object):
except KeyError: # If no label found, grab one from the geo-database
data = gazetteer.get_by_id(nod)
if data:
sp_en[nod] = data[0].name
sp_en[nod] = data[0].label
str_ = STR(tagged_, sp_en,toponym_first=False)
str_.set_graph(g)
......@@ -310,6 +310,7 @@ class STR(object):
data = gazetteer.get_by_id(id_se)
if len(data) > 0:
STR.__cache_entity_data[id_se] = data[0]
return data[0]
def transform_spatial_entities(self, transform_map: dict):
"""
......@@ -492,17 +493,17 @@ class STR(object):
data_se1, data_se2 = self.get_data(se1), self.get_data(se2)
if "P47" in data_se2 and se1 in get_p47_adjacency_data(data_se2):
if "P47" in data_se2.other and se1 in get_p47_adjacency_data(data_se2):
return True
# print("P47")
elif "P47" in data_se1 and se2 in get_p47_adjacency_data(data_se1):
elif "P47" in data_se1.other and se2 in get_p47_adjacency_data(data_se1):
return True
# print("P47")
if collisionTwoSEBoundaries(se1, se2):
return True
if "coord" in data_se1 and "coord" in data_se2:
if data_se1 and data_se2 and "coord" in data_se1.other and "coord" in data_se2.other:
if Point(data_se1.coord.lon, data_se1.coord.lat).distance(
Point(data_se2.coord.lon, data_se2.coord.lat)) < 1 and len(
set(data_se1.class_) & stop_class) < 1 and len(set(data_se2.class_) & stop_class) < 1:
......@@ -642,7 +643,7 @@ class STR(object):
data = gazetteer.get_by_id(se)[0]
try:
points.append(Point(data.coord.lon, data.coord.lat))
label.append(data.name)
label.append(data.label)
# class_.append(most_common(data["class"]))
except KeyError:
pass
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment