File size: 2,689 Bytes
1fae4f7
737f1ab
 
 
0807e33
1fae4f7
737f1ab
 
1fae4f7
737f1ab
 
 
 
 
 
 
 
 
1fae4f7
737f1ab
 
 
 
 
 
 
 
 
1fae4f7
737f1ab
 
 
 
 
0807e33
 
737f1ab
0807e33
 
 
 
 
 
737f1ab
1fae4f7
737f1ab
 
 
1fae4f7
3f6deb5
 
 
 
737f1ab
 
3f6deb5
 
737f1ab
 
 
1fae4f7
3f6deb5
 
 
 
737f1ab
 
3f6deb5
 
737f1ab
 
 
1fae4f7
3f6deb5
 
 
737f1ab
 
3f6deb5
 
737f1ab
 
 
1fae4f7
3f6deb5
737f1ab
 
 
 
 
1fae4f7
 
737f1ab
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import gradio as gr
from gradio_client import Client, handle_file
from PIL import Image
import requests
import io

# Hugging Face Space that performs the actual background removal. The Client
# is created once, at import time, and shared by every handler below.
_SPACE_ID = "not-lain/background-removal"
client = Client(_SPACE_ID)

def _write_temp_png(image):
    """Save a PIL image or NumPy array to a new temporary PNG file and return its path."""
    import tempfile

    if not isinstance(image, Image.Image):
        # gr.Image with its default type delivers a NumPy array.
        image = Image.fromarray(image)
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        image.save(tmp, format="PNG")
        return tmp.name


def process_image_via_api(image):
    """Remove the background from an uploaded image using the remote Space.

    Args:
        image: The uploaded image. A path/URL string (or path-like) is passed
            to ``handle_file`` directly; a PIL image or NumPy array is first
            written to a temporary PNG, since ``handle_file`` only accepts
            paths or URLs. ``None`` (nothing uploaded) short-circuits.

    Returns:
        A ``(processed, original)`` pair of PIL images opened from the paths
        the API returned, or ``(None, None)`` when there is no input or the
        API returned an empty result.
    """
    import os

    if image is None:
        return None, None

    temp_path = None
    if not isinstance(image, (str, os.PathLike)):
        temp_path = _write_temp_png(image)
        image = temp_path
    try:
        result = client.predict(
            image=handle_file(image),
            api_name="/image",
        )
    finally:
        # The upload is done once predict() returns (or raises); drop our copy.
        if temp_path is not None:
            os.remove(temp_path)

    if result:
        return Image.open(result[0]), Image.open(result[1])
    return None, None

def process_url_via_api(url):
    """Remove the background from an image referenced by URL using the remote Space.

    Args:
        url: Image address as typed into the textbox. Surrounding whitespace
            is stripped before the request is made.

    Returns:
        A ``(processed, original)`` pair of PIL images opened from the paths
        the API returned, or ``(None, None)`` when the URL is missing/blank or
        the API returned an empty result.
    """
    # Without this check an empty textbox would be forwarded to the Space as-is.
    url = (url or "").strip()
    if not url:
        return None, None

    result = client.predict(
        image=url,
        api_name="/text",
    )
    if result:
        return Image.open(result[0]), Image.open(result[1])
    return None, None

def process_file_via_api(f):
    """Remove the background from an image file and return the path of a PNG result.

    Args:
        f: Path to the uploaded image (the input component uses
            ``type="filepath"``), or ``None`` when nothing was uploaded.

    Returns:
        Path to the PNG result, or ``None`` when there is no input or the API
        response is neither a path string nor raw bytes.
    """
    if f is None:
        return None

    result = client.predict(
        f=handle_file(f),
        api_name="/png",
    )
    if isinstance(result, str):
        # A string result is taken to be the path of the returned file.
        return result
    if isinstance(result, bytes):
        import tempfile

        # Write each result to its own file: a shared fixed name such as
        # "output.png" would be overwritten by concurrent requests.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            Image.open(io.BytesIO(result)).save(tmp, format="PNG")
        return tmp.name
    return None

# Sample inputs listed under each tab.
# NOTE(review): `chameleon` actually points at a butterfly photo; the name is
# kept because the tab definitions below reference it.
chameleon = "butterfly.jpg"  # relative path, resolved against the working directory
url_example = (
    "https://hips.hearstapps.com/hmg-prod/images/"
    "gettyimages-1229892983-square.jpg"
)

# Tab 1: Image Upload
# The two outputs are named sliders_* but are plain gr.Image components;
# the first shows the processed result, the second the original.
sliders_processed_tab1 = gr.Image(label="Processed Image")
sliders_origin_tab1 = gr.Image(label="Original Image")
# type="filepath" hands process_image_via_api a file path, which handle_file()
# can upload (this matches the Tab 3 input). The default gr.Image type passes
# a NumPy array instead.
image_upload_tab1 = gr.Image(label="Upload an image", type="filepath")
tab1 = gr.Interface(
    fn=process_image_via_api,
    inputs=image_upload_tab1,
    outputs=[sliders_processed_tab1, sliders_origin_tab1],
    examples=[chameleon],
    api_name="/image_api",
)

# Tab 2: URL Input
# The user pastes an image address instead of uploading a file. Outputs are
# plain gr.Image components despite their sliders_* names: the processed
# image first, then the original.
sliders_processed_tab2 = gr.Image(label="Processed Image")
sliders_origin_tab2 = gr.Image(label="Original Image")
url_input_tab2 = gr.Textbox(label="Paste an image URL")
tab2 = gr.Interface(
    fn=process_url_via_api,
    inputs=url_input_tab2,
    outputs=[
        sliders_processed_tab2,
        sliders_origin_tab2,
    ],
    examples=[url_example],
    api_name="/url_api",
)

# Tab 3: File Output
# The upload is delivered as a file path and the result is offered as a
# downloadable PNG file rather than displayed inline.
output_file_tab3 = gr.File(label="Output PNG File")
image_file_upload_tab3 = gr.Image(
    label="Upload an image",
    type="filepath",
)
tab3 = gr.Interface(
    fn=process_file_via_api,
    inputs=image_file_upload_tab3,
    outputs=output_file_tab3,
    examples=[chameleon],
    api_name="/png_api",
)

# Assemble the three interfaces into one tabbed app. Tab titles are matched
# to interfaces by position.
demo = gr.TabbedInterface(
    interface_list=[tab1, tab2, tab3],
    tab_names=["Image Upload", "URL Input", "File Output"],
    title="Background Removal Tool using API",
)

if __name__ == "__main__":
    # show_error=True surfaces handler exceptions in the browser UI.
    demo.launch(show_error=True)