From 8444a01cc8e5b105ea97ccce9b8ed202a0b8394e Mon Sep 17 00:00:00 2001 From: Ienco Dino <dino.ienco@irstea.fr> Date: Wed, 10 Mar 2021 19:06:53 +0100 Subject: [PATCH] some modification --- main.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/main.py b/main.py index ef76768..0cfe225 100644 --- a/main.py +++ b/main.py @@ -136,6 +136,7 @@ def trainRNNAE(model, nClasses, data, f_data, s_data, y_val, loss_huber, optimiz #th = 40 n_epochs_warmUp = 40 centers = None + print("PRETRAINING STAGE : AE + CONTRASTIVE LOSS") for e in range(n_epochs_warmUp): f_data, s_data, y_val, = shuffle(f_data, s_data, y_val) data = shuffle(data) @@ -143,6 +144,8 @@ def trainRNNAE(model, nClasses, data, f_data, s_data, y_val, loss_huber, optimiz trainLoss += trainStepL(model, f_data, s_data, y_val, loss_huber, optimizer2, BATCH_SIZE, e) print("epoch %d with loss %f" % (e, trainLoss)) + + print("COMPUTE INTERMEDIATE CLUSTERING ASSIGNMENT") emb, _, _, _ = model(data) km = KMeans(n_clusters=nClasses) km.fit(emb) @@ -151,6 +154,8 @@ def trainRNNAE(model, nClasses, data, f_data, s_data, y_val, loss_huber, optimiz centers.append( km.cluster_centers_[val]) centers = np.array(centers) + + print("REFINEMENT STEP alternating AE + MANIFOLD STRETCH TOWARDS CENTROIDS and AE + CONTRASTIVE LOSS") for e in range(n_epochs - n_epochs_warmUp): #labelledData, labelsSmall = shuffle(labelledData, labelsSmall) data, centers = shuffle(data, centers) -- GitLab