File size: 2,783 Bytes
e546fea
 
 
 
 
 
 
 
f64024f
958511f
e546fea
f64024f
958511f
 
 
 
b218be6
f64024f
958511f
 
 
 
 
 
 
e546fea
f64024f
e546fea
958511f
 
3e75999
b218be6
 
2d64873
f64024f
2d64873
 
 
958511f
e546fea
 
 
 
 
3e75999
 
2d64873
b218be6
f64024f
2d64873
b218be6
2d64873
 
 
 
 
e546fea
f64024f
b218be6
 
 
 
 
 
20a2fe0
b218be6
b2b24c7
b218be6
958511f
f64024f
b218be6
 
 
e546fea
f64024f
958511f
b218be6
958511f
e546fea
fd00e11
e546fea
fd00e11
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os

import gradio as gr
from gradio_imageslider import ImageSlider
from loadimg import load_img
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms

# Float32 matmul precision: the first entry ("high") is the one in effect;
# the second is kept alongside it as the alternative setting.
_MATMUL_PRECISIONS = ("high", "highest")
torch.set_float32_matmul_precision(_MATMUL_PRECISIONS[0])

# BiRefNet segmentation model, fetched from the Hugging Face Hub. The
# repository ships its own model code, hence trust_remote_code=True.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
)
birefnet.to("cuda")

# Preprocessing applied to every input image before it reaches the model:
# resize to a fixed 1024x1024, convert to a tensor, then normalise each
# channel (these values match the commonly used ImageNet statistics).
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(_NORMALIZE_MEAN, _NORMALIZE_STD),
    ]
)

# Main function to handle image processing
def fn(image):
    """Remove the background from ``image`` and return an after/before pair.

    Args:
        image: Anything ``load_img`` can turn into a PIL image. In this app
            it is the value of the image-upload component or the URL typed
            into the text box.

    Returns:
        A ``(processed_image, origin)`` tuple: the image with the predicted
        alpha channel applied, and an untouched RGB copy of the input, in
        the order expected by the image slider outputs.
    """
    source = load_img(image, output_type="pil").convert("RGB")
    # Snapshot the input first: ``process`` attaches the alpha channel to
    # the very image object it receives.
    untouched = source.copy()
    return (process(source), untouched)

# Process function that runs on GPU
@spaces.GPU
def process(image):
    """Predict a foreground mask for ``image`` and attach it as alpha.

    Args:
        image: PIL image to segment. It is modified in place: the predicted
            mask is installed as its alpha channel.

    Returns:
        The same ``image`` object, now carrying the alpha channel.
    """
    original_size = image.size
    # Preprocess, add a leading batch dimension and move to the GPU.
    batch = transform_image(image).unsqueeze(0).to("cuda")

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        last_output = birefnet(batch)[-1]
        probabilities = last_output.sigmoid().cpu()

    # Take the single item of the batch and convert it to a PIL image.
    alpha = transforms.ToPILImage()(probabilities[0].squeeze())
    # The model ran on the resized input, so bring the mask back to the
    # original image size before using it as the alpha channel.
    image.putalpha(alpha.resize(original_size))
    return image

# Process function for file output
def process_file(f):
    """Remove the background from the image file ``f`` and save it as PNG.

    The result is written next to the input, under the same name with the
    file extension replaced by ``.png`` (or with ``.png`` appended when the
    file name has no extension). If ``f`` already ends in ``.png`` the
    output path equals the input path and that file is overwritten.

    Args:
        f: Path of the input image file, as a string.

    Returns:
        The path of the PNG file that was written.
    """
    # Only the final path component may contribute an extension. Splitting
    # the whole string on its last "." would cut the path short whenever a
    # directory name contains a dot and the file itself has no extension
    # (e.g. "/tmp/run.1/img" would become "/tmp/run.png").
    name_path = os.path.splitext(f)[0] + ".png"
    im = load_img(f, output_type="pil")
    im = im.convert("RGB")
    transparent = process(im)
    transparent.save(name_path)
    return name_path

# Input and output widgets used by the three tabs below
slider1 = ImageSlider(label="Processed Image", type="pil")
slider2 = ImageSlider(label="Processed Image from URL", type="pil")
image_upload = gr.Image(label="Upload an image")
image_file_upload = gr.Image(label="Upload an image", type="filepath")
url_input = gr.Textbox(label="Paste an image URL")
output_file = gr.File(label="Output PNG File")

# Example inputs offered in the tabs (the local example file is
# butterfly.jpg, whatever the variable name suggests)
chameleon = load_img("butterfly.jpg", output_type="pil")
url_example = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"

# One interface per tab: uploaded image, image URL, and PNG file download
tab1 = gr.Interface(
    fn=fn,
    inputs=image_upload,
    outputs=slider1,
    examples=[chameleon],
    api_name="image",
)
tab2 = gr.Interface(
    fn=fn,
    inputs=url_input,
    outputs=slider2,
    examples=[url_example],
    api_name="text",
)
tab3 = gr.Interface(
    fn=process_file,
    inputs=image_file_upload,
    outputs=output_file,
    examples=["butterfly.jpg"],
    api_name="png",
)

# Combine the three interfaces into a single tabbed app
demo = gr.TabbedInterface(
    interface_list=[tab1, tab2, tab3],
    tab_names=["Image Upload", "URL Input", "File Output"],
    title="Background Removal Tool",
)

# Start the app only when run as a script; errors are shown in the UI
if __name__ == "__main__":
    demo.launch(show_error=True)