thanglekdi commited on
Commit
e4b71d4
·
1 Parent(s): 2935264

sau khi da tham khao huong dan tren github

Browse files
Files changed (1) hide show
  1. app.py +57 -21
app.py CHANGED
@@ -1,40 +1,76 @@
1
- import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
 
 
3
 
4
- # Load PhoGPT-4B-Chat model and tokenizer
5
- tokenizer = AutoTokenizer.from_pretrained("vinai/PhoGPT-4B-Chat", trust_remote_code=True)
6
- model = AutoModelForCausalLM.from_pretrained("vinai/PhoGPT-4B-Chat", trust_remote_code=True)
7
 
8
- def respond(message, history, system_message, max_tokens, temperature, top_p):
9
- messages = f"{system_message}\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  for user_msg, bot_msg in history:
11
  if user_msg:
12
- messages += f"User: {user_msg}\n"
13
  if bot_msg:
14
- messages += f"Bot: {bot_msg}\n"
15
- messages += f"User: {message}\nBot:"
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- inputs = tokenizer(messages, return_tensors="pt")
18
  outputs = model.generate(
19
- **inputs,
 
20
  max_new_tokens=max_tokens,
21
  temperature=temperature,
22
  top_p=top_p,
 
23
  do_sample=True,
24
- pad_token_id=tokenizer.eos_token_id
 
25
  )
26
- full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
27
- response = full_output.replace(messages, "").strip()
28
- yield response
29
 
 
 
 
 
30
 
 
31
  demo = gr.ChatInterface(
32
- respond,
33
  additional_inputs=[
34
- gr.Textbox(value="Bạn là một chatbot người Việt thân thiện.", label="System message"),
35
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
36
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
37
- gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
38
  ],
39
  )
40
 
 
1
# app.py
import torch  # type: ignore
import gradio as gr  # type: ignore
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer  # type: ignore

# 1. Prepare model and tokenizer.
model_path = "vinai/PhoGPT-4B-Chat"

# Pick the device at runtime. The original hard-coded "cuda", which raises at
# load time on CPU-only hosts (e.g. a Hugging Face Space without a GPU).
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the config and set the init device. Use bfloat16 on GPU to halve
# memory; fall back to float32 on CPU, where bf16 support is spotty/slow.
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
config.init_device = device
# If FlashAttention is installed, you may enable:
# config.attn_config["attn_impl"] = "flash"

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    config=config,
    torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
    trust_remote_code=True,
)
model.eval()  # inference mode: disables dropout and training-only behavior

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
24
+
25
+ # 2️⃣ Hàm xử lý chat
26
# 2. Chat handler.
def respond(message, history: list[tuple[str, str]], system_message, max_tokens, temperature, top_p):
    """Generate one reply for gr.ChatInterface.

    Parameters mirror ``additional_inputs`` on the interface below; *history*
    is a list of (user, assistant) string pairs. Yields a single decoded
    answer string.
    """
    # 2.1 — Assemble the conversation in the standard chat-message format.
    messages = [{"role": "system", "content": system_message}]
    for user_msg, bot_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": message})

    # 2.2 — Let the tokenizer render the model-specific prompt template.
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # 2.3 — Tokenize and move tensors to wherever the model lives.
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(model.device)
    attention_mask = inputs["attention_mask"].to(model.device)

    # 2.4 — Generate. no_grad() skips building an autograd graph (less memory).
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=50,  # per the model card's recommended sampling settings
            do_sample=True,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )

    # 2.5 — Decode ONLY the newly generated tokens. The original decoded the
    # full sequence and str.replace()-ed the prompt out, which breaks whenever
    # decode() does not reproduce the prompt byte-for-byte (tokenizer
    # normalization) and can also delete legitimate repeated text.
    new_tokens = outputs[0][input_ids.shape[1]:]
    answer = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    yield answer
65
 
66
+ # 3️⃣ Tạo giao diện Gradio
67
# 3. Gradio UI: wire respond() to a chat widget with tunable sampling knobs.
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Textbox(
            value="Bạn là một chatbot tiếng Việt thân thiện.",
            label="System message",
        ),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=1.0, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)
76