kenlkehl commited on
Commit
fd5440d
·
verified ·
1 Parent(s): d1e19bb

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -36
app.py CHANGED
@@ -22,7 +22,6 @@ tokenizer = AutoTokenizer.from_pretrained("roberta-large")
22
  checker_pipe = pipeline('text-classification', 'ksg-dfci/TrialChecker', tokenizer=tokenizer,
23
  truncation=True, padding='max_length', max_length=512)
24
 
25
-
26
  import gradio as gr
27
  import pandas as pd
28
  import torch
@@ -32,11 +31,16 @@ from safetensors import safe_open
32
  from transformers import pipeline, AutoTokenizer
33
  import tempfile
34
 
35
- # We assume the following objects have already been loaded:
36
  # trial_spaces (DataFrame), embedding_model (SentenceTransformer),
37
  # trial_space_embeddings (torch.tensor), checker_pipe (transformers pipeline)
38
 
39
- def match_clinical_trials(patient_summary: str):
 
 
 
 
 
40
  # Encode patient summary
41
  patient_embedding = embedding_model.encode([patient_summary], convert_to_tensor=True)
42
 
@@ -47,12 +51,14 @@ def match_clinical_trials(patient_summary: str):
47
  sorted_similarities, sorted_indices = torch.sort(similarities, descending=True)
48
  top_indices = sorted_indices[0:10].cpu().numpy()
49
 
 
50
  relevant_spaces = trial_spaces.iloc[top_indices].this_space
51
  relevant_nctid = trial_spaces.iloc[top_indices].nct_id
52
  relevant_title = trial_spaces.iloc[top_indices].title
53
  relevant_brief_summary = trial_spaces.iloc[top_indices].brief_summary
54
  relevant_eligibility_criteria = trial_spaces.iloc[top_indices].eligibility_criteria
55
 
 
56
  analysis = pd.DataFrame({
57
  'patient_summary_query': patient_summary,
58
  'this_space': relevant_spaces,
@@ -62,6 +68,7 @@ def match_clinical_trials(patient_summary: str):
62
  'trial_eligibility_criteria': relevant_eligibility_criteria
63
  }).reset_index(drop=True)
64
 
 
65
  analysis['pt_trial_pair'] = (
66
  analysis['this_space']
67
  + "\nNow here is the patient summary:"
@@ -73,7 +80,7 @@ def match_clinical_trials(patient_summary: str):
73
  analysis['trial_checker_result'] = [x['label'] for x in classifier_results]
74
  analysis['trial_checker_score'] = [x['score'] for x in classifier_results]
75
 
76
- # Return the final subset of columns including patient_summary_query as first column
77
  out_df = analysis[[
78
  'patient_summary_query',
79
  'nct_id',
@@ -84,38 +91,91 @@ def match_clinical_trials(patient_summary: str):
84
  'trial_checker_result',
85
  'trial_checker_score'
86
  ]]
87
- return out_df, out_df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
  def export_results(df: pd.DataFrame):
90
- # Save the dataframe to a temporary CSV file and return its path
 
 
 
91
  temp = tempfile.NamedTemporaryFile(delete=False, suffix=".csv")
92
  df.to_csv(temp.name, index=False)
93
  return temp.name
94
 
 
95
  custom_css = """
96
  #input_box textarea {
97
  width: 600px !important;
98
  height: 250px !important;
99
  }
100
 
101
- #output_df table {
102
- width: 100% !important;
103
- table-layout: auto !important;
104
- border-collapse: collapse !important;
 
 
105
  }
106
 
107
- #output_df table td, #output_df table th {
108
- min-width: 100px;
109
- max-width: 300px;
110
- overflow-wrap: anywhere; /* or 'word-wrap: break-word;' */
111
- white-space: pre-wrap; /* or 'white-space: normal;' */
112
  border: 1px solid #ccc;
113
- padding: 4px;
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  }
115
  """
116
 
 
117
  with gr.Blocks(css=custom_css) as demo:
118
- gr.HTML("<h3>Alpha Version of Clinical Trial Search based on MatchMiner-AI models</h3>")
119
  gr.HTML("<h3>Based on clinicaltrials.gov cancer trials export 10/31/24</h3>")
120
 
121
  patient_summary_input = gr.Textbox(
@@ -126,38 +186,27 @@ with gr.Blocks(css=custom_css) as demo:
126
 
127
  submit_btn = gr.Button("Find Matches")
128
 
129
- # We'll store the DataFrame in a state so we can export it after generation
130
  results_state = gr.State()
131
 
132
- output_df = gr.DataFrame(
133
- headers=[
134
- "patient_summary_query",
135
- "nct_id",
136
- "title",
137
- "trial_brief_summary",
138
- "eligibility_criteria",
139
- "this_space",
140
- "trial_checker_result",
141
- "trial_checker_score"
142
- ],
143
- elem_id="output_df"
144
- )
145
 
146
  export_btn = gr.Button("Export Results")
147
 
148
- # On "Find Matches", show the DataFrame and store it in state
149
  submit_btn.click(
150
- fn=match_clinical_trials,
151
  inputs=patient_summary_input,
152
- outputs=[output_df, results_state]
153
  )
154
 
155
- # On "Export Results", use the state to create and return a CSV file
156
  export_btn.click(
157
  fn=export_results,
158
  inputs=results_state,
159
  outputs=gr.File(label="Download CSV")
160
  )
161
 
162
- if __name__ == '__main__':
163
  demo.launch()
 
22
  checker_pipe = pipeline('text-classification', 'ksg-dfci/TrialChecker', tokenizer=tokenizer,
23
  truncation=True, padding='max_length', max_length=512)
24
 
 
25
  import gradio as gr
26
  import pandas as pd
27
  import torch
 
31
  from transformers import pipeline, AutoTokenizer
32
  import tempfile
33
 
34
+ # We assume the following objects have already been loaded in your environment:
35
  # trial_spaces (DataFrame), embedding_model (SentenceTransformer),
36
  # trial_space_embeddings (torch.tensor), checker_pipe (transformers pipeline)
37
 
38
+ def match_clinical_trials_html(patient_summary: str):
39
+ """
40
+ Takes in a patient_summary string, computes the top 10 matching trials,
41
+ and returns a tuple of:
42
+ (html_table_string, df_for_export)
43
+ """
44
  # Encode patient summary
45
  patient_embedding = embedding_model.encode([patient_summary], convert_to_tensor=True)
46
 
 
51
  sorted_similarities, sorted_indices = torch.sort(similarities, descending=True)
52
  top_indices = sorted_indices[0:10].cpu().numpy()
53
 
54
+ # Retrieve relevant columns from trial_spaces
55
  relevant_spaces = trial_spaces.iloc[top_indices].this_space
56
  relevant_nctid = trial_spaces.iloc[top_indices].nct_id
57
  relevant_title = trial_spaces.iloc[top_indices].title
58
  relevant_brief_summary = trial_spaces.iloc[top_indices].brief_summary
59
  relevant_eligibility_criteria = trial_spaces.iloc[top_indices].eligibility_criteria
60
 
61
+ # Build the main DataFrame for analysis
62
  analysis = pd.DataFrame({
63
  'patient_summary_query': patient_summary,
64
  'this_space': relevant_spaces,
 
68
  'trial_eligibility_criteria': relevant_eligibility_criteria
69
  }).reset_index(drop=True)
70
 
71
+ # Create a merged text input for the reranking checker
72
  analysis['pt_trial_pair'] = (
73
  analysis['this_space']
74
  + "\nNow here is the patient summary:"
 
80
  analysis['trial_checker_result'] = [x['label'] for x in classifier_results]
81
  analysis['trial_checker_score'] = [x['score'] for x in classifier_results]
82
 
83
+ # Subset (and reorder) the final columns we want
84
  out_df = analysis[[
85
  'patient_summary_query',
86
  'nct_id',
 
91
  'trial_checker_result',
92
  'trial_checker_score'
93
  ]]
94
+
95
+ # Convert that DataFrame to an HTML table
96
+ html_table = df_to_html(out_df)
97
+
98
+ # Return (HTML for display, DataFrame for exporting)
99
+ return html_table, out_df
100
+
101
+ def df_to_html(df: pd.DataFrame) -> str:
102
+ """
103
+ Utility function to convert a DataFrame into an HTML table
104
+ with wrapping text.
105
+ """
106
+ # Build the table headers
107
+ header_row = "".join([f"<th>{col}</th>" for col in df.columns])
108
+
109
+ # Build the table rows
110
+ table_rows = []
111
+ for _, row in df.iterrows():
112
+ cells = ""
113
+ for col in df.columns:
114
+ cell_value = row[col]
115
+ # Convert to string and replace newlines with <br> (optional)
116
+ cell_str = str(cell_value).replace("\n", "<br>")
117
+ cells += f"<td>{cell_str}</td>"
118
+ table_rows.append(f"<tr>{cells}</tr>")
119
+
120
+ table_body = "\n".join(table_rows)
121
+
122
+ # Put it all together as an HTML string
123
+ table_html = f"""
124
+ <table class="styled-table">
125
+ <thead><tr>{header_row}</tr></thead>
126
+ <tbody>
127
+ {table_body}
128
+ </tbody>
129
+ </table>
130
+ """
131
+ return table_html
132
 
133
  def export_results(df: pd.DataFrame):
134
+ """
135
+ Saves the DataFrame to a temporary CSV file and returns its path
136
+ so that Gradio can prompt the user to download it.
137
+ """
138
  temp = tempfile.NamedTemporaryFile(delete=False, suffix=".csv")
139
  df.to_csv(temp.name, index=False)
140
  return temp.name
141
 
142
+
143
  custom_css = """
144
  #input_box textarea {
145
  width: 600px !important;
146
  height: 250px !important;
147
  }
148
 
149
+ /* Make the custom table more readable: allow wrapping text */
150
+ .styled-table {
151
+ width: 100%;
152
+ border-collapse: collapse;
153
+ table-layout: auto;
154
+ margin-top: 1em;
155
  }
156
 
157
+ .styled-table th, .styled-table td {
 
 
 
 
158
  border: 1px solid #ccc;
159
+ padding: 8px;
160
+ vertical-align: top;
161
+ text-align: left;
162
+ white-space: pre-wrap; /* Wrap text */
163
+ overflow-wrap: anywhere; /* Break long text automatically */
164
+ }
165
+
166
+ .styled-table thead tr {
167
+ background-color: #f2f2f2;
168
+ font-weight: bold;
169
+ }
170
+
171
+ .styled-table tbody tr:nth-of-type(even) {
172
+ background-color: #f9f9f9;
173
  }
174
  """
175
 
176
+ # Build the Gradio interface
177
  with gr.Blocks(css=custom_css) as demo:
178
+ gr.HTML("<h3>Alpha Version of Clinical Trial Search (HTML Table Output)</h3>")
179
  gr.HTML("<h3>Based on clinicaltrials.gov cancer trials export 10/31/24</h3>")
180
 
181
  patient_summary_input = gr.Textbox(
 
186
 
187
  submit_btn = gr.Button("Find Matches")
188
 
189
+ # We'll store the DataFrame in a state for exporting to CSV
190
  results_state = gr.State()
191
 
192
+ # The output is now HTML, instead of a DataFrame
193
+ output_html = gr.HTML(label="Results")
 
 
 
 
 
 
 
 
 
 
 
194
 
195
  export_btn = gr.Button("Export Results")
196
 
197
+ # When "Find Matches" is clicked, we get (HTML string, DataFrame)
198
  submit_btn.click(
199
+ fn=match_clinical_trials_html,
200
  inputs=patient_summary_input,
201
+ outputs=[output_html, results_state]
202
  )
203
 
204
+ # When "Export Results" is clicked, we export the DataFrame as CSV
205
  export_btn.click(
206
  fn=export_results,
207
  inputs=results_state,
208
  outputs=gr.File(label="Download CSV")
209
  )
210
 
211
+ if __name__ == "__main__":
212
  demo.launch()