# generate_str.py
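"""
Build a Spatial Textual Representation (STR) for every document of a corpus.

Example invocation (file names and option values are illustrative, any
combination of the options declared below is valid):

    python generate_str.py corpus.pkl -n spacy -d most_common -t gen -t ext -o corpus_str.pkl
"""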
import sys, os, re, argparse, warnings, json

import numpy as np
import pandas as pd
import networkx as nx
from tqdm import tqdm
tqdm.pandas()  # enable .progress_apply() on pandas objects (progress bars)
tqdm.monitor_interval = 0  # disable the tqdm monitor thread


from strpython.pipeline import Pipeline
from strpython.nlp.pos_tagger.tagger import Tagger
from strpython.models.str import STR

from strpython.nlp.ner.spacy import Spacy as spacy_ner
from strpython.nlp.ner.polyglot import Polyglot as poly_ner
from strpython.nlp.ner.stanford_ner import StanfordNER as stanford_ner

from strpython.nlp.disambiguator.wikipedia_cooc import WikipediaDisambiguator as wiki_d
from strpython.nlp.disambiguator.geodict_gaurav import GauravGeodict as shared_geo_d
from strpython.nlp.disambiguator.most_common import MostCommonDisambiguator as most_common_d

from mytoolbox.text.clean import *  # wildcard import, expected to provide clean_text() and resolv_a() used below
from mytoolbox.exception.inline import safe_execute

from stop_words import get_stop_words

import logging
# Silence verbose third-party loggers
for _logger_name in ("elasticsearch", "Fiona"):
    logging.getLogger(_logger_name).setLevel(logging.ERROR)


disambiguator_dict = {
    "occwiki" : wiki_d,
    "most_common" : most_common_d,
    "shareprop" : shared_geo_d
}

ner_dict = {
    "spacy": spacy_ner,
    "polyglot":poly_ner,
    "stanford":stanford_ner
}


help_input = """Filename of your input. Must be a pickled pandas DataFrame with the following columns:
  - filename : original filename that contains the text in `content`
  - id_doc : id of your document
  - content : text data associated with the document
  - lang : language of your document
"""

parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)

# REQUIRED
parser.add_argument("input_pkl",help=help_input)
# OPTIONAL
parser.add_argument("-n","--ner",
    help="The Named Entity Recognizer you wish to use",
    choices=list(ner_dict.keys()),
    default="spacy")
parser.add_argument("-d","--disambiguator",
    help="The Named Entity disambiguator you wish to use",
    choices=list(disambiguator_dict.keys()),
    default="most_common")
parser.add_argument("-t","--transform",
    help="Transformation to apply",
    action="append",
    choices=["gen","ext"])
    
parser.add_argument("-o","--output",
    help="Output Filename",
    default="output.pkl"
)

args = parser.parse_args()

if not os.path.exists(args.input_pkl):
    raise FileNotFoundError("Input file not found: {0}".format(args.input_pkl))

df = pd.read_pickle(args.input_pkl)

required_cols = {"filename", "id_doc", "content", "lang"}
missing_cols = required_cols - set(df.columns)
if missing_cols:
    raise ValueError("Missing column(s) in input data: {0}".format(", ".join(sorted(missing_cols))))

languages = np.unique(df.lang.values)
print("Languages available in the corpus:", languages)

# One Pipeline per language, all sharing the selected NER and disambiguator
pipelines = {
    lang: Pipeline(lang=lang,
                   ner=ner_dict[args.ner](lang=lang),
                   tagger=Tagger(),
                   disambiguator=disambiguator_dict[args.disambiguator]())
    for lang in tqdm(languages, desc="Load Pipelines model")
}


def matcher_agrovoc(lang):
    """
    Return a terminolgy matcher using the Agrovoc vocabulary.
    
    Parameters
    ----------
    nlp : spacy.lang.Language
        model
    lang : str
        language of the terms
    
    Returns
    -------
    TerminologyMatcher
        matcher
    """
    agrovoc_vocab = pd.read_csv("../thematic_str/data/terminology/agrovoc/agrovoc_cleaned.csv")
    # The preferred labels are stored as single-quoted JSON strings: parse them safely
    agrovoc_vocab["preferred_label_new"] = agrovoc_vocab["preferred_label_new"].apply(
        lambda x: safe_execute({}, Exception, json.loads, x.replace("\'", "\"")))
    # Keep the label in the requested language (NaN when unavailable, so it can be filtered out)
    agrovoc_vocab["label_lang"] = agrovoc_vocab["preferred_label_new"].apply(
        lambda x: str(resolv_a(x[lang])).strip().lower() if lang in x else np.nan)
    agrovoc_vocab = agrovoc_vocab[~pd.isna(agrovoc_vocab["label_lang"])]
    return agrovoc_vocab["label_lang"].values.tolist()

# Per-language stop list: Agrovoc labels plus standard stop words
stopwords = {
    lang: matcher_agrovoc(lang)
    for lang in tqdm(languages, desc="Load stopwords")
}
for lang in stopwords:
    stopwords[lang].extend(get_stop_words(lang))

print("Clean input content ...")
if not "entities" in df:
    df["content"]= df.content.progress_apply(lambda x :clean_text(x))

count_error = 0

def build(pipelines, x):
    """Build the STR of one document row, falling back to simpler strategies on failure."""
    global count_error
    # 1. Preferred path: reuse pre-extracted toponyms when available
    try:
        if "entities" in x:
            return pipelines[x.lang].build(x.content, toponyms=x.entities, stop_words=stopwords[x.lang])
    except Exception as e:
        print(e)

    # 2. Fallback: run the full pipeline on the raw content
    try:
        return pipelines[x.lang].build(x.content)
    except Exception as e:
        print(e)
        # 3. Last resort: retry after a UTF-8 round-trip, then give up with an empty STR
        try:
            return pipelines[x.lang].build(str(x.content).encode("utf-8").decode("utf-8"))
        except Exception:
            warnings.warn("Could not build STR for doc with id = {0}".format(x.id_doc))
            count_error += 1
            return STR.from_networkx_graph(nx.Graph())

print("Transforming text to STR ...")

df["str_object"]=df.progress_apply(lambda x: build(pipelines,x) if len(x.content) >0 else STR.from_networkx_graph(nx.Graph()) , axis = 1)
df["str_object"]=df["str_object"].apply(lambda x: x[0] if isinstance(x,tuple) else x)

if "ext" in args.transform:
    print("Extending STR ...")
    df["ext_1"]=df.progress_apply(lambda x: pipelines[x.lang].transform(x.str_object,type_trans="ext",adjacent_count=1,distance="100"), axis = 1)
    df["ext_2"]=df.progress_apply(lambda x: pipelines[x.lang].transform(x.str_object,type_trans="ext",adjacent_count=2,distance="100"), axis = 1)

if "gen" in args.transform:
    print("Generalising STR ...")
    df["gen_region"]=df.progress_apply(lambda x: pipelines[x.lang].transform(x.str_object,type_trans="gen",type_gen="bounded",bound="region"), axis = 1)
    df["gen_country"]=df.progress_apply(lambda x: pipelines[x.lang].transform(x.str_object,type_trans="gen",type_gen="bounded",bound="country"), axis = 1)

print("Done with {0} error(s)... Now saving !".format(count_error))
df.to_pickle(args.output)
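
# The output is again a pickled DataFrame; a quick sanity check could look like
# (the column selection is illustrative):
#   pd.read_pickle("output.pkl")[["id_doc", "str_object"]].head()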