Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Upload app.py
Browse files
app.py
CHANGED
@@ -22,7 +22,6 @@ tokenizer = AutoTokenizer.from_pretrained("roberta-large")
|
|
22 |
checker_pipe = pipeline('text-classification', 'ksg-dfci/TrialChecker', tokenizer=tokenizer,
|
23 |
truncation=True, padding='max_length', max_length=512)
|
24 |
|
25 |
-
|
26 |
import gradio as gr
|
27 |
import pandas as pd
|
28 |
import torch
|
@@ -32,11 +31,16 @@ from safetensors import safe_open
|
|
32 |
from transformers import pipeline, AutoTokenizer
|
33 |
import tempfile
|
34 |
|
35 |
-
# We assume the following objects have already been loaded:
|
36 |
# trial_spaces (DataFrame), embedding_model (SentenceTransformer),
|
37 |
# trial_space_embeddings (torch.tensor), checker_pipe (transformers pipeline)
|
38 |
|
39 |
-
def
|
|
|
|
|
|
|
|
|
|
|
40 |
# Encode patient summary
|
41 |
patient_embedding = embedding_model.encode([patient_summary], convert_to_tensor=True)
|
42 |
|
@@ -47,12 +51,14 @@ def match_clinical_trials(patient_summary: str):
|
|
47 |
sorted_similarities, sorted_indices = torch.sort(similarities, descending=True)
|
48 |
top_indices = sorted_indices[0:10].cpu().numpy()
|
49 |
|
|
|
50 |
relevant_spaces = trial_spaces.iloc[top_indices].this_space
|
51 |
relevant_nctid = trial_spaces.iloc[top_indices].nct_id
|
52 |
relevant_title = trial_spaces.iloc[top_indices].title
|
53 |
relevant_brief_summary = trial_spaces.iloc[top_indices].brief_summary
|
54 |
relevant_eligibility_criteria = trial_spaces.iloc[top_indices].eligibility_criteria
|
55 |
|
|
|
56 |
analysis = pd.DataFrame({
|
57 |
'patient_summary_query': patient_summary,
|
58 |
'this_space': relevant_spaces,
|
@@ -62,6 +68,7 @@ def match_clinical_trials(patient_summary: str):
|
|
62 |
'trial_eligibility_criteria': relevant_eligibility_criteria
|
63 |
}).reset_index(drop=True)
|
64 |
|
|
|
65 |
analysis['pt_trial_pair'] = (
|
66 |
analysis['this_space']
|
67 |
+ "\nNow here is the patient summary:"
|
@@ -73,7 +80,7 @@ def match_clinical_trials(patient_summary: str):
|
|
73 |
analysis['trial_checker_result'] = [x['label'] for x in classifier_results]
|
74 |
analysis['trial_checker_score'] = [x['score'] for x in classifier_results]
|
75 |
|
76 |
-
#
|
77 |
out_df = analysis[[
|
78 |
'patient_summary_query',
|
79 |
'nct_id',
|
@@ -84,38 +91,91 @@ def match_clinical_trials(patient_summary: str):
|
|
84 |
'trial_checker_result',
|
85 |
'trial_checker_score'
|
86 |
]]
|
87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
|
89 |
def export_results(df: pd.DataFrame):
|
90 |
-
|
|
|
|
|
|
|
91 |
temp = tempfile.NamedTemporaryFile(delete=False, suffix=".csv")
|
92 |
df.to_csv(temp.name, index=False)
|
93 |
return temp.name
|
94 |
|
|
|
95 |
custom_css = """
|
96 |
#input_box textarea {
|
97 |
width: 600px !important;
|
98 |
height: 250px !important;
|
99 |
}
|
100 |
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
border-collapse: collapse
|
|
|
|
|
105 |
}
|
106 |
|
107 |
-
|
108 |
-
min-width: 100px;
|
109 |
-
max-width: 300px;
|
110 |
-
overflow-wrap: anywhere; /* or 'word-wrap: break-word;' */
|
111 |
-
white-space: pre-wrap; /* or 'white-space: normal;' */
|
112 |
border: 1px solid #ccc;
|
113 |
-
padding:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
}
|
115 |
"""
|
116 |
|
|
|
117 |
with gr.Blocks(css=custom_css) as demo:
|
118 |
-
gr.HTML("<h3>Alpha Version of Clinical Trial Search
|
119 |
gr.HTML("<h3>Based on clinicaltrials.gov cancer trials export 10/31/24</h3>")
|
120 |
|
121 |
patient_summary_input = gr.Textbox(
|
@@ -126,38 +186,27 @@ with gr.Blocks(css=custom_css) as demo:
|
|
126 |
|
127 |
submit_btn = gr.Button("Find Matches")
|
128 |
|
129 |
-
# We'll store the DataFrame in a state
|
130 |
results_state = gr.State()
|
131 |
|
132 |
-
|
133 |
-
|
134 |
-
"patient_summary_query",
|
135 |
-
"nct_id",
|
136 |
-
"title",
|
137 |
-
"trial_brief_summary",
|
138 |
-
"eligibility_criteria",
|
139 |
-
"this_space",
|
140 |
-
"trial_checker_result",
|
141 |
-
"trial_checker_score"
|
142 |
-
],
|
143 |
-
elem_id="output_df"
|
144 |
-
)
|
145 |
|
146 |
export_btn = gr.Button("Export Results")
|
147 |
|
148 |
-
#
|
149 |
submit_btn.click(
|
150 |
-
fn=
|
151 |
inputs=patient_summary_input,
|
152 |
-
outputs=[
|
153 |
)
|
154 |
|
155 |
-
#
|
156 |
export_btn.click(
|
157 |
fn=export_results,
|
158 |
inputs=results_state,
|
159 |
outputs=gr.File(label="Download CSV")
|
160 |
)
|
161 |
|
162 |
-
if __name__ ==
|
163 |
demo.launch()
|
|
|
22 |
checker_pipe = pipeline('text-classification', 'ksg-dfci/TrialChecker', tokenizer=tokenizer,
|
23 |
truncation=True, padding='max_length', max_length=512)
|
24 |
|
|
|
25 |
import gradio as gr
|
26 |
import pandas as pd
|
27 |
import torch
|
|
|
31 |
from transformers import pipeline, AutoTokenizer
|
32 |
import tempfile
|
33 |
|
34 |
+
# We assume the following objects have already been loaded in your environment:
|
35 |
# trial_spaces (DataFrame), embedding_model (SentenceTransformer),
|
36 |
# trial_space_embeddings (torch.tensor), checker_pipe (transformers pipeline)
|
37 |
|
38 |
+
def match_clinical_trials_html(patient_summary: str):
|
39 |
+
"""
|
40 |
+
Takes in a patient_summary string, computes the top 10 matching trials,
|
41 |
+
and returns a tuple of:
|
42 |
+
(html_table_string, df_for_export)
|
43 |
+
"""
|
44 |
# Encode patient summary
|
45 |
patient_embedding = embedding_model.encode([patient_summary], convert_to_tensor=True)
|
46 |
|
|
|
51 |
sorted_similarities, sorted_indices = torch.sort(similarities, descending=True)
|
52 |
top_indices = sorted_indices[0:10].cpu().numpy()
|
53 |
|
54 |
+
# Retrieve relevant columns from trial_spaces
|
55 |
relevant_spaces = trial_spaces.iloc[top_indices].this_space
|
56 |
relevant_nctid = trial_spaces.iloc[top_indices].nct_id
|
57 |
relevant_title = trial_spaces.iloc[top_indices].title
|
58 |
relevant_brief_summary = trial_spaces.iloc[top_indices].brief_summary
|
59 |
relevant_eligibility_criteria = trial_spaces.iloc[top_indices].eligibility_criteria
|
60 |
|
61 |
+
# Build the main DataFrame for analysis
|
62 |
analysis = pd.DataFrame({
|
63 |
'patient_summary_query': patient_summary,
|
64 |
'this_space': relevant_spaces,
|
|
|
68 |
'trial_eligibility_criteria': relevant_eligibility_criteria
|
69 |
}).reset_index(drop=True)
|
70 |
|
71 |
+
# Create a merged text input for the reranking checker
|
72 |
analysis['pt_trial_pair'] = (
|
73 |
analysis['this_space']
|
74 |
+ "\nNow here is the patient summary:"
|
|
|
80 |
analysis['trial_checker_result'] = [x['label'] for x in classifier_results]
|
81 |
analysis['trial_checker_score'] = [x['score'] for x in classifier_results]
|
82 |
|
83 |
+
# Subset (and reorder) the final columns we want
|
84 |
out_df = analysis[[
|
85 |
'patient_summary_query',
|
86 |
'nct_id',
|
|
|
91 |
'trial_checker_result',
|
92 |
'trial_checker_score'
|
93 |
]]
|
94 |
+
|
95 |
+
# Convert that DataFrame to an HTML table
|
96 |
+
html_table = df_to_html(out_df)
|
97 |
+
|
98 |
+
# Return (HTML for display, DataFrame for exporting)
|
99 |
+
return html_table, out_df
|
100 |
+
|
101 |
+
def df_to_html(df: pd.DataFrame) -> str:
    """
    Render a DataFrame as an HTML ``<table class="styled-table">`` string.

    Column names become ``<th>`` cells and every value becomes a ``<td>``
    cell. Values are converted with ``str()``, HTML-escaped, and then have
    their newlines replaced with ``<br>`` so multi-line text keeps its
    line breaks.

    Escaping matters here: the table contains free text (for example the
    patient summary typed by the user) and the result is displayed as raw
    HTML, so unescaped ``<``, ``>`` or ``&`` would be interpreted as markup.

    Args:
        df: The DataFrame to render. An empty DataFrame yields a table
            with a header row and an empty body.

    Returns:
        The HTML markup for the table.
    """
    # Local import so this function does not depend on the module's import block.
    import html

    def _render(value) -> str:
        # Escape first, then insert <br>; the other order would escape the tags.
        return html.escape(str(value)).replace("\n", "<br>")

    # Build the table headers
    header_row = "".join(f"<th>{_render(col)}</th>" for col in df.columns)

    # Build the table rows. itertuples(index=False, name=None) yields plain
    # tuples in column order; unlike iterrows() it does not upcast a row to a
    # single dtype, and it stays correct when column names are duplicated.
    table_rows = [
        "<tr>" + "".join(f"<td>{_render(value)}</td>" for value in row) + "</tr>"
        for row in df.itertuples(index=False, name=None)
    ]
    table_body = "\n".join(table_rows)

    # Put it all together as an HTML string
    table_html = f"""
<table class="styled-table">
<thead><tr>{header_row}</tr></thead>
<tbody>
{table_body}
</tbody>
</table>
"""
    return table_html
|
132 |
|
133 |
def export_results(df: pd.DataFrame):
    """
    Save the DataFrame to a temporary CSV file and return the file's path
    so that Gradio can prompt the user to download it.

    The temporary file is created with ``delete=False`` so it still exists
    after this function returns; nothing in this function removes it.

    Args:
        df: The results table to export. The index is not written.

    Returns:
        The filesystem path of the ``.csv`` file that was written.
    """
    # Use the temp file only to reserve a unique path, and close the handle
    # right away: leaving it open leaks a file descriptor per export and can
    # prevent the file from being reopened by name on Windows.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as temp:
        csv_path = temp.name

    df.to_csv(csv_path, index=False)
    return csv_path
|
141 |
|
142 |
+
|
143 |
custom_css = """
|
144 |
#input_box textarea {
|
145 |
width: 600px !important;
|
146 |
height: 250px !important;
|
147 |
}
|
148 |
|
149 |
+
/* Make the custom table more readable: allow wrapping text */
|
150 |
+
.styled-table {
|
151 |
+
width: 100%;
|
152 |
+
border-collapse: collapse;
|
153 |
+
table-layout: auto;
|
154 |
+
margin-top: 1em;
|
155 |
}
|
156 |
|
157 |
+
.styled-table th, .styled-table td {
|
|
|
|
|
|
|
|
|
158 |
border: 1px solid #ccc;
|
159 |
+
padding: 8px;
|
160 |
+
vertical-align: top;
|
161 |
+
text-align: left;
|
162 |
+
white-space: pre-wrap; /* Wrap text */
|
163 |
+
overflow-wrap: anywhere; /* Break long text automatically */
|
164 |
+
}
|
165 |
+
|
166 |
+
.styled-table thead tr {
|
167 |
+
background-color: #f2f2f2;
|
168 |
+
font-weight: bold;
|
169 |
+
}
|
170 |
+
|
171 |
+
.styled-table tbody tr:nth-of-type(even) {
|
172 |
+
background-color: #f9f9f9;
|
173 |
}
|
174 |
"""
|
175 |
|
176 |
+
# Build the Gradio interface
|
177 |
with gr.Blocks(css=custom_css) as demo:
|
178 |
+
gr.HTML("<h3>Alpha Version of Clinical Trial Search (HTML Table Output)</h3>")
|
179 |
gr.HTML("<h3>Based on clinicaltrials.gov cancer trials export 10/31/24</h3>")
|
180 |
|
181 |
patient_summary_input = gr.Textbox(
|
|
|
186 |
|
187 |
submit_btn = gr.Button("Find Matches")
|
188 |
|
189 |
+
# We'll store the DataFrame in a state for exporting to CSV
|
190 |
results_state = gr.State()
|
191 |
|
192 |
+
# The output is now HTML, instead of a DataFrame
|
193 |
+
output_html = gr.HTML(label="Results")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
194 |
|
195 |
export_btn = gr.Button("Export Results")
|
196 |
|
197 |
+
# When "Find Matches" is clicked, we get (HTML string, DataFrame)
|
198 |
submit_btn.click(
|
199 |
+
fn=match_clinical_trials_html,
|
200 |
inputs=patient_summary_input,
|
201 |
+
outputs=[output_html, results_state]
|
202 |
)
|
203 |
|
204 |
+
# When "Export Results" is clicked, we export the DataFrame as CSV
|
205 |
export_btn.click(
|
206 |
fn=export_results,
|
207 |
inputs=results_state,
|
208 |
outputs=gr.File(label="Download CSV")
|
209 |
)
|
210 |
|
211 |
+
if __name__ == "__main__":
|
212 |
demo.launch()
|