gloignon committed on
Commit
1b5bc8f
·
verified ·
1 Parent(s): 1014197

Trying new generic method

Browse files
Files changed (1) hide show
  1. app.py +17 -36
app.py CHANGED
@@ -9,64 +9,45 @@ from sentence_transformers import SentenceTransformer
9
  model = SentenceTransformer('all-MiniLM-L6-v2')
10
 
11
  # Function to compute document embeddings and apply PCA
12
- def compute_pca(id1, text1, id2, text2, id3, text3, id4, text4):
13
- # Collect identifiers and texts into lists
14
- identifiers = [id1, id2, id3, id4]
15
- texts = [text1, text2, text3, text4]
16
-
17
- # Filter out any empty inputs
18
- valid_entries = [(id, text) for id, text in zip(identifiers, texts) if id and text]
19
- if not valid_entries:
20
  return gr.Plot.update(value=None, label="No data to process. Please fill in the boxes.")
21
 
22
- # Unzip identifiers and texts
23
- identifiers, texts = zip(*valid_entries)
24
-
25
  # Generate embeddings
26
- embeddings = model.encode(texts)
27
 
28
  # Perform PCA to reduce to 2 dimensions
29
  pca = PCA(n_components=2)
30
  pca_result = pca.fit_transform(embeddings)
31
 
32
- # Create DataFrame for visualization
33
- result_df = pd.DataFrame({
34
- 'Identifier': identifiers,
35
- 'PC1': pca_result[:, 0],
36
- 'PC2': pca_result[:, 1]
37
- })
38
 
39
  # Plot the PCA result with identifiers as labels
40
- fig = px.scatter(result_df, x='PC1', y='PC2', text='Identifier', title='PCA of Text Embeddings')
41
  return fig
42
 
43
- # Gradio interface
44
  def text_editor_app():
45
  with gr.Blocks() as demo:
46
- # Input boxes for four identifier-text pairs
47
- with gr.Row():
48
- id1 = gr.Textbox(label="Identifier 1")
49
- text1 = gr.Textbox(label="Text 1")
50
- with gr.Row():
51
- id2 = gr.Textbox(label="Identifier 2")
52
- text2 = gr.Textbox(label="Text 2")
53
- with gr.Row():
54
- id3 = gr.Textbox(label="Identifier 3")
55
- text3 = gr.Textbox(label="Text 3")
56
- with gr.Row():
57
- id4 = gr.Textbox(label="Identifier 4")
58
- text4 = gr.Textbox(label="Text 4")
59
 
60
  # Button to run the analysis
61
  analyze_button = gr.Button("Run Analysis")
62
-
63
  # Output plot
64
  output_plot = gr.Plot(label="PCA Visualization")
65
-
66
  # Run analysis when the button is clicked
67
- analyze_button.click(fn=compute_pca, inputs=[id1, text1, id2, text2, id3, text3, id4, text4], outputs=output_plot)
68
 
69
  return demo
70
 
 
71
  # Launch the app
72
  text_editor_app().launch()
 
9
  model = SentenceTransformer('all-MiniLM-L6-v2')
10
 
11
  # Function to compute document embeddings and apply PCA
12
+ # Takes the table of identifiers and texts supplied by the Gradio interface
13
def compute_pca(data):
    """Embed each row's text and plot its first two principal components.

    Args:
        data: Tabular input accepted by ``pd.DataFrame`` that yields
            ``'Identifier'`` and ``'Text'`` columns (for example the value
            passed in by the ``gr.Dataframe`` input component).

    Returns:
        A Plotly scatter figure with one point per usable row, labelled by
        its identifier. When fewer than two usable rows are available, a
        Gradio update that clears the plot and shows an explanatory label.
    """
    df = pd.DataFrame(data)

    # A row is usable only if no cell is missing AND no cell is blank.
    # Unfilled table cells arrive as empty strings rather than NaN, so
    # dropna() alone would let them through to the encoder.
    stripped = df.astype(str).apply(lambda column: column.str.strip())
    has_content = df.notna().all(axis=1) & stripped.ne('').all(axis=1)

    # Work on an explicit copy so adding the PC columns below never writes
    # into (or warns about) a view of the caller's frame.
    valid_entries = df.loc[has_content].copy()

    if valid_entries.empty:
        return gr.update(value=None, label="No data to process. Please fill in the boxes.")

    # A 2-component PCA cannot be fitted on a single sample; report it
    # instead of letting fit_transform raise.
    if len(valid_entries) < 2:
        return gr.update(value=None, label="At least two rows are needed to run PCA.")

    # Generate embeddings
    embeddings = model.encode(valid_entries['Text'].tolist())

    # Perform PCA to reduce to 2 dimensions
    pca = PCA(n_components=2)
    pca_result = pca.fit_transform(embeddings)

    # Add PCA results to the DataFrame
    valid_entries['PC1'] = pca_result[:, 0]
    valid_entries['PC2'] = pca_result[:, 1]

    # Plot the PCA result with identifiers as labels
    fig = px.scatter(valid_entries, x='PC1', y='PC2', text='Identifier', title='PCA of Text Embeddings')
    return fig
34
 
 
35
  def text_editor_app():
36
  with gr.Blocks() as demo:
37
+ # Use a DataFrame component for inputs
38
+ data_input = gr.Dataframe(headers=["Identifier", "Text"], datatype=["str", "str"], label="Input Data")
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  # Button to run the analysis
41
  analyze_button = gr.Button("Run Analysis")
42
+
43
  # Output plot
44
  output_plot = gr.Plot(label="PCA Visualization")
45
+
46
  # Run analysis when the button is clicked
47
+ analyze_button.click(fn=compute_pca, inputs=[data_input], outputs=output_plot)
48
 
49
  return demo
50
 
51
+
52
  # Launch the app
53
  # Build the Blocks interface and start serving it; this runs as soon as
  # the module is executed.
  text_editor_app().launch()