Spaces: Running on Zero
Commit · 9bcd020
1 Parent(s): c75a857

update code and gradio version

Files changed:
- README.md +1 -1
- adaface/adaface_wrapper.py +37 -22
- adaface/diffusers_attn_lora_capture.py +67 -62
- adaface/face_id_to_ada_prompt.py +26 -19
- adaface/unet_teachers.py +37 -36
- adaface/util.py +6 -6
- app.py +30 -28
README.md
CHANGED
@@ -4,7 +4,7 @@ emoji: 😻
 colorFrom: indigo
 colorTo: red
 sdk: gradio
-sdk_version: 5.0
+sdk_version: 5.30.0
 app_file: app.py
 pinned: false
 license: apache-2.0
adaface/adaface_wrapper.py
CHANGED
@@ -30,8 +30,8 @@ class AdaFaceWrapper(nn.Module):
                  use_840k_vae=False, use_ds_text_encoder=False,
                  main_unet_filepath=None, unet_types=None, extra_unet_dirpaths=None, unet_weights_in_ensemble=None,
                  enable_static_img_suffix_embs=None, unet_uses_attn_lora=False,
-                 attn_lora_layer_names=['q', 'k', 'v', 'out'],
-                 device='cuda', is_training=False):
+                 attn_lora_layer_names=['q', 'k', 'v', 'out'], normalize_cross_attn=False, q_lora_updates_query=False,
+                 device='cuda', is_training=False, is_on_hf_space=False):
         '''
         pipeline_name: "text2img", "text2imgxl", "img2img", "text2img3", "flux", or None.
         If None, it's used only as a face encoder, and the unet and vae are
@@ -52,7 +52,7 @@ class AdaFaceWrapper(nn.Module):
         self.q_lora_updates_query = q_lora_updates_query
         self.use_lcm = use_lcm
         self.subject_string = subject_string
-        self.
+        self.normalize_cross_attn = normalize_cross_attn

         self.default_scheduler_name = default_scheduler_name
         self.num_inference_steps = num_inference_steps if not use_lcm else 4
@@ -64,6 +64,7 @@ class AdaFaceWrapper(nn.Module):
         self.unet_weights_in_ensemble = unet_weights_in_ensemble
         self.device = device
         self.is_training = is_training
+        self.is_on_hf_space = is_on_hf_space

         if negative_prompt is None:
             self.negative_prompt = \
@@ -99,6 +100,7 @@ class AdaFaceWrapper(nn.Module):
                                        self.adaface_ckpt_paths,
                                        self.adaface_encoder_cfg_scales,
                                        self.enabled_encoders,
+                                       is_on_hf_space=self.is_on_hf_space,
                                        num_static_img_suffix_embs=4)

         self.id2ada_prompt_encoder.to(self.device)
@@ -189,10 +191,10 @@ class AdaFaceWrapper(nn.Module):
             pipeline.unet = unet_ensemble

         print(f"Loaded pipeline from {self.base_model_path}.")
-        if not remove_unet and (self.unet_uses_attn_lora or self.
+        if not remove_unet and (self.unet_uses_attn_lora or self.normalize_cross_attn):
             unet2 = self.load_unet_lora_weights(pipeline.unet, use_attn_lora=self.unet_uses_attn_lora,
                                                 attn_lora_layer_names=self.attn_lora_layer_names,
-
+                                                normalize_cross_attn=self.normalize_cross_attn,
                                                 q_lora_updates_query=self.q_lora_updates_query)

             pipeline.unet = unet2
@@ -294,12 +296,11 @@ class AdaFaceWrapper(nn.Module):
     def load_unet_loras(self, unet, unet_lora_modules_state_dict,
                         use_attn_lora=True, use_ffn_lora=False,
                         attn_lora_layer_names=['q', 'k', 'v', 'out'],
-
+                        normalize_cross_attn=False,
                         q_lora_updates_query=False):
         attn_capture_procs, attn_opt_modules = \
             set_up_attn_processors(unet, use_attn_lora=True, attn_lora_layer_names=attn_lora_layer_names,
                                    lora_rank=192, lora_scale_down=8,
-                                   cross_attn_shrink_factor=cross_attn_shrink_factor,
                                    q_lora_updates_query=q_lora_updates_query)
         # up_blocks.3.resnets.[1~2].conv1, conv2, conv_shortcut. [12] matches 1 or 2.
         if use_ffn_lora:
@@ -343,16 +344,17 @@ class AdaFaceWrapper(nn.Module):
         print(f"Loaded {len(unet_lora_modules_state_dict)} LoRA weights on the UNet:\n{unet_lora_modules.keys()}")
         self.outfeat_capture_blocks.append(unet.up_blocks[3])

-        # If
+        # If normalize_cross_attn is True and use_attn_lora is False, we load all these params from ckpt,
         # but since we set use_attn_lora to False, attn loras won't be used during inference nonetheless.
         set_lora_and_capture_flags(unet, None, self.attn_capture_procs, self.outfeat_capture_blocks,
                                    use_attn_lora, use_ffn_lora, 'recon_loss', capture_ca_activations=False,
-
+                                   normalize_cross_attn=normalize_cross_attn, mix_attn_mats_in_batch=False,
+                                   res_hidden_states_gradscale=0)

         return unet

     def load_unet_lora_weights(self, unet, use_attn_lora=True, attn_lora_layer_names=['q', 'k', 'v', 'out'],
-
+                               normalize_cross_attn=False, q_lora_updates_query=False):
         unet_lora_weight_found = False
         if isinstance(self.adaface_ckpt_paths, str):
             adaface_ckpt_paths = [self.adaface_ckpt_paths]
@@ -360,7 +362,7 @@ class AdaFaceWrapper(nn.Module):
             adaface_ckpt_paths = self.adaface_ckpt_paths

         for adaface_ckpt_path in adaface_ckpt_paths:
-            ckpt_dict = torch.load(adaface_ckpt_path, map_location='cpu')
+            ckpt_dict = torch.load(adaface_ckpt_path, map_location='cpu', weights_only=False)
             if 'unet_lora_modules' in ckpt_dict:
                 unet_lora_modules_state_dict = ckpt_dict['unet_lora_modules']
                 print(f"{len(unet_lora_modules_state_dict)} LoRA weights found in {adaface_ckpt_path}.")
@@ -379,7 +381,7 @@ class AdaFaceWrapper(nn.Module):
                 unet_ = self.load_unet_loras(unet_, unet_lora_modules_state_dict,
                                              use_attn_lora=use_attn_lora,
                                              attn_lora_layer_names=attn_lora_layer_names,
-
+                                             normalize_cross_attn=normalize_cross_attn,
                                              q_lora_updates_query=q_lora_updates_query)
                 unet.unets[i] = unet_
             print(f"Loaded LoRA processors on UNetEnsemble of {len(unet.unets)} UNets.")
@@ -387,7 +389,7 @@ class AdaFaceWrapper(nn.Module):
             unet = self.load_unet_loras(unet, unet_lora_modules_state_dict,
                                         use_attn_lora=use_attn_lora,
                                         attn_lora_layer_names=attn_lora_layer_names,
-
+                                        normalize_cross_attn=normalize_cross_attn,
                                         q_lora_updates_query=q_lora_updates_query)

         return unet
@@ -612,8 +614,9 @@ class AdaFaceWrapper(nn.Module):
         # Scan prompt and replace tokens in self.placeholder_token_ids
         # with the corresponding image embeddings.
         prompt_tokens = self.pipeline.tokenizer.tokenize(prompt)
+        # prompt_embeds are the ada embeddings.
         prompt_embeds2 = prompt_embeds.clone()
-        if alt_prompt_embed_type
+        if alt_prompt_embed_type.startswith('img'):
             if self.img_prompt_embs is None:
                 print("Unable to find img_prompt_embs. Either prepare_adaface_embeddings() hasn't been called, or faceless images were used.")
                 return prompt_embeds
@@ -628,17 +631,18 @@ class AdaFaceWrapper(nn.Module):
             breakpoint()

         repl_tokens = {}
+        ada_emb_weight = alt_prompt_emb_weights[0]
         for i in range(len(prompt_tokens)):
             if prompt_tokens[i] in self.all_placeholder_tokens:
                 encoder_idx = next((i for i, sublist in enumerate(self.encoder_placeholder_tokens) \
                                     if prompt_tokens[i] in sublist), 0)
-                alt_prompt_emb_weight = alt_prompt_emb_weights[encoder_idx]
-                prompt_embeds2[:, i] = prompt_embeds2[:, i] *
+                alt_prompt_emb_weight = alt_prompt_emb_weights[encoder_idx + 1]
+                prompt_embeds2[:, i] = prompt_embeds2[:, i] * ada_emb_weight \
                                        + repl_embeddings[:, self.all_placeholder_tokens.index(prompt_tokens[i])] * alt_prompt_emb_weight
                 repl_tokens[prompt_tokens[i]] = 1

         repl_token_count = len(repl_tokens)
-        if
+        if ada_emb_weight == 0:
             print(f"Replaced {repl_token_count} tokens with {alt_prompt_embed_type} embeddings.")
         else:
             print(f"Mixed {repl_token_count} tokens with {alt_prompt_embed_type} embeddings, weight {alt_prompt_emb_weights}.")
@@ -650,7 +654,7 @@ class AdaFaceWrapper(nn.Module):
                         placeholder_tokens_pos='append',
                         ablate_prompt_only_placeholders=False,
                         ablate_prompt_no_placeholders=False,
-                        ablate_prompt_embed_type='ada',   # 'ada', 'ada-nonmix', 'img'
+                        ablate_prompt_embed_type='ada',   # 'ada', 'ada-nonmix', 'img', 'img1', 'img2'.
                         nonmix_prompt_emb_weight=0,
                         repeat_prompt_for_each_encoder=True,
                         device=None, verbose=False):
@@ -678,14 +682,25 @@ class AdaFaceWrapper(nn.Module):
         prompt_embeds_, negative_prompt_embeds_, pooled_prompt_embeds_, negative_pooled_prompt_embeds_ = \
             self.diffusers_encode_prompts(prompt, plain_prompt, negative_prompt, device)

-        if ablate_prompt_embed_type
+        if ablate_prompt_embed_type.startswith('img'):
             alt_prompt_embed_type = ablate_prompt_embed_type
-
+            if alt_prompt_embed_type == 'img1':
+                # The mixing weights of ada, img1, and img2 are 0, 1, and 0.
+                alt_prompt_emb_weights = (0, 1, 0)
+            elif alt_prompt_embed_type == 'img2':
+                # The mixing weights of ada, img1, and img2 are 0, 0, and 1.
+                alt_prompt_emb_weights = (0, 0, 1)
+            else:
+                # The mixing weights of ada, img1, and img2 are 0, 1, and 1.
+                alt_prompt_emb_weights = (0, 1, 1)
         elif nonmix_prompt_emb_weight > 0:
             alt_prompt_embed_type = 'ada-nonmix'
-
+            # The mixing weight of ada is 1 - nonmix_prompt_emb_weight, instead of 1 - nonmix_prompt_emb_weight * 2.
+            # It means ada is mixed by this weight with both img1 and img2.
+            alt_prompt_emb_weights = (1 - nonmix_prompt_emb_weight, nonmix_prompt_emb_weight, nonmix_prompt_emb_weight)
         else:
-
+            # Don't change the prompt embeddings. So we set all the mixing weights to 0.
+            alt_prompt_emb_weights = (0, 0, 0)

         if sum(alt_prompt_emb_weights) > 0:
             prompt_embeds_ = self.mix_ada_embs_with_other_embs(prompt, prompt_embeds_,
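To make the new (ada, img1, img2) weighting concrete, here is a minimal, self-contained sketch of the per-token mixing that mix_ada_embs_with_other_embs() performs in the hunks above. The function name, shapes, and values are illustrative only, not the repository's API.

```python
import torch

# Hypothetical shapes: one placeholder token, 768-dim text-encoder embeddings.
ada_emb  = torch.randn(768)           # ada (subject) embedding for the token
repl_emb = torch.randn(768)           # replacement (e.g. image-prompt) embedding

# alt_prompt_emb_weights = (w_ada, w_encoder1, w_encoder2);
# e.g. 'img1' -> (0, 1, 0), 'ada-nonmix' -> (1 - w, w, w).
w_ada, w_enc1, w_enc2 = (0, 1, 0)

# A token owned by encoder 1 keeps w_ada of its ada embedding and adds
# the replacement embedding scaled by that encoder's weight.
mixed = ada_emb * w_ada + repl_emb * w_enc1
print(mixed.shape)                    # torch.Size([768])
```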
adaface/diffusers_attn_lora_capture.py
CHANGED
@@ -4,7 +4,6 @@ import torch.nn.functional as F
 from typing import Optional, Tuple, Dict, Any
 from diffusers.models.attention_processor import Attention, AttnProcessor2_0
 from diffusers.utils import logging, is_torch_version, deprecate
-from diffusers.utils.torch_utils import fourier_filter
 # UNet is a diffusers PeftAdapterMixin instance.
 from diffusers.loaders.peft import PeftAdapterMixin
 from peft import LoraConfig, get_peft_model
@@ -12,7 +11,6 @@ import peft.tuners.lora as peft_lora
 from peft.tuners.lora.dora import DoraLinearLayer
 from einops import rearrange
 import math, re
-import numpy as np
 from peft.tuners.tuners_utils import BaseTunerLayer


@@ -28,7 +26,7 @@ class ScaleGrad(torch.autograd.Function):
         ctx.save_for_backward(alpha_, debug)
         output = input_
         if debug:
-            print(f"input: {input_.abs().mean().item()}")
+            print(f"input: {input_.abs().mean().detach().item()}")
         return output

     @staticmethod
@@ -38,7 +36,7 @@ class ScaleGrad(torch.autograd.Function):
         if ctx.needs_input_grad[0]:
             grad_output2 = grad_output * alpha_
             if debug:
-                print(f"grad_output2: {grad_output2.abs().mean().item()}")
+                print(f"grad_output2: {grad_output2.abs().mean().detach().item()}")
         else:
             grad_output2 = None
         return grad_output2, None, None
@@ -77,36 +75,11 @@ def split_indices_by_instance(indices, as_dict=False):
     indices_by_instance = { uib.item(): indices_N[indices_B == uib] for uib in unique_indices_B }
     return indices_by_instance

-# If do_sum, returned emb_attns is 3D. Otherwise 4D.
-# indices are applied on the first 2 dims of attn_mat.
-def sel_emb_attns_by_indices(attn_mat, indices, all_token_weights=None, do_sum=True, do_mean=False):
-    indices_by_instance = split_indices_by_instance(indices)
-
-    # emb_attns[0]: [1, 9, 8, 64]
-    # 8: 8 attention heads. Last dim 64: number of image tokens.
-    emb_attns = [ attn_mat[inst_indices].unsqueeze(0) for inst_indices in indices_by_instance ]
-    if all_token_weights is not None:
-        # all_token_weights: [4, 77].
-        # token_weights_by_instance[0]: [1, 9, 1, 1].
-        token_weights = [ all_token_weights[inst_indices].reshape(1, -1, 1, 1) for inst_indices in indices_by_instance ]
-    else:
-        token_weights = [ 1 ] * len(indices_by_instance)
-
-    # Apply token weights.
-    emb_attns = [ emb_attns[i] * token_weights[i] for i in range(len(indices_by_instance)) ]
-
-    # sum among K_subj_i subj embeddings -> [1, 8, 64]
-    if do_sum:
-        emb_attns = [ emb_attns[i].sum(dim=1) for i in range(len(indices_by_instance)) ]
-    elif do_mean:
-        emb_attns = [ emb_attns[i].mean(dim=1) for i in range(len(indices_by_instance)) ]
-
-    emb_attns = torch.cat(emb_attns, dim=0)
-    return emb_attns
-
 # Slow implementation equivalent to F.scaled_dot_product_attention.
-def scaled_dot_product_attention(query, key, value,
-
+def scaled_dot_product_attention(query, key, value, cross_attn_scale_factor,
+                                 attn_mask=None, dropout_p=0.0,
+                                 subj_indices=None, normalize_cross_attn=False,
+                                 mix_attn_mats_in_batch=False,
                                  is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
     B, L, S = query.size(0), query.size(-2), key.size(-2)
     scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
@@ -128,21 +101,39 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
     key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
     value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)

-
+    attn_score = query @ key.transpose(-2, -1) * scale_factor

-
-
-
-
-
-
-
-
-
-
-
-
-
+    # attn_bias: [1, 1, 4096, 77], the same size as a single-head attn_score.
+    attn_score += attn_bias
+    if mix_attn_mats_in_batch:
+        # The instances in the batch are [sc, mc]. We average their attn scores,
+        # and apply to both instances.
+        # attn_score: [2, 8, 4096, 77] -> [1, 8, 4096, 77] -> [2, 8, 4096, 77].
+        # If BLOCK_SIZE > 1, attn_score.shape[0] = 2 * BLOCK_SIZE.
+        if attn_score.shape[0] % 2 != 0:
+            breakpoint()
+        attn_score_sc, attn_score_mc = attn_score.chunk(2, dim=0)
+        # Cut off the grad flow from the SC instance to the MC instance.
+        attn_score = (attn_score_sc + attn_score_mc.detach()) / 2
+        attn_score = attn_score.repeat(2, 1, 1, 1)
+    elif normalize_cross_attn:
+        if subj_indices is None:
+            breakpoint()
+        subj_indices_B, subj_indices_N = subj_indices
+        subj_attn_score = attn_score[subj_indices_B, :, :, subj_indices_N]
+        # Normalize the attention score of the subject tokens to have mean 0 across tokens,
+        # so that positive and negative scores are balanced.
+        subj_attn_score = subj_attn_score - subj_attn_score.mean(dim=2, keepdim=True).detach()
+        # cross_attn_scale is a learnable parameter, so the score will be scaled appropriately.
+        # Scale up the BP'ed gradient to cross_attn_scale_factor by 10x.
+        ca_scale_grad_scaler = gen_gradient_scaler(10)
+        subj_attn_score = subj_attn_score * ca_scale_grad_scaler(cross_attn_scale_factor)
+        attn_score2 = attn_score.clone()
+        attn_score2[subj_indices_B, :, :, subj_indices_N] = subj_attn_score
+        attn_score = attn_score2
+    # Otherwise, do nothing to attn_score.
+
+    attn_weight = torch.softmax(attn_score, dim=-1)
     attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
     output = attn_weight @ value
     return output, attn_score, attn_weight
@@ -156,23 +147,25 @@ class AttnProcessor_LoRA_Capture(nn.Module):
     def __init__(self, capture_ca_activations: bool = False, enable_lora: bool = False,
                  lora_uses_dora=True, lora_proj_layers=None,
                  lora_rank: int = 192, lora_alpha: float = 16,
-                 cross_attn_shrink_factor: float = 0.5,
                  q_lora_updates_query=False, attn_proc_idx=-1):
         super().__init__()

         self.global_enable_lora = enable_lora
         self.attn_proc_idx = attn_proc_idx
         # reset_attn_cache_and_flags() sets the local (call-specific) self.enable_lora flag.
-        # By default,
-        self.reset_attn_cache_and_flags(capture_ca_activations, False, enable_lora)
+        # By default, normalize_cross_attn is False. Later in layers 22, 23, 24 it will be set to True.
+        self.reset_attn_cache_and_flags(capture_ca_activations, False, False, enable_lora)
         self.lora_rank = lora_rank
         self.lora_alpha = lora_alpha
         self.lora_scale = self.lora_alpha / self.lora_rank
-        self.cross_attn_shrink_factor = cross_attn_shrink_factor
         self.q_lora_updates_query = q_lora_updates_query

         self.to_q_lora = self.to_k_lora = self.to_v_lora = self.to_out_lora = None
         if self.global_enable_lora:
+            # enable_lora = True iff this is a cross-attn layer in the last 3 up blocks.
+            # Since we only use cross_attn_scale_factor on cross-attn layers,
+            # we only use cross_attn_scale_factor when enable_lora is True.
+            self.cross_attn_scale_factor = nn.Parameter(torch.tensor(0.8), requires_grad=True)
             for lora_layer_name, lora_proj_layer in lora_proj_layers.items():
                 if lora_layer_name == 'q':
                     self.to_q_lora = peft_lora.Linear(lora_proj_layer, 'default', r=lora_rank, lora_alpha=lora_alpha,
@@ -188,9 +181,10 @@ class AttnProcessor_LoRA_Capture(nn.Module):
                                                       use_dora=lora_uses_dora, lora_dropout=0.1)

     # LoRA layers can be enabled/disabled dynamically.
-    def reset_attn_cache_and_flags(self, capture_ca_activations,
+    def reset_attn_cache_and_flags(self, capture_ca_activations, normalize_cross_attn, mix_attn_mats_in_batch, enable_lora):
         self.capture_ca_activations = capture_ca_activations
-        self.
+        self.normalize_cross_attn = normalize_cross_attn
+        self.mix_attn_mats_in_batch = mix_attn_mats_in_batch
         self.cached_activations = {}
         # Only enable LoRA for the next call(s) if global_enable_lora is set to True.
         self.enable_lora = enable_lora and self.global_enable_lora
@@ -312,11 +306,14 @@ class AttnProcessor_LoRA_Capture(nn.Module):
             breakpoint()

         # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        if is_cross_attn and (self.capture_ca_activations or self.
+        if is_cross_attn and (self.capture_ca_activations or self.normalize_cross_attn):
             hidden_states, attn_score, attn_prob = \
                 scaled_dot_product_attention(query, key, value, attn_mask=attention_mask,
-                                             dropout_p=0.0,
-
+                                             dropout_p=0.0, subj_indices=subj_indices,
+                                             normalize_cross_attn=self.normalize_cross_attn,
+                                             cross_attn_scale_factor=self.cross_attn_scale_factor,
+                                             mix_attn_mats_in_batch=self.mix_attn_mats_in_batch)
+
         else:
             # Use the faster implementation of scaled_dot_product_attention
             # when not capturing the activations or suppressing the subject attention.
@@ -452,7 +449,7 @@ def CrossAttnUpBlock2D_forward_capture(
 # Adapted from ConsistentIDPipeline:set_ip_adapter().
 # attn_lora_layer_names: candidates are subsets of ['q', 'k', 'v', 'out'].
 def set_up_attn_processors(unet, use_attn_lora, attn_lora_layer_names=['q', 'k', 'v', 'out'],
-                           lora_rank=192, lora_scale_down=8,
+                           lora_rank=192, lora_scale_down=8,
                            q_lora_updates_query=False):
     attn_procs = {}
     attn_capture_procs = {}
@@ -502,7 +499,6 @@ def set_up_attn_processors(unet, use_attn_lora, attn_lora_layer_names=['q', 'k',
                            lora_uses_dora=True, lora_proj_layers=lora_proj_layers,
                            # LoRA up is initialized to 0. So no need to worry that the LoRA output may be too large.
                            lora_rank=lora_rank, lora_alpha=lora_rank // lora_scale_down,
-                           cross_attn_shrink_factor=cross_attn_shrink_factor,
                            q_lora_updates_query=q_lora_updates_query, attn_proc_idx=attn_proc_idx)

         attn_proc_idx += 1
@@ -513,6 +509,11 @@ def set_up_attn_processors(unet, use_attn_lora, attn_lora_layer_names=['q', 'k',
         attn_capture_procs[name] = attn_capture_proc

         if use_attn_lora:
+            cross_attn_scale_factor_name = name + "_cross_attn_scale_factor"
+            # Put cross_attn_scale_factor in attn_opt_modules, so that we can optimize and save/load it.
+            attn_opt_modules[cross_attn_scale_factor_name] = attn_capture_proc.cross_attn_scale_factor
+
+            # Put LoRA layers in attn_opt_modules, so that we can optimize and save/load them.
             for subname, module in attn_capture_proc.named_modules():
                 if isinstance(module, peft_lora.LoraLayer):
                     # ModuleDict doesn't allow "." in the key.
@@ -537,7 +538,7 @@ def set_up_attn_processors(unet, use_attn_lora, attn_lora_layer_names=['q', 'k',
     return attn_capture_procs, attn_opt_modules

 # NOTE: cross-attn layers are included in the returned lora_modules.
-def set_up_ffn_loras(unet, target_modules_pat, lora_uses_dora=
+def set_up_ffn_loras(unet, target_modules_pat, lora_uses_dora=True, lora_rank=192, lora_alpha=16):
     # target_modules_pat = 'up_blocks.3.resnets.[12].conv[a-z0-9_]+'
     # up_blocks.3.resnets.[1~2].conv1, conv2, conv_shortcut
     # Cannot set to conv.+ as it will match added adapter module names, including
@@ -592,15 +593,18 @@ def set_up_ffn_loras(unet, target_modules_pat, lora_uses_dora=False, lora_rank=1
 def set_lora_and_capture_flags(unet, unet_lora_modules, attn_capture_procs,
                                outfeat_capture_blocks, res_hidden_states_gradscale_blocks,
                                use_attn_lora, use_ffn_lora, ffn_lora_adapter_name, capture_ca_activations,
-
+                               normalize_cross_attn, mix_attn_mats_in_batch, res_hidden_states_gradscale):
     # For attn capture procs, capture_ca_activations and use_attn_lora are set in reset_attn_cache_and_flags().
-    for attn_capture_proc in attn_capture_procs:
-        attn_capture_proc.reset_attn_cache_and_flags(capture_ca_activations,
+    for i, attn_capture_proc in enumerate(attn_capture_procs):
+        attn_capture_proc.reset_attn_cache_and_flags(capture_ca_activations, normalize_cross_attn, mix_attn_mats_in_batch,
+                                                     enable_lora=use_attn_lora)
     # outfeat_capture_blocks only contains the last up block, up_blocks[3].
     # It contains 3 FFN layers. We want to capture their output features.
     for block in outfeat_capture_blocks:
         block.capture_outfeats = capture_ca_activations

+    # res_hidden_states_gradscale_blocks contain the second to the last up blocks, up_blocks[1:].
+    # It's only used to set res_hidden_states_gradscale, and doesn't capture anything.
     for block in res_hidden_states_gradscale_blocks:
         block.res_hidden_states_gradscale = res_hidden_states_gradscale

@@ -639,6 +643,7 @@ def get_captured_activations(capture_ca_activations, attn_capture_procs, outfeat
         block.cached_outfeats = {}
         block.capture_outfeats = False

+
     for layer_idx in captured_layer_indices:
         # Subtract 22 to ca_layer_idx to match the layer index in up_blocks[3].cached_outfeats.
         # 23, 24 -> 1, 2 (!! not 0, 1 !!)
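The normalize_cross_attn branch added above zero-centers the subject tokens' attention scores and rescales them with a learnable cross_attn_scale_factor before the softmax. Below is a minimal sketch of that idea with made-up shapes and subject-token positions; it is not the module's actual interface and omits the gradient scaler used in the repo.

```python
import torch

B, H, Q, K = 2, 8, 64, 77              # batch, heads, image tokens, text tokens (toy sizes)
attn_score = torch.randn(B, H, Q, K)
subj_cols  = torch.tensor([4, 5])      # hypothetical subject-token positions
cross_attn_scale_factor = torch.nn.Parameter(torch.tensor(0.8))

# Zero-center the subject tokens' scores across image positions, so positive
# and negative scores are balanced, then scale them by the learnable factor.
subj = attn_score[:, :, :, subj_cols]
subj = subj - subj.mean(dim=2, keepdim=True).detach()

attn_score2 = attn_score.clone()       # clone before writing, as in the diff
attn_score2[:, :, :, subj_cols] = subj * cross_attn_scale_factor
attn_prob = attn_score2.softmax(dim=-1)
print(attn_prob.shape)                 # torch.Size([2, 8, 64, 77])
```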
adaface/face_id_to_ada_prompt.py
CHANGED
@@ -26,7 +26,7 @@ def create_id2ada_prompt_encoder(adaface_encoder_types, adaface_ckpt_paths=None,
     if adaface_encoder_type == 'arc2face':
         id2ada_prompt_encoder = \
             Arc2Face_ID2AdaPrompt(adaface_ckpt_path=adaface_ckpt_path,
-
+                                  *args, **kwargs)
     elif adaface_encoder_type == 'consistentID':
         id2ada_prompt_encoder = \
             ConsistentID_ID2AdaPrompt(pipe=None,
@@ -64,6 +64,7 @@ class FaceID2AdaPrompt(nn.Module):
         # i.e., 6 for arc2face and 1 for consistentID.
         self.out_id_embs_cfg_scale = kwargs.get('out_id_embs_cfg_scale', -1)
         self.is_training = kwargs.get('is_training', False)
+        self.is_on_hf_space = kwargs.get('is_on_hf_space', False)
         # extend_prompt2token_proj_attention_multiplier is an integer >= 1.
         # TODO: extend_prompt2token_proj_attention_multiplier should be a list of integers.
         self.extend_prompt2token_proj_attention_multiplier = kwargs.get('extend_prompt2token_proj_attention_multiplier', 1)
@@ -603,9 +604,13 @@ class Arc2Face_ID2AdaPrompt(FaceID2AdaPrompt):
         '''
         # Use the same model as ID2AdaPrompt does.
         # FaceAnalysis will try to find the ckpt in: models/insightface/models/antelopev2.
-        # Note there
+        # Note there's a second "model" in the path.
+        # Note DO use CUDAExecutionProvider during training and CPUExecutionProvider during inference.
+        # Otherwise, CPUExecutionProvider will hang DDP training,
+        # and CUDAExecutionProvider will cause OOM on huggingface spaces.
+        self.onnx_providers = ['CUDAExecutionProvider'] if self.is_training else ['CPUExecutionProvider']
         self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface',
-
+                                     providers=self.onnx_providers)
         self.face_app.prepare(ctx_id=0, det_size=(512, 512))
         print(f'Arc2Face Face encoder loaded on CPU.')

@@ -642,7 +647,6 @@ class Arc2Face_ID2AdaPrompt(FaceID2AdaPrompt):

     def _apply(self, fn):
         super()._apply(fn)  # Call the parent _apply to handle parameters and buffers
-        return
         # A dirty hack to get the device of the model, passed from
         # parent.model.to(self.root_device) => parent._apply(convert) => module._apply(fn)
         test_tensor = torch.zeros(1)  # Create a test tensor
@@ -651,22 +655,24 @@ class Arc2Face_ID2AdaPrompt(FaceID2AdaPrompt):
         # No need to reload face_app on the same device.
         if device == self.device:
             return
+        self.device = device
+
+        if self.is_on_hf_space and self.face_app is not None:
+            print(f'On HF space. Arc2Face Face encoder already loaded on cpu.')
+            return

         if str(device) == 'cpu':
             self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface',
-
+                                         providers=['CPUExecutionProvider'])
             self.face_app.prepare(ctx_id=0, det_size=(512, 512))
         else:
             device_id = device.index
             self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface',
                                          providers=['CUDAExecutionProvider'],
-                                         provider_options=[{
-
-                                                            "gpu_mem_limit": 2 * 1024**3
-                                                           }])
+                                         provider_options=[{'device_id': device_id,
+                                                            'cudnn_conv_algo_search': 'HEURISTIC'}])
             self.face_app.prepare(ctx_id=device_id, det_size=(512, 512))

-        self.device = device
         print(f'Arc2Face Face encoder reloaded on {device}.')
         return

@@ -739,8 +745,8 @@ class ConsistentID_ID2AdaPrompt(FaceID2AdaPrompt):
         # but diffusers will call .to(dtype) in .from_single_file(),
         # and at that moment, the consistentID specific modules are not loaded yet.
         pipe = ConsistentIDPipeline.from_single_file(base_model_path)
-        pipe.load_ConsistentID_model(consistentID_weight_path="
-                                     bise_net_weight_path="
+        pipe.load_ConsistentID_model(consistentID_weight_path="models/ConsistentID/ConsistentID-v1.bin",
+                                     bise_net_weight_path="models/ConsistentID/BiSeNet_pretrained_for_ConsistentID.pth")
         pipe.to(dtype=self.dtype)
         # Since the passed-in pipe is None, this should be called during inference,
         # when the teacher ConsistentIDPipeline is not initialized.
@@ -791,7 +797,6 @@ class ConsistentID_ID2AdaPrompt(FaceID2AdaPrompt):

     def _apply(self, fn):
         super()._apply(fn)  # Call the parent _apply to handle parameters and buffers
-        return
         # A dirty hack to get the device of the model, passed from
         # parent.model.to(self.root_device) => parent._apply(convert) => module._apply(fn)
         test_tensor = torch.zeros(1)  # Create a test tensor
@@ -800,6 +805,11 @@ class ConsistentID_ID2AdaPrompt(FaceID2AdaPrompt):
         # No need to reload face_app on the same device.
         if device == self.device:
             return
+        self.device = device
+
+        if self.is_on_hf_space and self.face_app is not None:
+            print(f'On HF space. Arc2Face Face encoder already loaded on cpu.')
+            return

         if str(device) == 'cpu':
             self.face_app = FaceAnalysis(name='buffalo_l', root='models/insightface',
@@ -809,13 +819,10 @@ class ConsistentID_ID2AdaPrompt(FaceID2AdaPrompt):
             device_id = device.index
             self.face_app = FaceAnalysis(name='buffalo_l', root='models/insightface',
                                          providers=['CUDAExecutionProvider'],
-                                         provider_options=[{
-
-                                                            "gpu_mem_limit": 2 * 1024**3
-                                                           }])
+                                         provider_options=[{'device_id': device_id,
+                                                            'cudnn_conv_algo_search': 'HEURISTIC'}])
             self.face_app.prepare(ctx_id=device_id, det_size=(512, 512))

-        self.device = device
         self.pipe.face_app = self.face_app
         print(f'ConsistentID Face encoder reloaded on {device}.')

@@ -1277,7 +1284,7 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
         # No faces are found in the images, so return None embeddings.
         # We don't want to return an all-zero embedding, which is useless.
         if num_available_id_vecs == 0:
-            return None, [0]
+            return None, None, [0]

         # If id2ada_prompt_encoders are ["arc2face", "consistentID"], then
         # during inference, we average across the batch dim.
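The provider changes above follow one policy: CUDAExecutionProvider while training, CPUExecutionProvider for inference and on HF Spaces, with the CUDA provider pinned to a device and a cheaper cudnn algorithm search. A hedged sketch of that selection follows; insightface's FaceAnalysis forwards these arguments to onnxruntime, and the helper name is made up.

```python
from insightface.app import FaceAnalysis

def build_face_app(is_training: bool, device_index: int = 0) -> FaceAnalysis:
    # CUDA during training, CPU during inference / on HF Spaces (see the comments in the diff).
    if is_training:
        app = FaceAnalysis(name='antelopev2', root='models/insightface',
                           providers=['CUDAExecutionProvider'],
                           provider_options=[{'device_id': device_index,
                                              'cudnn_conv_algo_search': 'HEURISTIC'}])
        app.prepare(ctx_id=device_index, det_size=(512, 512))
    else:
        app = FaceAnalysis(name='antelopev2', root='models/insightface',
                           providers=['CPUExecutionProvider'])
        app.prepare(ctx_id=0, det_size=(512, 512))
    return app
```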
adaface/unet_teachers.py
CHANGED
@@ -62,46 +62,41 @@ class UNetTeacher(nn.Module):
     # t: the initial t. We will sample additional (num_denoising_steps - 1) smaller t.
     # same_t_noise_across_instances: when sampling t and noise, use the same t and noise for all instances.
     def forward(self, ddpm_model, x_start, noise, t, teacher_context, negative_context=None,
-                num_denoising_steps=1, same_t_noise_across_instances=False,
+                num_denoising_steps=1, force_uses_cfg=False, same_t_noise_across_instances=False,
                 global_t_lb=0, global_t_ub=1000):
         assert num_denoising_steps <= 10

-
+        # force_uses_cfg overrides p_uses_cfg.
+        if force_uses_cfg > 0:
+            self.uses_cfg = True
+        elif self.p_uses_cfg > 0:
             self.uses_cfg = np.random.rand() < self.p_uses_cfg
-            if self.uses_cfg:
-                # Randomly sample a cfg_scale from cfg_scale_range.
-                self.cfg_scale = np.random.uniform(*self.cfg_scale_range)
-                if self.cfg_scale == 1:
-                    self.uses_cfg = False
-
-            if self.uses_cfg:
-                print(f"Teacher samples CFG scale {self.cfg_scale:.1f}.")
-                if negative_context is not None:
-                    negative_context = negative_context[:1].repeat(x_start.shape[0], 1, 1)
-
-                # if negative_context is None, then teacher_context is a combination of
-                # (one or multiple if unet_ensemble) pos_context and neg_context.
-                # If negative_context is not None, then teacher_context is only pos_context.
-            else:
-                self.cfg_scale = 1
-                print("Teacher does not use CFG.")
-
-                # If negative_context is None, then teacher_context is a combination of
-                # (one or multiple if unet_ensemble) pos_context and neg_context.
-                # Since not uses_cfg, we only need pos_context.
-                # If negative_context is not None, then teacher_context is only pos_context.
-                if negative_context is None:
-                    teacher_context = self.extract_pos_context(teacher_context, x_start.shape[0])
         else:
             # p_uses_cfg = 0. Never use CFG.
             self.uses_cfg = False
-            # In this case, the student only passes pos_context to the teacher,
-            # so no need to split teacher_context into pos_context and neg_context.
-            # self.cfg_scale will be accessed by the student,
-            # so we need to make sure it is always set correctly,
-            # in case someday we want to switch from CFG to non-CFG during runtime.
             self.cfg_scale = 1

+        if self.uses_cfg:
+            # Randomly sample a cfg_scale from cfg_scale_range.
+            self.cfg_scale = np.random.uniform(*self.cfg_scale_range)
+            print(f"Teacher samples CFG scale {self.cfg_scale:.1f}.")
+            if negative_context is not None:
+                negative_context = negative_context[:1].repeat(x_start.shape[0], 1, 1)
+
+            # if negative_context is None, then teacher_context is a combination of
+            # (one or multiple if unet_ensemble) pos_context and neg_context.
+            # If negative_context is not None, then teacher_context is only pos_context.
+        else:
+            self.cfg_scale = 1
+            print("Teacher does not use CFG.")
+
+        # If negative_context is None, then teacher_context is either a combination of
+        # (one or multiple if unet_ensemble) pos_context and neg_context, or only pos_context.
+        # Since not uses_cfg, we only need pos_context.
+        # If negative_context is not None, then teacher_context is only pos_context.
+        if negative_context is None:
+            teacher_context = self.extract_pos_context(teacher_context, x_start.shape[0])
+
         is_context_doubled = 2 if (self.uses_cfg and negative_context is None) else 1
         if self.name == 'unet_ensemble':
             # teacher_context is a list of teacher contexts.
@@ -199,14 +194,20 @@ class UNetTeacher(nn.Module):
             teacher_pos_contexts = []
             # teacher_context is a list of teacher contexts.
             for teacher_context_i in teacher_context:
-
-
-
+                if teacher_context_i.shape[0] == BS * 2:
+                    pos_context, neg_context = torch.chunk(teacher_context_i, 2, dim=0)
+                elif teacher_context_i.shape[0] == BS:
+                    pos_context = teacher_context_i
+                else:
+                    breakpoint()
                 teacher_pos_contexts.append(pos_context)
             teacher_context = teacher_pos_contexts
         else:
-
-
+            if teacher_context.shape[0] == BS * 2:
+                pos_context, neg_context = torch.chunk(teacher_context, 2, dim=0)
+            elif teacher_context.shape[0] == BS:
+                pos_context = teacher_context
+            else:
                 breakpoint()
             teacher_context = pos_context
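The forward() rewrite above boils down to a small decision rule: force_uses_cfg overrides the stochastic p_uses_cfg draw, and cfg_scale is only sampled when CFG is actually used. A standalone sketch of that rule is below; the default cfg_scale_range is an assumption for illustration.

```python
import numpy as np

def decide_cfg(force_uses_cfg, p_uses_cfg, cfg_scale_range=(1.5, 4.0)):
    # force_uses_cfg overrides p_uses_cfg; otherwise CFG is used with probability p_uses_cfg.
    if force_uses_cfg:
        uses_cfg = True
    elif p_uses_cfg > 0:
        uses_cfg = np.random.rand() < p_uses_cfg
    else:
        uses_cfg = False
    # cfg_scale stays 1 (no guidance) unless CFG is used.
    cfg_scale = np.random.uniform(*cfg_scale_range) if uses_cfg else 1.0
    return uses_cfg, cfg_scale

print(decide_cfg(force_uses_cfg=True,  p_uses_cfg=0.0))   # always uses CFG
print(decide_cfg(force_uses_cfg=False, p_uses_cfg=0.5))   # uses CFG about half the time
```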
adaface/util.py
CHANGED
@@ -48,7 +48,7 @@ def perturb_tensor(ts, perturb_std, perturb_std_is_relative=True, keep_norm=Fals
     ts = ts + noise

     if verbose:
-        print(f"Correlations between new and original tensors: {F.cosine_similarity(ts.flatten(), orig_ts.flatten(), dim=0).item():.03f}")
+        print(f"Correlations between new and original tensors: {F.cosine_similarity(ts.flatten(), orig_ts.flatten(), dim=0).detach().item():.03f}")

     return ts

@@ -69,7 +69,7 @@ def calc_stats(emb_name, embeddings, mean_dim=-1):
     # Compute it manually.
     l2_loss = ((embeddings - emb_mean) ** 2).mean().sqrt()
     norms = torch.norm(embeddings, dim=1).detach().cpu().numpy()
-    print("L1: %.4f, L2: %.4f" %(l1_loss.item(), l2_loss.item()))
+    print("L1: %.4f, L2: %.4f" %(l1_loss.detach().item(), l2_loss.detach().item()))
     print("Norms: min: %.4f, max: %.4f, mean: %.4f, std: %.4f" %(norms.min(), norms.max(), norms.mean(), norms.std()))


@@ -80,7 +80,7 @@ class ScaleGrad(torch.autograd.Function):
         ctx.save_for_backward(alpha_, debug)
         output = input_
         if debug:
-            print(f"input: {input_.abs().mean().item()}")
+            print(f"input: {input_.abs().mean().detach().item()}")
         return output

     @staticmethod
@@ -90,7 +90,7 @@ class ScaleGrad(torch.autograd.Function):
         if ctx.needs_input_grad[0]:
             grad_output2 = grad_output * alpha_
             if debug:
-                print(f"grad_output2: {grad_output2.abs().mean().item()}")
+                print(f"grad_output2: {grad_output2.abs().mean().detach().item()}")
         else:
             grad_output2 = None
         return grad_output2, None, None
@@ -232,8 +232,8 @@ def create_consistentid_pipeline(base_model_path="models/sd15-dste8-vae.safetens
     # consistentID specific modules are still in fp32. Will be converted to fp16
     # later with .to(device, torch_dtype) by the caller.
     pipe.load_ConsistentID_model(
-        consistentID_weight_path="
-        bise_net_weight_path="
+        consistentID_weight_path="models/ConsistentID/ConsistentID-v1.bin",
+        bise_net_weight_path="models/ConsistentID/BiSeNet_pretrained_for_ConsistentID.pth",
     )
     # Avoid passing dtype to ConsistentIDPipeline.from_single_file(),
    # because we've overloaded .to() to convert consistentID specific modules as well,
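For context on the ScaleGrad class whose debug prints are touched above: it is an identity in the forward pass and scales the gradient in the backward pass. The following is a compact, self-contained sketch of that pattern, simplified relative to the repo's version (which also carries a debug flag and uses ctx.save_for_backward).

```python
import torch

class GradScale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)           # identity in the forward pass

    @staticmethod
    def backward(ctx, grad_out):
        # Detaching before .item() in debug prints (as in the diff) keeps the
        # printed values out of the autograd graph; here we just scale the grad.
        return grad_out * ctx.alpha, None

x = torch.randn(3, requires_grad=True)
GradScale.apply(x, 0.1).sum().backward()
print(x.grad)                         # gradients scaled by 0.1
```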
app.py
CHANGED
@@ -20,14 +20,14 @@ def str2bool(v):
     else:
         raise argparse.ArgumentTypeError("Boolean value expected.")

-def
+def is_running_on_hf_space():
     return os.getenv("SPACE_ID") is not None

 import argparse
 parser = argparse.ArgumentParser()
 parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["consistentID", "arc2face"],
                     choices=["arc2face", "consistentID"], help="Type(s) of the ID2Ada prompt encoders")
-parser.add_argument('--adaface_ckpt_path', type=str, default='models/adaface/VGGface2_HQ_masks2025-
+parser.add_argument('--adaface_ckpt_path', type=str, default='models/adaface/VGGface2_HQ_masks2025-05-22T17-51-19_zero3-ada-1000.pt',
                     help="Path to the checkpoint of the ID2Ada prompt encoders")
 # If adaface_encoder_cfg_scales is not specified, the weights will be set to 6.0 (consistentID) and 1.0 (arc2face).
 parser.add_argument('--adaface_encoder_cfg_scales', type=float, nargs="+", default=[6.0, 1.0],
@@ -75,6 +75,16 @@ MAX_SEED = np.iinfo(np.int32).max
 global adaface
 adaface = None

+if is_running_on_hf_space():
+    args.device = 'cuda:0'
+    is_on_hf_space = True
+else:
+    if args.gpu is None:
+        args.device = "cuda"
+    else:
+        args.device = f"cuda:{args.gpu}"
+    is_on_hf_space = False
+
 if not args.test_ui_only:
     adaface = AdaFaceWrapper(pipeline_name="text2img", base_model_path=base_model_path,
                              adaface_encoder_types=args.adaface_encoder_types,
@@ -84,9 +94,10 @@ if not args.test_ui_only:
                              unet_types=None, extra_unet_dirpaths=None, unet_weights_in_ensemble=None,
                              unet_uses_attn_lora=args.unet_uses_attn_lora,
                              attn_lora_layer_names=args.attn_lora_layer_names,
-
+                             normalize_cross_attn=False,
                              q_lora_updates_query=args.q_lora_updates_query,
-                             device='cpu'
+                             device='cpu',
+                             is_on_hf_space=is_on_hf_space)

 def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
     if randomize_seed:
@@ -114,18 +125,7 @@ def generate_image(image_paths, image_paths2, guidance_scale, perturb_std,

     global adaface, args

-
-        device = 'cuda:0'
-    else:
-        if args.gpu is None:
-            device = "cuda"
-        else:
-            device = f"cuda:{args.gpu}"
-
-    print(f"Device: {device}")
-
-    adaface.to(device)
-    args.device = device
+    adaface.to(args.device)

     if image_paths is None or len(image_paths) == 0:
         raise gr.Error(f"Cannot find any input face image! Please upload a face image.")
@@ -255,16 +255,17 @@ def check_prompt_and_model_type(prompt, model_style_type, adaface_encoder_cfg_sc
         print(f"Switching to the base model type: {model_style_type}.")

         adaface = AdaFaceWrapper(pipeline_name="text2img", base_model_path=model_style_type2base_model_path[model_style_type],
-
-
-
-
-
-
-
-
-
-
+                                 adaface_encoder_types=args.adaface_encoder_types,
+                                 adaface_ckpt_paths=args.adaface_ckpt_path,
+                                 adaface_encoder_cfg_scales=args.adaface_encoder_cfg_scales,
+                                 enabled_encoders=args.enabled_encoders,
+                                 unet_types=None, extra_unet_dirpaths=None, unet_weights_in_ensemble=None,
+                                 unet_uses_attn_lora=args.unet_uses_attn_lora,
+                                 attn_lora_layer_names=args.attn_lora_layer_names,
+                                 normalize_cross_attn=False,
+                                 q_lora_updates_query=args.q_lora_updates_query,
+                                 device='cpu',
+                                 is_on_hf_space=is_on_hf_space)

     if adaface_encoder_cfg_scale1 != args.adaface_encoder_cfg_scales[0]:
         args.adaface_encoder_cfg_scales[0] = adaface_encoder_cfg_scale1
@@ -370,12 +371,13 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
                 "portrait, night view of tokyo street, neon light",
                 "portrait, playing guitar on a boat, ocean waves",
                 "portrait, with a passion for reading, curled up with a book in a cozy nook near a window, front view",
-                "portrait, celebrating new year, fireworks",
+                "portrait, celebrating new year alone, fireworks",
                 "portrait, running pose in a park",
                 "portrait, in space suit, space helmet, walking on mars",
                 "portrait, in superman costume, the sky ablaze with hues of orange and purple",
                 "in a wheelchair",
-                "on a horse"
+                "on a horse",
+                "on a bike",
             ])

             highlight_face = gr.Checkbox(label="Highlight face", value=False,