File size: 3,317 Bytes
657c17b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c1aa475
657c17b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer
from io import BytesIO
import base64
from huggingface_hub import login
from huggingface_hub import login
import os

class EndpointHandler:
    """Inference-endpoint handler that answers a question about an image with MiniCPM-o 2.6.

    The tokenizer and model are loaded once, at construction time. A request is
    read as ``{"inputs": {"image": <base64 str>, "question": <str>, "stream": <bool>}}``.
    """

    # Hub repo the tokenizer/model are loaded from. Subclasses may override it.
    # NOTE: the ``model_dir`` passed to ``__init__`` is not used for loading.
    MODEL_ID = "openbmb/MiniCPM-o-2_6"

    def __init__(self, model_dir=None):
        """Create the handler and load the tokenizer and model.

        Args:
            model_dir: Accepted to match the handler constructor signature the
                endpoint runtime calls; currently ignored — weights are always
                loaded from ``MODEL_ID``.

        Raises:
            RuntimeError: if the tokenizer or model fails to load.
        """
        print("[Init] Initializing EndpointHandler...")
        self.load_model()

    def load_model(self):
        """Load ``self.tokenizer`` and ``self.model`` from ``MODEL_ID`` onto CUDA.

        If the ``HF_TOKEN`` environment variable is set, logs in to the Hugging
        Face Hub with it before downloading.

        Raises:
            RuntimeError: wrapping whatever exception the load raised.
        """
        hf_token = os.getenv("HF_TOKEN")
        model_path = self.MODEL_ID  # use model repo name directly

        if hf_token:
            print("[Auth] Logging into Hugging Face Hub with token...")
            login(token=hf_token)

        print(f"[Model Load] Loading model from: {model_path}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            self.model = AutoModel.from_pretrained(
                model_path,
                trust_remote_code=True,
                attn_implementation="sdpa",
                torch_dtype="auto",  # safer on Spaces
                # Only the vision part is requested; audio/TTS components are
                # presumably skipped by the remote code (per the flag names).
                init_vision=True,
                init_audio=False,
                init_tts=False,
            ).eval().cuda()
        except Exception as e:
            print(f"[Model Load Error] {e}")
            raise RuntimeError(f"Failed to load model: {e}") from e
        print("[Model Load] Model successfully loaded and moved to CUDA.")

    def load_image(self, image_base64):
        """Decode a base64-encoded image into an RGB ``PIL.Image``.

        A ``data:<mime>;base64,<payload>`` URL is also accepted; only the part
        after the first comma is decoded.

        Raises:
            ValueError: if the payload cannot be decoded or opened as an image.
        """
        if isinstance(image_base64, str) and image_base64.startswith("data:"):
            # b64decode silently drops non-alphabet characters, so a data-URL
            # header would otherwise be decoded as (misaligned) payload bytes.
            _, _, image_base64 = image_base64.partition(",")
        try:
            print("[Image Load] Decoding base64 image...")
            image_bytes = base64.b64decode(image_base64)
            image = Image.open(BytesIO(image_bytes)).convert("RGB")
        except Exception as e:
            print(f"[Image Load Error] {e}")
            raise ValueError(f"Failed to open image from base64 string: {e}") from e
        print("[Image Load] Image successfully decoded and converted to RGB.")
        return image

    def predict(self, request):
        """Answer ``inputs["question"]`` about ``inputs["image"]``.

        Args:
            request: Dict whose ``"inputs"`` entry holds ``image`` (base64 str),
                ``question`` (str) and optional ``stream`` (truthy to stream).
                Anything that is not shaped like that is treated as missing
                image/question.

        Returns:
            If ``stream`` is falsy: ``{"output": text}`` on success or
            ``{"error": message}`` on failure.
            If ``stream`` is truthy: an iterator of ``{"output": chunk}`` dicts;
            on failure it yields a single ``{"error": message}`` dict instead.
        """
        inputs = request.get("inputs") if isinstance(request, dict) else None
        if not isinstance(inputs, dict):
            inputs = {}
        image_base64 = inputs.get("image")
        question = inputs.get("question")
        stream = bool(inputs.get("stream", False))
        # Log keys only: the payload carries a full base64 image.
        print(f"[Predict] Received request with input keys: {sorted(map(str, inputs))}")

        if not image_base64 or not question:
            print("[Predict Error] Missing 'image' or 'question' in the request.")
            error = {"error": "Missing 'image' or 'question' in inputs."}
            return iter([error]) if stream else error

        # NOTE: ``predict`` itself must not contain ``yield``; otherwise every
        # ``return {...}`` above would be swallowed into a generator.
        if stream:
            return self._predict_stream(image_base64, question)
        return self._predict_once(image_base64, question)

    def _predict_once(self, image_base64, question):
        """Run one blocking chat call; return ``{"output": ...}`` or ``{"error": ...}``."""
        try:
            res = self._chat(image_base64, question, stream=False)
            generated_text = "".join(res)
        except Exception as e:
            print(f"[Predict Error] {e}")
            return {"error": str(e)}
        print("[Predict] Inference complete.")
        return {"output": generated_text}

    def _predict_stream(self, image_base64, question):
        """Yield ``{"output": chunk}`` per chunk; on failure yield ``{"error": ...}`` and stop."""
        try:
            for new_text in self._chat(image_base64, question, stream=True):
                yield {"output": new_text}
        except Exception as e:
            print(f"[Predict Error] {e}")
            yield {"error": str(e)}
        else:
            print("[Predict] Inference complete.")

    def _chat(self, image_base64, question, stream):
        """Decode the image and send a single user turn to ``self.model.chat``."""
        image = self.load_image(image_base64)
        # The image is passed inside the message content, so the separate
        # ``image`` argument of ``chat`` is left as None.
        msgs = [{"role": "user", "content": [image, question]}]

        print(f"[Predict] Asking model with question: {question}")
        print("[Predict] Starting chat inference...")
        return self.model.chat(
            image=None,
            msgs=msgs,
            tokenizer=self.tokenizer,
            sampling=True,
            stream=stream,
        )

    def __call__(self, data):
        """Endpoint entry point; delegates to :meth:`predict`."""
        print("[__call__] Invoked handler with data.")
        return self.predict(data)