import os
import timeit
import warnings
from typing import Any, List, Optional

from rich import print
from rich.progress import Progress, SpinnerColumn, TextColumn

from rag.agents.interface import Pipeline
from sparrow_parse.extractors.vllm_extractor import VLLMExtractor
from sparrow_parse.vllm.inference_factory import InferenceFactory

from .sparrow_utils import is_valid_json, get_json_keys_as_string
from .sparrow_validator import Validator

warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)


class SparrowParsePipeline(Pipeline):
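    """Sparrow Parse pipeline: extracts structured data from a document image
    with a vision LLM and validates the output against the supplied JSON schema."""
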
    def __init__(self):
        pass

    def run_pipeline(self,
                     payload: str,
                     query_inputs: List[str],
                     query_types: List[str],
                     keywords: List[str],
                     query: str,
                     file_path: str,
                     index_name: str,
                     options: Optional[List[str]] = None,
                     group_by_rows: bool = True,
                     update_targets: bool = True,
                     debug: bool = False,
                     local: bool = True) -> Any:
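        """Run the Sparrow Parse pipeline: prepare the query and schema, execute
        vision LLM inference on the document, and validate the structured output."""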
print(f"\nRunning pipeline with {payload}\n")
start = timeit.default_timer()
query_all_data = False
if query == "*":
query_all_data = True
query = None
        else:
            # A concrete JSON query was provided; build the LLM prompt and keep the schema
            query, query_schema = self.invoke_pipeline_step(
                lambda: self.prepare_query_and_schema(query, debug),
                "Preparing query and schema", local)
        llm_output = self.invoke_pipeline_step(
            lambda: self.execute_query(options, query_all_data, query, file_path, debug),
            "Executing query", local)

        validation_result = None
        if not query_all_data:
            validation_result = self.invoke_pipeline_step(
                lambda: self.validate_result(llm_output, query_all_data, query_schema, debug),
                "Validating result", local)
        end = timeit.default_timer()
        print(f"Time to retrieve answer: {end - start:.2f} sec")

        return validation_result if validation_result is not None else llm_output

def prepare_query_and_schema(self, query, debug):
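        """Turn a JSON query into an LLM prompt plus the schema to validate against.

        Illustrative example (assuming get_json_keys_as_string returns a
        comma-separated key list): the query '{"invoice_number": "str"}' becomes
        the prompt 'retrieve invoice_number. return response in JSON format, by
        strictly following this JSON schema: {"invoice_number": "str"}'.
        """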
is_query_valid = is_valid_json(query)
if not is_query_valid:
raise ValueError("Invalid query. Please provide a valid JSON query.")
query_keys = get_json_keys_as_string(query)
query_schema = query
query = "retrieve " + query_keys
query = query + ". return response in JSON format, by strictly following this JSON schema: " + query_schema
return query, query_schema
def execute_query(self, options, query_all_data, query, file_path, debug):
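        """Run vision LLM inference on the document via the backend selected in options.

        Only the Hugging Face backend is wired up here: options[0] is the method
        name and options[1] the Hugging Face space to call.
        """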
extractor = VLLMExtractor()
        # Requires a Hugging Face token in the environment, e.g. export HF_TOKEN="hf_"
        config = {}
        if options and options[0] == 'huggingface':
            config = {
                "method": options[0],  # Could be 'huggingface' or 'local_gpu'
                "hf_space": options[1],
                "hf_token": os.getenv('HF_TOKEN')
            }
        else:
            # Other backends are not handled here
            return "Inference backend is not supported. Expected options[0] to be 'huggingface'."
# Use the factory to get the correct instance
factory = InferenceFactory(config)
model_inference_instance = factory.get_inference_instance()
input_data = [
{
"image": file_path,
"text_input": query
}
]
# Now you can run inference without knowing which implementation is used
llm_output = extractor.run_inference(model_inference_instance, input_data, generic_query=query_all_data,
debug=debug)
        return llm_output

def validate_result(self, llm_output, query_all_data, query_schema, debug):
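        """Validate the LLM output against the query schema.

        Returns the validator's result when validation fails, or None when the
        output conforms to the schema.
        """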
validator = Validator(query_schema)
validation_result = validator.validate_json_against_schema(llm_output, validator.generated_schema)
        if validation_result is not None:
            # Validation failed; return the details to the caller
            return validation_result
        if debug:
            print("LLM output is valid according to the schema.")

def invoke_pipeline_step(self, task_call, task_description, local):
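        """Execute task_call, wrapping it in a Rich spinner when running locally;
        otherwise just print the task description before running it."""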
if local:
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
transient=False,
) as progress:
progress.add_task(description=task_description, total=None)
ret = task_call()
else:
print(task_description)
ret = task_call()
return ret
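

# Minimal usage sketch (illustrative only; every value below is a placeholder,
# not something shipped with Sparrow Parse). Shown as a comment because this
# module uses relative imports and is meant to be driven by the surrounding
# RAG agent rather than executed directly:
#
#     pipeline = SparrowParsePipeline()
#     answer = pipeline.run_pipeline(
#         payload="sparrow-parse",
#         query_inputs=[],
#         query_types=[],
#         keywords=[],
#         query='{"invoice_number": "str", "total": "str"}',
#         file_path="data/invoice.png",
#         index_name="",
#         options=["huggingface", "user/sparrow-parse-space"],
#         debug=True,
#         local=True)
#     print(answer)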