kenlkehl commited on
Commit
0bf83d7
·
verified ·
1 Parent(s): a45c425

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -40
app.py CHANGED
@@ -23,7 +23,6 @@ checker_pipe = pipeline('text-classification', 'ksg-dfci/TrialChecker', tokenize
23
  truncation=True, padding='max_length', max_length=512)
24
 
25
 
26
-
27
  import gradio as gr
28
  import pandas as pd
29
  import torch
@@ -37,17 +36,17 @@ import tempfile
37
  # trial_spaces (DataFrame), embedding_model (SentenceTransformer),
38
  # trial_space_embeddings (torch.tensor), checker_pipe (transformers pipeline)
39
 
40
- def match_clinical_trials_text(patient_summary: str):
41
  """
42
  1) Perform the trial matching and classification.
43
- 2) Return a free-text representation of the results.
44
- 3) Also return the DataFrame for CSV export in a second output.
45
  """
46
- # Encode the patient summary
47
  patient_embedding = embedding_model.encode([patient_summary], convert_to_tensor=True)
48
 
49
  # Compute similarities
50
- similarities = F.cosine_similarity(patient_embedding, trial_space_embeddings)
51
 
52
  # Pull top 10
53
  sorted_similarities, sorted_indices = torch.sort(similarities, descending=True)
@@ -80,7 +79,7 @@ def match_clinical_trials_text(patient_summary: str):
80
  analysis['trial_checker_result'] = [x['label'] for x in classifier_results]
81
  analysis['trial_checker_score'] = [x['score'] for x in classifier_results]
82
 
83
- # Subset of final columns
84
  out_df = analysis[[
85
  'patient_summary_query',
86
  'nct_id',
@@ -92,38 +91,91 @@ def match_clinical_trials_text(patient_summary: str):
92
  'trial_checker_score'
93
  ]]
94
 
95
- # Convert the DataFrame rows into a free-text summary
96
- text_output_lines = []
97
- for idx, row in out_df.iterrows():
98
- text_block = (
99
- f"=== Result #{idx + 1} ===\n"
100
- f"Patient Summary: {row['patient_summary_query']}\n"
101
- f"NCT ID: {row['nct_id']}\n"
102
- f"Title: {row['trial_title']}\n"
103
- f"Brief Summary: {row['trial_brief_summary']}\n"
104
- f"Eligibility Criteria: {row['trial_eligibility_criteria']}\n"
105
- f"Trial Space: {row['this_space']}\n"
106
- f"Checker Result: {row['trial_checker_result']}\n"
107
- f"Checker Score: {row['trial_checker_score']}\n"
108
- "------------------------------\n"
109
- )
110
- text_output_lines.append(text_block)
111
 
112
- # Combine into a single multi-line string
113
- final_text_output = "".join(text_output_lines)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
- # Return (free text, DataFrame for export)
116
- return final_text_output, out_df
 
117
 
118
  def export_results(df: pd.DataFrame):
119
  """
120
  Saves the DataFrame to a temporary CSV file
121
  so Gradio can provide it as a downloadable file.
122
  """
 
123
  temp = tempfile.NamedTemporaryFile(delete=False, suffix=".csv")
124
  df.to_csv(temp.name, index=False)
125
  return temp.name
126
 
 
127
  custom_css = """
128
  #input_box textarea {
129
  width: 600px !important;
@@ -132,47 +184,44 @@ custom_css = """
132
  """
133
 
134
  with gr.Blocks(css=custom_css) as demo:
135
- # Display some introductory text
136
  gr.HTML("""
137
- <h3>Alpha Version of Clinical Trial Search (Free Text Output)</h3>
138
  <p>Based on clinicaltrials.gov cancer trials export 10/31/24</p>
139
  <p>Queries take approximately 30 seconds to run.</p>
140
  """)
141
 
142
- # Input box for the patient summary
143
  patient_summary_input = gr.Textbox(
144
  label="Enter Patient Summary",
145
  elem_id="input_box",
146
  value="70M with metastatic lung adenocarcinoma, KRAS G12C mutation, PD-L1 high, previously treated with pembrolizumab."
147
  )
148
 
149
- # Button to start matching
150
  submit_btn = gr.Button("Find Matches")
151
 
152
- # We'll store the DataFrame in a state for CSV export
153
  results_state = gr.State()
154
 
155
- # Free-text output (multi-line)
156
- output_text = gr.Textbox(label="Results (Free Text)", lines=20, interactive=False)
157
 
158
- # Button to export the CSV
159
  export_btn = gr.Button("Export Results")
160
 
161
- # On submit, run match_clinical_trials_text
162
  submit_btn.click(
163
- fn=match_clinical_trials_text,
164
  inputs=patient_summary_input,
165
- outputs=[output_text, results_state]
166
  )
167
 
168
- # On export, convert state (DataFrame) to a downloadable CSV
169
  export_btn.click(
170
  fn=export_results,
171
  inputs=results_state,
172
  outputs=gr.File(label="Download CSV")
173
  )
174
 
175
- # Enable queue so there's a visible "Processing..."
176
  demo.queue()
177
 
178
  if __name__ == "__main__":
 
23
  truncation=True, padding='max_length', max_length=512)
24
 
25
 
 
26
  import gradio as gr
27
  import pandas as pd
28
  import torch
 
36
  # trial_spaces (DataFrame), embedding_model (SentenceTransformer),
37
  # trial_space_embeddings (torch.tensor), checker_pipe (transformers pipeline)
38
 
39
+ def match_clinical_trials_collapsible(patient_summary: str):
40
  """
41
  1) Perform the trial matching and classification.
42
+ 2) Generate an HTML string with collapsible items for each trial.
43
+ 3) Return (collapsible_html, df_for_export).
44
  """
45
+ # Encode user input
46
  patient_embedding = embedding_model.encode([patient_summary], convert_to_tensor=True)
47
 
48
  # Compute similarities
49
+ similarities = torch.nn.functional.cosine_similarity(patient_embedding, trial_space_embeddings)
50
 
51
  # Pull top 10
52
  sorted_similarities, sorted_indices = torch.sort(similarities, descending=True)
 
79
  analysis['trial_checker_result'] = [x['label'] for x in classifier_results]
80
  analysis['trial_checker_score'] = [x['score'] for x in classifier_results]
81
 
82
+ # Subset final columns
83
  out_df = analysis[[
84
  'patient_summary_query',
85
  'nct_id',
 
91
  'trial_checker_score'
92
  ]]
93
 
94
+ # Convert DataFrame to collapsible HTML
95
+ collapsible_html = df_to_collapsible_html(out_df)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
+ # Return the HTML plus the DataFrame for CSV export
98
+ return collapsible_html, out_df
99
+
100
+ def df_to_collapsible_html(df: pd.DataFrame) -> str:
101
+ """
102
+ Creates an HTML string with an accordion-like display.
103
+ Clicking on an NCT ID + Title header reveals/hides more details.
104
+ """
105
+ # Basic styling for the accordion
106
+ css = """
107
+ <style>
108
+ .accordion-header {
109
+ cursor: pointer;
110
+ background-color: #f2f2f2;
111
+ padding: 8px;
112
+ margin-bottom: 4px;
113
+ border: 1px solid #ccc;
114
+ font-weight: bold;
115
+ }
116
+ .accordion-content {
117
+ display: none;
118
+ border-left: 2px solid #ccc;
119
+ margin-left: 10px;
120
+ padding-left: 10px;
121
+ padding-top: 4px;
122
+ padding-bottom: 4px;
123
+ margin-bottom: 10px;
124
+ }
125
+ </style>
126
+ """
127
+
128
+ # JavaScript for toggling the display of each accordion content
129
+ script = """
130
+ <script>
131
+ function toggleAccordion(contentId) {
132
+ var content = document.getElementById(contentId);
133
+ if (content.style.display === "none" || content.style.display === "") {
134
+ content.style.display = "block";
135
+ } else {
136
+ content.style.display = "none";
137
+ }
138
+ }
139
+ </script>
140
+ """
141
+
142
+ # Build the accordion items
143
+ accordion_items = []
144
+ for idx, row in df.iterrows():
145
+ content_id = f"accordion-content-{idx}"
146
+
147
+ header_html = f"""
148
+ <div class="accordion-header" onclick="toggleAccordion('{content_id}')">
149
+ [{idx + 1}] NCT ID: {row['nct_id']} - {row['trial_title']}
150
+ </div>
151
+ """
152
+
153
+ content_html = f"""
154
+ <div id="{content_id}" class="accordion-content">
155
+ <p><strong>Brief Summary:</strong> {row['trial_brief_summary']}</p>
156
+ <p><strong>Eligibility Criteria:</strong> {row['trial_eligibility_criteria']}</p>
157
+ <p><strong>Trial Space:</strong> {row['this_space']}</p>
158
+ <p><strong>Checker Result:</strong> {row['trial_checker_result']}</p>
159
+ <p><strong>Checker Score:</strong> {row['trial_checker_score']}</p>
160
+ </div>
161
+ """
162
+ accordion_items.append(header_html + content_html)
163
 
164
+ # Combine everything
165
+ full_html = css + script + "<div>" + "".join(accordion_items) + "</div>"
166
+ return full_html
167
 
168
  def export_results(df: pd.DataFrame):
169
  """
170
  Saves the DataFrame to a temporary CSV file
171
  so Gradio can provide it as a downloadable file.
172
  """
173
+ import tempfile
174
  temp = tempfile.NamedTemporaryFile(delete=False, suffix=".csv")
175
  df.to_csv(temp.name, index=False)
176
  return temp.name
177
 
178
+ # Minimal CSS for the input box
179
  custom_css = """
180
  #input_box textarea {
181
  width: 600px !important;
 
184
  """
185
 
186
  with gr.Blocks(css=custom_css) as demo:
187
+ # Intro text
188
  gr.HTML("""
189
+ <h3>Alpha Version of Clinical Trial Search (Collapsible Results)</h3>
190
  <p>Based on clinicaltrials.gov cancer trials export 10/31/24</p>
191
  <p>Queries take approximately 30 seconds to run.</p>
192
  """)
193
 
 
194
  patient_summary_input = gr.Textbox(
195
  label="Enter Patient Summary",
196
  elem_id="input_box",
197
  value="70M with metastatic lung adenocarcinoma, KRAS G12C mutation, PD-L1 high, previously treated with pembrolizumab."
198
  )
199
 
 
200
  submit_btn = gr.Button("Find Matches")
201
 
202
+ # We'll store the DataFrame in a state for CSV export.
203
  results_state = gr.State()
204
 
205
+ # Display the collapsible results in a gr.HTML component
206
+ output_html = gr.HTML(label="Results")
207
 
 
208
  export_btn = gr.Button("Export Results")
209
 
210
+ # On "Find Matches", produce (collapsible_html, df)
211
  submit_btn.click(
212
+ fn=match_clinical_trials_collapsible,
213
  inputs=patient_summary_input,
214
+ outputs=[output_html, results_state]
215
  )
216
 
217
+ # On "Export Results", convert state (DataFrame) to a downloadable CSV
218
  export_btn.click(
219
  fn=export_results,
220
  inputs=results_state,
221
  outputs=gr.File(label="Download CSV")
222
  )
223
 
224
+ # Enable queue for "Processing..." feedback
225
  demo.queue()
226
 
227
  if __name__ == "__main__":