import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import gradio as gr
# Fine-tuned adapter repo on the Hugging Face Hub; replace with your own model ID
MODEL_ID = "ShenghaoYummy/TinyLlama-ECommerce-Chatbot"
BASE_MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
def load_model():
    """Load the fine-tuned model with its PEFT adapter."""
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

    # Ensure a pad token is set; TinyLlama's tokenizer ships without one
    if tokenizer.pad_token is None:
        print("Tokenizer pad_token not set. Setting to eos_token.")
        tokenizer.pad_token = tokenizer.eos_token
        # The model config should reflect this too; it is synced below once the
        # model is loaded (passing pad_token_id to generate() usually suffices)
    print(f"Tokenizer pad_token: {tokenizer.pad_token}, ID: {tokenizer.pad_token_id}")
    print(f"Tokenizer eos_token: {tokenizer.eos_token}, ID: {tokenizer.eos_token_id}")
print("Loading base model...")
base_model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL_ID,
load_in_4bit=True, # comment out to use full precision
torch_dtype=torch.float16,
device_map="auto",
trust_remote_code=True,
)
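    # NOTE: load_in_4bit= is deprecated in recent transformers releases in favor
    # of quantization_config. A sketch of the equivalent call (assumes
    # bitsandbytes is installed):
    # from transformers import BitsAndBytesConfig
    # base_model = AutoModelForCausalLM.from_pretrained(
    #     BASE_MODEL_ID,
    #     quantization_config=BitsAndBytesConfig(
    #         load_in_4bit=True,
    #         bnb_4bit_compute_dtype=torch.float16,
    #     ),
    #     device_map="auto",
    #     trust_remote_code=True,
    # )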
print("Loading PEFT adapter...")
model = PeftModel.from_pretrained(base_model, MODEL_ID)
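    # Optional: the adapter could be merged into the base weights for slightly
    # faster inference (a sketch; merging is generally only done on a
    # non-quantized base model, so it does not apply to the 4-bit load above):
    # model = model.merge_and_unload()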
    # Keep the model config in sync with the tokenizer's pad token; this matters
    # mostly when pad_token_id is not passed to generate() explicitly
    if model.config.pad_token_id is None and tokenizer.pad_token_id is not None:
        print(f"Setting model.config.pad_token_id to: {tokenizer.pad_token_id}")
        model.config.pad_token_id = tokenizer.pad_token_id

    print("Model loaded successfully!")
    return model, tokenizer
# Load model and tokenizer
model, tokenizer = load_model()
def generate(message, history):
    """
    Generate a response using the fine-tuned e-commerce chatbot.
    message: current user message (string)
    history: list of {"role": ..., "content": ...} dicts (Gradio's "messages" format)
    returns: the assistant's reply (string)
    """
    DEFAULT_SYSTEM_PROMPT = "You are a helpful e-commerce customer service assistant. Provide accurate, helpful, and friendly responses to customer inquiries about products, orders, shipping, returns, and general shopping assistance."
    conversation = f"<|system|>\n{DEFAULT_SYSTEM_PROMPT}\n"

    # With type="messages" in ChatInterface, each history entry is a dict with
    # "role" and "content" keys, so map each turn onto the chat-markup tags
    if history:
        for turn in history:
            content = str(turn.get("content") or "")
            tag = "<|user|>" if turn.get("role") == "user" else "<|assistant|>"
            conversation += f"{tag}\n{content}\n"

    message_str = str(message) if message is not None else ""
    conversation += f"<|user|>\n{message_str}\n<|assistant|>\n"
    print(f"--- Constructed Prompt ---\n{conversation}\n--------------------------")
    inputs = tokenizer(
        conversation,
        return_tensors="pt",
        max_length=512,  # maximum input context length; longer prompts are truncated
        truncation=True,
        padding=True,  # a no-op for a single sequence, but harmless
    ).to(model.device)
    input_length = inputs["input_ids"].shape[1]
    # Determine the EOS token ID for generation
    eos_token_id_to_use = tokenizer.eos_token_id
    # If the model was trained to emit <|end|> as its end-of-turn token and it
    # maps to a real vocab entry, that ID could be used instead:
    # end_custom_token_id = tokenizer.convert_tokens_to_ids("<|end|>")
    # if end_custom_token_id != tokenizer.unk_token_id:  # check the token exists
    #     eos_token_id_to_use = end_custom_token_id
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=300,
            do_sample=True,
            temperature=0.8,
            top_p=0.9,
            top_k=50,
            repetition_penalty=1.1,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=eos_token_id_to_use,
        )
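    # For deterministic output, pass do_sample=False (greedy decoding) and drop
    # temperature/top_p/top_k above (an optional tweak, not the original setup)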
    new_tokens = outputs[0][input_length:]
    generated_reply_part = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    print(f"--- Raw Generated Reply Part (after skip_special_tokens=True) ---\n{generated_reply_part}\n----------------------------------------------------------------")

    # Trim at the first <|end|> marker, if the model emitted one as plain text
    end_token_marker = "<|end|>"
    first_end_token_pos = generated_reply_part.find(end_token_marker)
    if first_end_token_pos != -1:
        reply = generated_reply_part[:first_end_token_pos].strip()
    else:
        reply = generated_reply_part  # use the whole string if <|end|> isn't found

    # Fallback if the reply is empty after processing
    if not reply:
        print("Warning: Reply became empty after processing. Using fallback.")
        reply = "I apologize, but I couldn't generate a proper response. Please try again."

    print(f"--- Final Reply ---\n{reply}\n-------------------")
    return reply
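# Quick manual smoke test (an illustrative usage example, not in the original):
# print(generate("What's your return policy?", []))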
# Build the Gradio ChatInterface
demo = gr.ChatInterface(
    fn=generate,
    title="E-commerce Customer Service Chatbot",
    description="Chat with our AI-powered e-commerce assistant. Ask about products, orders, shipping, returns, and more!",
    examples=[
        "What's your return policy?",
        "How long does shipping take?",
        "Do you have any discounts available?",
        "I need help with my order",
        "What payment methods do you accept?",
    ],
    type="messages",  # history arrives as a list of {"role", "content"} dicts
).queue(api_open=True)
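# If the Space gets busy, the queue call above can be bounded (an optional
# tweak, not in the original), e.g.:
# .queue(api_open=True, max_size=32)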
# Launch the app
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )