Spaces:
Sleeping
Sleeping
Commit
·
129775b
1
Parent(s):
86e415b
Updated app with remote API inference and multi-input support
Browse files- __pycache__/main.cpython-310.pyc +0 -0
- app.py +35 -39
- main.py +24 -0
__pycache__/main.cpython-310.pyc
ADDED
Binary file (1.21 kB). View file
|
|
app.py
CHANGED
@@ -1,17 +1,12 @@
|
|
1 |
import streamlit as st
|
2 |
-
import
|
3 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM
|
4 |
import re
|
5 |
from pydub import AudioSegment
|
6 |
import speech_recognition as sr
|
7 |
import io
|
8 |
|
9 |
-
#
|
10 |
-
|
11 |
-
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
|
12 |
-
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map=None, low_cpu_mem_usage=True, torch_dtype=torch.float32)
|
13 |
-
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
14 |
-
#model.to(device)
|
15 |
|
16 |
# Dictionaries to decode user inputs
|
17 |
gender_map = {1: "Female", 2: "Male"}
|
@@ -19,31 +14,16 @@ cholesterol_map = {1: "Normal", 2: "Elevated", 3: "Peak"}
|
|
19 |
glucose_map = {1: "Normal", 2: "High", 3: "Extreme"}
|
20 |
binary_map = {0: "No", 1: "Yes"}
|
21 |
|
22 |
-
# Function to
|
23 |
-
def
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
- Cholesterol Level: {cholesterol_map[cholesterol]}
|
33 |
-
- Glucose Level: {glucose_map[glucose]}
|
34 |
-
- Smokes: {binary_map[smoke]}
|
35 |
-
- Alcohol Intake: {binary_map[alco]}
|
36 |
-
- Physically Active: {binary_map[active]}
|
37 |
-
|
38 |
-
Diagnosis:"""
|
39 |
-
|
40 |
-
inputs = tokenizer(input_text, return_tensors="pt")#.to(device)
|
41 |
-
model.eval()
|
42 |
-
with torch.no_grad():
|
43 |
-
outputs = model.generate(**inputs, max_new_tokens=4)
|
44 |
-
decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
45 |
-
diagnosis = decoded.split("Diagnosis:")[-1].strip()
|
46 |
-
return diagnosis
|
47 |
|
48 |
# Function to extract patient features from a phrase or transcribed audio
|
49 |
def extract_details_from_text(text):
|
@@ -74,15 +54,27 @@ if input_mode == "Manual Input":
|
|
74 |
weight = st.number_input("Weight (kg)", min_value=10, max_value=200)
|
75 |
ap_hi = st.number_input("Systolic BP", min_value=80, max_value=250)
|
76 |
ap_lo = st.number_input("Diastolic BP", min_value=40, max_value=150)
|
77 |
-
cholesterol = st.selectbox("Cholesterol", [("Normal", 1), ("
|
78 |
glucose = st.selectbox("Glucose", [("Normal", 1), ("High", 2), ("Extreme", 3)], format_func=lambda x: x[0])[1]
|
79 |
smoke = st.radio("Smoker?", [("No", 0), ("Yes", 1)], format_func=lambda x: x[0])[1]
|
80 |
alco = st.radio("Alcohol Intake?", [("No", 0), ("Yes", 1)], format_func=lambda x: x[0])[1]
|
81 |
active = st.radio("Physically Active?", [("No", 0), ("Yes", 1)], format_func=lambda x: x[0])[1]
|
82 |
|
83 |
if st.button("Predict Diagnosis"):
|
84 |
-
|
85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
st.success(f"🩺 **Predicted Diagnosis:** {diagnosis}")
|
87 |
|
88 |
elif input_mode == "Text Phrase":
|
@@ -91,7 +83,9 @@ elif input_mode == "Text Phrase":
|
|
91 |
try:
|
92 |
values = extract_details_from_text(phrase)
|
93 |
if all(v is not None for v in values):
|
94 |
-
|
|
|
|
|
95 |
st.success(f"🩺 **Predicted Diagnosis:** {diagnosis}")
|
96 |
else:
|
97 |
st.warning("Couldn't extract all fields from the text. Please revise.")
|
@@ -99,7 +93,7 @@ elif input_mode == "Text Phrase":
|
|
99 |
st.error(f"Error: {e}")
|
100 |
|
101 |
elif input_mode == "Audio Upload":
|
102 |
-
uploaded_file = st.file_uploader("Upload audio file (WAV, MP3, M4A)", type=["wav", "mp3", "m4a"])
|
103 |
if uploaded_file:
|
104 |
st.audio(uploaded_file, format='audio/wav')
|
105 |
audio = AudioSegment.from_file(uploaded_file)
|
@@ -116,7 +110,9 @@ elif input_mode == "Audio Upload":
|
|
116 |
st.markdown(f"**Transcribed Text:** _{text}_")
|
117 |
values = extract_details_from_text(text)
|
118 |
if all(v is not None for v in values):
|
119 |
-
|
|
|
|
|
120 |
st.success(f"🩺 **Predicted Diagnosis:** {diagnosis}")
|
121 |
else:
|
122 |
st.warning("Could not extract complete information from audio.")
|
|
|
1 |
import streamlit as st
|
2 |
+
import requests
|
|
|
3 |
import re
|
4 |
from pydub import AudioSegment
|
5 |
import speech_recognition as sr
|
6 |
import io
|
7 |
|
8 |
+
# Backend API endpoint
|
9 |
+
API_URL = "http://10.24.6.170:8000/predict" # Change to your server IP
|
|
|
|
|
|
|
|
|
10 |
|
11 |
# Dictionaries to decode user inputs
|
12 |
gender_map = {1: "Female", 2: "Male"}
|
|
|
14 |
glucose_map = {1: "Normal", 2: "High", 3: "Extreme"}
|
15 |
binary_map = {0: "No", 1: "Yes"}
|
16 |
|
17 |
+
# Function to call the backend API
|
18 |
+
def get_prediction_from_api(data):
|
19 |
+
try:
|
20 |
+
response = requests.post(API_URL, json=data)
|
21 |
+
if response.status_code == 200:
|
22 |
+
return response.json().get("diagnosis", "No diagnosis returned.")
|
23 |
+
else:
|
24 |
+
return f"Error {response.status_code}: {response.text}"
|
25 |
+
except Exception as e:
|
26 |
+
return f"API request failed: {e}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
# Function to extract patient features from a phrase or transcribed audio
|
29 |
def extract_details_from_text(text):
|
|
|
54 |
weight = st.number_input("Weight (kg)", min_value=10, max_value=200)
|
55 |
ap_hi = st.number_input("Systolic BP", min_value=80, max_value=250)
|
56 |
ap_lo = st.number_input("Diastolic BP", min_value=40, max_value=150)
|
57 |
+
cholesterol = st.selectbox("Cholesterol", [("Normal", 1), ("Elevated", 2), ("Peak", 3)], format_func=lambda x: x[0])[1]
|
58 |
glucose = st.selectbox("Glucose", [("Normal", 1), ("High", 2), ("Extreme", 3)], format_func=lambda x: x[0])[1]
|
59 |
smoke = st.radio("Smoker?", [("No", 0), ("Yes", 1)], format_func=lambda x: x[0])[1]
|
60 |
alco = st.radio("Alcohol Intake?", [("No", 0), ("Yes", 1)], format_func=lambda x: x[0])[1]
|
61 |
active = st.radio("Physically Active?", [("No", 0), ("Yes", 1)], format_func=lambda x: x[0])[1]
|
62 |
|
63 |
if st.button("Predict Diagnosis"):
|
64 |
+
payload = {
|
65 |
+
"age": age,
|
66 |
+
"gender": gender,
|
67 |
+
"height": height,
|
68 |
+
"weight": weight,
|
69 |
+
"ap_hi": ap_hi,
|
70 |
+
"ap_lo": ap_lo,
|
71 |
+
"cholesterol": cholesterol,
|
72 |
+
"glucose": glucose,
|
73 |
+
"smoke": smoke,
|
74 |
+
"alco": alco,
|
75 |
+
"active": active
|
76 |
+
}
|
77 |
+
diagnosis = get_prediction_from_api(payload)
|
78 |
st.success(f"🩺 **Predicted Diagnosis:** {diagnosis}")
|
79 |
|
80 |
elif input_mode == "Text Phrase":
|
|
|
83 |
try:
|
84 |
values = extract_details_from_text(phrase)
|
85 |
if all(v is not None for v in values):
|
86 |
+
keys = ["age", "gender", "height", "weight", "ap_hi", "ap_lo", "cholesterol", "glucose", "smoke", "alco", "active"]
|
87 |
+
payload = dict(zip(keys, values))
|
88 |
+
diagnosis = get_prediction_from_api(payload)
|
89 |
st.success(f"🩺 **Predicted Diagnosis:** {diagnosis}")
|
90 |
else:
|
91 |
st.warning("Couldn't extract all fields from the text. Please revise.")
|
|
|
93 |
st.error(f"Error: {e}")
|
94 |
|
95 |
elif input_mode == "Audio Upload":
|
96 |
+
uploaded_file = st.file_uploader("Upload audio file (WAV, MP3, M4A, MPEG)", type=["wav", "mp3", "m4a", "mpeg"])
|
97 |
if uploaded_file:
|
98 |
st.audio(uploaded_file, format='audio/wav')
|
99 |
audio = AudioSegment.from_file(uploaded_file)
|
|
|
110 |
st.markdown(f"**Transcribed Text:** _{text}_")
|
111 |
values = extract_details_from_text(text)
|
112 |
if all(v is not None for v in values):
|
113 |
+
keys = ["age", "gender", "height", "weight", "ap_hi", "ap_lo", "cholesterol", "glucose", "smoke", "alco", "active"]
|
114 |
+
payload = dict(zip(keys, values))
|
115 |
+
diagnosis = get_prediction_from_api(payload)
|
116 |
st.success(f"🩺 **Predicted Diagnosis:** {diagnosis}")
|
117 |
else:
|
118 |
st.warning("Could not extract complete information from audio.")
|
main.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# main.py
|
2 |
+
from fastapi import FastAPI, Request
|
3 |
+
from pydantic import BaseModel
|
4 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
5 |
+
import torch
|
6 |
+
|
7 |
+
app = FastAPI()
|
8 |
+
|
9 |
+
MODEL_PATH = "Tufan1/BioMedLM-Cardio-Fold1-CPU"
|
10 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
|
11 |
+
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto", torch_dtype=torch.float16)
|
12 |
+
|
13 |
+
class PatientData(BaseModel):
|
14 |
+
input_text: str
|
15 |
+
|
16 |
+
@app.post("/predict")
|
17 |
+
def predict(data: PatientData):
|
18 |
+
inputs = tokenizer(data.input_text, return_tensors="pt").to("cuda")
|
19 |
+
model.eval()
|
20 |
+
with torch.no_grad():
|
21 |
+
outputs = model.generate(**inputs, max_new_tokens=4)
|
22 |
+
decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
23 |
+
diagnosis = decoded.split("Diagnosis:")[-1].strip()
|
24 |
+
return {"diagnosis": diagnosis}
|