File size: 1,591 Bytes
434b10c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
# NOTE: imghdr is deprecated since Python 3.11 and removed in 3.13 (PEP 594).
import io, imghdr

import torch
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from PIL import UnidentifiedImageError
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
# ---------------------------------------------------------------------------
# Application object
# ---------------------------------------------------------------------------
# CORS policy: any origin and any request header are accepted, but only the
# POST method is allowed.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_methods": ["POST"],
    "allow_headers": ["*"],
}

app = FastAPI()
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

# ---------------------------------------------------------------------------
# Classifier
# ---------------------------------------------------------------------------
# Human-readable class names. The /analyze handler indexes this list with the
# argmax of the model's logits, so the order is assumed to match the
# checkpoint's class order -- TODO confirm against the model config.
labels = ["Real", "AI"]

# Both the preprocessor and the weights come from the same checkpoint and are
# loaded once, at import time, so every request reuses them.
_MODEL_ID = "Nahrawy/AIorNot"
feature_extractor = AutoFeatureExtractor.from_pretrained(_MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(_MODEL_ID)
def _decode_image(img_bytes: bytes) -> Image.Image:
    """Decode raw upload bytes into an RGB PIL image.

    Pillow's own format detection is the validity check. The stdlib
    ``imghdr`` sniffer is deliberately not used: it is deprecated (removed in
    Python 3.13) and does not recognise every format Pillow can decode, so it
    rejected images that would otherwise have been processed.

    Raises:
        HTTPException: 400 "Uploaded file is not a valid image" when Pillow
            cannot identify the data as any image format; 400 "Cannot open
            image" for any other failure while opening or converting it.
    """
    try:
        return Image.open(io.BytesIO(img_bytes)).convert("RGB")
    except UnidentifiedImageError as exc:
        # Must precede the generic handler: this is a subclass of OSError.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image"
        ) from exc
    except Exception as exc:
        # Request boundary: any decoder failure (truncated data, oversized
        # image, plugin error, ...) is reported to the client as a bad upload.
        raise HTTPException(status_code=400, detail="Cannot open image") from exc


def _classify(image: Image.Image) -> dict:
    """Run the classifier on one image and build the response payload.

    Returns:
        A dict with the winning ``label``, its ``confidence`` (softmax
        probability) and the per-label ``scores``.
    """
    inputs = feature_extractor(image, return_tensors="pt")
    with torch.no_grad():  # inference only; no autograd graph needed
        logits = model(**inputs).logits
        # Softmax over the class axis, then take the single batch item.
        probs = torch.nn.functional.softmax(logits, dim=1)[0]

    prediction = logits.argmax(-1).item()
    return {
        "label": labels[prediction],
        "confidence": float(probs[prediction]),
        "scores": dict(zip(labels, probs.tolist())),
    }


@app.post("/analyze")
async def analyze(file: UploadFile = File(...)):
    """Classify an uploaded image with the module-level model.

    Args:
        file: The uploaded image file (multipart form field ``file``).

    Returns:
        ``{"label": str, "confidence": float, "scores": {label: float}}``.

    Raises:
        HTTPException: 400 if the upload cannot be identified or opened as an
            image.
    """
    img_bytes = await file.read()

    # Decoding and inference are CPU-bound and synchronous; running them in
    # the threadpool keeps the event loop free to serve other requests.
    image = await run_in_threadpool(_decode_image, img_bytes)
    return await run_in_threadpool(_classify, image)
|