ShenghaoYummy committed
Commit d8c1d71 · 1 Parent(s): 8512311
Files changed (1)
  1. app.py +14 -10
app.py CHANGED
@@ -16,17 +16,22 @@ model = AutoModelForCausalLM.from_pretrained(
 )
 
 # 2) define inference function
-def generate(messages):
+def generate(message, history):
     """
-    messages: List of alternating [user, assistant, user, ...]
-    returns: [user, assistant, user, assistant, ...] with model's reply appended
+    message: Current user message (string)
+    history: List of [user_message, assistant_message] pairs
+    returns: assistant's reply (string)
     """
-    # rebuild a single prompt string
+    # rebuild a single prompt string from history + current message
     prompt = ""
-    for i in range(0, len(messages), 2):
-        prompt += f"User: {messages[i]}\n"
-        if i+1 < len(messages):
-            prompt += f"Assistant: {messages[i+1]}\n"
+
+    # Add conversation history
+    for user_msg, assistant_msg in history:
+        prompt += f"User: {user_msg}\n"
+        prompt += f"Assistant: {assistant_msg}\n"
+
+    # Add current user message
+    prompt += f"User: {message}\n"
     prompt += "Assistant:"
 
     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
@@ -40,8 +45,7 @@ def generate(messages):
     # strip everything before the last "Assistant:"
     reply = text.split("Assistant:")[-1].strip()
 
-    messages.append(reply)
-    return messages
+    return reply
 
 # 3) build Gradio ChatInterface
 demo = gr.ChatInterface(
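
For reference, a minimal sketch of how the updated generate(message, history) signature plugs into gr.ChatInterface. The diff cuts off before the arguments passed to gr.ChatInterface( and never shows the generation call, so the model ID, max_new_tokens value, and the fn=generate wiring below are assumptions, not the Space's actual configuration.

# Minimal sketch; MODEL_ID and generation settings are placeholders,
# not taken from this commit.
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "gpt2"  # placeholder; the real Space loads its own fine-tuned model

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

def generate(message, history):
    # history arrives as [user_message, assistant_message] pairs
    prompt = ""
    for user_msg, assistant_msg in history:
        prompt += f"User: {user_msg}\nAssistant: {assistant_msg}\n"
    prompt += f"User: {message}\nAssistant:"

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=128)
    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # keep only the text after the last "Assistant:" marker
    return text.split("Assistant:")[-1].strip()

demo = gr.ChatInterface(fn=generate)
demo.launch()

Returning a plain string matches the ChatInterface contract: Gradio appends the reply to the visible chat history itself, which is why the old messages.append(reply) / return messages pair is dropped in this commit. Note the pair-unpacking loop assumes Gradio's list-of-pairs history format; recent Gradio versions encourage openai-style message dicts (type="messages"), which would need a different loop.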