# -*- coding: utf-8 -*-
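"""Annotate every (G1, G2) pair of STR graphs referenced in a set of CSV files.

Usage sketch (positional arguments as declared below; the script name is
whatever this file is saved as):

    python <this_script>.py <csvinputdir> <graph_dir> <output_file>

Each CSV in ``csvinputdir`` is expected to contain ``G1`` and ``G2`` columns
holding numeric graph identifiers. The matching ``*.gexf`` graphs (whose file
names contain that identifier) are loaded from ``graph_dir``, each pair is
annotated with ``AnnotationAutomatic``, and the four resulting criteria
(``c1``..``c4``) are written to ``output_file``.
"""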
import argparse, os, re, json, glob
import pandas as pd
import networkx as nx
import numpy as np
from strpython.eval.automatic_annotation import AnnotationAutomatic, save_cache, add_cache
from strpython.models.str import STR
from tqdm import tqdm, TqdmSynchronisationWarning
from joblib import Parallel, delayed
from multiprocessing import cpu_count
import warnings
warnings.simplefilter("ignore", TqdmSynchronisationWarning)
tqdm.pandas()
annotater = AnnotationAutomatic()

# Command-line interface: CSV directory with (G1, G2) pairs, directory of
# *.gexf graphs, and the output CSV path.
parser = argparse.ArgumentParser()
parser.add_argument("csvinputdir")
parser.add_argument("graph_dir")
parser.add_argument("output_file")
args = parser.parse_args()
if not os.path.exists(args.csvinputdir) or not os.path.exists(args.graph_dir):
    raise FileNotFoundError("CSV input directory or graph directory does not exist")
# Collect every unique (G1, G2) pair from the input CSV files. Each pair is
# normalised by sorting the two identifiers so (a, b) and (b, a) count once.
all_cp = []
fns = glob.glob("{0}/*".format(args.csvinputdir))
for fn in fns:
    df = pd.read_csv(fn)
    cps = df[["G1", "G2"]].apply(lambda x: "_".join(np.sort(x.values).astype(str)), axis=1).values
    all_cp.extend(cps.tolist())
all_cp = set(all_cp)

df = pd.DataFrame([cp.split("_") for cp in all_cp], columns=["G1", "G2"])
# Load every STR graph (*.gexf) from the graph directory, keyed by the numeric
# identifier embedded in its file name.
str_graph_path = args.graph_dir
strs = {}
for file in tqdm(glob.glob(os.path.join(str_graph_path, "*.gexf")), desc="Load Graphs"):
    id_ = int(re.findall(r"\d+", file)[-1])
    strs[id_] = STR.from_networkx_graph(nx.read_gexf(file))
def foo(x):
    """Return the four annotation criteria for the pair (x.G1, x.G2)."""
    try:
        return annotater.all(strs[int(x.G1)], strs[int(x.G2)], int(x.G1), int(x.G2))
    except KeyError:
        # One of the two graphs is missing from `strs`: cache and return a null annotation.
        add_cache(int(x.G1), int(x.G2), [0, 0, 0, 0])
        return [0, 0, 0, 0]


# Sequential annotation with a progress bar; a joblib-based parallel variant
# is kept here for reference:
# df["res"] = Parallel(n_jobs=4)(delayed(foo)(x) for x in tqdm(df.itertuples(), total=len(df), desc="Extracting Crit"))
df["res"] = df.progress_apply(foo, axis=1)
# Split the four criteria returned by the annotator into separate columns.
df.res = df.res.apply(lambda x: list(map(int, x)) if x else [])
df["c1"] = df.res.apply(lambda x: x[0] if len(x) > 0 else 0)
df["c2"] = df.res.apply(lambda x: x[1] if len(x) > 0 else 0)
df["c3"] = df.res.apply(lambda x: x[2] if len(x) > 0 else 0)
df["c4"] = df.res.apply(lambda x: x[3] if len(x) > 0 else 0)
del df["res"]
# Persist the annotation cache and write the annotated pairs.
save_cache()
df.to_csv(args.output_file)