Spaces:
Sleeping
Sleeping
import gradio as gr | |
import numpy as np | |
import pandas as pd | |
import plotly.express as px | |
from sklearn.decomposition import PCA | |
from sentence_transformers import SentenceTransformer | |
# Shared sentence-embedding model, instantiated once at import time and
# reused by every analysis request.
model = SentenceTransformer("all-MiniLM-L6-v2")
# Compute sentence embeddings for the submitted texts and project them to 2-D with PCA.
# The input is a list of (identifier, text) pairs collected from the Gradio textboxes.
def _is_blank(value):
    """Return True for None, NaN, or a value whose string form is only whitespace."""
    if value is None or (isinstance(value, float) and value != value):
        return True
    return not str(value).strip()


def compute_pca(data):
    """Embed the submitted texts and project the embeddings onto two principal components.

    Args:
        data: Rows of ``(identifier, text)`` values, in any form that
            ``pd.DataFrame(data, columns=["Identifier", "Text"])`` accepts.
            Rows whose identifier or text is missing or blank are ignored.

    Returns:
        A Plotly scatter figure with one point per valid row, labelled by its
        identifier. When fewer than two valid rows remain, a Gradio update
        that clears the plot and puts a hint in its label is returned instead.
    """
    df = pd.DataFrame(data, columns=["Identifier", "Text"])

    # Build the mask element-wise so None/NaN count as blank. With the
    # ``.str.strip() != ''`` form, missing values compare as "not empty" and
    # would be handed to the encoder.
    blank_identifier = df["Identifier"].map(_is_blank).astype(bool)
    blank_text = df["Text"].map(_is_blank).astype(bool)
    # reset_index returns a new frame, so adding columns below never writes
    # into a view of ``df``.
    valid_entries = df[~(blank_identifier | blank_text)].reset_index(drop=True)

    if valid_entries.empty:
        # gr.update is the version-independent way to change component
        # properties; the per-component ``.update`` classmethod is not
        # available in newer Gradio releases.
        return gr.update(value=None, label="No data to process. Please fill in the boxes.")

    if len(valid_entries) < 2:
        # PCA cannot extract two components from a single sample.
        return gr.update(value=None, label="Please fill in at least two entries to run the analysis.")

    # Generate embeddings
    embeddings = model.encode(valid_entries["Text"].astype(str).tolist())

    # Perform PCA to reduce to 2 dimensions
    pca_result = PCA(n_components=2).fit_transform(embeddings)
    valid_entries["PC1"] = pca_result[:, 0]
    valid_entries["PC2"] = pca_result[:, 1]

    # Plot the PCA result with identifiers as labels
    return px.scatter(
        valid_entries,
        x="PC1",
        y="PC2",
        text="Identifier",
        title="PCA of Text Embeddings",
    )
def text_editor_app(num_entries=4):
    """Build the Gradio Blocks UI for entering labelled texts and plotting their PCA.

    Args:
        num_entries: Number of identifier/text input pairs to render.
            Defaults to 4.

    Returns:
        The ``gr.Blocks`` app. The caller is responsible for launching it.
    """
    with gr.Blocks() as demo:
        identifiers = []
        texts = []
        with gr.Row():
            for i in range(num_entries):
                with gr.Column():
                    identifiers.append(gr.Textbox(label=f"Identifier {i+1}"))
                    texts.append(gr.Textbox(label=f"Text {i+1}"))
        analyze_button = gr.Button("Run Analysis")
        output_plot = gr.Plot(label="PCA Visualization")

        def collect_inputs(*args):
            """Pair each identifier value with its text value and run the analysis."""
            # The click handler receives all identifier values first, then all
            # text values, matching the order of ``inputs`` below. Splitting at
            # num_entries keeps this in step with the number of boxes rendered.
            data = list(zip(args[:num_entries], args[num_entries:]))
            return compute_pca(data)

        analyze_button.click(
            fn=collect_inputs,
            inputs=identifiers + texts,
            outputs=output_plot,
        )
    return demo
# Build the interface, then start serving it.
demo = text_editor_app()
demo.launch()