gloignon committed on
Commit
6119c28
·
verified ·
1 Parent(s): 11f43b6

Trying a new generic method for input pairs

Browse files
Files changed (1) hide show
  1. app.py +47 -28
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import gradio as gr
2
- import numpy as np
3
  import pandas as pd
4
  import plotly.express as px
5
  from sklearn.decomposition import PCA
@@ -9,21 +8,28 @@ from sentence_transformers import SentenceTransformer
9
  model = SentenceTransformer('all-MiniLM-L6-v2')
10
 
11
  # Function to compute document embeddings and apply PCA
12
- # Modify the Gradio interface to accept a list of identifiers and texts
13
- def compute_pca(data):
14
- # data is expected to be a list of dictionaries with 'Identifier' and 'Text' keys
 
 
 
 
 
15
  df = pd.DataFrame(data, columns=["Identifier", "Text"])
16
-
17
-
18
  # Remove rows where 'Identifier' or 'Text' is empty or contains only whitespace
19
  valid_entries = df[
20
- (df['Identifier'].str.strip() != '') &
21
- (df['Text'].str.strip() != '')
22
  ]
23
 
24
  if valid_entries.empty:
25
  return gr.Plot.update(value=None, label="No data to process. Please fill in the boxes.")
26
 
 
 
 
27
  # Generate embeddings
28
  embeddings = model.encode(valid_entries['Text'].tolist())
29
 
@@ -38,50 +44,63 @@ def compute_pca(data):
38
 
39
  # Plot the PCA result with identifiers as labels
40
  fig = px.scatter(valid_entries, x='PC1', y='PC2', text='Identifier', title='PCA of Text Embeddings')
 
41
  return fig
42
 
43
  def text_editor_app():
44
  with gr.Blocks() as demo:
 
45
  identifier_inputs = []
46
  text_inputs = []
 
47
 
48
  gr.Markdown("### Enter at least two identifier-text pairs:")
49
-
50
- for i in range(4): # Assuming we have 4 entries
51
- with gr.Column():
52
- id_input = gr.Textbox(label=f"Identifier {i+1}")
53
- text_input = gr.Textbox(label=f"Text {i+1}")
 
 
54
  identifier_inputs.append(id_input)
55
  text_inputs.append(text_input)
56
- gr.Markdown("---") # Add a horizontal rule to create a break
57
-
58
- # Button to run the analysis
59
- analyze_button = gr.Button("Run Analysis")
60
 
61
- # Output plot
 
62
  output_plot = gr.Plot(label="PCA Visualization")
63
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  # Function to collect inputs and process them
65
  def collect_inputs(*args):
66
- # args will be identifier1, text1, identifier2, text2, ..., identifier4, text4
67
- # So we need to pair them up
68
  data = []
69
  for i in range(0, len(args), 2):
70
  identifier = args[i]
71
  text = args[i+1]
72
  data.append([identifier, text])
73
- return compute_pca(data)
74
-
75
  inputs = []
76
  for id_input, text_input in zip(identifier_inputs, text_inputs):
77
  inputs.extend([id_input, text_input])
78
-
79
  analyze_button.click(fn=collect_inputs, inputs=inputs, outputs=output_plot)
80
-
81
  return demo
82
 
83
-
84
-
85
-
86
  # Launch the app
87
  text_editor_app().launch()
 
 
1
  import gradio as gr
 
2
  import pandas as pd
3
  import plotly.express as px
4
  from sklearn.decomposition import PCA
 
8
  model = SentenceTransformer('all-MiniLM-L6-v2')
9
 
10
  # Function to compute document embeddings and apply PCA
11
+ def compute_pca(*args):
12
+ # args will be identifier1, text1, identifier2, text2, ..., identifierN, textN
13
+ # Pair them up
14
+ data = []
15
+ for i in range(0, len(args), 2):
16
+ identifier = args[i]
17
+ text = args[i+1]
18
+ data.append([identifier, text])
19
  df = pd.DataFrame(data, columns=["Identifier", "Text"])
20
+
 
21
  # Remove rows where 'Identifier' or 'Text' is empty or contains only whitespace
22
  valid_entries = df[
23
+ (df['Identifier'].astype(str).str.strip() != '') &
24
+ (df['Text'].astype(str).str.strip() != '')
25
  ]
26
 
27
  if valid_entries.empty:
28
  return gr.Plot.update(value=None, label="No data to process. Please fill in the boxes.")
29
 
30
+ if len(valid_entries) < 2:
31
+ return gr.Plot.update(value=None, label="At least two texts are required to perform PCA.")
32
+
33
  # Generate embeddings
34
  embeddings = model.encode(valid_entries['Text'].tolist())
35
 
 
44
 
45
  # Plot the PCA result with identifiers as labels
46
  fig = px.scatter(valid_entries, x='PC1', y='PC2', text='Identifier', title='PCA of Text Embeddings')
47
+ fig.update_traces(textposition='top center')
48
  return fig
49
 
50
  def text_editor_app():
51
  with gr.Blocks() as demo:
52
+ num_pairs_visible = gr.State(value=4)
53
  identifier_inputs = []
54
  text_inputs = []
55
+ pair_rows = []
56
 
57
  gr.Markdown("### Enter at least two identifier-text pairs:")
58
+
59
+ with gr.Column() as input_column:
60
+ for i in range(10): # Max 10 pairs
61
+ with gr.Column(visible=(i < 4)) as pair_row:
62
+ id_input = gr.Textbox(label=f"Identifier {i+1}")
63
+ text_input = gr.Textbox(label=f"Text {i+1}")
64
+ gr.Markdown("---") # Add a horizontal rule to create a break
65
  identifier_inputs.append(id_input)
66
  text_inputs.append(text_input)
67
+ pair_rows.append(pair_row)
 
 
 
68
 
69
+ add_pair_btn = gr.Button("Add Text")
70
+ analyze_button = gr.Button("Run Analysis")
71
  output_plot = gr.Plot(label="PCA Visualization")
72
+
73
+ def add_pair(num_visible):
74
+ if num_visible >= len(pair_rows):
75
+ return [gr.update()] * len(pair_rows) + [num_visible] # No more pairs to show
76
+ updates = []
77
+ for idx, pair_row in enumerate(pair_rows):
78
+ if idx < num_visible + 1:
79
+ updates.append(gr.update(visible=True))
80
+ else:
81
+ updates.append(gr.update())
82
+ num_visible += 1
83
+ return updates + [num_visible]
84
+
85
+ add_pair_btn.click(fn=add_pair, inputs=num_pairs_visible, outputs=pair_rows + [num_pairs_visible])
86
+
87
  # Function to collect inputs and process them
88
  def collect_inputs(*args):
 
 
89
  data = []
90
  for i in range(0, len(args), 2):
91
  identifier = args[i]
92
  text = args[i+1]
93
  data.append([identifier, text])
94
+ return compute_pca(*args)
95
+
96
  inputs = []
97
  for id_input, text_input in zip(identifier_inputs, text_inputs):
98
  inputs.extend([id_input, text_input])
99
+
100
  analyze_button.click(fn=collect_inputs, inputs=inputs, outputs=output_plot)
101
+
102
  return demo
103
 
 
 
 
104
  # Launch the app
105
  text_editor_app().launch()
106
+