import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import gradio as gr
import os

# Fine-tuned PEFT adapter on the Hugging Face Hub (replace with your own model ID)
MODEL_ID = "ShenghaoYummy/TinyLlama-ECommerce-Chatbot"
BASE_MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

def load_model():
    """Load the fine-tuned model with PEFT adapter"""
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    
    # Ensure pad token is set
    if tokenizer.pad_token is None:
        print("Tokenizer pad_token not set. Setting to eos_token.")
        tokenizer.pad_token = tokenizer.eos_token
        # The model's config is synced to this pad_token_id after the model is loaded
        # below, and pad_token_id is also passed to generate() directly, so both paths are covered.
    
    print(f"Tokenizer pad_token: {tokenizer.pad_token}, ID: {tokenizer.pad_token_id}")
    print(f"Tokenizer eos_token: {tokenizer.eos_token}, ID: {tokenizer.eos_token_id}")

    print("Loading base model...")
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        load_in_4bit=True,            # requires bitsandbytes; comment out to load in full precision
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
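    # Note: recent versions of transformers prefer passing quantization options via
    # BitsAndBytesConfig rather than the bare load_in_4bit flag. A rough equivalent
    # (untested sketch, assuming bitsandbytes is installed) would be:
    #
    #   from transformers import BitsAndBytesConfig
    #   quant_config = BitsAndBytesConfig(
    #       load_in_4bit=True,
    #       bnb_4bit_compute_dtype=torch.float16,
    #   )
    #   base_model = AutoModelForCausalLM.from_pretrained(
    #       BASE_MODEL_ID,
    #       quantization_config=quant_config,
    #       device_map="auto",
    #       trust_remote_code=True,
    #   )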
    
    print("Loading PEFT adapter...")
    model = PeftModel.from_pretrained(base_model, MODEL_ID)
    
    # Keep the model config's pad_token_id in sync with the tokenizer. This mainly matters
    # when pad_token_id is not passed to generate() explicitly, but it keeps things consistent.
    if model.config.pad_token_id is None and tokenizer.pad_token_id is not None:
        print(f"Setting model.config.pad_token_id to: {tokenizer.pad_token_id}")
        model.config.pad_token_id = tokenizer.pad_token_id

    print("Model loaded successfully!")
    return model, tokenizer

# Load model and tokenizer
model, tokenizer = load_model()

def generate(message, history):
    """
    Generate a response from the fine-tuned e-commerce chatbot.
    message: current user message (string)
    history: prior turns as a list of {"role": ..., "content": ...} dicts
             (the format Gradio passes when ChatInterface uses type="messages")
    returns: the assistant's reply (string)
    """
    DEFAULT_SYSTEM_PROMPT = "You are a helpful e-commerce customer service assistant. Provide accurate, helpful, and friendly responses to customer inquiries about products, orders, shipping, returns, and general shopping assistance."

    conversation = f"<|system|>\n{DEFAULT_SYSTEM_PROMPT}\n"
    for turn in history or []:  # history can be None or empty on the first turn
        role = turn.get("role", "")
        content = str(turn.get("content") or "")
        if role == "user":
            conversation += f"<|user|>\n{content}\n"
        elif role == "assistant":
            conversation += f"<|assistant|>\n{content}\n"

    message_str = str(message) if message is not None else ""
    conversation += f"<|user|>\n{message_str}\n<|assistant|>\n"
    
    print(f"--- Constructed Prompt ---\n{conversation}\n--------------------------")

    inputs = tokenizer(
        conversation, 
        return_tensors="pt", 
        max_length=512,  # cap on the tokenized input context; longer conversations are truncated
        truncation=True,
        padding=True,  # a no-op for a single sequence, but keeps the call batch-safe
    ).to(model.device)
    
    input_length = inputs["input_ids"].shape[1]

    # Decide which EOS token ID should stop generation. By default this is the
    # tokenizer's eos_token; if the adapter was trained to emit a custom <|end|>
    # token instead, its ID could be looked up and used here, e.g.:
    #   end_custom_token_id = tokenizer.convert_tokens_to_ids("<|end|>")
    #   if end_custom_token_id != tokenizer.unk_token_id:  # token exists in the vocab
    #       eos_token_id_to_use = end_custom_token_id
    eos_token_id_to_use = tokenizer.eos_token_id

    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=300,
            do_sample=True,
            temperature=0.8,
            top_p=0.9,
            top_k=50,
            repetition_penalty=1.1,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=eos_token_id_to_use, # Use the determined EOS token ID
        )
    
    new_tokens = outputs[0][input_length:]
    generated_reply_part = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    
    print(f"--- Raw Generated Reply Part (after skip_special_tokens=True) ---\n{generated_reply_part}\n----------------------------------------------------------------")

    end_token_marker = "<|end|>"  # literal end-of-turn marker; survives decoding if <|end|> is not a registered special token
    first_end_token_pos = generated_reply_part.find(end_token_marker)

    if first_end_token_pos != -1:
        reply = generated_reply_part[:first_end_token_pos].strip()
    else:
        reply = generated_reply_part # Use the whole string if <|end|> isn't found

    # Fallback if the reply is empty after processing
    if not reply:
        print("Warning: Reply became empty after processing. Using fallback.")
        reply = "I apologize, but I couldn't generate a proper response. Please try again."
    
    print(f"--- Final Reply ---\n{reply}\n-------------------")
    return reply

# Build Gradio ChatInterface
demo = (
    gr.ChatInterface(
        fn=generate,
        title="E-commerce Customer Service Chatbot",
        description="Chat with our AI-powered e-commerce assistant. Ask about products, orders, shipping, returns, and more!",
        examples=[
            "What's your return policy?",
            "How long does shipping take?",
            "Do you have any discounts available?",
            "I need help with my order",
            "What payment methods do you accept?"
        ],
        type="messages", # Ensures history is a list of lists/tuples
    )
    .queue(api_open=True)
)
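
# Because the queue is opened with api_open=True, the app can also be called
# programmatically. A rough sketch with gradio_client (the endpoint name is assumed
# to be the ChatInterface default "/chat", and the Space ID below is a placeholder):
#
#   from gradio_client import Client
#   client = Client("your-username/your-space-name")  # hypothetical Space ID
#   reply = client.predict("What's your return policy?", api_name="/chat")
#   print(reply)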

# Launch the app
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False 
    )
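
# Assumed runtime dependencies for this script (not pinned here; adjust versions as needed):
#   pip install torch transformers peft accelerate bitsandbytes gradio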