benstaf commited on
Commit
23dd76e
·
verified ·
1 Parent(s): c5561b1

Added batch processing

Browse files
Files changed (1) hide show
  1. app.py +39 -0
app.py CHANGED
@@ -7,6 +7,7 @@ from PIL import Image
7
  import torch
8
  import torch.nn.functional as F
9
  import io
 
10
 
11
  app = FastAPI()
12
 
@@ -42,3 +43,41 @@ async def classify_gender(image: UploadFile = File(...)):
42
  "most_likely": labels[max_idx],
43
  "confidence": round(probs[max_idx], 3)
44
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  import torch
8
  import torch.nn.functional as F
9
  import io
10
+ from typing import List
11
 
12
  app = FastAPI()
13
 
 
43
  "most_likely": labels[max_idx],
44
  "confidence": round(probs[max_idx], 3)
45
  }
46
+
47
+
48
+
49
+
50
+ @app.post("/classify_batch/")
51
+ async def classify_gender_batch(images: List[UploadFile] = File(...)):
52
+ pil_images = []
53
+ for image in images:
54
+ contents = await image.read()
55
+ try:
56
+ img = Image.open(io.BytesIO(contents)).convert("RGB")
57
+ pil_images.append(img)
58
+ except Exception:
59
+ return {"error": f"Invalid image file: {image.filename}"}
60
+
61
+ # Batch process
62
+ inputs = processor(images=pil_images, return_tensors="pt")
63
+
64
+ with torch.no_grad():
65
+ outputs = model(**inputs)
66
+ logits = outputs.logits
67
+ probs = F.softmax(logits, dim=1).tolist() # shape: [batch_size, 2]
68
+
69
+ labels = ["Female ♀", "Male ♂"]
70
+
71
+ results = []
72
+ for p in probs:
73
+ predictions = {labels[i]: round(p[i], 3) for i in range(len(p))}
74
+ max_idx = p.index(max(p))
75
+ results.append({
76
+ "predictions": predictions,
77
+ "most_likely": labels[max_idx],
78
+ "confidence": round(p[max_idx], 3)
79
+ })
80
+
81
+ return {"results": results}
82
+
83
+