File size: 2,534 Bytes
0eb636f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d99321
 
 
 
 
 
 
 
0eb636f
 
9d99321
 
 
 
 
0eb636f
 
9d99321
 
 
 
 
0eb636f
 
9d99321
 
 
 
 
0eb636f
 
 
9d99321
 
 
 
0eb636f
 
 
 
 
 
 
 
 
9d99321
 
 
 
 
0eb636f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import plotly
import pandas as pd
from umap import UMAP
import plotly.express as px
from bertopic import BERTopic

from typing import Dict, List


def custom_topic_barchart(model: BERTopic, topic_labels: Dict[int, str], top_n_topics: int = 10, n_words: int = 10) -> plotly.graph_objs.Figure:
    """
    Create a custom horizontal bar chart of top topics using plotly.express.

    Topics are taken in the iteration order of ``topic_labels``. The outlier
    topic (id ``-1``) and topics for which the model returns no usable
    word/score pairs are skipped and do not count towards ``top_n_topics``.

    Args:
        model: Fitted topic model; only ``model.get_topic(topic_id)`` is used.
        topic_labels: Mapping of topic id to the label shown in the legend.
        top_n_topics: Maximum number of topics to include in the chart.
        n_words: Maximum number of words plotted per topic.

    Returns:
        A grouped horizontal bar chart with one bar per (topic, word) pair,
        or an empty ``Figure`` when there is nothing to plot.
    """
    rows = []
    topics_kept = 0

    for topic_id, label in topic_labels.items():
        # Stop once the requested number of topics has been collected.
        if topics_kept >= top_n_topics:
            break

        # -1 is the outlier/noise topic and is never charted.
        if topic_id == -1:
            continue

        # get_topic may return a non-list value (e.g. for an unknown topic
        # id), so validate the shape before slicing.
        topic = model.get_topic(topic_id)
        if not isinstance(topic, list) or not topic:
            continue

        # Keep only well-formed (word, score) pairs.
        topic_rows = [
            {"Topic": label, "Word": pair[0], "Score": pair[1]}
            for pair in topic[:n_words]
            if isinstance(pair, (list, tuple)) and len(pair) == 2
        ]
        if topic_rows:
            rows.extend(topic_rows)
            topics_kept += 1

    # Build the chart only if there is data; px.bar cannot resolve the
    # column names on an empty frame.
    if not rows:
        print("[WARN] No topic-word-score data to visualize.")
        return plotly.graph_objs.Figure()

    df = pd.DataFrame(rows)

    fig = px.bar(
        df,
        x="Score",
        y="Word",
        color="Topic",
        orientation="h",
        barmode="group",
    )

    fig.update_layout(
        margin=dict(l=40, r=20, t=40, b=20),
        yaxis=dict(title=""),
        xaxis=dict(title="Relevance"),
        legend_title_text="Topic",
    )

    return fig


def custom_umap_plot(embeddings: List[List[float]], topics: List[int], topic_labels: Dict[int, str]) -> plotly.graph_objs.Figure:
    """
    Custom UMAP plotting to work better with the Gradio layout.

    Projects the embeddings to two dimensions with UMAP and draws a scatter
    plot coloured by topic label. Points assigned to topic ``-1`` (noise)
    take part in the projection but are not drawn.

    Args:
        embeddings: One embedding vector per document.
        topics: Topic id for each document, aligned with ``embeddings``.
        topic_labels: Mapping of topic id to display label. Topic ids that
            have no entry fall back to the label ``"Topic <id>"``.

    Returns:
        A scatter plot of the 2-D projection, or an empty ``Figure`` when
        ``embeddings`` is empty.

    Raises:
        ValueError: If ``embeddings`` and ``topics`` differ in length.
    """
    n_samples = len(embeddings)

    # Fail before the expensive UMAP fit rather than at DataFrame assignment.
    if n_samples != len(topics):
        raise ValueError(
            f"embeddings and topics must have the same length "
            f"(got {n_samples} and {len(topics)})."
        )

    # UMAP cannot be fitted on an empty dataset.
    if n_samples == 0:
        print("[WARN] No embeddings to visualize.")
        return plotly.graph_objs.Figure()

    # Compute a safe number of neighbours: at most 15, never more than
    # n_samples - 1 where possible, and never below 2.
    safe_n_neighbours = min(15, max(2, n_samples - 1))

    # A fixed random_state keeps the layout the same between calls.
    reducer = UMAP(n_neighbors=safe_n_neighbours, min_dist=0.1, metric="cosine", random_state=42)
    umap_coords = reducer.fit_transform(embeddings)

    df = pd.DataFrame(umap_coords, columns=["x", "y"])
    df["topic"] = topics
    # Use .get so topics without a label (e.g. the noise topic -1, which is
    # filtered out below anyway) do not raise KeyError.
    df["label"] = [topic_labels.get(t, f"Topic {t}") for t in topics]

    # Filter out topic -1 (noise)
    df = df[df["topic"] != -1]

    fig = px.scatter(
        df,
        x="x",
        y="y",
        color="label",
        labels={"label": "Topic"},
    )

    fig.update_layout(margin=dict(l=20, r=20, t=40, b=20))

    return fig