File size: 3,856 Bytes
e7de3df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# -*- coding: utf-8 -*-
"""app.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1oJP09Coya1D16dQ_7fpVGlbuLnCWxHTN
"""

import gradio as gr
import joblib
import numpy as np
from tensorflow.keras.models import load_model

# Trained regressors, read from paths relative to the current working
# directory. NOTE: joblib.load unpickles its input, so only ever point these
# at trusted, locally produced model files.
svr_model, rf_model, dt_model = map(
    joblib.load,
    (
        'SVR_model.joblib',
        'RandomForestRegressor_model.joblib',
        'best_DecisionTreeRegressor_model.joblib',
    ),
)
dl_model = load_model('best_DeepLearning_model.h5')

# Fitted preprocessing objects used by predict_warfarin_dose: the label
# encoder applied to the age band and the column transformer applied to the
# full feature row.
label_encoder, column_transformer = map(
    joblib.load,
    ('label_encoder.joblib', 'column_transformer.joblib'),
)

def predict_warfarin_dose(gender, race, age, height, weight, diabetes,
                         simvastatin, amiodarone, inr_reported,
                         cyp2c9, vkorc1, model_choice):
    """Predict a weekly warfarin dose from the values entered in the form.

    This is the Gradio callback, so it never raises: every outcome is
    reported as text for the output box.

    Args:
        gender, race, cyp2c9, vkorc1: Category labels chosen in the form.
        age: Age-band label (e.g. ``"50-59"``), encoded with ``label_encoder``.
        height, weight, inr_reported: Numeric entries.
        diabetes, simvastatin, amiodarone: 0.0 / 1.0 flags.
        model_choice: One of ``'Decision Tree'``,
            ``'Support Vector Regression'``, ``'Random Forest'`` or
            ``'Deep Learning'``.

    Returns:
        ``"Predicted Warfarin Dose: <x> mg/week"`` on success, otherwise a
        message starting with ``"Error in prediction:"``.
    """
    # A blank number or an unselected option must not be replaced by a
    # default (0.0, or the string "None" after str()): the models would still
    # return a dose, just a meaningless one. Labels mirror the form widgets.
    required = {
        "Gender": gender,
        "Race": race,
        "Age": age,
        "Height (cm)": height,
        "Weight (kg)": weight,
        "Diabetes": diabetes,
        "Simvastatin (Zocor)": simvastatin,
        "Amiodarone (Cordarone)": amiodarone,
        "INR on Reported Therapeutic Dose of Warfarin": inr_reported,
        "Cyp2C9 genotypes": cyp2c9,
        "VKORC1 genotypes": vkorc1,
        "Model Selection": model_choice,
    }
    # Checked with `is None` rather than truthiness: 0.0 is a valid flag value.
    missing = [
        label for label, value in required.items()
        if value is None or (isinstance(value, str) and not value.strip())
    ]
    if missing:
        return f"Error in prediction: missing input(s): {', '.join(missing)}"

    known_models = ('Decision Tree', 'Support Vector Regression',
                    'Random Forest', 'Deep Learning')
    if model_choice not in known_models:
        # Reject instead of silently answering with a different model.
        return f"Error in prediction: unknown model {model_choice!r}"

    try:
        # Encode the age band on its own with the fitted label encoder.
        age_encoded = label_encoder.transform([age])

        # Raw feature row; presumably in the column order column_transformer
        # was fitted on -- keep in sync with the training notebook.
        inputs = [
            str(gender),
            str(race),
            str(age),
            float(height),
            float(weight),
            float(diabetes),
            float(simvastatin),
            float(amiodarone),
            float(inr_reported),
            str(cyp2c9),
            str(vkorc1),
        ]

        inputs_transformed = column_transformer.transform([inputs])
        # NOTE(review): -7 presumably addresses the passed-through age column
        # (age plus the six numeric fields after it would be the last seven
        # output columns if untransformed columns are appended at the end).
        # Confirm against how column_transformer was fitted.
        inputs_transformed[0][-7] = age_encoded[0]
        input_data = np.array(inputs_transformed, dtype=np.float32)

        if model_choice == 'Deep Learning':
            # verbose=0 stops Keras printing a progress bar on every request.
            prediction = dl_model.predict(input_data, verbose=0)[0][0]
        else:
            regressor = {
                'Support Vector Regression': svr_model,
                'Random Forest': rf_model,
                'Decision Tree': dt_model,
            }[model_choice]
            prediction = regressor.predict(input_data)[0]

        return f"Predicted Warfarin Dose: {float(prediction):.2f} mg/week"

    except Exception as e:
        # UI boundary: surface the failure in the output box instead of raising.
        return f"Error in prediction: {str(e)}"

# Gradio UI: one input widget per parameter of predict_warfarin_dose, listed
# in the same order as that function's parameters.
_AGE_BANDS = ["0-9", "10-19", "20-29", "30-39", "40-49", "50-59", "60-69",
              "70-79", "80-89", "90+"]

_input_components = [
    gr.Radio(["male", "female"], label="Gender"),
    gr.Dropdown(["Asian", "Black", "White", "Unknown", "Mixed or Missing"],
                label="Race"),
    gr.Dropdown(_AGE_BANDS, label="Age"),
    gr.Number(label="Height (cm)"),
    gr.Number(label="Weight (kg)"),
    # Three 0.0 / 1.0 flags, each with its own fresh choices list.
    *(
        gr.Radio([0.0, 1.0], label=flag_label)
        for flag_label in ("Diabetes", "Simvastatin (Zocor)",
                           "Amiodarone (Cordarone)")
    ),
    gr.Number(label="INR on Reported Therapeutic Dose of Warfarin"),
    gr.Dropdown(["*1/*1", "*1/*2", "*1/*3", "*2/*2", "*2/*3", "*3/*3"],
                label="Cyp2C9 genotypes"),
    gr.Radio(["A/A", "A/G", "G/G"], label="VKORC1 genotypes"),
    gr.Dropdown(['Decision Tree', 'Support Vector Regression',
                 'Random Forest', 'Deep Learning'], label="Model Selection"),
]

# Pre-filled example rows; each row follows the widget order above.
_EXAMPLES = [
    ["male", "Asian", "50-59", 170, 70, 0.0, 0.0, 0.0, 2.5, "*1/*1", "A/G",
     "Random Forest"],
    ["female", "White", "60-69", 165, 65, 1.0, 1.0, 0.0, 2.8, "*1/*2", "G/G",
     "Deep Learning"],
]

iface = gr.Interface(
    fn=predict_warfarin_dose,
    inputs=_input_components,
    outputs=gr.Textbox(label="Prediction Result"),
    title="Warfarin Dosage Prediction System",
    description="""This system predicts the optimal warfarin dosage based on patient characteristics.
                   Enter the required information below and select a model for prediction.""",
    examples=_EXAMPLES,
    theme="default",
)

if __name__ == "__main__":
    # Serve the interface only when run as a script, not when imported.
    iface.launch()