# aidetect / main.py
# Last commit: 434b10c "Update main.py" by youngjeck (verified), 1.59 kB
# NOTE(review): imghdr is deprecated (PEP 594) and was removed in Python 3.13;
# importing it fails on 3.13+.
import io, imghdr

import torch
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image, UnidentifiedImageError
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
# The ASGI application that serves the /analyze endpoint.
app = FastAPI()

# CORS: let browser clients on any origin call the API with any request
# headers; only POST is advertised as an allowed method.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["POST"],
    allow_origins=["*"],
)
# Human-readable class names, indexed by the classifier's output position.
# NOTE(review): this order is assumed to match the checkpoint's label
# mapping — TODO confirm against model.config.id2label.
labels = ["Real", "AI"]

# The preprocessing config and the classifier weights are loaded from the same
# checkpoint id (presumably a Hugging Face Hub repo), once, at import time.
_MODEL_ID = "Nahrawy/AIorNot"
feature_extractor = AutoFeatureExtractor.from_pretrained(_MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(_MODEL_ID)
def _load_rgb_image(img_bytes: bytes) -> Image.Image:
    """Decode raw upload bytes into an RGB PIL image, mapping failures to HTTP 400.

    Raises:
        HTTPException: 400 "Uploaded file is not a valid image" when Pillow cannot
            identify the data as any image format (including empty uploads), or
            400 "Cannot open image" when it is identified but fails to decode.
    """
    try:
        # Image.open is lazy; convert() forces the pixel data to be decoded, so
        # corrupt files fail here rather than later inside the model.
        return Image.open(io.BytesIO(img_bytes)).convert("RGB")
    except UnidentifiedImageError as exc:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image"
        ) from exc
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Cannot open image") from exc


def _predict_probabilities(image: Image.Image) -> list:
    """Run the classifier on one RGB image and return its softmax probabilities.

    The returned list holds Python floats in the model's output order, which
    the endpoint pairs positionally with ``labels``.
    """
    inputs = feature_extractor(image, return_tensors="pt")
    # no_grad is thread-local, so it must be entered in the thread that runs
    # the forward pass (this helper runs in a worker thread).
    with torch.no_grad():
        logits = model(**inputs).logits
    # Presumably logits is (batch=1, num_labels); take the single row.
    return torch.nn.functional.softmax(logits, dim=1)[0].tolist()


@app.post("/analyze")
async def analyze(file: UploadFile = File(...)):
    """Classify an uploaded image as real or AI-generated.

    Args:
        file: Multipart file upload containing the image bytes.

    Returns:
        A dict with ``label`` (the most probable entry of ``labels``),
        ``confidence`` (that label's softmax probability), and ``scores``
        (every label mapped to its probability).

    Raises:
        HTTPException: 400 when the upload is not a decodable image.
    """
    img_bytes = await file.read()

    # Pillow alone decides whether the bytes are an image: an imghdr pre-check
    # rejects files Pillow can open (e.g. JPEGs without a JFIF/Exif header)
    # and imghdr no longer exists in Python 3.13+.
    # Decoding and inference are synchronous CPU/GPU work, so they run in a
    # worker thread instead of blocking the event loop for other requests.
    image = await run_in_threadpool(_load_rgb_image, img_bytes)
    probs = await run_in_threadpool(_predict_probabilities, image)

    # Softmax is monotonic, so the most probable class is also the argmax of
    # the logits; index() returns the first maximum, as torch.argmax does.
    prediction = probs.index(max(probs))
    return {
        "label": labels[prediction],
        "confidence": probs[prediction],
        "scores": dict(zip(labels, probs)),
    }