SuriRaja commited on
Commit
2233dcb
·
verified ·
1 Parent(s): 46cc277

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +12 -14
model.py CHANGED
@@ -4,29 +4,27 @@ import torch
4
  import numpy as np
5
  import cv2
6
 
7
- # Load model and processor
8
- processor = AutoImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
9
- model = AutoModelForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
10
 
11
  def predict_defect(image: Image.Image):
12
- # Convert image to numpy array to keep original size
13
  original = np.array(image)
14
  inputs = processor(images=image, return_tensors="pt")
15
-
16
  with torch.no_grad():
17
  outputs = model(**inputs)
18
 
19
  logits = outputs.logits
20
  segmentation = torch.argmax(logits.squeeze(), dim=0).detach().cpu().numpy()
21
-
22
- # Resize segmentation mask to match original image size
23
  resized_mask = cv2.resize(segmentation.astype(np.uint8), (original.shape[1], original.shape[0]), interpolation=cv2.INTER_NEAREST)
24
-
25
- # For example, highlight class index 12 (customize as needed)
26
- mask = (resized_mask == 12)
27
-
28
- # Overlay red color on defect areas
29
- overlay = original.copy()
30
- overlay[mask] = [255, 0, 0] # Red for defects
31
 
 
 
 
 
 
 
 
32
  return Image.fromarray(overlay)
 
4
  import numpy as np
5
  import cv2
6
 
7
+ # 🔁 Replace this with a model trained for road defects
8
+ processor = AutoImageProcessor.from_pretrained("segments/DeepLabV3")
9
+ model = AutoModelForSemanticSegmentation.from_pretrained("segments/DeepLabV3")
10
 
11
  def predict_defect(image: Image.Image):
 
12
  original = np.array(image)
13
  inputs = processor(images=image, return_tensors="pt")
 
14
  with torch.no_grad():
15
  outputs = model(**inputs)
16
 
17
  logits = outputs.logits
18
  segmentation = torch.argmax(logits.squeeze(), dim=0).detach().cpu().numpy()
19
+
20
+ # Resize mask to original image size
21
  resized_mask = cv2.resize(segmentation.astype(np.uint8), (original.shape[1], original.shape[0]), interpolation=cv2.INTER_NEAREST)
 
 
 
 
 
 
 
22
 
23
+ # 📌 NOTE: Update the label index below based on your dataset
24
+ road_defect_label_index = 1 # Assume 1 represents cracks/potholes
25
+ mask = (resized_mask == road_defect_label_index)
26
+
27
+ # Overlay red where defect detected
28
+ overlay = original.copy()
29
+ overlay[mask] = [255, 0, 0]
30
  return Image.fromarray(overlay)