# CVD-Predictor / app.py
# Hugging Face Space source by narainkumbari — last commit c1375d5:
# "Fix: Load CPU-safe model for HF Space" (6.21 kB).
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import re
from pydub import AudioSegment
import speech_recognition as sr
import io
# Fine-tuned BioMedLM checkpoint on the Hugging Face Hub. It must be reachable
# from the Space, i.e. published on the Hub or copied into the Space repo.
MODEL_PATH = "Tufan1/BioMedLM-Cardio-Fold2-CPU"


@st.cache_resource
def load_model():
    """Load the tokenizer and causal language model for MODEL_PATH.

    Decorated with ``st.cache_resource`` so Streamlit reruns reuse the
    already-loaded pair instead of loading the checkpoint again.

    Returns:
        tuple: ``(tokenizer, model)`` from ``AutoTokenizer`` and
        ``AutoModelForCausalLM``.
    """
    tok = AutoTokenizer.from_pretrained(MODEL_PATH)
    # device_map="auto" leaves weight placement to transformers; on the
    # CPU-only Space the weights land on the CPU. There is no explicit
    # fallback path here if that placement fails.
    # NOTE(review): trust_remote_code=True runs Python code shipped with the
    # Hub repo — confirm this checkpoint actually needs it.
    loading_options = dict(device_map="auto", trust_remote_code=True)
    mdl = AutoModelForCausalLM.from_pretrained(MODEL_PATH, **loading_options)
    return tok, mdl


tokenizer, model = load_model()
# Decoding tables: numeric codes chosen in the UI / produced by the text parser
# mapped to the wording that is written into the model prompt.
gender_map = {code: label for code, label in enumerate(("Female", "Male"), start=1)}
_SEVERITY_LABELS = ("Normal", "High", "Extreme")
cholesterol_map = dict(enumerate(_SEVERITY_LABELS, start=1))
glucose_map = dict(enumerate(_SEVERITY_LABELS, start=1))
binary_map = dict(enumerate(("No", "Yes")))
# Function to predict diagnosis using the LLM
def get_prediction(age, gender, height, weight, ap_hi, ap_lo,
                   cholesterol, glucose, smoke, alco, active):
    """Ask the fine-tuned LLM for a diagnosis of a single patient record.

    The coded inputs are rendered into a "Patient Record: ... Diagnosis:"
    prompt, and the model continues it with at most 4 new tokens.

    Args:
        age: Age in years.
        gender: 1 = Female, 2 = Male (keys of ``gender_map``).
        height: Height in cm.
        weight: Weight in kg.
        ap_hi: Systolic blood pressure in mmHg.
        ap_lo: Diastolic blood pressure in mmHg.
        cholesterol: 1-3 level code (keys of ``cholesterol_map``).
        glucose: 1-3 level code (keys of ``glucose_map``).
        smoke: 0/1 flag (keys of ``binary_map``).
        alco: 0/1 flag (keys of ``binary_map``).
        active: 0/1 flag (keys of ``binary_map``).

    Returns:
        str: The decoded text after the last "Diagnosis:", stripped.

    Raises:
        KeyError: If a coded argument is not a key of its decoding map.
    """
    input_text = f"""Patient Record:
- Age: {age} years
- Gender: {gender_map[gender]}
- Height: {height} cm
- Weight: {weight} kg
- Systolic BP: {ap_hi} mmHg
- Diastolic BP: {ap_lo} mmHg
- Cholesterol Level: {cholesterol_map[cholesterol]}
- Glucose Level: {glucose_map[glucose]}
- Smokes: {binary_map[smoke]}
- Alcohol Intake: {binary_map[alco]}
- Physically Active: {binary_map[active]}
Diagnosis:"""
    # The token tensors must sit on the same device as the model weights.
    # model.device reports where device_map="auto" placed them; no separate
    # `device` global exists in this app.
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    model.eval()
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=4)
    # The decoded output still contains the prompt, so keep only the text
    # that follows the final "Diagnosis:" marker.
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return decoded.split("Diagnosis:")[-1].strip()
# ---------------------------------------------------------------------------
# Free-text feature extraction (typed phrases and speech-to-text transcripts).
# All patterns are compiled once and matched case-insensitively.
# ---------------------------------------------------------------------------

# Fragments separated by punctuation or conjunctions are treated as separate
# statements, e.g. "smokes, but no alcohol" -> "smokes", " ", " no alcohol".
_CLAUSE_SPLIT_RE = re.compile(
    r"[,;.!?\n]|\b(?:and|but|with|while|whereas)\b", re.IGNORECASE
)
# Words that negate a habit when they appear before it in the same fragment.
_NEGATION_RE = re.compile(
    r"\b(?:no|not|non|never|nor|without|denies|quit|former|ex"
    r"|doesn['’]?t|don['’]?t|isn['’]?t)\b",
    re.IGNORECASE,
)
# Whole-word matching, so "woman" is not read as "man" and "male" is found.
_FEMALE_RE = re.compile(r"\b(?:female|woman|women|girl|lady)\b", re.IGNORECASE)
_MALE_RE = re.compile(r"\b(?:male|man|men|boy|gentleman)\b", re.IGNORECASE)
# "45 years", "45-year-old", "45 yrs", or "age 45" / "aged 45".
_AGE_RE = re.compile(r"(\d+)[\s-]*(?:year|yr)|\baged?\s*:?\s*(\d+)", re.IGNORECASE)
# The optional decimal part is consumed so "72.5 kg" yields 72, not 5.
_HEIGHT_RE = re.compile(r"(\d+)(?:\.\d+)?\s*(?:cm|centimet)", re.IGNORECASE)
_WEIGHT_RE = re.compile(r"(\d+)(?:\.\d+)?\s*(?:kg|kilo)", re.IGNORECASE)
# "BP 120/80", "bp: 120 / 80", "blood pressure of 120 over 80".
_BP_RE = re.compile(
    r"(?:bp|blood\s+pressure)\D{0,20}?(\d+)\s*(?:/|over)\s*(\d+)", re.IGNORECASE
)
# Severity vocabulary; level 3 is checked first ("very high" also contains "high").
_LEVEL_3_RE = re.compile(r"\b(?:peak|extreme\w*|very\s+high|well\s+above)\b", re.IGNORECASE)
_LEVEL_2_RE = re.compile(r"\b(?:elevated|high|raised|above\s+normal)\b", re.IGNORECASE)
_CHOLESTEROL_RE = re.compile(r"\bcholesterol", re.IGNORECASE)
_GLUCOSE_RE = re.compile(r"\b(?:glucose|sugar)", re.IGNORECASE)
_SMOKE_RE = re.compile(r"\bsmok", re.IGNORECASE)
_ALCOHOL_RE = re.compile(r"\balcohol", re.IGNORECASE)
# Leading \b keeps "inactive" from counting as "active".
_ACTIVE_RE = re.compile(r"\b(?:active|exercis)", re.IGNORECASE)


def _first_int(pattern, text):
    """Return the integer captured by the first match of *pattern*, or None.

    For patterns with alternative capture groups, the group that matched is used.
    """
    match = pattern.search(text)
    if match is None:
        return None
    return int(next(group for group in match.groups() if group is not None))


def _level(text, subject_re):
    """Return the 1-3 severity stated for the subject matched by *subject_re*.

    Only the first clause that names the subject is inspected, so
    "high blood pressure, normal glucose" does not mark glucose as high.
    Returns 3 for peak/extreme wording, 2 for elevated/high wording, and 1
    otherwise (including when the subject is never mentioned).
    """
    for clause in _CLAUSE_SPLIT_RE.split(text):
        if subject_re.search(clause):
            if _LEVEL_3_RE.search(clause):
                return 3
            if _LEVEL_2_RE.search(clause):
                return 2
            return 1
    return 1


def _mentioned(text, keyword_re):
    """Return 1 if *keyword_re* is mentioned affirmatively, else 0.

    The first clause containing the keyword decides. A negation word earlier
    in that clause ("non-smoker", "doesn't smoke", "no alcohol") yields 0, as
    does a text that never mentions the keyword.
    """
    for clause in _CLAUSE_SPLIT_RE.split(text):
        hit = keyword_re.search(clause)
        if hit:
            return 0 if _NEGATION_RE.search(clause[:hit.start()]) else 1
    return 0


# Function to extract patient features from a phrase or transcribed audio
def extract_details_from_text(text):
    """Parse a free-text patient description into get_prediction's arguments.

    Args:
        text: Typed phrase or speech transcript. None is treated as "".

    Returns:
        tuple: (age, gender, height, weight, ap_hi, ap_lo, cholesterol,
        glucose, smoke, alco, active).
        - age, height, weight, ap_hi, ap_lo and gender are None when not found.
        - gender is 1 (female) or 2 (male).
        - cholesterol and glucose are 1-3 and default to 1.
        - smoke, alco and active are 0/1 and default to 0.
    """
    text = text or ""
    age = _first_int(_AGE_RE, text)
    if _FEMALE_RE.search(text):
        gender = 1
    elif _MALE_RE.search(text):
        gender = 2
    else:
        gender = None
    height = _first_int(_HEIGHT_RE, text)
    weight = _first_int(_WEIGHT_RE, text)
    bp_match = _BP_RE.search(text)
    ap_hi, ap_lo = (int(bp_match.group(1)), int(bp_match.group(2))) if bp_match else (None, None)
    cholesterol = _level(text, _CHOLESTEROL_RE)
    glucose = _level(text, _GLUCOSE_RE)
    smoke = _mentioned(text, _SMOKE_RE)
    alco = _mentioned(text, _ALCOHOL_RE)
    active = _mentioned(text, _ACTIVE_RE)
    return age, gender, height, weight, ap_hi, ap_lo, cholesterol, glucose, smoke, alco, active
# ---------------------------------------------------------------------------
# Streamlit UI
# ---------------------------------------------------------------------------
# NOTE(review): some Streamlit releases require set_page_config() to be the
# first Streamlit call of a run, while load_model() (wrapped in
# st.cache_resource) is invoked earlier in this file — confirm the deployed
# Streamlit version accepts this order.
st.set_page_config(page_title="Cardiovascular Disease Predictor", layout="centered")
st.title("🫀 Cardiovascular Disease Predictor (LLM Powered)")
st.markdown("This tool uses a fine-tuned BioMedLM model to predict cardiovascular conditions from structured, text, or voice input.")

# Readable names for the values returned by extract_details_from_text, in the
# same order, so the user can be told which fields are missing.
_FIELD_NAMES = (
    "age", "gender", "height", "weight", "systolic BP", "diastolic BP",
    "cholesterol", "glucose", "smoking", "alcohol intake", "physical activity",
)


def _predict_from_text(text, incomplete_message):
    """Extract patient features from *text* and display the model's diagnosis.

    If any field comes back as None, a warning naming the missing fields is
    shown instead and the model is not called.
    """
    values = extract_details_from_text(text)
    missing = [name for name, value in zip(_FIELD_NAMES, values) if value is None]
    if missing:
        st.warning(f"{incomplete_message} Missing: {', '.join(missing)}.")
        return
    diagnosis = get_prediction(*values)
    st.success(f"🩺 **Predicted Diagnosis:** {diagnosis}")


def _yes_no(label):
    """Show a No/Yes radio button and return 0 for "No" or 1 for "Yes"."""
    return st.radio(label, [("No", 0), ("Yes", 1)], format_func=lambda x: x[0])[1]


input_mode = st.radio("Choose input method:", ["Manual Input", "Text Phrase", "Audio Upload"])

if input_mode == "Manual Input":
    age = st.number_input("Age (years)", min_value=1, max_value=120)
    gender = st.selectbox("Gender", [("Female", 1), ("Male", 2)], format_func=lambda x: x[0])[1]
    height = st.number_input("Height (cm)", min_value=50, max_value=250)
    weight = st.number_input("Weight (kg)", min_value=10, max_value=200)
    ap_hi = st.number_input("Systolic BP", min_value=80, max_value=250)
    ap_lo = st.number_input("Diastolic BP", min_value=40, max_value=150)
    # Codes must match cholesterol_map (1 Normal, 2 High, 3 Extreme) and the
    # free-text parser ("elevated" -> 2, "peak" -> 3); listed by severity.
    cholesterol = st.selectbox(
        "Cholesterol",
        [("Normal", 1), ("Elevated", 2), ("Peak", 3)],
        format_func=lambda x: x[0],
    )[1]
    glucose = st.selectbox("Glucose", [("Normal", 1), ("High", 2), ("Extreme", 3)], format_func=lambda x: x[0])[1]
    smoke = _yes_no("Smoker?")
    alco = _yes_no("Alcohol Intake?")
    active = _yes_no("Physically Active?")
    if st.button("Predict Diagnosis"):
        try:
            diagnosis = get_prediction(age, gender, height, weight, ap_hi, ap_lo,
                                       cholesterol, glucose, smoke, alco, active)
        except Exception as e:  # UI boundary: report instead of crashing the page
            st.error(f"Error: {e}")
        else:
            st.success(f"🩺 **Predicted Diagnosis:** {diagnosis}")

elif input_mode == "Text Phrase":
    phrase = st.text_area("Enter patient details in natural language:", height=200)
    if st.button("Extract & Predict"):
        try:
            _predict_from_text(phrase, "Couldn't extract all fields from the text. Please revise.")
        except Exception as e:  # UI boundary: report instead of crashing the page
            st.error(f"Error: {e}")

elif input_mode == "Audio Upload":
    uploaded_file = st.file_uploader("Upload audio file (WAV, MP3, M4A)", type=["wav", "mp3", "m4a"])
    if uploaded_file:
        # Play back using the upload's own MIME type instead of assuming WAV.
        st.audio(uploaded_file, format=uploaded_file.type or "audio/wav")
        try:
            # Decode from a fresh buffer so the result does not depend on the
            # upload's current read position.
            audio = AudioSegment.from_file(io.BytesIO(uploaded_file.getvalue()))
            wav_io = io.BytesIO()
            audio.export(wav_io, format="wav")
            wav_io.seek(0)
            recognizer = sr.Recognizer()
            with sr.AudioFile(wav_io) as source:
                audio_data = recognizer.record(source)
            text = recognizer.recognize_google(audio_data)
        except sr.UnknownValueError:
            st.error("Could not understand the speech in the audio file.")
        except sr.RequestError as e:
            st.error(f"Speech recognition service unavailable: {e}")
        except Exception as e:  # decoding/conversion failures (bad file, codec)
            st.error(f"Audio processing error: {e}")
        else:
            st.markdown(f"**Transcribed Text:** _{text}_")
            try:
                _predict_from_text(text, "Could not extract complete information from audio.")
            except Exception as e:  # UI boundary: report instead of crashing the page
                st.error(f"Prediction error: {e}")