import os
from typing import Dict, List, Tuple

import cv2
import numpy as np
import onnxruntime as ort
import requests
import torch
from PIL import Image, ImageFilter

# Human parts labels based on the CCIHP dataset
HUMAN_PARTS_LABELS = {
    0: ("background", "Background"),
    1: ("hat", "Hat: Hat, helmet, cap, hood, veil, headscarf, part covering the skull and hair of a hood/balaclava, crown…"),
    2: ("hair", "Hair"),
    3: ("glove", "Glove"),
    4: ("glasses", "Sunglasses/Glasses: Sunglasses, eyewear, protective glasses…"),
    5: ("upper_clothes", "UpperClothes: T-shirt, shirt, tank top, sweater under a coat, top of a dress…"),
    6: ("face_mask", "Face Mask: Protective mask, surgical mask, carnival mask, facial part of a balaclava, visor of a helmet…"),
    7: ("coat", "Coat: Coat, jacket worn without anything on it, vest with nothing on it, a sweater with nothing on it…"),
    8: ("socks", "Socks"),
    9: ("pants", "Pants: Pants, shorts, tights, leggings, swimsuit bottoms… (clothing with 2 legs)"),
    10: ("torso-skin", "Torso-skin"),
    11: ("scarf", "Scarf: Scarf, bow tie, tie…"),
    12: ("skirt", "Skirt: Skirt, kilt, bottom of a dress…"),
    13: ("face", "Face"),
    14: ("left-arm", "Left-arm (naked part)"),
    15: ("right-arm", "Right-arm (naked part)"),
    16: ("left-leg", "Left-leg (naked part)"),
    17: ("right-leg", "Right-leg (naked part)"),
    18: ("left-shoe", "Left-shoe"),
    19: ("right-shoe", "Right-shoe"),
    20: ("bag", "Bag: Backpack, shoulder bag, fanny pack… (bag carried on oneself)"),
    21: ("", "Others: Jewelry, tags, bibs, belts, ribbons, pins, head decorations, headphones…"),
}

# Model configuration and download location
current_dir = os.path.dirname(os.path.abspath(__file__))
models_dir = os.path.join(current_dir, "models")
models_dir_path = os.path.join(models_dir, "onnx", "human-parts")
model_url = "https://huggingface.co/Metal3d/deeplabv3p-resnet50-human/resolve/main/deeplabv3p-resnet50-human.onnx"
model_name = "deeplabv3p-resnet50-human.onnx"
model_path = os.path.join(models_dir_path, model_name)


def get_class_index(class_name: str) -> int:
    """Return the index of the class name in the model, or -1 if not found."""
    if class_name == "":
        return -1
    for key, value in HUMAN_PARTS_LABELS.items():
        if value[0] == class_name:
            return key
    return -1


def download_model(model_url: str, model_path: str) -> bool:
    """Download the human parts segmentation model if it is not already present."""
    if os.path.exists(model_path):
        return True
    try:
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        print(f"Downloading human parts model to {model_path}...")
        response = requests.get(model_url, stream=True)
        response.raise_for_status()

        total_size = int(response.headers.get("content-length", 0))
        downloaded = 0
        with open(model_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
                downloaded += len(chunk)
                if total_size > 0:
                    percent = (downloaded / total_size) * 100
                    print(f"\rDownload progress: {percent:.1f}%", end="", flush=True)
        print("\n✅ Model download completed")
        return True
    except Exception as e:
        print(f"\n❌ Error downloading model: {e}")
        return False


def get_human_parts_mask(
    image: torch.Tensor,
    model: ort.InferenceSession,
    rotation: float = 0,
    **kwargs,
) -> Tuple[torch.Tensor, int]:
    """
    Generate a human parts mask using the ONNX model.
    Args:
        image: Input image tensor.
        model: ONNX inference session.
        rotation: Rotation angle in degrees (defaults to 0, i.e. no rotation).
        **kwargs: Part-specific enable flags (e.g. face=True, hair=True).

    Returns:
        Tuple of (mask_tensor, score).
    """
    image = image.squeeze(0)
    image_np = image.numpy() * 255
    pil_image = Image.fromarray(image_np.astype(np.uint8))
    original_size = pil_image.size

    # Resize to 512x512 as the model expects
    pil_image = pil_image.resize((512, 512))
    center = (256, 256)

    if rotation != 0:
        pil_image = pil_image.rotate(rotation, center=center)

    # Normalize the image to the [-1, 1] range
    image_np = np.array(pil_image).astype(np.float32) / 127.5 - 1
    image_np = np.expand_dims(image_np, axis=0)

    # Run the ONNX model to get the segmentation map
    input_name = model.get_inputs()[0].name
    output_name = model.get_outputs()[0].name
    result = model.run([output_name], {input_name: image_np})
    result = np.array(result[0]).argmax(axis=3).squeeze(0)

    # Debug: check which classes the model actually detected
    unique_classes = np.unique(result)

    score = 0
    mask = np.zeros_like(result)

    # Combine masks for the enabled classes
    for class_name, enabled in kwargs.items():
        class_index = get_class_index(class_name)
        if enabled and class_index != -1:
            detected = result == class_index
            mask[detected] = 255
            score += int(detected.sum())  # Count the pixels detected for this class

    # Resize back to the original size
    mask_image = Image.fromarray(mask.astype(np.uint8), mode="L")
    if rotation != 0:
        mask_image = mask_image.rotate(-rotation, center=center)
    mask_image = mask_image.resize(original_size)

    # Convert back to numpy, normalized to the 0-1 range
    mask = np.array(mask_image).astype(np.float32) / 255.0

    # Add batch and channel dimensions for the torch tensor
    mask = np.expand_dims(mask, axis=0)
    mask = np.expand_dims(mask, axis=0)

    return torch.from_numpy(mask), score


def numpy_to_torch_tensor(image_np: np.ndarray) -> torch.Tensor:
    """Convert a numpy array to a torch tensor in the format expected by the model."""
    if len(image_np.shape) == 3:
        return torch.from_numpy(image_np.astype(np.float32) / 255.0).unsqueeze(0)
    return torch.from_numpy(image_np.astype(np.float32) / 255.0)


def torch_tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
    """Convert a torch tensor back to a numpy array."""
    if len(tensor.shape) == 4:
        tensor = tensor.squeeze(0)

    # Handle as a float32 tensor in the 0-1 range, then convert to binary
    tensor_np = tensor.numpy()
    if tensor_np.dtype == np.float32 and tensor_np.max() <= 1.0:
        return (tensor_np > 0.5).astype(np.float32)  # Binary threshold
    return tensor_np


class HumanPartsSegmentation:
    """
    Standalone human parts segmentation for face and hair using DeepLabV3+ ResNet50.
""" def __init__(self): self.model = None def check_model_cache(self): """Check if model file exists in cache - consistent with updated repos.""" if not os.path.exists(model_path): return False, "Model file not found" return True, "Model cache verified" def clear_model(self): """Clear model from memory - improved version.""" if self.model is not None: del self.model self.model = None def load_model(self): """Load the human parts segmentation model - improved version.""" try: # Check and download model if needed cache_status, message = self.check_model_cache() if not cache_status: print(f"Cache check: {message}") if not download_model(model_url, model_path): return False # Load model if needed if self.model is None: print("Loading human parts segmentation model...") self.model = ort.InferenceSession(model_path) print("✅ Human parts segmentation model loaded successfully") return True except Exception as e: print(f"❌ Error loading human parts model: {e}") self.clear_model() # Cleanup on error return False def segment_parts(self, image_path: str, parts: List[str], mask_blur: int = 0, mask_offset: int = 0) -> Dict[str, np.ndarray]: """ Segment specific human parts from an image - improved version with filtering. Args: image_path: Path to the image file parts: List of part names to segment (e.g., ['face', 'hair']) mask_blur: Blur amount for mask edges mask_offset: Expand/Shrink mask boundary Returns: Dictionary mapping part names to binary masks """ if not self.load_model(): print("❌ Cannot load human parts segmentation model") return {} try: # Load image image = cv2.imread(image_path) if image is None: print(f"❌ Could not load image: {image_path}") return {} image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # Convert to tensor format expected by the model image_tensor = numpy_to_torch_tensor(image_rgb) # Prepare kwargs for each part part_kwargs = {part: True for part in parts} # Get segmentation mask mask_tensor, score = get_human_parts_mask(image_tensor, self.model, **part_kwargs) # Convert back to numpy if len(mask_tensor.shape) == 4: mask_tensor = mask_tensor.squeeze(0).squeeze(0) elif len(mask_tensor.shape) == 3: mask_tensor = mask_tensor.squeeze(0) # Get the combined mask for all requested parts combined_mask = mask_tensor.numpy() # Generate individual masks for each part if multiple parts requested result_masks = {} if len(parts) == 1: # Single part - return the combined mask part_name = parts[0] final_mask = self._apply_filters(combined_mask, mask_blur, mask_offset) if np.sum(final_mask > 0) > 0: result_masks[part_name] = final_mask else: result_masks[part_name] = final_mask # Return empty mask instead of None else: # Multiple parts - need to segment each individually for part in parts: single_part_kwargs = {part: True} single_mask_tensor, _ = get_human_parts_mask(image_tensor, self.model, **single_part_kwargs) if len(single_mask_tensor.shape) == 4: single_mask_tensor = single_mask_tensor.squeeze(0).squeeze(0) elif len(single_mask_tensor.shape) == 3: single_mask_tensor = single_mask_tensor.squeeze(0) single_mask = single_mask_tensor.numpy() final_mask = self._apply_filters(single_mask, mask_blur, mask_offset) result_masks[part] = final_mask # Always add mask, even if empty return result_masks except Exception as e: print(f"❌ Error in human parts segmentation: {e}") return {} finally: # Clean up model if not needed self.clear_model() def _apply_filters(self, mask: np.ndarray, mask_blur: int = 0, mask_offset: int = 0) -> np.ndarray: """Apply filtering to mask - new method from updated 
repo.""" if mask_blur == 0 and mask_offset == 0: return mask try: # Convert to PIL for filtering mask_image = Image.fromarray((mask * 255).astype(np.uint8)) # Apply blur if specified if mask_blur > 0: mask_image = mask_image.filter(ImageFilter.GaussianBlur(radius=mask_blur)) # Apply offset if specified if mask_offset != 0: if mask_offset > 0: mask_image = mask_image.filter(ImageFilter.MaxFilter(size=mask_offset * 2 + 1)) else: mask_image = mask_image.filter(ImageFilter.MinFilter(size=-mask_offset * 2 + 1)) # Convert back to numpy filtered_mask = np.array(mask_image).astype(np.float32) / 255.0 return filtered_mask except Exception as e: print(f"❌ Error applying filters: {e}") return mask