alexnasa committed · Commit a0fd130 · verified · 1 Parent(s): 873eb28

Update app.py

Files changed (1):
  1. app.py +519 -372

app.py CHANGED
@@ -1,431 +1,578 @@
- #!/usr/bin/env python
- # Out-of-Focus v1.0 — Zero GPU-ready edition
- # -------------------------------------------------------------
- # 0. Imports (⚠️ keep `import spaces` FIRST)
- # -------------------------------------------------------------
- import warnings, os, gc, math, argparse, pickle
- warnings.filterwarnings("ignore")

- import spaces                      # ← mandatory for Zero GPU
- import torch, torchvision
  import torch.nn.functional as F
  import numpy as np
-
- from typing import Optional, Dict, Any
  from PIL import Image
- from diffusers import (DiffusionPipeline, DDIMInverseScheduler,
-                        DDIMScheduler, AutoencoderKL)
- from diffusers.models.attention_processor import (
-     Attention, AttnProcessor2_0
- )
  from safetensors.torch import load_file
  from huggingface_hub import hf_hub_download
- import gradio as gr
-
- # -------------------------------------------------------------
- # 1. Globals (initialised lazily inside the GPU context)
- # -------------------------------------------------------------
- PIPE: Optional[DiffusionPipeline] = None
- INVERSE_SCHEDULER: Optional[DDIMInverseScheduler] = None
- SCHEDULER: Optional[DDIMScheduler] = None
- TORCH_DTYPE = torch.float16        # H100/A100 FP16 slice
-
- # your existing state dictionaries / sliders
- weights: Dict[str, Dict[int, Dict[int, float]]] = {}
- res_list, foreground_mask = [], None
- heighest_resolution, signal_value, blur_value = -1, 2.0, None
- allowed_res_max = 1.0
- guidance_scale_value, num_inference_steps = 7.5, 10
- max_scale_value = 16
- res_range_min, res_range_max = 128, 1024
-
- # -------------------------------------------------------------
- # 2. Lazy pipeline loader (runs inside GPU context)
- # -------------------------------------------------------------
- def _get_pipeline() -> tuple[DiffusionPipeline,
-                              DDIMInverseScheduler,
-                              DDIMScheduler]:
-     """Initialise Stable Diffusion + schedulers on first call."""
-     global PIPE, INVERSE_SCHEDULER, SCHEDULER
-
-     if PIPE is None:               # first GPU call ➜ download
-         model_id = "runwayml/stable-diffusion-v1-5"
-         vae_folder = "vae"
-         resadapter_name = "resadapter_v2_sd1.5"
-
-         PIPE = DiffusionPipeline.from_pretrained(
-             model_id, torch_dtype=TORCH_DTYPE
-         ).to("cuda")
-
-         # external VAE
-         PIPE.vae = AutoencoderKL.from_pretrained(
-             model_id, subfolder=vae_folder, torch_dtype=TORCH_DTYPE
-         ).to("cuda")
-
-         # Res-Adapter LoRA + Norm weights
-         lora_path = hf_hub_download(
-             "jiaxiangc/res-adapter",
-             subfolder=resadapter_name,
-             filename="pytorch_lora_weights.safetensors"
-         )
-         norm_path = hf_hub_download(
-             "jiaxiangc/res-adapter",
-             subfolder=resadapter_name,
-             filename="diffusion_pytorch_model.safetensors"
-         )
-         PIPE.load_lora_weights(lora_path, adapter_name="res_adapter")
-         PIPE.set_adapters(["res_adapter"], adapter_weights=[1.0])
-         PIPE.unet.load_state_dict(load_file(norm_path), strict=False)

-         # schedulers
-         INVERSE_SCHEDULER = DDIMInverseScheduler.from_pretrained(
-             model_id, subfolder="scheduler"
-         )
-         SCHEDULER = DDIMScheduler.from_pretrained(
-             model_id, subfolder="scheduler"
-         )
-     return PIPE, INVERSE_SCHEDULER, SCHEDULER
-
- # -------------------------------------------------------------
- # 3. Helper functions (unchanged from your original)
- # -------------------------------------------------------------
- def save_state_to_file(state):           # … unchanged
      filename = "state.pkl"
      with open(filename, "wb") as f:
          pickle.dump(state, f)
      return filename

- def load_state_from_file(filename):      # … unchanged
      with open(filename, "rb") as f:
-         return pickle.load(f)

  def weight_population(layer_type, resolution, depth, value):
-     global heighest_resolution
      if layer_type not in weights:
          weights[layer_type] = {}
      if resolution not in weights[layer_type]:
          weights[layer_type][resolution] = {}
      if resolution > heighest_resolution:
          heighest_resolution = resolution
      weights[layer_type][resolution][depth] = value

- def resize_image_with_aspect(img, res_min=128, res_max=1024):
-     w, h = img.size
-     if w < res_min or h < res_min:
-         s = max(res_min / w, res_min / h)
-     elif w > res_max or h > res_max:
-         s = min(res_max / w, res_max / h)
      else:
-         s = 1
-     return img.resize(
-         (int(w * s), int(h * s)), Image.Resampling.LANCZOS
-     )

- def adjust_ends(vals, delta):
-     # helpers used by update_scale
-     for i in range(len(vals)):
-         if (delta > 0 and vals[i + 1] == 1.0) or (
-             delta < 0 and vals[i] > 0.0
-         ):
-             vals[i] += delta
-             break
-     for i in range(len(vals) - 1, -1, -1):
-         if (delta > 0 and vals[i - 1] == 1.0) or (
-             delta < 0 and vals[i] > 0.0
-         ):
-             vals[i] += delta
-             break
-     return vals

- def update_scale(scale):
-     global weights
-     values_flat = []
-     for _, d in weights.items():
-         for _, v in d.items():
-             for _ in v:
-                 values_flat.append(1.0)
-     for _ in range(scale, max_scale_value):
-         adjust_ends(values_flat, -0.5)
-     idx = 0
-     for k1, d in weights.items():
-         for k2 in d:
-             for k3 in d[k2]:
-                 weights[k1][k2][k3] = values_flat[idx]
-                 idx += 1
-
- # -------------------------------------------------------------
- # 4. Custom attention processor (unchanged)
- # -------------------------------------------------------------
  class AttnReplaceProcessor(AttnProcessor2_0):
-     def __init__(self, replace_all, layer_type,
-                  layer_count, blur_sigma=None):
          super().__init__()
          self.replace_all = replace_all
          self.layer_type = layer_type
          self.layer_count = layer_count
          self.blur_sigma = blur_sigma

      def __call__(
-         self, attn: Attention, hidden_states: torch.Tensor,
-         encoder_hidden_states: Optional[torch.Tensor] = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         temb: Optional[torch.Tensor] = None, *args, **kwargs
-     ) -> torch.Tensor:
-
-         dim2 = hidden_states.shape[1]
-         is_cross = encoder_hidden_states is not None
          residual = hidden_states

-         # (norms & projections identical to original code)
-         # --- omitted for brevity, copy your original implementation ---
-         # replace attention values when self.replace_all is True
-         #     using global `weights`
-         # --------------------------------------------------------------

-         return hidden_states          # after residual & rescale

  def replace_attention_processor(unet, clear=False, blur_sigma=None):
-     attn_count = 0
      for name, module in unet.named_modules():
          if "attn1" in name and "to" not in name:
              layer_type = name.split(".")[0].split("_")[0]
-             attn_count += 1
-             module.processor = AttnReplaceProcessor(
-                 not clear, layer_type, attn_count, blur_sigma
-             )
-
- # -------------------------------------------------------------
- # 5. GPU-bound functions
- # -------------------------------------------------------------
- @spaces.GPU(duration=120)          # 2 min quota
- def reconstruct(input_img: Image.Image, caption: str):
-     """
-     Reconstruct the input image & latents.
-     Returns: (np_image, caption, slider_val, meta_state)
-     """
-     pipe, inv_sched, sched = _get_pipeline()
-
-     img = resize_image_with_aspect(input_img,
-                                    res_range_min, res_range_max)
-     transform = torchvision.transforms.ToTensor()
-     loaded = transform(img).half().to("cuda").unsqueeze(0)
-     if loaded.shape[1] == 4:       # drop alpha
-         loaded = loaded[:, :3, :, :]

-     with torch.no_grad():
-         enc = pipe.vae.encode(loaded * 2 - 1)
-         real_latents = pipe.vae.config.scaling_factor * \
-                        enc.latent_dist.sample()

-     # inversion pass
-     inv_sched.set_timesteps(num_inference_steps, device="cuda")
-     latents = real_latents.clone()
-     inversed_latents = [latents]

-     def store_latent(_, step, __, cb_kwargs):
-         if step != num_inference_steps - 1:
-             inversed_latents.append(cb_kwargs["latents"])
-         return cb_kwargs
-
-     replace_attention_processor(pipe.unet, True)
-     pipe.scheduler = inv_sched
-     pipe(prompt=caption,
-          guidance_scale=1.0,
-          output_type="latent",
-          num_inference_steps=num_inference_steps,
-          latents=latents,
-          callback_on_step_end=store_latent,
-          callback_on_step_end_tensor_inputs=["latents"])
-
-     real_initial = inversed_latents[-1]
-     # forward synthesis with CFG
-     sched.set_timesteps(num_inference_steps, device="cuda")
-     replace_attention_processor(pipe.unet, True)
-
-     def adjust_latent(_, step, __, cb_kwargs):
-         cb_kwargs["latents"] = inversed_latents[
-             len(sched.timesteps) - 1 - step
-         ].detach()
-         return cb_kwargs
-
-     latents = pipe(prompt=caption,
-                    guidance_scale=guidance_scale_value,
-                    output_type="latent",
-                    num_inference_steps=num_inference_steps,
-                    latents=real_initial,
-                    callback_on_step_end=adjust_latent,
-                    callback_on_step_end_tensor_inputs=["latents"])[0]
-
-     image = pipe.vae.decode(
-         latents / pipe.vae.config.scaling_factor, return_dict=False
-     )[0]
-     img_np = image.squeeze(0).float().permute(1, 2, 0).detach().cpu()
-     img_np = ((img_np / 2 + 0.5).clamp(0, 1).numpy() * 255).astype(np.uint8)
-
-     update_scale(12)               # initial cross-attn value
-
-     pipe.to("cpu"); torch.cuda.empty_cache()
-     return img_np, caption, 12, [caption, real_initial.detach(),
-                                  inversed_latents, weights]
-
- @spaces.GPU(duration=120)          # 2 min quota
- def apply_prompt(meta_data: Any, new_prompt: str):
-     """
-     Re-generate the image using stored latents + new prompt.
-     """
-     pipe, _, sched = _get_pipeline()
-     caption, real_latents, inversed, _ = meta_data
-
-     steps = len(inversed)
-     sched.set_timesteps(steps, device="cuda")
-
-     initial = torch.cat([real_latents] * 2)
-     def adjust_latent(_, step, __, cb_kwargs):
          replace_attention_processor(pipe.unet)
-         delta = inversed[len(sched.timesteps) - 1 - step].detach()
-         cb_kwargs["latents"][1] += delta - cb_kwargs["latents"][0]
-         cb_kwargs["latents"][0] = delta
-         return cb_kwargs
-
-     latents = pipe(
-         prompt=[caption, new_prompt],
-         negative_prompt=["", ""],
-         guidance_scale=guidance_scale_value,
-         output_type="latent",
-         num_inference_steps=steps,
-         latents=initial,
-         callback_on_step_end=adjust_latent,
-         callback_on_step_end_tensor_inputs=["latents"]
-     )[0][1]
-
-     replace_attention_processor(pipe.unet, True)
-     image = pipe.vae.decode(
-         latents.unsqueeze(0) / pipe.vae.config.scaling_factor,
-         return_dict=False
-     )[0]
-     img_np = image.squeeze(0).float().permute(1, 2, 0).detach().cpu()
-     img_np = ((img_np / 2 + 0.5).clamp(0, 1).numpy() * 255).astype(np.uint8)
-
-     pipe.to("cpu"); torch.cuda.empty_cache()
-     return img_np
-
- # -------------------------------------------------------------
- # 6. Lightweight CPU callbacks
- # -------------------------------------------------------------
  def on_image_change(filepath):
-     fname = os.path.splitext(os.path.basename(filepath))[0]
-     if fname in ["example1", "example3", "example4"]:
-         meta = load_state_from_file(f"assets/{fname}-turbo.pkl")
          global weights
-         _, _, _, weights = meta
          global num_inference_steps
          num_inference_steps = 10
-         scale_val = 8 if fname == "example1" else 6 if fname == "example3" else 13
-         new_prompt = {
-             "example1": "a photo of a tree, summer, colourful",
-             "example3": ("a realistic photo of a female warrior, flowing "
-                          "dark purple or black hair, bronze shoulder armour, "
-                          "leather chest piece, sky background with clouds"),
-             "example4": ("a photo of plastic bottle on some sand, beach "
-                          "background, sky background")
-         }[fname]
-         update_scale(scale_val)
-         img = apply_prompt(meta, new_prompt)
-         return filepath, img, meta, num_inference_steps, scale_val, scale_val
-     return None
-
- def update_step(val):
      global num_inference_steps
-     num_inference_steps = val

- # -------------------------------------------------------------
- # 7. Gradio UI (unchanged layout)
- # -------------------------------------------------------------
  with gr.Blocks(analytics_enabled=False) as demo:
      gr.Markdown(
-         """<div style="text-align:center">
-         <img src="https://github.com/user-attachments/assets/55a38e74-ab93-4d80-91c8-0fa6130af45a">
-         <h1>Out of Focus v1.0 Turbo (Zero GPU)</h1>
-         <p>Prompt-based image reconstruction & manipulation.</p></div>"""
      )
-
      with gr.Row():
          with gr.Column():
-             example_in = gr.Image(type="filepath", visible=False)
-             img_in = gr.Image(type="pil",
-                               label="Upload Source Image")
-             steps = gr.Slider(minimum=5, maximum=50, step=5,
-                               value=num_inference_steps,
-                               label="Steps")
-             prompt_box = gr.Textbox(label="Prompt")
-             recon_btn = gr.Button("Reconstruct")
          with gr.Column():
-             recon_img = gr.Image(type="pil", label="Result")
-             inv_slider = gr.Slider(minimum=0, maximum=9, step=1,
-                                    value=7, visible=False)
-             xattn = gr.Slider(minimum=0, maximum=max_scale_value,
-                               step=1, value=max_scale_value,
-                               label="Cross-Attention Influence")
-             new_box = gr.Textbox(label="New Prompt", interactive=False)
-             apply_btn = gr.Button("Generate Vision",
-                                   variant="primary", interactive=False)
-
-     gr.Examples(
-         examples=[
-             ["assets/example4.png",
-              "a photo of plastic bottle on a rock, mountain background, sky background",
-              "a photo of plastic bottle on some sand, beach background, sky background",
-              13],
-             ["assets/example1.png",
-              "a photo of a tree, spring, foggy",
-              "a photo of a tree, summer, colourful",
-              8],
-             ["assets/example3.png",
-              ("a digital illustration of a female warrior, flowing "
-               "dark purple or black hair, bronze shoulder armour, "
-               "leather chest piece, sky background with clouds"),
-              ("a realistic photo of a female warrior, flowing "
-               "dark purple or black hair, bronze shoulder armour, "
-               "leather chest piece, sky background with clouds"),
-              6],
-         ],
-         inputs=[example_in, prompt_box, new_box, xattn],
-     )
-
-     meta_state = gr.State()
-
-     example_in.change(
-         on_image_change,
-         inputs=example_in,
-         outputs=[img_in, recon_img, meta_state,
-                  steps, inv_slider, xattn]
-     ).then(lambda: gr.update(interactive=True),
-            outputs=[apply_btn, new_box])
-
-     steps.release(update_step, inputs=steps)
-     xattn.release(update_scale, inputs=xattn)
-
-     recon_btn.click(
-         reconstruct,
-         inputs=[img_in, prompt_box],
-         outputs=[recon_img, new_box, xattn, meta_state]
-     ).then(lambda: gr.update(interactive=True),
-            outputs=[recon_btn, new_box, apply_btn])
-
-     recon_btn.click(lambda: gr.update(interactive=False),
-                     outputs=[recon_btn, apply_btn])
-
-     apply_btn.click(apply_prompt,
-                     inputs=[meta_state, new_box],
-                     outputs=recon_img)
-
- # -------------------------------------------------------------
- # 8. Launch
- # -------------------------------------------------------------
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--share", action="store_true",
-                         help="Enable public Gradio sharing")
-     args = parser.parse_args()
      demo.queue()
-     demo.launch(share=args.share, inbrowser=True)
+ import warnings

+ warnings.filterwarnings("ignore")
+ from diffusers import DiffusionPipeline, DDIMInverseScheduler, DDIMScheduler, AutoencoderKL
+ import torch
+ from typing import Optional
+ from tqdm import tqdm
+ from diffusers.models.attention_processor import Attention, AttnProcessor2_0
+ import torchvision
+ import torch.nn as nn
  import torch.nn.functional as F
+ import gc
+ import gradio as gr
  import numpy as np
+ import os
+ import pickle
+ import argparse
  from PIL import Image
+ import requests
+ import math
+ import torch
  from safetensors.torch import load_file
  from huggingface_hub import hf_hub_download
+ from diffusers import DiffusionPipeline
+ import spaces

+ def save_state_to_file(state):
      filename = "state.pkl"
      with open(filename, "wb") as f:
          pickle.dump(state, f)
      return filename

+
+ def load_state_from_file(filename):
      with open(filename, "rb") as f:
+         state = pickle.load(f)
+     return state
+
+ guidance_scale_value = 7.5
+ num_inference_steps = 10
+ weights = {}
+ res_list = []
+ foreground_mask = None
+ heighest_resolution = -1
+ signal_value = 2.0
+ blur_value = None
+ allowed_res_max = 1.0

+
+ def load_pipeline():
+
+     model_id = "stabilityai/stable-diffusion-xl-base-1.0"
+     vae_model_id = "madebyollin/sdxl-vae-fp16-fix"
+     vae_folder = ""
+     guidance_scale_value = 7.5
+     resadapter_model_name = "resadapter_v2_sdxl"
+     res_range_min = 256
+     res_range_max = 1536
+
+     torch_dtype = torch.float16
+
+     # torch_dtype = torch.float16
+     pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype).to("cuda")
+     pipe.vae = AutoencoderKL.from_pretrained(vae_model_id, subfolder=vae_folder, torch_dtype=torch_dtype).to("cuda")
+     pipe.load_lora_weights(
+         hf_hub_download(repo_id="jiaxiangc/res-adapter", subfolder=resadapter_model_name, filename="pytorch_lora_weights.safetensors"),
+         adapter_name="res_adapter",
+     )  # load lora weights
+     pipe.set_adapters(["res_adapter"], adapter_weights=[1.0])
+     pipe.unet.load_state_dict(
+         load_file(hf_hub_download(repo_id="jiaxiangc/res-adapter", subfolder=resadapter_model_name, filename="diffusion_pytorch_model.safetensors")),
+         strict=False,
+     )  # load norm weights
+
+     inverse_scheduler = DDIMInverseScheduler.from_pretrained(model_id, subfolder="scheduler")
+     scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
+
+     return pipe, inverse_scheduler, scheduler
  def weight_population(layer_type, resolution, depth, value):
+     # Check if layer_type exists, if not, create it
      if layer_type not in weights:
          weights[layer_type] = {}
+
+     # Check if resolution exists under layer_type, if not, create it
      if resolution not in weights[layer_type]:
          weights[layer_type][resolution] = {}
+
+     global heighest_resolution
      if resolution > heighest_resolution:
          heighest_resolution = resolution
+
+     # Add/Modify the value at the specified depth (which can be a string)
      weights[layer_type][resolution][depth] = value

+ def resize_image_with_aspect(image, res_range_min=128, res_range_max=1024):
+     # Get the original width and height of the image
+     width, height = image.size
+
+     # Determine the scaling factor to maintain the aspect ratio
+     scaling_factor = 1
+     if width < res_range_min or height < res_range_min:
+         scaling_factor = max(res_range_min / width, res_range_min / height)
+     elif width > res_range_max or height > res_range_max:
+         scaling_factor = min(res_range_max / width, res_range_max / height)
+
+     # Calculate the new dimensions
+     new_width = int(width * scaling_factor)
+     new_height = int(height * scaling_factor)
+
+     print(f'{new_width}-{new_height}')
+
+     # Resize the image with the new dimensions while maintaining the aspect ratio
+     resized_image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
+
+     return resized_image
+
+ @spaces.GPU()
+ def reconstruct(input_img, caption):
+
+     pipe, inverse_scheduler, scheduler = load_pipeline()
+
+     global weights
+     weights = {}
+
+     prompt = caption
+
+     img = input_img
+
+     img = resize_image_with_aspect(img, res_range_min, res_range_max)
+
+     transform = torchvision.transforms.Compose([
+         torchvision.transforms.ToTensor()
+     ])
+
+     if torch_dtype == torch.float16:
+         loaded_image = transform(img).half().to("cuda").unsqueeze(0)
      else:
+         loaded_image = transform(img).to("cuda").unsqueeze(0)

+     if loaded_image.shape[1] == 4:
+         loaded_image = loaded_image[:,:3,:,:]
+
+     with torch.no_grad():
+         encoded_image = pipe.vae.encode(loaded_image*2 - 1)
+         real_image_latents = pipe.vae.config.scaling_factor * encoded_image.latent_dist.sample()
+
+
+     # notice we disabled the CFG here by setting guidance scale as 1
+     guidance_scale = 1.0
+     inverse_scheduler.set_timesteps(num_inference_steps, device="cuda")
+     timesteps = inverse_scheduler.timesteps
+
+     latents = real_image_latents
+
+     inversed_latents = [latents]
+
+     def store_latent(pipe, step, timestep, callback_kwargs):
+         latents = callback_kwargs["latents"]
+
+         with torch.no_grad():
+             if step != num_inference_steps - 1:
+                 inversed_latents.append(latents)
+
+         return callback_kwargs
+
+     with torch.no_grad():
+
+         replace_attention_processor(pipe.unet, True)
+
+         pipe.scheduler = inverse_scheduler
+         latents = pipe(prompt=prompt,
+                        guidance_scale = guidance_scale,
+                        output_type="latent",
+                        return_dict=False,
+                        num_inference_steps=num_inference_steps,
+                        latents=latents,
+                        callback_on_step_end=store_latent,
+                        callback_on_step_end_tensor_inputs=["latents"],)[0]
+
+     # initial state
+     real_image_initial_latents = latents
+
+     guidance_scale = guidance_scale_value
+     scheduler.set_timesteps(num_inference_steps, device="cuda")
+     timesteps = scheduler.timesteps
+
+     def adjust_latent(pipe, step, timestep, callback_kwargs):
+
+         with torch.no_grad():
+             callback_kwargs["latents"] = inversed_latents[len(timesteps) - 1 - step].detach()
+
+         return callback_kwargs
+
+     with torch.no_grad():
+
+         replace_attention_processor(pipe.unet, True)
+
+         intermediate_values = real_image_initial_latents.clone()
+
+         pipe.scheduler = scheduler
+         intermediate_values = pipe(prompt=prompt,
+                                    guidance_scale = guidance_scale,
+                                    output_type="latent",
+                                    return_dict=False,
+                                    num_inference_steps=num_inference_steps,
+                                    latents=intermediate_values,
+                                    callback_on_step_end=adjust_latent,
+                                    callback_on_step_end_tensor_inputs=["latents"],)[0]
+
+     image = pipe.vae.decode(intermediate_values / pipe.vae.config.scaling_factor, return_dict=False)[0]
+     image_np = image.squeeze(0).float().permute(1, 2, 0).detach().cpu()
+     image_np = (image_np / 2 + 0.5).clamp(0, 1).numpy()
+     image_np = (image_np * 255).astype(np.uint8)
+
+     update_scale(12)
+
+     return image_np, caption, 12, [caption, real_image_initial_latents.detach(), inversed_latents, weights]

  class AttnReplaceProcessor(AttnProcessor2_0):
+
+     def __init__(self, replace_all, layer_type, layer_count, blur_sigma=None):
          super().__init__()
          self.replace_all = replace_all
          self.layer_type = layer_type
          self.layer_count = layer_count
+         self.weight_populated = False
          self.blur_sigma = blur_sigma

      def __call__(
+         self,
+         attn: Attention,
+         hidden_states: torch.FloatTensor,
+         encoder_hidden_states: Optional[torch.FloatTensor] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         temb: Optional[torch.FloatTensor] = None,
+         *args,
+         **kwargs,
+     ) -> torch.FloatTensor:
+
+
+         dimension_squared = hidden_states.shape[1]
+
+         is_cross = not encoder_hidden_states is None
+
          residual = hidden_states
+         if attn.spatial_norm is not None:
+             hidden_states = attn.spatial_norm(hidden_states, temb)
+
+         input_ndim = hidden_states.ndim
+
+         if input_ndim == 4:
+             batch_size, channel, height, width = hidden_states.shape
+             hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+         batch_size, sequence_length, _ = (
+             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+         )
+
+         if attention_mask is not None:
+             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+             # scaled_dot_product_attention expects attention_mask shape to be
+             # (batch, heads, source_length, target_length)
+             attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+         if attn.group_norm is not None:
+             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+         query = attn.to_q(hidden_states)
+
+         if encoder_hidden_states is None:
+             encoder_hidden_states = hidden_states
+         elif attn.norm_cross:
+             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

+         key = attn.to_k(encoder_hidden_states)
+         value = attn.to_v(encoder_hidden_states)
+
+         inner_dim = key.shape[-1]
+         head_dim = inner_dim // attn.heads
+
+         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+         height = width = math.isqrt(query.shape[2])
+
+
+         if self.replace_all:
+             weight_value = weights[self.layer_type][dimension_squared][self.layer_count]
+
+             ucond_attn_scores, attn_scores = query.chunk(2)
+             attn_scores[1].copy_(weight_value * attn_scores[0] + (1.0 - weight_value) * attn_scores[1])
+             ucond_attn_scores[1].copy_(weight_value * ucond_attn_scores[0] + (1.0 - weight_value) * ucond_attn_scores[1])
+
+
+             ucond_attn_scores, attn_scores = key.chunk(2)
+             attn_scores[1].copy_(weight_value * attn_scores[0] + (1.0 - weight_value) * attn_scores[1])
+             ucond_attn_scores[1].copy_(weight_value * ucond_attn_scores[0] + (1.0 - weight_value) * ucond_attn_scores[1])
+         else:
+             weight_population(self.layer_type, dimension_squared, self.layer_count, 1.0)
+
+
+         hidden_states = F.scaled_dot_product_attention(
+             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False,
+         )

+         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+         hidden_states = hidden_states.to(query.dtype)
+
+         # linear proj
+         hidden_states = attn.to_out[0](hidden_states)
+         # dropout
+         hidden_states = attn.to_out[1](hidden_states)
+
+         if input_ndim == 4:
+             hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+         if attn.residual_connection:
+             hidden_states = hidden_states + residual
+
+         hidden_states = hidden_states / attn.rescale_output_factor
+
+         return hidden_states

  def replace_attention_processor(unet, clear=False, blur_sigma=None):
+     attention_count = 0
+
+
      for name, module in unet.named_modules():
          if "attn1" in name and "to" not in name:
              layer_type = name.split(".")[0].split("_")[0]
+             attention_count += 1
+
+             if not clear:
+                 if layer_type == "down":
+                     module.processor = AttnReplaceProcessor(True, layer_type, attention_count, blur_sigma=blur_sigma)
+                 elif layer_type == "mid":
+                     module.processor = AttnReplaceProcessor(True, layer_type, attention_count, blur_sigma=blur_sigma)
+                 elif layer_type == "up":
+                     module.processor = AttnReplaceProcessor(True, layer_type, attention_count, blur_sigma=blur_sigma)
+
+             else:
+                 module.processor = AttnReplaceProcessor(False, layer_type, attention_count, blur_sigma=blur_sigma)

+ @spaces.GPU()
+ def apply_prompt(meta_data, new_prompt):

+     pipe, inverse_scheduler, scheduler = load_pipeline()
+
+     caption, real_image_initial_latents, inversed_latents, _ = meta_data
+     negative_prompt = ""
+
+     inference_steps = len(inversed_latents)
+
+     guidance_scale = guidance_scale_value
+     scheduler.set_timesteps(inference_steps, device="cuda")
+     timesteps = scheduler.timesteps
+
+     initial_latents = torch.cat([real_image_initial_latents] * 2)
+
+     def adjust_latent(pipe, step, timestep, callback_kwargs):
+         replace_attention_processor(pipe.unet)
+
+         with torch.no_grad():
+             callback_kwargs["latents"][1] = callback_kwargs["latents"][1] + (inversed_latents[len(timesteps) - 1 - step].detach() - callback_kwargs["latents"][0])
+             callback_kwargs["latents"][0] = inversed_latents[len(timesteps) - 1 - step].detach()
+
+         return callback_kwargs
+
+
+     with torch.no_grad():

          replace_attention_processor(pipe.unet)
+
+         pipe.scheduler = scheduler
+         latents = pipe(prompt=[caption, new_prompt],
+                        negative_prompt=[negative_prompt, negative_prompt],
+                        guidance_scale = guidance_scale,
+                        output_type="latent",
+                        return_dict=False,
+                        num_inference_steps=num_inference_steps,
+                        latents=initial_latents,
+                        callback_on_step_end=adjust_latent,
+                        callback_on_step_end_tensor_inputs=["latents"],)[0]
+
+     replace_attention_processor(pipe.unet, True)
+
+     image = pipe.vae.decode(latents[1].unsqueeze(0) / pipe.vae.config.scaling_factor, return_dict=False)[0]
+     image_np = image.squeeze(0).float().permute(1, 2, 0).detach().cpu()
+     image_np = (image_np / 2 + 0.5).clamp(0, 1).numpy()
+     image_np = (image_np * 255).astype(np.uint8)
+
+     return image_np
+
+
  def on_image_change(filepath):
+     # Extract the filename without extension
+     filename = os.path.splitext(os.path.basename(filepath))[0]
+
+     if filename in ["example1", "example3", "example4"]:
+
+         meta_data_raw = load_state_from_file(f"assets/{filename}-turbo.pkl")
+
          global weights
+         _, _, _, weights = meta_data_raw
+
          global num_inference_steps
          num_inference_steps = 10
+         scale_value = 7
+
+         if filename == "example1":
+             scale_value = 8
+             new_prompt = "a photo of a tree, summer, colourful"
+
+         elif filename == "example3":
+             scale_value = 6
+             new_prompt = "a realistic photo of a female warrior, flowing dark purple or black hair, bronze shoulder armour, leather chest piece, sky background with clouds"
+
+         elif filename == "example4":
+             scale_value = 13
+             new_prompt = "a photo of plastic bottle on some sand, beach background, sky background"
+
+         update_scale(scale_value)
+         img = apply_prompt(meta_data_raw, new_prompt)
+
+         return filepath, img, meta_data_raw, num_inference_steps, scale_value, scale_value
+
+
+ def update_value(value, layer_type, resolution, depth):
+     global weights
+     weights[layer_type][resolution][depth] = value
+
+
+ def update_step(value):
      global num_inference_steps
+     num_inference_steps = value
+
+ def adjust_ends(values, adjustment):
+     # Forward loop to adjust the first valid element from the left
+     for i in range(len(values)):
+         if (adjustment > 0 and values[i + 1] == 1.0) or (adjustment < 0 and values[i] > 0.0):
+             values[i] = values[i] + adjustment
+             break
+
+     # Backward loop to adjust the first valid element from the right
+     for i in range(len(values)-1, -1, -1):
+         if (adjustment > 0 and values[i - 1] == 1.0) or (adjustment < 0 and values[i] > 0.0):
+             values[i] = values[i] + adjustment
+             break
+
+     return values
+
+ max_scale_value = 16
+
+ def update_scale(scale):
+     global weights
+
+     value_count = 0
+
+     for outer_key, inner_dict in weights.items():
+         for inner_key, values in inner_dict.items():
+             for _, value in enumerate(values):
+                 value_count += 1
+
+     list_values = [1.0] * value_count
+
+     for _ in range(scale, max_scale_value):
+         adjust_ends(list_values, -0.5)
+
+     value_index = 0
+
+     for outer_key, inner_dict in weights.items():
+         for inner_key, values in inner_dict.items():
+             for idx, value in enumerate(values):
+
+                 weights[outer_key][inner_key][value] = list_values[value_index]
+                 value_index += 1
+
+
+ if __name__ == "__main__":
+
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--share", action="store_true", help="Enable sharing of the Gradio interface")
+     args = parser.parse_args()
+
+     num_inference_steps = 10
+
+     model_id = "stabilityai/stable-diffusion-xl-base-1.0"
+     vae_model_id = "madebyollin/sdxl-vae-fp16-fix"
+     vae_folder = ""
+     guidance_scale_value = 7.5
+     resadapter_model_name = "resadapter_v2_sdxl"
+     res_range_min = 256
+     res_range_max = 1536
+
+     torch_dtype = torch.float16

      with gr.Blocks(analytics_enabled=False) as demo:
          gr.Markdown(
+             """
+             <div style="text-align: center;">
+                 <div style="display: flex; justify-content: center;">
+                     <img src="https://github.com/user-attachments/assets/55a38e74-ab93-4d80-91c8-0fa6130af45a" alt="Logo">
+                 </div>
+                 <h1>Out of Focus v1.0 Turbo</h1>
+                 <p style="font-size:16px;">Out of AI presents a flexible tool to manipulate your images. This is our first version of Image modification tool through prompt manipulation by reconstruction through diffusion inversion process</p>
+             </div>
+             <br>
+             <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
+                 <a href="https://www.buymeacoffee.com/outofai" target="_blank"><img src="https://img.shields.io/badge/-buy_me_a%C2%A0coffee-red?logo=buy-me-a-coffee" alt="Buy Me A Coffee"></a> &ensp;
+                 <a href="https://twitter.com/OutofAi" target="_blank"><img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Out"></a>
+             </div>
+             """
          )
          with gr.Row():
              with gr.Column():
+
+                 with gr.Row():
+                     example_input = gr.Image(type="filepath", visible=False)
+                     image_input = gr.Image(type="pil", label="Upload Source Image")
+                 steps_slider = gr.Slider(minimum=5, maximum=50, step=5, value=num_inference_steps, label="Steps", info="Number of inference steps required to reconstruct and modify the image")
+                 prompt_input = gr.Textbox(label="Prompt", info="Give an initial prompt in details, describing the image")
+                 reconstruct_button = gr.Button("Reconstruct")
              with gr.Column():
+
+                 with gr.Row():
+                     reconstructed_image = gr.Image(type="pil", label="Reconstructed")
+                     invisible_slider = gr.Slider(minimum=0, maximum=9, step=1, value=7, visible=False)
+                 interpolate_slider = gr.Slider(minimum=0, maximum=max_scale_value, step=1, value=max_scale_value, label="Cross-Attention Influence", info="Scales the related influence the source image has on the target image")
+                 new_prompt_input = gr.Textbox(label="New Prompt", interactive=False, info="Manipulate the image by changing the prompt or adding words at the end; swap words instead of adding or removing them for better results")
+
+                 with gr.Row():
+                     apply_button = gr.Button("Generate Vision", variant="primary", interactive=False)
+
+         with gr.Row():
+             show_case = gr.Examples(
+                 examples=[
+                     ["assets/example4.png", "a photo of plastic bottle on a rock, mountain background, sky background", "a photo of plastic bottle on some sand, beach background, sky background", 13],
+                     ["assets/example1.png", "a photo of a tree, spring, foggy", "a photo of a tree, summer, colourful", 8],
+                     [
+                         "assets/example3.png",
+                         "a digital illustration of a female warrior, flowing dark purple or black hair, bronze shoulder armour, leather chest piece, sky background with clouds",
+                         "a realistic photo of a female warrior, flowing dark purple or black hair, bronze shoulder armour, leather chest piece, sky background with clouds",
+                         6,
+                     ],
+                 ],
+                 inputs=[example_input, prompt_input, new_prompt_input, interpolate_slider],
+                 label=None,
+             )
+
+         meta_data = gr.State()
+
+         example_input.change(fn=on_image_change, inputs=example_input, outputs=[image_input, reconstructed_image, meta_data, steps_slider, invisible_slider, interpolate_slider]).then(lambda: gr.update(interactive=True), outputs=apply_button).then(
+             lambda: gr.update(interactive=True), outputs=new_prompt_input
+         )
+         steps_slider.release(update_step, inputs=steps_slider)
+         interpolate_slider.release(update_scale, inputs=interpolate_slider)
+
+         value_trigger = True
+
+         def triggered():
+             global value_trigger
+             value_trigger = not value_trigger
+             return value_trigger
+
+         reconstruct_button.click(reconstruct, inputs=[image_input, prompt_input], outputs=[reconstructed_image, new_prompt_input, interpolate_slider, meta_data]).then(lambda: gr.update(interactive=True), outputs=reconstruct_button).then(lambda: gr.update(interactive=True), outputs=new_prompt_input).then(
+             lambda: gr.update(interactive=True), outputs=apply_button
+         )
+
+         reconstruct_button.click(lambda: gr.update(interactive=False), outputs=reconstruct_button)
+
+         reconstruct_button.click(lambda: gr.update(interactive=False), outputs=apply_button)
+
+         apply_button.click(apply_prompt, inputs=[meta_data, new_prompt_input], outputs=reconstructed_image)
+
          demo.queue()
+         demo.launch(share=args.share, inbrowser=True)
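
Note on the mechanism shown above: the new AttnReplaceProcessor.__call__ blends the self-attention queries and keys of the edited sample toward those of the source sample, using the per-layer weight stored in the global weights dict (which the "Cross-Attention Influence" slider adjusts through update_scale). A minimal, self-contained sketch of that interpolation idea, using hypothetical toy tensors rather than the Space's real batch layout, could look like this:

    import torch

    def blend_toward_source(src: torch.Tensor, tgt: torch.Tensor, weight: float) -> torch.Tensor:
        # weight = 1.0 copies the source attention features outright;
        # weight = 0.0 leaves the target's own features untouched.
        return weight * src + (1.0 - weight) * tgt

    # toy shapes: (heads, tokens, head_dim) for one source/target pair
    src_q = torch.randn(8, 64, 40)
    tgt_q = torch.randn(8, 64, 40)
    blended_q = blend_toward_source(src_q, tgt_q, weight=0.75)

In app.py itself the blend is applied in place on the chunked query and key tensors, for both the unconditional and conditional halves of the batch. To try the updated script locally, the __main__ block parses a --share flag, so running python app.py --share should queue and launch the Gradio demo, assuming the Space's dependencies are installed.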