hafez1082 commited on
Commit
6aafe6d
·
verified ·
1 Parent(s): 5ac05b9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -0
app.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File, HTTPException
2
+ from fastapi.responses import JSONResponse
3
+ from PIL import Image
4
+ import torch
5
+ from transformers import AutoModelForImageClassification, AutoImageProcessor
6
+ import io
7
+
8
+ app = FastAPI()
9
+
10
+ # Load model and processor once at startup
11
+ repo_name = "Jayanth2002/dinov2-base-finetuned-SkinDisease"
12
+ image_processor = AutoImageProcessor.from_pretrained(repo_name)
13
+ model = AutoModelForImageClassification.from_pretrained(repo_name)
14
+
15
+ # Class names
16
+ class_names = [
17
+ 'Basal Cell Carcinoma', 'Darier_s Disease', 'Epidermolysis Bullosa Pruriginosa',
18
+ 'Hailey-Hailey Disease', 'Herpes Simplex', 'Impetigo', 'Larva Migrans',
19
+ 'Leprosy Borderline', 'Leprosy Lepromatous', 'Leprosy Tuberculoid',
20
+ 'Lichen Planus', 'Lupus Erythematosus Chronicus Discoides', 'Melanoma',
21
+ 'Molluscum Contagiosum', 'Mycosis Fungoides', 'Neurofibromatosis',
22
+ 'Papilomatosis Confluentes And Reticulate', 'Pediculosis Capitis',
23
+ 'Pityriasis Rosea', 'Porokeratosis Actinic', 'Psoriasis', 'Tinea Corporis',
24
+ 'Tinea Nigra', 'Tungiasis', 'actinic keratosis', 'dermatofibroma', 'nevus',
25
+ 'pigmented benign keratosis', 'seborrheic keratosis', 'squamous cell carcinoma',
26
+ 'vascular lesion'
27
+ ]
28
+
29
+
30
+ @app.post("/predict/")
31
+ async def predict_skin_disease(file: UploadFile = File(...)):
32
+ # Check if the uploaded file is an image
33
+ if not file.content_type.startswith('image/'):
34
+ raise HTTPException(
35
+ status_code=400, detail="File provided is not an image.")
36
+
37
+ try:
38
+ # Read image file
39
+ contents = await file.read()
40
+ image = Image.open(io.BytesIO(contents)).convert("RGB")
41
+
42
+ # Preprocess the image
43
+ encoding = image_processor(image, return_tensors="pt")
44
+
45
+ # Make prediction
46
+ with torch.no_grad():
47
+ outputs = model(**encoding)
48
+ logits = outputs.logits
49
+
50
+ predicted_class_idx = logits.argmax(-1).item()
51
+ predicted_class_name = class_names[predicted_class_idx]
52
+
53
+ # You might also want to return the confidence/probability
54
+ probabilities = torch.nn.functional.softmax(logits, dim=-1)
55
+ confidence = probabilities[0][predicted_class_idx].item() * 100
56
+
57
+ return JSONResponse(content={
58
+ "predicted_class": predicted_class_name,
59
+ "confidence": float(confidence),
60
+ })
61
+
62
+ except Exception as e:
63
+ raise HTTPException(status_code=500, detail=str(e))
64
+
65
+
66
+ @app.get("/")
67
+ async def root():
68
+ return {"message": "Skin Disease Classification API"}
69
+
70
+
71
+ @app.get("/classes")
72
+ async def get_classes():
73
+ return {"classes": class_names}