import plotly
import pandas as pd
from umap import UMAP
import plotly.express as px
from bertopic import BERTopic

from typing import Dict, List


def custom_topic_barchart(model: BERTopic, topic_labels: Dict[int, str], top_n_topics: int = 10, n_words: int = 10) -> plotly.graph_objs.Figure:
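    """
    Build a horizontal grouped bar chart of the top words per topic with plotly.express.

    Only the first `top_n_topics` topics (lowest ids, skipping the outlier
    topic -1) are plotted, each with up to `n_words` words.
    """
    # BERTopic numbers topics by descending size by default, so the lowest ids are the largest topics.
    topic_ids = sorted(t for t in topic_labels if t != -1)[:top_n_topics]

    data = []

    for topic_id in topic_ids:
        # get_topic returns a falsy value when the topic id is unknown to the model.
        words = model.get_topic(topic_id) or []

        for word, score in words[:n_words]:
            data.append({"Topic": topic_labels[topic_id], "Word": word, "Score": score})

    # Explicit columns so the column names still exist when no rows were collected.
    df = pd.DataFrame(data, columns = ["Topic", "Word", "Score"])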

    fig = px.bar(
        df,
        x = "Score",
        y = "Word",
        color = "Topic",
        orientation = 'h',
        barmode = "group",
        #height = 500,
    )

    fig.update_layout(
        margin = dict(l = 40, r = 20, t = 40, b = 20),
        yaxis = dict(title = ""),
        xaxis = dict(title = "Relevance"),
        legend_title_text = "Topic",
    )

    return fig



def custom_umap_plot(embeddings: List[List[float]], topics: List[int], topic_labels: Dict[int, str]) -> plotly.graph_objs.Figure:
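    """
    Project the embeddings to 2D with UMAP and return a scatter plot colored by topic label,
    with tight margins so it fits the Gradio layout.
    """
    # A fixed random_state keeps the projection reproducible across runs.
    reducer = UMAP(n_neighbors = 15, min_dist = 0.1, metric = "cosine", random_state = 42)
    umap_coords = reducer.fit_transform(embeddings)

    df = pd.DataFrame(umap_coords, columns=["x", "y"])
    df["topic"] = topics

    # Filter out topic -1 (noise) before labelling, so topic_labels need not contain -1
    df = df[df["topic"] != -1].copy()
    # Fall back to the numeric id for any topic without a label (assumed fallback).
    df["label"] = [topic_labels.get(t, str(t)) for t in df["topic"]]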

    fig = px.scatter(
        df,
        x = 'x',
        y = 'y',
        color = "label",
        labels = {"label": "Topic"},
        #height = 500
    )

    fig.update_layout(margin = dict(l = 20, r = 20, t = 40, b = 20))

    return fig