# semanticdala/src/modelling/topic_model.py
import re
from collections import Counter
from typing import Dict, List, Tuple

import numpy as np
import plotly
from bertopic import BERTopic
from hdbscan import HDBSCAN
from sklearn.feature_extraction.text import CountVectorizer
from umap import UMAP

from src.modelling.embed import DalaEmbedder
from src.utils.data_utils import tokeniser
from src.utils.plotting import custom_topic_barchart, custom_umap_plot

class TopicModeller:
    """
    Wrapper for topic modelling with BERTopic.
    """

    def __init__(self):
        # Custom vectoriser (with stopword filtering) and BERTopic model;
        # both are built in fit()
        self.vectoriser_model = None
        self.model = None

    def _extract_dalat5_stopwords(self, texts: List[str], top_k: int = 75) -> List[str]:
        """
        Identify the most frequent tokens in the corpus, using DalaT5's
        tokeniser, and return them as proxy stopwords.
        """
token_counter = Counter()
for text in texts:
token_ids = tokeniser.encode(text, add_special_tokens = False)
token_counter.update(token_ids)
most_common = token_counter.most_common(top_k)
stop_tokens = [tokeniser.decode([tok_id]).strip() for tok_id, _ in most_common]
return stop_tokens

    def _preprocess_texts(self, texts: List[str]) -> List[str]:
        """
        Lowercase texts and collapse runs of digits and whitespace into
        single spaces (punctuation is kept for the vectoriser to filter).
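
        Illustrative example:
            >>> TopicModeller()._preprocess_texts(["Hello, 123  Worlds!"])
            ['hello, worlds!']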
"""
        return [
            re.sub(r"[\d\s]+", " ", t.lower()).strip()
            for t in texts
        ]

    def fit(
        self,
        texts: List[str],
        embeddings: List[List[float]]
    ) -> Tuple[List[str], plotly.graph_objs.Figure, Dict[int, str], plotly.graph_objs.Figure]:
        """
        Fit BERTopic on preprocessed texts and precomputed embeddings.

        Returns the per-document topic labels, a topic bar chart, the
        topic ID-to-label mapping and a UMAP scatter plot.
        """
        clean_texts = self._preprocess_texts(texts)

        # BERTopic requires embeddings as a NumPy array rather than a list of lists
        embeddings = np.asarray(embeddings)

        # Compute a safe number of neighbours and a minimum cluster size
        n_samples = len(embeddings)
        min_cluster_size = max(2, n_samples // 2)
        safe_n_neighbours = min(15, max(2, n_samples - 1))
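        # Note: UMAP's n_neighbors should stay below the number of samples
        # (umap-learn otherwise truncates it with a warning), hence the clamp
        # above; a min_cluster_size of roughly half the corpus limits HDBSCAN
        # to a few coarse clusters, which appears intended for small corpora.
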
# Create a UMAP model
umap_model = UMAP(n_neighbors = safe_n_neighbours, min_dist = 0.1, metric = "cosine", random_state = 42)
# Leverage DalaT5's tokeniser for stopword acquisition
stopwords = self._extract_dalat5_stopwords(clean_texts, top_k = 75)
# Define vectoriser and model
self.vectoriser_model = CountVectorizer(
stop_words = stopwords,
token_pattern = r"\b[a-zA-Z]+(?:-[a-zA-Z]+)?\b"
)
self.model = BERTopic(
language = "multilingual",
vectorizer_model = self.vectoriser_model,
embedding_model = DalaEmbedder().get_model(),
umap_model = umap_model,
hdbscan_model = HDBSCAN(min_cluster_size = min_cluster_size, min_samples = 1, cluster_selection_epsilon = 0.1)
)
topics, _ = self.model.fit_transform(clean_texts, embeddings)
# Generate labels
topic_info = self.model.get_topic_info()
        topic_labels = {}

        for topic_id in topic_info.Topic.values:
            # -1 is BERTopic's outlier ("noise") topic
            if topic_id == -1:
                topic_labels[topic_id] = "-"
                continue

            topic_words = self.model.get_topic(topic_id)

            if not isinstance(topic_words, list) or len(topic_words) == 0:
                print(f"[WARN] Invalid topic for topic_id = {topic_id} - falling back to a numeric label")
                topic_labels[topic_id] = str(topic_id)
                continue

            # Build a label from the top four topic words, e.g. "3_word1_word2_word3_word4"
            words = []

            for pair in topic_words[:4]:
                if isinstance(pair, (list, tuple)) and len(pair) >= 1:
                    words.append(pair[0])

            if not words:
                print(f"[WARN] No valid words found for topic_id = {topic_id} - falling back to a numeric label")
                topic_labels[topic_id] = str(topic_id)
                continue

            label = "_".join(words)
            topic_labels[topic_id] = f"{topic_id}_{label}"

fig = custom_topic_barchart(self.model, topic_labels)
umap_fig = custom_umap_plot(embeddings, topics, topic_labels)
        # .get() guards against any topic ID that somehow lacks a label
        labeled_topics = [topic_labels.get(t, str(t)) for t in topics]

        return labeled_topics, fig, topic_labels, umap_fig
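

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only). It assumes DalaEmbedder's
    # get_model() returns a SentenceTransformer-style object exposing an
    # .encode() method; adapt this to the actual embedding API if it differs.
    sample_texts = [
        "Example document one.",
        "Example document two.",
        "A third, slightly different document.",
        "Yet another short example text.",
        "One more document about something else.",
        "A final example to give the clusterer some material.",
    ]

    embedder = DalaEmbedder().get_model()
    sample_embeddings = embedder.encode(sample_texts)

    modeller = TopicModeller()
    labels, barchart, label_map, umap_plot = modeller.fit(sample_texts, sample_embeddings)
    print(label_map)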