from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
import torch
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
import io

# Initialize FastAPI app with permissive CORS (POST only).
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["POST"],
    allow_headers=["*"],
)

# Load the classifier once at startup. Index order of `labels` must match
# the model's output head ordering (0 -> "Real", 1 -> "AI").
labels = ["Real", "AI"]
feature_extractor = AutoFeatureExtractor.from_pretrained("Nahrawy/AIorNot")
model = AutoModelForImageClassification.from_pretrained("Nahrawy/AIorNot")


@app.post("/analyze")
async def analyze(file: UploadFile = File(...)):
    """Classify an uploaded image as Real or AI-generated.

    Returns:
        dict with the top ``label``, its ``confidence`` (float in [0, 1]),
        and ``scores`` mapping every label to its probability.

    Raises:
        HTTPException: 400 if the upload is not a decodable image.
    """
    # Read the full upload into memory.
    img_bytes = await file.read()

    # Sanity check: validate the payload with Pillow instead of `imghdr`,
    # which was deprecated in Python 3.11 and removed in 3.13 (PEP 594).
    # `verify()` checks integrity without decoding pixel data.
    try:
        Image.open(io.BytesIO(img_bytes)).verify()
    except Exception:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image"
        ) from None

    # Re-open for actual decoding: `verify()` leaves the Image unusable,
    # so a fresh open from the original bytes is required.
    try:
        image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="Cannot open image") from None

    # Run inference without tracking gradients.
    inputs = feature_extractor(image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits
    # dim=-1 softmaxes over the label axis regardless of batch shape.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]
    prediction = int(logits.argmax(-1).item())

    return {
        "label": labels[prediction],
        "confidence": float(probs[prediction]),
        "scores": {name: float(p) for name, p in zip(labels, probs)},
    }