Commit 8444a01c authored by Ienco Dino's avatar Ienco Dino
Browse files

some modification

parent 4f80c5d2
...@@ -136,6 +136,7 @@ def trainRNNAE(model, nClasses, data, f_data, s_data, y_val, loss_huber, optimiz ...@@ -136,6 +136,7 @@ def trainRNNAE(model, nClasses, data, f_data, s_data, y_val, loss_huber, optimiz
#th = 40 #th = 40
n_epochs_warmUp = 40 n_epochs_warmUp = 40
centers = None centers = None
print("PRETRAINING STAGE : AE + CONTRASTIVE LOSS")
for e in range(n_epochs_warmUp): for e in range(n_epochs_warmUp):
f_data, s_data, y_val, = shuffle(f_data, s_data, y_val) f_data, s_data, y_val, = shuffle(f_data, s_data, y_val)
data = shuffle(data) data = shuffle(data)
...@@ -143,6 +144,8 @@ def trainRNNAE(model, nClasses, data, f_data, s_data, y_val, loss_huber, optimiz ...@@ -143,6 +144,8 @@ def trainRNNAE(model, nClasses, data, f_data, s_data, y_val, loss_huber, optimiz
trainLoss += trainStepL(model, f_data, s_data, y_val, loss_huber, optimizer2, BATCH_SIZE, e) trainLoss += trainStepL(model, f_data, s_data, y_val, loss_huber, optimizer2, BATCH_SIZE, e)
print("epoch %d with loss %f" % (e, trainLoss)) print("epoch %d with loss %f" % (e, trainLoss))
print("COMPUTE INTERMEDIATE CLUSTERING ASSIGNMENT")
emb, _, _, _ = model(data) emb, _, _, _ = model(data)
km = KMeans(n_clusters=nClasses) km = KMeans(n_clusters=nClasses)
km.fit(emb) km.fit(emb)
...@@ -151,6 +154,8 @@ def trainRNNAE(model, nClasses, data, f_data, s_data, y_val, loss_huber, optimiz ...@@ -151,6 +154,8 @@ def trainRNNAE(model, nClasses, data, f_data, s_data, y_val, loss_huber, optimiz
centers.append( km.cluster_centers_[val]) centers.append( km.cluster_centers_[val])
centers = np.array(centers) centers = np.array(centers)
print("REFINEMENT STEP alternating AE + MANIFOLD STRETCH TOWARDS CENTROIDS and AE + CONTRASTIVE LOSS")
for e in range(n_epochs - n_epochs_warmUp): for e in range(n_epochs - n_epochs_warmUp):
#labelledData, labelsSmall = shuffle(labelledData, labelsSmall) #labelledData, labelsSmall = shuffle(labelledData, labelsSmall)
data, centers = shuffle(data, centers) data, centers = shuffle(data, centers)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment