groueix commited on
Commit
ba29ad3
·
1 Parent(s): a7c5a97

added white background

Browse files
Files changed (1) hide show
  1. copaint/gradio_ui.py +18 -2
copaint/gradio_ui.py CHANGED
@@ -24,6 +24,18 @@ logger = logging.getLogger(__name__)
24
  fromPIltoTensor = torchvision.transforms.ToTensor()
25
  fromTensortoPIL = torchvision.transforms.ToPILImage()
26
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  def add_grid_to_image(image, h_cells, w_cells):
29
  # image is a torch tensor of shape (3, h, w)
@@ -177,8 +189,12 @@ def build_gradio_ui():
177
 
178
  # Upload Design Template
179
  with gr.Column(scale=2):
180
- input_image = gr.Image(type="pil", label="Upload Your Design")
181
-
 
 
 
 
182
  with gr.Column(scale=1):
183
  # Grid
184
  with gr.Tab("Grid Layout"):
 
24
  fromPIltoTensor = torchvision.transforms.ToTensor()
25
  fromTensortoPIL = torchvision.transforms.ToPILImage()
26
 
27
+ from PIL import Image
28
+
29
+ def remove_transparency(image: Image.Image):
30
+ # Convert transparency to white background
31
+ if image.mode in ("RGBA", "LA"):
32
+ background = Image.new("RGB", image.size, (255, 255, 255)) # white background
33
+ background.paste(image, mask=image.split()[-1]) # paste using alpha channel as mask
34
+ image = background
35
+ else:
36
+ image = image.convert("RGB") # just to be safe
37
+
38
+ return image # or continue processing
39
 
40
  def add_grid_to_image(image, h_cells, w_cells):
41
  # image is a torch tensor of shape (3, h, w)
 
189
 
190
  # Upload Design Template
191
  with gr.Column(scale=2):
192
+ input_image = gr.Image(type="pil", label="Upload Your Design", preprocess=handle_upload)
193
+ image_input.upload(
194
+ fn=remove_transparency,
195
+ inputs=image_input,
196
+ outputs=image_input
197
+ )
198
  with gr.Column(scale=1):
199
  # Grid
200
  with gr.Tab("Grid Layout"):