narainkumbari commited on
Commit
4678109
·
1 Parent(s): 59732cc

Add Streamlit app for CVD prediction

Browse files
Files changed (2) hide show
  1. app.py +124 -0
  2. requirements.txt +348 -0
app.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
+ import re
5
+ from pydub import AudioSegment
6
+ import speech_recognition as sr
7
+ import io
8
+
9
+ # Load model and tokenizer from local fine-tuned directory
10
+ MODEL_PATH = "Tufan1/BioMedLM-Cardio-Fold2"
11
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
12
+ model = AutoModelForCausalLM.from_pretrained(MODEL_PATH)
13
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
+ model.to(device)
15
+
16
+ # Dictionaries to decode user inputs
17
+ gender_map = {1: "Female", 2: "Male"}
18
+ cholesterol_map = {1: "Normal", 2: "High", 3: "Extreme"}
19
+ glucose_map = {1: "Normal", 2: "High", 3: "Extreme"}
20
+ binary_map = {0: "No", 1: "Yes"}
21
+
22
+ # Function to predict diagnosis using the LLM
23
+ def get_prediction(age, gender, height, weight, ap_hi, ap_lo,
24
+ cholesterol, glucose, smoke, alco, active):
25
+ input_text = f"""Patient Record:
26
+ - Age: {age} years
27
+ - Gender: {gender_map[gender]}
28
+ - Height: {height} cm
29
+ - Weight: {weight} kg
30
+ - Systolic BP: {ap_hi} mmHg
31
+ - Diastolic BP: {ap_lo} mmHg
32
+ - Cholesterol Level: {cholesterol_map[cholesterol]}
33
+ - Glucose Level: {glucose_map[glucose]}
34
+ - Smokes: {binary_map[smoke]}
35
+ - Alcohol Intake: {binary_map[alco]}
36
+ - Physically Active: {binary_map[active]}
37
+
38
+ Diagnosis:"""
39
+
40
+ inputs = tokenizer(input_text, return_tensors="pt").to(device)
41
+ model.eval()
42
+ with torch.no_grad():
43
+ outputs = model.generate(**inputs, max_new_tokens=4)
44
+ decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
45
+ diagnosis = decoded.split("Diagnosis:")[-1].strip()
46
+ return diagnosis
47
+
48
+ # Function to extract patient features from a phrase or transcribed audio
49
+ def extract_details_from_text(text):
50
+ age = int(re.search(r'(\d+)\s*year', text).group(1)) if re.search(r'(\d+)\s*year', text) else None
51
+ gender = 2 if "man" in text.lower() else (1 if "female" in text.lower() else None)
52
+ height = int(re.search(r'(\d+)\s*cm', text).group(1)) if re.search(r'(\d+)\s*cm', text) else None
53
+ weight = int(re.search(r'(\d+)\s*kg', text).group(1)) if re.search(r'(\d+)\s*kg', text) else None
54
+ bp_match = re.search(r'BP\s*(\d+)[/](\d+)', text)
55
+ ap_hi, ap_lo = (int(bp_match.group(1)), int(bp_match.group(2))) if bp_match else (None, None)
56
+ cholesterol = 3 if "peak" in text.lower() else 2 if "elevated" in text.lower() else 1
57
+ glucose = 3 if "extreme" in text.lower() else 2 if "high" in text.lower() else 1
58
+ smoke = 1 if "smoke" in text.lower() else 0
59
+ alco = 1 if "alcohol" in text.lower() else 0
60
+ active = 1 if "exercise" in text.lower() or "active" in text.lower() else 0
61
+ return age, gender, height, weight, ap_hi, ap_lo, cholesterol, glucose, smoke, alco, active
62
+
63
+ # Streamlit UI
64
+ st.set_page_config(page_title="Cardiovascular Disease Predictor", layout="centered")
65
+ st.title("🫀 Cardiovascular Disease Predictor (LLM Powered)")
66
+ st.markdown("This tool uses a fine-tuned BioMedLM model to predict cardiovascular conditions from structured, text, or voice input.")
67
+
68
+ input_mode = st.radio("Choose input method:", ["Manual Input", "Text Phrase", "Audio Upload"])
69
+
70
+ if input_mode == "Manual Input":
71
+ age = st.number_input("Age (years)", min_value=1, max_value=120)
72
+ gender = st.selectbox("Gender", [("Female", 1), ("Male", 2)], format_func=lambda x: x[0])[1]
73
+ height = st.number_input("Height (cm)", min_value=50, max_value=250)
74
+ weight = st.number_input("Weight (kg)", min_value=10, max_value=200)
75
+ ap_hi = st.number_input("Systolic BP", min_value=80, max_value=250)
76
+ ap_lo = st.number_input("Diastolic BP", min_value=40, max_value=150)
77
+ cholesterol = st.selectbox("Cholesterol", [("Normal", 1), ("High", 2), ("Extreme", 3)], format_func=lambda x: x[0])[1]
78
+ glucose = st.selectbox("Glucose", [("Normal", 1), ("High", 2), ("Extreme", 3)], format_func=lambda x: x[0])[1]
79
+ smoke = st.radio("Smoker?", [("No", 0), ("Yes", 1)], format_func=lambda x: x[0])[1]
80
+ alco = st.radio("Alcohol Intake?", [("No", 0), ("Yes", 1)], format_func=lambda x: x[0])[1]
81
+ active = st.radio("Physically Active?", [("No", 0), ("Yes", 1)], format_func=lambda x: x[0])[1]
82
+
83
+ if st.button("Predict Diagnosis"):
84
+ diagnosis = get_prediction(age, gender, height, weight, ap_hi, ap_lo,
85
+ cholesterol, glucose, smoke, alco, active)
86
+ st.success(f"🩺 **Predicted Diagnosis:** {diagnosis}")
87
+
88
+ elif input_mode == "Text Phrase":
89
+ phrase = st.text_area("Enter patient details in natural language:", height=200)
90
+ if st.button("Extract & Predict"):
91
+ try:
92
+ values = extract_details_from_text(phrase)
93
+ if all(v is not None for v in values):
94
+ diagnosis = get_prediction(*values)
95
+ st.success(f"🩺 **Predicted Diagnosis:** {diagnosis}")
96
+ else:
97
+ st.warning("Couldn't extract all fields from the text. Please revise.")
98
+ except Exception as e:
99
+ st.error(f"Error: {e}")
100
+
101
+ elif input_mode == "Audio Upload":
102
+ uploaded_file = st.file_uploader("Upload audio file (WAV, MP3, M4A)", type=["wav", "mp3", "m4a"])
103
+ if uploaded_file:
104
+ st.audio(uploaded_file, format='audio/wav')
105
+ audio = AudioSegment.from_file(uploaded_file)
106
+ wav_io = io.BytesIO()
107
+ audio.export(wav_io, format="wav")
108
+ wav_io.seek(0)
109
+
110
+ recognizer = sr.Recognizer()
111
+ with sr.AudioFile(wav_io) as source:
112
+ audio_data = recognizer.record(source)
113
+
114
+ try:
115
+ text = recognizer.recognize_google(audio_data)
116
+ st.markdown(f"**Transcribed Text:** _{text}_")
117
+ values = extract_details_from_text(text)
118
+ if all(v is not None for v in values):
119
+ diagnosis = get_prediction(*values)
120
+ st.success(f"🩺 **Predicted Diagnosis:** {diagnosis}")
121
+ else:
122
+ st.warning("Could not extract complete information from audio.")
123
+ except Exception as e:
124
+ st.error(f"Audio processing error: {e}")
requirements.txt ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ eabsl-py==2.1.0
2
+ accelerate==1.6.0
3
+ addict==2.4.0
4
+ aiohappyeyeballs==2.6.1
5
+ aiohttp==3.11.14
6
+ aiosignal==1.3.2
7
+ aliyun-python-sdk-core==2.16.0
8
+ aliyun-python-sdk-kms==2.16.5
9
+ annotated-types==0.7.0
10
+ antlr4-python3-runtime==4.9.3
11
+ anyio==4.4.0
12
+ apturl==0.5.2
13
+ argon2-cffi==23.1.0
14
+ argon2-cffi-bindings==21.2.0
15
+ arrow==1.3.0
16
+ asttokens==2.4.1
17
+ astunparse==1.6.3
18
+ async-lru==2.0.4
19
+ async-timeout==5.0.1
20
+ attrs==24.2.0
21
+ babel==2.16.0
22
+ bcrypt==3.2.0
23
+ beautifulsoup4==4.12.3
24
+ bitsandbytes==0.45.5
25
+ black==24.2.0
26
+ bleach==6.1.0
27
+ blinker==1.9.0
28
+ boto3==1.36.25
29
+ botocore==1.36.25
30
+ Brlapi==0.8.3
31
+ cachetools==5.5.2
32
+ certifi==2020.6.20
33
+ cffi==1.17.1
34
+ chardet==4.0.0
35
+ charset-normalizer==3.4.1
36
+ chumpy==0.70
37
+ click==8.1.8
38
+ colorama==0.4.4
39
+ comm==0.2.0
40
+ command-not-found==0.3
41
+ contourpy==1.3.0
42
+ crcmod==1.7
43
+ cryptography==3.4.8
44
+ cupshelpers==1.0
45
+ cycler==0.12.1
46
+ Cython==3.0.11
47
+ datasets==3.4.1
48
+ dbus-python==1.2.18
49
+ debugpy==1.8.0
50
+ decorator==5.2.1
51
+ deepface==0.0.93
52
+ defer==1.0.6
53
+ defusedxml==0.7.1
54
+ depthai==2.23.0.0
55
+ dill==0.3.8
56
+ distlib==0.3.6
57
+ distro==1.7.0
58
+ distro-info==1.1+ubuntu0.2
59
+ docker-pycreds==0.4.0
60
+ duplicity==0.8.21
61
+ exceptiongroup==1.2.0
62
+ executing==2.0.1
63
+ fasteners==0.14.1
64
+ fastjsonschema==2.20.0
65
+ filelock==3.14.0
66
+ fire==0.7.0
67
+ Flask==3.1.0
68
+ flask-cors==5.0.1
69
+ flatbuffers==25.2.10
70
+ fonttools==4.53.1
71
+ fqdn==1.5.1
72
+ frozenlist==1.5.0
73
+ fsspec==2023.9.2
74
+ future==0.18.2
75
+ gast==0.4.0
76
+ gdown==5.2.0
77
+ gitdb==4.0.11
78
+ GitPython==3.1.43
79
+ google-auth==2.38.0
80
+ google-auth-oauthlib==0.4.6
81
+ google-pasta==0.2.0
82
+ greenlet==1.1.2
83
+ grpcio==1.70.0
84
+ gunicorn==23.0.0
85
+ gyp==0.1
86
+ h11==0.14.0
87
+ h5py==3.13.0
88
+ httpcore==1.0.5
89
+ httplib2==0.20.2
90
+ httpx==0.27.2
91
+ huggingface-hub==0.30.1
92
+ hydra-core==1.3.2
93
+ idna==3.3
94
+ importlib-metadata==4.6.4
95
+ install==1.3.5
96
+ iopath==0.1.10
97
+ ipykernel==6.27.1
98
+ ipython==8.18.1
99
+ ipywidgets==8.1.5
100
+ isoduration==20.11.0
101
+ itsdangerous==2.2.0
102
+ jedi==0.19.1
103
+ jeepney==0.7.1
104
+ Jinja2==3.1.2
105
+ jmespath==0.10.0
106
+ joblib==1.4.2
107
+ json-tricks==3.17.3
108
+ json5==0.9.25
109
+ jsonpointer==3.0.0
110
+ jsonschema==4.23.0
111
+ jsonschema-specifications==2023.12.1
112
+ jupyter==1.1.1
113
+ jupyter-console==6.6.3
114
+ jupyter-events==0.10.0
115
+ jupyter-lsp==2.2.5
116
+ jupyter_client==8.6.0
117
+ jupyter_core==5.5.1
118
+ jupyter_server==2.14.2
119
+ jupyter_server_terminals==0.5.3
120
+ jupyterlab==4.2.5
121
+ jupyterlab_pygments==0.3.0
122
+ jupyterlab_server==2.27.3
123
+ jupyterlab_widgets==3.0.13
124
+ kaggle==1.7.4.2
125
+ keras==2.11.0
126
+ Keras-Preprocessing==1.1.2
127
+ keyring==23.5.0
128
+ kiwisolver==1.4.7
129
+ language-selector==0.1
130
+ largestinteriorrectangle==0.2.1
131
+ launchpadlib==1.10.16
132
+ lazr.restfulclient==0.14.4
133
+ lazr.uri==1.0.6
134
+ libclang==18.1.1
135
+ libcst==1.4.0
136
+ llvmlite==0.44.0
137
+ lockfile==0.12.2
138
+ louis==3.20.0
139
+ lz4==4.4.4
140
+ macaroonbakery==1.3.1
141
+ Mako==1.1.3
142
+ Markdown==3.7
143
+ markdown-it-py==3.0.0
144
+ MarkupSafe==3.0.2
145
+ matplotlib==3.9.2
146
+ matplotlib-inline==0.1.6
147
+ mdurl==0.1.2
148
+ mistune==3.0.2
149
+ ml_dtypes==0.5.1
150
+ mmdet==3.0.0
151
+ mmengine==0.7.4
152
+ model-index==0.1.11
153
+ monotonic==1.6
154
+ more-itertools==8.10.0
155
+ moreorless==0.4.0
156
+ mpmath==1.3.0
157
+ msgpack==1.0.3
158
+ mtcnn==1.0.0
159
+ multidict==6.2.0
160
+ multiprocess==0.70.16
161
+ munkres==1.1.4
162
+ mypy-extensions==1.0.0
163
+ namex==0.0.8
164
+ natsort==8.4.0
165
+ nbclient==0.10.0
166
+ nbconvert==7.16.4
167
+ nbformat==5.10.4
168
+ nest-asyncio==1.5.8
169
+ netifaces==0.11.0
170
+ networkx==3.1
171
+ notebook==7.2.2
172
+ notebook_shim==0.2.4
173
+ numba==0.61.0
174
+ numpy==1.26.4
175
+ nvidia-cublas-cu12==12.4.5.8
176
+ nvidia-cuda-cupti-cu12==12.4.127
177
+ nvidia-cuda-nvrtc-cu12==12.4.127
178
+ nvidia-cuda-runtime-cu12==12.4.127
179
+ nvidia-cudnn-cu12==9.1.0.70
180
+ nvidia-cufft-cu12==11.2.1.3
181
+ nvidia-curand-cu12==10.3.5.147
182
+ nvidia-cusolver-cu12==11.6.1.9
183
+ nvidia-cusparse-cu12==12.3.1.170
184
+ nvidia-cusparselt-cu12==0.6.2
185
+ nvidia-nccl-cu12==2.21.5
186
+ nvidia-nvjitlink-cu12==12.4.127
187
+ nvidia-nvtx-cu12==12.4.127
188
+ oauthlib==3.2.0
189
+ olefile==0.46
190
+ omegaconf==2.3.0
191
+ opencv-python==4.10.0.84
192
+ opencv-python-headless==4.10.0.84
193
+ opendatalab==0.0.10
194
+ openmim==0.3.9
195
+ openxlab==0.1.2
196
+ opt_einsum==3.4.0
197
+ optree==0.14.0
198
+ ordered-set==4.1.0
199
+ oss2==2.17.0
200
+ overrides==7.7.0
201
+ packaging==24.2
202
+ pandas==2.2.2
203
+ pandocfilters==1.5.1
204
+ paramiko==2.9.3
205
+ parso==0.8.3
206
+ pathspec==0.12.1
207
+ peft==0.15.1
208
+ pexpect==4.8.0
209
+ pillow==10.4.0
210
+ pipreq==0.4
211
+ platformdirs==3.5.1
212
+ portalocker==2.10.1
213
+ prometheus_client==0.20.0
214
+ prompt-toolkit==3.0.43
215
+ propcache==0.3.0
216
+ protobuf==3.12.4
217
+ psutil==5.9.5
218
+ ptyprocess==0.7.0
219
+ pure-eval==0.2.2
220
+ py-cpuinfo==9.0.0
221
+ pyarrow==19.0.1
222
+ pyasn1==0.6.1
223
+ pyasn1_modules==0.4.1
224
+ pycairo==1.20.1
225
+ pycocotools==2.0.8
226
+ pycparser==2.22
227
+ pycryptodome==3.21.0
228
+ pycups==2.0.1
229
+ pydantic==2.11.1
230
+ pydantic_core==2.33.0
231
+ pydub==0.25.1
232
+ Pygments==2.17.2
233
+ PyGObject==3.42.1
234
+ PyJWT==2.3.0
235
+ pymacaroons==0.13.0
236
+ PyNaCl==1.5.0
237
+ pynvim==0.4.2
238
+ pyparsing==2.4.7
239
+ PyPDF2==3.0.1
240
+ pyRFC3339==1.1
241
+ pysmbc==1.0.23
242
+ PySocks==1.7.1
243
+ python-apt==2.4.0+ubuntu4
244
+ python-dateutil==2.8.2
245
+ python-debian==0.1.43+ubuntu1.1
246
+ python-json-logger==2.0.7
247
+ python-slugify==8.0.4
248
+ python-version==0.0.2
249
+ pytz==2023.4
250
+ pyxdg==0.27
251
+ PyYAML==6.0.2
252
+ pyzmq==25.1.2
253
+ referencing==0.35.1
254
+ regex==2024.11.6
255
+ reportlab==3.6.8
256
+ requests==2.32.3
257
+ requests-oauthlib==2.0.0
258
+ retina-face==0.0.17
259
+ rfc3339-validator==0.1.4
260
+ rfc3986-validator==0.1.1
261
+ rich==13.4.2
262
+ rpds-py==0.20.0
263
+ rsa==4.9
264
+ s3transfer==0.11.2
265
+ safetensors==0.5.3
266
+ scikit-learn==1.6.1
267
+ scipy==1.11.3
268
+ seaborn==0.13.0
269
+ SecretStorage==3.3.1
270
+ Send2Trash==1.8.3
271
+ sentry-sdk==2.25.1
272
+ setproctitle==1.3.5
273
+ shapely==2.0.7
274
+ six==1.16.0
275
+ smmap==5.0.1
276
+ sniffio==1.3.1
277
+ soupsieve==2.6
278
+ SpeechRecognition==3.14.2
279
+ ssh-import-id==5.11
280
+ stack-data==0.6.3
281
+ stdlibs==2024.5.15
282
+ stitching==0.6.1
283
+ supervision==0.23.0
284
+ sympy==1.13.1
285
+ systemd-python==234
286
+ tabulate==0.9.0
287
+ tensorboard==2.11.2
288
+ tensorboard-data-server==0.6.1
289
+ tensorboard-plugin-wit==1.8.1
290
+ tensorflow==2.11.0
291
+ tensorflow-estimator==2.11.0
292
+ tensorflow-io-gcs-filesystem==0.37.1
293
+ termcolor==2.5.0
294
+ terminado==0.18.1
295
+ terminaltables==3.1.10
296
+ text-unidecode==1.3
297
+ tf-keras==2.15.0
298
+ thop==0.1.1.post2209072238
299
+ threadpoolctl==3.6.0
300
+ tiktoken==0.9.0
301
+ timm==1.0.14
302
+ tinycss2==1.3.0
303
+ tokenizers==0.21.1
304
+ toml==0.10.2
305
+ tomli==2.0.1
306
+ tomlkit==0.13.2
307
+ torch==2.6.0
308
+ torchvision==0.21.0
309
+ tornado==6.4
310
+ tqdm==4.67.1
311
+ trailrunner==1.4.0
312
+ traitlets==5.14.0
313
+ transformers==4.51.0
314
+ triton==3.2.0
315
+ types-python-dateutil==2.9.0.20240906
316
+ typing-inspection==0.4.0
317
+ typing_extensions==4.13.1
318
+ tzdata==2023.3
319
+ ubuntu-drivers-common==0.0.0
320
+ ubuntu-pro-client==8001
321
+ ufmt==2.0.0b2
322
+ ufw==0.36.1
323
+ ultralytics==8.2.100
324
+ ultralytics-thop==2.0.8
325
+ unattended-upgrades==0.1
326
+ uri-template==1.3.0
327
+ urllib3==2.3.0
328
+ usb-creator==0.3.7
329
+ usort==1.0.2
330
+ virtualenv==20.23.0
331
+ wadllib==1.3.6
332
+ wandb==0.19.9
333
+ warmup-scheduler==0.3
334
+ wcwidth==0.2.12
335
+ webcolors==24.8.0
336
+ webencodings==0.5.1
337
+ websocket-client==1.8.0
338
+ Werkzeug==3.1.3
339
+ widgetsnbextension==4.0.13
340
+ wrapt==1.14.1
341
+ xdg==5
342
+ xkit==0.0.0
343
+ xtcocotools==1.14.3
344
+ xxhash==3.5.0
345
+ yacs==0.1.8
346
+ yapf==0.43.0
347
+ yarl==1.18.3
348
+ zipp==1.0.0