kenlkehl commited on
Commit
c3ce722
·
verified ·
1 Parent(s): 0262ad2

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -54
app.py CHANGED
@@ -1,8 +1,8 @@
1
  import gradio as gr
2
  import pandas as pd
3
  import torch
4
- import tempfile
5
  import torch.nn.functional as F
 
6
  from sentence_transformers import SentenceTransformer
7
  from safetensors import safe_open
8
  from transformers import pipeline, AutoTokenizer
@@ -19,43 +19,42 @@ with safe_open("trial_space_embeddings.safetensors", framework="pt") as f:
19
 
20
  # Load checker pipeline
21
  tokenizer = AutoTokenizer.from_pretrained("roberta-large")
22
- checker_pipe = pipeline('text-classification', 'ksg-dfci/TrialChecker', tokenizer=tokenizer,
23
- truncation=True, padding='max_length', max_length=512)
24
-
25
-
26
-
27
- import gradio as gr
28
- import pandas as pd
29
- import torch
30
- import torch.nn.functional as F
31
- import tempfile
32
-
33
- # Assume the following are already loaded:
34
- # trial_spaces (DataFrame), embedding_model (SentenceTransformer),
35
- # trial_space_embeddings (torch.tensor), checker_pipe (transformers pipeline)
36
- #
37
- # For example:
38
- # trial_spaces = pd.read_csv("some_file.csv")
39
- # embedding_model = SentenceTransformer("model-name", device="cuda")
40
- # trial_space_embeddings = torch.load("trial_space_embeddings.pt")
41
- # checker_pipe = pipeline(...)
42
- # etc.
43
-
44
- def match_clinical_trials_dropdown(patient_summary: str):
45
  """
46
  1) Runs the trial matching logic.
47
- 2) Returns a gr.update(...) for the dropdown (setting its choices),
48
- plus a DataFrame for further use.
49
  """
 
 
 
 
 
 
 
 
 
 
 
 
50
  # 1. Encode user input
51
  patient_embedding = embedding_model.encode([patient_summary], convert_to_tensor=True)
52
 
53
  # 2. Compute similarities
54
  similarities = F.cosine_similarity(patient_embedding, trial_space_embeddings)
55
 
56
- # 3. Pull top 10
57
  sorted_similarities, sorted_indices = torch.sort(similarities, descending=True)
58
- top_indices = sorted_indices[0:20].cpu().numpy()
59
 
60
  # 4. Build DataFrame
61
  relevant_spaces = trial_spaces.iloc[top_indices].this_space
@@ -85,10 +84,10 @@ def match_clinical_trials_dropdown(patient_summary: str):
85
  analysis['trial_checker_result'] = [x['label'] for x in classifier_results]
86
  analysis['trial_checker_score'] = [x['score'] for x in classifier_results]
87
 
88
- # restrict to trials that pass Checker
89
- analysis = analysis[analysis.trial_checker_result == 'POSITIVE'].reset_index()
90
 
91
- # 7. Final columns
92
  out_df = analysis[[
93
  'patient_summary_query',
94
  'nct_id',
@@ -100,38 +99,41 @@ def match_clinical_trials_dropdown(patient_summary: str):
100
  'trial_checker_score'
101
  ]]
102
 
103
- # Build the dropdown choices, e.g., "NCT001 - Some Title"
104
  dropdown_options = []
105
- for this_index, row in out_df.iterrows():
106
- option_str = f"{this_index+1}. {row['nct_id']} - {row['trial_title']}"
107
  dropdown_options.append(option_str)
108
 
109
- # Return an update for the dropdown (choices + clear any initial value)
110
- dropdown_update = gr.Dropdown(
111
- choices=dropdown_options,
112
- interactive=True,
113
- value=dropdown_options[0]
114
- )
115
 
116
- return dropdown_update, out_df
 
 
 
 
117
 
118
  def show_selected_trial(selected_option: str, df: pd.DataFrame):
119
  """
120
- 1) Given the selected dropdown option, e.g. "NCT001 - Some Title"
121
  2) Find the row in df and build a summary string.
122
  """
123
  if not selected_option:
124
  return ""
125
 
126
- # Parse NCT ID from "NCT001 - Some Title"
127
- chosen_index = selected_option.split(".")[0].strip()
 
 
 
 
128
 
129
- #row = df[df['nct_id'] == nct_id]
130
- row = df.iloc[[int(chosen_index) - 1]]
131
- if row.empty:
132
  return "No data found for the selected trial."
133
 
134
- record = row.iloc[0].to_dict()
135
  details = (
136
  f"Patient Summary Query: {record['patient_summary_query']}\n\n"
137
  f"NCT ID: {record['nct_id']}\n"
@@ -139,9 +141,8 @@ def show_selected_trial(selected_option: str, df: pd.DataFrame):
139
  f"Trial Space: {record['this_space']}\n\n"
140
  f"Trial Checker Result: {record['trial_checker_result']}\n"
141
  f"Trial Checker Score: {record['trial_checker_score']}\n\n"
142
- f"Trial Brief Summary: {record['trial_brief_summary']}\n\n"
143
- f"Trial Full Eligibility Criteria: {record['trial_eligibility_criteria']}\n\n"
144
-
145
  )
146
  return details
147
 
@@ -153,7 +154,7 @@ def export_results(df: pd.DataFrame):
153
  df.to_csv(temp.name, index=False)
154
  return temp.name
155
 
156
- # A little CSS for the input box
157
  custom_css = """
158
  #input_box textarea {
159
  width: 600px !important;
@@ -166,7 +167,12 @@ with gr.Blocks(css=custom_css) as demo:
166
  gr.HTML("""
167
  <h3>Demonstration version of clinical trial search based on MatchMiner-AI</h3>
168
  <p>Based on clinicaltrials.gov cancer trials export 10/31/24.</p>
169
- <p>Queries take approximately 60 seconds to run (demo is running on a small CPU instance).</p>
 
 
 
 
 
170
  """)
171
 
172
  # Textbox for patient summary
@@ -176,6 +182,12 @@ with gr.Blocks(css=custom_css) as demo:
176
  value="metastatic lung adenocarcinoma, KRAS G12C mutation, PD-L1 high, previously treated with pembrolizumab."
177
  )
178
 
 
 
 
 
 
 
179
  # Button to run the matching
180
  submit_btn = gr.Button("Find Matches")
181
 
@@ -204,7 +216,7 @@ with gr.Blocks(css=custom_css) as demo:
204
  # 1) "Find Matches" => updates the dropdown choices and the state
205
  submit_btn.click(
206
  fn=match_clinical_trials_dropdown,
207
- inputs=patient_summary_input,
208
  outputs=[trial_dropdown, results_state]
209
  )
210
 
 
1
  import gradio as gr
2
  import pandas as pd
3
  import torch
 
4
  import torch.nn.functional as F
5
+ import tempfile
6
  from sentence_transformers import SentenceTransformer
7
  from safetensors import safe_open
8
  from transformers import pipeline, AutoTokenizer
 
19
 
20
  # Load checker pipeline
21
  tokenizer = AutoTokenizer.from_pretrained("roberta-large")
22
+ checker_pipe = pipeline(
23
+ 'text-classification',
24
+ 'ksg-dfci/TrialChecker',
25
+ tokenizer=tokenizer,
26
+ truncation=True,
27
+ padding='max_length',
28
+ max_length=512
29
+ )
30
+
31
+ def match_clinical_trials_dropdown(patient_summary: str, max_results_str: str):
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  """
33
  1) Runs the trial matching logic.
34
+ 2) Returns a Dropdown (with the matched trials) and a DataFrame (for further use).
35
+ 3) The user-supplied max_results_str is converted to an int (1-50).
36
  """
37
+ # Parse the max_results input
38
+ try:
39
+ max_results = int(max_results_str)
40
+ except ValueError:
41
+ max_results = 10 # if invalid input, default to 10
42
+
43
+ # Clamp within [1, 50]
44
+ if max_results < 1:
45
+ max_results = 1
46
+ elif max_results > 50:
47
+ max_results = 50
48
+
49
  # 1. Encode user input
50
  patient_embedding = embedding_model.encode([patient_summary], convert_to_tensor=True)
51
 
52
  # 2. Compute similarities
53
  similarities = F.cosine_similarity(patient_embedding, trial_space_embeddings)
54
 
55
+ # 3. Pull top 'max_results'
56
  sorted_similarities, sorted_indices = torch.sort(similarities, descending=True)
57
+ top_indices = sorted_indices[:max_results].cpu().numpy()
58
 
59
  # 4. Build DataFrame
60
  relevant_spaces = trial_spaces.iloc[top_indices].this_space
 
84
  analysis['trial_checker_result'] = [x['label'] for x in classifier_results]
85
  analysis['trial_checker_score'] = [x['score'] for x in classifier_results]
86
 
87
+ # 7. Restrict to POSITIVE results only
88
+ analysis = analysis[analysis.trial_checker_result == 'POSITIVE'].reset_index(drop=True)
89
 
90
+ # 8. Final columns
91
  out_df = analysis[[
92
  'patient_summary_query',
93
  'nct_id',
 
99
  'trial_checker_score'
100
  ]]
101
 
102
+ # Build the dropdown choices, e.g., "1. NCT001 - Some Title"
103
  dropdown_options = []
104
+ for i, row in out_df.iterrows():
105
+ option_str = f"{i+1}. {row['nct_id']} - {row['trial_title']}"
106
  dropdown_options.append(option_str)
107
 
108
+ # If we have no results, keep the dropdown empty
109
+ if len(dropdown_options) == 0:
110
+ return gr.Dropdown(choices=[], interactive=True, value=None), out_df
 
 
 
111
 
112
+ # Otherwise, pick the first item as the default
113
+ return (
114
+ gr.Dropdown(choices=dropdown_options, interactive=True, value=dropdown_options[0]),
115
+ out_df
116
+ )
117
 
118
  def show_selected_trial(selected_option: str, df: pd.DataFrame):
119
  """
120
+ 1) Given the selected dropdown option, e.g. "1. NCT001 - Some Title"
121
  2) Find the row in df and build a summary string.
122
  """
123
  if not selected_option:
124
  return ""
125
 
126
+ # Parse the index from "1. NCT001 - Some Title"
127
+ chosen_index_str = selected_option.split(".")[0].strip()
128
+ try:
129
+ chosen_index = int(chosen_index_str) - 1
130
+ except ValueError:
131
+ return "No data found for the selected trial."
132
 
133
+ if chosen_index < 0 or chosen_index >= len(df):
 
 
134
  return "No data found for the selected trial."
135
 
136
+ record = df.iloc[chosen_index].to_dict()
137
  details = (
138
  f"Patient Summary Query: {record['patient_summary_query']}\n\n"
139
  f"NCT ID: {record['nct_id']}\n"
 
141
  f"Trial Space: {record['this_space']}\n\n"
142
  f"Trial Checker Result: {record['trial_checker_result']}\n"
143
  f"Trial Checker Score: {record['trial_checker_score']}\n\n"
144
+ f"Brief Summary: {record['trial_brief_summary']}\n\n"
145
+ f"Full Eligibility Criteria: {record['trial_eligibility_criteria']}\n\n"
 
146
  )
147
  return details
148
 
 
154
  df.to_csv(temp.name, index=False)
155
  return temp.name
156
 
157
+ # A little CSS for the input boxes
158
  custom_css = """
159
  #input_box textarea {
160
  width: 600px !important;
 
167
  gr.HTML("""
168
  <h3>Demonstration version of clinical trial search based on MatchMiner-AI</h3>
169
  <p>Based on clinicaltrials.gov cancer trials export 10/31/24.</p>
170
+ <p>Queries take approximately 30 seconds to run per ten results returned,
171
+ since demo is running on a small CPU instance.</p>
172
+ <p>Disclaimers:</p>
173
+ <p>1. Not a clinical decision support tool</p>
174
+ <p>2. AI-extracted trial "spaces" and candidate matches may contain errors</p>
175
+ <p>3. Will not necessarily return all trials that match a given query</p>
176
  """)
177
 
178
  # Textbox for patient summary
 
182
  value="metastatic lung adenocarcinoma, KRAS G12C mutation, PD-L1 high, previously treated with pembrolizumab."
183
  )
184
 
185
+ # Textbox for max results
186
+ max_results_input = gr.Textbox(
187
+ label="Enter the maximum number of results to return (1-50)",
188
+ value="10" # default
189
+ )
190
+
191
  # Button to run the matching
192
  submit_btn = gr.Button("Find Matches")
193
 
 
216
  # 1) "Find Matches" => updates the dropdown choices and the state
217
  submit_btn.click(
218
  fn=match_clinical_trials_dropdown,
219
+ inputs=[patient_summary_input, max_results_input],
220
  outputs=[trial_dropdown, results_state]
221
  )
222