from mlx_vlm import load, generate
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_image
from sparrow_parse.vllm.inference_base import ModelInference
import os
import json
from rich import print


class MLXInference(ModelInference):
    """
        A class for performing inference using the MLX model.
        Handles image preprocessing, response formatting, and model interaction.
        """

    def __init__(self, model_name):
        """
        Initialize the inference class with the given model name.

        :param model_name: Name of the model to load.
        """
        self.model_name = model_name
        print(f"MLXInference initialized for model: {model_name}")


    @staticmethod
    def _load_model_and_processor(model_name):
        """
        Load the model and processor for inference.

        :param model_name: Name of the model to load.
        :return: Tuple containing the loaded model and processor.
        """
        model, processor = load(model_name)
        print(f"Loaded model: {model_name}")
        return model, processor
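
    # Note (implementation detail of mlx_vlm, not this class): load() accepts
    # a local path or a Hugging Face repo id and returns the model together
    # with its processor/tokenizer.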


    def process_response(self, output_text):
        """
        Process and clean the model's raw output to format as JSON.

        Note: this is a heuristic cleanup. It strips every apostrophe, so
        values that legitimately contain one will be mangled; on any parse
        failure the raw text is returned unchanged.

        :param output_text: Raw output text from the model.
        :return: A formatted JSON string, or the original text if parsing fails.
        """
        try:
            # Remove the list/quote wrapper and the Markdown code fences the
            # model tends to emit around the JSON payload
            cleaned_text = (
                output_text.strip("[]'")
                .replace("```json\n", "")
                .replace("\n```", "")
                .replace("'", "")
            )
            formatted_json = json.loads(cleaned_text)
            return json.dumps(formatted_json, indent=2)
        except json.JSONDecodeError as e:
            print(f"Failed to parse JSON in MLX inference backend: {e}")
            return output_text
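
    # Illustrative example (hypothetical raw model output):
    #     raw = "['```json\n{\"invoice_no\": \"123\"}\n```']"
    #     self.process_response(raw)
    # The [''] wrapper and the fences are stripped, then the JSON is
    # pretty-printed:
    #     {
    #       "invoice_no": "123"
    #     }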


    def load_image_data(self, image_filepath, max_width=1250, max_height=1750):
        """
        Load and resize image while maintaining its aspect ratio.

        :param image_filepath: Path to the image file.
        :param max_width: Maximum allowed width of the image.
        :param max_height: Maximum allowed height of the image.
        :return: Tuple containing the image object and its (possibly reduced) dimensions.
        """
        image = load_image(image_filepath)  # load_image is imported from mlx_vlm.utils
        width, height = image.size

        # Downscale only when the image exceeds either bound. The min() pair
        # is equivalent to scaling both sides by
        # min(max_width / width, max_height / height), so the aspect ratio
        # is preserved.
        if width > max_width or height > max_height:
            aspect_ratio = width / height
            new_width = min(max_width, int(max_height * aspect_ratio))
            new_height = min(max_height, int(max_width / aspect_ratio))
            return image, new_width, new_height

        return image, width, height
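
    # Worked examples (assuming the 1250x1750 defaults):
    #     2500x3500 -> 1250x1750 (scaled by 0.5)
    #     3000x1500 -> 1250x625  (scaled by ~0.42)
    #     1000x1200 -> returned unchanged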


    def inference(self, input_data, mode=None):
        """
        Perform inference on input data using the specified model.

        :param input_data: A list of dictionaries containing image file paths and text inputs.
        :param mode: Optional mode for inference ("static" for simple JSON output).
        :return: List of processed model responses.
        """
        if mode == "static":
            return [self.get_simple_json()]

        # Load the model and processor
        model, processor = self._load_model_and_processor(self.model_name)
        config = model.config

        # Prepare absolute file paths
        file_paths = self._extract_file_paths(input_data)

        results = []
        for file_path in file_paths:
            image, width, height = self.load_image_data(file_path)

            # Prepare messages for the chat model; the first entry's
            # text_input is reused as the query for every file
            messages = [
                {"role": "system", "content": "You are an expert at extracting structured text from image documents."},
                {"role": "user", "content": input_data[0]["text_input"]},
            ]

            # Build the model-specific prompt and run generation;
            # apply_chat_template is imported from mlx_vlm.prompt_utils
            prompt = apply_chat_template(processor, config, messages)
            response = generate(
                model,
                processor,
                prompt,
                image,
                resize_shape=(width, height),
                max_tokens=4000,
                temperature=0.0,  # greedy decoding for reproducible extraction
                verbose=False
            )
            results.append(self.process_response(response))

            print("Inference completed successfully for: ", file_path)

        return results

    @staticmethod
    def _extract_file_paths(input_data):
        """
        Extract and resolve absolute file paths from input data.

        :param input_data: List of dictionaries, each holding a list of
            image file paths under the "file_path" key.
        :return: Flat list of absolute file paths.
        """
        return [
            os.path.abspath(file_path)
            for data in input_data
            for file_path in data.get("file_path", [])
        ]
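

# Minimal usage sketch (hypothetical model name and file paths): runs a
# single-image extraction and prints the parsed JSON for each file.
if __name__ == "__main__":
    inference = MLXInference("mlx-community/Qwen2-VL-7B-Instruct-4bit")
    input_data = [
        {
            "file_path": ["/path/to/invoice.png"],
            "text_input": "retrieve invoice_no and total, return response in JSON format",
        }
    ]
    for result in inference.inference(input_data):
        print(result)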