# Hugging Face Spaces app (Space status shown on the page at capture time: "Sleeping").
import io
import os

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image, UnidentifiedImageError
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Directory used to cache the downloaded model weights and processor config.
CACHE_DIR = "/code/.cache"
# NOTE: transformers / huggingface_hub resolve their cache-location environment
# variables when they are first imported, and `transformers` is already imported
# above, so this assignment alone does not redirect the cache. It is kept for any
# code that reads the variable later; the effective setting is the explicit
# `cache_dir` passed to `from_pretrained` below.
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR

app = FastAPI()

# Load model and processor once, at module import, so every request reuses the
# same in-memory weights instead of reloading them.
repo_name = "Jayanth2002/dinov2-base-finetuned-SkinDisease"
image_processor = AutoImageProcessor.from_pretrained(
    repo_name, use_fast=True, cache_dir=CACHE_DIR)
model = AutoModelForImageClassification.from_pretrained(
    repo_name, cache_dir=CACHE_DIR)

# Human-readable labels; position i is the name for the model's output class i.
# NOTE(review): this must stay in the same order and length as the model's
# label head — confirm against model.config.id2label if the checkpoint changes.
class_names = [
    'Basal Cell Carcinoma', 'Darier_s Disease', 'Epidermolysis Bullosa Pruriginosa',
    'Hailey-Hailey Disease', 'Herpes Simplex', 'Impetigo', 'Larva Migrans',
    'Leprosy Borderline', 'Leprosy Lepromatous', 'Leprosy Tuberculoid',
    'Lichen Planus', 'Lupus Erythematosus Chronicus Discoides', 'Melanoma',
    'Molluscum Contagiosum', 'Mycosis Fungoides', 'Neurofibromatosis',
    'Papilomatosis Confluentes And Reticulate', 'Pediculosis Capitis',
    'Pityriasis Rosea', 'Porokeratosis Actinic', 'Psoriasis', 'Tinea Corporis',
    'Tinea Nigra', 'Tungiasis', 'actinic keratosis', 'dermatofibroma', 'nevus',
    'pigmented benign keratosis', 'seborrheic keratosis', 'squamous cell carcinoma',
    'vascular lesion',
]
async def predict_skin_disease(file: UploadFile = File(...)):
    """Classify an uploaded skin image and return the top prediction.

    Args:
        file: The uploaded file. Its declared ``content_type`` must start with
            ``image/``.

    Returns:
        A ``JSONResponse`` whose body has ``model`` (always ``"Pytorch"``),
        ``class`` (the predicted entry of ``class_names``) and ``probability``
        (the softmax score of that class as a percentage string with two
        decimals, e.g. ``"87.12%"``).

    Raises:
        HTTPException: 400 if the upload is not declared as an image or its
            bytes cannot be decoded as one; 500 for any other failure while
            reading, preprocessing or classifying the image.
    """
    # UploadFile.content_type may be None (no Content-Type on the part); treat
    # that as "not an image" instead of crashing with AttributeError -> 500.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(
            status_code=400, detail="File provided is not an image.")
    try:
        contents = await file.read()
        # Decoding and model inference are synchronous and CPU-bound; running
        # them in the threadpool keeps this coroutine from blocking the event
        # loop (and with it every other in-flight request) while it works.
        predicted_class_name, confidence = await run_in_threadpool(
            _classify_image_bytes, contents)
    except UnidentifiedImageError as e:
        # The client declared an image type but sent bytes Pillow cannot
        # decode: that is a client error, not a server fault.
        raise HTTPException(
            status_code=400, detail="File provided is not an image.") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return JSONResponse(content={
        "model": "Pytorch",
        "class": predicted_class_name,
        "probability": f"{confidence:.2f}%",
    })


def _classify_image_bytes(contents: bytes):
    """Decode raw image bytes and run the classifier on them (blocking).

    Args:
        contents: Encoded image file bytes in any format Pillow can open.

    Returns:
        A ``(class_name, confidence)`` pair: ``class_name`` is the entry of
        ``class_names`` at the arg-max logit, and ``confidence`` is that
        class's softmax probability scaled to the 0-100 range.

    Raises:
        PIL.UnidentifiedImageError: if Pillow cannot identify the image data.
    """
    image = Image.open(io.BytesIO(contents)).convert("RGB")
    encoding = image_processor(image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**encoding).logits
    # Exactly one image was encoded, so the logits hold a single row.
    predicted_class_idx = logits.argmax(-1).item()
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    confidence = probabilities[0][predicted_class_idx].item() * 100
    return class_names[predicted_class_idx], confidence
async def get_classes():
    """List every label the classifier can output, in model class-index order."""
    payload = {"classes": class_names}
    return payload