tnt306's picture
Debug 20
759caee
raw
history blame
4.16 kB
from typing import Dict
from pathlib import Path
import pickle, logging, sys
from typing import Tuple, List, Dict
import numpy as np
import numpy.typing as npt
from scipy.sparse import csr_matrix
import torch
import gradio as gr
import pandas as pd
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# TorchScript model, loaded once at import time. map_location remaps tensors
# that were saved on another device (e.g. a CUDA-saved checkpoint on a
# CPU-only host) onto `device` instead of failing to load.
model = torch.jit.load("final_model.pth", map_location=device)
# This app only runs inference, so disable training-mode behaviour
# (such as dropout) in the loaded module.
model.eval()
def predict(patient_id: int) -> Dict[str, float]:
    """Return label probabilities for one of the bundled example patients.

    Reads ``Patient<id>.csv`` from the current working directory. The model is
    not consulted yet: the parsed CSV is discarded and the returned
    probabilities are a fixed placeholder.

    Args:
        patient_id: Number of the example patient. Gradio ``gr.Number``
            components may deliver this as a float (e.g. ``1.0``); any
            whole-number value is accepted.

    Returns:
        Mapping of label name to probability, the format ``gr.Label`` displays.

    Raises:
        ValueError: If ``patient_id`` is not a whole number.
        FileNotFoundError: If the patient's CSV file does not exist.
    """
    pid = int(patient_id)
    # Without this normalisation a float id such as 1.0 would be formatted
    # into the file name as "Patient1.0.csv".
    if pid != patient_id:
        raise ValueError(f"patient_id must be a whole number, got {patient_id!r}")
    print(f"predict patient {pid}")
    df = pd.read_csv(
        f"Patient{pid}.csv",
        header="infer",
        sep=",",
        encoding="utf-8",
        dtype={'condition': 'str', 'user_key': 'float32'},
        # Disable pandas' default NA markers (e.g. "NA", empty string).
        keep_default_na=False,
    )
    # TODO: feed `df` to the model; these are placeholder probabilities.
    return {"Death": 0.9, "Alive": 0.1}
def _example_patient_panel(patient_id: int):
    """Render one example-patient panel inside the enclosing layout.

    The panel holds a button that triggers prediction, a button that downloads
    the patient's CSV (``Patient<id>.csv``), and a hidden ``gr.Number`` whose
    value is the patient id passed to the click handler.

    Returns:
        Tuple ``(input_button, download_button, hidden_id_field)``.
    """
    with gr.Column(variant='panel', min_width=100):
        input_btn = gr.Button(f"Patient {patient_id}", size="sm")
        download_btn = gr.DownloadButton(
            label="Download", value=f"Patient{patient_id}.csv", size="sm"
        )
        # precision=0 asks Gradio to round and pass the value as an int rather
        # than a possible float, so handlers can build the CSV file name from it.
        id_field = gr.Number(value=patient_id, visible=False, precision=0)
    return input_btn, download_btn, id_field


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    with gr.Row():
        # Left column: upload widget plus the downloadable example patients.
        with gr.Column():
            gr.Markdown("\n## Input:\n(See examples for file structure)\n")
            # NOTE(review): no event handler reads this upload yet.
            patient_upload_file = gr.File(label="Upload A Patient",
                                          file_types=['.csv'],
                                          file_count="single",
                                          height=100)
            gr.Markdown("\n## Examples - Correct Prediction:\n")
            with gr.Row():
                patient_1_input_btn, patient_1_download_btn, patient_id_1 = _example_patient_panel(1)
                patient_2_input_btn, patient_2_download_btn, patient_id_2 = _example_patient_panel(2)
                patient_3_input_btn, patient_3_download_btn, patient_id_3 = _example_patient_panel(3)
            gr.Markdown("\n## Examples - Wrong Prediction:\n")
            with gr.Row():
                patient_4_input_btn, patient_4_download_btn, patient_id_4 = _example_patient_panel(4)
                patient_5_input_btn, patient_5_download_btn, patient_id_5 = _example_patient_panel(5)
                patient_6_input_btn, patient_6_download_btn, patient_id_6 = _example_patient_panel(6)
        # Right column: the prediction output.
        with gr.Column():
            gr.Markdown("\n## Mortality Prediction:\nIn 24 hours after ICU admission.\n")
            result = gr.Label(num_top_classes=2, label="Predictions")
    # Choose a patient to predict. Every example button, including the
    # "wrong prediction" ones (4-6), runs the same handler.
    # NOTE(review): api_name is reused across events; Gradio may de-duplicate
    # it with a numeric suffix — confirm the exposed endpoint names.
    for _input_btn, _id_field in (
        (patient_1_input_btn, patient_id_1),
        (patient_2_input_btn, patient_id_2),
        (patient_3_input_btn, patient_id_3),
        (patient_4_input_btn, patient_id_4),
        (patient_5_input_btn, patient_id_5),
        (patient_6_input_btn, patient_id_6),
    ):
        _input_btn.click(fn=predict, inputs=_id_field, outputs=result, api_name="predict")

if __name__ == "__main__":
    demo.launch(debug=True)