gloignon committed on
Commit
76026d0
·
verified ·
1 Parent(s): 2834d04

back to working version

Browse files
Files changed (1) hide show
  1. app.py +29 -47
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import gradio as gr
 
2
  import pandas as pd
3
  import plotly.express as px
4
  from sklearn.decomposition import PCA
@@ -8,28 +9,21 @@ from sentence_transformers import SentenceTransformer
8
  model = SentenceTransformer('all-MiniLM-L6-v2')
9
 
10
  # Function to compute document embeddings and apply PCA
11
- def compute_pca(*args):
12
- # args will be identifier1, text1, identifier2, text2, ..., identifierN, textN
13
- # Pair them up
14
- data = []
15
- for i in range(0, len(args), 2):
16
- identifier = args[i]
17
- text = args[i+1]
18
- data.append([identifier, text])
19
  df = pd.DataFrame(data, columns=["Identifier", "Text"])
20
-
 
21
  # Remove rows where 'Identifier' or 'Text' is empty or contains only whitespace
22
  valid_entries = df[
23
- (df['Identifier'].astype(str).str.strip() != '') &
24
- (df['Text'].astype(str).str.strip() != '')
25
  ]
26
 
27
  if valid_entries.empty:
28
  return gr.Plot.update(value=None, label="No data to process. Please fill in the boxes.")
29
 
30
- if len(valid_entries) < 2:
31
- return gr.Plot.update(value=None, label="At least two texts are required to perform PCA.")
32
-
33
  # Generate embeddings
34
  embeddings = model.encode(valid_entries['Text'].tolist())
35
 
@@ -44,62 +38,50 @@ def compute_pca(*args):
44
 
45
  # Plot the PCA result with identifiers as labels
46
  fig = px.scatter(valid_entries, x='PC1', y='PC2', text='Identifier', title='PCA of Text Embeddings')
47
- fig.update_traces(textposition='top center')
48
  return fig
49
 
50
  def text_editor_app():
51
  with gr.Blocks() as demo:
52
  identifier_inputs = []
53
  text_inputs = []
54
- pair_rows = []
55
 
56
  gr.Markdown("### Enter at least two identifier-text pairs:")
57
-
58
- add_pair_btn = gr.Button("Add Text")
59
- analyze_button = gr.Button("Run Analysis")
60
- output_plot = gr.Plot(label="PCA Visualization")
61
-
62
- max_pairs = 10 # Maximum number of pairs
63
- initial_pairs = 4 # Initial number of visible pairs
64
-
65
- # Create the input pairs
66
- for i in range(max_pairs):
67
- with gr.Column(visible=(i < initial_pairs)) as pair_row:
68
  id_input = gr.Textbox(label=f"Identifier {i+1}")
69
  text_input = gr.Textbox(label=f"Text {i+1}")
70
- gr.Markdown("---") # Add a horizontal rule to create a break
71
- identifier_inputs.append(id_input)
72
- text_inputs.append(text_input)
73
- pair_rows.append(pair_row)
74
-
75
- # Function to add a new pair
76
- def add_pair():
77
- # Find the next invisible pair and make it visible
78
- for pair_row in pair_rows:
79
- if not pair_row.visible:
80
- return gr.update(visible=True, value=None, interactive=True, component=pair_row)
81
- return None # No more pairs to show
82
-
83
- # Connect the add_pair function to the button
84
- add_pair_btn.click(fn=add_pair, inputs=None, outputs=pair_rows)
85
 
 
 
 
86
  # Function to collect inputs and process them
87
  def collect_inputs(*args):
 
 
88
  data = []
89
  for i in range(0, len(args), 2):
90
  identifier = args[i]
91
  text = args[i+1]
92
  data.append([identifier, text])
93
- return compute_pca(*args)
94
-
95
- # Combine all inputs
96
  inputs = []
97
  for id_input, text_input in zip(identifier_inputs, text_inputs):
98
  inputs.extend([id_input, text_input])
99
-
100
  analyze_button.click(fn=collect_inputs, inputs=inputs, outputs=output_plot)
101
-
102
  return demo
103
 
 
 
 
104
  # Launch the app
105
  text_editor_app().launch()
 
1
  import gradio as gr
2
+ import numpy as np
3
  import pandas as pd
4
  import plotly.express as px
5
  from sklearn.decomposition import PCA
 
9
  model = SentenceTransformer('all-MiniLM-L6-v2')
10
 
11
  # Function to compute document embeddings and apply PCA
12
+ # Modify the Gradio interface to accept a list of identifiers and texts
13
+ def compute_pca(data):
14
+ # data is expected to be a list of [identifier, text] pairs (one two-item list per input row)
 
 
 
 
 
15
  df = pd.DataFrame(data, columns=["Identifier", "Text"])
16
+
17
+
18
  # Remove rows where 'Identifier' or 'Text' is empty or contains only whitespace
19
  valid_entries = df[
20
+ (df['Identifier'].str.strip() != '') &
21
+ (df['Text'].str.strip() != '')
22
  ]
23
 
24
  if valid_entries.empty:
25
  return gr.Plot.update(value=None, label="No data to process. Please fill in the boxes.")
26
 
 
 
 
27
  # Generate embeddings
28
  embeddings = model.encode(valid_entries['Text'].tolist())
29
 
 
38
 
39
  # Plot the PCA result with identifiers as labels
40
  fig = px.scatter(valid_entries, x='PC1', y='PC2', text='Identifier', title='PCA of Text Embeddings')
 
41
  return fig
42
 
43
  def text_editor_app():
44
  with gr.Blocks() as demo:
45
  identifier_inputs = []
46
  text_inputs = []
 
47
 
48
  gr.Markdown("### Enter at least two identifier-text pairs:")
49
+
50
+ for i in range(4): # Assuming we have 4 entries
51
+ with gr.Column():
 
 
 
 
 
 
 
 
52
  id_input = gr.Textbox(label=f"Identifier {i+1}")
53
  text_input = gr.Textbox(label=f"Text {i+1}")
54
+ identifier_inputs.append(id_input)
55
+ text_inputs.append(text_input)
56
+ gr.Markdown("---") # Add a horizontal rule to create a break
57
+
58
+ # Button to run the analysis
59
+ analyze_button = gr.Button("Run Analysis")
 
 
 
 
 
 
 
 
 
60
 
61
+ # Output plot
62
+ output_plot = gr.Plot(label="PCA Visualization")
63
+
64
  # Function to collect inputs and process them
65
  def collect_inputs(*args):
66
+ # args will be identifier1, text1, identifier2, text2, ..., identifier4, text4
67
+ # So we need to pair them up
68
  data = []
69
  for i in range(0, len(args), 2):
70
  identifier = args[i]
71
  text = args[i+1]
72
  data.append([identifier, text])
73
+ return compute_pca(data)
74
+
 
75
  inputs = []
76
  for id_input, text_input in zip(identifier_inputs, text_inputs):
77
  inputs.extend([id_input, text_input])
78
+
79
  analyze_button.click(fn=collect_inputs, inputs=inputs, outputs=output_plot)
80
+
81
  return demo
82
 
83
+
84
+
85
+
86
  # Launch the app
87
  text_editor_app().launch()