# NOTE: scraped page banner (not part of the program): "Spaces: Sleeping"
import plotly | |
import pandas as pd | |
from umap import UMAP | |
import plotly.express as px | |
from bertopic import BERTopic | |
from typing import Dict, List | |
def custom_topic_barchart(model: BERTopic, topic_labels: Dict[int, str], top_n_topics: int = 10, n_words: int = 10) -> plotly.graph_objs.Figure:
    """
    Create a custom horizontal bar chart of top topics using plotly.express.

    Args:
        model: BERTopic model whose ``get_topic(topic_id)`` supplies the
            (word, score) pairs for each topic.
        topic_labels: Mapping of topic id to the label shown in the legend.
            The entry for topic ``-1`` is always skipped.
        top_n_topics: Maximum number of topics to plot, taken in the
            iteration order of ``topic_labels``. ``None`` plots every topic.
        n_words: Maximum number of words plotted per topic.

    Returns:
        A grouped horizontal bar figure with one bar per (topic, word) pair,
        coloured by topic label.
    """
    rows = []
    plotted = 0
    for topic_id, label in topic_labels.items():
        if topic_id == -1:
            continue
        if top_n_topics is not None and plotted >= top_n_topics:
            break
        # NOTE(review): BERTopic documents get_topic as returning False for an
        # id the model does not know; skip such ids instead of slicing a bool.
        words = model.get_topic(topic_id)
        if not words:
            continue
        plotted += 1
        rows.extend(
            {"Topic": label, "Word": word, "Score": score}
            for word, score in words[:n_words]
        )

    # Explicit columns keep "Score"/"Word"/"Topic" present even with no rows.
    df = pd.DataFrame(rows, columns=["Topic", "Word", "Score"])
    fig = px.bar(
        df,
        x="Score",
        y="Word",
        color="Topic",
        orientation="h",
        barmode="group",
    )
    fig.update_layout(
        margin=dict(l=40, r=20, t=40, b=20),
        yaxis=dict(title=""),
        xaxis=dict(title="Relevance"),
        legend_title_text="Topic",
    )
    return fig
def custom_umap_plot(embeddings: List[List[float]], topics: List[int], topic_labels: Dict[int, str]) -> plotly.graph_objs.Figure:
    """
    Custom UMAP plotting to work better with the Gradio layout.

    Args:
        embeddings: Vectors to project; passed unchanged to
            ``UMAP.fit_transform``. Presumably one row per document, aligned
            with ``topics`` -- confirm against the caller.
        topics: Topic id assigned to each embedding row. Rows assigned to
            topic ``-1`` are dropped from the plot.
        topic_labels: Mapping of topic id to the label used for colouring.
            Ids missing from the mapping fall back to ``str(topic_id)``.

    Returns:
        A scatter figure of the 2-D projection, coloured by topic label.
    """
    # Fixed random_state so repeated calls with the same input are seeded
    # identically.
    reducer = UMAP(n_neighbors=15, min_dist=0.1, metric="cosine", random_state=42)
    umap_coords = reducer.fit_transform(embeddings)

    df = pd.DataFrame(umap_coords, columns=["x", "y"])
    df["topic"] = topics
    # Noise rows (-1) are removed below, so they need no label; looking them
    # up would raise KeyError when topic_labels has no -1 entry. Other
    # unlabelled ids fall back to their id as text.
    df["label"] = [
        None if topic_id == -1 else topic_labels.get(topic_id, str(topic_id))
        for topic_id in topics
    ]

    # Filter out topic -1 (noise)
    df = df[df["topic"] != -1]

    fig = px.scatter(
        df,
        x="x",
        y="y",
        color="label",
        labels={"label": "Topic"},
    )
    fig.update_layout(margin=dict(l=20, r=20, t=40, b=20))
    return fig