File size: 2,619 Bytes
6aafe6d
 
 
 
 
 
4f83924
 
6aafe6d
 
 
 
 
1e67db5
6aafe6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f83924
6aafe6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84061a8
 
427e849
6aafe6d
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
from PIL import Image
import torch
from transformers import AutoModelForImageClassification, AutoImageProcessor
import io
import os
# Directory where downloaded model/processor files are cached.
CACHE_DIR = "/code/.cache"
# NOTE: transformers / huggingface_hub resolve their default cache location
# when they are imported (above), so setting this variable here does not by
# itself redirect downloads. It is kept for any later readers of the variable;
# the directory is also passed explicitly via ``cache_dir`` below, which is
# what actually takes effect.
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR

app = FastAPI()

# Load model and processor once at startup so every request reuses them.
repo_name = "Jayanth2002/dinov2-base-finetuned-SkinDisease"
image_processor = AutoImageProcessor.from_pretrained(
    repo_name, use_fast=True, cache_dir=CACHE_DIR)
model = AutoModelForImageClassification.from_pretrained(
    repo_name, cache_dir=CACHE_DIR)

# Human-readable labels, listed in the order of the model's output logits:
# the prediction endpoint maps the argmax index straight into this list, so
# the order must not be changed independently of the model.
class_names = [
    "Basal Cell Carcinoma",
    "Darier_s Disease",
    "Epidermolysis Bullosa Pruriginosa",
    "Hailey-Hailey Disease",
    "Herpes Simplex",
    "Impetigo",
    "Larva Migrans",
    "Leprosy Borderline",
    "Leprosy Lepromatous",
    "Leprosy Tuberculoid",
    "Lichen Planus",
    "Lupus Erythematosus Chronicus Discoides",
    "Melanoma",
    "Molluscum Contagiosum",
    "Mycosis Fungoides",
    "Neurofibromatosis",
    "Papilomatosis Confluentes And Reticulate",
    "Pediculosis Capitis",
    "Pityriasis Rosea",
    "Porokeratosis Actinic",
    "Psoriasis",
    "Tinea Corporis",
    "Tinea Nigra",
    "Tungiasis",
    "actinic keratosis",
    "dermatofibroma",
    "nevus",
    "pigmented benign keratosis",
    "seborrheic keratosis",
    "squamous cell carcinoma",
    "vascular lesion",
]


@app.post("/")
async def predict_skin_disease(file: UploadFile = File(...)):
    """Classify an uploaded skin image and return the top-1 prediction.

    The upload is decoded with Pillow, converted to RGB, preprocessed by the
    module-level ``image_processor`` and scored by the module-level ``model``.

    Args:
        file: Multipart upload whose declared content type must start with
            ``image/``.

    Returns:
        JSONResponse with ``"model"`` (always ``"Pytorch"``), ``"class"``
        (the ``class_names`` entry at the highest logit) and
        ``"probability"`` (that class's softmax score as a percentage
        string with two decimals, e.g. ``"87.31%"``).

    Raises:
        HTTPException: 400 if the content type is missing or not an image,
            or if the bytes cannot be decoded as an image; 500 for any
            other failure while reading the upload or running inference.
    """
    # content_type may be None (e.g. the client sent no Content-Type for the
    # part); calling .startswith on it directly would raise AttributeError
    # and surface as an unhandled server error.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(
            status_code=400, detail="File provided is not an image.")

    try:
        contents = await file.read()

        try:
            # Image.open is lazy; convert() forces the full decode, so both
            # calls must be inside this guard.
            image = Image.open(io.BytesIO(contents)).convert("RGB")
        except (OSError, Image.DecompressionBombError) as e:
            # Corrupt, truncated, unidentifiable or oversized payloads are a
            # client error, not a server failure. (PIL's
            # UnidentifiedImageError is a subclass of OSError.)
            raise HTTPException(
                status_code=400,
                detail="File provided is not a valid image.") from e

        encoding = image_processor(image, return_tensors="pt")

        # NOTE(review): inference runs synchronously inside an async handler,
        # so it blocks the event loop for its duration; consider offloading
        # it to a worker thread if concurrent requests matter.
        with torch.no_grad():
            logits = model(**encoding).logits

        predicted_class_idx = logits.argmax(-1).item()
        predicted_class_name = class_names[predicted_class_idx]

        # Softmax score of the predicted class, expressed as a percentage.
        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        confidence = probabilities[0][predicted_class_idx].item() * 100

        return JSONResponse(content={
            "model": "Pytorch",
            "class": predicted_class_name,
            "probability": f"{confidence:.2f}%",
        })

    except HTTPException:
        # Propagate deliberate HTTP errors (the 400 above) unchanged instead
        # of re-wrapping them as 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e




@app.get("/classes")
async def get_classes():
    """Return every label the classifier can emit, in model-index order.

    The response body is ``{"classes": [...]}`` where the list is the
    module-level ``class_names``.
    """
    payload = dict(classes=class_names)
    return payload