import json
import timeit
import warnings
from typing import Any, List, Optional

import box
import pydantic
import yaml
from haystack import Pipeline, component
from haystack.components.builders import PromptBuilder
from haystack.components.embedders import SentenceTransformersTextEmbedder
from haystack_integrations.components.generators.ollama import OllamaGenerator
from haystack_integrations.components.retrievers.weaviate.embedding_retriever import WeaviateEmbeddingRetriever
from haystack_integrations.document_stores.weaviate.document_store import WeaviateDocumentStore
from pydantic import ValidationError, create_model
from rich import print
from rich.progress import Progress, SpinnerColumn, TextColumn

from rag.agents.interface import Pipeline as PipelineInterface


warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)


# Import config vars
with open('config.yml', 'r', encoding='utf8') as ymlfile:
    cfg = box.Box(yaml.safe_load(ymlfile))
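# Keys read from config.yml in this module: WEAVIATE_URL, EMBEDDINGS_HAYSTACK,
# LLM_HAYSTACK, OLLAMA_BASE_URL_HAYSTACK, MAX_LOOPS_ALLOWED_HAYSTACK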


class HaystackPipeline(PipelineInterface):
    def run_pipeline(self,
                     payload: str,
                     query_inputs: List[str],
                     query_types: List[str],
                     keywords: List[str],
                     query: str,
                     file_path: str,
                     index_name: str,
                     options: Optional[List[str]] = None,
                     group_by_rows: bool = True,
                     update_targets: bool = True,
                     debug: bool = False,
                     local: bool = True) -> Any:
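        """Run the end-to-end extraction flow: build a dynamic response model
        from the query inputs/types, open the Weaviate document store, and
        answer the query through the RAG inference pipeline.
        """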
        print(f"\nRunning pipeline with {payload}\n")

        ResponseModel, json_schema = self.invoke_pipeline_step(lambda: self.build_response_class(query_inputs, query_types),
                                                               "Building dynamic response class...",
                                                               local)

        output_validator = self.invoke_pipeline_step(lambda: self.build_validator(ResponseModel),
                                                     "Building output validator...",
                                                     local)

        document_store = self.run_preprocessing_pipeline(index_name, local)

        answer = self.run_inference_pipeline(document_store, json_schema, output_validator, query, local)

        return answer

    # Safely evaluate a type string against a restricted namespace
    def safe_eval_type(self, type_str, context):
        try:
            # Empty __builtins__ so only the names provided in `context` resolve
            return eval(type_str, {"__builtins__": {}}, context)
        except NameError:
            raise ValueError(f"Type '{type_str}' is not recognized")

    def build_response_class(self, query_inputs, query_types_as_strings):
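        """Build a pydantic model from parallel lists of field names and type
        strings, and return it together with its JSON schema.

        Example (hypothetical inputs): query_inputs=["invoice_number", "total"]
        with query_types_as_strings=["str", "float"] produces a model with
        fields invoice_number: str and total: float.
        """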
        # Controlled context for eval
        context = {
            'List': List,
            'str': str,
            'int': int,
            'float': float
            # Include other necessary types or typing constructs here
        }

        # Convert string representations to actual types
        query_types = [self.safe_eval_type(type_str, context) for type_str in query_types_as_strings]

        # Create fields dictionary
        fields = {name: (type_, ...) for name, type_ in zip(query_inputs, query_types)}

        DynamicModel = create_model('DynamicModel', **fields)

        # schema_json() is deprecated in pydantic v2; build the schema string explicitly
        json_schema = json.dumps(DynamicModel.model_json_schema(), indent=2)

        return DynamicModel, json_schema

    def build_validator(self, Invoice):
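        """Wrap the pydantic model in a Haystack component that checks whether
        the LLM reply parses as JSON. Invalid replies and the error message are
        routed back to the prompt builder (see the pipeline connections in
        run_inference_pipeline) so the LLM can retry.
        """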
        @component
        class OutputValidator:
            def __init__(self, pydantic_model: pydantic.BaseModel):
                self.pydantic_model = pydantic_model
                self.iteration_counter = 0

            # Define the component output
            @component.output_types(valid_replies=List[str], invalid_replies=Optional[List[str]],
                                    error_message=Optional[str])
            def run(self, replies: List[str]):

                self.iteration_counter += 1

                ## Try to parse the LLM's reply ##
                # If the LLM's reply is a valid object, return `"valid_replies"`
                try:
                    output_dict = json.loads(replies[0].strip())
                    # Disable data validation for now
                    # self.pydantic_model.model_validate(output_dict)
                    print(
                        f"OutputValidator at Iteration {self.iteration_counter}: Valid JSON from LLM - No need for looping."
                    )
                    return {"valid_replies": replies}

                # If the reply is corrupted or invalid JSON, return "invalid_replies" and an "error_message" so the LLM can try again
                except (ValueError, ValidationError) as e:
                    print(
                          f"\nOutputValidator at Iteration {self.iteration_counter}: Invalid JSON from LLM - Let's try again.\n"
                          f"Output from LLM:\n {replies[0]} \n"
                          f"Error from OutputValidator: {e}"
                    )
                    return {"invalid_replies": replies, "error_message": str(e)}

        output_validator = OutputValidator(pydantic_model=Invoice)

        return output_validator

    def run_preprocessing_pipeline(self, index_name, local):
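        """Open the Weaviate collection backing index_name and verify it
        already contains documents; indexing itself happens elsewhere.
        """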
        document_store = WeaviateDocumentStore(url=cfg.WEAVIATE_URL, collection_settings={"class": index_name})

        print(f"\nNumber of documents in document store: {document_store.count_documents()}\n")

        if document_store.count_documents() == 0:
            raise ValueError("Document store is empty. Please check your data source.")

        return document_store

    def run_inference_pipeline(self, document_store, json_schema, output_validator, query, local):
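        """Assemble and run the RAG pipeline: embed the query, retrieve the
        top matching documents, prompt the LLM with the JSON schema, validate
        the reply, and loop back on invalid output. Returns the parsed JSON.
        """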
        start = timeit.default_timer()

        generator = OllamaGenerator(model=cfg.LLM_HAYSTACK,
                                    url=cfg.OLLAMA_BASE_URL_HAYSTACK + "/api/generate",
                                    timeout=900)

        template = """
        Given only the following document information, answer the question.
        Ignore your own knowledge. Format the response with the following JSON schema:
        {{schema}}
        Make sure your response is a dict and not a list. Return only JSON, no additional text.

        Context:
        {% for document in documents %}
            {{ document.content }}
        {% endfor %}

        Question: {{ question }}?

        {% if invalid_replies and error_message %}
          You already created the following output in a previous attempt: {{invalid_replies}}
          However, this doesn't comply with the format requirements from above and triggered this Python exception: {{error_message}}
          Correct the output and try again. Just return the corrected output without any extra explanations.
        {% endif %}
        """

        text_embedder = SentenceTransformersTextEmbedder(model=cfg.EMBEDDINGS_HAYSTACK,
                                                         progress_bar=False)

        retriever = WeaviateEmbeddingRetriever(document_store=document_store, top_k=3)

        prompt_builder = PromptBuilder(template=template)

        pipe = Pipeline(max_loops_allowed=cfg.MAX_LOOPS_ALLOWED_HAYSTACK)
        pipe.add_component("embedder", text_embedder)
        pipe.add_component("retriever", retriever)
        pipe.add_component("prompt_builder", prompt_builder)
        pipe.add_component("llm", generator)
        pipe.add_component("output_validator", output_validator)

        pipe.connect("embedder.embedding", "retriever.query_embedding")
        pipe.connect("retriever", "prompt_builder.documents")
        pipe.connect("prompt_builder", "llm")
        pipe.connect("llm", "output_validator")
        # If a component has more than one output or input, explicitly specify the connections:
        pipe.connect("output_validator.invalid_replies", "prompt_builder.invalid_replies")
        pipe.connect("output_validator.error_message", "prompt_builder.error_message")

        question = query

        response = self.invoke_pipeline_step(
            lambda: pipe.run(
                {
                    "embedder": {"text": question},
                    "prompt_builder": {"question": question, "schema": json_schema}
                }
            ),
            "Running inference pipeline...",
            local)

        end = timeit.default_timer()

        valid_reply = response["output_validator"]["valid_replies"][0]
        valid_json = json.loads(valid_reply)
        print(f"\nJSON response:\n")
        print(valid_json)
        print('\n' + ('=' * 50))

        print(f"Time to retrieve answer: {end - start}")

        return valid_json

    def invoke_pipeline_step(self, task_call, task_description, local):
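        """Execute task_call, showing a spinner while it runs when local,
        or a plain status line otherwise.
        """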
        if local:
            with Progress(
                    SpinnerColumn(),
                    TextColumn("[progress.description]{task.description}"),
                    transient=False,
            ) as progress:
                progress.add_task(description=task_description, total=None)
                ret = task_call()
        else:
            print(task_description)
            ret = task_call()

        return ret
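

# A minimal usage sketch (hypothetical argument values; in the real project the
# caller supplies these through the rag.agents interface):
#
#   pipeline = HaystackPipeline()
#   answer = pipeline.run_pipeline(
#       payload="haystack",
#       query_inputs=["invoice_number", "total"],
#       query_types=["str", "float"],
#       keywords=[],
#       query="retrieve invoice_number and total",
#       file_path="",
#       index_name="Invoices",
#       local=True,
#   )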