import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import gradio as gr
import os

# Update this to your Hugging Face model ID
MODEL_ID = "ShenghaoYummy/TinyLlama-ECommerce-Chatbot"  # Replace with your actual model ID
BASE_MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"


def load_model():
    """Load the fine-tuned model with its PEFT adapter."""
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

    # Ensure a pad token is set
    if tokenizer.pad_token is None:
        print("Tokenizer pad_token not set. Setting to eos_token.")
        tokenizer.pad_token = tokenizer.eos_token
        # The model config should reflect this too if it's used during generation;
        # model.config.pad_token_id is synced after the model is loaded below
        # (passing pad_token_id to generate() is usually enough, but this keeps things consistent).
    print(f"Tokenizer pad_token: {tokenizer.pad_token}, ID: {tokenizer.pad_token_id}")
    print(f"Tokenizer eos_token: {tokenizer.eos_token}, ID: {tokenizer.eos_token_id}")

    print("Loading base model...")
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        load_in_4bit=True,  # comment out to use full precision (newer transformers prefer quantization_config=BitsAndBytesConfig(load_in_4bit=True))
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )

    print("Loading PEFT adapter...")
    model = PeftModel.from_pretrained(base_model, MODEL_ID)

    # If the tokenizer's pad_token was set above, keep the adapted model's config consistent.
    # This matters mostly when pad_token_id is not passed directly to generate().
    if model.config.pad_token_id is None and tokenizer.pad_token_id is not None:
        print(f"Setting model.config.pad_token_id to: {tokenizer.pad_token_id}")
        model.config.pad_token_id = tokenizer.pad_token_id

    print("Model loaded successfully!")
    return model, tokenizer


# Load model and tokenizer once at startup
model, tokenizer = load_model()


def generate(message, history):
    """
    Generate a response using the fine-tuned e-commerce chatbot.

    message: the current user message (string)
    history: prior turns as a list of {"role": ..., "content": ...} dicts
             (Gradio ChatInterface with type="messages")
    returns: the assistant's reply (string)
    """
    DEFAULT_SYSTEM_PROMPT = (
        "You are a helpful e-commerce customer service assistant. Provide accurate, "
        "helpful, and friendly responses to customer inquiries about products, orders, "
        "shipping, returns, and general shopping assistance."
    )
    # Build the prompt in the TinyLlama chat format: <|system|>, <|user|>, <|assistant|>
    conversation = f"<|system|>\n{DEFAULT_SYSTEM_PROMPT}\n"

    # With ChatInterface(type="messages"), history is a list of dicts:
    # [{"role": "user", "content": ...}, {"role": "assistant", "content": ...}, ...]
    if history:  # ensure history is not None or empty before iterating
        for turn in history:
            role = turn.get("role", "")
            content = str(turn.get("content") or "")
            if role == "user":
                conversation += f"<|user|>\n{content}\n"
            elif role == "assistant":
                conversation += f"<|assistant|>\n{content}\n"

    message_str = str(message) if message is not None else ""
    conversation += f"<|user|>\n{message_str}\n<|assistant|>\n"

    print(f"--- Constructed Prompt ---\n{conversation}\n--------------------------")

    inputs = tokenizer(
        conversation,
        return_tensors="pt",
        max_length=512,  # maximum length of the input context; new tokens are budgeted separately via max_new_tokens
        truncation=True,
        padding=True,  # pad to the longest sequence in the batch (a no-op for a single sequence)
    ).to(model.device)
    input_length = inputs["input_ids"].shape[1]

    # Ensure eos_token_id is correctly set for generation.
    # If the model was trained to emit <|end|> as its end-of-turn token, use that token's ID instead.
    eos_token_id_to_use = tokenizer.eos_token_id
    # Example: if <|end|> has a specific ID different from the default eos_token
    # end_custom_token_id = tokenizer.convert_tokens_to_ids("<|end|>")
    # if end_custom_token_id != tokenizer.unk_token_id:  # check that the token exists
    #     eos_token_id_to_use = end_custom_token_id
    # print(f"Using EOS token ID for generation: {eos_token_id_to_use}")

    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=300,
            do_sample=True,
            temperature=0.8,
            top_p=0.9,
            top_k=50,
            repetition_penalty=1.1,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=eos_token_id_to_use,  # use the EOS token ID determined above
        )

    # Decode only the newly generated tokens, not the prompt
    new_tokens = outputs[0][input_length:]
    generated_reply_part = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    print(
        f"--- Raw Generated Reply Part (after skip_special_tokens=True) ---\n"
        f"{generated_reply_part}\n"
        f"----------------------------------------------------------------"
    )

    # Trim anything after the first <|end|> marker, in case the model emitted it as plain text
    end_token_marker = "<|end|>"
    first_end_token_pos = generated_reply_part.find(end_token_marker)
    if first_end_token_pos != -1:
        reply = generated_reply_part[:first_end_token_pos].strip()
    else:
        reply = generated_reply_part  # use the whole string if <|end|> isn't found

    # Fallback if the reply is empty after processing
    if not reply:
        print("Warning: Reply became empty after processing. Using fallback.")
        reply = "I apologize, but I couldn't generate a proper response. Please try again."

    print(f"--- Final Reply ---\n{reply}\n-------------------")
    return reply


# Build the Gradio ChatInterface
demo = gr.ChatInterface(
    fn=generate,
    title="E-commerce Customer Service Chatbot",
    description=(
        "Chat with our AI-powered e-commerce assistant. "
        "Ask about products, orders, shipping, returns, and more!"
    ),
    examples=[
        "What's your return policy?",
        "How long does shipping take?",
        "Do you have any discounts available?",
        "I need help with my order",
        "What payment methods do you accept?",
    ],
    type="messages",  # history is passed to generate() as a list of {"role", "content"} dicts
).queue(api_open=True)

# Launch the app
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )
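
# Usage sketch (assumptions: the app is running locally on port 7860, the gradio_client
# package is installed, and "/chat" is the endpoint name ChatInterface exposes with the
# queue's api_open=True). Kept as comments so this script stays runnable as-is:
#
#   from gradio_client import Client
#   client = Client("http://localhost:7860/")
#   print(client.predict("What's your return policy?", api_name="/chat"))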