File size: 4,431 Bytes
f1fb6bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5661265
f1fb6bb
 
 
 
 
 
 
 
 
 
5661265
 
f1fb6bb
5661265
f1fb6bb
 
5661265
f1fb6bb
 
 
 
5661265
f1fb6bb
 
 
 
 
 
 
 
5661265
f1fb6bb
 
 
 
 
 
 
 
 
 
 
 
 
 
5661265
f1fb6bb
 
 
 
 
 
 
 
 
 
 
5661265
 
f1fb6bb
5661265
f1fb6bb
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import torch
import gradio as gr
from PIL import Image
from rembg import remove
from diffusers import StableDiffusionPipeline

# -----------------------------------------------------------------------------
# Helper function to adjust image size to multiples of 8.
# -----------------------------------------------------------------------------
def adjust_size(w, h):
    """
    Round width and height down to multiples of 8, as required by the Stable Diffusion model.

    Each dimension is floored to the nearest multiple of 8 but never below 8:
    a plain floor would turn any dimension smaller than 8 into 0, which is not
    a usable image size.

    Parameters:
        w (int): Source width in pixels.
        h (int): Source height in pixels.

    Returns:
        tuple: ``(new_w, new_h)``, each a multiple of 8 and at least 8.
    """
    # max(8, ...) guards the tiny-image case; inputs >= 8 are unaffected.
    new_w = max(8, (w // 8) * 8)
    new_h = max(8, (h // 8) * 8)
    return new_w, new_h

# -----------------------------------------------------------------------------
# Core processing function:
#   1. Remove background from the uploaded image.
#   2. Generate a new background image based on the text prompt.
#   3. Composite the foreground onto the generated background.
# -----------------------------------------------------------------------------
def process_image(input_image: Image.Image, bg_prompt: str) -> Image.Image:
    """
    Processes the uploaded image by removing its background and replacing it with a generated one.

    The subject is cut out with ``remove`` (rembg), a new background is
    generated from ``bg_prompt`` by the module-level ``pipe`` (looked up when
    the function is called), and the cut-out is alpha-composited on top.

    Parameters:
        input_image (PIL.Image.Image): The uploaded image.
        bg_prompt (str): Text prompt describing the new background.

    Returns:
        PIL.Image.Image: The final composited RGBA image. Its size is that of
        the generated background, i.e. the input size after ``adjust_size``.

    Raises:
        ValueError: If ``input_image`` is None.
    """
    if input_image is None:
        raise ValueError("No image provided.")

    # Step 1: Remove the background from the input image.
    print("Removing background from the uploaded image...")
    # RGBA is required later by Image.alpha_composite.
    foreground = remove(input_image).convert("RGBA")

    # Step 2: Adjust dimensions for background generation.
    orig_w, orig_h = foreground.size
    gen_w, gen_h = adjust_size(orig_w, orig_h)
    print(f"Original size: {orig_w}x{orig_h} | Adjusted size: {gen_w}x{gen_h}")

    # Step 3: Generate a new background using the provided text prompt.
    print("Generating new background using Stable Diffusion...")
    bg_output = pipe(
        bg_prompt,
        height=gen_h,
        width=gen_w,
        num_inference_steps=50,  # Adjust if needed.
        guidance_scale=7.5,      # Adjust for more/less prompt adherence.
    )
    # Convert the generated background to RGBA.
    background = bg_output.images[0].convert("RGBA")

    # Step 4: Ensure the foreground matches the background dimensions.
    # alpha_composite needs both images to be the same size.
    if foreground.size != background.size:
        print("Resizing foreground to match background dimensions...")
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
        # under its current name and also exists in older Pillow releases.
        foreground = foreground.resize(background.size, Image.LANCZOS)

    # Step 5: Composite the images.
    print("Compositing images...")
    final_image = Image.alpha_composite(background, foreground)

    return final_image

# -----------------------------------------------------------------------------
# Load the Stable Diffusion pipeline from Hugging Face.
# -----------------------------------------------------------------------------
MODEL_ID = "stabilityai/stable-diffusion-2"  # Swap in another checkpoint here if desired.

# Pick the tensor dtype: float16 when CUDA is available, float32 otherwise.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

print("Loading Stable Diffusion pipeline...")
pipe = StableDiffusionPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=torch_dtype,
)
# Move the pipeline to the GPU when one is present; otherwise leave it as loaded.
if torch.cuda.is_available():
    pipe = pipe.to("cuda")
print("Stable Diffusion pipeline loaded.")

# -----------------------------------------------------------------------------
# Create the Gradio Interface using the updated API.
# -----------------------------------------------------------------------------
# Text shown at the top of the web page.
title = "Background Removal & Replacement"
description = (
    "Upload an image (e.g., a person or an animal) and provide a text prompt describing "
    "the new background. The app will remove the original background and composite the "
    "subject onto a generated background."
)

# Two inputs (image + prompt) feed process_image; its return value is shown as an image.
iface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(label="Upload Your Image", type="pil"),
        gr.Textbox(
            label="Background Prompt",
            placeholder="Describe the new background...",
            lines=2,
        ),
    ],
    outputs=gr.Image(label="Output Image"),
    title=title,
    description=description,
    allow_flagging="never",
)

# -----------------------------------------------------------------------------
# Launch the app.
# -----------------------------------------------------------------------------
# Start the Gradio app only when this file is run as a script; importing the
# module builds `iface` without launching it.
if __name__ == "__main__":
    iface.launch()