import os
from typing import Dict, List, Tuple

import cv2
import numpy as np
import onnxruntime as ort
import requests
import torch
from PIL import Image, ImageFilter


HUMAN_PARTS_LABELS = {
    0: ("background", "Background"),
    1: ("hat", "Hat: Hat, helmet, cap, hood, veil, headscarf, part covering the skull and hair of a hood/balaclava, crown…"),
    2: ("hair", "Hair"),
    3: ("glove", "Glove"),
    4: ("glasses", "Sunglasses/Glasses: Sunglasses, eyewear, protective glasses…"),
    5: ("upper_clothes", "UpperClothes: T-shirt, shirt, tank top, sweater under a coat, top of a dress…"),
    6: ("face_mask", "Face Mask: Protective mask, surgical mask, carnival mask, facial part of a balaclava, visor of a helmet…"),
    7: ("coat", "Coat: Coat, jacket worn without anything on it, vest with nothing on it, a sweater with nothing on it…"),
    8: ("socks", "Socks"),
    9: ("pants", "Pants: Pants, shorts, tights, leggings, swimsuit bottoms… (clothing with 2 legs)"),
    10: ("torso-skin", "Torso-skin"),
    11: ("scarf", "Scarf: Scarf, bow tie, tie…"),
    12: ("skirt", "Skirt: Skirt, kilt, bottom of a dress…"),
    13: ("face", "Face"),
    14: ("left-arm", "Left-arm (naked part)"),
    15: ("right-arm", "Right-arm (naked part)"),
    16: ("left-leg", "Left-leg (naked part)"),
    17: ("right-leg", "Right-leg (naked part)"),
    18: ("left-shoe", "Left-shoe"),
    19: ("right-shoe", "Right-shoe"),
    20: ("bag", "Bag: Backpack, shoulder bag, fanny pack… (bag carried on oneself)"),
    21: ("", "Others: Jewelry, tags, bibs, belts, ribbons, pins, head decorations, headphones…"),
}


current_dir = os.path.dirname(os.path.abspath(__file__))
models_dir = os.path.join(current_dir, "models")
models_dir_path = os.path.join(models_dir, "onnx", "human-parts")
model_url = "https://huggingface.co/Metal3d/deeplabv3p-resnet50-human/resolve/main/deeplabv3p-resnet50-human.onnx"
model_name = "deeplabv3p-resnet50-human.onnx"
model_path = os.path.join(models_dir_path, model_name)


def get_class_index(class_name: str) -> int:
    """Return the model's class index for a part name, or -1 if the name is empty or unknown."""
    if class_name == "":
        return -1

    for key, value in HUMAN_PARTS_LABELS.items():
        if value[0] == class_name:
            return key
    return -1
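# Example use of get_class_index (illustrative; indices follow HUMAN_PARTS_LABELS above):
#   get_class_index("face")  # -> 13
#   get_class_index("hair")  # -> 2
#   get_class_index("")      # -> -1 (the unnamed "Others" class cannot be selected by name)

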
def download_model(model_url: str, model_path: str) -> bool:
    """Download the human parts segmentation model if it is not already present."""
    if os.path.exists(model_path):
        return True

    try:
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        print(f"Downloading human parts model to {model_path}...")

        response = requests.get(model_url, stream=True)
        response.raise_for_status()

        total_size = int(response.headers.get('content-length', 0))
        downloaded = 0

        with open(model_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
                downloaded += len(chunk)
                if total_size > 0:
                    percent = (downloaded / total_size) * 100
                    print(f"\rDownload progress: {percent:.1f}%", end='', flush=True)

        print("\n✅ Model download completed")
        return True

    except Exception as e:
        print(f"\n❌ Error downloading model: {e}")
        return False
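# Example use of download_model (illustrative): fetch the model once at startup and stop
# early if it cannot be retrieved; raising RuntimeError is just one possible failure policy.
#   if not download_model(model_url, model_path):
#       raise RuntimeError("human parts segmentation model is unavailable")

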
def get_human_parts_mask(image: torch.Tensor, model: ort.InferenceSession, rotation: float = 0, **kwargs) -> Tuple[torch.Tensor, int]:
    """
    Generate a human parts mask with the ONNX model.

    Args:
        image: Input image tensor of shape (1, H, W, 3) with values in [0, 1]
        model: ONNX inference session
        rotation: Rotation angle in degrees, applied around the image center when non-zero
        **kwargs: Part-name flags; every part whose value is truthy is added to the mask

    Returns:
        Tuple of (mask tensor of shape (1, 1, H, W), score)
    """
    image = image.squeeze(0)
    image_np = image.numpy() * 255

    pil_image = Image.fromarray(image_np.astype(np.uint8))
    original_size = pil_image.size

    # The model expects a 512x512 input
    pil_image = pil_image.resize((512, 512))
    center = (256, 256)

    if rotation != 0:
        pil_image = pil_image.rotate(rotation, center=center)

    # Normalize to [-1, 1] and add a batch dimension
    image_np = np.array(pil_image).astype(np.float32) / 127.5 - 1
    image_np = np.expand_dims(image_np, axis=0)

    # Run inference and take the argmax over the class dimension
    input_name = model.get_inputs()[0].name
    output_name = model.get_outputs()[0].name
    result = model.run([output_name], {input_name: image_np})
    result = np.array(result[0]).argmax(axis=3).squeeze(0)

    score = 0
    mask = np.zeros_like(result)

    # Merge every enabled part into a single binary mask
    for class_name, enabled in kwargs.items():
        class_index = get_class_index(class_name)
        if enabled and class_index != -1:
            detected = result == class_index
            mask[detected] = 255
            score += mask.sum()

    # Undo the rotation and scale the mask back to the original image size
    mask_image = Image.fromarray(mask.astype(np.uint8), mode="L")
    if rotation != 0:
        mask_image = mask_image.rotate(-rotation, center=center)

    mask_image = mask_image.resize(original_size)

    mask = np.array(mask_image).astype(np.float32) / 255.0

    # Shape (1, 1, H, W) so callers can squeeze what they need
    mask = np.expand_dims(mask, axis=0)
    mask = np.expand_dims(mask, axis=0)

    return torch.from_numpy(mask), score
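# Example use of get_human_parts_mask (illustrative; "photo.jpg" is a placeholder path and
# the ONNX model is assumed to already exist at model_path):
#   session = ort.InferenceSession(model_path)
#   bgr = cv2.imread("photo.jpg")
#   tensor = numpy_to_torch_tensor(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
#   mask, score = get_human_parts_mask(tensor, session, face=True, hair=True)
#   # mask: float tensor of shape (1, 1, H, W), values in [0, 1], high where face or hair was found

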
def numpy_to_torch_tensor(image_np: np.ndarray) -> torch.Tensor:
    """Convert a uint8 image array to a float tensor in [0, 1], adding a batch dimension for (H, W, C) input."""
    if len(image_np.shape) == 3:
        return torch.from_numpy(image_np.astype(np.float32) / 255.0).unsqueeze(0)
    return torch.from_numpy(image_np.astype(np.float32) / 255.0)


def torch_tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
    """Convert a mask tensor back to a numpy array, binarizing float masks at 0.5."""
    if len(tensor.shape) == 4:
        tensor = tensor.squeeze(0)

    tensor_np = tensor.numpy()
    if tensor_np.dtype == np.float32 and tensor_np.max() <= 1.0:
        return (tensor_np > 0.5).astype(np.float32)
    else:
        return tensor_np
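# Illustrative round trip through the two helpers above (shapes are example values):
#   rgb = np.zeros((480, 640, 3), dtype=np.uint8)
#   t = numpy_to_torch_tensor(rgb)    # shape (1, 480, 640, 3), float32 in [0, 1]
#   back = torch_tensor_to_numpy(t)   # shape (480, 640, 3), binarized at 0.5

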
class HumanPartsSegmentation:
    """
    Standalone human parts segmentation (e.g. face and hair) using DeepLabV3+ ResNet50.
    """

    def __init__(self):
        self.model = None

    def check_model_cache(self):
        """Check whether the model file exists on disk."""
        if not os.path.exists(model_path):
            return False, "Model file not found"
        return True, "Model cache verified"

    def clear_model(self):
        """Release the ONNX session from memory."""
        if self.model is not None:
            del self.model
        self.model = None

    def load_model(self):
        """Load the human parts segmentation model, downloading it first if necessary."""
        try:
            cache_status, message = self.check_model_cache()
            if not cache_status:
                print(f"Cache check: {message}")
                if not download_model(model_url, model_path):
                    return False

            if self.model is None:
                print("Loading human parts segmentation model...")
                self.model = ort.InferenceSession(model_path)
                print("✅ Human parts segmentation model loaded successfully")

            return True

        except Exception as e:
            print(f"❌ Error loading human parts model: {e}")
            self.clear_model()
            return False

    def segment_parts(self, image_path: str, parts: List[str], mask_blur: int = 0, mask_offset: int = 0) -> Dict[str, np.ndarray]:
        """
        Segment specific human parts from an image.

        Args:
            image_path: Path to the image file
            parts: List of part names to segment (e.g., ['face', 'hair'])
            mask_blur: Gaussian blur radius applied to the mask edges
            mask_offset: Pixels to expand (positive) or shrink (negative) the mask boundary

        Returns:
            Dictionary mapping part names to float masks in [0, 1]
        """
        if not self.load_model():
            print("❌ Cannot load human parts segmentation model")
            return {}

        try:
            image = cv2.imread(image_path)
            if image is None:
                print(f"❌ Could not load image: {image_path}")
                return {}

            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image_tensor = numpy_to_torch_tensor(image_rgb)

            # Enable every requested part in a single combined pass
            part_kwargs = {part: True for part in parts}
            mask_tensor, score = get_human_parts_mask(image_tensor, self.model, **part_kwargs)

            if len(mask_tensor.shape) == 4:
                mask_tensor = mask_tensor.squeeze(0).squeeze(0)
            elif len(mask_tensor.shape) == 3:
                mask_tensor = mask_tensor.squeeze(0)

            combined_mask = mask_tensor.numpy()

            result_masks = {}
            if len(parts) == 1:
                # A single requested part: the combined mask is that part's mask
                part_name = parts[0]
                result_masks[part_name] = self._apply_filters(combined_mask, mask_blur, mask_offset)
            else:
                # Multiple parts: run one pass per part so each gets its own mask
                for part in parts:
                    single_part_kwargs = {part: True}
                    single_mask_tensor, _ = get_human_parts_mask(image_tensor, self.model, **single_part_kwargs)

                    if len(single_mask_tensor.shape) == 4:
                        single_mask_tensor = single_mask_tensor.squeeze(0).squeeze(0)
                    elif len(single_mask_tensor.shape) == 3:
                        single_mask_tensor = single_mask_tensor.squeeze(0)

                    single_mask = single_mask_tensor.numpy()
                    result_masks[part] = self._apply_filters(single_mask, mask_blur, mask_offset)

            return result_masks

        except Exception as e:
            print(f"❌ Error in human parts segmentation: {e}")
            return {}
        finally:
            # Release the ONNX session after each call
            self.clear_model()

    def _apply_filters(self, mask: np.ndarray, mask_blur: int = 0, mask_offset: int = 0) -> np.ndarray:
        """Apply optional edge blur and grow/shrink filtering to a mask."""
        if mask_blur == 0 and mask_offset == 0:
            return mask

        try:
            mask_image = Image.fromarray((mask * 255).astype(np.uint8))

            # Feather the mask edges
            if mask_blur > 0:
                mask_image = mask_image.filter(ImageFilter.GaussianBlur(radius=mask_blur))

            # Grow (positive offset) or shrink (negative offset) the mask by mask_offset pixels
            if mask_offset != 0:
                if mask_offset > 0:
                    mask_image = mask_image.filter(ImageFilter.MaxFilter(size=mask_offset * 2 + 1))
                else:
                    mask_image = mask_image.filter(ImageFilter.MinFilter(size=-mask_offset * 2 + 1))

            filtered_mask = np.array(mask_image).astype(np.float32) / 255.0
            return filtered_mask

        except Exception as e:
            print(f"❌ Error applying filters: {e}")
            return mask
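

if __name__ == "__main__":
    # Minimal usage sketch: "person.jpg" is a placeholder path, and the blur/offset values
    # are arbitrary examples. On first use, load_model() downloads the ONNX model into
    # models/onnx/human-parts next to this file before running segmentation.
    segmenter = HumanPartsSegmentation()
    masks = segmenter.segment_parts("person.jpg", ["face", "hair"], mask_blur=4, mask_offset=2)
    for part, part_mask in masks.items():
        out_path = f"{part}_mask.png"
        cv2.imwrite(out_path, (part_mask * 255).astype(np.uint8))
        print(f"{part}: {int((part_mask > 0).sum())} mask pixels -> {out_path}")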