semanticdala / src /utils /plotting.py
crossroderick's picture
Major clustering update
9d99321
import plotly
import pandas as pd
from umap import UMAP
import plotly.express as px
from bertopic import BERTopic
from typing import Dict, List
def custom_topic_barchart(model: BERTopic, topic_labels: Dict[int, str], top_n_topics: int = 10, n_words: int = 10) -> plotly.graph_objs.Figure:
    """
    Create a custom horizontal bar chart of top topics using plotly.express.

    Topics are visited in the iteration order of ``topic_labels``. For each
    one, up to ``n_words`` (word, score) pairs returned by
    ``model.get_topic`` become bars, coloured by the topic's label.

    Args:
        model: Fitted topic model; only ``get_topic(topic_id)`` is called.
        topic_labels: Mapping of topic id to the label shown in the legend.
            The id ``-1`` is always skipped.
        top_n_topics: Maximum number of topics to plot. Only topics that
            contribute at least one bar count towards the cap. ``None``
            disables the cap.
        n_words: Maximum number of words plotted per topic.

    Returns:
        The grouped bar chart, or an empty ``Figure`` when no topic yields
        any usable (word, score) pair.
    """
    # ``None`` means "no cap": every labelled topic may be plotted.
    max_topics = len(topic_labels) if top_n_topics is None else top_n_topics

    rows = []
    plotted_topics = 0

    for topic_id, label in topic_labels.items():
        if plotted_topics >= max_topics:
            break

        if topic_id == -1:
            continue

        topic = model.get_topic(topic_id)

        # Anything other than a non-empty list cannot be plotted.
        if not isinstance(topic, list) or not topic:
            continue

        # Keep only well-formed (word, score) pairs.
        topic_rows = [
            {"Topic": label, "Word": pair[0], "Score": pair[1]}
            for pair in topic[:n_words]
            if isinstance(pair, (list, tuple)) and len(pair) == 2
        ]

        # A topic with no usable pairs does not use up a slot of the cap.
        if topic_rows:
            rows.extend(topic_rows)
            plotted_topics += 1

    if not rows:
        print("[WARN] No topic-word-score data to visualize.")
        return plotly.graph_objs.Figure()

    # Every row carries exactly the Topic/Word/Score keys, so the frame
    # always has the columns px.bar needs.
    df = pd.DataFrame(rows)

    fig = px.bar(
        df,
        x="Score",
        y="Word",
        color="Topic",
        orientation="h",
        barmode="group",
    )

    fig.update_layout(
        margin=dict(l=40, r=20, t=40, b=20),
        yaxis=dict(title=""),
        xaxis=dict(title="Relevance"),
        legend_title_text="Topic",
    )

    return fig
def custom_umap_plot(embeddings: List[List[float]], topics: List[int], topic_labels: Dict[int, str]) -> plotly.graph_objs.Figure:
    """
    Custom UMAP plotting to work better with the Gradio layout.

    Projects ``embeddings`` to two dimensions with UMAP (cosine metric,
    fixed ``random_state``) and draws a scatter plot coloured by topic
    label. Points assigned to topic ``-1`` are removed after the
    projection, so they still take part in the UMAP fit but are not drawn.

    Args:
        embeddings: One embedding vector per document.
        topics: Topic id of each document, aligned with ``embeddings``.
        topic_labels: Mapping of topic id to the label shown in the legend.
            Ids without an entry fall back to ``"Topic <id>"``.

    Returns:
        The scatter plot, or an empty ``Figure`` when ``embeddings`` is
        empty.
    """
    n_samples = len(embeddings)

    # There is nothing to project; return an empty figure instead of
    # letting the UMAP fit fail.
    if n_samples == 0:
        print("[WARN] No embeddings to visualize.")
        return plotly.graph_objs.Figure()

    # Compute a safe number of neighbours: at most 15, at least 2, and
    # otherwise one fewer than the number of samples.
    safe_n_neighbours = min(15, max(2, n_samples - 1))

    reducer = UMAP(n_neighbors = safe_n_neighbours, min_dist = 0.1, metric = "cosine", random_state = 42)
    umap_coords = reducer.fit_transform(embeddings)

    df = pd.DataFrame(umap_coords, columns = ["x", "y"])
    df["topic"] = topics

    # Use a fallback so an unlabelled topic id (for instance the -1 rows
    # dropped just below) cannot raise KeyError.
    df["label"] = [topic_labels.get(t, f"Topic {t}") for t in topics]

    # Filter out topic -1 (noise)
    df = df[df["topic"] != -1]

    fig = px.scatter(
        df,
        x = "x",
        y = "y",
        color = "label",
        labels = {"label": "Topic"},
    )

    fig.update_layout(margin = dict(l = 20, r = 20, t = 40, b = 20))

    return fig