|
import streamlit as st |
|
import pandas as pd |
|
import vec2text |
|
import torch |
|
from transformers import AutoModel, AutoTokenizer |
|
from umap import UMAP |
|
from tqdm import tqdm |
|
import plotly.express as px |
|
import numpy as np |
|
from sklearn.decomposition import PCA |
|
from streamlit_plotly_events import plotly_events |
|
import plotly.graph_objects as go |
|
import logging |
|
|
|
# Register tqdm's pandas integration (adds e.g. `progress_apply`).
# NOTE(review): nothing visible in this script calls a progress_* method;
# confirm this is still needed.
tqdm.pandas()
|
|
|
@st.cache_resource
def vector_compressor_from_config(*, n_components: int = 2, n_neighbors: int = 15):
    """Build the (unfitted) reducer used to project embeddings for plotting.

    Configuration-driven selection of the reducer is not implemented yet;
    this always returns a UMAP instance.

    Args:
        n_components: Output dimensionality. The scatter plot and the
            click handler work with two coordinates, so keep this at 2.
        n_neighbors: UMAP neighbourhood size (UMAP's own default is 15).

    Returns:
        An unfitted ``umap.UMAP`` instance. Because of ``st.cache_resource``,
        repeated calls with the same arguments return the same object.
    """
    # Pass by keyword: UMAP's first positional parameter is ``n_neighbors``,
    # so ``UMAP(2)`` would request a 2-neighbour graph, not 2 output dimensions.
    return UMAP(n_components=n_components, n_neighbors=n_neighbors)
|
|
|
|
|
@st.cache_data
def load_data():
    """Download the reddit-syac-urls ``train.csv`` split into a DataFrame.

    Cached with ``st.cache_data`` so the download happens once, not on every rerun.
    """
    csv_url = "https://huggingface.co/datasets/marksverdhei/reddit-syac-urls/resolve/main/train.csv"
    frame = pd.read_csv(csv_url)
    return frame


# The "title" column is used as hover text in the scatter plot below.
df = load_data()
|
|
|
|
|
@st.cache_resource
def load_model_and_tokenizer():
    """Load the GTR-T5-base encoder stack (moved to CUDA) and its tokenizer.

    Only the ``.encoder`` sub-module of the loaded model is kept.

    Returns:
        Tuple ``(encoder, tokenizer)``.
    """
    checkpoint = "sentence-transformers/gtr-t5-base"
    gtr_encoder = AutoModel.from_pretrained(checkpoint).encoder
    gtr_encoder = gtr_encoder.to("cuda")
    gtr_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    return gtr_encoder, gtr_tokenizer


# NOTE(review): neither name is referenced elsewhere in this script.
encoder, tokenizer = load_model_and_tokenizer()
|
|
|
|
|
@st.cache_resource
def load_corrector():
    """Load vec2text's pretrained ``"gtr-base"`` corrector, cached across reruns."""
    pretrained_corrector = vec2text.load_pretrained_corrector("gtr-base")
    return pretrained_corrector


# Used by vec2text.invert_embeddings when a point is clicked.
corrector = load_corrector()
|
|
|
|
|
@st.cache_data
def load_embeddings():
    """Read the precomputed title embeddings from ``syac-title-embeddings.npy``.

    The path is relative, so it resolves against the process working directory.
    """
    embedding_file = "syac-title-embeddings.npy"
    loaded = np.load(embedding_file)
    return loaded


# NOTE(review): assumed to be row-aligned with df["title"]; nothing checks this.
embeddings = load_embeddings()
|
|
|
|
|
@st.cache_resource
def reduce_embeddings(embeddings):
    """Fit the reducer from ``vector_compressor_from_config`` on ``embeddings``.

    Returns:
        Tuple ``(projected, reducer)``. ``projected`` is the ``fit_transform``
        output; ``reducer`` is the fitted instance, kept so clicked plot
        coordinates can later be mapped back with ``inverse_transform``.
    """
    compressor = vector_compressor_from_config()
    projected = compressor.fit_transform(embeddings)
    return projected, compressor


vectors_2d, reducer = reduce_embeddings(embeddings)
|
|
|
|
|
# One marker per title, placed at its reduced (x, y) coordinates; hovering shows the title.
fig = px.scatter(
    x=vectors_2d.T[0],
    y=vectors_2d.T[1],
    opacity=0.6,
    hover_data=dict(Title=df["title"]),
    labels=dict(x='UMAP Dimension 1', y='UMAP Dimension 2'),
    title="UMAP Scatter Plot of Reddit Titles",
    color_discrete_sequence=["#01a8d3"],
).update_layout(
    template=None,
    # Fully transparent plot and paper backgrounds.
    plot_bgcolor="rgba(0, 0, 0, 0)",
    paper_bgcolor="rgba(0, 0, 0, 0)",
)

# Render via streamlit_plotly_events so click events come back to the script.
selected_points = plotly_events(
    fig,
    click_event=True,
    hover_event=False,
    override_height=600,
    override_width="100%",
)
|
|
|
|
|
|
|
if selected_points:
    # Only the first clicked point is used.
    clicked_point = selected_points[0]
    x_coord = clicked_point['x']
    y_coord = clicked_point['y']

    st.text(f"Embeddings shape: {embeddings.shape}")
    st.text(f"2dvector shapes shape: {vectors_2d.shape}")
    st.text(f"Clicked point coordinates: x = {x_coord}, y = {y_coord}")

    with st.spinner("Inverting embedding..."):
        # Map the clicked 2-D point back into the original embedding space;
        # the reducer is given a single-row (1, 2) array.
        inferred_embedding = reducer.inverse_transform(np.array([[x_coord, y_coord]]))
        # Cast to float32 before building the CUDA tensor
        # (presumably to match the corrector model's dtype).
        inferred_embedding = inferred_embedding.astype("float32")

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            output = vec2text.invert_embeddings(
                embeddings=torch.tensor(inferred_embedding).cuda(),
                corrector=corrector,
                num_steps=20,
            )

    st.text(str(output))
    st.text(str(inferred_embedding))
else:
    st.text("Click on a point in the scatterplot to see its coordinates.")