File size: 4,527 Bytes
f1fb6bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import torch
import gradio as gr
from PIL import Image
from rembg import remove
from diffusers import StableDiffusionPipeline

# -----------------------------------------------------------------------------
# Helper function to adjust image size to multiples of 8.
# -----------------------------------------------------------------------------
def adjust_size(w, h):
    """
    Round width and height down to multiples of 8, as required by the Stable Diffusion model.

    Each dimension is clamped to a minimum of 8 so that an input smaller than
    8 pixels on a side never produces a zero-sized generation target.

    Parameters:
        w (int): Source width in pixels.
        h (int): Source height in pixels.

    Returns:
        tuple: (new_w, new_h), each a multiple of 8 and at least 8.
    """
    # Floor to the nearest multiple of 8; without the clamp, any value below 8
    # would floor to 0.
    new_w = max(8, (w // 8) * 8)
    new_h = max(8, (h // 8) * 8)
    return new_w, new_h

# -----------------------------------------------------------------------------
# Core processing function:
#   1. Remove background from the uploaded image.
#   2. Generate a new background image based on the text prompt.
#   3. Composite the foreground onto the generated background.
# -----------------------------------------------------------------------------
def process_image(input_image: Image.Image, bg_prompt: str) -> Image.Image:
    """
    Processes the uploaded image by removing its background and replacing it with a generated one.

    The background is generated by the module-level Stable Diffusion pipeline
    (`pipe`) at the foreground's size rounded down to multiples of 8; the
    foreground is resized to that size before compositing if the two differ.

    Parameters:
        input_image (PIL.Image.Image): The uploaded image.
        bg_prompt (str): Text prompt describing the new background.

    Returns:
        PIL.Image.Image: The final composited image (RGBA).

    Raises:
        ValueError: If no image was provided.
    """
    if input_image is None:
        raise ValueError("No image provided.")

    # Step 1: Remove the background from the input image.
    print("Removing background from the uploaded image...")
    foreground = remove(input_image)
    # alpha_composite below needs both layers in RGBA mode.
    foreground = foreground.convert("RGBA")

    # Step 2: Determine new dimensions (multiples of 8) based on the foreground.
    orig_w, orig_h = foreground.size
    gen_w, gen_h = adjust_size(orig_w, orig_h)
    print(f"Original size: {orig_w}x{orig_h} | Adjusted size: {gen_w}x{gen_h}")

    # Step 3: Generate a new background using the provided text prompt.
    print("Generating new background using Stable Diffusion...")
    bg_output = pipe(
        bg_prompt,
        height=gen_h,
        width=gen_w,
        num_inference_steps=50,  # Adjust as needed.
        guidance_scale=7.5       # Adjust for prompt adherence.
    )
    # The generated background is in RGB mode; convert to RGBA for compositing.
    background = bg_output.images[0].convert("RGBA")

    # Step 4: If necessary, resize the foreground to match the background.
    # alpha_composite requires both images to have identical dimensions.
    if foreground.size != background.size:
        print("Resizing foreground to match background dimensions...")
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
        # and is available in both older and newer Pillow releases.
        foreground = foreground.resize(background.size, Image.LANCZOS)

    # Step 5: Composite the foreground over the new background.
    print("Compositing images...")
    final_image = Image.alpha_composite(background, foreground)

    return final_image

# -----------------------------------------------------------------------------
# Load the Stable Diffusion pipeline from Hugging Face.
# -----------------------------------------------------------------------------
# Checkpoint identifier passed to from_pretrained; swap in another model if desired.
MODEL_ID = "stabilityai/stable-diffusion-2"

# Half precision when a CUDA GPU is available, full precision otherwise.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

print("Loading Stable Diffusion pipeline...")
pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=torch_dtype)
# Move the pipeline onto the GPU when one is present; otherwise leave it where it was loaded.
if torch.cuda.is_available():
    pipe = pipe.to("cuda")
print("Stable Diffusion pipeline loaded.")

# -----------------------------------------------------------------------------
# Create the Gradio Interface.
# -----------------------------------------------------------------------------
# Text shown at the top of the web UI.
title = "Background Removal & Replacement"
description = (
    "Upload an image (e.g., a person or an animal) and provide a text prompt "
    "describing the new background. The app will remove the original background and "
    "composite the subject onto a generated background."
)

# Components come from the top-level `gr` namespace: the legacy `gr.inputs` /
# `gr.outputs` modules were deprecated in Gradio 3 and removed in Gradio 4.
# The inputs map positionally onto process_image(input_image, bg_prompt).
iface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Upload Your Image"),
        gr.Textbox(lines=2, placeholder="Describe the new background...", label="Background Prompt"),
    ],
    outputs=gr.Image(type="pil", label="Output Image"),
    title=title,
    description=description,
    allow_flagging="never",
)

# -----------------------------------------------------------------------------
# Launch the app.
# -----------------------------------------------------------------------------
# Launch the Gradio interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    iface.launch()