semanticdala / src /utils /plotting.py
crossroderick's picture
Major clustering update
9d99321
raw
history blame
2.53 kB
import plotly
import pandas as pd
from umap import UMAP
import plotly.express as px
from bertopic import BERTopic
from typing import Dict, List
def custom_topic_barchart(model: BERTopic, topic_labels: Dict[int, str], top_n_topics: int = 10, n_words: int = 10) -> plotly.graph_objs.Figure:
    """
    Create a custom horizontal bar chart of top topics using plotly.express.

    Each plotted topic contributes up to ``n_words`` bars (one per word),
    coloured by the topic's label.

    Args:
        model: Fitted topic model; only ``model.get_topic(topic_id)`` is called.
        topic_labels: Mapping of topic id to the label shown in the legend.
            Topics are considered in this mapping's iteration order.
        top_n_topics: Maximum number of topics to plot. ``None`` disables
            the limit.
        n_words: Maximum number of words plotted per topic.

    Returns:
        The grouped bar chart, or an empty figure when no topic yields any
        usable (word, score) pair.
    """
    rows = []
    topics_plotted = 0

    for topic_id, label in topic_labels.items():
        # Topic -1 is the noise/outlier bucket and is never charted
        if topic_id == -1:
            continue

        # Honour the requested topic limit; only topics that actually
        # contribute rows count towards it
        if top_n_topics is not None and topics_plotted >= top_n_topics:
            break

        topic = model.get_topic(topic_id)

        # Anything other than a non-empty list means there is nothing to plot
        if not isinstance(topic, list) or not topic:
            continue

        # Keep only well-formed (word, score) pairs
        topic_rows = [
            {"Topic": label, "Word": pair[0], "Score": pair[1]}
            for pair in topic[:n_words]
            if isinstance(pair, (list, tuple)) and len(pair) == 2
        ]

        if topic_rows:
            rows.extend(topic_rows)
            topics_plotted += 1

    # Construct the chart only if data exists
    if not rows:
        print("[WARN] No topic-word-score data to visualize.")
        return plotly.graph_objs.Figure()

    df = pd.DataFrame(rows)

    fig = px.bar(
        df,
        x="Score",
        y="Word",
        color="Topic",
        orientation="h",
        barmode="group",
    )

    fig.update_layout(
        margin=dict(l=40, r=20, t=40, b=20),
        yaxis=dict(title=""),
        xaxis=dict(title="Relevance"),
        legend_title_text="Topic",
    )

    return fig
def custom_umap_plot(embeddings: List[List[float]], topics: List[int], topic_labels: Dict[int, str]) -> plotly.graph_objs.Figure:
    """
    Custom UMAP plotting to work better with the Gradio layout.

    Reduces the embeddings with UMAP (cosine metric, fixed random state) and
    draws one scatter point per embedding, coloured by topic label. Points
    assigned to topic -1 (noise) are used for the reduction but are removed
    before plotting.

    Args:
        embeddings: One embedding vector per document.
        topics: Topic id for each embedding, in the same order.
        topic_labels: Mapping of topic id to the label shown in the legend.
            Ids without an entry fall back to ``"Topic <id>"``.

    Returns:
        The scatter plot, or an empty figure when there are no embeddings.
    """
    # len() rather than truthiness so array-like inputs are handled as well
    n_samples = len(embeddings)

    if n_samples == 0:
        print("[WARN] No embeddings to visualize.")
        return plotly.graph_objs.Figure()

    # Compute a safe number of neighbours: at most 15, at least 2, and
    # below the sample count whenever that count allows it
    safe_n_neighbours = min(15, max(2, n_samples - 1))

    reducer = UMAP(n_neighbors=safe_n_neighbours, min_dist=0.1, metric="cosine", random_state=42)
    umap_coords = reducer.fit_transform(embeddings)

    df = pd.DataFrame(umap_coords, columns=["x", "y"])
    df["topic"] = topics

    # Use a fallback label instead of indexing directly: the noise topic (or
    # any other id) may have no entry in topic_labels, and labels are looked
    # up before the noise rows are dropped
    df["label"] = [topic_labels.get(t, f"Topic {t}") for t in topics]

    # Filter out topic -1 (noise)
    df = df[df["topic"] != -1]

    fig = px.scatter(
        df,
        x="x",
        y="y",
        color="label",
        labels={"label": "Topic"},
    )

    fig.update_layout(margin=dict(l=20, r=20, t=40, b=20))

    return fig