AbstractPhil committed on
Commit c98544f · verified · 1 Parent(s): 535b292

Create conditioning_shifter.py

Files changed (1)
  1. conditioning_shifter.py +402 -0
conditioning_shifter.py ADDED
@@ -0,0 +1,402 @@
import torch
import numpy as np
import logging
from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass

from ..model.dual_stream_adapter_model import ConditionModulationShuntAdapter, reshape_for_shunt

logger = logging.getLogger(__name__)

@dataclass
class ShiftConfig:
    """Unified configuration for all modifications"""
    prompt: str = ""
    seed: int = -1  # -1 means no seed, use random
    strength: float = 1.0
    delta_mean: float = 0.0
    delta_scale: float = 1.0
    sigma_scale: float = 0.0
    gate_probability: float = 1.0
    gate_threshold: float = 0.1
    noise_injection: float = 0.0
    use_anchor: bool = True
    pool_method: str = "sequential"  # "sequential" or "weighted_average"
    # Top-K parameters
    use_topk: bool = False
    topk_percentage: float = 100.0  # Percentage of tokens to keep
    tau_temperature: float = 1.0  # Temperature scaling for tau
    topk_mode: str = "attention"  # "attention", "gate", "combined", "tau_softmax"
    guidance_scale: float = 1.0
    max_tokens: int = 77  # Maximum number of tokens to process

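# Example (hypothetical values): ShiftConfig can be built directly, or from a plain dict,
# which extract_encoder_embeddings() below also accepts and converts:
#   cfg = ShiftConfig(prompt="a misty forest", strength=0.8, noise_injection=0.05)
#   cfg = ShiftConfig(**{"prompt": "a misty forest", "use_topk": True, "topk_percentage": 50.0})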

@dataclass
class AdapterOutput:
    """Raw output from adapter forward pass"""
    anchor: torch.Tensor
    delta: torch.Tensor  # Note: already has gate multiplied in!
    log_sigma: torch.Tensor
    tau: torch.Tensor
    g_pred: torch.Tensor
    gate: torch.Tensor
    adapter_type: str
    slice_range: Tuple[int, int]
    # Attention weights for top-k
    attn_c2m: Optional[torch.Tensor] = None
    attn_m2c: Optional[torch.Tensor] = None

class ConditioningShifter:
    @staticmethod
    def extract_encoder_embeddings(
            encoder_pipe: Dict[str, Any],
            device: torch.device,
            shift_config: Optional[ShiftConfig | dict[str, Any]] = None,
            sampler_cfg: Optional[Dict[str, Any]] = None
    ) -> torch.Tensor:
        """
        1) Clean prompt of any shunt tokens
        2) Tokenize + encode via T5/BERT
        3) Optionally project to sampler_cfg['projection_dims_in']
        """
        # 1) prompt cleanup
        if isinstance(shift_config, dict):
            shift_config = ShiftConfig(**shift_config)
        raw_prompt = shift_config.prompt
        prompt = raw_prompt  # RemoveSpecialTokens.remove_special_tokens(raw_prompt) is currently disabled

        # 2) tokenize & encode
        tokenizer = encoder_pipe["tokenizer"]
        model = encoder_pipe["model"]
        cfg = encoder_pipe["config"]["config"]  # the existing mini-config

        tokens = tokenizer(
            prompt,
            return_tensors="pt",
            padding=cfg.get("padding", "max_length"),
            truncation=True,
            max_length=cfg.get("max_tokens", cfg.get("max_length", 512)),
        )
        input_ids = tokens["input_ids"].to(device)
        attention_mask = tokens["attention_mask"].to(device)

        with torch.no_grad():
            model.to(device)
            mtype = encoder_pipe["config"].get("model_type", "")
            if "t5" in mtype:
                embeddings = model.encoder(
                    input_ids=input_ids,
                    attention_mask=attention_mask
                ).last_hidden_state
            elif mtype in ("bert", "nomic_bert"):
                embeddings = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    return_dict=True
                ).last_hidden_state
            else:
                raise ValueError(f"Unsupported encoder type {mtype!r}")
            model.to("cpu")  # free GPU memory

        # 3) optional input projection to match CLIP dims
        if sampler_cfg and sampler_cfg.get("force_projection_in", False):
            target_dims = sampler_cfg["projection_dims_in"]
            embeddings = ConditioningShifter._project_embeddings(
                embeddings, target_dims, sampler_cfg["interpolation_method_in"]
            )

        return embeddings.to(device)

    @staticmethod
    def _project_embeddings(
            embeddings: torch.Tensor,
            target_dim: int,
            mode: str
    ) -> torch.Tensor:
        """
        Interpolate the last dimension from D → target_dim via F.interpolate,
        preserving the batch & sequence dims.
        """
        B, T, D = embeddings.shape
        if D == target_dim:
            return embeddings

        # [B*T, 1, D] → interpolate → [B*T, 1, target_dim] → back to [B, T, target_dim]
        flat = embeddings.reshape(B * T, 1, D)
        proj = torch.nn.functional.interpolate(
            flat.float(),
            size=target_dim,
            mode=mode,
            # align_corners is only valid for the interpolating modes; pass None otherwise
            align_corners=True if mode in {"linear", "bilinear", "trilinear"} else None,
        )
        return proj.reshape(B, T, target_dim)

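    # A minimal shape sketch (hypothetical dims): projecting 1024-d encoder tokens down to
    # a 768-d CLIP slice leaves the batch and sequence dimensions untouched:
    #   x = torch.randn(2, 77, 1024)
    #   y = ConditioningShifter._project_embeddings(x, 768, "linear")
    #   y.shape  # -> torch.Size([2, 77, 768])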
    @staticmethod
    def run_adapter(adapter_model: ConditionModulationShuntAdapter,
                    encoder_embeddings: torch.Tensor,
                    clip_slice: torch.Tensor,
                    guidance_scale: float,
                    adapter_type: str,
                    slice_range: Tuple[int, int]) -> AdapterOutput:
        """Run adapter and package output"""
        gen_config = {"max_guidance": guidance_scale if guidance_scale > 0 else 1.0}

        # encoder_embeddings, clip_slice = reshape_for_shunt(encoder_embeddings, clip_slice, adapter_model)

        with torch.no_grad():
            outputs = adapter_model(encoder_embeddings.float(), clip_slice.float(), config=gen_config)

        if isinstance(outputs, tuple) and len(outputs) == 8:
            anchor, delta, log_sigma, attn_c2m, attn_m2c, tau, g_pred, gate = outputs
            return AdapterOutput(
                anchor=anchor,
                delta=delta,  # Already has gate multiplied!
                log_sigma=log_sigma,
                tau=tau,
                g_pred=g_pred,
                gate=gate,
                adapter_type=adapter_type,
                slice_range=slice_range,
                attn_c2m=attn_c2m,
                attn_m2c=attn_m2c
            )
        else:
            raise ValueError(f"Unexpected adapter output format: {type(outputs)}")

    @staticmethod
    def apply_topk_selection(output: AdapterOutput, config: ShiftConfig) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Apply top-k selection using tau and the attention weights.
        Returns a mask and the selection scores for the CLIP tokens.
        """
        if not config.use_topk:
            # Return a full mask matching the gate dimensions
            return torch.ones_like(output.gate.squeeze(-1)), None

        # Calculate selection scores based on mode
        if config.topk_mode == "attention":
            # Use modulation->condition attention (how much each CLIP token attends to the encoder).
            # Sum across the encoder dimension to get an importance score per CLIP token.
            scores = output.attn_m2c.mean(dim=1).sum(dim=-1)  # [batch, seq_clip]

        elif config.topk_mode == "attention_collaborative":
            # Same m->c importance score as above, but rescaled using the c->m attention
            # as a soft mask that weakens or strengthens each token's score.
            scores = output.attn_m2c.mean(dim=1).sum(dim=-1)
            c2m_scores = output.attn_c2m.mean(dim=1).sum(dim=-1)  # [batch, seq_clip]
            scores = (scores - c2m_scores.min()) / (c2m_scores.max() - c2m_scores.min() + 1e-8)

        elif config.topk_mode == "gate":
            # Use gate values directly (already in CLIP space)
            scores = output.gate.squeeze(-1)  # [batch, seq_clip]

        elif config.topk_mode == "combined":
            # Combine attention and gate scores
            attn_score = output.attn_m2c.mean(dim=1).sum(dim=-1)  # [batch, seq_clip]
            gate_score = output.gate.squeeze(-1)

            # Normalize and combine
            attn_score = (attn_score - attn_score.min()) / (attn_score.max() - attn_score.min() + 1e-8)
            gate_score = (gate_score - gate_score.min()) / (gate_score.max() - gate_score.min() + 1e-8)

            scores = (attn_score + gate_score) / 2

        elif config.topk_mode == "tau_softmax":
            # Use tau as the temperature for a softmax selection
            attn_score = output.attn_m2c.mean(dim=1).sum(dim=-1)  # [batch, seq_clip]

            # Apply tau temperature scaling
            tau_value = output.tau.mean().item() * config.tau_temperature
            scores = torch.nn.functional.softmax(attn_score / tau_value, dim=-1)

        else:
            scores = output.gate.squeeze(-1)

        # Calculate k
        k = int(scores.size(-1) * (config.topk_percentage / 100.0))
        k = max(1, min(k, scores.size(-1)))

        # Get the top-k indices
        topk_values, topk_indices = torch.topk(scores, k, dim=-1)

        # Create a sparse mask
        mask = torch.zeros_like(scores)
        mask.scatter_(-1, topk_indices, 1.0)

        return mask, scores

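    # A quick sanity check of the k calculation (hypothetical numbers): with 77 CLIP tokens
    # and topk_percentage=50.0, k = int(77 * 0.5) = 38, so 38 tokens keep their delta/gate
    # and the remaining 39 are zeroed out by the mask.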
    @staticmethod
    def apply_modifications(clip_slice: torch.Tensor, outputs: List[AdapterOutput],
                            config: ShiftConfig) -> torch.Tensor:
        """Apply modifications based on config.pool_method"""
        torch.manual_seed(config.seed if config.seed >= 0 else torch.randint(0, 2**32, (1,)).item())

        modified = clip_slice.clone()
        if config.pool_method == "sequential":
            # Apply each adapter sequentially
            for output in outputs:
                modified = ConditioningShifter._apply_single(modified, output, config)
            return modified

        elif config.pool_method == "weighted_average":
            # Pool all adapters, then apply once
            if len(outputs) == 1:
                return ConditioningShifter._apply_single(modified, outputs[0], config)

            pooled = ConditioningShifter._pool_outputs(outputs)
            return ConditioningShifter._apply_single(clip_slice, pooled, config)

        else:
            raise ValueError(f"Unknown pool_method: {config.pool_method}")

    @staticmethod
    def _apply_single(clip_slice: torch.Tensor, output: AdapterOutput,
                      config: ShiftConfig) -> torch.Tensor:
        """Apply a single adapter output with optional top-k selection"""

        # Apply top-k selection if enabled
        topk_mask, scores = ConditioningShifter.apply_topk_selection(output, config)

        # Preprocess (but remember: the delta already has the gate multiplied in!)
        delta = output.delta * config.delta_scale + config.delta_mean

        gate_scaled = output.gate * config.gate_probability
        gate_mask = (gate_scaled > config.gate_threshold).float()
        gate_masked = gate_scaled * gate_mask

        # Apply the top-k mask to the gate and delta
        if config.use_topk:
            # Expand the mask to match dimensions
            topk_mask_expanded = topk_mask.unsqueeze(-1)
            gate_masked = gate_masked * topk_mask_expanded
            delta = delta * topk_mask_expanded

        # Apply strength
        delta_final = delta * config.strength

        # Apply based on anchor mode
        if config.use_anchor:
            # Blend the original with the anchor, then add the delta
            blended = clip_slice * (1 - gate_masked) + output.anchor * gate_masked
            clip_modified = blended + delta_final
        else:
            # Simple additive
            clip_modified = clip_slice + delta_final

        # Apply noise
        if config.sigma_scale > 0 and config.noise_injection > 0:
            sigma = torch.exp(output.log_sigma * config.sigma_scale)
            clip_modified += torch.randn_like(clip_modified) * sigma * config.noise_injection
        elif config.noise_injection > 0:
            clip_modified += torch.randn_like(clip_modified) * config.noise_injection

        return clip_modified

    @staticmethod
    def _pool_outputs(outputs: List[AdapterOutput]) -> AdapterOutput:
        """Pool multiple adapter outputs into one"""
        # Simple weighted average
        total_weight = len(outputs)

        pooled_anchor = sum(o.anchor for o in outputs) / total_weight
        pooled_delta = sum(o.delta for o in outputs) / total_weight
        pooled_log_sigma = sum(o.log_sigma for o in outputs) / total_weight

        # Handle tau with different head counts
        if all(o.tau is not None for o in outputs):
            # Take the mean across heads for each adapter, then average
            tau_values = [o.tau.mean().item() for o in outputs]
            pooled_tau_value = sum(tau_values) / total_weight
            # Create a scalar tensor on the same device
            pooled_tau = torch.tensor(pooled_tau_value, device=outputs[0].tau.device)
        else:
            pooled_tau = None

        pooled_g_pred = sum(o.g_pred for o in outputs) / total_weight if outputs[0].g_pred is not None else None
        pooled_gate = sum(o.gate for o in outputs) / total_weight

        # Pool attention weights if available - handle different head counts
        pooled_attn_c2m = None
        pooled_attn_m2c = None
        if all(o.attn_c2m is not None for o in outputs):
            # First, average across heads for each adapter to get [batch, seq_c, seq_m]
            attn_c2m_list = []
            attn_m2c_list = []

            for o in outputs:
                # Average across the heads dimension
                attn_c2m_avg = o.attn_c2m.mean(dim=1)  # [batch, seq_c, seq_m]
                attn_m2c_avg = o.attn_m2c.mean(dim=1)  # [batch, seq_m, seq_c]
                attn_c2m_list.append(attn_c2m_avg)
                attn_m2c_list.append(attn_m2c_avg)

            # Now average across adapters
            pooled_attn_c2m = sum(attn_c2m_list) / total_weight
            pooled_attn_m2c = sum(attn_m2c_list) / total_weight

            # Add back a dummy heads dimension for compatibility
            pooled_attn_c2m = pooled_attn_c2m.unsqueeze(1)  # [batch, 1, seq_c, seq_m]
            pooled_attn_m2c = pooled_attn_m2c.unsqueeze(1)  # [batch, 1, seq_m, seq_c]

        return AdapterOutput(
            anchor=pooled_anchor,
            delta=pooled_delta,
            log_sigma=pooled_log_sigma,
            tau=pooled_tau,
            g_pred=pooled_g_pred,
            gate=pooled_gate,
            adapter_type=outputs[0].adapter_type,
            slice_range=outputs[0].slice_range,
            attn_c2m=pooled_attn_c2m,
            attn_m2c=pooled_attn_m2c
        )

    @staticmethod
    def conditioning_set_values(conditioning, values={}, append=False):
        """
        Set values in the conditioning based on the provided values.
        The original set-values helper comes from ComfyUI's node_helpers.py.
        """
        c = []
        for t in conditioning:
            n = [t[0], t[1].copy()]
            for k in values:
                val = values[k]
                if append:
                    old_val = n[1].get(k, None)
                    if old_val is not None:
                        val = old_val + val

                n[1][k] = val
            c.append(n)

        return c

    @staticmethod
    def conditioning_set_strength(conditioning, cond_strength: float, pool_strength: float = 1.0):
        """
        Set strength in the conditioning based on the provided strength - the tensors have to be
        modified manually instead of just setting values.
        Layout: [ [base_tensor, { "pooled_outputs": pool, ... other dict entries } ], ... ]
        """
        c = []
        for t in conditioning:
            base_tensor = t[0].clone()
            # Apply our usage strength, then find out if we have pooled outputs
            base_tensor *= cond_strength
            kwarg_dict = t[1].copy() if t[1] is not None else {}  # copies the config params for later use

            # Get and remove the pooled outputs if they exist
            pooled: Optional[torch.Tensor] = kwarg_dict.get("pooled_outputs", None)
            if pooled is not None:
                del kwarg_dict["pooled_outputs"]
                pooled = pooled.clone()
                # If we have pooled outputs, apply the pooled strength
                pooled *= pool_strength
                kwarg_dict["pooled_outputs"] = pooled

            c.append([base_tensor, kwarg_dict])

        return c
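For reference, a minimal usage sketch of the class above. It assumes an already-loaded encoder_pipe dict, a ConditionModulationShuntAdapter instance (adapter_model), and a CLIP embedding slice (clip_slice); those objects, along with the adapter_type and slice_range values shown, are placeholders and not part of this commit:

    import torch

    # hypothetical setup: encoder_pipe, adapter_model and clip_slice are created elsewhere
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    config = ShiftConfig(prompt="a misty forest at dawn", strength=0.8,
                         use_topk=True, topk_percentage=50.0, topk_mode="attention")

    encoder_emb = ConditioningShifter.extract_encoder_embeddings(encoder_pipe, device, config)
    result = ConditioningShifter.run_adapter(adapter_model, encoder_emb, clip_slice,
                                             guidance_scale=1.0, adapter_type="t5-shunt",
                                             slice_range=(0, 768))
    modified_slice = ConditioningShifter.apply_modifications(clip_slice, [result], config)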