import io

import requests
import torch
import open_clip
from PIL import Image
from typing import Any, Dict


class EndpointHandler:
    """Custom inference handler returning OpenCLIP embeddings for text or image inputs."""

    def __init__(self, model_dir: str):
        self.device = "cpu"
        # Load the OpenCLIP ViT-B-32 model together with its matching
        # preprocessing transforms and tokenizer.
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            "ViT-B-32", pretrained="laion2b_s34b_b79K", device=self.device
        )
        self.model.eval()  # inference mode: disable dropout etc.
        self.tokenizer = open_clip.get_tokenizer("ViT-B-32")

    def _encode_text(self, text: str):
        tokens = self.tokenizer([text]).to(self.device)
        with torch.no_grad():
            return self.model.encode_text(tokens).cpu().numpy()[0].tolist()

    def _encode_image(self, image: Image.Image):
        # preprocess() yields a CHW tensor; unsqueeze adds the batch dimension.
        img = self.preprocess(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            return self.model.encode_image(img).cpu().numpy()[0].tolist()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        if "image" in data:
            payload = data["image"]
            if isinstance(payload, str):
                # A string payload is treated as a URL and downloaded.
                resp = requests.get(payload, timeout=10)
                resp.raise_for_status()
                img = Image.open(io.BytesIO(resp.content)).convert("RGB")
            else:
                # Raw bytes need a file-like wrapper before PIL can open them.
                if isinstance(payload, (bytes, bytearray)):
                    payload = io.BytesIO(payload)
                img = Image.open(payload).convert("RGB")
            emb = self._encode_image(img)
        elif "inputs" in data:
            emb = self._encode_text(data["inputs"])
        else:
            raise ValueError("Provide 'image' or 'inputs'.")
        return {"embedding": emb}
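

# Minimal local usage sketch (an assumption for illustration, not part of the
# Inference Endpoints contract): instantiate the handler directly and call it
# with the same payload shapes the deployed endpoint would receive. The image
# URL below is hypothetical; the 512-dim output is a property of ViT-B-32.
if __name__ == "__main__":
    handler = EndpointHandler(model_dir=".")

    # Text embedding via the "inputs" key.
    text_out = handler({"inputs": "a photo of a cat"})
    print(len(text_out["embedding"]))  # 512

    # Image embedding via the "image" key, passed as a URL string.
    img_out = handler({"image": "https://example.com/cat.jpg"})
    print(len(img_out["embedding"]))  # 512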