Saad0KH's picture
Update SegCloth.py
871269e verified
raw
history blame
1.47 kB
from transformers import pipeline
from PIL import Image
import numpy as np
from io import BytesIO
import io
import base64
# Build the clothing-segmentation pipeline once, at import time, so that every
# call to segment_clothing() reuses the same pipeline object instead of
# constructing a new one per request.
_CLOTHES_MODEL_ID = "mattmdjaga/segformer_b2_clothes"
segmenter = pipeline(model=_CLOTHES_MODEL_ID)
def encode_image_to_base64(image):
    """Serialize *image* as PNG and return those bytes as a base64 text string.

    *image* only needs a ``save(fp, format=...)`` method (for example a PIL
    image); it is written into an in-memory buffer, never to disk.
    """
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode('utf-8')
# Garment labels extracted when the caller does not pass its own list. A tuple
# (not a list) so the default can never be mutated and shared between calls.
DEFAULT_CLOTHES = (
    "Hat",
    "Upper-clothes",
    "Skirt",
    "Pants",
    "Dress",
    "Belt",
    "Left-shoe",
    "Right-shoe",
    "Scarf",
)


def _mask_to_alpha(mask):
    """Convert a segment mask into a 0/255 8-bit single-channel alpha image."""
    # Masks from the transformers image-segmentation pipeline are already 0/255
    # "L" images. Multiplying that uint8 data by 255 wraps around
    # (255 * 255 % 256 == 1), leaving an alpha of 1 that makes the pasted
    # garment nearly invisible. Binarize first so both 0/1 and 0/255 masks
    # map to exactly 0/255 without overflow.
    opaque = np.asarray(mask) > 0
    return Image.fromarray(opaque.astype(np.uint8) * 255)


def segment_clothing(img, clothes=DEFAULT_CLOTHES):
    """Cut each requested clothing item out of *img* as its own transparent PNG.

    Runs the module-level ``segmenter`` pipeline on *img*. For every returned
    segment whose label is in *clothes*, builds an RGBA image the same size
    as *img* that is transparent everywhere except that segment's pixels.

    Args:
        img: Input picture. Its ``.size`` is read and it is pasted with
            ``Image.paste``, so it is presumably a ``PIL.Image.Image`` —
            confirm against callers.
        clothes: Labels to keep; any container supporting ``in``. Defaults
            to ``DEFAULT_CLOTHES``.

    Returns:
        A list of ``{"type": label, "image_base64": png_as_base64}`` dicts,
        one per matching segment, in the order the pipeline returned them.
    """
    results = []
    for segment in segmenter(img):
        label = segment['label']
        if label not in clothes:
            continue
        alpha = _mask_to_alpha(segment['mask'])
        # Fully transparent canvas; img is copied in only where alpha is 255.
        clothing_image = Image.new("RGBA", img.size, (0, 0, 0, 0))
        clothing_image.paste(img, mask=alpha)
        results.append({
            "type": label,
            "image_base64": encode_image_to_base64(clothing_image),
        })
    return results