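"""
SparrowParse pipeline agent: converts a JSON extraction query into a prompt,
runs inference through sparrow-parse (Hugging Face Space backend), and
validates the model output against the query-derived JSON schema.
"""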
from rag.agents.interface import Pipeline
from sparrow_parse.vllm.inference_factory import InferenceFactory
from sparrow_parse.extractors.vllm_extractor import VLLMExtractor
import timeit
from rich import print
from rich.progress import Progress, SpinnerColumn, TextColumn
from typing import Any, List, Optional
from .sparrow_validator import Validator
from .sparrow_utils import is_valid_json, get_json_keys_as_string
import warnings
import os

warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)


class SparrowParsePipeline(Pipeline):
    def __init__(self):
        pass

    def run_pipeline(self,
                     payload: str,
                     query_inputs: List[str],
                     query_types: List[str],
                     keywords: List[str],
                     query: str,
                     file_path: str,
                     index_name: str,
                     options: Optional[List[str]] = None,
                     group_by_rows: bool = True,
                     update_targets: bool = True,
                     debug: bool = False,
                     local: bool = True) -> Any:
        """Prepare the query, run sparrow-parse inference, and validate the output."""
print(f"\nRunning pipeline with {payload}\n") | |
start = timeit.default_timer() | |
query_all_data = False | |
if query == "*": | |
query_all_data = True | |
query = None | |
else: | |
try: | |
query, query_schema = self.invoke_pipeline_step(lambda: self.prepare_query_and_schema(query, debug), | |
"Preparing query and schema", local) | |
except ValueError as e: | |
raise e | |
llm_output = self.invoke_pipeline_step(lambda: self.execute_query(options, query_all_data, query, file_path, debug), | |
"Executing query", local) | |
validation_result = None | |
if query_all_data is False: | |
validation_result = self.invoke_pipeline_step(lambda: self.validate_result(llm_output, query_all_data, query_schema, debug), | |
"Validating result", local) | |
end = timeit.default_timer() | |
print(f"Time to retrieve answer: {end - start}") | |
return validation_result if validation_result is not None else llm_output | |

    def prepare_query_and_schema(self, query, debug):
        """Turn a JSON query into an extraction prompt; keep the original JSON as the schema."""
        is_query_valid = is_valid_json(query)
        if not is_query_valid:
            raise ValueError("Invalid query. Please provide a valid JSON query.")

        query_keys = get_json_keys_as_string(query)
        query_schema = query
        query = (f"retrieve {query_keys}. return response in JSON format, "
                 f"by strictly following this JSON schema: {query_schema}")

        return query, query_schema
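
    # Worked example for prepare_query_and_schema (hypothetical values; assumes
    # get_json_keys_as_string returns a comma-separated list of the JSON keys):
    #   query      = '{"invoice_number": "str", "total": "str"}'
    #   query_keys = 'invoice_number, total'
    #   returns query = 'retrieve invoice_number, total. return response in JSON
    #                    format, by strictly following this JSON schema: {...}'
    #   and query_schema = the original JSON string, reused later for validation.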

    def execute_query(self, options, query_all_data, query, file_path, debug):
        extractor = VLLMExtractor()

        # The Hugging Face token must be available in the environment,
        # e.g. export HF_TOKEN="hf_..."
        config = {}
        if options and options[0] == 'huggingface':
            config = {
                "method": options[0],  # Could be 'huggingface' or 'local_gpu'
                "hf_space": options[1],
                "hf_token": os.getenv('HF_TOKEN')
            }
        else:
            # Handle other cases if needed
            return "First element is not 'huggingface'"

        # Use the factory to get the correct inference backend for the config
        factory = InferenceFactory(config)
        model_inference_instance = factory.get_inference_instance()

        input_data = [
            {
                "image": file_path,
                "text_input": query
            }
        ]

        # Run inference without knowing which backend implementation is used
        llm_output = extractor.run_inference(model_inference_instance, input_data,
                                             generic_query=query_all_data, debug=debug)

        return llm_output

    def validate_result(self, llm_output, query_all_data, query_schema, debug):
        """Validate LLM output against the schema; return the validation error, or None if valid."""
        validator = Validator(query_schema)

        validation_result = validator.validate_json_against_schema(llm_output, validator.generated_schema)
        if validation_result is not None:
            return validation_result

        if debug:
            print("LLM output is valid according to the schema.")
        return None

    def invoke_pipeline_step(self, task_call, task_description, local):
        """Run a pipeline step, showing a rich spinner when executing locally."""
        if local:
            with Progress(
                SpinnerColumn(),
                TextColumn("[progress.description]{task.description}"),
                transient=False,
            ) as progress:
                progress.add_task(description=task_description, total=None)
                ret = task_call()
        else:
            print(task_description)
            ret = task_call()

        return ret
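

# Minimal usage sketch. The Space name, file path, and query below are
# illustrative placeholders, not values defined by this module; HF_TOKEN
# must be set in the environment for the Hugging Face backend to work.
if __name__ == "__main__":
    pipeline = SparrowParsePipeline()
    result = pipeline.run_pipeline(
        payload="sparrow-parse",
        query_inputs=[],
        query_types=[],
        keywords=[],
        query='{"invoice_number": "str", "total": "str"}',  # hypothetical query
        file_path="data/invoice.png",                       # hypothetical input file
        index_name="",
        options=["huggingface", "user/sparrow-space"],      # hypothetical HF Space
        debug=False,
        local=True)
    print(result)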