File size: 4,285 Bytes
ae8f111
 
 
 
 
 
 
 
 
 
ab15865
7303661
 
ab15865
7303661
079be8c
dd8b58f
7a08276
ae8f111
 
8e84628
ae8f111
 
 
 
 
 
079be8c
ae8f111
 
 
079be8c
ae8f111
 
 
 
 
 
 
079be8c
ab15865
 
7303661
ab15865
 
079be8c
ab15865
 
7303661
079be8c
ae8f111
ab15865
ae8f111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1b2058
8e84628
ab15865
 
 
e1b2058
ab15865
 
 
e1b2058
ab15865
64f36cb
079be8c
ab15865
 
 
 
b1aa13e
 
ab15865
 
 
 
b1aa13e
ab15865
079be8c
 
 
 
dd8b58f
079be8c
64f36cb
8e84628
64f36cb
079be8c
 
64f36cb
 
 
 
 
ab15865
 
64f36cb
 
 
 
 
 
dd8b58f
079be8c
7a08276
 
8e84628
7a08276
 
 
079be8c
 
7a08276
8e84628
 
64f36cb
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import csv
import datetime
import hmac
import os
import uuid
import zipfile
from urllib.parse import parse_qs

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from torchvision import models, transforms

# Admin secret used to unlock the admin download section.
# Read from the ADMIN_KEY environment variable so the real key does not have to
# live in source control. The fallback (also used when the variable is set but
# empty, which would otherwise make the key trivially guessable) keeps the
# previous behaviour; set ADMIN_KEY in deployment and consider dropping it.
ADMIN_KEY = os.environ.get("ADMIN_KEY") or "Diabetes_Detection"

# All inference happens on the CPU.
device = torch.device("cpu")

# Load model: an untrained ResNet-50 skeleton whose final fully-connected layer
# is swapped for a 2-output head, then filled with the weights saved on disk.
# NOTE: torch.load unpickles the checkpoint — only ever point it at a trusted file.
model = models.resnet50(weights=None)
num_features = model.fc.in_features
model.fc = torch.nn.Linear(num_features, 2)
model.load_state_dict(
    torch.load("resnet50_dr_classifier.pth", map_location=device)
)
# nn.Module.to() and .eval() both return the module itself, so this chains.
model = model.to(device).eval()

# Grad-CAM: heat-maps are computed from the last block of layer4.
target_layer = model.layer4[-1]
cam = GradCAM(model=model, target_layers=[target_layer])

# Preprocess: resize to the 224x224 network input, convert to a CHW float
# tensor, then normalise each channel with the ImageNet mean / std values.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)

# Folders & logs: every prediction saves its input image here ...
image_folder = "collected_images"
os.makedirs(image_folder, exist_ok=True)

# ... and appends a row to this CSV. The header is written only when the file
# is first created, so restarts keep appending to the same log.
csv_log_path = "prediction_logs.csv"
if not os.path.exists(csv_log_path):
    header = ["timestamp", "image_filename", "prediction", "confidence"]
    with open(csv_log_path, "w", newline="") as log_file:
        csv.writer(log_file).writerow(header)

# Prediction
def _classify(img_tensor):
    """Run the classifier on a 1-image batch; return (class index, softmax probability)."""
    with torch.no_grad():
        probs = F.softmax(model(img_tensor), dim=1)
    pred = torch.argmax(probs, dim=1).item()
    confidence = probs[0][pred].item()
    return pred, confidence


def _gradcam_overlay(img, img_tensor, target_class):
    """Return a PIL image of the Grad-CAM heat-map for ``target_class`` blended over ``img``."""
    # Pixel values scaled to float32 in [0, 1] for show_cam_on_image.
    rgb_img_np = np.ascontiguousarray(np.array(img).astype(np.float32) / 255.0)
    # Kept outside torch.no_grad(): Grad-CAM backpropagates through the model.
    grayscale_cam = cam(
        input_tensor=img_tensor, targets=[ClassifierOutputTarget(target_class)]
    )[0]
    cam_image = show_cam_on_image(rgb_img_np, grayscale_cam, use_rgb=True)
    return Image.fromarray(cam_image)


def _log_prediction(image, timestamp, label, confidence):
    """Save the uploaded image into ``image_folder`` and append a row to the CSV log.

    Returns:
        The saved image's file name (relative to ``image_folder``).
    """
    # The timestamp only has one-second resolution; the random suffix keeps two
    # uploads with the same label in the same second from overwriting each other.
    label_slug = label.replace(" ", "_")
    image_filename = f"{timestamp}_{label_slug}_{uuid.uuid4().hex[:8]}.png"
    image.save(os.path.join(image_folder, image_filename))

    with open(csv_log_path, mode="a", newline="") as log_file:
        csv.writer(log_file).writerow(
            [timestamp, image_filename, label, f"{confidence:.4f}"]
        )
    return image_filename


def predict_retinopathy(image):
    """Classify a retinal image, render a Grad-CAM overlay, and log the result.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when Submit is pressed without an upload.

    Returns:
        ``(cam_pil, text)``: the Grad-CAM heat-map over the 224x224 resized
        input, and the predicted label with its softmax confidence.

    Raises:
        gr.Error: if no image was provided (Gradio shows the message to the user).

    Side effects:
        Saves the original upload under ``image_folder`` and appends one row to
        ``csv_log_path``.
    """
    if image is None:
        # Otherwise `image.convert` raises AttributeError and the user only sees
        # a generic error.
        raise gr.Error("Please upload a retinal image before submitting.")

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    img = image.convert("RGB").resize((224, 224))
    img_tensor = transform(img).unsqueeze(0).to(device)

    pred, confidence = _classify(img_tensor)
    # NOTE(review): class index 0 is treated as DR — presumably this matches
    # the label order used in training; confirm against the training script.
    label = "Diabetic Retinopathy (DR)" if pred == 0 else "No DR"

    cam_pil = _gradcam_overlay(img, img_tensor, pred)
    _log_prediction(image, timestamp, label, confidence)

    return cam_pil, f"{label} (Confidence: {confidence:.2f})"

# Downloads
def download_csv():
    """Hand the prediction CSV log's path to the ``gr.File`` output for download."""
    log_path = csv_log_path
    return log_path

def download_dataset_zip():
    """Bundle the CSV log and every collected image into one ZIP for download.

    The archive is (re)written as ``dataset_bundle.zip`` in the working
    directory. The CSV is stored at the archive root as
    ``prediction_logs.csv``, and each regular file from ``image_folder`` is
    stored under ``images/``.

    Returns:
        Path of the ZIP file, for the ``gr.File`` output.
    """
    zip_filename = "dataset_bundle.zip"
    # DEFLATE mainly shrinks the CSV; the PNGs are already compressed.
    with zipfile.ZipFile(zip_filename, "w", compression=zipfile.ZIP_DEFLATED) as zipf:
        zipf.write(csv_log_path, arcname="prediction_logs.csv")
        # Sorted for a stable member order. Only regular files are bundled, so a
        # stray sub-directory does not turn into an empty archive entry.
        for fname in sorted(os.listdir(image_folder)):
            fpath = os.path.join(image_folder, fname)
            if os.path.isfile(fpath):
                # ZIP member names use "/" regardless of the host OS.
                zipf.write(fpath, arcname=f"images/{fname}")
    return zip_filename

def check_admin(query_str):
    """Decide whether the admin download section should be visible.

    Args:
        query_str: URL query string such as ``"admin=<key>&x=1"``. A leading
            ``"?"`` or a full URL is also accepted. ``None`` is treated as empty.

    Returns:
        ``gr.update(visible=True)`` when an ``admin`` parameter equals
        ``ADMIN_KEY`` exactly, otherwise ``gr.update(visible=False)``.
    """
    # Drop any "#fragment", then keep only what follows "?" if a URL was given.
    query = (query_str or "").partition("#")[0]
    if "?" in query:
        query = query.partition("?")[2]

    supplied = parse_qs(query).get("admin", [])
    expected = ADMIN_KEY.encode()
    # Exact, constant-time match. The previous substring test also accepted
    # e.g. "admin=<key>-extra" or "xadmin=<key>".
    allowed = any(hmac.compare_digest(value.encode(), expected) for value in supplied)
    return gr.update(visible=allowed)

# Gradio UI
def _read_query_string(request: gr.Request):
    """Return the query string of the page URL the visitor opened (e.g. "admin=...")."""
    return str(request.query_params)


def _admin_only(handler):
    """Wrap a zero-argument download handler so it re-checks the admin key.

    The returned function takes the query string held in ``url_input``. It raises
    ``gr.Error`` unless ``check_admin`` accepts that string, and otherwise
    returns ``handler()``. Hiding the admin column only changes what the page
    shows. The click events are presumably still callable through Gradio's API,
    so the key is verified here on the server before any file path is returned.
    """
    # Not using functools.wraps: it would make the wrapper's signature look like
    # the zero-argument original, which Gradio appears to inspect when mapping inputs.
    def guarded(query_str):
        # check_admin returns gr.update(visible=...), a dict of component updates.
        if not check_admin(query_str or "").get("visible"):
            raise gr.Error("Admin access required.")
        return handler()

    return guarded


with gr.Blocks() as demo:
    gr.Markdown("## 🧠 Diabetic Retinopathy Detection with Grad-CAM")

    url_input = gr.Textbox(visible=False)  # Holds the page URL's query string

    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Retinal Image")
        cam_output = gr.Image(type="pil", label="Grad-CAM")

    prediction_output = gr.Text(label="Prediction")
    run_button = gr.Button("Submit")

    run_button.click(
        fn=predict_retinopathy,
        inputs=image_input,
        outputs=[cam_output, prediction_output]
    )

    with gr.Column(visible=False) as admin_section:
        gr.Markdown("### 🔐 Admin Downloads (Private)")
        with gr.Row():
            download_csv_btn = gr.Button("📄 Download CSV Log")
            download_zip_btn = gr.Button("📦 Download Dataset ZIP")
        csv_file = gr.File()
        zip_file = gr.File()

    # Nothing else ever sets url_input, so fill it from the URL when the page
    # loads. Its change event then decides whether to reveal the admin section.
    demo.load(fn=_read_query_string, inputs=None, outputs=url_input)
    url_input.change(fn=check_admin, inputs=url_input, outputs=admin_section)

    # The key travels with each download request and is re-checked server-side.
    download_csv_btn.click(
        fn=_admin_only(download_csv), inputs=url_input, outputs=csv_file
    )
    download_zip_btn.click(
        fn=_admin_only(download_dataset_zip), inputs=url_input, outputs=zip_file
    )

demo.launch()