import io
from typing import Any, Dict, List

import open_clip
import torch
from PIL import Image


class EndpointHandler:
    """Inference handler returning a CLIP ViT-B-32 embedding for an image or a text.

    ``__call__`` takes a payload dict containing either ``"image"`` or
    ``"inputs"`` and returns ``{"embedding": [...]}``, where the embedding is a
    plain Python list of floats (the first row of the model's output).
    """

    # Seconds to wait for a remote image before giving up. Without a timeout,
    # ``requests.get`` can block the worker indefinitely on a slow host.
    request_timeout: float = 10.0

    def __init__(self, model_dir: str):
        """Load the CLIP model, its eval-time preprocessing and its tokenizer.

        Args:
            model_dir: Accepted to satisfy the handler interface but not used
                here. Weights come from open_clip's ``"laion2b_s34b_b79K"``
                pretrained tag (presumably downloaded/cached by open_clip -
                TODO confirm this is acceptable for the deployment).
        """
        self.device = "cpu"
        # create_model_and_transforms returns (model, train_transform,
        # eval_transform). The augmenting training transform is discarded.
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            "ViT-B-32", pretrained="laion2b_s34b_b79K", device=self.device
        )
        # open_clip hands the model back in train mode. Inference must run in
        # eval mode so dropout / stochastic-depth style layers are disabled.
        self.model.eval()
        self.tokenizer = open_clip.get_tokenizer("ViT-B-32")

    def _encode_text(self, text: str) -> List[float]:
        """Return the CLIP text embedding of a single string as a list of floats."""
        tokens = self.tokenizer([text]).to(self.device)
        with torch.no_grad():
            return self.model.encode_text(tokens).cpu().numpy()[0].tolist()

    def _encode_image(self, image: Image.Image) -> List[float]:
        """Return the CLIP image embedding of one PIL image as a list of floats."""
        # The preprocess transform yields a single (C, H, W) tensor; add a batch dim.
        img = self.preprocess(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            return self.model.encode_image(img).cpu().numpy()[0].tolist()

    def _load_image(self, source: Any) -> Image.Image:
        """Turn the ``"image"`` payload value into an RGB PIL image.

        Accepted forms:
            * ``str`` - treated as a URL and fetched over HTTP(S);
            * ``bytes`` / ``bytearray`` - encoded image data;
            * ``PIL.Image.Image`` - used as-is (converted to RGB);
            * anything else ``PIL.Image.open`` accepts (path-like or file object).

        Raises:
            requests.HTTPError: the URL responded with a non-2xx status.
            requests.Timeout: the URL did not respond within ``request_timeout``.
            PIL.UnidentifiedImageError: the data is not a decodable image.
        """
        if isinstance(source, Image.Image):
            return source.convert("RGB")
        if isinstance(source, str):
            # Imported lazily so text-only use does not require requests.
            import requests

            # NOTE(review): this fetches an arbitrary caller-supplied URL. On a
            # publicly exposed endpoint that is an SSRF vector - consider an
            # allow-list of hosts/schemes.
            resp = requests.get(source, timeout=self.request_timeout)
            # Fail with the real HTTP error instead of handing an HTML error
            # page to PIL.
            resp.raise_for_status()
            source = io.BytesIO(resp.content)
        elif isinstance(source, (bytes, bytearray)):
            source = io.BytesIO(source)
        # convert() returns a new, fully loaded image, so the opened one can be
        # closed. Pillow only closes file handles it opened itself.
        with Image.open(source) as img:
            return img.convert("RGB")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Embed the image or text in ``data``.

        Args:
            data: Payload dict. If ``"image"`` is present it takes precedence
                (see ``_load_image`` for accepted forms). Otherwise ``"inputs"``
                is embedded as text.

        Returns:
            ``{"embedding": list_of_floats}``.

        Raises:
            ValueError: neither ``"image"`` nor ``"inputs"`` is present.
        """
        if "image" in data:
            emb = self._encode_image(self._load_image(data["image"]))
        elif "inputs" in data:
            emb = self._encode_text(data["inputs"])
        else:
            raise ValueError("Provide 'image' or 'inputs'.")
        return {"embedding": emb}