# generate_selected_document.py 1.32 KiB
# coding: utf-8
import argparse,glob,random,re
import networkx as nx
import numpy as np

# Number of documents the script tries to select in total.
SIZE_SELECTION = 100
# Exclusive upper bound of the last size category.
MAX_GRAPH_SIZE = 1000000


def load_graphs(graph_input_dir):
    """Load every ``<graph_input_dir>/normal/*.gexf`` file.

    Args:
        graph_input_dir: directory that contains the ``normal`` sub-directory.

    Returns:
        dict mapping a document id to the graph read from the file. The id
        is the last run of digits found in the file path.

    Raises:
        ValueError: if a matched path contains no digits at all.
    """
    graphs = {}
    for path in glob.glob(graph_input_dir + "/normal/*.gexf"):
        # Raw string: "\d" in a plain literal is an invalid escape sequence.
        numbers = re.findall(r"\d+", path)
        if not numbers:
            raise ValueError(
                "no numeric id found in graph path: {0}".format(path))
        graphs[int(numbers[-1])] = nx.read_gexf(path)
    return graphs


def size_intervals(sizes):
    """Build the three half-open ``[low, high)`` size categories.

    The pivot is the median of ``sizes``; when that median is <= 2 the
    truncated mean is used instead.

    Args:
        sizes: iterable of graph sizes (node counts).

    Returns:
        A list of three ``[low, high]`` pairs:
        ``[1, pivot]``, ``[pivot, 2 * pivot]``, ``[2 * pivot, MAX_GRAPH_SIZE]``.

    Raises:
        ValueError: if ``sizes`` is empty (the median would be NaN).
    """
    sizes = list(sizes)
    if not sizes:
        raise ValueError("cannot build size intervals without any graph")
    # float() keeps the printed intervals independent of the NumPy version.
    pivot = float(np.median(sizes))
    if pivot <= 2:
        pivot = int(np.mean(sizes))
    return [
        [1, pivot],
        [pivot, pivot * 2],
        # Starts at 2 * pivot so the categories do not overlap; the
        # first-match classification already behaved this way.
        [pivot * 2, MAX_GRAPH_SIZE],
    ]


def category_of(size, intervals):
    """Return the index of the first interval with ``low <= size < high``.

    Returns ``None`` when ``size`` falls in no interval (for instance a
    size of 0, or a size >= the last upper bound).
    """
    for index, (low, high) in enumerate(intervals):
        if low <= size < high:
            return index
    return None


def select_documents(sizes, size_selection=SIZE_SELECTION):
    """Randomly pick document ids with fixed per-category quotas.

    Quotas are 1/5, 2/5 and 2/5 of ``size_selection`` for the three size
    categories returned by :func:`size_intervals`. A category holding fewer
    ids than its quota contributes all of them, so fewer than
    ``size_selection`` ids may be returned.

    Args:
        sizes: dict mapping document id -> graph size.
        size_selection: total number of documents aimed for.

    Returns:
        ``(intervals, selected, counts)`` where ``selected`` is the list of
        chosen ids and ``counts`` maps category index -> number of chosen
        ids in that category.

    Raises:
        ValueError: if ``sizes`` is empty.
    """
    intervals = size_intervals(sizes.values())
    quotas = [
        size_selection / 5,
        (size_selection / 5) * 2,
        (size_selection / 5) * 2,
    ]

    per_category = {index: [] for index in range(len(intervals))}
    for doc_id, size in sizes.items():
        category = category_of(size, intervals)
        if category is not None:
            per_category[category].append(doc_id)

    selected = []
    for category, doc_ids in per_category.items():
        random.shuffle(doc_ids)
        selected.extend(doc_ids[:int(quotas[category])])

    # Re-classify the chosen ids to report the proportions actually obtained.
    counts = {index: 0 for index in per_category}
    for doc_id in selected:
        counts[category_of(sizes[doc_id], intervals)] += 1
    return intervals, selected, counts


def main(argv=None):
    """Command-line entry point: select documents and print the result.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("graph_input_dir")
    args = parser.parse_args(argv)

    graphs = load_graphs(args.graph_input_dir)
    if not graphs:
        parser.error(
            "no .gexf graph found in {0}/normal".format(args.graph_input_dir))

    sizes = {doc_id: len(graph) for doc_id, graph in graphs.items()}
    intervals, selected, counts = select_documents(sizes)

    print("Interval", intervals)
    print(sorted(selected))
    print("Check if good proportions {0}".format(counts))


if __name__ == "__main__":
    main()