SlouchyBuffalo commited on
Commit
91cb78a
·
verified ·
1 Parent(s): 51798a6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +146 -0
app.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import logging
import os
import shutil

import gradio as gr
import spaces
from huggingface_hub import InferenceClient, get_token
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
10
+
11
# Set up logging so startup and request-handling problems show up in Space logs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set HF_HOME for caching Hugging Face assets in persistent storage
# (/data survives Space restarts, so downloaded models/embeddings are reused).
os.environ["HF_HOME"] = "/data/.huggingface"
os.makedirs(os.environ["HF_HOME"], exist_ok=True)

# Define persistent storage directories
DATA_DIR = "/data"  # Root persistent storage directory
DOCS_DIR = os.path.join(DATA_DIR, "documents")  # Subdirectory for uploaded PDFs
CHROMA_DIR = os.path.join(DATA_DIR, "chroma_db")  # Subdirectory for Chroma vector store

# Create directories if they don't exist
os.makedirs(DOCS_DIR, exist_ok=True)
os.makedirs(CHROMA_DIR, exist_ok=True)

# Initialize Cerebras InferenceClient.
# On failure, client is left as None; query_documents checks for this before use.
try:
    client = InferenceClient(
        model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
        provider="cerebras",
        token=get_token()  # PRO account token with $2 monthly credits
    )
except Exception as e:
    logger.error(f"Failed to initialize InferenceClient: {str(e)}")
    client = None

# Global variables for vector store.
# Populated either at startup (existing persisted DB) or by initialize_rag().
vectorstore = None
retriever = None
42
+
43
@spaces.GPU(duration=120)  # Use ZeroGPU for embedding generation, 120s timeout
def initialize_rag(file):
    """Ingest an uploaded PDF into the persistent RAG index.

    Saves the upload to persistent storage, splits it into overlapping
    chunks, embeds them, and (re)builds the global Chroma vector store
    and retriever.

    Args:
        file: Gradio upload value — either an object exposing the temp
            path via ``.name`` or a plain filesystem path string,
            depending on the Gradio version/config.

    Returns:
        A human-readable status message for the upload textbox (also
        used to surface errors, since this feeds a gr.Textbox).
    """
    global vectorstore, retriever
    try:
        # Normalize the upload to a source path: newer Gradio passes a
        # path string, older versions a tempfile wrapper with .name.
        src_path = getattr(file, "name", file)
        file_name = os.path.basename(src_path)
        file_path = os.path.join(DOCS_DIR, file_name)
        if os.path.exists(file_path):
            logger.info(f"File {file_name} already exists in {DOCS_DIR}, skipping save.")
        else:
            # Copy from the temp upload location rather than calling
            # file.read(): the wrapper may already be closed or opened in
            # text mode, so .read() can fail or yield str for a "wb" handle.
            shutil.copyfile(src_path, file_path)
            logger.info(f"Saved {file_name} to {file_path}")

        # Load and split document into overlapping chunks for retrieval.
        loader = PyPDFLoader(file_path)
        documents = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        texts = text_splitter.split_documents(documents)

        # Create or update embeddings and vector store in persistent storage.
        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        vectorstore = Chroma.from_documents(
            texts, embeddings, persist_directory=CHROMA_DIR
        )
        vectorstore.persist()  # Flush to disk so the index survives restarts
        retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
        return f"Document '{file_name}' processed and saved to {DOCS_DIR}!"
    except Exception as e:
        logger.error(f"Error processing document: {str(e)}")
        return f"Error processing document: {str(e)}"
74
+
75
def query_documents(query, history, system_prompt, max_tokens, temperature):
    """Answer *query* from retrieved document context via Cerebras inference.

    Args:
        query: The user's question.
        history: Chatbot history as a list of (user, assistant) tuples.
        system_prompt: System message prepended to the chat request.
        max_tokens: Completion length cap (slider value, coerced to int).
        temperature: Sampling temperature (slider value, coerced to float).

    Returns:
        (history, history) — the updated history twice, matching the two
        chatbot outputs wired up in the UI. Errors are appended to the
        history as a chat turn; returning a bare string here would be
        rendered into the gr.Chatbot component, which expects message
        pairs, and break the UI.
    """
    global retriever, client
    try:
        # Fail fast with a visible chat message instead of crashing the UI.
        if client is None:
            history.append((query, "Error: InferenceClient not initialized."))
            return history, history
        if retriever is None:
            history.append((query, "Error: No documents loaded. Please upload a document first."))
            return history, history

        # Retrieve relevant documents and flatten them into one context block.
        docs = retriever.get_relevant_documents(query)
        context = "\n".join([doc.page_content for doc in docs])

        # Call Cerebras inference with the retrieved context inlined.
        response = client.chat_completion(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": f"Context: {context}\n\nQuery: {query}"}
            ],
            max_tokens=int(max_tokens),
            temperature=float(temperature),
            stream=False
        )
        answer = response.choices[0].message.content

        # Update chat history
        history.append((query, answer))
        return history, history
    except Exception as e:
        logger.error(f"Error querying documents: {str(e)}")
        history.append((query, f"Error querying documents: {str(e)}"))
        return history, history
105
+
106
# Load existing vector store on startup so previously ingested documents
# are queryable without re-uploading.
try:
    # CHROMA_DIR is created unconditionally above, so a bare existence
    # check is always true; only load when the directory actually holds
    # persisted DB files, otherwise skip and wait for the first upload.
    if os.path.exists(CHROMA_DIR) and os.listdir(CHROMA_DIR):
        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        vectorstore = Chroma(persist_directory=CHROMA_DIR, embedding_function=embeddings)
        retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
        logger.info(f"Loaded existing vector store from {CHROMA_DIR}")
except Exception as e:
    logger.error(f"Error loading vector store: {str(e)}")
115
+
116
# --- Gradio UI (construction order defines the page layout) ---------------
with gr.Blocks() as demo:
    gr.Markdown("# RAG Chatbot with Persistent Storage and Cerebras Inference")

    # File upload: ingestion status from initialize_rag is echoed here.
    file_input = gr.File(label="Upload Document (PDF)", file_types=[".pdf"])
    file_output = gr.Textbox(label="Upload Status")
    file_input.upload(initialize_rag, file_input, file_output)

    # Chat interface
    chatbot = gr.Chatbot(label="Conversation")

    # Query and generation parameters (passed straight to the inference call)
    with gr.Row():
        query_input = gr.Textbox(label="Query", placeholder="Ask about the document...")
        system_prompt = gr.Textbox(
            label="System Prompt",
            value="You are a helpful assistant answering questions based on the provided document context."
        )
        max_tokens = gr.Slider(label="Max Tokens", minimum=50, maximum=2000, value=500, step=50)
        temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.7, step=0.1)

    # Submit button.
    # NOTE(review): outputs lists the chatbot twice to match the handler's
    # two return values; both slots receive the history value.
    submit_btn = gr.Button("Send")
    submit_btn.click(
        query_documents,
        inputs=[query_input, chatbot, system_prompt, max_tokens, temperature],
        outputs=[chatbot, chatbot]
    )

if __name__ == "__main__":
    demo.launch()