|
import torch |
|
import gradio as gr |
|
from PIL import Image |
|
from rembg import remove |
|
from diffusers import StableDiffusionPipeline |
|
|
|
|
|
|
|
|
|
def adjust_size(w, h):
    """
    Round width and height down to multiples of 8, as required by the Stable Diffusion model.

    A dimension smaller than 8 would round down to 0, which is not a usable
    image size, so each result is clamped to a minimum of 8.

    Parameters:
        w (int): Requested width in pixels.
        h (int): Requested height in pixels.

    Returns:
        tuple: ``(new_w, new_h)``, each a multiple of 8 and never below 8.
    """
    # Floor to the nearest multiple of 8, but never let a tiny input collapse to 0.
    new_w = max(8, (w // 8) * 8)
    new_h = max(8, (h // 8) * 8)
    return new_w, new_h
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def process_image(input_image: Image.Image, bg_prompt: str) -> Image.Image:
    """
    Processes the uploaded image by removing its background and replacing it with a generated one.

    The subject is cut out with ``rembg.remove``, a new background is generated
    from ``bg_prompt`` by the module-level Stable Diffusion pipeline ``pipe``,
    and the cut-out is alpha-composited on top of that background.

    Parameters:
        input_image (PIL.Image.Image): The uploaded image.
        bg_prompt (str): Text prompt describing the new background.

    Returns:
        PIL.Image.Image: The final composited RGBA image, at the size of the
        generated background.

    Raises:
        ValueError: If ``input_image`` is None.
    """
    if input_image is None:
        raise ValueError("No image provided.")

    print("Removing background from the uploaded image...")
    # RGBA is needed so the cut-out carries the transparency used for compositing.
    foreground = remove(input_image).convert("RGBA")

    orig_w, orig_h = foreground.size
    gen_w, gen_h = adjust_size(orig_w, orig_h)
    print(f"Original size: {orig_w}x{orig_h} | Adjusted size: {gen_w}x{gen_h}")

    print("Generating new background using Stable Diffusion...")
    bg_output = pipe(
        bg_prompt,
        height=gen_h,
        width=gen_w,
        num_inference_steps=50,
        guidance_scale=7.5,
    )
    background = bg_output.images[0].convert("RGBA")

    # alpha_composite requires both images to have identical dimensions; the
    # generated background may differ from the original after size adjustment.
    if foreground.size != background.size:
        print("Resizing foreground to match background dimensions...")
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
        # under the name that is still available.
        foreground = foreground.resize(background.size, Image.LANCZOS)

    print("Compositing images...")
    final_image = Image.alpha_composite(background, foreground)

    return final_image
|
|
|
|
|
|
|
|
|
# Identifier of the pretrained Stable Diffusion checkpoint used for backgrounds.
MODEL_ID = "stabilityai/stable-diffusion-2"

# Half precision when CUDA is available, full precision otherwise.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

print("Loading Stable Diffusion pipeline...")
pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=torch_dtype)
if torch.cuda.is_available():
    # Move the pipeline to the GPU when one is present.
    pipe = pipe.to("cuda")
print("Stable Diffusion pipeline loaded.")
|
|
|
|
|
|
|
|
|
# Text shown at the top of the Gradio page.
title = "Background Removal & Replacement"
description = (
    "Upload an image (e.g., a person or an animal) and provide a text prompt describing "
    "the new background. The app will remove the original background and composite "
    "the subject onto a generated background."
)

# Two inputs (image + prompt) feed process_image; its result is shown as an image.
iface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Upload Your Image"),
        gr.Textbox(
            lines=2,
            placeholder="Describe the new background...",
            label="Background Prompt",
        ),
    ],
    outputs=gr.Image(label="Output Image"),
    title=title,
    description=description,
    # NOTE(review): newer Gradio releases appear to rename this option to
    # `flagging_mode` — confirm against the Gradio version this app pins.
    allow_flagging="never",
)


# Start the web UI only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()
|
|