kenlkehl committed on
Commit
3815070
·
verified ·
1 Parent(s): 49dc990

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -13
app.py CHANGED
@@ -21,13 +21,7 @@ tokenizer = AutoTokenizer.from_pretrained("roberta-large")
21
  checker_pipe = pipeline('text-classification', 'ksg-dfci/TrialChecker', tokenizer=tokenizer,
22
  truncation=True, padding='max_length', max_length=512)
23
 
24
- import gradio as gr
25
- import pandas as pd
26
- import torch
27
- import torch.nn.functional as F
28
- from sentence_transformers import SentenceTransformer
29
- from safetensors import safe_open
30
- from transformers import pipeline, AutoTokenizer
31
 
32
  # We assume the following objects have already been loaded:
33
  # trial_spaces (DataFrame), embedding_model (SentenceTransformer),
@@ -51,7 +45,7 @@ def match_clinical_trials(patient_summary: str):
51
  relevant_eligibility_criteria = trial_spaces.iloc[top_indices].eligibility_criteria
52
 
53
  analysis = pd.DataFrame({
54
- 'patient_summary': patient_summary,
55
  'this_space': relevant_spaces,
56
  'nct_id': relevant_nctid,
57
  'trial_title': relevant_title,
@@ -59,15 +53,16 @@ def match_clinical_trials(patient_summary: str):
59
  'trial_eligibility_criteria': relevant_eligibility_criteria
60
  }).reset_index(drop=True)
61
 
62
- analysis['pt_trial_pair'] = analysis['this_space'] + "\nNow here is the patient summary:" + analysis['patient_summary']
63
 
64
  # Run checker pipeline
65
  classifier_results = checker_pipe(analysis.pt_trial_pair.tolist())
66
  analysis['trial_checker_result'] = [x['label'] for x in classifier_results]
67
  analysis['trial_checker_score'] = [x['score'] for x in classifier_results]
68
 
69
- # Return a subset of columns that are most relevant
70
  return analysis[[
 
71
  'nct_id',
72
  'trial_title',
73
  'trial_brief_summary',
@@ -76,6 +71,12 @@ def match_clinical_trials(patient_summary: str):
76
  'trial_checker_score'
77
  ]]
78
 
 
 
 
 
 
 
79
  custom_css = """
80
  #input_box textarea {
81
  width: 600px !important;
@@ -100,10 +101,21 @@ custom_css = """
100
 
101
  with gr.Blocks(css=custom_css) as demo:
102
  gr.HTML("<h3>Clinical Trial Matcher</h3>")
103
- patient_summary_input = gr.Textbox(label="Enter Patient Summary", elem_id="input_box")
 
 
 
 
 
 
104
  submit_btn = gr.Button("Find Matches")
 
 
 
 
105
  output_df = gr.DataFrame(
106
  headers=[
 
107
  "nct_id",
108
  "trial_title",
109
  "trial_brief_summary",
@@ -113,11 +125,18 @@ with gr.Blocks(css=custom_css) as demo:
113
  ],
114
  elem_id="output_df"
115
  )
 
 
116
 
 
117
  submit_btn.click(fn=match_clinical_trials,
118
  inputs=patient_summary_input,
119
- outputs=output_df)
120
 
 
 
 
 
121
 
122
- if __name__ == "__main__":
123
  demo.launch()
 
21
  checker_pipe = pipeline('text-classification', 'ksg-dfci/TrialChecker', tokenizer=tokenizer,
22
  truncation=True, padding='max_length', max_length=512)
23
 
24
+
 
 
 
 
 
 
25
 
26
  # We assume the following objects have already been loaded:
27
  # trial_spaces (DataFrame), embedding_model (SentenceTransformer),
 
45
  relevant_eligibility_criteria = trial_spaces.iloc[top_indices].eligibility_criteria
46
 
47
  analysis = pd.DataFrame({
48
+ 'patient_summary_query': patient_summary,
49
  'this_space': relevant_spaces,
50
  'nct_id': relevant_nctid,
51
  'trial_title': relevant_title,
 
53
  'trial_eligibility_criteria': relevant_eligibility_criteria
54
  }).reset_index(drop=True)
55
 
56
+ analysis['pt_trial_pair'] = analysis['this_space'] + "\nNow here is the patient summary:" + analysis['patient_summary_query']
57
 
58
  # Run checker pipeline
59
  classifier_results = checker_pipe(analysis.pt_trial_pair.tolist())
60
  analysis['trial_checker_result'] = [x['label'] for x in classifier_results]
61
  analysis['trial_checker_score'] = [x['score'] for x in classifier_results]
62
 
63
+ # Return the final subset of columns including patient_summary_query as first column
64
  return analysis[[
65
+ 'patient_summary_query',
66
  'nct_id',
67
  'trial_title',
68
  'trial_brief_summary',
 
71
  'trial_checker_score'
72
  ]]
73
 
74
def export_results(df: pd.DataFrame):
    """Write the results DataFrame to a temporary CSV file and return its path.

    Args:
        df: The DataFrame to export, or ``None`` when nothing has been
            stored yet (for example, export requested before any search
            has produced results).

    Returns:
        The filesystem path of the CSV file that was written, or ``None``
        when ``df`` is ``None`` so that there is nothing to download.
    """
    if df is None:
        return None

    # Imported locally so the function does not rely on a module-level
    # import that is not visible in this part of the file.
    import tempfile

    # Reserve a unique path, then close the handle before writing by name:
    # this avoids leaking the open file object and avoids writing to a path
    # that is still held open. delete=False keeps the file on disk after the
    # handle is closed so the returned path stays valid.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as temp:
        csv_path = temp.name

    df.to_csv(csv_path, index=False)
    return csv_path
79
+
80
  custom_css = """
81
  #input_box textarea {
82
  width: 600px !important;
 
101
 
102
  with gr.Blocks(css=custom_css) as demo:
103
  gr.HTML("<h3>Clinical Trial Matcher</h3>")
104
+
105
+ patient_summary_input = gr.Textbox(
106
+ label="Enter Patient Summary",
107
+ elem_id="input_box",
108
+ value="70M with metastatic lung adenocarcinoma, KRAS G12C mutation, PD-L1 high, previously treated with pembrolizumab."
109
+ )
110
+
111
  submit_btn = gr.Button("Find Matches")
112
+
113
+ # We'll store the DataFrame in a state so we can export it after generation
114
+ results_state = gr.State()
115
+
116
  output_df = gr.DataFrame(
117
  headers=[
118
+ "patient_summary_query",
119
  "nct_id",
120
  "trial_title",
121
  "trial_brief_summary",
 
125
  ],
126
  elem_id="output_df"
127
  )
128
+
129
+ export_btn = gr.Button("Export Results")
130
 
131
+ # On "Find Matches", show the DataFrame and store it in state
132
  submit_btn.click(fn=match_clinical_trials,
133
  inputs=patient_summary_input,
134
+ outputs=[output_df, results_state])
135
 
136
+ # On "Export Results", use the state to create and return a CSV file
137
+ export_btn.click(fn=export_results,
138
+ inputs=results_state,
139
+ outputs=gr.File(label="Download CSV"))
140
 
141
# Launch the Gradio app only when this file is executed as a script.
# Python sets __name__ to "__main__" in that case; comparing against 'main'
# is never true, so demo.launch() would never be reached.
if __name__ == "__main__":
    demo.launch()