SuriRaja commited on
Commit
f07ca33
·
verified ·
1 Parent(s): 72d54a5

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +8 -5
model.py CHANGED
@@ -13,8 +13,11 @@ def predict_defect(image: Image.Image):
13
  outputs = model(**inputs)
14
  logits = outputs.logits
15
  segmentation = torch.argmax(logits.squeeze(), dim=0).detach().cpu().numpy()
16
-
17
- # Convert to RGB overlay
18
- overlay = np.zeros((segmentation.shape[0], segmentation.shape[1], 3), dtype=np.uint8)
19
- overlay[segmentation == 12] = [255, 0, 0] # example label index for defects (adjust accordingly)
20
- return Image.fromarray(overlay)
 
 
 
 
13
  outputs = model(**inputs)
14
  logits = outputs.logits
15
  segmentation = torch.argmax(logits.squeeze(), dim=0).detach().cpu().numpy()
16
+
17
+ # Overlay on original image
18
+ original = np.array(image).copy()
19
+ mask = (segmentation == 12) # Replace 12 with correct defect label
20
+ original[mask] = [255, 0, 0] # Red highlight for defects
21
+
22
+ return Image.fromarray(original)
23
+