Rodiyah commited on
Commit
ab15865
·
verified ·
1 Parent(s): e1b2058

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -35
app.py CHANGED
@@ -8,9 +8,11 @@ from pytorch_grad_cam import GradCAM
8
  from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
9
  from pytorch_grad_cam.utils.image import show_cam_on_image
10
 
11
- import io
12
  import csv
13
  import datetime
 
 
14
 
15
  # Set device
16
  device = torch.device("cpu")
@@ -26,7 +28,7 @@ model.eval()
26
  target_layer = model.layer4[-1]
27
  cam = GradCAM(model=model, target_layers=[target_layer])
28
 
29
- # Preprocessing
30
  transform = transforms.Compose([
31
  transforms.Resize((224, 224)),
32
  transforms.ToTensor(),
@@ -34,18 +36,20 @@ transform = transforms.Compose([
34
  [0.229, 0.224, 0.225])
35
  ])
36
 
37
- # In-memory log list
38
- prediction_log = [["timestamp", "image_name", "prediction", "confidence"]]
 
39
 
40
- # Logging function
41
- def log_prediction(filename, prediction, confidence):
42
- timestamp = datetime.datetime.now().isoformat()
43
- row = [timestamp, filename, prediction, f"{confidence:.4f}"]
44
- prediction_log.append(row)
45
- print(" Logged:", row)
46
 
47
  # Prediction function
48
  def predict_retinopathy(image):
 
49
  img = image.convert("RGB").resize((224, 224))
50
  img_tensor = transform(img).unsqueeze(0).to(device)
51
 
@@ -62,31 +66,40 @@ def predict_retinopathy(image):
62
  rgb_img_np = np.ascontiguousarray(rgb_img_np)
63
  grayscale_cam = cam(input_tensor=img_tensor, targets=[ClassifierOutputTarget(pred)])[0]
64
  cam_image = show_cam_on_image(rgb_img_np, grayscale_cam, use_rgb=True)
65
-
66
- # Log it
67
- filename = getattr(image, "filename", "uploaded_image")
68
- log_prediction(filename, label, confidence)
69
-
70
  cam_pil = Image.fromarray(cam_image)
71
- return cam_pil, f"{label} (Confidence: {confidence:.2f})"
72
-
73
- # CSV download function
74
- def download_logs():
75
- output = io.StringIO()
76
- writer = csv.writer(output)
77
- writer.writerows(prediction_log)
78
- output.seek(0)
79
 
80
- # Save as a temporary file for download
81
- with open("prediction_logs.csv", "w", newline="") as f:
82
- f.write(output.getvalue())
 
83
 
84
- return "prediction_logs.csv"
 
 
 
85
 
 
86
 
87
- # Build the UI with Gradio Blocks
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  with gr.Blocks() as demo:
89
- gr.Markdown("## 🧠 Diabetic Retinopathy Detection with Grad-CAM & Logging")
90
 
91
  with gr.Row():
92
  image_input = gr.Image(type="pil", label="Upload Retinal Image")
@@ -94,10 +107,13 @@ with gr.Blocks() as demo:
94
 
95
  prediction_output = gr.Text(label="Prediction")
96
 
 
 
97
  with gr.Row():
98
- run_button = gr.Button("Submit")
99
- download_button = gr.Button("📥 Download Logs")
100
- download_file = gr.File(label="Your Log File", interactive=False)
 
101
 
102
  run_button.click(
103
  fn=predict_retinopathy,
@@ -105,10 +121,16 @@ with gr.Blocks() as demo:
105
  outputs=[cam_output, prediction_output]
106
  )
107
 
108
- download_button.click(
109
- fn=download_logs,
 
 
 
 
 
 
110
  inputs=[],
111
- outputs=download_file
112
  )
113
 
114
  demo.launch()
 
8
  from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
9
  from pytorch_grad_cam.utils.image import show_cam_on_image
10
 
11
+ import os
12
  import csv
13
  import datetime
14
+ import io
15
+ import zipfile
16
 
17
  # Set device
18
  device = torch.device("cpu")
 
28
  target_layer = model.layer4[-1]
29
  cam = GradCAM(model=model, target_layers=[target_layer])
30
 
31
+ # Image preprocessing
32
  transform = transforms.Compose([
33
  transforms.Resize((224, 224)),
34
  transforms.ToTensor(),
 
36
  [0.229, 0.224, 0.225])
37
  ])
38
 
39
+ # Folder to store uploaded images
40
+ image_folder = "collected_images"
41
+ os.makedirs(image_folder, exist_ok=True)
42
 
43
+ # CSV log file
44
+ csv_log_path = "prediction_logs.csv"
45
+ if not os.path.exists(csv_log_path):
46
+ with open(csv_log_path, mode="w", newline="") as f:
47
+ writer = csv.writer(f)
48
+ writer.writerow(["timestamp", "image_filename", "prediction", "confidence"])
49
 
50
  # Prediction function
51
  def predict_retinopathy(image):
52
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
53
  img = image.convert("RGB").resize((224, 224))
54
  img_tensor = transform(img).unsqueeze(0).to(device)
55
 
 
66
  rgb_img_np = np.ascontiguousarray(rgb_img_np)
67
  grayscale_cam = cam(input_tensor=img_tensor, targets=[ClassifierOutputTarget(pred)])[0]
68
  cam_image = show_cam_on_image(rgb_img_np, grayscale_cam, use_rgb=True)
 
 
 
 
 
69
  cam_pil = Image.fromarray(cam_image)
 
 
 
 
 
 
 
 
70
 
71
+ # Save uploaded image
72
+ image_filename = f"{timestamp}_{label.replace(' ', '_')}.png"
73
+ image_path = os.path.join(image_folder, image_filename)
74
+ image.save(image_path)
75
 
76
+ # Log prediction
77
+ with open(csv_log_path, mode="a", newline="") as f:
78
+ writer = csv.writer(f)
79
+ writer.writerow([timestamp, image_filename, label, f"{confidence:.4f}"])
80
 
81
+ return cam_pil, f"{label} (Confidence: {confidence:.2f})"
82
 
83
+ # Download logs
84
+ def download_csv():
85
+ return csv_log_path
86
+
87
+ # Zip dataset for download
88
+ def download_dataset_zip():
89
+ zip_buffer = io.BytesIO()
90
+ with zipfile.ZipFile(zip_buffer, "w") as zipf:
91
+ # Add CSV
92
+ zipf.write(csv_log_path, arcname="prediction_logs.csv")
93
+ # Add images
94
+ for fname in os.listdir(image_folder):
95
+ fpath = os.path.join(image_folder, fname)
96
+ zipf.write(fpath, arcname=os.path.join("images", fname))
97
+ zip_buffer.seek(0)
98
+ return zip_buffer
99
+
100
+ # Gradio UI
101
  with gr.Blocks() as demo:
102
+ gr.Markdown("## 🧠 DR Detection with Grad-CAM + Full Dataset Logging")
103
 
104
  with gr.Row():
105
  image_input = gr.Image(type="pil", label="Upload Retinal Image")
 
107
 
108
  prediction_output = gr.Text(label="Prediction")
109
 
110
+ run_button = gr.Button("Submit")
111
+
112
  with gr.Row():
113
+ download_csv_btn = gr.Button("📄 Download CSV Log")
114
+ download_zip_btn = gr.Button("📦 Download Full Dataset")
115
+ csv_file = gr.File()
116
+ zip_file = gr.File()
117
 
118
  run_button.click(
119
  fn=predict_retinopathy,
 
121
  outputs=[cam_output, prediction_output]
122
  )
123
 
124
+ download_csv_btn.click(
125
+ fn=download_csv,
126
+ inputs=[],
127
+ outputs=csv_file
128
+ )
129
+
130
+ download_zip_btn.click(
131
+ fn=download_dataset_zip,
132
  inputs=[],
133
+ outputs=zip_file
134
  )
135
 
136
  demo.launch()