ShenghaoYummy committed
Commit 35d9d45 · verified · 1 Parent(s): 37550b7

Update app.py

Files changed (1): app.py +100 -32
app.py CHANGED
@@ -1,56 +1,124 @@
  import torch
  from transformers import AutoTokenizer, AutoModelForCausalLM
  import gradio as gr
  import os

- MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

- # 1) load model & tokenizer
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
- model = AutoModelForCausalLM.from_pretrained(
-     MODEL_ID,
-     load_in_4bit=True,  # comment out to use full precision
-     torch_dtype=torch.float16,
-     device_map="auto",
-     trust_remote_code=True,
- )

- # 2) define inference function
  def generate(message, history):
      """
      message: Current user message (string)
      history: List of [user_message, assistant_message] pairs
      returns: assistant's reply (string)
      """
-     # rebuild a single prompt string from history + current message
-     prompt = ""
      for user_msg, assistant_msg in history:
-         prompt += f"User: {user_msg}\n"
-         prompt += f"Assistant: {assistant_msg}\n"
-     prompt += f"User: {message}\nAssistant:"
-
-     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-     outputs = model.generate(
-         **inputs,
-         max_new_tokens=128,
-         do_sample=True,
-         temperature=0.7,
-     )
-     text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-     reply = text.split("Assistant:")[-1].strip()
      return reply

- # 3) build Gradio ChatInterface *with open_routes enabled*
  demo = (
      gr.ChatInterface(
          fn=generate,
-         title="TinyLlama-1.1B Chat API",
-         description="Chat with TinyLlama-1.1B and call via /api/predict",
          type="messages",
      )
-     .queue(api_open=True)  # allow direct HTTP POST to /api/predict
  )

- # 4) launch
  if __name__ == "__main__":
-     demo.launch()

  import torch
  from transformers import AutoTokenizer, AutoModelForCausalLM
+ from peft import PeftModel
  import gradio as gr
  import os

+ # Update this to your Hugging Face model ID
+ MODEL_ID = "YourUsername/TinyLlama-ECommerce-Chatbot"  # Replace with your actual model ID
+ BASE_MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

+ def load_model():
+     """Load the fine-tuned model with its PEFT adapter."""
+     print("Loading tokenizer...")
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
+
+     # Ensure a pad token is set
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     print("Loading base model...")
+     base_model = AutoModelForCausalLM.from_pretrained(
+         BASE_MODEL_ID,
+         load_in_4bit=True,  # comment out to use full precision
+         torch_dtype=torch.float16,
+         device_map="auto",
+         trust_remote_code=True,
+     )
+
+     print("Loading PEFT adapter...")
+     model = PeftModel.from_pretrained(base_model, MODEL_ID)
+
+     print("Model loaded successfully!")
+     return model, tokenizer
+
+ # Load model and tokenizer once at startup
+ model, tokenizer = load_model()

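Note: recent transformers releases deprecate passing load_in_4bit=True directly to from_pretrained in favor of an explicit BitsAndBytesConfig. A minimal equivalent sketch, assuming the Space has the bitsandbytes package installed:

    # Hypothetical alternative to the bare load_in_4bit=True kwarg above:
    # an explicit quantization config (requires bitsandbytes).
    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,  # match the fp16 dtype used above
    )

    base_model = AutoModelForCausalLM.from_pretrained(
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )
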
  def generate(message, history):
      """
+     Generate a response with the fine-tuned e-commerce chatbot.
      message: Current user message (string)
+     history: list of {"role": ..., "content": ...} dicts (Gradio type="messages")
      returns: assistant's reply (string)
      """
+     DEFAULT_SYSTEM_PROMPT = (
+         "You are a helpful e-commerce customer service assistant. Provide accurate, "
+         "helpful, and friendly responses to customer inquiries about products, orders, "
+         "shipping, returns, and general shopping assistance."
+     )
+
+     # Build the conversation in the <|system|>/<|user|>/<|assistant|> format the
+     # model was fine-tuned on
+     conversation = f"<|system|>\n{DEFAULT_SYSTEM_PROMPT}\n"
+
+     # Add history; with type="messages", Gradio passes role/content dicts,
+     # not [user_message, assistant_message] pairs
+     for turn in history:
+         conversation += f"<|{turn['role']}|>\n{turn['content']}\n"
+
+     # Add the current message and cue the assistant's turn
+     conversation += f"<|user|>\n{message}\n<|assistant|>\n"
+
+     # Tokenize (truncation drops tokens from the right by default; for very long
+     # histories, consider tokenizer.truncation_side = "left" to keep recent turns)
+     inputs = tokenizer(
+         conversation,
+         return_tensors="pt",
+         max_length=512,
+         truncation=True,
+         padding=True,
+     ).to(model.device)
+
+     # Generate response
+     with torch.no_grad():
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=150,
+             do_sample=True,
+             temperature=0.8,
+             top_p=0.9,
+             top_k=50,
+             repetition_penalty=1.1,
+             pad_token_id=tokenizer.pad_token_id,
+             eos_token_id=tokenizer.eos_token_id,
+         )
+
+     # Decode and pull out the assistant's reply
+     full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+     # The prompt ends with "<|assistant|>\n", so the last chunk after that marker
+     # is the newly generated reply
+     if "<|assistant|>" in full_text:
+         reply = full_text.split("<|assistant|>")[-1].strip()
+         # Trim anything generated past the next user turn
+         if "<|user|>" in reply:
+             reply = reply.split("<|user|>")[0].strip()
+     else:
+         reply = "I apologize, but I couldn't generate a proper response. Please try again."
+
      return reply

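Note: the hand-built prompt string above reflects the format the adapter was reportedly trained on. If the uploaded tokenizer also kept the base model's chat template, tokenizer.apply_chat_template produces the same Zephyr-style layout (including the </s> turn terminators) with less room for drift; a sketch under that assumption:

    # Sketch: format the prompt via the tokenizer's chat template instead of
    # manual string building (assumes the tokenizer still carries a chat_template).
    msgs = [{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}]
    msgs += history  # already role/content dicts when type="messages"
    msgs.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(
        msgs,
        add_generation_prompt=True,  # append the assistant tag so generation continues the reply
        return_tensors="pt",
    ).to(model.device)
    # then: outputs = model.generate(input_ids, max_new_tokens=150, ...)
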
+ # Build Gradio ChatInterface
  demo = (
      gr.ChatInterface(
          fn=generate,
+         title="E-commerce Customer Service Chatbot",
+         description="Chat with our AI-powered e-commerce assistant. Ask about products, orders, shipping, returns, and more!",
+         examples=[
+             "What's your return policy?",
+             "How long does shipping take?",
+             "Do you have any discounts available?",
+             "I need help with my order",
+             "What payment methods do you accept?",
+         ],
          type="messages",
      )
+     .queue(api_open=True)  # keep the HTTP API open for programmatic calls
  )

+ # Launch the app
  if __name__ == "__main__":
+     demo.launch(
+         server_name="0.0.0.0",  # Allow external access
+         server_port=7860,       # Default Gradio port
+         share=False,            # Set to True if you want a public link
+     )
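
With api_open=True the Space accepts programmatic calls as well. The exact HTTP routes have changed across Gradio major versions, so the version-agnostic way in is the official client; a sketch, where "YourUsername/your-space-name" is a placeholder for the real Space ID:

    # Query the deployed Space without the browser UI (pip install gradio_client).
    from gradio_client import Client

    client = Client("YourUsername/your-space-name")  # placeholder Space ID
    reply = client.predict(
        "What's your return policy?",  # message
        api_name="/chat",              # endpoint name gr.ChatInterface registers
    )
    print(reply)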