Rodiyah commited on
Commit
412ca3a
·
verified ·
1 Parent(s): 32232cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -8
app.py CHANGED
@@ -8,8 +8,13 @@ from pytorch_grad_cam import GradCAM
8
  from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
9
  from pytorch_grad_cam.utils.image import show_cam_on_image
10
 
11
- # Set device
 
 
 
12
  device = torch.device("cpu")
 
 
13
 
14
  # Load model
15
  model = models.resnet50(weights=None)
@@ -18,11 +23,11 @@ model.load_state_dict(torch.load("resnet50_dr_classifier.pth", map_location=devi
18
  model.to(device)
19
  model.eval()
20
 
21
- # Grad-CAM setup
22
  target_layer = model.layer4[-1]
23
  cam = GradCAM(model=model, target_layers=[target_layer])
24
 
25
- # Image preprocessing
26
  transform = transforms.Compose([
27
  transforms.Resize((224, 224)),
28
  transforms.ToTensor(),
@@ -30,8 +35,9 @@ transform = transforms.Compose([
30
  [0.229, 0.224, 0.225])
31
  ])
32
 
33
- # Prediction function
34
  def predict_retinopathy(image):
 
35
  img = image.convert("RGB").resize((224, 224))
36
  img_tensor = transform(img).unsqueeze(0).to(device)
37
 
@@ -41,18 +47,22 @@ def predict_retinopathy(image):
41
  pred = torch.argmax(probs, dim=1).item()
42
  confidence = probs[0][pred].item()
43
 
44
- label = "Diabetic Retinopathy (DR)" if pred == 0 else "No DR"
45
 
46
  # Grad-CAM
47
  rgb_img_np = np.array(img).astype(np.float32) / 255.0
48
  rgb_img_np = np.ascontiguousarray(rgb_img_np)
49
  grayscale_cam = cam(input_tensor=img_tensor, targets=[ClassifierOutputTarget(pred)])[0]
50
  cam_image = show_cam_on_image(rgb_img_np, grayscale_cam, use_rgb=True)
51
-
52
  cam_pil = Image.fromarray(cam_image)
 
 
 
 
 
53
  return cam_pil, f"{label} (Confidence: {confidence:.2f})"
54
 
55
- # Gradio interface
56
  gr.Interface(
57
  fn=predict_retinopathy,
58
  inputs=gr.Image(type="pil"),
@@ -61,5 +71,5 @@ gr.Interface(
61
  gr.Text(label="Prediction")
62
  ],
63
  title="Diabetic Retinopathy Detection",
64
- description="Upload a retinal image to classify DR and view Grad-CAM heatmap."
65
  ).launch()
 
8
  from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
9
  from pytorch_grad_cam.utils.image import show_cam_on_image
10
 
11
+ import os
12
+ import datetime
13
+
14
+ # Setup
15
  device = torch.device("cpu")
16
+ save_dir = "saved_predictions"
17
+ os.makedirs(save_dir, exist_ok=True)
18
 
19
  # Load model
20
  model = models.resnet50(weights=None)
 
23
  model.to(device)
24
  model.eval()
25
 
26
+ # Grad-CAM
27
  target_layer = model.layer4[-1]
28
  cam = GradCAM(model=model, target_layers=[target_layer])
29
 
30
+ # Preprocessing
31
  transform = transforms.Compose([
32
  transforms.Resize((224, 224)),
33
  transforms.ToTensor(),
 
35
  [0.229, 0.224, 0.225])
36
  ])
37
 
38
+ # Predict and save
39
  def predict_retinopathy(image):
40
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
41
  img = image.convert("RGB").resize((224, 224))
42
  img_tensor = transform(img).unsqueeze(0).to(device)
43
 
 
47
  pred = torch.argmax(probs, dim=1).item()
48
  confidence = probs[0][pred].item()
49
 
50
+ label = "DR" if pred == 0 else "NoDR"
51
 
52
  # Grad-CAM
53
  rgb_img_np = np.array(img).astype(np.float32) / 255.0
54
  rgb_img_np = np.ascontiguousarray(rgb_img_np)
55
  grayscale_cam = cam(input_tensor=img_tensor, targets=[ClassifierOutputTarget(pred)])[0]
56
  cam_image = show_cam_on_image(rgb_img_np, grayscale_cam, use_rgb=True)
 
57
  cam_pil = Image.fromarray(cam_image)
58
+
59
+ # Save image with label and confidence
60
+ filename = f"{timestamp}_{label}_{confidence:.2f}.png"
61
+ cam_pil.save(os.path.join(save_dir, filename))
62
+
63
  return cam_pil, f"{label} (Confidence: {confidence:.2f})"
64
 
65
+ # Gradio app
66
  gr.Interface(
67
  fn=predict_retinopathy,
68
  inputs=gr.Image(type="pil"),
 
71
  gr.Text(label="Prediction")
72
  ],
73
  title="Diabetic Retinopathy Detection",
74
+ description="Upload a retinal image to classify DR and view Grad-CAM heatmap. All predictions are auto-saved with label and confidence."
75
  ).launch()