import logging
import os
import shutil

import gradio as gr
import spaces
from huggingface_hub import InferenceClient, get_token
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)  # module-level logger used throughout this file

# Set HF_HOME for caching Hugging Face assets in persistent storage
# (assumes this Space has a persistent /data volume mounted — TODO confirm)
os.environ["HF_HOME"] = "/data/.huggingface"
os.makedirs(os.environ["HF_HOME"], exist_ok=True)

# Define persistent storage directories
DATA_DIR = "/data"  # Root persistent storage directory
DOCS_DIR = os.path.join(DATA_DIR, "documents")  # Subdirectory for uploaded PDFs
CHROMA_DIR = os.path.join(DATA_DIR, "chroma_db")  # Subdirectory for Chroma vector store

# Create directories if they don't exist
os.makedirs(DOCS_DIR, exist_ok=True)
os.makedirs(CHROMA_DIR, exist_ok=True)

# Initialize Cerebras InferenceClient
# On failure the app still starts; query_documents() checks `client is None`
# and reports the error to the user instead of crashing at import time.
try:
    client = InferenceClient(
        model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
        provider="cerebras",
        token=get_token()  # PRO account token with $2 monthly credits
    )
except Exception as e:
    logger.error(f"Failed to initialize InferenceClient: {str(e)}")
    client = None

# Global variables for vector store
# Both are (re)assigned by initialize_rag() and by the startup loader below;
# query_documents() reads them.
vectorstore = None
retriever = None

@spaces.GPU(duration=120)  # Use ZeroGPU (H200) for embedding generation, 120s timeout
def initialize_rag(file):
    """Persist an uploaded PDF, embed it, and (re)build the Chroma retriever.

    Args:
        file: Gradio upload object; only its ``.name`` attribute (the path of
            the temp file on disk) is used.

    Returns:
        A human-readable status string for the upload-status textbox
        (success message, or an error description on failure).
    """
    global vectorstore, retriever
    try:
        # Copy the uploaded temp file into persistent storage by path.
        # shutil.copy works for both the legacy tempfile wrapper and the
        # path-like objects newer Gradio versions pass; file.read() does not
        # reliably exist (and returned text, not bytes, on some versions).
        file_name = os.path.basename(file.name)
        file_path = os.path.join(DOCS_DIR, file_name)
        if os.path.exists(file_path):
            logger.info(f"File {file_name} already exists in {DOCS_DIR}, skipping save.")
        else:
            shutil.copy(file.name, file_path)
            logger.info(f"Saved {file_name} to {file_path}")

        # Load the PDF and split it into overlapping chunks for retrieval.
        loader = PyPDFLoader(file_path)
        documents = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        texts = text_splitter.split_documents(documents)

        # Embed the chunks and (re)create the persisted vector store.
        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        vectorstore = Chroma.from_documents(
            texts, embeddings, persist_directory=CHROMA_DIR
        )
        vectorstore.persist()  # Save to persistent storage
        retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
        return f"Document '{file_name}' processed and saved to {DOCS_DIR}!"
    except Exception as e:
        # logger.exception also records the traceback, which logger.error did not.
        logger.exception("Error processing document")
        return f"Error processing document: {str(e)}"

def query_documents(query, history, system_prompt, max_tokens, temperature):
    """Answer *query* via RAG: retrieve top-k chunks, then call the Cerebras LLM.

    Args:
        query: The user's question.
        history: Chat history as a list of (user, assistant) tuples.
        system_prompt: System message for the LLM.
        max_tokens: Completion length cap (slider value; coerced to int).
        temperature: Sampling temperature (slider value; coerced to float).

    Returns:
        ``(history, history)`` — the updated chat history, duplicated because
        the click handler wires both outputs to the same ``gr.Chatbot``.
    """
    global retriever, client

    def _fail(message):
        # Surface errors inside the chat itself so both return values are
        # always Chatbot-compatible. The original returned a bare string as
        # the second output, which gr.Chatbot cannot render.
        history.append((query, message))
        return history, history

    if client is None:
        return _fail("Error: InferenceClient not initialized.")
    if retriever is None:
        return _fail("Error: No documents loaded. Please upload a document first.")

    try:
        # Retrieve the k most relevant chunks and concatenate them as context.
        docs = retriever.get_relevant_documents(query)
        context = "\n".join(doc.page_content for doc in docs)

        # Call Cerebras inference with the retrieved context prepended.
        response = client.chat_completion(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": f"Context: {context}\n\nQuery: {query}"},
            ],
            max_tokens=int(max_tokens),
            temperature=float(temperature),
            stream=False,
        )
        answer = response.choices[0].message.content

        # Update chat history
        history.append((query, answer))
        return history, history
    except Exception as e:
        # logger.exception also records the traceback for debugging.
        logger.exception("Error querying documents")
        return _fail(f"Error querying documents: {str(e)}")

# Load an existing vector store on startup so uploads survive restarts.
try:
    # CHROMA_DIR is always created by the makedirs call above, so a bare
    # exists() check was always True; only load when the directory actually
    # contains persisted Chroma data.
    if os.path.isdir(CHROMA_DIR) and os.listdir(CHROMA_DIR):
        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        vectorstore = Chroma(persist_directory=CHROMA_DIR, embedding_function=embeddings)
        retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
        logger.info(f"Loaded existing vector store from {CHROMA_DIR}")
except Exception as e:
    logger.error(f"Error loading vector store: {str(e)}")

# Build the Gradio UI: PDF upload -> indexing status, plus a RAG chat panel.
with gr.Blocks() as demo:
    gr.Markdown("# RAG Chatbot with Persistent Storage and Cerebras Inference")

    # File upload: indexing runs on upload, status lands in the textbox.
    file_input = gr.File(label="Upload Document (PDF)", file_types=[".pdf"])
    file_output = gr.Textbox(label="Upload Status")
    file_input.upload(initialize_rag, file_input, file_output)

    # Chat interface
    chatbot = gr.Chatbot(label="Conversation")

    # Query and generation parameters (passed straight to query_documents)
    with gr.Row():
        query_input = gr.Textbox(label="Query", placeholder="Ask about the document...")
        system_prompt = gr.Textbox(
            label="System Prompt",
            value="You are a helpful assistant answering questions based on the provided document context."
        )
        max_tokens = gr.Slider(label="Max Tokens", minimum=50, maximum=2000, value=500, step=50)
        temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.7, step=0.1)

    # Submit button
    # NOTE(review): both outputs point at the same Chatbot, so the second
    # return value overwrites the first; on query_documents' error paths the
    # second value is a plain string, which gr.Chatbot cannot render —
    # consider a dedicated status component or a single output.
    submit_btn = gr.Button("Send")
    submit_btn.click(
        query_documents,
        inputs=[query_input, chatbot, system_prompt, max_tokens, temperature],
        outputs=[chatbot, chatbot]
    )

if __name__ == "__main__":
    demo.launch()