# Hugging Face Spaces metadata residue: "Spaces: Running on Zero" (ZeroGPU badge)
""" | |
Image Part Segmentation and Labeling Tool | |
This script segments images into meaningful parts using the Segment Anything Model (SAM) | |
and optionally removes backgrounds using BriaRMBG. It identifies, visualizes, and merges | |
different parts of objects in images. | |
Key features: | |
- Background removal with alpha channel preservation | |
- Automatic part segmentation with SAM | |
- Intelligent part merging for logical grouping | |
- Detection of parts that SAM might miss | |
- Splitting of disconnected parts into separate components | |
- Edge cleaning and smoothing of segmentations | |
- Visualization of segmented parts with clear labeling | |
""" | |
import os | |
import argparse | |
import numpy as np | |
import cv2 | |
import torch | |
from PIL import Image | |
from torchvision.transforms import functional as F | |
from torchvision import transforms | |
import torch.nn.functional as F_nn | |
from segment_anything import SamAutomaticMaskGenerator, build_sam | |
from modules.label_2d_mask.visualizer import Visualizer | |
# Minimum pixel count for a segment to be kept (smaller masks are discarded).
size_th = 2000
def get_mask(group_ids, image, ids=None, img_name=None, save_dir=None):
    """
    Render a colored visualization of mask segments and save it as a PNG.

    Args:
        group_ids: Array of segment IDs for each pixel (-1 = background).
        image: Input image (only its height/width are used).
        ids: Identifier appended to the output filename.
        img_name: Base name of the image used for the output file.
        save_dir: Directory the visualization is written to.

    Returns:
        The ``group_ids`` array, unchanged (returned for convenience).
    """
    height, width = image.shape[0], image.shape[1]
    canvas = np.zeros((height, width, 3), dtype=np.uint8)
    # Background pixels are drawn white.
    canvas[group_ids == -1] = [255, 255, 255]
    segment_ids = np.unique(group_ids)
    segment_ids = segment_ids[segment_ids >= 0]
    for idx, seg_id in enumerate(segment_ids):
        # Deterministic pseudo-random color derived from the segment index.
        rgb = [(idx * 50 + 80) % 256,
               (idx * 120 + 40) % 256,
               (idx * 180 + 20) % 256]
        canvas[group_ids == seg_id] = rgb
    mask_path = os.path.join(save_dir, f"{img_name}_mask_segments_{ids}.png")
    # OpenCV expects BGR channel order on disk writes.
    cv2.imwrite(mask_path, cv2.cvtColor(canvas, cv2.COLOR_RGB2BGR))
    print(f"Saved mask segments visualization to {mask_path}")
    return group_ids
def clean_segment_edges(group_ids):
    """
    Smooth segment boundaries with per-segment morphological filtering.

    Each segment is closed (fills small gaps along the edge) and then opened
    (removes isolated stray pixels) with a 3x3 kernel; the cleaned segments
    are composited back onto a fresh all-background array.

    Args:
        group_ids: Array of segment IDs for each pixel (-1 = background).

    Returns:
        New array of segment IDs with smoother boundaries.
    """
    segment_ids = np.unique(group_ids)
    segment_ids = segment_ids[segment_ids >= 0]  # drop background (-1)
    cleaned = np.full_like(group_ids, -1)  # start from all-background
    kernel = np.ones((3, 3), np.uint8)
    for seg_id in segment_ids:
        binary = (group_ids == seg_id).astype(np.uint8)
        # Close first (smooth edges), then open (drop isolated pixels).
        binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel, iterations=1)
        binary = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel, iterations=1)
        cleaned[binary > 0] = seg_id
    print(f"Cleaned edges for {len(segment_ids)} segments")
    return cleaned
def prepare_image(image, bg_color=None, rmbg_net=None, device='cuda'):
    """
    Remove the background of a PIL image by predicting an alpha matte.

    Runs the background-removal network on a normalized 1024x1024 copy of the
    image, resizes the predicted matte back to the original resolution, and
    attaches it to the image's alpha channel (in place).

    Args:
        image: Input PIL image; modified in place via ``putalpha``.
        bg_color: Unused; kept for interface compatibility with callers.
        rmbg_net: Background-removal network (e.g. BriaRMBG); called on a
            normalized batch and expected to return a sequence whose last
            element holds the matte logits.
        device: Torch device string the inference runs on. Defaults to
            ``'cuda'`` (previous behavior); pass ``'cpu'`` on GPU-less hosts.

    Returns:
        The same PIL image with the predicted alpha channel attached.
    """
    image_size = (1024, 1024)
    transform_image = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        # ImageNet mean/std normalization expected by the matting network.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    input_images = transform_image(image).unsqueeze(0).to(device)
    # Inference only: no gradients needed.
    with torch.no_grad():
        preds = rmbg_net(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    # Resize the matte back to the original image resolution.
    mask = pred_pil.resize(image.size)
    image.putalpha(mask)
    return image
def resize_and_pad_to_square(image, target_size=518):
    """
    Scale an image so its longest side equals ``target_size``, then pad the
    shorter side (centered) to produce a square image.

    Args:
        image: PIL image or numpy array. A 3-channel numpy input is treated
            as BGR (OpenCV order) and converted to RGB.
        target_size: Side length of the output square, defaults to 518.

    Returns:
        PIL Image of size (target_size, target_size); RGBA inputs keep a
        transparent padding, RGB inputs get white padding.
    """
    # Accept numpy input; 3-channel arrays are assumed to be OpenCV BGR.
    if isinstance(image, np.ndarray):
        if len(image.shape) == 3 and image.shape[2] == 3:
            image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        else:
            image = Image.fromarray(image)

    width, height = image.size
    # The longer side fixes the scale; ties (square input) take the else path.
    if width > height:
        new_width = target_size
        new_height = int(height * (target_size / width))
    else:
        new_height = target_size
        new_width = int(width * (target_size / height))

    resized = image.resize((new_width, new_height), Image.LANCZOS)

    # Preserve transparency when the source has an alpha channel.
    has_alpha = image.mode == "RGBA"
    mode = "RGBA" if has_alpha else "RGB"
    fill = (255, 255, 255, 0) if has_alpha else (255, 255, 255)
    square_image = Image.new(mode, (target_size, target_size), fill)

    # Center the resized image on the square canvas.
    offset = ((target_size - new_width) // 2, (target_size - new_height) // 2)
    if has_alpha:
        # Use the image itself as the paste mask to keep transparency.
        square_image.paste(resized, offset, resized)
    else:
        square_image.paste(resized, offset)
    return square_image
def split_disconnected_parts(group_ids, size_threshold=None):
    """
    Give every connected region its own part ID.

    Parts made of several disconnected regions are split into one part per
    region; fragments smaller than ``size_threshold / 5`` pixels are dropped
    to the background. Output IDs are renumbered sequentially from 0.

    Args:
        group_ids: Array of segment IDs for each pixel (-1 = background).
        size_threshold: Minimum segment size in pixels. Falls back to the
            global ``size_th`` when None.

    Returns:
        New array in which each connected component has a unique ID.
    """
    threshold = size_th if size_threshold is None else threshold_or(size_threshold)
    result = np.full_like(group_ids, -1)  # start from all-background
    part_ids = np.unique(group_ids)
    part_ids = part_ids[part_ids >= 0]
    next_id = 0        # next unused output ID
    split_total = 0    # regions produced by actual splits
    for part_id in part_ids:
        binary = (group_ids == part_id).astype(np.uint8)
        num_labels, labels = cv2.connectedComponents(binary, connectivity=8)
        if num_labels == 1:
            # Only background found: nothing to carry over for this ID.
            continue
        if num_labels == 2:
            # Already a single connected region; just renumber it.
            result[labels == 1] = next_id
            next_id += 1
            continue
        # Several disconnected regions: each large-enough one becomes a part.
        print(f"Part {part_id} has {num_labels-1} disconnected regions, splitting...")
        kept = 0
        for label in range(1, num_labels):
            component = labels == label
            component_size = np.sum(component)
            if component_size >= threshold / 5:  # avoid tiny fragments
                result[component] = next_id
                kept += 1
                next_id += 1
            else:
                print(f" Skipping small disconnected region ({component_size} pixels)")
        split_total += kept
    if split_total > 0:
        print(f"Split disconnected parts: original {len(part_ids)} parts -> {next_id} connected parts")
    else:
        print("No parts needed splitting - all parts are already connected")
    return result


def threshold_or(value):
    """Identity helper kept trivial; returns *value* unchanged."""
    return value
# ------------------------------------------------------- | |
# MAIN SEGMENTATION FUNCTION | |
# ------------------------------------------------------- | |
def get_sam_mask(image, mask_generator, visual, merge_groups=None, existing_group_ids=None,
                 check_undetected=True, rgba_image=None, img_name=None, skip_split=False, save_dir=None, size_threshold=None):
    """
    Generate and process SAM masks for the image, with optional merging and undetected region detection.

    Pipeline (as implemented below):
      1. Reuse ``existing_group_ids`` if given, otherwise run SAM and assign
         IDs largest-mask-first, then split disconnected parts.
      2. Optionally recover foreground regions SAM missed, using the RGBA
         alpha channel, merging nearby recovered regions via union-find.
      3. Optionally merge parts per ``merge_groups`` and renumber IDs.
      4. Split disconnected parts again (unless ``skip_split`` after merging).
      5. Draw a labeled overlay with the ``visual`` visualizer.

    Args:
        image: Input image as a numpy array (H x W x C).
        mask_generator: SAM automatic mask generator; used only when
            ``existing_group_ids`` is None.
        visual: Visualizer instance used to draw masks and labels.
        merge_groups: Optional list of ID lists; each inner list of part IDs
            is merged into a single part (keeping the first ID, then all IDs
            are renumbered sequentially from 0).
        existing_group_ids: Optional precomputed per-pixel ID array; skips
            SAM generation when provided.
        check_undetected: When True and ``rgba_image`` is given, detect
            foreground regions missed by SAM and add them as new parts.
        rgba_image: Optional RGBA image whose alpha channel marks foreground.
        img_name: Base name for the debug visualization file.
        skip_split: When True, skip the disconnected-part split that follows
            merging.
        save_dir: Directory for the debug visualization.
        size_threshold: Minimum size threshold for considering a segment (in pixels).
            If None, uses the global size_th variable.

    Returns:
        Tuple ``(group_ids, im)``: the final per-pixel ID array and the
        rendered visualization image.
    """
    # Use provided threshold or fall back to global variable
    if size_threshold is None:
        size_threshold = size_th
    label_mode = '1'
    anno_mode = ['Mask', 'Mark']
    exist_group = False
    # Use existing group IDs if provided, otherwise generate new ones with SAM
    if existing_group_ids is not None:
        group_ids = existing_group_ids.copy()
        group_counter = np.max(group_ids) + 1
        exist_group = True
    else:
        # Generate masks using SAM
        masks = mask_generator.generate(image)
        group_ids = np.full((image.shape[0], image.shape[1]), -1, dtype=int)
        num_masks = len(masks)
        group_counter = 0
        # Sort masks by area (largest first) so smaller masks painted later
        # can overwrite pixels of larger, overlapping masks.
        area_sorted_masks = sorted(masks, key=lambda x: x["area"], reverse=True)
        # Create background mask if we have RGBA image
        background_mask = None
        if rgba_image is not None:
            rgba_array = np.array(rgba_image)
            if rgba_array.shape[2] == 4:
                # Use alpha channel to create foreground/background mask
                background_mask = rgba_array[:, :, 3] <= 10  # Areas with very low alpha are background
        # First pass: assign original group IDs
        for i in range(0, num_masks):
            if area_sorted_masks[i]["area"] < size_threshold:
                print(f"Skipping mask {i}, area too small: {area_sorted_masks[i]['area']} < {size_threshold}")
                continue
            mask = area_sorted_masks[i]["segmentation"]
            # Check proportion of background pixels in this mask
            if background_mask is not None:
                # Calculate how many pixels in this mask are background
                background_pixels_in_mask = np.sum(mask & background_mask)
                mask_area = np.sum(mask)
                background_ratio = background_pixels_in_mask / mask_area
                # Skip mask if background proportion is too high (>10%)
                if background_ratio > 0.1:
                    print(f" Skipping mask {i}, background ratio: {background_ratio:.2f}")
                    continue
            # Assign group ID to this mask's pixels
            group_ids[mask] = group_counter
            print(f"Assigned mask {i} with area {area_sorted_masks[i]['area']} to group {group_counter}")
            group_counter += 1
        # Split disconnected parts immediately after SAM segmentation
        print("Splitting disconnected parts in initial segmentation...")
        group_ids = split_disconnected_parts(group_ids, size_threshold)
        # Update group counter after splitting (splitting renumbers from 0)
        if np.max(group_ids) >= 0:
            group_counter = np.max(group_ids) + 1
        print(f"After early splitting, now have {len(np.unique(group_ids))-1} regions (excluding background)")
    # Check for undetected parts using RGBA information
    # NOTE(review): nesting reconstructed from a whitespace-mangled paste —
    # this section is placed at function level (it recomputes rgba_array);
    # confirm against the original source.
    if check_undetected and rgba_image is not None:
        print("Checking for undetected parts using RGBA image...")
        # Create a foreground mask from the alpha channel
        rgba_array = np.array(rgba_image)
        # Check if the image has an alpha channel
        if rgba_array.shape[2] == 4:
            print(f"Image has alpha channel, checking for undetected parts...")
            # Use alpha channel to identify non-transparent pixels (foreground)
            alpha_mask = rgba_array[:, :, 3] > 0
            # Create existing parts mask and dilate it
            existing_parts_mask = (group_ids != -1)
            kernel = np.ones((4, 4), np.uint8)
            # Use larger kernel for faster dilation
            # NOTE(review): same size as `kernel` despite the comment.
            large_kernel = np.ones((4, 4), np.uint8)
            dilated_parts = cv2.dilate(existing_parts_mask.astype(np.uint8), large_kernel)
            # Find undetected areas (foreground but not detected by SAM)
            undetected_mask = alpha_mask & (~dilated_parts.astype(bool))
            # Process only if there are enough undetected pixels
            if np.sum(undetected_mask) > size_threshold:
                print(f"Found undetected parts with {np.sum(undetected_mask)} pixels")
                # Find connected components in undetected regions
                num_labels, labels = cv2.connectedComponents(
                    undetected_mask.astype(np.uint8),
                    connectivity=8
                )
                print(f" Found {num_labels-1} initial regions")
                # Use Union-Find data structure for efficient region merging
                parent = list(range(num_labels))
                # Find with path compression
                def find(x):
                    """Find with path compression for Union-Find"""
                    if parent[x] != x:
                        parent[x] = find(parent[x])
                    return parent[x]
                # Union (no rank: smaller root ID always wins)
                def union(x, y):
                    """Union operation for Union-Find"""
                    root_x = find(x)
                    root_y = find(y)
                    if root_x != root_y:
                        # Use smaller ID as parent
                        if root_x < root_y:
                            parent[root_y] = root_x
                        else:
                            parent[root_x] = root_y
                # Calculate areas for all regions at once
                areas = np.bincount(labels.flatten())[1:] if num_labels > 1 else []
                # Filter regions by minimum size (same fragment rule as splitting)
                valid_regions = np.where(areas >= size_threshold/5)[0] + 1
                # Barrier mask for connectivity checks: existing parts must not
                # sit between two regions that get merged.
                barrier_mask = existing_parts_mask
                # Pre-compute dilated regions for all valid regions
                dilated_regions = {}
                for i in valid_regions:
                    region_mask = (labels == i).astype(np.uint8)
                    dilated_regions[i] = cv2.dilate(region_mask, kernel, iterations=2)
                # Check for region merges based on proximity and overlap
                for idx, i in enumerate(valid_regions[:-1]):
                    for j in valid_regions[idx+1:]:
                        # Check overlap between dilated regions
                        overlap = dilated_regions[i] & dilated_regions[j]
                        overlap_size = np.sum(overlap)
                        # Merge if significant overlap and not separated by existing parts
                        if overlap_size > 40 and not np.any(overlap & barrier_mask):
                            # Calculate overlap ratios relative to each region's area
                            overlap_ratio_i = overlap_size / areas[i-1]
                            overlap_ratio_j = overlap_size / areas[j-1]
                            if max(overlap_ratio_i, overlap_ratio_j) > 0.03:
                                union(i, j)
                                print(f" Merging regions {i} and {j} (overlap: {overlap_size} px)")
                # Apply the merging results to create merged labels
                merged_labels = np.zeros_like(labels)
                for label in range(1, num_labels):
                    merged_labels[labels == label] = find(label)
                # Get unique merged regions
                unique_merged_regions = np.unique(merged_labels[merged_labels > 0])
                print(f" After merging: {len(unique_merged_regions)} connected regions")
                # Add regions to group_ids if they're large enough
                group_counter_start = group_counter
                for label in unique_merged_regions:
                    region_mask = merged_labels == label
                    region_size = np.sum(region_mask)
                    if region_size > size_threshold:
                        print(f" Adding region with ID {label} ({region_size} pixels) as group {group_counter}")
                        group_ids[region_mask] = group_counter
                        group_counter += 1
                    else:
                        print(f" Skipping small region with ID {label} ({region_size} pixels < {size_threshold})")
                print(f" Added {group_counter - group_counter_start} regions that weren't detected by SAM")
                # Process edges for all new parts at once
                if group_counter > group_counter_start:
                    print("Processing edges for newly detected parts...")
                    # Create combined mask for all new parts
                    new_parts_mask = np.zeros_like(group_ids, dtype=bool)
                    for part_id in range(group_counter_start, group_counter):
                        new_parts_mask |= (group_ids == part_id)
                    # Compute edges for all new parts at once
                    all_new_dilated = cv2.dilate(new_parts_mask.astype(np.uint8), kernel, iterations=1)
                    all_new_eroded = cv2.erode(new_parts_mask.astype(np.uint8), kernel, iterations=1)
                    # NOTE(review): `all_new_edges` is computed but never used below.
                    all_new_edges = all_new_dilated.astype(bool) & (~all_new_eroded.astype(bool))
                    print(f"Edge processing completed for {group_counter - group_counter_start} new parts")
    # Save debug visualization of initial segmentation
    if not exist_group:
        get_mask(group_ids, image, ids=2, img_name=img_name, save_dir=save_dir)
    # Merge groups if specified
    if merge_groups is not None:
        # Start with current group_ids.
        # NOTE(review): this is an alias, not a copy — writes below also
        # mutate `group_ids` until the renumbered array replaces it.
        merged_group_ids = group_ids
        # Preserve background regions
        merged_group_ids[group_ids == -1] = -1
        # For each merge group, assign all pixels to the first ID in that group
        for new_id, group in enumerate(merge_groups):
            # Create a mask to include all original IDs in this group
            group_mask = np.zeros_like(group_ids, dtype=bool)
            orig_ids_first = group[0]
            # Process each original ID
            for orig_id in group:
                # Get mask for this original ID
                mask = (group_ids == orig_id)
                pixels = np.sum(mask)
                if pixels > 0:
                    print(f" Including original ID {orig_id} ({pixels} pixels)")
                    group_mask = group_mask | mask
                else:
                    print(f" Warning: Original ID {orig_id} does not exist")
            # Set all pixels in this group to the first ID in the group
            if np.any(group_mask):
                print(f" Merging {np.sum(group_mask)} pixels to ID {orig_ids_first}")
                merged_group_ids[group_mask] = orig_ids_first
        # Reassign IDs to be continuous from 0
        unique_ids = np.unique(merged_group_ids)
        unique_ids = unique_ids[unique_ids != -1]  # Exclude background
        id_reassignment = {old_id: new_id for new_id, old_id in enumerate(unique_ids)}
        # Create new array with reassigned IDs
        new_group_ids = np.full_like(merged_group_ids, -1)  # Start with all background
        for old_id, new_id in id_reassignment.items():
            new_group_ids[merged_group_ids == old_id] = new_id
        # Update merged_group_ids with continuous IDs
        merged_group_ids = new_group_ids
        print(f"ID reassignment complete: {len(id_reassignment)} groups now have sequential IDs from 0 to {len(id_reassignment)-1}")
        # Replace original group IDs with merged result
        group_ids = merged_group_ids
        print(f"Merging complete, now have {len(np.unique(group_ids))-1} regions (excluding background)")
        # Skip splitting disconnected parts if requested
        if not skip_split:
            # Split disconnected parts into separate parts
            group_ids = split_disconnected_parts(group_ids, size_threshold)
            print(f"After splitting disconnected parts, now have {len(np.unique(group_ids))-1} regions (excluding background)")
    else:
        # Always split disconnected parts for initial segmentation
        group_ids = split_disconnected_parts(group_ids, size_threshold)
        print(f"After splitting disconnected parts, now have {len(np.unique(group_ids))-1} regions (excluding background)")
    # Create visualization with clear labeling
    vis_mask = visual
    # First draw background areas (ID -1)
    background_mask = (group_ids == -1)
    if np.any(background_mask):
        vis_mask = visual.draw_binary_mask(background_mask, color=[1.0, 1.0, 1.0], alpha=0.0)
    # Then draw each segment with unique colors and labels
    for unique_id in np.unique(group_ids):
        if unique_id == -1:  # Skip background
            continue
        mask = (group_ids == unique_id)
        # Calculate center point and area of this region
        y_indices, x_indices = np.where(mask)
        if len(y_indices) > 0 and len(x_indices) > 0:
            area = len(y_indices)  # Calculate region area
            print(f"Labeling region {unique_id}, area: {area} pixels")
            if area < 30:  # Skip very small regions
                continue
            # Use different colors for different IDs to enhance visual distinction
            # (values land in [0.2, ~0.98] so no segment is drawn too dark).
            color_r = (unique_id * 50 + 80) % 200 / 255.0 + 0.2
            color_g = (unique_id * 120 + 40) % 200 / 255.0 + 0.2
            color_b = (unique_id * 180 + 20) % 200 / 255.0 + 0.2
            color = [color_r, color_g, color_b]
            # Adjust transparency based on area size (larger parts more opaque, capped at 0.3)
            adaptive_alpha = min(0.3, max(0.1, 0.1 + area / 100000))
            # Extract edges of this region (dilation minus erosion = boundary band)
            kernel = np.ones((3, 3), np.uint8)
            dilated = cv2.dilate(mask.astype(np.uint8), kernel, iterations=1)
            eroded = cv2.erode(mask.astype(np.uint8), kernel, iterations=1)
            edge = dilated.astype(bool) & (~eroded.astype(bool))
            # Build label text
            label = f"{unique_id}"
            # First draw the main body of the region
            vis_mask = visual.draw_binary_mask_with_number(
                mask,
                text=label,
                label_mode=label_mode,
                alpha=adaptive_alpha,
                anno_mode=anno_mode,
                color=color,
                font_size=20
            )
            # Enhance edges (add border effect for all parts)
            edge_color = [min(c*1.3, 1.0) for c in color]  # Slightly brighter edge color
            vis_mask = visual.draw_binary_mask(
                edge,
                alpha=0.8,  # Lower transparency for edges to make them more visible
                color=edge_color
            )
    im = vis_mask.get_image()
    return group_ids, im