Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Upload app.py
Browse files
app.py
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
import gradio as gr
|
2 |
import pandas as pd
|
3 |
import torch
|
4 |
-
import tempfile
|
5 |
import torch.nn.functional as F
|
|
|
6 |
from sentence_transformers import SentenceTransformer
|
7 |
from safetensors import safe_open
|
8 |
from transformers import pipeline, AutoTokenizer
|
@@ -19,43 +19,42 @@ with safe_open("trial_space_embeddings.safetensors", framework="pt") as f:
|
|
19 |
|
20 |
# Load checker pipeline
|
21 |
tokenizer = AutoTokenizer.from_pretrained("roberta-large")
|
22 |
-
checker_pipe = pipeline(
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
# Assume the following are already loaded:
|
34 |
-
# trial_spaces (DataFrame), embedding_model (SentenceTransformer),
|
35 |
-
# trial_space_embeddings (torch.tensor), checker_pipe (transformers pipeline)
|
36 |
-
#
|
37 |
-
# For example:
|
38 |
-
# trial_spaces = pd.read_csv("some_file.csv")
|
39 |
-
# embedding_model = SentenceTransformer("model-name", device="cuda")
|
40 |
-
# trial_space_embeddings = torch.load("trial_space_embeddings.pt")
|
41 |
-
# checker_pipe = pipeline(...)
|
42 |
-
# etc.
|
43 |
-
|
44 |
-
def match_clinical_trials_dropdown(patient_summary: str):
|
45 |
"""
|
46 |
1) Runs the trial matching logic.
|
47 |
-
2) Returns a
|
48 |
-
|
49 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
# 1. Encode user input
|
51 |
patient_embedding = embedding_model.encode([patient_summary], convert_to_tensor=True)
|
52 |
|
53 |
# 2. Compute similarities
|
54 |
similarities = F.cosine_similarity(patient_embedding, trial_space_embeddings)
|
55 |
|
56 |
-
# 3. Pull top
|
57 |
sorted_similarities, sorted_indices = torch.sort(similarities, descending=True)
|
58 |
-
top_indices = sorted_indices[
|
59 |
|
60 |
# 4. Build DataFrame
|
61 |
relevant_spaces = trial_spaces.iloc[top_indices].this_space
|
@@ -85,10 +84,10 @@ def match_clinical_trials_dropdown(patient_summary: str):
|
|
85 |
analysis['trial_checker_result'] = [x['label'] for x in classifier_results]
|
86 |
analysis['trial_checker_score'] = [x['score'] for x in classifier_results]
|
87 |
|
88 |
-
#
|
89 |
-
analysis = analysis[analysis.trial_checker_result == 'POSITIVE'].reset_index()
|
90 |
|
91 |
-
#
|
92 |
out_df = analysis[[
|
93 |
'patient_summary_query',
|
94 |
'nct_id',
|
@@ -100,38 +99,41 @@ def match_clinical_trials_dropdown(patient_summary: str):
|
|
100 |
'trial_checker_score'
|
101 |
]]
|
102 |
|
103 |
-
# Build the dropdown choices, e.g., "NCT001 - Some Title"
|
104 |
dropdown_options = []
|
105 |
-
for
|
106 |
-
option_str = f"{
|
107 |
dropdown_options.append(option_str)
|
108 |
|
109 |
-
#
|
110 |
-
|
111 |
-
choices=
|
112 |
-
interactive=True,
|
113 |
-
value=dropdown_options[0]
|
114 |
-
)
|
115 |
|
116 |
-
|
|
|
|
|
|
|
|
|
117 |
|
118 |
def show_selected_trial(selected_option: str, df: pd.DataFrame):
|
119 |
"""
|
120 |
-
1) Given the selected dropdown option, e.g. "NCT001 - Some Title"
|
121 |
2) Find the row in df and build a summary string.
|
122 |
"""
|
123 |
if not selected_option:
|
124 |
return ""
|
125 |
|
126 |
-
# Parse
|
127 |
-
|
|
|
|
|
|
|
|
|
128 |
|
129 |
-
|
130 |
-
row = df.iloc[[int(chosen_index) - 1]]
|
131 |
-
if row.empty:
|
132 |
return "No data found for the selected trial."
|
133 |
|
134 |
-
record =
|
135 |
details = (
|
136 |
f"Patient Summary Query: {record['patient_summary_query']}\n\n"
|
137 |
f"NCT ID: {record['nct_id']}\n"
|
@@ -139,9 +141,8 @@ def show_selected_trial(selected_option: str, df: pd.DataFrame):
|
|
139 |
f"Trial Space: {record['this_space']}\n\n"
|
140 |
f"Trial Checker Result: {record['trial_checker_result']}\n"
|
141 |
f"Trial Checker Score: {record['trial_checker_score']}\n\n"
|
142 |
-
f"
|
143 |
-
f"
|
144 |
-
|
145 |
)
|
146 |
return details
|
147 |
|
@@ -153,7 +154,7 @@ def export_results(df: pd.DataFrame):
|
|
153 |
df.to_csv(temp.name, index=False)
|
154 |
return temp.name
|
155 |
|
156 |
-
# A little CSS for the input
|
157 |
custom_css = """
|
158 |
#input_box textarea {
|
159 |
width: 600px !important;
|
@@ -166,7 +167,12 @@ with gr.Blocks(css=custom_css) as demo:
|
|
166 |
gr.HTML("""
|
167 |
<h3>Demonstration version of clinical trial search based on MatchMiner-AI</h3>
|
168 |
<p>Based on clinicaltrials.gov cancer trials export 10/31/24.</p>
|
169 |
-
<p>Queries take approximately
|
|
|
|
|
|
|
|
|
|
|
170 |
""")
|
171 |
|
172 |
# Textbox for patient summary
|
@@ -176,6 +182,12 @@ with gr.Blocks(css=custom_css) as demo:
|
|
176 |
value="metastatic lung adenocarcinoma, KRAS G12C mutation, PD-L1 high, previously treated with pembrolizumab."
|
177 |
)
|
178 |
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
# Button to run the matching
|
180 |
submit_btn = gr.Button("Find Matches")
|
181 |
|
@@ -204,7 +216,7 @@ with gr.Blocks(css=custom_css) as demo:
|
|
204 |
# 1) "Find Matches" => updates the dropdown choices and the state
|
205 |
submit_btn.click(
|
206 |
fn=match_clinical_trials_dropdown,
|
207 |
-
inputs=patient_summary_input,
|
208 |
outputs=[trial_dropdown, results_state]
|
209 |
)
|
210 |
|
|
|
1 |
import gradio as gr
|
2 |
import pandas as pd
|
3 |
import torch
|
|
|
4 |
import torch.nn.functional as F
|
5 |
+
import tempfile
|
6 |
from sentence_transformers import SentenceTransformer
|
7 |
from safetensors import safe_open
|
8 |
from transformers import pipeline, AutoTokenizer
|
|
|
19 |
|
20 |
# Load checker pipeline
|
21 |
tokenizer = AutoTokenizer.from_pretrained("roberta-large")
|
22 |
+
checker_pipe = pipeline(
|
23 |
+
'text-classification',
|
24 |
+
'ksg-dfci/TrialChecker',
|
25 |
+
tokenizer=tokenizer,
|
26 |
+
truncation=True,
|
27 |
+
padding='max_length',
|
28 |
+
max_length=512
|
29 |
+
)
|
30 |
+
|
31 |
+
def match_clinical_trials_dropdown(patient_summary: str, max_results_str: str):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
"""
|
33 |
1) Runs the trial matching logic.
|
34 |
+
2) Returns a Dropdown (with the matched trials) and a DataFrame (for further use).
|
35 |
+
3) The user-supplied max_results_str is converted to an int (1-50).
|
36 |
"""
|
37 |
+
# Parse the max_results input
|
38 |
+
try:
|
39 |
+
max_results = int(max_results_str)
|
40 |
+
except ValueError:
|
41 |
+
max_results = 10 # if invalid input, default to 10
|
42 |
+
|
43 |
+
# Clamp within [1, 50]
|
44 |
+
if max_results < 1:
|
45 |
+
max_results = 1
|
46 |
+
elif max_results > 50:
|
47 |
+
max_results = 50
|
48 |
+
|
49 |
# 1. Encode user input
|
50 |
patient_embedding = embedding_model.encode([patient_summary], convert_to_tensor=True)
|
51 |
|
52 |
# 2. Compute similarities
|
53 |
similarities = F.cosine_similarity(patient_embedding, trial_space_embeddings)
|
54 |
|
55 |
+
# 3. Pull top 'max_results'
|
56 |
sorted_similarities, sorted_indices = torch.sort(similarities, descending=True)
|
57 |
+
top_indices = sorted_indices[:max_results].cpu().numpy()
|
58 |
|
59 |
# 4. Build DataFrame
|
60 |
relevant_spaces = trial_spaces.iloc[top_indices].this_space
|
|
|
84 |
analysis['trial_checker_result'] = [x['label'] for x in classifier_results]
|
85 |
analysis['trial_checker_score'] = [x['score'] for x in classifier_results]
|
86 |
|
87 |
+
# 7. Restrict to POSITIVE results only
|
88 |
+
analysis = analysis[analysis.trial_checker_result == 'POSITIVE'].reset_index(drop=True)
|
89 |
|
90 |
+
# 8. Final columns
|
91 |
out_df = analysis[[
|
92 |
'patient_summary_query',
|
93 |
'nct_id',
|
|
|
99 |
'trial_checker_score'
|
100 |
]]
|
101 |
|
102 |
+
# Build the dropdown choices, e.g., "1. NCT001 - Some Title"
|
103 |
dropdown_options = []
|
104 |
+
for i, row in out_df.iterrows():
|
105 |
+
option_str = f"{i+1}. {row['nct_id']} - {row['trial_title']}"
|
106 |
dropdown_options.append(option_str)
|
107 |
|
108 |
+
# If we have no results, keep the dropdown empty
|
109 |
+
if len(dropdown_options) == 0:
|
110 |
+
return gr.Dropdown(choices=[], interactive=True, value=None), out_df
|
|
|
|
|
|
|
111 |
|
112 |
+
# Otherwise, pick the first item as the default
|
113 |
+
return (
|
114 |
+
gr.Dropdown(choices=dropdown_options, interactive=True, value=dropdown_options[0]),
|
115 |
+
out_df
|
116 |
+
)
|
117 |
|
118 |
def show_selected_trial(selected_option: str, df: pd.DataFrame):
|
119 |
"""
|
120 |
+
1) Given the selected dropdown option, e.g. "1. NCT001 - Some Title"
|
121 |
2) Find the row in df and build a summary string.
|
122 |
"""
|
123 |
if not selected_option:
|
124 |
return ""
|
125 |
|
126 |
+
# Parse the index from "1. NCT001 - Some Title"
|
127 |
+
chosen_index_str = selected_option.split(".")[0].strip()
|
128 |
+
try:
|
129 |
+
chosen_index = int(chosen_index_str) - 1
|
130 |
+
except ValueError:
|
131 |
+
return "No data found for the selected trial."
|
132 |
|
133 |
+
if chosen_index < 0 or chosen_index >= len(df):
|
|
|
|
|
134 |
return "No data found for the selected trial."
|
135 |
|
136 |
+
record = df.iloc[chosen_index].to_dict()
|
137 |
details = (
|
138 |
f"Patient Summary Query: {record['patient_summary_query']}\n\n"
|
139 |
f"NCT ID: {record['nct_id']}\n"
|
|
|
141 |
f"Trial Space: {record['this_space']}\n\n"
|
142 |
f"Trial Checker Result: {record['trial_checker_result']}\n"
|
143 |
f"Trial Checker Score: {record['trial_checker_score']}\n\n"
|
144 |
+
f"Brief Summary: {record['trial_brief_summary']}\n\n"
|
145 |
+
f"Full Eligibility Criteria: {record['trial_eligibility_criteria']}\n\n"
|
|
|
146 |
)
|
147 |
return details
|
148 |
|
|
|
154 |
df.to_csv(temp.name, index=False)
|
155 |
return temp.name
|
156 |
|
157 |
+
# A little CSS for the input boxes
|
158 |
custom_css = """
|
159 |
#input_box textarea {
|
160 |
width: 600px !important;
|
|
|
167 |
gr.HTML("""
|
168 |
<h3>Demonstration version of clinical trial search based on MatchMiner-AI</h3>
|
169 |
<p>Based on clinicaltrials.gov cancer trials export 10/31/24.</p>
|
170 |
+
<p>Queries take approximately 30 seconds to run per ten results returned,
|
171 |
+
since demo is running on a small CPU instance.</p>
|
172 |
+
<p>Disclaimers:</p>
|
173 |
+
<p>1. Not a clinical decision support tool</p>
|
174 |
+
<p>2. AI-extracted trial "spaces" and candidate matches may contain errors</p>
|
175 |
+
<p>3. Will not necessarily return all trials that match a given query</p>
|
176 |
""")
|
177 |
|
178 |
# Textbox for patient summary
|
|
|
182 |
value="metastatic lung adenocarcinoma, KRAS G12C mutation, PD-L1 high, previously treated with pembrolizumab."
|
183 |
)
|
184 |
|
185 |
+
# Textbox for max results
|
186 |
+
max_results_input = gr.Textbox(
|
187 |
+
label="Enter the maximum number of results to return (1-50)",
|
188 |
+
value="10" # default
|
189 |
+
)
|
190 |
+
|
191 |
# Button to run the matching
|
192 |
submit_btn = gr.Button("Find Matches")
|
193 |
|
|
|
216 |
# 1) "Find Matches" => updates the dropdown choices and the state
|
217 |
submit_btn.click(
|
218 |
fn=match_clinical_trials_dropdown,
|
219 |
+
inputs=[patient_summary_input, max_results_input],
|
220 |
outputs=[trial_dropdown, results_state]
|
221 |
)
|
222 |
|