# Hugging Face file-view header captured with the source (not part of the program):
# textographe / app.py
# gloignon's picture
# Update app.py
# f8825b7 verified
# raw
# history blame
# 2.37 kB
import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
from sklearn.decomposition import PCA
from sentence_transformers import SentenceTransformer
# Pre-trained sentence-embedding model, instantiated once at import time and
# reused by every embedding request in this module.
_EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
model = SentenceTransformer(_EMBEDDING_MODEL_NAME)
def compute_pca(data):
    """Embed the submitted texts and plot them on two principal components.

    Args:
        data: Rows of ``(identifier, text)``. Anything ``pd.DataFrame`` can lay
            out under the columns ``"Identifier"`` and ``"Text"`` is accepted
            (pairs, or dicts using those two keys).

    Returns:
        A Plotly scatter figure with one point per usable row, labelled with
        that row's identifier. If fewer than two usable rows remain, a
        ``gr.update`` payload that clears the plot and carries an explanatory
        message in its label.
    """
    df = pd.DataFrame(data, columns=["Identifier", "Text"])

    # Treat missing cells (None/NaN) as empty strings and force everything to
    # str, so the whitespace filter rejects them instead of letting them slip
    # through to the encoder.
    df = df.fillna("").astype(str)

    # Keep only rows where both fields hold something other than whitespace.
    has_identifier = df["Identifier"].str.strip() != ""
    has_text = df["Text"].str.strip() != ""
    valid_entries = df[has_identifier & has_text].reset_index(drop=True)

    if valid_entries.empty:
        # Use the generic gr.update helper: the per-component
        # `gr.Plot.update` classmethod was removed in Gradio 4.
        return gr.update(value=None, label="No data to process. Please fill in the boxes.")

    if len(valid_entries) < 2:
        # PCA(n_components=2) requires n_components <= n_samples, so a single
        # row would raise inside scikit-learn; report it to the user instead.
        return gr.update(value=None, label="At least two filled-in entries are needed for PCA.")

    # One embedding vector per remaining text.
    embeddings = model.encode(valid_entries["Text"].tolist())

    # Project the embeddings onto their first two principal components.
    pca_result = PCA(n_components=2).fit_transform(embeddings)

    # Row order is preserved end to end, so the coordinates line up with the
    # freshly re-indexed frame.
    valid_entries = valid_entries.assign(PC1=pca_result[:, 0], PC2=pca_result[:, 1])

    # Scatter plot with each point labelled by its identifier.
    fig = px.scatter(
        valid_entries,
        x="PC1",
        y="PC2",
        text="Identifier",
        title="PCA of Text Embeddings",
    )
    return fig
def text_editor_app(num_entries=4):
    """Build the Gradio Blocks interface for entering texts and plotting them.

    The layout is a single row of ``num_entries`` columns, each holding an
    identifier box and a text box, followed by a "Run Analysis" button and a
    plot area. Clicking the button sends every box's value to ``compute_pca``.

    Args:
        num_entries: Number of identifier/text pairs to display. Defaults to
            4, the count that was previously hard-coded.

    Returns:
        The assembled ``gr.Blocks`` object; the caller is responsible for
        launching it.
    """
    with gr.Blocks() as demo:
        identifiers = []
        texts = []

        with gr.Row():
            for i in range(num_entries):
                with gr.Column():
                    identifiers.append(gr.Textbox(label=f"Identifier {i+1}"))
                    texts.append(gr.Textbox(label=f"Text {i+1}"))

        analyze_button = gr.Button("Run Analysis")
        output_plot = gr.Plot(label="PCA Visualization")

        def collect_inputs(*args):
            # Values arrive positionally in the order of `inputs` below: all
            # identifiers first, then all texts. The split point is derived
            # from the components actually created so it cannot drift from
            # the layout loop.
            split = len(identifiers)
            data = list(zip(args[:split], args[split:]))  # (identifier, text) pairs
            return compute_pca(data)

        analyze_button.click(
            fn=collect_inputs,
            inputs=identifiers + texts,
            outputs=output_plot,
        )

    return demo
# Build the interface, then start serving it.
_app = text_editor_app()
_app.launch()