File size: 3,484 Bytes
8aa44e7
 
 
 
644a030
8aa44e7
 
 
 
93f5069
644a030
 
 
8aa44e7
 
 
93f5069
 
 
644a030
 
93f5069
8aa44e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93f5069
8aa44e7
 
 
 
 
 
 
 
644a030
8aa44e7
 
644a030
 
 
 
 
 
 
 
 
8aa44e7
 
644a030
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import streamlit as st
import pandas as pd
import vec2text
import torch
from transformers import AutoModel, AutoTokenizer
from umap import UMAP
from tqdm import tqdm
import plotly.express as px
import numpy as np
from sklearn.decomposition import PCA
from streamlit_plotly_events import plotly_events
import plotly.graph_objects as go
import logging
# Register tqdm's pandas integration (adds `progress_apply` / `progress_map`
# to DataFrame/Series objects).
# NOTE(review): nothing in this file calls progress_apply/progress_map, so this
# looks removable — confirm before deleting.
tqdm.pandas()

@st.cache_resource
def vector_compressor_from_config():
    """Return the (unfitted) reducer that projects embeddings to 2-D.

    The instance is cached by Streamlit and fitted later by the caller.

    TODO: read the reducer choice from configuration; ``PCA(n_components=2)``
    is the drop-in alternative.
    """
    # UMAP's first positional parameter is ``n_neighbors``, not
    # ``n_components``: ``UMAP(2)`` meant n_neighbors=2, which makes UMAP
    # focus on extremely local structure. Pass the intended argument by name.
    return UMAP(n_components=2)

# Cache the DataFrame: downloading the CSV from the Hugging Face Hub on every
# Streamlit rerun would be slow.
@st.cache_data
def load_data():
    """Download the Reddit SYAC training split and return it as a DataFrame."""
    csv_url = "https://huggingface.co/datasets/marksverdhei/reddit-syac-urls/resolve/main/train.csv"
    frame = pd.read_csv(csv_url)
    return frame

df = load_data()

# Cache the model and tokenizer so they are loaded once, not on every rerun.
@st.cache_resource
def load_model_and_tokenizer():
    """Load the GTR-T5-base encoder stack and its tokenizer.

    Returns:
        (encoder, tokenizer): the ``.encoder`` sub-module of
        ``sentence-transformers/gtr-t5-base`` moved to the GPU when CUDA is
        available (CPU otherwise), and the matching tokenizer.
    """
    model_name = "sentence-transformers/gtr-t5-base"
    # Fall back to CPU rather than raising a CUDA error on GPU-less hosts.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    encoder = AutoModel.from_pretrained(model_name).encoder.to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return encoder, tokenizer

encoder, tokenizer = load_model_and_tokenizer()

# Cache the vec2text corrector so its models are loaded once and reused across
# reruns.
@st.cache_resource
def load_corrector():
    """Return vec2text's pretrained corrector for GTR-base embeddings."""
    pretrained = vec2text.load_pretrained_corrector("gtr-base")
    return pretrained

corrector = load_corrector()

# Cache the precomputed title embeddings: the .npy file is local but large.
@st.cache_data
def load_embeddings():
    """Read the precomputed title embeddings from ``syac-title-embeddings.npy``."""
    embedding_path = "syac-title-embeddings.npy"
    loaded = np.load(embedding_path)
    return loaded

embeddings = load_embeddings()

# Cache the dimensionality reduction: fitting the reducer is expensive.
# NOTE(review): Streamlit hashes the `embeddings` argument to build the cache
# key, which may be costly for a large array on every rerun — confirm.
@st.cache_resource
def reduce_embeddings(embeddings):
    """Fit the configured 2-D reducer on ``embeddings``.

    Returns:
        (points_2d, reducer): the projected coordinates, plus the fitted
        reducer so clicked plot coordinates can later be mapped back through
        ``inverse_transform``.
    """
    fitted = vector_compressor_from_config()
    points_2d = fitted.fit_transform(embeddings)
    return points_2d, fitted

vectors_2d, reducer = reduce_embeddings(embeddings)

# Scatter plot of the 2-D projection; hovering a point shows its Reddit title.
# Axis labels and title use the fitted reducer's class name (e.g. "UMAP",
# "PCA") so they stay accurate if vector_compressor_from_config() changes.
reducer_name = type(reducer).__name__
fig = px.scatter(
    x=vectors_2d[:, 0],
    y=vectors_2d[:, 1],
    opacity=0.6,
    hover_data={"Title": df["title"]},
    labels={"x": f"{reducer_name} Dimension 1", "y": f"{reducer_name} Dimension 2"},
    title=f"{reducer_name} Scatter Plot of Reddit Titles",
    color_discrete_sequence=["#01a8d3"],  # one blue colour for every point
)

# Transparent plot/paper backgrounds so the surrounding page colour shows through.
# NOTE(review): plotly itself does not read the browser's light/dark preference;
# confirm that template=None renders acceptably in both Streamlit themes.
fig.update_layout(
    template=None,
    plot_bgcolor="rgba(0, 0, 0, 0)",
    paper_bgcolor="rgba(0, 0, 0, 0)",
)

# Render through streamlit_plotly_events so click events come back to Python.
# `selected_points` is presumably an empty list until a point is clicked.
selected_points = plotly_events(fig, click_event=True, hover_event=False, override_height=600, override_width="100%")


# When a point is clicked, map its 2-D coordinates back into embedding space
# with the reducer's inverse transform, then decode that embedding to text
# with vec2text.
if selected_points:
    clicked_point = selected_points[0]
    x_coord = clicked_point['x']
    y_coord = clicked_point['y']
    st.text(f"Embeddings shape: {embeddings.shape}")
    st.text(f"2D vectors shape: {vectors_2d.shape}")
    st.text(f"Clicked point coordinates: x = {x_coord}, y = {y_coord}")

    # inverse_transform takes a batch of 2-D points; pass a batch of one.
    inferred_embedding = reducer.inverse_transform(np.array([[x_coord, y_coord]]))
    # Force float32 so torch.tensor yields a float32 tensor; the inverse
    # transform may return float64, which would become a torch.float64 tensor.
    inferred_embedding = inferred_embedding.astype("float32")

    # Run the inversion on the GPU when available, otherwise on CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    output = vec2text.invert_embeddings(
        embeddings=torch.tensor(inferred_embedding).to(device),
        corrector=corrector,
        num_steps=20,
    )

    st.text(str(output))
    st.text(str(inferred_embedding))
else:
    st.text("Click on a point in the scatterplot to see its coordinates.")