import re
from collections import Counter
from typing import Dict, List, Optional, Tuple

import numpy as np
import plotly
from umap import UMAP
from hdbscan import HDBSCAN
from bertopic import BERTopic
from sklearn.feature_extraction.text import CountVectorizer

from src.utils.data_utils import tokeniser
from src.modelling.embed import DalaEmbedder
from src.utils.plotting import custom_topic_barchart, custom_umap_plot


class TopicModeller:
    """
    Wrapper for topic modelling with BERTopic.

    Call `fit` with raw texts and their precomputed embeddings. The fitted
    BERTopic model and its CountVectorizer are stored on the instance afterwards.
    """

    # Word tokens for the c-TF-IDF vectoriser: a run of letters, optionally
    # joined to a second run by one hyphen. `[^\W\d_]` means "any Unicode
    # letter". For ASCII input it matches exactly what `[a-zA-Z]` did, but it
    # no longer breaks words containing non-ASCII letters into fragments.
    _TOKEN_PATTERN = r"\b[^\W\d_]+(?:-[^\W\d_]+)?\b"

    # Any run of digits and/or whitespace; collapsed to a single space.
    _DIGITS_OR_SPACE = re.compile(r"[\d\s]+")

    def __init__(self):
        # Both attributes are populated by `fit` and refer to the same
        # CountVectorizer. `vectoriser_model` is the name `fit` has always
        # assigned; `vectorizer_model` is kept as an alias so code reading
        # either spelling sees the fitted vectoriser.
        self.vectorizer_model = None
        self.vectoriser_model = None
        # The fitted BERTopic instance, set by `fit`.
        self.model = None

    def _extract_dalat5_stopwords(self, texts: List[str], top_k: int = 75) -> List[str]:
        """
        Identify frequent tokens using DalaT5's tokeniser as proxy stopwords.

        Every text is encoded without special tokens. Token-ID frequencies are
        counted across all texts, and the `top_k` most frequent IDs are
        decoded back to strings.

        Args:
            texts: Documents to scan.
            top_k: Number of most frequent token IDs to consider.

        Returns:
            The decoded, whitespace-stripped tokens in descending frequency
            order, with empty strings and duplicates removed (several IDs may
            decode to the same text). The list can therefore be shorter than
            `top_k`.
        """
        token_counter = Counter()
        for text in texts:
            token_counter.update(tokeniser.encode(text, add_special_tokens=False))

        decoded = (
            tokeniser.decode([tok_id]).strip()
            for tok_id, _ in token_counter.most_common(top_k)
        )
        # dict.fromkeys de-duplicates while keeping frequency order. A lone
        # subword piece may decode to an empty string, which is useless as a
        # stop word, so those are dropped.
        return list(dict.fromkeys(token for token in decoded if token))

    def _preprocess_texts(self, texts: List[str]) -> List[str]:
        """
        Lowercase texts and replace digits with whitespace.

        Each run of digits and/or whitespace becomes a single space, and the
        result is stripped at both ends. Punctuation and other symbols are
        kept.
        """
        return [self._DIGITS_OR_SPACE.sub(" ", text.lower()).strip() for text in texts]

    def _build_topic_labels(self, top_n_words: int = 4) -> Dict[int, str]:
        """
        Build a readable label for every topic of the fitted `self.model`.

        A label has the form `"<id>_<word1>_<word2>_..."` and uses the topic's
        first `top_n_words` representation words. The outlier topic `-1` is
        labelled `'-'`. Topics whose words cannot be read fall back to their
        bare ID, so every topic ID reported by the model gets a label.
        """
        topic_labels: Dict[int, str] = {}
        for raw_id in self.model.get_topic_info().Topic.values:
            # Plain int keys, matching the Dict[int, str] contract.
            topic_id = int(raw_id)
            if topic_id == -1:
                topic_labels[topic_id] = '-'
                continue

            # `get_topic` is not guaranteed to return a non-empty list, so
            # validate the result before using it.
            topic_words = self.model.get_topic(topic_id)
            if not isinstance(topic_words, list) or len(topic_words) == 0:
                print(f"[WARN] Skipping label generation for topic_id={topic_id} - invalid topic; using topic id as label")
                topic_labels[topic_id] = str(topic_id)
                continue

            words = [
                pair[0]
                for pair in topic_words[:top_n_words]
                if isinstance(pair, (list, tuple)) and len(pair) >= 1
            ]
            if not words:
                print(f"[WARN] No valid words found for topic_id = {topic_id}; using topic id as label")
                topic_labels[topic_id] = str(topic_id)
                continue

            label = "_".join(words)
            topic_labels[topic_id] = f"{topic_id}_{label}"
        return topic_labels

    def fit(
        self,
        texts: List[str],
        embeddings: List[List[float]],
        *,
        min_cluster_size: Optional[int] = None,
        stopword_top_k: int = 75,
    ) -> Tuple[List[str], plotly.graph_objs.Figure, Dict[int, str], plotly.graph_objs.Figure]:
        """
        Fit BERTopic on preprocessed texts and their precomputed embeddings.

        Args:
            texts: Raw documents. They are lowercased and cleaned with
                `_preprocess_texts` before fitting.
            embeddings: One embedding vector per document, in the same order as
                `texts`. Converted with `np.asarray`, because BERTopic only
                accepts array-like (numpy / sparse) embeddings.
            min_cluster_size: HDBSCAN minimum cluster size. Defaults to
                `max(2, len(embeddings) // 2)`, with which at most two
                non-outlier topics can form.
            stopword_top_k: How many of the most frequent DalaT5 tokens to use
                as stop words for the vectoriser.

        Returns:
            A tuple of:
            - the label of each document's topic, in `texts` order;
            - the topic bar-chart figure;
            - the topic-ID to label mapping (`-1`, the outlier topic, maps to `'-'`);
            - the UMAP scatter figure.

        Raises:
            ValueError: If `texts` is empty, or if its length differs from that
                of `embeddings`.
        """
        if len(texts) == 0:
            raise ValueError("Cannot fit a topic model on an empty list of texts.")
        if len(texts) != len(embeddings):
            raise ValueError(
                f"Got {len(texts)} texts but {len(embeddings)} embeddings; "
                "expected exactly one embedding per text."
            )

        clean_texts = self._preprocess_texts(texts)
        embedding_matrix = np.asarray(embeddings)
        n_samples = len(embedding_matrix)

        # Compute a safe number of neighbours and clusters so that small
        # corpora do not ask UMAP/HDBSCAN for more points than exist.
        if min_cluster_size is None:
            min_cluster_size = max(2, n_samples // 2)
        safe_n_neighbours = min(15, max(2, n_samples - 1))

        umap_model = UMAP(
            n_neighbors=safe_n_neighbours,
            min_dist=0.1,
            metric="cosine",
            random_state=42,
        )

        # Leverage DalaT5's tokeniser for stopword acquisition.
        stopwords = self._extract_dalat5_stopwords(clean_texts, top_k=stopword_top_k)

        self.vectoriser_model = CountVectorizer(
            stop_words=stopwords,
            token_pattern=self._TOKEN_PATTERN,
        )
        self.vectorizer_model = self.vectoriser_model

        self.model = BERTopic(
            language="multilingual",
            vectorizer_model=self.vectoriser_model,
            embedding_model=DalaEmbedder().get_model(),
            umap_model=umap_model,
            hdbscan_model=HDBSCAN(
                min_cluster_size=min_cluster_size,
                min_samples=1,
                cluster_selection_epsilon=0.1,
            ),
        )
        topics, _ = self.model.fit_transform(clean_texts, embedding_matrix)

        topic_labels = self._build_topic_labels()

        fig = custom_topic_barchart(self.model, topic_labels)
        umap_fig = custom_umap_plot(embedding_matrix, topics, topic_labels)

        # Every topic in `topic_info` has a label. Fall back to the bare ID
        # defensively instead of raising KeyError.
        labeled_topics = [topic_labels.get(t, str(t)) for t in topics]
        return labeled_topics, fig, topic_labels, umap_fig