import io
from typing import Any, Dict

import requests
import torch
import open_clip
from PIL import Image


class EndpointHandler:
    """Custom handler that returns OpenCLIP embeddings for text or images."""

    def __init__(self, model_dir: str):
        # Use the GPU when one is available, otherwise fall back to CPU
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            "ViT-B-32", pretrained="laion2b_s34b_b79k", device=self.device
        )
        self.model.eval()  # open_clip models are created in train mode
        self.tokenizer = open_clip.get_tokenizer("ViT-B-32")

    def _encode_text(self, text: str):
        tokens = self.tokenizer([text]).to(self.device)
        with torch.no_grad():
            return self.model.encode_text(tokens).cpu().numpy()[0].tolist()

    def _encode_image(self, image: Image.Image):
        # preprocess yields a CHW tensor; unsqueeze adds the batch dimension
        img = self.preprocess(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            return self.model.encode_image(img).cpu().numpy()[0].tolist()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        if "image" in data:
            if isinstance(data["image"], str):
                # A string payload is treated as a URL to fetch
                resp = requests.get(data["image"], timeout=10)
                resp.raise_for_status()
                img = Image.open(io.BytesIO(resp.content)).convert("RGB")
            elif isinstance(data["image"], bytes):
                # Raw bytes, e.g. a binary request body
                img = Image.open(io.BytesIO(data["image"])).convert("RGB")
            else:
                # File-like object or filesystem path
                img = Image.open(data["image"]).convert("RGB")
            emb = self._encode_image(img)
        elif "inputs" in data:
            emb = self._encode_text(data["inputs"])
        else:
            raise ValueError("Provide 'image' or 'inputs'.")
        return {"embedding": emb}
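

# A minimal local smoke test: a sketch assuming the file is run directly
# rather than inside the Inference Endpoints runtime. `model_dir` is unused
# by this handler, and the image URL below is a hypothetical placeholder.
if __name__ == "__main__":
    handler = EndpointHandler(model_dir=".")

    # Text embedding: ViT-B-32 produces a 512-dimensional vector
    out = handler({"inputs": "a photo of a cat"})
    print(len(out["embedding"]))  # -> 512

    # Image embedding from a URL (same output dimensionality)
    # out = handler({"image": "https://example.com/cat.jpg"})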