File size: 4,549 Bytes
ae8f111
 
 
 
 
 
 
 
 
 
ab15865
7303661
 
ab15865
dd8b58f
7303661
8e84628
dd8b58f
7a08276
8e84628
ae8f111
 
8e84628
ae8f111
 
 
 
 
 
8e84628
ae8f111
 
 
8e84628
ae8f111
 
 
 
 
 
 
8e84628
ab15865
 
7303661
ab15865
 
 
 
 
7303661
8e84628
ae8f111
ab15865
ae8f111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1b2058
8e84628
ab15865
 
 
e1b2058
ab15865
 
 
e1b2058
ab15865
64f36cb
8e84628
ab15865
 
 
 
b1aa13e
 
ab15865
 
 
 
b1aa13e
ab15865
8e84628
dd8b58f
3222f21
dd8b58f
8e84628
64f36cb
8e84628
64f36cb
 
 
 
 
 
ab15865
 
64f36cb
 
 
 
 
 
8e84628
dd8b58f
8e84628
7a08276
 
8e84628
7a08276
 
 
8e84628
dd8b58f
3222f21
8e84628
dd8b58f
8e84628
 
3222f21
7a08276
 
8e84628
 
64f36cb
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import csv
import datetime
import hmac
import os
import uuid
import zipfile

import gradio as gr
from PIL import Image
import torch
import torch.nn.functional as F
import numpy as np
from torchvision import models, transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from gradio.routes import Request

# 🔐 Admin key. Set the ADMIN_KEY environment variable in deployment so the
# real key is not readable from the source code; the literal below is only a
# backward-compatible fallback. An empty variable also falls back, so the key
# can never be the empty string (which an empty ``?admin=`` would match).
ADMIN_KEY = os.environ.get("ADMIN_KEY") or "Diabetes_Detection"

# All inference runs on the CPU.
device = torch.device("cpu")

# Classifier: a ResNet-50 built without pretrained weights, its final fully
# connected layer swapped for a 2-output head, then filled from the
# fine-tuned checkpoint on disk.
# NOTE(review): torch.load unpickles the checkpoint file — only deploy trusted
# checkpoints here (torch>=1.13 also accepts weights_only=True for a plain
# state dict).
model = models.resnet50(weights=None)
model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=2)
model.load_state_dict(
    torch.load("resnet50_dr_classifier.pth", map_location=device)
)
# nn.Module.to() and .eval() both return the module itself, so this rebinding
# keeps the very same object while moving it to `device` in inference mode.
model = model.to(device).eval()

# Grad-CAM explainer hooked onto the last block of ResNet's final residual
# stage (layer4).
target_layer = model.layer4[-1]
cam = GradCAM(target_layers=[target_layer], model=model)

# Preprocessing pipeline: resize to 224x224, convert to a float tensor with
# ToTensor, then normalize each channel with the ImageNet mean/std values.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Storage for collected data: uploaded images are saved under `image_folder`,
# and every prediction is appended as a row of `csv_log_path`. The CSV header
# row is written only when the log file does not exist yet.
image_folder = "collected_images"
csv_log_path = "prediction_logs.csv"
os.makedirs(image_folder, exist_ok=True)

if not os.path.exists(csv_log_path):
    with open(csv_log_path, mode="w", newline="") as log_file:
        csv.writer(log_file).writerow(
            ["timestamp", "image_filename", "prediction", "confidence"]
        )

# 🔍 Prediction
def predict_retinopathy(image):
    """Classify a retinal image, render its Grad-CAM overlay, and archive it.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when Submit is clicked before anything was uploaded.

    Returns:
        Tuple ``(cam_pil, text)``: the Grad-CAM heat map blended over the
        224x224 RGB copy of the input, and ``"<label> (Confidence: <p>)"``
        where ``p`` is the softmax probability of the predicted class.

    Raises:
        gr.Error: when no image was provided, so the UI shows a readable
            message instead of an ``AttributeError`` on ``None``.
    """
    if image is None:
        raise gr.Error("Please upload a retinal image before submitting.")

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    img = image.convert("RGB").resize((224, 224))
    img_tensor = transform(img).unsqueeze(0).to(device)

    pred, confidence = _classify(img_tensor)
    # NOTE(review): class index 0 is mapped to DR; this must match the class
    # order used when the checkpoint was trained — confirm against training.
    label = "Diabetic Retinopathy (DR)" if pred == 0 else "No DR"

    cam_pil = _grad_cam_overlay(img, img_tensor, pred)
    _archive_prediction(image, timestamp, label, confidence)

    return cam_pil, f"{label} (Confidence: {confidence:.2f})"


def _classify(img_tensor):
    """Run the classifier; return ``(predicted_index, softmax_probability)``."""
    with torch.no_grad():
        probs = F.softmax(model(img_tensor), dim=1)
        pred = torch.argmax(probs, dim=1).item()
        confidence = probs[0][pred].item()
    return pred, confidence


def _grad_cam_overlay(img, img_tensor, pred):
    """Return a PIL image of the Grad-CAM map for class ``pred`` over ``img``."""
    # show_cam_on_image expects a float32 RGB array scaled to [0, 1].
    rgb_img_np = np.ascontiguousarray(np.asarray(img, dtype=np.float32) / 255.0)
    # Grad-CAM back-propagates through the model, so this call deliberately
    # runs outside torch.no_grad().
    grayscale_cam = cam(input_tensor=img_tensor, targets=[ClassifierOutputTarget(pred)])[0]
    return Image.fromarray(show_cam_on_image(rgb_img_np, grayscale_cam, use_rgb=True))


def _archive_prediction(image, timestamp, label, confidence):
    """Save the original upload under ``image_folder`` and append a CSV log row."""
    # `timestamp` only has one-second resolution and there are just two
    # labels, so a random suffix keeps two uploads in the same second from
    # overwriting each other's file (and two CSV rows sharing one image).
    image_filename = f"{timestamp}_{uuid.uuid4().hex[:8]}_{label.replace(' ', '_')}.png"
    image.save(os.path.join(image_folder, image_filename))

    with open(csv_log_path, mode="a", newline="") as f:
        csv.writer(f).writerow([timestamp, image_filename, label, f"{confidence:.4f}"])

# 📁 Admin downloads
def download_csv(request: Request = None):
    """Return the prediction-log CSV path for the admin ``gr.File`` output.

    Args:
        request: Injected by Gradio because of the ``Request`` annotation;
            ``None`` for direct Python calls.

    Returns:
        ``csv_log_path``.

    Raises:
        gr.Error: when ``is_admin(request)`` is falsy.
    """
    # The admin column is only hidden in the UI; Gradio event handlers can
    # still be invoked directly by API clients, so authorization has to be
    # enforced here, server-side.
    if not is_admin(request):
        raise gr.Error("Admin access required.")
    return csv_log_path

def download_dataset_zip(request: Request = None):
    """Bundle the prediction log and every collected image into one ZIP.

    Args:
        request: Injected by Gradio because of the ``Request`` annotation;
            ``None`` for direct Python calls.

    Returns:
        Path of the written archive (``dataset_bundle.zip`` in the working
        directory), containing ``prediction_logs.csv`` and ``images/<name>``
        for each entry of ``image_folder``.

    Raises:
        gr.Error: when ``is_admin(request)`` is falsy.
    """
    # Hiding the admin column is not access control: the click handler is
    # reachable by API clients, so re-check the admin key server-side.
    if not is_admin(request):
        raise gr.Error("Admin access required.")

    zip_filename = "dataset_bundle.zip"
    # DEFLATE shrinks the CSV; sorted() gives a deterministic member order.
    with zipfile.ZipFile(zip_filename, "w", compression=zipfile.ZIP_DEFLATED) as zipf:
        zipf.write(csv_log_path, arcname="prediction_logs.csv")
        for fname in sorted(os.listdir(image_folder)):
            zipf.write(os.path.join(image_folder, fname), arcname=f"images/{fname}")
    return zip_filename

# ✅ Admin check (query param)
def is_admin(request: Request):
    """Return True only if the request's ``?admin=`` value equals ADMIN_KEY.

    Args:
        request: Gradio request wrapper, or ``None`` when none was injected.

    Returns:
        A real ``bool``. Always ``False`` when the request, the supplied
        value, or the configured key is missing/empty, so an empty
        ``?admin=`` can never grant access.
    """
    if request is None or not ADMIN_KEY:
        return False
    supplied = request.query_params.get("admin")
    if not supplied:
        return False
    # Constant-time comparison (on UTF-8 bytes, since compare_digest rejects
    # non-ASCII str) so the key cannot be probed via response timing.
    return hmac.compare_digest(supplied.encode("utf-8"), ADMIN_KEY.encode("utf-8"))

# 🌐 App
def _toggle_admin_section(request: Request):
    """Show the admin column only when the page URL carries the admin key.

    Gradio injects ``request`` because of the ``Request`` type annotation.
    This replaces an unannotated lambda (which Gradio called with no
    arguments) and a ``request=True`` keyword that ``load`` does not accept.
    """
    return gr.update(visible=bool(is_admin(request)))


with gr.Blocks() as demo:
    gr.Markdown("## 🧠 Diabetic Retinopathy Detection with Grad-CAM")

    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Retinal Image")
        cam_output = gr.Image(type="pil", label="Grad-CAM")

    prediction_output = gr.Text(label="Prediction")
    run_button = gr.Button("Submit")

    run_button.click(
        fn=predict_retinopathy,
        inputs=image_input,
        outputs=[cam_output, prediction_output],
    )

    # 🔒 Admin section, hidden by default. Visibility alone is not access
    # control: handlers wired below remain callable by API clients.
    with gr.Column(visible=False) as admin_section:
        gr.Markdown("### 🔐 Private Downloads (Rodiyah Only)")
        with gr.Row():
            download_csv_btn = gr.Button("📄 Download CSV Log")
            download_zip_btn = gr.Button("📦 Download Dataset ZIP")
        csv_file = gr.File()
        zip_file = gr.File()

    # ✅ Reveal the section only when the URL contains ?admin=<ADMIN_KEY>.
    demo.load(
        fn=_toggle_admin_section,
        inputs=None,
        outputs=admin_section,
        queue=False,
        api_name=False,
    )

    # api_name=False keeps the admin handlers out of the public API listing.
    download_csv_btn.click(fn=download_csv, inputs=None, outputs=csv_file, api_name=False)
    download_zip_btn.click(fn=download_dataset_zip, inputs=None, outputs=zip_file, api_name=False)

demo.launch()