# MentalQA – Arabic Mental Health Assistant (chat + classifier)

import torch
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    AutoModelForSequenceClassification, pipeline
)

CHAT_REPO = "yasser-alharbi/MentalQA"
CLASSIFIER_REPO = "yasser-alharbi/MentalQA-Classification"

# Load chat model
chat_tok = AutoTokenizer.from_pretrained(CHAT_REPO, use_fast=False)
chat_model = AutoModelForCausalLM.from_pretrained(
    CHAT_REPO,
    torch_dtype="auto",
    device_map="auto",
    low_cpu_mem_usage=True,
)

# Load classifier
clf_tok = AutoTokenizer.from_pretrained(CLASSIFIER_REPO)
clf_model = AutoModelForSequenceClassification.from_pretrained(CLASSIFIER_REPO)

device_idx = 0 if torch.cuda.is_available() else -1
clf_pipe = pipeline("text-classification", model=clf_model, tokenizer=clf_tok, device=device_idx)

label_map = {
    "LABEL_0": "A",  # Diagnosis
    "LABEL_1": "B",  # Treatment
    "LABEL_2": "C",  # Anatomy
    "LABEL_3": "D",  # Epidemiology
    "LABEL_4": "E",  # Lifestyle
    "LABEL_5": "F",  # Service provider
    "LABEL_6": "G",  # Other
}

# System prompt (Arabic): "You are a smart mental-health assistant named MentalQA.
# Do not mention your name or the platform you run on unless explicitly asked about your identity."
SYSTEM_MSG = (
    "أنت مساعد ذكي للصحة النفسية اسمه MentalQA. "
    "لا تذكر اسمك أو منصة عملك إلا إذا سُئلت صراحةً عن هويتك."
)

def classify_question(text: str, threshold: float = 0.5) -> str:
    # Take the classifier's highest-scoring prediction; fall back to "G" (Other)
    # when the confidence is below the threshold or the label is unknown.
    pred = max(clf_pipe(text), key=lambda x: x["score"])
    return label_map.get(pred["label"], "G") if pred["score"] >= threshold else "G"

def build_prompt(question: str, final_qt: str) -> str:
    # Prompt body (Arabic): "User question: ...", then "Write one detailed paragraph of
    # at least three connected sentences, after thinking step by step. Final answer:"
    return (
        f"{SYSTEM_MSG}\n\n"
        f"final_QT: {final_qt}\n\n"
        f"سؤال المستخدم:\n{question}\n\n"
        "اكتب فقرة واحدة مفصّلة لا تقل عن ثلاث جمل مترابطة، بعد أن تفكّر خطوة بخطوة.\n"
        "الإجابة النهائية:\n"
    )

def generate_mentalqa_answer(question: str, threshold: float = 0.5) -> str:
    final_qt = classify_question(question, threshold)
    prompt = build_prompt(question, final_qt)

    chat_input = chat_tok.apply_chat_template(
        [{"role": "system", "content": SYSTEM_MSG},
         {"role": "user", "content": prompt}],
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(chat_model.device)

    # Sample a response with mild anti-repetition settings
    gen_output = chat_model.generate(
        chat_input,
        max_new_tokens=128,
        do_sample=True,
        temperature=0.6,
        top_p=0.95,
        repetition_penalty=1.15,
        no_repeat_ngram_size=2,
        pad_token_id=chat_tok.eos_token_id,
        eos_token_id=chat_tok.eos_token_id,
    )[0]

    # Decode only the newly generated tokens, excluding the prompt portion
    answer = chat_tok.decode(gen_output[chat_input.shape[1]:], skip_special_tokens=True)
    return answer.strip()
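
# Example usage — a minimal sketch; the sample question below is hypothetical
# (Arabic for "How do I deal with constant anxiety before exams?").
if __name__ == "__main__":
    sample_question = "كيف أتعامل مع القلق المستمر قبل الاختبارات؟"
    question_type = classify_question(sample_question)
    print(f"Predicted question type: {question_type}")
    print(generate_mentalqa_answer(sample_question))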