Rodiyah commited on
Commit
8e84628
·
verified ·
1 Parent(s): dd8b58f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -31
app.py CHANGED
@@ -14,24 +14,24 @@ import datetime
14
  import zipfile
15
  from gradio.routes import Request
16
 
17
- # === SECRET ADMIN KEY ===
18
  ADMIN_KEY = "Diabetes_Detection"
19
 
20
- # === DEVICE SETUP ===
21
  device = torch.device("cpu")
22
 
23
- # === MODEL LOADING ===
24
  model = models.resnet50(weights=None)
25
  model.fc = torch.nn.Linear(model.fc.in_features, 2)
26
  model.load_state_dict(torch.load("resnet50_dr_classifier.pth", map_location=device))
27
  model.to(device)
28
  model.eval()
29
 
30
- # === GRAD-CAM ===
31
  target_layer = model.layer4[-1]
32
  cam = GradCAM(model=model, target_layers=[target_layer])
33
 
34
- # === IMAGE TRANSFORM ===
35
  transform = transforms.Compose([
36
  transforms.Resize((224, 224)),
37
  transforms.ToTensor(),
@@ -39,7 +39,7 @@ transform = transforms.Compose([
39
  [0.229, 0.224, 0.225])
40
  ])
41
 
42
- # === STORAGE ===
43
  image_folder = "collected_images"
44
  os.makedirs(image_folder, exist_ok=True)
45
 
@@ -49,7 +49,7 @@ if not os.path.exists(csv_log_path):
49
  writer = csv.writer(f)
50
  writer.writerow(["timestamp", "image_filename", "prediction", "confidence"])
51
 
52
- # === PREDICT FUNCTION ===
53
  def predict_retinopathy(image):
54
  timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
55
  img = image.convert("RGB").resize((224, 224))
@@ -70,19 +70,18 @@ def predict_retinopathy(image):
70
  cam_image = show_cam_on_image(rgb_img_np, grayscale_cam, use_rgb=True)
71
  cam_pil = Image.fromarray(cam_image)
72
 
73
- # Save uploaded image
74
  image_filename = f"{timestamp}_{label.replace(' ', '_')}.png"
75
  image_path = os.path.join(image_folder, image_filename)
76
  image.save(image_path)
77
 
78
- # Log prediction
79
  with open(csv_log_path, mode="a", newline="") as f:
80
  writer = csv.writer(f)
81
  writer.writerow([timestamp, image_filename, label, f"{confidence:.4f}"])
82
 
83
  return cam_pil, f"{label} (Confidence: {confidence:.2f})"
84
 
85
- # === ADMIN FILES ===
86
  def download_csv():
87
  return csv_log_path
88
 
@@ -95,14 +94,13 @@ def download_dataset_zip():
95
  zipf.write(fpath, arcname=os.path.join("images", fname))
96
  return zip_filename
97
 
98
- # === VISIBILITY CHECK ===
99
  def is_admin(request: Request):
100
- query_params = dict(request.query_params)
101
- return query_params.get("admin", "") == ADMIN_KEY
102
 
103
- # === GRADIO APP ===
104
  with gr.Blocks() as demo:
105
- gr.Markdown("## 🧠 Diabetic Retinopathy Detection with Grad-CAM + Private Logging")
106
 
107
  with gr.Row():
108
  image_input = gr.Image(type="pil", label="Upload Retinal Image")
@@ -117,32 +115,26 @@ with gr.Blocks() as demo:
117
  outputs=[cam_output, prediction_output]
118
  )
119
 
 
120
  with gr.Column(visible=False) as admin_section:
121
- gr.Markdown("### 🔐 Admin Downloads")
122
  with gr.Row():
123
  download_csv_btn = gr.Button("📄 Download CSV Log")
124
- download_zip_btn = gr.Button("📦 Download Full Dataset")
125
  csv_file = gr.File()
126
  zip_file = gr.File()
127
 
128
- # Admin visibility only for Rodiyah with key
129
  demo.load(
130
  lambda req: gr.update(visible=True) if is_admin(req) else gr.update(visible=False),
131
- inputs=None,
132
  outputs=admin_section,
133
- queue=False
 
134
  )
135
 
136
- download_csv_btn.click(
137
- fn=download_csv,
138
- inputs=[],
139
- outputs=csv_file
140
- )
141
-
142
- download_zip_btn.click(
143
- fn=download_dataset_zip,
144
- inputs=[],
145
- outputs=zip_file
146
- )
147
 
148
  demo.launch()
 
 
14
  import zipfile
15
  from gradio.routes import Request
16
 
17
+ # 🔐 Secret key
18
  ADMIN_KEY = "Diabetes_Detection"
19
 
20
+ # Device
21
  device = torch.device("cpu")
22
 
23
+ # Load model
24
  model = models.resnet50(weights=None)
25
  model.fc = torch.nn.Linear(model.fc.in_features, 2)
26
  model.load_state_dict(torch.load("resnet50_dr_classifier.pth", map_location=device))
27
  model.to(device)
28
  model.eval()
29
 
30
+ # Grad-CAM setup
31
  target_layer = model.layer4[-1]
32
  cam = GradCAM(model=model, target_layers=[target_layer])
33
 
34
+ # Preprocessing
35
  transform = transforms.Compose([
36
  transforms.Resize((224, 224)),
37
  transforms.ToTensor(),
 
39
  [0.229, 0.224, 0.225])
40
  ])
41
 
42
+ # Folders
43
  image_folder = "collected_images"
44
  os.makedirs(image_folder, exist_ok=True)
45
 
 
49
  writer = csv.writer(f)
50
  writer.writerow(["timestamp", "image_filename", "prediction", "confidence"])
51
 
52
+ # 🔍 Prediction
53
  def predict_retinopathy(image):
54
  timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
55
  img = image.convert("RGB").resize((224, 224))
 
70
  cam_image = show_cam_on_image(rgb_img_np, grayscale_cam, use_rgb=True)
71
  cam_pil = Image.fromarray(cam_image)
72
 
73
+ # Save image + log
74
  image_filename = f"{timestamp}_{label.replace(' ', '_')}.png"
75
  image_path = os.path.join(image_folder, image_filename)
76
  image.save(image_path)
77
 
 
78
  with open(csv_log_path, mode="a", newline="") as f:
79
  writer = csv.writer(f)
80
  writer.writerow([timestamp, image_filename, label, f"{confidence:.4f}"])
81
 
82
  return cam_pil, f"{label} (Confidence: {confidence:.2f})"
83
 
84
+ # 📁 Admin downloads
85
  def download_csv():
86
  return csv_log_path
87
 
 
94
  zipf.write(fpath, arcname=os.path.join("images", fname))
95
  return zip_filename
96
 
97
+ # Admin check (query param)
98
  def is_admin(request: Request):
99
+ return request.query_params.get("admin") == ADMIN_KEY
 
100
 
101
+ # 🌐 App
102
  with gr.Blocks() as demo:
103
+ gr.Markdown("## 🧠 Diabetic Retinopathy Detection with Grad-CAM")
104
 
105
  with gr.Row():
106
  image_input = gr.Image(type="pil", label="Upload Retinal Image")
 
115
  outputs=[cam_output, prediction_output]
116
  )
117
 
118
+ # 🔒 Hidden admin section
119
  with gr.Column(visible=False) as admin_section:
120
+ gr.Markdown("### 🔐 Private Downloads (Rodiyah Only)")
121
  with gr.Row():
122
  download_csv_btn = gr.Button("📄 Download CSV Log")
123
+ download_zip_btn = gr.Button("📦 Download Dataset ZIP")
124
  csv_file = gr.File()
125
  zip_file = gr.File()
126
 
127
+ # Reveal only if correct ?admin=Diabetes_Detection in URL
128
  demo.load(
129
  lambda req: gr.update(visible=True) if is_admin(req) else gr.update(visible=False),
130
+ inputs=[],
131
  outputs=admin_section,
132
+ queue=False,
133
+ api_name=False,
134
  )
135
 
136
+ download_csv_btn.click(fn=download_csv, inputs=[], outputs=csv_file)
137
+ download_zip_btn.click(fn=download_dataset_zip, inputs=[], outputs=zip_file)
 
 
 
 
 
 
 
 
 
138
 
139
  demo.launch()
140
+