Commit a0d78cb7 authored by Eelco van der Wel

dataframe

parent b164f6df
Pipeline #12190 failed in 5 minutes and 14 seconds
Showing with 3023 additions and 26 deletions
This source diff could not be displayed because it is too large.
@@ -2,6 +2,7 @@ import json
from dataclasses import dataclass
from typing import Dict, List
import hdbscan
import numpy as np
from bertopic import BERTopic
from sentence_transformers import SentenceTransformer
@@ -12,9 +13,11 @@ from .preprocessing import get_stopwords
MODEL_NAME = "paraphrase-multilingual-MiniLM-L12-v2"
@dataclass
class TopicModelFit:
"""return type for TopicModel.fit"""
model: BERTopic
assignments: List[int]
probs: np.ndarray
@@ -22,14 +25,32 @@ class TopicModelFit:
topic_descriptions: Dict[int, List[str]]
document_relevance: Dict[int, np.ndarray]
class TopicModel:
def __init__(self) -> None:
self.embedding_model = SentenceTransformer(MODEL_NAME)
def fit(self, corpus: List[str], diversity: float=0.1, ngram_range=(1, 2), num_topic_descriptors=3):
def fit(
self,
corpus: List[str],
diversity: float = 0.01,
ngram_range=(1, 2),
num_topic_descriptors=3,
):
vectorizer_model = CountVectorizer(ngram_range=ngram_range, stop_words=get_stopwords())
clustering_model = hdbscan.HDBSCAN(
min_cluster_size=5,
metric="euclidean",
cluster_selection_method="eom",
prediction_data=True,
# cluster_selection_epsilon=0.1,
)
topic_model = BERTopic(
vectorizer_model=vectorizer_model, language="multilingual", diversity=diversity, embedding_model=self.embedding_model
vectorizer_model=vectorizer_model,
language="multilingual",
diversity=diversity,
embedding_model=self.embedding_model,
hdbscan_model=clustering_model,
)
embeddings = self.embedding_model.encode(sentences=corpus)
assignments, probs = topic_model.fit_transform(corpus, embeddings=embeddings)
@@ -45,7 +66,9 @@ class TopicModel:
document_relevance=relevance,
)
def get_document_relevance(self, topic_model: BERTopic, document_embeddings: np.ndarray) -> Dict[int, np.ndarray]:
def get_document_relevance(
self, topic_model: BERTopic, document_embeddings: np.ndarray
) -> Dict[int, np.ndarray]:
"""Returns relevance (cosine similarity) for each topic for all documents.
Args:
@@ -66,7 +89,6 @@ class TopicModel:
topic_relevance[topic_idx] = similarity_matrix[i]
return topic_relevance
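The relevance lookup above can be traced with fabricated data; the shapes below and the way BERTopic's topic embeddings are obtained are assumptions, since the full body of get_document_relevance is not shown in this hunk:

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Illustration only: 2 topics, 4 documents, 3-dimensional embeddings (random values).
topic_embeddings = np.random.rand(2, 3)
document_embeddings = np.random.rand(4, 3)

# Row i holds the cosine similarity of every document to topic i,
# which is what ends up in topic_relevance[topic_idx] above.
similarity_matrix = cosine_similarity(topic_embeddings, document_embeddings)
topic_relevance = {topic_idx: similarity_matrix[i] for i, topic_idx in enumerate([0, 1])}
print(topic_relevance[0].shape)  # (4,) -> one relevance score per document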
def get_topic_descriptions(self, topic_model, num_descriptors=3) -> Dict[int, str]:
topics = topic_model.get_topics()
return {k: [v_i[0] for v_i in v[:num_descriptors]] for k, v in topics.items()}
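For orientation, a minimal usage sketch of the class above, assuming it runs inside the plugin package and that the corpus is large enough for HDBSCAN's min_cluster_size=5; the example texts are made up:

from .model import TopicModel

# Fabricated corpus: enough repeated documents for HDBSCAN to form clusters.
corpus = [f"voorbeeld tweet over onderwerp {i % 3} met wat extra woorden" for i in range(60)]

model = TopicModel()
fit = model.fit(corpus, diversity=0.01, num_topic_descriptors=3)

# fit.assignments[i] is the topic index of corpus[i] (-1 is HDBSCAN's outlier bucket);
# fit.topic_descriptions maps each topic index to its top n-grams.
for topic_idx, descriptors in fit.topic_descriptions.items():
    print(topic_idx, descriptors)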
import json
from datetime import datetime
from os import environ
from typing import Dict, List, Optional, Tuple
import pandas as pd
from pymemri.data.schema import Tweet
from pymemri.pod.client import PodClient
from pymemri.plugin.constants import POD_FULL_ADDRESS_ENV
from pymemri.plugin.pluginbase import PluginBase
from pymemri.plugin.states import RUN_DAEMON
from pymemri.webserver.public_api import register_endpoint
from pymemri.pod.api import DEFAULT_POD_ADDRESS
from .model import TopicModel, TopicModelFit
from .preprocessing import preprocess_tweets
from .schema import Cluster, ClusterEntry, TwitterTopicModel
from .utils import get_tweets
NUM_TOPIC_DESCRIPTORS = 3
DESCRIPTION_DIVERSITY = 0.1
DESCRIPTION_DIVERSITY = 0.05
class TwitterTopicModelPlugin(PluginBase):
schema_classes = [Cluster, ClusterEntry, TwitterTopicModel, Tweet]
@@ -38,18 +37,24 @@ class TwitterTopicModelPlugin(PluginBase):
print("Starting daemon mode")
def fit_topic_model(self, tweets: List[Tweet]):
tweet_contents = preprocess_tweets(tweets)
model_fit = self.model.fit(list(tweet_contents.values()), diversity=DESCRIPTION_DIVERSITY, num_topic_descriptors=NUM_TOPIC_DESCRIPTORS)
clustered_tweets = self.get_tweets_per_cluster(tweets, tweet_contents, model_fit)
tweets_df = preprocess_tweets(tweets)
tweet_contents = tweets_df["message"].to_list()
model_fit = self.model.fit(
tweet_contents,
diversity=DESCRIPTION_DIVERSITY,
num_topic_descriptors=NUM_TOPIC_DESCRIPTORS
)
clustered_tweets = self.get_tweets_per_cluster(tweets, tweets_df, model_fit)
return model_fit, clustered_tweets
def get_tweets_per_cluster(self, tweets: List[Tweet], tweet_contents: Dict[str, str], model_fit: TopicModelFit) -> Dict[int, List[Tuple[Tweet, float]]]:
def get_tweets_per_cluster(self, tweets: List[Tweet], tweets_df: pd.DataFrame, model_fit: TopicModelFit) -> Dict[int, List[Tuple[Tweet, float]]]:
"""
Returns tweets and tweet relevances as tuple, clustered by topic.
"""
clustered_tweets = {topic_idx: list() for topic_idx in model_fit.topic_descriptions.keys()}
tweet_id_map = {tweet.id: tweet for tweet in tweets}
for document_idx, (tweet_id, topic_idx) in enumerate(zip(tweet_contents.keys(), model_fit.assignments)):
for document_idx, (tweet_id, topic_idx) in enumerate(zip(tweets_df["id"], model_fit.assignments)):
tweet = tweet_id_map[tweet_id]
relevance = model_fit.document_relevance[topic_idx][document_idx]
clustered_tweets[topic_idx].append((tweet, relevance))
......
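The regrouping step in get_tweets_per_cluster can be traced with stand-in data; the tweets, assignments, and relevance scores below are fabricated, and FakeTweet only mimics the id attribute of the real pymemri Tweet:

from collections import namedtuple

import numpy as np
import pandas as pd

# Stand-ins: the plugin works with pymemri Tweet items and a real TopicModelFit.
FakeTweet = namedtuple("FakeTweet", ["id"])
tweets = [FakeTweet("t1"), FakeTweet("t2"), FakeTweet("t3")]

# tweets_df is what preprocess_tweets now returns: one row per tweet that survived filtering.
tweets_df = pd.DataFrame({"id": ["t1", "t2", "t3"]})

# assignments[i] is the topic of row i; document_relevance[topic] holds one score per document.
assignments = [0, 0, 1]
document_relevance = {0: np.array([0.9, 0.8, 0.1]), 1: np.array([0.2, 0.3, 0.7])}

tweet_id_map = {tweet.id: tweet for tweet in tweets}
clustered_tweets = {topic_idx: [] for topic_idx in document_relevance}
for document_idx, (tweet_id, topic_idx) in enumerate(zip(tweets_df["id"], assignments)):
    relevance = document_relevance[topic_idx][document_idx]
    clustered_tweets[topic_idx].append((tweet_id_map[tweet_id], relevance))

print(clustered_tweets[0])  # [(FakeTweet(id='t1'), 0.9), (FakeTweet(id='t2'), 0.8)]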
@@ -2,18 +2,20 @@ import re
from typing import Dict, List
import nltk
import pandas as pd
from nltk.corpus import stopwords
from pymemri.data.schema import Tweet
from sklearn.feature_extraction import text
def preprocess_tweets(tweets: List[Tweet]) -> Dict[str, str]:
tweets_processed = {tweet.id: preprocess_tweet(tweet) for tweet in tweets}
tweets_processed = filter_num_words(tweets_processed)
return tweets_processed
def preprocess_tweets(tweets: List[Tweet], min_num_words: int = 5) -> pd.DataFrame:
_preprocessed = [{"id": tweet.id, "message": preprocess_tweet(tweet)} for tweet in tweets]
tweets_df = pd.DataFrame.from_records(_preprocessed)
tweets_df = filter_num_words(tweets_df, "message", min_num_words)
return tweets_df
def filter_num_words(x: Dict[str, str], min_num_words: int=10):
return {k: v for k, v in x.items() if len(v.split(" ")) > 10}
def filter_num_words(df: pd.DataFrame, column: str, min_num_words: int = 5):
return df[df[column].apply(lambda x: len(x.split(" ")) >= min_num_words)]
def preprocess_tweet(tweet: Tweet):
tweet_content = tweet.message
@@ -21,7 +23,3 @@ def preprocess_tweet(tweet):
def _remove_urls(s: str):
return re.sub(r'http\S+', '', s)
def get_stopwords():
nltk.download("stopwords")
return text.ENGLISH_STOP_WORDS.union(stopwords.words("dutch"))
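A quick sketch of the DataFrame-based filtering introduced above; the rows are made up, and filter_num_words is copied verbatim so the snippet runs standalone:

import pandas as pd

def filter_num_words(df: pd.DataFrame, column: str, min_num_words: int = 5):
    # Keep rows whose text contains at least min_num_words whitespace-separated tokens.
    return df[df[column].apply(lambda x: len(x.split(" ")) >= min_num_words)]

tweets_df = pd.DataFrame.from_records([
    {"id": "t1", "message": "te kort"},                              # 2 words, dropped
    {"id": "t2", "message": "deze tweet heeft wel genoeg woorden"},  # 6 words, kept
])
print(filter_num_words(tweets_df, "message"))  # only the "t2" row remains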
@@ -10,6 +10,7 @@ class ClusterEntry(Item):
class Cluster(Item):
isUnassigned: Optional[bool] = None
supervisedLabel: Optional[str] = None
clusterDescription: Optional[str] = None
isRead: Optional[bool] = None
readLater: Optional[bool] = None
......
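To illustrate the new clusterDescription property, a small sketch assuming pymemri Item subclasses accept their properties as constructor keyword arguments; the descriptor strings are hypothetical:

from .schema import Cluster

# Hypothetical descriptors produced by the topic model for one cluster.
descriptors = ["verkiezingen", "kabinet", "stemmen"]

# The new property lets a Cluster carry a human-readable summary of its topic.
cluster = Cluster(clusterDescription=", ".join(descriptors), isUnassigned=False)
print(cluster.clusterDescription)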