import os
import time
import torch
from diffusers import StableDiffusionXLImg2ImgPipeline
from diffusers.utils import load_image
from PIL import Image
from PIL.PngImagePlugin import PngInfo
import json
import gradio as gr
import tempfile
import spaces

# Set environment variable to reduce memory fragmentation
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
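# ("expandable_segments" is an option understood by the CUDA caching allocator
# in PyTorch 2.x; assumed available in this environment)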

# Check if CUDA is available, fallback to CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32

# Load pipeline with error handling for HF Spaces
try:
    pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-refiner-1.0",
        torch_dtype=torch_dtype,
        variant="fp16" if device == "cuda" else None,
        use_safetensors=True
    )
    
    # Enable optimizations based on available hardware
    if device == "cuda":
        # CPU offloading reduces peak VRAM usage; in this mode diffusers
        # manages device placement itself, so skip the explicit .to("cuda")
        pipe.enable_model_cpu_offload()

        # Try to enable memory-efficient attention, falling back to slicing
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except (ModuleNotFoundError, ImportError):
            print("xformers not available, using attention slicing")
            pipe.enable_attention_slicing()
    else:
        # For CPU inference, move the pipeline over and enable attention
        # slicing to lower peak memory usage
        pipe = pipe.to(device)
        pipe.enable_attention_slicing()
        
except Exception as e:
    print(f"Error loading pipeline: {e}")
    pipe = None


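# On Hugging Face ZeroGPU Spaces, @spaces.GPU attaches a GPU for the duration
# of each decorated call; outside ZeroGPU the decorator has no effect.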
@spaces.GPU
def img2img(
    uploaded_image,
    image_url: str,
    prompt: str,
    negative_prompt: str = "",
    strength: float = 0.7,
    guidance_scale: float = 7.5,
    num_inference_steps: int = 30,
    seed: int = -1,
):
    if pipe is None:
        return None, "❌ Model failed to load. Please try again later.", None
        
    try:
        # Validate the prompt before doing any image loading or inference
        if not prompt.strip():
            return None, "❌ Please enter a prompt", None

        # Choose image source
        if uploaded_image is not None:
            init_image = Image.open(uploaded_image).convert("RGB")
        elif image_url and image_url.strip():
            try:
                init_image = load_image(image_url).convert("RGB")
            except Exception as e:
                return None, f"❌ Failed to load image from URL: {str(e)}", None
        else:
            return None, "❌ Please upload an image or enter a valid URL", None

        # Downscale to fit within 1024x1024 while preserving aspect ratio
        init_image.thumbnail((1024, 1024), Image.Resampling.LANCZOS)

        # Snap dimensions down to multiples of 8 (SDXL's VAE downsamples by 8)
        width, height = init_image.size
        width = (width // 8) * 8
        height = (height // 8) * 8
        init_image = init_image.resize((width, height), Image.Resampling.LANCZOS)

        # Set seed and generator; an unseeded torch.Generator starts from a
        # fixed default state, so explicitly draw a fresh random seed for -1
        if seed == -1:
            generator = torch.Generator(device=device)
            generator.seed()  # seeds from a non-deterministic source
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        # Run inference with progress tracking
        with torch.inference_mode():
            result = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt if negative_prompt.strip() else None,
                image=init_image,
                strength=max(0.1, min(1.0, strength)),  # Clamp strength
                guidance_scale=max(1.0, min(20.0, guidance_scale)),  # Clamp guidance
                num_inference_steps=max(10, min(100, num_inference_steps)),  # Clamp steps
                generator=generator
            ).images[0]

        used_seed = generator.initial_seed()

        # Create metadata dictionary
        metadata = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "seed": used_seed,
            "model": "stabilityai/stable-diffusion-xl-refiner-1.0",
            "pipeline": "StableDiffusionXLImg2ImgPipeline",
            "guidance_scale": guidance_scale,
            "strength": strength,
            "steps": num_inference_steps,
            "width": result.width,
            "height": result.height,
            "device": device
        }

        # Save metadata into PNG
        png_info = PngInfo()
        png_info.add_text("parameters", json.dumps(metadata))
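        # The "parameters" key follows the tEXt-chunk convention popularized by
        # AUTOMATIC1111's web UI, though the value here is JSON rather than
        # that UI's plain-text format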

        # Use temporary file for HF Spaces
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
            output_path = tmp_file.name
            result.save(output_path, format="PNG", pnginfo=png_info)

        # Build markdown preview of metadata
        metadata_str = (
            f"**Prompt:** {metadata['prompt']}\n\n"
            f"**Negative Prompt:** {metadata['negative_prompt']}\n\n"
            f"**Seed:** {metadata['seed']}\n\n"
            f"**Model:** {metadata['model']}\n\n"
            f"**Guidance Scale:** {metadata['guidance_scale']}\n\n"
            f"**Strength:** {metadata['strength']}\n\n"
            f"**Steps:** {metadata['steps']}\n\n"
            f"**Dimensions:** {metadata['width']}x{metadata['height']}\n\n"
            f"**Device:** {metadata['device']}"
        )

        return output_path, f"✅ **Generation Complete!**\n\n{metadata_str}", output_path

    except torch.cuda.OutOfMemoryError:
        return None, "❌ GPU out of memory. Try reducing image size or inference steps.", None
    except Exception as e:
        return None, f"❌ Error during generation: {str(e)}", None
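

# Not wired into the UI: a minimal sketch showing how the JSON parameters
# embedded by img2img above can be read back from a saved PNG with Pillow,
# e.g. read_embedded_metadata(some_path).get("seed").
def read_embedded_metadata(path: str) -> dict:
    """Return the generation parameters stored in the PNG "parameters" text chunk."""
    with Image.open(path) as im:
        return json.loads(im.text.get("parameters", "{}"))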


# Define UI components with better styling
title = "🎨 SDXL Image-to-Image Editor"
description = """
Transform your images with AI! Upload an image and describe the changes you want to make.

**Tips:**
- Use detailed prompts for better results
- Lower strength values preserve more of the original image
- Higher guidance scale follows your prompt more closely
"""

# Custom CSS for better appearance
css = """
.gradio-container {
    font-family: 'IBM Plex Sans', sans-serif;
}
.gr-button {
    color: white;
    background: linear-gradient(90deg, #4f46e5, #7c3aed);
    border: none;
}
.gr-button:hover {
    background: linear-gradient(90deg, #4338ca, #6d28d9);
}
"""

with gr.Blocks(title=title, css=css, theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(description)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### πŸ“Έ Input Image")
            uploaded_image = gr.Image(
                label="Upload Image", 
                type="filepath",
                height=300
            )
            
            gr.Markdown("**Or**")
            image_url = gr.Textbox(
                label="Image URL", 
                placeholder="https://example.com/image.jpg",
                info="Paste a direct link to an image"
            )

            gr.Markdown("### ✍️ Prompts")
            prompt = gr.Textbox(
                label="Prompt", 
                placeholder="a beautiful sunset over mountains, photorealistic, detailed",
                lines=3,
                info="Describe what you want to see"
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt", 
                placeholder="blurry, low quality, distorted",
                lines=2,
                info="What to avoid in the image"
            )

            gr.Markdown("### ⚙️ Settings")
            with gr.Row():
                strength = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.7, step=0.05,
                    label="Transformation Strength",
                    info="0.1 = subtle changes, 1.0 = major changes"
                )
                guidance_scale = gr.Slider(
                    minimum=1.0, maximum=20.0, value=7.5, step=0.5,
                    label="Guidance Scale",
                    info="How closely to follow the prompt"
                )
                
            with gr.Row():
                num_inference_steps = gr.Slider(
                    minimum=10, maximum=50, step=5, value=30,
                    label="Quality Steps",
                    info="More steps = higher quality but slower"
                )
                seed = gr.Slider(
                    minimum=-1, maximum=999999, step=1, value=-1,
                    label="Seed",
                    info="-1 for random"
                )

            submit_btn = gr.Button("🚀 Generate Image", variant="primary", size="lg")

        with gr.Column(scale=1):
            gr.Markdown("### 🖼️ Result")
            image_output = gr.Image(label="Generated Image", height=400)
            download_button = gr.File(label="📥 Download Full Resolution", visible=False)

            gr.Markdown("### 📊 Generation Details")
            metadata_output = gr.Markdown()

    # Event handlers
    submit_btn.click(
        fn=img2img,
        inputs=[
            uploaded_image,
            image_url,
            prompt,
            negative_prompt,
            strength,
            guidance_scale,
            num_inference_steps,
            seed
        ],
        outputs=[image_output, metadata_output, download_button]
    ).then(
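        # Reveal the download button only once a generated image exists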
        lambda x: gr.update(visible=x is not None),
        inputs=[image_output],
        outputs=[download_button]
    )

    # Examples
    gr.Markdown("### 🎯 Examples")
    gr.Examples(
        examples=[
            [None, "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png", "make it a van gogh painting", "blurry, low quality", 0.8, 7.5, 30, 42],
            [None, "https://picsum.photos/512/512?random=1", "turn into a cyberpunk cityscape", "blurry, distorted", 0.9, 8.0, 30, 123],
        ],
        inputs=[uploaded_image, image_url, prompt, negative_prompt, strength, guidance_scale, num_inference_steps, seed],
    )

# Launch configuration for HF Spaces
if __name__ == "__main__":
    demo.queue(max_size=20)  # Enable queuing for better performance
    demo.launch(
        show_error=True,
        share=False,  # Don't create gradio.live links in HF Spaces
        inbrowser=False,  # Don't try to open browser in cloud environment
        quiet=False
    )
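
# Local usage (a sketch; dependency list assumed, not pinned by this file):
#   pip install torch diffusers transformers accelerate gradio spaces
#   python app.py  # Gradio serves on http://localhost:7860 by default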