# pipeline.py
# coding=utf-8
import re

from nltk import word_tokenize
from joblib import Parallel, delayed

from .models.spatial_relation import RelationExtractor
from .models.str import STR
from .models.transformation.transform import Generalisation, Expansion

from .nlp.disambiguator import *
from .nlp.ner import *

from .nlp.exception.disambiguator import NotADisambiguatorInstance
from .nlp.exception.ner import NotANERInstance
from multiprocessing import cpu_count

from mytoolbox.env import in_notebook
if in_notebook():
    from tqdm._tqdm_notebook import tqdm_notebook as tqdm
else:
    from tqdm import tqdm


class Pipeline(object):
    """
    Run the whole treatment on texts: named-entity recognition, disambiguation
    of the identified entities, spatial-relation extraction, then construction
    and optional transformation of STR objects.
    """

    def __init__(self, lang="en", **kwargs):
        """
        Constructor

        :param lang: language code; only its first two characters are kept
            (e.g. "en_US" -> "en")
        :param kwargs: optional components:
            - ``ner``: NER object used by :meth:`parse`
              (default: ``Spacy(lang=<two-letter lang>)``)
            - ``disambiguator``: disambiguator used by :meth:`parse`
              (default: ``MostCommonDisambiguator()``)
        """
        self.lang = lang[:2]
        # Defaults are only constructed when the component is not supplied,
        # so passing a custom component avoids building the default one.
        self.ner = kwargs["ner"] if "ner" in kwargs else Spacy(lang=self.lang)
        self.disambiguator = (
            kwargs["disambiguator"]
            if "disambiguator" in kwargs
            else MostCommonDisambiguator()
        )

    def parse(self, text, debug=False):
        """
        Identify and disambiguate the spatial entities of a text.

        :param text: raw text to process
        :param debug: if True, print the disambiguation output
        :return: the input ``text`` and the disambiguator output; the latter is
            used as a mapping whose values are entity ids by :meth:`pipe_build`
        :rtype: tuple
        """
        ner_output = self.ner.identify(text)
        se_identified = self.disambiguator.disambiguate(self.lang, ner_output=ner_output)
        if debug:
            print(se_identified)
        return text, se_identified

    def set_ner(self, ner):
        """
        Set the NER used in the pipeline.

        :param ner: an ``NER`` instance
        :raises NotANERInstance: if ``ner`` is not an ``NER`` instance
        """
        if isinstance(ner, NER):
            self.ner = ner
        else:
            raise NotANERInstance()

    def set_disambiguator(self, disambiguator):
        """
        Set the disambiguator used in the pipeline.

        :param disambiguator: a ``Disambiguator`` instance
        :raises NotADisambiguatorInstance: if ``disambiguator`` is not a
            ``Disambiguator`` instance
        """
        if isinstance(disambiguator, Disambiguator):
            self.disambiguator = disambiguator
        else:
            raise NotADisambiguatorInstance()

    def extract_all_relation(self, spatial_entities):
        """
        Extract relation information between spatial entities.

        Geometry-based and metadata-based relations are computed by a
        ``RelationExtractor`` and then fused.

        :param spatial_entities: list of spatial entity ids
        :return: the two fused relation tables (``adj`` then ``inc``,
            presumably adjacency / inclusion), each converted with
            ``.to_dict()``
        :rtype: tuple(dict, dict)
        """
        extractor = RelationExtractor(spatial_entities)
        extractor.get_relation_geometry_based()
        extractor.get_relation_meta_based()
        df_adj, df_inc = extractor.fuse_meta_and_geom()
        return df_adj.to_dict(), df_inc.to_dict()

    @staticmethod
    def _collect_spatial_entities(parse_results):
        """
        Gather the distinct entity ids starting with "GD" from parse results.

        :param parse_results: iterable of ``(text, entities)`` pairs as
            returned by :meth:`parse`; ``entities`` must expose ``.values()``
        :return: list of ids, in first-seen order, without duplicates
        """
        # The same entity usually occurs in many texts: relations are computed
        # once per distinct entity. A dict keeps insertion order while deduping.
        seen = {}
        for _, entities in parse_results:
            for entity_id in entities.values():
                # Skip non-string values (e.g. unresolved entries) instead of
                # crashing on ``.startswith``.
                if isinstance(entity_id, str) and entity_id.startswith("GD"):
                    seen[entity_id] = None
        return list(seen)

    def pipe_build(self, texts, cpu_count=cpu_count(), **kwargs):
        """
        Build one STR per text.

        1. :meth:`parse` every text (joblib, threading backend);
        2. extract relations once for all distinct "GD" entities of the corpus;
        3. :meth:`build` every STR from those shared relation dicts.

        :param texts: iterable of texts
        :param cpu_count: number of joblib jobs (default: CPU count evaluated
            when the class is defined)
        :param kwargs: forwarded to :meth:`build`
        :return: list of STR, in the order of ``texts``
        """
        text_and_spatial_entities = Parallel(n_jobs=cpu_count, backend="threading")(
            delayed(self.parse)(text)
            for text in tqdm(texts, desc="Extract spatial entities from the texts")
        )
        sp_es = self._collect_spatial_entities(text_and_spatial_entities)
        print("Extract Spatial Relation for all identified spatial entities")
        adj_rel_dict, inc_rel_dict = self.extract_all_relation(sp_es)

        str_s = Parallel(n_jobs=cpu_count, backend="threading")(
            delayed(self.build)(text, spatial_entities, adj_rel_dict, inc_rel_dict, **kwargs)
            for text, spatial_entities in tqdm(text_and_spatial_entities, desc="Build STR")
        )
        return str_s

    def pipe_transform(self, strs_, cpu_count=cpu_count(), **kwargs):
        """
        Apply :meth:`transform` to every STR (joblib, threading backend).

        :param strs_: iterable of STR
        :param cpu_count: number of joblib jobs
        :param kwargs: forwarded to :meth:`transform` (e.g. ``type_trans``)
        :return: list of transformed STR, in input order
        """
        str_s = Parallel(n_jobs=cpu_count, backend="threading")(
            delayed(self.transform)(str_, **kwargs)
            for str_ in tqdm(strs_, desc="Transform STR")
        )
        return str_s

    def build(self, text_input, spatial_entities_identified, prec_adj, prec_inc, **kwargs):
        """
        Build the STR of a single text.

        :param text_input: raw text, tokenized with ``word_tokenize``
        :param spatial_entities_identified: disambiguator output for this text
        :param prec_adj: precomputed ``adj`` relations (first value returned
            by :meth:`extract_all_relation`)
        :param prec_inc: precomputed ``inc`` relations (second value returned
            by :meth:`extract_all_relation`)
        :param kwargs: optional ``adj`` / ``inc`` flags forwarded to
            ``STR.build`` (both default to True); other keys are ignored.
            Accepting ``**kwargs`` lets :meth:`pipe_build` forward its keyword
            arguments without raising ``TypeError``.
        :return: the built STR
        """
        str_ = STR(
            word_tokenize(text_input),
            spatial_entities_identified,
            toponym_first=True,
            precomputed_adj=prec_adj,
            precomputed_inc=prec_inc,
        )
        str_.build(adj=kwargs.get("adj", True), inc=kwargs.get("inc", True))
        return str_

    def transform(self, str_, **kwargs):
        """
        Optionally generalise or expand an STR.

        :param str_: STR to transform
        :param kwargs: ``type_trans`` selects the transformation ("gen" ->
            ``Generalisation``, any other value -> ``Expansion``); the
            remaining keyword arguments are forwarded to its ``transform``.
            Without ``type_trans`` the STR is returned unchanged.
        :return: the transformed (or original) STR
        """
        if "type_trans" not in kwargs:
            return str_
        # ``kwargs`` is local to this call, so popping does not affect callers.
        type_trans = kwargs.pop("type_trans")
        transformer = Generalisation() if type_trans == "gen" else Expansion()
        return transformer.transform(str_, **kwargs)