Rodiyah commited on
Commit
1c7b662
·
verified ·
1 Parent(s): ac84363

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -82
app.py CHANGED
@@ -8,13 +8,9 @@ from pytorch_grad_cam import GradCAM
8
  from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
9
  from pytorch_grad_cam.utils.image import show_cam_on_image
10
 
11
- import os
12
  import csv
13
  import datetime
14
- import zipfile
15
-
16
- # ✅ Admin key (hidden until typed)
17
- ADMIN_KEY = "Diabetes_Detection"
18
 
19
  # Set device
20
  device = torch.device("cpu")
@@ -38,19 +34,21 @@ transform = transforms.Compose([
38
  [0.229, 0.224, 0.225])
39
  ])
40
 
41
- # Data storage
42
- image_folder = "collected_images"
43
- os.makedirs(image_folder, exist_ok=True)
 
 
 
44
 
45
- csv_log_path = "prediction_logs.csv"
46
- if not os.path.exists(csv_log_path):
47
- with open(csv_log_path, "w", newline="") as f:
48
- writer = csv.writer(f)
49
- writer.writerow(["timestamp", "image_filename", "prediction", "confidence"])
50
 
51
  # Prediction function
52
  def predict_retinopathy(image):
53
- timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
54
  img = image.convert("RGB").resize((224, 224))
55
  img_tensor = transform(img).unsqueeze(0).to(device)
56
 
@@ -67,75 +65,22 @@ def predict_retinopathy(image):
67
  rgb_img_np = np.ascontiguousarray(rgb_img_np)
68
  grayscale_cam = cam(input_tensor=img_tensor, targets=[ClassifierOutputTarget(pred)])[0]
69
  cam_image = show_cam_on_image(rgb_img_np, grayscale_cam, use_rgb=True)
70
- cam_pil = Image.fromarray(cam_image)
71
 
72
- # Save image & log
73
- image_filename = f"{timestamp}_{label.replace(' ', '_')}.png"
74
- image_path = os.path.join(image_folder, image_filename)
75
- image.save(image_path)
76
-
77
- with open(csv_log_path, mode="a", newline="") as f:
78
- writer = csv.writer(f)
79
- writer.writerow([timestamp, image_filename, label, f"{confidence:.4f}"])
80
 
 
81
  return cam_pil, f"{label} (Confidence: {confidence:.2f})"
82
 
83
- # Admin unlock
84
- def unlock_admin(key_input):
85
- if key_input == ADMIN_KEY:
86
- return gr.update(visible=True)
87
- return gr.update(visible=False)
88
-
89
- # Download functions
90
- def download_csv():
91
- return csv_log_path
92
-
93
- def download_dataset_zip():
94
- zip_filename = "dataset_bundle.zip"
95
- with zipfile.ZipFile(zip_filename, "w") as zipf:
96
- zipf.write(csv_log_path, arcname="prediction_logs.csv")
97
- for fname in os.listdir(image_folder):
98
- fpath = os.path.join(image_folder, fname)
99
- zipf.write(fpath, arcname=os.path.join("images", fname))
100
- return zip_filename
101
-
102
- # UI
103
- with gr.Blocks() as demo:
104
- gr.Markdown("## 🧠 Diabetic Retinopathy Detection with Grad-CAM")
105
-
106
- with gr.Row():
107
- image_input = gr.Image(type="pil", label="Upload Retinal Image")
108
- cam_output = gr.Image(type="pil", label="Grad-CAM Output")
109
-
110
- prediction_output = gr.Text(label="Prediction")
111
- run_button = gr.Button("Submit")
112
-
113
- run_button.click(
114
- fn=predict_retinopathy,
115
- inputs=image_input,
116
- outputs=[cam_output, prediction_output]
117
- )
118
-
119
- gr.Markdown("### 🔐 Admin Access (Rodiyah only)")
120
-
121
- admin_key_input = gr.Text(label="Enter Admin Key", type="password", placeholder="Only Rodiyah knows this!")
122
- unlock_button = gr.Button("Unlock Downloads")
123
-
124
- with gr.Column(visible=False) as admin_panel:
125
- gr.Markdown("### ✅ Download Panel (Private Access)")
126
- with gr.Row():
127
- download_csv_btn = gr.Button("📄 Download CSV Log")
128
- download_zip_btn = gr.Button("📦 Download Full Dataset")
129
- csv_file = gr.File()
130
- zip_file = gr.File()
131
-
132
- unlock_button.click(
133
- fn=unlock_admin,
134
- inputs=admin_key_input,
135
- outputs=admin_panel
136
- )
137
-
138
- download_csv_btn.click(fn=download_csv, inputs=[], outputs=csv_file)
139
- download_zip_btn.click(fn=download_dataset_zip, inputs=[], outputs=zip_file)
140
-
141
- demo.launch()
 
8
  from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
9
  from pytorch_grad_cam.utils.image import show_cam_on_image
10
 
 
11
  import csv
12
  import datetime
13
+ import os
 
 
 
14
 
15
  # Set device
16
  device = torch.device("cpu")
 
34
  [0.229, 0.224, 0.225])
35
  ])
36
 
37
+ # Logging setup
38
+ log_path = "prediction_logs.csv"
39
+
40
+ def log_prediction(filename, prediction, confidence):
41
+ timestamp = datetime.datetime.now().isoformat()
42
+ row = [timestamp, filename, prediction, f"{confidence:.4f}"]
43
 
44
+ print("⏺ Logging prediction:", row) # 🔍 Add this line
45
+
46
+ with open(log_path, mode='a', newline='') as file:
47
+ writer = csv.writer(file)
48
+ writer.writerow(row)
49
 
50
  # Prediction function
51
  def predict_retinopathy(image):
 
52
  img = image.convert("RGB").resize((224, 224))
53
  img_tensor = transform(img).unsqueeze(0).to(device)
54
 
 
65
  rgb_img_np = np.ascontiguousarray(rgb_img_np)
66
  grayscale_cam = cam(input_tensor=img_tensor, targets=[ClassifierOutputTarget(pred)])[0]
67
  cam_image = show_cam_on_image(rgb_img_np, grayscale_cam, use_rgb=True)
 
68
 
69
+ # Logging
70
+ filename = getattr(image, "filename", "uploaded_image")
71
+ log_prediction(filename, label, confidence)
 
 
 
 
 
72
 
73
+ cam_pil = Image.fromarray(cam_image)
74
  return cam_pil, f"{label} (Confidence: {confidence:.2f})"
75
 
76
+ # Gradio interface
77
+ gr.Interface(
78
+ fn=predict_retinopathy,
79
+ inputs=gr.Image(type="pil"),
80
+ outputs=[
81
+ gr.Image(type="pil", label="Grad-CAM"),
82
+ gr.Text(label="Prediction")
83
+ ],
84
+ title="Diabetic Retinopathy Detection",
85
+ description="Upload a retinal image to classify DR and view Grad-CAM heatmap. All predictions are logged for analysis."
86
+ ).launch()