Spaces:
Sleeping
Sleeping
import gradio as gr | |
import numpy as np | |
import pandas as pd | |
import plotly.express as px | |
from sklearn.decomposition import PCA | |
from sentence_transformers import SentenceTransformer | |
# Sentence-embedding checkpoint shared by every request; loaded once at import
# so the weights are not re-read on each button click.
MODEL_NAME = 'all-MiniLM-L6-v2'
model = SentenceTransformer(MODEL_NAME)
def compute_pca(data):
    """Embed the submitted texts and plot their first two principal components.

    Args:
        data: Sequence of ``[identifier, text]`` pairs (anything
            ``pd.DataFrame`` accepts as two-column rows). Pairs whose
            identifier or text is missing or whitespace-only are ignored.

    Returns:
        A Plotly scatter figure with one labelled point per valid pair, or a
        Gradio update that clears the plot and shows an explanatory label
        when fewer than two valid pairs were supplied.
    """
    df = pd.DataFrame(data, columns=["Identifier", "Text"])

    # A missing value (None/NaN) would compare unequal to '' and slip through
    # the emptiness filter, so normalise to stripped strings for the check.
    # The rows themselves are kept unmodified.
    identifiers = df["Identifier"].fillna("").astype(str).str.strip()
    texts = df["Text"].fillna("").astype(str).str.strip()
    valid_entries = df[(identifiers != "") & (texts != "")]
    valid_entries = valid_entries.reset_index(drop=True)

    if valid_entries.empty:
        return gr.update(
            value=None,
            label="No data to process. Please fill in the boxes.",
        )

    # PCA with n_components=2 requires at least two samples; a single filled
    # pair would otherwise raise ValueError inside scikit-learn.
    if len(valid_entries) < 2:
        return gr.update(
            value=None,
            label="Please fill in at least two identifier-text pairs.",
        )

    # Generate embeddings, then project them onto two principal components.
    embeddings = model.encode(valid_entries["Text"].tolist())
    pca_result = PCA(n_components=2).fit_transform(embeddings)

    valid_entries["PC1"] = pca_result[:, 0]
    valid_entries["PC2"] = pca_result[:, 1]

    # Identifiers are drawn as text labels next to their points.
    return px.scatter(
        valid_entries,
        x="PC1",
        y="PC2",
        text="Identifier",
        title="PCA of Text Embeddings",
    )
def text_editor_app(num_entries=4):
    """Build the Gradio Blocks UI for the text-embedding PCA explorer.

    Args:
        num_entries: Number of identifier/text input pairs to render.
            Defaults to 4, the count that was previously hard-coded.

    Returns:
        The assembled ``gr.Blocks`` app; the caller is responsible for
        launching it.
    """
    with gr.Blocks() as demo:
        gr.Markdown("### Enter at least two identifier-text pairs:")

        # Components are stored interleaved (id1, text1, id2, text2, ...)
        # because the click handler receives their values positionally in
        # exactly this order.
        inputs = []
        for i in range(num_entries):
            with gr.Column():
                inputs.append(gr.Textbox(label=f"Identifier {i+1}"))
                inputs.append(gr.Textbox(label=f"Text {i+1}"))
                gr.Markdown("---")  # Horizontal rule separating entries

        # Button to run the analysis
        analyze_button = gr.Button("Run Analysis")

        # Output plot
        output_plot = gr.Plot(label="PCA Visualization")

        def collect_inputs(*args):
            """Regroup the flat textbox values into [identifier, text] pairs."""
            data = [list(pair) for pair in zip(args[::2], args[1::2])]
            return compute_pca(data)

        analyze_button.click(
            fn=collect_inputs,
            inputs=inputs,
            outputs=output_plot,
        )

    return demo
# Build the interface, then start serving it.
demo = text_editor_app()
demo.launch()