import re
from collections import Counter
from typing import Dict, List, Tuple

import plotly
from bertopic import BERTopic
from hdbscan import HDBSCAN
from sklearn.feature_extraction.text import CountVectorizer
from umap import UMAP

from src.modelling.embed import DalaEmbedder
from src.utils.data_utils import tokeniser
from src.utils.plotting import custom_topic_barchart, custom_umap_plot

class TopicModeller:
    """
    Wrapper for topic modelling with BERTopic.
    """
    def __init__(self):
        # Both are set lazily in fit(): a vectoriser built with corpus-derived stopwords and the fitted BERTopic model
        self.vectoriser_model = None
        self.model = None
def _extract_dalat5_stopwords(self, texts: List[str], top_k: int = 75) -> List[str]:
"""
Identify frequent tokens using DalaT5's tokeniser as proxy stopwords.
"""
token_counter = Counter()
for text in texts:
token_ids = tokeniser.encode(text, add_special_tokens = False)
token_counter.update(token_ids)
most_common = token_counter.most_common(top_k)
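        # Decode each frequent token id back to its surface form; with a subword tokeniser these may be word fragments rather than whole words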
stop_tokens = [tokeniser.decode([tok_id]).strip() for tok_id, _ in most_common]
return stop_tokens
    def _preprocess_texts(self, texts: List[str]) -> List[str]:
        """
        Lowercase texts and collapse runs of digits and whitespace into single spaces.
        """
        return [
            re.sub(r"\d+|\s+", " ", t.lower()).strip()
            for t in texts
        ]
def fit(
self,
texts: List[str],
embeddings: List[List[float]]
) -> Tuple[List[str], plotly.graph_objs.Figure, Dict[int, str], plotly.graph_objs.Figure]:
"""
Fit BERTopic on preprocessed texts and given embeddings.
Returns topics and an interactive plot.
"""
clean_texts = self._preprocess_texts(texts)
# Compute a safe number of neighbours and clusters
n_samples = len(embeddings)
min_cluster_size = max(2, len(embeddings) // 2)
safe_n_neighbours = min(15, max(2, n_samples - 1))
# Create a UMAP model
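        # (cosine suits sentence-level embeddings; random_state keeps the projection reproducible)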
umap_model = UMAP(n_neighbors = safe_n_neighbours, min_dist = 0.1, metric = "cosine", random_state = 42)
# Leverage DalaT5's tokeniser for stopword acquisition
stopwords = self._extract_dalat5_stopwords(clean_texts, top_k = 75)
# Define vectoriser and model
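        # The token pattern keeps purely alphabetic words, optionally with a single hyphenated part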
self.vectoriser_model = CountVectorizer(
stop_words = stopwords,
token_pattern = r"\b[a-zA-Z]+(?:-[a-zA-Z]+)?\b"
)
self.model = BERTopic(
language = "multilingual",
vectorizer_model = self.vectoriser_model,
embedding_model = DalaEmbedder().get_model(),
umap_model = umap_model,
hdbscan_model = HDBSCAN(min_cluster_size = min_cluster_size, min_samples = 1, cluster_selection_epsilon = 0.1)
)
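        # Precomputed embeddings are passed in, so BERTopic skips re-embedding the documents during fitting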
topics, _ = self.model.fit_transform(clean_texts, embeddings)
# Generate labels
topic_info = self.model.get_topic_info()
topic_labels = {}
for topic_id in topic_info.Topic.values:
if topic_id == -1:
topic_labels[topic_id] = '-'
continue
topic_words = self.model.get_topic(topic_id)
if not isinstance(topic_words, list) or len(topic_words) == 0:
print(f"[WARN] Skipping label generation for topic_id={topic_id} - invalid topic")
continue
words = []
for pair in topic_words[:4]:
if isinstance(pair, (list, tuple)) and len(pair) >= 1:
words.append(pair[0])
if not words:
print(f"[WARN] No valid words found for topic_id = {topic_id}")
continue
label = "_".join(words)
topic_labels[topic_id] = f"{topic_id}_{label}"
fig = custom_topic_barchart(self.model, topic_labels)
umap_fig = custom_umap_plot(embeddings, topics, topic_labels)
        # Fall back to the raw topic id for any topic whose label generation was skipped above
        labeled_topics = [topic_labels.get(t, str(t)) for t in topics]
return labeled_topics, fig, topic_labels, umap_fig
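

# Minimal usage sketch (illustrative only). The embedding step below is an assumption:
# `embed_texts` is a hypothetical helper name, not a confirmed DalaEmbedder method - swap in
# the project's real embedding call when using this.
if __name__ == "__main__":
    sample_texts = [
        "example document one",
        "example document two",
        "example document three",
        "example document four",
    ]
    embedder = DalaEmbedder()
    embeddings = embedder.embed_texts(sample_texts)  # hypothetical method name
    modeller = TopicModeller()
    labeled_topics, barchart, topic_labels, umap_fig = modeller.fit(sample_texts, embeddings)
    print(topic_labels)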