NunoCarvalho committed
Commit 4e79dbb · verified · 1 Parent(s): be72f4a

add handler

Files changed (1)
  handler.py +31 -32
handler.py CHANGED
@@ -1,38 +1,37 @@
- from typing import Dict, List
- import torch, base64, io
  from PIL import Image
- import open_clip
-
- device = "cuda" if torch.cuda.is_available() else "cpu"
- model, _, preprocess = open_clip.create_model_and_transforms(
-     'ViT-B-32', pretrained='laion2b_s34b_b79K', device=device
- )
-
- def _embed_image(img_b64: str) -> List[float]:
-     img = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
-     tensor = preprocess(img).unsqueeze(0).to(device)
-     with torch.no_grad():
-         emb = model.encode_image(tensor)
-     return emb.squeeze().cpu().tolist()
-
- def _embed_text(text: str) -> List[float]:
-     tok = open_clip.tokenize([text]).to(device)
-     with torch.no_grad():
-         emb = model.encode_text(tok)
-     return emb.squeeze().cpu().tolist()
-
- # === HF endpoint entrypoint ===
- def preprocess(payload: Dict):
-     return payload
-
- def inference(payload: Dict):
-     if isinstance(payload, str) and payload.startswith("data:image"):
-         b64 = payload.split(",")[-1]
-         return {"vector": _embed_image(b64)}
-     elif isinstance(payload, str):
-         return {"vector": _embed_text(payload)}
-     else:
-         raise ValueError("Unsupported input")
-
- def postprocess(output):  # HF expects this even if you pass it straight through
-     return output
+ import torch, open_clip
  from PIL import Image
+ from typing import Any, Dict
+
+ class EndpointHandler:
+     def __init__(self, model_dir: str):
+         self.device = "cpu"
+         self.model, _, self.preprocess = open_clip.create_model_and_transforms(
+             "ViT-B-32", pretrained="laion2b_s34b_b79K", device=self.device
+         )
+         self.tokenizer = open_clip.get_tokenizer("ViT-B-32")
+
+     def _encode_text(self, text: str):
+         tokens = self.tokenizer([text]).to(self.device)
+         with torch.no_grad():
+             return self.model.encode_text(tokens).cpu().numpy()[0].tolist()
+
+     def _encode_image(self, image: Image.Image):
+         img = self.preprocess(image).unsqueeze(0).to(self.device)
+         with torch.no_grad():
+             return self.model.encode_image(img).cpu().numpy()[0].tolist()
+
+     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+         if "image" in data:
+             if isinstance(data["image"], str):
+                 import requests, io
+                 resp = requests.get(data["image"])
+                 img = Image.open(io.BytesIO(resp.content)).convert("RGB")
+             else:
+                 img = Image.open(data["image"]).convert("RGB")
+             emb = self._encode_image(img)
+         elif "inputs" in data:
+             emb = self._encode_text(data["inputs"])
+         else:
+             raise ValueError("Provide 'image' or 'inputs'.")
+         return {"embedding": emb}
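
For a quick sanity check, a minimal sketch of how Inference Endpoints drives the new EndpointHandler when run locally. The image URL and the model_dir value are placeholders, not part of the commit; it assumes handler.py is in the current directory and that torch, open_clip_torch, Pillow, and requests are installed:

# Local smoke test (sketch; URL and model_dir below are illustrative).
from handler import EndpointHandler

handler = EndpointHandler(model_dir=".")  # model_dir is accepted but unused by this handler

# Text payload: the endpoint convention is a JSON body like {"inputs": "<text>"}.
text_vec = handler({"inputs": "a photo of a cat"})["embedding"]
print(len(text_vec))  # 512 dimensions for ViT-B-32

# Image payload: a string value is treated as a URL and fetched with requests.
img_vec = handler({"image": "https://example.com/cat.jpg"})["embedding"]
print(len(img_vec))

Note that text and image vectors live in the same CLIP space, so comparing them by cosine similarity is meaningful; the handler does not L2-normalize the embeddings, so normalize them first if a downstream consumer expects unit vectors.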