import streamlit as st
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import os


def initialize_model():
    """Initialize the model and tokenizer."""
    # Log in to Hugging Face
    token = os.environ.get("hf")
    login(token)

    # Define the model ID and device
    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Configure INT8 quantization
    bnb_config = BitsAndBytesConfig(
        load_in_8bit=True,
        llm_int8_enable_fp32_cpu_offload=True
    )

    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=bnb_config,
        device_map="auto"
    )

    # Ensure a padding token is defined
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer, device


def format_conversation(conversation_history):
    """Format the conversation history into a single string."""
    formatted = ""
    for turn in conversation_history:
        formatted += f"User: {turn['user']}\nAssistant: {turn['assistant']}\n"
    return formatted.strip()


def generate_response(model, tokenizer, device, prompt, conversation_history):
    """Generate the model response."""
    # Format the conversation context (every turn before the current one)
    context = format_conversation(conversation_history[:-1])
    if context:
        full_prompt = f"{context}\nUser: {prompt}"
    else:
        full_prompt = f"User: {prompt}"

    # Tokenize the input
    inputs = tokenizer(
        full_prompt, return_tensors="pt", padding=True, truncation=True
    ).to(device)

    # Calculate how many new tokens can still fit in the context window
    input_length = inputs["input_ids"].shape[1]
    max_model_length = 2048
    max_new_tokens = min(200, max_model_length - input_length)

    # Generate the response
    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=max_new_tokens,
        temperature=0.7,
        top_p=0.9,
        pad_token_id=tokenizer.pad_token_id,
        do_sample=True,
        min_new_tokens=20,
        no_repeat_ngram_size=3
    )

    # Decode the output and keep only the assistant's latest reply
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    response_parts = response.split("User: ")
    model_response = response_parts[-1].split("Assistant: ")[-1].strip()
    return model_response


def main():
    st.set_page_config(page_title="LLM Chat Interface", page_icon="🤖")
    st.title("Chat with LLM 🤖")

    # Initialize session state for the chat history
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []

    # Initialize the model (only once)
    if "model" not in st.session_state:
        with st.spinner("Loading the model... This might take a minute..."):
            model, tokenizer, device = initialize_model()
            st.session_state.model = model
            st.session_state.tokenizer = tokenizer
            st.session_state.device = device

    # Display chat messages
    for message in st.session_state.chat_history:
        with st.chat_message("user"):
            st.write(message["user"])
        with st.chat_message("assistant"):
            st.write(message["assistant"])

    # Chat input
    if prompt := st.chat_input("What would you like to know?"):
        # Display the user message
        with st.chat_message("user"):
            st.write(prompt)

        # Generate and display the assistant response
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                current_turn = {"user": prompt, "assistant": ""}
                st.session_state.chat_history.append(current_turn)

                response = generate_response(
                    st.session_state.model,
                    st.session_state.tokenizer,
                    st.session_state.device,
                    prompt,
                    st.session_state.chat_history
                )

                st.write(response)
                st.session_state.chat_history[-1]["assistant"] = response

        # Manage the context window: keep only the last five turns
        if len(st.session_state.chat_history) > 5:
            st.session_state.chat_history = st.session_state.chat_history[-5:]

    # Add a clear chat button
    if st.sidebar.button("Clear Chat"):
        st.session_state.chat_history = []
        st.rerun()


if __name__ == "__main__":
    main()
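# Usage note (a minimal sketch, assuming the script above is saved as app.py;
# the file name and the token value below are placeholders, not part of the
# original code). The "hf" environment variable must hold a Hugging Face
# access token with permission to download the gated Llama 3.1 model:
#
#   export hf=hf_xxxxxxxxxxxxxxxx   # hypothetical token value
#   streamlit run app.py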