youngjeck commited on
Commit
434b10c
·
verified ·
1 Parent(s): eea5e88

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +52 -44
main.py CHANGED
@@ -1,44 +1,52 @@
1
- import os
2
- os.environ["HF_HOME"] = "/tmp/huggingface" # Allow caching on Hugging Face Spaces
3
-
4
- from fastapi import FastAPI, UploadFile, File, HTTPException
5
- from fastapi.middleware.cors import CORSMiddleware
6
- from transformers import AutoImageProcessor, AutoModelForImageClassification
7
- from PIL import Image
8
- import torch
9
- import io
10
-
11
- app = FastAPI()
12
- app.add_middleware(
13
- CORSMiddleware,
14
- allow_origins=["*"],
15
- allow_methods=["*"],
16
- allow_headers=["*"],
17
- )
18
-
19
- # Load model + processor
20
- processor = AutoImageProcessor.from_pretrained("Organika/sdxl-detector")
21
- model = AutoModelForImageClassification.from_pretrained("Organika/sdxl-detector")
22
-
23
- @app.post("/analyze")
24
- async def analyze(file: UploadFile = File(...)):
25
- img_bytes = await file.read()
26
-
27
- try:
28
- image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
29
- except:
30
- raise HTTPException(status_code=400, detail="Invalid image file.")
31
-
32
- inputs = processor(images=image, return_tensors="pt")
33
- with torch.no_grad():
34
- logits = model(**inputs).logits
35
- probs = torch.nn.functional.softmax(logits, dim=1)[0]
36
-
37
- labels = model.config.id2label
38
- scores = {labels[i]: float(probs[i]) for i in range(len(probs))}
39
-
40
- return {
41
- "result": max(scores, key=scores.get),
42
- "confidence": max(scores.values()),
43
- "scores": scores,
44
- }
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File, HTTPException
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from PIL import Image
4
+ import torch
5
+ from transformers import AutoFeatureExtractor, AutoModelForImageClassification
6
+ import io, imghdr
7
+
8
+ # Initialize FastAPI app
9
+ app = FastAPI()
10
+ app.add_middleware(
11
+ CORSMiddleware,
12
+ allow_origins=["*"],
13
+ allow_methods=["POST"],
14
+ allow_headers=["*"],
15
+ )
16
+
17
+ # Load the model + labels
18
+ labels = ["Real", "AI"]
19
+ feature_extractor = AutoFeatureExtractor.from_pretrained("Nahrawy/AIorNot")
20
+ model = AutoModelForImageClassification.from_pretrained("Nahrawy/AIorNot")
21
+
22
+ @app.post("/analyze")
23
+ async def analyze(file: UploadFile = File(...)):
24
+ # Read image bytes
25
+ img_bytes = await file.read()
26
+
27
+ # Sanity check
28
+ if imghdr.what(None, img_bytes) is None:
29
+ raise HTTPException(status_code=400, detail="Uploaded file is not a valid image")
30
+
31
+ # Load image with PIL
32
+ try:
33
+ image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
34
+ except Exception:
35
+ raise HTTPException(status_code=400, detail="Cannot open image")
36
+
37
+ # Run inference
38
+ inputs = feature_extractor(image, return_tensors="pt")
39
+ with torch.no_grad():
40
+ outputs = model(**inputs)
41
+ logits = outputs.logits
42
+ probs = torch.nn.functional.softmax(logits, dim=1)[0]
43
+
44
+ prediction = logits.argmax(-1).item()
45
+ label = labels[prediction]
46
+ confidence = float(probs[prediction])
47
+
48
+ return {
49
+ "label": label,
50
+ "confidence": confidence,
51
+ "scores": {labels[i]: float(probs[i]) for i in range(len(labels))}
52
+ }