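"""Streamlit chat app for TinyLlama/TinyLlama-1.1B-Chat-v1.0 running on CPU.

Run locally with: streamlit run app.py
"""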
import os

import streamlit as st
import torch
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer


def initialize_model():
    """Initialize the model and tokenizer with CPU support"""
    # Log in to Hugging Face
    token = os.environ.get("hf")
    if token:
        login(token)

    model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    try:
        # Regular CPU mode (simpler and more reliable)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="cpu",
            trust_remote_code=True,
            low_cpu_mem_usage=True
        )
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        raise e

    # Ensure padding token is defined
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer


def format_prompt(user_input, conversation_history=[]):
    """Format the prompt according to TinyLlama's expected chat format"""
    messages = []

    # Add conversation history
    for turn in conversation_history:
        messages.append({"role": "user", "content": turn["user"]})
        messages.append({"role": "assistant", "content": turn["assistant"]})

    # Add current user input
    messages.append({"role": "user", "content": user_input})

    # Format into TinyLlama chat format
    formatted_prompt = "<|system|>You are a helpful AI assistant.</s>"
    for message in messages:
        if message["role"] == "user":
            formatted_prompt += f"<|user|>{message['content']}</s>"
        else:
            formatted_prompt += f"<|assistant|>{message['content']}</s>"
    formatted_prompt += "<|assistant|>"

    return formatted_prompt
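
# Note: transformers tokenizers also offer tokenizer.apply_chat_template(messages,
# tokenize=False, add_generation_prompt=True), which builds the prompt from the
# model's bundled chat template; the hand-rolled formatting above approximates it.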


def generate_response(model, tokenizer, prompt, conversation_history):
    """Generate model response"""
    try:
        # Format prompt using TinyLlama's chat template
        formatted_prompt = format_prompt(prompt, conversation_history[:-1])

        # Tokenize input
        inputs = tokenizer(formatted_prompt, return_tensors="pt", padding=True, truncation=True)

        # Move inputs to the same device as the model
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Calculate max new tokens that still fit in the context window
        input_length = inputs["input_ids"].shape[1]
        max_model_length = 1024
        max_new_tokens = min(150, max_model_length - input_length)

        # Generate response
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.pad_token_id,
            do_sample=True,
            min_length=10,
            no_repeat_ngram_size=3,
            eos_token_id=tokenizer.eos_token_id  # stop at the </s> end-of-sequence token
        )

        # Decode response and extract only the last assistant message
        full_response = tokenizer.decode(outputs[0], skip_special_tokens=False)
        assistant_response = full_response.split("<|assistant|>")[-1].split("</s>")[0].strip()

        return assistant_response if assistant_response else "I apologize, but I couldn't generate a proper response."

    except RuntimeError as e:
        if "out of memory" in str(e):
            torch.cuda.empty_cache()
            return "I apologize, but I ran out of memory. Please try a shorter message or clear the chat history."
        return f"An error occurred: {str(e)}"


def main():
    st.set_page_config(
        page_title="LLM Chat Interface",
        page_icon="🤖",
        layout="wide"
    )

    # Add CSS to make the chat interface more compact
    st.markdown("""
<style>
.stChat {
    padding-top: 0rem;
}
.stChatMessage {
    padding: 0.5rem;
}
</style>
""", unsafe_allow_html=True)

    st.title("Chat with TinyLlama 🤖")

    # Initialize session state for chat history
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []

    # Initialize model (only once per session)
    if "model" not in st.session_state:
        with st.spinner("Loading the model... This might take a minute..."):
            try:
                model, tokenizer = initialize_model()
                st.session_state.model = model
                st.session_state.tokenizer = tokenizer
                st.success("Model loaded successfully!")
            except Exception as e:
                st.error(f"Error loading model: {str(e)}")
                return

    # Display chat messages
    for message in st.session_state.chat_history:
        with st.chat_message("user"):
            st.write(message["user"])
        with st.chat_message("assistant"):
            st.write(message["assistant"])

    # Chat input
    if prompt := st.chat_input("What would you like to know?"):
        # Display user message
        with st.chat_message("user"):
            st.write(prompt)

        # Generate and display assistant response
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                current_turn = {"user": prompt, "assistant": ""}
                st.session_state.chat_history.append(current_turn)

                response = generate_response(
                    st.session_state.model,
                    st.session_state.tokenizer,
                    prompt,
                    st.session_state.chat_history
                )

                st.write(response)
                st.session_state.chat_history[-1]["assistant"] = response

        # Manage context window (keep only the last 5 turns)
        if len(st.session_state.chat_history) > 5:
            st.session_state.chat_history = st.session_state.chat_history[-5:]

    # Sidebar controls
    with st.sidebar:
        st.title("Controls")
        if st.button("Clear Chat"):
            st.session_state.chat_history = []
            st.rerun()

        st.markdown("---")
        st.markdown("""
### Model Info
- Using TinyLlama 1.1B Chat
- CPU optimized
- Context window: 1024 tokens
""")


if __name__ == "__main__":
    main()