# clip-img-encoder / handler.py
import io
from typing import Any, Dict

import open_clip
import requests
import torch
from PIL import Image


class EndpointHandler:
    def __init__(self, model_dir: str):
        # CPU-only inference; model_dir is unused because open_clip downloads
        # the pretrained LAION-2B ViT-B/32 weights at startup.
        self.device = "cpu"
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            "ViT-B-32", pretrained="laion2b_s34b_b79K", device=self.device
        )
        self.model.eval()  # open_clip models are created in train mode
        self.tokenizer = open_clip.get_tokenizer("ViT-B-32")

    def _encode_text(self, text: str):
        # Tokenize and return the text embedding as a plain Python list.
        tokens = self.tokenizer([text]).to(self.device)
        with torch.no_grad():
            return self.model.encode_text(tokens).cpu().numpy()[0].tolist()

    def _encode_image(self, image: Image.Image):
        # Preprocess, add a batch dimension, and return the image embedding.
        img = self.preprocess(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            return self.model.encode_image(img).cpu().numpy()[0].tolist()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        if "image" in data:
            if isinstance(data["image"], str):
                # URL: download the image before decoding it.
                resp = requests.get(data["image"], timeout=30)
                resp.raise_for_status()
                img = Image.open(io.BytesIO(resp.content)).convert("RGB")
            elif isinstance(data["image"], (bytes, bytearray)):
                # Raw bytes: wrap them in a buffer so PIL can decode them.
                img = Image.open(io.BytesIO(data["image"])).convert("RGB")
            else:
                # Assume a file path or file-like object.
                img = Image.open(data["image"]).convert("RGB")
            emb = self._encode_image(img)
        elif "inputs" in data:
            emb = self._encode_text(data["inputs"])
        else:
            raise ValueError("Provide 'image' or 'inputs'.")
        return {"embedding": emb}
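

# A minimal local smoke test, assuming the handler is called directly with a
# dict payload; on a deployed Inference Endpoint the serving toolkit constructs
# EndpointHandler and passes the deserialized request itself. The URL in the
# commented-out image call is a placeholder, not part of the original handler.
if __name__ == "__main__":
    handler = EndpointHandler(model_dir=".")

    # Text embedding: the payload carries the text under the "inputs" key.
    text_out = handler({"inputs": "a photo of a cat"})
    print(len(text_out["embedding"]))  # 512-dimensional for ViT-B-32

    # Image embedding: the payload carries a URL, raw bytes, or a file path
    # under the "image" key.
    # img_out = handler({"image": "https://example.com/cat.jpg"})
    # print(len(img_out["embedding"]))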