# coding: utf-8
import argparse
import glob
import os
import re
import warnings

import networkx as nx
import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from tqdm import tqdm, TqdmSynchronisationWarning

from strpython.eval.automatic_annotation import AnnotationAutomatic, add_cache, save_cache
from strpython.models.str import STR

warnings.simplefilter("ignore", TqdmSynchronisationWarning)
tqdm.pandas()

annotater = AnnotationAutomatic()

parser = argparse.ArgumentParser()
parser.add_argument("csvinputdir")
parser.add_argument("graph_dir")
parser.add_argument("output_file")
args = parser.parse_args()

if not os.path.exists(args.csvinputdir) or not os.path.exists(args.graph_dir):
    raise FileNotFoundError("Input CSV directory or graph directory does not exist")

# Collect every unique (G1, G2) pair found in the input CSV files.
# Each pair is normalised by sorting the two ids so (a, b) and (b, a) count as one pair.
all_cp = []
for fn in glob.glob("{0}/*".format(args.csvinputdir)):
    df = pd.read_csv(fn)
    cps = df[["G1", "G2"]].apply(
        lambda x: "_".join(np.sort(x.values).astype(str)), axis=1
    ).values
    all_cp.extend(cps.tolist())
all_cp = set(all_cp)

df = pd.DataFrame([cp.split("_") for cp in all_cp], columns=["G1", "G2"])

# Load every STR graph stored as a GEXF file; the graph id is the last number in the filename.
strs = {}
for file in tqdm(glob.glob(os.path.join(args.graph_dir, "*.gexf")), desc="Load Graphs"):
    id_ = int(re.findall(r"\d+", file)[-1])
    strs[id_] = STR.from_networkx_graph(nx.read_gexf(file))


def annotate_pair(x):
    """Compute the four annotation criteria for a (G1, G2) pair.

    If one of the graphs is missing, cache and return a null annotation.
    """
    try:
        return annotater.all(strs[int(x.G1)], strs[int(x.G2)], int(x.G1), int(x.G2))
    except KeyError:
        add_cache(int(x.G1), int(x.G2), [0, 0, 0, 0])
        return [0, 0, 0, 0]


df["res"] = df.progress_apply(annotate_pair, axis=1)
# Parallel alternative:
# Parallel(n_jobs=4)(delayed(annotate_pair)(x) for x in tqdm(df.itertuples(), total=len(df), desc="Extracting Crit"))

# Split the four criteria into dedicated columns.
df["res"] = df["res"].apply(lambda x: list(map(int, x)) if x else [])
df["c1"] = df["res"].apply(lambda x: x[0] if len(x) > 0 else 0)
df["c2"] = df["res"].apply(lambda x: x[1] if len(x) > 0 else 0)
df["c3"] = df["res"].apply(lambda x: x[2] if len(x) > 0 else 0)
df["c4"] = df["res"].apply(lambda x: x[3] if len(x) > 0 else 0)
del df["res"]

save_cache()
df.to_csv(args.output_file)
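
# A minimal usage sketch, assuming this file is saved as annotate_pairs.py (hypothetical
# name), that ./pairs/ holds the input CSVs with G1/G2 columns, and that ./strs/ holds the
# STR graphs as <id>.gexf files -- all paths below are illustrative:
#
#   python annotate_pairs.py ./pairs ./strs ./criteria_annotations.csv
#
# The output CSV then contains one row per unique (G1, G2) pair with the four criterion
# columns c1..c4.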