import os
import io
import base64
import gc
import torch  # used below for dtype selection and CUDA cache management
from huggingface_hub.utils import HfHubHTTPError
from langchain_core.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEndpoint
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
from doctr.io import DocumentFile
from doctr.models import ocr_predictor
from pypdf import PdfReader
from PIL import Image
import chromadb
from chromadb.utils import embedding_functions
from langchain.text_splitter import RecursiveCharacterTextSplitter
import gradio as gr
# ─────────────────────────────────────────────────────────────────────────────
# Globals
CURRENT_VDB = None
processor = None
vision_model = None
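# HF Inference API token; assumed here to be supplied via the HF_TOKEN environment variable
HF_TOKEN = os.getenv("HF_TOKEN")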
# OCR & V+L defaults
OCR_CHOICES = {
"db_resnet50 + crnn_mobilenet_v3_large": ("db_resnet50", "crnn_mobilenet_v3_large"),
"db_resnet50 + crnn_resnet31": ("db_resnet50", "crnn_resnet31"),
}
SHARED_EMB_FN = embedding_functions.SentenceTransformerEmbeddingFunction(
model_name="all-MiniLM-L6-v2"
)
def get_image_description(img: Image.Image) -> str:
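    """Caption a PIL image with the default LLaVA-NeXT model, loading it lazily on first use."""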
global processor, vision_model
if processor is None or vision_model is None:
# use the same default V+L model everywhere
vlm = "llava-hf/llava-v1.6-mistral-7b-hf"
processor = LlavaNextProcessor.from_pretrained(vlm)
vision_model = LlavaNextForConditionalGeneration.from_pretrained(
vlm, torch_dtype=torch.float16, low_cpu_mem_usage=True
).to("cuda")
torch.cuda.empty_cache(); gc.collect()
prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
    inputs = processor(images=img, text=prompt, return_tensors="pt").to("cuda")
out = vision_model.generate(**inputs, max_new_tokens=100)
return processor.decode(out[0], skip_special_tokens=True)
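# `extract_images` and `image_to_bytes` are referenced below but not shown in this
# listing; the two definitions here are minimal sketches of what the callers expect:
# extract_images() pulls embedded images out of PDFs with pypdf, and image_to_bytes()
# base64-encodes a PIL image so it can round-trip through Chroma metadata and be
# decoded again in conversation().
def extract_images(paths):
    """Return the PIL images embedded in the given PDF files."""
    imgs = []
    for path in paths:
        for page in PdfReader(path).pages:
            for embedded in page.images:  # pypdf exposes embedded images per page
                try:
                    imgs.append(Image.open(io.BytesIO(embedded.data)).convert("RGB"))
                except Exception:
                    continue  # skip images PIL cannot decode
    return imgs
def image_to_bytes(img: Image.Image) -> str:
    """Base64-encode a PIL image as PNG for storage in Chroma metadata."""
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode("utf-8")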
def extract_data_from_pdfs(
docs, session, include_images, do_ocr, ocr_choice, vlm_choice, progress=gr.Progress()
):
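    """Extract text (optionally via OCR) and images from the uploaded PDFs,
    caption the images, and index everything into an in-memory Chroma DB."""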
if not docs:
raise gr.Error("No documents to process")
# 1) Optional OCR
local_ocr = None
if do_ocr == "Get Text With OCR":
db_m, crnn_m = OCR_CHOICES[ocr_choice]
local_ocr = ocr_predictor(db_m, crnn_m, pretrained=True, assume_straight_pages=True)
# 2) Prepare V+L
proc = LlavaNextProcessor.from_pretrained(vlm_choice)
vis = LlavaNextForConditionalGeneration.from_pretrained(
vlm_choice, torch_dtype=torch.float16, low_cpu_mem_usage=True
).to("cuda")
# 3) Patch get_image_description to use this choice
def describe(img: Image.Image) -> str:
torch.cuda.empty_cache(); gc.collect()
prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
        inp = proc(images=img, text=prompt, return_tensors="pt").to("cuda")
out = vis.generate(**inp, max_new_tokens=100)
return proc.decode(out[0], skip_special_tokens=True)
global get_image_description, CURRENT_VDB
get_image_description = describe
# 4) Pull text + images
progress(0.2, "Extracting text and images…")
full_text, images, names = "", [], []
for p in docs:
if local_ocr:
pdf = DocumentFile.from_pdf(p)
res = local_ocr(pdf)
            full_text += " ".join(
                w.value for page in res.pages for blk in page.blocks for line in blk.lines for w in line.words
            ) + "\n\n"
else:
            full_text += "\n".join((page.extract_text() or "") for page in PdfReader(p).pages) + "\n\n"
if include_images == "Include Images":
imgs = extract_images([p])
images.extend(imgs)
names.extend([os.path.basename(p)] * len(imgs))
# 5) Build in-memory Chroma
progress(0.6, "Indexing in vector DB…")
client = chromadb.EphemeralClient()
for col in ("text_db", "image_db"):
if col in [c.name for c in client.list_collections()]:
client.delete_collection(col)
text_col = client.get_or_create_collection("text_db", embedding_function=SHARED_EMB_FN)
img_col = client.get_or_create_collection("image_db", embedding_function=SHARED_EMB_FN,
metadata={"hnsw:space":"cosine"})
if images:
descs, metas = [], []
for i, im in enumerate(images):
cap = get_image_description(im)
descs.append(f"{names[i]}: {cap}")
metas.append({"image": image_to_bytes(im)})
img_col.add(ids=[str(i) for i in range(len(images))],
documents=descs, metadatas=metas)
splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
docs_ = splitter.create_documents([full_text])
text_col.add(ids=[str(i) for i in range(len(docs_))],
documents=[d.page_content for d in docs_])
CURRENT_VDB = client
session["processed"] = True
sample = images[:4] if include_images=="Include Images" else []
return session, full_text[:2000]+"...", sample, "<h3>Done!</h3>"
def conversation(session, question, num_ctx, img_ctx, history, temp, max_tok, model_id):
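    """Answer a question by retrieving text and image context from the in-memory
    Chroma DB and sending a formatted prompt to the selected HF Inference endpoint."""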
global CURRENT_VDB
if not session.get("processed") or CURRENT_VDB is None:
raise gr.Error("Please extract data first")
# a) text retrieval
docs = CURRENT_VDB.get_collection("text_db")\
.query(query_texts=[question], n_results=int(num_ctx), include=["documents"])["documents"][0]
# b) image retrieval
img_q = CURRENT_VDB.get_collection("image_db")\
.query(query_texts=[question], n_results=int(img_ctx),
include=["metadatas","documents"])
img_descs = img_q["documents"][0] or ["No images found"]
images = []
    for m in img_q["metadatas"][0]:
        b = m.get("image", "")
        try:
            images.append(Image.open(io.BytesIO(base64.b64decode(b))))
        except Exception:
            continue  # skip metadata entries whose image payload cannot be decoded
img_desc = "\n".join(img_descs)
# c) prompt & LLM
prompt = PromptTemplate(
template="""
Context:
{text}
Included Images:
{img_desc}
Question:
{q}
Answer:
""", input_variables=["text","img_desc","q"])
inp = prompt.format(text="\n\n".join(docs), img_desc=img_desc, q=question)
llm = HuggingFaceEndpoint(
repo_id=model_id, task="text-generation",
temperature=temp, max_new_tokens=max_tok,
huggingfacehub_api_token=HF_TOKEN
)
    try:
        ans = llm.invoke(inp)
    except HfHubHTTPError as e:
        if e.response is not None and e.response.status_code == 404:
            ans = f"❌ Model `{model_id}` is not hosted on the HF Inference API."
        else:
            ans = f"⚠️ HF API error: {e}"
    except Exception as e:
        ans = f"⚠️ Unexpected error: {e}"
new_hist = history + [{"role":"user","content":question},
{"role":"assistant","content":ans}]
    # gr.Dataframe expects tabular data, so wrap each retrieved chunk in its own row
    return new_hist, [[d] for d in docs], images
# ─────────────────────────────────────────────────────────────────────────────
# Gradio UI
CSS = """
footer {visibility:hidden;}
"""
MODEL_OPTIONS = [
"HuggingFaceH4/zephyr-7b-beta",
"mistralai/Mistral-7B-Instruct-v0.2",
"openchat/openchat-3.5-0106",
"google/gemma-7b-it",
"deepseek-ai/deepseek-llm-7b-chat",
"microsoft/Phi-3-mini-4k-instruct",
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"Qwen/Qwen1.5-7B-Chat",
"tiiuae/falcon-7b-instruct", # Falcon 7B Instruct
"bigscience/bloomz-7b1", # BLOOMZ 7B
"facebook/opt-2.7b",
]
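# Landing-page copy; the original welcome text is not shown in this listing, so this
# is an assumed stand-in.
WELCOME_INTRO = (
    "<h1>📚 Multimodal Chat-PDF</h1>"
    "<p>Upload PDFs, extract their text and images, and chat with their contents.</p>"
)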
with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
session_state = gr.State({})
with gr.Column(visible=True) as welcome_col:
gr.Markdown(f"<div style='text-align:center'>{WELCOME_INTRO}</div>")
start_btn = gr.Button("🚀 Start")
with gr.Column(visible=False) as app_col:
gr.Markdown("## 📚 Multimodal Chat-PDF Playground")
extract_event = None
with gr.Tabs() as tabs:
with gr.TabItem("1. Upload & Extract"):
docs = gr.File(file_count="multiple", file_types=[".pdf"], label="Upload PDFs")
                include_dd = gr.Radio(choices=["Include Images", "Exclude Images"], value="Exclude Images", label="Images")
                ocr_radio = gr.Radio(choices=["Get Text With OCR", "Get Available Text Only"], value="Get Available Text Only", label="OCR")
                ocr_dd = gr.Dropdown(choices=list(OCR_CHOICES.keys()), value=list(OCR_CHOICES.keys())[0], label="OCR Model")
                vlm_dd = gr.Dropdown(choices=["llava-hf/llava-v1.6-mistral-7b-hf", "llava-hf/llava-v1.6-vicuna-7b-hf"], value="llava-hf/llava-v1.6-mistral-7b-hf", label="Vision-Language Model")
extract_btn = gr.Button("Extract")
preview_text = gr.Textbox(lines=10, label="Sample Text", interactive=False)
preview_img = gr.Gallery(label="Sample Images", rows=2, value=[])
preview_html = gr.HTML()
extract_event = extract_btn.click(
fn=extract_data_from_pdfs,
inputs=[docs, session_state, include_dd, ocr_radio, ocr_dd, vlm_dd],
outputs=[session_state, preview_text, preview_img, preview_html]
)
with gr.TabItem("2. Chat", visible=False) as chat_tab:
with gr.Row():
with gr.Column(scale=3):
chat = gr.Chatbot(type="messages", label="Chat")
msg = gr.Textbox(placeholder="Ask about your PDF...", label="Your question")
send = gr.Button("Send")
with gr.Column(scale=1):
                        model_dd = gr.Dropdown(choices=MODEL_OPTIONS, value=MODEL_OPTIONS[0], label="Choose Chat Model")
num_ctx = gr.Slider(1,20, value=3, label="Text Contexts")
img_ctx = gr.Slider(1,10, value=2, label="Image Contexts")
temp = gr.Slider(0.1,1.0, step=0.1, value=0.4, label="Temperature")
max_tok = gr.Slider(10,1000, step=10, value=200, label="Max Tokens")
                ref_docs = gr.Dataframe(headers=["Retrieved Text"], label="Relevant Passages")
                ref_imgs = gr.Gallery(label="Relevant Images", rows=2, value=[])
                send.click(
                    fn=conversation,
                    inputs=[session_state, msg, num_ctx, img_ctx, chat, temp, max_tok, model_dd],
                    outputs=[chat, ref_docs, ref_imgs]
                )
# Unhide the Chat tab once extraction completes
extract_event.then(
fn=lambda: gr.update(visible=True),
inputs=[],
outputs=[chat_tab]
)
gr.HTML("<center>Made with ❤️ by Zamal</center>")
start_btn.click(
fn=lambda: (gr.update(visible=False), gr.update(visible=True)),
outputs=[welcome_col, app_col]
)
if __name__ == "__main__":
demo.launch()