# MentalQA – Arabic Mental Health Assistant (chat + classifier)
import torch
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    AutoModelForSequenceClassification, pipeline
)

CHAT_REPO = "yasser-alharbi/MentalQA"
CLASSIFIER_REPO = "yasser-alharbi/MentalQA-Classification"
# Load the generative chat model
chat_tok = AutoTokenizer.from_pretrained(CHAT_REPO, use_fast=False)
chat_model = AutoModelForCausalLM.from_pretrained(
    CHAT_REPO,
    torch_dtype="auto",
    device_map="auto",
    low_cpu_mem_usage=True,
)
# Load the question-type classifier
clf_tok = AutoTokenizer.from_pretrained(CLASSIFIER_REPO)
clf_model = AutoModelForSequenceClassification.from_pretrained(CLASSIFIER_REPO)

device_idx = 0 if torch.cuda.is_available() else -1
clf_pipe = pipeline("text-classification", model=clf_model, tokenizer=clf_tok, device=device_idx)
# Map raw classifier labels to MentalQA question-type codes
label_map = {
    "LABEL_0": "A",  # Diagnosis
    "LABEL_1": "B",  # Treatment
    "LABEL_2": "C",  # Anatomy
    "LABEL_3": "D",  # Epidemiology
    "LABEL_4": "E",  # Lifestyle
    "LABEL_5": "F",  # Healthcare provider
    "LABEL_6": "G",  # Other
}
# System prompt (Arabic): "You are an intelligent mental-health assistant named MentalQA.
# Do not mention your name or the platform you run on unless explicitly asked about your identity."
SYSTEM_MSG = (
    "أنت مساعد ذكي للصحة النفسية اسمه MentalQA. "
    "لا تذكر اسمك أو منصة عملك إلا إذا سُئلت صراحةً عن هويتك."
)
def classify_question(text: str, threshold: float = 0.5) -> str:
    # Take the highest-scoring prediction; fall back to "G" (Other) below the confidence threshold
    pred = max(clf_pipe(text), key=lambda x: x["score"])
    return label_map.get(pred["label"], "G") if pred["score"] >= threshold else "G"
def build_prompt(question: str, final_qt: str) -> str:
    # Arabic prompt parts: "User question:", then "Write one detailed paragraph of at least
    # three connected sentences, after thinking step by step.", then "Final answer:"
    return (
        f"{SYSTEM_MSG}\n\n"
        f"final_QT: {final_qt}\n\n"
        f"سؤال المستخدم:\n{question}\n\n"
        "اكتب فقرة واحدة مفصّلة لا تقل عن ثلاث جمل مترابطة، بعد أن تفكّر خطوة بخطوة.\n"
        "الإجابة النهائية:\n"
    )
def generate_mentalqa_answer(question: str, threshold: float = 0.5) -> str:
    final_qt = classify_question(question, threshold)
    prompt = build_prompt(question, final_qt)

    # Tokenize the conversation with the model's chat template
    chat_input = chat_tok.apply_chat_template(
        [{"role": "system", "content": SYSTEM_MSG},
         {"role": "user", "content": prompt}],
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(chat_model.device)

    gen_output = chat_model.generate(
        chat_input,
        max_new_tokens=128,
        do_sample=True,
        temperature=0.6,
        top_p=0.95,
        repetition_penalty=1.15,
        no_repeat_ngram_size=2,
        pad_token_id=chat_tok.eos_token_id,
        eos_token_id=chat_tok.eos_token_id,
    )[0]

    # Decode only the newly generated tokens, dropping the prompt portion
    answer = chat_tok.decode(gen_output[chat_input.shape[1]:], skip_special_tokens=True)
    return answer.strip()
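
# Minimal usage sketch. The sample question and the __main__ guard below are illustrative
# additions, not part of the model card; any Arabic mental-health question works the same way.
if __name__ == "__main__":
    question = "أعاني من قلق مستمر وصعوبة في النوم، ما الذي يمكنني فعله؟"  # "I suffer from constant anxiety and trouble sleeping; what can I do?"
    qt = classify_question(question)  # one of "A".."G", e.g. "B" (Treatment)
    print("Question type:", qt)
    print("Answer:", generate_mentalqa_answer(question))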