textographe / app.py
gloignon's picture
back to working version
76026d0 verified
raw
history blame
2.98 kB
import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
from sklearn.decomposition import PCA
from sentence_transformers import SentenceTransformer
# Checkpoint name of the pre-trained sentence-transformer used for embeddings.
_MODEL_NAME = 'all-MiniLM-L6-v2'

# Instantiate the encoder once, at import time, so it is shared by all calls.
model = SentenceTransformer(_MODEL_NAME)
def compute_pca(data):
    """Embed the submitted texts and project the embeddings to 2-D with PCA.

    Args:
        data: Rows of ``[identifier, text]`` pairs (anything ``pd.DataFrame``
            can lay out under the columns "Identifier" and "Text").

    Returns:
        A Plotly scatter figure with one labelled point per usable row. When
        fewer than two usable rows are supplied, a Gradio update that clears
        the plot and puts an explanatory message in its label instead.
    """
    df = pd.DataFrame(data, columns=["Identifier", "Text"])

    # Missing values are treated as blank so they are filtered out here rather
    # than slipping through the comparison and reaching the encoder.
    identifiers = df["Identifier"].fillna("").astype(str).str.strip()
    texts = df["Text"].fillna("").astype(str).str.strip()
    valid_entries = df[(identifiers != "") & (texts != "")].reset_index(drop=True)

    # gr.update() is used instead of gr.Plot.update(), which newer Gradio
    # releases no longer provide.
    if valid_entries.empty:
        return gr.update(value=None, label="No data to process. Please fill in the boxes.")

    # A 2-component PCA cannot be fitted on a single sample; report that to the
    # user instead of letting scikit-learn raise.
    if len(valid_entries) < 2:
        return gr.update(
            value=None,
            label="Please fill in at least two identifier-text pairs.",
        )

    # Generate embeddings
    embeddings = model.encode(valid_entries["Text"].tolist())

    # Reduce the embeddings to 2 dimensions
    pca = PCA(n_components=2)
    pca_result = pca.fit_transform(embeddings)

    # Attach the projected coordinates to the rows they belong to
    valid_entries["PC1"] = pca_result[:, 0]
    valid_entries["PC2"] = pca_result[:, 1]

    # Plot the PCA result with identifiers as point labels
    fig = px.scatter(
        valid_entries,
        x="PC1",
        y="PC2",
        text="Identifier",
        title="PCA of Text Embeddings",
    )
    return fig
def text_editor_app(num_entries=4):
    """Build the Gradio Blocks interface for the text-embedding PCA demo.

    Args:
        num_entries: Number of identifier/text box pairs to display.
            Defaults to 4.

    Returns:
        The assembled ``gr.Blocks`` app. It is not launched here.

    Raises:
        ValueError: If ``num_entries`` is smaller than 2, since the interface
            asks for at least two identifier-text pairs.
    """
    if num_entries < 2:
        raise ValueError("num_entries must be at least 2")

    with gr.Blocks() as demo:
        gr.Markdown("### Enter at least two identifier-text pairs:")

        # Components are kept interleaved (identifier, text, identifier, ...)
        # because that is the positional order in which their values are
        # handed to the click handler below.
        inputs = []
        for i in range(1, num_entries + 1):
            with gr.Column():
                inputs.append(gr.Textbox(label=f"Identifier {i}"))
                inputs.append(gr.Textbox(label=f"Text {i}"))
            gr.Markdown("---")  # Horizontal rule to separate the entries

        # Button to run the analysis
        analyze_button = gr.Button("Run Analysis")

        # Output plot
        output_plot = gr.Plot(label="PCA Visualization")

        def collect_inputs(*args):
            """Pair the flat (id1, text1, id2, text2, ...) values and analyse them."""
            data = [
                [identifier, text]
                for identifier, text in zip(args[::2], args[1::2])
            ]
            return compute_pca(data)

        analyze_button.click(fn=collect_inputs, inputs=inputs, outputs=output_plot)

    return demo
# Build the interface first, then start serving it
demo = text_editor_app()
demo.launch()