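# NOTE: the next three lines are web-page header residue, not program code;
# they are kept as comments so the file remains valid Python.
#   tdurzynski's picture
#   Update app.py
#   5661265 verified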
import torch
import gradio as gr
from PIL import Image
from rembg import remove
from diffusers import StableDiffusionPipeline
# -----------------------------------------------------------------------------
# Helper function to adjust image size to multiples of 8.
# -----------------------------------------------------------------------------
def adjust_size(w: int, h: int, multiple: int = 8):
    """
    Round a width/height pair down to the nearest multiple of ``multiple``.

    The caller passes the result as the ``width``/``height`` of the
    Stable Diffusion pipeline call, which needs sizes that are multiples of 8.

    Each dimension is clamped to at least one full ``multiple`` so that very
    small inputs (e.g. a 5x3 image) never produce a zero-sized request.

    Parameters:
    w (int): Original width in pixels.
    h (int): Original height in pixels.
    multiple (int): Granularity to snap to. Defaults to 8.
    Returns:
    tuple: ``(new_w, new_h)``, both positive multiples of ``multiple``.
    Raises:
    ValueError: If ``multiple`` is not a positive integer.
    """
    if multiple <= 0:
        raise ValueError("multiple must be a positive integer.")
    # Floor to the grid, but never below a single grid cell.
    new_w = max(multiple, (w // multiple) * multiple)
    new_h = max(multiple, (h // multiple) * multiple)
    return new_w, new_h
# -----------------------------------------------------------------------------
# Core processing function:
# 1. Remove background from the uploaded image.
# 2. Generate a new background image based on the text prompt.
# 3. Composite the foreground onto the generated background.
# -----------------------------------------------------------------------------
def process_image(input_image: Image.Image, bg_prompt: str) -> Image.Image:
    """
    Replace the background of an uploaded image with a generated one.

    The pipeline is:
    1. Strip the background from ``input_image`` with ``rembg.remove``.
    2. Generate a new background from ``bg_prompt`` with the module-level
       Stable Diffusion pipeline ``pipe``.
    3. Alpha-composite the extracted foreground over the new background.

    The generated background uses the foreground size rounded by
    ``adjust_size``; when that differs from the original size the foreground
    is resized to match, so the returned image has the adjusted dimensions.

    Parameters:
    input_image (PIL.Image.Image): The uploaded image.
    bg_prompt (str): Text prompt describing the new background.
    Returns:
    PIL.Image.Image: The final composited RGBA image.
    Raises:
    ValueError: If ``input_image`` is None.
    """
    if input_image is None:
        raise ValueError("No image provided.")

    # Step 1: Remove the background from the input image.
    print("Removing background from the uploaded image...")
    # RGBA is required by Image.alpha_composite below.
    foreground = remove(input_image).convert("RGBA")

    # Step 2: Adjust dimensions for background generation.
    orig_w, orig_h = foreground.size
    gen_w, gen_h = adjust_size(orig_w, orig_h)
    print(f"Original size: {orig_w}x{orig_h} | Adjusted size: {gen_w}x{gen_h}")

    # Step 3: Generate a new background using the provided text prompt.
    print("Generating new background using Stable Diffusion...")
    bg_output = pipe(
        bg_prompt,
        height=gen_h,
        width=gen_w,
        num_inference_steps=50,  # Adjust if needed.
        guidance_scale=7.5,  # Adjust for more/less prompt adherence.
    )
    # Convert the generated background to RGBA.
    background = bg_output.images[0].convert("RGBA")

    # Step 4: Ensure the foreground matches the background dimensions
    # (alpha_composite requires both images to be the same size).
    if foreground.size != background.size:
        print("Resizing foreground to match background dimensions...")
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same
        # filter under its current name and also exists in older releases.
        foreground = foreground.resize(background.size, Image.LANCZOS)

    # Step 5: Composite the images.
    print("Compositing images...")
    return Image.alpha_composite(background, foreground)
# -----------------------------------------------------------------------------
# Load the Stable Diffusion pipeline from Hugging Face.
# -----------------------------------------------------------------------------
# Hugging Face model identifier of the checkpoint to load; swap it to try
# a different model.
MODEL_ID = "stabilityai/stable-diffusion-2"

# Half precision when CUDA is available, full precision otherwise.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

print("Loading Stable Diffusion pipeline...")
pipe = StableDiffusionPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=torch_dtype,
)
# Move the pipeline onto the GPU when one is present.
if torch.cuda.is_available():
    pipe = pipe.to("cuda")
print("Stable Diffusion pipeline loaded.")
# -----------------------------------------------------------------------------
# Create the Gradio Interface using the updated API.
# -----------------------------------------------------------------------------
# Text shown at the top of the web page.
title = "Background Removal & Replacement"
description = (
    "Upload an image (e.g., a person or an animal) and provide a text prompt describing "
    "the new background. The app will remove the original background and composite "
    "the subject onto a generated background."
)

# UI components, created in the order they are wired into the interface.
_image_input = gr.Image(type="pil", label="Upload Your Image")
_prompt_input = gr.Textbox(
    lines=2,
    placeholder="Describe the new background...",
    label="Background Prompt",
)
_result_output = gr.Image(label="Output Image")

# Wire process_image to one image + one text input and a single image output.
iface = gr.Interface(
    fn=process_image,
    inputs=[_image_input, _prompt_input],
    outputs=_result_output,
    title=title,
    description=description,
    allow_flagging="never",
)

# -----------------------------------------------------------------------------
# Start the web server only when this file is run as a script.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    iface.launch()