# auto_fill_annotation.py
# coding: utf-8


import argparse, os, re, json, glob
import pandas as pd
import networkx as nx

from strpython.eval.automatic_annotation import AnnotationAutomatic,save_cache
from strpython.models.str import STR
from tqdm import tqdm,TqdmSynchronisationWarning
import warnings
# Silence tqdm's synchronisation warnings and register DataFrame.progress_apply.
warnings.simplefilter("ignore", TqdmSynchronisationWarning)
tqdm.pandas()

parser = argparse.ArgumentParser(
    description="Fill the c1-c4 columns of a CSV of (G1, G2) graph pairs "
    "using AnnotationAutomatic.all."
)
parser.add_argument(
    "csv_file", help="input CSV (first column is the index; needs G1 and G2 columns)"
)
parser.add_argument("graph_dir", help="directory containing the *.gexf graph files")
parser.add_argument("output_file", help="path of the annotated CSV to write")

args = parser.parse_args()

# Validate inputs before building the annotator so that bad arguments fail
# fast, and so that the error names the offending path. graph_dir must be a
# directory: a plain file would make the *.gexf glob silently match nothing.
if not os.path.isfile(args.csv_file):
    raise FileNotFoundError("CSV file not found: {}".format(args.csv_file))
if not os.path.isdir(args.graph_dir):
    raise FileNotFoundError("Graph directory not found: {}".format(args.graph_dir))

annotater = AnnotationAutomatic()

df = pd.read_csv(args.csv_file, index_col=0)
str_graph_path = args.graph_dir


# Load every *.gexf graph in the graph directory as an STR. Each graph is keyed
# by the integer formed from the last run of digits in its file name.
strs = {}
# Sorted so that, if two files map to the same id, the winner is deterministic.
for gexf_path in sorted(glob.glob(os.path.join(str_graph_path, "*.gexf"))):
    # Match on the file name only: digits in parent directories must not be
    # mistaken for the graph id.
    digits = re.findall(r"\d+", os.path.basename(gexf_path))
    if not digits:
        warnings.warn("Skipping {}: no numeric id in file name".format(gexf_path))
        continue
    graph_id = int(digits[-1])
    if graph_id in strs:
        warnings.warn(
            "Duplicate graph id {} from {}; replacing the previously loaded graph".format(
                graph_id, gexf_path
            )
        )
    strs[graph_id] = STR.from_networkx_graph(nx.read_gexf(gexf_path))
def foo(x):
    """Run the automatic annotation for one (G1, G2) row.

    Parameters
    ----------
    x : object with ``G1`` and ``G2`` attributes (e.g. a pandas row Series)
        Ids of the two graphs to compare. They are used as keys into the
        module-level ``strs`` dict and are also passed to ``annotater.all``.

    Returns
    -------
    Whatever ``annotater.all`` returns for the pair. If either graph was not
    loaded, or the annotator raises, returns ``[0, 0, 0, 0]`` instead.
    """
    try:
        str1, str2 = strs[x.G1], strs[x.G2]
    except KeyError as err:
        # str(KeyError(k)) is just the key, so say explicitly what is missing.
        tqdm.write("Row ({}, {}): no graph loaded for id {}".format(x.G1, x.G2, err))
        return [0, 0, 0, 0]
    try:
        return annotater.all(str1, str2, x.G1, x.G2)
    except Exception as err:
        # Best-effort: one failing pair must not abort the whole run.
        # tqdm.write keeps the progress bar intact while reporting.
        tqdm.write("Row ({}, {}): annotation failed: {!r}".format(x.G1, x.G2, err))
        return [0, 0, 0, 0]


# Score every row. foo() already falls back to zeros for pairs that fail.
# Iterating the rows directly also avoids DataFrame.apply's extra probe call
# on an empty frame.
score_columns = ["c1", "c2", "c3", "c4"]
row_scores = [
    [int(value) for value in foo(row)]
    for _, row in tqdm(df.iterrows(), total=len(df))
]
scores = pd.DataFrame(row_scores, index=df.index, columns=score_columns)
# Assign one column at a time. Recent pandas versions reject
# df[["c1"]] = <Series> (list key, 1-D value) with
# "Columns must be same length as key".
for column in score_columns:
    df[column] = scores[column]

save_cache()
df.to_csv(args.output_file)