NLPAlhuzali committed
Commit 74ef6d5 · verified · Parent: 6e92cbd

Update models/space_b.py

Files changed (1): models/space_b.py (+31 −17)
models/space_b.py CHANGED
@@ -1,10 +1,15 @@
+# MentalQA – Arabic Mental Health Assistant (chat + classifier)
+
 import torch
-from transformers import (AutoTokenizer, AutoModelForCausalLM,
-                          AutoModelForSequenceClassification, pipeline)
+from transformers import (
+    AutoTokenizer, AutoModelForCausalLM,
+    AutoModelForSequenceClassification, pipeline
+)
 
 CHAT_REPO = "yasser-alharbi/MentalQA"
 CLASSIFIER_REPO = "yasser-alharbi/MentalQA-Classification"
 
+# Load chat model
 chat_tok = AutoTokenizer.from_pretrained(CHAT_REPO, use_fast=False)
 chat_model = AutoModelForCausalLM.from_pretrained(
     CHAT_REPO,
@@ -13,6 +18,7 @@ chat_model = AutoModelForCausalLM.from_pretrained(
     low_cpu_mem_usage=True,
 )
 
+# Load classifier
 clf_tok = AutoTokenizer.from_pretrained(CLASSIFIER_REPO)
 clf_model = AutoModelForSequenceClassification.from_pretrained(CLASSIFIER_REPO)
 
@@ -20,8 +26,13 @@ device_idx = 0 if torch.cuda.is_available() else -1
 clf_pipe = pipeline("text-classification", model=clf_model, tokenizer=clf_tok, device=device_idx)
 
 label_map = {
-    "LABEL_0": "A", "LABEL_1": "B", "LABEL_2": "C",
-    "LABEL_3": "D", "LABEL_4": "E", "LABEL_5": "F", "LABEL_6": "G"
+    "LABEL_0": "A",  # تشخيص (diagnosis)
+    "LABEL_1": "B",  # علاج (treatment)
+    "LABEL_2": "C",  # تشريح (anatomy)
+    "LABEL_3": "D",  # وبائيات (epidemiology)
+    "LABEL_4": "E",  # نمط حياة (lifestyle)
+    "LABEL_5": "F",  # مقدم خدمة (service provider)
+    "LABEL_6": "G",  # أخرى (other)
 }
 
 SYSTEM_MSG = (
@@ -29,29 +40,32 @@ SYSTEM_MSG = (
     "لا تذكر اسمك أو منصة عملك إلا إذا سُئلت صراحةً عن هويتك."
 )
 
-def classify_question(text: str, thr: float = 0.5) -> str:
+def classify_question(text: str, threshold: float = 0.5) -> str:
     pred = max(clf_pipe(text), key=lambda x: x["score"])
-    return label_map.get(pred["label"], pred["label"]) if pred["score"] >= thr else "G"
+    return label_map.get(pred["label"], "G") if pred["score"] >= threshold else "G"
 
-def build_prompt(question: str, tag: str) -> str:
+def build_prompt(question: str, final_qt: str) -> str:
     return (
-        f"{SYSTEM_MSG}\n\nfinal_QT: {tag}\n\n"
+        f"{SYSTEM_MSG}\n\n"
+        f"final_QT: {final_qt}\n\n"
        f"سؤال المستخدم:\n{question}\n\n"
        "اكتب فقرة واحدة مفصّلة لا تقل عن ثلاث جمل مترابطة، بعد أن تفكّر خطوة بخطوة.\n"
        "الإجابة النهائية:\n"
     )
 
-def generate_mentalqa_answer(question: str) -> str:
-    tag = classify_question(question)
-    prompt = build_prompt(question, tag)
-    chat_ids = chat_tok.apply_chat_template(
-        [{"role": "system", "content": SYSTEM_MSG}, {"role": "user", "content": prompt}],
+def generate_mentalqa_answer(question: str, threshold: float = 0.5) -> str:
+    final_qt = classify_question(question, threshold)
+    prompt = build_prompt(question, final_qt)
+
+    chat_input = chat_tok.apply_chat_template(
+        [{"role": "system", "content": SYSTEM_MSG},
+         {"role": "user", "content": prompt}],
         add_generation_prompt=True,
         return_tensors="pt"
     ).to(chat_model.device)
 
-    gen_ids = chat_model.generate(
-        chat_ids,
+    gen_output = chat_model.generate(
+        chat_input,
         max_new_tokens=128,
         do_sample=True,
         temperature=0.6,
@@ -62,5 +76,5 @@ def generate_mentalqa_answer(question: str) -> str:
         eos_token_id=chat_tok.eos_token_id,
     )[0]
 
-    answer_ids = gen_ids[chat_ids.shape[1]:]
-    return chat_tok.decode(answer_ids, skip_special_tokens=True).strip()
+    answer = chat_tok.decode(gen_output[chat_input.shape[1]:], skip_special_tokens=True)
+    return answer.strip()
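
For context, the Arabic strings above: SYSTEM_MSG instructs the assistant not to mention its name or platform unless explicitly asked about its identity, and the prompt template asks for one detailed paragraph of at least three connected sentences, written after thinking step by step. Below is a minimal usage sketch of the updated entry points; the models.space_b import path and the sample question are illustrative assumptions, not part of this commit.

# Minimal usage sketch (assumptions: the file is importable as models.space_b;
# the sample question is illustrative).
from models.space_b import classify_question, generate_mentalqa_answer

question = "كيف أتعامل مع نوبات القلق قبل النوم؟"  # "How do I cope with anxiety attacks before sleep?"

qt = classify_question(question, threshold=0.5)  # "A".."G"; falls back to "G" below the threshold
answer = generate_mentalqa_answer(question)      # classify -> build final_QT prompt -> generate
print(f"final_QT={qt}\n{answer}")

Note the changed fallback: low-confidence or unmapped classifier labels now route to "G" (other) instead of the raw "LABEL_*" string, so the final_QT tag in the prompt is always one of the seven known question types.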