Spaces:
Sleeping
Sleeping
File size: 2,534 Bytes
0eb636f 9d99321 0eb636f 9d99321 0eb636f 9d99321 0eb636f 9d99321 0eb636f 9d99321 0eb636f 9d99321 0eb636f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import plotly
import pandas as pd
from umap import UMAP
import plotly.express as px
from bertopic import BERTopic
from typing import Dict, List
def custom_topic_barchart(model: BERTopic, topic_labels: Dict[int, str], top_n_topics: int = 10, n_words: int = 10) -> plotly.graph_objs.Figure:
    """
    Create a custom horizontal bar chart of top topics using plotly.express.

    Args:
        model: Topic model queried through ``model.get_topic(topic_id)``, which
            is expected to return a list of ``(word, score)`` pairs. Any other
            return value causes the topic to be skipped.
        topic_labels: Mapping of topic id to the label shown in the legend.
            Topics are visited in this mapping's iteration order.
        top_n_topics: Maximum number of topics to plot. Topic ``-1`` and topics
            that yield no usable ``(word, score)`` pairs do not count towards
            the limit. Pass ``None`` to plot every labelled topic.
        n_words: Maximum number of words plotted per topic.

    Returns:
        A grouped horizontal bar chart (x = score, y = word, colour = topic
        label), or an empty figure when there is nothing to plot.
    """
    rows = []
    topics_shown = 0
    for topic_id, label in topic_labels.items():
        # Honour the requested topic limit; previously this argument was ignored.
        if top_n_topics is not None and topics_shown >= top_n_topics:
            break
        # Topic -1 is the noise bucket and is never plotted.
        if topic_id == -1:
            continue
        topic = model.get_topic(topic_id)
        if not isinstance(topic, list):
            continue
        # Keep only well-formed (word, score) pairs, capped at n_words.
        topic_rows = [
            {"Topic": label, "Word": pair[0], "Score": pair[1]}
            for pair in topic[:n_words]
            if isinstance(pair, (list, tuple)) and len(pair) == 2
        ]
        if topic_rows:
            rows.extend(topic_rows)
            topics_shown += 1

    if not rows:
        print("[WARN] No topic-word-score data to visualize.")
        return plotly.graph_objs.Figure()

    # Every row carries the same three keys, so the frame needs no column check.
    df = pd.DataFrame(rows)
    fig = px.bar(
        df,
        x="Score",
        y="Word",
        color="Topic",
        orientation="h",
        barmode="group",
    )
    fig.update_layout(
        margin=dict(l=40, r=20, t=40, b=20),
        yaxis=dict(title=""),
        xaxis=dict(title="Relevance"),
        legend_title_text="Topic",
    )
    return fig
def custom_umap_plot(embeddings: List[List[float]], topics: List[int], topic_labels: Dict[int, str]) -> plotly.graph_objs.Figure:
    """
    Custom UMAP plotting to work better with the Gradio layout.

    Reduces the embeddings to two dimensions with UMAP and draws a scatter
    plot coloured by topic label. Points assigned to topic ``-1`` (noise) are
    included in the UMAP fit but left out of the plot.

    Args:
        embeddings: One embedding vector per document.
        topics: Topic id of each document, aligned with ``embeddings``.
        topic_labels: Mapping of topic id to the label shown in the legend.
            Topic ids missing from the mapping are labelled ``"Topic <id>"``.

    Returns:
        The scatter figure, or an empty figure when ``embeddings`` is empty.

    Raises:
        ValueError: If ``embeddings`` and ``topics`` differ in length.
    """
    n_samples = len(embeddings)
    # Validate before the (expensive) UMAP fit rather than failing afterwards.
    if len(topics) != n_samples:
        raise ValueError(
            "embeddings and topics must have the same length "
            f"(got {n_samples} and {len(topics)})."
        )
    if n_samples == 0:
        print("[WARN] No embeddings to visualize.")
        return plotly.graph_objs.Figure()

    # Compute a safe number of neighbours: at most 15, at least 2, and below
    # the sample count whenever the data set is large enough to allow it.
    safe_n_neighbours = min(15, max(2, n_samples - 1))
    reducer = UMAP(n_neighbors=safe_n_neighbours, min_dist=0.1, metric="cosine", random_state=42)
    umap_coords = reducer.fit_transform(embeddings)

    df = pd.DataFrame(umap_coords, columns=["x", "y"])
    df["topic"] = topics
    # .get() avoids a KeyError for ids without a label (typically -1, which is
    # dropped just below anyway).
    df["label"] = [topic_labels.get(t, f"Topic {t}") for t in topics]

    # Filter out topic -1 (noise)
    df = df[df["topic"] != -1]

    fig = px.scatter(
        df,
        x="x",
        y="y",
        color="label",
        labels={"label": "Topic"},
    )
    fig.update_layout(margin=dict(l=20, r=20, t=40, b=20))
    return fig