import torch
import numpy as np
import logging
from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass

from two_stream_shunt_adapter import ConditionModulationShuntAdapter, reshape_for_shunt

logger = logging.getLogger(__name__)


@dataclass
class ShiftConfig:
    """Unified configuration for all modifications."""
    prompt: str = ""
    seed: int = -1  # -1 means no seed, use random
    strength: float = 1.0
    delta_mean: float = 0.0
    delta_scale: float = 1.0
    sigma_scale: float = 0.0
    gate_probability: float = 1.0
    gate_threshold: float = 0.1
    noise_injection: float = 0.0
    use_anchor: bool = True
    pool_method: str = "sequential"  # "sequential" or "weighted_average"
    # Top-K parameters
    use_topk: bool = False
    topk_percentage: float = 100.0  # percentage of tokens to keep
    tau_temperature: float = 1.0  # temperature scaling for tau
    topk_mode: str = "attention"  # "attention", "attention_collaborative", "gate", "combined", "tau_softmax"
    guidance_scale: float = 1.0
    max_tokens: int = 77  # maximum number of tokens to process
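
# Illustrative usage sketch (not from the original source; the field values are
# hypothetical): a ShiftConfig can be constructed directly or from a plain dict,
# which extract_encoder_embeddings below also accepts.
#
#   cfg = ShiftConfig(
#       prompt="a photo of a red bicycle",
#       strength=0.8,
#       use_topk=True,
#       topk_percentage=50.0,
#       topk_mode="tau_softmax",
#   )
#   # equivalently: ShiftConfig(**{"prompt": "...", "gate_threshold": 0.2})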


@dataclass
class AdapterOutput:
    """Raw output from a single adapter forward pass."""
    anchor: torch.Tensor
    delta: torch.Tensor  # note: the gate is already multiplied into delta
    log_sigma: torch.Tensor
    tau: torch.Tensor
    g_pred: torch.Tensor
    gate: torch.Tensor
    adapter_type: str
    slice_range: Tuple[int, int]
    # Attention weights, kept for top-k selection
    attn_c2m: Optional[torch.Tensor] = None
    attn_m2c: Optional[torch.Tensor] = None


class ConditioningShifter:
    """Static helpers for extracting encoder embeddings and shifting CLIP conditioning."""

    @staticmethod
    def extract_encoder_embeddings(
        encoder_pipe: Dict[str, Any],
        device: torch.device,
        shift_config: Optional[ShiftConfig | dict[str, Any]] = None,
        sampler_cfg: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        """
        1) Clean the prompt of any shunt tokens
        2) Tokenize + encode via T5/BERT
        3) Optionally project to sampler_cfg['projection_dims_in']
        """
        # 1) prompt cleanup
        if isinstance(shift_config, dict):
            shift_config = ShiftConfig(**shift_config)
        raw_prompt = shift_config.prompt
        prompt = raw_prompt  # RemoveSpecialTokens.remove_special_tokens(raw_prompt)

        # 2) tokenize & encode
        tokenizer = encoder_pipe["tokenizer"]
        model = encoder_pipe["model"]
        cfg = encoder_pipe["config"]["config"]  # the existing mini-config
        tokens = tokenizer(
            prompt,
            return_tensors="pt",
            padding=cfg.get("padding", "max_length"),
            truncation=True,
            max_length=cfg.get("max_tokens", cfg.get("max_length", 512)),
        )
        input_ids = tokens["input_ids"].to(device)
        attention_mask = tokens["attention_mask"].to(device)

        with torch.no_grad():
            model.to(device)
            mtype = encoder_pipe["config"].get("model_type", "")
            if "t5" in mtype:
                embeddings = model.encoder(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                ).last_hidden_state
            elif mtype in ("bert", "nomic_bert"):
                embeddings = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    return_dict=True,
                ).last_hidden_state
            else:
                raise ValueError(f"Unsupported encoder type {mtype!r}")
            model.to("cpu")  # free GPU memory

        # 3) optional input projection to match CLIP dims
        if sampler_cfg and sampler_cfg.get("force_projection_in", False):
            target_dims = sampler_cfg["projection_dims_in"]
            embeddings = ConditioningShifter._project_embeddings(
                embeddings, target_dims, sampler_cfg["interpolation_method_in"]
            )
        return embeddings.to(device)
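
    # Hedged usage sketch. The encoder_pipe layout below is inferred from the key
    # lookups in extract_encoder_embeddings and is not guaranteed by the original
    # source; the tokenizer/model objects are placeholders.
    #
    #   encoder_pipe = {
    #       "tokenizer": t5_tokenizer,
    #       "model": t5_model,
    #       "config": {
    #           "model_type": "t5",
    #           "config": {"padding": "max_length", "max_tokens": 512},
    #       },
    #   }
    #   emb = ConditioningShifter.extract_encoder_embeddings(
    #       encoder_pipe, torch.device("cuda"), ShiftConfig(prompt="a cat")
    #   )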

    @staticmethod
    def _project_embeddings(
        embeddings: torch.Tensor,
        target_dim: int,
        mode: str,
    ) -> torch.Tensor:
        """
        Interpolate the last dimension from D → target_dim via F.interpolate,
        preserving the batch and sequence dims.
        """
        B, T, D = embeddings.shape
        if D == target_dim:
            return embeddings
        # [B*T, 1, D] → interpolate → [B*T, 1, target_dim] → back to [B, T, target_dim]
        flat = embeddings.reshape(B * T, 1, D)
        # align_corners is only valid for the linear-family modes; pass None otherwise
        align_corners = True if mode in {"linear", "bilinear", "trilinear"} else None
        proj = torch.nn.functional.interpolate(
            flat.float(),
            size=target_dim,
            mode=mode,
            align_corners=align_corners,
        )
        return proj.reshape(B, T, target_dim)
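
    # Worked example (illustrative shapes only): projecting a 2048-wide encoder
    # embedding down to a 1280-wide CLIP slice with nearest-neighbour interpolation.
    #
    #   x = torch.randn(2, 77, 2048)                                   # [B, T, D]
    #   y = ConditioningShifter._project_embeddings(x, 1280, "nearest")
    #   y.shape  # torch.Size([2, 77, 1280])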

    @staticmethod
    def run_adapter(adapter_model: ConditionModulationShuntAdapter,
                    encoder_embeddings: torch.Tensor,
                    clip_slice: torch.Tensor,
                    guidance_scale: float,
                    adapter_type: str,
                    slice_range: Tuple[int, int]) -> AdapterOutput:
        """Run the adapter and package its output."""
        gen_config = {"max_guidance": guidance_scale if guidance_scale > 0 else 1.0}
        # encoder_embeddings, clip_slice = reshape_for_shunt(encoder_embeddings, clip_slice, adapter_model)
        with torch.no_grad():
            outputs = adapter_model(encoder_embeddings.float(), clip_slice.float(), config=gen_config)

        if isinstance(outputs, tuple) and len(outputs) == 8:
            anchor, delta, log_sigma, attn_c2m, attn_m2c, tau, g_pred, gate = outputs
            return AdapterOutput(
                anchor=anchor,
                delta=delta,  # already has the gate multiplied in
                log_sigma=log_sigma,
                tau=tau,
                g_pred=g_pred,
                gate=gate,
                adapter_type=adapter_type,
                slice_range=slice_range,
                attn_c2m=attn_c2m,
                attn_m2c=attn_m2c,
            )
        else:
            raise ValueError(f"Unexpected adapter output format: {type(outputs)}")

    @staticmethod
    def apply_topk_selection(output: AdapterOutput,
                             config: ShiftConfig) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Apply top-k selection using tau and the attention weights.
        Returns a binary mask over CLIP tokens and the selection scores
        (scores are None when top-k is disabled).
        """
        if not config.use_topk:
            # Return a full mask matching the gate dimensions
            return torch.ones_like(output.gate.squeeze(-1)), None

        # Calculate selection scores based on the configured mode
        if config.topk_mode == "attention":
            # Use modulation->condition attention (how much each CLIP token attends to the encoder).
            # Average over heads, then sum across the encoder dimension for one score per CLIP token.
            scores = output.attn_m2c.mean(dim=1).sum(dim=-1)  # [batch, seq_clip]
        elif config.topk_mode == "attention_collaborative":
            # Same m2c importance score as above, rescaled by the range of the
            # condition->modulation attention, which acts as a soft mask.
            scores = output.attn_m2c.mean(dim=1).sum(dim=-1)
            c2m_scores = output.attn_c2m.mean(dim=1).sum(dim=-1)  # [batch, seq_clip]
            scores = (scores - c2m_scores.min()) / (c2m_scores.max() - c2m_scores.min() + 1e-8)
        elif config.topk_mode == "gate":
            # Use gate values directly (already in CLIP space)
            scores = output.gate.squeeze(-1)  # [batch, seq_clip]
        elif config.topk_mode == "combined":
            # Combine attention and gate scores
            attn_score = output.attn_m2c.mean(dim=1).sum(dim=-1)  # [batch, seq_clip]
            gate_score = output.gate.squeeze(-1)
            # Normalize each to [0, 1], then average
            attn_score = (attn_score - attn_score.min()) / (attn_score.max() - attn_score.min() + 1e-8)
            gate_score = (gate_score - gate_score.min()) / (gate_score.max() - gate_score.min() + 1e-8)
            scores = (attn_score + gate_score) / 2
        elif config.topk_mode == "tau_softmax":
            # Use tau as the temperature for a softmax over attention scores
            attn_score = output.attn_m2c.mean(dim=1).sum(dim=-1)  # [batch, seq_clip]
            tau_value = output.tau.mean().item() * config.tau_temperature
            scores = torch.nn.functional.softmax(attn_score / tau_value, dim=-1)
        else:
            scores = output.gate.squeeze(-1)

        # Calculate k from the configured percentage
        k = int(scores.size(-1) * (config.topk_percentage / 100.0))
        k = max(1, min(k, scores.size(-1)))

        # Get top-k indices and build a sparse mask
        topk_values, topk_indices = torch.topk(scores, k, dim=-1)
        mask = torch.zeros_like(scores)
        mask.scatter_(-1, topk_indices, 1.0)
        return mask, scores
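
    # Example of the k calculation above (illustrative numbers): with 77 CLIP
    # tokens and topk_percentage=25.0, k = int(77 * 0.25) = 19, so the returned
    # mask keeps the 19 highest-scoring token positions and zeroes the rest.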

    @staticmethod
    def apply_modifications(clip_slice: torch.Tensor, outputs: List[AdapterOutput],
                            config: ShiftConfig) -> torch.Tensor:
        """Apply modifications based on config.pool_method."""
        torch.manual_seed(config.seed if config.seed >= 0 else torch.randint(0, 2**32, (1,)).item())
        modified = clip_slice.clone()

        if config.pool_method == "sequential":
            # Apply each adapter output in turn
            for output in outputs:
                modified = ConditioningShifter._apply_single(modified, output, config)
            return modified
        elif config.pool_method == "weighted_average":
            # Pool all adapter outputs, then apply once
            if len(outputs) == 1:
                return ConditioningShifter._apply_single(modified, outputs[0], config)
            pooled = ConditioningShifter._pool_outputs(outputs)
            return ConditioningShifter._apply_single(clip_slice, pooled, config)
        else:
            raise ValueError(f"Unknown pool_method: {config.pool_method}")

    @staticmethod
    def _apply_single(clip_slice: torch.Tensor, output: AdapterOutput,
                      config: ShiftConfig) -> torch.Tensor:
        """Apply a single adapter output with optional top-k selection."""
        # Apply top-k selection if enabled
        topk_mask, scores = ConditioningShifter.apply_topk_selection(output, config)

        # Preprocess (remember: delta already has the gate multiplied in)
        delta = output.delta * config.delta_scale + config.delta_mean
        gate_scaled = output.gate * config.gate_probability
        gate_mask = (gate_scaled > config.gate_threshold).float()
        gate_masked = gate_scaled * gate_mask

        # Apply the top-k mask to gate and delta
        if config.use_topk:
            # Expand the mask to match the feature dimension
            topk_mask_expanded = topk_mask.unsqueeze(-1)
            gate_masked = gate_masked * topk_mask_expanded
            delta = delta * topk_mask_expanded

        # Apply overall strength
        delta_final = delta * config.strength

        # Apply based on anchor mode
        if config.use_anchor:
            # Blend the original with the anchor, then add delta
            blended = clip_slice * (1 - gate_masked) + output.anchor * gate_masked
            clip_modified = blended + delta_final
        else:
            # Simple additive update
            clip_modified = clip_slice + delta_final

        # Apply noise
        if config.sigma_scale > 0 and config.noise_injection > 0:
            sigma = torch.exp(output.log_sigma * config.sigma_scale)
            clip_modified += torch.randn_like(clip_modified) * sigma * config.noise_injection
        elif config.noise_injection > 0:
            clip_modified += torch.randn_like(clip_modified) * config.noise_injection

        return clip_modified
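
    # Per token and channel, the anchor path above computes
    #   out = clip * (1 - gate) + anchor * gate + delta
    # where delta is the scaled (and optionally top-k masked) delta, so a masked
    # gate of 0 leaves the original CLIP embedding untouched apart from the
    # additive delta, while a gate of 1 replaces it with the anchor.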

    @staticmethod
    def _pool_outputs(outputs: List[AdapterOutput]) -> AdapterOutput:
        """Pool multiple adapter outputs into one."""
        # Uniform (equal-weight) average across adapters
        total_weight = len(outputs)
        pooled_anchor = sum(o.anchor for o in outputs) / total_weight
        pooled_delta = sum(o.delta for o in outputs) / total_weight
        pooled_log_sigma = sum(o.log_sigma for o in outputs) / total_weight

        # Handle tau, which may have a different head count per adapter
        if all(o.tau is not None for o in outputs):
            # Take the mean across heads for each adapter, then average the adapters
            tau_values = [o.tau.mean().item() for o in outputs]
            pooled_tau_value = sum(tau_values) / total_weight
            # Create a scalar tensor on the same device
            pooled_tau = torch.tensor(pooled_tau_value, device=outputs[0].tau.device)
        else:
            pooled_tau = None

        pooled_g_pred = sum(o.g_pred for o in outputs) / total_weight if outputs[0].g_pred is not None else None
        pooled_gate = sum(o.gate for o in outputs) / total_weight

        # Pool attention weights if available, handling different head counts
        pooled_attn_c2m = None
        pooled_attn_m2c = None
        if all(o.attn_c2m is not None for o in outputs):
            # First average across heads for each adapter to get [batch, seq_c, seq_m]
            attn_c2m_list = []
            attn_m2c_list = []
            for o in outputs:
                attn_c2m_avg = o.attn_c2m.mean(dim=1)  # [batch, seq_c, seq_m]
                attn_m2c_avg = o.attn_m2c.mean(dim=1)  # [batch, seq_m, seq_c]
                attn_c2m_list.append(attn_c2m_avg)
                attn_m2c_list.append(attn_m2c_avg)
            # Then average across adapters
            pooled_attn_c2m = sum(attn_c2m_list) / total_weight
            pooled_attn_m2c = sum(attn_m2c_list) / total_weight
            # Add back a dummy heads dimension for compatibility
            pooled_attn_c2m = pooled_attn_c2m.unsqueeze(1)  # [batch, 1, seq_c, seq_m]
            pooled_attn_m2c = pooled_attn_m2c.unsqueeze(1)  # [batch, 1, seq_m, seq_c]

        return AdapterOutput(
            anchor=pooled_anchor,
            delta=pooled_delta,
            log_sigma=pooled_log_sigma,
            tau=pooled_tau,
            g_pred=pooled_g_pred,
            gate=pooled_gate,
            adapter_type=outputs[0].adapter_type,
            slice_range=outputs[0].slice_range,
            attn_c2m=pooled_attn_c2m,
            attn_m2c=pooled_attn_m2c,
        )
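
# Hedged end-to-end sketch (illustrative only; adapter_model, clip_slice, the
# adapter_type name and the slice_range values are placeholders supplied by the
# caller, not defined in this module):
#
#   cfg = ShiftConfig(prompt="a castle at dusk", strength=0.7)
#   emb = ConditioningShifter.extract_encoder_embeddings(encoder_pipe, device, cfg)
#   out = ConditioningShifter.run_adapter(
#       adapter_model, emb, clip_slice,
#       guidance_scale=cfg.guidance_scale,
#       adapter_type="clip_l", slice_range=(0, 768),
#   )
#   modified_slice = ConditioningShifter.apply_modifications(clip_slice, [out], cfg)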


def conditioning_set_values(conditioning, values={}, append=False):
    """
    Set values in the conditioning dicts from the provided mapping.
    The original set-values helper comes from ComfyUI's node_helpers.py.
    """
    c = []
    for t in conditioning:
        n = [t[0], t[1].copy()]
        for k in values:
            val = values[k]
            if append:
                old_val = n[1].get(k, None)
                if old_val is not None:
                    val = old_val + val
            n[1][k] = val
        c.append(n)
    return c


def conditioning_set_strength(conditioning, cond_strength: float, pool_strength: float = 1.0):
    """
    Scale the conditioning tensors by the provided strengths; we modify the tensors
    directly instead of just setting dict values.
    Expected layout: [ [base_tensor, { "pooled_outputs": pool, ...other dict entries } ], ... ]
    """
    c = []
    for t in conditioning:
        # Clone the base tensor, then apply the conditioning strength
        base_tensor = t[0].clone()
        base_tensor *= cond_strength
        kwarg_dict = t[1].copy() if t[1] is not None else {}  # copy the config params for later use
        # Pull out the pooled outputs if they exist
        pooled: Optional[torch.Tensor] = kwarg_dict.get("pooled_outputs", None)
        if pooled is not None:
            del kwarg_dict["pooled_outputs"]
            pooled = pooled.clone()
            # If we have pooled outputs, apply the pooled strength
            pooled *= pool_strength
            kwarg_dict["pooled_outputs"] = pooled
        c.append([base_tensor, kwarg_dict])
    return c
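

if __name__ == "__main__":
    # Minimal smoke test of the pure-tensor helpers, using random placeholder
    # data; illustrative only, no encoder or adapter weights are loaded, and
    # "example_key" is a made-up conditioning key.
    dummy = torch.randn(1, 77, 1024)
    projected = ConditioningShifter._project_embeddings(dummy, target_dim=768, mode="nearest")
    print("projected:", tuple(projected.shape))  # (1, 77, 768)

    cond = [[torch.randn(1, 77, 768), {"pooled_outputs": torch.randn(1, 1280)}]]
    cond = conditioning_set_values(cond, {"example_key": 1.0})
    cond = conditioning_set_strength(cond, cond_strength=0.85, pool_strength=1.0)
    print("conditioning keys:", list(cond[0][1].keys()))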