# generate_transform.py
# coding = utf-8

# coding = utf-8
import argparse
import glob
import logging
import time
from concurrent.futures import ThreadPoolExecutor

from progressbar import ProgressBar, Timer, Bar, ETA, Counter

from strpython.helpers.boundary import get_all_shapes
from strpython.nlp.disambiguator.geodict_gaurav import *
from strpython.pipeline import *

# Command-line interface: two positional directories plus one sub-command
# ("generalisation" or "extension") that selects the transformation.
parser = argparse.ArgumentParser()
parser.add_argument("graphs_input_dir")
parser.add_argument("graphs_output_dir")

subparsers = parser.add_subparsers(help='commands')

# --- "generalisation" sub-command -------------------------------------------
gen_parser = subparsers.add_parser(
    'generalisation', help='Apply a generalisation transformation on the generated STRs')
gen_parser.set_defaults(which="gene")
gen_parser.add_argument(
    '-t', '--type_gen', help='Type of generalisation', default="all")
# NOTE(review): the help text says "Language" while the default is the
# integer 1 -- confirm what this option is meant to carry.
gen_parser.add_argument(
    '-n', help='Language', default=1)
# The two adjacent literals are concatenated by Python, so the first one must
# end with a space (it previously rendered as "correspondto").
gen_parser.add_argument(
    '-b', '--bound',
    help='If Generalisation is bounded, this arg. correspond '
         'to the maximal ',
    default="country")

# --- "extension" sub-command -------------------------------------------------
ext_parser = subparsers.add_parser(
    'extension', help='Apply a extension process on the STRs')
ext_parser.set_defaults(which="ext")
# Numeric options are converted explicitly: without ``type=`` a value given on
# the command line would be a str while the default is a number.
ext_parser.add_argument(
    '-d', '--distance', type=float, help='radius distance', default=150)
ext_parser.add_argument(
    '-u', '--unit', help='unit used for the radius distance', default="km")
ext_parser.add_argument(
    '-a', '--adjacent_count', type=int,
    help='number of adjacent SE add to the STR', default=1)

args = parser.parse_args()

# Translate the sub-command marker ("which") into the ``type_trans`` attribute.
# ``which`` is absent when no sub-command was given; in that case (or for an
# unknown marker) ``type_trans`` is left unset, as before.
_TYPE_TRANS_BY_COMMAND = {"gene": "gen", "ext": "ext"}
_selected_command = getattr(args, "which", None)
if _selected_command in _TYPE_TRANS_BY_COMMAND:
    args.type_trans = _TYPE_TRANS_BY_COMMAND[_selected_command]

print("Parameters entered : ",args)


# Timestamp taken before the pipelines are built.
start = time.time()

# NER implementation instantiated once per language below.
class_ = StanfordNER

# Initialise Graphs Transformers: one Pipeline per language code, each with
# its own Tagger and its own NER instance.
pipeline = {
    code: Pipeline(lang=lang_name, tagger=Tagger(), ner=class_(lang=code))
    for code, lang_name in (
        ("en", "english"),
        ("fr", "french"),
        ("es", "espagnol"),
    )
}


# Per-document data keyed by the numeric document id.
associated_es = {}   # id -> spatial entities of the loaded STR
count_per_doc = {}   # never filled in this section; dumped as-is later on

# Read Input Files
import re

# Raw string: "\d" inside a normal string literal is an invalid escape sequence.
_ID_PATTERN = re.compile(r"\d+")

graphs_ = {}
if os.path.exists(args.graphs_input_dir):
    files_glob = glob.glob(args.graphs_input_dir + "/*.gexf")
    for fn in files_glob:
        # Look for the id in the file name only, so digits appearing in the
        # directory part of the path cannot be mistaken for a document id.
        numbers = _ID_PATTERN.findall(os.path.basename(fn))
        if not numbers:
            # Previously this ended in a bare IndexError (or picked up digits
            # from a parent directory); report the file and carry on instead.
            logging.warning(
                "Skipping %s: no numeric document id in its file name", fn)
            continue
        # NOTE(review): ``id`` shadows the builtin; the name is kept because it
        # is a module-level variable that code outside this section may read.
        id = int(numbers[-1])
        graphs_[id] = STR.from_networkx_graph(nx.read_gexf(fn))
        associated_es[id] = graphs_[id].spatial_entities
    if not graphs_:
        print("No .gexf files found in {0}".format(args.graphs_input_dir))
        exit()

# If output Dir doesn't exists
if not os.path.exists(args.graphs_output_dir):
    os.makedirs(args.graphs_output_dir)

# Reached with an empty ``graphs_`` only when the input directory is missing.
if not graphs_:
    print("No text files were loaded !")
    exit()

# Result accumulator and progress counter used by the processing code below.
list_gs = []
i = 0

# Union of the spatial entities of every loaded document.
all_es = set()
for entities in associated_es.values():
    all_es.update(entities)

logging.info("Get All Shapes from Database for all ES")
# A single get_all_shapes call covers every entity of every document.
all_shapes = get_all_shapes(list(all_es))

# Hand the shapes to *every* graph. The loop used to index with ``id`` (the
# variable left over from the loading loop) instead of its own ``id_``, so
# only the last-loaded graph was ever updated.
for id_ in graphs_:
    graphs_[id_].set_all_shapes(all_shapes)

def workSTR(id_doc, g, list_gs, pg, argu):
    """Transform one STR, save the resulting graph and advance the progress bar.

    :param id_doc: document identifier, used as the output file name
    :param g: the STR passed to the English pipeline's ``transform``
    :param list_gs: list that receives the transformed graph
    :param pg: progress bar, updated with the global counter ``i``
    :param argu: parsed arguments; every attribute is forwarded to
        ``transform`` as a keyword argument and ``graphs_output_dir`` gives
        the output directory
    """
    global i

    # All attributes of the namespace are forwarded as keyword arguments.
    options = vars(argu)
    transformed = pipeline["en"].transform(g, **options)

    result_graph = transformed.graph
    list_gs.append(result_graph)

    # Save Graph structure
    destination = argu.graphs_output_dir + "/{0}.gexf".format(id_doc)
    nx.write_gexf(result_graph, destination)

    i = i + 1
    pg.update(i)

queue = []
# NOTE: nothing is ever submitted to ``executor``; the documents are processed
# one after another in this thread. workSTR updates the global counter ``i``
# and appends to the shared ``list_gs``.
with ThreadPoolExecutor(max_workers=4) as executor:
    widgets = [' [', Timer(), '] ', Bar(), '(', Counter(), ')', '(', ETA(), ')']
    with ProgressBar(max_value=len(graphs_), widgets=widgets) as pg:
        pg.start()
        for id_doc in graphs_:
            workSTR(id_doc, graphs_[id_doc], list_gs, pg, args)

# Write the per-document associations next to the generated graphs. The
# ``with`` block guarantees the file is flushed and closed (the handle used to
# be left to the garbage collector).
with open(os.path.join(args.graphs_output_dir, "asso.json"), 'w') as asso_file:
    asso_file.write(json.dumps([associated_es, count_per_doc], indent=4))

print("--- %s seconds ---" % (time.time() - start))