import gradio as gr
import spaces
import os
import logging
import shutil
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from huggingface_hub import InferenceClient, get_token
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Set HF_HOME for caching Hugging Face assets in persistent storage
os.environ["HF_HOME"] = "/data/.huggingface"
os.makedirs(os.environ["HF_HOME"], exist_ok=True)
# Define persistent storage directories
DATA_DIR = "/data"  # Root persistent storage directory
DOCS_DIR = os.path.join(DATA_DIR, "documents")  # Subdirectory for uploaded PDFs
CHROMA_DIR = os.path.join(DATA_DIR, "chroma_db")  # Subdirectory for Chroma vector store
# Create directories if they don't exist
os.makedirs(DOCS_DIR, exist_ok=True)
os.makedirs(CHROMA_DIR, exist_ok=True)
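# Note: /data is writable only when persistent storage is enabled for the Space;
# without it, these writes will fail and the paths would need to target a temp dir.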
# Initialize Cerebras InferenceClient
try:
    client = InferenceClient(
        model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
        provider="cerebras",
        token=get_token()  # PRO account token with $2 monthly credits
    )
except Exception as e:
    logger.error(f"Failed to initialize InferenceClient: {str(e)}")
    client = None
# Global variables for vector store
vectorstore = None
retriever = None
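# Note: module-level globals are shared across all sessions of the Space; fine for
# a single-user demo, but concurrent users would overwrite each other's retriever.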
# Use ZeroGPU (H200) for embedding generation, 120s timeout
@spaces.GPU(duration=120)
def initialize_rag(file):
    global vectorstore, retriever
    try:
        # Save uploaded file to persistent storage
        file_name = os.path.basename(file.name)
        file_path = os.path.join(DOCS_DIR, file_name)
        if os.path.exists(file_path):
            logger.info(f"File {file_name} already exists in {DOCS_DIR}, skipping save.")
        else:
            # gr.File passes a temp file path, not an open file object, so copy it
            shutil.copyfile(file.name, file_path)
            logger.info(f"Saved {file_name} to {file_path}")
        # Load and split document
        loader = PyPDFLoader(file_path)
        documents = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        texts = text_splitter.split_documents(documents)
        # Create or update embeddings and vector store
        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        vectorstore = Chroma.from_documents(
            texts, embeddings, persist_directory=CHROMA_DIR
        )
        # chromadb >= 0.4 persists automatically when persist_directory is set
        retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
        return f"Document '{file_name}' processed and saved to {DOCS_DIR}!"
    except Exception as e:
        logger.error(f"Error processing document: {str(e)}")
        return f"Error processing document: {str(e)}"
def query_documents(query, history, system_prompt, max_tokens, temperature):
    global retriever, client
    history = history or []
    try:
        if client is None:
            history.append((query, "Error: InferenceClient not initialized."))
            return history
        if retriever is None:
            history.append((query, "Error: No documents loaded. Please upload a document first."))
            return history
        # Retrieve relevant documents (invoke() replaces the deprecated get_relevant_documents())
        docs = retriever.invoke(query)
        context = "\n".join(doc.page_content for doc in docs)
        # Call Cerebras inference
        response = client.chat_completion(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": f"Context: {context}\n\nQuery: {query}"}
            ],
            max_tokens=int(max_tokens),
            temperature=float(temperature),
            stream=False
        )
        answer = response.choices[0].message.content
        # Update chat history
        history.append((query, answer))
        return history
    except Exception as e:
        logger.error(f"Error querying documents: {str(e)}")
        history.append((query, f"Error querying documents: {str(e)}"))
        return history
# Load existing vector store on startup
try:
    # CHROMA_DIR is always created above, so also check that it is non-empty
    if os.path.exists(CHROMA_DIR) and os.listdir(CHROMA_DIR):
        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        vectorstore = Chroma(persist_directory=CHROMA_DIR, embedding_function=embeddings)
        retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
        logger.info(f"Loaded existing vector store from {CHROMA_DIR}")
except Exception as e:
    logger.error(f"Error loading vector store: {str(e)}")
with gr.Blocks() as demo:
    gr.Markdown("# RAG Chatbot with Persistent Storage and Cerebras Inference")
    # File upload
    file_input = gr.File(label="Upload Document (PDF)", file_types=[".pdf"])
    file_output = gr.Textbox(label="Upload Status")
    file_input.upload(initialize_rag, file_input, file_output)
    # Chat interface
    chatbot = gr.Chatbot(label="Conversation")
    # Query and parameters
    with gr.Row():
        query_input = gr.Textbox(label="Query", placeholder="Ask about the document...")
        system_prompt = gr.Textbox(
            label="System Prompt",
            value="You are a helpful assistant answering questions based on the provided document context."
        )
        max_tokens = gr.Slider(label="Max Tokens", minimum=50, maximum=2000, value=500, step=50)
        temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.7, step=0.1)
    # Submit button
    submit_btn = gr.Button("Send")
    submit_btn.click(
        query_documents,
        inputs=[query_input, chatbot, system_prompt, max_tokens, temperature],
        outputs=[chatbot]  # single output; query_documents returns the updated history
    )
if __name__ == "__main__":
    demo.launch()
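# Assumed dependencies for this Space (not pinned in the source): gradio, spaces,
# huggingface_hub, langchain-community, langchain-text-splitters, chromadb, pypdf,
# sentence-transformers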