# MentalQA – Arabic Mental Health Assistant (chat + classifier)
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    pipeline,
)
CHAT_REPO = "yasser-alharbi/MentalQA"
CLASSIFIER_REPO = "yasser-alharbi/MentalQA-Classification"
# Load chat model
chat_tok = AutoTokenizer.from_pretrained(CHAT_REPO, use_fast=False)
chat_model = AutoModelForCausalLM.from_pretrained(
    CHAT_REPO,
    torch_dtype="auto",       # use the checkpoint's native dtype
    device_map="auto",        # place weights on GPU when available
    low_cpu_mem_usage=True,
)
# Load classifier
clf_tok = AutoTokenizer.from_pretrained(CLASSIFIER_REPO)
clf_model = AutoModelForSequenceClassification.from_pretrained(CLASSIFIER_REPO)
device_idx = 0 if torch.cuda.is_available() else -1
clf_pipe = pipeline("text-classification", model=clf_model, tokenizer=clf_tok, device=device_idx)
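# Illustrative output shape (the score below is made up, not a real result):
# clf_pipe("...") returns a list like [{"label": "LABEL_1", "score": 0.87}]
# containing the top-scoring label by default.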
label_map = {
    "LABEL_0": "A",  # Diagnosis
    "LABEL_1": "B",  # Treatment
    "LABEL_2": "C",  # Anatomy
    "LABEL_3": "D",  # Epidemiology
    "LABEL_4": "E",  # Lifestyle
    "LABEL_5": "F",  # Service provider
    "LABEL_6": "G",  # Other
}
# System message (Arabic): "You are an intelligent mental-health assistant named
# MentalQA. Do not mention your name or the platform you run on unless you are
# explicitly asked about your identity."
SYSTEM_MSG = (
    "أنت مساعد ذكي للصحة النفسية اسمه MentalQA. "
    "لا تذكر اسمك أو منصة عملك إلا إذا سُئلت صراحةً عن هويتك."
)
def classify_question(text: str, threshold: float = 0.5) -> str:
    """Return the MentalQA question-type letter (A-G) for a question,
    falling back to "G" (Other) when the classifier is not confident."""
    pred = clf_pipe(text)[0]  # pipeline returns the top-scoring label by default
    if pred["score"] < threshold:
        return "G"
    return label_map.get(pred["label"], "G")
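# e.g. classify_question("ما أعراض الاكتئاب؟") would be expected to return "A"
# (Diagnosis); the sample question ("What are the symptoms of depression?") is
# illustrative, not taken from the model's data.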
def build_prompt(question: str, final_qt: str) -> str:
    # The system message is supplied separately via the chat template in
    # generate_mentalqa_answer, so it is not repeated here. Arabic strings:
    # "User question:", "Write one detailed paragraph of at least three
    # connected sentences, after thinking step by step.", "Final answer:".
    return (
        f"final_QT: {final_qt}\n\n"
        f"سؤال المستخدم:\n{question}\n\n"
        "اكتب فقرة واحدة مفصّلة لا تقل عن ثلاث جمل مترابطة، بعد أن تفكّر خطوة بخطوة.\n"
        "الإجابة النهائية:\n"
    )
def generate_mentalqa_answer(question: str, threshold: float = 0.5) -> str:
    # 1) Route the question to its type, 2) build the prompt, 3) generate.
    final_qt = classify_question(question, threshold)
    prompt = build_prompt(question, final_qt)
    chat_input = chat_tok.apply_chat_template(
        [{"role": "system", "content": SYSTEM_MSG},
         {"role": "user", "content": prompt}],
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(chat_model.device)
    gen_output = chat_model.generate(
        chat_input,
        max_new_tokens=128,
        do_sample=True,
        temperature=0.6,
        top_p=0.95,
        repetition_penalty=1.15,
        no_repeat_ngram_size=2,
        pad_token_id=chat_tok.eos_token_id,
        eos_token_id=chat_tok.eos_token_id,
    )[0]
    # Drop the prompt tokens and decode only the newly generated answer.
    answer = chat_tok.decode(gen_output[chat_input.shape[1]:], skip_special_tokens=True)
    return answer.strip()
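# Minimal usage sketch; the sample question is illustrative (Arabic: "I feel
# constant anxiety and have trouble sleeping, what do you advise me to do?").
if __name__ == "__main__":
    question = "أشعر بقلق مستمر وصعوبة في النوم، بماذا تنصحني؟"
    print(generate_mentalqa_answer(question))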