import re
import plotly
from umap import UMAP
from hdbscan import HDBSCAN
from bertopic import BERTopic
from collections import Counter
from src.utils.data_utils import tokeniser
from src.modelling.embed import DalaEmbedder
from sklearn.feature_extraction.text import CountVectorizer
from src.utils.plotting import custom_topic_barchart, custom_umap_plot

from typing import Dict, List, Tuple


class TopicModeller:
    """
    Wrapper for topic modelling with BERTopic.
    """
    def __init__(self):
        # Vectoriser and topic model are constructed lazily in fit()
        self.vectoriser_model = None
        self.model = None


    def _extract_dalat5_stopwords(self, texts: List[str], top_k: int = 75) -> List[str]:
        """
        Identify frequent tokens using DalaT5's tokeniser as proxy stopwords.
        """
        token_counter = Counter()

        for text in texts:
            token_ids = tokeniser.encode(text, add_special_tokens = False)
            token_counter.update(token_ids)

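        # NB: decoding single token ids often yields subword fragments rather
        # than whole words, so this list is only a rough stopword proxy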
        most_common = token_counter.most_common(top_k)
        stop_tokens = [tokeniser.decode([tok_id]).strip() for tok_id, _ in most_common]

        return stop_tokens


    def _preprocess_texts(self, texts: List[str]) -> List[str]:
        """
        Lowercase texts and collapse runs of digits and whitespace into a
        single space.
        """
        return [
            re.sub(r"[\d\s]+", " ", t.lower()).strip()
            for t in texts
        ]
    

    def fit(
        self, 
        texts: List[str], 
        embeddings: List[List[float]]
    ) -> Tuple[List[str], plotly.graph_objs.Figure, Dict[int, str], plotly.graph_objs.Figure]:
        """
        Fit BERTopic on preprocessed texts and given embeddings.
        Returns topics and an interactive plot.
        """
        clean_texts = self._preprocess_texts(texts)

        # Derive parameters that stay valid for small corpora: UMAP requires
        # n_neighbors < n_samples, and HDBSCAN needs min_cluster_size >= 2
        n_samples = len(embeddings)
        min_cluster_size = max(2, n_samples // 2)
        safe_n_neighbours = min(15, max(2, n_samples - 1))

        # Create a UMAP model
        umap_model = UMAP(n_neighbors = safe_n_neighbours, min_dist = 0.1, metric = "cosine", random_state = 42)

        # Leverage DalaT5's tokeniser for stopword acquisition
        stopwords = self._extract_dalat5_stopwords(clean_texts, top_k = 75)

        # Define vectoriser and model
        self.vectoriser_model = CountVectorizer(
            stop_words = stopwords,
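            # NB: this pattern keeps only ASCII alphabetic tokens (optionally
            # hyphenated), so tokens in non-Latin scripts are discarded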
            token_pattern = r"\b[a-zA-Z]+(?:-[a-zA-Z]+)?\b"
        )
        self.model = BERTopic(
            language = "multilingual", 
            vectorizer_model = self.vectoriser_model,
            embedding_model = DalaEmbedder().get_model(),
            umap_model = umap_model,
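            # A permissive HDBSCAN (min_samples = 1, small cluster_selection_epsilon)
            # reduces the number of documents relegated to the -1 outlier topic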
            hdbscan_model = HDBSCAN(min_cluster_size = min_cluster_size, min_samples = 1, cluster_selection_epsilon = 0.1)
        )

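        # Passing precomputed embeddings lets BERTopic skip its internal
        # embedding step and feed ours straight into UMAP/HDBSCAN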
        topics, _ = self.model.fit_transform(clean_texts, embeddings)

        # Generate labels
        topic_info = self.model.get_topic_info()
        topic_labels = {}

        for topic_id in topic_info.Topic.values:
            if topic_id == -1:
                # -1 is BERTopic's outlier bucket; give it a placeholder label
                topic_labels[topic_id] = "-"
                continue

            topic_words = self.model.get_topic(topic_id)

            if not isinstance(topic_words, list) or len(topic_words) == 0:
                print(f"[WARN] Skipping label generation for topic_id={topic_id} - invalid topic")
                continue

            words = []

            for pair in topic_words[:4]:
                if isinstance(pair, (list, tuple)) and len(pair) >= 1:
                    words.append(pair[0])

            if not words:
                print(f"[WARN] No valid words found for topic_id = {topic_id}")
                continue

            label = "_".join(words)
            topic_labels[topic_id] = f"{topic_id}_{label}"

        fig = custom_topic_barchart(self.model, topic_labels)
        umap_fig = custom_umap_plot(embeddings, topics, topic_labels)
        # Fall back to the raw topic id for any topic skipped during
        # labelling, so the lookup cannot raise a KeyError
        labeled_topics = [topic_labels.get(t, str(t)) for t in topics]

        return labeled_topics, fig, topic_labels, umap_fig
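

if __name__ == "__main__":
    # Illustrative usage sketch (a minimal example, not exercised elsewhere).
    # Assumes DalaEmbedder().get_model() returns a SentenceTransformer-style
    # model exposing .encode(); adapt the embedding call if the real API differs.
    sample_texts = [
        "first document about one theme",
        "second document on the same theme",
        "third document, still on that theme",
        "a document about a different theme",
        "another document on the different theme",
        "one last document on the different theme",
    ]

    sample_embeddings = DalaEmbedder().get_model().encode(sample_texts)

    modeller = TopicModeller()
    labeled, barchart, labels, umap_fig = modeller.fit(sample_texts, sample_embeddings)

    print(labels)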