"""FastAPI service that classifies skin-disease images with a fine-tuned DINOv2 model."""
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image
import torch
from transformers import AutoModelForImageClassification, AutoImageProcessor
import io
import os

CACHE_DIR = "/code/.cache"
# transformers / huggingface_hub read cache-location environment variables when
# they are first imported, so this assignment (which runs after the import
# above) does not redirect the downloads below on its own. The explicit
# ``cache_dir=`` arguments are what actually pin the cache location; the env
# var is kept for any other code that consults it.
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR

app = FastAPI()

# Load model and processor once at startup (downloads on first run).
repo_name = "Jayanth2002/dinov2-base-finetuned-SkinDisease"
image_processor = AutoImageProcessor.from_pretrained(
    repo_name, use_fast=True, cache_dir=CACHE_DIR
)
model = AutoModelForImageClassification.from_pretrained(repo_name, cache_dir=CACHE_DIR)

# Human-readable labels, indexed by the model's output position.
# NOTE(review): assumed to match the ordering of model.config.id2label — confirm.
class_names = [
    'Basal Cell Carcinoma', 'Darier_s Disease', 'Epidermolysis Bullosa Pruriginosa',
    'Hailey-Hailey Disease', 'Herpes Simplex', 'Impetigo', 'Larva Migrans',
    'Leprosy Borderline', 'Leprosy Lepromatous', 'Leprosy Tuberculoid', 'Lichen Planus',
    'Lupus Erythematosus Chronicus Discoides', 'Melanoma', 'Molluscum Contagiosum',
    'Mycosis Fungoides', 'Neurofibromatosis', 'Papilomatosis Confluentes And Reticulate',
    'Pediculosis Capitis', 'Pityriasis Rosea', 'Porokeratosis Actinic', 'Psoriasis',
    'Tinea Corporis', 'Tinea Nigra', 'Tungiasis', 'actinic keratosis', 'dermatofibroma',
    'nevus', 'pigmented benign keratosis', 'seborrheic keratosis',
    'squamous cell carcinoma', 'vascular lesion',
]


def _classify(contents: bytes) -> tuple[str, float]:
    """Decode *contents* as an image and run the classifier on it.

    Synchronous and CPU-bound; call it through ``run_in_threadpool`` from async code.

    Returns:
        ``(class_name, confidence_percent)`` for the top-scoring class.

    Raises:
        OSError: the bytes are not a decodable (or are a truncated) image.
        PIL.Image.DecompressionBombError: the image exceeds Pillow's pixel limit.
    """
    image = Image.open(io.BytesIO(contents)).convert("RGB")
    encoding = image_processor(image, return_tensors="pt")

    with torch.no_grad():
        logits = model(**encoding).logits

    predicted_class_idx = logits.argmax(-1).item()
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    # Batch of one image, so row 0 holds this image's class distribution.
    confidence = probabilities[0][predicted_class_idx].item() * 100
    return class_names[predicted_class_idx], confidence


@app.post("/")
async def predict_skin_disease(file: UploadFile = File(...)):
    """Classify an uploaded skin image.

    Returns a JSON object with the backend name, the predicted class and its
    softmax probability formatted as a percentage string.

    Raises:
        HTTPException(400): the upload is not declared as an image, or it
            cannot be decoded as one.
        HTTPException(500): any other failure during inference.
    """
    # content_type is None when the client omits the header; treat that as "not an image"
    # instead of crashing with AttributeError.
    if not (file.content_type or "").startswith('image/'):
        raise HTTPException(
            status_code=400, detail="File provided is not an image.")

    try:
        contents = await file.read()
        # Decoding + inference are blocking CPU work; keep them off the event loop so
        # concurrent requests are not stalled.
        predicted_class_name, confidence = await run_in_threadpool(_classify, contents)
    except (OSError, Image.DecompressionBombError) as e:
        # Corrupt, truncated or oversized images are client errors, not server faults.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file could not be decoded as an image.") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return JSONResponse(content={
        "model": "Pytorch",
        "class": predicted_class_name,
        "probability": f"{confidence:.2f}%",
    })


@app.get("/classes")
async def get_classes():
    """Return the list of class labels the classifier can predict."""
    return {"classes": class_names}