alexnasa committed on
Commit
deed1cf
·
verified ·
1 Parent(s): 1df9ccd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -5
app.py CHANGED
@@ -215,13 +215,13 @@ def reconstruct(input_img, caption):
215
 
216
  update_scale(12)
217
 
218
- real_gpu = real_image_initial_latents.detach()
219
- inversed_gpu = [x.detach() for x in inversed_latents]
220
 
221
  return image_np, caption, 12, [
222
  caption,
223
- real_gpu,
224
- inversed_gpu,
225
  weights
226
  ]
227
 
@@ -358,10 +358,21 @@ def apply_prompt(meta_data, new_prompt):
358
  pipe, _, scheduler = load_pipeline()
359
  pipe.to("cuda")
360
 
361
- caption, real_image_initial_latents, inversed_latents, saved_weights = meta_data
 
362
  # overwrite the global so your processor can see it
363
  global weights
364
  weights = saved_weights
 
 
 
 
 
 
 
 
 
 
365
 
366
  negative_prompt = ""
367
 
 
215
 
216
  update_scale(12)
217
 
218
+ real_cpu = real_image_initial_latents.detach().cpu()
219
+ inversed_cpu = [x.detach().cpu() for x in inversed_latents]
220
 
221
  return image_np, caption, 12, [
222
  caption,
223
+ real_cpu,
224
+ inversed_cpu,
225
  weights
226
  ]
227
 
 
358
  pipe, _, scheduler = load_pipeline()
359
  pipe.to("cuda")
360
 
361
+ caption, real_latents_cpu, inversed_latents_cpu, saved_weights = meta_data
362
+
363
  # overwrite the global so your processor can see it
364
  global weights
365
  weights = saved_weights
366
+
367
+ # move everything onto CUDA (and match dtype if needed)
368
+ device = next(pipe.unet.parameters()).device
369
+ dtype = next(pipe.unet.parameters()).dtype
370
+
371
+ real_latents = real_latents_cpu.to(device=device, dtype=dtype)
372
+ inversed_latents = [x.to(device=device, dtype=dtype) for x in inversed_latents_cpu]
373
+
374
+ # now all your latents live on CUDA, so the callback won't mix devices
375
+ initial_latents = torch.cat([real_latents] * 2)
376
 
377
  negative_prompt = ""
378