File size: 4,043 Bytes
ae8f111 ab15865 7303661 ab15865 7303661 ae8f111 ab15865 ae8f111 ab15865 7303661 ab15865 7303661 ae8f111 ab15865 ae8f111 e1b2058 ab15865 e1b2058 ab15865 e1b2058 ab15865 64f36cb ab15865 64f36cb ab15865 64f36cb ab15865 64f36cb ab15865 64f36cb ab15865 64f36cb ab15865 64f36cb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import csv
import datetime
import io
import os
import tempfile
import zipfile

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from torchvision import models, transforms
# All inference runs on the CPU.
device = torch.device("cpu")

# Build a ResNet-50 with a 2-way classification head, then restore the
# fine-tuned weights from the local checkpoint.
model = models.resnet50(weights=None)
model.fc = torch.nn.Linear(model.fc.in_features, 2)
# The checkpoint is a state_dict (it is fed to load_state_dict), so
# weights_only=True is sufficient and prevents a tampered file from running
# arbitrary pickle code during loading. Requires torch >= 1.13.
model.load_state_dict(
    torch.load("resnet50_dr_classifier.pth", map_location=device, weights_only=True)
)
model.to(device)
model.eval()

# Grad-CAM is computed against the last block of layer4.
target_layer = model.layer4[-1]
cam = GradCAM(model=model, target_layers=[target_layer])
# Preprocessing applied to each uploaded image before it reaches the model:
# resize to 224x224, convert to a float tensor, then normalize every channel
# with the standard ImageNet mean/std values.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Directory where every uploaded image is archived (created if missing).
image_folder = "collected_images"
os.makedirs(image_folder, exist_ok=True)

# CSV prediction log. The header row is written only when the file is created
# for the first time; an existing log is left untouched.
csv_log_path = "prediction_logs.csv"
try:
    with open(csv_log_path, mode="x", newline="") as log_file:
        csv.writer(log_file).writerow(
            ["timestamp", "image_filename", "prediction", "confidence"]
        )
except FileExistsError:
    pass
# Prediction function
def predict_retinopathy(image):
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
img = image.convert("RGB").resize((224, 224))
img_tensor = transform(img).unsqueeze(0).to(device)
with torch.no_grad():
output = model(img_tensor)
probs = F.softmax(output, dim=1)
pred = torch.argmax(probs, dim=1).item()
confidence = probs[0][pred].item()
label = "Diabetic Retinopathy (DR)" if pred == 0 else "No DR"
# Grad-CAM
rgb_img_np = np.array(img).astype(np.float32) / 255.0
rgb_img_np = np.ascontiguousarray(rgb_img_np)
grayscale_cam = cam(input_tensor=img_tensor, targets=[ClassifierOutputTarget(pred)])[0]
cam_image = show_cam_on_image(rgb_img_np, grayscale_cam, use_rgb=True)
cam_pil = Image.fromarray(cam_image)
# Save uploaded image
image_filename = f"{timestamp}_{label.replace(' ', '_')}.png"
image_path = os.path.join(image_folder, image_filename)
image.save(image_path)
# Log prediction
with open(csv_log_path, mode="a", newline="") as f:
writer = csv.writer(f)
writer.writerow([timestamp, image_filename, label, f"{confidence:.4f}"])
return cam_pil, f"{label} (Confidence: {confidence:.2f})"
# Download logs
def download_csv():
return csv_log_path
# Zip the dataset (CSV log + collected images) for download.
def download_dataset_zip(csv_path=None, images_dir=None):
    """Bundle the prediction log and all collected images into a ZIP file.

    Args:
        csv_path: CSV log to include as ``prediction_logs.csv``; defaults to
            ``csv_log_path``.
        images_dir: Directory whose regular files are added under ``images/``;
            defaults to ``image_folder``.

    Returns:
        Path (str) to the ZIP, written into a fresh temporary directory.
        ``gr.File`` outputs take a file path, not an in-memory buffer, which
        matches what ``download_csv`` returns.

    Note:
        Each call leaves a new ZIP in the system temp directory; nothing here
        removes it afterwards.
    """
    csv_path = csv_log_path if csv_path is None else csv_path
    images_dir = image_folder if images_dir is None else images_dir
    zip_path = os.path.join(tempfile.mkdtemp(prefix="dr_dataset_"), "dr_dataset.zip")
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zipf:
        zipf.write(csv_path, arcname="prediction_logs.csv")
        # Sorted for a deterministic archive order; non-files are skipped.
        for fname in sorted(os.listdir(images_dir)):
            fpath = os.path.join(images_dir, fname)
            if os.path.isfile(fpath):
                zipf.write(fpath, arcname=f"images/{fname}")
    return zip_path
# Gradio UI: image upload with Grad-CAM output, prediction text, and buttons
# to download the CSV log or the full collected dataset.
with gr.Blocks() as demo:
    gr.Markdown("## 🧠 DR Detection with Grad-CAM + Full Dataset Logging")
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Retinal Image")
        cam_output = gr.Image(type="pil", label="Grad-CAM")
    prediction_output = gr.Text(label="Prediction")
    run_button = gr.Button("Submit")
    with gr.Row():
        download_csv_btn = gr.Button("📄 Download CSV Log")
        download_zip_btn = gr.Button("📦 Download Full Dataset")
    # Distinct labels so the two otherwise identical file outputs can be told apart.
    csv_file = gr.File(label="CSV Log")
    zip_file = gr.File(label="Dataset ZIP")
    run_button.click(
        fn=predict_retinopathy,
        inputs=image_input,
        outputs=[cam_output, prediction_output],
    )
    download_csv_btn.click(
        fn=download_csv,
        inputs=[],
        outputs=csv_file,
    )
    download_zip_btn.click(
        fn=download_dataset_zip,
        inputs=[],
        outputs=zip_file,
    )

# Launch only when run as a script, so importing this module (e.g. for tests
# or the `gradio` reload CLI) does not start a blocking server.
if __name__ == "__main__":
    demo.launch()
|