Spaces:
Sleeping
Sleeping
File size: 4,159 Bytes
6dc443d 47d95c0 759caee 6dc443d 759caee 6dc443d 759caee 6dc443d 759caee 47d95c0 759caee 99e8682 b4ee158 759caee 3e38893 759caee 99e8682 759caee 3e38893 759caee 3e38893 759caee 99e8682 b4ee158 759caee 787f98f b4ee158 06a1d7b 5ccc81c 759caee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
from typing import Dict
from pathlib import Path
import pickle, logging, sys
from typing import Tuple, List, Dict
import numpy as np
import numpy.typing as npt
from scipy.sparse import csr_matrix
import torch
import gradio as gr
import pandas as pd
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# map_location remaps tensors that were saved on another device (e.g. a
# checkpoint exported from a CUDA machine) onto `device`, so loading also
# works on CPU-only hosts.
model = torch.jit.load("final_model.pth", map_location=device)
# Inference only: switch layers such as dropout / batch-norm (if the model
# has any) to evaluation behaviour.
model.eval()
# NOTE(review): `predict` does not call `model` yet; it returns fixed scores.
def predict(patient_id: float) -> Dict[str, float]:
    """Return mortality-class scores for one of the bundled example patients.

    Reads ``Patient<id>.csv`` from the current working directory. The scores
    returned are currently a fixed placeholder: the loaded CSV is not yet
    passed to any model.

    Args:
        patient_id: Example patient number. Gradio's ``gr.Number`` hands
            values over as floats (e.g. ``1.0``) unless ``precision=0`` is
            set, so the id is normalised to ``int`` before the file name is
            built. Otherwise the lookup would target ``Patient1.0.csv``.

    Returns:
        Mapping of class label to score, in the shape ``gr.Label`` displays.

    Raises:
        ValueError: If ``patient_id`` is not a whole number.
        FileNotFoundError: If the patient's CSV file does not exist.
    """
    if not float(patient_id).is_integer():
        raise ValueError(f"patient_id must be a whole number, got {patient_id!r}")
    patient_id = int(patient_id)
    print(f"predict patient {patient_id}")
    df = pd.read_csv(
        f"Patient{patient_id}.csv",
        header="infer",
        sep=",",
        encoding="utf-8",
        dtype={'condition': 'str', 'user_key': 'float32'},
        # Keep empty cells as "" rather than converting them to NaN.
        keep_default_na=False,
    )
    # TODO(review): `df` is parsed (which fails fast on a missing or
    # malformed file) but not used yet; the scores below are a fixed
    # placeholder until model inference is wired in.
    return {"Death": 0.9, "Alive": 0.1}
def _patient_example_column(patient_number):
    """Build one example-patient panel inside the current ``gr.Row``.

    The panel holds a predict button, a button that downloads the patient's
    CSV, and a hidden number carrying the patient id into ``predict``.
    Returns ``(input_button, download_button, hidden_id)``.
    """
    with gr.Column(variant='panel', min_width=100):
        input_btn = gr.Button(f"Patient {patient_number}", size="sm")
        download_btn = gr.DownloadButton(
            label="Download", value=f"Patient{patient_number}.csv", size="sm"
        )
        # precision=0 makes Gradio deliver an int, so predict() looks up
        # "Patient1.csv" rather than "Patient1.0.csv".
        hidden_id = gr.Number(value=patient_number, visible=False, precision=0)
    return input_btn, download_btn, hidden_id


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    with gr.Row():
        # Left column: file upload plus the bundled example patients.
        with gr.Column():
            gr.Markdown(
                """
                ## Input:
                (See examples for file structure)
                """
            )
            # TODO(review): this upload component is not attached to any
            # event, so uploaded files currently have no effect.
            patient_upload_file = gr.File(label="Upload A Patient",
                                          file_types=['.csv'],
                                          file_count="single",
                                          height=100)
            gr.Markdown(
                """
                ## Examples - Correct Prediction:
                """
            )
            with gr.Row():
                patient_1_input_btn, patient_1_download_btn, patient_id_1 = _patient_example_column(1)
                patient_2_input_btn, patient_2_download_btn, patient_id_2 = _patient_example_column(2)
                patient_3_input_btn, patient_3_download_btn, patient_id_3 = _patient_example_column(3)
            gr.Markdown(
                """
                ## Examples - Wrong Prediction:
                """
            )
            with gr.Row():
                patient_4_input_btn, patient_4_download_btn, patient_id_4 = _patient_example_column(4)
                patient_5_input_btn, patient_5_download_btn, patient_id_5 = _patient_example_column(5)
                patient_6_input_btn, patient_6_download_btn, patient_id_6 = _patient_example_column(6)
        # Right column: prediction output.
        with gr.Column():
            gr.Markdown(
                """
                ## Mortality Prediction:
                In 24 hours after ICU admission.
                """
            )
            result = gr.Label(num_top_classes=2, label="Predictions")

    # Choose a patient to predict. Every example button (1-6) is wired;
    # previously patients 4-6 had buttons that did nothing. Gradio makes
    # repeated api_names unique by appending a numeric suffix, so the first
    # three endpoints keep the names they had before.
    example_patients = [
        (patient_1_input_btn, patient_id_1),
        (patient_2_input_btn, patient_id_2),
        (patient_3_input_btn, patient_id_3),
        (patient_4_input_btn, patient_id_4),
        (patient_5_input_btn, patient_id_5),
        (patient_6_input_btn, patient_id_6),
    ]
    for example_btn, example_id in example_patients:
        example_btn.click(fn=predict, inputs=example_id, outputs=result, api_name="predict")

demo.launch(debug=True)
|