classify.py 2.12 KB
Newer Older
Dino Ienco's avatar
Dino Ienco committed
import numpy as np
import sys
import glob
from scipy import sparse
from gbssl import LGC,HMN,PARW,MAD,OMNIProp,CAMLP
from sklearn.neighbors import kneighbors_graph
from scipy.sparse import coo_matrix
from sklearn.preprocessing import normalize
from sklearn.utils import shuffle
from sklearn.metrics import accuracy_score
from scipy.spatial.distance import pdist
from scipy.spatial.distance import squareform
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import f1_score
import os.path

def extractKNNGraph(knn, k):
	"""Build the mutual k-nearest-neighbour graph from a neighbour ranking.

	Parameters
	----------
	knn : array of int, shape (n, m)
		Row ``i`` lists point indices ordered by increasing distance from
		point ``i`` (as returned by ``getKNNEucl``). Column 0 is skipped
		because it is expected to be the point itself.
	k : int
		Number of neighbours kept per point (columns ``1..k``; fewer if the
		row is shorter).

	Returns
	-------
	scipy.sparse.lil_matrix, shape (n, n)
		0/1 adjacency where entry (i, j) is 1 iff j is among the k nearest
		neighbours of i AND i is among the k nearest neighbours of j.
	"""
	nrow = knn.shape[0]
	neighbours = np.asarray(knn)[:, 1:k + 1]

	# One directed edge per (point, neighbour) pair, built in one vectorised
	# step instead of element-wise lil assignment inside a Python loop.
	rows = np.repeat(np.arange(nrow), neighbours.shape[1])
	cols = neighbours.ravel()
	data = np.ones(rows.shape[0])
	G = sparse.coo_matrix((data, (rows, cols)), shape=(nrow, nrow)).tocsr()
	# tocsr() sums repeated (i, j) pairs; reset to 1 so the graph stays 0/1
	# exactly like repeated G[i, j] = 1 assignments would.
	G.data[:] = 1.0

	# MUTUAL KNN GRAPH: keep an edge only if it exists in both directions.
	# The sparse minimum avoids materialising two dense n x n matrices.
	mKNN = G.minimum(G.T)
	return sparse.lil_matrix(mKNN)


def getKNNEucl(X):
	"""Rank every point's neighbours by Euclidean distance.

	Parameters
	----------
	X : array-like, shape (n, d)
		One sample per row.

	Returns
	-------
	ndarray of int, shape (n, n)
		Row ``i`` lists all point indices sorted by increasing distance from
		point ``i``. Column 0 is always ``i`` itself, even when ``X`` contains
		duplicate rows, because ``extractKNNGraph`` skips column 0 as "self".
	"""
	dist = squareform(pdist(X, 'euclidean'))
	# The self-distance is 0 and ties with exact duplicates; make it strictly
	# smallest so the point itself always lands in column 0.
	np.fill_diagonal(dist, -np.inf)
	# Stable sort makes the order among equal distances deterministic
	# (lower index first).
	knn = np.argsort(dist, axis=1, kind='stable')
	return knn



def classify(directoy, embFileName, labelFileName, numberOfNearestNeighbors):
	"""Predict a class for every sample with CAMLP over a mutual-kNN graph.

	directoy: accepted for interface compatibility; not used by this function.
	embFileName: .npy file holding the sample representation, one row per sample.
	labelFileName: .npy file with one row per labeled sample; column 0 is the
		sample's row index and column 1 its class (both cast to int).
	numberOfNearestNeighbors: k used when building the mutual-kNN graph.

	Returns, for every sample, the argmax over the columns of CAMLP's
	predict_proba output.
	"""
	# Rescale each feature to MinMaxScaler's default [0, 1] range before
	# measuring distances.
	features = MinMaxScaler().fit_transform(np.load(embFileName))
	graph = extractKNNGraph(getKNNEucl(features), numberOfNearestNeighbors)
	n_samples = graph.shape[0]

	seeds = np.load(labelFileName)
	seed_ids = seeds[:, 0].astype("int")
	seed_classes = seeds[:, 1].astype("int")

	model = CAMLP(graph=graph)
	model.fit(np.array(seed_ids), np.array(seed_classes))

	scores = model.predict_proba(np.arange(n_samples))
	return np.argmax(scores, axis=1)


if __name__ == "__main__":
	# Run only when executed as a script, so the module can be imported
	# (e.g. to reuse classify) without reading sys.argv.
	# Usage: classify.py DIRECTORY EMBEDDING_FILE LABEL_FILE [K]
	if len(sys.argv) < 4:
		sys.exit("usage: %s DIRECTORY EMBEDDING_FILE LABEL_FILE [K]" % sys.argv[0])

	# Directory in which the data are stored
	directory = sys.argv[1]

	# File containing the representation learned with the SESAM approach,
	# or any other data representation
	embFileName = sys.argv[2]

	# File in the directory/labels folder with the label information.
	# It has one row per labeled example; each row holds the position of the
	# labeled example w.r.t. the data file data.npy and its associated label.
	labelFileName = sys.argv[3]

	# Optional 4th argument overrides the number of nearest neighbours
	# (default 20, the previously hard-coded value).
	numberOfNearestNeighbors = int(sys.argv[4]) if len(sys.argv) > 4 else 20
	prediction = classify(directory, embFileName, labelFileName, numberOfNearestNeighbors)
	print(prediction)