# -*- coding: utf-8 -*-
"""app.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1BmTzCgYHoIX81jKTqf4ImJaKRRbxgoTS
"""
import os
import csv
import pandas as pd
import plotly.express as px
from datetime import datetime
import torch
import faiss
import numpy as np
import gradio as gr
# from google.colab import drive
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer
from peft import PeftModel
from huggingface_hub import login
from transformers import pipeline as hf_pipeline
from fpdf import FPDF
import uuid
import textwrap
from dotenv import load_dotenv
import shutil
# Whisper is an optional dependency; install it at runtime if it's missing.
try:
    import whisper
except ImportError:
    os.system("pip install -U openai-whisper")
    import whisper
# Load Whisper model here
whisper_model = whisper.load_model("base")
load_dotenv()
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)
else:
    print("⚠️ HF_TOKEN not set; gated models such as Gemma may fail to load.")
# Mount Google Drive
#drive.mount('/content/drive')
# -------------------------------
# πŸ”§ Configuration
# -------------------------------
base_model_path = "google/gemma-2-9b-it"
peft_model_path = "Jaamie/gemma-mental-health-qlora"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
embedding_model_bge = "BAAI/bge-base-en-v1.5"
#save_path_bge = "./models/bge-base-en-v1.5"
faiss_index_path = "./qa_faiss_embedding.index"
chunked_text_path = "./chunked_text_RAG_text.txt"
READER_MODEL_NAME = "google/gemma-2-9b-it"
#READER_MODEL_NAME = "google/gemma-2b-it"
log_file_path = "./diagnosis_logs.csv"
feedback_file_path = "./feedback_logs.csv"
# -------------------------------
# πŸ”§ Logging setup
# -------------------------------
if not os.path.exists(log_file_path):
with open(log_file_path, "w", newline="", encoding="utf-8") as f:
writer = csv.writer(f)
writer.writerow(["timestamp", "user_id", "input_type", "query", "diagnosis", "confidence_score", "status"])
# -------------------------------
# πŸ”§ Feedback setup
# -------------------------------
if not os.path.exists(feedback_file_path):
with open(feedback_file_path, "w", newline="", encoding="utf-8") as f:
writer = csv.writer(f)
writer.writerow([
"feedback_id", "timestamp", "user_id", "input_type", "query",
"diagnosis", "status", "feedback"
])
# Ensure directory exists
#os.makedirs(save_path_bge, exist_ok=True)
# -------------------------------
# πŸ”§ Model setup
# -------------------------------
# Load Sentence Transformer Model
# if not os.path.exists(os.path.join(save_path_bge, "config.json")):
# print("Saving model to Google Drive...")
# embedding_model = SentenceTransformer(embedding_model_bge)
# embedding_model.save(save_path_bge)
# print("Model saved successfully!")
# else:
# print("Loading model from Google Drive...")
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# embedding_model = SentenceTransformer(save_path_bge, device=device)
embedding_model = SentenceTransformer(embedding_model_bge, device=device)
print("βœ… BGE Embedding model loaded from Hugging Face.")
# Load FAISS Index
faiss_index = faiss.read_index(faiss_index_path)
print("FAISS index loaded successfully!")
# Load chunked text
def load_chunked_text():
with open(chunked_text_path, "r", encoding="utf-8") as f:
return f.read().split("\n\n---\n\n")
chunked_text = load_chunked_text()
print(f"Loaded {len(chunked_text)} text chunks.")
# Load the emotion classifier model
emotion_classifier = hf_pipeline("text-classification", model="nateraw/bert-base-uncased-emotion")
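# The pipeline returns the top label as a list of dicts, e.g. (illustrative):
# [{"label": "sadness", "score": 0.98}]; process_query() reads that label
# and score into the RAG prompt.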
# -------------------------------
# 🧠 Load base model + LoRA adapter
# -------------------------------
# base_model = AutoModelForCausalLM.from_pretrained(
# base_model_path,
# torch_dtype=torch.float16,
# device_map="auto" # Use accelerate for smart placement
# )
# # Load the LoRA adapter on top of the base model
# diagnosis_model = PeftModel.from_pretrained(
# base_model,
# peft_model_path
# ).to(device)
# # Load tokenizer from the same fine-tuned repo
# diagnosis_tokenizer = AutoTokenizer.from_pretrained(peft_model_path)
# # Set model to evaluation mode
# diagnosis_model.eval()
# print("βœ… Model & tokenizer loaded successfully.")
# # Create text-generation pipeline WITHOUT `device` arg
# READER_LLM = pipeline(
# model=diagnosis_model,
# tokenizer=diagnosis_tokenizer,
# task="text-generation",
# do_sample=True,
# temperature=0.2,
# repetition_penalty=1.1,
# return_full_text=False,
# max_new_tokens=500
# )
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(READER_MODEL_NAME)
#model = AutoModelForCausalLM.from_pretrained(READER_MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(READER_MODEL_NAME).to(device)
# model_id = "mistralai/Mistral-7B-Instruct-v0.1"
# #model_id = "TheBloke/Gemma-2-7B-IT-GGUF"
# tokenizer = AutoTokenizer.from_pretrained(model_id)
# model = AutoModelForCausalLM.from_pretrained(
# model_id,
# torch_dtype=torch.float16,
# device_map="auto",
# ).to(device)
READER_LLM = pipeline(
model=model,
tokenizer=tokenizer,
task="text-generation",
do_sample=True,
temperature=0.2,
repetition_penalty=1.1,
return_full_text=False,
max_new_tokens=500,
#device=device,
)
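# Optional smoke test (commented out; it costs a full generation pass):
# print(READER_LLM("Briefly introduce yourself.")[0]["generated_text"])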
# -------------------------------
# πŸ”§ Whisper Model Setup
# -------------------------------
def process_whisper_query(audio):
try:
audio_data = whisper.load_audio(audio)
audio_data = whisper.pad_or_trim(audio_data)
mel = whisper.log_mel_spectrogram(audio_data).to(whisper_model.device)
result = whisper_model.decode(mel, whisper.DecodingOptions(fp16=False))
transcribed_text = result.text.strip()
response, download_path = process_query(transcribed_text, input_type="voice")
return response, download_path
except Exception as e:
return f"⚠️ Error processing audio: {str(e)}", None
def extract_diagnosis(response_text: str) -> str:
for line in response_text.splitlines():
if "Diagnosed Mental Disorder" in line:
return line.split(":")[-1].strip()
return "Unknown"
def process_query(user_query, input_type="text"):
# Embed the query
query_embedding = embedding_model.encode(user_query, normalize_embeddings=True)
query_embedding = np.array([query_embedding], dtype=np.float32)
# Search FAISS index
k = 5 # Retrieve top 5 relevant docs
distances, indices = faiss_index.search(query_embedding, k)
retrieved_docs = [chunked_text[i] for i in indices[0]]
# Construct context
context = "\nExtracted documents:\n" + "".join([f"Document {i}:::\n{doc}\n" for i, doc in enumerate(retrieved_docs)])
# Detect emotion
emotion_result = emotion_classifier(user_query)[0]
print(f"Detected emotion: {emotion_result}")
emotion = emotion_result['label']
value = emotion_result['score']
# Define RAG prompt
prompt_in_chat_format = [
{"role": "user", "content": f"""
You are an AI assistant specialized in diagnosing mental disorders in humans.
Using the information contained in the context, answer the question comprehensively.
The **Diagnosed Mental Disorder** must be exactly one item from this list:
[Normal, Depression, Suicidal, Anxiety, Stress, Bi-Polar, Personality Disorder]
Your response must include:
1. **Diagnosed Mental Disorder**
2. **Detected emotion** {emotion}
3. **Intensity of emotion** {value}
4. **Matching Symptoms** from the context
5. **Personalized Treatment**
6. **Helpline Numbers**
7. **Source Link** (if applicable)
Provide a comprehensive, accurate diagnosis and explain the personalized treatment in detail.
If a disorder cannot be determined, return **Diagnosed Mental Disorder** as "Unknown".
---
Context:
{context}
Question: {user_query}"""},
{"role": "assistant", "content": ""},
]
RAG_PROMPT_TEMPLATE = tokenizer.apply_chat_template(
prompt_in_chat_format, tokenize=False, add_generation_prompt=True
)
# Generate response
#answer = READER_LLM(RAG_PROMPT_TEMPLATE)[0]["generated_text"]
try:
response = READER_LLM(RAG_PROMPT_TEMPLATE)
# print("πŸ” Raw LLM output:", response)
answer = response[0]["generated_text"] if response and "generated_text" in response[0] else "⚠️ No output generated."
except Exception as e:
print("❌ Error during generation:", e)
answer = "⚠️ An error occurred while generating the response."
    # Placeholder severity score: random draw, NOT derived from token
    # probabilities; swap in a real calibration when one is available.
    severity_score = round(np.random.uniform(0.6, 1.0), 2)
    answer += f"\n\n🧭 Confidence Score: {value}"
    answer += "\n\n*Confidence Score reflects the strength of the detected emotion."
# Extracting diagnosis
diagnosis = extract_diagnosis(answer)
status = "fallback" if diagnosis.lower() == "unknown" else "success"
# Log interaction
log_query(input_type=input_type, query=user_query, diagnosis=diagnosis, confidence_score=severity_score, status=status)
download_path = create_summary_txt(answer)
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
user_id = session_data["latest"]["user_id"] # grab it from session
# Prepend to the answer string
answer_header = f"🧾 Session ID: {user_id}\nπŸ“… Timestamp: {timestamp}\n\n"
return answer_header + answer, download_path
#return answer, download_path
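# Usage sketch (hypothetical query):
# answer_text, report_path = process_query("I can't sleep and I feel on edge")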
# Dashboard Interface
def diagnosis_dashboard():
try:
df = pd.read_csv(log_file_path)
        if df.empty:
            return None  # nothing to plot yet; gr.Plot renders None as empty
# Filter out unknown or fallback cases if needed
df = df[df["diagnosis"].notna()]
df = df[df["diagnosis"].str.lower() != "unknown"]
# Diagnosis frequency
diagnosis_counts = df["diagnosis"].value_counts().reset_index()
diagnosis_counts.columns = ["Diagnosis", "Count"]
# Create bar chart
fig = px.bar(
diagnosis_counts,
x="Diagnosis",
y="Count",
color="Diagnosis",
title="πŸ“Š Mental Health Diagnosis Distribution",
text_auto=True
)
fig.update_layout(showlegend=False)
return fig
    except Exception as e:
        print(f"⚠️ Error loading dashboard: {e}")
        return None  # gr.Plot expects a figure, not an error string
# For logs functionality
# def log_query(input_type, query, diagnosis, confidence_score, status):
# with open(log_file_path, "a", newline="", encoding="utf-8") as f:
# writer = csv.writer(f, quoting=csv.QUOTE_ALL)
# writer.writerow([
# datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
# input_type.replace('"', '""'),
# query.replace('"', '""'),
# diagnosis.replace('"', '""'),
# str(confidence_score),
# status
# ])
session_data = {}
def log_query(input_type, query, diagnosis, confidence_score, status):
user_id = f"SSuser_ID_{uuid.uuid4().hex[:8]}"
# Store in-memory session data for feedback use
session_data["latest"] = {
"user_id": user_id,
"input_type": input_type,
"query": query,
"diagnosis": diagnosis,
"confidence_score": confidence_score,
"status": status
}
with open(log_file_path, "a", newline="", encoding="utf-8") as f:
writer = csv.writer(f, quoting=csv.QUOTE_ALL)
writer.writerow([
str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")),
str(user_id),
str(input_type).replace('"', '""'),
str(query).replace('"', '""'),
str(diagnosis).replace('"', '""'),
str(confidence_score),
str(status)
])
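# Illustrative row in diagnosis_logs.csv (values are examples only):
# "2024-01-01 12:00:00","SSuser_ID_ab12cd34","text","I feel low","Depression","0.82","success"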
def show_logs():
try:
df = pd.read_csv(log_file_path)
return df.tail(100)
    except Exception as e:
        # gr.Dataframe expects tabular data, so wrap the error in a one-row frame
        return pd.DataFrame({"error": [str(e)]})
def create_summary_pdf(text, filename_prefix="diagnosis_report"):
try:
filename = f"{filename_prefix}_{uuid.uuid4().hex[:6]}.pdf"
filepath = os.path.join(".", filename) # Save in current directory
pdf = FPDF()
pdf.add_page()
pdf.set_font("Arial", style='B', size=14)
pdf.cell(200, 10, txt="🧠 Mental Health Diagnosis Report", ln=True, align='C')
pdf.set_font("Arial", size=12)
pdf.ln(10)
        # Drop characters outside latin-1 so FPDF's core fonts don't raise.
        safe_text = text.encode("latin-1", errors="ignore").decode("latin-1")
        wrapped = textwrap.wrap(safe_text, width=90)
        for line in wrapped:
            pdf.cell(200, 10, txt=line, ln=True)
pdf.output(filepath)
print(f"βœ… PDF created at: {filepath}")
return filepath
except Exception as e:
print(f"❌ Error creating PDF: {e}")
return None
def create_summary_txt(text, filename_prefix="diagnosis_report"):
filename = f"{filename_prefix}_{uuid.uuid4().hex[:6]}.txt"
with open(filename, "w", encoding="utf-8") as f:
f.write(text)
print(f"βœ… TXT report created: {filename}")
return filename
# πŸ“₯ Feedback
# feedback_data = []
# def submit_feedback(feedback, input_type, query, diagnosis, confidence_score, status):
# feedback_id = str(uuid.uuid4())
# timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
# with open(feedback_file_path, "a", newline="", encoding="utf-8") as f:
# writer = csv.writer(f, quoting=csv.QUOTE_ALL)
# writer.writerow([
# feedback_id,
# timestamp,
# input_type.replace('"', '""'),
# query.replace('"', '""'),
# diagnosis.replace('"', '""'),
# str(confidence_score),
# status,
# feedback.replace('"', '""')
# ])
# return f"βœ… Feedback received! Your Feedback ID: {feedback_id}"
def submit_feedback(feedback):
    # Feedback must attach to an existing diagnosis session.
    if "latest" not in session_data:
        return "⚠️ No diagnosis found for this session. Please get a diagnosis first."
    user_info = session_data["latest"]
feedback_id = f"fb_{uuid.uuid4().hex[:8]}"
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
with open(feedback_file_path, "a", newline="", encoding="utf-8") as f:
writer = csv.writer(f, quoting=csv.QUOTE_ALL)
writer.writerow([
feedback_id,
timestamp,
user_info["user_id"],
user_info["input_type"],
user_info["query"],
user_info["diagnosis"],
user_info["status"],
feedback.replace('"', '""')
])
return f"βœ… Feedback received! Your Feedback ID: {feedback_id}"
def download_feedback_log():
return feedback_file_path
# def send_email_report(to_email, response):
# response = resend.Emails.send({
# "from": "MentalBot <noreply@safespaceai.com>",
# "to": [to_email],
# "subject": "🧠 Your Personalized Mental Health Report",
# "text": response
# })
# return "βœ… Diagnosis report sent to your email!" if response.get("id") else "⚠️ Failed to send email."
# For pdf
# def unified_handler(audio, text):
# if audio:
# response, download_path = process_whisper_query(audio)
# else:
# response, download_path = process_query(text, input_type="text")
# # Ensure download path is valid
# # if not (download_path and os.path.exists(download_path)):
# # print("❌ PDF not found or failed to generate.")
# # return response, None
# if download_path and os.path.exists(download_path):
# return response, download_path
# else:
# print("❌ PDF not found or failed to generate.")
# return response, None
# Text-report download handler.
# Audio takes precedence over text when both inputs are supplied; the voice
# path logs input_type="voice" inside process_whisper_query. process_query()
# already writes a .txt report, but its path is discarded here and the final,
# header-prefixed response is saved again so the download matches what the
# user sees on screen.
def unified_handler(audio, text):
    if audio:
        response, _ = process_whisper_query(audio)
    else:
        response, _ = process_query(text, input_type="text")
    download_path = create_summary_txt(response)  # save as .txt
    return response, download_path
#Agentic Framework from HF spaces
# agent_iframe = gr.HTML(
# '<iframe src="https://jaamie-mental-health-agent.hf.space" width="100%" height="700px" style="border:none;"></iframe>'
# )
# if email:
# send_status = send_email_report(to_email=email, response=response)
# response += f"\n\n{send_status}"
# return response, download_path
# Gradio UI
main_assistant_tab = gr.Interface(
fn=unified_handler,
inputs=[
gr.Audio(type="filepath", label="πŸŽ™ Speak your concern"),
gr.Textbox(lines=2, placeholder="Or type your mental health concern here...")
],
outputs=[
gr.Textbox(label="🧠 Personalized Diagnosis", lines=15, show_copy_button=True),
gr.File(label="πŸ“₯ Download Diagnosis Report")
],
title="🧠 SafeSpace AI",
description="πŸ’™ *We care for you.*\n\nSpeak or type your concern to receive AI-powered mental health insights. Get your report emailed or download it as a file."
)
dashboard_tab = gr.Interface(
fn=diagnosis_dashboard,
inputs=[],
outputs=gr.Plot(label="πŸ“Š Diagnosis Distribution"),
title="πŸ“Š Usage Dashboard"
)
logs_tab = gr.Interface(
fn=show_logs,
inputs=[],
outputs=gr.Dataframe(label="πŸ“„ Diagnosis Logs (Latest 100 entries)"),
title="πŸ“„ Logs"
)
# πŸ“ Anonymous Feedback
# feedback_tab = gr.Interface(
# fn=lambda fb, inp_type, query, diag, score, status: submit_feedback(fb, inp_type, query, diag, score, status),
# inputs=[
# gr.Textbox(label="πŸ“ Feedback"),
# gr.Textbox(label="Input Type"),
# gr.Textbox(label="Query"),
# gr.Textbox(label="Diagnosis"),
# gr.Textbox(label="Confidence Score"),
# gr.Textbox(label="Status")
# ],
# outputs="text",
# title="πŸ“ Submit Feedback With Session Metadata"
# )
# def feedback_handler(fb, inp_type, query, diag, score, status):
# return submit_feedback(fb, inp_type, query, diag, score, status)
feedback_tab = gr.Interface(
fn=submit_feedback,
inputs=[gr.Textbox(label="πŸ“ Share your thoughts")],
outputs="text",
title="πŸ“ Submit Feedback"
)
feedback_download_tab = gr.Interface(
fn=download_feedback_log,
inputs=[],
outputs=gr.File(label="πŸ“₯ Download All Feedback Logs"),
title="πŸ“‚ Download Feedback CSV"
)
agent_tab = gr.Interface(
fn=lambda: "",
inputs=[],
outputs=gr.HTML(
"""<button onclick="window.open('https://jaamie-mental-health-agent.hf.space', '_blank')"
style='padding:10px 20px; font-size:16px; background-color:#4CAF50; color:white; border:none; border-radius:5px;'>
🧠 Launch Agent SafeSpace 001
</button>"""
),
title="πŸ€– Agent SafeSpace 001"
)
# Add to your tab list
app = gr.TabbedInterface(
interface_list=[
main_assistant_tab,
dashboard_tab,
logs_tab,
feedback_tab,
feedback_download_tab,
agent_tab
],
tab_names=[
"🧠 Assistant",
"πŸ“Š Dashboard",
"πŸ“„ Logs",
"πŸ“ Feedback",
"πŸ“‚ Feedback CSV",
"πŸ€– Agent 001"
]
)
#app.launch(share=True)
print("πŸš€ SafeSpace AI is live!")
# Launch the Gradio App
if __name__ == "__main__":
app.launch()