crossroderick committed
Commit 9d99321 · 1 Parent(s): 724dcb9

Major clustering update

app.py CHANGED
@@ -125,13 +125,13 @@ def process_file(file: Any) -> Tuple[List[Tuple[str, int]], Any, Any]:
     vector_db.add(embeddings, metadata)
 
     # Topic modelling
-    topics, fig, topic_labels, umap_fig = topic_modeller.fit(translits, embeddings)
+    topics, fig, topic_labels, umap_fig = topic_modeller.fit(dedup_translits, embeddings)
 
     # Get a list of rows for topic labels
     overview_table = [[k, v] for k, v in topic_labels.items()]
 
     # Zip back transliterated text with topic IDs
-    annotated = list(zip(translits, topics))
+    annotated = list(zip(dedup_translits, topics))
 
     # Log success
     log_submission(file.name, len(chunks), start, status = "success")
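
Note on the hunk above: `topic_modeller.fit()` returns one topic ID per input document, positionally aligned with the list it was given, so the list zipped with `topics` must be the same list passed to `fit()`. A minimal sketch with hypothetical data (the string values below are illustrative, not from the repository) of the alignment invariant this change restores:

# Hypothetical data; one topic ID per document, aligned by position.
dedup_translits = ["qazaq tili", "til bilimi", "latyn alipbii"]  # illustrative
topics = [0, 1, 0]  # as fit() would return for these three documents

annotated = list(zip(dedup_translits, topics))
assert len(annotated) == len(dedup_translits)
# Zipping the pre-deduplication `translits` instead would silently truncate
# or misalign (text, topic) pairs once the two lists differ in length.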
src/modelling/__pycache__/topic_model.cpython-312.pyc CHANGED
Binary files a/src/modelling/__pycache__/topic_model.cpython-312.pyc and b/src/modelling/__pycache__/topic_model.cpython-312.pyc differ
 
src/modelling/topic_model.py CHANGED
@@ -1,5 +1,7 @@
 import re
 import plotly
+from umap import UMAP
+from hdbscan import HDBSCAN
 from bertopic import BERTopic
 from collections import Counter
 from src.utils.data_utils import tokeniser
@@ -27,7 +29,7 @@ class TopicModeller:
         token_counter = Counter()
 
         for text in texts:
-            token_ids = tokeniser.encode(text, add_special_tokens=False)
+            token_ids = tokeniser.encode(text, add_special_tokens = False)
             token_counter.update(token_ids)
 
         most_common = token_counter.most_common(top_k)
@@ -57,6 +59,14 @@
         """
         clean_texts = self._preprocess_texts(texts)
 
+        # Compute a safe number of neighbours and clusters
+        n_samples = len(embeddings)
+        min_cluster_size = max(2, len(embeddings) // 2)
+        safe_n_neighbours = min(15, max(2, n_samples - 1))
+
+        # Create a UMAP model
+        umap_model = UMAP(n_neighbors = safe_n_neighbours, min_dist = 0.1, metric = "cosine", random_state = 42)
+
         # Leverage DalaT5's tokeniser for stopword acquisition
         stopwords = self._extract_dalat5_stopwords(clean_texts, top_k = 75)
 
@@ -68,7 +78,9 @@
         self.model = BERTopic(
             language = "multilingual",
             vectorizer_model = self.vectoriser_model,
-            embedding_model = DalaEmbedder().get_model()
+            embedding_model = DalaEmbedder().get_model(),
+            umap_model = umap_model,
+            hdbscan_model = HDBSCAN(min_cluster_size = min_cluster_size, min_samples = 1, cluster_selection_epsilon = 0.1)
         )
 
         topics, _ = self.model.fit_transform(clean_texts, embeddings)
@@ -83,7 +95,23 @@
 
                 continue
 
-            words = [word for word, _ in self.model.get_topic(topic_id)[:4]]
+            topic_words = self.model.get_topic(topic_id)
+
+            if not isinstance(topic_words, list) or len(topic_words) == 0:
+                print(f"[WARN] Skipping label generation for topic_id={topic_id} - invalid topic")
+                continue
+
+            words = []
+
+            for pair in topic_words[:4]:
+                if isinstance(pair, (list, tuple)) and len(pair) >= 1:
+                    words.append(pair[0])
+
+            if not words:
+                print(f"[WARN] No valid words found for topic_id = {topic_id}")
+
+                continue
+
             label = "_".join(words)
             topic_labels[topic_id] = f"{topic_id}_{label}"
 
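Note on the guards above: UMAP requires `n_neighbors` to be smaller than the number of samples, and HDBSCAN requires `min_cluster_size >= 2`, so unclamped defaults crash on very small corpora. A standalone check of the clamping arithmetic (my summary of the commit's guards, using a hypothetical helper name, not code from the repository):

# Hypothetical helper reproducing the guard arithmetic from fit().
def safe_params(n_samples: int) -> tuple:
    min_cluster_size = max(2, n_samples // 2)           # HDBSCAN needs >= 2
    safe_n_neighbours = min(15, max(2, n_samples - 1))  # UMAP needs < n_samples
    return safe_n_neighbours, min_cluster_size

for n in (3, 5, 16, 200):
    print(n, safe_params(n))
# 3   -> (2, 2)
# 5   -> (4, 2)
# 16  -> (15, 8)
# 200 -> (15, 100)

Worth noting: `min_cluster_size = n_samples // 2` caps the model at roughly two clusters, which looks deliberate for tiny uploads but will force very coarse topics on large corpora.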
src/utils/__pycache__/plotting.cpython-312.pyc CHANGED
Binary files a/src/utils/__pycache__/plotting.cpython-312.pyc and b/src/utils/__pycache__/plotting.cpython-312.pyc differ
 
src/utils/plotting.py CHANGED
@@ -17,37 +17,56 @@ def custom_topic_barchart(model: BERTopic, topic_labels: Dict[int, str], top_n_t
         if topic_id == -1:
             continue
 
-        for word, score in model.get_topic(topic_id)[:n_words]:
+        topic = model.get_topic(topic_id)
+        if not isinstance(topic, list) or len(topic) == 0:
+            continue
+
+        for pair in topic[:n_words]:
+            if not isinstance(pair, (list, tuple)) or len(pair) != 2:
+                continue
+            word, score = pair
             data.append({"Topic": label, "Word": word, "Score": score})
 
+    # ✅ Construct only if data exists
+    if not data:
+        print("[WARN] No topic-word-score data to visualize.")
+        return plotly.graph_objs.Figure()
+
     df = pd.DataFrame(data)
 
+    required_cols = {"Topic", "Word", "Score"}
+    if not required_cols.issubset(df.columns):
+        print("[ERROR] Required columns missing in DataFrame.")
+        return plotly.graph_objs.Figure()
+
     fig = px.bar(
         df,
-        x = "Score",
-        y = "Word",
-        color = "Topic",
-        orientation = 'h',
-        barmode = "group",
-        #height = 500,
+        x="Score",
+        y="Word",
+        color="Topic",
+        orientation='h',
+        barmode="group",
     )
 
     fig.update_layout(
-        margin = dict(l = 40, r = 20, t = 40, b = 20),
-        yaxis = dict(title = ""),
-        xaxis = dict(title = "Relevance"),
-        legend_title_text = "Topic",
+        margin=dict(l=40, r=20, t=40, b=20),
+        yaxis=dict(title=""),
+        xaxis=dict(title="Relevance"),
+        legend_title_text="Topic",
     )
 
     return fig
 
 
-
 def custom_umap_plot(embeddings: List[List[float]], topics: List[int], topic_labels: Dict[int, str]) -> plotly.graph_objs.Figure:
     """
     Custom UMAP plotting to work better with the Gradio layout.
     """
-    reducer = UMAP(n_neighbors = 15, min_dist = 0.1, metric = "cosine", random_state = 42)
+    # Compute a safe number of neighbours
+    n_samples = len(embeddings)
+    safe_n_neighbours = min(15, max(2, n_samples - 1))
+
+    reducer = UMAP(n_neighbors = safe_n_neighbours, min_dist = 0.1, metric = "cosine", random_state = 42)
     umap_coords = reducer.fit_transform(embeddings)
 
     df = pd.DataFrame(umap_coords, columns=["x", "y"])
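
Note on the `isinstance` checks above: as far as I know, `BERTopic.get_topic()` returns `False` rather than raising when a topic ID is unknown, so a bare slice like `model.get_topic(topic_id)[:n_words]` can fail with a `TypeError`. A self-contained sketch of the same guard pattern (the helper name and sample values are hypothetical):

from typing import Any, List, Tuple

def safe_topic_pairs(topic: Any, n_words: int) -> List[Tuple[str, float]]:
    """Keep only well-formed (word, score) pairs; tolerate False/None/junk."""
    if not isinstance(topic, list) or len(topic) == 0:
        return []
    return [tuple(pair) for pair in topic[:n_words]
            if isinstance(pair, (list, tuple)) and len(pair) == 2]

print(safe_topic_pairs([("til", 0.42), ("qazaq", 0.31)], 4))  # both pairs kept
print(safe_topic_pairs(False, 4))                             # [] for an unknown topic ID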
vector_store/faiss_index.index CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:25004d0d5df0be08b29e41af806fefc2215d37f215c08fdd5b8ce16484ee83fc
-size 175149
+oid sha256:ce19ebb4f9f8a57800a85bffbc97c637134adf2775420b4b09889dec95943cf6
+size 6189
vector_store/faiss_index.json CHANGED
The diff for this file is too large to render. See raw diff