Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Upload app.py
Browse files
app.py
CHANGED
@@ -23,7 +23,6 @@ checker_pipe = pipeline('text-classification', 'ksg-dfci/TrialChecker', tokenize
|
|
23 |
truncation=True, padding='max_length', max_length=512)
|
24 |
|
25 |
|
26 |
-
|
27 |
import gradio as gr
|
28 |
import pandas as pd
|
29 |
import torch
|
@@ -37,17 +36,17 @@ import tempfile
|
|
37 |
# trial_spaces (DataFrame), embedding_model (SentenceTransformer),
|
38 |
# trial_space_embeddings (torch.tensor), checker_pipe (transformers pipeline)
|
39 |
|
40 |
-
def
|
41 |
"""
|
42 |
1) Perform the trial matching and classification.
|
43 |
-
2)
|
44 |
-
3)
|
45 |
"""
|
46 |
-
# Encode
|
47 |
patient_embedding = embedding_model.encode([patient_summary], convert_to_tensor=True)
|
48 |
|
49 |
# Compute similarities
|
50 |
-
similarities =
|
51 |
|
52 |
# Pull top 10
|
53 |
sorted_similarities, sorted_indices = torch.sort(similarities, descending=True)
|
@@ -80,7 +79,7 @@ def match_clinical_trials_text(patient_summary: str):
|
|
80 |
analysis['trial_checker_result'] = [x['label'] for x in classifier_results]
|
81 |
analysis['trial_checker_score'] = [x['score'] for x in classifier_results]
|
82 |
|
83 |
-
# Subset
|
84 |
out_df = analysis[[
|
85 |
'patient_summary_query',
|
86 |
'nct_id',
|
@@ -92,38 +91,91 @@ def match_clinical_trials_text(patient_summary: str):
|
|
92 |
'trial_checker_score'
|
93 |
]]
|
94 |
|
95 |
-
# Convert
|
96 |
-
|
97 |
-
for idx, row in out_df.iterrows():
|
98 |
-
text_block = (
|
99 |
-
f"=== Result #{idx + 1} ===\n"
|
100 |
-
f"Patient Summary: {row['patient_summary_query']}\n"
|
101 |
-
f"NCT ID: {row['nct_id']}\n"
|
102 |
-
f"Title: {row['trial_title']}\n"
|
103 |
-
f"Brief Summary: {row['trial_brief_summary']}\n"
|
104 |
-
f"Eligibility Criteria: {row['trial_eligibility_criteria']}\n"
|
105 |
-
f"Trial Space: {row['this_space']}\n"
|
106 |
-
f"Checker Result: {row['trial_checker_result']}\n"
|
107 |
-
f"Checker Score: {row['trial_checker_score']}\n"
|
108 |
-
"------------------------------\n"
|
109 |
-
)
|
110 |
-
text_output_lines.append(text_block)
|
111 |
|
112 |
-
#
|
113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
|
115 |
-
#
|
116 |
-
|
|
|
117 |
|
118 |
def export_results(df: pd.DataFrame):
    """
    Saves the DataFrame to a temporary CSV file
    so Gradio can provide it as a downloadable file.

    Args:
        df: The results DataFrame held in the Gradio state, or None when
            no search has been run yet.

    Returns:
        The path of the CSV file that was written, or None when there is
        nothing to export.
    """
    # The state is empty until a search has completed; return None rather
    # than failing with an AttributeError on `None.to_csv`.
    if df is None:
        return None

    # Only a unique, persistent path is needed here. Close the handle
    # straight away so it is not leaked and so the file can be reopened for
    # writing on every platform.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as temp:
        csv_path = temp.name

    df.to_csv(csv_path, index=False)
    return csv_path
|
126 |
|
|
|
127 |
custom_css = """
|
128 |
#input_box textarea {
|
129 |
width: 600px !important;
|
@@ -132,47 +184,44 @@ custom_css = """
|
|
132 |
"""
|
133 |
|
134 |
with gr.Blocks(css=custom_css) as demo:
|
135 |
-
#
|
136 |
gr.HTML("""
|
137 |
-
<h3>Alpha Version of Clinical Trial Search (
|
138 |
<p>Based on clinicaltrials.gov cancer trials export 10/31/24</p>
|
139 |
<p>Queries take approximately 30 seconds to run.</p>
|
140 |
""")
|
141 |
|
142 |
-
# Input box for the patient summary
|
143 |
patient_summary_input = gr.Textbox(
|
144 |
label="Enter Patient Summary",
|
145 |
elem_id="input_box",
|
146 |
value="70M with metastatic lung adenocarcinoma, KRAS G12C mutation, PD-L1 high, previously treated with pembrolizumab."
|
147 |
)
|
148 |
|
149 |
-
# Button to start matching
|
150 |
submit_btn = gr.Button("Find Matches")
|
151 |
|
152 |
-
# We'll store the DataFrame in a state for CSV export
|
153 |
results_state = gr.State()
|
154 |
|
155 |
-
#
|
156 |
-
|
157 |
|
158 |
-
# Button to export the CSV
|
159 |
export_btn = gr.Button("Export Results")
|
160 |
|
161 |
-
# On
|
162 |
submit_btn.click(
|
163 |
-
fn=
|
164 |
inputs=patient_summary_input,
|
165 |
-
outputs=[
|
166 |
)
|
167 |
|
168 |
-
# On
|
169 |
export_btn.click(
|
170 |
fn=export_results,
|
171 |
inputs=results_state,
|
172 |
outputs=gr.File(label="Download CSV")
|
173 |
)
|
174 |
|
175 |
-
# Enable queue
|
176 |
demo.queue()
|
177 |
|
178 |
if __name__ == "__main__":
|
|
|
23 |
truncation=True, padding='max_length', max_length=512)
|
24 |
|
25 |
|
|
|
26 |
import gradio as gr
|
27 |
import pandas as pd
|
28 |
import torch
|
|
|
36 |
# trial_spaces (DataFrame), embedding_model (SentenceTransformer),
|
37 |
# trial_space_embeddings (torch.tensor), checker_pipe (transformers pipeline)
|
38 |
|
39 |
+
def match_clinical_trials_collapsible(patient_summary: str):
|
40 |
"""
|
41 |
1) Perform the trial matching and classification.
|
42 |
+
2) Generate an HTML string with collapsible items for each trial.
|
43 |
+
3) Return (collapsible_html, df_for_export).
|
44 |
"""
|
45 |
+
# Encode user input
|
46 |
patient_embedding = embedding_model.encode([patient_summary], convert_to_tensor=True)
|
47 |
|
48 |
# Compute similarities
|
49 |
+
similarities = torch.nn.functional.cosine_similarity(patient_embedding, trial_space_embeddings)
|
50 |
|
51 |
# Pull top 10
|
52 |
sorted_similarities, sorted_indices = torch.sort(similarities, descending=True)
|
|
|
79 |
analysis['trial_checker_result'] = [x['label'] for x in classifier_results]
|
80 |
analysis['trial_checker_score'] = [x['score'] for x in classifier_results]
|
81 |
|
82 |
+
# Subset final columns
|
83 |
out_df = analysis[[
|
84 |
'patient_summary_query',
|
85 |
'nct_id',
|
|
|
91 |
'trial_checker_score'
|
92 |
]]
|
93 |
|
94 |
+
# Convert DataFrame to collapsible HTML
|
95 |
+
collapsible_html = df_to_collapsible_html(out_df)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
|
97 |
+
# Return the HTML plus the DataFrame for CSV export
|
98 |
+
return collapsible_html, out_df
|
99 |
+
|
100 |
+
def df_to_collapsible_html(df: pd.DataFrame) -> str:
    """
    Creates an HTML string with an accordion-like display.
    Clicking on an NCT ID + Title header reveals/hides more details.

    Args:
        df: Results table. Must contain the columns 'nct_id',
            'trial_title', 'trial_brief_summary',
            'trial_eligibility_criteria', 'this_space',
            'trial_checker_result' and 'trial_checker_score'.

    Returns:
        A single HTML string: a <style> block, a <script> block, and one
        header/content pair per row wrapped in a <div>. An empty DataFrame
        yields an empty wrapper <div>.
    """
    # Imported locally so this function does not depend on a module-level
    # import being present.
    from html import escape

    def _cell(value) -> str:
        # Every value is text taken from trial records or model output;
        # escape it so characters such as <, > and & are displayed
        # literally instead of being interpreted as markup.
        return escape(str(value))

    # Basic styling for the accordion
    css = """
    <style>
    .accordion-header {
        cursor: pointer;
        background-color: #f2f2f2;
        padding: 8px;
        margin-bottom: 4px;
        border: 1px solid #ccc;
        font-weight: bold;
    }
    .accordion-content {
        display: none;
        border-left: 2px solid #ccc;
        margin-left: 10px;
        padding-left: 10px;
        padding-top: 4px;
        padding-bottom: 4px;
        margin-bottom: 10px;
    }
    </style>
    """

    # JavaScript for toggling the display of each accordion content.
    # NOTE(review): browsers do not execute <script> elements inserted via
    # innerHTML; if the hosting component renders its value that way,
    # toggleAccordion will be undefined. Confirm in the UI, and consider
    # <details>/<summary> if the toggle does not fire.
    script = """
    <script>
    function toggleAccordion(contentId) {
        var content = document.getElementById(contentId);
        if (content.style.display === "none" || content.style.display === "") {
            content.style.display = "block";
        } else {
            content.style.display = "none";
        }
    }
    </script>
    """

    # Build the accordion items. The row *position* is used for both the
    # displayed rank and the element id: the DataFrame index may carry
    # non-sequential (or non-integer) labels after sorting/subsetting.
    accordion_items = []
    for position, (_, row) in enumerate(df.iterrows()):
        content_id = f"accordion-content-{position}"

        header_html = f"""
        <div class="accordion-header" onclick="toggleAccordion('{content_id}')">
            [{position + 1}] NCT ID: {_cell(row['nct_id'])} - {_cell(row['trial_title'])}
        </div>
        """

        content_html = f"""
        <div id="{content_id}" class="accordion-content">
            <p><strong>Brief Summary:</strong> {_cell(row['trial_brief_summary'])}</p>
            <p><strong>Eligibility Criteria:</strong> {_cell(row['trial_eligibility_criteria'])}</p>
            <p><strong>Trial Space:</strong> {_cell(row['this_space'])}</p>
            <p><strong>Checker Result:</strong> {_cell(row['trial_checker_result'])}</p>
            <p><strong>Checker Score:</strong> {_cell(row['trial_checker_score'])}</p>
        </div>
        """
        accordion_items.append(header_html + content_html)

    # Combine everything
    full_html = css + script + "<div>" + "".join(accordion_items) + "</div>"
    return full_html
|
167 |
|
168 |
def export_results(df: pd.DataFrame):
    """
    Saves the DataFrame to a temporary CSV file
    so Gradio can provide it as a downloadable file.

    Args:
        df: The results DataFrame held in the Gradio state, or None when
            no search has been run yet.

    Returns:
        The path of the CSV file that was written, or None when there is
        nothing to export.
    """
    import tempfile

    # The state is empty until a search has completed; return None rather
    # than failing with an AttributeError on `None.to_csv`.
    if df is None:
        return None

    # Only a unique, persistent path is needed here. Close the handle
    # straight away so it is not leaked and so the file can be reopened for
    # writing on every platform.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as temp:
        csv_path = temp.name

    df.to_csv(csv_path, index=False)
    return csv_path
|
177 |
|
178 |
+
# Minimal CSS for the input box
|
179 |
custom_css = """
|
180 |
#input_box textarea {
|
181 |
width: 600px !important;
|
|
|
184 |
"""
|
185 |
|
186 |
# Build the Gradio UI. Components are created in the order they should
# appear on the page, then the two buttons are wired to their callbacks.
with gr.Blocks(css=custom_css) as demo:
    # Intro text
    gr.HTML("""
    <h3>Alpha Version of Clinical Trial Search (Collapsible Results)</h3>
    <p>Based on clinicaltrials.gov cancer trials export 10/31/24</p>
    <p>Queries take approximately 30 seconds to run.</p>
    """)

    # Free-text patient summary, pre-filled with an example query.
    # elem_id matches the `#input_box textarea` selector in custom_css.
    patient_summary_input = gr.Textbox(
        label="Enter Patient Summary",
        elem_id="input_box",
        value="70M with metastatic lung adenocarcinoma, KRAS G12C mutation, PD-L1 high, previously treated with pembrolizumab."
    )

    submit_btn = gr.Button("Find Matches")

    # We'll store the DataFrame in a state for CSV export.
    # No initial value is given, so nothing is stored until a search runs.
    results_state = gr.State()

    # Display the collapsible results in a gr.HTML component
    output_html = gr.HTML(label="Results")

    export_btn = gr.Button("Export Results")

    # On "Find Matches", produce (collapsible_html, df): the HTML goes to
    # the results panel and the DataFrame is kept in the state.
    submit_btn.click(
        fn=match_clinical_trials_collapsible,
        inputs=patient_summary_input,
        outputs=[output_html, results_state]
    )

    # On "Export Results", convert state (DataFrame) to a downloadable CSV.
    # The gr.File output component is created inline, as part of this call.
    export_btn.click(
        fn=export_results,
        inputs=results_state,
        outputs=gr.File(label="Download CSV")
    )

    # Enable queue for "Processing..." feedback
    demo.queue()
|
226 |
|
227 |
if __name__ == "__main__":
|