# Uploaded to Hugging Face by benstaf ("Upload 3 files", commit ccd3d82, verified).
# Original file size: 924 bytes.
import io

import torch
import torch.nn.functional as F
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Hugging Face Hub checkpoint used for BOTH the processor and the model; keeping
# it in one place guarantees the preprocessing always matches the weights.
MODEL_ID = "prithivMLmods/Gender-Classifier-Mini"

app = FastAPI()

# Loaded once, at import time, and shared by every request handled by `app`.
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
def _predict(img):
    """Run the classifier on one decoded RGB image and return ``(label, confidence)``.

    Synchronous and compute-bound: async callers must offload it to a worker
    thread instead of calling it directly on the event loop.
    """
    inputs = processor(images=img, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    # Exactly one image was passed, so take batch row 0 explicitly; argmax then
    # indexes the label axis rather than a flattened (batch * labels) tensor.
    probs = F.softmax(logits, dim=1)[0]
    pred = torch.argmax(probs).item()
    return model.config.id2label[pred], probs[pred].item()


@app.post("/classify/")
async def classify_gender(image: UploadFile = File(...)):
    """Classify an uploaded image and return the most likely label.

    Args:
        image: The uploaded file; any format Pillow can decode is accepted and
            converted to RGB before preprocessing.

    Returns:
        ``{"label": str, "confidence": float}``. ``label`` is looked up in
        ``model.config.id2label``; ``confidence`` is that label's softmax
        probability.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    contents = await image.read()
    try:
        img = Image.open(io.BytesIO(contents)).convert("RGB")
    except OSError as exc:
        # Pillow signals unreadable or truncated data with OSError (its
        # UnidentifiedImageError subclasses OSError). That is a client error,
        # so report 400 instead of letting it surface as a 500.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from exc
    # Inference is blocking; run it in the threadpool so this coroutine does
    # not stall the event loop for concurrent requests.
    label, confidence = await run_in_threadpool(_predict, img)
    return {"label": label, "confidence": confidence}