semanticdala / src /modelling /topic_model.py
crossroderick's picture
Added all files
0eb636f
raw
history blame
3.03 kB
import re
from collections import Counter
from typing import Dict, List, Tuple

import numpy as np
import plotly
from bertopic import BERTopic
from sklearn.feature_extraction.text import CountVectorizer

from src.modelling.embed import DalaEmbedder
from src.utils.data_utils import tokeniser
from src.utils.plotting import custom_topic_barchart, custom_umap_plot
class TopicModeller:
    """
    Wrapper for topic modelling with BERTopic.

    Call ``fit`` with raw texts and their precomputed embeddings; afterwards the
    fitted vectoriser and BERTopic model are available as ``vectoriser_model``
    and ``model``.
    """

    # Runs of digits and/or whitespace, each replaced by a single space during
    # preprocessing. Matching them together turns e.g. "a 12 b" into "a b"
    # instead of leaving several consecutive spaces behind.
    _DIGITS_OR_SPACE_RE = re.compile(r"[\d\s]+")

    def __init__(self):
        # Both are built in ``fit`` because the stopword list depends on the
        # corpus being modelled.
        self.vectoriser_model = None
        self.model = None

    @property
    def vectorizer_model(self):
        """
        Alias of ``vectoriser_model`` (American spelling).

        ``__init__`` used to set ``vectorizer_model`` while ``fit`` set
        ``vectoriser_model``, so the former was always ``None``; the alias keeps
        both spellings pointing at the same object.
        """
        return self.vectoriser_model

    @vectorizer_model.setter
    def vectorizer_model(self, value):
        self.vectoriser_model = value

    def _extract_dalat5_stopwords(self, texts: List[str], top_k: int = 75) -> List[str]:
        """
        Identify frequent tokens using DalaT5's tokeniser as proxy stopwords.

        Args:
            texts: Documents to count tokens over.
            top_k: Number of most frequent token ids to consider.

        Returns:
            Unique, non-empty decoded strings of the ``top_k`` most frequent
            token ids, in descending frequency order. The result can be shorter
            than ``top_k`` because distinct ids may decode to the same string or
            to whitespace only.
        """
        token_counter = Counter()
        for text in texts:
            token_counter.update(tokeniser.encode(text, add_special_tokens=False))

        decoded = (
            tokeniser.decode([tok_id]).strip()
            for tok_id, _ in token_counter.most_common(top_k)
        )
        # dict.fromkeys de-duplicates while preserving frequency order; empty
        # strings are dropped because they are useless as stopwords.
        return list(dict.fromkeys(word for word in decoded if word))

    def _preprocess_texts(self, texts: List[str]) -> List[str]:
        """
        Lowercase texts, replace digits with spaces and collapse whitespace.

        Punctuation and other symbols are intentionally kept: the vectoriser's
        token pattern needs hyphens for hyphenated words, and it ignores other
        non-letter characters anyway.
        """
        return [self._DIGITS_OR_SPACE_RE.sub(" ", t.lower()).strip() for t in texts]

    def _build_topic_labels(self, n_words: int = 4) -> Dict[int, str]:
        """
        Map every topic id of the fitted model to a readable label.

        Regular topics are labelled ``"<id>_<word1>_..._<wordN>"`` from their
        ``n_words`` top words; the outlier topic ``-1`` is labelled ``'-'``.
        """
        topic_labels: Dict[int, str] = {}
        for topic_id in self.model.get_topic_info().Topic.values:
            # Cast numpy integers to plain ints so the mapping matches its
            # annotation and can be JSON-serialised by callers.
            topic_id = int(topic_id)
            if topic_id == -1:
                topic_labels[topic_id] = '-'
                continue
            words = [word for word, _ in self.model.get_topic(topic_id)[:n_words]]
            label = "_".join(words)
            topic_labels[topic_id] = f"{topic_id}_{label}"
        return topic_labels

    def fit(
        self,
        texts: List[str],
        embeddings: List[List[float]]
    ) -> Tuple[List[str], plotly.graph_objs.Figure, Dict[int, str], plotly.graph_objs.Figure]:
        """
        Fit BERTopic on preprocessed texts using precomputed embeddings.

        Args:
            texts: Raw documents; they are preprocessed with ``_preprocess_texts``.
            embeddings: One embedding per document, as a nested list or a 2-D
                array. Converted with ``np.asarray`` because BERTopic only
                accepts ``numpy.ndarray`` / scipy sparse embeddings.

        Returns:
            ``(labeled_topics, barchart_fig, topic_labels, umap_fig)``: the topic
            label of each document, the topic bar chart, the mapping from topic
            id to label (``'-'`` for the outlier topic ``-1``), and the UMAP plot.
        """
        clean_texts = self._preprocess_texts(texts)
        embeddings = np.asarray(embeddings)

        # Leverage DalaT5's tokeniser for stopword acquisition
        stopwords = self._extract_dalat5_stopwords(clean_texts, top_k=75)

        # Define vectoriser and model
        self.vectoriser_model = CountVectorizer(
            stop_words=stopwords,
            token_pattern=r"\b[a-zA-Z]+(?:-[a-zA-Z]+)?\b",
        )
        self.model = BERTopic(
            language="multilingual",
            vectorizer_model=self.vectoriser_model,
            embedding_model=DalaEmbedder().get_model(),
        )
        topics, _ = self.model.fit_transform(clean_texts, embeddings)

        topic_labels = self._build_topic_labels()
        fig = custom_topic_barchart(self.model, topic_labels)
        umap_fig = custom_umap_plot(embeddings, topics, topic_labels)
        labeled_topics = [topic_labels[t] for t in topics]
        return labeled_topics, fig, topic_labels, umap_fig