From 33da2019d82dfd6aa123cad7f9043e10f46d1bfb Mon Sep 17 00:00:00 2001 From: Decoupes Remy <remy.decoupes@irstea.fr> Date: Tue, 5 Apr 2022 10:13:14 +0200 Subject: [PATCH] debuging of zeroshot in progress --- elasticsearch/src/tf-idf-es.py | 56 +++++++++++++++++++--------------- 1 file changed, 32 insertions(+), 24 deletions(-) diff --git a/elasticsearch/src/tf-idf-es.py b/elasticsearch/src/tf-idf-es.py index d05abe4..7dcdf85 100644 --- a/elasticsearch/src/tf-idf-es.py +++ b/elasticsearch/src/tf-idf-es.py @@ -141,28 +141,15 @@ if __name__ == '__main__': df_tfidf = tf_idf(df_results["text"].tolist()) df_tfidf["tf_idf_terms"] = df_tfidf.index df_tfidf = df_tfidf.merge(df_results, on="index_of_tweet") - # prepare to Gephi for graph vizu: Graph bipartites. Nodes are Newspaper and TF-IDf - news_paper_name = pd.read_csv("./../params/accountsFollowed.csv") # get account (followed by MOOD) names - news_paper_name["retweeted_status.user.id"] = news_paper_name["twitterID"] # prepare for merge - gephi = df_tfidf - # gephi["Source"] = gephi["user.id"].apply(lambda x: hashlib.md5(str(x).encode()).hexdigest()) # pseudonimization - gephi["Source"] = gephi.index # id du lien - gephi["Target"] = gephi["retweeted_status.user.id"] - gephi["Id"] = gephi.index - gephi["Label"] = gephi["tf_idf_terms"] - gephi["timeset"] = gephi["@timestamp"] - gephi = gephi[gephi["Target"].str.len() !=0] # filter out tweet that are not retweeted - gephi[["ID", "label", "Source", "Target", "Timestamp"]].to_csv( - "analysis-output/acquitaine_script_gephi.csv", - index=False - ) + # Translation to english model_checkpoint_fr = "Helsinki-NLP/opus-mt-fr-en" translator_fr = pipeline("translation", model=model_checkpoint_fr) # zeroshot classification classifier = pipeline("zero-shot-classification", model="digitalepidemiologylab/covid-twitter-bert-v2-mnli") - candidate_labels_fr = ["covid-19", "grippe aviaire", "AMR", "tiques", "autres"] + candidate_labels_fr = ["covid-19"] + # candidate_labels_fr = ["covid-19", "grippe aviaire", "AMR", "tiques", "autres"] candidate_labels_en = ["avian influenza"] # candidate_labels_en = ["covid-19", "avian influenza", "AMR", "tick borne", "others"] classifier_results = [] @@ -170,23 +157,44 @@ if __name__ == '__main__': for i, tweets in tqdm(df_tfidf.iterrows(), total=df_tfidf.shape[0]): text = tweets["text"] try: - text_translated = translator_fr(text)[0]["translation_text"] - classifier_results.append(classifier(text_translated, candidate_labels_en)["scores"]) - item = {"text" : text, "scores" : classifier(text_translated, candidate_labels_en)["scores"]} + text_translated = text + # text_translated = translator_fr(text)[0]["translation_text"] + classifier_results.append(classifier(text_translated, candidate_labels_fr)["scores"]) + item = {"text" : text, "scores" : classifier(text_translated, candidate_labels_fr)["scores"]} classifier_results_2.append(item) except: df_tfidf.drop([i], inplace=True) print("text: " + text + " | translated: " + text_translated) - classifier_df = pd.DataFrame(classifier_results, columns=candidate_labels_en) - f=open("analysis-output/test_2.txt", "w") - for l in classifier_results_2: - f.write(l) - f.close() + classifier_df = pd.DataFrame(classifier_results, columns=candidate_labels_fr) + try: + f = open("analysis-output/test_2.txt", "w") + for l in classifier_results_2: + f.write(str(l)) + f.close() + except: + print("can not save file with results from zeroshot") classifier_df_2 = pd.DataFrame(classifier_results_2) classifier_df_2.to_csv("analysis-output/acquitaine_test.csv") df_tfidf = df_tfidf.join(classifier_df) df_tfidf.to_csv("analysis-output/acquitaine-digitalepidemiologylab.csv") df_tfidf.to_pickle("analysis-output/acquitaine-digitalepidemiologylab.pkl") + + + # prepare to Gephi for graph vizu: Graph bipartites. Nodes are Newspaper and TF-IDf + news_paper_name = pd.read_csv("./../params/accountsFollowed.csv") # get account (followed by MOOD) names + news_paper_name["retweeted_status.user.id"] = news_paper_name["twitterID"] # prepare for merge + gephi = df_tfidf + # gephi["Source"] = gephi["user.id"].apply(lambda x: hashlib.md5(str(x).encode()).hexdigest()) # pseudonimization + gephi["Source"] = gephi.index # id du lien + gephi["Target"] = gephi["retweeted_status.user.id"] + gephi["Id"] = gephi.index + gephi["Label"] = gephi["tf_idf_terms"] + gephi["timeset"] = gephi["@timestamp"] + gephi = gephi[gephi["Target"].str.len() !=0] # filter out tweet that are not retweeted + gephi[["ID", "label", "Source", "Target", "Timestamp"]].to_csv( + "analysis-output/acquitaine_script_gephi.csv", + index=False + ) gephi[["Id", "Label", "Source", "Target", "timeset"]].to_csv( "analysis-output/acquitaine_script_gephi_edge.csv", index=False -- GitLab