Rodiyah commited on
Commit
ae8f111
·
verified ·
1 Parent(s): ad487de
Files changed (1) hide show
  1. app.py +64 -0
app.py CHANGED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ from torchvision import models, transforms
7
+ from pytorch_grad_cam import GradCAM
8
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
9
+ from pytorch_grad_cam.utils.image import show_cam_on_image
10
+
11
+ # Set device (CPU for Hugging Face Spaces)
12
+ device = torch.device("cpu")
13
+
14
+ # Load model
15
+ model = models.resnet50(weights=None)
16
+ model.fc = torch.nn.Linear(model.fc.in_features, 2)
17
+ model.load_state_dict(torch.load("resnet50_dr_classifier.pth", map_location=device))
18
+ model.to(device)
19
+ model.eval()
20
+
21
+ # Grad-CAM setup
22
+ target_layer = model.layer4[-1]
23
+ cam = GradCAM(model=model, target_layers=[target_layer])
24
+
25
+ # Image transform
26
+ transform = transforms.Compose([
27
+ transforms.Resize((224, 224)),
28
+ transforms.ToTensor(),
29
+ transforms.Normalize([0.485, 0.456, 0.406],
30
+ [0.229, 0.224, 0.225])
31
+ ])
32
+
33
+ def predict_retinopathy(image):
34
+ img = image.convert("RGB").resize((224, 224))
35
+ img_tensor = transform(img).unsqueeze(0).to(device)
36
+
37
+ with torch.no_grad():
38
+ output = model(img_tensor)
39
+ probs = F.softmax(output, dim=1)
40
+ pred = torch.argmax(probs, dim=1).item()
41
+ confidence = probs[0][pred].item()
42
+
43
+ label = "Diabetic Retinopathy (DR)" if pred == 0 else "No DR"
44
+
45
+ # Grad-CAM
46
+ rgb_img_np = np.array(img).astype(np.float32) / 255.0
47
+ rgb_img_np = np.ascontiguousarray(rgb_img_np)
48
+ grayscale_cam = cam(input_tensor=img_tensor, targets=[ClassifierOutputTarget(pred)])[0]
49
+ cam_image = show_cam_on_image(rgb_img_np, grayscale_cam, use_rgb=True)
50
+
51
+ cam_pil = Image.fromarray(cam_image)
52
+ return cam_pil, f"{label} (Confidence: {confidence:.2f})"
53
+
54
+ # Gradio UI
55
+ gr.Interface(
56
+ fn=predict_retinopathy,
57
+ inputs=gr.Image(type="pil"),
58
+ outputs=[
59
+ gr.Image(type="pil", label="Grad-CAM"),
60
+ gr.Text(label="Prediction")
61
+ ],
62
+ title="Diabetic Retinopathy Detection",
63
+ description="Upload a retinal image to classify DR and view Grad-CAM heatmap."
64
+ ).launch()