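# app.py: Gradio Space serving an e-commerce customer-service chatbot built
# from the TinyLlama-1.1B-Chat base model plus a fine-tuned PEFT adapter.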
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import gradio as gr
import os

# Update this to your Hugging Face model ID
MODEL_ID = "ShenghaoYummy/TinyLlama-ECommerce-Chatbot"  # Replace with your actual model ID
BASE_MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
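# Note: the tokenizer below is loaded from MODEL_ID, so the adapter repo is
# assumed to also contain the tokenizer files saved during fine-tuning.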
def load_model():
    """Load the fine-tuned model with its PEFT adapter."""
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

    # Ensure a pad token is set
    if tokenizer.pad_token is None:
        print("Tokenizer pad_token not set. Setting to eos_token.")
        tokenizer.pad_token = tokenizer.eos_token
        # If generation relies on the model config rather than the pad_token_id
        # passed to generate(), mirror this on model.config.pad_token_id once
        # the model is loaded (handled below).
    print(f"Tokenizer pad_token: {tokenizer.pad_token}, ID: {tokenizer.pad_token_id}")
    print(f"Tokenizer eos_token: {tokenizer.eos_token}, ID: {tokenizer.eos_token_id}")
print("Loading base model...") | |
base_model = AutoModelForCausalLM.from_pretrained( | |
BASE_MODEL_ID, | |
load_in_4bit=True, # comment out to use full precision | |
torch_dtype=torch.float16, | |
device_map="auto", | |
trust_remote_code=True, | |
) | |
print("Loading PEFT adapter...") | |
model = PeftModel.from_pretrained(base_model, MODEL_ID) | |
# If you had to set tokenizer.pad_token, ensure the merged model's config is also aware | |
# This is more relevant if not passing pad_token_id directly to generate, but good for consistency | |
if model.config.pad_token_id is None and tokenizer.pad_token_id is not None: | |
print(f"Setting model.config.pad_token_id to: {tokenizer.pad_token_id}") | |
model.config.pad_token_id = tokenizer.pad_token_id | |
print("Model loaded successfully!") | |
return model, tokenizer | |
# Load model and tokenizer | |
model, tokenizer = load_model() | |
def generate(message, history):
    """
    Generate a response with the fine-tuned e-commerce chatbot.
    message: current user message (string)
    history: prior turns as a list of {"role": ..., "content": ...} dicts
             (Gradio ChatInterface with type="messages")
    returns: the assistant's reply (string)
    """
    DEFAULT_SYSTEM_PROMPT = "You are a helpful e-commerce customer service assistant. Provide accurate, helpful, and friendly responses to customer inquiries about products, orders, shipping, returns, and general shopping assistance."
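    # The <|system|>/<|user|>/<|assistant|> markers below follow the Zephyr-style
    # chat format that TinyLlama-1.1B-Chat was trained with.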
conversation = f"<|system|>\n{DEFAULT_SYSTEM_PROMPT}\n" | |
if history: # Ensure history is not None or empty before iterating | |
for user_msg, assistant_msg in history: | |
# Ensure messages are strings | |
user_msg_str = str(user_msg) if user_msg is not None else "" | |
assistant_msg_str = str(assistant_msg) if assistant_msg is not None else "" | |
conversation += f"<|user|>\n{user_msg_str}\n<|assistant|>\n{assistant_msg_str}\n" | |
message_str = str(message) if message is not None else "" | |
conversation += f"<|user|>\n{message_str}\n<|assistant|>\n" | |
print(f"--- Constructed Prompt ---\n{conversation}\n--------------------------") | |
    inputs = tokenizer(
        conversation,
        return_tensors="pt",
        max_length=512,  # truncate the input prompt to 512 tokens
        truncation=True,
        padding=True,  # no-op for a single sequence, but safe if batching later
    ).to(model.device)
    input_length = inputs["input_ids"].shape[1]

    # Ensure eos_token_id is correctly set for generation.
    # If the model was trained to use <|end|> as its EOS token, its ID should
    # already be tokenizer.eos_token_id.
    eos_token_id_to_use = tokenizer.eos_token_id
    # Example: if <|end|> has its own ID distinct from the default eos_token:
    # end_custom_token_id = tokenizer.convert_tokens_to_ids("<|end|>")
    # if end_custom_token_id != tokenizer.unk_token_id:  # check the token exists
    #     eos_token_id_to_use = end_custom_token_id
    #     print(f"Using EOS token ID for generation: {eos_token_id_to_use}")
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=300,
            do_sample=True,
            temperature=0.8,
            top_p=0.9,
            top_k=50,
            repetition_penalty=1.1,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=eos_token_id_to_use,  # use the determined EOS token ID
        )

    # Decode only the newly generated tokens (everything after the prompt)
    new_tokens = outputs[0][input_length:]
    generated_reply_part = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    print(f"--- Raw Generated Reply Part (after skip_special_tokens=True) ---\n{generated_reply_part}\n----------------------------------------------------------------")

    # Cut the reply at the first <|end|> marker, if one survives decoding
    end_token_marker = "<|end|>"
    first_end_token_pos = generated_reply_part.find(end_token_marker)
    if first_end_token_pos != -1:
        reply = generated_reply_part[:first_end_token_pos].strip()
    else:
        reply = generated_reply_part  # use the whole string if <|end|> isn't found

    # Fallback if the reply is empty after processing
    if not reply:
        print("Warning: Reply became empty after processing. Using fallback.")
        reply = "I apologize, but I couldn't generate a proper response. Please try again."

    print(f"--- Final Reply ---\n{reply}\n-------------------")
    return reply
# Build Gradio ChatInterface
demo = (
    gr.ChatInterface(
        fn=generate,
        title="E-commerce Customer Service Chatbot",
        description="Chat with our AI-powered e-commerce assistant. Ask about products, orders, shipping, returns, and more!",
        examples=[
            "What's your return policy?",
            "How long does shipping take?",
            "Do you have any discounts available?",
            "I need help with my order",
            "What payment methods do you accept?",
        ],
        type="messages",  # history is passed to fn as a list of role/content dicts
    )
    .queue(api_open=True)
)
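# .queue() enables request queuing; api_open=True keeps the backend API routes
# open so requests can also be made to them directly.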
# Launch the app
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )