semanticdala / src /utils /plotting.py
crossroderick's picture
Added all files
0eb636f
raw
history blame
1.86 kB
import plotly
import pandas as pd
from umap import UMAP
import plotly.express as px
from bertopic import BERTopic
from typing import Dict, List
def custom_topic_barchart(model: BERTopic, topic_labels: Dict[int, str], top_n_topics: int = 10, n_words: int = 10) -> plotly.graph_objs.Figure:
    """
    Create a horizontal, grouped bar chart of the top words per topic using plotly.express.

    Args:
        model: Fitted BERTopic model; ``model.get_topic(topic_id)`` supplies the
            ``(word, score)`` pairs plotted for each topic.
        topic_labels: Mapping from topic id to the label shown in the legend.
            Topic ``-1`` (the outlier topic) is always skipped.
        top_n_topics: Maximum number of topics to plot, taken in the iteration
            order of ``topic_labels`` after skipping ``-1``. Negative values are
            treated as 0.
        n_words: Maximum number of words plotted per topic.

    Returns:
        A plotly Figure with words on the y-axis, their scores on the x-axis,
        and one colour per topic label.
    """
    # Keep only real topics, then honour the top_n_topics limit.
    selected = [(topic_id, label) for topic_id, label in topic_labels.items() if topic_id != -1]
    selected = selected[:max(top_n_topics, 0)]

    rows = []
    for topic_id, label in selected:
        # BERTopic's get_topic returns False for an unknown topic id; treat that as "no words".
        words = model.get_topic(topic_id) or []
        rows.extend(
            {"Topic": label, "Word": word, "Score": score}
            for word, score in words[:n_words]
        )

    # Explicit columns keep the frame well-formed (and plottable) even when no rows were collected.
    df = pd.DataFrame(rows, columns=["Topic", "Word", "Score"])

    fig = px.bar(
        df,
        x="Score",
        y="Word",
        color="Topic",
        orientation="h",
        barmode="group",
    )

    fig.update_layout(
        margin=dict(l=40, r=20, t=40, b=20),
        yaxis=dict(title=""),
        xaxis=dict(title="Relevance"),
        legend_title_text="Topic",
    )

    return fig
def custom_umap_plot(
    embeddings: List[List[float]],
    topics: List[int],
    topic_labels: Dict[int, str],
    *,
    n_neighbors: int = 15,
    min_dist: float = 0.1,
    random_state: int = 42,
) -> plotly.graph_objs.Figure:
    """
    Project document embeddings to 2-D with UMAP and scatter-plot them coloured by topic.

    Intended to work better with the Gradio layout than BERTopic's built-in plot.

    Args:
        embeddings: One embedding vector per document.
        topics: Topic id assigned to each document, aligned with ``embeddings``.
            Documents with topic ``-1`` (noise) are excluded from the plot but
            are still included in the UMAP fit.
        topic_labels: Mapping from topic id to display label. Topics missing
            from the mapping fall back to ``str(topic_id)``; ``-1`` does not
            need an entry.
        n_neighbors: Forwarded to ``umap.UMAP``.
        min_dist: Forwarded to ``umap.UMAP``.
        random_state: Forwarded to ``umap.UMAP`` for a reproducible layout.

    Returns:
        A plotly scatter Figure with one colour per topic label.

    Raises:
        ValueError: If ``embeddings`` and ``topics`` differ in length.
    """
    # Fail fast, before the (comparatively expensive) UMAP fit.
    if len(embeddings) != len(topics):
        raise ValueError(
            f"embeddings and topics must have the same length "
            f"(got {len(embeddings)} and {len(topics)})"
        )

    reducer = UMAP(
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        metric="cosine",
        random_state=random_state,
    )
    # UMAP is fit on all documents, noise included; only the plotted points are filtered below.
    umap_coords = reducer.fit_transform(embeddings)

    df = pd.DataFrame(umap_coords, columns=["x", "y"])
    df["topic"] = topics

    # Filter out topic -1 (noise) BEFORE looking up labels, so topic_labels need not contain -1.
    df = df[df["topic"] != -1].copy()
    df["label"] = [topic_labels.get(int(t), str(int(t))) for t in df["topic"]]

    fig = px.scatter(
        df,
        x="x",
        y="y",
        color="label",
        labels={"label": "Topic"},
    )

    fig.update_layout(margin=dict(l=20, r=20, t=40, b=20))

    return fig