import streamlit as st
import os
from transformers import pipeline
import torch # PyTorch is commonly used by transformers

# --- Set Page Config FIRST --- 
st.set_page_config(layout="wide") # Use wider layout

# --- Configuration --- 
MODEL_NAME = "AdaptLLM/finance-LLM"
# Attempt to get token from secrets, handle case where it might not be set yet
HF_TOKEN = os.environ.get("HF_TOKEN")

# --- Model Loading (Cached by Streamlit for efficiency) --- 
@st.cache_resource # Cache the pipeline object
def load_text_generation_pipeline():
    """Loads the text generation pipeline."""
    if not HF_TOKEN:
        st.warning("HF_TOKEN secret not found. Ensure the model is public or add the token to secrets.")
        # Decide if you want to stop or proceed cautiously
        # st.stop() # Uncomment this line to halt execution if token is strictly required

    try:
        # Determine device: Use GPU (cuda:0) if available, otherwise CPU (-1)
        # Free Spaces typically only have CPU, so device will likely be -1
        device = 0 if torch.cuda.is_available() else -1

        st.info(f"Loading model {MODEL_NAME}... This might take a while on the first run.")
        # Use pipeline for easier text generation
        generator = pipeline(
            "text-generation",
            model=MODEL_NAME,
            tokenizer=MODEL_NAME,
            token=HF_TOKEN,  # use the secret if set (required for private/gated models)
            # float16 halves memory but needs a GPU; fall back to float32 on CPU
            torch_dtype=torch.float16 if device == 0 else torch.float32,
            device=device,
            trust_remote_code=True
        )
        st.success(f"Model {MODEL_NAME} loaded successfully!")
        return generator
    except Exception as e:
        st.error(f"Error loading model pipeline: {e}", icon="πŸ”₯")
        st.error("This could be due to memory limits on the free tier, missing token for a private model, or other issues.")
        st.stop() # Stop the app if the model fails to load

# --- Load the Model Pipeline --- 
generator = load_text_generation_pipeline()
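# If loading fails with out-of-memory errors on the free CPU tier, two common
# workarounds (sketches, not enabled here) are to point MODEL_NAME at a smaller
# checkpoint, or, on a GPU Space with bitsandbytes installed, to load the model
# quantized, e.g.:
#
#   from transformers import BitsAndBytesConfig
#   quant_config = BitsAndBytesConfig(load_in_4bit=True)
#   generator = pipeline("text-generation", model=MODEL_NAME, token=HF_TOKEN,
#                        model_kwargs={"quantization_config": quant_config})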

# --- Streamlit App UI --- 
st.title("πŸ’° FinBuddy Assistant")
st.caption("Your AI-powered financial planning assistant (Text Chat - v1)")

# Initialize chat history in session state if it doesn't exist
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display past chat messages
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"]) # Display content as markdown

# Get user input using chat_input
if prompt := st.chat_input("Ask a question about finance..."):
    # Add user message to session state and display it
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Generate assistant response
    with st.chat_message("assistant"):
        message_placeholder = st.empty() # Create placeholder for streaming/final response
        message_placeholder.markdown("Thinking...⏳") # Initial thinking message

        # --- Prepare prompt for the model --- 
        # Simple approach: just use the latest user prompt.
        # TODO: Improve this later to include conversation history for better context.
        prompt_for_model = prompt
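        # A history-aware variant (sketch, not enabled here): concatenate the last
        # few turns from st.session_state.messages so the model sees prior context.
        # The "Role: text" layout and trailing "Assistant:" are an assumed plain-text
        # convention, not a format this model is guaranteed to expect:
        #
        #   recent = st.session_state.messages[-6:]  # last few turns
        #   prompt_for_model = "\n".join(
        #       f"{m['role'].capitalize()}: {m['content']}" for m in recent
        #   ) + "\nAssistant:"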

        try:
            # Generate response using the pipeline
            outputs = generator(
                prompt_for_model,
                max_new_tokens=512,  # Limit the length of the response
                num_return_sequences=1,
                eos_token_id=generator.tokenizer.eos_token_id,
                pad_token_id=generator.tokenizer.eos_token_id # Helps prevent warnings/issues
            )
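            # Streaming variant (sketch, not enabled here): transformers'
            # TextIteratorStreamer yields text chunks as they are generated, so the
            # placeholder could be updated incrementally instead of waiting for the
            # full output. Roughly (assumes model and inputs share the same device):
            #
            #   from threading import Thread
            #   from transformers import TextIteratorStreamer
            #   streamer = TextIteratorStreamer(generator.tokenizer, skip_prompt=True)
            #   inputs = generator.tokenizer(prompt_for_model, return_tensors="pt")
            #   Thread(target=generator.model.generate,
            #          kwargs={**inputs, "max_new_tokens": 512, "streamer": streamer}).start()
            #   partial = ""
            #   for chunk in streamer:
            #       partial += chunk
            #       message_placeholder.markdown(partial + "▌")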

            if outputs and len(outputs) > 0 and 'generated_text' in outputs[0]:
                # Extract the generated text
                full_response = outputs[0]['generated_text']

                # --- Attempt to clean the response --- 
                # The pipeline often returns the prompt + response. Try to remove the prompt part.
                if full_response.startswith(prompt_for_model):
                    assistant_response = full_response[len(prompt_for_model):].strip()
                    # Sometimes models add their own role prefix
                    if assistant_response.lower().startswith("assistant:"):
                        assistant_response = assistant_response[len("assistant:"):].strip()
                    elif assistant_response.lower().startswith("response:"):
                        assistant_response = assistant_response[len("response:"):].strip()
                else:
                    assistant_response = full_response  # Fallback if prompt isn't found at start

                # Handle cases where the response might be empty after cleaning
                if not assistant_response:
                    assistant_response = "I received your message, but I don't have a further response right now."

            else:
                assistant_response = "Sorry, I couldn't generate a response."

            # Display the final response
            message_placeholder.markdown(assistant_response)
            # Add the final assistant response to session state
            st.session_state.messages.append({"role": "assistant", "content": assistant_response})

        except Exception as e:
            error_message = f"Error during text generation: {e}"
            st.error(error_message, icon="πŸ”₯")
            message_placeholder.markdown("Sorry, an error occurred while generating the response.")
            # Add error indication to history
            st.session_state.messages.append({"role": "assistant", "content": f"[Error: {e}]"})