pranav13081999
committed on
Add files via upload
- multimodal_gradio.py +43 -0
- multimodal_rag_chat.py +110 -0
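Note: despite its filename, multimodal_gradio.py is a Streamlit app (it imports streamlit, not gradio), so it would presumably be launched with `streamlit run multimodal_gradio.py` rather than as a Gradio app.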
multimodal_gradio.py
ADDED
@@ -0,0 +1,43 @@
import streamlit as st
from multimodal_rag_chat import partition_pdf_elements, classify_elements, summarize_tables, generate_img_summaries, handle_query, handle_image_query

# Google API key (replace this placeholder with your actual key)
GOOGLE_API_KEY = "YOUR_GOOGLE_API_KEY"

st.title("PDF and Image Content Summarizer and Query Answerer")

st.header("Upload PDF or Image")
uploaded_file = st.file_uploader("Choose a PDF or Image file", type=["pdf", "jpg", "jpeg", "png"])
query = st.text_input("Enter your query")

if uploaded_file is not None and query:
    # Persist the upload to a temporary file so downstream libraries can read it by path
    file_type = uploaded_file.type
    file_path = "temp." + file_type.split('/')[1]

    with open(file_path, "wb") as f:
        f.write(uploaded_file.getbuffer())

    if file_type.startswith("application/pdf"):
        raw_pdf_elements = partition_pdf_elements(file_path)
        Header, Footer, Title, NarrativeText, Text, ListItem, img, tab = classify_elements(raw_pdf_elements)

        # Answer the query against all extracted text elements
        text_elements = Header + Footer + Title + NarrativeText + Text + ListItem
        text_response = handle_query(query, GOOGLE_API_KEY, text_elements)

        st.header("Query Response")
        st.write(text_response)

        if tab:
            st.header("Table Summaries")
            table_summaries = summarize_tables(tab, GOOGLE_API_KEY)
            st.write(table_summaries)

        if img:
            st.header("Image Summaries")
            img_base64_list, image_summaries = generate_img_summaries("extracted_data", GOOGLE_API_KEY)
            st.write(image_summaries)

    elif file_type.startswith("image"):
        image_query_response = handle_image_query(file_path, query, GOOGLE_API_KEY)
        st.header("Image Query Response")
        st.write(image_query_response)
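A note on dependencies: judging from the imports in the two files, the Space would presumably also need a requirements.txt listing streamlit, unstructured[pdf], langchain-core, langchain-google-genai, pillow, and pytesseract (plus a system Tesseract install for the OCR path); no such file is included in this commit.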
multimodal_rag_chat.py
ADDED
@@ -0,0 +1,110 @@
import os
import base64
from unstructured.partition.pdf import partition_pdf
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import HumanMessage
from PIL import Image
import pytesseract

# Partition a PDF into elements, extracting image and table blocks to disk
def partition_pdf_elements(filename):
    raw_pdf_elements = partition_pdf(
        filename=filename,
        strategy="hi_res",
        extract_images_in_pdf=True,
        extract_image_block_types=["Image", "Table"],
        extract_image_block_to_payload=False,
        extract_image_block_output_dir="extracted_data"
    )
    return raw_pdf_elements

# Sort partitioned elements into buckets by their unstructured element type
def classify_elements(raw_pdf_elements):
    Header, Footer, Title, NarrativeText, Text, ListItem, img, tab = [], [], [], [], [], [], [], []

    for element in raw_pdf_elements:
        if "unstructured.documents.elements.Header" in str(type(element)):
            Header.append(str(element))
        elif "unstructured.documents.elements.Footer" in str(type(element)):
            Footer.append(str(element))
        elif "unstructured.documents.elements.Title" in str(type(element)):
            Title.append(str(element))
        elif "unstructured.documents.elements.NarrativeText" in str(type(element)):
            NarrativeText.append(str(element))
        elif "unstructured.documents.elements.Text" in str(type(element)):
            Text.append(str(element))
        elif "unstructured.documents.elements.ListItem" in str(type(element)):
            ListItem.append(str(element))
        elif "unstructured.documents.elements.Image" in str(type(element)):
            img.append(str(element))
        elif "unstructured.documents.elements.Table" in str(type(element)):
            tab.append(str(element))
    return Header, Footer, Title, NarrativeText, Text, ListItem, img, tab

# Summarize tables with Gemini, batching requests through a small LCEL chain
def summarize_tables(tab, google_api_key):
    prompt_text = """You are an assistant tasked with summarizing tables for retrieval. \
These summaries will be embedded and used to retrieve the raw table elements. \
Give a concise summary of the table that is well optimized for retrieval. Table {element}"""
    prompt = ChatPromptTemplate.from_template(prompt_text)

    model = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key=google_api_key)
    summarize_chain = {"element": lambda x: x} | prompt | model | StrOutputParser()
    table_summaries = summarize_chain.batch(tab, {"max_concurrency": 5})
    return table_summaries

# Read an image file and return its base64-encoded contents
def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")

# Send one base64-encoded image plus a text prompt to Gemini and return the reply
def image_summarize(img_base64, prompt, google_api_key):
    chat = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key=google_api_key, max_output_tokens=512)
    msg = chat.invoke(
        [
            HumanMessage(
                content=[
                    {"type": "text", "text": prompt},
                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img_base64}"}}
                ]
            )
        ]
    )
    return msg.content

# Summarize every image in a directory. The app passes the "extracted_data"
# folder produced by partition_pdf_elements, so this walks the directory
# rather than encoding a single file.
def generate_img_summaries(path, google_api_key):
    img_base64_list = []
    image_summaries = []
    prompt = """You are an assistant tasked with summarizing images for retrieval. \
These summaries will be embedded and used to retrieve the raw image. \
Give a concise summary of the image that is well optimized for retrieval. \
Also describe the image content if possible."""
    for fname in sorted(os.listdir(path)):
        if fname.lower().endswith((".jpg", ".jpeg", ".png")):
            base64_image = encode_image(os.path.join(path, fname))
            img_base64_list.append(base64_image)
            image_summaries.append(image_summarize(base64_image, prompt, google_api_key))
    return img_base64_list, image_summaries

# Answer a text query against the text elements extracted from the PDF
def handle_query(query, google_api_key, text_elements):
    prompt_text = f"You are an assistant tasked with answering the following query based on the provided text elements:\n\n{query}\n\nText elements: {text_elements}"
    model = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key=google_api_key)
    msg = model.invoke([HumanMessage(content=prompt_text)])
    return msg.content

# OCR an image with Tesseract
def extract_text_from_image(image_path):
    image = Image.open(image_path)
    text = pytesseract.image_to_string(image)
    return text

# Answer a query against the OCR text extracted from an uploaded image
def handle_image_query(image_path, query, google_api_key):
    extracted_text = extract_text_from_image(image_path)
    prompt_text = f"You are an assistant tasked with answering the following query based on the extracted text from the image:\n\n{query}\n\nExtracted text: {extracted_text}"
    model = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key=google_api_key)
    msg = model.invoke([HumanMessage(content=prompt_text)])
    return msg.content
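For reference, the helpers in multimodal_rag_chat.py can also be exercised outside Streamlit. A minimal sketch, assuming a local sample.pdf and a valid Google API key (both are placeholders here, not part of the commit):

from multimodal_rag_chat import partition_pdf_elements, classify_elements, handle_query

GOOGLE_API_KEY = "YOUR_GOOGLE_API_KEY"  # placeholder key

# Partition a local PDF and bucket its elements by type
elements = partition_pdf_elements("sample.pdf")  # hypothetical input file
Header, Footer, Title, NarrativeText, Text, ListItem, img, tab = classify_elements(elements)

# Ask a question against every text element that was extracted
text_elements = Header + Footer + Title + NarrativeText + Text + ListItem
print(handle_query("What is this document about?", GOOGLE_API_KEY, text_elements))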