Rodiyah commited on
Commit
dd8b58f
·
verified ·
1 Parent(s): 7a08276

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -25
app.py CHANGED
@@ -12,25 +12,26 @@ import os
12
  import csv
13
  import datetime
14
  import zipfile
 
15
 
16
- # === ADMIN SETUP ===
17
- ADMIN_KEY = "rodiyah_secret" # change to your own private key
18
 
19
- # Set device
20
  device = torch.device("cpu")
21
 
22
- # Load model
23
  model = models.resnet50(weights=None)
24
  model.fc = torch.nn.Linear(model.fc.in_features, 2)
25
  model.load_state_dict(torch.load("resnet50_dr_classifier.pth", map_location=device))
26
  model.to(device)
27
  model.eval()
28
 
29
- # Grad-CAM setup
30
  target_layer = model.layer4[-1]
31
  cam = GradCAM(model=model, target_layers=[target_layer])
32
 
33
- # Image preprocessing
34
  transform = transforms.Compose([
35
  transforms.Resize((224, 224)),
36
  transforms.ToTensor(),
@@ -38,7 +39,7 @@ transform = transforms.Compose([
38
  [0.229, 0.224, 0.225])
39
  ])
40
 
41
- # Folder setup
42
  image_folder = "collected_images"
43
  os.makedirs(image_folder, exist_ok=True)
44
 
@@ -48,7 +49,7 @@ if not os.path.exists(csv_log_path):
48
  writer = csv.writer(f)
49
  writer.writerow(["timestamp", "image_filename", "prediction", "confidence"])
50
 
51
- # === MAIN PREDICTION FUNCTION ===
52
  def predict_retinopathy(image):
53
  timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
54
  img = image.convert("RGB").resize((224, 224))
@@ -81,10 +82,7 @@ def predict_retinopathy(image):
81
 
82
  return cam_pil, f"{label} (Confidence: {confidence:.2f})"
83
 
84
- # === ADMIN DOWNLOAD FUNCTIONS ===
85
- def unlock_downloads(key):
86
- return gr.update(visible=True) if key == ADMIN_KEY else gr.update(visible=False)
87
-
88
  def download_csv():
89
  return csv_log_path
90
 
@@ -97,9 +95,14 @@ def download_dataset_zip():
97
  zipf.write(fpath, arcname=os.path.join("images", fname))
98
  return zip_filename
99
 
100
- # === UI SETUP ===
 
 
 
 
 
101
  with gr.Blocks() as demo:
102
- gr.Markdown("## 🧠 Diabetic Retinopathy Detection with Grad-CAM & Data Collection")
103
 
104
  with gr.Row():
105
  image_input = gr.Image(type="pil", label="Upload Retinal Image")
@@ -114,23 +117,20 @@ with gr.Blocks() as demo:
114
  outputs=[cam_output, prediction_output]
115
  )
116
 
117
- gr.Markdown("### 🔐 Admin Area (Restricted Access)")
118
-
119
- with gr.Row():
120
- admin_input = gr.Text(label="Enter Admin Key", type="password", placeholder="Only Rodiyah knows this 🔐")
121
- unlock_btn = gr.Button("Unlock Downloads")
122
-
123
- with gr.Column(visible=False) as download_section:
124
  with gr.Row():
125
  download_csv_btn = gr.Button("📄 Download CSV Log")
126
  download_zip_btn = gr.Button("📦 Download Full Dataset")
127
  csv_file = gr.File()
128
  zip_file = gr.File()
129
 
130
- unlock_btn.click(
131
- fn=unlock_downloads,
132
- inputs=admin_input,
133
- outputs=download_section
 
 
134
  )
135
 
136
  download_csv_btn.click(
 
12
  import csv
13
  import datetime
14
  import zipfile
15
+ from gradio.routes import Request
16
 
17
+ # === SECRET ADMIN KEY ===
18
+ ADMIN_KEY = "Diabetes_Detection"
19
 
20
+ # === DEVICE SETUP ===
21
  device = torch.device("cpu")
22
 
23
+ # === MODEL LOADING ===
24
  model = models.resnet50(weights=None)
25
  model.fc = torch.nn.Linear(model.fc.in_features, 2)
26
  model.load_state_dict(torch.load("resnet50_dr_classifier.pth", map_location=device))
27
  model.to(device)
28
  model.eval()
29
 
30
+ # === GRAD-CAM ===
31
  target_layer = model.layer4[-1]
32
  cam = GradCAM(model=model, target_layers=[target_layer])
33
 
34
+ # === IMAGE TRANSFORM ===
35
  transform = transforms.Compose([
36
  transforms.Resize((224, 224)),
37
  transforms.ToTensor(),
 
39
  [0.229, 0.224, 0.225])
40
  ])
41
 
42
+ # === STORAGE ===
43
  image_folder = "collected_images"
44
  os.makedirs(image_folder, exist_ok=True)
45
 
 
49
  writer = csv.writer(f)
50
  writer.writerow(["timestamp", "image_filename", "prediction", "confidence"])
51
 
52
+ # === PREDICT FUNCTION ===
53
  def predict_retinopathy(image):
54
  timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
55
  img = image.convert("RGB").resize((224, 224))
 
82
 
83
  return cam_pil, f"{label} (Confidence: {confidence:.2f})"
84
 
85
+ # === ADMIN FILES ===
 
 
 
86
  def download_csv():
87
  return csv_log_path
88
 
 
95
  zipf.write(fpath, arcname=os.path.join("images", fname))
96
  return zip_filename
97
 
98
+ # === VISIBILITY CHECK ===
99
+ def is_admin(request: Request):
100
+ query_params = dict(request.query_params)
101
+ return query_params.get("admin", "") == ADMIN_KEY
102
+
103
+ # === GRADIO APP ===
104
  with gr.Blocks() as demo:
105
+ gr.Markdown("## 🧠 Diabetic Retinopathy Detection with Grad-CAM + Private Logging")
106
 
107
  with gr.Row():
108
  image_input = gr.Image(type="pil", label="Upload Retinal Image")
 
117
  outputs=[cam_output, prediction_output]
118
  )
119
 
120
+ with gr.Column(visible=False) as admin_section:
121
+ gr.Markdown("### 🔐 Admin Downloads")
 
 
 
 
 
122
  with gr.Row():
123
  download_csv_btn = gr.Button("📄 Download CSV Log")
124
  download_zip_btn = gr.Button("📦 Download Full Dataset")
125
  csv_file = gr.File()
126
  zip_file = gr.File()
127
 
128
+ # Admin visibility only for Rodiyah with key
129
+ demo.load(
130
+ lambda req: gr.update(visible=True) if is_admin(req) else gr.update(visible=False),
131
+ inputs=None,
132
+ outputs=admin_section,
133
+ queue=False
134
  )
135
 
136
  download_csv_btn.click(