import json
import os

from mlx_vlm import load, generate
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_image
from rich import print

from sparrow_parse.vllm.inference_base import ModelInference


class MLXInference(ModelInference):
    """
    A class for performing inference using the MLX model.
    Handles image preprocessing, response formatting, and model interaction.
    """

    def __init__(self, model_name):
        """
        Initialize the inference class with the given model name.
        :param model_name: Name of the model to load.
        """
        self.model_name = model_name
        print(f"MLXInference initialized for model: {model_name}")

    @staticmethod
    def _load_model_and_processor(model_name):
        """
        Load the model and processor for inference.
        :param model_name: Name of the model to load.
        :return: Tuple containing the loaded model and processor.
        """
        model, processor = load(model_name)
        print(f"Loaded model: {model_name}")
        return model, processor

    def process_response(self, output_text):
        """
        Process and clean the model's raw output to format it as JSON.
        :param output_text: Raw output text from the model.
        :return: A formatted JSON string, or the original text in case of errors.
        """
        try:
            cleaned_text = (
                output_text.strip("[]'")
                .replace("```json\n", "")
                .replace("\n```", "")
                .replace("'", "")
            )
            formatted_json = json.loads(cleaned_text)
            return json.dumps(formatted_json, indent=2)
        except json.JSONDecodeError as e:
            print(f"Failed to parse JSON in MLX inference backend: {e}")
            return output_text

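    # Illustrative note (assumed example, not from the original source): given the
    # cleaning steps above, a raw response wrapped in a Markdown fence such as
    #   ```json\n{"total_amount": "100.00"}\n```
    # is returned as an indented JSON string, while any text that still fails
    # json.loads() is passed back unchanged.
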
    def load_image_data(self, image_filepath, max_width=1250, max_height=1750):
        """
        Load and resize an image while maintaining its aspect ratio.
        :param image_filepath: Path to the image file.
        :param max_width: Maximum allowed width of the image.
        :param max_height: Maximum allowed height of the image.
        :return: Tuple containing the image object and its new dimensions.
        """
        image = load_image(image_filepath)  # load_image is imported from mlx_vlm.utils
        width, height = image.size

        # Calculate new dimensions while maintaining the aspect ratio
        if width > max_width or height > max_height:
            aspect_ratio = width / height
            new_width = min(max_width, int(max_height * aspect_ratio))
            new_height = min(max_height, int(max_width / aspect_ratio))
            return image, new_width, new_height

        return image, width, height

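    # Illustrative note (assumed example, not from the original source): with the
    # default limits, a 2500x1500 page (aspect ratio ~1.67) exceeds max_width, so the
    # code above returns target dimensions of 1250x750; images already within the
    # limits keep their original size.
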
    def inference(self, input_data, mode=None):
        """
        Perform inference on input data using the specified model.
        :param input_data: A list of dictionaries containing image file paths and text inputs.
        :param mode: Optional mode for inference ("static" for simple JSON output).
        :return: List of processed model responses.
        """
        if mode == "static":
            return [self.get_simple_json()]

        # Load the model and processor
        model, processor = self._load_model_and_processor(self.model_name)
        config = model.config

        # Prepare absolute file paths
        file_paths = self._extract_file_paths(input_data)

        results = []
        for file_path in file_paths:
            image, width, height = self.load_image_data(file_path)

            # Prepare messages for the chat model; the same text query is applied to every document
            messages = [
                {"role": "system", "content": "You are an expert at extracting structured text from image documents."},
                {"role": "user", "content": input_data[0]["text_input"]},
            ]

            # Generate and process the response
            prompt = apply_chat_template(processor, config, messages)  # imported from mlx_vlm.prompt_utils
            response = generate(
                model,
                processor,
                prompt,
                image,
                resize_shape=(width, height),
                max_tokens=4000,
                temperature=0.0,
                verbose=False
            )
            results.append(self.process_response(response))
            print("Inference completed successfully for:", file_path)

        return results

    @staticmethod
    def _extract_file_paths(input_data):
        """
        Extract and resolve absolute file paths from input data.
        :param input_data: List of dictionaries containing image file paths.
        :return: List of absolute file paths.
        """
        return [
            os.path.abspath(file_path)
            for data in input_data
            for file_path in data.get("file_path", [])
        ]
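

# Example usage (illustrative sketch, not part of the original module): the model
# name and file path below are placeholders and would need to be replaced with a
# real MLX-community vision-language model and a local document image.
if __name__ == "__main__":
    inference = MLXInference("mlx-community/Qwen2-VL-7B-Instruct-4bit")
    input_data = [
        {
            "file_path": ["/path/to/invoice_page.png"],
            "text_input": "retrieve invoice_number and total_amount, return the response in JSON format",
        }
    ]
    results = inference.inference(input_data)
    print(results[0])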