From 33da2019d82dfd6aa123cad7f9043e10f46d1bfb Mon Sep 17 00:00:00 2001
From: Decoupes Remy <remy.decoupes@irstea.fr>
Date: Tue, 5 Apr 2022 10:13:14 +0200
Subject: [PATCH] debugging of zeroshot in progress

---
 elasticsearch/src/tf-idf-es.py | 56 +++++++++++++++++++---------------
 1 file changed, 32 insertions(+), 24 deletions(-)

diff --git a/elasticsearch/src/tf-idf-es.py b/elasticsearch/src/tf-idf-es.py
index d05abe4..7dcdf85 100644
--- a/elasticsearch/src/tf-idf-es.py
+++ b/elasticsearch/src/tf-idf-es.py
@@ -141,28 +141,15 @@ if __name__ == '__main__':
     df_tfidf = tf_idf(df_results["text"].tolist())
     df_tfidf["tf_idf_terms"] = df_tfidf.index
     df_tfidf = df_tfidf.merge(df_results, on="index_of_tweet")
-    # prepare to Gephi for graph vizu: Graph bipartites. Nodes are Newspaper and TF-IDf
-    news_paper_name = pd.read_csv("./../params/accountsFollowed.csv") # get account (followed by MOOD) names
-    news_paper_name["retweeted_status.user.id"] = news_paper_name["twitterID"] # prepare for merge
-    gephi = df_tfidf
-    # gephi["Source"] = gephi["user.id"].apply(lambda x: hashlib.md5(str(x).encode()).hexdigest())  # pseudonimization
-    gephi["Source"] = gephi.index # id du lien
-    gephi["Target"] = gephi["retweeted_status.user.id"]
-    gephi["Id"] = gephi.index
-    gephi["Label"] = gephi["tf_idf_terms"]
-    gephi["timeset"] = gephi["@timestamp"]
-    gephi = gephi[gephi["Target"].str.len() !=0] # filter out tweet that are not retweeted
-    gephi[["ID", "label", "Source", "Target", "Timestamp"]].to_csv(
-        "analysis-output/acquitaine_script_gephi.csv",
-        index=False
-    )
+
     # Translation to english
     model_checkpoint_fr = "Helsinki-NLP/opus-mt-fr-en"
     translator_fr = pipeline("translation", model=model_checkpoint_fr)
     # zeroshot classification
     classifier = pipeline("zero-shot-classification",
                           model="digitalepidemiologylab/covid-twitter-bert-v2-mnli")
-    candidate_labels_fr = ["covid-19", "grippe aviaire", "AMR", "tiques", "autres"]
+    candidate_labels_fr = ["covid-19"]
+    # candidate_labels_fr = ["covid-19", "grippe aviaire", "AMR", "tiques", "autres"]
     candidate_labels_en = ["avian influenza"]
     # candidate_labels_en = ["covid-19", "avian influenza", "AMR", "tick borne", "others"]
     classifier_results = []
@@ -170,23 +157,44 @@ if __name__ == '__main__':
     for i, tweets in tqdm(df_tfidf.iterrows(), total=df_tfidf.shape[0]):
         text = tweets["text"]
         try:
-            text_translated = translator_fr(text)[0]["translation_text"]
-            classifier_results.append(classifier(text_translated, candidate_labels_en)["scores"])
-            item = {"text" : text, "scores" : classifier(text_translated, candidate_labels_en)["scores"]}
+            text_translated = text
+            # text_translated = translator_fr(text)[0]["translation_text"]
+            classifier_results.append(classifier(text_translated, candidate_labels_fr)["scores"])
+            item = {"text" : text, "scores" : classifier(text_translated, candidate_labels_fr)["scores"]}
             classifier_results_2.append(item)
         except:
             df_tfidf.drop([i], inplace=True)
             print("text: " + text + " | translated: " + text_translated)
-    classifier_df = pd.DataFrame(classifier_results, columns=candidate_labels_en)
-    f=open("analysis-output/test_2.txt", "w")
-    for l in classifier_results_2:
-        f.write(l)
-    f.close()
+    classifier_df = pd.DataFrame(classifier_results, columns=candidate_labels_fr)
+    try:
+        f = open("analysis-output/test_2.txt", "w")
+        for l in classifier_results_2:
+            f.write(str(l))
+        f.close()
+    except:
+        print("can not save file with results from zeroshot")
     classifier_df_2 = pd.DataFrame(classifier_results_2)
     classifier_df_2.to_csv("analysis-output/acquitaine_test.csv")
     df_tfidf = df_tfidf.join(classifier_df)
     df_tfidf.to_csv("analysis-output/acquitaine-digitalepidemiologylab.csv")
     df_tfidf.to_pickle("analysis-output/acquitaine-digitalepidemiologylab.pkl")
+
+
+    # prepare to Gephi for graph vizu: Graph bipartites. Nodes are Newspaper and TF-IDf
+    news_paper_name = pd.read_csv("./../params/accountsFollowed.csv") # get account (followed by MOOD) names
+    news_paper_name["retweeted_status.user.id"] = news_paper_name["twitterID"] # prepare for merge
+    gephi = df_tfidf
+    # gephi["Source"] = gephi["user.id"].apply(lambda x: hashlib.md5(str(x).encode()).hexdigest())  # pseudonimization
+    gephi["Source"] = gephi.index # id du lien
+    gephi["Target"] = gephi["retweeted_status.user.id"]
+    gephi["Id"] = gephi.index
+    gephi["Label"] = gephi["tf_idf_terms"]
+    gephi["timeset"] = gephi["@timestamp"]
+    gephi = gephi[gephi["Target"].str.len() !=0] # filter out tweet that are not retweeted
+    gephi[["ID", "label", "Source", "Target", "Timestamp"]].to_csv(
+        "analysis-output/acquitaine_script_gephi.csv",
+        index=False
+    )
     gephi[["Id", "Label", "Source", "Target", "timeset"]].to_csv(
         "analysis-output/acquitaine_script_gephi_edge.csv",
         index=False
-- 
GitLab