import streamlit as st import torch from transformers import AutoTokenizer, AutoModelForCausalLM import re from pydub import AudioSegment import speech_recognition as sr import io # Load model and tokenizer from local fine-tuned directory MODEL_PATH = "Tufan1/BioMedLM-Cardio-Fold2" tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH) model = AutoModelForCausalLM.from_pretrained(MODEL_PATH) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.to(device) # Dictionaries to decode user inputs gender_map = {1: "Female", 2: "Male"} cholesterol_map = {1: "Normal", 2: "High", 3: "Extreme"} glucose_map = {1: "Normal", 2: "High", 3: "Extreme"} binary_map = {0: "No", 1: "Yes"} # Function to predict diagnosis using the LLM def get_prediction(age, gender, height, weight, ap_hi, ap_lo, cholesterol, glucose, smoke, alco, active): input_text = f"""Patient Record: - Age: {age} years - Gender: {gender_map[gender]} - Height: {height} cm - Weight: {weight} kg - Systolic BP: {ap_hi} mmHg - Diastolic BP: {ap_lo} mmHg - Cholesterol Level: {cholesterol_map[cholesterol]} - Glucose Level: {glucose_map[glucose]} - Smokes: {binary_map[smoke]} - Alcohol Intake: {binary_map[alco]} - Physically Active: {binary_map[active]} Diagnosis:""" inputs = tokenizer(input_text, return_tensors="pt").to(device) model.eval() with torch.no_grad(): outputs = model.generate(**inputs, max_new_tokens=4) decoded = tokenizer.decode(outputs[0], skip_special_tokens=True) diagnosis = decoded.split("Diagnosis:")[-1].strip() return diagnosis # Function to extract patient features from a phrase or transcribed audio def extract_details_from_text(text): age = int(re.search(r'(\d+)\s*year', text).group(1)) if re.search(r'(\d+)\s*year', text) else None gender = 2 if "man" in text.lower() else (1 if "female" in text.lower() else None) height = int(re.search(r'(\d+)\s*cm', text).group(1)) if re.search(r'(\d+)\s*cm', text) else None weight = int(re.search(r'(\d+)\s*kg', text).group(1)) if re.search(r'(\d+)\s*kg', text) else None bp_match = re.search(r'BP\s*(\d+)[/](\d+)', text) ap_hi, ap_lo = (int(bp_match.group(1)), int(bp_match.group(2))) if bp_match else (None, None) cholesterol = 3 if "peak" in text.lower() else 2 if "elevated" in text.lower() else 1 glucose = 3 if "extreme" in text.lower() else 2 if "high" in text.lower() else 1 smoke = 1 if "smoke" in text.lower() else 0 alco = 1 if "alcohol" in text.lower() else 0 active = 1 if "exercise" in text.lower() or "active" in text.lower() else 0 return age, gender, height, weight, ap_hi, ap_lo, cholesterol, glucose, smoke, alco, active # Streamlit UI st.set_page_config(page_title="Cardiovascular Disease Predictor", layout="centered") st.title("🫀 Cardiovascular Disease Predictor (LLM Powered)") st.markdown("This tool uses a fine-tuned BioMedLM model to predict cardiovascular conditions from structured, text, or voice input.") input_mode = st.radio("Choose input method:", ["Manual Input", "Text Phrase", "Audio Upload"]) if input_mode == "Manual Input": age = st.number_input("Age (years)", min_value=1, max_value=120) gender = st.selectbox("Gender", [("Female", 1), ("Male", 2)], format_func=lambda x: x[0])[1] height = st.number_input("Height (cm)", min_value=50, max_value=250) weight = st.number_input("Weight (kg)", min_value=10, max_value=200) ap_hi = st.number_input("Systolic BP", min_value=80, max_value=250) ap_lo = st.number_input("Diastolic BP", min_value=40, max_value=150) cholesterol = st.selectbox("Cholesterol", [("Normal", 1), ("High", 2), ("Extreme", 3)], format_func=lambda x: x[0])[1] glucose = st.selectbox("Glucose", [("Normal", 1), ("High", 2), ("Extreme", 3)], format_func=lambda x: x[0])[1] smoke = st.radio("Smoker?", [("No", 0), ("Yes", 1)], format_func=lambda x: x[0])[1] alco = st.radio("Alcohol Intake?", [("No", 0), ("Yes", 1)], format_func=lambda x: x[0])[1] active = st.radio("Physically Active?", [("No", 0), ("Yes", 1)], format_func=lambda x: x[0])[1] if st.button("Predict Diagnosis"): diagnosis = get_prediction(age, gender, height, weight, ap_hi, ap_lo, cholesterol, glucose, smoke, alco, active) st.success(f"🩺 **Predicted Diagnosis:** {diagnosis}") elif input_mode == "Text Phrase": phrase = st.text_area("Enter patient details in natural language:", height=200) if st.button("Extract & Predict"): try: values = extract_details_from_text(phrase) if all(v is not None for v in values): diagnosis = get_prediction(*values) st.success(f"🩺 **Predicted Diagnosis:** {diagnosis}") else: st.warning("Couldn't extract all fields from the text. Please revise.") except Exception as e: st.error(f"Error: {e}") elif input_mode == "Audio Upload": uploaded_file = st.file_uploader("Upload audio file (WAV, MP3, M4A)", type=["wav", "mp3", "m4a"]) if uploaded_file: st.audio(uploaded_file, format='audio/wav') audio = AudioSegment.from_file(uploaded_file) wav_io = io.BytesIO() audio.export(wav_io, format="wav") wav_io.seek(0) recognizer = sr.Recognizer() with sr.AudioFile(wav_io) as source: audio_data = recognizer.record(source) try: text = recognizer.recognize_google(audio_data) st.markdown(f"**Transcribed Text:** _{text}_") values = extract_details_from_text(text) if all(v is not None for v in values): diagnosis = get_prediction(*values) st.success(f"🩺 **Predicted Diagnosis:** {diagnosis}") else: st.warning("Could not extract complete information from audio.") except Exception as e: st.error(f"Audio processing error: {e}")