File size: 4,549 Bytes
ae8f111 ab15865 7303661 ab15865 dd8b58f 7303661 8e84628 dd8b58f 7a08276 8e84628 ae8f111 8e84628 ae8f111 8e84628 ae8f111 8e84628 ae8f111 8e84628 ab15865 7303661 ab15865 7303661 8e84628 ae8f111 ab15865 ae8f111 e1b2058 8e84628 ab15865 e1b2058 ab15865 e1b2058 ab15865 64f36cb 8e84628 ab15865 b1aa13e ab15865 b1aa13e ab15865 8e84628 dd8b58f 3222f21 dd8b58f 8e84628 64f36cb 8e84628 64f36cb ab15865 64f36cb 8e84628 dd8b58f 8e84628 7a08276 8e84628 7a08276 8e84628 dd8b58f 3222f21 8e84628 dd8b58f 8e84628 3222f21 7a08276 8e84628 64f36cb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import gradio as gr
from PIL import Image
import torch
import torch.nn.functional as F
import numpy as np
from torchvision import models, transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import os
import csv
import datetime
import zipfile
from gradio.routes import Request
# 🔐 Admin key for the hidden download section. Prefer supplying it through the
# ADMIN_KEY environment variable (e.g. a deployment secret) so it does not live
# in source control; the literal is kept only as a backward-compatible fallback.
# TODO: remove the fallback once the environment variable is configured.
ADMIN_KEY = os.environ.get("ADMIN_KEY", "Diabetes_Detection")

# Device: all inference runs on the CPU.
device = torch.device("cpu")

# Load model: ResNet-50 (no pretrained weights) with a 2-way classification head,
# then restore the fine-tuned weights from the local checkpoint.
model = models.resnet50(weights=None)
model.fc = torch.nn.Linear(model.fc.in_features, 2)
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints
# (passing weights_only=True restricts this on torch >= 1.13).
model.load_state_dict(torch.load("resnet50_dr_classifier.pth", map_location=device))
model.to(device)
model.eval()

# Grad-CAM setup: explanations are taken from the last block of layer4.
target_layer = model.layer4[-1]
cam = GradCAM(model=model, target_layers=[target_layer])

# Preprocessing: resize to 224x224, convert to a tensor, then normalise with the
# standard ImageNet channel mean/std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225]),
])

# Folders: uploaded images are archived here and every prediction is appended
# to a CSV log whose header is written once, on first start.
image_folder = "collected_images"
os.makedirs(image_folder, exist_ok=True)
csv_log_path = "prediction_logs.csv"
if not os.path.exists(csv_log_path):
    with open(csv_log_path, mode="w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["timestamp", "image_filename", "prediction", "confidence"])
# 🔍 Prediction
def predict_retinopathy(image):
    """Classify a retinal image, build a Grad-CAM overlay, and log the upload.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when Submit is pressed without an upload.

    Returns:
        Tuple ``(cam_pil, message)``: the Grad-CAM heat map blended over the
        224x224 input image, and a ``"<label> (Confidence: x.xx)"`` string.

    Raises:
        gr.Error: if no image was uploaded (Gradio shows the message in the UI).

    Side effects:
        Saves the image as a PNG under ``image_folder`` and appends one row
        (timestamp, filename, label, confidence) to ``csv_log_path``.
    """
    import uuid  # local import: only needed for collision-free filenames

    if image is None:
        raise gr.Error("Please upload a retinal image before submitting.")

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

    # Convert once: the RGB copy feeds the model and is also what gets archived,
    # since PNG cannot store every PIL mode (e.g. CMYK).
    rgb_full = image.convert("RGB")
    # Resize here (not only inside `transform`) so the overlay base image has
    # the same 224x224 resolution the CAM is computed at.
    img = rgb_full.resize((224, 224))
    img_tensor = transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        probs = F.softmax(model(img_tensor), dim=1)
    pred = torch.argmax(probs, dim=1).item()
    confidence = probs[0][pred].item()
    # NOTE(review): assumes class index 0 is "DR" in the training label order — confirm.
    label = "Diabetic Retinopathy (DR)" if pred == 0 else "No DR"

    # Grad-CAM needs gradients, so it must run outside the no_grad block above.
    rgb_img_np = np.ascontiguousarray(np.asarray(img, dtype=np.float32) / 255.0)
    grayscale_cam = cam(input_tensor=img_tensor, targets=[ClassifierOutputTarget(pred)])[0]
    cam_pil = Image.fromarray(show_cam_on_image(rgb_img_np, grayscale_cam, use_rgb=True))

    # Random suffix: a second-resolution timestamp alone collides (and silently
    # overwrites an earlier upload) when two images with the same label arrive
    # within the same second.
    label_slug = label.replace(" ", "_")
    image_filename = f"{timestamp}_{uuid.uuid4().hex[:8]}_{label_slug}.png"
    image_path = os.path.join(image_folder, image_filename)
    rgb_full.save(image_path)

    with open(csv_log_path, mode="a", newline="", encoding="utf-8") as f:
        csv.writer(f).writerow([timestamp, image_filename, label, f"{confidence:.4f}"])

    return cam_pil, f"{label} (Confidence: {confidence:.2f})"
# 📁 Admin downloads
def download_csv(request: gr.Request = None):
    """Return the prediction-log CSV path for download, for admins only.

    Args:
        request: injected by Gradio because of the ``gr.Request`` annotation.
            It is used to re-check the ``?admin=`` key server-side, because
            hiding the admin column in the UI does not stop the event itself
            from being triggered.

    Returns:
        ``csv_log_path``, which the ``gr.File`` output serves to the browser.

    Raises:
        gr.Error: if the request does not carry the admin key.
    """
    if not is_admin(request):
        raise gr.Error("Admin access required.")
    return csv_log_path
def download_dataset_zip(request: gr.Request = None):
    """Bundle the prediction log and all collected images into a ZIP, for admins only.

    Args:
        request: injected by Gradio because of the ``gr.Request`` annotation.
            It is used to re-check the ``?admin=`` key server-side, because
            hiding the admin column in the UI does not stop the event itself
            from being triggered.

    Returns:
        Path of the written archive (``dataset_bundle.zip`` in the working
        directory). The archive holds ``prediction_logs.csv`` plus
        ``images/<name>`` for every regular file in ``image_folder``.

    Raises:
        gr.Error: if the request does not carry the admin key.
    """
    if not is_admin(request):
        raise gr.Error("Admin access required.")

    zip_filename = "dataset_bundle.zip"
    # DEFLATE compresses the CSV well; the default ZIP_STORED does not compress at all.
    with zipfile.ZipFile(zip_filename, "w", compression=zipfile.ZIP_DEFLATED) as zipf:
        if os.path.isfile(csv_log_path):
            zipf.write(csv_log_path, arcname="prediction_logs.csv")
        # Sorted for a stable archive order; skip anything that is not a regular file.
        for fname in sorted(os.listdir(image_folder)):
            fpath = os.path.join(image_folder, fname)
            if os.path.isfile(fpath):
                zipf.write(fpath, arcname=f"images/{fname}")
    return zip_filename
# ✅ Admin check (query param)
def is_admin(request: "Request") -> bool:
    """Return True when the request's ``?admin=`` query parameter matches ADMIN_KEY.

    Args:
        request: the Gradio request for the current event, or ``None`` when
            no request is available (e.g. a direct Python call).

    Returns:
        ``True`` only on an exact key match, otherwise ``False`` (never ``None``).
    """
    import secrets  # local import: constant-time comparison for the secret

    if request is None:
        return False
    supplied = request.query_params.get("admin")
    if supplied is None:
        return False
    # compare_digest avoids leaking the key through comparison timing; comparing
    # UTF-8 bytes means non-ASCII input cannot raise TypeError.
    return secrets.compare_digest(
        str(supplied).encode("utf-8"), str(ADMIN_KEY).encode("utf-8")
    )
# 🌐 App
def _admin_visibility(request: gr.Request):
    """Return a visibility update that reveals the admin column only for the admin key.

    Gradio injects ``request`` because of the ``gr.Request`` annotation. A bare
    lambda cannot carry that annotation, so Gradio would call it with no arguments.
    """
    return gr.update(visible=bool(is_admin(request)))


with gr.Blocks() as demo:
    gr.Markdown("## 🧠 Diabetic Retinopathy Detection with Grad-CAM")
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Retinal Image")
        cam_output = gr.Image(type="pil", label="Grad-CAM")
    prediction_output = gr.Text(label="Prediction")
    run_button = gr.Button("Submit")
    run_button.click(
        fn=predict_retinopathy,
        inputs=image_input,
        outputs=[cam_output, prediction_output],
    )

    # 🔒 Hidden admin section. Visibility is only a UI convenience: it hides the
    # widgets but does not by itself restrict who can trigger the events below.
    with gr.Column(visible=False) as admin_section:
        gr.Markdown("### 🔐 Private Downloads (Rodiyah Only)")
        with gr.Row():
            download_csv_btn = gr.Button("📄 Download CSV Log")
            download_zip_btn = gr.Button("📦 Download Dataset ZIP")
        csv_file = gr.File()
        zip_file = gr.File()

    # ✅ Reveal the admin column only if the page URL carries ?admin=<ADMIN_KEY>.
    demo.load(
        fn=_admin_visibility,
        inputs=[],
        outputs=admin_section,
        queue=False,
        api_name=False,
    )
    # api_name=False: do not expose the admin downloads as named API endpoints.
    download_csv_btn.click(fn=download_csv, inputs=[], outputs=csv_file, api_name=False)
    download_zip_btn.click(fn=download_dataset_zip, inputs=[], outputs=zip_file, api_name=False)

demo.launch()
|