|
import os |
|
import torch |
|
import torch.nn as nn |
|
import numpy as np |
|
from typing import Union, Tuple |
|
from PIL import Image, ImageFilter |
|
import cv2 |
|
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation |
|
from huggingface_hub import hf_hub_download |
|
import shutil |
|
|
|
|
|
# Inference device: the GPU when PyTorch can see one, the CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Short model key -> Hugging Face Hub repository id used when downloading.
AVAILABLE_MODELS = {"segformer_b2_clothes": "1038lab/segformer_clothes"}

# Model files are kept in a "models" folder located next to this module.
current_dir = os.path.dirname(os.path.abspath(__file__))
models_dir = os.path.join(current_dir, "models")
|
|
|
|
|
def pil2tensor(image: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a float32 tensor.

    Pixel values are divided by 255 and a new leading (batch) axis of
    size 1 is added in front of the image axes.
    """
    scaled = np.array(image).astype(np.float32) / 255.0
    return torch.from_numpy(scaled).unsqueeze(0)
|
|
|
|
|
def tensor2pil(image: torch.Tensor) -> Image.Image:
    """Convert a tensor with values in [0, 1] to a PIL image.

    Values are scaled by 255, clipped to [0, 255] and cast to uint8.
    Leading batch axes of size 1 -- such as the one added by ``pil2tensor``
    or ``mask2image`` -- are removed first, because ``Image.fromarray`` only
    understands ``(H, W)`` and ``(H, W, C)`` arrays.

    Args:
        image: Tensor shaped ``(H, W)``, ``(H, W, C)``, ``(1, H, W)`` or
            ``(1, H, W, C)``.

    Returns:
        The corresponding PIL image.
    """
    array = np.clip(255. * image.cpu().numpy(), 0, 255).astype(np.uint8)

    # Strip leading singleton axes. A 3-D array whose last axis looks like a
    # channel count (2, 3 or 4) is left alone so a genuine one-row image of
    # shape (1, W, C) keeps working exactly as before.
    while (
        array.ndim > 2
        and array.shape[0] == 1
        and (array.ndim > 3 or array.shape[-1] not in (2, 3, 4))
    ):
        array = array[0]

    return Image.fromarray(array)
|
|
|
|
|
def image2mask(image: Image.Image) -> torch.Tensor:
    """Convert an image to a 2-D mask tensor.

    A PIL image is first converted with ``pil2tensor``. For multi-channel
    input the first channel is returned; a single-channel PIL image is
    returned whole.

    Args:
        image: PIL image, or an already converted tensor whose last axis is
            the channel axis.

    Returns:
        Tensor holding the mask values.
    """
    if isinstance(image, Image.Image):
        image = pil2tensor(image)
        if image.dim() == 3:
            # A single-band PIL image becomes (1, H, W): there is no channel
            # axis to index, so `squeeze()[..., 0]` would return only the
            # first column. Return the full (H, W) plane instead.
            return image[0]
    return image.squeeze()[..., 0]
|
|
|
|
|
def mask2image(mask: torch.Tensor) -> Image.Image:
    """Convert a mask tensor to a PIL image via ``tensor2pil``.

    A 2-D mask gets a leading axis of size 1 before it is handed over;
    masks of any other rank are passed through unchanged.
    """
    is_plain_2d = mask.ndim == 2
    batched = mask.unsqueeze(0) if is_plain_2d else mask
    return tensor2pil(batched)
|
|
|
|
|
class ClothesSegmentation:
    """Standalone clothing segmentation using a Segformer model.

    The model files are downloaded from the Hugging Face Hub into
    ``cache_dir`` when they are missing, loaded lazily by ``load_model`` and
    released again at the end of every ``segment_clothes`` call.
    """

    # Files that must be present in ``cache_dir`` for the model to load.
    _MODEL_FILES = ("config.json", "model.safetensors", "preprocessor_config.json")

    def __init__(self):
        # Both are populated by load_model() and reset by clear_model().
        self.processor = None
        self.model = None
        self.cache_dir = os.path.join(models_dir, "RMBG", "segformer_clothes")

        # Label name -> class index compared against the per-pixel prediction.
        self.class_map = {
            "Background": 0, "Hat": 1, "Hair": 2, "Sunglasses": 3,
            "Upper-clothes": 4, "Skirt": 5, "Pants": 6, "Dress": 7,
            "Belt": 8, "Left-shoe": 9, "Right-shoe": 10, "Face": 11,
            "Left-leg": 12, "Right-leg": 13, "Left-arm": 14, "Right-arm": 15,
            "Bag": 16, "Scarf": 17,
        }

    def check_model_cache(self):
        """Check whether every required model file exists in ``cache_dir``.

        Returns:
            ``(ok, message)``: ``ok`` is True only when the directory and all
            required files exist; ``message`` describes the outcome.
        """
        if not os.path.exists(self.cache_dir):
            return False, "Model directory not found"

        missing_files = [
            name for name in self._MODEL_FILES
            if not os.path.exists(os.path.join(self.cache_dir, name))
        ]
        if missing_files:
            return False, f"Required model files missing: {', '.join(missing_files)}"
        return True, "Model cache verified"

    def clear_model(self):
        """Drop the model and processor and free cached CUDA memory."""
        if self.model is not None:
            # Move the weights off the GPU before dropping the reference.
            self.model.cpu()
            del self.model
        self.model = None
        self.processor = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def download_model_files(self):
        """Download the required model files from Hugging Face into ``cache_dir``.

        Returns:
            ``(ok, message)``: ``ok`` is False when any download raised, in
            which case ``message`` carries the error text.
        """
        model_id = AVAILABLE_MODELS["segformer_b2_clothes"]

        os.makedirs(self.cache_dir, exist_ok=True)
        print("Downloading Clothes Segformer model files...")

        try:
            for filename in self._MODEL_FILES:
                print(f"Downloading {filename}...")
                downloaded_path = hf_hub_download(
                    repo_id=model_id,
                    filename=filename,
                    local_dir=self.cache_dir,
                    local_dir_use_symlinks=False,
                )

                # Make sure the file ends up directly inside cache_dir.
                if os.path.dirname(downloaded_path) != self.cache_dir:
                    shutil.move(downloaded_path, os.path.join(self.cache_dir, filename))
            return True, "Model files downloaded successfully"
        except Exception as e:
            return False, f"Error downloading model files: {str(e)}"

    def load_model(self):
        """Load the processor and model, downloading the files first if needed.

        Returns:
            True when the model is ready for inference, False otherwise.
        """
        try:
            cache_status, message = self.check_model_cache()
            if not cache_status:
                print(f"Cache check: {message}")
                download_status, download_message = self.download_model_files()
                if not download_status:
                    print(f"❌ {download_message}")
                    return False

            if self.processor is None:
                print("Loading clothes segmentation model...")
                self.processor = SegformerImageProcessor.from_pretrained(self.cache_dir)
                self.model = AutoModelForSemanticSegmentation.from_pretrained(self.cache_dir)
                # Inference only: eval mode and no gradient tracking.
                self.model.eval()
                for param in self.model.parameters():
                    param.requires_grad = False
                self.model.to(device)
                print("✅ Clothes segmentation model loaded successfully")

            return True

        except Exception as e:
            print(f"❌ Error loading clothes model: {e}")
            self.clear_model()
            return False

    def segment_clothes(self, image_path: str, target_classes: Union[list, None] = None,
                        process_res: int = 512) -> Union[np.ndarray, None]:
        """Segment the requested clothing classes in an image.

        Args:
            image_path: Path to the image.
            target_classes: Names from ``class_map`` to include in the mask
                (default: ``["Upper-clothes"]``). Unknown names are reported
                and skipped.
            process_res: When different from 512, the image is resized to a
                ``process_res`` x ``process_res`` square before it is handed
                to the processor.

        Returns:
            float32 array with the original image height and width, 1.0 where
            any target class was predicted and 0.0 elsewhere, or None when
            the classes, the model or the image could not be used.
        """
        if target_classes is None:
            target_classes = ["Upper-clothes"]

        try:
            # Resolve the class names first: a request without a single valid
            # class must not pay for a model download, load and inference.
            class_ids = []
            for class_name in target_classes:
                if class_name in self.class_map:
                    class_ids.append(self.class_map[class_name])
                else:
                    print(f"⚠️ Unknown class: {class_name}")

            if not class_ids:
                print(f"❌ No valid classes found in: {target_classes}")
                return None

            if not self.load_model():
                print("❌ Cannot load clothes segmentation model")
                return None

            image = cv2.imread(image_path)
            if image is None:
                print(f"❌ Could not load image: {image_path}")
                return None

            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            original_size = image_rgb.shape[:2]  # (height, width)

            pil_image = Image.fromarray(image_rgb)
            if process_res != 512:
                pil_image = pil_image.resize((process_res, process_res), Image.Resampling.LANCZOS)

            inputs = self.processor(images=pil_image, return_tensors="pt")
            inputs = {k: v.to(device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self.model(**inputs)
                logits = outputs.logits.cpu()

            # Bring the logits back to the source resolution before taking
            # the per-pixel arg-max class.
            upsampled_logits = nn.functional.interpolate(
                logits,
                size=original_size,
                mode="bilinear",
                align_corners=False,
            )
            pred_seg = upsampled_logits.argmax(dim=1)[0]

            # Union of all requested classes.
            combined_mask = torch.zeros_like(pred_seg, dtype=torch.bool)
            for class_id in class_ids:
                combined_mask |= pred_seg == class_id

            return combined_mask.numpy().astype(np.float32)

        except Exception as e:
            print(f"❌ Error in clothes segmentation: {e}")
            return None
        finally:
            # Release the model after each call (it is in eval mode once loaded).
            if self.model is not None and not self.model.training:
                self.clear_model()

    def segment_clothes_with_filters(self, image_path: str, target_classes: Union[list, None] = None,
                                     mask_blur: int = 0, mask_offset: int = 0,
                                     process_res: int = 512) -> Union[np.ndarray, None]:
        """Segment clothing and post-process the mask.

        Args:
            image_path: Path to the image.
            target_classes: List of clothing classes to segment.
            mask_blur: Gaussian blur radius applied to the mask when > 0.
            mask_offset: Positive values grow the mask (max filter), negative
                values shrink it (min filter); 0 leaves it unchanged.
            process_res: Processing resolution passed to ``segment_clothes``.

        Returns:
            float32 mask with values in [0, 1], the unfiltered mask when a
            filter raised, or None when segmentation itself failed.
        """
        mask = self.segment_clothes(image_path, target_classes, process_res)
        if mask is None:
            return None

        try:
            mask_image = Image.fromarray((mask * 255).astype(np.uint8))

            if mask_blur > 0:
                mask_image = mask_image.filter(ImageFilter.GaussianBlur(radius=mask_blur))

            # Kernel size 2*|offset|+1 is always odd, as the rank filters require.
            if mask_offset > 0:
                mask_image = mask_image.filter(ImageFilter.MaxFilter(size=mask_offset * 2 + 1))
            elif mask_offset < 0:
                mask_image = mask_image.filter(ImageFilter.MinFilter(size=-mask_offset * 2 + 1))

            return np.array(mask_image).astype(np.float32) / 255.0

        except Exception as e:
            print(f"❌ Error applying filters: {e}")
            # Best effort: fall back to the unfiltered mask.
            return mask
|
|
|
|
|
|
|
def segment_upper_clothes(image_path: str) -> np.ndarray:
    """Segment only the "Upper-clothes" class of an image.

    Builds a fresh ``ClothesSegmentation`` instance for the call.

    Args:
        image_path: Path to the image

    Returns:
        Binary mask as numpy array or None if failed
    """
    wanted_classes = ["Upper-clothes"]
    return ClothesSegmentation().segment_clothes(image_path, wanted_classes)