textographe / app.py
gloignon's picture
Update app.py
68dda59 verified
raw
history blame
2.8 kB
import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
from sklearn.decomposition import PCA
from sentence_transformers import SentenceTransformer
# Sentence-transformer checkpoint used to embed every input text; it is
# instantiated once at import time and reused by compute_pca().
MODEL_NAME = 'all-MiniLM-L6-v2'
model = SentenceTransformer(MODEL_NAME)
# Function to calculate embeddings and PCA
def compute_pca(texts, ids):
    """Embed texts and plot their projection onto two principal components.

    Args:
        texts: Sequence of strings, encoded with the module-level ``model``.
        ids: Sequence of identifiers, one per text, drawn as point labels.

    Returns:
        A Plotly scatter figure with one labelled point per text
        (columns ``PC1`` / ``PC2``).

    Raises:
        ValueError: If ``texts`` and ``ids`` differ in length, or if fewer
            than two texts are supplied.
    """
    texts = list(texts)
    ids = list(ids)

    # Validate before encoding so bad input fails fast with a clear message
    # instead of surfacing later from inside pandas or scikit-learn.
    if len(texts) != len(ids):
        raise ValueError("Number of texts and identifiers must match.")
    if len(texts) < 2:
        # PCA cannot extract 2 components from fewer than 2 samples.
        raise ValueError("At least two texts are required to compute a 2-component PCA.")

    # Generate embeddings
    embeddings = model.encode(texts)

    # Reduce the embeddings to two dimensions
    pca_result = PCA(n_components=2).fit_transform(embeddings)

    # Create DataFrame for visualization
    df = pd.DataFrame({
        'ID': ids,
        'Text': texts,
        'PC1': pca_result[:, 0],
        'PC2': pca_result[:, 1],
    })

    # Plot the PCA result with identifiers as labels
    fig = px.scatter(df, x='PC1', y='PC2', text='ID', title='PCA of Text Embeddings')
    return fig
# Define Gradio app layout and interactions
def text_editor_app():
    """Build the Gradio interface for embedding texts and plotting their PCA.

    The layout has two multi-line textboxes (texts and identifiers, one entry
    per line), an editable table pairing them, a status line for input
    problems, a button, and a plot. Editing either textbox rebuilds the
    table; clicking the button passes the table rows to ``compute_pca``.

    Returns:
        The assembled ``gr.Blocks`` app; the caller is responsible for
        launching it.
    """

    def split_lines(raw):
        """Return the stripped, non-blank lines of a textbox value."""
        # Dropping blank lines means an empty textbox yields [] rather than
        # [''], so no phantom empty row appears in the table.
        return [line.strip() for line in (raw or "").splitlines() if line.strip()]

    def process_texts_and_ids(texts_raw, ids_raw):
        """Pair texts with identifiers; return ``(table_rows, status_message)``."""
        text_list = split_lines(texts_raw)
        id_list = split_lines(ids_raw)
        # Ensure both lists are of equal length. The message goes to the
        # status line rather than an error pop-up because this handler runs
        # on every edit, and the counts routinely differ mid-typing.
        if len(text_list) != len(id_list):
            return [], "Number of texts and identifiers must match."
        return [[ident, text] for ident, text in zip(id_list, text_list)], ""

    def cell_to_str(value):
        """Render a table cell as a stripped string (``None`` becomes '')."""
        return "" if value is None else str(value).strip()

    def run_pca(table):
        """Extract (id, text) rows from the table and plot their PCA."""
        # Normalise to a list of rows. The table is configured with
        # type="array", but pandas/numpy payloads are accepted too; iterating
        # a DataFrame directly would yield column labels, not rows.
        if table is None:
            table = []
        elif hasattr(table, "to_numpy"):
            table = table.to_numpy().tolist()
        elif hasattr(table, "tolist"):
            table = table.tolist()

        ids, texts = [], []
        for row in table:
            if len(row) < 2:
                continue
            text = cell_to_str(row[1])
            if not text:
                # Skip blank rows (e.g. the placeholder row of an empty table).
                continue
            ids.append(cell_to_str(row[0]))
            texts.append(text)

        if len(texts) < 2:
            # A 2-component PCA needs at least two samples.
            raise gr.Error("Enter at least two texts before computing the PCA.")
        return compute_pca(texts, ids)

    with gr.Blocks() as demo:
        # Input fields for text and identifier
        text_input = gr.Textbox(
            lines=5,
            placeholder="Enter or paste your texts here, one per line...",
            label="Text Inputs",
        )
        id_input = gr.Textbox(
            lines=5,
            placeholder="Enter an identifier for each text, one per line...",
            label="Identifier Inputs",
        )

        # Display the list of texts with identifiers. type="array" makes
        # Gradio pass the table to handlers as a list of rows.
        texts_df = gr.Dataframe(
            headers=["ID", "Text"],
            label="Text List with Identifiers",
            interactive=True,
            type="array",
        )

        # Status line for input problems (empty when the inputs are consistent)
        status = gr.Markdown()

        # Button to process texts and identifiers
        submit_button = gr.Button("Compute Embeddings and PCA")

        # Output plot
        output_plot = gr.Plot(label="PCA Visualization")

        # Define the button click interaction
        submit_button.click(fn=run_pca, inputs=texts_df, outputs=output_plot)

        # Update the table (and status line) whenever either textbox changes
        for box in (text_input, id_input):
            box.change(
                fn=process_texts_and_ids,
                inputs=[text_input, id_input],
                outputs=[texts_df, status],
            )

    return demo
# Build the interface, then start serving it
demo = text_editor_app()
demo.launch()