# Hugging Face Spaces page header (Space status: Sleeping) -- page residue, not code.
from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation | |
from PIL import Image | |
import torch | |
import numpy as np | |
import cv2 | |
# SegFormer-B0 checkpoint fine-tuned on ADE20K; processor and model are loaded
# once at import time and shared by every call to predict_defect.
MODEL_ID = "nvidia/segformer-b0-finetuned-ade-512-512"
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForSemanticSegmentation.from_pretrained(MODEL_ID)

# Class names (as spelled in the model's id2label mapping) treated as road
# surface; pixels of any other class are flagged as suspicious. Tune as needed.
ROAD_CLASS_NAMES = ("road", "sidewalk")

# Class indices to consider as road. Looked up by name rather than hard-coded:
# the previous literal [0, 8] maps to "wall" and "windowpane" in this
# checkpoint's ADE20K label set, not to road classes.
ROAD_LABELS = sorted(
    int(idx)
    for idx, name in model.config.id2label.items()
    if name.strip() in ROAD_CLASS_NAMES
)
if not ROAD_LABELS:
    # An empty list would flag every pixel, which then trips the
    # oversegmentation guard in predict_defect and silently shows nothing.
    raise RuntimeError(
        f"None of {ROAD_CLASS_NAMES} found in the id2label mapping of {MODEL_ID}"
    )
def predict_defect(
    image: Image.Image, *, max_defect_fraction: float = 0.4
) -> Image.Image:
    """Paint non-road pixels of *image* red as suspected road defects.

    The image is segmented with the module-level ``processor``/``model``;
    every pixel whose predicted class is not in ``ROAD_LABELS`` is set to
    pure red. If more than ``max_defect_fraction`` of the image would be
    flagged, the result is treated as oversegmentation and nothing is
    highlighted.

    Args:
        image: Input picture in any PIL mode; it is converted to RGB first.
        max_defect_fraction: Fraction of flagged pixels above which the mask
            is discarded. Defaults to 0.4 (the previous hard-coded value).

    Returns:
        An RGB image of the same size as *image* with flagged pixels set to
        (255, 0, 0).
    """
    # Grayscale/palette images become 2-D arrays and RGBA images have four
    # channels; either would break the 3-value colour assignment below.
    image = image.convert("RGB")
    original = np.array(image)
    height, width = original.shape[:2]

    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits

    # Take the single image of the batch explicitly (squeeze() would also drop
    # any spatial dimension of size 1), then pick the best class per pixel.
    labels = logits[0].argmax(dim=0).cpu().numpy().astype(np.uint8)  # assumes <= 256 classes
    # Upscale the label map to the input size; nearest-neighbour keeps the
    # class ids discrete instead of blending neighbouring ids.
    labels = cv2.resize(labels, (width, height), interpolation=cv2.INTER_NEAREST)

    defect_mask = ~np.isin(labels, ROAD_LABELS)
    overlay = original.copy()
    # Mark only suspicious areas (non-road) unless it's oversegmenting.
    if defect_mask.mean() <= max_defect_fraction:
        overlay[defect_mask] = (255, 0, 0)
    return Image.fromarray(overlay)