File size: 5,183 Bytes
1b0d0e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
import base64
from unstructured.partition.pdf import partition_pdf
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import HumanMessage
from PIL import Image
import pytesseract

# Function to partition PDF
def partition_pdf_elements(filename):
    """Partition a PDF into unstructured elements using the ``hi_res`` strategy.

    Image and Table blocks are requested for extraction
    (``extract_image_block_types``), with ``extract_image_block_to_payload=False``
    and ``extract_image_block_output_dir="extracted_data"``.

    Args:
        filename: Path of the PDF to partition.

    Returns:
        Whatever ``partition_pdf`` returns for these options (the element list).
    """
    options = dict(
        strategy="hi_res",
        extract_images_in_pdf=True,
        extract_image_block_types=["Image", "Table"],
        extract_image_block_to_payload=False,
        extract_image_block_output_dir="extracted_data",
    )
    return partition_pdf(filename=filename, **options)

# Function to classify elements
def classify_elements(raw_pdf_elements):
    """Bucket partitioned PDF elements by the name of their element class.

    Each element's ``str(type(element))`` is searched for
    ``"unstructured.documents.elements.<Kind>"``. The kinds are tried in a fixed
    order and the first one found wins. Matching is by substring, so e.g. a
    ``TableChunk`` element also lands in the table bucket. Elements matching
    no kind are dropped.

    Args:
        raw_pdf_elements: Iterable of elements (e.g. from ``partition_pdf``).

    Returns:
        An 8-tuple of lists of ``str(element)``, in this order:
        (Header, Footer, Title, NarrativeText, Text, ListItem, Image, Table).
    """
    marker_prefix = "unstructured.documents.elements."
    # Order matters: NarrativeText must be tried before the more generic Text.
    kinds = ("Header", "Footer", "Title", "NarrativeText", "Text", "ListItem", "Image", "Table")
    buckets = {kind: [] for kind in kinds}

    for element in raw_pdf_elements:
        type_repr = str(type(element))
        matched_kind = next(
            (kind for kind in kinds if marker_prefix + kind in type_repr),
            None,
        )
        if matched_kind is not None:
            buckets[matched_kind].append(str(element))

    return tuple(buckets[kind] for kind in kinds)

# Function to summarize tables
def summarize_tables(tab, google_api_key, *, model_name="gemini-1.5-flash", max_concurrency=5):
    """Summarize each table with Gemini so the summaries can be embedded for retrieval.

    Args:
        tab: Sequence of table elements. Each item is substituted into the
            prompt's ``{element}`` slot (``classify_elements`` yields strings).
        google_api_key: API key passed to ``ChatGoogleGenerativeAI``.
        model_name: Gemini model to use (keyword-only; defaults to the
            previously hard-coded ``"gemini-1.5-flash"``).
        max_concurrency: Upper bound on concurrent requests made by
            ``batch`` (keyword-only; defaults to the previous value of 5).

    Returns:
        A list of summary strings, one per input table, in input order.
        An empty list when ``tab`` is empty.
    """
    # Nothing to summarize: skip building the prompt, the client (which needs a
    # usable API key) and the chain entirely.
    if not tab:
        return []

    prompt_text = """You are an assistant tasked with summarizing tables for retrieval. \

    These summaries will be embedded and used to retrieve the raw table elements. \

    Give a concise summary of the table that is well optimized for retrieval. Table {element} """
    prompt = ChatPromptTemplate.from_template(prompt_text)

    model = ChatGoogleGenerativeAI(model=model_name, google_api_key=google_api_key)
    # LCEL coerces the dict into a runnable that feeds each input item into the
    # prompt's ``element`` variable.
    summarize_chain = {"element": lambda x: x} | prompt | model | StrOutputParser()
    return summarize_chain.batch(tab, {"max_concurrency": max_concurrency})

# Function to encode image
def encode_image(image_path):
    """Return the contents of the file at *image_path* as base64 text.

    The file is read in binary mode. The base64 bytes are decoded as UTF-8 so
    the result is a ``str``, ready to embed in a data URL.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")

# Function to summarize images
def image_summarize(img_base64, prompt, google_api_key, *, mime_type="image/jpeg"):
    """Ask Gemini to describe a base64-encoded image.

    Args:
        img_base64: Base64 text of the image bytes (as produced by ``encode_image``).
        prompt: Instruction text sent alongside the image.
        google_api_key: API key passed to ``ChatGoogleGenerativeAI``.
        mime_type: MIME type written into the ``data:`` URL (keyword-only).
            Defaults to ``"image/jpeg"`` to keep the previous behavior. Pass
            the real type (e.g. ``"image/png"``) for non-JPEG images so the
            payload is not mislabeled.

    Returns:
        The ``content`` of the model's reply.
    """
    chat = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key=google_api_key, max_output_tokens=512)
    # The image is sent inline as a data URL rather than as a remote link.
    image_part = {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{img_base64}"}}
    msg = chat.invoke(
        [
            HumanMessage(
                content=[
                    {"type": "text", "text": prompt},
                    image_part,
                ]
            )
        ]
    )
    return msg.content

# Function to generate image summaries
def generate_img_summaries(path, google_api_key):
    """Encode a single image file and summarize it for retrieval.

    Args:
        path: Filesystem path of one image.
        google_api_key: API key forwarded to ``image_summarize``.

    Returns:
        A pair of one-element lists ``(img_base64_list, image_summaries)``:
        the image's base64 text and the model's summary of it.
    """
    prompt = """You are an assistant tasked with summarizing images for retrieval. \

    These summaries will be embedded and used to retrieve the raw image. \

    Give a concise summary of the image that is well optimized for retrieval.

    also give the image output if possible"""
    encoded_image = encode_image(path)
    summary = image_summarize(encoded_image, prompt, google_api_key)
    return [encoded_image], [summary]

# Function to handle text-based queries
def handle_query(query, google_api_key, text_elements):
    """Answer *query* with Gemini, using *text_elements* as context.

    Both values are interpolated into the prompt with an f-string, i.e. via
    their ``str()`` form (a list therefore appears as its bracketed repr).

    Returns:
        The ``content`` of the model's reply.
    """
    prompt_text = f"You are an assistant tasked with answering the following query based on the provided text elements:\n\n{query}\n\nText elements: {text_elements}"
    llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key=google_api_key)
    reply = llm.invoke([HumanMessage(content=prompt_text)])
    return reply.content

# Function to extract text from an image
def extract_text_from_image(image_path):
    """Run Tesseract OCR on the image at *image_path* and return the recognized text.

    The image is opened in a ``with`` block so the file handle PIL holds is
    released as soon as OCR finishes. Previously it was left open until
    garbage collection.

    Args:
        image_path: Path of the image file to OCR.

    Returns:
        The string produced by ``pytesseract.image_to_string``.
    """
    with Image.open(image_path) as image:
        text = pytesseract.image_to_string(image)
    return text

# Function to handle image-based queries
def handle_image_query(image_path, query, google_api_key):
    """OCR the image at *image_path* and answer *query* from the recognized text.

    The OCR output (from ``extract_text_from_image``) and the query are
    interpolated into a single text prompt sent to Gemini. The image itself is
    not sent to the model.

    Returns:
        The ``content`` of the model's reply.
    """
    ocr_text = extract_text_from_image(image_path)
    prompt_text = f"You are an assistant tasked with answering the following query based on the extracted text from the image:\n\n{query}\n\nExtracted text: {ocr_text}"
    llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key=google_api_key)
    response = llm.invoke([HumanMessage(content=prompt_text)])
    return response.content