import torch
import gradio as gr
from PIL import Image
from rembg import remove
from diffusers import StableDiffusionPipeline


# -----------------------------------------------------------------------------
# Helper function to adjust image size to multiples of 8.
# -----------------------------------------------------------------------------
def adjust_size(w: int, h: int) -> tuple:
    """
    Adjust width and height to be multiples of 8, as required by the
    Stable Diffusion model.

    Each dimension is rounded down to the nearest multiple of 8 and clamped
    to a minimum of 8, so inputs smaller than 8 pixels never produce a
    zero-sized generation request.

    Parameters:
        w (int): Original width in pixels.
        h (int): Original height in pixels.

    Returns:
        tuple: ``(new_w, new_h)``, both positive multiples of 8.
    """
    new_w = max(8, (w // 8) * 8)
    new_h = max(8, (h // 8) * 8)
    return new_w, new_h


# -----------------------------------------------------------------------------
# Core processing function:
# 1. Remove background from the uploaded image.
# 2. Generate a new background image based on the text prompt.
# 3. Composite the foreground onto the generated background.
# -----------------------------------------------------------------------------
def process_image(input_image: Image.Image, bg_prompt: str) -> Image.Image:
    """
    Processes the uploaded image by removing its background and replacing it
    with a generated one.

    Relies on the module-level ``pipe`` (the Stable Diffusion pipeline), which
    is created further down in this file before the interface is launched.

    Parameters:
        input_image (PIL.Image.Image): The uploaded image.
        bg_prompt (str): Text prompt describing the new background.

    Returns:
        PIL.Image.Image: The final composited RGBA image, sized to the
        generated background (dimensions rounded to multiples of 8).

    Raises:
        ValueError: If no image was provided.
    """
    if input_image is None:
        raise ValueError("No image provided.")

    # Step 1: Remove the background from the input image.
    print("Removing background from the uploaded image...")
    foreground = remove(input_image)
    foreground = foreground.convert("RGBA")

    # Step 2: Adjust dimensions for background generation.
    orig_w, orig_h = foreground.size
    gen_w, gen_h = adjust_size(orig_w, orig_h)
    print(f"Original size: {orig_w}x{orig_h} | Adjusted size: {gen_w}x{gen_h}")

    # Step 3: Generate a new background using the provided text prompt.
    print("Generating new background using Stable Diffusion...")
    bg_output = pipe(
        bg_prompt,
        height=gen_h,
        width=gen_w,
        num_inference_steps=50,  # Adjust if needed.
        guidance_scale=7.5,  # Adjust for more/less prompt adherence.
    )
    # Convert the generated background to RGBA.
    background = bg_output.images[0].convert("RGBA")

    # Step 4: Ensure the foreground matches the background dimensions.
    # alpha_composite requires both images to have identical sizes.
    if foreground.size != background.size:
        print("Resizing foreground to match background dimensions...")
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
        # under the name that exists in both old and new Pillow releases.
        foreground = foreground.resize(background.size, Image.LANCZOS)

    # Step 5: Composite the images.
    print("Compositing images...")
    final_image = Image.alpha_composite(background, foreground)
    return final_image


# -----------------------------------------------------------------------------
# Load the Stable Diffusion pipeline from Hugging Face.
# -----------------------------------------------------------------------------
MODEL_ID = "stabilityai/stable-diffusion-2"  # Change the model if desired.

# Query CUDA availability once and reuse the answer for dtype and placement.
_cuda_available = torch.cuda.is_available()

# Use half precision if GPU is available.
if _cuda_available:
    torch_dtype = torch.float16
else:
    torch_dtype = torch.float32

print("Loading Stable Diffusion pipeline...")
pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=torch_dtype)
if _cuda_available:
    pipe = pipe.to("cuda")
print("Stable Diffusion pipeline loaded.")

# -----------------------------------------------------------------------------
# Create the Gradio Interface using the updated API.
# -----------------------------------------------------------------------------
title = "Background Removal & Replacement"
description = (
    "Upload an image (e.g., a person or an animal) and provide a text prompt "
    "describing the new background. The app will remove the original background and "
    "composite the subject onto a generated background."
)

iface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Upload Your Image"),
        gr.Textbox(
            lines=2,
            placeholder="Describe the new background...",
            label="Background Prompt",
        ),
    ],
    outputs=gr.Image(label="Output Image"),
    title=title,
    description=description,
    allow_flagging="never",
)

# -----------------------------------------------------------------------------
# Launch the app.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    iface.launch()