"""FastAPI service that classifies the gender shown in uploaded images.

The ``prithivMLmods/Gender-Classifier-Mini`` SigLIP image-classification model
and its image processor are loaded once at import time and shared by all
requests.
"""

import os

# Must be set before ``transformers`` is imported: the Hugging Face libraries
# read HF_HOME at import time to decide where the Hub cache lives.
os.environ["HF_HOME"] = "/tmp/huggingface"

import asyncio  # noqa: E402
import io  # noqa: E402
from typing import Dict, List  # noqa: E402

import torch  # noqa: E402
import torch.nn.functional as F  # noqa: E402
from fastapi import FastAPI, File, UploadFile  # noqa: E402
from PIL import Image  # noqa: E402
from transformers import AutoImageProcessor, SiglipForImageClassification  # noqa: E402

app = FastAPI()

model_name = "prithivMLmods/Gender-Classifier-Mini"
model = SiglipForImageClassification.from_pretrained(model_name)
processor = AutoImageProcessor.from_pretrained(model_name)

# Display labels, indexed by the model's output class index.
# NOTE(review): assumes class 0 is female and class 1 is male for this
# checkpoint; confirm against model.config.id2label.
LABELS = ["Female ♀", "Male ♂"]


def _load_rgb(contents: bytes) -> Image.Image:
    """Decode raw upload bytes into an RGB PIL image.

    Propagates whatever PIL raises for undecodable data.
    """
    with Image.open(io.BytesIO(contents)) as src:
        # convert() returns a new, loaded image (a copy even when the source is
        # already RGB), so the result stays valid after ``src`` is closed.
        return src.convert("RGB")


def _predict_probs(images: List[Image.Image]) -> List[List[float]]:
    """Run the classifier on a batch of images.

    Returns one list of class probabilities per input image, in input order.
    """
    inputs = processor(images=images, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    # Softmax over the class dimension; one row per image.
    return F.softmax(logits, dim=1).tolist()


async def _run_inference(images: List[Image.Image]) -> List[List[float]]:
    """Run ``_predict_probs`` in the default thread pool.

    Model inference is CPU-bound. Calling it directly inside an ``async def``
    endpoint would block the event loop for every other request.
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _predict_probs, images)


def _summarize(probs: List[float]) -> Dict[str, object]:
    """Build the JSON result for one image from its class probabilities."""
    predictions = {LABELS[i]: round(p, 3) for i, p in enumerate(probs)}
    max_idx = probs.index(max(probs))  # first index on ties
    return {
        "predictions": predictions,
        "most_likely": LABELS[max_idx],
        "confidence": round(probs[max_idx], 3),
    }


@app.get("/")
async def root():
    """Return a short usage hint; doubles as a liveness check."""
    return {"message": "Gender classifier API is running. Use POST /classify/ with an image file."}


@app.post("/classify/")
async def classify_gender(image: UploadFile = File(...)):
    """Classify a single uploaded image.

    Returns ``{"predictions", "most_likely", "confidence"}`` on success, or
    ``{"error": "Invalid image file"}`` when the upload cannot be decoded.
    """
    contents = await image.read()
    try:
        img = _load_rgb(contents)
    except Exception:
        # Request boundary: any decode failure becomes the documented error body.
        return {"error": "Invalid image file"}

    probs = await _run_inference([img])
    return _summarize(probs[0])


@app.post("/classify_batch/")
async def classify_gender_batch(images: List[UploadFile] = File(...)):
    """Classify several uploaded images in one model forward pass.

    Returns ``{"results": [...]}`` with one entry per image, in upload order.
    If any upload fails to decode, returns ``{"error": ...}`` naming the first
    failing file.
    """
    pil_images = []
    for image in images:
        contents = await image.read()
        try:
            pil_images.append(_load_rgb(contents))
        except Exception:
            # Request boundary: report the first undecodable file by name.
            return {"error": f"Invalid image file: {image.filename}"}

    if not pil_images:
        # Nothing to classify; avoid calling the processor with an empty batch.
        return {"results": []}

    probs = await _run_inference(pil_images)
    return {"results": [_summarize(p) for p in probs]}