tdurzynski committed on
Commit
f1fb6bb
·
verified ·
1 Parent(s): 8392bf4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -0
app.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from PIL import Image
4
+ from rembg import remove
5
+ from diffusers import StableDiffusionPipeline
6
+
7
+ # -----------------------------------------------------------------------------
8
+ # Helper function to adjust image size to multiples of 8.
9
+ # -----------------------------------------------------------------------------
10
def adjust_size(w, h):
    """
    Round width and height down to multiples of 8, as required by the Stable Diffusion model.

    Each dimension is floored to the nearest multiple of 8 and then clamped to
    a minimum of 8, so inputs smaller than 8 pixels never produce a zero-sized
    (and therefore invalid) generation target.

    Parameters:
        w (int): Source width in pixels.
        h (int): Source height in pixels.

    Returns:
        tuple: ``(new_w, new_h)``, both positive multiples of 8.
    """
    # Without the clamp, any dimension below 8 would floor to 0.
    new_w = max(8, (w // 8) * 8)
    new_h = max(8, (h // 8) * 8)
    return new_w, new_h
17
+
18
+ # -----------------------------------------------------------------------------
19
+ # Core processing function:
20
+ # 1. Remove background from the uploaded image.
21
+ # 2. Generate a new background image based on the text prompt.
22
+ # 3. Composite the foreground onto the generated background.
23
+ # -----------------------------------------------------------------------------
24
def process_image(input_image: Image.Image, bg_prompt: str) -> Image.Image:
    """
    Replace the background of an uploaded image with a generated one.

    The subject is cut out with ``rembg``, a new background is generated from
    the text prompt with the module-level Stable Diffusion pipeline ``pipe``,
    and the cut-out is alpha-composited on top of that background.

    Parameters:
        input_image (PIL.Image.Image): The uploaded image.
        bg_prompt (str): Text prompt describing the new background.

    Returns:
        PIL.Image.Image: The final composited RGBA image, sized to the
        generated background.

    Raises:
        ValueError: If no image was provided.
    """
    if input_image is None:
        raise ValueError("No image provided.")

    # Step 1: Remove the background from the input image.
    print("Removing background from the uploaded image...")
    foreground = remove(input_image)
    # alpha_composite requires both layers to be RGBA.
    foreground = foreground.convert("RGBA")

    # Step 2: Determine new dimensions (multiples of 8) based on the foreground.
    orig_w, orig_h = foreground.size
    gen_w, gen_h = adjust_size(orig_w, orig_h)
    print(f"Original size: {orig_w}x{orig_h} | Adjusted size: {gen_w}x{gen_h}")

    # Step 3: Generate a new background using the provided text prompt.
    print("Generating new background using Stable Diffusion...")
    bg_output = pipe(
        bg_prompt,
        height=gen_h,
        width=gen_w,
        num_inference_steps=50,  # Adjust as needed.
        guidance_scale=7.5,  # Adjust for prompt adherence.
    )
    # The generated background is in RGB mode; convert to RGBA for compositing.
    background = bg_output.images[0].convert("RGBA")

    # Step 4: If necessary, resize the foreground to match the background
    # (e.g. when the source dimensions were not already multiples of 8).
    if foreground.size != background.size:
        print("Resizing foreground to match background dimensions...")
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
        # under its current name and is also available in older releases.
        foreground = foreground.resize(background.size, Image.LANCZOS)

    # Step 5: Composite the foreground over the new background.
    print("Compositing images...")
    final_image = Image.alpha_composite(background, foreground)

    return final_image
70
+
71
+ # -----------------------------------------------------------------------------
72
+ # Load the Stable Diffusion pipeline from Hugging Face.
73
+ # -----------------------------------------------------------------------------
74
# Hugging Face model identifier for the background generator.
MODEL_ID = "stabilityai/stable-diffusion-2"  # You may change the model if desired.

# Half precision is selected only when a CUDA device is present; full precision otherwise.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

print("Loading Stable Diffusion pipeline...")
pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=torch_dtype)
# Move the pipeline onto the GPU when one is available; otherwise it stays where it was loaded.
if torch.cuda.is_available():
    pipe = pipe.to("cuda")
print("Stable Diffusion pipeline loaded.")
87
+
88
+ # -----------------------------------------------------------------------------
89
+ # Create the Gradio Interface.
90
+ # -----------------------------------------------------------------------------
91
# Text shown at the top of the web UI.
title = "Background Removal & Replacement"
description = (
    "Upload an image (e.g., a person or an animal) and provide a text prompt "
    "describing the new background. The app will remove the original background and "
    "composite the subject onto a generated background."
)

# Components come from the top-level ``gr.Image`` / ``gr.Textbox`` classes: the
# legacy ``gr.inputs`` / ``gr.outputs`` namespaces were deprecated in Gradio 3
# and removed in Gradio 4, so referencing them fails on current releases.
iface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Upload Your Image"),
        gr.Textbox(lines=2, placeholder="Describe the new background...", label="Background Prompt"),
    ],
    outputs=gr.Image(type="pil", label="Output Image"),
    title=title,
    description=description,
    allow_flagging="never",
)

# -----------------------------------------------------------------------------
# Launch the app.
# -----------------------------------------------------------------------------
# Only start the server when run as a script, not when imported as a module.
if __name__ == "__main__":
    iface.launch()