Rodiyah commited on
Commit
7303661
·
verified ·
1 Parent(s): e39c9f2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -4
app.py CHANGED
@@ -8,7 +8,11 @@ from pytorch_grad_cam import GradCAM
8
  from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
9
  from pytorch_grad_cam.utils.image import show_cam_on_image
10
 
11
- # Set device (CPU for Hugging Face Spaces)
 
 
 
 
12
  device = torch.device("cpu")
13
 
14
  # Load model
@@ -22,7 +26,7 @@ model.eval()
22
  target_layer = model.layer4[-1]
23
  cam = GradCAM(model=model, target_layers=[target_layer])
24
 
25
- # Image transform
26
  transform = transforms.Compose([
27
  transforms.Resize((224, 224)),
28
  transforms.ToTensor(),
@@ -30,6 +34,23 @@ transform = transforms.Compose([
30
  [0.229, 0.224, 0.225])
31
  ])
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  def predict_retinopathy(image):
34
  img = image.convert("RGB").resize((224, 224))
35
  img_tensor = transform(img).unsqueeze(0).to(device)
@@ -48,10 +69,14 @@ def predict_retinopathy(image):
48
  grayscale_cam = cam(input_tensor=img_tensor, targets=[ClassifierOutputTarget(pred)])[0]
49
  cam_image = show_cam_on_image(rgb_img_np, grayscale_cam, use_rgb=True)
50
 
 
 
 
 
51
  cam_pil = Image.fromarray(cam_image)
52
  return cam_pil, f"{label} (Confidence: {confidence:.2f})"
53
 
54
- # Gradio UI
55
  gr.Interface(
56
  fn=predict_retinopathy,
57
  inputs=gr.Image(type="pil"),
@@ -60,5 +85,5 @@ gr.Interface(
60
  gr.Text(label="Prediction")
61
  ],
62
  title="Diabetic Retinopathy Detection",
63
- description="Upload a retinal image to classify DR and view Grad-CAM heatmap."
64
  ).launch()
 
8
  from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
9
  from pytorch_grad_cam.utils.image import show_cam_on_image
10
 
11
+ import csv
12
+ import datetime
13
+ import os
14
+
15
+ # Set device
16
  device = torch.device("cpu")
17
 
18
  # Load model
 
26
  target_layer = model.layer4[-1]
27
  cam = GradCAM(model=model, target_layers=[target_layer])
28
 
29
+ # Image preprocessing
30
  transform = transforms.Compose([
31
  transforms.Resize((224, 224)),
32
  transforms.ToTensor(),
 
34
  [0.229, 0.224, 0.225])
35
  ])
36
 
37
+ # Logging setup
38
+ log_path = "prediction_logs.csv"
39
+
40
+ def log_prediction(filename, prediction, confidence):
41
+ timestamp = datetime.datetime.now().isoformat()
42
+ row = [timestamp, filename, prediction, f"{confidence:.4f}"]
43
+
44
+ if not os.path.exists(log_path):
45
+ with open(log_path, mode='w', newline='') as file:
46
+ writer = csv.writer(file)
47
+ writer.writerow(["timestamp", "image_name", "prediction", "confidence"])
48
+
49
+ with open(log_path, mode='a', newline='') as file:
50
+ writer = csv.writer(file)
51
+ writer.writerow(row)
52
+
53
+ # Prediction function
54
  def predict_retinopathy(image):
55
  img = image.convert("RGB").resize((224, 224))
56
  img_tensor = transform(img).unsqueeze(0).to(device)
 
69
  grayscale_cam = cam(input_tensor=img_tensor, targets=[ClassifierOutputTarget(pred)])[0]
70
  cam_image = show_cam_on_image(rgb_img_np, grayscale_cam, use_rgb=True)
71
 
72
+ # Logging
73
+ filename = getattr(image, "filename", "uploaded_image")
74
+ log_prediction(filename, label, confidence)
75
+
76
  cam_pil = Image.fromarray(cam_image)
77
  return cam_pil, f"{label} (Confidence: {confidence:.2f})"
78
 
79
+ # Gradio interface
80
  gr.Interface(
81
  fn=predict_retinopathy,
82
  inputs=gr.Image(type="pil"),
 
85
  gr.Text(label="Prediction")
86
  ],
87
  title="Diabetic Retinopathy Detection",
88
+ description="Upload a retinal image to classify DR and view Grad-CAM heatmap. All predictions are logged for analysis."
89
  ).launch()