Rodiyah commited on
Commit
3dee463
·
verified ·
1 Parent(s): 2d888c0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -59
app.py CHANGED
@@ -8,28 +8,25 @@ from pytorch_grad_cam import GradCAM
8
  from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
9
  from pytorch_grad_cam.utils.image import show_cam_on_image
10
 
11
- import os
12
  import datetime
13
- import sqlite3
14
 
15
- # === Setup paths and model ===
16
  device = torch.device("cpu")
17
- ADMIN_KEY = "Diabetes_Detection"
18
- image_folder = "collected_images"
19
- os.makedirs(image_folder, exist_ok=True)
20
 
21
- # === Load model ===
22
  model = models.resnet50(weights=None)
23
  model.fc = torch.nn.Linear(model.fc.in_features, 2)
24
  model.load_state_dict(torch.load("resnet50_dr_classifier.pth", map_location=device))
25
  model.to(device)
26
  model.eval()
27
 
28
- # === Grad-CAM setup ===
29
  target_layer = model.layer4[-1]
30
  cam = GradCAM(model=model, target_layers=[target_layer])
31
 
32
- # === Image transform ===
33
  transform = transforms.Compose([
34
  transforms.Resize((224, 224)),
35
  transforms.ToTensor(),
@@ -37,35 +34,21 @@ transform = transforms.Compose([
37
  [0.229, 0.224, 0.225])
38
  ])
39
 
40
- # === SQLite setup ===
41
- def init_db():
42
- conn = sqlite3.connect("logs.db")
43
- cursor = conn.cursor()
44
- cursor.execute("""
45
- CREATE TABLE IF NOT EXISTS predictions (
46
- id INTEGER PRIMARY KEY AUTOINCREMENT,
47
- timestamp TEXT,
48
- filename TEXT,
49
- prediction TEXT,
50
- confidence REAL
51
- )
52
- """)
53
- conn.commit()
54
- conn.close()
55
-
56
- def log_to_db(timestamp, filename, prediction, confidence):
57
- conn = sqlite3.connect("logs.db")
58
- cursor = conn.cursor()
59
- cursor.execute("INSERT INTO predictions (timestamp, filename, prediction, confidence) VALUES (?, ?, ?, ?)",
60
- (timestamp, filename, prediction, confidence))
61
- conn.commit()
62
- conn.close()
63
-
64
- init_db() # ✅ Initialize table
65
-
66
- # === Prediction Function ===
67
  def predict_retinopathy(image):
68
- timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
69
  img = image.convert("RGB").resize((224, 224))
70
  img_tensor = transform(img).unsqueeze(0).to(device)
71
 
@@ -82,32 +65,26 @@ def predict_retinopathy(image):
82
  rgb_img_np = np.ascontiguousarray(rgb_img_np)
83
  grayscale_cam = cam(input_tensor=img_tensor, targets=[ClassifierOutputTarget(pred)])[0]
84
  cam_image = show_cam_on_image(rgb_img_np, grayscale_cam, use_rgb=True)
85
- cam_pil = Image.fromarray(cam_image)
86
-
87
- # Save image and log
88
- filename = f"{timestamp}_{label.replace(' ', '_')}.png"
89
- image_path = os.path.join(image_folder, filename)
90
- image.save(image_path)
91
 
92
- log_to_db(timestamp, image_path, label, confidence)
 
 
93
 
 
94
  return cam_pil, f"{label} (Confidence: {confidence:.2f})"
95
 
96
- # === Gradio Interface ===
97
- with gr.Blocks() as demo:
98
- gr.Markdown("## 🧠 DR Detection with Grad-CAM + SQLite Logging")
99
-
100
- with gr.Row():
101
- image_input = gr.Image(type="pil", label="Upload Retinal Image")
102
- cam_output = gr.Image(type="pil", label="Grad-CAM")
103
-
104
- prediction_output = gr.Text(label="Prediction")
105
- run_button = gr.Button("Submit")
106
-
107
- run_button.click(
108
- fn=predict_retinopathy,
109
- inputs=image_input,
110
- outputs=[cam_output, prediction_output]
111
  )
112
 
113
  demo.launch()
 
8
  from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
9
  from pytorch_grad_cam.utils.image import show_cam_on_image
10
 
11
+ import csv
12
  import datetime
13
+ import os
14
 
15
+ # Set device
16
  device = torch.device("cpu")
 
 
 
17
 
18
+ # Load model
19
  model = models.resnet50(weights=None)
20
  model.fc = torch.nn.Linear(model.fc.in_features, 2)
21
  model.load_state_dict(torch.load("resnet50_dr_classifier.pth", map_location=device))
22
  model.to(device)
23
  model.eval()
24
 
25
+ # Grad-CAM setup
26
  target_layer = model.layer4[-1]
27
  cam = GradCAM(model=model, target_layers=[target_layer])
28
 
29
+ # Image preprocessing
30
  transform = transforms.Compose([
31
  transforms.Resize((224, 224)),
32
  transforms.ToTensor(),
 
34
  [0.229, 0.224, 0.225])
35
  ])
36
 
37
+ # Logging setup
38
+ log_path = "prediction_logs.csv"
39
+
40
+ def log_prediction(filename, prediction, confidence):
41
+ timestamp = datetime.datetime.now().isoformat()
42
+ row = [timestamp, filename, prediction, f"{confidence:.4f}"]
43
+
44
+ print("⏺ Logging prediction:", row) # 🔍 Add this line
45
+
46
+ with open(log_path, mode='a', newline='') as file:
47
+ writer = csv.writer(file)
48
+ writer.writerow(row)
49
+
50
+ # Prediction function
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  def predict_retinopathy(image):
 
52
  img = image.convert("RGB").resize((224, 224))
53
  img_tensor = transform(img).unsqueeze(0).to(device)
54
 
 
65
  rgb_img_np = np.ascontiguousarray(rgb_img_np)
66
  grayscale_cam = cam(input_tensor=img_tensor, targets=[ClassifierOutputTarget(pred)])[0]
67
  cam_image = show_cam_on_image(rgb_img_np, grayscale_cam, use_rgb=True)
 
 
 
 
 
 
68
 
69
+ # Logging
70
+ filename = getattr(image, "filename", "uploaded_image")
71
+ log_prediction(filename, label, confidence)
72
 
73
+ cam_pil = Image.fromarray(cam_image)
74
  return cam_pil, f"{label} (Confidence: {confidence:.2f})"
75
 
76
+ # Gradio interface
77
+ gr.Interface(
78
+ fn=predict_retinopathy,
79
+ inputs=gr.Image(type="pil"),
80
+ outputs=[
81
+ gr.Image(type="pil", label="Grad-CAM"),
82
+ gr.Text(label="Prediction")
83
+ ],
84
+ title="Diabetic Retinopathy Detection",
85
+ description="Upload a retinal image to classify DR and view Grad-CAM heatmap. All predictions are logged for analysis."
86
+ ).launch()
87
+ s=[cam_output, prediction_output]
 
 
 
88
  )
89
 
90
  demo.launch()