Spaces:
Running
Running
import streamlit as st | |
from huggingface_hub import login | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
import torch | |
from transformers import BitsAndBytesConfig | |
import os | |
def initialize_model(): | |
"""Initialize the model and tokenizer with CPU support""" | |
# Log in to Hugging Face | |
token = os.environ.get("hf") | |
if token: | |
login(token) | |
# Use a smaller model that's more CPU-friendly | |
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" # Much smaller model | |
# Load tokenizer | |
tokenizer = AutoTokenizer.from_pretrained(model_id) | |
# Configure 4-bit quantization for CPU | |
try: | |
# First try with bitsandbytes 4-bit quantization | |
from transformers import AutoModelForCausalLM, BitsAndBytesConfig | |
compute_dtype = getattr(torch, "float16") | |
bnb_config = BitsAndBytesConfig( | |
load_in_4bit=True, | |
bnb_4bit_quant_type="nf4", | |
bnb_4bit_compute_dtype=compute_dtype, | |
bnb_4bit_use_double_quant=False, | |
) | |
model = AutoModelForCausalLM.from_pretrained( | |
model_id, | |
quantization_config=bnb_config, | |
device_map="auto", | |
trust_remote_code=True | |
) | |
except: | |
# Fallback to CPU without quantization | |
print("Falling back to CPU without quantization") | |
model = AutoModelForCausalLM.from_pretrained( | |
model_id, | |
device_map="cpu", | |
trust_remote_code=True, | |
low_cpu_mem_usage=True | |
) | |
# Ensure padding token is defined | |
if tokenizer.pad_token is None: | |
tokenizer.pad_token = tokenizer.eos_token | |
return model, tokenizer | |
def format_conversation(conversation_history): | |
"""Format the conversation history into a single string.""" | |
formatted = "" | |
for turn in conversation_history: | |
formatted += f"Human: {turn['user']}\nAssistant: {turn['assistant']}\n" | |
return formatted.strip() | |
def generate_response(model, tokenizer, prompt, conversation_history): | |
"""Generate model response""" | |
# Format the entire conversation context | |
context = format_conversation(conversation_history[:-1]) | |
if context: | |
full_prompt = f"{context}\nHuman: {prompt}" | |
else: | |
full_prompt = f"Human: {prompt}" | |
# Tokenize input | |
inputs = tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True) | |
# Move inputs to the same device as the model | |
device = next(model.parameters()).device | |
inputs = {k: v.to(device) for k, v in inputs.items()} | |
# Calculate max new tokens | |
input_length = inputs["input_ids"].shape[1] | |
max_model_length = 1024 # Reduced context window for memory efficiency | |
max_new_tokens = min(150, max_model_length - input_length) | |
try: | |
# Generate response with lower temperature for faster generation | |
outputs = model.generate( | |
inputs["input_ids"], | |
attention_mask=inputs["attention_mask"], | |
max_new_tokens=max_new_tokens, | |
temperature=0.5, # Lower temperature for faster, more focused responses | |
top_p=0.9, | |
pad_token_id=tokenizer.pad_token_id, | |
do_sample=True, | |
min_length=10, # Reduced minimum length | |
no_repeat_ngram_size=3 | |
) | |
# Decode response | |
response = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
response_parts = response.split("Human: ") | |
model_response = response_parts[-1].split("Assistant: ")[-1].strip() | |
return model_response | |
except RuntimeError as e: | |
if "out of memory" in str(e): | |
torch.cuda.empty_cache() | |
return "I apologize, but I ran out of memory. Please try a shorter message or clear the chat history." | |
else: | |
return f"An error occurred: {str(e)}" | |
def main(): | |
st.set_page_config( | |
page_title="LLM Chat Interface", | |
page_icon="π€", | |
layout="wide" | |
) | |
# Add CSS to make the chat interface more compact | |
st.markdown(""" | |
<style> | |
.stChat { | |
padding-top: 0rem; | |
} | |
.stChatMessage { | |
padding: 0.5rem; | |
} | |
</style> | |
""", unsafe_allow_html=True) | |
st.title("Welcome to LowCode No Code Demo") | |
# Initialize session state for chat history | |
if "chat_history" not in st.session_state: | |
st.session_state.chat_history = [] | |
# Initialize model (only once) | |
if "model" not in st.session_state: | |
with st.spinner("Loading the model... This might take a minute..."): | |
try: | |
model, tokenizer = initialize_model() | |
st.session_state.model = model | |
st.session_state.tokenizer = tokenizer | |
st.success("Model loaded successfully!") | |
except Exception as e: | |
st.error(f"Error loading model: {str(e)}") | |
return | |
# Display chat messages | |
for message in st.session_state.chat_history: | |
with st.chat_message("user"): | |
st.write(message["user"]) | |
with st.chat_message("assistant"): | |
st.write(message["assistant"]) | |
# Chat input | |
if prompt := st.chat_input("What would you like to know?"): | |
# Display user message | |
with st.chat_message("user"): | |
st.write(prompt) | |
# Generate and display assistant response | |
with st.chat_message("assistant"): | |
with st.spinner("Thinking..."): | |
current_turn = {"user": prompt, "assistant": ""} | |
st.session_state.chat_history.append(current_turn) | |
response = generate_response( | |
st.session_state.model, | |
st.session_state.tokenizer, | |
prompt, | |
st.session_state.chat_history | |
) | |
st.write(response) | |
st.session_state.chat_history[-1]["assistant"] = response | |
# Manage context window | |
if len(st.session_state.chat_history) > 5: | |
st.session_state.chat_history = st.session_state.chat_history[-5:] | |
# Sidebar controls | |
with st.sidebar: | |
st.title("Controls") | |
if st.button("Clear Chat"): | |
st.session_state.chat_history = [] | |
st.rerun() | |
st.markdown("---") | |
st.markdown(""" | |
### Model Info | |
- Using TinyLlama 1.1B Chat | |
- Optimized for CPU usage | |
- Context window: 1024 tokens | |
""") | |
if __name__ == "__main__": | |
main() |