from typing import Dict, List, Optional

import pandas as pd
import plotly
import plotly.express as px
from bertopic import BERTopic
from umap import UMAP


def custom_topic_barchart(
    model: BERTopic,
    topic_labels: Dict[int, str],
    top_n_topics: Optional[int] = 10,
    n_words: int = 10,
) -> plotly.graph_objs.Figure:
    """Build a grouped horizontal bar chart of topic words with plotly.express.

    Each bar is one (word, score) pair taken from ``model.get_topic``; bars
    are coloured by the human-readable topic label.

    Args:
        model: Fitted topic model; only ``model.get_topic(topic_id)`` is used.
        topic_labels: Mapping of topic id to display label. Topics are
            visited in this mapping's iteration order.
        top_n_topics: Maximum number of topics to plot. Only topics that
            actually contribute at least one bar count towards the limit.
            Pass ``None`` to plot every labelled topic.
        n_words: Maximum number of words taken from each topic.

    Returns:
        The bar chart, or an empty ``Figure`` (after printing a warning) when
        no topic produced any usable (word, score) pair.
    """
    rows = []
    plotted_topics = 0
    for topic_id, label in topic_labels.items():
        if top_n_topics is not None and plotted_topics >= top_n_topics:
            break
        # Topic -1 is the noise topic and is never charted.
        if topic_id == -1:
            continue

        topic = model.get_topic(topic_id)
        # Anything other than a non-empty list means there is nothing to plot
        # for this id, so it is skipped rather than treated as an error.
        if not isinstance(topic, list) or not topic:
            continue

        topic_rows = [
            {"Topic": label, "Word": pair[0], "Score": pair[1]}
            for pair in topic[:n_words]
            if isinstance(pair, (list, tuple)) and len(pair) == 2
        ]
        if not topic_rows:
            continue

        rows.extend(topic_rows)
        plotted_topics += 1

    # Only build the chart when there is something to show.
    if not rows:
        print("[WARN] No topic-word-score data to visualize.")
        return plotly.graph_objs.Figure()

    df = pd.DataFrame(rows)

    fig = px.bar(
        df,
        x="Score",
        y="Word",
        color="Topic",
        orientation="h",
        barmode="group",
    )
    fig.update_layout(
        margin=dict(l=40, r=20, t=40, b=20),
        yaxis=dict(title=""),
        xaxis=dict(title="Relevance"),
        legend_title_text="Topic",
    )
    return fig


def custom_umap_plot(
    embeddings: List[List[float]],
    topics: List[int],
    topic_labels: Dict[int, str],
) -> plotly.graph_objs.Figure:
    """Project document embeddings to 2-D with UMAP and scatter-plot them.

    Custom UMAP plotting to work better with the Gradio layout. Points are
    coloured by topic label; documents assigned to topic -1 (noise) are left
    out of the plot.

    Args:
        embeddings: One embedding vector per document.
        topics: Topic id of each document, aligned with ``embeddings``.
        topic_labels: Mapping of topic id to display label. An entry for the
            noise topic -1 is optional.

    Returns:
        The scatter plot, or an empty ``Figure`` (after printing a warning)
        when there are no embeddings.

    Raises:
        ValueError: If ``embeddings`` and ``topics`` differ in length.
        KeyError: If a non-noise topic id has no entry in ``topic_labels``.
    """
    n_samples = len(embeddings)
    # Validate before the UMAP fit so a mismatch fails fast instead of after
    # the expensive projection.
    if n_samples != len(topics):
        raise ValueError(
            f"embeddings and topics must have the same length "
            f"(got {n_samples} and {len(topics)})."
        )
    if n_samples == 0:
        print("[WARN] No embeddings to visualize.")
        return plotly.graph_objs.Figure()

    # Compute a safe number of neighbours: at most 15, never below 2, and
    # otherwise one fewer than the number of samples.
    safe_n_neighbours = min(15, max(2, n_samples - 1))

    reducer = UMAP(
        n_neighbors=safe_n_neighbours,
        min_dist=0.1,
        metric="cosine",
        random_state=42,
    )
    umap_coords = reducer.fit_transform(embeddings)

    df = pd.DataFrame(umap_coords, columns=["x", "y"])
    df["topic"] = topics
    # Noise rows are dropped just below, so a missing label for -1 must not
    # raise; every other topic still requires a label.
    df["label"] = [
        topic_labels.get(t) if t == -1 else topic_labels[t] for t in topics
    ]

    # Filter out topic -1 (noise)
    df = df[df["topic"] != -1]

    fig = px.scatter(
        df,
        x="x",
        y="y",
        color="label",
        labels={"label": "Topic"},
    )
    fig.update_layout(margin=dict(l=20, r=20, t=40, b=20))
    return fig