adaface-neurips committed on
Commit 9bcd020 · 1 Parent(s): c75a857

update code and gradio version

README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 😻
4
  colorFrom: indigo
5
  colorTo: red
6
  sdk: gradio
7
- sdk_version: 5.0.2
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
 
4
  colorFrom: indigo
5
  colorTo: red
6
  sdk: gradio
7
+ sdk_version: 5.30.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
adaface/adaface_wrapper.py CHANGED
@@ -30,8 +30,8 @@ class AdaFaceWrapper(nn.Module):
30
  use_840k_vae=False, use_ds_text_encoder=False,
31
  main_unet_filepath=None, unet_types=None, extra_unet_dirpaths=None, unet_weights_in_ensemble=None,
32
  enable_static_img_suffix_embs=None, unet_uses_attn_lora=False,
33
- attn_lora_layer_names=['q', 'k', 'v', 'out'], shrink_cross_attn=False, q_lora_updates_query=False,
34
- device='cuda', is_training=False):
35
  '''
36
  pipeline_name: "text2img", "text2imgxl", "img2img", "text2img3", "flux", or None.
37
  If None, it's used only as a face encoder, and the unet and vae are
@@ -52,7 +52,7 @@ class AdaFaceWrapper(nn.Module):
52
  self.q_lora_updates_query = q_lora_updates_query
53
  self.use_lcm = use_lcm
54
  self.subject_string = subject_string
55
- self.shrink_cross_attn = shrink_cross_attn
56
 
57
  self.default_scheduler_name = default_scheduler_name
58
  self.num_inference_steps = num_inference_steps if not use_lcm else 4
@@ -64,6 +64,7 @@ class AdaFaceWrapper(nn.Module):
64
  self.unet_weights_in_ensemble = unet_weights_in_ensemble
65
  self.device = device
66
  self.is_training = is_training
 
67
 
68
  if negative_prompt is None:
69
  self.negative_prompt = \
@@ -99,6 +100,7 @@ class AdaFaceWrapper(nn.Module):
99
  self.adaface_ckpt_paths,
100
  self.adaface_encoder_cfg_scales,
101
  self.enabled_encoders,
 
102
  num_static_img_suffix_embs=4)
103
 
104
  self.id2ada_prompt_encoder.to(self.device)
@@ -189,10 +191,10 @@ class AdaFaceWrapper(nn.Module):
189
  pipeline.unet = unet_ensemble
190
 
191
  print(f"Loaded pipeline from {self.base_model_path}.")
192
- if not remove_unet and (self.unet_uses_attn_lora or self.shrink_cross_attn):
193
  unet2 = self.load_unet_lora_weights(pipeline.unet, use_attn_lora=self.unet_uses_attn_lora,
194
  attn_lora_layer_names=self.attn_lora_layer_names,
195
- shrink_cross_attn=self.shrink_cross_attn,
196
  q_lora_updates_query=self.q_lora_updates_query)
197
 
198
  pipeline.unet = unet2
@@ -294,12 +296,11 @@ class AdaFaceWrapper(nn.Module):
294
  def load_unet_loras(self, unet, unet_lora_modules_state_dict,
295
  use_attn_lora=True, use_ffn_lora=False,
296
  attn_lora_layer_names=['q', 'k', 'v', 'out'],
297
- shrink_cross_attn=False, cross_attn_shrink_factor=0.5,
298
  q_lora_updates_query=False):
299
  attn_capture_procs, attn_opt_modules = \
300
  set_up_attn_processors(unet, use_attn_lora=True, attn_lora_layer_names=attn_lora_layer_names,
301
  lora_rank=192, lora_scale_down=8,
302
- cross_attn_shrink_factor=cross_attn_shrink_factor,
303
  q_lora_updates_query=q_lora_updates_query)
304
  # up_blocks.3.resnets.[1~2].conv1, conv2, conv_shortcut. [12] matches 1 or 2.
305
  if use_ffn_lora:
@@ -343,16 +344,17 @@ class AdaFaceWrapper(nn.Module):
343
  print(f"Loaded {len(unet_lora_modules_state_dict)} LoRA weights on the UNet:\n{unet_lora_modules.keys()}")
344
  self.outfeat_capture_blocks.append(unet.up_blocks[3])
345
 
346
- # If shrink_cross_attn is True and use_attn_lora is False, we load all these params from ckpt,
347
  # but since we set use_attn_lora to False, attn loras won't be used during inference nonetheless.
348
  set_lora_and_capture_flags(unet, None, self.attn_capture_procs, self.outfeat_capture_blocks,
349
  use_attn_lora, use_ffn_lora, 'recon_loss', capture_ca_activations=False,
350
- shrink_cross_attn=shrink_cross_attn)
 
351
 
352
  return unet
353
 
354
  def load_unet_lora_weights(self, unet, use_attn_lora=True, attn_lora_layer_names=['q', 'k', 'v', 'out'],
355
- shrink_cross_attn=False, q_lora_updates_query=False):
356
  unet_lora_weight_found = False
357
  if isinstance(self.adaface_ckpt_paths, str):
358
  adaface_ckpt_paths = [self.adaface_ckpt_paths]
@@ -360,7 +362,7 @@ class AdaFaceWrapper(nn.Module):
360
  adaface_ckpt_paths = self.adaface_ckpt_paths
361
 
362
  for adaface_ckpt_path in adaface_ckpt_paths:
363
- ckpt_dict = torch.load(adaface_ckpt_path, map_location='cpu')
364
  if 'unet_lora_modules' in ckpt_dict:
365
  unet_lora_modules_state_dict = ckpt_dict['unet_lora_modules']
366
  print(f"{len(unet_lora_modules_state_dict)} LoRA weights found in {adaface_ckpt_path}.")
@@ -379,7 +381,7 @@ class AdaFaceWrapper(nn.Module):
379
  unet_ = self.load_unet_loras(unet_, unet_lora_modules_state_dict,
380
  use_attn_lora=use_attn_lora,
381
  attn_lora_layer_names=attn_lora_layer_names,
382
- shrink_cross_attn=shrink_cross_attn,
383
  q_lora_updates_query=q_lora_updates_query)
384
  unet.unets[i] = unet_
385
  print(f"Loaded LoRA processors on UNetEnsemble of {len(unet.unets)} UNets.")
@@ -387,7 +389,7 @@ class AdaFaceWrapper(nn.Module):
387
  unet = self.load_unet_loras(unet, unet_lora_modules_state_dict,
388
  use_attn_lora=use_attn_lora,
389
  attn_lora_layer_names=attn_lora_layer_names,
390
- shrink_cross_attn=shrink_cross_attn,
391
  q_lora_updates_query=q_lora_updates_query)
392
 
393
  return unet
@@ -612,8 +614,9 @@ class AdaFaceWrapper(nn.Module):
612
  # Scan prompt and replace tokens in self.placeholder_token_ids
613
  # with the corresponding image embeddings.
614
  prompt_tokens = self.pipeline.tokenizer.tokenize(prompt)
 
615
  prompt_embeds2 = prompt_embeds.clone()
616
- if alt_prompt_embed_type == 'img':
617
  if self.img_prompt_embs is None:
618
  print("Unable to find img_prompt_embs. Either prepare_adaface_embeddings() hasn't been called, or faceless images were used.")
619
  return prompt_embeds
@@ -628,17 +631,18 @@ class AdaFaceWrapper(nn.Module):
628
  breakpoint()
629
 
630
  repl_tokens = {}
 
631
  for i in range(len(prompt_tokens)):
632
  if prompt_tokens[i] in self.all_placeholder_tokens:
633
  encoder_idx = next((i for i, sublist in enumerate(self.encoder_placeholder_tokens) \
634
  if prompt_tokens[i] in sublist), 0)
635
- alt_prompt_emb_weight = alt_prompt_emb_weights[encoder_idx]
636
- prompt_embeds2[:, i] = prompt_embeds2[:, i] * (1 - alt_prompt_emb_weight) \
637
  + repl_embeddings[:, self.all_placeholder_tokens.index(prompt_tokens[i])] * alt_prompt_emb_weight
638
  repl_tokens[prompt_tokens[i]] = 1
639
 
640
  repl_token_count = len(repl_tokens)
641
- if np.all(np.array(alt_prompt_emb_weights) == 1):
642
  print(f"Replaced {repl_token_count} tokens with {alt_prompt_embed_type} embeddings.")
643
  else:
644
  print(f"Mixed {repl_token_count} tokens with {alt_prompt_embed_type} embeddings, weight {alt_prompt_emb_weights}.")
@@ -650,7 +654,7 @@ class AdaFaceWrapper(nn.Module):
650
  placeholder_tokens_pos='append',
651
  ablate_prompt_only_placeholders=False,
652
  ablate_prompt_no_placeholders=False,
653
- ablate_prompt_embed_type='ada', # 'ada', 'ada-nonmix', 'img'
654
  nonmix_prompt_emb_weight=0,
655
  repeat_prompt_for_each_encoder=True,
656
  device=None, verbose=False):
@@ -678,14 +682,25 @@ class AdaFaceWrapper(nn.Module):
678
  prompt_embeds_, negative_prompt_embeds_, pooled_prompt_embeds_, negative_pooled_prompt_embeds_ = \
679
  self.diffusers_encode_prompts(prompt, plain_prompt, negative_prompt, device)
680
 
681
- if ablate_prompt_embed_type != 'ada':
682
  alt_prompt_embed_type = ablate_prompt_embed_type
683
- alt_prompt_emb_weights = (1, 1)
684
  elif nonmix_prompt_emb_weight > 0:
685
  alt_prompt_embed_type = 'ada-nonmix'
686
- alt_prompt_emb_weights = (nonmix_prompt_emb_weight, nonmix_prompt_emb_weight)
 
 
687
  else:
688
- alt_prompt_emb_weights = (0, 0)
 
689
 
690
  if sum(alt_prompt_emb_weights) > 0:
691
  prompt_embeds_ = self.mix_ada_embs_with_other_embs(prompt, prompt_embeds_,
 
30
  use_840k_vae=False, use_ds_text_encoder=False,
31
  main_unet_filepath=None, unet_types=None, extra_unet_dirpaths=None, unet_weights_in_ensemble=None,
32
  enable_static_img_suffix_embs=None, unet_uses_attn_lora=False,
33
+ attn_lora_layer_names=['q', 'k', 'v', 'out'], normalize_cross_attn=False, q_lora_updates_query=False,
34
+ device='cuda', is_training=False, is_on_hf_space=False):
35
  '''
36
  pipeline_name: "text2img", "text2imgxl", "img2img", "text2img3", "flux", or None.
37
  If None, it's used only as a face encoder, and the unet and vae are
 
52
  self.q_lora_updates_query = q_lora_updates_query
53
  self.use_lcm = use_lcm
54
  self.subject_string = subject_string
55
+ self.normalize_cross_attn = normalize_cross_attn
56
 
57
  self.default_scheduler_name = default_scheduler_name
58
  self.num_inference_steps = num_inference_steps if not use_lcm else 4
 
64
  self.unet_weights_in_ensemble = unet_weights_in_ensemble
65
  self.device = device
66
  self.is_training = is_training
67
+ self.is_on_hf_space = is_on_hf_space
68
 
69
  if negative_prompt is None:
70
  self.negative_prompt = \
 
100
  self.adaface_ckpt_paths,
101
  self.adaface_encoder_cfg_scales,
102
  self.enabled_encoders,
103
+ is_on_hf_space=self.is_on_hf_space,
104
  num_static_img_suffix_embs=4)
105
 
106
  self.id2ada_prompt_encoder.to(self.device)
 
191
  pipeline.unet = unet_ensemble
192
 
193
  print(f"Loaded pipeline from {self.base_model_path}.")
194
+ if not remove_unet and (self.unet_uses_attn_lora or self.normalize_cross_attn):
195
  unet2 = self.load_unet_lora_weights(pipeline.unet, use_attn_lora=self.unet_uses_attn_lora,
196
  attn_lora_layer_names=self.attn_lora_layer_names,
197
+ normalize_cross_attn=self.normalize_cross_attn,
198
  q_lora_updates_query=self.q_lora_updates_query)
199
 
200
  pipeline.unet = unet2
 
296
  def load_unet_loras(self, unet, unet_lora_modules_state_dict,
297
  use_attn_lora=True, use_ffn_lora=False,
298
  attn_lora_layer_names=['q', 'k', 'v', 'out'],
299
+ normalize_cross_attn=False,
300
  q_lora_updates_query=False):
301
  attn_capture_procs, attn_opt_modules = \
302
  set_up_attn_processors(unet, use_attn_lora=True, attn_lora_layer_names=attn_lora_layer_names,
303
  lora_rank=192, lora_scale_down=8,
 
304
  q_lora_updates_query=q_lora_updates_query)
305
  # up_blocks.3.resnets.[1~2].conv1, conv2, conv_shortcut. [12] matches 1 or 2.
306
  if use_ffn_lora:
 
344
  print(f"Loaded {len(unet_lora_modules_state_dict)} LoRA weights on the UNet:\n{unet_lora_modules.keys()}")
345
  self.outfeat_capture_blocks.append(unet.up_blocks[3])
346
 
347
+ # If normalize_cross_attn is True and use_attn_lora is False, we load all these params from ckpt,
348
  # but since we set use_attn_lora to False, attn loras won't be used during inference nonetheless.
349
  set_lora_and_capture_flags(unet, None, self.attn_capture_procs, self.outfeat_capture_blocks,
350
  use_attn_lora, use_ffn_lora, 'recon_loss', capture_ca_activations=False,
351
+ normalize_cross_attn=normalize_cross_attn, mix_attn_mats_in_batch=False,
352
+ res_hidden_states_gradscale=0)
353
 
354
  return unet
355
 
356
  def load_unet_lora_weights(self, unet, use_attn_lora=True, attn_lora_layer_names=['q', 'k', 'v', 'out'],
357
+ normalize_cross_attn=False, q_lora_updates_query=False):
358
  unet_lora_weight_found = False
359
  if isinstance(self.adaface_ckpt_paths, str):
360
  adaface_ckpt_paths = [self.adaface_ckpt_paths]
 
362
  adaface_ckpt_paths = self.adaface_ckpt_paths
363
 
364
  for adaface_ckpt_path in adaface_ckpt_paths:
365
+ ckpt_dict = torch.load(adaface_ckpt_path, map_location='cpu', weights_only=False)
366
  if 'unet_lora_modules' in ckpt_dict:
367
  unet_lora_modules_state_dict = ckpt_dict['unet_lora_modules']
368
  print(f"{len(unet_lora_modules_state_dict)} LoRA weights found in {adaface_ckpt_path}.")
 
381
  unet_ = self.load_unet_loras(unet_, unet_lora_modules_state_dict,
382
  use_attn_lora=use_attn_lora,
383
  attn_lora_layer_names=attn_lora_layer_names,
384
+ normalize_cross_attn=normalize_cross_attn,
385
  q_lora_updates_query=q_lora_updates_query)
386
  unet.unets[i] = unet_
387
  print(f"Loaded LoRA processors on UNetEnsemble of {len(unet.unets)} UNets.")
 
389
  unet = self.load_unet_loras(unet, unet_lora_modules_state_dict,
390
  use_attn_lora=use_attn_lora,
391
  attn_lora_layer_names=attn_lora_layer_names,
392
+ normalize_cross_attn=normalize_cross_attn,
393
  q_lora_updates_query=q_lora_updates_query)
394
 
395
  return unet
 
614
  # Scan prompt and replace tokens in self.placeholder_token_ids
615
  # with the corresponding image embeddings.
616
  prompt_tokens = self.pipeline.tokenizer.tokenize(prompt)
617
+ # prompt_embeds are the ada embeddings.
618
  prompt_embeds2 = prompt_embeds.clone()
619
+ if alt_prompt_embed_type.startswith('img'):
620
  if self.img_prompt_embs is None:
621
  print("Unable to find img_prompt_embs. Either prepare_adaface_embeddings() hasn't been called, or faceless images were used.")
622
  return prompt_embeds
 
631
  breakpoint()
632
 
633
  repl_tokens = {}
634
+ ada_emb_weight = alt_prompt_emb_weights[0]
635
  for i in range(len(prompt_tokens)):
636
  if prompt_tokens[i] in self.all_placeholder_tokens:
637
  encoder_idx = next((i for i, sublist in enumerate(self.encoder_placeholder_tokens) \
638
  if prompt_tokens[i] in sublist), 0)
639
+ alt_prompt_emb_weight = alt_prompt_emb_weights[encoder_idx + 1]
640
+ prompt_embeds2[:, i] = prompt_embeds2[:, i] * ada_emb_weight \
641
  + repl_embeddings[:, self.all_placeholder_tokens.index(prompt_tokens[i])] * alt_prompt_emb_weight
642
  repl_tokens[prompt_tokens[i]] = 1
643
 
644
  repl_token_count = len(repl_tokens)
645
+ if ada_emb_weight == 0:
646
  print(f"Replaced {repl_token_count} tokens with {alt_prompt_embed_type} embeddings.")
647
  else:
648
  print(f"Mixed {repl_token_count} tokens with {alt_prompt_embed_type} embeddings, weight {alt_prompt_emb_weights}.")
 
654
  placeholder_tokens_pos='append',
655
  ablate_prompt_only_placeholders=False,
656
  ablate_prompt_no_placeholders=False,
657
+ ablate_prompt_embed_type='ada', # 'ada', 'ada-nonmix', 'img', 'img1', 'img2'.
658
  nonmix_prompt_emb_weight=0,
659
  repeat_prompt_for_each_encoder=True,
660
  device=None, verbose=False):
 
682
  prompt_embeds_, negative_prompt_embeds_, pooled_prompt_embeds_, negative_pooled_prompt_embeds_ = \
683
  self.diffusers_encode_prompts(prompt, plain_prompt, negative_prompt, device)
684
 
685
+ if ablate_prompt_embed_type.startswith('img'):
686
  alt_prompt_embed_type = ablate_prompt_embed_type
687
+ if alt_prompt_embed_type == 'img1':
688
+ # The mixing weights of ada, img1, and img2 are 0, 1, and 0.
689
+ alt_prompt_emb_weights = (0, 1, 0)
690
+ elif alt_prompt_embed_type == 'img2':
691
+ # The mixing weights of ada, img1, and img2 are 0, 0, and 1.
692
+ alt_prompt_emb_weights = (0, 0, 1)
693
+ else:
694
+ # The mixing weights of ada, img1, and img2 are 0, 1, and 1.
695
+ alt_prompt_emb_weights = (0, 1, 1)
696
  elif nonmix_prompt_emb_weight > 0:
697
  alt_prompt_embed_type = 'ada-nonmix'
698
+ # The mixing weight of ada is 1 - nonmix_prompt_emb_weight, instead of 1 - nonmix_prompt_emb_weight * 2.
699
+ # It means ada is mixed by this weight with both img1 and img2.
700
+ alt_prompt_emb_weights = (1 - nonmix_prompt_emb_weight, nonmix_prompt_emb_weight, nonmix_prompt_emb_weight)
701
  else:
702
+ # Don't change the prompt embeddings. So we set all the mixing weights to 0.
703
+ alt_prompt_emb_weights = (0, 0, 0)
704
 
705
  if sum(alt_prompt_emb_weights) > 0:
706
  prompt_embeds_ = self.mix_ada_embs_with_other_embs(prompt, prompt_embeds_,
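
Editor's note on the adaface_wrapper.py hunks above: the alternate-prompt mixing weights change from a per-encoder pair to a 3-tuple (w_ada, w_img1, w_img2), so the ada embedding's own weight is explicit instead of being implied as 1 - w, and torch.load now passes weights_only=False so checkpoints holding non-tensor Python objects keep loading under newer PyTorch defaults. A minimal sketch of the new mixing rule, with illustrative names rather than the repo's exact helpers:

import torch

# Sketch only (not part of the commit): mirrors the (w_ada, w_img1, w_img2)
# weight scheme used by mix_ada_embs_with_other_embs; names are illustrative.
def mix_token_embedding(ada_emb: torch.Tensor, repl_emb: torch.Tensor,
                        encoder_idx: int, weights=(0, 1, 1)) -> torch.Tensor:
    # weights[0] scales the ada embedding of a placeholder token;
    # weights[1 + encoder_idx] scales the replacement (image-prompt) embedding
    # from that encoder (index 0 is consistentID, 1 is arc2face in the default app setup).
    return ada_emb * weights[0] + repl_emb * weights[1 + encoder_idx]

# 'img1' ablation          -> weights (0, 1, 0): keep only encoder-0 image embeddings.
# 'ada-nonmix', weight 0.3 -> weights (0.7, 0.3, 0.3): blend ada with both encoders.
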
adaface/diffusers_attn_lora_capture.py CHANGED
@@ -4,7 +4,6 @@ import torch.nn.functional as F
4
  from typing import Optional, Tuple, Dict, Any
5
  from diffusers.models.attention_processor import Attention, AttnProcessor2_0
6
  from diffusers.utils import logging, is_torch_version, deprecate
7
- from diffusers.utils.torch_utils import fourier_filter
8
  # UNet is a diffusers PeftAdapterMixin instance.
9
  from diffusers.loaders.peft import PeftAdapterMixin
10
  from peft import LoraConfig, get_peft_model
@@ -12,7 +11,6 @@ import peft.tuners.lora as peft_lora
12
  from peft.tuners.lora.dora import DoraLinearLayer
13
  from einops import rearrange
14
  import math, re
15
- import numpy as np
16
  from peft.tuners.tuners_utils import BaseTunerLayer
17
 
18
 
@@ -28,7 +26,7 @@ class ScaleGrad(torch.autograd.Function):
28
  ctx.save_for_backward(alpha_, debug)
29
  output = input_
30
  if debug:
31
- print(f"input: {input_.abs().mean().item()}")
32
  return output
33
 
34
  @staticmethod
@@ -38,7 +36,7 @@ class ScaleGrad(torch.autograd.Function):
38
  if ctx.needs_input_grad[0]:
39
  grad_output2 = grad_output * alpha_
40
  if debug:
41
- print(f"grad_output2: {grad_output2.abs().mean().item()}")
42
  else:
43
  grad_output2 = None
44
  return grad_output2, None, None
@@ -77,36 +75,11 @@ def split_indices_by_instance(indices, as_dict=False):
77
  indices_by_instance = { uib.item(): indices_N[indices_B == uib] for uib in unique_indices_B }
78
  return indices_by_instance
79
 
80
- # If do_sum, returned emb_attns is 3D. Otherwise 4D.
81
- # indices are applied on the first 2 dims of attn_mat.
82
- def sel_emb_attns_by_indices(attn_mat, indices, all_token_weights=None, do_sum=True, do_mean=False):
83
- indices_by_instance = split_indices_by_instance(indices)
84
-
85
- # emb_attns[0]: [1, 9, 8, 64]
86
- # 8: 8 attention heads. Last dim 64: number of image tokens.
87
- emb_attns = [ attn_mat[inst_indices].unsqueeze(0) for inst_indices in indices_by_instance ]
88
- if all_token_weights is not None:
89
- # all_token_weights: [4, 77].
90
- # token_weights_by_instance[0]: [1, 9, 1, 1].
91
- token_weights = [ all_token_weights[inst_indices].reshape(1, -1, 1, 1) for inst_indices in indices_by_instance ]
92
- else:
93
- token_weights = [ 1 ] * len(indices_by_instance)
94
-
95
- # Apply token weights.
96
- emb_attns = [ emb_attns[i] * token_weights[i] for i in range(len(indices_by_instance)) ]
97
-
98
- # sum among K_subj_i subj embeddings -> [1, 8, 64]
99
- if do_sum:
100
- emb_attns = [ emb_attns[i].sum(dim=1) for i in range(len(indices_by_instance)) ]
101
- elif do_mean:
102
- emb_attns = [ emb_attns[i].mean(dim=1) for i in range(len(indices_by_instance)) ]
103
-
104
- emb_attns = torch.cat(emb_attns, dim=0)
105
- return emb_attns
106
-
107
  # Slow implementation equivalent to F.scaled_dot_product_attention.
108
- def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
109
- shrink_cross_attn=False, cross_attn_shrink_factor=0.5,
 
 
110
  is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
111
  B, L, S = query.size(0), query.size(-2), key.size(-2)
112
  scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
@@ -128,21 +101,39 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
128
  key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
129
  value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)
130
 
131
- attn_weight = query @ key.transpose(-2, -1) * scale_factor
132
 
133
- if shrink_cross_attn:
134
- cross_attn_scale = cross_attn_shrink_factor
135
- else:
136
- cross_attn_scale = 1
137
-
138
- # attn_bias: [1, 1, 4096, 77], the same size as a single-head attn_weight.
139
- attn_weight += attn_bias
140
- attn_score = attn_weight
141
- attn_weight = torch.softmax(attn_weight, dim=-1)
142
- # NOTE: After scaling, the "probabilities" of the subject embeddings will sum to < 1.
143
- # But this is intended, as we want to scale down the impact of the subject embeddings
144
- # in the computed attention output tensors.
145
- attn_weight = attn_weight * cross_attn_scale
146
  attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
147
  output = attn_weight @ value
148
  return output, attn_score, attn_weight
@@ -156,23 +147,25 @@ class AttnProcessor_LoRA_Capture(nn.Module):
156
  def __init__(self, capture_ca_activations: bool = False, enable_lora: bool = False,
157
  lora_uses_dora=True, lora_proj_layers=None,
158
  lora_rank: int = 192, lora_alpha: float = 16,
159
- cross_attn_shrink_factor: float = 0.5,
160
  q_lora_updates_query=False, attn_proc_idx=-1):
161
  super().__init__()
162
 
163
  self.global_enable_lora = enable_lora
164
  self.attn_proc_idx = attn_proc_idx
165
  # reset_attn_cache_and_flags() sets the local (call-specific) self.enable_lora flag.
166
- # By default, shrink_cross_attn is False. Later in layers 22, 23, 24 it will be set to True.
167
- self.reset_attn_cache_and_flags(capture_ca_activations, False, enable_lora)
168
  self.lora_rank = lora_rank
169
  self.lora_alpha = lora_alpha
170
  self.lora_scale = self.lora_alpha / self.lora_rank
171
- self.cross_attn_shrink_factor = cross_attn_shrink_factor
172
  self.q_lora_updates_query = q_lora_updates_query
173
 
174
  self.to_q_lora = self.to_k_lora = self.to_v_lora = self.to_out_lora = None
175
  if self.global_enable_lora:
176
  for lora_layer_name, lora_proj_layer in lora_proj_layers.items():
177
  if lora_layer_name == 'q':
178
  self.to_q_lora = peft_lora.Linear(lora_proj_layer, 'default', r=lora_rank, lora_alpha=lora_alpha,
@@ -188,9 +181,10 @@ class AttnProcessor_LoRA_Capture(nn.Module):
188
  use_dora=lora_uses_dora, lora_dropout=0.1)
189
 
190
  # LoRA layers can be enabled/disabled dynamically.
191
- def reset_attn_cache_and_flags(self, capture_ca_activations, shrink_cross_attn, enable_lora):
192
  self.capture_ca_activations = capture_ca_activations
193
- self.shrink_cross_attn = shrink_cross_attn
 
194
  self.cached_activations = {}
195
  # Only enable LoRA for the next call(s) if global_enable_lora is set to True.
196
  self.enable_lora = enable_lora and self.global_enable_lora
@@ -312,11 +306,14 @@ class AttnProcessor_LoRA_Capture(nn.Module):
312
  breakpoint()
313
 
314
  # the output of sdp = (batch, num_heads, seq_len, head_dim)
315
- if is_cross_attn and (self.capture_ca_activations or self.shrink_cross_attn):
316
  hidden_states, attn_score, attn_prob = \
317
  scaled_dot_product_attention(query, key, value, attn_mask=attention_mask,
318
- dropout_p=0.0, shrink_cross_attn=self.shrink_cross_attn,
319
- cross_attn_shrink_factor=self.cross_attn_shrink_factor)
320
  else:
321
  # Use the faster implementation of scaled_dot_product_attention
322
  # when not capturing the activations or suppressing the subject attention.
@@ -452,7 +449,7 @@ def CrossAttnUpBlock2D_forward_capture(
452
  # Adapted from ConsistentIDPipeline:set_ip_adapter().
453
  # attn_lora_layer_names: candidates are subsets of ['q', 'k', 'v', 'out'].
454
  def set_up_attn_processors(unet, use_attn_lora, attn_lora_layer_names=['q', 'k', 'v', 'out'],
455
- lora_rank=192, lora_scale_down=8, cross_attn_shrink_factor=0.5,
456
  q_lora_updates_query=False):
457
  attn_procs = {}
458
  attn_capture_procs = {}
@@ -502,7 +499,6 @@ def set_up_attn_processors(unet, use_attn_lora, attn_lora_layer_names=['q', 'k',
502
  lora_uses_dora=True, lora_proj_layers=lora_proj_layers,
503
  # LoRA up is initialized to 0. So no need to worry that the LoRA output may be too large.
504
  lora_rank=lora_rank, lora_alpha=lora_rank // lora_scale_down,
505
- cross_attn_shrink_factor=cross_attn_shrink_factor,
506
  q_lora_updates_query=q_lora_updates_query, attn_proc_idx=attn_proc_idx)
507
 
508
  attn_proc_idx += 1
@@ -513,6 +509,11 @@ def set_up_attn_processors(unet, use_attn_lora, attn_lora_layer_names=['q', 'k',
513
  attn_capture_procs[name] = attn_capture_proc
514
 
515
  if use_attn_lora:
516
  for subname, module in attn_capture_proc.named_modules():
517
  if isinstance(module, peft_lora.LoraLayer):
518
  # ModuleDict doesn't allow "." in the key.
@@ -537,7 +538,7 @@ def set_up_attn_processors(unet, use_attn_lora, attn_lora_layer_names=['q', 'k',
537
  return attn_capture_procs, attn_opt_modules
538
 
539
  # NOTE: cross-attn layers are included in the returned lora_modules.
540
- def set_up_ffn_loras(unet, target_modules_pat, lora_uses_dora=False, lora_rank=192, lora_alpha=16):
541
  # target_modules_pat = 'up_blocks.3.resnets.[12].conv[a-z0-9_]+'
542
  # up_blocks.3.resnets.[1~2].conv1, conv2, conv_shortcut
543
  # Cannot set to conv.+ as it will match added adapter module names, including
@@ -592,15 +593,18 @@ def set_up_ffn_loras(unet, target_modules_pat, lora_uses_dora=False, lora_rank=1
592
  def set_lora_and_capture_flags(unet, unet_lora_modules, attn_capture_procs,
593
  outfeat_capture_blocks, res_hidden_states_gradscale_blocks,
594
  use_attn_lora, use_ffn_lora, ffn_lora_adapter_name, capture_ca_activations,
595
- shrink_cross_attn, res_hidden_states_gradscale):
596
  # For attn capture procs, capture_ca_activations and use_attn_lora are set in reset_attn_cache_and_flags().
597
- for attn_capture_proc in attn_capture_procs:
598
- attn_capture_proc.reset_attn_cache_and_flags(capture_ca_activations, shrink_cross_attn, enable_lora=use_attn_lora)
 
599
  # outfeat_capture_blocks only contains the last up block, up_blocks[3].
600
  # It contains 3 FFN layers. We want to capture their output features.
601
  for block in outfeat_capture_blocks:
602
  block.capture_outfeats = capture_ca_activations
603
 
 
 
604
  for block in res_hidden_states_gradscale_blocks:
605
  block.res_hidden_states_gradscale = res_hidden_states_gradscale
606
 
@@ -639,6 +643,7 @@ def get_captured_activations(capture_ca_activations, attn_capture_procs, outfeat
639
  block.cached_outfeats = {}
640
  block.capture_outfeats = False
641
 
 
642
  for layer_idx in captured_layer_indices:
643
  # Subtract 22 to ca_layer_idx to match the layer index in up_blocks[3].cached_outfeats.
644
  # 23, 24 -> 1, 2 (!! not 0, 1 !!)
 
4
  from typing import Optional, Tuple, Dict, Any
5
  from diffusers.models.attention_processor import Attention, AttnProcessor2_0
6
  from diffusers.utils import logging, is_torch_version, deprecate
 
7
  # UNet is a diffusers PeftAdapterMixin instance.
8
  from diffusers.loaders.peft import PeftAdapterMixin
9
  from peft import LoraConfig, get_peft_model
 
11
  from peft.tuners.lora.dora import DoraLinearLayer
12
  from einops import rearrange
13
  import math, re
 
14
  from peft.tuners.tuners_utils import BaseTunerLayer
15
 
16
 
 
26
  ctx.save_for_backward(alpha_, debug)
27
  output = input_
28
  if debug:
29
+ print(f"input: {input_.abs().mean().detach().item()}")
30
  return output
31
 
32
  @staticmethod
 
36
  if ctx.needs_input_grad[0]:
37
  grad_output2 = grad_output * alpha_
38
  if debug:
39
+ print(f"grad_output2: {grad_output2.abs().mean().detach().item()}")
40
  else:
41
  grad_output2 = None
42
  return grad_output2, None, None
 
75
  indices_by_instance = { uib.item(): indices_N[indices_B == uib] for uib in unique_indices_B }
76
  return indices_by_instance
77
 
78
  # Slow implementation equivalent to F.scaled_dot_product_attention.
79
+ def scaled_dot_product_attention(query, key, value, cross_attn_scale_factor,
80
+ attn_mask=None, dropout_p=0.0,
81
+ subj_indices=None, normalize_cross_attn=False,
82
+ mix_attn_mats_in_batch=False,
83
  is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
84
  B, L, S = query.size(0), query.size(-2), key.size(-2)
85
  scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
 
101
  key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
102
  value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)
103
 
104
+ attn_score = query @ key.transpose(-2, -1) * scale_factor
105
 
106
+ # attn_bias: [1, 1, 4096, 77], the same size as a single-head attn_score.
107
+ attn_score += attn_bias
108
+ if mix_attn_mats_in_batch:
109
+ # The instances in the batch are [sc, mc]. We average their attn scores,
110
+ # and apply to both instances.
111
+ # attn_score: [2, 8, 4096, 77] -> [1, 8, 4096, 77] -> [2, 8, 4096, 77].
112
+ # If BLOCK_SIZE > 1, attn_score.shape[0] = 2 * BLOCK_SIZE.
113
+ if attn_score.shape[0] %2 != 0:
114
+ breakpoint()
115
+ attn_score_sc, attn_score_mc = attn_score.chunk(2, dim=0)
116
+ # Cut off the grad flow from the SC instance to the MC instance.
117
+ attn_score = (attn_score_sc + attn_score_mc.detach()) / 2
118
+ attn_score = attn_score.repeat(2, 1, 1, 1)
119
+ elif normalize_cross_attn:
120
+ if subj_indices is None:
121
+ breakpoint()
122
+ subj_indices_B, subj_indices_N = subj_indices
123
+ subj_attn_score = attn_score[subj_indices_B, :, :, subj_indices_N]
124
+ # Normalize the attention score of the subject tokens to have mean 0 across tokens,
125
+ # so that positive and negative scores are balanced.
126
+ subj_attn_score = subj_attn_score - subj_attn_score.mean(dim=2, keepdim=True).detach()
127
+ # cross_attn_scale is a learnable parameter, so the score will be scaled appropriately.
128
+ # Scale up the BP'ed gradient to cross_attn_scale_factor by 10x.
129
+ ca_scale_grad_scaler = gen_gradient_scaler(10)
130
+ subj_attn_score = subj_attn_score * ca_scale_grad_scaler(cross_attn_scale_factor)
131
+ attn_score2 = attn_score.clone()
132
+ attn_score2[subj_indices_B, :, :, subj_indices_N] = subj_attn_score
133
+ attn_score = attn_score2
134
+ # Otherwise, do nothing to attn_score.
135
+
136
+ attn_weight = torch.softmax(attn_score, dim=-1)
137
  attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
138
  output = attn_weight @ value
139
  return output, attn_score, attn_weight
 
147
  def __init__(self, capture_ca_activations: bool = False, enable_lora: bool = False,
148
  lora_uses_dora=True, lora_proj_layers=None,
149
  lora_rank: int = 192, lora_alpha: float = 16,
 
150
  q_lora_updates_query=False, attn_proc_idx=-1):
151
  super().__init__()
152
 
153
  self.global_enable_lora = enable_lora
154
  self.attn_proc_idx = attn_proc_idx
155
  # reset_attn_cache_and_flags() sets the local (call-specific) self.enable_lora flag.
156
+ # By default, normalize_cross_attn is False. Later in layers 22, 23, 24 it will be set to True.
157
+ self.reset_attn_cache_and_flags(capture_ca_activations, False, False, enable_lora)
158
  self.lora_rank = lora_rank
159
  self.lora_alpha = lora_alpha
160
  self.lora_scale = self.lora_alpha / self.lora_rank
 
161
  self.q_lora_updates_query = q_lora_updates_query
162
 
163
  self.to_q_lora = self.to_k_lora = self.to_v_lora = self.to_out_lora = None
164
  if self.global_enable_lora:
165
+ # enable_lora = True iff this is a cross-attn layer in the last 3 up blocks.
166
+ # Since we only use cross_attn_scale_factor on cross-attn layers,
167
+ # we only use cross_attn_scale_factor when enable_lora is True.
168
+ self.cross_attn_scale_factor = nn.Parameter(torch.tensor(0.8), requires_grad=True)
169
  for lora_layer_name, lora_proj_layer in lora_proj_layers.items():
170
  if lora_layer_name == 'q':
171
  self.to_q_lora = peft_lora.Linear(lora_proj_layer, 'default', r=lora_rank, lora_alpha=lora_alpha,
 
181
  use_dora=lora_uses_dora, lora_dropout=0.1)
182
 
183
  # LoRA layers can be enabled/disabled dynamically.
184
+ def reset_attn_cache_and_flags(self, capture_ca_activations, normalize_cross_attn, mix_attn_mats_in_batch, enable_lora):
185
  self.capture_ca_activations = capture_ca_activations
186
+ self.normalize_cross_attn = normalize_cross_attn
187
+ self.mix_attn_mats_in_batch = mix_attn_mats_in_batch
188
  self.cached_activations = {}
189
  # Only enable LoRA for the next call(s) if global_enable_lora is set to True.
190
  self.enable_lora = enable_lora and self.global_enable_lora
 
306
  breakpoint()
307
 
308
  # the output of sdp = (batch, num_heads, seq_len, head_dim)
309
+ if is_cross_attn and (self.capture_ca_activations or self.normalize_cross_attn):
310
  hidden_states, attn_score, attn_prob = \
311
  scaled_dot_product_attention(query, key, value, attn_mask=attention_mask,
312
+ dropout_p=0.0, subj_indices=subj_indices,
313
+ normalize_cross_attn=self.normalize_cross_attn,
314
+ cross_attn_scale_factor=self.cross_attn_scale_factor,
315
+ mix_attn_mats_in_batch=self.mix_attn_mats_in_batch)
316
+
317
  else:
318
  # Use the faster implementation of scaled_dot_product_attention
319
  # when not capturing the activations or suppressing the subject attention.
 
449
  # Adapted from ConsistentIDPipeline:set_ip_adapter().
450
  # attn_lora_layer_names: candidates are subsets of ['q', 'k', 'v', 'out'].
451
  def set_up_attn_processors(unet, use_attn_lora, attn_lora_layer_names=['q', 'k', 'v', 'out'],
452
+ lora_rank=192, lora_scale_down=8,
453
  q_lora_updates_query=False):
454
  attn_procs = {}
455
  attn_capture_procs = {}
 
499
  lora_uses_dora=True, lora_proj_layers=lora_proj_layers,
500
  # LoRA up is initialized to 0. So no need to worry that the LoRA output may be too large.
501
  lora_rank=lora_rank, lora_alpha=lora_rank // lora_scale_down,
 
502
  q_lora_updates_query=q_lora_updates_query, attn_proc_idx=attn_proc_idx)
503
 
504
  attn_proc_idx += 1
 
509
  attn_capture_procs[name] = attn_capture_proc
510
 
511
  if use_attn_lora:
512
+ cross_attn_scale_factor_name = name + "_cross_attn_scale_factor"
513
+ # Put cross_attn_scale_factor in attn_opt_modules, so that we can optimize and save/load it.
514
+ attn_opt_modules[cross_attn_scale_factor_name] = attn_capture_proc.cross_attn_scale_factor
515
+
516
+ # Put LoRA layers in attn_opt_modules, so that we can optimize and save/load them.
517
  for subname, module in attn_capture_proc.named_modules():
518
  if isinstance(module, peft_lora.LoraLayer):
519
  # ModuleDict doesn't allow "." in the key.
 
538
  return attn_capture_procs, attn_opt_modules
539
 
540
  # NOTE: cross-attn layers are included in the returned lora_modules.
541
+ def set_up_ffn_loras(unet, target_modules_pat, lora_uses_dora=True, lora_rank=192, lora_alpha=16):
542
  # target_modules_pat = 'up_blocks.3.resnets.[12].conv[a-z0-9_]+'
543
  # up_blocks.3.resnets.[1~2].conv1, conv2, conv_shortcut
544
  # Cannot set to conv.+ as it will match added adapter module names, including
 
593
  def set_lora_and_capture_flags(unet, unet_lora_modules, attn_capture_procs,
594
  outfeat_capture_blocks, res_hidden_states_gradscale_blocks,
595
  use_attn_lora, use_ffn_lora, ffn_lora_adapter_name, capture_ca_activations,
596
+ normalize_cross_attn, mix_attn_mats_in_batch, res_hidden_states_gradscale):
597
  # For attn capture procs, capture_ca_activations and use_attn_lora are set in reset_attn_cache_and_flags().
598
+ for i, attn_capture_proc in enumerate(attn_capture_procs):
599
+ attn_capture_proc.reset_attn_cache_and_flags(capture_ca_activations, normalize_cross_attn, mix_attn_mats_in_batch,
600
+ enable_lora=use_attn_lora)
601
  # outfeat_capture_blocks only contains the last up block, up_blocks[3].
602
  # It contains 3 FFN layers. We want to capture their output features.
603
  for block in outfeat_capture_blocks:
604
  block.capture_outfeats = capture_ca_activations
605
 
606
+ # res_hidden_states_gradscale_blocks contain the second to the last up blocks, up_blocks[1:].
607
+ # It's only used to set res_hidden_states_gradscale, and doesn't capture anything.
608
  for block in res_hidden_states_gradscale_blocks:
609
  block.res_hidden_states_gradscale = res_hidden_states_gradscale
610
 
 
643
  block.cached_outfeats = {}
644
  block.capture_outfeats = False
645
 
646
+
647
  for layer_idx in captured_layer_indices:
648
  # Subtract 22 to ca_layer_idx to match the layer index in up_blocks[3].cached_outfeats.
649
  # 23, 24 -> 1, 2 (!! not 0, 1 !!)
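
The diffusers_attn_lora_capture.py hunks drop the fixed shrink_cross_attn / cross_attn_shrink_factor scaling in favor of normalize_cross_attn with a learnable per-processor cross_attn_scale_factor, plus an optional mix_attn_mats_in_batch mode that averages the sc/mc halves of a batch. A simplified sketch of the normalization idea, assuming a single attention map; the per-instance subject-token indexing and the 10x gradient scaler on the scale factor are simplified or omitted here:

import torch

# Sketch only (not part of the commit); shapes and names are illustrative.
def normalize_subject_scores(attn_score: torch.Tensor,
                             subj_token_idx: torch.Tensor,
                             cross_attn_scale: torch.Tensor) -> torch.Tensor:
    # attn_score: [B, heads, L_query, L_text] pre-softmax cross-attn scores.
    # subj_token_idx: 1-D long tensor of subject placeholder token columns.
    # cross_attn_scale: learnable scalar, e.g. nn.Parameter(torch.tensor(0.8)).
    score = attn_score.clone()
    subj = score[..., subj_token_idx]                  # [B, heads, L_query, n_subj]
    # Zero-center the subject-token scores over query positions so positive and
    # negative logits balance, then rescale them with the learnable factor.
    subj = subj - subj.mean(dim=-2, keepdim=True).detach()
    score[..., subj_token_idx] = subj * cross_attn_scale
    return torch.softmax(score, dim=-1)

In the committed processors the scale factor is an nn.Parameter on each AttnProcessor_LoRA_Capture, registered under the name + "_cross_attn_scale_factor" key in attn_opt_modules so it is optimized and saved/loaded alongside the LoRA layers.
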
adaface/face_id_to_ada_prompt.py CHANGED
@@ -26,7 +26,7 @@ def create_id2ada_prompt_encoder(adaface_encoder_types, adaface_ckpt_paths=None,
26
  if adaface_encoder_type == 'arc2face':
27
  id2ada_prompt_encoder = \
28
  Arc2Face_ID2AdaPrompt(adaface_ckpt_path=adaface_ckpt_path,
29
- *args, **kwargs)
30
  elif adaface_encoder_type == 'consistentID':
31
  id2ada_prompt_encoder = \
32
  ConsistentID_ID2AdaPrompt(pipe=None,
@@ -64,6 +64,7 @@ class FaceID2AdaPrompt(nn.Module):
64
  # i.e., 6 for arc2face and 1 for consistentID.
65
  self.out_id_embs_cfg_scale = kwargs.get('out_id_embs_cfg_scale', -1)
66
  self.is_training = kwargs.get('is_training', False)
 
67
  # extend_prompt2token_proj_attention_multiplier is an integer >= 1.
68
  # TODO: extend_prompt2token_proj_attention_multiplier should be a list of integers.
69
  self.extend_prompt2token_proj_attention_multiplier = kwargs.get('extend_prompt2token_proj_attention_multiplier', 1)
@@ -603,9 +604,13 @@ class Arc2Face_ID2AdaPrompt(FaceID2AdaPrompt):
603
  '''
604
  # Use the same model as ID2AdaPrompt does.
605
  # FaceAnalysis will try to find the ckpt in: models/insightface/models/antelopev2.
606
- # Note there are two "models" in the path.
 
607
  self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface',
608
- providers=['CPUExecutionProvider'])
609
  self.face_app.prepare(ctx_id=0, det_size=(512, 512))
610
  print(f'Arc2Face Face encoder loaded on CPU.')
611
 
@@ -642,7 +647,6 @@ class Arc2Face_ID2AdaPrompt(FaceID2AdaPrompt):
642
 
643
  def _apply(self, fn):
644
  super()._apply(fn) # Call the parent _apply to handle parameters and buffers
645
- return
646
  # A dirty hack to get the device of the model, passed from
647
  # parent.model.to(self.root_device) => parent._apply(convert) => module._apply(fn)
648
  test_tensor = torch.zeros(1) # Create a test tensor
@@ -651,22 +655,24 @@ class Arc2Face_ID2AdaPrompt(FaceID2AdaPrompt):
651
  # No need to reload face_app on the same device.
652
  if device == self.device:
653
  return
654
 
655
  if str(device) == 'cpu':
656
  self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface',
657
- providers=['CPUExecutionProvider'])
658
  self.face_app.prepare(ctx_id=0, det_size=(512, 512))
659
  else:
660
  device_id = device.index
661
  self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface',
662
  providers=['CUDAExecutionProvider'],
663
- provider_options=[{"device_id": device_id,
664
- "cudnn_conv_algo_search": "HEURISTIC",
665
- "gpu_mem_limit": 2 * 1024**3
666
- }])
667
  self.face_app.prepare(ctx_id=device_id, det_size=(512, 512))
668
 
669
- self.device = device
670
  print(f'Arc2Face Face encoder reloaded on {device}.')
671
  return
672
 
@@ -739,8 +745,8 @@ class ConsistentID_ID2AdaPrompt(FaceID2AdaPrompt):
739
  # but diffusers will call .to(dtype) in .from_single_file(),
740
  # and at that moment, the consistentID specific modules are not loaded yet.
741
  pipe = ConsistentIDPipeline.from_single_file(base_model_path)
742
- pipe.load_ConsistentID_model(consistentID_weight_path="./models/ConsistentID/ConsistentID-v1.bin",
743
- bise_net_weight_path="./models/ConsistentID/BiSeNet_pretrained_for_ConsistentID.pth")
744
  pipe.to(dtype=self.dtype)
745
  # Since the passed-in pipe is None, this should be called during inference,
746
  # when the teacher ConsistentIDPipeline is not initialized.
@@ -791,7 +797,6 @@ class ConsistentID_ID2AdaPrompt(FaceID2AdaPrompt):
791
 
792
  def _apply(self, fn):
793
  super()._apply(fn) # Call the parent _apply to handle parameters and buffers
794
- return
795
  # A dirty hack to get the device of the model, passed from
796
  # parent.model.to(self.root_device) => parent._apply(convert) => module._apply(fn)
797
  test_tensor = torch.zeros(1) # Create a test tensor
@@ -800,6 +805,11 @@ class ConsistentID_ID2AdaPrompt(FaceID2AdaPrompt):
800
  # No need to reload face_app on the same device.
801
  if device == self.device:
802
  return
803
 
804
  if str(device) == 'cpu':
805
  self.face_app = FaceAnalysis(name='buffalo_l', root='models/insightface',
@@ -809,13 +819,10 @@ class ConsistentID_ID2AdaPrompt(FaceID2AdaPrompt):
809
  device_id = device.index
810
  self.face_app = FaceAnalysis(name='buffalo_l', root='models/insightface',
811
  providers=['CUDAExecutionProvider'],
812
- provider_options=[{"device_id": device_id,
813
- "cudnn_conv_algo_search": "HEURISTIC",
814
- "gpu_mem_limit": 2 * 1024**3
815
- }])
816
  self.face_app.prepare(ctx_id=device_id, det_size=(512, 512))
817
 
818
- self.device = device
819
  self.pipe.face_app = self.face_app
820
  print(f'ConsistentID Face encoder reloaded on {device}.')
821
 
@@ -1277,7 +1284,7 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
1277
  # No faces are found in the images, so return None embeddings.
1278
  # We don't want to return an all-zero embedding, which is useless.
1279
  if num_available_id_vecs == 0:
1280
- return None, [0]
1281
 
1282
  # If id2ada_prompt_encoders are ["arc2face", "consistentID"], then
1283
  # during inference, we average across the batch dim.
 
26
  if adaface_encoder_type == 'arc2face':
27
  id2ada_prompt_encoder = \
28
  Arc2Face_ID2AdaPrompt(adaface_ckpt_path=adaface_ckpt_path,
29
+ *args, **kwargs)
30
  elif adaface_encoder_type == 'consistentID':
31
  id2ada_prompt_encoder = \
32
  ConsistentID_ID2AdaPrompt(pipe=None,
 
64
  # i.e., 6 for arc2face and 1 for consistentID.
65
  self.out_id_embs_cfg_scale = kwargs.get('out_id_embs_cfg_scale', -1)
66
  self.is_training = kwargs.get('is_training', False)
67
+ self.is_on_hf_space = kwargs.get('is_on_hf_space', False)
68
  # extend_prompt2token_proj_attention_multiplier is an integer >= 1.
69
  # TODO: extend_prompt2token_proj_attention_multiplier should be a list of integers.
70
  self.extend_prompt2token_proj_attention_multiplier = kwargs.get('extend_prompt2token_proj_attention_multiplier', 1)
 
604
  '''
605
  # Use the same model as ID2AdaPrompt does.
606
  # FaceAnalysis will try to find the ckpt in: models/insightface/models/antelopev2.
607
+ # Note there's a second "model" in the path.
608
+ # Note DO use CUDAExecutionProvider during training and CPUExecutionProvider during inference.
609
+ # Otherwise, CPUExecutionProvider will hang DDP training,
610
+ # and CUDAExecutionProvider will cause OOM on huggingface spaces.
611
+ self.onnx_providers = ['CUDAExecutionProvider'] if self.is_training else ['CPUExecutionProvider']
612
  self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface',
613
+ providers=self.onnx_providers)
614
  self.face_app.prepare(ctx_id=0, det_size=(512, 512))
615
  print(f'Arc2Face Face encoder loaded on CPU.')
616
 
 
647
 
648
  def _apply(self, fn):
649
  super()._apply(fn) # Call the parent _apply to handle parameters and buffers
 
650
  # A dirty hack to get the device of the model, passed from
651
  # parent.model.to(self.root_device) => parent._apply(convert) => module._apply(fn)
652
  test_tensor = torch.zeros(1) # Create a test tensor
 
655
  # No need to reload face_app on the same device.
656
  if device == self.device:
657
  return
658
+ self.device = device
659
+
660
+ if self.is_on_hf_space and self.face_app is not None:
661
+ print(f'On HF space. Arc2Face Face encoder already loaded on cpu.')
662
+ return
663
 
664
  if str(device) == 'cpu':
665
  self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface',
666
+ providers=['CPUExecutionProvider'])
667
  self.face_app.prepare(ctx_id=0, det_size=(512, 512))
668
  else:
669
  device_id = device.index
670
  self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface',
671
  providers=['CUDAExecutionProvider'],
672
+ provider_options=[{'device_id': device_id,
673
+ 'cudnn_conv_algo_search': 'HEURISTIC'}])
 
 
674
  self.face_app.prepare(ctx_id=device_id, det_size=(512, 512))
675
 
 
676
  print(f'Arc2Face Face encoder reloaded on {device}.')
677
  return
678
 
 
745
  # but diffusers will call .to(dtype) in .from_single_file(),
746
  # and at that moment, the consistentID specific modules are not loaded yet.
747
  pipe = ConsistentIDPipeline.from_single_file(base_model_path)
748
+ pipe.load_ConsistentID_model(consistentID_weight_path="models/ConsistentID/ConsistentID-v1.bin",
749
+ bise_net_weight_path="models/ConsistentID/BiSeNet_pretrained_for_ConsistentID.pth")
750
  pipe.to(dtype=self.dtype)
751
  # Since the passed-in pipe is None, this should be called during inference,
752
  # when the teacher ConsistentIDPipeline is not initialized.
 
797
 
798
  def _apply(self, fn):
799
  super()._apply(fn) # Call the parent _apply to handle parameters and buffers
 
800
  # A dirty hack to get the device of the model, passed from
801
  # parent.model.to(self.root_device) => parent._apply(convert) => module._apply(fn)
802
  test_tensor = torch.zeros(1) # Create a test tensor
 
805
  # No need to reload face_app on the same device.
806
  if device == self.device:
807
  return
808
+ self.device = device
809
+
810
+ if self.is_on_hf_space and self.face_app is not None:
811
+ print(f'On HF space. ConsistentID Face encoder already loaded on cpu.')
812
+ return
813
 
814
  if str(device) == 'cpu':
815
  self.face_app = FaceAnalysis(name='buffalo_l', root='models/insightface',
 
819
  device_id = device.index
820
  self.face_app = FaceAnalysis(name='buffalo_l', root='models/insightface',
821
  providers=['CUDAExecutionProvider'],
822
+ provider_options=[{'device_id': device_id,
823
+ 'cudnn_conv_algo_search': 'HEURISTIC'}])
 
 
824
  self.face_app.prepare(ctx_id=device_id, det_size=(512, 512))
825
 
 
826
  self.pipe.face_app = self.face_app
827
  print(f'ConsistentID Face encoder reloaded on {device}.')
828
 
 
1284
  # No faces are found in the images, so return None embeddings.
1285
  # We don't want to return an all-zero embedding, which is useless.
1286
  if num_available_id_vecs == 0:
1287
+ return None, None, [0]
1288
 
1289
  # If id2ada_prompt_encoders are ["arc2face", "consistentID"], then
1290
  # during inference, we average across the batch dim.
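
The face_id_to_ada_prompt.py hunks mostly decide where the insightface detector runs: CUDAExecutionProvider during training (the CPU provider can hang DDP), CPUExecutionProvider for inference on HF Spaces (the CUDA provider can OOM there), and no reload at all when the wrapper is moved between devices on a Space. A condensed sketch of that provider choice, using the same FaceAnalysis arguments as the diff (illustrative, not the exact class method):

from insightface.app import FaceAnalysis

# Sketch only (not part of the commit).
def build_face_app(is_training: bool, device_id: int = 0) -> FaceAnalysis:
    if is_training:
        # Training path: ONNX Runtime on the GPU that hosts this DDP rank.
        face_app = FaceAnalysis(name='antelopev2', root='models/insightface',
                                providers=['CUDAExecutionProvider'],
                                provider_options=[{'device_id': device_id,
                                                   'cudnn_conv_algo_search': 'HEURISTIC'}])
        face_app.prepare(ctx_id=device_id, det_size=(512, 512))
    else:
        # Inference / HF Space path: keep the detector on CPU.
        face_app = FaceAnalysis(name='antelopev2', root='models/insightface',
                                providers=['CPUExecutionProvider'])
        face_app.prepare(ctx_id=0, det_size=(512, 512))
    return face_app
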
adaface/unet_teachers.py CHANGED
@@ -62,46 +62,41 @@ class UNetTeacher(nn.Module):
62
  # t: the initial t. We will sample additional (num_denoising_steps - 1) smaller t.
63
  # same_t_noise_across_instances: when sampling t and noise, use the same t and noise for all instances.
64
  def forward(self, ddpm_model, x_start, noise, t, teacher_context, negative_context=None,
65
- num_denoising_steps=1, same_t_noise_across_instances=False,
66
  global_t_lb=0, global_t_ub=1000):
67
  assert num_denoising_steps <= 10
68
 
69
- if self.p_uses_cfg > 0:
70
  self.uses_cfg = np.random.rand() < self.p_uses_cfg
71
- if self.uses_cfg:
72
- # Randomly sample a cfg_scale from cfg_scale_range.
73
- self.cfg_scale = np.random.uniform(*self.cfg_scale_range)
74
- if self.cfg_scale == 1:
75
- self.uses_cfg = False
76
-
77
- if self.uses_cfg:
78
- print(f"Teacher samples CFG scale {self.cfg_scale:.1f}.")
79
- if negative_context is not None:
80
- negative_context = negative_context[:1].repeat(x_start.shape[0], 1, 1)
81
-
82
- # if negative_context is None, then teacher_context is a combination of
83
- # (one or multiple if unet_ensemble) pos_context and neg_context.
84
- # If negative_context is not None, then teacher_context is only pos_context.
85
- else:
86
- self.cfg_scale = 1
87
- print("Teacher does not use CFG.")
88
-
89
- # If negative_context is None, then teacher_context is a combination of
90
- # (one or multiple if unet_ensemble) pos_context and neg_context.
91
- # Since not uses_cfg, we only need pos_context.
92
- # If negative_context is not None, then teacher_context is only pos_context.
93
- if negative_context is None:
94
- teacher_context = self.extract_pos_context(teacher_context, x_start.shape[0])
95
  else:
96
  # p_uses_cfg = 0. Never use CFG.
97
  self.uses_cfg = False
98
- # In this case, the student only passes pos_context to the teacher,
99
- # so no need to split teacher_context into pos_context and neg_context.
100
- # self.cfg_scale will be accessed by the student,
101
- # so we need to make sure it is always set correctly,
102
- # in case someday we want to switch from CFG to non-CFG during runtime.
103
  self.cfg_scale = 1
104
 
 
105
  is_context_doubled = 2 if (self.uses_cfg and negative_context is None) else 1
106
  if self.name == 'unet_ensemble':
107
  # teacher_context is a list of teacher contexts.
@@ -199,14 +194,20 @@ class UNetTeacher(nn.Module):
199
  teacher_pos_contexts = []
200
  # teacher_context is a list of teacher contexts.
201
  for teacher_context_i in teacher_context:
202
- pos_context, neg_context = torch.chunk(teacher_context_i, 2, dim=0)
203
- if pos_context.shape[0] != BS:
204
- breakpoint()
205
  teacher_pos_contexts.append(pos_context)
206
  teacher_context = teacher_pos_contexts
207
  else:
208
- pos_context, neg_context = torch.chunk(teacher_context, 2, dim=0)
209
- if pos_context.shape[0] != BS:
210
  breakpoint()
211
  teacher_context = pos_context
212
 
 
62
  # t: the initial t. We will sample additional (num_denoising_steps - 1) smaller t.
63
  # same_t_noise_across_instances: when sampling t and noise, use the same t and noise for all instances.
64
  def forward(self, ddpm_model, x_start, noise, t, teacher_context, negative_context=None,
65
+ num_denoising_steps=1, force_uses_cfg=False, same_t_noise_across_instances=False,
66
  global_t_lb=0, global_t_ub=1000):
67
  assert num_denoising_steps <= 10
68
 
69
+ # force_uses_cfg overrides p_uses_cfg.
70
+ if force_uses_cfg > 0:
71
+ self.uses_cfg = True
72
+ elif self.p_uses_cfg > 0:
73
  self.uses_cfg = np.random.rand() < self.p_uses_cfg
74
  else:
75
  # p_uses_cfg = 0. Never use CFG.
76
  self.uses_cfg = False
 
77
  self.cfg_scale = 1
78
 
79
+ if self.uses_cfg:
80
+ # Randomly sample a cfg_scale from cfg_scale_range.
81
+ self.cfg_scale = np.random.uniform(*self.cfg_scale_range)
82
+ print(f"Teacher samples CFG scale {self.cfg_scale:.1f}.")
83
+ if negative_context is not None:
84
+ negative_context = negative_context[:1].repeat(x_start.shape[0], 1, 1)
85
+
86
+ # if negative_context is None, then teacher_context is a combination of
87
+ # (one or multiple if unet_ensemble) pos_context and neg_context.
88
+ # If negative_context is not None, then teacher_context is only pos_context.
89
+ else:
90
+ self.cfg_scale = 1
91
+ print("Teacher does not use CFG.")
92
+
93
+ # If negative_context is None, then teacher_context is either a combination of
94
+ # (one or multiple if unet_ensemble) pos_context and neg_context, or only pos_context.
95
+ # Since not uses_cfg, we only need pos_context.
96
+ # If negative_context is not None, then teacher_context is only pos_context.
97
+ if negative_context is None:
98
+ teacher_context = self.extract_pos_context(teacher_context, x_start.shape[0])
99
+
100
  is_context_doubled = 2 if (self.uses_cfg and negative_context is None) else 1
101
  if self.name == 'unet_ensemble':
102
  # teacher_context is a list of teacher contexts.
 
194
  teacher_pos_contexts = []
195
  # teacher_context is a list of teacher contexts.
196
  for teacher_context_i in teacher_context:
197
+ if teacher_context_i.shape[0] == BS * 2:
198
+ pos_context, neg_context = torch.chunk(teacher_context_i, 2, dim=0)
199
+ elif teacher_context_i.shape[0] == BS:
200
+ pos_context = teacher_context_i
201
+ else:
202
+ breakpoint()
203
  teacher_pos_contexts.append(pos_context)
204
  teacher_context = teacher_pos_contexts
205
  else:
206
+ if teacher_context.shape[0] == BS * 2:
207
+ pos_context, neg_context = torch.chunk(teacher_context, 2, dim=0)
208
+ elif teacher_context.shape[0] == BS:
209
+ pos_context = teacher_context
210
+ else:
211
  breakpoint()
212
  teacher_context = pos_context
213
 
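
Two behavioral changes in unet_teachers.py are worth calling out: forward() gains a force_uses_cfg flag that overrides p_uses_cfg (cfg_scale is still sampled from cfg_scale_range), and the pos/neg context split now tolerates a teacher_context that already contains only positive contexts. A minimal sketch of the more tolerant split, with illustrative names:

import torch

# Sketch only (not part of the commit).
def extract_pos_context(teacher_context: torch.Tensor, BS: int) -> torch.Tensor:
    # teacher_context may be [pos; neg] stacked along the batch dim (2*BS rows)
    # or already positive-only (BS rows); anything else is a caller bug.
    if teacher_context.shape[0] == BS * 2:
        pos_context, _neg_context = torch.chunk(teacher_context, 2, dim=0)
    elif teacher_context.shape[0] == BS:
        pos_context = teacher_context
    else:
        raise ValueError(f"Unexpected context batch size {teacher_context.shape[0]} for BS={BS}")
    return pos_context

(The committed code drops into breakpoint() instead of raising, which is the repo's usual debugging convention.)
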
adaface/util.py CHANGED
@@ -48,7 +48,7 @@ def perturb_tensor(ts, perturb_std, perturb_std_is_relative=True, keep_norm=Fals
48
  ts = ts + noise
49
 
50
  if verbose:
51
- print(f"Correlations between new and original tensors: {F.cosine_similarity(ts.flatten(), orig_ts.flatten(), dim=0).item():.03f}")
52
 
53
  return ts
54
 
@@ -69,7 +69,7 @@ def calc_stats(emb_name, embeddings, mean_dim=-1):
69
  # Compute it manually.
70
  l2_loss = ((embeddings - emb_mean) ** 2).mean().sqrt()
71
  norms = torch.norm(embeddings, dim=1).detach().cpu().numpy()
72
- print("L1: %.4f, L2: %.4f" %(l1_loss.item(), l2_loss.item()))
73
  print("Norms: min: %.4f, max: %.4f, mean: %.4f, std: %.4f" %(norms.min(), norms.max(), norms.mean(), norms.std()))
74
 
75
 
@@ -80,7 +80,7 @@ class ScaleGrad(torch.autograd.Function):
80
  ctx.save_for_backward(alpha_, debug)
81
  output = input_
82
  if debug:
83
- print(f"input: {input_.abs().mean().item()}")
84
  return output
85
 
86
  @staticmethod
@@ -90,7 +90,7 @@ class ScaleGrad(torch.autograd.Function):
90
  if ctx.needs_input_grad[0]:
91
  grad_output2 = grad_output * alpha_
92
  if debug:
93
- print(f"grad_output2: {grad_output2.abs().mean().item()}")
94
  else:
95
  grad_output2 = None
96
  return grad_output2, None, None
@@ -232,8 +232,8 @@ def create_consistentid_pipeline(base_model_path="models/sd15-dste8-vae.safetens
232
  # consistentID specific modules are still in fp32. Will be converted to fp16
233
  # later with .to(device, torch_dtype) by the caller.
234
  pipe.load_ConsistentID_model(
235
- consistentID_weight_path="./models/ConsistentID/ConsistentID-v1.bin",
236
- bise_net_weight_path="./models/ConsistentID/BiSeNet_pretrained_for_ConsistentID.pth",
237
  )
238
  # Avoid passing dtype to ConsistentIDPipeline.from_single_file(),
239
  # because we've overloaded .to() to convert consistentID specific modules as well,
 
48
  ts = ts + noise
49
 
50
  if verbose:
51
+ print(f"Correlations between new and original tensors: {F.cosine_similarity(ts.flatten(), orig_ts.flatten(), dim=0).detach().item():.03f}")
52
 
53
  return ts
54
 
 
69
  # Compute it manually.
70
  l2_loss = ((embeddings - emb_mean) ** 2).mean().sqrt()
71
  norms = torch.norm(embeddings, dim=1).detach().cpu().numpy()
72
+ print("L1: %.4f, L2: %.4f" %(l1_loss.detach().item(), l2_loss.detach().item()))
73
  print("Norms: min: %.4f, max: %.4f, mean: %.4f, std: %.4f" %(norms.min(), norms.max(), norms.mean(), norms.std()))
74
 
75
 
 
80
  ctx.save_for_backward(alpha_, debug)
81
  output = input_
82
  if debug:
83
+ print(f"input: {input_.abs().mean().detach().item()}")
84
  return output
85
 
86
  @staticmethod
 
90
  if ctx.needs_input_grad[0]:
91
  grad_output2 = grad_output * alpha_
92
  if debug:
93
+ print(f"grad_output2: {grad_output2.abs().mean().detach().item()}")
94
  else:
95
  grad_output2 = None
96
  return grad_output2, None, None
 
232
  # consistentID specific modules are still in fp32. Will be converted to fp16
233
  # later with .to(device, torch_dtype) by the caller.
234
  pipe.load_ConsistentID_model(
235
+ consistentID_weight_path="models/ConsistentID/ConsistentID-v1.bin",
236
+ bise_net_weight_path="models/ConsistentID/BiSeNet_pretrained_for_ConsistentID.pth",
237
  )
238
  # Avoid passing dtype to ConsistentIDPipeline.from_single_file(),
239
  # because we've overloaded .to() to convert consistentID specific modules as well,
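
util.py only touches debug output and checkpoint paths: the ScaleGrad debug prints now call .detach().item() before logging, and the ConsistentID weight paths lose their leading "./". For context, here is a minimal gradient scaler in the spirit of ScaleGrad / gen_gradient_scaler (the helper used in the attention hunk above to boost the gradient to cross_attn_scale_factor by 10x); this is a sketch, not the repo's exact implementation, which saves alpha via save_for_backward and supports a debug flag:

import torch

# Sketch only (not part of the commit).
class ScaleGradSketch(torch.autograd.Function):
    # Identity in the forward pass; multiplies the incoming gradient by alpha
    # in the backward pass, so downstream modules see unchanged activations
    # while upstream parameters receive scaled gradients.
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * ctx.alpha, None

def gen_gradient_scaler(alpha: float):
    return lambda x: ScaleGradSketch.apply(x, alpha)
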
app.py CHANGED
@@ -20,14 +20,14 @@ def str2bool(v):
20
  else:
21
  raise argparse.ArgumentTypeError("Boolean value expected.")
22
 
23
- def is_running_on_spaces():
24
  return os.getenv("SPACE_ID") is not None
25
 
26
  import argparse
27
  parser = argparse.ArgumentParser()
28
  parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["consistentID", "arc2face"],
29
  choices=["arc2face", "consistentID"], help="Type(s) of the ID2Ada prompt encoders")
30
- parser.add_argument('--adaface_ckpt_path', type=str, default='models/adaface/VGGface2_HQ_masks2025-03-06T03-31-21_zero3-ada-1000.pt',
31
  help="Path to the checkpoint of the ID2Ada prompt encoders")
32
  # If adaface_encoder_cfg_scales is not specified, the weights will be set to 6.0 (consistentID) and 1.0 (arc2face).
33
  parser.add_argument('--adaface_encoder_cfg_scales', type=float, nargs="+", default=[6.0, 1.0],
@@ -75,6 +75,16 @@ MAX_SEED = np.iinfo(np.int32).max
75
  global adaface
76
  adaface = None
77
 
 
78
  if not args.test_ui_only:
79
  adaface = AdaFaceWrapper(pipeline_name="text2img", base_model_path=base_model_path,
80
  adaface_encoder_types=args.adaface_encoder_types,
@@ -84,9 +94,10 @@ if not args.test_ui_only:
84
  unet_types=None, extra_unet_dirpaths=None, unet_weights_in_ensemble=None,
85
  unet_uses_attn_lora=args.unet_uses_attn_lora,
86
  attn_lora_layer_names=args.attn_lora_layer_names,
87
- shrink_cross_attn=False,
88
  q_lora_updates_query=args.q_lora_updates_query,
89
- device='cpu')
 
90
 
91
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
92
  if randomize_seed:
@@ -114,18 +125,7 @@ def generate_image(image_paths, image_paths2, guidance_scale, perturb_std,
114
 
115
  global adaface, args
116
 
117
- if is_running_on_spaces():
118
- device = 'cuda:0'
119
- else:
120
- if args.gpu is None:
121
- device = "cuda"
122
- else:
123
- device = f"cuda:{args.gpu}"
124
-
125
- print(f"Device: {device}")
126
-
127
- adaface.to(device)
128
- args.device = device
129
 
130
  if image_paths is None or len(image_paths) == 0:
131
  raise gr.Error(f"Cannot find any input face image! Please upload a face image.")
@@ -255,16 +255,17 @@ def check_prompt_and_model_type(prompt, model_style_type, adaface_encoder_cfg_sc
255
  print(f"Switching to the base model type: {model_style_type}.")
256
 
257
  adaface = AdaFaceWrapper(pipeline_name="text2img", base_model_path=model_style_type2base_model_path[model_style_type],
258
- adaface_encoder_types=args.adaface_encoder_types,
259
- adaface_ckpt_paths=args.adaface_ckpt_path,
260
- adaface_encoder_cfg_scales=args.adaface_encoder_cfg_scales,
261
- enabled_encoders=args.enabled_encoders,
262
- unet_types=None, extra_unet_dirpaths=None, unet_weights_in_ensemble=None,
263
- unet_uses_attn_lora=args.unet_uses_attn_lora,
264
- attn_lora_layer_names=args.attn_lora_layer_names,
265
- shrink_cross_attn=False,
266
- q_lora_updates_query=args.q_lora_updates_query,
267
- device='cpu')
 
268
 
269
  if adaface_encoder_cfg_scale1 != args.adaface_encoder_cfg_scales[0]:
270
  args.adaface_encoder_cfg_scales[0] = adaface_encoder_cfg_scale1
@@ -370,12 +371,13 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
370
  "portrait, night view of tokyo street, neon light",
371
  "portrait, playing guitar on a boat, ocean waves",
372
  "portrait, with a passion for reading, curled up with a book in a cozy nook near a window, front view",
373
- "portrait, celebrating new year, fireworks",
374
  "portrait, running pose in a park",
375
  "portrait, in space suit, space helmet, walking on mars",
376
  "portrait, in superman costume, the sky ablaze with hues of orange and purple",
377
  "in a wheelchair",
378
- "on a horse"
 
379
  ])
380
 
381
  highlight_face = gr.Checkbox(label="Highlight face", value=False,
 
20
  else:
21
  raise argparse.ArgumentTypeError("Boolean value expected.")
22
 
23
+ def is_running_on_hf_space():
24
  return os.getenv("SPACE_ID") is not None
25
 
26
  import argparse
27
  parser = argparse.ArgumentParser()
28
  parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["consistentID", "arc2face"],
29
  choices=["arc2face", "consistentID"], help="Type(s) of the ID2Ada prompt encoders")
30
+ parser.add_argument('--adaface_ckpt_path', type=str, default='models/adaface/VGGface2_HQ_masks2025-05-22T17-51-19_zero3-ada-1000.pt',
31
  help="Path to the checkpoint of the ID2Ada prompt encoders")
32
  # If adaface_encoder_cfg_scales is not specified, the weights will be set to 6.0 (consistentID) and 1.0 (arc2face).
33
  parser.add_argument('--adaface_encoder_cfg_scales', type=float, nargs="+", default=[6.0, 1.0],
 
75
  global adaface
76
  adaface = None
77
 
78
+ if is_running_on_hf_space():
79
+ args.device = 'cuda:0'
80
+ is_on_hf_space = True
81
+ else:
82
+ if args.gpu is None:
83
+ args.device = "cuda"
84
+ else:
85
+ args.device = f"cuda:{args.gpu}"
86
+ is_on_hf_space = False
87
+
88
  if not args.test_ui_only:
89
  adaface = AdaFaceWrapper(pipeline_name="text2img", base_model_path=base_model_path,
90
  adaface_encoder_types=args.adaface_encoder_types,
 
94
  unet_types=None, extra_unet_dirpaths=None, unet_weights_in_ensemble=None,
95
  unet_uses_attn_lora=args.unet_uses_attn_lora,
96
  attn_lora_layer_names=args.attn_lora_layer_names,
97
+ normalize_cross_attn=False,
98
  q_lora_updates_query=args.q_lora_updates_query,
99
+ device='cpu',
100
+ is_on_hf_space=is_on_hf_space)
101
 
102
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
103
  if randomize_seed:
 
125
 
126
  global adaface, args
127
 
128
+ adaface.to(args.device)
129
 
130
  if image_paths is None or len(image_paths) == 0:
131
  raise gr.Error(f"Cannot find any input face image! Please upload a face image.")
 
255
  print(f"Switching to the base model type: {model_style_type}.")
256
 
257
  adaface = AdaFaceWrapper(pipeline_name="text2img", base_model_path=model_style_type2base_model_path[model_style_type],
258
+ adaface_encoder_types=args.adaface_encoder_types,
259
+ adaface_ckpt_paths=args.adaface_ckpt_path,
260
+ adaface_encoder_cfg_scales=args.adaface_encoder_cfg_scales,
261
+ enabled_encoders=args.enabled_encoders,
262
+ unet_types=None, extra_unet_dirpaths=None, unet_weights_in_ensemble=None,
263
+ unet_uses_attn_lora=args.unet_uses_attn_lora,
264
+ attn_lora_layer_names=args.attn_lora_layer_names,
265
+ normalize_cross_attn=False,
266
+ q_lora_updates_query=args.q_lora_updates_query,
267
+ device='cpu',
268
+ is_on_hf_space=is_on_hf_space)
269
 
270
  if adaface_encoder_cfg_scale1 != args.adaface_encoder_cfg_scales[0]:
271
  args.adaface_encoder_cfg_scales[0] = adaface_encoder_cfg_scale1
 
371
  "portrait, night view of tokyo street, neon light",
372
  "portrait, playing guitar on a boat, ocean waves",
373
  "portrait, with a passion for reading, curled up with a book in a cozy nook near a window, front view",
374
+ "portrait, celebrating new year alone, fireworks",
375
  "portrait, running pose in a park",
376
  "portrait, in space suit, space helmet, walking on mars",
377
  "portrait, in superman costume, the sky ablaze with hues of orange and purple",
378
  "in a wheelchair",
379
+ "on a horse",
380
+ "on a bike",
381
  ])
382
 
383
  highlight_face = gr.Checkbox(label="Highlight face", value=False,
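
Finally, app.py moves device selection out of generate_image() and into startup: on a HF Space (detected via SPACE_ID) the device is pinned to cuda:0 and is_on_hf_space is threaded into AdaFaceWrapper; locally, --gpu picks the CUDA index. A compact sketch of the resolution logic mirrored from the diff (the surrounding wiring is illustrative):

import os

# Sketch only (not part of the commit).
def resolve_device(gpu_index=None):
    # HF Spaces set SPACE_ID; pin to the single visible GPU there.
    if os.getenv("SPACE_ID") is not None:
        return "cuda:0", True
    device = "cuda" if gpu_index is None else f"cuda:{gpu_index}"
    return device, False

args_device, is_on_hf_space = resolve_device(gpu_index=None)
# The wrapper is still constructed on CPU (device='cpu', is_on_hf_space=is_on_hf_space)
# and only moved with adaface.to(args.device) inside generate_image().
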