|
from fastapi import FastAPI, UploadFile, File, HTTPException |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from PIL import Image |
|
import torch |
|
from transformers import AutoFeatureExtractor, AutoModelForImageClassification |
|
import io, imghdr |
|
|
|
|
|
# FastAPI application instance; all routes below are registered on it.
app = FastAPI()

# Enable CORS so a browser frontend hosted on a different origin can call
# the API.
# NOTE(review): allow_origins=["*"] accepts requests from any origin —
# restrict to the known frontend domain(s) in production.
app.add_middleware(

    CORSMiddleware,

    allow_origins=["*"],

    allow_methods=["POST"],  # only POST is needed: /analyze is the sole endpoint

    allow_headers=["*"],

)
|
|
|
|
|
# Human-readable class names indexed by logit position.
# NOTE(review): assumes the model's output ordering is [Real, AI] —
# verify against model.config.id2label.
labels = ["Real", "AI"]

# Load the pretrained AI-vs-real classifier once at import time so each
# request doesn't pay the download/initialization cost.
# NOTE(review): AutoFeatureExtractor is deprecated for vision models in
# recent transformers releases in favor of AutoImageProcessor — confirm
# the installed version and migrate when convenient.
feature_extractor = AutoFeatureExtractor.from_pretrained("Nahrawy/AIorNot")

model = AutoModelForImageClassification.from_pretrained("Nahrawy/AIorNot")
|
|
|
@app.post("/analyze")
async def analyze(file: UploadFile = File(...)):
    """Classify an uploaded image as real or AI-generated.

    Args:
        file: Multipart image upload.

    Returns:
        dict with the predicted ``label``, its ``confidence`` (probability
        of the predicted class), and per-class probability ``scores``.

    Raises:
        HTTPException: 400 if the upload is empty, not a parseable image,
            or cannot be decoded.
    """
    img_bytes = await file.read()

    # Reject empty uploads and anything Pillow cannot parse structurally.
    # (Previously validated with imghdr, which was removed in Python 3.13;
    # Image.verify() covers the same formats and more.)
    if not img_bytes:
        raise HTTPException(status_code=400, detail="Uploaded file is not a valid image")
    try:
        Image.open(io.BytesIO(img_bytes)).verify()
    except Exception:
        raise HTTPException(status_code=400, detail="Uploaded file is not a valid image")

    # verify() leaves the image object unusable, so reopen for actual decoding.
    try:
        image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="Cannot open image")

    # Preprocess and run inference without gradient tracking.
    inputs = feature_extractor(image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    # Batch size is 1, so take the first row of class probabilities.
    probs = torch.nn.functional.softmax(logits, dim=1)[0]

    prediction = int(probs.argmax().item())
    label = labels[prediction]
    confidence = float(probs[prediction])

    return {
        "label": label,
        "confidence": confidence,
        "scores": {name: float(probs[i]) for i, name in enumerate(labels)},
    }
|
|