from rag.agents.interface import Pipeline
from sparrow_parse.vllm.inference_factory import InferenceFactory
from sparrow_parse.extractors.vllm_extractor import VLLMExtractor
import timeit
from rich import print
from rich.progress import Progress, SpinnerColumn, TextColumn
from typing import Any, List, Optional
from .sparrow_validator import Validator
from .sparrow_utils import is_valid_json, get_json_keys_as_string
import warnings
import os


warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)


class SparrowParsePipeline(Pipeline):
    """Pipeline that runs Sparrow Parse inference to extract structured JSON from documents."""

    def __init__(self):
        pass

    def run_pipeline(self,
                     payload: str,
                     query_inputs: List[str],
                     query_types: List[str],
                     keywords: List[str],
                     query: str,
                     file_path: str,
                     index_name: str,
                     options: Optional[List[str]] = None,
                     group_by_rows: bool = True,
                     update_targets: bool = True,
                     debug: bool = False,
                     local: bool = True) -> Any:
        print(f"\nRunning pipeline with {payload}\n")

        start = timeit.default_timer()

        query_all_data = False
        query_schema = None
        if query == "*":
            # "*" means a generic query: extract everything, skip schema validation.
            query_all_data = True
            query = None
        else:
            query, query_schema = self.invoke_pipeline_step(
                lambda: self.prepare_query_and_schema(query, debug),
                "Preparing query and schema", local)

        llm_output = self.invoke_pipeline_step(
            lambda: self.execute_query(options, query_all_data, query, file_path, debug),
            "Executing query", local)

        validation_result = None
        if not query_all_data:
            validation_result = self.invoke_pipeline_step(
                lambda: self.validate_result(llm_output, query_all_data, query_schema, debug),
                "Validating result", local)

        end = timeit.default_timer()

        print(f"Time to retrieve answer: {end - start}")

        return validation_result if validation_result is not None else llm_output


    def prepare_query_and_schema(self, query, debug):
        # The incoming query must itself be valid JSON; it doubles as the
        # schema the answer is validated against later.
        if not is_valid_json(query):
            raise ValueError("Invalid query. Please provide a valid JSON query.")

        query_keys = get_json_keys_as_string(query)
        query_schema = query
        query = ("retrieve " + query_keys
                 + ". return response in JSON format, by strictly following this JSON schema: "
                 + query_schema)

        return query, query_schema
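
    # Illustrative rewrite performed above (assuming get_json_keys_as_string
    # returns a comma-separated key list; the field names are hypothetical):
    #   query in:  '{"invoice_number": "str", "total": "str"}'
    #   query out: 'retrieve invoice_number, total. return response in JSON
    #               format, by strictly following this JSON schema: {...}'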


    def execute_query(self, options, query_all_data, query, file_path, debug):
        extractor = VLLMExtractor()

        # Requires a Hugging Face token in the environment: export HF_TOKEN="hf_..."
        if options and options[0] == 'huggingface':
            config = {
                "method": options[0],  # could be 'huggingface' or 'local_gpu'
                "hf_space": options[1],
                "hf_token": os.getenv('HF_TOKEN')
            }
        else:
            # Other backends (e.g. 'local_gpu') are not wired up in this pipeline yet.
            return "First element is not 'huggingface'"

        # Use the factory to get the correct instance
        factory = InferenceFactory(config)
        model_inference_instance = factory.get_inference_instance()

        input_data = [
            {
                "image": file_path,
                "text_input": query
            }
        ]

        # Now you can run inference without knowing which implementation is used
        llm_output = extractor.run_inference(model_inference_instance, input_data, generic_query=query_all_data,
                                             debug=debug)

        return llm_output
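
    # Illustrative shape of the options list consumed above (the Space name is
    # a hypothetical placeholder, not a real deployment):
    #   options = ["huggingface", "some-user/sparrow-parse"]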


    def validate_result(self, llm_output, query_all_data, query_schema, debug):
        validator = Validator(query_schema)

        # Returns error details when the LLM output violates the schema;
        # returns None when it is valid, so the caller falls back to llm_output.
        validation_result = validator.validate_json_against_schema(llm_output, validator.generated_schema)
        if validation_result is not None:
            return validation_result

        if debug:
            print("LLM output is valid according to the schema.")

    def invoke_pipeline_step(self, task_call, task_description, local):
        # Show a spinner for local interactive runs; plain status text otherwise.
        if local:
            with Progress(
                    SpinnerColumn(),
                    TextColumn("[progress.description]{task.description}"),
                    transient=False,
            ) as progress:
                progress.add_task(description=task_description, total=None)
                ret = task_call()
        else:
            print(task_description)
            ret = task_call()

        return ret
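

# Usage sketch (illustrative only; the query fields, file path, and Hugging Face
# Space name are assumptions, not values shipped with this module):
#
#   pipeline = SparrowParsePipeline()
#   answer = pipeline.run_pipeline(
#       payload="sparrow-parse",
#       query_inputs=[], query_types=[], keywords=[],
#       query='{"invoice_number": "str", "total": "str"}',
#       file_path="/data/invoice.png",
#       index_name="",
#       options=["huggingface", "some-user/sparrow-parse"],
#       debug=False,
#       local=True)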