Rodiyah commited on
Commit
ac84363
·
verified ·
1 Parent(s): 079be8c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -21
app.py CHANGED
@@ -13,9 +13,10 @@ import csv
13
  import datetime
14
  import zipfile
15
 
16
- # Admin secret
17
  ADMIN_KEY = "Diabetes_Detection"
18
 
 
19
  device = torch.device("cpu")
20
 
21
  # Load model
@@ -25,11 +26,11 @@ model.load_state_dict(torch.load("resnet50_dr_classifier.pth", map_location=devi
25
  model.to(device)
26
  model.eval()
27
 
28
- # Grad-CAM
29
  target_layer = model.layer4[-1]
30
  cam = GradCAM(model=model, target_layers=[target_layer])
31
 
32
- # Preprocess
33
  transform = transforms.Compose([
34
  transforms.Resize((224, 224)),
35
  transforms.ToTensor(),
@@ -37,7 +38,7 @@ transform = transforms.Compose([
37
  [0.229, 0.224, 0.225])
38
  ])
39
 
40
- # Folders & logs
41
  image_folder = "collected_images"
42
  os.makedirs(image_folder, exist_ok=True)
43
 
@@ -47,7 +48,7 @@ if not os.path.exists(csv_log_path):
47
  writer = csv.writer(f)
48
  writer.writerow(["timestamp", "image_filename", "prediction", "confidence"])
49
 
50
- # Prediction
51
  def predict_retinopathy(image):
52
  timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
53
  img = image.convert("RGB").resize((224, 224))
@@ -68,7 +69,7 @@ def predict_retinopathy(image):
68
  cam_image = show_cam_on_image(rgb_img_np, grayscale_cam, use_rgb=True)
69
  cam_pil = Image.fromarray(cam_image)
70
 
71
- # Save image + log
72
  image_filename = f"{timestamp}_{label.replace(' ', '_')}.png"
73
  image_path = os.path.join(image_folder, image_filename)
74
  image.save(image_path)
@@ -79,7 +80,13 @@ def predict_retinopathy(image):
79
 
80
  return cam_pil, f"{label} (Confidence: {confidence:.2f})"
81
 
82
- # Downloads
 
 
 
 
 
 
83
  def download_csv():
84
  return csv_log_path
85
 
@@ -92,20 +99,13 @@ def download_dataset_zip():
92
  zipf.write(fpath, arcname=os.path.join("images", fname))
93
  return zip_filename
94
 
95
- def check_admin(query_str):
96
- if f"admin={ADMIN_KEY}" in query_str:
97
- return gr.update(visible=True)
98
- return gr.update(visible=False)
99
-
100
- # Gradio UI
101
  with gr.Blocks() as demo:
102
  gr.Markdown("## 🧠 Diabetic Retinopathy Detection with Grad-CAM")
103
 
104
- url_input = gr.Textbox(visible=False) # Holds query string
105
-
106
  with gr.Row():
107
  image_input = gr.Image(type="pil", label="Upload Retinal Image")
108
- cam_output = gr.Image(type="pil", label="Grad-CAM")
109
 
110
  prediction_output = gr.Text(label="Prediction")
111
  run_button = gr.Button("Submit")
@@ -116,16 +116,24 @@ with gr.Blocks() as demo:
116
  outputs=[cam_output, prediction_output]
117
  )
118
 
119
- with gr.Column(visible=False) as admin_section:
120
- gr.Markdown("### 🔐 Admin Downloads (Private)")
 
 
 
 
 
121
  with gr.Row():
122
  download_csv_btn = gr.Button("📄 Download CSV Log")
123
- download_zip_btn = gr.Button("📦 Download Dataset ZIP")
124
  csv_file = gr.File()
125
  zip_file = gr.File()
126
 
127
- # Logic to reveal admin section
128
- url_input.change(fn=check_admin, inputs=url_input, outputs=admin_section)
 
 
 
129
 
130
  download_csv_btn.click(fn=download_csv, inputs=[], outputs=csv_file)
131
  download_zip_btn.click(fn=download_dataset_zip, inputs=[], outputs=zip_file)
 
13
  import datetime
14
  import zipfile
15
 
16
+ # Admin key (hidden until typed)
17
  ADMIN_KEY = "Diabetes_Detection"
18
 
19
+ # Set device
20
  device = torch.device("cpu")
21
 
22
  # Load model
 
26
  model.to(device)
27
  model.eval()
28
 
29
+ # Grad-CAM setup
30
  target_layer = model.layer4[-1]
31
  cam = GradCAM(model=model, target_layers=[target_layer])
32
 
33
+ # Image preprocessing
34
  transform = transforms.Compose([
35
  transforms.Resize((224, 224)),
36
  transforms.ToTensor(),
 
38
  [0.229, 0.224, 0.225])
39
  ])
40
 
41
+ # Data storage
42
  image_folder = "collected_images"
43
  os.makedirs(image_folder, exist_ok=True)
44
 
 
48
  writer = csv.writer(f)
49
  writer.writerow(["timestamp", "image_filename", "prediction", "confidence"])
50
 
51
+ # Prediction function
52
  def predict_retinopathy(image):
53
  timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
54
  img = image.convert("RGB").resize((224, 224))
 
69
  cam_image = show_cam_on_image(rgb_img_np, grayscale_cam, use_rgb=True)
70
  cam_pil = Image.fromarray(cam_image)
71
 
72
+ # Save image & log
73
  image_filename = f"{timestamp}_{label.replace(' ', '_')}.png"
74
  image_path = os.path.join(image_folder, image_filename)
75
  image.save(image_path)
 
80
 
81
  return cam_pil, f"{label} (Confidence: {confidence:.2f})"
82
 
83
+ # Admin unlock
84
+ def unlock_admin(key_input):
85
+ if key_input == ADMIN_KEY:
86
+ return gr.update(visible=True)
87
+ return gr.update(visible=False)
88
+
89
+ # Download functions
90
  def download_csv():
91
  return csv_log_path
92
 
 
99
  zipf.write(fpath, arcname=os.path.join("images", fname))
100
  return zip_filename
101
 
102
+ # UI
 
 
 
 
 
103
  with gr.Blocks() as demo:
104
  gr.Markdown("## 🧠 Diabetic Retinopathy Detection with Grad-CAM")
105
 
 
 
106
  with gr.Row():
107
  image_input = gr.Image(type="pil", label="Upload Retinal Image")
108
+ cam_output = gr.Image(type="pil", label="Grad-CAM Output")
109
 
110
  prediction_output = gr.Text(label="Prediction")
111
  run_button = gr.Button("Submit")
 
116
  outputs=[cam_output, prediction_output]
117
  )
118
 
119
+ gr.Markdown("### 🔐 Admin Access (Rodiyah only)")
120
+
121
+ admin_key_input = gr.Text(label="Enter Admin Key", type="password", placeholder="Only Rodiyah knows this!")
122
+ unlock_button = gr.Button("Unlock Downloads")
123
+
124
+ with gr.Column(visible=False) as admin_panel:
125
+ gr.Markdown("### ✅ Download Panel (Private Access)")
126
  with gr.Row():
127
  download_csv_btn = gr.Button("📄 Download CSV Log")
128
+ download_zip_btn = gr.Button("📦 Download Full Dataset")
129
  csv_file = gr.File()
130
  zip_file = gr.File()
131
 
132
+ unlock_button.click(
133
+ fn=unlock_admin,
134
+ inputs=admin_key_input,
135
+ outputs=admin_panel
136
+ )
137
 
138
  download_csv_btn.click(fn=download_csv, inputs=[], outputs=csv_file)
139
  download_zip_btn.click(fn=download_dataset_zip, inputs=[], outputs=zip_file)