# NOTE(review): the lines below replace non-Python scrape residue from the
# hosting page (Spaces status "Sleeping", file size 3,026 bytes, commit
# 0eb636f, line-number gutter) that would otherwise break the module.
import re
import plotly
from bertopic import BERTopic
from collections import Counter
from src.utils.data_utils import tokeniser
from src.modelling.embed import DalaEmbedder
from sklearn.feature_extraction.text import CountVectorizer
from src.utils.plotting import custom_topic_barchart, custom_umap_plot
from typing import Dict, List, Tuple
class TopicModeller:
    """
    Wrapper for topic modelling with BERTopic.

    Derives a corpus-specific stopword list from DalaT5's tokeniser,
    fits BERTopic on pre-computed embeddings, and returns labelled
    topics together with interactive Plotly figures.
    """
    def __init__(self):
        # Custom vectoriser with stopword filtering; built lazily in fit()
        # because the stopword list depends on the corpus being modelled.
        # BUG FIX: __init__ previously initialised 'vectorizer_model' (US
        # spelling) while fit() assigned 'vectoriser_model', so the attribute
        # set here was never the one used; unified on the British spelling
        # used elsewhere in this class (TopicModeller, tokeniser).
        self.vectoriser_model = None
        self.model = None
    def _extract_dalat5_stopwords(self, texts: List[str], top_k: int = 75) -> List[str]:
        """
        Identify frequent tokens using DalaT5's tokeniser as proxy stopwords.

        Args:
            texts: Corpus to count token frequencies over.
            top_k: Number of most-frequent token ids to treat as stopwords.

        Returns:
            Decoded, whitespace-stripped text of the top_k most frequent tokens.
        """
        token_counter = Counter()
        for text in texts:
            token_counter.update(tokeniser.encode(text, add_special_tokens=False))
        return [
            tokeniser.decode([token_id]).strip()
            for token_id, _ in token_counter.most_common(top_k)
        ]
    def _preprocess_texts(self, texts: List[str]) -> List[str]:
        """
        Lowercase each text, replace digit runs with a space, and collapse
        all whitespace to single spaces.
        """
        # [\d\s]+ treats adjacent digits/whitespace as one run, so cleaning
        # never leaves consecutive spaces (the previous \d+|\s+ pattern could
        # turn "a 1 b" into "a   b").
        return [re.sub(r"[\d\s]+", " ", t.lower()).strip() for t in texts]
    def fit(
        self,
        texts: List[str],
        embeddings: List[List[float]]
    ) -> "Tuple[List[str], plotly.graph_objs.Figure, Dict[int, str], plotly.graph_objs.Figure]":
        """
        Fit BERTopic on preprocessed texts and given embeddings.

        Args:
            texts: Raw documents to model.
            embeddings: Pre-computed embedding vectors, one per document.

        Returns:
            A tuple of (labelled topic per document, topic bar chart,
            topic-id -> label mapping, UMAP scatter figure).
        """
        clean_texts = self._preprocess_texts(texts)
        # Leverage DalaT5's tokeniser for stopword acquisition
        stopwords = self._extract_dalat5_stopwords(clean_texts, top_k=75)
        # Vectoriser keeps alphabetic (optionally hyphenated) tokens only
        self.vectoriser_model = CountVectorizer(
            stop_words=stopwords,
            token_pattern=r"\b[a-zA-Z]+(?:-[a-zA-Z]+)?\b"
        )
        self.model = BERTopic(
            language="multilingual",
            vectorizer_model=self.vectoriser_model,
            embedding_model=DalaEmbedder().get_model()
        )
        topics, _ = self.model.fit_transform(clean_texts, embeddings)
        # Build human-readable labels from each topic's top four words;
        # -1 is BERTopic's outlier topic and gets a placeholder label.
        topic_info = self.model.get_topic_info()
        topic_labels = {}
        for topic_id in topic_info.Topic.values:
            if topic_id == -1:
                topic_labels[topic_id] = '-'
                continue
            words = [word for word, _ in self.model.get_topic(topic_id)[:4]]
            label = "_".join(words)
            topic_labels[topic_id] = f"{topic_id}_{label}"
        fig = custom_topic_barchart(self.model, topic_labels)
        umap_fig = custom_umap_plot(embeddings, topics, topic_labels)
        labeled_topics = [topic_labels[t] for t in topics]
        return labeled_topics, fig, topic_labels, umap_fig
|