junaidbaber committed
Commit 0d5774d · verified · 1 Parent(s): fccfdf4

Update app.py

Files changed (1)
  1. app.py +119 -58
app.py CHANGED
@@ -1,86 +1,134 @@
 import streamlit as st
 from huggingface_hub import login
-from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
+from transformers import BitsAndBytesConfig
 import os

 def initialize_model():
-    """Initialize the model and tokenizer"""
+    """Initialize the model and tokenizer with CPU support"""
     # Log in to Hugging Face
     token = os.environ.get("hf")
-    login(token)
+    if token:
+        login(token)

-    # Define the model ID and device
-    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-    # Configure INT8 quantization
-    bnb_config = BitsAndBytesConfig(
-        load_in_8bit=True,
-        llm_int8_enable_fp32_cpu_offload=True
-    )
-
-    # Load tokenizer and model
+    # Use a smaller model that's more CPU-friendly
+    model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" # Much smaller model
+
+    # Load tokenizer
     tokenizer = AutoTokenizer.from_pretrained(model_id)
-    model = AutoModelForCausalLM.from_pretrained(
-        model_id,
-        quantization_config=bnb_config,
-        device_map="auto"
-    )
+
+    # Configure 4-bit quantization for CPU
+    try:
+        # First try with bitsandbytes 4-bit quantization
+        from transformers import AutoModelForCausalLM, BitsAndBytesConfig
+
+        compute_dtype = getattr(torch, "float16")
+
+        bnb_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=compute_dtype,
+            bnb_4bit_use_double_quant=False,
+        )
+
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            quantization_config=bnb_config,
+            device_map="auto",
+            trust_remote_code=True
+        )
+    except:
+        # Fallback to CPU without quantization
+        print("Falling back to CPU without quantization")
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            device_map="cpu",
+            trust_remote_code=True,
+            low_cpu_mem_usage=True
+        )

     # Ensure padding token is defined
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token

-    return model, tokenizer, device
+    return model, tokenizer

 def format_conversation(conversation_history):
     """Format the conversation history into a single string."""
     formatted = ""
     for turn in conversation_history:
-        formatted += f"User: {turn['user']}\nAssistant: {turn['assistant']}\n"
+        formatted += f"Human: {turn['user']}\nAssistant: {turn['assistant']}\n"
     return formatted.strip()

-def generate_response(model, tokenizer, device, prompt, conversation_history):
+def generate_response(model, tokenizer, prompt, conversation_history):
     """Generate model response"""
     # Format the entire conversation context
     context = format_conversation(conversation_history[:-1])
     if context:
-        full_prompt = f"{context}\nUser: {prompt}"
+        full_prompt = f"{context}\nHuman: {prompt}"
     else:
-        full_prompt = f"User: {prompt}"
+        full_prompt = f"Human: {prompt}"

     # Tokenize input
-    inputs = tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True).to(device)
+    inputs = tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True)
+
+    # Move inputs to the same device as the model
+    device = next(model.parameters()).device
+    inputs = {k: v.to(device) for k, v in inputs.items()}

     # Calculate max new tokens
     input_length = inputs["input_ids"].shape[1]
-    max_model_length = 2048
-    max_new_tokens = min(200, max_model_length - input_length)
-
-    # Generate response
-    outputs = model.generate(
-        inputs["input_ids"],
-        attention_mask=inputs["attention_mask"],
-        max_new_tokens=max_new_tokens,
-        temperature=0.7,
-        top_p=0.9,
-        pad_token_id=tokenizer.pad_token_id,
-        do_sample=True,
-        min_length=20,
-        no_repeat_ngram_size=3
-    )
-
-    # Decode response
-    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    response_parts = response.split("User: ")
-    model_response = response_parts[-1].split("Assistant: ")[-1].strip()
-
-    return model_response
+    max_model_length = 1024 # Reduced context window for memory efficiency
+    max_new_tokens = min(150, max_model_length - input_length)
+
+    try:
+        # Generate response with lower temperature for faster generation
+        outputs = model.generate(
+            inputs["input_ids"],
+            attention_mask=inputs["attention_mask"],
+            max_new_tokens=max_new_tokens,
+            temperature=0.5, # Lower temperature for faster, more focused responses
+            top_p=0.9,
+            pad_token_id=tokenizer.pad_token_id,
+            do_sample=True,
+            min_length=10, # Reduced minimum length
+            no_repeat_ngram_size=3
+        )
+
+        # Decode response
+        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        response_parts = response.split("Human: ")
+        model_response = response_parts[-1].split("Assistant: ")[-1].strip()
+
+        return model_response
+    except RuntimeError as e:
+        if "out of memory" in str(e):
+            torch.cuda.empty_cache()
+            return "I apologize, but I ran out of memory. Please try a shorter message or clear the chat history."
+        else:
+            return f"An error occurred: {str(e)}"

 def main():
-    st.set_page_config(page_title="LLM Chat Interface", page_icon="🤖")
-    st.title("Chat with LLM 🤖")
+    st.set_page_config(
+        page_title="LLM Chat Interface",
+        page_icon="🤖",
+        layout="wide"
+    )
+
+    # Add CSS to make the chat interface more compact
+    st.markdown("""
+        <style>
+        .stChat {
+            padding-top: 0rem;
+        }
+        .stChatMessage {
+            padding: 0.5rem;
+        }
+        </style>
+    """, unsafe_allow_html=True)
+
+    st.title("Welcome to LowCode No Code Demo")

     # Initialize session state for chat history
     if "chat_history" not in st.session_state:
@@ -89,10 +137,14 @@ def main():
     # Initialize model (only once)
     if "model" not in st.session_state:
         with st.spinner("Loading the model... This might take a minute..."):
-            model, tokenizer, device = initialize_model()
-            st.session_state.model = model
-            st.session_state.tokenizer = tokenizer
-            st.session_state.device = device
+            try:
+                model, tokenizer = initialize_model()
+                st.session_state.model = model
+                st.session_state.tokenizer = tokenizer
+                st.success("Model loaded successfully!")
+            except Exception as e:
+                st.error(f"Error loading model: {str(e)}")
+                return

     # Display chat messages
     for message in st.session_state.chat_history:
@@ -116,7 +168,6 @@ def main():
                 response = generate_response(
                     st.session_state.model,
                     st.session_state.tokenizer,
-                    st.session_state.device,
                     prompt,
                     st.session_state.chat_history
                 )
@@ -128,10 +179,20 @@
         if len(st.session_state.chat_history) > 5:
             st.session_state.chat_history = st.session_state.chat_history[-5:]

-    # Add a clear chat button
-    if st.sidebar.button("Clear Chat"):
-        st.session_state.chat_history = []
-        st.rerun()
+    # Sidebar controls
+    with st.sidebar:
+        st.title("Controls")
+        if st.button("Clear Chat"):
+            st.session_state.chat_history = []
+            st.rerun()
+
+        st.markdown("---")
+        st.markdown("""
+        ### Model Info
+        - Using TinyLlama 1.1B Chat
+        - Optimized for CPU usage
+        - Context window: 1024 tokens
+        """)

 if __name__ == "__main__":
     main()
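
The updated initialize_model() and generate_response() can be smoke-tested outside the Streamlit UI. The snippet below is a minimal sketch, not part of the commit: it assumes the file above is saved locally as app.py, that streamlit, torch, and transformers (plus bitsandbytes if available) are installed, and that the "hf" environment variable is set if the Hub requires authentication; the prompt string is purely illustrative.

# Minimal sketch: exercise the committed functions directly (assumes app.py is on the path).
from app import initialize_model, generate_response

# Loads TinyLlama in 4-bit if bitsandbytes works, otherwise falls back to plain CPU.
model, tokenizer = initialize_model()

# generate_response() expects the current turn to already be in the history;
# format_conversation() drops the last entry when building the context.
prompt = "Explain what 4-bit quantization does in one sentence."  # illustrative prompt
history = [{"user": prompt, "assistant": ""}]

print(generate_response(model, tokenizer, prompt, history))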