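"""
Topic modelling wrapper around BERTopic, using DalaT5-based embeddings,
a UMAP + HDBSCAN pipeline, and custom plotly visualisations.
"""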
import re
from collections import Counter
from typing import Dict, List, Tuple

import numpy as np
import plotly.graph_objs as go
from bertopic import BERTopic
from hdbscan import HDBSCAN
from sklearn.feature_extraction.text import CountVectorizer
from umap import UMAP

from src.modelling.embed import DalaEmbedder
from src.utils.data_utils import tokeniser
from src.utils.plotting import custom_topic_barchart, custom_umap_plot


class TopicModeller:
    """
    Wrapper for topic modelling with BERTopic.
    """

    def __init__(self):
        # Custom vectoriser with stopword filtering; both attributes are
        # populated in fit()
        self.vectoriser_model = None
        self.model = None

    def _extract_dalat5_stopwords(self, texts: List[str], top_k: int = 75) -> List[str]:
        """
        Identify frequent tokens using DalaT5's tokeniser as proxy stopwords.
        """
        token_counter = Counter()
        for text in texts:
            token_ids = tokeniser.encode(text, add_special_tokens = False)
            token_counter.update(token_ids)

        most_common = token_counter.most_common(top_k)
        stop_tokens = [tokeniser.decode([tok_id]).strip() for tok_id, _ in most_common]

        return stop_tokens
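
    # Illustration for _extract_dalat5_stopwords (hypothetical corpus): if
    # the tokeniser's most frequent token ids across the texts decode to,
    # say, "bir" and "men", the method returns ["bir", "men", ...], which
    # fit() then hands to CountVectorizer as stop_words.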

    def _preprocess_texts(self, texts: List[str]) -> List[str]:
        """
        Lowercase texts and replace digit runs and whitespace runs with
        single spaces. (Other symbols are left untouched.)
        """
        return [
            re.sub(r"\d+|\s+", " ", t.lower()).strip()
            for t in texts
        ]
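
    # Note on _preprocess_texts: each `\d+` or `\s+` match is replaced
    # independently, so "Born 1997 here" becomes "born   here" (three
    # spaces), not "born here"; downstream tokenisation is unaffected.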

    def fit(
        self,
        texts: List[str],
        embeddings: List[List[float]]
    ) -> Tuple[List[str], go.Figure, Dict[int, str], go.Figure]:
        """
        Fit BERTopic on preprocessed texts and given embeddings.

        Returns the per-document topic labels, an interactive topic
        barchart, the topic_id -> label mapping, and a UMAP scatter plot.
        """
        clean_texts = self._preprocess_texts(texts)

        # Compute a safe number of neighbours and a minimum cluster size
        n_samples = len(embeddings)
        min_cluster_size = max(2, n_samples // 2)
        safe_n_neighbours = min(15, max(2, n_samples - 1))

        # Create a UMAP model
        umap_model = UMAP(n_neighbors = safe_n_neighbours, min_dist = 0.1, metric = "cosine", random_state = 42)

        # Leverage DalaT5's tokeniser for stopword acquisition
        stopwords = self._extract_dalat5_stopwords(clean_texts, top_k = 75)

        # Define vectoriser and model (the token pattern keeps Latin-script,
        # optionally hyphenated, words only)
        self.vectoriser_model = CountVectorizer(
            stop_words = stopwords,
            token_pattern = r"\b[a-zA-Z]+(?:-[a-zA-Z]+)?\b"
        )
        self.model = BERTopic(
            language = "multilingual",
            vectorizer_model = self.vectoriser_model,
            embedding_model = DalaEmbedder().get_model(),
            umap_model = umap_model,
            hdbscan_model = HDBSCAN(min_cluster_size = min_cluster_size, min_samples = 1, cluster_selection_epsilon = 0.1)
        )

        # BERTopic expects precomputed embeddings as an ndarray
        topics, _ = self.model.fit_transform(clean_texts, np.asarray(embeddings))

        # Generate labels
        topic_info = self.model.get_topic_info()
        topic_labels = {}
        for topic_id in topic_info.Topic.values:
            if topic_id == -1:
                topic_labels[topic_id] = '-'
                continue

            topic_words = self.model.get_topic(topic_id)
            if not isinstance(topic_words, list) or len(topic_words) == 0:
                print(f"[WARN] Skipping label generation for topic_id={topic_id} - invalid topic")
                continue

            words = []
            for pair in topic_words[:4]:
                if isinstance(pair, (list, tuple)) and len(pair) >= 1:
                    words.append(pair[0])

            if not words:
                print(f"[WARN] No valid words found for topic_id={topic_id}")
                continue

            label = "_".join(words)
            topic_labels[topic_id] = f"{topic_id}_{label}"

        fig = custom_topic_barchart(self.model, topic_labels)
        umap_fig = custom_umap_plot(embeddings, topics, topic_labels)

        # Fall back to the raw topic id for any topic skipped during label
        # generation, so a missing key cannot raise here
        labeled_topics = [topic_labels.get(t, str(t)) for t in topics]

        return labeled_topics, fig, topic_labels, umap_fig
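

# ---------------------------------------------------------------------------
# Minimal usage sketch. The .encode() call is an assumption: this module only
# shows DalaEmbedder().get_model(), so we assume the returned model behaves
# like a sentence-transformers model; adapt the embedding step if the real
# DalaEmbedder API differs.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    sample_texts = [
        "first example document about steppe horses",
        "second example document about steppe horses",
        "third example document about winter weather",
        "fourth example document about winter weather",
    ]

    # Embed the documents (assumed sentence-transformers-style API)
    embedding_model = DalaEmbedder().get_model()
    sample_embeddings = embedding_model.encode(sample_texts)

    modeller = TopicModeller()
    labels, barchart, label_map, umap_fig = modeller.fit(sample_texts, sample_embeddings)
    print(label_map)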