File size: 4,285 Bytes
ae8f111 ab15865 7303661 ab15865 7303661 079be8c dd8b58f 7a08276 ae8f111 8e84628 ae8f111 079be8c ae8f111 079be8c ae8f111 079be8c ab15865 7303661 ab15865 079be8c ab15865 7303661 079be8c ae8f111 ab15865 ae8f111 e1b2058 8e84628 ab15865 e1b2058 ab15865 e1b2058 ab15865 64f36cb 079be8c ab15865 b1aa13e ab15865 b1aa13e ab15865 079be8c dd8b58f 079be8c 64f36cb 8e84628 64f36cb 079be8c 64f36cb ab15865 64f36cb dd8b58f 079be8c 7a08276 8e84628 7a08276 079be8c 7a08276 8e84628 64f36cb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
import gradio as gr
from PIL import Image
import torch
import torch.nn.functional as F
import numpy as np
from torchvision import models, transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import os
import csv
import datetime
import zipfile
# Admin secret
# Prefer a deployment-provided secret (environment variable ADMIN_KEY) so the
# key does not have to live in source control. The original literal is kept
# as the fallback so existing deployments behave exactly as before.
ADMIN_KEY = os.environ.get("ADMIN_KEY") or "Diabetes_Detection"
# All inference in this app is pinned to the CPU.
device = torch.device("cpu")
# Load model
# ResNet-50 built without pretrained weights; its final fully connected layer
# is swapped for a 2-output classifier before the local checkpoint is loaded.
model = models.resnet50(weights=None)
_fc_inputs = model.fc.in_features
model.fc = torch.nn.Linear(_fc_inputs, 2)
model.load_state_dict(
    torch.load("resnet50_dr_classifier.pth", map_location=device)
)
# Both .to() and .eval() return the module itself, so chaining is equivalent
# to calling them one after the other.
model = model.to(device).eval()
# Grad-CAM
# The heat-map is computed from the last block of the network's layer4.
target_layer = model.layer4[-1]
cam = GradCAM(model=model, target_layers=[target_layer])
# Preprocess
# Per-channel mean / std used by the Normalize step (applied after ToTensor).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
]
transform = transforms.Compose(_preprocess_steps)
# Folders & logs
# Where uploaded images are kept and where each prediction is recorded.
image_folder = "collected_images"
csv_log_path = "prediction_logs.csv"
os.makedirs(image_folder, exist_ok=True)
# Write the header row only once, when the log does not exist yet.
if not os.path.exists(csv_log_path):
    with open(csv_log_path, "w", newline="") as log_file:
        csv.writer(log_file).writerow(
            ["timestamp", "image_filename", "prediction", "confidence"]
        )
# Prediction
def predict_retinopathy(image):
    """Classify a retinal image, render a Grad-CAM overlay, and log the result.

    Args:
        image: PIL image supplied by the Gradio image input, or ``None`` when
            the user submits without uploading anything.

    Returns:
        A ``(cam_image, message)`` tuple. ``cam_image`` is the Grad-CAM
        heat-map blended over the 224x224 input as a PIL image, and
        ``message`` carries the predicted label with its softmax confidence.
        When ``image`` is ``None`` the overlay is ``None`` and the message
        asks for an upload.

    Side effects:
        Saves the uploaded image under ``image_folder`` and appends one row
        to the CSV at ``csv_log_path``.
    """
    # Submitting with an empty image component passes None; answer with a
    # message instead of failing on ``None.convert``.
    if image is None:
        return None, "Please upload a retinal image first."

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    img = image.convert("RGB").resize((224, 224))
    img_tensor = transform(img).unsqueeze(0).to(device)

    # Plain inference needs no gradients; Grad-CAM below runs outside this
    # context because it has to back-propagate through the model.
    with torch.no_grad():
        output = model(img_tensor)
        probs = F.softmax(output, dim=1)
    pred = torch.argmax(probs, dim=1).item()
    confidence = probs[0][pred].item()
    # NOTE(review): assumes class index 0 is the DR class in the trained
    # checkpoint -- confirm against the training label order.
    label = "Diabetic Retinopathy (DR)" if pred == 0 else "No DR"

    # Grad-CAM
    # show_cam_on_image is given a float image scaled to [0, 1].
    rgb_img_np = np.ascontiguousarray(np.array(img).astype(np.float32) / 255.0)
    grayscale_cam = cam(
        input_tensor=img_tensor,
        targets=[ClassifierOutputTarget(pred)],
    )[0]
    cam_image = show_cam_on_image(rgb_img_np, grayscale_cam, use_rgb=True)
    cam_pil = Image.fromarray(cam_image)

    # Save image + log
    # The timestamp has one-second resolution, so two same-label predictions
    # within a second would share a filename; add a counter rather than
    # silently overwriting the earlier upload.
    stem = f"{timestamp}_{label.replace(' ', '_')}"
    image_filename = f"{stem}.png"
    duplicate = 1
    while os.path.exists(os.path.join(image_folder, image_filename)):
        image_filename = f"{stem}_{duplicate}.png"
        duplicate += 1
    image_path = os.path.join(image_folder, image_filename)
    image.save(image_path)

    with open(csv_log_path, mode="a", newline="") as f:
        csv.writer(f).writerow(
            [timestamp, image_filename, label, f"{confidence:.4f}"]
        )

    return cam_pil, f"{label} (Confidence: {confidence:.2f})"
# Downloads
def download_csv():
    """Return the path of the prediction log CSV so Gradio can serve it."""
    log_path = csv_log_path
    return log_path
def download_dataset_zip():
    """Bundle the prediction log and every collected image into one ZIP.

    The archive is (re)written as ``dataset_bundle.zip`` in the working
    directory. The CSV is stored as ``prediction_logs.csv`` and each entry of
    ``image_folder`` is stored under ``images/``.

    Returns:
        The archive's filename.
    """
    bundle_name = "dataset_bundle.zip"
    archive = zipfile.ZipFile(bundle_name, "w")
    with archive:
        archive.write(csv_log_path, arcname="prediction_logs.csv")
        for entry in os.listdir(image_folder):
            archive.write(
                os.path.join(image_folder, entry),
                arcname=os.path.join("images", entry),
            )
    return bundle_name
def check_admin(query_str):
    """Decide whether the admin section should be visible.

    The section is shown only when the query string has an ``admin``
    parameter whose value equals ``ADMIN_KEY`` exactly. A plain substring
    test would also accept inputs such as ``notadmin=<key>`` or
    ``admin=<key>extra``, so the string is split into parameters first.

    Args:
        query_str: Query string to inspect. A bare ``a=b&c=d`` string, a
            string with a leading ``?``, or a full URL are all accepted.
            ``None`` or an empty string is treated as "not admin".

    Returns:
        ``gr.update(visible=True)`` when the key matches, otherwise
        ``gr.update(visible=False)``.
    """
    # Function-scope imports keep this block self-contained (stdlib only).
    import hmac
    from urllib.parse import unquote

    if not query_str:
        return gr.update(visible=False)

    # Drop any fragment, then keep only what follows the first "?".
    query = str(query_str).split("#", 1)[0]
    if "?" in query:
        query = query.split("?", 1)[1]

    expected = ADMIN_KEY.encode("utf-8")
    is_admin = False
    for pair in query.split("&"):
        name, _, value = pair.partition("=")
        if name != "admin":
            continue
        # Accept the key either verbatim or percent-encoded; compare in
        # constant time so the check does not leak how much of it matched.
        for candidate in (value, unquote(value)):
            if hmac.compare_digest(candidate.encode("utf-8"), expected):
                is_admin = True
    return gr.update(visible=is_admin)
# Gradio UI
def _load_admin_state(request: gr.Request):
    """Feed the page's query string into the admin check on page load.

    Nothing else in this file writes to ``url_input``, so without this
    handler the admin section can never become visible.

    Returns:
        ``(query_string, visibility_update)`` for ``url_input`` and
        ``admin_section`` respectively.
    """
    from urllib.parse import urlencode

    params = getattr(request, "query_params", None) if request is not None else None
    # Re-encode as "a=b&c=d", the form check_admin looks for.
    query = urlencode(dict(params)) if params else ""
    return query, check_admin(query)


# NOTE(review): the original indentation was lost; the nesting below is the
# most plausible reading of the layout -- confirm against the deployed app.
with gr.Blocks() as demo:
    gr.Markdown("## 🧠 Diabetic Retinopathy Detection with Grad-CAM")
    url_input = gr.Textbox(visible=False)  # Holds query string

    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Retinal Image")
        cam_output = gr.Image(type="pil", label="Grad-CAM")
    prediction_output = gr.Text(label="Prediction")
    run_button = gr.Button("Submit")
    run_button.click(
        fn=predict_retinopathy,
        inputs=image_input,
        outputs=[cam_output, prediction_output],
    )

    with gr.Column(visible=False) as admin_section:
        gr.Markdown("### 🔐 Admin Downloads (Private)")
        with gr.Row():
            download_csv_btn = gr.Button("📄 Download CSV Log")
            download_zip_btn = gr.Button("📦 Download Dataset ZIP")
        csv_file = gr.File()
        zip_file = gr.File()

    # Logic to reveal admin section
    # On page load, read the query string from the request and set the admin
    # section's visibility directly, so the reveal does not depend on a
    # change event firing for a hidden textbox.
    demo.load(
        fn=_load_admin_state,
        inputs=None,
        outputs=[url_input, admin_section],
    )
    url_input.change(fn=check_admin, inputs=url_input, outputs=admin_section)

    # NOTE(review): hiding the section is only a UI measure; these handlers
    # perform no key check of their own. Consider verifying the admin key
    # inside the handlers as well.
    download_csv_btn.click(fn=download_csv, inputs=[], outputs=csv_file)
    download_zip_btn.click(fn=download_dataset_zip, inputs=[], outputs=zip_file)

demo.launch()
|