Rodiyah commited on
Commit
2d888c0
·
verified ·
1 Parent(s): 1c7b662

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -35
app.py CHANGED
@@ -8,25 +8,28 @@ from pytorch_grad_cam import GradCAM
8
  from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
9
  from pytorch_grad_cam.utils.image import show_cam_on_image
10
 
11
- import csv
12
- import datetime
13
  import os
 
 
14
 
15
- # Set device
16
  device = torch.device("cpu")
 
 
 
17
 
18
- # Load model
19
  model = models.resnet50(weights=None)
20
  model.fc = torch.nn.Linear(model.fc.in_features, 2)
21
  model.load_state_dict(torch.load("resnet50_dr_classifier.pth", map_location=device))
22
  model.to(device)
23
  model.eval()
24
 
25
- # Grad-CAM setup
26
  target_layer = model.layer4[-1]
27
  cam = GradCAM(model=model, target_layers=[target_layer])
28
 
29
- # Image preprocessing
30
  transform = transforms.Compose([
31
  transforms.Resize((224, 224)),
32
  transforms.ToTensor(),
@@ -34,21 +37,35 @@ transform = transforms.Compose([
34
  [0.229, 0.224, 0.225])
35
  ])
36
 
37
- # Logging setup
38
- log_path = "prediction_logs.csv"
39
-
40
- def log_prediction(filename, prediction, confidence):
41
- timestamp = datetime.datetime.now().isoformat()
42
- row = [timestamp, filename, prediction, f"{confidence:.4f}"]
43
-
44
- print("⏺ Logging prediction:", row) # 🔍 Add this line
45
-
46
- with open(log_path, mode='a', newline='') as file:
47
- writer = csv.writer(file)
48
- writer.writerow(row)
49
-
50
- # Prediction function
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  def predict_retinopathy(image):
 
52
  img = image.convert("RGB").resize((224, 224))
53
  img_tensor = transform(img).unsqueeze(0).to(device)
54
 
@@ -65,22 +82,32 @@ def predict_retinopathy(image):
65
  rgb_img_np = np.ascontiguousarray(rgb_img_np)
66
  grayscale_cam = cam(input_tensor=img_tensor, targets=[ClassifierOutputTarget(pred)])[0]
67
  cam_image = show_cam_on_image(rgb_img_np, grayscale_cam, use_rgb=True)
 
68
 
69
- # Logging
70
- filename = getattr(image, "filename", "uploaded_image")
71
- log_prediction(filename, label, confidence)
 
 
 
72
 
73
- cam_pil = Image.fromarray(cam_image)
74
  return cam_pil, f"{label} (Confidence: {confidence:.2f})"
75
 
76
- # Gradio interface
77
- gr.Interface(
78
- fn=predict_retinopathy,
79
- inputs=gr.Image(type="pil"),
80
- outputs=[
81
- gr.Image(type="pil", label="Grad-CAM"),
82
- gr.Text(label="Prediction")
83
- ],
84
- title="Diabetic Retinopathy Detection",
85
- description="Upload a retinal image to classify DR and view Grad-CAM heatmap. All predictions are logged for analysis."
86
- ).launch()
 
 
 
 
 
 
 
 
8
  from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
9
  from pytorch_grad_cam.utils.image import show_cam_on_image
10
 
 
 
11
  import os
12
+ import datetime
13
+ import sqlite3
14
 
15
+ # === Setup paths and model ===
16
  device = torch.device("cpu")
17
+ ADMIN_KEY = "Diabetes_Detection"
18
+ image_folder = "collected_images"
19
+ os.makedirs(image_folder, exist_ok=True)
20
 
21
+ # === Load model ===
22
  model = models.resnet50(weights=None)
23
  model.fc = torch.nn.Linear(model.fc.in_features, 2)
24
  model.load_state_dict(torch.load("resnet50_dr_classifier.pth", map_location=device))
25
  model.to(device)
26
  model.eval()
27
 
28
+ # === Grad-CAM setup ===
29
  target_layer = model.layer4[-1]
30
  cam = GradCAM(model=model, target_layers=[target_layer])
31
 
32
+ # === Image transform ===
33
  transform = transforms.Compose([
34
  transforms.Resize((224, 224)),
35
  transforms.ToTensor(),
 
37
  [0.229, 0.224, 0.225])
38
  ])
39
 
40
+ # === SQLite setup ===
41
+ def init_db():
42
+ conn = sqlite3.connect("logs.db")
43
+ cursor = conn.cursor()
44
+ cursor.execute("""
45
+ CREATE TABLE IF NOT EXISTS predictions (
46
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
47
+ timestamp TEXT,
48
+ filename TEXT,
49
+ prediction TEXT,
50
+ confidence REAL
51
+ )
52
+ """)
53
+ conn.commit()
54
+ conn.close()
55
+
56
+ def log_to_db(timestamp, filename, prediction, confidence):
57
+ conn = sqlite3.connect("logs.db")
58
+ cursor = conn.cursor()
59
+ cursor.execute("INSERT INTO predictions (timestamp, filename, prediction, confidence) VALUES (?, ?, ?, ?)",
60
+ (timestamp, filename, prediction, confidence))
61
+ conn.commit()
62
+ conn.close()
63
+
64
+ init_db() # ✅ Initialize table
65
+
66
+ # === Prediction Function ===
67
  def predict_retinopathy(image):
68
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
69
  img = image.convert("RGB").resize((224, 224))
70
  img_tensor = transform(img).unsqueeze(0).to(device)
71
 
 
82
  rgb_img_np = np.ascontiguousarray(rgb_img_np)
83
  grayscale_cam = cam(input_tensor=img_tensor, targets=[ClassifierOutputTarget(pred)])[0]
84
  cam_image = show_cam_on_image(rgb_img_np, grayscale_cam, use_rgb=True)
85
+ cam_pil = Image.fromarray(cam_image)
86
 
87
+ # Save image and log
88
+ filename = f"{timestamp}_{label.replace(' ', '_')}.png"
89
+ image_path = os.path.join(image_folder, filename)
90
+ image.save(image_path)
91
+
92
+ log_to_db(timestamp, image_path, label, confidence)
93
 
 
94
  return cam_pil, f"{label} (Confidence: {confidence:.2f})"
95
 
96
+ # === Gradio Interface ===
97
+ with gr.Blocks() as demo:
98
+ gr.Markdown("## 🧠 DR Detection with Grad-CAM + SQLite Logging")
99
+
100
+ with gr.Row():
101
+ image_input = gr.Image(type="pil", label="Upload Retinal Image")
102
+ cam_output = gr.Image(type="pil", label="Grad-CAM")
103
+
104
+ prediction_output = gr.Text(label="Prediction")
105
+ run_button = gr.Button("Submit")
106
+
107
+ run_button.click(
108
+ fn=predict_retinopathy,
109
+ inputs=image_input,
110
+ outputs=[cam_output, prediction_output]
111
+ )
112
+
113
+ demo.launch()