File size: 4,159 Bytes
6dc443d
47d95c0
759caee
 
6dc443d
759caee
 
 
 
6dc443d
759caee
6dc443d
759caee
 
47d95c0
759caee
 
 
 
 
 
 
 
 
99e8682
 
b4ee158
759caee
3e38893
 
759caee
 
 
 
 
 
99e8682
 
759caee
 
 
 
 
 
 
3e38893
759caee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e38893
759caee
 
 
 
 
 
99e8682
b4ee158
 
759caee
 
 
787f98f
b4ee158
06a1d7b
5ccc81c
759caee
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from typing import Dict
from pathlib import Path
import pickle, logging, sys
from typing import Tuple, List, Dict

import numpy as np
import numpy.typing as npt
from scipy.sparse import csr_matrix
import torch
import gradio as gr
import pandas as pd

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# TorchScript model, loaded once at import time. map_location places its
# tensors on `device` (without it, a model saved from a CUDA run cannot be
# loaded on a CPU-only host). eval() switches any train/eval-dependent layers
# to inference behaviour, since this app only predicts.
model = torch.jit.load("final_model.pth", map_location=device)
model.eval()

def predict(patient_id: int):
    """Score the example patient stored in ``Patient<patient_id>.csv``.

    Args:
        patient_id: Example-patient number. Whole-valued floats (e.g. ``1.0``,
            as a ``gr.Number`` may deliver) are accepted and normalised to
            ``int`` so the file name resolves to ``Patient1.csv`` rather than
            ``Patient1.0.csv``.

    Returns:
        Mapping of class label to probability, in the form ``gr.Label``
        displays.

    Raises:
        ValueError: If ``patient_id`` is not a whole number.
        FileNotFoundError: If the patient CSV does not exist in the working
            directory.
    """
    pid = int(patient_id)
    if pid != patient_id:
        raise ValueError(f"patient_id must be a whole number, got {patient_id!r}")
    print(f"predict patient {pid}")
    # The frame is read so that a missing or malformed file fails loudly.
    # NOTE(review): it is not yet fed to the model; the result below is a
    # hard-coded placeholder until feature extraction + inference are wired.
    df = pd.read_csv(f"Patient{pid}.csv",
                     header="infer",
                     sep=",",
                     encoding="utf-8",
                     dtype={'condition': 'str', 'user_key': 'float32'},
                     keep_default_na=False)

    return {"Death": 0.9, "Alive": 0.1}


def _patient_example(patient_id: int):
    """Render one example-patient panel inside the current ``gr.Row``.

    Returns a ``(predict button, download button, hidden id)`` tuple. The
    hidden ``gr.Number`` uses ``precision=0`` so the click handler receives
    an ``int`` id and the CSV name resolves to ``Patient<id>.csv``.
    """
    with gr.Column(variant='panel', min_width=100):
        input_btn = gr.Button(f"Patient {patient_id}", size="sm")
        download_btn = gr.DownloadButton(label="Download",
                                         value=f"Patient{patient_id}.csv",
                                         size="sm")
        hidden_id = gr.Number(value=patient_id, precision=0, visible=False)
    return input_btn, download_btn, hidden_id


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    with gr.Row():
        # Left column: patient input (upload + downloadable examples).
        with gr.Column():
            gr.Markdown(
                """
                ## Input:
                (See examples for file structure)
                """
            )
            # TODO(review): the upload is not connected to any event handler,
            # so an uploaded file is never scored; predict() only takes an id.
            patient_upload_file = gr.File(label="Upload A Patient",
                                          file_types=['.csv'],
                                          file_count="single",
                                          height=100)
            gr.Markdown(
                """
                ## Examples - Correct Prediction:
                """
            )
            with gr.Row():
                patient_1_input_btn, patient_1_download_btn, patient_id_1 = _patient_example(1)
                patient_2_input_btn, patient_2_download_btn, patient_id_2 = _patient_example(2)
                patient_3_input_btn, patient_3_download_btn, patient_id_3 = _patient_example(3)
            gr.Markdown(
                """
                ## Examples - Wrong Prediction:
                """
            )
            with gr.Row():
                patient_4_input_btn, patient_4_download_btn, patient_id_4 = _patient_example(4)
                patient_5_input_btn, patient_5_download_btn, patient_id_5 = _patient_example(5)
                patient_6_input_btn, patient_6_download_btn, patient_id_6 = _patient_example(6)
        # Right column: prediction output.
        with gr.Column():
            gr.Markdown(
                """
                ## Mortality Prediction:
                In 24 hours after ICU admission.
                """
            )
            result = gr.Label(num_top_classes=2, label="Predictions")

    # Wire every example button (1-6) to predict(); 4-6 were previously
    # rendered but inert. Only the first event publishes the "predict" API
    # endpoint, because reusing one api_name makes Gradio warn and rename the
    # duplicates. api_name=False keeps the others UI-only.
    example_events = [
        (patient_1_input_btn, patient_id_1),
        (patient_2_input_btn, patient_id_2),
        (patient_3_input_btn, patient_id_3),
        (patient_4_input_btn, patient_id_4),
        (patient_5_input_btn, patient_id_5),
        (patient_6_input_btn, patient_id_6),
    ]
    for event_index, (event_btn, event_id) in enumerate(example_events):
        event_btn.click(fn=predict,
                        inputs=event_id,
                        outputs=result,
                        api_name="predict" if event_index == 0 else False)

# Launch only when run as a script, so importing this module (e.g. Gradio
# reload mode or tests) does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch(debug=True)