# app.py — author: junaidbaber
# Last commit: 0d5774d "Update app.py" (verified), file size 6.63 kB
import streamlit as st
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from transformers import BitsAndBytesConfig
import os
def initialize_model(model_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0"):
    """Load the chat model and its tokenizer.

    First tries to load the model with 4-bit NF4 quantization (bitsandbytes,
    ``device_map="auto"``). If that load raises, falls back to an unquantized
    model placed on the CPU.

    Args:
        model_id: Hugging Face Hub model id. Defaults to TinyLlama 1.1B Chat,
            a small model chosen so the app can run on CPU-only hardware.

    Returns:
        A ``(model, tokenizer)`` tuple. The tokenizer always has a
        ``pad_token`` (the EOS token is reused when none is defined).
    """
    # Authenticate with the Hub only when a token is supplied through the
    # "hf" environment variable.
    token = os.environ.get("hf")
    if token:
        login(token)

    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # NOTE(review): trust_remote_code=True lets the model repo execute its own
    # Python code at load time. It is presumably unnecessary for the default
    # TinyLlama checkpoint — confirm and drop it if so.
    try:
        # bitsandbytes 4-bit loading typically requires a CUDA GPU; on a
        # CPU-only host this attempt is expected to fail and use the fallback.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=False,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
        )
    except Exception as exc:
        # Catch Exception (not a bare except) so Ctrl-C / SystemExit still
        # propagate, and surface why the quantized load failed.
        print(f"Falling back to CPU without quantization ({type(exc).__name__}: {exc})")
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="cpu",
            trust_remote_code=True,
            low_cpu_mem_usage=True,
        )

    # generate() is called with pad_token_id=tokenizer.pad_token_id, so a
    # padding token must exist.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer
def format_conversation(conversation_history):
    """Render completed chat turns as a plain-text transcript.

    Each turn (a dict with "user" and "assistant" keys) becomes a
    "Human: ..." line followed by an "Assistant: ..." line. Turns are
    newline-separated, and the whole transcript has surrounding whitespace
    stripped, so an empty history yields "".
    """
    rendered_turns = [
        f"Human: {entry['user']}\nAssistant: {entry['assistant']}"
        for entry in conversation_history
    ]
    return "\n".join(rendered_turns).strip()
def generate_response(model, tokenizer, prompt, conversation_history):
    """Generate the assistant's reply to ``prompt``.

    Args:
        model: causal LM (as returned by ``initialize_model``).
        tokenizer: the matching tokenizer; must define ``pad_token_id``.
        prompt: the user's latest message.
        conversation_history: list of ``{"user": ..., "assistant": ...}`` turns.
            The LAST entry is treated as the in-progress turn for ``prompt``
            (``main`` appends it before calling) and is excluded from context.

    Returns:
        The reply text, or a human-readable message if generation raised a
        ``RuntimeError`` (an out-of-memory hint, or the error text).
    """
    max_model_length = 1024  # reduced context window for memory efficiency
    max_reply_tokens = 150

    # Same "Human:/Assistant:" transcript format as format_conversation, ending
    # with an "Assistant:" cue so the model answers rather than continuing the
    # user's message.
    context = format_conversation(conversation_history[:-1])
    current_turn = f"Human: {prompt}\nAssistant:"
    full_prompt = f"{context}\n{current_turn}" if context else current_turn

    inputs = tokenizer(full_prompt, return_tensors="pt")

    # Keep only the most recent tokens so prompt + reply fit in the window.
    # Truncating on the left drops the oldest context instead of the new
    # question, and guarantees room for max_reply_tokens new tokens.
    max_input_length = max_model_length - max_reply_tokens
    if inputs["input_ids"].shape[1] > max_input_length:
        inputs = {k: v[:, -max_input_length:] for k, v in inputs.items()}

    # Move inputs to the same device as the model.
    device = next(model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}
    input_length = inputs["input_ids"].shape[1]

    try:
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=max_reply_tokens,
            temperature=0.5,  # lower temperature -> more focused, less random replies
            top_p=0.9,
            pad_token_id=tokenizer.pad_token_id,
            do_sample=True,
            # min_length would count the prompt tokens too (a no-op for any
            # non-trivial prompt); min_new_tokens bounds the reply itself.
            min_new_tokens=10,
            no_repeat_ngram_size=3,
        )
    except RuntimeError as e:
        if "out of memory" in str(e):
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            return "I apologize, but I ran out of memory. Please try a shorter message or clear the chat history."
        return f"An error occurred: {str(e)}"

    # Decode only the newly generated tokens: decoding the whole sequence and
    # splitting on "Human: " could return the user's own text or a user turn
    # the model hallucinated.
    reply = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
    # Stop where the model starts writing the next user turn, if it does.
    reply = reply.split("Human:", 1)[0].strip()
    if reply.startswith("Assistant:"):
        reply = reply[len("Assistant:"):].strip()
    return reply
@st.cache_resource(show_spinner=False)
def _load_shared_model():
    """Load the model and tokenizer once per server process.

    st.cache_resource shares the returned objects across browser sessions, so
    each new visitor does not trigger another full model load.
    """
    return initialize_model()


def main():
    """Render the Streamlit chat UI: page setup, chat history, input box and sidebar."""
    st.set_page_config(
        page_title="LLM Chat Interface",
        page_icon="🤖",
        layout="wide"
    )
    # Add CSS to make the chat interface more compact
    st.markdown("""
<style>
.stChat {
padding-top: 0rem;
}
.stChatMessage {
padding: 0.5rem;
}
</style>
""", unsafe_allow_html=True)
    st.title("Welcome to LowCode No Code Demo")

    # Per-session list of {"user": ..., "assistant": ...} turns.
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []

    # Fetch the process-wide cached model once per session; the session_state
    # keys are kept so the rest of the app reads the model from the same place.
    if "model" not in st.session_state:
        with st.spinner("Loading the model... This might take a minute..."):
            try:
                model, tokenizer = _load_shared_model()
                st.session_state.model = model
                st.session_state.tokenizer = tokenizer
                st.success("Model loaded successfully!")
            except Exception as e:
                st.error(f"Error loading model: {str(e)}")
                return

    # Replay the stored conversation on every rerun.
    for message in st.session_state.chat_history:
        with st.chat_message("user"):
            st.write(message["user"])
        with st.chat_message("assistant"):
            st.write(message["assistant"])

    # Chat input
    if prompt := st.chat_input("What would you like to know?"):
        with st.chat_message("user"):
            st.write(prompt)
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                # The in-progress turn is appended first; generate_response
                # excludes the last history entry from its context.
                current_turn = {"user": prompt, "assistant": ""}
                st.session_state.chat_history.append(current_turn)
                response = generate_response(
                    st.session_state.model,
                    st.session_state.tokenizer,
                    prompt,
                    st.session_state.chat_history
                )
                st.write(response)
                st.session_state.chat_history[-1]["assistant"] = response
        # Keep only the 5 most recent turns to bound the prompt size.
        if len(st.session_state.chat_history) > 5:
            st.session_state.chat_history = st.session_state.chat_history[-5:]

    # Sidebar controls
    with st.sidebar:
        st.title("Controls")
        if st.button("Clear Chat"):
            st.session_state.chat_history = []
            st.rerun()
        st.markdown("---")
        st.markdown("""
### Model Info
- Using TinyLlama 1.1B Chat
- Optimized for CPU usage
- Context window: 1024 tokens
""")
# Script entry point: render the chat app.
# NOTE(review): assumes `streamlit run` executes this file with
# __name__ == "__main__" so that main() is invoked — confirm.
if __name__ == "__main__":
    main()