Rodiyah commited on
Commit
7a08276
·
verified ·
1 Parent(s): b1aa13e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -17
app.py CHANGED
@@ -11,9 +11,11 @@ from pytorch_grad_cam.utils.image import show_cam_on_image
11
  import os
12
  import csv
13
  import datetime
14
- import io
15
  import zipfile
16
 
 
 
 
17
  # Set device
18
  device = torch.device("cpu")
19
 
@@ -36,18 +38,17 @@ transform = transforms.Compose([
36
  [0.229, 0.224, 0.225])
37
  ])
38
 
39
- # Folder to store uploaded images
40
  image_folder = "collected_images"
41
  os.makedirs(image_folder, exist_ok=True)
42
 
43
- # CSV log file
44
  csv_log_path = "prediction_logs.csv"
45
  if not os.path.exists(csv_log_path):
46
  with open(csv_log_path, mode="w", newline="") as f:
47
  writer = csv.writer(f)
48
  writer.writerow(["timestamp", "image_filename", "prediction", "confidence"])
49
 
50
- # Prediction function
51
  def predict_retinopathy(image):
52
  timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
53
  img = image.convert("RGB").resize((224, 224))
@@ -80,46 +81,58 @@ def predict_retinopathy(image):
80
 
81
  return cam_pil, f"{label} (Confidence: {confidence:.2f})"
82
 
83
- # Download logs
 
 
 
84
  def download_csv():
85
  return csv_log_path
86
 
87
- # Zip dataset for download
88
  def download_dataset_zip():
89
  zip_filename = "dataset_bundle.zip"
90
  with zipfile.ZipFile(zip_filename, "w") as zipf:
91
- # Add CSV
92
  zipf.write(csv_log_path, arcname="prediction_logs.csv")
93
- # Add images
94
  for fname in os.listdir(image_folder):
95
  fpath = os.path.join(image_folder, fname)
96
  zipf.write(fpath, arcname=os.path.join("images", fname))
97
  return zip_filename
98
 
99
- # Gradio UI
100
  with gr.Blocks() as demo:
101
- gr.Markdown("## 🧠 DR Detection with Grad-CAM + Full Dataset Logging")
102
 
103
  with gr.Row():
104
  image_input = gr.Image(type="pil", label="Upload Retinal Image")
105
  cam_output = gr.Image(type="pil", label="Grad-CAM")
106
 
107
  prediction_output = gr.Text(label="Prediction")
108
-
109
  run_button = gr.Button("Submit")
110
 
111
- with gr.Row():
112
- download_csv_btn = gr.Button("📄 Download CSV Log")
113
- download_zip_btn = gr.Button("📦 Download Full Dataset")
114
- csv_file = gr.File()
115
- zip_file = gr.File()
116
-
117
  run_button.click(
118
  fn=predict_retinopathy,
119
  inputs=image_input,
120
  outputs=[cam_output, prediction_output]
121
  )
122
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  download_csv_btn.click(
124
  fn=download_csv,
125
  inputs=[],
 
11
  import os
12
  import csv
13
  import datetime
 
14
  import zipfile
15
 
16
+ # === ADMIN SETUP ===
17
+ ADMIN_KEY = "rodiyah_secret" # change to your own private key
18
+
19
  # Set device
20
  device = torch.device("cpu")
21
 
 
38
  [0.229, 0.224, 0.225])
39
  ])
40
 
41
+ # Folder setup
42
  image_folder = "collected_images"
43
  os.makedirs(image_folder, exist_ok=True)
44
 
 
45
  csv_log_path = "prediction_logs.csv"
46
  if not os.path.exists(csv_log_path):
47
  with open(csv_log_path, mode="w", newline="") as f:
48
  writer = csv.writer(f)
49
  writer.writerow(["timestamp", "image_filename", "prediction", "confidence"])
50
 
51
+ # === MAIN PREDICTION FUNCTION ===
52
  def predict_retinopathy(image):
53
  timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
54
  img = image.convert("RGB").resize((224, 224))
 
81
 
82
  return cam_pil, f"{label} (Confidence: {confidence:.2f})"
83
 
84
+ # === ADMIN DOWNLOAD FUNCTIONS ===
85
+ def unlock_downloads(key):
86
+ return gr.update(visible=True) if key == ADMIN_KEY else gr.update(visible=False)
87
+
88
  def download_csv():
89
  return csv_log_path
90
 
 
91
  def download_dataset_zip():
92
  zip_filename = "dataset_bundle.zip"
93
  with zipfile.ZipFile(zip_filename, "w") as zipf:
 
94
  zipf.write(csv_log_path, arcname="prediction_logs.csv")
 
95
  for fname in os.listdir(image_folder):
96
  fpath = os.path.join(image_folder, fname)
97
  zipf.write(fpath, arcname=os.path.join("images", fname))
98
  return zip_filename
99
 
100
+ # === UI SETUP ===
101
  with gr.Blocks() as demo:
102
+ gr.Markdown("## 🧠 Diabetic Retinopathy Detection with Grad-CAM & Data Collection")
103
 
104
  with gr.Row():
105
  image_input = gr.Image(type="pil", label="Upload Retinal Image")
106
  cam_output = gr.Image(type="pil", label="Grad-CAM")
107
 
108
  prediction_output = gr.Text(label="Prediction")
 
109
  run_button = gr.Button("Submit")
110
 
 
 
 
 
 
 
111
  run_button.click(
112
  fn=predict_retinopathy,
113
  inputs=image_input,
114
  outputs=[cam_output, prediction_output]
115
  )
116
 
117
+ gr.Markdown("### 🔐 Admin Area (Restricted Access)")
118
+
119
+ with gr.Row():
120
+ admin_input = gr.Text(label="Enter Admin Key", type="password", placeholder="Only Rodiyah knows this 🔐")
121
+ unlock_btn = gr.Button("Unlock Downloads")
122
+
123
+ with gr.Column(visible=False) as download_section:
124
+ with gr.Row():
125
+ download_csv_btn = gr.Button("📄 Download CSV Log")
126
+ download_zip_btn = gr.Button("📦 Download Full Dataset")
127
+ csv_file = gr.File()
128
+ zip_file = gr.File()
129
+
130
+ unlock_btn.click(
131
+ fn=unlock_downloads,
132
+ inputs=admin_input,
133
+ outputs=download_section
134
+ )
135
+
136
  download_csv_btn.click(
137
  fn=download_csv,
138
  inputs=[],