Rodiyah commited on
Commit
079be8c
·
verified ·
1 Parent(s): 3222f21

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -24
app.py CHANGED
@@ -12,12 +12,10 @@ import os
12
  import csv
13
  import datetime
14
  import zipfile
15
- from gradio.routes import Request
16
 
17
- # 🔐 Secret key
18
  ADMIN_KEY = "Diabetes_Detection"
19
 
20
- # Device
21
  device = torch.device("cpu")
22
 
23
  # Load model
@@ -27,11 +25,11 @@ model.load_state_dict(torch.load("resnet50_dr_classifier.pth", map_location=devi
27
  model.to(device)
28
  model.eval()
29
 
30
- # Grad-CAM setup
31
  target_layer = model.layer4[-1]
32
  cam = GradCAM(model=model, target_layers=[target_layer])
33
 
34
- # Preprocessing
35
  transform = transforms.Compose([
36
  transforms.Resize((224, 224)),
37
  transforms.ToTensor(),
@@ -39,17 +37,17 @@ transform = transforms.Compose([
39
  [0.229, 0.224, 0.225])
40
  ])
41
 
42
- # Folders
43
  image_folder = "collected_images"
44
  os.makedirs(image_folder, exist_ok=True)
45
 
46
  csv_log_path = "prediction_logs.csv"
47
  if not os.path.exists(csv_log_path):
48
- with open(csv_log_path, mode="w", newline="") as f:
49
  writer = csv.writer(f)
50
  writer.writerow(["timestamp", "image_filename", "prediction", "confidence"])
51
 
52
- # 🔍 Prediction
53
  def predict_retinopathy(image):
54
  timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
55
  img = image.convert("RGB").resize((224, 224))
@@ -81,7 +79,7 @@ def predict_retinopathy(image):
81
 
82
  return cam_pil, f"{label} (Confidence: {confidence:.2f})"
83
 
84
- # 📁 Admin downloads
85
  def download_csv():
86
  return csv_log_path
87
 
@@ -94,14 +92,17 @@ def download_dataset_zip():
94
  zipf.write(fpath, arcname=os.path.join("images", fname))
95
  return zip_filename
96
 
97
- # ✅ Admin check (query param)
98
- def is_admin(request: Request):
99
- return request and request.query_params.get("admin") == ADMIN_KEY
 
100
 
101
- # 🌐 App
102
  with gr.Blocks() as demo:
103
  gr.Markdown("## 🧠 Diabetic Retinopathy Detection with Grad-CAM")
104
 
 
 
105
  with gr.Row():
106
  image_input = gr.Image(type="pil", label="Upload Retinal Image")
107
  cam_output = gr.Image(type="pil", label="Grad-CAM")
@@ -115,24 +116,16 @@ with gr.Blocks() as demo:
115
  outputs=[cam_output, prediction_output]
116
  )
117
 
118
- # 🔒 Hidden admin section
119
  with gr.Column(visible=False) as admin_section:
120
- gr.Markdown("### 🔐 Private Downloads (Rodiyah Only)")
121
  with gr.Row():
122
  download_csv_btn = gr.Button("📄 Download CSV Log")
123
  download_zip_btn = gr.Button("📦 Download Dataset ZIP")
124
  csv_file = gr.File()
125
  zip_file = gr.File()
126
 
127
- # Reveal only if correct ?admin=Diabetes_Detection in URL
128
- demo.load(
129
- fn=lambda req: gr.update(visible=True) if is_admin(req) else gr.update(visible=False),
130
- inputs=[],
131
- outputs=admin_section,
132
- queue=False,
133
- api_name=False,
134
- request=True # ✅ Required to pass HTTP request into lambda
135
- )
136
 
137
  download_csv_btn.click(fn=download_csv, inputs=[], outputs=csv_file)
138
  download_zip_btn.click(fn=download_dataset_zip, inputs=[], outputs=zip_file)
 
12
  import csv
13
  import datetime
14
  import zipfile
 
15
 
16
+ # Admin secret
17
  ADMIN_KEY = "Diabetes_Detection"
18
 
 
19
  device = torch.device("cpu")
20
 
21
  # Load model
 
25
  model.to(device)
26
  model.eval()
27
 
28
+ # Grad-CAM
29
  target_layer = model.layer4[-1]
30
  cam = GradCAM(model=model, target_layers=[target_layer])
31
 
32
+ # Preprocess
33
  transform = transforms.Compose([
34
  transforms.Resize((224, 224)),
35
  transforms.ToTensor(),
 
37
  [0.229, 0.224, 0.225])
38
  ])
39
 
40
+ # Folders & logs
41
  image_folder = "collected_images"
42
  os.makedirs(image_folder, exist_ok=True)
43
 
44
  csv_log_path = "prediction_logs.csv"
45
  if not os.path.exists(csv_log_path):
46
+ with open(csv_log_path, "w", newline="") as f:
47
  writer = csv.writer(f)
48
  writer.writerow(["timestamp", "image_filename", "prediction", "confidence"])
49
 
50
+ # Prediction
51
  def predict_retinopathy(image):
52
  timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
53
  img = image.convert("RGB").resize((224, 224))
 
79
 
80
  return cam_pil, f"{label} (Confidence: {confidence:.2f})"
81
 
82
+ # Downloads
83
  def download_csv():
84
  return csv_log_path
85
 
 
92
  zipf.write(fpath, arcname=os.path.join("images", fname))
93
  return zip_filename
94
 
95
+ def check_admin(query_str):
96
+ if f"admin={ADMIN_KEY}" in query_str:
97
+ return gr.update(visible=True)
98
+ return gr.update(visible=False)
99
 
100
+ # Gradio UI
101
  with gr.Blocks() as demo:
102
  gr.Markdown("## 🧠 Diabetic Retinopathy Detection with Grad-CAM")
103
 
104
+ url_input = gr.Textbox(visible=False) # Holds query string
105
+
106
  with gr.Row():
107
  image_input = gr.Image(type="pil", label="Upload Retinal Image")
108
  cam_output = gr.Image(type="pil", label="Grad-CAM")
 
116
  outputs=[cam_output, prediction_output]
117
  )
118
 
 
119
  with gr.Column(visible=False) as admin_section:
120
+ gr.Markdown("### 🔐 Admin Downloads (Private)")
121
  with gr.Row():
122
  download_csv_btn = gr.Button("📄 Download CSV Log")
123
  download_zip_btn = gr.Button("📦 Download Dataset ZIP")
124
  csv_file = gr.File()
125
  zip_file = gr.File()
126
 
127
+ # Logic to reveal admin section
128
+ url_input.change(fn=check_admin, inputs=url_input, outputs=admin_section)
 
 
 
 
 
 
 
129
 
130
  download_csv_btn.click(fn=download_csv, inputs=[], outputs=csv_file)
131
  download_zip_btn.click(fn=download_dataset_zip, inputs=[], outputs=zip_file)