narainkumbari committed
Commit e7bec26 · Parent: 129775b

Fix: CPU-safe model for HF Space

Files changed (1): app.py (+43 −34)
app.py CHANGED
@@ -1,12 +1,22 @@
 import streamlit as st
-import requests
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
 import re
 from pydub import AudioSegment
 import speech_recognition as sr
 import io
 
-# Backend API endpoint
-API_URL = "http://10.24.6.170:8000/predict" # Change to your server IP
+# Load model and tokenizer from local fine-tuned directory
+MODEL_PATH = "Tufan1/BioMedLM-Cardio-Fold4-CPU"
+tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
+#model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map=None, low_cpu_mem_usage=True, torch_dtype=torch.float32)
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_PATH,
+    device_map="auto" if torch.cuda.is_available() else None,
+    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
+).to(device)
+#model.to(device)
 
 # Dictionaries to decode user inputs
 gender_map = {1: "Female", 2: "Male"}
@@ -14,16 +24,31 @@ cholesterol_map = {1: "Normal", 2: "Elevated", 3: "Peak"}
 glucose_map = {1: "Normal", 2: "High", 3: "Extreme"}
 binary_map = {0: "No", 1: "Yes"}
 
-# Function to call the backend API
-def get_prediction_from_api(data):
-    try:
-        response = requests.post(API_URL, json=data)
-        if response.status_code == 200:
-            return response.json().get("diagnosis", "No diagnosis returned.")
-        else:
-            return f"Error {response.status_code}: {response.text}"
-    except Exception as e:
-        return f"API request failed: {e}"
+# Function to predict diagnosis using the LLM
+def get_prediction(age, gender, height, weight, ap_hi, ap_lo,
+                   cholesterol, glucose, smoke, alco, active):
+    input_text = f"""Patient Record:
+- Age: {age} years
+- Gender: {gender_map[gender]}
+- Height: {height} cm
+- Weight: {weight} kg
+- Systolic BP: {ap_hi} mmHg
+- Diastolic BP: {ap_lo} mmHg
+- Cholesterol Level: {cholesterol_map[cholesterol]}
+- Glucose Level: {glucose_map[glucose]}
+- Smokes: {binary_map[smoke]}
+- Alcohol Intake: {binary_map[alco]}
+- Physically Active: {binary_map[active]}
+
+Diagnosis:"""
+
+    inputs = tokenizer(input_text, return_tensors="pt").to(device)
+    model.eval()
+    with torch.no_grad():
+        outputs = model.generate(**inputs, max_new_tokens=4)
+    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    diagnosis = decoded.split("Diagnosis:")[-1].strip()
+    return diagnosis
 
 # Function to extract patient features from a phrase or transcribed audio
 def extract_details_from_text(text):
@@ -61,20 +86,8 @@ if input_mode == "Manual Input":
     active = st.radio("Physically Active?", [("No", 0), ("Yes", 1)], format_func=lambda x: x[0])[1]
 
     if st.button("Predict Diagnosis"):
-        payload = {
-            "age": age,
-            "gender": gender,
-            "height": height,
-            "weight": weight,
-            "ap_hi": ap_hi,
-            "ap_lo": ap_lo,
-            "cholesterol": cholesterol,
-            "glucose": glucose,
-            "smoke": smoke,
-            "alco": alco,
-            "active": active
-        }
-        diagnosis = get_prediction_from_api(payload)
+        diagnosis = get_prediction(age, gender, height, weight, ap_hi, ap_lo,
+                                   cholesterol, glucose, smoke, alco, active)
         st.success(f"🩺 **Predicted Diagnosis:** {diagnosis}")
 
 elif input_mode == "Text Phrase":
@@ -83,9 +96,7 @@ elif input_mode == "Text Phrase":
         try:
             values = extract_details_from_text(phrase)
             if all(v is not None for v in values):
-                keys = ["age", "gender", "height", "weight", "ap_hi", "ap_lo", "cholesterol", "glucose", "smoke", "alco", "active"]
-                payload = dict(zip(keys, values))
-                diagnosis = get_prediction_from_api(payload)
+                diagnosis = get_prediction(*values)
                 st.success(f"🩺 **Predicted Diagnosis:** {diagnosis}")
             else:
                 st.warning("Couldn't extract all fields from the text. Please revise.")
@@ -93,7 +104,7 @@ elif input_mode == "Text Phrase":
             st.error(f"Error: {e}")
 
 elif input_mode == "Audio Upload":
-    uploaded_file = st.file_uploader("Upload audio file (WAV, MP3, M4A, MPEG)", type=["wav", "mp3", "m4a", "mpeg"])
+    uploaded_file = st.file_uploader("Upload audio file (WAV, MP3, M4A, MPEG)", type=["wav", "mp3", "m4a","mpeg"])
     if uploaded_file:
         st.audio(uploaded_file, format='audio/wav')
         audio = AudioSegment.from_file(uploaded_file)
@@ -110,9 +121,7 @@ elif input_mode == "Audio Upload":
            st.markdown(f"**Transcribed Text:** _{text}_")
            values = extract_details_from_text(text)
            if all(v is not None for v in values):
-               keys = ["age", "gender", "height", "weight", "ap_hi", "ap_lo", "cholesterol", "glucose", "smoke", "alco", "active"]
-               payload = dict(zip(keys, values))
-               diagnosis = get_prediction_from_api(payload)
+               diagnosis = get_prediction(*values)
                st.success(f"🩺 **Predicted Diagnosis:** {diagnosis}")
            else:
                st.warning("Could not extract complete information from audio.")
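
For reference, the snippet below is a minimal standalone sketch (not part of the commit) of the CPU-only prediction path the diff introduces, useful for smoke-testing the model outside Streamlit. It assumes the Tufan1/BioMedLM-Cardio-Fold4-CPU checkpoint is downloadable and reuses the prompt format from get_prediction; the patient values are made up for illustration.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_PATH = "Tufan1/BioMedLM-Cardio-Fold4-CPU"

# CPU-only path mirroring the commit: full-precision weights, no device_map.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.float32)
model.to("cpu")
model.eval()

# Made-up patient record, formatted the same way as the app's prompt.
prompt = """Patient Record:
- Age: 54 years
- Gender: Male
- Height: 172 cm
- Weight: 84 kg
- Systolic BP: 150 mmHg
- Diastolic BP: 95 mmHg
- Cholesterol Level: Elevated
- Glucose Level: Normal
- Smokes: No
- Alcohol Intake: No
- Physically Active: Yes

Diagnosis:"""

inputs = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=4)

decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(decoded.split("Diagnosis:")[-1].strip())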