marksverdhei's picture
It works now
644a030
raw
history blame
3.48 kB
import streamlit as st
import pandas as pd
import vec2text
import torch
from transformers import AutoModel, AutoTokenizer
from umap import UMAP
from tqdm import tqdm
import plotly.express as px
import numpy as np
from sklearn.decomposition import PCA
from streamlit_plotly_events import plotly_events
import plotly.graph_objects as go
import logging
# Register tqdm's pandas integration (the `progress_apply`-style helpers).
# NOTE(review): nothing in the visible code calls those helpers — confirm
# whether this registration is still needed.
tqdm.pandas()
@st.cache_resource
def vector_compressor_from_config():
    """Build the reducer that projects embeddings down to two dimensions.

    TODO: choose the reducer (e.g. PCA vs. UMAP) from configuration instead
    of hard-coding it here.

    Returns:
        An unfitted ``UMAP`` instance configured for a 2-D output. The
        instance is held in Streamlit's resource cache.
    """
    # Unlike PCA, UMAP's first positional parameter is `n_neighbors`, not
    # `n_components`, so the output dimensionality is passed by keyword.
    # `n_neighbors` is left at the library default.
    return UMAP(n_components=2)
# The CSV is fetched from a remote URL, so the parsed frame is kept in
# Streamlit's data cache.
@st.cache_data
def load_data():
    """Read the reddit-syac-urls training CSV and return it as a DataFrame."""
    csv_url = "https://huggingface.co/datasets/marksverdhei/reddit-syac-urls/resolve/main/train.csv"
    frame = pd.read_csv(csv_url)
    return frame


df = load_data()
# Cache the model and tokenizer so they are loaded once rather than on every
# script rerun.
@st.cache_resource
def load_model_and_tokenizer():
    """Load the GTR-T5 encoder and its tokenizer.

    Returns:
        tuple: ``(encoder, tokenizer)``. ``encoder`` is the ``.encoder``
        sub-module of the pretrained model, placed on CUDA when a GPU is
        available and on the CPU otherwise.
    """
    model_name = "sentence-transformers/gtr-t5-base"
    # Fall back to the CPU instead of crashing on hosts without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    encoder = AutoModel.from_pretrained(model_name).encoder.to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return encoder, tokenizer


encoder, tokenizer = load_model_and_tokenizer()
# Keep a single vec2text corrector instance in the resource cache.
@st.cache_resource
def load_corrector():
    """Return the pretrained vec2text corrector registered as "gtr-base"."""
    pretrained_corrector = vec2text.load_pretrained_corrector("gtr-base")
    return pretrained_corrector


corrector = load_corrector()
# The precomputed embeddings are read from a local file; cache the loaded
# array so it is not re-read on every rerun.
@st.cache_data
def load_embeddings():
    """Load the precomputed title embeddings from the local ``.npy`` file."""
    embeddings_file = "syac-title-embeddings.npy"
    return np.load(embeddings_file)


embeddings = load_embeddings()
# The dimensionality reduction is a heavy computation, so both its output and
# the fitted reducer are cached.
@st.cache_resource
def reduce_embeddings(embeddings):
    """Fit the configured reducer on ``embeddings`` and project them.

    Args:
        embeddings: The embedding matrix handed to ``fit_transform``.

    Returns:
        tuple: ``(projected, reducer)`` — the result of ``fit_transform`` and
        the fitted reducer object itself.
    """
    fitted_reducer = vector_compressor_from_config()
    projected = fitted_reducer.fit_transform(embeddings)
    return projected, fitted_reducer


vectors_2d, reducer = reduce_embeddings(embeddings)
# Build the Plotly scatter plot of the projected points; each point's hover
# tooltip carries the matching entry of df["title"].
_axis_labels = {'x': 'UMAP Dimension 1', 'y': 'UMAP Dimension 2'}
_hover_columns = {"Title": df["title"]}
_point_colors = ["#01a8d3"]  # single default blue used for every point

fig = px.scatter(
    x=vectors_2d[:, 0],
    y=vectors_2d[:, 1],
    opacity=0.6,
    hover_data=_hover_columns,
    labels=_axis_labels,
    title="UMAP Scatter Plot of Reddit Titles",
    color_discrete_sequence=_point_colors,
)

# No explicit template and fully transparent backgrounds, so the figure does
# not impose its own light/dark colors on the page.
_transparent = "rgba(0, 0, 0, 0)"
fig.update_layout(
    template=None,
    plot_bgcolor=_transparent,
    paper_bgcolor=_transparent,
)
# Render the scatter plot and capture click events from it.
selected_points = plotly_events(
    fig,
    click_event=True,
    hover_event=False,
    override_height=600,
    override_width="100%",
)

# When a point has been clicked, map its 2-D coordinates back into embedding
# space with the fitted reducer and invert that embedding to text.
if selected_points:
    clicked_point = selected_points[0]
    x_coord = clicked_point['x']
    y_coord = clicked_point['y']

    st.text(f"Embeddings shape: {embeddings.shape}")
    st.text(f"2dvector shapes shape: {vectors_2d.shape}")
    st.text(f"Clicked point coordinates: x = {x_coord}, y = {y_coord}")

    # inverse_transform takes a 2-D array: one row holding the clicked point.
    logging.info("Inverse-transforming clicked point (%s, %s)", x_coord, y_coord)
    inferred_embedding = reducer.inverse_transform(np.array([[x_coord, y_coord]]))
    # Cast to float32 before building the tensor passed to vec2text.
    inferred_embedding = inferred_embedding.astype("float32")

    # Use the GPU when present, but do not crash on CPU-only hosts.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logging.info("Inverting embedding with vec2text on %s", device)
    output = vec2text.invert_embeddings(
        embeddings=torch.tensor(inferred_embedding).to(device),
        corrector=corrector,
        num_steps=20,
    )

    st.text(str(output))
    st.text(str(inferred_embedding))
else:
    st.text("Click on a point in the scatterplot to see its coordinates.")