NunoCarvalho commited on
Commit
323fbae
·
verified ·
1 Parent(s): be8cf8d

add handler

Browse files
Files changed (1) hide show
  1. handler.py +38 -0
handler.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List
2
+ import torch, base64, io
3
+ from PIL import Image
4
+ import open_clip
5
+
6
+ device = "cuda" if torch.cuda.is_available() else "cpu"
7
+ model, _, preprocess = open_clip.create_model_and_transforms(
8
+ 'ViT-B-32', pretrained='laion2b_s34b_b79K', device=device
9
+ )
10
+
11
+ def _embed_image(img_b64: str) -> List[float]:
12
+ img = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
13
+ tensor = preprocess(img).unsqueeze(0).to(device)
14
+ with torch.no_grad():
15
+ emb = model.encode_image(tensor)
16
+ return emb.squeeze().cpu().tolist()
17
+
18
+ def _embed_text(text: str) -> List[float]:
19
+ tok = open_clip.tokenize([text]).to(device)
20
+ with torch.no_grad():
21
+ emb = model.encode_text(tok)
22
+ return emb.squeeze().cpu().tolist()
23
+
24
+ # === HF endpoint entrypoint ===
25
+ def preprocess(payload: Dict):
26
+ return payload
27
+
28
+ def inference(payload: Dict):
29
+ if isinstance(payload, str) and payload.startswith("data:image"):
30
+ b64 = payload.split(",")[-1]
31
+ return {"vector": _embed_image(b64)}
32
+ elif isinstance(payload, str):
33
+ return {"vector": _embed_text(payload)}
34
+ else:
35
+ raise ValueError("Unsupported input")
36
+
37
+ def postprocess(output): # HF expects this even se passas direto
38
+ return output