alexnasa committed · verified
Commit 38d2535 · 1 Parent(s): 1a8a8c5

Update app.py

Files changed (1): app.py (+11, -7)
app.py CHANGED
@@ -46,6 +46,10 @@ signal_value = 2.0
 blur_value = None
 allowed_res_max = 1.0
 
+# Device configuration
+device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+print(f"Using device: {device}")
+
 
 def weight_population(layer_type, resolution, depth, value):
     # Check if layer_type exists, if not, create it
@@ -100,9 +104,9 @@ def reconstruct(input_img, caption):
     ])
 
     if torch_dtype == torch.float16:
-        loaded_image = transform(img).half().to("cuda").unsqueeze(0)
+        loaded_image = transform(img).half().to(device).unsqueeze(0)
     else:
-        loaded_image = transform(img).to("cuda").unsqueeze(0)
+        loaded_image = transform(img).to(device).unsqueeze(0)
 
     if loaded_image.shape[1] == 4:
         loaded_image = loaded_image[:,:3,:,:]
@@ -114,7 +118,7 @@ def reconstruct(input_img, caption):
 
     # notice we disabled the CFG here by setting guidance scale as 1
     guidance_scale = 1.0
-    inverse_scheduler.set_timesteps(num_inference_steps, device="cuda")
+    inverse_scheduler.set_timesteps(num_inference_steps, device=device)
     timesteps = inverse_scheduler.timesteps
 
     latents = real_image_latents
@@ -148,7 +152,7 @@ def reconstruct(input_img, caption):
     real_image_initial_latents = latents
 
     guidance_scale = guidance_scale_value
-    scheduler.set_timesteps(num_inference_steps, device="cuda")
+    scheduler.set_timesteps(num_inference_steps, device=device)
     timesteps = scheduler.timesteps
 
     def adjust_latent(pipe, step, timestep, callback_kwargs):
@@ -319,7 +323,7 @@ def apply_prompt(meta_data, new_prompt):
     inference_steps = len(inversed_latents)
 
     guidance_scale = guidance_scale_value
-    scheduler.set_timesteps(inference_steps, device="cuda")
+    scheduler.set_timesteps(inference_steps, device=device)
     timesteps = scheduler.timesteps
 
     initial_latents = torch.cat([real_image_initial_latents] * 2)
@@ -470,8 +474,8 @@ if __name__ == "__main__":
     torch_dtype = torch.float16
 
     # torch_dtype = torch.float16
-    pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype).to("cuda")
-    pipe.vae = AutoencoderKL.from_pretrained(vae_model_id, subfolder=vae_folder, torch_dtype=torch_dtype).to("cuda")
+    pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype).to(device)
+    pipe.vae = AutoencoderKL.from_pretrained(vae_model_id, subfolder=vae_folder, torch_dtype=torch_dtype).to(device)
     pipe.load_lora_weights(
         hf_hub_download(repo_id="jiaxiangc/res-adapter", subfolder=resadapter_model_name, filename="pytorch_lora_weights.safetensors"),
         adapter_name="res_adapter",
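Taken together, the hunks replace every hardcoded "cuda" string with a single module-level device variable, so the Space can still start on CPU-only hardware. Below is a minimal, self-contained sketch of that pattern; the DDIMScheduler, the 50-step count, and the random tensor are stand-ins rather than objects from app.py, and unlike app.py (which leaves torch_dtype hardcoded to float16) the sketch derives the dtype from the chosen device, since many float16 ops are poorly supported on CPU.

# Minimal sketch of the device-fallback pattern introduced by this commit.
# Stand-ins: DDIMScheduler, the 50-step count, and the random tensor are
# illustrative only, not objects from app.py.
import torch
from diffusers import DDIMScheduler

# Pick CUDA when available, otherwise fall back to CPU (as in the commit).
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# app.py keeps torch_dtype hardcoded to float16; deriving it from the
# device avoids float16 ops that CPUs handle poorly or not at all.
torch_dtype = torch.float16 if device.type == "cuda" else torch.float32

# Move an input tensor the same way reconstruct() moves the loaded image.
img = torch.rand(3, 512, 512)  # stand-in for transform(img)
if torch_dtype == torch.float16:
    loaded_image = img.half().to(device).unsqueeze(0)
else:
    loaded_image = img.to(device).unsqueeze(0)

# Place the scheduler's timesteps on the same device, mirroring the
# set_timesteps(..., device=device) calls in the diff.
scheduler = DDIMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(50, device=device)
print(loaded_image.device, loaded_image.dtype, scheduler.timesteps.device)

Keeping one device variable at module scope keeps every later call site (image loading, both schedulers, the pipeline) consistent; the remaining gap is the torch_dtype = torch.float16 still hardcoded under __main__, which would need the same conditional treatment for a fully working CPU fallback.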
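The final hunk applies the same variable to model loading. A hedged sketch of that loading sequence follows, with hypothetical Hub IDs standing in for the model_id, vae_model_id, vae_folder, and resadapter_model_name values that app.py defines outside this diff; only the jiaxiangc/res-adapter repo, the safetensors filename, and the adapter name are taken from the diff itself.

# Sketch of device-agnostic pipeline loading, mirroring the @@ -470,8 +474,8 @@ hunk.
# Hypothetical stand-ins: model_id, vae_model_id, resadapter_model_name.
import torch
from diffusers import AutoencoderKL, DiffusionPipeline
from huggingface_hub import hf_hub_download

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
torch_dtype = torch.float16 if device.type == "cuda" else torch.float32

model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"  # stand-in base model
vae_model_id = "stabilityai/sd-vae-ft-mse"                # stand-in VAE (no subfolder needed here)
resadapter_model_name = "resadapter_v1_sd1.5"             # stand-in subfolder name

pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype).to(device)
pipe.vae = AutoencoderKL.from_pretrained(vae_model_id, torch_dtype=torch_dtype).to(device)

# LoRA weights fetched from the Hub, as in the diff (requires peft installed).
pipe.load_lora_weights(
    hf_hub_download(
        repo_id="jiaxiangc/res-adapter",
        subfolder=resadapter_model_name,
        filename="pytorch_lora_weights.safetensors",
    ),
    adapter_name="res_adapter",
)

Passing torch_dtype at from_pretrained time and then calling .to(device) keeps the weights in reduced precision on GPU while remaining loadable on a CPU-only machine.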