Spaces:
Running
Running
from typing import Dict | |
from pathlib import Path | |
import pickle, logging, sys | |
from typing import Tuple, List, Dict | |
import numpy as np | |
import numpy.typing as npt | |
from scipy.sparse import csr_matrix | |
import torch | |
import gradio as gr | |
import pandas as pd | |
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# TorchScript model. map_location remaps tensors that were saved on another
# device (e.g. a CUDA-trained checkpoint loaded on a CPU-only host) onto
# `device`; without it such a checkpoint fails to load.
model = torch.jit.load("final_model.pth", map_location=device)
# Inference only: switch layers whose behaviour depends on training mode
# (dropout, batch norm, if the model has any) to evaluation behaviour.
model.eval()
def predict(patient_id: int, *, data_dir: Path = Path(".")) -> Dict[str, float]:
    """Return outcome probabilities for one example patient.

    Reads ``Patient{patient_id}.csv`` from *data_dir* (the current working
    directory by default) and returns a ``{label: probability}`` mapping in
    the shape ``gr.Label`` displays.

    NOTE(review): the CSV is loaded but not yet passed to ``model``; the
    returned probabilities are a hard-coded placeholder.

    Args:
        patient_id: Example patient number. Gradio ``gr.Number`` inputs may
            deliver it as a float (e.g. ``1.0``), so it is coerced to ``int``
            before the file name is built (otherwise ``Patient1.0.csv`` would
            be looked up).
        data_dir: Directory holding the ``Patient{N}.csv`` files; any
            path-like value accepted by ``Path``. Keyword-only so Gradio does
            not treat it as a component input.

    Returns:
        Mapping of outcome label to probability.

    Raises:
        ValueError: If *patient_id* is not a whole number.
        FileNotFoundError: If the patient CSV does not exist.
    """
    pid = int(patient_id)
    if float(patient_id) != pid:
        raise ValueError(f"patient_id must be a whole number, got {patient_id!r}")
    print(f"predict patient {pid}")
    csv_path = Path(data_dir) / f"Patient{pid}.csv"
    df = pd.read_csv(
        csv_path,
        header="infer",
        sep=",",
        encoding="utf-8",
        dtype={"condition": "str", "user_key": "float32"},
        # Keep empty cells as "" rather than NaN.
        keep_default_na=False,
    )
    # TODO: build the model input from `df`, run `model`, and return its
    # probabilities instead of this placeholder.
    return {"Death": 0.9, "Alive": 0.1}
def _example_patient_row(patient_ids):
    """Render one row of example-patient panels inside the current layout.

    Each panel holds a "Patient N" button, a download button serving
    ``PatientN.csv``, and a hidden ``gr.Number`` carrying N so the button's
    click handler can pass the id to ``predict``.

    Args:
        patient_ids: Patient numbers to render, in display order.

    Returns:
        List of ``(button, hidden_id_number)`` pairs in display order.
    """
    panels = []
    with gr.Row():
        for pid in patient_ids:
            with gr.Column(variant="panel", min_width=100):
                button = gr.Button(f"Patient {pid}", size="sm")
                gr.DownloadButton(label="Download", value=f"Patient{pid}.csv", size="sm")
                hidden_id = gr.Number(value=pid, visible=False)
            panels.append((button, hidden_id))
    return panels


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    with gr.Row():
        # Left column: file upload plus the example patients.
        with gr.Column():
            gr.Markdown(
                """
                ## Input:
                (See examples for file structure)
                """
            )
            # TODO: no event is attached to this upload yet, so uploaded
            # files are currently ignored.
            patient_upload_file = gr.File(
                label="Upload A Patient",
                file_types=[".csv"],
                file_count="single",
                height=100,
            )
            gr.Markdown(
                """
                ## Examples - Correct Prediction:
                """
            )
            example_panels = _example_patient_row((1, 2, 3))
            gr.Markdown(
                """
                ## Examples - Wrong Prediction:
                """
            )
            example_panels += _example_patient_row((4, 5, 6))
        # Right column: prediction output.
        with gr.Column():
            gr.Markdown(
                """
                ## Mortality Prediction:
                In 24 hours after ICU admission.
                """
            )
            result = gr.Label(num_top_classes=2, label="Predictions")

    # Wire every example button to predict. Each listener gets a unique
    # endpoint name (predict, predict_1, ...): reusing one api_name makes
    # Gradio warn and append a numeric suffix anyway.
    for index, (button, hidden_id) in enumerate(example_panels):
        button.click(
            fn=predict,
            inputs=hidden_id,
            outputs=result,
            api_name="predict" if index == 0 else f"predict_{index}",
        )


if __name__ == "__main__":
    demo.launch(debug=True)