File size: 4,043 Bytes
ae8f111
 
 
 
 
 
 
 
 
 
ab15865
7303661
 
ab15865
 
7303661
 
ae8f111
 
 
 
 
 
 
 
 
 
 
 
 
ab15865
ae8f111
 
 
 
 
 
 
ab15865
 
 
7303661
ab15865
 
 
 
 
 
7303661
 
ae8f111
ab15865
ae8f111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1b2058
ab15865
 
 
 
e1b2058
ab15865
 
 
 
e1b2058
ab15865
64f36cb
ab15865
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64f36cb
ab15865
64f36cb
 
 
 
 
 
 
ab15865
 
64f36cb
ab15865
 
 
 
64f36cb
 
 
 
 
 
 
ab15865
 
 
 
 
 
 
 
64f36cb
ab15865
64f36cb
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import gradio as gr
from PIL import Image
import torch
import torch.nn.functional as F
import numpy as np
from torchvision import models, transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

import os
import csv
import datetime
import io
import zipfile

# Inference is pinned to the CPU
device = torch.device("cpu")

# Build a ResNet-50 with a two-output head, then restore the weights stored in
# resnet50_dr_classifier.pth and switch the network to evaluation mode
model = models.resnet50(weights=None)
model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=2)
model.load_state_dict(
    torch.load("resnet50_dr_classifier.pth", map_location=device)
)
model = model.to(device).eval()

# Grad-CAM is attached to the last block of the final residual stage
target_layer = model.layer4[-1]
cam = GradCAM(model=model, target_layers=[target_layer])

# Image preprocessing: fixed 224x224 size, tensor conversion, per-channel
# normalization.
# NOTE(review): the mean/std values look like the usual ImageNet statistics;
# confirm they match what the checkpoint was trained with.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Every submitted image is archived in this folder
image_folder = "collected_images"
os.makedirs(image_folder, exist_ok=True)

# Prediction log: the header row is written only when the file does not exist yet
csv_log_path = "prediction_logs.csv"
if not os.path.exists(csv_log_path):
    with open(csv_log_path, mode="w", newline="") as log_file:
        csv.writer(log_file).writerow(
            ["timestamp", "image_filename", "prediction", "confidence"]
        )

# Prediction function
def predict_retinopathy(image):
    """Classify a retinal image, render a Grad-CAM overlay, and log the result.

    Args:
        image: PIL image coming from the Gradio image input, or ``None`` when
            the user pressed Submit without uploading anything.

    Returns:
        A ``(cam_image, message)`` tuple. ``cam_image`` is a PIL image with the
        Grad-CAM heat map blended over the 224x224 RGB input; ``message`` is
        the predicted label followed by its softmax confidence. When ``image``
        is ``None`` the result is ``(None, <prompt text>)`` and nothing is
        saved or logged.
    """
    # An empty image component arrives as None; answer with a prompt instead
    # of failing on image.convert().
    if image is None:
        return None, "Please upload a retinal image first."

    now = datetime.datetime.now()
    timestamp = now.strftime("%Y%m%d_%H%M%S")
    img = image.convert("RGB").resize((224, 224))
    img_tensor = transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        probs = F.softmax(model(img_tensor), dim=1)
    pred = torch.argmax(probs, dim=1).item()
    confidence = probs[0][pred].item()

    # NOTE(review): class index 0 is treated as DR here; confirm this matches
    # the label order used when the checkpoint was trained.
    label = "Diabetic Retinopathy (DR)" if pred == 0 else "No DR"

    # Grad-CAM runs outside the no_grad block because it needs gradients.
    rgb_img_np = np.ascontiguousarray(np.array(img).astype(np.float32) / 255.0)
    grayscale_cam = cam(
        input_tensor=img_tensor,
        targets=[ClassifierOutputTarget(pred)],
    )[0]
    cam_image = show_cam_on_image(rgb_img_np, grayscale_cam, use_rgb=True)
    cam_pil = Image.fromarray(cam_image)

    # Save the original upload. The microsecond component keeps two uploads
    # that land in the same second with the same label from overwriting each
    # other on disk.
    image_filename = f"{timestamp}_{now:%f}_{label.replace(' ', '_')}.png"
    image_path = os.path.join(image_folder, image_filename)
    image.save(image_path)

    # Append one row per prediction; the timestamp column keeps its
    # second-resolution format.
    with open(csv_log_path, mode="a", newline="") as f:
        csv.writer(f).writerow(
            [timestamp, image_filename, label, f"{confidence:.4f}"]
        )

    return cam_pil, f"{label} (Confidence: {confidence:.2f})"

# Download logs
def download_csv():
    """Return the path of the prediction CSV log for the file output component."""
    log_file = csv_log_path
    return log_file

# Zip dataset for download
def download_dataset_zip(log_path=None, images_dir=None):
    """Bundle the prediction log and the collected images into a zip on disk.

    Args:
        log_path: CSV log to include. Defaults to the module-level
            ``csv_log_path``.
        images_dir: Directory whose files are added under ``images/``.
            Defaults to the module-level ``image_folder``.

    Returns:
        The filesystem path (``str``) of the freshly written archive. A path
        is returned rather than an in-memory buffer because the Gradio file
        output this function feeds is given a file path to serve.

    Raises:
        FileNotFoundError: If the log file or the image directory is missing.
    """
    import tempfile  # local import: only this function needs it

    if log_path is None:
        log_path = csv_log_path
    if images_dir is None:
        images_dir = image_folder

    # delete=False: the archive must still exist after this function returns
    # so the caller can read it.
    with tempfile.NamedTemporaryFile(
        prefix="dr_dataset_", suffix=".zip", delete=False
    ) as tmp:
        with zipfile.ZipFile(tmp, "w", compression=zipfile.ZIP_DEFLATED) as zipf:
            # Add CSV
            zipf.write(log_path, arcname="prediction_logs.csv")
            # Add images; sorted for a stable archive order, and anything
            # that is not a regular file is skipped.
            for fname in sorted(os.listdir(images_dir)):
                fpath = os.path.join(images_dir, fname)
                if os.path.isfile(fpath):
                    zipf.write(fpath, arcname=f"images/{fname}")
    return tmp.name

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## 🧠 DR Detection with Grad-CAM + Full Dataset Logging")

    # Uploaded image and its Grad-CAM overlay share one row
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Retinal Image")
        cam_output = gr.Image(type="pil", label="Grad-CAM")

    prediction_output = gr.Text(label="Prediction")
    run_button = gr.Button("Submit")

    # Export controls: two buttons plus the file slots they populate
    with gr.Row():
        download_csv_btn = gr.Button("📄 Download CSV Log")
        download_zip_btn = gr.Button("📦 Download Full Dataset")
        csv_file = gr.File()
        zip_file = gr.File()

    # Event wiring: each button calls its handler and routes the result
    run_button.click(
        predict_retinopathy,
        inputs=image_input,
        outputs=[cam_output, prediction_output],
    )
    download_csv_btn.click(download_csv, inputs=[], outputs=csv_file)
    download_zip_btn.click(download_dataset_zip, inputs=[], outputs=zip_file)

demo.launch()