Spaces:
Running
Running
File size: 4,868 Bytes
42cd5f6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
from rag.agents.interface import Pipeline
from sparrow_parse.vllm.inference_factory import InferenceFactory
from sparrow_parse.extractors.vllm_extractor import VLLMExtractor
import timeit
from rich import print
from rich.progress import Progress, SpinnerColumn, TextColumn
from typing import Any, List
from .sparrow_validator import Validator
from .sparrow_utils import is_valid_json, get_json_keys_as_string
import warnings
import os
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
class SparrowParsePipeline(Pipeline):
    """Pipeline that extracts structured data from a document image using the
    sparrow-parse vLLM backend.

    The caller supplies a JSON-shaped ``query`` describing the fields to
    extract (or ``"*"`` to extract everything); the LLM output is then
    validated against a JSON schema derived from that query.
    """

    def __init__(self):
        pass

    def run_pipeline(self,
                     payload: str,
                     query_inputs: List[str],
                     query_types: List[str],
                     keywords: List[str],
                     query: str,
                     file_path: str,
                     index_name: str,
                     options: List[str] = None,
                     group_by_rows: bool = True,
                     update_targets: bool = True,
                     debug: bool = False,
                     local: bool = True) -> Any:
        """Execute the full extraction pipeline for one document.

        :param payload: label used only for progress logging.
        :param query: JSON string listing the fields to extract, or ``"*"``
            to request all data (skips schema validation).
        :param file_path: path to the input image handed to the extractor.
        :param options: inference backend selection, e.g.
            ``['huggingface', '<hf_space>']``.
        :param local: when True, wrap each step in a rich spinner.
        :raises ValueError: if ``query`` is not ``"*"`` and not valid JSON.
        :return: the validation error message if validation failed,
            otherwise the raw LLM output.

        NOTE(review): query_inputs, query_types, keywords, index_name,
        group_by_rows and update_targets are unused here — presumably part of
        the shared ``Pipeline`` interface; kept for signature compatibility.
        """
        print(f"\nRunning pipeline with {payload}\n")

        start = timeit.default_timer()

        query_all_data = False
        query_schema = None  # only populated for non-wildcard queries
        if query == "*":
            # Wildcard: ask the extractor for everything, no schema to enforce.
            query_all_data = True
            query = None
        else:
            # Propagates ValueError for an invalid (non-JSON) query; the
            # previous `except ValueError as e: raise e` was a no-op re-raise.
            query, query_schema = self.invoke_pipeline_step(
                lambda: self.prepare_query_and_schema(query, debug),
                "Preparing query and schema", local)

        llm_output = self.invoke_pipeline_step(
            lambda: self.execute_query(options, query_all_data, query, file_path, debug),
            "Executing query", local)

        validation_result = None
        if not query_all_data:
            # Validation only makes sense when a schema was derived above.
            validation_result = self.invoke_pipeline_step(
                lambda: self.validate_result(llm_output, query_all_data, query_schema, debug),
                "Validating result", local)

        end = timeit.default_timer()
        print(f"Time to retrieve answer: {end - start}")

        # validate_result returns None on success, an error message otherwise.
        return validation_result if validation_result is not None else llm_output

    def prepare_query_and_schema(self, query, debug):
        """Turn a JSON field-description query into an LLM prompt plus schema.

        :param query: JSON string whose keys name the fields to extract.
        :raises ValueError: if ``query`` is not valid JSON.
        :return: tuple of (prompt string, original JSON schema string).
        """
        if not is_valid_json(query):
            raise ValueError("Invalid query. Please provide a valid JSON query.")

        query_keys = get_json_keys_as_string(query)
        query_schema = query
        # Build the instruction prompt: list the keys, then pin the output
        # format to the caller-supplied JSON schema.
        query = "retrieve " + query_keys
        query = query + ". return response in JSON format, by strictly following this JSON schema: " + query_schema

        return query, query_schema

    def execute_query(self, options, query_all_data, query, file_path, debug):
        """Run vLLM inference on ``file_path`` with the prepared ``query``.

        :param options: backend selection; ``options[0]`` must be
            ``'huggingface'`` and ``options[1]`` the HF space name.
        :return: the LLM output, or an error string for unsupported/missing
            backend options.
        """
        extractor = VLLMExtractor()

        # Guard: options defaults to None in run_pipeline; indexing it
        # unchecked previously raised TypeError.
        if not options:
            return "First element is not 'huggingface'"

        # export HF_TOKEN="hf_"
        if options[0] == 'huggingface':
            config = {
                "method": options[0],  # Could be 'huggingface' or 'local_gpu'
                "hf_space": options[1],
                "hf_token": os.getenv('HF_TOKEN')
            }
        else:
            # Handle other cases if needed
            return "First element is not 'huggingface'"

        # Use the factory to get the correct instance
        factory = InferenceFactory(config)
        model_inference_instance = factory.get_inference_instance()

        input_data = [
            {
                "image": file_path,
                "text_input": query
            }
        ]

        # Now you can run inference without knowing which implementation is used
        llm_output = extractor.run_inference(model_inference_instance, input_data,
                                             generic_query=query_all_data,
                                             debug=debug)

        return llm_output

    def validate_result(self, llm_output, query_all_data, query_schema, debug):
        """Validate LLM output against the schema derived from the query.

        :return: the validator's error result when validation fails,
            ``None`` when the output conforms to the schema.
        """
        validator = Validator(query_schema)

        validation_result = validator.validate_json_against_schema(llm_output, validator.generated_schema)
        if validation_result is not None:
            return validation_result
        else:
            if debug:
                print("LLM output is valid according to the schema.")

    def invoke_pipeline_step(self, task_call, task_description, local):
        """Run ``task_call`` with a spinner (local) or a plain log line.

        :param task_call: zero-argument callable performing the step.
        :param task_description: label shown to the user.
        :param local: True for an interactive rich spinner, False for print.
        :return: whatever ``task_call`` returns.
        """
        if local:
            with Progress(
                SpinnerColumn(),
                TextColumn("[progress.description]{task.description}"),
                transient=False,
            ) as progress:
                progress.add_task(description=task_description, total=None)
                ret = task_call()
        else:
            print(task_description)
            ret = task_call()

        return ret