File size: 2,531 Bytes
ae8f111 412ca3a ae8f111 300c3a7 18a45f0 412ca3a ae8f111 3dee463 ae8f111 412ca3a ae8f111 412ca3a ae8f111 412ca3a ae8f111 412ca3a ae8f111 412ca3a ae8f111 3dee463 412ca3a ab15865 64f36cb 412ca3a 3dee463 3e7941e 3dee463 3e7941e 5dcad43 3dee463 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
import gradio as gr
from PIL import Image
import torch
import torch.nn.functional as F
import numpy as np
from torchvision import models, transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import os
import datetime
# Setup
# All inference in this app runs on the CPU.
device = torch.device("cpu")

# Directory where prediction overlays are written. The default is presumably
# the Hugging Face Spaces app directory; set DR_SAVE_DIR to run elsewhere.
save_dir = os.environ.get("DR_SAVE_DIR", "/home/user/app/saved_predictions")

# Create the folder in a single makedirs call (the original checked existence
# and then called makedirs twice). exist_ok=True tolerates a concurrent
# creation; a regular *file* at this path still raises FileExistsError.
if not os.path.isdir(save_dir):
    os.makedirs(save_dir, exist_ok=True)
    print("📁 Folder created:", save_dir)
# Load model
# ResNet-50 backbone built without pretrained weights; its final fully
# connected layer is replaced by a 2-output head (DR / NoDR). All weights are
# then restored from the trained checkpoint below.
model = models.resnet50(weights=None)
model.fc = torch.nn.Linear(model.fc.in_features, 2)

# The checkpoint is a state_dict (tensors in a dict), so the restricted
# weights_only unpickler is sufficient and refuses arbitrary pickled objects.
# NOTE(review): the path is relative to the current working directory.
state_dict = torch.load(
    "resnet50_dr_classifier.pth",
    map_location=device,
    weights_only=True,
)
model.load_state_dict(state_dict)
model.to(device)
model.eval()  # inference mode: dropout off, BatchNorm uses running statistics
# Grad-CAM
# Explanations are computed from the last residual block of ResNet-50's final
# stage (layer4), i.e. the deepest convolutional feature maps before pooling.
target_layer = model.layer4[-1]
cam = GradCAM(
    model=model,
    target_layers=[target_layer],
)
# Preprocessing
# Per-channel RGB mean / std used for normalization (the standard ImageNet
# statistics expected by torchvision ResNets).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# PIL image -> 224x224 -> float tensor in [0, 1] -> normalized tensor.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# Predict and save
def _classify(img_tensor):
    """Run the classifier on a preprocessed single-image batch.

    Returns:
        ``(pred, confidence)``: the argmax class index and its softmax
        probability.
    """
    with torch.no_grad():
        output = model(img_tensor)
    probs = F.softmax(output, dim=1)
    pred = torch.argmax(probs, dim=1).item()
    return pred, probs[0][pred].item()


def _grad_cam_overlay(img, img_tensor, class_idx):
    """Return a PIL image of the Grad-CAM heatmap for ``class_idx`` blended over ``img``."""
    # show_cam_on_image takes a float RGB array scaled to [0, 1].
    rgb_img_np = np.ascontiguousarray(np.array(img).astype(np.float32) / 255.0)
    # cam() returns one heatmap per batch item; the batch here has one image.
    grayscale_cam = cam(
        input_tensor=img_tensor,
        targets=[ClassifierOutputTarget(class_idx)],
    )[0]
    cam_image = show_cam_on_image(rgb_img_np, grayscale_cam, use_rgb=True)
    return Image.fromarray(cam_image)


def predict_retinopathy(image):
    """Classify an uploaded image, build its Grad-CAM overlay, and save it.

    Args:
        image: PIL image from the Gradio input; ``None`` when the user
            submits without uploading anything.

    Returns:
        ``(cam_pil, text)``: the Grad-CAM overlay as a PIL image, and a string
        such as ``"DR (Confidence: 0.87)"`` giving the predicted label and the
        softmax probability of that label (0-1 scale).

    Raises:
        gr.Error: if no image was provided (shown to the user in the UI).

    Side effect:
        Writes the overlay as a PNG into ``save_dir``.
    """
    if image is None:
        raise gr.Error("Пожалуйста, загрузите изображение.")

    # Microseconds keep filenames unique when several predictions with the
    # same label/confidence happen within the same second.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")

    # Resize here (not only inside `transform`) so the RGB image used for the
    # overlay has the same 224x224 size as the model input.
    img = image.convert("RGB").resize((224, 224))
    img_tensor = transform(img).unsqueeze(0).to(device)

    pred, confidence = _classify(img_tensor)
    # NOTE(review): assumes class index 0 is "DR" in the training label order
    # (e.g. alphabetical ImageFolder classes) — confirm against training code.
    label = "DR" if pred == 0 else "NoDR"

    cam_pil = _grad_cam_overlay(img, img_tensor, pred)

    # Save the overlay with label and confidence encoded in the filename.
    filename = f"{timestamp}_{label}_{confidence:.2f}.png"
    cam_pil.save(os.path.join(save_dir, filename))
    return cam_pil, f"{label} (Confidence: {confidence:.2f})"
# Gradio app
# One image in; the Grad-CAM overlay and a text verdict out.
demo = gr.Interface(
    fn=predict_retinopathy,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil", label="Метод Grad-CAM"),
        # The text output looks like "DR (Confidence: 0.87)": the predicted
        # class and the softmax probability of that class on a 0-1 scale.
        # It is not the probability of DR in percent, so the label now says
        # "diagnosis and model confidence (0-1)".
        gr.Text(label="Диагноз и уверенность модели (0–1)"),
    ],
    title="Диагностика диабетической ретинопатии",
    description="Загрузите ОКТ и смотрите ИИ-карту Grad-CAM heatmap",
)
demo.launch()
|