# rag-llama-4 / app.py
import gradio as gr
import spaces
import os
import logging
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from huggingface_hub import InferenceClient, get_token
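
# Note (our observation, not from the original): these import paths assume an
# older langchain (pre-0.2). Newer releases moved the loaders, embeddings, and
# Chroma wrappers to langchain_community and the splitter to
# langchain_text_splitters, so pinning langchain<0.2 keeps this file working as-is.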

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set HF_HOME for caching Hugging Face assets in persistent storage
os.environ["HF_HOME"] = "/data/.huggingface"
os.makedirs(os.environ["HF_HOME"], exist_ok=True)

# Define persistent storage directories
DATA_DIR = "/data"  # Root persistent storage directory
DOCS_DIR = os.path.join(DATA_DIR, "documents")  # Subdirectory for uploaded PDFs
CHROMA_DIR = os.path.join(DATA_DIR, "chroma_db")  # Subdirectory for Chroma vector store

# Create directories if they don't exist
os.makedirs(DOCS_DIR, exist_ok=True)
os.makedirs(CHROMA_DIR, exist_ok=True)
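
# Note (assumption about the deployment, not stated in the original): the /data
# paths above require the Space to have persistent storage enabled; without it,
# /data is ephemeral and the Chroma index and saved PDFs are lost on restart.
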
# Initialize Cerebras InferenceClient
try:
    token = get_token()
    if not token:
        logger.error("HF_TOKEN is not set in Space secrets")
        client = None
    else:
        client = InferenceClient(
            model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
            provider="cerebras",
            token=token,
        )
        logger.info("InferenceClient initialized successfully")
except Exception as e:
    logger.error(f"Failed to initialize InferenceClient: {str(e)}")
    client = None
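
# Optional connectivity check (a hedged sketch; this helper is ours and is not
# called anywhere by the app): a one-token chat completion confirms the token
# and Cerebras provider routing work before any document is uploaded.
def _smoke_test_client():
    if client is None:
        return "client not initialized"
    resp = client.chat_completion(
        messages=[{"role": "user", "content": "ping"}],
        max_tokens=1,
    )
    return resp.choices[0].message.content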

# Global variables for vector store
vectorstore = None
retriever = None

@spaces.GPU(duration=180)  # Use ZeroGPU (H200) for embedding generation, 180s timeout
def initialize_rag(file):
    global vectorstore, retriever
    try:
        # Validate file first, before touching file.name below
        if not file or not getattr(file, "name", None):
            logger.error("No file provided or invalid file name")
            return "Error: No file provided or invalid file name"

        # Debug file object properties
        logger.info(f"File object: {type(file)}, Attributes: {dir(file)}")
        logger.info(f"File name: {file.name}")
        # Verify temporary file exists and is accessible
        if not os.path.exists(file.name):
            logger.error(f"Temporary file {file.name} does not exist")
            return f"Error: Temporary file {file.name} does not exist"

        # Check temporary file size
        file_size = os.path.getsize(file.name)
        logger.info(f"Temporary file size: {file_size} bytes")
        if file_size == 0:
            logger.error("Uploaded file is empty")
            return "Error: Uploaded file is empty"

        # Save uploaded file to persistent storage
        file_name = os.path.basename(file.name)
        file_path = os.path.join(DOCS_DIR, file_name)

        # Check if the file already exists and whether it is non-empty
        should_save = True
        if os.path.exists(file_path):
            existing_size = os.path.getsize(file_path)
            logger.info(f"Existing file {file_name} size: {existing_size} bytes")
            if existing_size == 0:
                logger.warning(f"Existing file {file_name} is empty, will overwrite")
            else:
                logger.info(f"File {file_name} already exists and is not empty, skipping save")
                should_save = False

        if should_save:
            try:
                with open(file.name, "rb") as src_file:
                    file_content = src_file.read()
                logger.info(f"Read {len(file_content)} bytes from temporary file")
                if not file_content:
                    logger.error("File content is empty after reading")
                    return "Error: File content is empty after reading"
                with open(file_path, "wb") as dst_file:
                    dst_file.write(file_content)
                    dst_file.flush()  # Ensure write completes
                # Verify written file
                written_size = os.path.getsize(file_path)
                logger.info(f"Saved {file_name} to {file_path}, size: {written_size} bytes")
                if written_size == 0:
                    logger.error(f"Failed to write {file_name}, file is empty")
                    return f"Error: Failed to write {file_name}, file is empty"
            except PermissionError as e:
                logger.error(f"Permission error writing to {file_path}: {str(e)}")
                return f"Error: Permission denied writing to {file_path}"
            except Exception as e:
                logger.error(f"Error writing file to {file_path}: {str(e)}")
                return f"Error writing file: {str(e)}"

        # Load and split document
        try:
            loader = PyPDFLoader(file_path)
            documents = loader.load()
            if not documents:
                logger.error("No content loaded from PDF")
                return "Error: No content loaded from PDF"
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
            texts = text_splitter.split_documents(documents)
        except Exception as e:
            logger.error(f"Error loading PDF: {str(e)}")
            return f"Error loading PDF: {str(e)}"

        # Create or update embeddings and vector store
        try:
            logger.info("Initializing HuggingFaceEmbeddings")
            embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
            logger.info("Creating Chroma vector store")
            vectorstore = Chroma.from_documents(
                texts, embeddings, persist_directory=CHROMA_DIR
            )
            vectorstore.persist()  # Save to persistent storage
            retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
            logger.info(f"Vector store created and persisted to {CHROMA_DIR}")
            return f"Document '{file_name}' processed and saved to {DOCS_DIR}!"
        except Exception as e:
            logger.error(f"Error in embeddings or Chroma: {str(e)}")
            return f"Error processing embeddings: {str(e)}"
    except Exception as e:
        logger.error(f"Error processing document: {str(e)}")
        return f"Error processing document: {str(e)}"

def query_documents(query, history, system_prompt, max_tokens, temperature):
    global retriever, client
    try:
        if client is None:
            logger.error("InferenceClient not initialized")
            return history, "Error: InferenceClient not initialized. Check HF_TOKEN."
        if retriever is None:
            logger.error("No documents loaded")
            return history, "Error: No documents loaded. Please upload a document first."

        # Ensure history is a list of [user, assistant] lists
        logger.info(f"History before processing: {history}")
        if not isinstance(history, list):
            logger.warning("History is not a list, resetting")
            history = []
        history = [[str(item[0]), str(item[1])] for item in history if isinstance(item, (list, tuple)) and len(item) == 2]

        # Retrieve relevant documents
        docs = retriever.get_relevant_documents(query)
        context = "\n".join([doc.page_content for doc in docs])

        # Call Cerebras inference
        logger.info("Calling Cerebras inference")
        response = client.chat_completion(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": f"Context: {context}\n\nQuery: {query}"}
            ],
            max_tokens=int(max_tokens),
            temperature=float(temperature),
            stream=False
        )
        answer = response.choices[0].message.content
        logger.info("Inference successful")

        # Update chat history with list format
        history.append([query, answer])
        logger.info(f"History after append: {history}")
        return history, ""  # second output is the status textbox; empty on success
    except Exception as e:
        logger.error(f"Error querying documents: {str(e)}")
        return history, f"Error querying documents: {str(e)}"

# Load existing vector store on startup
try:
    # os.makedirs above guarantees CHROMA_DIR exists, so also require it to be
    # non-empty before loading; otherwise Chroma would open an empty store
    if os.path.exists(CHROMA_DIR) and os.listdir(CHROMA_DIR):
        logger.info("Loading existing vector store")
        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        vectorstore = Chroma(persist_directory=CHROMA_DIR, embedding_function=embeddings)
        retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
        logger.info(f"Loaded vector store from {CHROMA_DIR}")
except Exception as e:
    logger.error(f"Error loading vector store: {str(e)}")

with gr.Blocks() as demo:
    gr.Markdown("# RAG chatbot w/ persistent storage (works best with CPU Upgrade)")

    # File upload
    file_input = gr.File(label="Upload Document (PDF)", file_types=[".pdf"])
    file_output = gr.Textbox(label="Upload Status")
    file_input.upload(initialize_rag, file_input, file_output)

    # Chat interface
    chatbot = gr.Chatbot(label="Conversation")

    # Query and parameters
    with gr.Row():
        query_input = gr.Textbox(label="Query", placeholder="Ask about the document...")
        system_prompt = gr.Textbox(
            label="System Prompt",
            value="You are a helpful assistant answering questions based on the provided document context."
        )
        max_tokens = gr.Slider(label="Max Tokens", minimum=50, maximum=2000, value=500, step=50)
        temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.7, step=0.1)
    # Submit button and status box; the click outputs must reference the
    # components defined above, not fresh gr.Chatbot()/gr.Textbox() instances
    # that never appear in the layout
    submit_btn = gr.Button("Send")
    query_status = gr.Textbox(label="Query Status")
    submit_btn.click(
        query_documents,
        inputs=[query_input, chatbot, system_prompt, max_tokens, temperature],
        outputs=[chatbot, query_status],
    )

if __name__ == "__main__":
    demo.launch()