import os

import streamlit as st
import torch
from dotenv import load_dotenv
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load environment variables from .env so secrets stay out of the source.
load_dotenv()
api_key = os.getenv("api_key")  # NOTE(review): loaded but never used in this file — confirm it is needed.

# App title and description
st.title("I am Your GrowBuddy 🌱")
st.write("Let me help you start gardening. Let's grow together!")


@st.cache_resource(show_spinner="Loading the model...")
def load_model():
    """Load and cache the tokenizer/model pair.

    Decorated with ``st.cache_resource`` because Streamlit re-executes
    the whole script on every user interaction; without caching the
    multi-GB model would be re-instantiated on each rerun.

    Returns:
        tuple: ``(tokenizer, model)`` on success, ``(None, None)`` on
        failure (the error is surfaced in the UI via ``st.error``).
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-2-2b")
        model = AutoModelForCausalLM.from_pretrained("unsloth/gemma-2-2b")
        return tokenizer, model
    except Exception as e:
        st.error(f"Failed to load model: {e}")
        return None, None


# Load model and tokenizer; halt the app cleanly if either failed.
tokenizer, model = load_model()
if not tokenizer or not model:
    st.stop()

# Default to CPU, or use GPU if available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Initialize session state messages (survives reruns within a session).
if "messages" not in st.session_state:
    st.session_state.messages = [
        {"role": "assistant", "content": "Hello there! How can I help you with gardening today?"}
    ]

# Display conversation history
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.write(message["content"])


def generate_response(prompt):
    """Generate a model reply for *prompt* and return it as a string.

    Fixes versus the original version:
      * passes ``attention_mask`` to ``generate()`` — required for
        correct results when the tokenizer pads the input;
      * decodes only the newly generated tokens, so the user's prompt
        is no longer echoed back at the start of every reply;
      * wraps generation in ``torch.no_grad()`` so no unused autograd
        graph is built during inference.

    Returns:
        str: the generated reply, or a fallback message on error.
    """
    try:
        # Tokenize with truncation so overly long questions cannot
        # overflow the model's context window.
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=512,
        ).to(device)

        # Generate output from the model (sampled, mildly creative).
        with torch.no_grad():
            outputs = model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=100,
                temperature=0.7,
                do_sample=True,
            )

        # Decode only the tokens produced *after* the prompt.
        new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True)
    except Exception as e:
        st.error(f"Error during text generation: {e}")
        return "Sorry, I couldn't process your request."
# User input field for gardening questions
user_input = st.chat_input("Type your gardening question here:")

if user_input:
    # Echo the question into the visible transcript.
    with st.chat_message("user"):
        st.write(user_input)

    # Generate and show the assistant's answer with a progress spinner.
    with st.chat_message("assistant"):
        with st.spinner("Generating your answer..."):
            response = generate_response(user_input)
            st.write(response)

    # Persist both turns so they survive Streamlit's script reruns.
    st.session_state.messages.extend([
        {"role": "user", "content": user_input},
        {"role": "assistant", "content": response},
    ])