narainkumbari commited on
Commit
129775b
·
1 Parent(s): 86e415b

Updated app with remote API inference and multi-input support

Browse files
Files changed (3) hide show
  1. __pycache__/main.cpython-310.pyc +0 -0
  2. app.py +35 -39
  3. main.py +24 -0
__pycache__/main.cpython-310.pyc ADDED
Binary file (1.21 kB). View file
 
app.py CHANGED
@@ -1,17 +1,12 @@
1
  import streamlit as st
2
- import torch
3
- from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import re
5
  from pydub import AudioSegment
6
  import speech_recognition as sr
7
  import io
8
 
9
- # Load model and tokenizer from local fine-tuned directory
10
- MODEL_PATH = "Tufan1/BioMedLM-Cardio-Fold1-CPU"
11
- tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
12
- model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map=None, low_cpu_mem_usage=True, torch_dtype=torch.float32)
13
- #device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
- #model.to(device)
15
 
16
  # Dictionaries to decode user inputs
17
  gender_map = {1: "Female", 2: "Male"}
@@ -19,31 +14,16 @@ cholesterol_map = {1: "Normal", 2: "Elevated", 3: "Peak"}
19
  glucose_map = {1: "Normal", 2: "High", 3: "Extreme"}
20
  binary_map = {0: "No", 1: "Yes"}
21
 
22
- # Function to predict diagnosis using the LLM
23
- def get_prediction(age, gender, height, weight, ap_hi, ap_lo,
24
- cholesterol, glucose, smoke, alco, active):
25
- input_text = f"""Patient Record:
26
- - Age: {age} years
27
- - Gender: {gender_map[gender]}
28
- - Height: {height} cm
29
- - Weight: {weight} kg
30
- - Systolic BP: {ap_hi} mmHg
31
- - Diastolic BP: {ap_lo} mmHg
32
- - Cholesterol Level: {cholesterol_map[cholesterol]}
33
- - Glucose Level: {glucose_map[glucose]}
34
- - Smokes: {binary_map[smoke]}
35
- - Alcohol Intake: {binary_map[alco]}
36
- - Physically Active: {binary_map[active]}
37
-
38
- Diagnosis:"""
39
-
40
- inputs = tokenizer(input_text, return_tensors="pt")#.to(device)
41
- model.eval()
42
- with torch.no_grad():
43
- outputs = model.generate(**inputs, max_new_tokens=4)
44
- decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
45
- diagnosis = decoded.split("Diagnosis:")[-1].strip()
46
- return diagnosis
47
 
48
  # Function to extract patient features from a phrase or transcribed audio
49
  def extract_details_from_text(text):
@@ -74,15 +54,27 @@ if input_mode == "Manual Input":
74
  weight = st.number_input("Weight (kg)", min_value=10, max_value=200)
75
  ap_hi = st.number_input("Systolic BP", min_value=80, max_value=250)
76
  ap_lo = st.number_input("Diastolic BP", min_value=40, max_value=150)
77
- cholesterol = st.selectbox("Cholesterol", [("Normal", 1), ("High", 2), ("Extreme", 3)], format_func=lambda x: x[0])[1]
78
  glucose = st.selectbox("Glucose", [("Normal", 1), ("High", 2), ("Extreme", 3)], format_func=lambda x: x[0])[1]
79
  smoke = st.radio("Smoker?", [("No", 0), ("Yes", 1)], format_func=lambda x: x[0])[1]
80
  alco = st.radio("Alcohol Intake?", [("No", 0), ("Yes", 1)], format_func=lambda x: x[0])[1]
81
  active = st.radio("Physically Active?", [("No", 0), ("Yes", 1)], format_func=lambda x: x[0])[1]
82
 
83
  if st.button("Predict Diagnosis"):
84
- diagnosis = get_prediction(age, gender, height, weight, ap_hi, ap_lo,
85
- cholesterol, glucose, smoke, alco, active)
 
 
 
 
 
 
 
 
 
 
 
 
86
  st.success(f"🩺 **Predicted Diagnosis:** {diagnosis}")
87
 
88
  elif input_mode == "Text Phrase":
@@ -91,7 +83,9 @@ elif input_mode == "Text Phrase":
91
  try:
92
  values = extract_details_from_text(phrase)
93
  if all(v is not None for v in values):
94
- diagnosis = get_prediction(*values)
 
 
95
  st.success(f"🩺 **Predicted Diagnosis:** {diagnosis}")
96
  else:
97
  st.warning("Couldn't extract all fields from the text. Please revise.")
@@ -99,7 +93,7 @@ elif input_mode == "Text Phrase":
99
  st.error(f"Error: {e}")
100
 
101
  elif input_mode == "Audio Upload":
102
- uploaded_file = st.file_uploader("Upload audio file (WAV, MP3, M4A)", type=["wav", "mp3", "m4a"])
103
  if uploaded_file:
104
  st.audio(uploaded_file, format='audio/wav')
105
  audio = AudioSegment.from_file(uploaded_file)
@@ -116,7 +110,9 @@ elif input_mode == "Audio Upload":
116
  st.markdown(f"**Transcribed Text:** _{text}_")
117
  values = extract_details_from_text(text)
118
  if all(v is not None for v in values):
119
- diagnosis = get_prediction(*values)
 
 
120
  st.success(f"🩺 **Predicted Diagnosis:** {diagnosis}")
121
  else:
122
  st.warning("Could not extract complete information from audio.")
 
1
  import streamlit as st
2
+ import requests
 
3
  import re
4
  from pydub import AudioSegment
5
  import speech_recognition as sr
6
  import io
7
 
8
+ # Backend API endpoint
9
+ API_URL = "http://10.24.6.170:8000/predict" # Change to your server IP
 
 
 
 
10
 
11
  # Dictionaries to decode user inputs
12
  gender_map = {1: "Female", 2: "Male"}
 
14
  glucose_map = {1: "Normal", 2: "High", 3: "Extreme"}
15
  binary_map = {0: "No", 1: "Yes"}
16
 
17
+ # Function to call the backend API
18
+ def get_prediction_from_api(data):
19
+ try:
20
+ response = requests.post(API_URL, json=data)
21
+ if response.status_code == 200:
22
+ return response.json().get("diagnosis", "No diagnosis returned.")
23
+ else:
24
+ return f"Error {response.status_code}: {response.text}"
25
+ except Exception as e:
26
+ return f"API request failed: {e}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  # Function to extract patient features from a phrase or transcribed audio
29
  def extract_details_from_text(text):
 
54
  weight = st.number_input("Weight (kg)", min_value=10, max_value=200)
55
  ap_hi = st.number_input("Systolic BP", min_value=80, max_value=250)
56
  ap_lo = st.number_input("Diastolic BP", min_value=40, max_value=150)
57
+ cholesterol = st.selectbox("Cholesterol", [("Normal", 1), ("Elevated", 2), ("Peak", 3)], format_func=lambda x: x[0])[1]
58
  glucose = st.selectbox("Glucose", [("Normal", 1), ("High", 2), ("Extreme", 3)], format_func=lambda x: x[0])[1]
59
  smoke = st.radio("Smoker?", [("No", 0), ("Yes", 1)], format_func=lambda x: x[0])[1]
60
  alco = st.radio("Alcohol Intake?", [("No", 0), ("Yes", 1)], format_func=lambda x: x[0])[1]
61
  active = st.radio("Physically Active?", [("No", 0), ("Yes", 1)], format_func=lambda x: x[0])[1]
62
 
63
  if st.button("Predict Diagnosis"):
64
+ payload = {
65
+ "age": age,
66
+ "gender": gender,
67
+ "height": height,
68
+ "weight": weight,
69
+ "ap_hi": ap_hi,
70
+ "ap_lo": ap_lo,
71
+ "cholesterol": cholesterol,
72
+ "glucose": glucose,
73
+ "smoke": smoke,
74
+ "alco": alco,
75
+ "active": active
76
+ }
77
+ diagnosis = get_prediction_from_api(payload)
78
  st.success(f"🩺 **Predicted Diagnosis:** {diagnosis}")
79
 
80
  elif input_mode == "Text Phrase":
 
83
  try:
84
  values = extract_details_from_text(phrase)
85
  if all(v is not None for v in values):
86
+ keys = ["age", "gender", "height", "weight", "ap_hi", "ap_lo", "cholesterol", "glucose", "smoke", "alco", "active"]
87
+ payload = dict(zip(keys, values))
88
+ diagnosis = get_prediction_from_api(payload)
89
  st.success(f"🩺 **Predicted Diagnosis:** {diagnosis}")
90
  else:
91
  st.warning("Couldn't extract all fields from the text. Please revise.")
 
93
  st.error(f"Error: {e}")
94
 
95
  elif input_mode == "Audio Upload":
96
+ uploaded_file = st.file_uploader("Upload audio file (WAV, MP3, M4A, MPEG)", type=["wav", "mp3", "m4a", "mpeg"])
97
  if uploaded_file:
98
  st.audio(uploaded_file, format='audio/wav')
99
  audio = AudioSegment.from_file(uploaded_file)
 
110
  st.markdown(f"**Transcribed Text:** _{text}_")
111
  values = extract_details_from_text(text)
112
  if all(v is not None for v in values):
113
+ keys = ["age", "gender", "height", "weight", "ap_hi", "ap_lo", "cholesterol", "glucose", "smoke", "alco", "active"]
114
+ payload = dict(zip(keys, values))
115
+ diagnosis = get_prediction_from_api(payload)
116
  st.success(f"🩺 **Predicted Diagnosis:** {diagnosis}")
117
  else:
118
  st.warning("Could not extract complete information from audio.")
main.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # main.py
2
+ from fastapi import FastAPI, Request
3
+ from pydantic import BaseModel
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
+ import torch
6
+
7
+ app = FastAPI()
8
+
9
+ MODEL_PATH = "Tufan1/BioMedLM-Cardio-Fold1-CPU"
10
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
11
+ model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto", torch_dtype=torch.float16)
12
+
13
+ class PatientData(BaseModel):
14
+ input_text: str
15
+
16
+ @app.post("/predict")
17
+ def predict(data: PatientData):
18
+ inputs = tokenizer(data.input_text, return_tensors="pt").to("cuda")
19
+ model.eval()
20
+ with torch.no_grad():
21
+ outputs = model.generate(**inputs, max_new_tokens=4)
22
+ decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
23
+ diagnosis = decoded.split("Diagnosis:")[-1].strip()
24
+ return {"diagnosis": diagnosis}