# Hugging Face file-view header: author hafez1082 · commit "Update app.py" · 1e67db5 (verified)
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
from PIL import Image
import torch
from transformers import AutoModelForImageClassification, AutoImageProcessor
import io
import os
# Directory inside the container where downloaded model files are cached.
CACHE_DIR = "/code/.cache"

# NOTE: transformers resolves its cache location when the package is imported,
# and it is already imported above, so this assignment alone comes too late to
# redirect the downloads below. It is kept for anything else that reads the
# variable; the effective setting is the explicit ``cache_dir`` argument.
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR

app = FastAPI()

# Load model and processor once at startup so every request reuses them.
repo_name = "Jayanth2002/dinov2-base-finetuned-SkinDisease"
image_processor = AutoImageProcessor.from_pretrained(
    repo_name, use_fast=True, cache_dir=CACHE_DIR)
model = AutoModelForImageClassification.from_pretrained(
    repo_name, cache_dir=CACHE_DIR)
# Human-readable labels, indexed by the class id predicted by the model.
# The order is significant: position i names class id i.
class_names = [
    "Basal Cell Carcinoma",
    "Darier_s Disease",
    "Epidermolysis Bullosa Pruriginosa",
    "Hailey-Hailey Disease",
    "Herpes Simplex",
    "Impetigo",
    "Larva Migrans",
    "Leprosy Borderline",
    "Leprosy Lepromatous",
    "Leprosy Tuberculoid",
    "Lichen Planus",
    "Lupus Erythematosus Chronicus Discoides",
    "Melanoma",
    "Molluscum Contagiosum",
    "Mycosis Fungoides",
    "Neurofibromatosis",
    "Papilomatosis Confluentes And Reticulate",
    "Pediculosis Capitis",
    "Pityriasis Rosea",
    "Porokeratosis Actinic",
    "Psoriasis",
    "Tinea Corporis",
    "Tinea Nigra",
    "Tungiasis",
    "actinic keratosis",
    "dermatofibroma",
    "nevus",
    "pigmented benign keratosis",
    "seborrheic keratosis",
    "squamous cell carcinoma",
    "vascular lesion",
]
@app.post("/")
async def predict_skin_disease(file: UploadFile = File(...)):
    """Classify an uploaded skin image and return the top predicted class.

    Args:
        file: The uploaded file. Its declared content type must start with
            ``image/``.

    Returns:
        A JSON response with the keys ``model``, ``class`` (the predicted
        class name) and ``probability`` (the softmax score of that class,
        formatted as a percentage with two decimals).

    Raises:
        HTTPException: 400 if the upload is not declared as an image or its
            bytes cannot be decoded as one; 500 if reading, preprocessing or
            inference fails for any other reason.
    """
    # content_type is None when the client omits the part's Content-Type
    # header; treat that as "not an image" instead of failing on None.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(
            status_code=400, detail="File provided is not an image.")

    try:
        # Read the upload and decode it into an RGB image.
        contents = await file.read()
        image = Image.open(io.BytesIO(contents)).convert("RGB")

        # Preprocess the image into model-ready tensors.
        encoding = image_processor(image, return_tensors="pt")

        # Run inference without tracking gradients.
        with torch.no_grad():
            outputs = model(**encoding)
            logits = outputs.logits

        predicted_class_idx = logits.argmax(-1).item()
        predicted_class_name = class_names[predicted_class_idx]

        # Confidence of the winning class, as a percentage.
        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        confidence = probabilities[0][predicted_class_idx].item() * 100

        return JSONResponse(content={
            "model": "Pytorch",
            "class": predicted_class_name,
            "probability": f"{confidence:.2f}%",
        })
    except (Image.UnidentifiedImageError, Image.DecompressionBombError) as e:
        # The bytes are not a decodable (or safely decodable) image: this is
        # a problem with the request, not with the server.
        raise HTTPException(
            status_code=400,
            detail="File provided is not a valid image.") from e
    except Exception as e:
        # Top-level boundary: surface any other failure as a 500.
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/classes")
async def get_classes():
    """Return every class label the service can predict, in index order."""
    payload = {"classes": class_names}
    return payload