Update app.py
app.py CHANGED
@@ -15,8 +15,15 @@ def load_model():
 
     # Ensure pad token is set
     if tokenizer.pad_token is None:
+        print("Tokenizer pad_token not set. Setting to eos_token.")
         tokenizer.pad_token = tokenizer.eos_token
+        # It's also good to ensure the model's config reflects this if it's used during generation
+        # model.config.pad_token_id = tokenizer.pad_token_id
+        # (Do this after model is loaded if needed, but usually tokenizer.pad_token_id in generate is enough)
 
+    print(f"Tokenizer pad_token: {tokenizer.pad_token}, ID: {tokenizer.pad_token_id}")
+    print(f"Tokenizer eos_token: {tokenizer.eos_token}, ID: {tokenizer.eos_token_id}")
+
     print("Loading base model...")
     base_model = AutoModelForCausalLM.from_pretrained(
         BASE_MODEL_ID,
@@ -29,6 +36,12 @@ def load_model():
     print("Loading PEFT adapter...")
     model = PeftModel.from_pretrained(base_model, MODEL_ID)
 
+    # If you had to set tokenizer.pad_token, ensure the merged model's config is also aware
+    # This is more relevant if not passing pad_token_id directly to generate, but good for consistency
+    if model.config.pad_token_id is None and tokenizer.pad_token_id is not None:
+        print(f"Setting model.config.pad_token_id to: {tokenizer.pad_token_id}")
+        model.config.pad_token_id = tokenizer.pad_token_id
+
     print("Model loaded successfully!")
     return model, tokenizer
 
@@ -42,32 +55,45 @@ def generate(message, history):
     history: List of [user_message, assistant_message] pairs
     returns: assistant's reply (string)
     """
-    # Use ChatML format that your model was trained on
     DEFAULT_SYSTEM_PROMPT = "You are a helpful e-commerce customer service assistant. Provide accurate, helpful, and friendly responses to customer inquiries about products, orders, shipping, returns, and general shopping assistance."
 
-    # Build conversation in ChatML format
     conversation = f"<|system|>\n{DEFAULT_SYSTEM_PROMPT}\n"
+    if history:  # Ensure history is not None or empty before iterating
+        for user_msg, assistant_msg in history:
+            # Ensure messages are strings
+            user_msg_str = str(user_msg) if user_msg is not None else ""
+            assistant_msg_str = str(assistant_msg) if assistant_msg is not None else ""
+            conversation += f"<|user|>\n{user_msg_str}\n<|assistant|>\n{assistant_msg_str}\n"
 
-
-
-        conversation += f"<|user|>\n{user_msg}\n<|assistant|>\n{assistant_msg}\n"
-
-    # Add current message
-    conversation += f"<|user|>\n{message}\n<|assistant|>\n"
+    message_str = str(message) if message is not None else ""
+    conversation += f"<|user|>\n{message_str}\n<|assistant|>\n"
 
-
+    print(f"--- Constructed Prompt ---\n{conversation}\n--------------------------")
+
     inputs = tokenizer(
         conversation,
         return_tensors="pt",
-        max_length=512,
+        max_length=512,  # Max length of context + new tokens for some models, but here it's input context length
         truncation=True,
-        padding=True
+        padding=True  # Pad to max_length or longest in batch if dynamic
     ).to(model.device)
 
-
+    input_length = inputs["input_ids"].shape[1]
+
+    # Ensure eos_token_id is correctly set for generation
+    # If your model was trained to use <|end|> as an EOS token, its ID should be tokenizer.eos_token_id
+    eos_token_id_to_use = tokenizer.eos_token_id
+    # Example: if <|end|> has a specific ID different from the default eos_token
+    # end_custom_token_id = tokenizer.convert_tokens_to_ids("<|end|>")
+    # if end_custom_token_id != tokenizer.unk_token_id:  # Check if token exists
+    #     eos_token_id_to_use = end_custom_token_id
+    # print(f"Using EOS token ID for generation: {eos_token_id_to_use}")
+
+
     with torch.no_grad():
         outputs = model.generate(
-
+            input_ids=inputs["input_ids"],
+            attention_mask=inputs["attention_mask"],
             max_new_tokens=300,
             do_sample=True,
             temperature=0.8,
@@ -75,26 +101,28 @@ def generate(message, history):
             top_k=50,
             repetition_penalty=1.1,
             pad_token_id=tokenizer.pad_token_id,
-            eos_token_id=
+            eos_token_id=eos_token_id_to_use,  # Use the determined EOS token ID
         )
 
-
-
+    new_tokens = outputs[0][input_length:]
+    generated_reply_part = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
 
-
-
-
-
-
-
-
-    if "<|user|>" in reply:
-        reply = reply.split("<|user|>")[0].strip()
-    else:
-        reply = "I apologize, but I couldn't generate a proper response. Please try again."
+    print(f"--- Raw Generated Reply Part (after skip_special_tokens=True) ---\n{generated_reply_part}\n----------------------------------------------------------------")
+
+    end_token_marker = "<|end|>"  # The specific string marker you're looking for
+    first_end_token_pos = generated_reply_part.find(end_token_marker)
+
+    if first_end_token_pos != -1:
+        reply = generated_reply_part[:first_end_token_pos].strip()
     else:
+        reply = generated_reply_part  # Use the whole string if <|end|> isn't found
+
+    # Fallback if the reply is empty after processing
+    if not reply:
+        print("Warning: Reply became empty after processing. Using fallback.")
         reply = "I apologize, but I couldn't generate a proper response. Please try again."
 
+    print(f"--- Final Reply ---\n{reply}\n-------------------")
     return reply
 
 # Build Gradio ChatInterface
@@ -110,15 +138,15 @@ demo = (
            "I need help with my order",
            "What payment methods do you accept?"
        ],
-        type="messages",
+        type="messages",  # Ensures history is a list of lists/tuples
     )
-    .queue(api_open=True)
+    .queue(api_open=True)
 )
 
 # Launch the app
 if __name__ == "__main__":
     demo.launch(
-        server_name="0.0.0.0",
-        server_port=7860,
-        share=False
+        server_name="0.0.0.0",
+        server_port=7860,
+        share=False
     )
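
A note on type="messages": with this argument, gr.ChatInterface passes history to generate() as a list of {"role": ..., "content": ...} dicts rather than [user_message, assistant_message] pairs, so the pair-unpacking loop above needs a small conversion step if that argument is kept. Below is a minimal sketch under that assumption, using the same ChatML-style markers built in generate(); the helper name build_conversation is illustrative and not part of this commit.

# Sketch only: adapt messages-format history (role/content dicts) to the
# ChatML-style prompt that generate() builds. Not part of the committed file.
def build_conversation(message, history, system_prompt):
    conversation = f"<|system|>\n{system_prompt}\n"
    pending_user = None
    for turn in history or []:
        role = turn.get("role")
        content = str(turn.get("content", ""))
        if role == "user":
            # Remember the user turn until the matching assistant turn arrives
            pending_user = content
        elif role == "assistant" and pending_user is not None:
            conversation += f"<|user|>\n{pending_user}\n<|assistant|>\n{content}\n"
            pending_user = None
    conversation += f"<|user|>\n{message}\n<|assistant|>\n"
    return conversation

With the older tuple-style history format, the existing "for user_msg, assistant_msg in history" loop works as written.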