import streamlit as st
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import os
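
# Note: beyond the imports above, 8-bit loading via BitsAndBytesConfig assumes the
# `bitsandbytes` package is installed, and `device_map="auto"` assumes `accelerate`
# is available (typically listed in the Space's requirements.txt).
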
def initialize_model():
    """Initialize the model and tokenizer."""
    # Log in to Hugging Face
    token = os.environ.get("hf")
    login(token)

    # Define the model ID and device
    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Configure INT8 quantization
    bnb_config = BitsAndBytesConfig(
        load_in_8bit=True,
        llm_int8_enable_fp32_cpu_offload=True
    )

    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=bnb_config,
        device_map="auto"
    )

    # Ensure a padding token is defined
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer, device

def format_conversation(conversation_history):
    """Format the conversation history into a single string."""
    formatted = ""
    for turn in conversation_history:
        formatted += f"User: {turn['user']}\nAssistant: {turn['assistant']}\n"
    return formatted.strip()

def generate_response(model, tokenizer, device, prompt, conversation_history):
    """Generate the model's response to the latest user prompt."""
    # Format the prior conversation turns (the current turn, appended with an
    # empty assistant reply, is excluded via [:-1])
    context = format_conversation(conversation_history[:-1])
    if context:
        full_prompt = f"{context}\nUser: {prompt}"
    else:
        full_prompt = f"User: {prompt}"

    # Tokenize input
    inputs = tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True).to(device)

    # Calculate how many new tokens fit in the context window
    input_length = inputs["input_ids"].shape[1]
    max_model_length = 2048
    max_new_tokens = min(200, max_model_length - input_length)

    # Generate response
    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=max_new_tokens,
        temperature=0.7,
        top_p=0.9,
        pad_token_id=tokenizer.pad_token_id,
        do_sample=True,
        min_length=20,
        no_repeat_ngram_size=3
    )

    # Decode the output and strip the echoed prompt, keeping only the newest
    # assistant reply after the last "User: " / "Assistant: " markers
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    response_parts = response.split("User: ")
    model_response = response_parts[-1].split("Assistant: ")[-1].strip()
    return model_response

def main():
    st.set_page_config(page_title="LLM Chat Interface", page_icon="🤖")
    st.title("Chat with LLM 🤖")

    # Initialize session state for chat history
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []

    # Initialize model (only once)
    if "model" not in st.session_state:
        with st.spinner("Loading the model... This might take a minute..."):
            model, tokenizer, device = initialize_model()
            st.session_state.model = model
            st.session_state.tokenizer = tokenizer
            st.session_state.device = device

    # Display chat messages
    for message in st.session_state.chat_history:
        with st.chat_message("user"):
            st.write(message["user"])
        with st.chat_message("assistant"):
            st.write(message["assistant"])

    # Chat input
    if prompt := st.chat_input("What would you like to know?"):
        # Display user message
        with st.chat_message("user"):
            st.write(prompt)

        # Generate and display assistant response
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                current_turn = {"user": prompt, "assistant": ""}
                st.session_state.chat_history.append(current_turn)

                response = generate_response(
                    st.session_state.model,
                    st.session_state.tokenizer,
                    st.session_state.device,
                    prompt,
                    st.session_state.chat_history
                )

                st.write(response)
                st.session_state.chat_history[-1]["assistant"] = response

        # Manage context window: keep only the last 5 turns
        if len(st.session_state.chat_history) > 5:
            st.session_state.chat_history = st.session_state.chat_history[-5:]

    # Add a clear chat button
    if st.sidebar.button("Clear Chat"):
        st.session_state.chat_history = []
        st.rerun()

if __name__ == "__main__":
    main()
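
# To run locally (an assumption: a CUDA GPU with enough memory for the 8-bit model,
# and a Hugging Face access token exported in the "hf" environment variable, as
# read by initialize_model above):
#   export hf=<your_hf_access_token>
#   streamlit run app.py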