import torch import numpy as np import logging from typing import Dict, List, Tuple, Optional, Any from dataclasses import dataclass from two_stream_shunt_adapter import ConditionModulationShuntAdapter, reshape_for_shunt logger = logging.getLogger(__name__) @dataclass class ShiftConfig: """Unified configuration for all modifications""" prompt: str = "" seed: int = -1 # -1 means no seed, use random strength: float = 1.0 delta_mean: float = 0.0 delta_scale: float = 1.0 sigma_scale: float = 0.0 gate_probability: float = 1.0 gate_threshold: float = 0.1 noise_injection: float = 0.0 use_anchor: bool = True pool_method: str = "sequential" # "sequential" or "weighted_average" # Top-K parameters use_topk: bool = False topk_percentage: float = 100.0 # Percentage of tokens to keep tau_temperature: float = 1.0 # Temperature scaling for tau topk_mode: str = "attention" # "attention", "gate", "combined", "tau_softmax" guidance_scale: float = 1.0, max_tokens: int = 77 # Maximum number of tokens to process @dataclass class AdapterOutput: """Raw output from adapter forward pass""" anchor: torch.Tensor delta: torch.Tensor # Note: already has gate multiplied in! log_sigma: torch.Tensor tau: torch.Tensor g_pred: torch.Tensor gate: torch.Tensor adapter_type: str slice_range: Tuple[int, int] # Add attention weights for top-k attn_c2m: Optional[torch.Tensor] = None attn_m2c: Optional[torch.Tensor] = None class ConditioningShifter: @staticmethod def extract_encoder_embeddings( encoder_pipe: Dict[str, Any], device: torch.device, shift_config: Optional[ShiftConfig | dict[str, Any]] = None, sampler_cfg: Dict[str, Any] = None ) -> torch.Tensor: """ 1) Clean prompt of any shunt tokens 2) Tokenize + encode via T5/BERT 3) Optionally project to sampler_cfg['projection_dims_in'] """ # 1) prompt cleanup if isinstance(shift_config, dict): shift_config = ShiftConfig(**shift_config) raw_prompt = shift_config.prompt prompt = raw_prompt#RemoveSpecialTokens.remove_special_tokens(raw_prompt) # 2) tokenize & encode tokenizer = encoder_pipe["tokenizer"] model = encoder_pipe["model"] cfg = encoder_pipe["config"]["config"] # your existing mini‐config tokens = tokenizer( prompt, return_tensors="pt", padding=cfg.get("padding","max_length"), truncation=True, max_length=cfg.get("max_tokens",cfg.get("max_length", 512)), ) input_ids = tokens["input_ids"].to(device) attention_mask = tokens["attention_mask"].to(device) with torch.no_grad(): model.to(device) mtype = encoder_pipe["config"].get("model_type","") if "t5" in mtype: embeddings = model.encoder(input_ids=input_ids, attention_mask=attention_mask ).last_hidden_state elif mtype in ("bert","nomic_bert"): embeddings = model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True ).last_hidden_state else: raise ValueError(f"Unsupported encoder type {mtype!r}") model.to("cpu") # free GPU memory # 3) optional input‐projection to match CLIP dims if sampler_cfg and sampler_cfg.get("force_projection_in", False): target_dims = sampler_cfg["projection_dims_in"] embeddings = ConditioningShifter._project_embeddings( embeddings, target_dims, sampler_cfg["interpolation_method_in"] ) return embeddings.to(device) @staticmethod def _project_embeddings( embeddings: torch.Tensor, target_dim: int, mode: str ) -> torch.Tensor: """ Interpolate the last dimension from D→target_dim via F.interpolate, preserving batch & sequence dims. """ B, T, D = embeddings.shape if D == target_dim: return embeddings # [B*T, 1, D] → interpolate → [B*T, 1, target_dim] → back to [B,T,target_dim] flat = embeddings.reshape(B*T, 1, D) proj = torch.nn.functional.interpolate( flat.float(), size=target_dim, mode=mode, align_corners=(mode in {"linear","bilinear","trilinear"}) ) return proj.reshape(B, T, target_dim) @staticmethod def run_adapter(adapter_model: ConditionModulationShuntAdapter, encoder_embeddings: torch.Tensor, clip_slice: torch.Tensor, guidance_scale: float, adapter_type: str, slice_range: Tuple[int, int]) -> AdapterOutput: """Run adapter and package output""" gen_config = {"max_guidance": guidance_scale if guidance_scale > 0 else 1.0} #encoder_embeddings, clip_slice = reshape_for_shunt(encoder_embeddings, clip_slice, adapter_model) with torch.no_grad(): outputs = adapter_model(encoder_embeddings.float(), clip_slice.float(), config=gen_config) if isinstance(outputs, tuple) and len(outputs) == 8: anchor, delta, log_sigma, attn_c2m, attn_m2c, tau, g_pred, gate = outputs return AdapterOutput( anchor=anchor, delta=delta, # Already has gate multiplied! log_sigma=log_sigma, tau=tau, g_pred=g_pred, gate=gate, adapter_type=adapter_type, slice_range=slice_range, attn_c2m=attn_c2m, attn_m2c=attn_m2c ) else: raise ValueError(f"Unexpected adapter output format: {type(outputs)}") @staticmethod def apply_topk_selection(output: AdapterOutput, config: ShiftConfig) -> Tuple[torch.Tensor, torch.Tensor]: """ Apply top-k selection using tau and attention weights. Returns mask and selection scores for CLIP tokens. """ if not config.use_topk: # Return full mask matching gate dimensions return torch.ones_like(output.gate.squeeze(-1)), None # Calculate selection scores based on mode if config.topk_mode == "attention": # Use modulation->condition attention (how much each CLIP token attends to encoder) # Sum across encoder dimension to get importance score per CLIP token scores = output.attn_m2c.mean(dim=1).sum(dim=-1) # [batch, seq_clip] elif config.topk_mode == "attention_collaborative": # Use modulation->condition attention (how much each CLIP token attends to encoder) # Sum across encoder dimension to get importance score per CLIP token # compare and normalize using the c2m attention as a soft mask scores = output.attn_m2c.mean(dim=1).sum(dim=-1) c2m_scores = output.attn_c2m.mean(dim=1).sum(dim=-1) # [batch, seq_clip] # soft mask weaken and strengthen scores based on c2m_scores scores = (scores - c2m_scores.min()) / (c2m_scores.max() - c2m_scores.min() + 1e-8) elif config.topk_mode == "gate": # Use gate values directly (already in CLIP space) scores = output.gate.squeeze(-1) # [batch, seq_clip] elif config.topk_mode == "combined": # Combine attention and gate scores attn_score = output.attn_m2c.mean(dim=1).sum(dim=-1) # [batch, seq_clip] gate_score = output.gate.squeeze(-1) # Normalize and combine attn_score = (attn_score - attn_score.min()) / (attn_score.max() - attn_score.min() + 1e-8) gate_score = (gate_score - gate_score.min()) / (gate_score.max() - gate_score.min() + 1e-8) scores = (attn_score + gate_score) / 2 elif config.topk_mode == "tau_softmax": # Use tau as temperature for softmax selection attn_score = output.attn_m2c.mean(dim=1).sum(dim=-1) # [batch, seq_clip] # Apply tau temperature scaling tau_value = output.tau.mean().item() * config.tau_temperature scores = torch.nn.functional.softmax(attn_score / tau_value, dim=-1) else: scores = output.gate.squeeze(-1) # Calculate k k = int(scores.size(-1) * (config.topk_percentage / 100.0)) k = max(1, min(k, scores.size(-1))) # Get top-k indices topk_values, topk_indices = torch.topk(scores, k, dim=-1) # Create sparse mask mask = torch.zeros_like(scores) mask.scatter_(-1, topk_indices, 1.0) return mask, scores @staticmethod def apply_modifications(clip_slice: torch.Tensor, outputs: List[AdapterOutput], config: ShiftConfig) -> torch.Tensor: """Apply modifications based on config.pool_method""" torch.manual_seed(config.seed if config.seed >= 0 else torch.randint(0, 2**32, (1,)).item()) modified = clip_slice.clone() if config.pool_method == "sequential": # Apply each adapter sequentially for output in outputs: modified = ConditioningShifter._apply_single(modified, output, config) return modified elif config.pool_method == "weighted_average": # Pool all adapters then apply once if len(outputs) == 1: return ConditioningShifter._apply_single(modified, outputs[0], config) pooled = ConditioningShifter._pool_outputs(outputs) return ConditioningShifter._apply_single(clip_slice, pooled, config) else: raise ValueError(f"Unknown pool_method: {config.pool_method}") @staticmethod def _apply_single(clip_slice: torch.Tensor, output: AdapterOutput, config: ShiftConfig) -> torch.Tensor: """Apply a single adapter output with optional top-k selection""" # Apply top-k selection if enabled topk_mask, scores = ConditioningShifter.apply_topk_selection(output, config) # Preprocess (but remember delta already has gate!) delta = output.delta * config.delta_scale + config.delta_mean gate_scaled = output.gate * config.gate_probability gate_mask = (gate_scaled > config.gate_threshold).float() gate_masked = gate_scaled * gate_mask # Apply top-k mask to gate and delta if config.use_topk: # Expand mask to match dimensions topk_mask_expanded = topk_mask.unsqueeze(-1) gate_masked = gate_masked * topk_mask_expanded delta = delta * topk_mask_expanded # Apply strength delta_final = delta # Apply based on anchor mode if config.use_anchor: # Blend original with anchor, then add delta blended = clip_slice * (1 - gate_masked) + output.anchor * gate_masked clip_modified = blended + delta_final else: # Simple additive clip_modified = clip_slice + delta_final # Apply noise if config.sigma_scale > 0 and config.noise_injection > 0: sigma = torch.exp(output.log_sigma * config.sigma_scale) clip_modified += torch.randn_like(clip_modified) * sigma * config.noise_injection elif config.noise_injection > 0: clip_modified += torch.randn_like(clip_modified) * config.noise_injection return clip_modified @staticmethod def _pool_outputs(outputs: List[AdapterOutput]) -> AdapterOutput: """Pool multiple adapter outputs into one""" # Simple weighted average total_weight = len(outputs) pooled_anchor = sum(o.anchor for o in outputs) / total_weight pooled_delta = sum(o.delta for o in outputs) / total_weight pooled_log_sigma = sum(o.log_sigma for o in outputs) / total_weight # Handle tau with different head counts if all(o.tau is not None for o in outputs): # Take mean across heads for each adapter, then average tau_values = [o.tau.mean().item() for o in outputs] pooled_tau_value = sum(tau_values) / total_weight # Create scalar tensor on same device pooled_tau = torch.tensor(pooled_tau_value, device=outputs[0].tau.device) else: pooled_tau = None pooled_g_pred = sum(o.g_pred for o in outputs) / total_weight if outputs[0].g_pred is not None else None pooled_gate = sum(o.gate for o in outputs) / total_weight # Pool attention weights if available - handle different head counts pooled_attn_c2m = None pooled_attn_m2c = None if all(o.attn_c2m is not None for o in outputs): # First, average across heads for each adapter to get [batch, seq_c, seq_m] attn_c2m_list = [] attn_m2c_list = [] for o in outputs: # Average across heads dimension attn_c2m_avg = o.attn_c2m.mean(dim=1) # [batch, seq_c, seq_m] attn_m2c_avg = o.attn_m2c.mean(dim=1) # [batch, seq_m, seq_c] attn_c2m_list.append(attn_c2m_avg) attn_m2c_list.append(attn_m2c_avg) # Now average across adapters pooled_attn_c2m = sum(attn_c2m_list) / total_weight pooled_attn_m2c = sum(attn_m2c_list) / total_weight # Add back a dummy heads dimension for compatibility pooled_attn_c2m = pooled_attn_c2m.unsqueeze(1) # [batch, 1, seq_c, seq_m] pooled_attn_m2c = pooled_attn_m2c.unsqueeze(1) # [batch, 1, seq_m, seq_c] return AdapterOutput( anchor=pooled_anchor, delta=pooled_delta, log_sigma=pooled_log_sigma, tau=pooled_tau, g_pred=pooled_g_pred, gate=pooled_gate, adapter_type=outputs[0].adapter_type, slice_range=outputs[0].slice_range, attn_c2m=pooled_attn_c2m, attn_m2c=pooled_attn_m2c ) @staticmethod def conditioning_set_values(conditioning, values={}, append=False): """ Set values in conditioning based on provided values. Original set values was provided by comfyui node_helpers.py """ c = [] for t in conditioning: n = [t[0], t[1].copy()] for k in values: val = values[k] if append: old_val = n[1].get(k, None) if old_val is not None: val = old_val + val n[1][k] = val c.append(n) return @staticmethod def conditioning_set_strength(conditioning, cond_strength: float, pool_strength: float = 1.0): """ Set strength in conditioning based on provided strength - we need to manually modify instead of setting values. [ [base_tensor, { "pooled_outputs": pool, ... other dict entries } ], ... ] """ c = [] for t in conditioning: base_tensor = t[0].copy() # Set our usage strength, then find out if we have pooled outputs base_tensor *= cond_strength kwarg_dict = t[1].clone() if t[1] is not None else {} # copies the config params for later use # lets get and remove the pooled outputs if they exist pooled: Optional[None | torch.Tensor] = kwarg_dict.get("pooled_outputs", None) if pooled is not None: del kwarg_dict["pooled_outputs"] pooled = pooled.clone() # If we have pooled outputs, apply the pooled strength pooled *= pool_strength kwarg_dict["pooled_outputs"] = pooled c.append([base_tensor, kwarg_dict])