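"""Multimodal Chat-PDF Gradio app.

Extracts text (optionally via docTR OCR) and images from uploaded PDFs,
captions the images with a LLaVA-Next vision-language model, indexes both in
an in-memory Chroma vector store, and answers questions over the retrieved
context through the Hugging Face Inference API.
"""
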
import os
import io
import base64
import gc

import torch
from huggingface_hub.utils import HfHubHTTPError
from langchain_core.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEndpoint
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
from doctr.io import DocumentFile
from doctr.models import ocr_predictor
from pypdf import PdfReader
from PIL import Image
import chromadb
from chromadb.utils import embedding_functions
from langchain.text_splitter import RecursiveCharacterTextSplitter
import gradio as gr

# ─────────────────────────────────────────────────────────────────────────────
# Globals
CURRENT_VDB = None
processor = None
vision_model = None

# OCR & V+L defaults
OCR_CHOICES = {
    "db_resnet50 + crnn_mobilenet_v3_large": ("db_resnet50", "crnn_mobilenet_v3_large"),
    "db_resnet50 + crnn_resnet31": ("db_resnet50", "crnn_resnet31"),
}
SHARED_EMB_FN = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="all-MiniLM-L6-v2"
)
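
# The original snippet references HF_TOKEN and WELCOME_INTRO without defining
# them; the values below are assumptions: the token is read from the
# environment and the intro text is only a placeholder.
HF_TOKEN = os.getenv("HF_TOKEN")
WELCOME_INTRO = (
    "## 📚 Multimodal Chat-PDF\n"
    "Upload PDFs, extract their text and images, then chat with them."
)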

def get_image_description(img: Image.Image) -> str:
    global processor, vision_model
    if processor is None or vision_model is None:
        # use the same default V+L model everywhere
        vlm = "llava-hf/llava-v1.6-mistral-7b-hf"
        processor = LlavaNextProcessor.from_pretrained(vlm)
        vision_model = LlavaNextForConditionalGeneration.from_pretrained(
            vlm, torch_dtype=torch.float16, low_cpu_mem_usage=True
        ).to("cuda")
    torch.cuda.empty_cache()
    gc.collect()
    prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
    # pass text/images by keyword: their positional order differs across transformers versions
    inputs = processor(text=prompt, images=img, return_tensors="pt").to("cuda")
    out = vision_model.generate(**inputs, max_new_tokens=100)
    return processor.decode(out[0], skip_special_tokens=True)
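
# The original snippet calls image_to_bytes() and extract_images() without
# defining them. The helpers below are minimal sketches of what they are
# assumed to do: serialise a PIL image to a base64 string for Chroma metadata
# (decoded again in conversation()), and pull embedded images out of PDFs
# with pypdf.
def image_to_bytes(img: Image.Image) -> str:
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode("utf-8")

def extract_images(paths):
    images = []
    for path in paths:
        for page in PdfReader(path).pages:
            for im in page.images:  # embedded image objects exposed by pypdf
                try:
                    images.append(Image.open(io.BytesIO(im.data)).convert("RGB"))
                except Exception:
                    continue  # skip images PIL cannot decode
    return images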

def extract_data_from_pdfs(
    docs, session, include_images, do_ocr, ocr_choice, vlm_choice, progress=gr.Progress()
):
    if not docs:
        raise gr.Error("No documents to process")

    # 1) Optional OCR
    local_ocr = None
    if do_ocr == "Get Text With OCR":
        db_m, crnn_m = OCR_CHOICES[ocr_choice]
        local_ocr = ocr_predictor(db_m, crnn_m, pretrained=True, assume_straight_pages=True)

    # 2) Prepare V+L
    proc = LlavaNextProcessor.from_pretrained(vlm_choice)
    vis = LlavaNextForConditionalGeneration.from_pretrained(
        vlm_choice, torch_dtype=torch.float16, low_cpu_mem_usage=True
    ).to("cuda")

    # 3) Patch get_image_description to use this choice
    def describe(img: Image.Image) -> str:
        torch.cuda.empty_cache()
        gc.collect()
        prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
        inp = proc(text=prompt, images=img, return_tensors="pt").to("cuda")
        out = vis.generate(**inp, max_new_tokens=100)
        return proc.decode(out[0], skip_special_tokens=True)
    global get_image_description, CURRENT_VDB
    get_image_description = describe

    # 4) Pull text + images
    progress(0.2, "Extracting text and images…")
    full_text, images, names = "", [], []
    for p in docs:
        if local_ocr:
            pdf = DocumentFile.from_pdf(p)
            res = local_ocr(pdf)
            # docTR nests results as page -> block -> line -> word
            full_text += " ".join(w.value for page in res.pages for blk in page.blocks
                                  for line in blk.lines for w in line.words) + "\n\n"
        else:
            # read every page, not just the first
            full_text += "\n".join((page.extract_text() or "")
                                   for page in PdfReader(p).pages) + "\n\n"

        if include_images == "Include Images":
            imgs = extract_images([p])
            images.extend(imgs)
            names.extend([os.path.basename(p)] * len(imgs))

    # 5) Build in-memory Chroma
    progress(0.6, "Indexing in vector DB…")
    client = chromadb.EphemeralClient()
    for col in ("text_db", "image_db"):
        if col in [c.name for c in client.list_collections()]:
            client.delete_collection(col)
    text_col = client.get_or_create_collection("text_db", embedding_function=SHARED_EMB_FN)
    img_col = client.get_or_create_collection("image_db", embedding_function=SHARED_EMB_FN,
                                              metadata={"hnsw:space":"cosine"})

    if images:
        descs, metas = [], []
        for i, im in enumerate(images):
            cap = get_image_description(im)
            descs.append(f"{names[i]}: {cap}")
            metas.append({"image": image_to_bytes(im)})
        img_col.add(ids=[str(i) for i in range(len(images))],
                    documents=descs, metadatas=metas)

    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    docs_ = splitter.create_documents([full_text])
    text_col.add(ids=[str(i) for i in range(len(docs_))],
                 documents=[d.page_content for d in docs_])

    CURRENT_VDB = client
    session["processed"] = True
    sample = images[:4] if include_images == "Include Images" else []
    return session, full_text[:2000] + "...", sample, "<h3>Done!</h3>"

def conversation(session, question, num_ctx, img_ctx, history, temp, max_tok, model_id):
    global CURRENT_VDB
    if not session.get("processed") or CURRENT_VDB is None:
        raise gr.Error("Please extract data first")

    # a) text retrieval
    docs = CURRENT_VDB.get_collection("text_db").query(
        query_texts=[question], n_results=int(num_ctx), include=["documents"]
    )["documents"][0]

    # b) image retrieval
    img_q = CURRENT_VDB.get_collection("image_db").query(
        query_texts=[question], n_results=int(img_ctx),
        include=["metadatas", "documents"],
    )
    img_descs = img_q["documents"][0] or ["No images found"]
    images = []
    for m in img_q["metadatas"][0]:
        b = m.get("image", "")
        try:
            images.append(Image.open(io.BytesIO(base64.b64decode(b))))
        except Exception:
            continue  # skip entries whose image payload cannot be decoded
    img_desc = "\n".join(img_descs)

    # c) prompt & LLM
    prompt = PromptTemplate(
        template="""
Context:
{text}

Included Images:
{img_desc}

Question:
{q}

Answer:
""", input_variables=["text","img_desc","q"])
    inp = prompt.format(text="\n\n".join(docs), img_desc=img_desc, q=question)

    llm = HuggingFaceEndpoint(
        repo_id=model_id, task="text-generation",
        temperature=temp, max_new_tokens=max_tok,
        huggingfacehub_api_token=HF_TOKEN
    )
    try:
        ans = llm.invoke(inp)
    except HfHubHTTPError as e:
        if e.response.status_code == 404:
            ans = f"❌ Model `{model_id}` not hosted."
        else:
            ans = f"⚠️ HF API error: {e}"
    except Exception as e:
        ans = f"⚠️ Unexpected error: {e}"

    new_hist = history + [{"role": "user", "content": question},
                          {"role": "assistant", "content": ans}]
    # the Dataframe output expects rows, so wrap each retrieved chunk in a list
    return new_hist, [[d] for d in docs], images




# ─────────────────────────────────────────────────────────────────────────────
# Gradio UI
CSS = """
footer {visibility:hidden;}
"""

MODEL_OPTIONS = [
    "HuggingFaceH4/zephyr-7b-beta",
    "mistralai/Mistral-7B-Instruct-v0.2",
    "openchat/openchat-3.5-0106",
    "google/gemma-7b-it",
    "deepseek-ai/deepseek-llm-7b-chat",
    "microsoft/Phi-3-mini-4k-instruct",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "Qwen/Qwen1.5-7B-Chat",
    "tiiuae/falcon-7b-instruct",              # Falcon 7B Instruct
    "bigscience/bloomz-7b1",                  # BLOOMZ 7B
    "facebook/opt-2.7b",  
]

with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
    session_state = gr.State({})

    with gr.Column(visible=True) as welcome_col:
        gr.Markdown(f"<div style='text-align:center'>{WELCOME_INTRO}</div>")
        start_btn = gr.Button("🚀 Start")

    with gr.Column(visible=False) as app_col:
        gr.Markdown("## 📚 Multimodal Chat-PDF Playground")
        extract_event = None

        with gr.Tabs() as tabs:
            with gr.TabItem("1. Upload & Extract"):
                docs = gr.File(file_count="multiple", file_types=[".pdf"], label="Upload PDFs")
                include_dd = gr.Radio(["Include Images", "Exclude Images"],
                                      value="Exclude Images", label="Images")
                ocr_radio = gr.Radio(["Get Text With OCR", "Get Available Text Only"],
                                     value="Get Available Text Only", label="OCR")
                ocr_dd = gr.Dropdown(list(OCR_CHOICES.keys()),
                                     value=list(OCR_CHOICES.keys())[0], label="OCR Model")
                vlm_dd = gr.Dropdown(
                    ["llava-hf/llava-v1.6-mistral-7b-hf", "llava-hf/llava-v1.6-vicuna-7b-hf"],
                    value="llava-hf/llava-v1.6-mistral-7b-hf", label="Vision-Language Model")
                extract_btn  = gr.Button("Extract")
                preview_text = gr.Textbox(lines=10, label="Sample Text", interactive=False)
                preview_img  = gr.Gallery(label="Sample Images", rows=2, value=[])
                preview_html = gr.HTML()

                extract_event = extract_btn.click(
                    fn=extract_data_from_pdfs,
                    inputs=[docs, session_state, include_dd, ocr_radio, ocr_dd, vlm_dd],
                    outputs=[session_state, preview_text, preview_img, preview_html]
                )

            with gr.TabItem("2. Chat", visible=False) as chat_tab:
                with gr.Row():
                    with gr.Column(scale=3):
                        chat = gr.Chatbot(type="messages", label="Chat")
                        msg  = gr.Textbox(placeholder="Ask about your PDF...", label="Your question")
                        send = gr.Button("Send")
                    with gr.Column(scale=1):
                        model_dd = gr.Dropdown(MODEL_OPTIONS, value=MODEL_OPTIONS[0],
                                               label="Choose Chat Model")
                        num_ctx  = gr.Slider(1,20, value=3, label="Text Contexts")
                        img_ctx  = gr.Slider(1,10, value=2, label="Image Contexts")
                        temp     = gr.Slider(0.1,1.0, step=0.1, value=0.4, label="Temperature")
                        max_tok  = gr.Slider(10,1000, step=10, value=200, label="Max Tokens")

                with gr.Row():
                    ret_docs = gr.Dataframe(headers=["Retrieved Context"], label="Retrieved Contexts")
                    ret_imgs = gr.Gallery(label="Relevant Images", rows=2, value=[])

                send.click(
                    fn=conversation,
                    inputs=[session_state, msg, num_ctx, img_ctx, chat, temp, max_tok, model_dd],
                    outputs=[chat, ret_docs, ret_imgs]
                )

        # Unhide the Chat tab once extraction completes
        extract_event.then(
            fn=lambda: gr.update(visible=True),
            inputs=[],
            outputs=[chat_tab]
        )

        gr.HTML("<center>Made with ❤️ by Zamal</center>")

    start_btn.click(
        fn=lambda: (gr.update(visible=False), gr.update(visible=True)),
        outputs=[welcome_col, app_col]
    )

if __name__ == "__main__":
    demo.launch()