Commit b2d98d3b authored by Aziz Berkay Yesilyurt

Merge branch 'dev' into 'uat'

reduce min tweets for cluster

See merge request !28
parents c9ff242c 5c56ea68
Pipeline #13668 passed with stages in 15 minutes and 39 seconds
Showing 43 additions and 149 deletions
@@ -18,14 +18,12 @@ install_requires =
     pymemri==0.0.45
     torch==1.13.1
     numpy>=1.20.0
-    umap-learn>=0.5.0
-    hdbscan>=0.8.29
-    sentence-transformers>=0.4.1
+    sentence-transformers>=0.4.1 # sentence_transformers
     tqdm>=4.41.1
     scikit-learn==1.1.3
-    wget==3.2
-    fasttext-langdetect==1.0.3
-    keybert
+    fasttext-wheel==0.9.2 # fasttext-langdetect depends on this
+    fasttext-langdetect==1.0.3 # ftlangdetect
+    wget==3.2 # ftlangdetect depends on this
     pytest

 [options.extras_require]
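The inline comments record the import name where it differs from the pinned PyPI name. A quick sanity-check sketch (module names as noted in the comments above):

```python
# import names differ from the PyPI package names pinned in setup.cfg
import sentence_transformers  # package: sentence-transformers
import ftlangdetect           # package: fasttext-langdetect (needs fasttext-wheel)
import wget                   # kept because ftlangdetect depends on it
```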
@@ -10,7 +10,7 @@ from pymemri.pod.client import PodClient
 from pymemri.pod.graphql_utils import GQLQuery
 from twitter_topic_model.model import TopicModelCluster
-from twitter_topic_model.plugin import NUM_TOPIC_DESCRIPTORS, TwitterTopicModelPlugin
+from twitter_topic_model.plugin import TwitterTopicModelPlugin
 from twitter_topic_model.postprocessing import (
     filter_min_accounts_per_cluster,
     get_unique_cluster_account_ids,
@@ -58,7 +58,9 @@ def test_fit_model_on_tweets(tweets: List[Tweet]) -> None:
     plugin = TwitterTopicModelPlugin(client=client)
     tweets = tweets[:250]
-    clusters, unassigned_tweets = plugin.fit_topic_model(tweets, prediction_threshold=0.6)
+    clusters, unassigned_tweets = plugin.fit_topic_model(
+        tweets, prediction_threshold=0.6
+    )
     # for cluster in clusters:
     #     print(cluster.coarse_label, cluster.fine_label, max(cluster.tweet_relevance))
@@ -69,10 +71,24 @@ def test_fit_model_on_tweets(tweets: List[Tweet]) -> None:
     #     print("-------")


+def test_fit_model_on_2_tweets(tweets: List[Tweet]) -> None:
+    client = PodClient()
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        plugin = TwitterTopicModelPlugin(client=client)
+    tweets = tweets[:2]
+    clusters, unassigned_tweets = plugin.fit_topic_model(
+        tweets, prediction_threshold=0.1
+    )
+    assert clusters
+
+
 def test_fit_model_on_pod(tweets: List[Tweet], topic_model_query: GQLQuery) -> None:
     # Client with Tweets
     client = PodClient()
     client.add_to_schema(Tweet, Cluster, ClusterEntry, TwitterTopicModel)
+    tweets = tweets[:2]
     client.bulk_action(create_items=tweets)

     # Load Plugin and fit model
@@ -53,7 +53,8 @@ class TweetTopicClassifier(nn.Module):
 def load_classifier_head(model_path, device="cpu"):
-    chkpt = torch.load(model_path, map_location=device)
+    chkpt = torch.load(model_path, map_location="cpu")
     hparams = chkpt["hyper_parameters"]
     state_dict = chkpt["state_dict"]
     # the training script appends `classifier.*` to all params; removed here.
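Loading with `map_location="cpu"` makes deserialization independent of whether CUDA is available on the host; the module can be moved to the target device afterwards. A minimal sketch of the pattern (names are placeholders, not the plugin's API):

```python
import torch

def load_checkpoint(model_path: str):
    # always deserialize on CPU so GPU-saved checkpoints load on CUDA-less machines
    chkpt = torch.load(model_path, map_location="cpu")
    return chkpt["hyper_parameters"], chkpt["state_dict"]
```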
-import numpy as np
-from hdbscan import HDBSCAN
-from umap import UMAP
-
-
-def compute_umap(embeddings: np.ndarray) -> np.ndarray:
-    umap = UMAP(
-        n_neighbors=15,
-        n_components=5,
-        min_dist=0.0,
-        metric='cosine',
-        low_memory=False,
-        random_state=42,
-    )
-    return umap.fit_transform(embeddings)
-
-
-def cluster_embeddings(embeddings: np.ndarray, min_cluster_size: int = 3, cluster_selection_epsilon: float = 0.0) -> np.ndarray:
-    hdbscan = HDBSCAN(
-        min_cluster_size=min_cluster_size,
-        metric='euclidean',
-        cluster_selection_method='eom',
-        prediction_data=False,
-        cluster_selection_epsilon=cluster_selection_epsilon,
-    )
-    return hdbscan.fit_predict(embeddings)
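For the record, the deleted module chained dimensionality reduction and clustering roughly like this against the pre-merge code (the random 768-dim embeddings are only illustrative, matching the mpnet encoder's output size):

```python
import numpy as np
from twitter_topic_model.clustering import cluster_embeddings, compute_umap

embeddings = np.random.rand(100, 768).astype(np.float32)  # illustrative tweet embeddings
reduced = compute_umap(embeddings)                        # cosine UMAP down to 5 components
labels = cluster_embeddings(reduced, min_cluster_size=3)  # HDBSCAN; -1 marks unassigned noise
```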
@@ -3,17 +3,12 @@ from typing import Dict, List, Optional, Set, Tuple, Union
 import numpy as np
 import torch
-from keybert import KeyBERT
-from keybert.backend._sentencetransformers import SentenceTransformerBackend
 from pymemri.data.schema import Tweet
 from sentence_transformers import SentenceTransformer
-from sklearn.feature_extraction.text import CountVectorizer
-from sklearn.metrics.pairwise import cosine_similarity
 from twitter_topic_model.coarse_labels import COARSE_CATEGORY_MAPPING

 from .classifier_fine_head import load_classifier_head
-from .clustering import cluster_embeddings, compute_umap
 from .preprocessing import PreprocessedTweet

 EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
@@ -46,7 +41,6 @@ class TopicModel:
             "assets/i9hajgp7.ckpt",
             self.device,
         )
-        self.keybert = KeyBERT(model=SentenceTransformerBackend(self.embedding_model))

     def _classify_supervised(self, embeddings: np.ndarray, prediction_threshold: float = 0.5) -> List[List[str]]:
         return self.classifier_fine_head.predict(embeddings, pred_threshold=prediction_threshold)
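The `predict` implementation is not part of this diff; one plausible reading of `pred_threshold`, assuming a sigmoid multi-label head (all names in this sketch are illustrative, not the plugin's actual code):

```python
import numpy as np

def predict(logits: np.ndarray, label_names: list, pred_threshold: float = 0.5):
    # hypothetical: keep every label whose sigmoid probability clears the threshold
    probs = 1.0 / (1.0 + np.exp(-logits))
    return [
        [label_names[i] for i, p in enumerate(row) if p >= pred_threshold]
        for row in probs
    ]
```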
@@ -110,59 +104,6 @@ class TopicModel:
         return [tweet for tweet in tweets if tweet.id not in assigned_ids]

-    def cluster_unsupervised(
-        self,
-        tweets_subset: List[PreprocessedTweet],
-        all_embeddings: np.ndarray,
-        all_umap_embeddings: np.ndarray,
-        supervised_label: Optional[str] = None,
-        diversity: float = 0.01,
-        ngram_range=(1, 1),
-        num_topic_descriptors=3,
-        min_cluster_size=4,
-        cluster_selection_epsilon=0.0,
-    ):
-        # NOTE: embeddings and umap embeddings for all tweets are passed; `tweets_subset` may be a subset of those.
-        # Always use `self.get_embeddings_for_tweets`.
-        embeddings_subset = self.get_embeddings_for_tweets(
-            tweets_subset, all_umap_embeddings
-        )
-        assignments = cluster_embeddings(
-            embeddings_subset,
-            min_cluster_size=min_cluster_size,
-            cluster_selection_epsilon=cluster_selection_epsilon,
-        )
-        tweets_per_cluster = self.get_tweets_per_cluster(tweets_subset, assignments)
-        clusters = []
-        for cluster_idx, cluster_tweets in tweets_per_cluster.items():
-            if cluster_idx == -1:
-                # hdbscan idx == -1 means there is no assignment.
-                continue
-            embeddings_cluster = self.get_embeddings_for_tweets(
-                cluster_tweets, all_embeddings
-            )
-            unsupervised_labels, topic_embedding = self.get_topic_description(
-                cluster_tweets,
-                embeddings_cluster,
-                num_descriptions=num_topic_descriptors,
-                ngram_range=ngram_range,
-                diversity=diversity,
-            )
-            relevances = self.get_document_relevance(
-                topic_embedding, embeddings_cluster
-            )
-            cluster = TopicModelCluster(
-                supervised_label=supervised_label,
-                unsupervised_labels=unsupervised_labels,
-                tweets=cluster_tweets,
-                tweet_relevance=relevances,
-            )
-            clusters.append(cluster)
-        return clusters
-
     @staticmethod
     def get_embeddings_for_tweets(
         tweets: List[PreprocessedTweet], all_embeddings: np.ndarray
@@ -184,49 +125,6 @@ class TopicModel:
             tweets_clustered[assignment].append(tweet)
         return tweets_clustered

-    def get_document_relevance(
-        self, topic_embedding: np.ndarray, document_embeddings: np.ndarray
-    ) -> np.ndarray:
-        """Returns the relevance (cosine similarity) of each document w.r.t. the topic.
-
-        Args:
-            topic_embedding (np.ndarray): a topic embedding, shape (D,)
-            document_embeddings (np.ndarray): array of document embeddings, shape (N, D)
-
-        Returns:
-            np.ndarray: document relevance, shape (N,)
-        """
-        topic_embedding = np.expand_dims(topic_embedding, 0)
-        similarities = cosine_similarity(topic_embedding, document_embeddings)
-        return similarities.squeeze(0)
-
-    def get_topic_description(
-        self,
-        tweets: List[PreprocessedTweet],
-        tweet_embeddings: List[np.ndarray],
-        diversity: float = 0.3,
-        num_descriptions: Optional[int] = None,
-        ngram_range=(1, 1),
-    ) -> Tuple[List[str], np.ndarray]:
-        doc = self._concat_tweets(tweets)
-        avg_embedding = np.mean(tweet_embeddings, axis=0)
-        keybert = KeyBERT(model=SentenceTransformerBackend(self.embedding_model))
-        res = keybert.extract_keywords(
-            doc,
-            top_n=num_descriptions,
-            keyphrase_ngram_range=ngram_range,
-            stop_words="english",
-            min_df=1,
-            doc_embeddings=[avg_embedding],
-            use_mmr=True,
-            diversity=diversity,
-        )
-        keywords = [r[0] for r in res]
-        topic_embedding = self.embed_topic_description(keywords)
-        return keywords, topic_embedding
-
-    def embed_topic_description(self, unsupervised_labels: List[str]) -> np.ndarray:
-        query = ", ".join(unsupervised_labels)
-        return self.embedding_model.encode(query)
@@ -20,11 +20,6 @@ from .preprocessing import PreprocessedTweet, preprocess_tweets
 from .schema import Cluster, ClusterEntry, TwitterTopicModel
 from .utils import get_tweets

-MIN_TOPIC_SIZE = 4
-NUM_TOPIC_DESCRIPTORS = 3
-DESCRIPTION_DIVERSITY = 0.3
-CLUSTER_MERGE_DISTANCE = 0.6

 class TwitterTopicModelPlugin(PluginBase):
     schema_classes = [Cluster, ClusterEntry, TwitterTopicModel, Tweet]
@@ -33,7 +28,7 @@ class TwitterTopicModelPlugin(PluginBase):
         self,
         pluginRun=None,
         client=None,
-        min_cluster_tweets=10,
+        min_cluster_tweets=2,
         run_on_start=False,
         **kwargs,
     ):
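This default change is the substance of the MR title: the guard further down now only requires two tweets in the pod. A hypothetical call for anyone who wants to keep the previous behavior (the constructor signature is taken from the hunk above):

```python
# hypothetical usage: restore the old threshold explicitly
plugin = TwitterTopicModelPlugin(client=client, min_cluster_tweets=10)
```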
@@ -178,7 +173,10 @@ class TwitterTopicModelPlugin(PluginBase):
             str: item id of created topic model
         """
         if not 0 < prediction_threshold < 1:
-            raise HTTPException(status_code=400, detail="`prediction_threshold` should be between 0 and 1")
+            raise HTTPException(
+                status_code=400,
+                detail="`prediction_threshold` should be between 0 and 1",
+            )

         client = self.get_client_from_kwargs(pod_client_kwargs)
         client.add_to_schema(*self.schema_classes)
@@ -186,17 +184,25 @@ class TwitterTopicModelPlugin(PluginBase):
         if len(tweets) < self.min_cluster_tweets:
             raise HTTPException(
                 status_code=400,
-                detail=f"Not enough tweets in pod ({len(tweets)}) to cluster"
+                detail=f"Not enough tweets in pod ({len(tweets)}) to cluster",
             )

-        clusters, unassigned_tweets = self.fit_topic_model(tweets, prediction_threshold=prediction_threshold)
+        clusters, unassigned_tweets = self.fit_topic_model(
+            tweets, prediction_threshold=prediction_threshold
+        )

         # Filters
         if exclude_read_tweets:
-            clusters, unassigned_tweets = filter_read_tweets(clusters, unassigned_tweets, client)
+            clusters, unassigned_tweets = filter_read_tweets(
+                clusters, unassigned_tweets, client
+            )
         if min_accounts_per_cluster > 1:
-            clusters = filter_min_accounts_per_cluster(clusters, min_accounts_per_cluster)
+            clusters = filter_min_accounts_per_cluster(
+                clusters, min_accounts_per_cluster
+            )

-        logger.debug(f"Creating {len(clusters)} clusters with {len(unassigned_tweets)} unassigned tweets")
+        logger.debug(
+            f"Creating {len(clusters)} clusters with {len(unassigned_tweets)} unassigned tweets"
+        )

         topic_model_id = self.sync_to_pod(
             clusters, unassigned_tweets, start_date, client
         )