File size: 3,026 Bytes
0eb636f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import re
from collections import Counter
from typing import Dict, List, Tuple

import numpy as np
import plotly
from bertopic import BERTopic
from sklearn.feature_extraction.text import CountVectorizer

from src.modelling.embed import DalaEmbedder
from src.utils.data_utils import tokeniser
from src.utils.plotting import custom_topic_barchart, custom_umap_plot


class TopicModeller:
    """
    Wrapper for topic modelling with BERTopic.

    Call ``fit`` with raw texts and their precomputed embeddings. The fitted
    ``BERTopic`` instance is kept on ``self.model`` and its ``CountVectorizer``
    on ``self.vectoriser_model`` (also reachable as ``self.vectorizer_model``).
    """

    # Any run of digits and/or whitespace; collapsed to one space in preprocessing.
    _DIGITS_AND_SPACES = re.compile(r"[\d\s]+")

    def __init__(self):
        # Both are populated by ``fit``; ``None`` until then.
        self.vectoriser_model = None
        self.model = None

    @property
    def vectorizer_model(self):
        """Alias of ``vectoriser_model`` (the spelling ``__init__`` previously set)."""
        return self.vectoriser_model

    @vectorizer_model.setter
    def vectorizer_model(self, value):
        self.vectoriser_model = value

    def _extract_dalat5_stopwords(self, texts: List[str], top_k: int = 75) -> List[str]:
        """
        Identify frequent tokens using DalaT5's tokeniser as proxy stopwords.

        Args:
            texts: Texts whose token ids are counted.
            top_k: Number of most frequent token ids to consider.

        Returns:
            The decoded, whitespace-stripped tokens in descending frequency
            order, with empty strings and duplicates removed, so the list may
            hold fewer than ``top_k`` entries.
        """
        token_counter = Counter()
        for text in texts:
            token_counter.update(tokeniser.encode(text, add_special_tokens=False))

        decoded = (
            tokeniser.decode([tok_id]).strip()
            for tok_id, _ in token_counter.most_common(top_k)
        )
        # Distinct ids can decode to the same string, and whitespace-only pieces
        # become "" after strip(); neither is a useful stopword, so drop them
        # while keeping frequency order (dict preserves insertion order).
        return list(dict.fromkeys(tok for tok in decoded if tok))

    def _preprocess_texts(self, texts: List[str]) -> List[str]:
        """
        Lowercase texts and replace digits with whitespace.

        Every run of digits and/or whitespace becomes a single space and the
        result is stripped. Punctuation is left in place (the vectoriser's
        token pattern in ``fit`` accepts hyphenated words).
        """
        return [self._DIGITS_AND_SPACES.sub(" ", t.lower()).strip() for t in texts]

    def fit(
        self,
        texts: List[str],
        embeddings: List[List[float]],
        *,
        stopword_top_k: int = 75,
        n_label_words: int = 4,
    ) -> "Tuple[List[str], plotly.graph_objs.Figure, Dict[int, str], plotly.graph_objs.Figure]":
        """
        Fit BERTopic on preprocessed texts and the given embeddings.

        Args:
            texts: Raw documents.
            embeddings: One embedding per document, as a nested list or an
                array-like with a ``shape``. Inputs without a ``shape`` are
                converted to a NumPy array, because BERTopic does not accept
                plain lists.
            stopword_top_k: How many of the most frequent DalaT5 tokens to
                treat as stopwords.
            n_label_words: How many top topic words to join into each label.

        Returns:
            ``(labeled_topics, barchart_fig, topic_labels, umap_fig)``, which are:
            - the label of each document's topic;
            - the topic bar chart;
            - the mapping from topic id to label (the outlier topic ``-1`` maps to ``'-'``);
            - the UMAP plot.

        Raises:
            ValueError: If ``texts`` is empty.
        """
        if not texts:
            raise ValueError("fit() requires at least one text")

        # BERTopic only accepts numpy arrays / scipy sparse matrices here;
        # anything that already exposes ``shape`` is passed through untouched.
        if not hasattr(embeddings, "shape"):
            embeddings = np.asarray(embeddings)

        clean_texts = self._preprocess_texts(texts)

        # Leverage DalaT5's tokeniser for stopword acquisition
        stopwords = self._extract_dalat5_stopwords(clean_texts, top_k=stopword_top_k)

        # Define vectoriser and model
        self.vectoriser_model = CountVectorizer(
            stop_words=stopwords,
            token_pattern=r"\b[a-zA-Z]+(?:-[a-zA-Z]+)?\b",
        )
        self.model = BERTopic(
            language="multilingual",
            vectorizer_model=self.vectoriser_model,
            embedding_model=DalaEmbedder().get_model(),
        )

        topics, _ = self.model.fit_transform(clean_texts, embeddings)

        # Generate labels of the form "<id>_<word1>_<word2>..."
        topic_labels = {}
        for raw_id in self.model.get_topic_info().Topic.values:
            topic_id = int(raw_id)  # plain int keys, matching Dict[int, str]
            if topic_id == -1:
                topic_labels[topic_id] = '-'
                continue

            # Topic representations may be padded with empty-string words;
            # skip them so labels don't end in stray underscores.
            words = [word for word, _ in self.model.get_topic(topic_id) if word]
            label = "_".join(words[:n_label_words])
            topic_labels[topic_id] = f"{topic_id}_{label}"

        fig = custom_topic_barchart(self.model, topic_labels)
        umap_fig = custom_umap_plot(embeddings, topics, topic_labels)
        labeled_topics = [topic_labels[t] for t in topics]

        return labeled_topics, fig, topic_labels, umap_fig