# criteria_cache.py (2.04 KiB)
# -*- coding: utf-8 -*-

import argparse, os, re, json, glob
import pandas as pd
import networkx as nx
import numpy as np


from strpython.eval.automatic_annotation import AnnotationAutomatic,save_cache,add_cache
from strpython.models.str import STR
from tqdm import tqdm,TqdmSynchronisationWarning
from joblib import Parallel, delayed
from multiprocessing import cpu_count


import warnings
# Silence tqdm's synchronisation warnings for the whole script run.
warnings.simplefilter(action="ignore", category=TqdmSynchronisationWarning)
# Register the ``progress_apply`` family of methods on pandas objects.
tqdm.pandas()

# Single annotator instance reused by ``foo`` for every graph pair.
annotater = AnnotationAutomatic()

parser = argparse.ArgumentParser(
    description="Compute criteria c1-c4 for every (G1, G2) graph pair "
    "found in a directory of CSV files."
)

parser.add_argument("csvinputdir", help="directory of CSV files with G1 and G2 columns")
parser.add_argument("graph_dir", help="directory containing the *.gexf graph files")
parser.add_argument("output_file", help="path of the CSV file to write")

args = parser.parse_args()

# Report exactly which input path is missing instead of a generic error.
missing_inputs = [
    path for path in (args.csvinputdir, args.graph_dir) if not os.path.exists(path)
]
if missing_inputs:
    raise FileNotFoundError("Error in Input: not found: " + ", ".join(missing_inputs))

# Gather every (G1, G2) pair from all files in csvinputdir. Each pair is
# canonicalised by sorting its two values and joining them with "_", so
# (a, b) and (b, a) collapse to the same key.
all_cp = set()
fns = glob.glob(os.path.join(args.csvinputdir, "*"))
for fn in fns:
    if not os.path.isfile(fn):
        # pd.read_csv cannot read sub-directories; skip them.
        continue
    pairs_df = pd.read_csv(fn)
    cps = pairs_df[["G1", "G2"]].apply(
        lambda x: "_".join(np.sort(x.values).astype(str)), axis=1
    )
    all_cp.update(cps.tolist())
# Iterate in sorted order so the output row order is reproducible; the
# iteration order of a set of strings varies between runs (hash randomisation).
df = pd.DataFrame([cp.split("_") for cp in sorted(all_cp)], columns=["G1", "G2"])
str_graph_path = args.graph_dir


# Raw string: "\d" in a normal string literal is an invalid escape sequence
# (DeprecationWarning, SyntaxWarning since Python 3.12).
_GRAPH_ID_RE = re.compile(r"\d+")

# Map graph id -> STR, built from every *.gexf file in the graph directory.
strs = {}
for gexf_path in tqdm(glob.glob(os.path.join(str_graph_path, "*.gexf")), desc="Load Graphs"):
    # The id is the last run of digits found anywhere in the path.
    id_ = int(_GRAPH_ID_RE.findall(gexf_path)[-1])
    strs[id_] = STR.from_networkx_graph(nx.read_gexf(gexf_path))

#print(strs)
def foo(x):
    """Return the annotation criteria for one (G1, G2) pair row.

    Args:
        x: A row whose ``G1`` and ``G2`` fields hold graph ids. In this
            script they are strings produced by ``split("_")``, hence the
            ``int`` conversions.

    Returns:
        Whatever ``annotater.all`` returns for the two graphs. If either
        graph id is absent from ``strs``, ``[0, 0, 0, 0]`` is passed to
        ``add_cache`` and returned instead.
    """
    id1, id2 = int(x.G1), int(x.G2)
    try:
        str1, str2 = strs[id1], strs[id2]
    except KeyError:
        # No graph was loaded for one side of the pair.
        add_cache(id1, id2, [0, 0, 0, 0])
        return [0, 0, 0, 0]
    # Called outside the try so KeyErrors raised inside the annotator
    # surface instead of being silently recorded as zero criteria.
    return annotater.all(str1, str2, id1, id2)


# Score every pair. ``progress_apply`` (registered by ``tqdm.pandas()``)
# shows a progress bar while ``foo`` runs on each row.
if len(df):
    df["res"] = df.progress_apply(foo, axis=1)
else:
    # A row-wise apply over an empty frame does not reliably yield a Series
    # of per-row results (it can come back as a DataFrame, which cannot be
    # stored in one column), so create the empty result column explicitly.
    df["res"] = pd.Series(dtype=object)
# Normalise each result to a list of ints; falsy results become [].
df["res"] = df["res"].apply(lambda x: list(map(int, x)) if x else [])
# Spread the results into c1..c4. Guard on the position itself so results
# shorter than four values yield 0 instead of raising IndexError.
for i, col in enumerate(["c1", "c2", "c3", "c4"]):
    df[col] = df["res"].apply(lambda x, i=i: x[i] if len(x) > i else 0)

del df["res"]
# save_cache() presumably persists entries recorded via add_cache — confirm.
save_cache()
df.to_csv(args.output_file)