Commit e786b9db authored by Fize Jacques's avatar Fize Jacques
Browse files

Add new parameter to annotation + debug

parent 07c1fe38
No related merge requests found
Showing with 616 additions and 695 deletions
+616 -695
...@@ -28,4 +28,5 @@ __pycache__/ ...@@ -28,4 +28,5 @@ __pycache__/
*cache.json *cache.json
*.gexf *.gexf
temp_cluster_2/ temp_cluster_2/
agromada* agromada*
\ No newline at end of file output*
\ No newline at end of file
...@@ -4,8 +4,7 @@ ...@@ -4,8 +4,7 @@
import argparse, os import argparse, os
import warnings import warnings
import os, re, glob import os, re, glob,json
import pandas as pd
import networkx as nx import networkx as nx
import numpy as np import numpy as np
...@@ -19,18 +18,24 @@ from strpython.models.str import STR ...@@ -19,18 +18,24 @@ from strpython.models.str import STR
from strpython.helpers.sim_matrix import matrix_to_pandas_dataframe, read_bz2_matrix from strpython.helpers.sim_matrix import matrix_to_pandas_dataframe, read_bz2_matrix
def main(dataset, matrix_sim_dir, raw_graph_dir, selected_graphs, threshold, inclusion_fn, adjacency_fn): def main(dataset, matrix_sim_dir, raw_graph_dir, selected_graphs, threshold, inclusion_fn, adjacency_fn,min_carac_fn, min_size_G1,min_size_G2,n_car_min_doc1,n_car_min_doc2):
annotater = AnnotationAutomatic(dataset, threshold, inclusion_fn, adjacency_fn) annotater = AnnotationAutomatic(dataset, threshold, inclusion_fn, adjacency_fn)
first_step_output = "output_first_step_{0}_{1}".format(dataset, threshold) first_step_output = "output_first_step_{0}_{1}".format(dataset, threshold)
last_step_output = "output_final_{0}_{1}".format(dataset, threshold) last_step_output = "output_final_{0}_{1}".format(dataset, threshold)
generate_annotation_dataframe(matrix_sim_dir, selected_graphs, first_step_output) generate_annotation_dataframe(matrix_sim_dir, selected_graphs, first_step_output)
extract_criteria_4_all(annotater, first_step_output, raw_graph_dir, dataset, threshold) # size_str = extract_criteria_4_all(annotater, first_step_output, raw_graph_dir, dataset, threshold)
if not os.path.exists(last_step_output): if not os.path.exists(last_step_output):
os.makedirs(last_step_output) os.makedirs(last_step_output)
for fn in glob.glob(os.path.join(first_step_output,"*.csv")):
annotate_eval_sample(annotater, fn, os.path.join(last_step_output, os.path.basename(fn))) # for fn in tqdm(glob.glob(os.path.join(first_step_output,"*.csv")),desc="Annotate sample"):
synthesize(last_step_output,"{0}_{1}.csv".format(dataset,threshold)) # annotate_eval_sample(annotater, fn, os.path.join(last_step_output, os.path.basename(fn)),size_str)
min_carac_dict=None
if min_carac_fn != "" and os.path.exists(min_carac_fn):
min_carac_dict=json.load(open(min_carac_fn))
synthesize(last_step_output,"{0}_{1}.csv".format(dataset,threshold),min_size_G1,min_size_G2,min_carac_dict,n_car_min_doc1,n_car_min_doc2)
...@@ -47,6 +52,7 @@ def generate_annotation_dataframe(matrix_sim_dir, selected_graphs, output_dir): ...@@ -47,6 +52,7 @@ def generate_annotation_dataframe(matrix_sim_dir, selected_graphs, output_dir):
------- -------
""" """
if not os.path.exists(matrix_sim_dir): if not os.path.exists(matrix_sim_dir):
raise FileNotFoundError("Similarity matrix directory not found : {0}".format(matrix_sim_dir)) raise FileNotFoundError("Similarity matrix directory not found : {0}".format(matrix_sim_dir))
...@@ -55,6 +61,8 @@ def generate_annotation_dataframe(matrix_sim_dir, selected_graphs, output_dir): ...@@ -55,6 +61,8 @@ def generate_annotation_dataframe(matrix_sim_dir, selected_graphs, output_dir):
type_ = "_".join(os.path.basename(fn).split("_")[1:]).replace(".npy.bz2", "") type_ = "_".join(os.path.basename(fn).split("_")[1:]).replace(".npy.bz2", "")
print("Proceeding...", measure, type_) print("Proceeding...", measure, type_)
if os.path.exists(os.path.join(output_dir, "{0}_{1}.csv".format(measure, type_))):
continue
df = matrix_to_pandas_dataframe(np.nan_to_num(read_bz2_matrix(fn)), df = matrix_to_pandas_dataframe(np.nan_to_num(read_bz2_matrix(fn)),
selected_graphs, selected_graphs,
measure, type_) measure, type_)
...@@ -96,30 +104,38 @@ def extract_criteria_4_all(annotater, csv_input_dir, raw_graph_dir, dataset, thr ...@@ -96,30 +104,38 @@ def extract_criteria_4_all(annotater, csv_input_dir, raw_graph_dir, dataset, thr
# Load STRs # Load STRs
strs = {} strs = {}
size_STR={}
def load(fn):
id_ = int(re.findall("\d+", fn)[-1])
strs[id_] = STR.from_networkx_graph(nx.read_gexf(fn))
size_STR[id_] = len(strs[id_])
for file in tqdm(glob.glob(os.path.join(raw_graph_dir, "*.gexf")), desc="Load Graphs"): for file in tqdm(glob.glob(os.path.join(raw_graph_dir, "*.gexf")), desc="Load Graphs"):
id_ = int(re.findall("\d+", file)[-1]) id_ = int(re.findall("\d+", file)[-1])
strs[id_] = STR.from_networkx_graph(nx.read_gexf(file)) strs[id_] = STR.from_networkx_graph(nx.read_gexf(file))
size_STR[id_]= len(strs[id_])
# Do the annotation for a match between two STR #Do the annotation for a match between two STR
def annotate(x): def annotate(x):
try: try:
return annotater.all(strs[int(x.G1)], strs[int(x.G2)], int(x.G1), int(x.G2)) return annotater.all(strs[int(x.G1)], strs[int(x.G2)], int(x.G1), int(x.G2))
except KeyError as e: except KeyError as e:
annotater.matching_cache.add(int(x.G1), int(x.G2), *(0, 0, 0, 0)) annotater.matching_cache.add(int(x.G1), int(x.G2), *(0, 0, 0, 0,300000))
return [0, 0, 0, 0] return [0, 0, 0, 0,300000,0]
# Annotation Time # Annotation Time
print("Computing Criteria for each match")
matching_dataframe["res"] = matching_dataframe.progress_apply(lambda x: annotate(x), axis=1) matching_dataframe["res"] = matching_dataframe.progress_apply(lambda x: annotate(x), axis=1)
matching_dataframe.res = matching_dataframe.res.apply(lambda x: list(map(int, x)) if x else []) matching_dataframe.res = matching_dataframe.res.apply(lambda x: [int(x[0]),int(x[1]),int(x[2]),int(x[3]),float(x[4])] if x else [])
for ix, col in enumerate("c1 c2 c3 c4".split()): for ix, col in enumerate("c1 c2 c3 c4 c5".split()):
matching_dataframe[col] = matching_dataframe.res.apply(lambda x: x[ix] if len(x) > 0 else 0) matching_dataframe[col] = matching_dataframe.res.apply(lambda x: x[ix] if len(x) > 0 else 0)
del matching_dataframe["res"] del matching_dataframe["res"]
# Writing output # Writing output
matching_dataframe.to_csv(output_file) return size_STR
def annotate_eval_sample(annotater, csv_file, output_file): def annotate_eval_sample(annotater, csv_file, output_file, size_str):
""" """
Third Step Third Step
Parameters Parameters
...@@ -141,21 +157,23 @@ def annotate_eval_sample(annotater, csv_file, output_file): ...@@ -141,21 +157,23 @@ def annotate_eval_sample(annotater, csv_file, output_file):
try: try:
return annotater.all(None, None, x.G1, x.G2) return annotater.all(None, None, x.G1, x.G2)
except Exception as e: except Exception as e:
return [0, 0, 0, 0] return [0, 0, 0, 0,300000]
df["res"] = df.apply(lambda x: foo(x), axis=1) df["res"] = df.apply(lambda x: foo(x), axis=1)
df.res = df.res.apply(lambda x: list(map(int, x)) if x else []) # if bool df.res = df.res.apply(lambda x: list(map(float, x)) if x else []) # if bool
df[["c1"]] = df.res.apply(lambda x: x[0] if len(x) > 0 else 0) df[["c1"]] = df.res.apply(lambda x: x[0] if len(x) > 0 else 0)
df[["c2"]] = df.res.apply(lambda x: x[1] if len(x) > 0 else 0) df[["c2"]] = df.res.apply(lambda x: x[1] if len(x) > 0 else 0)
df[["c3"]] = df.res.apply(lambda x: x[2] if len(x) > 0 else 0) df[["c3"]] = df.res.apply(lambda x: x[2] if len(x) > 0 else 0)
df[["c4"]] = df.res.apply(lambda x: x[3] if len(x) > 0 else 0) df[["c4"]] = df.res.apply(lambda x: x[3] if len(x) > 0 else 0)
df[["c5"]] = df.res.apply(lambda x: x[4] if len(x) > 0 else 300000)
df["size_G1"] =df.apply(lambda x: size_str[x.G1] if x.G1 in size_str else 0, axis=1)
df["size_G2"] = df.apply(lambda x: size_str[x.G2] if x.G2 in size_str else 0, axis=1)
del df["res"] del df["res"]
df.to_csv(output_file) df.to_csv(output_file)
def synthesize(last_step_output,output_filename): def synthesize(last_step_output,output_filename,min_size_G1=None,min_size_G2=None,min_carac_dict=None,ncar_min_doc1=0,ncar_min_doc2=0):
""" """
Fourth Step Fourth Step
Parameters Parameters
...@@ -168,16 +186,41 @@ def synthesize(last_step_output,output_filename): ...@@ -168,16 +186,41 @@ def synthesize(last_step_output,output_filename):
""" """
fns = glob.glob(os.path.join(last_step_output, "*.csv")) fns = glob.glob(os.path.join(last_step_output, "*.csv"))
if min_size_G1:
output_filename= output_filename+"_ming1_{0}".format(min_size_G1)
if min_size_G2:
output_filename= output_filename+"_ming2_{0}".format(min_size_G2)
if min_carac_dict and ncar_min_doc1 > 0:
output_filename= output_filename+"_mindoc1len_{0}".format(ncar_min_doc1)
if min_carac_dict and ncar_min_doc2 > 0:
output_filename= output_filename+"_mindoc2len_{0}".format(ncar_min_doc2)
data = [] data = []
for fn in fns: for fn in tqdm(fns,desc="Synthetise Results"):
df = pd.read_csv(fn) df = pd.read_csv(fn)
if min_size_G1:
df= df[df.size_G1 >= min_size_G1]
if min_size_G2:
df = df[df.size_G2 >= min_size_G2]
if min_carac_dict and ncar_min_doc1>0:
df["len_doc1"]=df.apply(lambda x:min_carac_dict[str(x.G1)],axis=1)
df =df[df.len_doc1 >= ncar_min_doc1]
if min_carac_dict and ncar_min_doc2>0:
df["len_doc2"]=df.apply(lambda x:min_carac_dict[str(x.G2)] if str(x.G2) in min_carac_dict else 0,axis=1)
df =df[df.len_doc2 >= ncar_min_doc2]
df = df.replace([np.inf, -np.inf], 300000)
df["c5"] = 1 - (df.c5 - df.c5.min()) / (df.c5.max() - df.c5.min())
if len(df) <1:
continue
mes = np.unique(df.sim_measure)[0] mes = np.unique(df.sim_measure)[0]
type_ = np.unique(df.type_str)[0] type_ = np.unique(df.type_str)[0]
val = df.groupby("G1").mean().mean()["c1 c2 c3 c4".split()].values.tolist() val = df.groupby("G1").mean().mean()["c1 c2 c3 c4 c5".split()].values.tolist()
val.insert(0, type_) val.insert(0, type_)
val.insert(0, mes) val.insert(0, mes)
data.append(val) data.append(val)
pd.DataFrame(data, columns="mesure type c1 c2 c3 c4".split())
res = pd.DataFrame(data, columns="mesure type c1 c2 c3 c4".split()) res = pd.DataFrame(data, columns="mesure type c1 c2 c3 c4 c5".split())
res.to_csv(output_filename) res.to_csv(output_filename)
\ No newline at end of file
notebooks/MatchingAnalysis/.ipynb_checkpoints/output-checkpoint.png

23.4 KB

File deleted
This diff is collapsed.
notebooks/MatchingAnalysis/c1.png

29.6 KB

notebooks/MatchingAnalysis/c2.png

32.1 KB

notebooks/MatchingAnalysis/c3.png

30.4 KB

notebooks/MatchingAnalysis/c4.png

31 KB

notebooks/MatchingAnalysis/output.png

30.5 KB

notebooks/MatchingAnalysis/sum.png

22.6 KB

# coding = utf-8 # coding = utf-8
import argparse import argparse, shutil, os
import logging,json import logging, json
from mytoolbox.env import yes_or_no
from auto_fill_annotation import main
for _ in ("boto", "elasticsearch", "urllib3", "sklearn"): for _ in ("boto", "elasticsearch", "urllib3", "sklearn"):
logging.getLogger(_).setLevel(logging.CRITICAL) logging.getLogger(_).setLevel(logging.CRITICAL)
parser=argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("dataset",help="Name of the dataset") parser.add_argument("dataset", help="Name of the dataset")
parser.add_argument("sim_matrix_dir",help="Similarity Matrix Directory") parser.add_argument("sim_matrix_dir", help="Similarity Matrix Directory")
parser.add_argument("graph_data_dir", help="STR without transformation graph directory") parser.add_argument("graph_data_dir", help="STR without transformation graph directory")
parser.add_argument("adjacency_fn", help ="Adjacency Information json filename") parser.add_argument("adjacency_fn", help="Adjacency Information json filename")
parser.add_argument("inclusion_fn", help ="Inclusion Information json filename") parser.add_argument("inclusion_fn", help="Inclusion Information json filename")
parser.add_argument("selected_json_file",help="Filename containing the STR graph you want to make your evaluation on") parser.add_argument("selected_json_file", help="Filename containing the STR graph you want to make your evaluation on")
parser.add_argument("-t","--threshold",default=0.5,help="Threshold for the third criteria") parser.add_argument("length_json_file", help="Filename containing the STR text length")
parser.add_argument("-t", "--threshold", default=0.5, help="Threshold for the third criteria")
parser.add_argument("-g", "--ming1",type=int, default=0, help="Return evaluation results based on min size for G1")
parser.add_argument("-j", "--ming2",type=int, default=0, help="Return evaluation results based on min size for G2")
parser.add_argument("-m", "--nb_car_doc1",type=int, default=0, help="Return evaluation results based on min size of associated text for G1")
parser.add_argument("-n", "--nb_car_doc2",type=int, default=0, help="Return evaluation results based on min size of associated text for G2")
args = parser.parse_args()
if os.path.exists("temp_cluster") and yes_or_no("Do you want to compute STR's clusters all over again ?"):
shutil.rmtree('temp_cluster', ignore_errors=True)
os.makedirs("temp_cluster")
args=parser.parse_args()
from auto_fill_annotation import main
main(args.dataset,args.sim_matrix_dir,args.graph_data_dir,json.load(open(args.selected_json_file)),args.threshold,args.inclusion_fn, args.adjacency_fn) main(args.dataset,
\ No newline at end of file args.sim_matrix_dir,
args.graph_data_dir,
json.load(open(args.selected_json_file)),
args.threshold,
args.inclusion_fn,
args.adjacency_fn,
args.length_json_file,
args.ming1,
args.ming2,
args.nb_car_doc1,
args.nb_car_doc2)
run_test.py 0 → 100644
# coding = utf-8
import argparse
import os
import pandas as pd
import numpy as np
from tqdm import tqdm
from skcriteria.madm import closeness, simple
from skcriteria import Data, MIN, MAX
def pareto_frontier_multi(myArray):
    """Build a chain of successively dominating rows from a 2-D score array.

    Rows are sorted in ascending order of their first column. The first
    sorted row seeds the frontier; each following row is accepted when it is
    greater than or equal to the most recently accepted row on every column.

    Parameters
    ----------
    myArray : numpy.ndarray
        2-D array with one row per candidate and one column per criterion.

    Returns
    -------
    indices : list of int
        Positions **in the input array** of the rows accepted after the seed
        row, so they can be used directly to index the data the array was
        built from. The seed row itself is not listed (unchanged from the
        previous behaviour).
    pareto_frontier : numpy.ndarray
        The seed row followed by every accepted row, in sorted order.
    """
    # Keep the sort permutation: it is what lets us report accepted rows as
    # positions in the caller's (unsorted) array instead of the sorted copy.
    order = myArray[:, 0].argsort()
    sorted_rows = myArray[order]
    if len(sorted_rows) == 0:
        return [], sorted_rows[0:1, :]

    kept = [0]  # positions (in sorted order) of the frontier rows
    indices = []
    for pos in range(1, len(sorted_rows)):
        # Accept the row only if it is at least as good on all criteria as
        # the last row already on the frontier.
        if np.all(sorted_rows[pos] >= sorted_rows[kept[-1]]):
            kept.append(pos)
            indices.append(int(order[pos]))
    # One fancy-index selection instead of re-allocating with concatenate
    # on every accepted row.
    return indices, sorted_rows[kept]
# Command-line interface; built and parsed at module level.
parser = argparse.ArgumentParser()
# Path of the CSV results file (loaded with pandas.read_csv further down).
parser.add_argument("input")
# Path handed to pandas.ExcelWriter for the generated workbook.
parser.add_argument("output_fn")
# Number of rows kept for each criteria combination (default: 5).
parser.add_argument("-t","--topn",type=int,default=5)
args = parser.parse_args()
# Validate the input path first so a bad argument fails before any output
# writer object is created.
if not os.path.exists(args.input):
    raise FileNotFoundError("{0} does not exists !".format(args.input))

# Workbook writer used by write_excel(); relies on the xlsxwriter engine.
writer = pd.ExcelWriter(args.output_fn, engine='xlsxwriter')

# The first CSV column is used as the index.
data = pd.read_csv(args.input, index_col=0)
# Shorten the "BagOfNodes" measure name to "BOW".
data["mesure"] = data.mesure.apply(lambda x: "BOW" if x == "BagOfNodes" else x)
# Row-wise total of the five criteria; used to rank the Pareto selection.
data["sum"] = data["c1 c2 c3 c4 c5".split()].sum(axis=1)
# Criteria subsets evaluated with the Pareto selection.
# Each entry is (label, list of column names): the label is written to the
# "name" column of the output and the list selects the columns that are
# passed to pareto_frontier_multi.
combination_pareto_criteria = [
("c1_c2_c3_c4_c5", "c1 c2 c3 c4 c5".split()),
("c1_c2_c5", "c1 c2 c5".split()),
("c1_c2_c3", "c1 c2 c3".split()),
("c3_c4", "c3 c4".split()),
("c5", "c5".split()),
("c2", "c2".split()),
]
# Weight vectors evaluated with the weighted-sum model.
# Each entry is (label, weights); the weights are given in the order
# c1, c2, c3, c4, c5, matching the columns selected in
# get_top_combination_wsm.
# NOTE(review): the fourth label uses a comma ("c3_0,33") where the others use
# a dot, and its weights sum to 0.99 rather than 1 - confirm both are intended.
weight_criteria = [
("all_0.2", [0.2, 0.2, 0.2, 0.2, 0.2]),
("c1_0.5_c5_0.5", [0.5, 0., 0., 0., 0.5]),
("c2_0.5_c5_0.5", [0., 0.5, 0., 0., 0.5]),
("c1_0.33_c2_0.33_c3_0,33", [0.33, 0.33, 0.33, 0., 0.]),
("c1_0.5_c2_0.5", [0.5, 0.5, 0., 0., 0.]),
("c3_0.5_c4_0.5", [0., 0., 0.5, 0.5, 0.])
]
def get_top_combination_wsm(dataframe, weights, topn):
    """Select the ``topn`` rows ranked first by a weighted-sum model.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Must contain the columns ``c1`` .. ``c5``.
    weights : tuple
        ``(label, weight_list)``; ``weight_list`` holds one weight per
        criterion in the order c1..c5 and ``label`` is written to the
        ``name`` column of the result.
    topn : int
        Number of rows to keep.

    Returns
    -------
    pandas.DataFrame
        A new frame with the selected rows plus the columns ``name`` and
        ``type_score`` (always ``"wsm"``). The input frame is not modified.
    """
    label, weight_values = weights
    scores = dataframe["c1 c2 c3 c4 c5".split()].values
    # Every criterion is declared as "to maximise".
    dd = Data(scores, criteria=[MAX, MAX, MAX, MAX, MAX], weights=weight_values)
    # Positions ordered by ascending rank value.
    # NOTE(review): presumably rank 1 is the best alternative - confirm
    # against the skcriteria version in use.
    index_max = np.argsort(simple.WeightedSum().decide(dd)._rank)[:topn]
    # Explicit copy: the columns added below must never be written through
    # to (or warned about as a view of) the caller's frame.
    df = dataframe.iloc[index_max].copy()
    df["name"] = label
    df["type_score"] = "wsm"
    return df
def get_top_combination_pareto(dataframe, columns, topn):
    """Select up to ``topn`` rows picked by the Pareto selection.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Must contain the criteria columns named in ``columns[1]`` and a
        ``sum`` column used for the final ordering.
    columns : tuple
        ``(label, column_names)``; ``column_names`` are the criteria handed
        to ``pareto_frontier_multi`` and ``label`` is written to the ``name``
        column of the result.
    topn : int
        Maximum number of rows to keep.

    Returns
    -------
    pandas.DataFrame
        A new frame with the selected rows, sorted by descending ``sum``,
        plus the columns ``name`` and ``type_score`` (always ``"pareto"``).
    """
    label, column_names = columns
    index, _ = pareto_frontier_multi(dataframe[column_names].values)
    # Index the frame that was passed in (not the module-level one) and take
    # an explicit copy before adding columns.
    df = dataframe.iloc[index].copy()
    df = df.sort_values(by="sum", ascending=False).head(topn)
    df["name"] = label
    df["type_score"] = "pareto"
    return df
def write_excel(writer, dataframe, title, group_size=5):
    """Write ``dataframe`` to the ``result`` sheet and highlight extremes.

    The sheet is split into consecutive groups of ``group_size`` data rows.
    Within each group, and for each of the sheet columns C to H, the lowest
    value gets a red fill and the highest value a green fill. Trailing rows
    that do not fill a complete group are left unformatted.

    Parameters
    ----------
    writer : pandas.ExcelWriter
        Writer backed by the xlsxwriter engine; it is closed (and the file
        written) before returning.
    dataframe : pandas.DataFrame
        Data to export; the index is not written.
    title : str
        Currently unused; kept for call compatibility.
    group_size : int, optional
        Number of data rows per highlighted group (default 5).

    Raises
    ------
    ValueError
        If ``group_size`` is smaller than 1.
    """
    if group_size < 1:
        raise ValueError("group_size must be >= 1")

    dataframe.to_excel(writer, sheet_name="result", index=False)
    number_of_rows = len(dataframe)
    worksheet = writer.sheets["result"]
    workbook = writer.book

    C_letter = 67  # ord("C")
    I_letter = 73  # ord("I"), exclusive upper bound -> columns C..H

    # Light red fill with dark red text (lowest value of a group).
    format1 = workbook.add_format({'bg_color': '#FFC7CE',
                                   'font_color': '#9C0006'})
    # Green fill with dark green text (highest value of a group).
    format2 = workbook.add_format({'bg_color': '#C6EFCE',
                                   'font_color': '#006100'})

    # Sheet row 1 is the header, so data occupies rows 2 .. number_of_rows + 1.
    last_row = number_of_rows + 1
    for i in range(C_letter, I_letter):
        ch_ = chr(i)
        begin = 2
        # The stop value is last_row + 1 so that a group ending exactly on
        # the last data row is included.
        for end in range(group_size + 1, last_row + 1, group_size):
            color_range = "{0}{1}:{0}{2}".format(ch_, begin, end)
            worksheet.conditional_format(color_range, {'type': 'bottom',
                                                       'value': '1',
                                                       'format': format1})
            worksheet.conditional_format(color_range, {'type': 'top',
                                                       'value': '1',
                                                       'format': format2})
            begin = end + 1
    # close() writes the workbook; ExcelWriter.save() no longer exists in
    # recent pandas releases.
    writer.close()
# Gather one frame per Pareto criteria combination, then one per weight
# vector, and stack them in that order before exporting.
frames = []
for comb_ in tqdm(combination_pareto_criteria, desc="Pareto computation"):
    frames.append(get_top_combination_pareto(data, comb_, args.topn))
for weight in tqdm(weight_criteria, desc="WSM computation"):
    frames.append(get_top_combination_wsm(data, weight, args.topn))

result = pd.concat(frames, axis=0)
# The last path component of the output file name is passed as the title.
write_excel(writer, result, args.output_fn.split("/")[-1])
\ No newline at end of file
...@@ -9,6 +9,17 @@ from ..models.str import STR ...@@ -9,6 +9,17 @@ from ..models.str import STR
from ..helpers.match_cache import MatchingCache from ..helpers.match_cache import MatchingCache
from ..helpers.relation_extraction import AdjacencyRelation, InclusionRelation from ..helpers.relation_extraction import AdjacencyRelation, InclusionRelation
import sys
class JsonProgress(object):
    """Callable counter that prints a progress line every tenth call.

    An instance is passed as ``object_hook`` to ``json.load``; it returns the
    object it receives unchanged.
    """

    def __init__(self, fn):
        # fn: file name shown in the progress line.
        self.fn = fn
        self.count = 0

    def __call__(self, obj):
        self.count += 1
        if self.count % 10 != 0:
            return obj
        # Carriage return rewrites the same terminal line on each update.
        counter_text = ": %8d" % self.count
        sys.stdout.write("\rLoading" + self.fn + counter_text)
        return obj
class AnnotationAutomatic(object): class AnnotationAutomatic(object):
""" """
...@@ -23,10 +34,11 @@ class AnnotationAutomatic(object): ...@@ -23,10 +34,11 @@ class AnnotationAutomatic(object):
self.inc_rel_db = InclusionRelation() self.inc_rel_db = InclusionRelation()
self.inclusion,self.adjacency = {},{} self.inclusion,self.adjacency = {},{}
if inclusion_fn: if inclusion_fn:
self.inclusion = json.load(open(inclusion_fn)) self.inclusion = json.load(open(inclusion_fn),object_hook=JsonProgress(inclusion_fn))
if adjacency_fn: if adjacency_fn:
self.adjacency = json.load(open(adjacency_fn)) self.adjacency = json.load(open(adjacency_fn),object_hook=JsonProgress(adjacency_fn))
self.threshold = threshold_c3 self.threshold = threshold_c3
def all(self, str1, str2, id1=None, id2=None): def all(self, str1, str2, id1=None, id2=None):
""" """
...@@ -47,7 +59,7 @@ class AnnotationAutomatic(object): ...@@ -47,7 +59,7 @@ class AnnotationAutomatic(object):
return list(value) return list(value)
crit_ = [self.criterion1(str1, str2), self.criterion2(str1, str2), self.criterion3(str1, str2, id1, id2), crit_ = [self.criterion1(str1, str2), self.criterion2(str1, str2), self.criterion3(str1, str2, id1, id2),
self.criterion4(str1, str2, id1, id2)] self.criterion4(str1, str2, id1, id2),self.criteria5(str1, str2, id1, id2)]
self.matching_cache.add(id1, id2, *crit_) self.matching_cache.add(id1, id2, *crit_)
return crit_ return crit_
......
...@@ -27,7 +27,7 @@ class GeoRelationMatchingDatabase(): ...@@ -27,7 +27,7 @@ class GeoRelationMatchingDatabase():
(idse1 text, idse2 text, value integer) (idse1 text, idse2 text, value integer)
""" """
matching_schema = """CREATE TABLE matching matching_schema = """CREATE TABLE matching
(dataset text, g1 integer, g2 integer, c1 integer, c2 integer, c3 integer,c4 integer) (dataset text, g1 integer, g2 integer, c1 integer, c2 integer, c3 integer,c4 integer, c5 REAL )
""" """
cursor.execute(inclusion_schema) cursor.execute(inclusion_schema)
cursor.execute(adjacency_schema) cursor.execute(adjacency_schema)
...@@ -74,7 +74,7 @@ class GeoRelationMatchingDatabase(): ...@@ -74,7 +74,7 @@ class GeoRelationMatchingDatabase():
self._db_connection.commit() self._db_connection.commit()
cursor.close() cursor.close()
def add_matching(self, dataset: str, G1: int, G2: int, c1: bool, c2: bool, c3: bool, c4: bool): def add_matching(self, dataset: str, G1: int, G2: int, c1: bool, c2: bool, c3: bool, c4: bool,c5: float):
""" """
Add a matching criteria result within the database Add a matching criteria result within the database
Parameters Parameters
...@@ -96,8 +96,8 @@ class GeoRelationMatchingDatabase(): ...@@ -96,8 +96,8 @@ class GeoRelationMatchingDatabase():
""" """
cursor = self._db_connection.cursor() cursor = self._db_connection.cursor()
cursor.execute('INSERT INTO matching VALUES(?,?,?,?,?,?,?)', cursor.execute('INSERT INTO matching VALUES(?,?,?,?,?,?,?,?)',
(dataset, G1, G2, int(c1), int(c2), int(c3), int(c4))) (dataset, G1, G2, int(c1), int(c2), int(c3), int(c4),float(c5)))
self._db_connection.commit() self._db_connection.commit()
cursor.close() cursor.close()
...@@ -169,7 +169,7 @@ class GeoRelationMatchingDatabase(): ...@@ -169,7 +169,7 @@ class GeoRelationMatchingDatabase():
result_ = cursor.fetchone() result_ = cursor.fetchone()
cursor.close() cursor.close()
if result_: if result_:
return True, tuple(map(int, result_[-4:])) return True, tuple(map(float, result_[-5:]))
return False, False return False, False
...@@ -185,9 +185,9 @@ if __name__ == "__main__": ...@@ -185,9 +185,9 @@ if __name__ == "__main__":
assert g.get_inclusion("GD1", "GD2") == (True, True) assert g.get_inclusion("GD1", "GD2") == (True, True)
assert g.get_inclusion("GD2", "GD1") == (False, False) assert g.get_inclusion("GD2", "GD1") == (False, False)
g.add_matching("test", 1, 2, True, True, False, True) g.add_matching("test", 1, 2, True, True, False, True,0.)
g.add_matching("test2", 1, 2, True, False, False, True) g.add_matching("test2", 1, 2, True, False, False, True,0.)
assert g.get_matching(1, 2, "test") == (True, (True, True, False, True)) assert g.get_matching(1, 2, "test") == (True, (True, True, False, True,0.))
assert g.get_matching(1, 2, "test2") != (True, (True, True, False, True)) assert g.get_matching(1, 2, "test2") != (True, (True, True, False, True,0.))
print("Passed the tests !") print("Passed the tests !")
...@@ -12,6 +12,6 @@ class MatchingCache: ...@@ -12,6 +12,6 @@ class MatchingCache:
def is_match(self, id_str1: int, id_str2: int): def is_match(self, id_str1: int, id_str2: int):
return self.db_rel_match.get_matching(id_str1, id_str2, self.dataset) return self.db_rel_match.get_matching(id_str1, id_str2, self.dataset)
def add(self, id_str1: int, id_str2: int, c1: int, c2: int, c3: int, c4: int): def add(self, id_str1: int, id_str2: int, c1: int, c2: int, c3: int, c4: int, c5: float):
if not self.is_match(id_str1, id_str2)[0]: if not self.is_match(id_str1, id_str2)[0]:
self.db_rel_match.add_matching(self.dataset, id_str1, id_str2, c1, c2, c3, c4) self.db_rel_match.add_matching(self.dataset, id_str1, id_str2, c1, c2, c3, c4,c5)
...@@ -36,6 +36,6 @@ def matrix_to_pandas_dataframe(matrix, selected, sim_measure, type_str, n=5): ...@@ -36,6 +36,6 @@ def matrix_to_pandas_dataframe(matrix, selected, sim_measure, type_str, n=5):
top_n = np.argsort(matrix[line])[::-1][1:n + 1] top_n = np.argsort(matrix[line])[::-1][1:n + 1]
rank = 1 rank = 1
for val in top_n: for val in top_n:
tab_array.append([line, val, sim, type_, rank, 0, 0, 0, 0]) tab_array.append([line, val, sim, type_, rank, 0, 0, 0, 0,300000])
rank += 1 rank += 1
return pd.DataFrame(tab_array, columns="G1 G2 sim_measure type_str rank c1 c2 c3 c4".split()) return pd.DataFrame(tab_array, columns="G1 G2 sim_measure type_str rank c1 c2 c3 c4 c5".split())
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment