import os

import streamlit as st
import torch
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig


def initialize_model():
    """Initialize the model and tokenizer with CPU support"""
    # Log in to Hugging Face
    token = os.environ.get("hf")
    if token:
        login(token)

    model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    try:
        # Try with regular CPU mode first (simpler and more reliable)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="cpu",
            trust_remote_code=True,
            low_cpu_mem_usage=True
        )
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        raise e

    # Ensure padding token is defined
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer


def format_prompt(user_input, conversation_history=[]):
    """Format the prompt according to TinyLlama's expected chat format"""
    messages = []

    # Add conversation history
    for turn in conversation_history:
        messages.append({"role": "user", "content": turn["user"]})
        messages.append({"role": "assistant", "content": turn["assistant"]})

    # Add current user input
    messages.append({"role": "user", "content": user_input})

    # Format into TinyLlama chat format, terminating each message with </s>
    formatted_prompt = "<|system|>You are a helpful AI assistant.</s>"
    for message in messages:
        if message["role"] == "user":
            formatted_prompt += f"<|user|>{message['content']}</s>"
        else:
            formatted_prompt += f"<|assistant|>{message['content']}</s>"
    formatted_prompt += "<|assistant|>"

    return formatted_prompt


def generate_response(model, tokenizer, prompt, conversation_history):
    """Generate model response"""
    try:
        # Format prompt using TinyLlama's chat template
        formatted_prompt = format_prompt(prompt, conversation_history[:-1])

        # Tokenize input
        inputs = tokenizer(formatted_prompt, return_tensors="pt", padding=True, truncation=True)

        # Move inputs to the same device as the model
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Calculate max new tokens
        input_length = inputs["input_ids"].shape[1]
        max_model_length = 1024
        max_new_tokens = min(150, max_model_length - input_length)

        # Generate response
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.pad_token_id,
            do_sample=True,
            min_length=10,
            no_repeat_ngram_size=3,
            eos_token_id=tokenizer.eos_token_id  # Stop at the end-of-sequence token (</s>)
        )

        # Decode response and extract only the assistant's message
        full_response = tokenizer.decode(outputs[0], skip_special_tokens=False)

        # Extract only the last assistant response (text between the final <|assistant|> tag and </s>)
        assistant_response = full_response.split("<|assistant|>")[-1].split("</s>")[0].strip()

        return assistant_response if assistant_response else "I apologize, but I couldn't generate a proper response."

    except RuntimeError as e:
        if "out of memory" in str(e):
            torch.cuda.empty_cache()
            return "I apologize, but I ran out of memory. Please try a shorter message or clear the chat history."
        else:
            return f"An error occurred: {str(e)}"


def main():
    st.set_page_config(
        page_title="LLM Chat Interface",
        page_icon="🤖",
        layout="wide"
    )

    # Add CSS to make the chat interface more compact
    st.markdown(""" """, unsafe_allow_html=True)

    st.title("Chat with TinyLlama 🤖")

    # Initialize session state for chat history
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []

    # Initialize model (only once)
    if "model" not in st.session_state:
        with st.spinner("Loading the model... This might take a minute..."):
            try:
                model, tokenizer = initialize_model()
                st.session_state.model = model
                st.session_state.tokenizer = tokenizer
                st.success("Model loaded successfully!")
            except Exception as e:
                st.error(f"Error loading model: {str(e)}")
                return

    # Display chat messages
    for message in st.session_state.chat_history:
        with st.chat_message("user"):
            st.write(message["user"])
        with st.chat_message("assistant"):
            st.write(message["assistant"])

    # Chat input
    if prompt := st.chat_input("What would you like to know?"):
        # Display user message
        with st.chat_message("user"):
            st.write(prompt)

        # Generate and display assistant response
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                current_turn = {"user": prompt, "assistant": ""}
                st.session_state.chat_history.append(current_turn)

                response = generate_response(
                    st.session_state.model,
                    st.session_state.tokenizer,
                    prompt,
                    st.session_state.chat_history
                )

                st.write(response)
                st.session_state.chat_history[-1]["assistant"] = response

                # Manage context window
                if len(st.session_state.chat_history) > 5:
                    st.session_state.chat_history = st.session_state.chat_history[-5:]

    # Sidebar controls
    with st.sidebar:
        st.title("Controls")
        if st.button("Clear Chat"):
            st.session_state.chat_history = []
            st.rerun()

        st.markdown("---")
        st.markdown("""
        ### Model Info
        - Using TinyLlama 1.1B Chat
        - CPU optimized
        - Context window: 1024 tokens
        """)


if __name__ == "__main__":
    main()
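
# Usage sketch (assumptions: this script is saved as app.py, and a Hugging Face
# access token — only needed if login is required for your setup — is exported
# in the "hf" environment variable read by initialize_model):
#
#   export hf=<your_huggingface_token>   # optional
#   streamlit run app.py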