Spaces:
Sleeping
Sleeping
import plotly | |
import pandas as pd | |
from umap import UMAP | |
import plotly.express as px | |
from bertopic import BERTopic | |
from typing import Dict, List | |
def custom_topic_barchart(model: BERTopic, topic_labels: Dict[int, str], top_n_topics: int = 10, n_words: int = 10) -> plotly.graph_objs.Figure:
    """
    Create a custom horizontal bar chart of the top topics' words using plotly.express.

    Args:
        model: A fitted BERTopic model. ``get_topic_freq`` is used to rank the
            topics and ``get_topic`` supplies each topic's (word, score) pairs.
        topic_labels: Mapping from topic id to the label shown in the legend.
            Only topics present in this mapping are plotted; topic ``-1``
            (noise) is always skipped.
        top_n_topics: Maximum number of topics to plot, largest topics first.
        n_words: Maximum number of words shown per topic.

    Returns:
        A plotly Figure with one bar per (topic, word), coloured by topic.
        An empty Figure is returned when there is nothing to plot.
    """
    # get_topic_freq lists topics by size (largest first); keep only labelled,
    # non-noise topics so that `top_n_topics` is honoured.
    ranked = [
        topic_id
        for topic_id in model.get_topic_freq()["Topic"].tolist()
        if topic_id != -1 and topic_id in topic_labels
    ]
    selected = ranked[:top_n_topics]

    rows = []
    for topic_id in selected:
        topic = model.get_topic(topic_id)
        # get_topic may return a non-list (e.g. False) for unknown topics.
        if not isinstance(topic, list) or not topic:
            continue
        label = topic_labels[topic_id]
        for pair in topic[:n_words]:
            # Defensive: skip anything that is not a (word, score) pair.
            if not isinstance(pair, (list, tuple)) or len(pair) != 2:
                continue
            word, score = pair
            rows.append({"Topic": label, "Word": word, "Score": score})

    # Construct the chart only if there is data to show.
    if not rows:
        print("[WARN] No topic-word-score data to visualize.")
        return plotly.graph_objs.Figure()

    df = pd.DataFrame(rows)
    fig = px.bar(
        df,
        x="Score",
        y="Word",
        color="Topic",
        orientation="h",
        barmode="group",
    )
    fig.update_layout(
        margin=dict(l=40, r=20, t=40, b=20),
        yaxis=dict(title=""),
        xaxis=dict(title="Relevance"),
        legend_title_text="Topic",
    )
    return fig
def custom_umap_plot(embeddings: List[List[float]], topics: List[int], topic_labels: Dict[int, str]) -> plotly.graph_objs.Figure:
    """
    Project embeddings to 2-D with UMAP and scatter-plot them coloured by topic.

    The layout (margins, no fixed height) is tuned to work better with the
    Gradio layout.

    Args:
        embeddings: One embedding vector per document.
        topics: Topic id for each document, aligned index-for-index with
            ``embeddings``.
        topic_labels: Mapping from topic id to the label shown in the legend.
            Topics missing from the mapping are labelled with their id as a
            string; topic ``-1`` (noise) does not need an entry.

    Returns:
        A plotly scatter Figure of all non-noise documents. An empty Figure is
        returned when ``embeddings`` is empty.
    """
    n_samples = len(embeddings)
    if n_samples == 0:
        print("[WARN] No embeddings to visualize.")
        return plotly.graph_objs.Figure()

    # Keep n_neighbors at least 2 and below the sample count, capped at 15.
    safe_n_neighbours = min(15, max(2, n_samples - 1))
    reducer = UMAP(n_neighbors=safe_n_neighbours, min_dist=0.1, metric="cosine", random_state=42)
    # Fit on every document (noise included); noise is only hidden from the plot.
    umap_coords = reducer.fit_transform(embeddings)

    df = pd.DataFrame(umap_coords, columns=["x", "y"])
    df["topic"] = topics
    # Filter out topic -1 (noise) BEFORE labelling, so topic_labels need not
    # contain -1; .copy() avoids assigning into a view of the original frame.
    df = df[df["topic"] != -1].copy()
    df["label"] = [topic_labels.get(t, str(t)) for t in df["topic"]]

    fig = px.scatter(
        df,
        x="x",
        y="y",
        color="label",
        labels={"label": "Topic"},
    )
    fig.update_layout(margin=dict(l=20, r=20, t=40, b=20))
    return fig