Update model.py
Browse files
model.py
CHANGED
@@ -13,8 +13,11 @@ def predict_defect(image: Image.Image):
|
|
13 |
outputs = model(**inputs)
|
14 |
logits = outputs.logits
|
15 |
segmentation = torch.argmax(logits.squeeze(), dim=0).detach().cpu().numpy()
|
16 |
-
|
17 |
-
#
|
18 |
-
|
19 |
-
|
20 |
-
|
|
|
|
|
|
|
|
13 |
outputs = model(**inputs)
|
14 |
logits = outputs.logits
|
15 |
segmentation = torch.argmax(logits.squeeze(), dim=0).detach().cpu().numpy()
|
16 |
+
|
17 |
+
# Overlay on original image
|
18 |
+
original = np.array(image).copy()
|
19 |
+
mask = (segmentation == 12) # Replace 12 with correct defect label
|
20 |
+
original[mask] = [255, 0, 0] # Red highlight for defects
|
21 |
+
|
22 |
+
return Image.fromarray(original)
|
23 |
+
|