Spaces: Running on Zero
Update app.py
app.py
CHANGED
@@ -41,6 +41,9 @@ load_dotenv()
 HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
 processor = None
 vision_model = None
+# hold the in-memory vectordb
+CURRENT_VDB = None
+
 # OCR + multimodal image description setup
 ocr_model = ocr_predictor(
     "db_resnet50", "crnn_mobilenet_v3_large", pretrained=True, assume_straight_pages=True
@@ -97,52 +100,34 @@ SHARED_EMB_FN = embedding_functions.SentenceTransformerEmbeddingFunction(
 )
 
 def get_vectordb(text: str, images: list[Image.Image], img_names: list[str]):
-
-
-
-
-
-
-
-
-
-    client = chromadb.PersistentClient(
-        path=PERSIST_DIR,
-        settings=Settings(),
-        tenant=DEFAULT_TENANT,
-        database=DEFAULT_DATABASE
-    )
-
-    # 3) Create / wipe collections
-    for col in ("text_db", "image_db"):
-        if col in [c.name for c in client.list_collections()]:
-            client.delete_collection(col)
-
-    text_col = client.get_or_create_collection(
-        name="text_db",
-        embedding_function=SHARED_EMB_FN
-    )
-    img_col = client.get_or_create_collection(
-        name="image_db",
+    client = chromadb.EphemeralClient()
+    # wipe old
+    for name in ("text_db", "image_db"):
+        if name in [c.name for c in client.list_collections()]:
+            client.delete_collection(name)
+
+    text_col = client.get_or_create_collection("text_db", embedding_function=SHARED_EMB_FN)
+    img_col = client.get_or_create_collection(
+        "image_db",
         embedding_function=SHARED_EMB_FN,
         metadata={"hnsw:space": "cosine"}
     )
 
-    #
+    # add images
     if images:
         descs, metas = [], []
-        for
+        for i, img in enumerate(images):
             try:
                 cap = get_image_description(img)
             except:
                 cap = "⚠️ could not describe image"
-            descs.append(f"{img_names[
+            descs.append(f"{img_names[i]}: {cap}")
             metas.append({"image": image_to_bytes(img)})
         img_col.add(ids=[str(i) for i in range(len(images))],
                     documents=descs,
                     metadatas=metas)
 
-    #
+    # chunk + add text
     splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
     docs = splitter.create_documents([text])
     text_col.add(ids=[str(i) for i in range(len(docs))],
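For reference, a minimal standalone sketch of the in-memory Chroma pattern this hunk moves to: an EphemeralClient, a collection backed by a SentenceTransformer embedding function, and a query_texts lookup. The model name, collection name, and sample documents are placeholders for illustration, not values taken from this commit.

import chromadb
from chromadb.utils import embedding_functions

# Illustrative stand-in for the Space's SHARED_EMB_FN (model name assumed).
emb_fn = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="all-MiniLM-L6-v2"
)

# EphemeralClient keeps the index in process memory only; nothing is written to disk.
client = chromadb.EphemeralClient()
col = client.get_or_create_collection("text_db", embedding_function=emb_fn)

# Index a couple of placeholder chunks, then retrieve the closest one.
col.add(ids=["0", "1"], documents=["first sample chunk", "second sample chunk"])
hits = col.query(query_texts=["sample"], n_results=1, include=["documents"])
print(hits["documents"][0])

Because the index lives only in process memory, the handle has to stay reachable by later callbacks, which is what the module-level CURRENT_VDB introduced above is for.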
@@ -153,6 +138,7 @@ def get_vectordb(text: str, images: list[Image.Image], img_names: list[str]):
 
 
 
+
 # Text extraction
 def result_to_text(result, as_text=False):
     pages = []
@@ -224,15 +210,15 @@ def extract_data_from_pdfs(
     progress(0.6, "Indexing in vector DB…")
     client = get_vectordb(all_text, images, names)
 
-
-
-    session["persist_directory"] = PERSIST_DIR
-    sample_imgs = images[:4] if include_images == "Include Images" else []
+    global CURRENT_VDB
+    CURRENT_VDB = get_vectordb(all_text, images, names)
 
+    session["processed"] = True
+    sample = images[:4] if include_images=="Include Images" else []
     return (
-        session,
+        session,
         all_text[:2000] + "...",
-
+        sample,
         "<h3>Done!</h3>"
     )
 
@@ -250,49 +236,41 @@ def conversation(
     max_tok: int,
     model_id: str
 ):
-
-
+    """
+    Uses the in-memory CURRENT_VDB (set by extract_data_from_pdfs) to answer the user.
+    """
+    global CURRENT_VDB
+
+    # 0) Guard: make sure we've extracted at least once
+    if not session.get("processed") or CURRENT_VDB is None:
         raise gr.Error("Please extract data first")
 
-    # 1)
-
-
-
-
-
+    # 1) Retrieve top-k text chunks
+    text_col = CURRENT_VDB.get_collection("text_db")
+    docs = text_col.query(
+        query_texts=[question],
+        n_results=int(num_ctx),
+        include=["documents"]
+    )["documents"][0]
+
+    # 2) Retrieve top-k images
+    img_col = CURRENT_VDB.get_collection("image_db")
+    img_q = img_col.query(
+        query_texts=[question],
+        n_results=int(img_ctx),
+        include=["metadatas", "documents"]
     )
-
-
-    # 2) Text retrieval
-    text_col = client.get_collection("text_db")
-    docs = text_col.query(query_texts=[question],
-                          n_results=int(num_ctx),
-                          include=["documents"])["documents"][0]
-
-    # 3) Image retrieval
-    img_col = client.get_collection("image_db")
-    img_q = img_col.query(query_texts=[question],
-                          n_results=int(img_ctx),
-                          include=["metadatas","documents"])
     img_descs = img_q["documents"][0] or ["No images found"]
     images = []
     for meta in img_q["metadatas"][0]:
-        b64 = meta.get("image","")
+        b64 = meta.get("image", "")
         try:
             images.append(Image.open(io.BytesIO(base64.b64decode(b64))))
         except:
             pass
     img_desc = "\n".join(img_descs)
 
-    #
-    llm = HuggingFaceEndpoint(
-        repo_id=model_id,
-        task="text-generation",
-        temperature=temp,
-        max_new_tokens=max_tok,
-        huggingfacehub_api_token=HF_TOKEN
-    )
-
+    # 3) Build the prompt
     prompt = PromptTemplate(
         template="""
 Context:
@@ -305,27 +283,43 @@ Question:
 {q}
 
 Answer:
-""",
+""",
+        input_variables=["text", "img_desc", "q"],
+    )
+    user_input = prompt.format(
+        text="\n\n".join(docs),
+        img_desc=img_desc,
+        q=question
     )
-    inp = prompt.format(text="\n\n".join(docs), img_desc=img_desc, q=question)
 
+    # 4) Call the LLM
+    llm = HuggingFaceEndpoint(
+        repo_id=model_id,
+        task="text-generation",
+        temperature=temp,
+        max_new_tokens=max_tok,
+        # the client will pick up HUGGINGFACEHUB_API_TOKEN from env automatically
+    )
     try:
-        answer = llm.invoke(
+        answer = llm.invoke(user_input)
     except HfHubHTTPError as e:
-
+        if e.response.status_code == 404:
+            answer = f"❌ Model `{model_id}` not hosted on HF Inference API."
+        else:
+            answer = f"⚠️ HF API error: {e}"
     except Exception as e:
         answer = f"⚠️ Unexpected error: {e}"
 
+    # 5) Append to chat history and return
     new_history = history + [
-        {"role":"user",
-        {"role":"assistant","content":answer}
+        {"role": "user", "content": question},
+        {"role": "assistant", "content": answer}
     ]
     return new_history, docs, images
 
 
 
 
-
 # ─────────────────────────────────────────────────────────────────────────────
 # Gradio UI
 CSS = """
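For reference, a minimal sketch of the HuggingFaceEndpoint call pattern used in conversation() above, with the same HfHubHTTPError handling. The repo id, prompt, and sampling values are placeholders; the token is read from the HUGGINGFACEHUB_API_TOKEN environment variable.

from huggingface_hub.utils import HfHubHTTPError
from langchain_huggingface import HuggingFaceEndpoint

# Placeholder model id and generation settings; swap in the model chosen in the UI.
llm = HuggingFaceEndpoint(
    repo_id="HuggingFaceH4/zephyr-7b-beta",
    task="text-generation",
    temperature=0.3,
    max_new_tokens=256,
)

prompt = "Context:\n...\n\nQuestion:\n...\n\nAnswer:"
try:
    answer = llm.invoke(prompt)
except HfHubHTTPError as e:
    # A 404 here usually means the repo is not served by the hosted Inference API.
    answer = f"HF API error: {e}"
except Exception as e:
    answer = f"Unexpected error: {e}"
print(answer)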
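The hunks above share one in-memory vector DB between the extraction and chat callbacks through a module-level global plus a session flag. A stripped-down sketch of that wiring, with invented component names and a dummy stand-in for the Chroma client:

import gradio as gr

CURRENT_VDB = None  # module-level handle, mirroring the commit's CURRENT_VDB

def extract(session: dict):
    global CURRENT_VDB
    CURRENT_VDB = {"docs": ["chunk one", "chunk two"]}  # stand-in for the Chroma client
    session["processed"] = True
    return session, "Done!"

def ask(question: str, session: dict):
    # Same guard as conversation(): refuse to answer before extraction has run.
    if not session.get("processed") or CURRENT_VDB is None:
        raise gr.Error("Please extract data first")
    return f"{question!r} answered against {len(CURRENT_VDB['docs'])} chunks"

with gr.Blocks() as demo:
    session = gr.State({})
    status = gr.Markdown()
    question = gr.Textbox(label="Question")
    answer = gr.Textbox(label="Answer")
    gr.Button("Extract").click(extract, inputs=session, outputs=[session, status])
    gr.Button("Ask").click(ask, inputs=[question, session], outputs=answer)

demo.launch()

Note that a module-level global is per-process state, so this pattern assumes the Space serves requests from a single worker process.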