import torch
import numpy as np
import logging
from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass
from two_stream_shunt_adapter import ConditionModulationShuntAdapter, reshape_for_shunt
logger = logging.getLogger(__name__)
@dataclass
class ShiftConfig:
"""Unified configuration for all modifications"""
prompt: str = ""
seed: int = -1 # -1 means no seed, use random
strength: float = 1.0
delta_mean: float = 0.0
delta_scale: float = 1.0
sigma_scale: float = 0.0
gate_probability: float = 1.0
gate_threshold: float = 0.1
noise_injection: float = 0.0
use_anchor: bool = True
pool_method: str = "sequential" # "sequential" or "weighted_average"
# Top-K parameters
use_topk: bool = False
topk_percentage: float = 100.0 # Percentage of tokens to keep
tau_temperature: float = 1.0 # Temperature scaling for tau
topk_mode: str = "attention" # "attention", "gate", "combined", "tau_softmax"
    guidance_scale: float = 1.0
max_tokens: int = 77 # Maximum number of tokens to process
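
# Illustrative usage (assumed caller-side values, not defaults required anywhere):
#   cfg = ShiftConfig(prompt="a red bicycle", strength=0.8,
#                     use_topk=True, topk_percentage=40.0, topk_mode="attention")
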
@dataclass
class AdapterOutput:
"""Raw output from adapter forward pass"""
anchor: torch.Tensor
delta: torch.Tensor # Note: already has gate multiplied in!
log_sigma: torch.Tensor
tau: torch.Tensor
g_pred: torch.Tensor
gate: torch.Tensor
adapter_type: str
slice_range: Tuple[int, int]
# Add attention weights for top-k
attn_c2m: Optional[torch.Tensor] = None
attn_m2c: Optional[torch.Tensor] = None
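
# Shape conventions inferred from how these fields are consumed in this module
# (assumptions, not guarantees from the adapter): anchor/delta/log_sigma are
# [batch, seq_clip, dim], gate is [batch, seq_clip, 1], attn_m2c is
# [batch, heads, seq_clip, seq_enc] and attn_c2m is [batch, heads, seq_enc, seq_clip].
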
class ConditioningShifter:
@staticmethod
def extract_encoder_embeddings(
encoder_pipe: Dict[str, Any],
device: torch.device,
shift_config: Optional[ShiftConfig | dict[str, Any]] = None,
        sampler_cfg: Optional[Dict[str, Any]] = None
) -> torch.Tensor:
"""
1) Clean prompt of any shunt tokens
2) Tokenize + encode via T5/BERT
3) Optionally project to sampler_cfg['projection_dims_in']
"""
# 1) prompt cleanup
if isinstance(shift_config, dict):
shift_config = ShiftConfig(**shift_config)
raw_prompt = shift_config.prompt
        prompt = raw_prompt  # special-token stripping (RemoveSpecialTokens) is currently disabled
# 2) tokenize & encode
tokenizer = encoder_pipe["tokenizer"]
model = encoder_pipe["model"]
        cfg = encoder_pipe["config"]["config"]  # nested mini-config holding tokenization settings
tokens = tokenizer(
prompt,
return_tensors="pt",
padding=cfg.get("padding","max_length"),
truncation=True,
max_length=cfg.get("max_tokens",cfg.get("max_length", 512)),
)
input_ids = tokens["input_ids"].to(device)
attention_mask = tokens["attention_mask"].to(device)
with torch.no_grad():
model.to(device)
mtype = encoder_pipe["config"].get("model_type","")
if "t5" in mtype:
embeddings = model.encoder(input_ids=input_ids,
attention_mask=attention_mask
).last_hidden_state
elif mtype in ("bert","nomic_bert"):
embeddings = model(input_ids=input_ids,
attention_mask=attention_mask,
return_dict=True
).last_hidden_state
else:
raise ValueError(f"Unsupported encoder type {mtype!r}")
model.to("cpu") # free GPU memory
# 3) optional input‐projection to match CLIP dims
if sampler_cfg and sampler_cfg.get("force_projection_in", False):
target_dims = sampler_cfg["projection_dims_in"]
embeddings = ConditioningShifter._project_embeddings(
embeddings, target_dims, sampler_cfg["interpolation_method_in"]
)
return embeddings.to(device)
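
    # Expected encoder_pipe layout, inferred from the lookups above (keys other than
    # those read here are unknown and not assumed):
    #   {"tokenizer": tok, "model": enc,
    #    "config": {"model_type": "t5",
    #               "config": {"padding": "max_length", "max_tokens": 512}}}
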
@staticmethod
def _project_embeddings(
embeddings: torch.Tensor,
target_dim: int,
mode: str
) -> torch.Tensor:
"""
Interpolate the last dimension from D→target_dim via F.interpolate,
preserving batch & sequence dims.
"""
B, T, D = embeddings.shape
if D == target_dim:
return embeddings
# [B*T, 1, D] → interpolate → [B*T, 1, target_dim] → back to [B,T,target_dim]
flat = embeddings.reshape(B*T, 1, D)
proj = torch.nn.functional.interpolate(
flat.float(),
size=target_dim,
mode=mode,
align_corners=(mode in {"linear","bilinear","trilinear"})
)
return proj.reshape(B, T, target_dim)
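
    # Shape sketch (illustrative): a [2, 77, 768] embedding interpolated to 1280
    # comes back as [2, 77, 1280]; batch and sequence dims survive because the
    # interpolation runs over the flattened [B*T, 1, D] view of the tensor.
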
@staticmethod
def run_adapter(adapter_model: ConditionModulationShuntAdapter,
encoder_embeddings: torch.Tensor,
clip_slice: torch.Tensor,
guidance_scale: float,
adapter_type: str,
slice_range: Tuple[int, int]) -> AdapterOutput:
"""Run adapter and package output"""
gen_config = {"max_guidance": guidance_scale if guidance_scale > 0 else 1.0}
#encoder_embeddings, clip_slice = reshape_for_shunt(encoder_embeddings, clip_slice, adapter_model)
with torch.no_grad():
outputs = adapter_model(encoder_embeddings.float(), clip_slice.float(), config=gen_config)
if isinstance(outputs, tuple) and len(outputs) == 8:
anchor, delta, log_sigma, attn_c2m, attn_m2c, tau, g_pred, gate = outputs
return AdapterOutput(
anchor=anchor,
delta=delta, # Already has gate multiplied!
log_sigma=log_sigma,
tau=tau,
g_pred=g_pred,
gate=gate,
adapter_type=adapter_type,
slice_range=slice_range,
attn_c2m=attn_c2m,
attn_m2c=attn_m2c
)
else:
raise ValueError(f"Unexpected adapter output format: {type(outputs)}")
@staticmethod
def apply_topk_selection(output: AdapterOutput, config: ShiftConfig) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply top-k selection using tau and attention weights.
Returns mask and selection scores for CLIP tokens.
"""
if not config.use_topk:
# Return full mask matching gate dimensions
return torch.ones_like(output.gate.squeeze(-1)), None
# Calculate selection scores based on mode
if config.topk_mode == "attention":
# Use modulation->condition attention (how much each CLIP token attends to encoder)
# Sum across encoder dimension to get importance score per CLIP token
scores = output.attn_m2c.mean(dim=1).sum(dim=-1) # [batch, seq_clip]
        elif config.topk_mode == "attention_collaborative":
            # Start from the modulation->condition attention, as in "attention" mode
            scores = output.attn_m2c.mean(dim=1).sum(dim=-1)  # [batch, seq_clip]
            # Reduce the condition->modulation attention the same way and use its
            # value range as a soft reference for rescaling the scores
            c2m_scores = output.attn_c2m.mean(dim=1).sum(dim=-1)
            scores = (scores - c2m_scores.min()) / (c2m_scores.max() - c2m_scores.min() + 1e-8)
elif config.topk_mode == "gate":
# Use gate values directly (already in CLIP space)
scores = output.gate.squeeze(-1) # [batch, seq_clip]
elif config.topk_mode == "combined":
# Combine attention and gate scores
attn_score = output.attn_m2c.mean(dim=1).sum(dim=-1) # [batch, seq_clip]
gate_score = output.gate.squeeze(-1)
# Normalize and combine
attn_score = (attn_score - attn_score.min()) / (attn_score.max() - attn_score.min() + 1e-8)
gate_score = (gate_score - gate_score.min()) / (gate_score.max() - gate_score.min() + 1e-8)
scores = (attn_score + gate_score) / 2
elif config.topk_mode == "tau_softmax":
# Use tau as temperature for softmax selection
attn_score = output.attn_m2c.mean(dim=1).sum(dim=-1) # [batch, seq_clip]
# Apply tau temperature scaling
tau_value = output.tau.mean().item() * config.tau_temperature
scores = torch.nn.functional.softmax(attn_score / tau_value, dim=-1)
else:
scores = output.gate.squeeze(-1)
# Calculate k
k = int(scores.size(-1) * (config.topk_percentage / 100.0))
k = max(1, min(k, scores.size(-1)))
# Get top-k indices
topk_values, topk_indices = torch.topk(scores, k, dim=-1)
# Create sparse mask
mask = torch.zeros_like(scores)
mask.scatter_(-1, topk_indices, 1.0)
return mask, scores
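
    # Worked example (illustrative): with 77 CLIP tokens and topk_percentage=25.0,
    # k = int(77 * 0.25) = 19, so the returned mask holds 1.0 at the 19 highest
    # scoring token positions and 0.0 everywhere else.
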
@staticmethod
def apply_modifications(clip_slice: torch.Tensor, outputs: List[AdapterOutput],
config: ShiftConfig) -> torch.Tensor:
"""Apply modifications based on config.pool_method"""
torch.manual_seed(config.seed if config.seed >= 0 else torch.randint(0, 2**32, (1,)).item())
modified = clip_slice.clone()
if config.pool_method == "sequential":
# Apply each adapter sequentially
for output in outputs:
modified = ConditioningShifter._apply_single(modified, output, config)
return modified
elif config.pool_method == "weighted_average":
# Pool all adapters then apply once
if len(outputs) == 1:
return ConditioningShifter._apply_single(modified, outputs[0], config)
pooled = ConditioningShifter._pool_outputs(outputs)
return ConditioningShifter._apply_single(clip_slice, pooled, config)
else:
raise ValueError(f"Unknown pool_method: {config.pool_method}")
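
    # Illustrative contrast between the two pool methods for adapters A and B:
    #   "sequential"       -> _apply_single(_apply_single(clip, A), B)
    #   "weighted_average" -> _apply_single(clip, _pool_outputs([A, B]))
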
@staticmethod
def _apply_single(clip_slice: torch.Tensor, output: AdapterOutput,
config: ShiftConfig) -> torch.Tensor:
"""Apply a single adapter output with optional top-k selection"""
# Apply top-k selection if enabled
topk_mask, scores = ConditioningShifter.apply_topk_selection(output, config)
# Preprocess (but remember delta already has gate!)
delta = output.delta * config.delta_scale + config.delta_mean
gate_scaled = output.gate * config.gate_probability
gate_mask = (gate_scaled > config.gate_threshold).float()
gate_masked = gate_scaled * gate_mask
# Apply top-k mask to gate and delta
if config.use_topk:
# Expand mask to match dimensions
topk_mask_expanded = topk_mask.unsqueeze(-1)
gate_masked = gate_masked * topk_mask_expanded
delta = delta * topk_mask_expanded
        # Note: config.strength is not folded in here; the delta is used as-is
        delta_final = delta
# Apply based on anchor mode
if config.use_anchor:
# Blend original with anchor, then add delta
blended = clip_slice * (1 - gate_masked) + output.anchor * gate_masked
clip_modified = blended + delta_final
else:
# Simple additive
clip_modified = clip_slice + delta_final
# Apply noise
if config.sigma_scale > 0 and config.noise_injection > 0:
sigma = torch.exp(output.log_sigma * config.sigma_scale)
clip_modified += torch.randn_like(clip_modified) * sigma * config.noise_injection
elif config.noise_injection > 0:
clip_modified += torch.randn_like(clip_modified) * config.noise_injection
return clip_modified
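
    # Blending sketch when use_anchor=True (illustrative):
    #   blended  = clip * (1 - gate) + anchor * gate
    #   modified = blended + delta
    # A gate near 0 keeps the original CLIP token (plus delta); a gate near 1
    # replaces it with the adapter's anchor before the delta is added.
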
@staticmethod
def _pool_outputs(outputs: List[AdapterOutput]) -> AdapterOutput:
"""Pool multiple adapter outputs into one"""
# Simple weighted average
total_weight = len(outputs)
pooled_anchor = sum(o.anchor for o in outputs) / total_weight
pooled_delta = sum(o.delta for o in outputs) / total_weight
pooled_log_sigma = sum(o.log_sigma for o in outputs) / total_weight
# Handle tau with different head counts
if all(o.tau is not None for o in outputs):
# Take mean across heads for each adapter, then average
tau_values = [o.tau.mean().item() for o in outputs]
pooled_tau_value = sum(tau_values) / total_weight
# Create scalar tensor on same device
pooled_tau = torch.tensor(pooled_tau_value, device=outputs[0].tau.device)
else:
pooled_tau = None
pooled_g_pred = sum(o.g_pred for o in outputs) / total_weight if outputs[0].g_pred is not None else None
pooled_gate = sum(o.gate for o in outputs) / total_weight
# Pool attention weights if available - handle different head counts
pooled_attn_c2m = None
pooled_attn_m2c = None
if all(o.attn_c2m is not None for o in outputs):
# First, average across heads for each adapter to get [batch, seq_c, seq_m]
attn_c2m_list = []
attn_m2c_list = []
for o in outputs:
# Average across heads dimension
attn_c2m_avg = o.attn_c2m.mean(dim=1) # [batch, seq_c, seq_m]
attn_m2c_avg = o.attn_m2c.mean(dim=1) # [batch, seq_m, seq_c]
attn_c2m_list.append(attn_c2m_avg)
attn_m2c_list.append(attn_m2c_avg)
# Now average across adapters
pooled_attn_c2m = sum(attn_c2m_list) / total_weight
pooled_attn_m2c = sum(attn_m2c_list) / total_weight
# Add back a dummy heads dimension for compatibility
pooled_attn_c2m = pooled_attn_c2m.unsqueeze(1) # [batch, 1, seq_c, seq_m]
pooled_attn_m2c = pooled_attn_m2c.unsqueeze(1) # [batch, 1, seq_m, seq_c]
return AdapterOutput(
anchor=pooled_anchor,
delta=pooled_delta,
log_sigma=pooled_log_sigma,
tau=pooled_tau,
g_pred=pooled_g_pred,
gate=pooled_gate,
adapter_type=outputs[0].adapter_type,
slice_range=outputs[0].slice_range,
attn_c2m=pooled_attn_c2m,
attn_m2c=pooled_attn_m2c
)
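
    # Pooling note (illustrative arithmetic): the mean is unweighted, so deltas of
    # +0.4 and -0.2 from two adapters pool to +0.1 at that position; attention maps
    # are head-averaged per adapter first, then averaged across adapters.
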
@staticmethod
    def conditioning_set_values(conditioning, values: Optional[Dict[str, Any]] = None, append: bool = False):
        """
        Set values in conditioning based on the provided dict.
        Original implementation adapted from ComfyUI's node_helpers.py.
        """
        values = values if values is not None else {}
        c = []
for t in conditioning:
n = [t[0], t[1].copy()]
for k in values:
val = values[k]
if append:
old_val = n[1].get(k, None)
if old_val is not None:
val = old_val + val
n[1][k] = val
c.append(n)
        return c
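
    # Illustrative call (conditioning layout follows the ComfyUI convention of
    # [[tensor, options_dict], ...]):
    #   cond = ConditioningShifter.conditioning_set_values(
    #       cond, {"pooled_outputs": pooled}, append=False)
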
@staticmethod
def conditioning_set_strength(conditioning, cond_strength: float, pool_strength: float = 1.0):
"""
        Scale each conditioning tensor by cond_strength (and any pooled output by pool_strength); the tensors are modified directly instead of going through set-values. Expected layout:
[ [base_tensor, { "pooled_outputs": pool, ... other dict entries } ], ... ]
"""
c = []
for t in conditioning:
            base_tensor = t[0].clone()
            # Apply the conditioning strength, then check for pooled outputs
            base_tensor *= cond_strength
            kwarg_dict = t[1].copy() if t[1] is not None else {}  # copy the config dict for later use
            # Get and remove the pooled outputs if they exist
            pooled: Optional[torch.Tensor] = kwarg_dict.get("pooled_outputs", None)
if pooled is not None:
del kwarg_dict["pooled_outputs"]
pooled = pooled.clone()
# If we have pooled outputs, apply the pooled strength
pooled *= pool_strength
kwarg_dict["pooled_outputs"] = pooled
            c.append([base_tensor, kwarg_dict])
        return c
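

if __name__ == "__main__":
    # Minimal smoke test (illustrative only; tensor shapes are assumptions based on
    # how this module consumes them, and no real encoder or adapter is loaded).
    batch, seq_clip, seq_enc, dim = 1, 77, 64, 768

    # _project_embeddings: interpolate the feature dimension 768 -> 1280
    emb = torch.randn(batch, seq_clip, dim)
    projected = ConditioningShifter._project_embeddings(emb, 1280, "linear")
    print("projected shape:", tuple(projected.shape))  # expected (1, 77, 1280)

    # apply_topk_selection: keep the top 25% of CLIP tokens by attention score
    dummy = AdapterOutput(
        anchor=torch.randn(batch, seq_clip, dim),
        delta=torch.randn(batch, seq_clip, dim),
        log_sigma=torch.zeros(batch, seq_clip, dim),
        tau=torch.ones(1),
        g_pred=torch.ones(batch, 1),
        gate=torch.rand(batch, seq_clip, 1),
        adapter_type="t5",
        slice_range=(0, dim),
        attn_c2m=torch.rand(batch, 4, seq_enc, seq_clip),
        attn_m2c=torch.rand(batch, 4, seq_clip, seq_enc),
    )
    cfg = ShiftConfig(use_topk=True, topk_percentage=25.0, topk_mode="attention")
    mask, scores = ConditioningShifter.apply_topk_selection(dummy, cfg)
    print("tokens kept:", int(mask.sum().item()), "of", seq_clip)  # expected 19 of 77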