import torch
import numpy as np
import logging
from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass

from two_stream_shunt_adapter import ConditionModulationShuntAdapter, reshape_for_shunt

logger = logging.getLogger(__name__)

@dataclass
class ShiftConfig:
    """Unified configuration for all modifications"""
    prompt: str = ""
    seed: int = -1  # -1 means no seed, use random
    strength: float = 1.0
    delta_mean: float = 0.0
    delta_scale: float = 1.0
    sigma_scale: float = 0.0
    gate_probability: float = 1.0
    gate_threshold: float = 0.1
    noise_injection: float = 0.0
    use_anchor: bool = True
    pool_method: str = "sequential"  # "sequential" or "weighted_average"
    # Top-K parameters
    use_topk: bool = False
    topk_percentage: float = 100.0  # Percentage of tokens to keep
    tau_temperature: float = 1.0  # Temperature scaling for tau
    topk_mode: str = "attention"  # "attention", "attention_collaborative", "gate", "combined", "tau_softmax"
    guidance_scale: float = 1.0
    max_tokens: int = 77  # Maximum number of tokens to process
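
    # Illustrative only (assumed values, not tuned defaults): a config for a
    # sequential, top-k attention run might look like
    #   ShiftConfig(prompt="a watercolor fox in the snow", strength=0.8,
    #               use_topk=True, topk_percentage=25.0, topk_mode="attention")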


@dataclass
class AdapterOutput:
    """Raw output from adapter forward pass"""
    anchor: torch.Tensor
    delta: torch.Tensor  # Note: already has gate multiplied in!
    log_sigma: torch.Tensor
    tau: torch.Tensor
    g_pred: torch.Tensor
    gate: torch.Tensor
    adapter_type: str
    slice_range: Tuple[int, int]
    # Add attention weights for top-k
    attn_c2m: Optional[torch.Tensor] = None
    attn_m2c: Optional[torch.Tensor] = None


class ConditioningShifter:
    @staticmethod
    def extract_encoder_embeddings(
        encoder_pipe: Dict[str, Any],
        device: torch.device,
        shift_config: Optional[ShiftConfig | dict[str, Any]] = None,
        sampler_cfg: Optional[Dict[str, Any]] = None
    ) -> torch.Tensor:
        """
        1) Clean prompt of any shunt tokens
        2) Tokenize + encode via T5/BERT
        3) Optionally project to sampler_cfg['projection_dims_in']
        """
        # 1) prompt cleanup
        if isinstance(shift_config, dict):
            shift_config = ShiftConfig(**shift_config)
        elif shift_config is None:
            shift_config = ShiftConfig()
        raw_prompt = shift_config.prompt
        # Shunt-token stripping is currently disabled; the original call was
        # RemoveSpecialTokens.remove_special_tokens(raw_prompt).
        prompt = raw_prompt

        # 2) tokenize & encode
        tokenizer = encoder_pipe["tokenizer"]
        model     = encoder_pipe["model"]
        cfg       = encoder_pipe["config"]["config"]  # nested mini-config

        tokens = tokenizer(
            prompt,
            return_tensors="pt",
            padding=cfg.get("padding", "max_length"),
            truncation=True,
            max_length=cfg.get("max_tokens", cfg.get("max_length", 512)),
        )
        input_ids      = tokens["input_ids"].to(device)
        attention_mask = tokens["attention_mask"].to(device)

        with torch.no_grad():
            model.to(device)
            mtype = encoder_pipe["config"].get("model_type", "")
            if "t5" in mtype:
                embeddings = model.encoder(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                ).last_hidden_state
            elif mtype in ("bert", "nomic_bert"):
                embeddings = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    return_dict=True,
                ).last_hidden_state
            else:
                raise ValueError(f"Unsupported encoder type {mtype!r}")
            model.to("cpu")  # free GPU memory

        # 3) optional input‐projection to match CLIP dims
        if sampler_cfg and sampler_cfg.get("force_projection_in", False):
            target_dims = sampler_cfg["projection_dims_in"]
            embeddings = ConditioningShifter._project_embeddings(
                embeddings, target_dims, sampler_cfg["interpolation_method_in"]
            )

        return embeddings.to(device)
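
    # Sketch of a call (the encoder_pipe layout below is an assumption that
    # mirrors the key accesses above, not a documented schema):
    #   pipe = {
    #       "tokenizer": t5_tokenizer,
    #       "model": t5_model,
    #       "config": {"model_type": "t5",
    #                  "config": {"padding": "max_length", "max_tokens": 512}},
    #   }
    #   emb = ConditioningShifter.extract_encoder_embeddings(
    #       pipe, torch.device("cuda"), {"prompt": "a red bicycle"})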


    @staticmethod
    def _project_embeddings(
        embeddings: torch.Tensor,
        target_dim: int,
        mode: str
    ) -> torch.Tensor:
        """
        Interpolate the last dimension from D→target_dim via F.interpolate,
        preserving batch & sequence dims.
        """
        B, T, D = embeddings.shape
        if D == target_dim:
            return embeddings

        # [B*T, 1, D] → interpolate → [B*T, 1, target_dim] → back to [B,T,target_dim]
        flat = embeddings.reshape(B*T, 1, D)
        proj = torch.nn.functional.interpolate(
            flat.float(),
            size=target_dim,
            mode=mode,
            # align_corners must be None (not False) for non-interpolating
            # modes such as "nearest", or F.interpolate raises a ValueError
            align_corners=True if mode in {"linear", "bilinear", "trilinear"} else None,
        )
        return proj.reshape(B, T, target_dim)
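
    # Shape sketch for the projection above (sizes are assumptions; "linear" mode):
    #   x = torch.randn(2, 77, 768)
    #   y = ConditioningShifter._project_embeddings(x, 1280, "linear")
    #   assert y.shape == (2, 77, 1280)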

    @staticmethod
    def run_adapter(adapter_model: ConditionModulationShuntAdapter,
                    encoder_embeddings: torch.Tensor,
                    clip_slice: torch.Tensor,
                    guidance_scale: float,
                    adapter_type: str,
                    slice_range: Tuple[int, int]) -> AdapterOutput:
        """Run adapter and package output"""
        gen_config = {"max_guidance": guidance_scale if guidance_scale > 0 else 1.0}

        # Optional shape alignment before the forward pass (currently disabled):
        # encoder_embeddings, clip_slice = reshape_for_shunt(encoder_embeddings, clip_slice, adapter_model)

        with torch.no_grad():
            outputs = adapter_model(encoder_embeddings.float(), clip_slice.float(), config=gen_config)

            if isinstance(outputs, tuple) and len(outputs) == 8:
                anchor, delta, log_sigma, attn_c2m, attn_m2c, tau, g_pred, gate = outputs
                return AdapterOutput(
                    anchor=anchor,
                    delta=delta,  # Already has gate multiplied!
                    log_sigma=log_sigma,
                    tau=tau,
                    g_pred=g_pred,
                    gate=gate,
                    adapter_type=adapter_type,
                    slice_range=slice_range,
                    attn_c2m=attn_c2m,
                    attn_m2c=attn_m2c
                )
            else:
                raise ValueError(f"Unexpected adapter output format: {type(outputs)}")

    @staticmethod
    def apply_topk_selection(output: AdapterOutput, config: ShiftConfig) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply top-k selection using tau and attention weights.
        Returns mask and selection scores for CLIP tokens.
        """
        if not config.use_topk:
            # Return full mask matching gate dimensions
            return torch.ones_like(output.gate.squeeze(-1)), None

        # Calculate selection scores based on mode
        if config.topk_mode == "attention":
            # Use modulation->condition attention (how much each CLIP token attends to encoder)
            # Sum across encoder dimension to get importance score per CLIP token
            scores = output.attn_m2c.mean(dim=1).sum(dim=-1)  # [batch, seq_clip]
        elif config.topk_mode == "attention_collaborative":
            # Same m2c importance score as above, but rescaled by the statistics
            # of the condition->modulation attention, which acts as a soft mask
            scores = output.attn_m2c.mean(dim=1).sum(dim=-1)  # [batch, seq_clip]
            c2m_scores = output.attn_c2m.mean(dim=1).sum(dim=-1)  # [batch, seq_clip]
            scores = (scores - c2m_scores.min()) / (c2m_scores.max() - c2m_scores.min() + 1e-8)

        elif config.topk_mode == "gate":
            # Use gate values directly (already in CLIP space)
            scores = output.gate.squeeze(-1)  # [batch, seq_clip]

        elif config.topk_mode == "combined":
            # Combine attention and gate scores
            attn_score = output.attn_m2c.mean(dim=1).sum(dim=-1)  # [batch, seq_clip]
            gate_score = output.gate.squeeze(-1)

            # Normalize and combine
            attn_score = (attn_score - attn_score.min()) / (attn_score.max() - attn_score.min() + 1e-8)
            gate_score = (gate_score - gate_score.min()) / (gate_score.max() - gate_score.min() + 1e-8)

            scores = (attn_score + gate_score) / 2

        elif config.topk_mode == "tau_softmax":
            # Use tau as temperature for softmax selection
            attn_score = output.attn_m2c.mean(dim=1).sum(dim=-1)  # [batch, seq_clip]

            # Apply tau temperature scaling (guard against a zero temperature)
            tau_value = max(output.tau.mean().item() * config.tau_temperature, 1e-8)
            scores = torch.nn.functional.softmax(attn_score / tau_value, dim=-1)
        else:
            scores = output.gate.squeeze(-1)

        # Calculate k
        k = int(scores.size(-1) * (config.topk_percentage / 100.0))
        k = max(1, min(k, scores.size(-1)))

        # Get top-k indices
        topk_values, topk_indices = torch.topk(scores, k, dim=-1)

        # Create sparse mask
        mask = torch.zeros_like(scores)
        mask.scatter_(-1, topk_indices, 1.0)

        return mask, scores
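
    # Mask semantics, as a standalone sketch (sizes are assumptions): with
    # scores of shape [1, 77] and topk_percentage=25.0, k = 19 positions get 1.0.
    #   scores = torch.rand(1, 77)
    #   k = max(1, int(77 * 0.25))
    #   mask = torch.zeros_like(scores).scatter_(-1, scores.topk(k, dim=-1).indices, 1.0)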

    @staticmethod
    def apply_modifications(clip_slice: torch.Tensor, outputs: List[AdapterOutput],
                            config: ShiftConfig) -> torch.Tensor:
        """Apply modifications based on config.pool_method"""
        torch.manual_seed(config.seed if config.seed >= 0 else torch.randint(0, 2**32, (1,)).item())

        modified = clip_slice.clone()
        if config.pool_method == "sequential":
            # Apply each adapter sequentially
            for output in outputs:
                modified = ConditioningShifter._apply_single(modified, output, config)
            return modified

        elif config.pool_method == "weighted_average":
            # Pool all adapters then apply once
            if len(outputs) == 1:
                return ConditioningShifter._apply_single(modified, outputs[0], config)

            pooled = ConditioningShifter._pool_outputs(outputs)
            return ConditioningShifter._apply_single(clip_slice, pooled, config)

        else:
            raise ValueError(f"Unknown pool_method: {config.pool_method}")

    @staticmethod
    def _apply_single(clip_slice: torch.Tensor, output: AdapterOutput,
                      config: ShiftConfig) -> torch.Tensor:
        """Apply a single adapter output with optional top-k selection"""

        # Apply top-k selection if enabled
        topk_mask, scores = ConditioningShifter.apply_topk_selection(output, config)

        # Preprocess (but remember delta already has gate!)
        delta = output.delta * config.delta_scale + config.delta_mean

        gate_scaled = output.gate * config.gate_probability
        gate_mask = (gate_scaled > config.gate_threshold).float()
        gate_masked = gate_scaled * gate_mask

        # Apply top-k mask to gate and delta
        if config.use_topk:
            # Expand mask to match dimensions
            topk_mask_expanded = topk_mask.unsqueeze(-1)
            gate_masked = gate_masked * topk_mask_expanded
            delta = delta * topk_mask_expanded

        # Apply overall strength
        delta_final = delta * config.strength

        # Apply based on anchor mode
        if config.use_anchor:
            # Blend original with anchor, then add delta
            blended = clip_slice * (1 - gate_masked) + output.anchor * gate_masked
            clip_modified = blended + delta_final
        else:
            # Simple additive
            clip_modified = clip_slice + delta_final

        # Apply noise
        if config.sigma_scale > 0 and config.noise_injection > 0:
            sigma = torch.exp(output.log_sigma * config.sigma_scale)
            clip_modified += torch.randn_like(clip_modified) * sigma * config.noise_injection
        elif config.noise_injection > 0:
            clip_modified += torch.randn_like(clip_modified) * config.noise_injection

        return clip_modified
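
    # Per-token form of the anchor branch above, for reference:
    #   out_t = clip_t * (1 - g_t) + anchor_t * g_t + delta_t * strength
    # where g_t is the probability-scaled, thresholded (and optionally
    # top-k-masked) gate value for token t.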

    @staticmethod
    def _pool_outputs(outputs: List[AdapterOutput]) -> AdapterOutput:
        """Pool multiple adapter outputs into one"""
        # Simple weighted average
        total_weight = len(outputs)

        pooled_anchor = sum(o.anchor for o in outputs) / total_weight
        pooled_delta = sum(o.delta for o in outputs) / total_weight
        pooled_log_sigma = sum(o.log_sigma for o in outputs) / total_weight

        # Handle tau with different head counts
        if all(o.tau is not None for o in outputs):
            # Take mean across heads for each adapter, then average
            tau_values = [o.tau.mean().item() for o in outputs]
            pooled_tau_value = sum(tau_values) / total_weight
            # Create scalar tensor on same device
            pooled_tau = torch.tensor(pooled_tau_value, device=outputs[0].tau.device)
        else:
            pooled_tau = None

        pooled_g_pred = sum(o.g_pred for o in outputs) / total_weight if outputs[0].g_pred is not None else None
        pooled_gate = sum(o.gate for o in outputs) / total_weight

        # Pool attention weights if available - handle different head counts
        pooled_attn_c2m = None
        pooled_attn_m2c = None
        if all(o.attn_c2m is not None for o in outputs):
            # First, average across heads for each adapter to get [batch, seq_c, seq_m]
            attn_c2m_list = []
            attn_m2c_list = []

            for o in outputs:
                # Average across heads dimension
                attn_c2m_avg = o.attn_c2m.mean(dim=1)  # [batch, seq_c, seq_m]
                attn_m2c_avg = o.attn_m2c.mean(dim=1)  # [batch, seq_m, seq_c]
                attn_c2m_list.append(attn_c2m_avg)
                attn_m2c_list.append(attn_m2c_avg)

            # Now average across adapters
            pooled_attn_c2m = sum(attn_c2m_list) / total_weight
            pooled_attn_m2c = sum(attn_m2c_list) / total_weight

            # Add back a dummy heads dimension for compatibility
            pooled_attn_c2m = pooled_attn_c2m.unsqueeze(1)  # [batch, 1, seq_c, seq_m]
            pooled_attn_m2c = pooled_attn_m2c.unsqueeze(1)  # [batch, 1, seq_m, seq_c]

        return AdapterOutput(
            anchor=pooled_anchor,
            delta=pooled_delta,
            log_sigma=pooled_log_sigma,
            tau=pooled_tau,
            g_pred=pooled_g_pred,
            gate=pooled_gate,
            adapter_type=outputs[0].adapter_type,
            slice_range=outputs[0].slice_range,
            attn_c2m=pooled_attn_c2m,
            attn_m2c=pooled_attn_m2c
        )
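
    # Head-count note: attention maps from adapters with different numbers of
    # heads are first head-averaged to [batch, seq, seq], averaged across
    # adapters, then given a singleton head axis so downstream top-k code can
    # keep calling .mean(dim=1) unchanged.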

    @staticmethod
    def conditioning_set_values(conditioning, values: Optional[Dict[str, Any]] = None, append: bool = False):
        """
        Set entries in each conditioning dict from the provided values.
        Adapted from ComfyUI's node_helpers.py.
        """
        values = values or {}
        c = []
        for t in conditioning:
            n = [t[0], t[1].copy()]
            for k in values:
                val = values[k]
                if append:
                    old_val = n[1].get(k, None)
                    if old_val is not None:
                        val = old_val + val

                n[1][k] = val
            c.append(n)

        return c
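
    # ComfyUI-style conditioning is a list of [tensor, dict] pairs. A usage
    # sketch (shapes are assumptions for illustration):
    #   cond = [[torch.randn(1, 77, 768), {"pooled_outputs": torch.randn(1, 1280)}]]
    #   cond = ConditioningShifter.conditioning_set_values(cond, {"strength": 0.5})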

    @staticmethod
    def conditioning_set_strength(conditioning, cond_strength: float, pool_strength: float = 1.0):
        """
        Scale conditioning strength by modifying tensors directly, instead of
        setting dict values. Expected conditioning layout:
            [ [base_tensor, {"pooled_outputs": pool, ...other dict entries}], ... ]
        """
        c = []
        for t in conditioning:
            base_tensor = t[0].clone()  # tensors are cloned, dicts are copied
            # Scale the base conditioning, then check for pooled outputs
            base_tensor *= cond_strength
            kwarg_dict = t[1].copy() if t[1] is not None else {}  # shallow-copy the config params for later use

            # Get and remove the pooled outputs if they exist
            pooled: Optional[torch.Tensor] = kwarg_dict.get("pooled_outputs", None)
            if pooled is not None:
                del kwarg_dict["pooled_outputs"]
                pooled = pooled.clone()
                # If we have pooled outputs, apply the pooled strength
                pooled *= pool_strength
                kwarg_dict["pooled_outputs"] = pooled

            c.append([base_tensor, kwarg_dict])

        return c