# SegMatch / human_parts_segmentation.py
# skallewag's picture
# Upload 25 files
# 62cc23b verified
import os
import numpy as np
import torch
from PIL import Image, ImageFilter
import cv2
import requests
from typing import Dict, List, Tuple, Optional
import onnxruntime as ort
# Human parts labels of the CCIHP dataset.
# Maps the model's output class index -> (short name, human-readable description).
# The short name is what callers pass as a part name (see get_class_index);
# class 21 ("Others") has an empty short name, so it can never be selected.
HUMAN_PARTS_LABELS = {
    0: ("background", "Background"),
    1: ("hat", "Hat: Hat, helmet, cap, hood, veil, headscarf, part covering the skull and hair of a hood/balaclava, crown…"),
    2: ("hair", "Hair"),
    3: ("glove", "Glove"),
    4: ("glasses", "Sunglasses/Glasses: Sunglasses, eyewear, protective glasses…"),
    5: ("upper_clothes", "UpperClothes: T-shirt, shirt, tank top, sweater under a coat, top of a dress…"),
    6: ("face_mask", "Face Mask: Protective mask, surgical mask, carnival mask, facial part of a balaclava, visor of a helmet…"),
    7: ("coat", "Coat: Coat, jacket worn without anything on it, vest with nothing on it, a sweater with nothing on it…"),
    8: ("socks", "Socks"),
    9: ("pants", "Pants: Pants, shorts, tights, leggings, swimsuit bottoms… (clothing with 2 legs)"),
    10: ("torso-skin", "Torso-skin"),
    11: ("scarf", "Scarf: Scarf, bow tie, tie…"),
    12: ("skirt", "Skirt: Skirt, kilt, bottom of a dress…"),
    13: ("face", "Face"),
    14: ("left-arm", "Left-arm (naked part)"),
    15: ("right-arm", "Right-arm (naked part)"),
    16: ("left-leg", "Left-leg (naked part)"),
    17: ("right-leg", "Right-leg (naked part)"),
    18: ("left-shoe", "Left-shoe"),
    19: ("right-shoe", "Right-shoe"),
    20: ("bag", "Bag: Backpack, shoulder bag, fanny pack… (bag carried on oneself)"),
    21: ("", "Others: Jewelry, tags, bibs, belts, ribbons, pins, head decorations, headphones…"),
}
# Remote source and file name of the ONNX human-parts model.
model_url = "https://huggingface.co/Metal3d/deeplabv3p-resnet50-human/resolve/main/deeplabv3p-resnet50-human.onnx"
model_name = "deeplabv3p-resnet50-human.onnx"

# Local cache layout: <directory of this file>/models/onnx/human-parts/<model_name>
current_dir = os.path.dirname(os.path.abspath(__file__))
models_dir = os.path.join(current_dir, "models")
models_dir_path = os.path.join(models_dir, "onnx", "human-parts")
model_path = os.path.join(models_dir_path, model_name)
def get_class_index(class_name: str) -> int:
    """Look up the model class index for a part short name.

    Returns -1 when the name is unknown. The empty string also maps to -1,
    even though class 21 ("Others") has an empty short name, so that class
    can never be selected by name.
    """
    if class_name == "":
        return -1
    # First index (in dict order) whose short name matches, else -1.
    matching_indices = (
        index
        for index, (short_name, _description) in HUMAN_PARTS_LABELS.items()
        if short_name == class_name
    )
    return next(matching_indices, -1)
def download_model(model_url: str, model_path: str) -> bool:
    """Download the human parts segmentation model if it is not already present.

    The response is streamed into a temporary ``<model_path>.part`` file and
    only moved to ``model_path`` once the whole body has been written. An
    interrupted download therefore never leaves a truncated file at
    ``model_path``, which a later call would otherwise accept as a cached model.

    Args:
        model_url: URL to fetch the ONNX model from.
        model_path: Destination path of the model file.

    Returns:
        True if the file already exists or was downloaded successfully,
        False if the download failed (the error is printed, not raised).
    """
    if os.path.exists(model_path):
        return True
    tmp_path = model_path + ".part"
    try:
        target_dir = os.path.dirname(model_path)
        if target_dir:
            # os.makedirs("") raises, so only create a directory when there is one.
            os.makedirs(target_dir, exist_ok=True)
        print(f"Downloading human parts model to {model_path}...")
        # (connect, read) timeout so a stalled connection cannot hang forever;
        # the context manager releases the streamed connection when done.
        with requests.get(model_url, stream=True, timeout=(10, 60)) as response:
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))
            downloaded = 0
            with open(tmp_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if not chunk:
                        continue
                    f.write(chunk)
                    downloaded += len(chunk)
                    if total_size > 0:
                        percent = (downloaded / total_size) * 100
                        print(f"\rDownload progress: {percent:.1f}%", end='', flush=True)
        # Atomically publish the completed file.
        os.replace(tmp_path, model_path)
        print("\n✅ Model download completed")
        return True
    except Exception as e:
        print(f"\n❌ Error downloading model: {e}")
        return False
    finally:
        # Remove any partial download left behind by a failure.
        if os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
            except OSError:
                pass
def get_human_parts_mask(image: torch.Tensor, model: ort.InferenceSession, rotation: float = 0, **kwargs) -> Tuple[torch.Tensor, int]:
    """
    Generate a combined mask of the enabled human parts using the ONNX model.

    Args:
        image: Input image tensor with values in [0, 1]. A leading dimension of
            size 1 is squeezed away; presumably (1, H, W, 3) as produced by
            numpy_to_torch_tensor (TODO confirm for other callers).
        model: ONNX Runtime inference session for the human-parts model.
        rotation: Angle in degrees. The 512x512 model input is rotated by this
            amount before inference and the mask is rotated back afterwards.
        **kwargs: Part short name -> enable flag (e.g. face=True). Names not
            present in HUMAN_PARTS_LABELS are ignored.

    Returns:
        Tuple of (mask_tensor, score):
        - mask_tensor: float32 tensor of shape (1, 1, H, W) at the original
          image size, with values in [0, 1]. Edges may be fractional because
          of the resize interpolation.
        - score: sum of the combined 0/255 mask at 512x512 resolution.
    """
    # Move to host memory and drop autograd so .numpy() also works for
    # CUDA tensors or tensors that require grad.
    image_np = image.squeeze(0).detach().cpu().numpy()
    # Clip before the uint8 cast so values slightly outside [0, 1] saturate
    # instead of wrapping around.
    pil_image = Image.fromarray(np.clip(image_np * 255, 0, 255).astype(np.uint8))
    original_size = pil_image.size

    # The model expects a 3-channel 512x512 input.
    pil_image = pil_image.convert("RGB").resize((512, 512))
    center = (256, 256)
    if rotation != 0:
        pil_image = pil_image.rotate(rotation, center=center)

    # Scale pixels to [-1, 1] and add the batch axis (NHWC).
    model_input = np.expand_dims(np.array(pil_image).astype(np.float32) / 127.5 - 1, axis=0)

    input_name = model.get_inputs()[0].name
    output_name = model.get_outputs()[0].name
    outputs = model.run([output_name], {input_name: model_input})
    # Per-pixel class index: argmax over the class axis, then drop the batch axis.
    class_map = np.array(outputs[0]).argmax(axis=3).squeeze(0)

    # Union of all enabled classes, as a 0/255 mask.
    mask = np.zeros(class_map.shape, dtype=np.uint8)
    for class_name, enabled in kwargs.items():
        class_index = get_class_index(class_name)
        if enabled and class_index != -1:
            mask[class_map == class_index] = 255
    # Score the union once. Accumulating inside the loop counted earlier parts
    # repeatedly and made the score depend on kwargs order.
    score = int(mask.sum(dtype=np.int64))

    # A 2-D uint8 array is converted to an "L" (grayscale) image.
    mask_image = Image.fromarray(mask)
    if rotation != 0:
        mask_image = mask_image.rotate(-rotation, center=center)
    mask_image = mask_image.resize(original_size)

    # Back to float32 in [0, 1], shaped (1, 1, H, W).
    mask_np = np.array(mask_image).astype(np.float32) / 255.0
    return torch.from_numpy(mask_np[np.newaxis, np.newaxis, ...]), score
def numpy_to_torch_tensor(image_np: np.ndarray) -> torch.Tensor:
    """Convert an array to a float32 tensor divided by 255.

    3-D arrays (e.g. H x W x C images) also get a leading batch axis of size 1;
    arrays of any other rank keep their shape.
    """
    scaled = torch.from_numpy(image_np.astype(np.float32) / 255.0)
    return scaled.unsqueeze(0) if image_np.ndim == 3 else scaled
def torch_tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
    """Convert a mask tensor back to a numpy array.

    A 4-D tensor has its leading (batch) dimension squeezed. Float32 data
    whose maximum is <= 1.0 (or that is empty) is treated as a soft mask and
    binarized at 0.5, giving float32 0.0/1.0 values. Any other data is
    returned unchanged.
    """
    if tensor.dim() == 4:
        tensor = tensor.squeeze(0)
    # detach()/cpu() so this also works for tensors that require grad or live on a GPU.
    tensor_np = tensor.detach().cpu().numpy()
    # An empty array has no max(); threshold it (yielding an empty mask) instead of raising.
    if tensor_np.dtype == np.float32 and (tensor_np.size == 0 or tensor_np.max() <= 1.0):
        return (tensor_np > 0.5).astype(np.float32)  # Binary threshold
    return tensor_np
class HumanPartsSegmentation:
    """
    Standalone human parts segmentation using a DeepLabV3+ ResNet50 ONNX model.

    The model file is downloaded on first use if missing. The session is loaded
    at the start of each segment_parts() call and released when that call
    returns.
    """

    def __init__(self):
        # ONNX Runtime inference session; None while no model is loaded.
        self.model = None

    def check_model_cache(self):
        """Return ``(exists, message)`` for the model file at the module-level ``model_path``."""
        if not os.path.exists(model_path):
            return False, "Model file not found"
        return True, "Model cache verified"

    def clear_model(self):
        """Release the reference to the loaded model session, if any."""
        if self.model is not None:
            del self.model
            self.model = None

    def load_model(self):
        """
        Ensure the model file is present (downloading it if needed) and load it.

        Returns:
            True when ``self.model`` holds a session, False on failure. Errors
            are printed rather than raised, and the model reference is cleared.
        """
        try:
            cache_status, message = self.check_model_cache()
            if not cache_status:
                print(f"Cache check: {message}")
                if not download_model(model_url, model_path):
                    return False
            if self.model is None:
                print("Loading human parts segmentation model...")
                self.model = ort.InferenceSession(model_path)
                print("✅ Human parts segmentation model loaded successfully")
            return True
        except Exception as e:
            print(f"❌ Error loading human parts model: {e}")
            self.clear_model()  # Cleanup on error
            return False

    def segment_parts(self, image_path: str, parts: List[str], mask_blur: int = 0, mask_offset: int = 0) -> Dict[str, np.ndarray]:
        """
        Segment specific human parts from an image.

        Args:
            image_path: Path to the image file (read with cv2.imread).
            parts: Part short names to segment (e.g. ['face', 'hair']). Names
                not in HUMAN_PARTS_LABELS yield an all-zero mask.
            mask_blur: Gaussian blur radius applied to each mask (0 = none).
            mask_offset: Grow (> 0) or shrink (< 0) each mask by this many pixels.

        Returns:
            Dict mapping each requested part name to its (H, W) mask with
            values in [0, 1]. The mask is all zeros when the part is not
            detected. Returns {} when ``parts`` is empty, when the model or
            image cannot be loaded, or on error.
        """
        if not parts:
            # Nothing to segment: skip loading (or downloading) the model.
            return {}
        if not self.load_model():
            print("❌ Cannot load human parts segmentation model")
            return {}
        try:
            image = cv2.imread(image_path)
            if image is None:
                print(f"❌ Could not load image: {image_path}")
                return {}
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            # Convert to tensor format expected by the model
            image_tensor = numpy_to_torch_tensor(image_rgb)

            result_masks = {}
            # One inference per distinct part; dict.fromkeys removes duplicates
            # while keeping the requested order.
            for part in dict.fromkeys(parts):
                mask_tensor, _ = get_human_parts_mask(image_tensor, self.model, **{part: True})
                part_mask = self._to_2d_numpy(mask_tensor)
                # Always include the mask, even if it is empty (part not detected).
                result_masks[part] = self._apply_filters(part_mask, mask_blur, mask_offset)
            return result_masks
        except Exception as e:
            print(f"❌ Error in human parts segmentation: {e}")
            return {}
        finally:
            # Release the session after every call; the next call reloads it.
            self.clear_model()

    @staticmethod
    def _to_2d_numpy(mask_tensor: torch.Tensor) -> np.ndarray:
        """Strip the leading singleton dims of a (1, 1, H, W) or (1, H, W) mask tensor."""
        if mask_tensor.dim() == 4:
            mask_tensor = mask_tensor.squeeze(0).squeeze(0)
        elif mask_tensor.dim() == 3:
            mask_tensor = mask_tensor.squeeze(0)
        return mask_tensor.numpy()

    def _apply_filters(self, mask: np.ndarray, mask_blur: int = 0, mask_offset: int = 0) -> np.ndarray:
        """
        Blur and/or grow/shrink a mask with values in [0, 1].

        Args:
            mask: Float mask with values in [0, 1] (2-D as produced by segment_parts).
            mask_blur: Gaussian blur radius; values <= 0 disable blurring.
            mask_offset: > 0 dilates by this many pixels (MaxFilter);
                < 0 erodes by this many pixels (MinFilter).

        Returns:
            The filtered float32 mask in [0, 1]. Returns ``mask`` itself when no
            filter is requested or when filtering fails (the error is printed).
        """
        if mask_blur == 0 and mask_offset == 0:
            return mask
        try:
            # Clip before the uint8 cast so out-of-range values saturate instead of wrapping.
            mask_image = Image.fromarray((np.clip(mask, 0.0, 1.0) * 255).astype(np.uint8))
            if mask_blur > 0:
                mask_image = mask_image.filter(ImageFilter.GaussianBlur(radius=mask_blur))
            # A (2k + 1)-sized rank filter moves the mask boundary by k pixels.
            if mask_offset > 0:
                mask_image = mask_image.filter(ImageFilter.MaxFilter(size=mask_offset * 2 + 1))
            elif mask_offset < 0:
                mask_image = mask_image.filter(ImageFilter.MinFilter(size=-mask_offset * 2 + 1))
            return np.array(mask_image).astype(np.float32) / 255.0
        except Exception as e:
            print(f"❌ Error applying filters: {e}")
            return mask