# MentalQA/models/space_b.py
# MentalQA – Arabic Mental Health Assistant (chat + classifier)
import torch
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    AutoModelForSequenceClassification, pipeline
)
CHAT_REPO = "yasser-alharbi/MentalQA"
CLASSIFIER_REPO = "yasser-alharbi/MentalQA-Classification"
# Load chat model
chat_tok = AutoTokenizer.from_pretrained(CHAT_REPO, use_fast=False)
chat_model = AutoModelForCausalLM.from_pretrained(
    CHAT_REPO,
    torch_dtype="auto",       # use the checkpoint's native dtype
    device_map="auto",        # place weights on GPU when available
    low_cpu_mem_usage=True,
)
# Load classifier
clf_tok = AutoTokenizer.from_pretrained(CLASSIFIER_REPO)
clf_model = AutoModelForSequenceClassification.from_pretrained(CLASSIFIER_REPO)
device_idx = 0 if torch.cuda.is_available() else -1  # 0 = first CUDA GPU, -1 = CPU
clf_pipe = pipeline("text-classification", model=clf_model, tokenizer=clf_tok, device=device_idx)
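# Illustrative output shape (values are made up): with the default settings the
# pipeline returns the single top prediction, e.g.
#   clf_pipe("نص السؤال") -> [{"label": "LABEL_1", "score": 0.91}]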
label_map = {
    "LABEL_0": "A",  # Diagnosis
    "LABEL_1": "B",  # Treatment
    "LABEL_2": "C",  # Anatomy
    "LABEL_3": "D",  # Epidemiology
    "LABEL_4": "E",  # Lifestyle
    "LABEL_5": "F",  # Service provider
    "LABEL_6": "G",  # Other
}
# System prompt (Arabic): "You are a smart mental-health assistant named MentalQA.
# Do not mention your name or host platform unless explicitly asked about your identity."
SYSTEM_MSG = (
    "أنت مساعد ذكي للصحة النفسية اسمه MentalQA. "
    "لا تذكر اسمك أو منصة عملك إلا إذا سُئلت صراحةً عن هويتك."
)
def classify_question(text: str, threshold: float = 0.5) -> str:
    """Map a question to its MentalQA type (A–G); fall back to G ("Other") below the threshold."""
    pred = max(clf_pipe(text), key=lambda x: x["score"])
    return label_map.get(pred["label"], "G") if pred["score"] >= threshold else "G"
def build_prompt(question: str, final_qt: str) -> str:
    # Arabic template (translated): "User question: ... Write one detailed paragraph
    # of at least three connected sentences, after thinking step by step. Final answer:"
    return (
        f"{SYSTEM_MSG}\n\n"
        f"final_QT: {final_qt}\n\n"
        f"سؤال المستخدم:\n{question}\n\n"
        "اكتب فقرة واحدة مفصّلة لا تقل عن ثلاث جمل مترابطة، بعد أن تفكّر خطوة بخطوة.\n"
        "الإجابة النهائية:\n"
    )
def generate_mentalqa_answer(question: str, threshold: float = 0.5) -> str:
    """Classify the question type, then generate an Arabic answer conditioned on it."""
    final_qt = classify_question(question, threshold)
    prompt = build_prompt(question, final_qt)
    chat_input = chat_tok.apply_chat_template(
        [{"role": "system", "content": SYSTEM_MSG},
         {"role": "user", "content": prompt}],
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(chat_model.device)
    gen_output = chat_model.generate(
        chat_input,
        max_new_tokens=128,
        do_sample=True,
        temperature=0.6,
        top_p=0.95,
        repetition_penalty=1.15,
        no_repeat_ngram_size=2,
        pad_token_id=chat_tok.eos_token_id,
        eos_token_id=chat_tok.eos_token_id,
    )[0]
    # Decode only the newly generated tokens, skipping the prompt portion.
    answer = chat_tok.decode(gen_output[chat_input.shape[1]:], skip_special_tokens=True)
    return answer.strip()
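# A minimal usage sketch (assumes a GPU or enough CPU RAM to host the chat model;
# the sample question is illustrative, not taken from the MentalQA dataset):
if __name__ == "__main__":
    sample_question = "ما هي أعراض القلق العام؟"  # "What are the symptoms of generalized anxiety?"
    print("Predicted question type:", classify_question(sample_question))
    print("Answer:", generate_mentalqa_answer(sample_question))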