pranav13081999 committed on
Commit
1b0d0e9
·
unverified ·
1 Parent(s): 659874c

Add files via upload

Browse files
Files changed (2) hide show
  1. multimodal_gradio.py +43 -0
  2. multimodal_rag_chat.py +110 -0
multimodal_gradio.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from multimodal_rag_chat import partition_pdf_elements, classify_elements, summarize_tables, generate_img_summaries, handle_query, handle_image_query
3
+
4
+ # Google API Key (Make sure to replace this with your actual API key)
5
+ GOOGLE_API_KEY = "YOUR_GOOGLE_API_KEY"
6
+
7
+ st.title("PDF and Image Content Summarizer and Query Answerer")
8
+
9
+ st.header("Upload PDF or Image")
10
+ uploaded_file = st.file_uploader("Choose a PDF or Image file", type=["pdf", "jpg", "jpeg", "png"])
11
+ query = st.text_input("Enter your query")
12
+
13
+ if uploaded_file is not None and query:
14
+ file_type = uploaded_file.type
15
+ file_path = "temp." + file_type.split('/')[1]
16
+
17
+ with open(file_path, "wb") as f:
18
+ f.write(uploaded_file.getbuffer())
19
+
20
+ if file_type.startswith("application/pdf"):
21
+ raw_pdf_elements = partition_pdf_elements(file_path)
22
+ Header, Footer, Title, NarrativeText, Text, ListItem, img, tab = classify_elements(raw_pdf_elements)
23
+
24
+ text_elements = Header + Footer + Title + NarrativeText + Text + ListItem
25
+ text_response = handle_query(query, GOOGLE_API_KEY, text_elements)
26
+
27
+ st.header("Query Response")
28
+ st.write(text_response)
29
+
30
+ if tab:
31
+ st.header("Table Summaries")
32
+ table_summaries = summarize_tables(tab, GOOGLE_API_KEY)
33
+ st.write(table_summaries)
34
+
35
+ if img:
36
+ st.header("Image Summaries")
37
+ img_base64_list, image_summaries = generate_img_summaries("extracted_data", GOOGLE_API_KEY)
38
+ st.write(image_summaries)
39
+
40
+ elif file_type.startswith("image"):
41
+ image_query_response = handle_image_query(file_path, query, GOOGLE_API_KEY)
42
+ st.header("Image Query Response")
43
+ st.write(image_query_response)
multimodal_rag_chat.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import base64
3
+ from unstructured.partition.pdf import partition_pdf
4
+ from langchain_core.output_parsers import StrOutputParser
5
+ from langchain_core.prompts import ChatPromptTemplate
6
+ from langchain_google_genai import ChatGoogleGenerativeAI
7
+ from langchain_core.messages import HumanMessage
8
+ from PIL import Image
9
+ import pytesseract
10
+
11
+ # Function to partition PDF
12
+ def partition_pdf_elements(filename):
13
+ raw_pdf_elements = partition_pdf(
14
+ filename=filename,
15
+ strategy="hi_res",
16
+ extract_images_in_pdf=True,
17
+ extract_image_block_types=["Image", "Table"],
18
+ extract_image_block_to_payload=False,
19
+ extract_image_block_output_dir="extracted_data"
20
+ )
21
+ return raw_pdf_elements
22
+
23
+ # Function to classify elements
24
+ def classify_elements(raw_pdf_elements):
25
+ Header, Footer, Title, NarrativeText, Text, ListItem, img, tab = [], [], [], [], [], [], [], []
26
+
27
+ for element in raw_pdf_elements:
28
+ if "unstructured.documents.elements.Header" in str(type(element)):
29
+ Header.append(str(element))
30
+ elif "unstructured.documents.elements.Footer" in str(type(element)):
31
+ Footer.append(str(element))
32
+ elif "unstructured.documents.elements.Title" in str(type(element)):
33
+ Title.append(str(element))
34
+ elif "unstructured.documents.elements.NarrativeText" in str(type(element)):
35
+ NarrativeText.append(str(element))
36
+ elif "unstructured.documents.elements.Text" in str(type(element)):
37
+ Text.append(str(element))
38
+ elif "unstructured.documents.elements.ListItem" in str(type(element)):
39
+ ListItem.append(str(element))
40
+ elif "unstructured.documents.elements.Image" in str(type(element)):
41
+ img.append(str(element))
42
+ elif "unstructured.documents.elements.Table" in str(type(element)):
43
+ tab.append(str(element))
44
+ return Header, Footer, Title, NarrativeText, Text, ListItem, img, tab
45
+
46
+ # Function to summarize tables
47
+ def summarize_tables(tab, google_api_key):
48
+ prompt_text = """You are an assistant tasked with summarizing tables for retrieval. \
49
+ These summaries will be embedded and used to retrieve the raw table elements. \
50
+ Give a concise summary of the table that is well optimized for retrieval. Table {element} """
51
+ prompt = ChatPromptTemplate.from_template(prompt_text)
52
+
53
+ model = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key=google_api_key)
54
+ summarize_chain = {"element": lambda x: x} | prompt | model | StrOutputParser()
55
+ table_summaries = summarize_chain.batch(tab, {"max_concurrency": 5})
56
+ return table_summaries
57
+
58
+ # Function to encode image
59
+ def encode_image(image_path):
60
+ with open(image_path, "rb") as image_file:
61
+ return base64.b64encode(image_file.read()).decode("utf-8")
62
+
63
+ # Function to summarize images
64
+ def image_summarize(img_base64, prompt, google_api_key):
65
+ chat = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key=google_api_key, max_output_tokens=512)
66
+ msg = chat.invoke(
67
+ [
68
+ HumanMessage(
69
+ content=[
70
+ {"type": "text", "text": prompt},
71
+ {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img_base64}"}}
72
+ ]
73
+ )
74
+ ]
75
+ )
76
+ return msg.content
77
+
78
+ # Function to generate image summaries
79
+ def generate_img_summaries(path, google_api_key):
80
+ img_base64_list = []
81
+ image_summaries = []
82
+ prompt = """You are an assistant tasked with summarizing images for retrieval. \
83
+ These summaries will be embedded and used to retrieve the raw image. \
84
+ Give a concise summary of the image that is well optimized for retrieval.
85
+ also give the image output if possible"""
86
+ base64_image = encode_image(path)
87
+ img_base64_list.append(base64_image)
88
+ image_summaries.append(image_summarize(base64_image, prompt, google_api_key))
89
+ return img_base64_list, image_summaries
90
+
91
+ # Function to handle text-based queries
92
+ def handle_query(query, google_api_key, text_elements):
93
+ prompt_text = f"You are an assistant tasked with answering the following query based on the provided text elements:\n\n{query}\n\nText elements: {text_elements}"
94
+ model = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key=google_api_key)
95
+ msg = model.invoke([HumanMessage(content=prompt_text)])
96
+ return msg.content
97
+
98
+ # Function to extract text from an image
99
+ def extract_text_from_image(image_path):
100
+ image = Image.open(image_path)
101
+ text = pytesseract.image_to_string(image)
102
+ return text
103
+
104
+ # Function to handle image-based queries
105
+ def handle_image_query(image_path, query, google_api_key):
106
+ extracted_text = extract_text_from_image(image_path)
107
+ prompt_text = f"You are an assistant tasked with answering the following query based on the extracted text from the image:\n\n{query}\n\nExtracted text: {extracted_text}"
108
+ model = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key=google_api_key)
109
+ msg = model.invoke([HumanMessage(content=prompt_text)])
110
+ return msg.content