File size: 5,310 Bytes
93407fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import os
from typing import List

from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
import faiss
import pandas as pd
from groq import Groq
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

app = FastAPI()

# Groq client setup.
# Export GROQ_API_KEY in the environment before starting the server instead of
# hard-coding a secret in source. setdefault only fills in the placeholder when
# no key is set, so a real key from the environment is never overwritten (a plain
# assignment would clobber it). With the placeholder, Groq requests will
# presumably be rejected as unauthenticated.
os.environ.setdefault("GROQ_API_KEY", "your-groq-api-key")
client = Groq(api_key=os.environ["GROQ_API_KEY"])

# Load AI models
similarity_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")  # disorder matching
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")  # question retrieval
# NOTE(review): google/long-t5-tglobal-base appears to be a pretrained checkpoint
# rather than one fine-tuned for summarization — confirm summaries are usable.
summarization_model = AutoModelForSeq2SeqLM.from_pretrained("google/long-t5-tglobal-base")
summarization_tokenizer = AutoTokenizer.from_pretrained("google/long-t5-tglobal-base")

# Load datasets
recommendations_df = pd.read_csv("treatment_recommendations.csv")
questions_df = pd.read_csv("symptom_questions.csv")

# FAISS index for disorder detection.
# IndexFlatIP ranks by raw inner product, which equals cosine similarity only for
# unit-length vectors. Normalize explicitly rather than relying on the model's
# pipeline to do it. Normalizing the stored vectors is enough for ranking,
# because the query's own norm scales every score by the same factor.
# NOTE(review): only the Disorder labels are embedded, so whole chat transcripts
# are compared against short names.
treatment_embeddings = similarity_model.encode(
    recommendations_df["Disorder"].tolist(),
    convert_to_numpy=True,
    normalize_embeddings=True,
)
index = faiss.IndexFlatIP(treatment_embeddings.shape[1])
index.add(treatment_embeddings)

# FAISS index for question retrieval.
# This uses L2 distance. With unit-length stored vectors, the nearest neighbour is
# also the most cosine-similar row, whatever the query's norm.
question_embeddings = embedding_model.encode(
    questions_df["Questions"].tolist(),
    convert_to_numpy=True,
    normalize_embeddings=True,
)
question_index = faiss.IndexFlatL2(question_embeddings.shape[1])
question_index.add(question_embeddings)

# Request models
class ChatRequest(BaseModel):
    """Body of a /get_questions request: one message typed by the user."""

    message: str  # free-text user message used for retrieval and rephrasing

class SummaryRequest(BaseModel):
    """Body carrying a whole chat session.

    Used by /summarize_chat, /detect_disorders and /get_treatment. Each of them
    joins the messages with spaces, so every entry must be a string. Declaring
    the element type makes FastAPI reject malformed bodies with a 422 instead of
    failing later inside ``" ".join`` with a 500.
    """

    chat_history: List[str]  # chat messages, joined in list order

# Retrieve the most relevant question
def retrieve_questions(user_input):
    """Return the single diagnostic question closest to ``user_input``.

    The nearest row of ``questions_df["Questions"]`` is found through the FAISS
    ``question_index``. A row may contain several questions separated by ", ";
    only the first question is returned, as a stripped string.
    """
    input_embedding = embedding_model.encode([user_input], convert_to_numpy=True)
    _, indices = question_index.search(input_embedding, 1)  # nearest row only

    question_block = questions_df["Questions"].iloc[indices[0][0]]
    return _first_question(question_block)

def _first_question(question_block):
    """Extract the first question from a cell that may hold several."""
    # pandas reads an empty CSV cell as NaN (a float), which has no .split().
    if pd.isna(question_block):
        return ""
    text = str(question_block).strip()
    if "?" in text:
        # Cut right after the first question mark. Splitting on every ", " would
        # truncate questions that contain commas themselves, e.g.
        # "Do you feel restless, on edge, or easily tired?".
        return text[: text.index("?") + 1].strip()
    # With no question mark to anchor on, fall back to the ", " separator.
    return text.split(", ")[0].strip()

# Groq API for rephrasing
def generate_empathetic_response(
    user_input,
    retrieved_question,
    *,
    model="llama-3.1-8b-instant",
    temperature=0.8,
    top_p=0.9,
):
    """Ask the Groq chat API to rephrase ``retrieved_question`` empathetically.

    Args:
        user_input: The user's message. It is quoted verbatim into the prompt.
            NOTE(review): this is untrusted text inside an LLM prompt.
        retrieved_question: The diagnostic question to rephrase.
        model: Groq model ID. The former hard-coded "llama3-8b" does not match
            any Groq model ID, so every request failed.
        temperature: Sampling temperature. Higher values give more variation.
        top_p: Nucleus-sampling cutoff.

    Returns:
        The message content of the first completion choice.
    """
    # Prompt asks for exactly one empathetic question.
    prompt = f"""
    The user said: "{user_input}"
    
    Relevant Question:
    - {retrieved_question}

    You are an empathetic AI psychiatrist. Rephrase this question naturally in a human-like way.
    Acknowledge the user's emotions before asking the question.

    Example format:
    - "I understand that anxiety can be overwhelming. Can you tell me more about when you started feeling this way?"
    
    Generate only one empathetic response.
    """

    chat_completion = client.chat.completions.create(
        messages=[
            {"role": "system", "content": "You are a helpful, empathetic AI psychiatrist."},
            {"role": "user", "content": prompt},
        ],
        model=model,
        temperature=temperature,
        top_p=top_p,
    )

    return chat_completion.choices[0].message.content

# API endpoint: empathetic follow-up question (retrieval + LLM rephrasing)
@app.post("/get_questions")
def get_recommended_questions(request: ChatRequest):
    """Retrieve the closest diagnostic question and have Groq rephrase it.

    Returns ``{"question": <rephrased text>}``.
    """
    user_message = request.message
    # retrieve_questions runs first (argument evaluation), then the LLM call.
    return {
        "question": generate_empathetic_response(
            user_message, retrieve_questions(user_message)
        )
    }

# API endpoint: summarize chat
@app.post("/summarize_chat")
def summarize_chat(request: SummaryRequest):
    """Summarize a full chat session with the seq2seq summarization model.

    The messages are joined with spaces and prefixed with "summarize: ". The
    input is truncated to 4096 tokens and decoded with 4-beam search, producing
    at most 500 tokens.

    Returns:
        ``{"summary": <text>}``. The summary is ``""`` when the history is
        empty or whitespace-only; the model is not run in that case, so it
        cannot produce text from nothing.
    """
    chat_text = " ".join(request.chat_history)
    if not chat_text.strip():
        return {"summary": ""}

    inputs = summarization_tokenizer(
        "summarize: " + chat_text, return_tensors="pt", max_length=4096, truncation=True
    )
    # **inputs forwards the attention mask along with the ids. The old code
    # passed input_ids alone.
    summary_ids = summarization_model.generate(
        **inputs, max_length=500, num_beams=4, early_stopping=True
    )
    summary = summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return {"summary": summary}

# API endpoint: detect disorders
@app.post("/detect_disorders")
def detect_disorders(request: SummaryRequest):
    """Match the whole chat history against the known disorder names.

    Returns:
        ``{"disorders": [...]}``: up to three values of
        ``recommendations_df["Disorder"]`` in the order FAISS ranks them. The
        list is empty when the history is blank or the index holds no
        disorders.
    """
    full_chat_text = " ".join(request.chat_history)
    if not full_chat_text.strip() or index.ntotal == 0:
        return {"disorders": []}

    text_embedding = similarity_model.encode([full_chat_text], convert_to_numpy=True)
    # FAISS pads results with label -1 when k exceeds the number of stored
    # vectors, and iloc[-1] would silently return the last row. So never ask
    # for more than exist, and skip any -1 defensively.
    top_k = min(3, index.ntotal)
    _, indices = index.search(text_embedding, top_k)
    disorders = [recommendations_df["Disorder"].iloc[i] for i in indices[0] if i >= 0]
    return {"disorders": disorders}

# API endpoint: treatment recommendations
@app.post("/get_treatment")
def get_treatment(request: SummaryRequest):
    """Map each detected disorder to its treatment recommendation.

    Disorders come from ``detect_disorders``. For each one, the first matching
    row's "Treatment Recommendation" is used.

    Returns:
        ``{"treatments": {disorder: recommendation-or-None}}``. ``None`` (JSON
        null) replaces a missing or empty recommendation cell.
    """
    detected_disorders = detect_disorders(request)["disorders"]
    treatments = {}
    for disorder in detected_disorders:
        matches = recommendations_df.loc[
            recommendations_df["Disorder"] == disorder, "Treatment Recommendation"
        ]
        treatment = matches.iloc[0] if len(matches) else None
        # pandas reads an empty CSV cell as NaN, and FastAPI's JSON encoder rejects
        # NaN (allow_nan=False), which would turn the response into a 500.
        treatments[disorder] = None if pd.isna(treatment) else treatment
    return {"treatments": treatments}