Jaamie Maarsh Joy Martin committed
Commit 30a1164 · 0 Parent(s)

Initial commit

Files changed (2)
  1. app.py +475 -0
  2. requirements.txt +13 -0
app.py ADDED
@@ -0,0 +1,475 @@
+ # -*- coding: utf-8 -*-
+ """app.ipynb
+
+ Automatically generated by Colab.
+
+ Original file is located at
+ https://colab.research.google.com/drive/1BmTzCgYHoIX81jKTqf4ImJaKRRbxgoTS
+ """
+
+
+ import os
+ import csv
+ import pandas as pd
+ import plotly.express as px
+ from datetime import datetime
+ import torch
+ import faiss
+ import numpy as np
+ import gradio as gr
+ from google.colab import drive
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
+ from sentence_transformers import SentenceTransformer
+ from peft import PeftModel
+ from huggingface_hub import login
+ from transformers import pipeline as hf_pipeline
+ from fpdf import FPDF
+ import uuid
+ import textwrap
+ from dotenv import load_dotenv
+ try:
+     import whisper
+ except ImportError:
+     os.system("pip install -U openai-whisper")
+     import whisper
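+ # NOTE: the google.colab import above and drive.mount() below tie this script to a
+ # Colab runtime; running it elsewhere would require removing or replacing those calls.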
+
+ # Load Whisper model here
+ whisper_model = whisper.load_model("base")
+
+ load_dotenv()
+
+ hf_token = os.getenv("HF_TOKEN")
+ resend_api_key = os.getenv("RESEND_API_KEY")
+
+ login(token=hf_token)
+
+ # Mount Google Drive
+ drive.mount('/content/drive')
+
+ # -------------------------------
+ # 🔧 Configuration
+ # -------------------------------
+ base_model_path = "google/gemma-2-9b-it"
+ peft_model_path = "Jaamie/gemma-mental-health-qlora"
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ embedding_model_bge = "BAAI/bge-base-en-v1.5"
+ save_path_bge = "/content/drive/MyDrive/models/bge-base-en-v1.5"
+ faiss_index_path = "/content/qa_faiss_embedding.index"
+ chunked_text_path = "/content/chunked_text_RAG_text.txt"
+ READER_MODEL_NAME = "google/gemma-2-9b-it"
+ log_file_path = "./diagnosis_logs.csv"
+ feedback_file_path = "./feedback_logs.csv"
+
+
+ # -------------------------------
+ # 🔧 Logging setup
+ # -------------------------------
+ if not os.path.exists(log_file_path):
+     with open(log_file_path, "w", newline="", encoding="utf-8") as f:
+         writer = csv.writer(f)
+         writer.writerow(["timestamp", "input_type", "query", "diagnosis", "confidence_score", "status"])
+
+ # -------------------------------
+ # 🔧 Feedback setup
+ # -------------------------------
+ if not os.path.exists(feedback_file_path):
+     with open(feedback_file_path, "w", newline="", encoding="utf-8") as f:
+         writer = csv.writer(f)
+         writer.writerow([
+             "feedback_id", "timestamp", "input_type", "query",
+             "diagnosis", "confidence_score", "status", "feedback"
+         ])
+
+ # Ensure directory exists
+ os.makedirs(save_path_bge, exist_ok=True)
+
+ # -------------------------------
+ # 🔧 Model setup
+ # -------------------------------
+
+ # Load Sentence Transformer Model
+ if not os.path.exists(os.path.join(save_path_bge, "config.json")):
+     print("Saving model to Google Drive...")
+     embedding_model = SentenceTransformer(embedding_model_bge)
+     embedding_model.save(save_path_bge)
+     print("Model saved successfully!")
+ else:
+     print("Loading model from Google Drive...")
+     device = 'cuda' if torch.cuda.is_available() else 'cpu'
+     embedding_model = SentenceTransformer(save_path_bge, device=device)
+
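+ # NOTE: The branch above caches the BGE embedding model to Google Drive on first run
+ # and reloads it from there afterwards, avoiding a re-download in later sessions.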
+ # Load FAISS Index
+ faiss_index = faiss.read_index(faiss_index_path)
+ print("FAISS index loaded successfully!")
+
+ # Load chunked text
+ def load_chunked_text():
+     with open(chunked_text_path, "r", encoding="utf-8") as f:
+         return f.read().split("\n\n---\n\n")
+
+ chunked_text = load_chunked_text()
+ print(f"Loaded {len(chunked_text)} text chunks.")
+
+
+ # Load model for emotion classifier
+ emotion_result = {}
+ emotion_classifier = hf_pipeline("text-classification", model="nateraw/bert-base-uncased-emotion")
+
+
+ # -------------------------------
+ # 🧠 Load base model + LoRA adapter
+ # -------------------------------
+ # base_model = AutoModelForCausalLM.from_pretrained(
+ #     base_model_path,
+ #     torch_dtype=torch.float16,
+ #     device_map="auto"  # Use accelerate for smart placement
+ # )
+
+ # # Load the LoRA adapter on top of the base model
+ # diagnosis_model = PeftModel.from_pretrained(
+ #     base_model,
+ #     peft_model_path
+ # ).to(device)
+
+ # # Load tokenizer from the same fine-tuned repo
+ # diagnosis_tokenizer = AutoTokenizer.from_pretrained(peft_model_path)
+
+ # # Set model to evaluation mode
+ # diagnosis_model.eval()
+
+ # print("✅ Model & tokenizer loaded successfully.")
+
+ # # Create text-generation pipeline WITHOUT `device` arg
+ # READER_LLM = pipeline(
+ #     model=diagnosis_model,
+ #     tokenizer=diagnosis_tokenizer,
+ #     task="text-generation",
+ #     do_sample=True,
+ #     temperature=0.2,
+ #     repetition_penalty=1.1,
+ #     return_full_text=False,
+ #     max_new_tokens=500
+ # )
+
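+ # NOTE: The QLoRA adapter loading above (peft_model_path) is left commented out;
+ # the base Gemma model is loaded directly below and used as the reader LLM instead.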
+ device = 0 if torch.cuda.is_available() else -1
+ tokenizer = AutoTokenizer.from_pretrained(READER_MODEL_NAME)
+ # .to() expects a device string/torch.device; the pipeline's `device` arg takes the integer index (-1 = CPU)
+ model = AutoModelForCausalLM.from_pretrained(READER_MODEL_NAME).to("cuda" if torch.cuda.is_available() else "cpu")
+ READER_LLM = pipeline(
+     model=model,
+     tokenizer=tokenizer,
+     task="text-generation",
+     do_sample=True,
+     temperature=0.2,
+     repetition_penalty=1.1,
+     return_full_text=False,
+     max_new_tokens=500,
+     device=device,
+ )
+ # -------------------------------
+ # 🔧 Whisper Model Setup
+ # -------------------------------
+
+ def process_whisper_query(audio):
+     try:
+         audio_data = whisper.load_audio(audio)
+         audio_data = whisper.pad_or_trim(audio_data)
+         mel = whisper.log_mel_spectrogram(audio_data).to(whisper_model.device)
+         result = whisper_model.decode(mel, whisper.DecodingOptions(fp16=False))
+         transcribed_text = result.text.strip()
+         response, download_path = process_query(transcribed_text, input_type="voice")
+         return response, download_path
+     except Exception as e:
+         return f"⚠️ Error processing audio: {str(e)}", None
+
+
+ def extract_diagnosis(response_text: str) -> str:
+     for line in response_text.splitlines():
+         if "Diagnosed Mental Disorder" in line:
+             return line.split(":")[-1].strip()
+     return "Unknown"
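+ # extract_diagnosis expects the generated answer to contain a line such as
+ # "**Diagnosed Mental Disorder**: Depression"; anything else falls back to "Unknown".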
+
+ def process_query(user_query, input_type="text"):
+     # Embed the query
+     query_embedding = embedding_model.encode(user_query, normalize_embeddings=True)
+     query_embedding = np.array([query_embedding], dtype=np.float32)
+
+     # Search FAISS index
+     k = 5  # Retrieve top 5 relevant docs
+     distances, indices = faiss_index.search(query_embedding, k)
+     retrieved_docs = [chunked_text[i] for i in indices[0]]
+
+     # Construct context
+     context = "\nExtracted documents:\n" + "".join([f"Document {i}:::\n{doc}\n" for i, doc in enumerate(retrieved_docs)])
+
+     # Detect emotion
+     emotion_result = emotion_classifier(user_query)[0]
+     print(f"Detected emotion: {emotion_result}")
+     emotion = emotion_result['label']
+     value = emotion_result['score']
+     # Define RAG prompt
+     prompt_in_chat_format = [
+         {"role": "user", "content": f"""
+ You are an AI assistant specialized in diagnosing mental disorders in humans.
+ Using the information contained in the context, answer the question comprehensively.
+
+ The **Diagnosed Mental Disorder** must be exactly one from the list below:
+ [Normal, Depression, Suicidal, Anxiety, Stress, Bi-Polar, Personality Disorder]
+
+ Your response must include:
+ 1. **Diagnosed Mental Disorder**
+ 2. **Detected emotion** {emotion}
+ 3. **Intensity of emotion** {value}
+ 4. **Matching Symptoms**
+ 5. **Personalized Treatment**
+ 6. **Helpline Numbers**
+ 7. **Source Link** (if applicable)
+
+ If a disorder cannot be determined, return **Diagnosed Mental Disorder** as "Unknown".
+
+ ---
+ Context:
+ {context}
+
+ Question: {user_query}"""},
+         {"role": "assistant", "content": ""},
+     ]
+
+     RAG_PROMPT_TEMPLATE = tokenizer.apply_chat_template(
+         prompt_in_chat_format, tokenize=False, add_generation_prompt=True
+     )
+
+     # Generate response
+     answer = READER_LLM(RAG_PROMPT_TEMPLATE)[0]["generated_text"]
+     # Placeholder severity score (randomly sampled, not derived from token probabilities)
+     severity_score = round(np.random.uniform(0.6, 1.0), 2)
+     answer += f"\n\n🧭 Confidence Score: {value}"
+     answer += "\n\n*The Confidence Score reflects the estimated correctness of this answer."
+
+     # Extract the diagnosis
+     diagnosis = extract_diagnosis(answer)
+     status = "fallback" if diagnosis.lower() == "unknown" else "success"
+
+     # Log interaction
+     log_query(input_type=input_type, query=user_query, diagnosis=diagnosis, confidence_score=severity_score, status=status)
+     download_path = create_summary_pdf(answer)
+
+     return answer, download_path
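+ # Illustrative (hypothetical) direct call, bypassing the Gradio UI:
+ #   answer, pdf_path = process_query("I can't sleep and feel hopeless", input_type="text")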
+
+ # Dashboard Interface
+ def diagnosis_dashboard():
+     try:
+         df = pd.read_csv(log_file_path)
+         if df.empty:
+             return "No data logged yet."
+
+         # Filter out unknown or fallback cases if needed
+         df = df[df["diagnosis"].notna()]
+         df = df[df["diagnosis"].str.lower() != "unknown"]
+
+         # Diagnosis frequency
+         diagnosis_counts = df["diagnosis"].value_counts().reset_index()
+         diagnosis_counts.columns = ["Diagnosis", "Count"]
+
+         # Create bar chart
+         fig = px.bar(
+             diagnosis_counts,
+             x="Diagnosis",
+             y="Count",
+             color="Diagnosis",
+             title="📊 Mental Health Diagnosis Distribution",
+             text_auto=True
+         )
+         fig.update_layout(showlegend=False)
+         return fig
+
+     except Exception as e:
+         return f"⚠️ Error loading dashboard: {str(e)}"
+
+ # For logs functionality
+ def log_query(input_type, query, diagnosis, confidence_score, status):
+     with open(log_file_path, "a", newline="", encoding="utf-8") as f:
+         writer = csv.writer(f, quoting=csv.QUOTE_ALL)
+         writer.writerow([
+             datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
+             input_type.replace('"', '""'),
+             query.replace('"', '""'),
+             diagnosis.replace('"', '""'),
+             str(confidence_score),
+             status
+         ])
+ def show_logs():
+     try:
+         df = pd.read_csv(log_file_path)
+         return df.tail(100)
+     except Exception as e:
+         return f"⚠️ Error: {e}"
+
+
+ def create_summary_pdf(text, filename_prefix="diagnosis_report"):
+     try:
+         pdf = FPDF()
+         pdf.add_page()
+         pdf.set_font("Arial", style='B', size=14)
+         pdf.cell(200, 10, txt="Mental Health Diagnosis Report", ln=True, align='C')
+         pdf.set_font("Arial", size=12)
+         pdf.ln(10)
+
+         # FPDF's built-in fonts are Latin-1 only, so drop characters (e.g. emoji) it cannot encode
+         text = text.encode("latin-1", "ignore").decode("latin-1")
+         wrapped = textwrap.wrap(text, width=90)
+         for line in wrapped:
+             pdf.cell(200, 10, txt=line, ln=True)
+
+         # Save to /tmp instead of root dir
+         filename = f"/tmp/{filename_prefix}_{uuid.uuid4().hex[:6]}.pdf"
+         pdf.output(filename)
+
+         print(f"✅ PDF created at: {filename}")
+         return filename
+     except Exception as e:
+         print(f"❌ Error creating PDF: {e}")
+         return None
+
+
+ def create_text_file(content, filename_prefix="diagnosis_text"):
+     filename = f"{filename_prefix}_{uuid.uuid4().hex[:6]}.txt"
+     with open(filename, "w", encoding="utf-8") as f:
+         f.write(content)
+     return filename
+
+
+
+ # 📥 Feedback
+ feedback_data = []
+ def submit_feedback(feedback, input_type, query, diagnosis, confidence_score, status):
+     feedback_id = str(uuid.uuid4())
+     timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+
+     with open(feedback_file_path, "a", newline="", encoding="utf-8") as f:
+         writer = csv.writer(f, quoting=csv.QUOTE_ALL)
+         writer.writerow([
+             feedback_id,
+             timestamp,
+             input_type.replace('"', '""'),
+             query.replace('"', '""'),
+             diagnosis.replace('"', '""'),
+             str(confidence_score),
+             status,
+             feedback.replace('"', '""')
+         ])
+
+     return f"✅ Feedback received! Your Feedback ID: {feedback_id}"
+
+
+ def download_feedback_log():
+     return feedback_file_path
+
+
+ # def send_email_report(to_email, response):
+ #     response = resend.Emails.send({
+ #         "from": "MentalBot <noreply@safespaceai.com>",
+ #         "to": [to_email],
+ #         "subject": "🧠 Your Personalized Mental Health Report",
+ #         "text": response
+ #     })
+ #     return "✅ Diagnosis report sent to your email!" if response.get("id") else "⚠️ Failed to send email."
+
+
+ def unified_handler(audio, text):
+     if audio:
+         response, download_path = process_whisper_query(audio)
+     else:
+         response, download_path = process_query(text, input_type="text")
+
+     # Ensure download path is valid
+     if not (download_path and os.path.exists(download_path)):
+         print("❌ PDF not found or failed to generate.")
+         return response, None
+
+     return response, download_path
+
+
+
+ # if email:
+ #     send_status = send_email_report(to_email=email, response=response)
+ #     response += f"\n\n{send_status}"
+
+ # return response, download_path
+
+
+ # Gradio UI
+
+ main_assistant_tab = gr.Interface(
+     fn=unified_handler,
+     inputs=[
+         gr.Audio(type="filepath", label="🎙 Speak your concern"),
+         gr.Textbox(lines=2, placeholder="Or type your mental health concern here...")
+     ],
+     outputs=[
+         gr.Textbox(label="🧠 Personalized Diagnosis", lines=8),
+         gr.File(label="📥 Download Diagnosis Report")
+     ],
+     title="🧠 SafeSpace AI",
+     description="💙 *We care for you.*\n\nSpeak or type your concern to receive AI-powered mental health insights and download your report as a file."
+ )
+
+ dashboard_tab = gr.Interface(
+     fn=diagnosis_dashboard,
+     inputs=[],
+     outputs=gr.Plot(label="📊 Diagnosis Distribution"),
+     title="📊 Usage Dashboard"
+ )
+
+
+ logs_tab = gr.Interface(
+     fn=show_logs,
+     inputs=[],
+     outputs=gr.Dataframe(label="📄 Diagnosis Logs (Latest 100 entries)"),
+     title="📄 Logs"
+ )
+
+
+ # 📝 Anonymous Feedback
+ feedback_tab = gr.Interface(
+     fn=lambda fb, inp_type, query, diag, score, status: submit_feedback(fb, inp_type, query, diag, score, status),
+     inputs=[
+         gr.Textbox(label="📝 Feedback"),
+         gr.Textbox(label="Input Type"),
+         gr.Textbox(label="Query"),
+         gr.Textbox(label="Diagnosis"),
+         gr.Textbox(label="Confidence Score"),
+         gr.Textbox(label="Status")
+     ],
+     outputs="text",
+     title="📝 Submit Feedback With Session Metadata"
+ )
+
+
+ feedback_download_tab = gr.Interface(
+     fn=download_feedback_log,
+     inputs=[],
+     outputs=gr.File(label="📥 Download All Feedback Logs"),
+     title="📂 Download Feedback CSV"
+ )
+
+
+ # Final App Launch
+ app = gr.TabbedInterface(
+     interface_list=[
+         main_assistant_tab,
+         dashboard_tab,
+         logs_tab,
+         feedback_tab,
+         feedback_download_tab
+     ],
+     tab_names=[
+         "🧠 Assistant",
+         "📊 Dashboard",
+         "📄 Logs",
+         "📝 Feedback",
+         "📂 Feedback CSV"
+     ]
+ )
+
+
+ app.launch(share=True)
+ print("🚀 SafeSpace AI is live!")
requirements.txt ADDED
@@ -0,0 +1,13 @@
+ transformers>=4.36.0
+ sentence-transformers
+ torch
+ faiss-cpu
+ pandas
+ plotly
+ gradio
+ huggingface_hub
+ peft
+ fpdf
+ openai-whisper
+ python-dotenv
+ numpy