Commit 8444a01c authored by Ienco Dino's avatar Ienco Dino
Browse files

some modification

parent 4f80c5d2
No related merge requests found
Showing with 5 additions and 0 deletions
+5 -0
......@@ -136,6 +136,7 @@ def trainRNNAE(model, nClasses, data, f_data, s_data, y_val, loss_huber, optimiz
#th = 40
n_epochs_warmUp = 40
centers = None
print("PRETRAINING STAGE : AE + CONTRASTIVE LOSS")
for e in range(n_epochs_warmUp):
f_data, s_data, y_val, = shuffle(f_data, s_data, y_val)
data = shuffle(data)
......@@ -143,6 +144,8 @@ def trainRNNAE(model, nClasses, data, f_data, s_data, y_val, loss_huber, optimiz
trainLoss += trainStepL(model, f_data, s_data, y_val, loss_huber, optimizer2, BATCH_SIZE, e)
print("epoch %d with loss %f" % (e, trainLoss))
print("COMPUTE INTERMEDIATE CLUSTERING ASSIGNMENT")
emb, _, _, _ = model(data)
km = KMeans(n_clusters=nClasses)
km.fit(emb)
......@@ -151,6 +154,8 @@ def trainRNNAE(model, nClasses, data, f_data, s_data, y_val, loss_huber, optimiz
centers.append( km.cluster_centers_[val])
centers = np.array(centers)
print("REFINEMENT STEP alternating AE + MANIFOLD STRETCH TOWARDS CENTROIDS and AE + CONTRASTIVE LOSS")
for e in range(n_epochs - n_epochs_warmUp):
#labelledData, labelsSmall = shuffle(labelledData, labelsSmall)
data, centers = shuffle(data, centers)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment