import streamlit as st
import os
from transformers import pipeline
import torch  # PyTorch is commonly used by transformers

# --- Set Page Config FIRST ---
st.set_page_config(layout="wide")  # Use wider layout

# --- Configuration ---
MODEL_NAME = "AdaptLLM/finance-LLM"
# Attempt to get token from secrets, handle case where it might not be set yet
HF_TOKEN = os.environ.get("HF_TOKEN")
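# Note: Hugging Face Spaces exposes repository secrets to the app as environment
# variables, which is why os.environ is used to read HF_TOKEN here.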

# --- Model Loading (Cached by Streamlit for efficiency) ---
@st.cache_resource  # Cache the pipeline object so the model is loaded only once
def load_text_generation_pipeline():
    """Loads the text generation pipeline."""
    if not HF_TOKEN:
        st.warning("HF_TOKEN secret not found. Ensure the model is public or add the token to secrets.")
        # Decide if you want to stop or proceed cautiously
        # st.stop()  # Uncomment this line to halt execution if token is strictly required
    try:
        # Determine device: Use GPU (cuda:0) if available, otherwise CPU (-1)
        # Free Spaces typically only have CPU, so device will likely be -1
        device = 0 if torch.cuda.is_available() else -1
        st.info(f"Loading model {MODEL_NAME}... This might take a while on the first run.")
        # Use pipeline for easier text generation
        generator = pipeline(
            "text-generation",
            model=MODEL_NAME,
            tokenizer=MODEL_NAME,
            # float16 only pays off on GPU; many CPU ops don't support it, so fall back to float32
            torch_dtype=torch.float16 if device == 0 else torch.float32,
            device=device,
            token=HF_TOKEN,  # Forward the token (None is fine) in case the model is gated/private
            trust_remote_code=True
        )
        st.success(f"Model {MODEL_NAME} loaded successfully!")
        return generator
    except Exception as e:
        st.error(f"Error loading model pipeline: {e}", icon="🔥")
        st.error("This could be due to memory limits on the free tier, missing token for a private model, or other issues.")
        st.stop()  # Stop the app if the model fails to load

# --- Load the Model Pipeline ---
generator = load_text_generation_pipeline()

# --- Streamlit App UI ---
st.title("💰 FinBuddy Assistant")
st.caption("Your AI-powered financial planning assistant (Text Chat - v1)")

# Initialize chat history in session state if it doesn't exist
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display past chat messages
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])  # Display content as markdown

# Get user input using chat_input
if prompt := st.chat_input("Ask a question about finance..."):
    # Add user message to session state and display it
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Generate assistant response
    with st.chat_message("assistant"):
        message_placeholder = st.empty()  # Create placeholder for streaming/final response
        message_placeholder.markdown("Thinking...⏳")  # Initial thinking message

        # --- Prepare prompt for the model ---
        # Simple approach: just use the latest user prompt.
        # TODO: Improve this later to include conversation history for better context.
        prompt_for_model = prompt
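
        # Hypothetical sketch (not part of the current app) of one way to fold recent
        # history into the prompt, per the TODO above. The exact prompt format that
        # AdaptLLM/finance-LLM expects is an assumption, so this stays commented out.
        # history = "\n".join(f"{m['role']}: {m['content']}" for m in st.session_state.messages[-6:])
        # prompt_for_model = f"{history}\nassistant:"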

        try:
            # Generate response using the pipeline
            outputs = generator(
                prompt_for_model,
                max_new_tokens=512,  # Limit the length of the response
                num_return_sequences=1,
                eos_token_id=generator.tokenizer.eos_token_id,
                pad_token_id=generator.tokenizer.eos_token_id  # Helps prevent warnings/issues
            )
            if outputs and len(outputs) > 0 and 'generated_text' in outputs[0]:
                # Extract the generated text
                full_response = outputs[0]['generated_text']

                # --- Attempt to clean the response ---
                # The pipeline often returns the prompt + response. Try to remove the prompt part.
                if full_response.startswith(prompt_for_model):
                    assistant_response = full_response[len(prompt_for_model):].strip()
                    # Sometimes models add their own role prefix
                    if assistant_response.lower().startswith("assistant:"):
                        assistant_response = assistant_response[len("assistant:"):].strip()
                    elif assistant_response.lower().startswith("response:"):
                        assistant_response = assistant_response[len("response:"):].strip()
                else:
                    assistant_response = full_response  # Fallback if prompt isn't found at start

                # Handle cases where the response might be empty after cleaning
                if not assistant_response:
                    assistant_response = "I received your message, but I don't have a further response right now."
            else:
                assistant_response = "Sorry, I couldn't generate a response."

            # Display the final response
            message_placeholder.markdown(assistant_response)
            # Add the final assistant response to session state
            st.session_state.messages.append({"role": "assistant", "content": assistant_response})
        except Exception as e:
            error_message = f"Error during text generation: {e}"
            st.error(error_message, icon="🔥")
            message_placeholder.markdown("Sorry, an error occurred while generating the response.")
            # Add error indication to history
            st.session_state.messages.append({"role": "assistant", "content": f"[Error: {e}]"})