# AI-chatbot / app.py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import gradio as gr
import os
# Update this to your Hugging Face model ID
MODEL_ID = "ShenghaoYummy/TinyLlama-ECommerce-Chatbot"  # fine-tuned PEFT adapter repo
BASE_MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
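# Optional (an assumption, not original behavior): let a deployment override
# the model IDs via environment variables, falling back to the defaults above.
MODEL_ID = os.environ.get("MODEL_ID", MODEL_ID)
BASE_MODEL_ID = os.environ.get("BASE_MODEL_ID", BASE_MODEL_ID)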
def load_model():
    """Load the fine-tuned model with its PEFT adapter."""
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

    # Ensure a pad token is set; Llama-style tokenizers often ship without one.
    if tokenizer.pad_token is None:
        print("Tokenizer pad_token not set. Setting to eos_token.")
        tokenizer.pad_token = tokenizer.eos_token
        # The model config is synced with this pad token after loading (below);
        # passing pad_token_id directly to generate() is usually sufficient.
    print(f"Tokenizer pad_token: {tokenizer.pad_token}, ID: {tokenizer.pad_token_id}")
    print(f"Tokenizer eos_token: {tokenizer.eos_token}, ID: {tokenizer.eos_token_id}")
print("Loading base model...")
base_model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL_ID,
load_in_4bit=True, # comment out to use full precision
torch_dtype=torch.float16,
device_map="auto",
trust_remote_code=True,
)
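    # Aside (not verified against this Space's pinned dependencies): recent
    # transformers releases prefer an explicit quantization config over the
    # bare load_in_4bit flag, along the lines of:
    #   from transformers import BitsAndBytesConfig
    #   quantization_config = BitsAndBytesConfig(
    #       load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16
    #   )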
print("Loading PEFT adapter...")
model = PeftModel.from_pretrained(base_model, MODEL_ID)
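    # Optional (an assumption, not part of the original flow): the adapter can
    # be folded into the base weights via model.merge_and_unload() to remove
    # per-forward adapter overhead, though merging into 4-bit quantized weights
    # is generally not recommended.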
    # Keep the model config's pad token in sync with the tokenizer. This mainly
    # matters when pad_token_id is not passed to generate() directly, but it
    # keeps the two consistent either way.
    if model.config.pad_token_id is None and tokenizer.pad_token_id is not None:
        print(f"Setting model.config.pad_token_id to: {tokenizer.pad_token_id}")
        model.config.pad_token_id = tokenizer.pad_token_id

    print("Model loaded successfully!")
    return model, tokenizer
# Load model and tokenizer
model, tokenizer = load_model()
def generate(message, history):
    """
    Generate a response from the fine-tuned e-commerce chatbot.

    message: current user message (string)
    history: list of {"role": ..., "content": ...} dicts, as produced by
             gr.ChatInterface(type="messages")
    returns: the assistant's reply (string)
    """
    DEFAULT_SYSTEM_PROMPT = "You are a helpful e-commerce customer service assistant. Provide accurate, helpful, and friendly responses to customer inquiries about products, orders, shipping, returns, and general shopping assistance."
    conversation = f"<|system|>\n{DEFAULT_SYSTEM_PROMPT}\n"
    if history:  # history may be None or empty on the first turn
        for turn in history:
            role = turn.get("role", "user")
            content = str(turn.get("content") or "")
            tag = "<|user|>" if role == "user" else "<|assistant|>"
            conversation += f"{tag}\n{content}\n"
    message_str = str(message) if message is not None else ""
    conversation += f"<|user|>\n{message_str}\n<|assistant|>\n"
print(f"--- Constructed Prompt ---\n{conversation}\n--------------------------")
inputs = tokenizer(
conversation,
return_tensors="pt",
max_length=512, # Max length of context + new tokens for some models, but here it's input context length
truncation=True,
padding=True # Pad to max_length or longest in batch if dynamic
).to(model.device)
    input_length = inputs["input_ids"].shape[1]

    # Use the tokenizer's EOS token ID for generation. If the model was trained
    # to stop on a custom marker such as <|end|>, resolve its ID instead, e.g.:
    #   end_id = tokenizer.convert_tokens_to_ids("<|end|>")
    #   if end_id != tokenizer.unk_token_id:  # the token exists in the vocab
    #       eos_token_id_to_use = end_id
    eos_token_id_to_use = tokenizer.eos_token_id
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=300,
            do_sample=True,
            temperature=0.8,
            top_p=0.9,
            top_k=50,
            repetition_penalty=1.1,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=eos_token_id_to_use,
        )
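    # Sampling keeps replies varied; for reproducible output one could instead
    # pass do_sample=False (greedy decoding) and drop temperature/top_p/top_k.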
    # Decode only the newly generated tokens, not the echoed prompt.
    new_tokens = outputs[0][input_length:]
    generated_reply_part = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    print(f"--- Raw Generated Reply Part (after skip_special_tokens=True) ---\n{generated_reply_part}\n----------------------------------------------------------------")

    # Cut the reply at the first <|end|> marker, in case the model emitted it
    # as plain text (skip_special_tokens only strips registered special tokens).
    end_token_marker = "<|end|>"
    first_end_token_pos = generated_reply_part.find(end_token_marker)
    if first_end_token_pos != -1:
        reply = generated_reply_part[:first_end_token_pos].strip()
    else:
        reply = generated_reply_part  # use the whole string if <|end|> is absent

    # Fall back to a canned apology if the reply is empty after processing.
    if not reply:
        print("Warning: Reply became empty after processing. Using fallback.")
        reply = "I apologize, but I couldn't generate a proper response. Please try again."

    print(f"--- Final Reply ---\n{reply}\n-------------------")
    return reply
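# Manual smoke test (illustrative): uncomment to verify generation end-to-end,
# with the empty history of a first turn, before wiring up the UI.
# print(generate("What's your return policy?", []))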
# Build the Gradio ChatInterface
demo = gr.ChatInterface(
    fn=generate,
    title="E-commerce Customer Service Chatbot",
    description="Chat with our AI-powered e-commerce assistant. Ask about products, orders, shipping, returns, and more!",
    examples=[
        "What's your return policy?",
        "How long does shipping take?",
        "Do you have any discounts available?",
        "I need help with my order",
        "What payment methods do you accept?",
    ],
    type="messages",  # history arrives as a list of {"role", "content"} dicts
).queue(api_open=True)
# Launch the app
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )