alexnasa committed on
Commit e788852 · verified · 1 Parent(s): 9b0d447

Update app.py

Files changed (1)
  1. app.py +387 -410
app.py CHANGED
@@ -1,7 +1,6 @@
  import warnings
-
  warnings.filterwarnings("ignore")
- from diffusers import DiffusionPipeline, DDIMInverseScheduler, DDIMScheduler, AutoencoderKL
  import torch
  from typing import Optional
  from tqdm import tqdm
@@ -14,193 +13,205 @@ import gradio as gr
14
  import numpy as np
15
  import os
16
  import pickle
 
 
17
  import argparse
18
- from PIL import Image
19
- import requests
20
- import math
21
- import torch
22
- from safetensors.torch import load_file
23
- from huggingface_hub import hf_hub_download
24
- from diffusers import DiffusionPipeline
25
- import spaces
26
 
27
- @spaces.GPU()
28
  def save_state_to_file(state):
29
  filename = "state.pkl"
30
- with open(filename, "wb") as f:
31
- pickle.dump(state, f)
32
  return filename
33
 
34
- @spaces.GPU()
35
  def load_state_from_file(filename):
36
- with open(filename, "rb") as f:
37
- state = pickle.load(f)
38
- return state
39
 
40
- guidance_scale_value = 7.5
41
- num_inference_steps = 10
42
- weights = {}
43
- res_list = []
44
- foreground_mask = None
45
- heighest_resolution = -1
46
- signal_value = 2.0
47
- blur_value = None
48
- allowed_res_max = 1.0
49
-
50
- # Device configuration
51
- device = "cuda"
52
- print(f"Using device: {device}")
53
-
54
- @spaces.GPU()
55
- def weight_population(layer_type, resolution, depth, value):
56
- # Check if layer_type exists, if not, create it
57
- if layer_type not in weights:
58
- weights[layer_type] = {}
59
-
60
- # Check if resolution exists under layer_type, if not, create it
61
- if resolution not in weights[layer_type]:
62
- weights[layer_type][resolution] = {}
63
 
64
- global heighest_resolution
65
- if resolution > heighest_resolution:
66
- heighest_resolution = resolution
67
-
68
- # Add/Modify the value at the specified depth (which can be a string)
69
- weights[layer_type][resolution][depth] = value
70
 
71
- @spaces.GPU()
72
- def resize_image_with_aspect(image, res_range_min=128, res_range_max=1024):
73
- # Get the original width and height of the image
74
- width, height = image.size
75
-
76
- # Determine the scaling factor to maintain the aspect ratio
77
- scaling_factor = 1
78
- if width < res_range_min or height < res_range_min:
79
- scaling_factor = max(res_range_min / width, res_range_min / height)
80
- elif width > res_range_max or height > res_range_max:
81
- scaling_factor = min(res_range_max / width, res_range_max / height)
82
-
83
- # Calculate the new dimensions
84
- new_width = int(width * scaling_factor)
85
- new_height = int(height * scaling_factor)
86
 
87
- print(f'{new_width}-{new_height}')
88
-
89
- # Resize the image with the new dimensions while maintaining the aspect ratio
90
- resized_image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
91
-
92
- return resized_image
93
 
94
- @spaces.GPU()
95
- def reconstruct(input_img, caption):
96
- pipe, inverse_scheduler, scheduler = load_pipeline()
97
- global weights
98
- weights = {}
99
 
100
- prompt = caption
101
 
102
- img = input_img
103
 
104
- img = resize_image_with_aspect(img, res_range_min, res_range_max)
105
 
106
- transform = torchvision.transforms.Compose([
107
- torchvision.transforms.ToTensor()
108
- ])
109
 
110
- if torch_dtype == torch.float16:
111
- loaded_image = transform(img).half().to(device).unsqueeze(0)
112
- else:
113
- loaded_image = transform(img).to(device).unsqueeze(0)
114
 
115
- if loaded_image.shape[1] == 4:
116
- loaded_image = loaded_image[:,:3,:,:]
117
-
118
- with torch.no_grad():
119
- encoded_image = pipe.vae.encode(loaded_image*2 - 1)
120
- real_image_latents = pipe.vae.config.scaling_factor * encoded_image.latent_dist.sample()
121
 
122
 
123
- # notice we disabled the CFG here by setting guidance scale as 1
124
- guidance_scale = 1.0
125
- inverse_scheduler.set_timesteps(num_inference_steps, device=device)
126
- timesteps = inverse_scheduler.timesteps
127
 
128
- latents = real_image_latents
 
129
 
130
- inversed_latents = [latents]
131
 
132
- def store_latent(pipe, step, timestep, callback_kwargs):
133
- latents = callback_kwargs["latents"]
 
134
 
135
- with torch.no_grad():
136
- if step != num_inference_steps - 1:
137
- inversed_latents.append(latents)
138
 
139
- return callback_kwargs
 
 
 
140
 
141
- with torch.no_grad():
142
 
143
- replace_attention_processor(pipe.unet, True)
 
 
144
 
145
- pipe.scheduler = inverse_scheduler
146
- latents = pipe(prompt=prompt,
147
- guidance_scale = guidance_scale,
148
- output_type="latent",
149
- return_dict=False,
150
- num_inference_steps=num_inference_steps,
151
- latents=latents,
152
- callback_on_step_end=store_latent,
153
- callback_on_step_end_tensor_inputs=["latents"],)[0]
154
 
155
- # initial state
156
- real_image_initial_latents = latents
157
 
158
- guidance_scale = guidance_scale_value
159
- scheduler.set_timesteps(num_inference_steps, device=device)
160
- timesteps = scheduler.timesteps
161
 
162
- def adjust_latent(pipe, step, timestep, callback_kwargs):
 
163
 
164
- with torch.no_grad():
165
- callback_kwargs["latents"] = inversed_latents[len(timesteps) - 1 - step].detach()
166
 
167
- return callback_kwargs
168
-
169
- with torch.no_grad():
170
 
171
- replace_attention_processor(pipe.unet, True)
172
 
173
- intermediate_values = real_image_initial_latents.clone()
174
 
175
- pipe.scheduler = scheduler
176
- intermediate_values = pipe(prompt=prompt,
177
- guidance_scale = guidance_scale,
178
- output_type="latent",
179
- return_dict=False,
180
- num_inference_steps=num_inference_steps,
181
- latents=intermediate_values,
182
- callback_on_step_end=adjust_latent,
183
- callback_on_step_end_tensor_inputs=["latents"],)[0]
184
 
185
- image = pipe.vae.decode(intermediate_values / pipe.vae.config.scaling_factor, return_dict=False)[0]
186
- image_np = image.squeeze(0).float().permute(1, 2, 0).detach().cpu()
187
- image_np = (image_np / 2 + 0.5).clamp(0, 1).numpy()
188
- image_np = (image_np * 255).astype(np.uint8)
189
 
190
- update_scale(12)
191
 
192
- return image_np, caption, 12, [caption, real_image_initial_latents.detach(), inversed_latents, weights]
193
 
194
- @spaces.GPU()
195
  class AttnReplaceProcessor(AttnProcessor2_0):
196
 
197
- def __init__(self, replace_all, layer_type, layer_count, blur_sigma=None):
198
  super().__init__()
199
  self.replace_all = replace_all
200
- self.layer_type = layer_type
201
- self.layer_count = layer_count
202
- self.weight_populated = False
203
- self.blur_sigma = blur_sigma
204
 
205
  def __call__(
206
  self,
@@ -213,31 +224,20 @@ class AttnReplaceProcessor(AttnProcessor2_0):
213
  **kwargs,
214
  ) -> torch.FloatTensor:
215
 
216
-
217
- dimension_squared = hidden_states.shape[1]
218
 
219
  is_cross = not encoder_hidden_states is None
220
 
221
- residual = hidden_states
222
- if attn.spatial_norm is not None:
223
- hidden_states = attn.spatial_norm(hidden_states, temb)
224
-
225
  input_ndim = hidden_states.ndim
226
 
227
  if input_ndim == 4:
228
  batch_size, channel, height, width = hidden_states.shape
229
  hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
230
 
231
- batch_size, sequence_length, _ = (
232
  hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
233
  )
234
 
235
- if attention_mask is not None:
236
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
237
- # scaled_dot_product_attention expects attention_mask shape to be
238
- # (batch, heads, source_length, target_length)
239
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
240
-
241
  if attn.group_norm is not None:
242
  hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
243
 
@@ -251,43 +251,27 @@ class AttnReplaceProcessor(AttnProcessor2_0):
251
  key = attn.to_k(encoder_hidden_states)
252
  value = attn.to_v(encoder_hidden_states)
253
 
254
- inner_dim = key.shape[-1]
255
- head_dim = inner_dim // attn.heads
256
-
257
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
258
 
259
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
260
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
261
 
262
- height = width = math.isqrt(query.shape[2])
263
-
264
-
265
- if self.replace_all:
266
- weight_value = weights[self.layer_type][dimension_squared][self.layer_count]
267
-
268
- ucond_attn_scores, attn_scores = query.chunk(2)
269
- attn_scores[1].copy_(weight_value * attn_scores[0] + (1.0 - weight_value) * attn_scores[1])
270
- ucond_attn_scores[1].copy_(weight_value * ucond_attn_scores[0] + (1.0 - weight_value) * ucond_attn_scores[1])
271
-
272
-
273
- ucond_attn_scores, attn_scores = key.chunk(2)
274
- attn_scores[1].copy_(weight_value * attn_scores[0] + (1.0 - weight_value) * attn_scores[1])
275
- ucond_attn_scores[1].copy_(weight_value * ucond_attn_scores[0] + (1.0 - weight_value) * ucond_attn_scores[1])
276
- else:
277
- weight_population(self.layer_type, dimension_squared, self.layer_count, 1.0)
278
 
279
 
280
- hidden_states = F.scaled_dot_product_attention(
281
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False,
282
- )
283
 
284
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
285
- hidden_states = hidden_states.to(query.dtype)
 
286
 
287
- # linear proj
288
  hidden_states = attn.to_out[0](hidden_states)
289
- # dropout
290
- hidden_states = attn.to_out[1](hidden_states)
291
 
292
  if input_ndim == 4:
293
  hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
@@ -299,296 +283,289 @@ class AttnReplaceProcessor(AttnProcessor2_0):
299
 
300
  return hidden_states
301
 
302
- @spaces.GPU()
303
- def replace_attention_processor(unet, clear=False, blur_sigma=None):
304
- attention_count = 0
305
 
306
 
307
- for name, module in unet.named_modules():
308
- if "attn1" in name and "to" not in name:
309
- layer_type = name.split(".")[0].split("_")[0]
310
- attention_count += 1
311
-
312
- if not clear:
313
- if layer_type == "down":
314
- module.processor = AttnReplaceProcessor(True, layer_type, attention_count, blur_sigma=blur_sigma)
315
- elif layer_type == "mid":
316
- module.processor = AttnReplaceProcessor(True, layer_type, attention_count, blur_sigma=blur_sigma)
317
- elif layer_type == "up":
318
- module.processor = AttnReplaceProcessor(True, layer_type, attention_count, blur_sigma=blur_sigma)
319
-
320
- else:
321
- module.processor = AttnReplaceProcessor(False, layer_type, attention_count, blur_sigma=blur_sigma)
322
-
323
- @spaces.GPU()
324
  def apply_prompt(meta_data, new_prompt):
325
 
326
- caption, real_image_initial_latents, inversed_latents, _ = meta_data
327
- negative_prompt = ""
328
 
329
- inference_steps = len(inversed_latents)
330
 
331
- guidance_scale = guidance_scale_value
332
- scheduler.set_timesteps(inference_steps, device=device)
333
- timesteps = scheduler.timesteps
334
 
335
- initial_latents = torch.cat([real_image_initial_latents] * 2)
336
 
337
- def adjust_latent(pipe, step, timestep, callback_kwargs):
338
- replace_attention_processor(pipe.unet)
339
 
340
- with torch.no_grad():
341
- callback_kwargs["latents"][1] = callback_kwargs["latents"][1] + (inversed_latents[len(timesteps) - 1 - step].detach() - callback_kwargs["latents"][0])
342
- callback_kwargs["latents"][0] = inversed_latents[len(timesteps) - 1 - step].detach()
 
 
 
 
343
 
344
- return callback_kwargs
345
 
346
 
347
- with torch.no_grad():
348
-
349
- replace_attention_processor(pipe.unet)
350
-
351
- pipe.scheduler = scheduler
352
- latents = pipe(prompt=[caption, new_prompt],
353
- negative_prompt=[negative_prompt, negative_prompt],
354
- guidance_scale = guidance_scale,
355
- output_type="latent",
356
- return_dict=False,
357
- num_inference_steps=num_inference_steps,
358
- latents=initial_latents,
359
- callback_on_step_end=adjust_latent,
360
- callback_on_step_end_tensor_inputs=["latents"],)[0]
361
-
362
- replace_attention_processor(pipe.unet, True)
363
-
364
- image = pipe.vae.decode(latents[1].unsqueeze(0) / pipe.vae.config.scaling_factor, return_dict=False)[0]
365
- image_np = image.squeeze(0).float().permute(1, 2, 0).detach().cpu()
366
- image_np = (image_np / 2 + 0.5).clamp(0, 1).numpy()
367
- image_np = (image_np * 255).astype(np.uint8)
368
-
369
- return image_np
370
-
371
- @spaces.GPU()
372
- def on_image_change(filepath):
373
- # Extract the filename without extension
374
- filename = os.path.splitext(os.path.basename(filepath))[0]
375
 
376
- if filename in ["example1", "example3", "example4"]:
377
 
378
- meta_data_raw = load_state_from_file(f"assets/{filename}-turbo.pkl")
379
 
380
- global weights
381
- _, _, _, weights = meta_data_raw
382
 
383
  global num_inference_steps
384
- num_inference_steps = 10
385
  scale_value = 7
 
386
 
387
  if filename == "example1":
388
- scale_value = 8
389
  new_prompt = "a photo of a tree, summer, colourful"
390
 
391
  elif filename == "example3":
392
- scale_value = 6
393
  new_prompt = "a realistic photo of a female warrior, flowing dark purple or black hair, bronze shoulder armour, leather chest piece, sky background with clouds"
394
-
395
  elif filename == "example4":
396
- scale_value = 13
397
  new_prompt = "a photo of plastic bottle on some sand, beach background, sky background"
398
 
399
  update_scale(scale_value)
400
  img = apply_prompt(meta_data_raw, new_prompt)
401
-
402
  return filepath, img, meta_data_raw, num_inference_steps, scale_value, scale_value
403
 
404
- @spaces.GPU()
405
- def update_value(value, layer_type, resolution, depth):
406
  global weights
407
- weights[layer_type][resolution][depth] = value
408
-
409
 
410
  def update_step(value):
411
  global num_inference_steps
412
  num_inference_steps = value
413
 
414
- def adjust_ends(values, adjustment):
415
- # Forward loop to adjust the first valid element from the left
416
- for i in range(len(values)):
417
- if (adjustment > 0 and values[i + 1] == 1.0) or (adjustment < 0 and values[i] > 0.0):
418
- values[i] = values[i] + adjustment
419
- break
420
 
421
- # Backward loop to adjust the first valid element from the right
422
- for i in range(len(values)-1, -1, -1):
423
- if (adjustment > 0 and values[i - 1] == 1.0) or (adjustment < 0 and values[i] > 0.0):
424
- values[i] = values[i] + adjustment
425
  break
426
 
427
- return values
428
-
429
- max_scale_value = 16
430
-
431
- @spaces.GPU()
432
- def update_scale(scale):
433
  global weights
434
-
435
- value_count = 0
436
 
437
  for outer_key, inner_dict in weights.items():
438
- for inner_key, values in inner_dict.items():
439
- for _, value in enumerate(values):
440
- value_count += 1
441
-
442
- list_values = [1.0] * value_count
443
-
444
- for _ in range(scale, max_scale_value):
445
- adjust_ends(list_values, -0.5)
446
-
447
- value_index = 0
448
-
449
- for outer_key, inner_dict in weights.items():
450
- for inner_key, values in inner_dict.items():
451
- for idx, value in enumerate(values):
452
-
453
- weights[outer_key][inner_key][value] = list_values[value_index]
454
- value_index += 1
455
-
456
-
457
- @spaces.GPU()
458
- def load_pipeline():
459
- model_id = "runwayml/stable-diffusion-v1-5"
460
- vae_model_id = "runwayml/stable-diffusion-v1-5"
461
- vae_folder = "vae"
462
- guidance_scale_value = 7.5
463
- resadapter_model_name = "resadapter_v2_sd1.5"
464
- res_range_min = 128
465
- res_range_max = 1024
466
-
467
-
468
- torch_dtype = torch.float16
469
-
470
- # torch_dtype = torch.float16
471
- pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype).to(device)
472
- pipe.vae = AutoencoderKL.from_pretrained(vae_model_id, subfolder=vae_folder, torch_dtype=torch_dtype).to(device)
473
- pipe.load_lora_weights(
474
- hf_hub_download(repo_id="jiaxiangc/res-adapter", subfolder=resadapter_model_name, filename="pytorch_lora_weights.safetensors"),
475
- adapter_name="res_adapter",
476
- ) # load lora weights
477
- pipe.set_adapters(["res_adapter"], adapter_weights=[1.0])
478
- pipe.unet.load_state_dict(
479
- load_file(hf_hub_download(repo_id="jiaxiangc/res-adapter", subfolder=resadapter_model_name, filename="diffusion_pytorch_model.safetensors")),
480
- strict=False,
481
- ) # load norm weights
482
-
483
- inverse_scheduler = DDIMInverseScheduler.from_pretrained(model_id, subfolder="scheduler")
484
- scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
485
-
486
- return pipe, inverse_scheduler, scheduler
487
 
488
- if __name__ == "__main__":
489
-
490
- parser = argparse.ArgumentParser()
491
- parser.add_argument("--share", action="store_true", help="Enable sharing of the Gradio interface")
492
- args = parser.parse_args()
493
-
494
- num_inference_steps = 10
495
-
496
- # model_id = "stabilityai/stable-diffusion-xl-base-1.0"
497
- # vae_model_id = "madebyollin/sdxl-vae-fp16-fix"
498
- # vae_folder = ""
499
- # guidance_scale_value = 7.5
500
- # resadapter_model_name = "resadapter_v2_sdxl"
501
- # res_range_min = 256
502
- # res_range_max = 1536
503
- model_id = "runwayml/stable-diffusion-v1-5"
504
- vae_model_id = "runwayml/stable-diffusion-v1-5"
505
- vae_folder = "vae"
506
- guidance_scale_value = 7.5
507
- resadapter_model_name = "resadapter_v2_sd1.5"
508
- res_range_min = 128
509
- res_range_max = 1024
510
-
511
-
512
- torch_dtype = torch.float16
513
-
514
 
515
- with gr.Blocks(analytics_enabled=False) as demo:
516
  gr.Markdown(
517
- """
518
  <div style="text-align: center;">
519
  <div style="display: flex; justify-content: center;">
520
  <img src="https://github.com/user-attachments/assets/55a38e74-ab93-4d80-91c8-0fa6130af45a" alt="Logo">
521
  </div>
522
- <h1>Out of Focus v1.0 Turbo</h1>
523
  <p style="font-size:16px;">Out of AI presents a flexible tool to manipulate your images. This is our first version of Image modification tool through prompt manipulation by reconstruction through diffusion inversion process</p>
524
  </div>
525
  <br>
526
  <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
527
  <a href="https://www.buymeacoffee.com/outofai" target="_blank"><img src="https://img.shields.io/badge/-buy_me_a%C2%A0coffee-red?logo=buy-me-a-coffee" alt="Buy Me A Coffee"></a> &ensp;
528
- <a href="https://twitter.com/OutofAi" target="_blank"><img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Out"></a>
 
529
  </div>
530
- """
531
- )
532
  with gr.Row():
533
- with gr.Column():
534
-
535
- with gr.Row():
536
- example_input = gr.Image(type="filepath", visible=False)
537
- image_input = gr.Image(type="pil", label="Upload Source Image")
538
- steps_slider = gr.Slider(minimum=5, maximum=50, step=5, value=num_inference_steps, label="Steps", info="Number of inference steps required to reconstruct and modify the image")
539
- prompt_input = gr.Textbox(label="Prompt", info="Give an initial prompt in details, describing the image")
540
- reconstruct_button = gr.Button("Reconstruct")
541
- with gr.Column():
542
-
543
- with gr.Row():
544
- reconstructed_image = gr.Image(type="pil", label="Reconstructed")
545
- invisible_slider = gr.Slider(minimum=0, maximum=9, step=1, value=7, visible=False)
546
- interpolate_slider = gr.Slider(minimum=0, maximum=max_scale_value, step=1, value=max_scale_value, label="Cross-Attention Influence", info="Scales the related influence the source image has on the target image")
547
- new_prompt_input = gr.Textbox(label="New Prompt", interactive=False, info="Manipulate the image by changing the prompt or adding words at the end; swap words instead of adding or removing them for better results")
548
-
549
- with gr.Row():
550
- apply_button = gr.Button("Generate Vision", variant="primary", interactive=False)
551
-
552
- with gr.Row():
553
- show_case = gr.Examples(
554
- examples=[
555
- ["assets/example4.png", "a photo of plastic bottle on a rock, mountain background, sky background", "a photo of plastic bottle on some sand, beach background, sky background", 13],
556
- ["assets/example1.png", "a photo of a tree, spring, foggy", "a photo of a tree, summer, colourful", 8],
557
- [
558
- "assets/example3.png",
559
- "a digital illustration of a female warrior, flowing dark purple or black hair, bronze shoulder armour, leather chest piece, sky background with clouds",
560
- "a realistic photo of a female warrior, flowing dark purple or black hair, bronze shoulder armour, leather chest piece, sky background with clouds",
561
- 6 ,
562
- ],
563
- ],
564
- inputs=[example_input, prompt_input, new_prompt_input, interpolate_slider],
565
- label=None,
566
- )
567
 
568
  meta_data = gr.State()
569
 
570
- example_input.change(fn=on_image_change, inputs=example_input, outputs=[image_input, reconstructed_image, meta_data, steps_slider, invisible_slider, interpolate_slider]).then(lambda: gr.update(interactive=True), outputs=apply_button).then(
571
- lambda: gr.update(interactive=True), outputs=new_prompt_input
572
  )
573
  steps_slider.release(update_step, inputs=steps_slider)
574
- interpolate_slider.release(update_scale, inputs=interpolate_slider)
575
-
576
- value_trigger = True
577
 
578
- def triggered():
579
- global value_trigger
580
- value_trigger = not value_trigger
581
- return value_trigger
582
 
583
- reconstruct_button.click(reconstruct, inputs=[image_input, prompt_input], outputs=[reconstructed_image, new_prompt_input, interpolate_slider, meta_data]).then(lambda: gr.update(interactive=True), outputs=reconstruct_button).then(lambda: gr.update(interactive=True), outputs=new_prompt_input).then(
584
- lambda: gr.update(interactive=True), outputs=apply_button
 
585
  )
586
 
587
- reconstruct_button.click(lambda: gr.update(interactive=False), outputs=reconstruct_button)
 
 
 
588
 
589
- reconstruct_button.click(lambda: gr.update(interactive=False), outputs=apply_button)
 
 
 
590
 
591
  apply_button.click(apply_prompt, inputs=[meta_data, new_prompt_input], outputs=reconstructed_image)
 
592
593
  demo.queue()
594
- demo.launch(share=args.share, inbrowser=True)
 
  import warnings
  warnings.filterwarnings("ignore")
+ from diffusers import StableDiffusionPipeline, DDIMInverseScheduler, DDIMScheduler
  import torch
  from typing import Optional
  from tqdm import tqdm

  import numpy as np
  import os
  import pickle
+ from transformers import CLIPImageProcessor
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
  import argparse

+ weights = {
+     'down': {
+         4096: 0.0,
+         1024: 1.0,
+         256: 1.0,
+     },
+     'mid': {
+         64: 1.0,
+     },
+     'up': {
+         256: 1.0,
+         1024: 1.0,
+         4096: 0.0,
+     }
+ }
+ num_inference_steps = 10
+ model_id = "stabilityai/stable-diffusion-2-1-base"
+
+ pipe = StableDiffusionPipeline.from_pretrained(model_id).to("cuda")
+ inverse_scheduler = DDIMInverseScheduler.from_pretrained(model_id, subfolder="scheduler")
+ scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
+
+ safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker").to("cuda")
+ feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+ should_stop = False
+
  def save_state_to_file(state):
      filename = "state.pkl"
+     with open(filename, 'wb') as f:
+         pickle.dump(state, f)
      return filename

  def load_state_from_file(filename):
+     with open(filename, 'rb') as f:
+         state = pickle.load(f)
+     return state

+ def stop_reconstruct():
+     global should_stop
+     should_stop = True
61
 
62
+ def reconstruct(input_img, caption):
63
 
64
+ img = input_img
65
 
66
+ cond_prompt_embeds = pipe.encode_prompt(prompt=caption, device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]
67
+ uncond_prompt_embeds = pipe.encode_prompt(prompt="", device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]
68
 
69
+ prompt_embeds_combined = torch.cat([uncond_prompt_embeds, cond_prompt_embeds])
70
+
71
+
72
+ transform = torchvision.transforms.Compose([
73
+ torchvision.transforms.Resize((512, 512)),
74
+ torchvision.transforms.ToTensor()
75
+ ])
76
+
77
+ loaded_image = transform(img).to("cuda").unsqueeze(0)
78
+
79
+ if loaded_image.shape[1] == 4:
80
+ loaded_image = loaded_image[:,:3,:,:]
81
+
82
+ with torch.no_grad():
83
+ encoded_image = pipe.vae.encode(loaded_image*2 - 1)
84
+ real_image_latents = pipe.vae.config.scaling_factor * encoded_image.latent_dist.sample()
85
+
86
+ guidance_scale = 1
87
+ inverse_scheduler.set_timesteps(num_inference_steps, device="cuda")
88
+ timesteps = inverse_scheduler.timesteps
89
+
90
+ latents = real_image_latents
91
+
92
+ inversed_latents = []
93
+
94
+ with torch.no_grad():
95
+
96
+ replace_attention_processor(pipe.unet, True)
97
+
98
+ for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc="Inference steps"):
99
 
100
+ inversed_latents.append(latents)
101
 
102
+ latent_model_input = torch.cat([latents] * 2)
103
 
104
+ noise_pred = pipe.unet(
105
+ latent_model_input,
106
+ t,
107
+ encoder_hidden_states=prompt_embeds_combined,
108
+ cross_attention_kwargs=None,
109
+ return_dict=False,
110
+ )[0]
111
112
 
113
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
114
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
 
115
 
116
+ latents = inverse_scheduler.step(noise_pred, t, latents, return_dict=False)[0]
117
 
118
 
119
+ # initial state
120
+ real_image_initial_latents = latents
 
 
121
 
122
+ W_values = uncond_prompt_embeds.repeat(num_inference_steps, 1, 1)
123
+ QT = nn.Parameter(W_values.clone())
124
 
 
125
 
126
+ guidance_scale = 7.5
127
+ scheduler.set_timesteps(num_inference_steps, device="cuda")
128
+ timesteps = scheduler.timesteps
129
 
130
+ optimizer = torch.optim.AdamW([QT], lr=0.008)
 
 
131
 
132
+ pipe.vae.eval()
133
+ pipe.vae.requires_grad_(False)
134
+ pipe.unet.eval()
135
+ pipe.unet.requires_grad_(False)
136
 
137
+ last_loss = 1
138
 
139
+ for epoch in range(50):
140
+ gc.collect()
141
+ torch.cuda.empty_cache()
142
 
143
+ if last_loss < 0.02:
144
+ break
145
+ elif last_loss < 0.03:
146
+ for param_group in optimizer.param_groups:
147
+ param_group['lr'] = 0.003
148
+ elif last_loss < 0.035:
149
+ for param_group in optimizer.param_groups:
150
+ param_group['lr'] = 0.006
 
151
 
152
+ intermediate_values = real_image_initial_latents.clone()
 
153
 
 
 
 
154
 
155
+ for i in range(num_inference_steps):
156
+ latents = intermediate_values.detach().clone()
157
 
158
+ t = timesteps[i]
 
159
 
160
+ prompt_embeds = torch.cat([QT[i].unsqueeze(0), cond_prompt_embeds.detach()])
 
 
161
 
162
+ latent_model_input = torch.cat([latents] * 2)
163
 
164
+ noise_pred_model = pipe.unet(
165
+ latent_model_input,
166
+ t,
167
+ encoder_hidden_states=prompt_embeds,
168
+ cross_attention_kwargs=None,
169
+ return_dict=False,
170
+ )[0]
171
 
172
+ noise_pred_uncond, noise_pred_text = noise_pred_model.chunk(2)
173
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
174
 
175
+ intermediate_values = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
 
 
 
176
 
 
177
 
178
+ loss = F.mse_loss(inversed_latents[len(timesteps) - 1 - i].detach(), intermediate_values, reduction="mean")
179
+ last_loss = loss
180
+
181
+ optimizer.zero_grad()
182
+ loss.backward()
183
+ optimizer.step()
184
+
185
+ global should_stop
186
+ if should_stop:
187
+ should_stop = False
188
+ break
189
+
190
+ image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
191
+ image = (image / 2.0 + 0.5).clamp(0.0, 1.0)
192
+ safety_checker_input = feature_extractor(image, return_tensors="pt", do_rescale=False).to("cuda")
193
+ image = safety_checker(images=[image], clip_input=safety_checker_input.pixel_values.to("cuda"))[0]
194
+ image_np = image[0].squeeze(0).float().permute(1,2,0).detach().cpu().numpy()
195
+ image_np = (image_np * 255).astype(np.uint8)
196
+
197
+ yield image_np, caption, [caption, real_image_initial_latents, QT]
198
+
199
+ image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
200
+ image = (image / 2.0 + 0.5).clamp(0.0, 1.0)
201
+ safety_checker_input = feature_extractor(image, return_tensors="pt", do_rescale=False).to("cuda")
202
+ image = safety_checker(images=[image], clip_input=safety_checker_input.pixel_values.to("cuda"))[0]
203
+ image_np = image[0].squeeze(0).float().permute(1,2,0).detach().cpu().numpy()
204
+ image_np = (image_np * 255).astype(np.uint8)
205
+
206
+ yield image_np, caption, [caption, real_image_initial_latents, QT]
207
+
208
 
 
209
  class AttnReplaceProcessor(AttnProcessor2_0):
210
 
211
+ def __init__(self, replace_all, weight):
212
  super().__init__()
213
  self.replace_all = replace_all
214
+ self.weight = weight
 
 
 
215
 
216
  def __call__(
217
  self,
 
224
  **kwargs,
225
  ) -> torch.FloatTensor:
226
 
227
+ residual = hidden_states
 
228
 
229
  is_cross = not encoder_hidden_states is None
230
 
231
  input_ndim = hidden_states.ndim
232
 
233
  if input_ndim == 4:
234
  batch_size, channel, height, width = hidden_states.shape
235
  hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
236
 
237
+ batch_size, _, _ = (
238
  hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
239
  )
240
 
241
  if attn.group_norm is not None:
242
  hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
243
 
 
251
  key = attn.to_k(encoder_hidden_states)
252
  value = attn.to_v(encoder_hidden_states)
253
 
254
+         query = attn.head_to_batch_dim(query)
+         key = attn.head_to_batch_dim(key)
+         value = attn.head_to_batch_dim(value)
+
+         attention_scores = attn.scale * torch.bmm(query, key.transpose(-1, -2))
+
+         dimension_squared = hidden_states.shape[1]
+
+         if not is_cross and (self.replace_all):
+             ucond_attn_scores_src, ucond_attn_scores_dst, attn_scores_src, attn_scores_dst = attention_scores.chunk(4)
+             attn_scores_dst.copy_(self.weight[dimension_squared] * attn_scores_src + (1.0 - self.weight[dimension_squared]) * attn_scores_dst)
+             ucond_attn_scores_dst.copy_(self.weight[dimension_squared] * ucond_attn_scores_src + (1.0 - self.weight[dimension_squared]) * ucond_attn_scores_dst)
+
+         attention_probs = attention_scores.softmax(dim=-1)
+         del attention_scores
+
+         hidden_states = torch.bmm(attention_probs, value)
+         hidden_states = attn.batch_to_head_dim(hidden_states)
+         del attention_probs
+
          hidden_states = attn.to_out[0](hidden_states)

          if input_ndim == 4:
              hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

          return hidden_states

+ def replace_attention_processor(unet, clear = False):
+
+     for name, module in unet.named_modules():
+         if 'attn1' in name and 'to' not in name:
+             layer_type = name.split('.')[0].split('_')[0]
+
+             if not clear:
+                 if layer_type == 'down':
+                     module.processor = AttnReplaceProcessor(True, weights['down'])
+                 elif layer_type == 'mid':
+                     module.processor = AttnReplaceProcessor(True, weights['mid'])
+                 elif layer_type == 'up':
+                     module.processor = AttnReplaceProcessor(True, weights['up'])
+             else:
+                 module.processor = AttnReplaceProcessor(False, 0.0)

302
  def apply_prompt(meta_data, new_prompt):
303
 
304
+ caption, real_image_initial_latents, QT = meta_data
305
+
306
+ inference_steps = len(QT)
307
+
308
+ cond_prompt_embeds = pipe.encode_prompt(prompt=caption, device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]
309
+ # uncond_prompt_embeds = pipe.encode_prompt(prompt=caption, device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]
310
+ new_prompt_embeds = pipe.encode_prompt(prompt=new_prompt, device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]
311
+
312
+ guidance_scale = 7.5
313
+ scheduler.set_timesteps(inference_steps, device="cuda")
314
+ timesteps = scheduler.timesteps
315
 
316
+ latents = torch.cat([real_image_initial_latents] * 2)
317
 
318
+ with torch.no_grad():
319
+ replace_attention_processor(pipe.unet)
 
320
 
321
+ for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc="Inference steps"):
322
 
323
+ modified_prompt_embeds = torch.cat([QT[i].unsqueeze(0), QT[i].unsqueeze(0), cond_prompt_embeds, new_prompt_embeds])
324
+ latent_model_input = torch.cat([latents] * 2)
325
 
326
+ noise_pred = pipe.unet(
327
+ latent_model_input,
328
+ t,
329
+ encoder_hidden_states=modified_prompt_embeds,
330
+ cross_attention_kwargs=None,
331
+ return_dict=False,
332
+ )[0]
333
 
 
334
 
335
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
336
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
337
+
338
+ latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
339
+
340
+ replace_attention_processor(pipe.unet, True)
341
+
342
+ image = pipe.vae.decode(latents[1].unsqueeze(0) / pipe.vae.config.scaling_factor, return_dict=False)[0]
343
+ image = (image / 2.0 + 0.5).clamp(0.0, 1.0)
344
+ safety_checker_input = feature_extractor(image, return_tensors="pt", do_rescale=False).to("cuda")
345
+ image = safety_checker(images=[image], clip_input=safety_checker_input.pixel_values.to("cuda"))[0]
346
+ image_np = image[0].squeeze(0).float().permute(1,2,0).detach().cpu().numpy()
347
+ image_np = (image_np * 255).astype(np.uint8)
348
 
349
+ return image_np
350
 
 
351
 
 
352
 
353
+ def on_image_change(filepath):
354
+ # Extract the filename without extension
355
+ filename = os.path.splitext(os.path.basename(filepath))[0]
356
+
357
+ # Check if the filename is "example1" or "example2"
358
+ if filename in ["example1", "example2", "example3", "example4"]:
359
+ meta_data_raw = load_state_from_file(f"assets/{filename}.pkl")
360
+ _, _, QT_raw = meta_data_raw
361
 
362
  global num_inference_steps
363
+ num_inference_steps = len(QT_raw)
364
  scale_value = 7
365
+ new_prompt = ""
366
 
367
  if filename == "example1":
368
+ scale_value = 7
369
  new_prompt = "a photo of a tree, summer, colourful"
370
+
371
+ elif filename == "example2":
372
+ scale_value = 8
373
+ new_prompt = "a photo of a panda, two ears, white background"
374
 
375
  elif filename == "example3":
376
+ scale_value = 7
377
  new_prompt = "a realistic photo of a female warrior, flowing dark purple or black hair, bronze shoulder armour, leather chest piece, sky background with clouds"
378
+
379
  elif filename == "example4":
380
+ scale_value = 7
381
  new_prompt = "a photo of plastic bottle on some sand, beach background, sky background"
382
 
383
  update_scale(scale_value)
384
  img = apply_prompt(meta_data_raw, new_prompt)
385
+
386
  return filepath, img, meta_data_raw, num_inference_steps, scale_value, scale_value
387
 
388
+ def update_value(value, key, res):
 
389
  global weights
390
+ weights[key][res] = value
 
391
 
392
  def update_step(value):
393
  global num_inference_steps
394
  num_inference_steps = value
395
 
396
+ def update_scale(scale):
+     values = [1.0] * 7

+     if scale == 9:
+         return values
+
+     reduction_steps = (9 - scale) * 0.5
+
+     for i in range(4): # There are 4 positions to reduce symmetrically
+         if reduction_steps >= 1:
+             values[i] = 0.0
+             values[-(i + 1)] = 0.0
+             reduction_steps -= 1
+         elif reduction_steps > 0:
+             values[i] = 0.5
+             values[-(i + 1)] = 0.5
              break

      global weights
+     index = 0

      for outer_key, inner_dict in weights.items():
+         for inner_key in inner_dict:
+             inner_dict[inner_key] = values[index]
+             index += 1

+     return weights['down'][4096], weights['down'][1024], weights['down'][256], weights['mid'][64], weights['up'][256], weights['up'][1024], weights['up'][4096]
+
424
 
425
+ with gr.Blocks() as demo:
426
  gr.Markdown(
427
+ '''
428
  <div style="text-align: center;">
429
  <div style="display: flex; justify-content: center;">
430
  <img src="https://github.com/user-attachments/assets/55a38e74-ab93-4d80-91c8-0fa6130af45a" alt="Logo">
431
  </div>
432
+ <h1>Out of Focus 1.0</h1>
433
  <p style="font-size:16px;">Out of AI presents a flexible tool to manipulate your images. This is our first version of Image modification tool through prompt manipulation by reconstruction through diffusion inversion process</p>
434
  </div>
435
  <br>
436
  <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
437
  <a href="https://www.buymeacoffee.com/outofai" target="_blank"><img src="https://img.shields.io/badge/-buy_me_a%C2%A0coffee-red?logo=buy-me-a-coffee" alt="Buy Me A Coffee"></a> &ensp;
438
+ <a href="https://twitter.com/OutofAi" target="_blank"><img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Ashleigh%20Watson"></a> &ensp;
439
+ <a href="https://twitter.com/banterless_ai" target="_blank"><img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Alex%20Nasa"></a>
440
  </div>
441
+ <br>
442
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
443
+ <p style="display: flex;gap: 6px;">
444
+ <a href="https://huggingface.co/spaces/fffiloni/OutofFocus?duplicate=true">
445
+ <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-md.svg" alt="Duplicate this Space">
446
+ </a> to skip the queue and enjoy faster inference on the GPU of your choice
447
+ </p>
448
+ </div>
449
+ '''
450
+ )
451
  with gr.Row():
452
+ with gr.Column():
453
+
454
+ with gr.Row():
455
+ example_input = gr.Image(height=512, width=512, type="filepath", visible=False)
456
+ image_input = gr.Image(height=512, width=512, type="pil", label="Upload Source Image")
457
+ steps_slider = gr.Slider(minimum=5, maximum=25, step=5, value=num_inference_steps, label="Steps", info="Number of inference steps required to reconstruct and modify the image")
458
+ prompt_input = gr.Textbox(label="Prompt", info="Give an initial prompt in details, describing the image")
459
+ reconstruct_button = gr.Button("Reconstruct")
460
+ stop_button = gr.Button("Stop", variant="stop", interactive=False)
461
+ with gr.Column():
462
+ reconstructed_image = gr.Image(type="pil", label="Reconstructed")
463
+
464
+ with gr.Row():
465
+ invisible_slider = gr.Slider(minimum=0, maximum=9, step=1, value=7, visible=False)
466
+ interpolate_slider = gr.Slider(minimum=0, maximum=9, step=1, value=7, label="Cross-Attention Influence", info="Scales the related influence the source image has on the target image")
467
+ with gr.Row():
468
+ new_prompt_input = gr.Textbox(label="New Prompt", interactive=False, info="Manipulate the image by changing the prompt or word addition at the end, achieve the best results by swapping words instead of adding or removing in between")
469
+ with gr.Row():
470
+ apply_button = gr.Button("Generate Vision", variant="primary", interactive=False)
471
+ with gr.Row():
472
+ with gr.Accordion(label="Advanced Options", open=False):
473
+ gr.Markdown(
474
+ '''
475
+ <div style="text-align: center;">
476
+ <h1>Weight Adjustment</h1>
477
+ <p style="font-size:16px;">Specific Cross-Attention Influence weights can be manually modified for given resolutions (1.0 = Fully Source Attn 0.0 = Fully Target Attn)</p>
478
+ </div>
479
+ '''
480
+ )
481
+ down_slider_4096 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['down'][4096], label="Self-Attn Down 64x64")
482
+ down_slider_1024 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['down'][1024], label="Self-Attn Down 32x32")
483
+ down_slider_256 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['down'][256], label="Self-Attn Down 16x16")
484
+ mid_slider_64 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['mid'][64], label="Self-Attn Mid 8x8")
485
+ up_slider_256 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['up'][256], label="Self-Attn Up 16x16")
486
+ up_slider_1024 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['up'][1024], label="Self-Attn Up 32x32")
487
+ up_slider_4096 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['up'][4096], label="Self-Attn Up 64x64")
488
+
489
+ with gr.Row():
490
+ show_case = gr.Examples(
491
+ examples=[
492
+ ["assets/example4.png", "a photo of plastic bottle on a rock, mountain background, sky background", "a photo of plastic bottle on some sand, beach background, sky background"],
493
+ ["assets/example1.png", "a photo of a tree, spring, foggy", "a photo of a tree, summer, colourful"],
494
+ ["assets/example2.png", "a photo of a cat, two ears, white background", "a photo of a panda, two ears, white background"],
495
+ ["assets/example3.png", "a digital illustration of a female warrior, flowing dark purple or black hair, bronze shoulder armour, leather chest piece, sky background with clouds", "a realistic photo of a female warrior, flowing dark purple or black hair, bronze shoulder armour, leather chest piece, sky background with clouds"],
496
+
497
+ ],
498
+ inputs=[example_input, prompt_input, new_prompt_input],
499
+ label=None
500
+ )
501
 
502
  meta_data = gr.State()
503
 
504
+ example_input.change(
505
+ fn=on_image_change,
506
+ inputs=example_input,
507
+ outputs=[image_input, reconstructed_image, meta_data, steps_slider, invisible_slider, interpolate_slider]
508
+ ).then(
509
+ lambda: gr.update(interactive=True),
510
+ outputs=apply_button
511
+ ).then(
512
+ lambda: gr.update(interactive=True),
513
+ outputs=new_prompt_input
514
  )
515
  steps_slider.release(update_step, inputs=steps_slider)
516
+ interpolate_slider.release(update_scale, inputs=interpolate_slider, outputs=[down_slider_4096, down_slider_1024, down_slider_256, mid_slider_64, up_slider_256, up_slider_1024, up_slider_4096 ])
517
+ invisible_slider.change(update_scale, inputs=invisible_slider, outputs=[down_slider_4096, down_slider_1024, down_slider_256, mid_slider_64, up_slider_256, up_slider_1024, up_slider_4096 ])
518
+
519
+ up_slider_4096.change(update_value, inputs=[up_slider_4096, gr.State('up'), gr.State(4096)])
520
+ up_slider_1024.change(update_value, inputs=[up_slider_1024, gr.State('up'), gr.State(1024)])
521
+ up_slider_256.change(update_value, inputs=[up_slider_256, gr.State('up'), gr.State(256)])
522
+
523
+ down_slider_4096.change(update_value, inputs=[down_slider_4096, gr.State('down'), gr.State(4096)])
524
+ down_slider_1024.change(update_value, inputs=[down_slider_1024, gr.State('down'), gr.State(1024)])
525
+ down_slider_256.change(update_value, inputs=[down_slider_256, gr.State('down'), gr.State(256)])
526
+
527
+ mid_slider_64.change(update_value, inputs=[mid_slider_64, gr.State('mid'), gr.State(64)])
528
+
529
+ reconstruct_button.click(reconstruct, inputs=[image_input, prompt_input], outputs=[reconstructed_image, new_prompt_input, meta_data]).then(
530
+ lambda: gr.update(interactive=True),
531
+ outputs=reconstruct_button
532
+ ).then(
533
+ lambda: gr.update(interactive=True),
534
+ outputs=new_prompt_input
535
+ ).then(
536
+ lambda: gr.update(interactive=True),
537
+ outputs=apply_button
538
+ ).then(
539
+ lambda: gr.update(interactive=False),
540
+ outputs=stop_button
541
+ )
542
 
543
+ reconstruct_button.click(
544
+ lambda: gr.update(interactive=False),
545
+ outputs=reconstruct_button
546
+ )
547
 
548
+ reconstruct_button.click(
549
+ lambda: gr.update(interactive=True),
550
+ outputs=stop_button
551
  )
552
 
553
+ reconstruct_button.click(
554
+ lambda: gr.update(interactive=False),
555
+ outputs=apply_button
556
+ )
557
 
558
+ stop_button.click(
559
+ lambda: gr.update(interactive=False),
560
+ outputs=stop_button
561
+ )
562
 
563
  apply_button.click(apply_prompt, inputs=[meta_data, new_prompt_input], outputs=reconstructed_image)
564
+ stop_button.click(stop_reconstruct)
565
 
566
+ if __name__ == "__main__":
567
+ parser = argparse.ArgumentParser()
568
+ parser.add_argument("--share", action="store_true")
569
+ args = parser.parse_args()
570
  demo.queue()
571
+ demo.launch(share=args.share)
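
Note (illustrative, not part of the commit): the rewritten AttnReplaceProcessor no longer blends query/key tensors by layer index; it blends the self-attention score maps of the source and target branches with a per-resolution weight taken from weights[layer_type][tokens]. The sketch below is a minimal, self-contained illustration of that blend; the helper name and toy shapes are assumptions, and it presumes the [uncond_src, uncond_dst, cond_src, cond_dst] batch ordering that apply_prompt() builds.

import torch

# Illustrative sketch (not from app.py): mix source and target self-attention
# scores with a single weight, in the spirit of the new AttnReplaceProcessor.
def blend_self_attention_scores(scores: torch.Tensor, weight: float) -> torch.Tensor:
    # scores is assumed to be ordered [uncond_src, uncond_dst, cond_src, cond_dst]
    # along dim 0, matching the two-latent / two-prompt batch used in apply_prompt().
    ucond_src, ucond_dst, cond_src, cond_dst = scores.chunk(4)
    cond_dst = weight * cond_src + (1.0 - weight) * cond_dst
    ucond_dst = weight * ucond_src + (1.0 - weight) * ucond_dst
    return torch.cat([ucond_src, ucond_dst, cond_src, cond_dst])

# Toy example: 4 batch entries, 16 query tokens, 16 key tokens.
toy_scores = torch.randn(4, 16, 16)
print(blend_self_attention_scores(toy_scores, weight=0.7).shape)  # torch.Size([4, 16, 16])

With weight = 1.0 the target branch reuses the source attention map unchanged, and with weight = 0.0 it keeps its own map; the Cross-Attention Influence slider sets these per-resolution weights through update_scale().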