Commit fcc41b98 authored by Decoupes Remy's avatar Decoupes Remy
Browse files

visualize sentence transformers through t-sne

parent 2532e56a
......@@ -16,7 +16,9 @@ import pandas as pd
import as px
import as pio
from plotly.subplots import make_subplots
from collections import defaultdict
from sentence_transformers import SentenceTransformer
import numpy as np
from sklearn.manifold import TSNE
def logsetup():
......@@ -159,6 +161,16 @@ def time_series_by_disease_keywords(jinja_env, es_url, index_es, list_of_keyword
return df_all_kw_timeserie
def get_tweet_content_by_disease(jinja_env, es_url, index_es, list_of_keywords, disease, nb_of_estimated_results=10000):
Retrieves all tweets for a specific disease thanks its keywords
:param jinja_env:
:param es_url:
:param index_es:
:param list_of_keywords:
:param disease:
:param nb_of_estimated_results:
template = jinja_env.get_template("get_tweets_content_by_keywords.json.j2")
query = template.render(list_of_keywords=list_of_keywords)
headers = {'content-type': 'application/json'}
......@@ -179,9 +191,33 @@ def get_tweet_content_by_disease(jinja_env, es_url, index_es, list_of_keywords,
df_results = pd.DataFrame(list_of_tweets)
# df_results.to_pickle("/home/rdecoupe/Téléchargements/test/get_tweet_content_by_disease.pkl")
return df_results
def visualize_sentence_through_embedding(corpus):
# How to choose the model: An overview on huggingface
multi_lingual_model = "distiluse-base-multilingual-cased-v1"
best_quality_model = "all-mpnet-base-v2"
faster_model = "all-MiniLM-L6-v2" # used by UKPLab :
embedder = SentenceTransformer(faster_model)
# Encode !
corpus_embeddings = embedder.encode(corpus)
# Normalize the embeddings to unit length
corpus_embeddings = corpus_embeddings / np.linalg.norm(corpus_embeddings, axis=1, keepdims=True)
# Dimension reduction with t-SNE
tsne_model = TSNE(perplexity=40, n_components=2, init='pca', n_iter=2500, random_state=23)
corpus_embeddings_tsne = tsne_model.fit_transform(corpus_embeddings)
corpus_embeddings_tsne_df = pd.DataFrame(corpus_embeddings_tsne)
corpus_embeddings_tsne_df["label"] = corpus
# plot with plotly express
fig = px.scatter(
corpus_embeddings_tsne_df, x=0, y=1,
if __name__ == '__main__':
logger = logsetup()"EDA start")
......@@ -289,5 +325,8 @@ if __name__ == '__main__':
list_of_keywords = ['Fowl', 'Bird', 'Avian', 'HPAI', 'FowlPlague', 'AvianInfluenza', 'avianInfluenza',
'Avianflu', 'bird', 'BirdFlu']
corpus_tweets = get_tweet_content_by_disease(jinja_env, es_url, index_es, list_of_keywords, disease)
corpus_tweets_list = corpus_tweets.text.values.tolist()
corpus_tweets_list = list(set(corpus_tweets_list)) #Remove duplicate tweets (mostly RT)
visualize_sentence_through_embedding(corpus_tweets_list)"EDA stop")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment