ShenghaoYummy committed 426176a (verified) · 1 parent: f54d28c

Update app.py

Files changed (1):
  1. app.py +60 -32
app.py CHANGED
@@ -15,8 +15,15 @@ def load_model():
 
     # Ensure pad token is set
     if tokenizer.pad_token is None:
+        print("Tokenizer pad_token not set. Setting to eos_token.")
         tokenizer.pad_token = tokenizer.eos_token
+        # It's also good to ensure the model's config reflects this if it's used during generation
+        # model.config.pad_token_id = tokenizer.pad_token_id
+        # (Do this after model is loaded if needed, but usually tokenizer.pad_token_id in generate is enough)
 
+    print(f"Tokenizer pad_token: {tokenizer.pad_token}, ID: {tokenizer.pad_token_id}")
+    print(f"Tokenizer eos_token: {tokenizer.eos_token}, ID: {tokenizer.eos_token_id}")
+
     print("Loading base model...")
     base_model = AutoModelForCausalLM.from_pretrained(
         BASE_MODEL_ID,
@@ -29,6 +36,12 @@ def load_model():
     print("Loading PEFT adapter...")
     model = PeftModel.from_pretrained(base_model, MODEL_ID)
 
+    # If you had to set tokenizer.pad_token, ensure the merged model's config is also aware
+    # This is more relevant if not passing pad_token_id directly to generate, but good for consistency
+    if model.config.pad_token_id is None and tokenizer.pad_token_id is not None:
+        print(f"Setting model.config.pad_token_id to: {tokenizer.pad_token_id}")
+        model.config.pad_token_id = tokenizer.pad_token_id
+
     print("Model loaded successfully!")
     return model, tokenizer
 
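
The two hunks above keep the tokenizer and the model config in agreement about padding: reuse EOS as PAD when no pad token exists, and mirror its id into model.config. A minimal standalone sketch of the same idea (illustrative only; "gpt2" stands in for BASE_MODEL_ID, whose value is not shown in this diff):

from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; the Space actually loads BASE_MODEL_ID plus a PEFT adapter.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # fall back to EOS as the padding token

if model.config.pad_token_id is None and tokenizer.pad_token_id is not None:
    model.config.pad_token_id = tokenizer.pad_token_id  # keep config and tokenizer in sync

print(tokenizer.pad_token_id, model.config.pad_token_id)  # both print the EOS id (50256 for GPT-2)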
 
@@ -42,32 +55,45 @@ def generate(message, history):
     history: List of [user_message, assistant_message] pairs
     returns: assistant's reply (string)
     """
-    # Use ChatML format that your model was trained on
     DEFAULT_SYSTEM_PROMPT = "You are a helpful e-commerce customer service assistant. Provide accurate, helpful, and friendly responses to customer inquiries about products, orders, shipping, returns, and general shopping assistance."
 
-    # Build conversation in ChatML format
     conversation = f"<|system|>\n{DEFAULT_SYSTEM_PROMPT}\n"
+    if history:  # Ensure history is not None or empty before iterating
+        for user_msg, assistant_msg in history:
+            # Ensure messages are strings
+            user_msg_str = str(user_msg) if user_msg is not None else ""
+            assistant_msg_str = str(assistant_msg) if assistant_msg is not None else ""
+            conversation += f"<|user|>\n{user_msg_str}\n<|assistant|>\n{assistant_msg_str}\n"
 
-    # Add history
-    for user_msg, assistant_msg in history:
-        conversation += f"<|user|>\n{user_msg}\n<|assistant|>\n{assistant_msg}\n"
-
-    # Add current message
-    conversation += f"<|user|>\n{message}\n<|assistant|>\n"
+    message_str = str(message) if message is not None else ""
+    conversation += f"<|user|>\n{message_str}\n<|assistant|>\n"
 
-    # Tokenize
+    print(f"--- Constructed Prompt ---\n{conversation}\n--------------------------")
+
     inputs = tokenizer(
         conversation,
         return_tensors="pt",
-        max_length=512,
+        max_length=512,  # Max length of context + new tokens for some models, but here it's input context length
        truncation=True,
-        padding=True
+        padding=True  # Pad to max_length or longest in batch if dynamic
     ).to(model.device)
 
-    # Generate response
+    input_length = inputs["input_ids"].shape[1]
+
+    # Ensure eos_token_id is correctly set for generation
+    # If your model was trained to use <|end|> as an EOS token, its ID should be tokenizer.eos_token_id
+    eos_token_id_to_use = tokenizer.eos_token_id
+    # Example: if <|end|> has a specific ID different from the default eos_token
+    # end_custom_token_id = tokenizer.convert_tokens_to_ids("<|end|>")
+    # if end_custom_token_id != tokenizer.unk_token_id:  # Check if token exists
+    #     eos_token_id_to_use = end_custom_token_id
+    # print(f"Using EOS token ID for generation: {eos_token_id_to_use}")
+
+
     with torch.no_grad():
         outputs = model.generate(
-            **inputs,
+            input_ids=inputs["input_ids"],
+            attention_mask=inputs["attention_mask"],
             max_new_tokens=300,
             do_sample=True,
             temperature=0.8,
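
For reference, the prompt-building code in the hunk above produces a ChatML-style layout like the following (the example history and message here are made up for illustration):

system_prompt = "You are a helpful e-commerce customer service assistant."
history = [("Where is my order?", "Could you share your order number, please?")]
message = "It's #12345."

conversation = f"<|system|>\n{system_prompt}\n"
for user_msg, assistant_msg in history:
    conversation += f"<|user|>\n{user_msg}\n<|assistant|>\n{assistant_msg}\n"
conversation += f"<|user|>\n{message}\n<|assistant|>\n"

print(conversation)
# <|system|>
# You are a helpful e-commerce customer service assistant.
# <|user|>
# Where is my order?
# <|assistant|>
# Could you share your order number, please?
# <|user|>
# It's #12345.
# <|assistant|>

Because the prompt is tokenized with truncation at 512 tokens, input_length records how many tokens belong to the prompt so the decoding step can keep only the newly generated tokens.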
@@ -75,26 +101,28 @@ def generate(message, history):
             top_k=50,
             repetition_penalty=1.1,
             pad_token_id=tokenizer.pad_token_id,
-            eos_token_id=tokenizer.eos_token_id,
+            eos_token_id=eos_token_id_to_use,  # Use the determined EOS token ID
         )
 
-    # Decode and extract assistant response
-    full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    new_tokens = outputs[0][input_length:]
+    generated_reply_part = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
 
-    # Extract only the new assistant response
-    if "<|assistant|>" in full_text:
-        # Get the last assistant response
-        assistant_parts = full_text.split("<|assistant|>")
-        if len(assistant_parts) > 1:
-            reply = assistant_parts[-1].strip()
-            # Remove any trailing tokens
-            if "<|user|>" in reply:
-                reply = reply.split("<|user|>")[0].strip()
-        else:
-            reply = "I apologize, but I couldn't generate a proper response. Please try again."
+    print(f"--- Raw Generated Reply Part (after skip_special_tokens=True) ---\n{generated_reply_part}\n----------------------------------------------------------------")
+
+    end_token_marker = "<|end|>"  # The specific string marker you're looking for
+    first_end_token_pos = generated_reply_part.find(end_token_marker)
+
+    if first_end_token_pos != -1:
+        reply = generated_reply_part[:first_end_token_pos].strip()
     else:
+        reply = generated_reply_part  # Use the whole string if <|end|> isn't found
+
+    # Fallback if the reply is empty after processing
+    if not reply:
+        print("Warning: Reply became empty after processing. Using fallback.")
         reply = "I apologize, but I couldn't generate a proper response. Please try again."
 
+    print(f"--- Final Reply ---\n{reply}\n-------------------")
     return reply
 
 # Build Gradio ChatInterface
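
The decoding path above only looks at tokens generated after the prompt (outputs[0][input_length:]) and trims the text at the first "<|end|>" marker. The same post-processing as a small pure-Python helper (a sketch, not part of the commit):

FALLBACK = "I apologize, but I couldn't generate a proper response. Please try again."

def extract_reply(decoded_new_tokens: str) -> str:
    # Cut at the first <|end|> marker if present; otherwise keep the whole string.
    end_pos = decoded_new_tokens.find("<|end|>")
    reply = decoded_new_tokens[:end_pos].strip() if end_pos != -1 else decoded_new_tokens.strip()
    return reply or FALLBACK  # fall back when nothing usable remains

print(extract_reply("Sure, we accept credit cards and PayPal.<|end|><|user|>"))  # trimmed reply
print(extract_reply(""))                                                         # fallback message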
@@ -110,15 +138,15 @@ demo = (
         "I need help with my order",
         "What payment methods do you accept?"
     ],
-    type="messages",
+    type="messages",  # Ensures history is a list of lists/tuples
     )
-    .queue(api_open=True)  # allow direct HTTP POST to /api/predict
+    .queue(api_open=True)
 )
 
 # Launch the app
 if __name__ == "__main__":
     demo.launch(
-        server_name="0.0.0.0",  # Allow external access
-        server_port=7860,       # Default Gradio port
-        share=False             # Set to True if you want a public link
+        server_name="0.0.0.0",
+        server_port=7860,
+        share=False
     )
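
One detail worth double-checking against the pinned Gradio version: with type="messages", recent Gradio releases pass history to the chat function as a list of {"role": ..., "content": ...} dicts rather than [user, assistant] pairs, while generate() above unpacks pairs. If that applies here, a small adapter (an assumption, not part of this commit) restores the pair format before the history loop:

def history_to_pairs(history):
    # Convert messages-format history ([{"role": ..., "content": ...}, ...])
    # into the [user_msg, assistant_msg] pairs that generate() iterates over.
    pairs, pending_user = [], None
    for msg in history or []:
        if msg.get("role") == "user":
            pending_user = msg.get("content", "")
        elif msg.get("role") == "assistant" and pending_user is not None:
            pairs.append((pending_user, msg.get("content", "")))
            pending_user = None
    return pairs

print(history_to_pairs([
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello! How can I help?"},
]))  # -> [('Hi', 'Hello! How can I help?')]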
 
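
With the queue opened via .queue(api_open=True), the Space can also be called programmatically. An untested sketch using gradio_client; the Space id and api_name below are assumptions, so check the Space's "Use via API" page for the exact values:

from gradio_client import Client

client = Client("ShenghaoYummy/your-space-name")  # hypothetical Space id
reply = client.predict(
    "What payment methods do you accept?",  # message
    api_name="/chat",                        # typical endpoint name for gr.ChatInterface
)
print(reply)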