Saad0KH commited on
Commit
187d444
Β·
verified Β·
1 Parent(s): 0703834

Update SegCloth.py

Browse files
Files changed (1) hide show
  1. SegCloth.py +22 -15
SegCloth.py CHANGED
@@ -13,26 +13,33 @@ def encode_image_to_base64(image):
13
  image.save(buffered, format="PNG")
14
  return base64.b64encode(buffered.getvalue()).decode('utf-8')
15
 
16
- def segment_clothing(img, clothes= ["Hat", "Upper-clothes", "Skirt", "Pants", "Dress", "Belt", "Left-shoe", "Right-shoe", "Scarf"]):
17
  # Segment image
18
  segments = segmenter(img)
19
 
20
- # Create list of masks
21
- mask_list = []
22
- for s in segments:
23
- if(s['label'] in clothes):
24
- mask_list.append(s['mask'])
25
 
 
26
  result_images = []
27
-
28
- # Paste all masks on top of eachother
29
- final_mask = np.array(mask_list[0])
30
- for mask in mask_list:
31
- current_mask = np.array(mask)
32
- final_mask_bis = Image.fromarray(current_mask)
33
- img.putalpha(final_mask_bis)
34
- imageBase64 = encode_image_to_base64(img)
35
- result_images.append(('clothing_type', imageBase64))
 
 
 
 
 
 
 
 
 
36
 
37
 
38
 
 
13
  image.save(buffered, format="PNG")
14
  return base64.b64encode(buffered.getvalue()).decode('utf-8')
15
 
16
+ def segment_clothing(img, clothes=["Hat", "Upper-clothes", "Skirt", "Pants", "Dress", "Scarf"]):
17
  # Segment image
18
  segments = segmenter(img)
19
 
20
+ # Convert image to RGBA
21
+ img = img.convert("RGBA")
 
 
 
22
 
23
+ # Create list of masks
24
  result_images = []
25
+ for s in segments:
26
+ if s['label'] in clothes:
27
+ # Extract mask and resize image to mask size
28
+ current_mask = np.array(s['mask'])
29
+ mask_size = current_mask.shape[::-1] # Mask size is (width, height)
30
+
31
+ # Resize the original image to match the mask size
32
+ resized_img = img.resize(mask_size)
33
+
34
+ # Apply mask to resized image
35
+ final_mask = Image.fromarray(current_mask)
36
+ resized_img.putalpha(final_mask)
37
+
38
+ # Convert the final image to base64
39
+ imageBase64 = encode_image_to_base64(resized_img)
40
+ result_images.append((s['label'], imageBase64))
41
+
42
+ return result_images
43
 
44
 
45