AbstractPhil committed on
Commit 788f431 · verified · 1 Parent(s): 1fe00bb

Update two_stream_shunt_adapter.py

Files changed (1)
  1. two_stream_shunt_adapter.py +376 -81

two_stream_shunt_adapter.py CHANGED
@@ -1,115 +1,410 @@
-# adapter_v2.py ────────────────────────────────────────────────────────────
-import torch, math
 import torch.nn as nn
-import torch.nn.functional as F
-
-
-# ─── Residual pocket block ────────────────────────────────────────────────
-class PocketBlock(nn.Module):
     def __init__(self, dim, kernel=3, dropout=0.0):
         super().__init__()
-        self.body = nn.Sequential(
-            nn.LayerNorm(dim),
-            nn.Conv1d(dim, dim, kernel, padding=kernel // 2, groups=1),
             nn.GELU(),
-            nn.Conv1d(dim, dim, kernel, padding=kernel // 2, groups=1),
-            nn.Dropout(dropout),
         )
 
     def forward(self, x):
-        y = self.body(x.transpose(1, 2)).transpose(1, 2)
-        return x + y
 
 
-# ─── adapter ──────────────────────────────────────────────────────────────
-class TwoStreamShuntAdapter(nn.Module):
-    """T5-seq ➔ bottleneck ⇄ CLIP-seq → anchor / delta / σ …"""
 
-    def __init__(self, cfg: dict):
         super().__init__()
-        self.cfg = cfg
-        hid_t5 = cfg["t5"]["hidden_size"]
-        hid_clip = cfg["clip"]["hidden_size"]
-        bneck = cfg["bottleneck"]
-        heads = cfg["heads"]
-        proj_layers = cfg.get("proj_layers", 2)
-        use_norm = cfg.get("layer_norm", True)
-        p_drop = cfg.get("dropout", 0.0)
-        pocket_depth = cfg.get("pocket_depth", 2)
-
-        # helper ----------------------------------------------------------------
-        def proj(in_d, out_d):
-            layers, d = [], in_d
-            for i in range(proj_layers):
-                if use_norm:
-                    layers.append(nn.LayerNorm(d))
-                layers += [nn.Linear(d, bneck if i == proj_layers - 1 else bneck * 2),
-                           nn.GELU()]
-                if p_drop: layers.append(nn.Dropout(p_drop))
-                d = bneck
             return nn.Sequential(*layers)
 
-        # projections -----------------------------------------------------------
-        self.t5_in = proj(hid_t5, bneck)
-        self.clip_in = proj(hid_clip, bneck)
 
-        # bidirectional cross-attention ----------------------------------------
-        self.attn_t2c = nn.MultiheadAttention(bneck, heads, batch_first=True, dropout=p_drop)
-        self.attn_c2t = nn.MultiheadAttention(bneck, heads, batch_first=True, dropout=p_drop)
-        self.tau = nn.Parameter(torch.full((heads, 1, 1), cfg.get("tau_init", 1.0)))
 
-        # pocket stack ----------------------------------------------------------
-        self.pocket = nn.Sequential(*[PocketBlock(bneck, dropout=p_drop) for _ in range(pocket_depth)])
 
-        # fuse bottleneck → bneck ----------------------------------------------
         self.fuse = nn.Sequential(
-            nn.LayerNorm(bneck * 2),
-            nn.Linear(bneck * 2, bneck * 2),
             nn.GELU(),
-            nn.Linear(bneck * 2, bneck)
         )
 
-        # head projections ------------------------------------------------------
-        self.anchor_out = proj(bneck, hid_clip)
-        self.delta_out = proj(bneck, hid_clip)
-        self.sigma_out = proj(bneck, hid_clip)  # log σ
 
-        self.gate_guid_proj = nn.Sequential(
-            nn.LayerNorm(bneck),
-            nn.Linear(bneck, bneck),
             nn.GELU(),
-            nn.Linear(bneck, 2),  # [:, :, 0] → gate, [:, :, 1] → g_pred
         )
 
-        self.max_guidance = cfg.get("max_guidance", 2.0)
 
-    # --- forward --------------------------------------------------------------
-    def forward(self, t5_seq: torch.Tensor, clip_seq: torch.Tensor):
-        assert t5_seq.size(-1) == self.cfg["t5"]["hidden_size"]
-        assert clip_seq.size(-1) == self.cfg["clip"]["hidden_size"]
 
-        t5_b = self.t5_in(t5_seq)
-        clip_b = self.clip_in(clip_seq)
 
-        t2c, attn_t2c = self.attn_t2c(t5_b, clip_b, clip_b, need_weights=True, average_attn_weights=False)
-        c2t, attn_c2t = self.attn_c2t(clip_b, t5_b, t5_b, need_weights=True, average_attn_weights=False)
 
-        p = self.pocket(t2c)
-        z = torch.cat([p.mean(1, keepdim=True).expand_as(c2t), c2t], dim=-1)
-        h = self.fuse(z)
 
-        anchor = self.anchor_out(h)
-        delta = self.delta_out(h)
 
-        log_sigma = self.sigma_out(h)
 
-        gate_and_g = self.gate_guid_proj(h)
-        gate = torch.sigmoid(gate_and_g[..., 0:1])
-        g_pred = torch.clamp(gate_and_g[..., 1:2].mean(1, keepdim=True),
-                             0, self.max_guidance)
 
-        return (anchor, delta, log_sigma,
-                attn_t2c, attn_c2t,
-                self.tau,
-                g_pred,
-                gate)
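
Worth noting why the projection helper above was rewritten: the legacy proj never uses its out_d argument, and its width bookkeeping drifts (d is reset to bneck even after a layer that just emitted bneck * 2), so with the default proj_layers=2 the stacked nn.Linear widths do not chain. A minimal sketch of the mismatch and of the v2 fix below (dimensions illustrative):

    import torch
    import torch.nn as nn

    bneck = 256

    # Legacy pattern for proj_layers=2: Linear(768, 2*bneck) feeds a
    # Linear(bneck, bneck), so a 512-wide activation hits a 256-wide layer.
    legacy = nn.Sequential(nn.Linear(768, bneck * 2), nn.GELU(), nn.Linear(bneck, bneck))
    try:
        legacy(torch.randn(1, 4, 768))
    except RuntimeError as err:
        print("legacy proj fails:", err)

    # v2 build_projection tracks last_dim and ends with Linear(last_dim, output_dim),
    # so widths chain and the requested output size is honored.
    fixed = nn.Sequential(nn.Linear(768, bneck * 2), nn.GELU(), nn.Linear(bneck * 2, bneck))
    print(fixed(torch.randn(1, 4, 768)).shape)  # torch.Size([1, 4, 256])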
 
+from typing import Tuple
+
+import torch
 import torch.nn as nn
+from .configs import ENCODER_CONFIGS, HARMONIC_SHUNT_REPOS
+
+
+class DualConversionNames:
+    """
+    Mapping from legacy dual adapter layer names to the updated
+    condition/modulation schema. Also supports delta/gate harmonization.
+    """
+    LAYER_NAMES = {
+        # Projection remapping
+        "t5_proj": "condition_projection",
+        "clip_proj": "modulation_projection",
+
+        # Cross attention
+        "cross_t2c": "cross_c2m",  # condition to modulation
+        "cross_c2t": "cross_m2c",  # modulation to condition
+
+        # Output projections
+        "anchor_proj": "anchor_projection",
+        "delta_proj": "delta_projection",
+        "logsig_proj": "log_sigma_projection",
 
+        # Gate and guidance
+        "gate_proj": "gate_projection",
+        "guidance_proj": "guidance_projection",
 
+        # Fuse block
+        "fuse": "fusion_block",
+
+        # Pocket residual
+        "pocket_blocks": "residual_pocket_block"
+    }
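
These names drive the substring-based key migration in load_converted_safetensors further down. A quick sketch of how one legacy checkpoint key is rewritten, given the mapping above (the key itself is illustrative):

    key = "cross_t2c.in_proj_weight"  # hypothetical v1 checkpoint key

    # The same substring replacement load_converted_safetensors applies.
    for old, new in DualConversionNames.LAYER_NAMES.items():
        if old in key:
            key = key.replace(old, new)
            break

    print(key)  # cross_c2m.in_proj_weight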


+# ─── Residual Pocket Block ───────────────────────────────────
+class BottleneckResBlock(nn.Module):
     def __init__(self, dim, kernel=3, dropout=0.0):
         super().__init__()
+        self.norm = nn.LayerNorm(dim)
+        self.conv = nn.Conv1d(dim, dim, kernel_size=kernel, padding=kernel // 2, groups=1)
+        self.proj = nn.Sequential(
+            nn.Linear(dim, dim * 2),
             nn.GELU(),
+            nn.Linear(dim * 2, dim),
+            nn.Dropout(dropout)
         )
 
     def forward(self, x):
+        residual = x
+        x = self.norm(x)
+        x = x.transpose(1, 2)
+        x = self.conv(x).transpose(1, 2)
+        return residual + self.proj(x)
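
A shape walk-through of the block as a sanity check (batch and sequence sizes arbitrary): LayerNorm acts on the feature dim, Conv1d wants channels first, and the two-layer MLP restores dim before the residual add, so the block is shape-preserving:

    import torch

    block = BottleneckResBlock(dim=256, kernel=3)
    x = torch.randn(2, 77, 256)   # [B, T, D]
    # norm [2, 77, 256] -> transpose [2, 256, 77] -> conv keeps length
    # (padding = kernel // 2) -> transpose back -> proj [2, 77, 256]
    print(block(x).shape)         # torch.Size([2, 77, 256])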


+class ConditionModulationShuntAdapter(nn.Module):
+    def __init__(self, config: dict):
+        super().__init__()
+        self.config = config
+        self.dtype = config.get("dtype", torch.float32)
+        # Assumes each encoder list has at least one entry.
+        self.condition_dim = config.get("condition_encoders", [])[0].get("hidden_size", 768)
+        self.modulation_dim = config.get("modulation_encoders", [])[0].get("hidden_size", 768)
+        self.bneck = config["bottleneck"]
+        self.heads = config["heads"]
+        self.tau_init = config["tau_init"]
+        self.max_guidance = config["max_guidance"]
+
+        use_norm = config.get("layer_norm", True)
+        use_do = config.get("use_dropout", True)
+        do_p = config.get("dropout", 0.0)
+        proj_depth = config.get("proj_layers", 2)
+
+        def build_projection(input_dim, output_dim):
+            layers = []
+            last_dim = input_dim
+            if use_norm:
+                layers.append(nn.LayerNorm(last_dim))
+            for i in range(proj_depth):
+                next_dim = self.bneck * (2 if i == 0 and proj_depth > 1 else 1)
+                layers.append(nn.Linear(last_dim, next_dim))
+                layers.append(nn.GELU())
+                if use_do:
+                    layers.append(nn.Dropout(do_p))
+                last_dim = next_dim
+            layers.append(nn.Linear(last_dim, output_dim))
+            return nn.Sequential(*layers)
+
+        # Projection layers
+        self.condition_projection = build_projection(self.condition_dim, self.bneck)
+        self.modulation_projection = build_projection(self.modulation_dim, self.bneck)
+
+        # Cross attention blocks
+        self.cross_c2m = nn.MultiheadAttention(self.bneck, self.heads, batch_first=True, dropout=do_p)
+        self.cross_m2c = nn.MultiheadAttention(self.bneck, self.heads, batch_first=True, dropout=do_p)
+        self.tau = nn.Parameter(torch.full((self.heads, 1, 1), self.tau_init))
+
+        # Residual processing block
+        self.residual_pocket_block = nn.Sequential(
+            BottleneckResBlock(self.bneck, dropout=do_p),
+            BottleneckResBlock(self.bneck, dropout=do_p)
+        )
+
+        # Fusion pathway
+        self.fusion_block = nn.Sequential(
+            nn.LayerNorm(2 * self.bneck),
+            nn.Linear(2 * self.bneck, self.bneck * 2),
+            nn.GELU(),
+            nn.Linear(self.bneck * 2, self.bneck)
+        )
+
+        # Output projections
+        self.anchor_projection = build_projection(self.bneck, self.modulation_dim)
+        self.delta_projection = build_projection(self.bneck, self.modulation_dim)
+        self.log_sigma_projection = build_projection(self.bneck, self.modulation_dim)
+
+        # Gate and guidance
+        self.gate_projection = nn.Sequential(
+            nn.LayerNorm(self.bneck),
+            nn.Linear(self.bneck, self.bneck),
+            nn.GELU(),
+            nn.Linear(self.bneck, 1),
+            nn.Tanh(),     # Tanh then Sigmoid confines the gate to roughly (0.27, 0.73)
+            nn.Sigmoid()
+        )
+        self.guidance_projection = nn.Sequential(
+            nn.LayerNorm(self.bneck),
+            nn.Linear(self.bneck, 1),
+            nn.Sigmoid()
+        )
+
+        # ─── Legacy Aliases (Version 1 Compatibility) ──────────────────────────
+        self.proj_t5 = self.condition_projection
+        self.proj_clip = self.modulation_projection
+        self.cross_t2c = self.cross_c2m
+        self.cross_c2t = self.cross_m2c
+        self.pocket_blocks = self.residual_pocket_block
+        self.fuse = self.fusion_block
+        self.anchor_proj = self.anchor_projection
+        self.delta_proj = self.delta_projection
+        self.logsig_proj = self.log_sigma_projection
+        self.gate_proj = self.gate_projection
+        self.guidance_proj = self.guidance_projection
+
+    def forward(self, cond_seq: torch.Tensor, mod_seq: torch.Tensor, config: dict = None):
+        if self.config.get("assert_input_dims", True):
+            assert cond_seq.size(-1) == self.condition_dim
+            assert mod_seq.size(-1) == self.modulation_dim
+
+        # Guidance ceiling: per-call config wins, then the constructor value,
+        # then guidance_scale (guarded so a None config cannot be dereferenced).
+        max_guidance = self.max_guidance if config is None else config.get("max_guidance", 0.0)
+        if max_guidance <= 0:
+            max_guidance = self.max_guidance
+        if max_guidance <= 0:
+            max_guidance = config.get("guidance_scale", 10.0) if config is not None else 10.0
+
+        cond_b = self.condition_projection(cond_seq)
+        mod_b = self.modulation_projection(mod_seq)
+
+        c2m, attn_c2m = self.cross_c2m(cond_b, mod_b, mod_b, need_weights=True, average_attn_weights=False)
+        m2c, attn_m2c = self.cross_m2c(mod_b, cond_b, cond_b, need_weights=True, average_attn_weights=False)
+
+        pocket = self.residual_pocket_block(c2m)
+        pocket_mean = pocket.mean(1, keepdim=True).expand(-1, mod_b.size(1), -1)
+
+        h = self.fusion_block(torch.cat([pocket_mean, m2c], dim=-1))
+
+        gate = self.gate_projection(h)   # computed once and reused below
+        anchor = self.anchor_projection(h)
+        delta = self.delta_projection(h) * gate
+        log_sigma = self.log_sigma_projection(h)
+
+        g_tok = self.guidance_projection(h).squeeze(-1)
+        g_pred = g_tok.mean(1, keepdim=True) * max_guidance
+
+        return anchor, delta, log_sigma, attn_c2m, attn_m2c, self.tau, g_pred, gate
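
A minimal smoke test sketching the expected config shape and the eight-tuple the forward pass returns (all sizes illustrative; tau_init and max_guidance are required keys, and heads must divide bottleneck). With the default proj_layers=2, each build_projection stack runs input -> 2*bneck -> bneck -> output:

    import torch

    config = {
        "condition_encoders": [{"hidden_size": 768}],    # e.g. a T5-class encoder
        "modulation_encoders": [{"hidden_size": 1024}],  # e.g. a CLIP-class encoder
        "bottleneck": 256,
        "heads": 8,
        "tau_init": 1.0,
        "max_guidance": 2.0,
    }
    adapter = ConditionModulationShuntAdapter(config)

    cond = torch.randn(2, 64, 768)   # [B, T_cond, condition_dim]
    mod = torch.randn(2, 77, 1024)   # [B, T_mod, modulation_dim]
    anchor, delta, log_sigma, attn_c2m, attn_m2c, tau, g_pred, gate = adapter(cond, mod)
    print(anchor.shape, g_pred.shape)  # torch.Size([2, 77, 1024]) torch.Size([2, 1])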


+# ─── V1 Original Two Stream Shunt Adapter ──────────────────────────────────────
+class TwoStreamShuntAdapter(nn.Module):
+    def __init__(self, config: dict):
         super().__init__()
+        self.config = config
+        self.dtype = config.get("dtype", torch.float32)
+        self.t5_dim = config.get("condition_encoders", [])[0].get("hidden_size", 768)
+        self.clip_dim = config.get("modulation_encoders", [])[0].get("hidden_size", 768)
+        self.bneck = config["bottleneck"]
+        self.heads = config["heads"]
+        self.tau_init = config["tau_init"]
+        self.max_guidance = config["max_guidance"]
+
+        use_norm = config.get("layer_norm", True)
+        use_do = config.get("use_dropout", True)
+        do_p = config.get("dropout", 0.0)
+        proj_depth = config.get("proj_layers", 2)
+
+        def build_projection(input_dim, output_dim):
+            layers = []
+            last_dim = input_dim
+            if use_norm:
+                layers.append(nn.LayerNorm(last_dim))
+            for i in range(proj_depth):
+                next_dim = self.bneck * (2 if i == 0 and proj_depth > 1 else 1)
+                layers.append(nn.Linear(last_dim, next_dim))
+                layers.append(nn.GELU())
+                if use_do:
+                    layers.append(nn.Dropout(do_p))
+                last_dim = next_dim
+            layers.append(nn.Linear(last_dim, output_dim))
             return nn.Sequential(*layers)
 
+        # Projections
+        self.proj_t5 = build_projection(self.t5_dim, self.bneck)
+        self.proj_clip = build_projection(self.clip_dim, self.bneck)
 
+        # Attention
+        self.cross_t2c = nn.MultiheadAttention(self.bneck, self.heads, batch_first=True, dropout=do_p)
+        self.cross_c2t = nn.MultiheadAttention(self.bneck, self.heads, batch_first=True, dropout=do_p)
+        self.tau = nn.Parameter(torch.full((self.heads, 1, 1), self.tau_init))
 
+        # Residual Pocket
+        self.pocket_blocks = nn.Sequential(
+            BottleneckResBlock(self.bneck, dropout=do_p),
+            BottleneckResBlock(self.bneck, dropout=do_p)
+        )
 
+        # Fuse
         self.fuse = nn.Sequential(
+            nn.LayerNorm(2 * self.bneck),
+            nn.Linear(2 * self.bneck, self.bneck * 2),
             nn.GELU(),
+            nn.Linear(self.bneck * 2, self.bneck)
         )
 
+        # Output Projections
+        self.anchor_proj = build_projection(self.bneck, self.clip_dim)
+        self.delta_proj = build_projection(self.bneck, self.clip_dim)
+        self.logsig_proj = build_projection(self.bneck, self.clip_dim)
 
+        self.gate_proj = nn.Sequential(
+            nn.LayerNorm(self.bneck),
+            nn.Linear(self.bneck, self.bneck),
             nn.GELU(),
+            nn.Linear(self.bneck, 1),
+            nn.Tanh(),     # same Tanh -> Sigmoid stack as the v2 gate
+            nn.Sigmoid()
         )
 
+        self.guidance_proj = nn.Sequential(
+            nn.LayerNorm(self.bneck),
+            nn.Linear(self.bneck, 1),
+            nn.Sigmoid()
+        )
+
+    def forward(self, t5_seq: torch.Tensor, clip_seq: torch.Tensor, config: dict = None):
+        if self.config.get("assert_input_dims", True):
+            assert t5_seq.size(-1) == self.t5_dim
+            assert clip_seq.size(-1) == self.clip_dim
+
+        # Guidance ceiling resolved as in the v2 forward, with a guarded
+        # guidance_scale fallback so a None config cannot be dereferenced.
+        max_guidance = self.max_guidance if config is None else config.get("max_guidance", 0.0)
+        if max_guidance <= 0:
+            max_guidance = self.max_guidance
+        if max_guidance <= 0:
+            max_guidance = config.get("guidance_scale", 5.0) if config is not None else 5.0
+
+        t5_b = self.proj_t5(t5_seq)
+        clip_b = self.proj_clip(clip_seq)
+
+        t2c, attn_t2c = self.cross_t2c(t5_b, clip_b, clip_b, need_weights=True, average_attn_weights=False)
+        c2t, attn_c2t = self.cross_c2t(clip_b, t5_b, t5_b, need_weights=True, average_attn_weights=False)
+
+        pocket = self.pocket_blocks(t2c)
+
+        pocket_mean = pocket.mean(1, keepdim=True).expand(-1, clip_b.size(1), -1)
+        h = self.fuse(torch.cat([pocket_mean, c2t], dim=-1))
+
+        gate = self.gate_proj(h)   # computed once and reused below
+        anchor = self.anchor_proj(h)
+        delta = self.delta_proj(h) * gate
+        log_sigma = self.logsig_proj(h)
+
+        g_tok = self.guidance_proj(h).squeeze(-1)
+        g_pred = g_tok.mean(1, keepdim=True) * max_guidance
+
+        return anchor, delta, log_sigma, attn_t2c, attn_c2t, self.tau, g_pred, gate
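
Two numeric consequences of the heads shared by both classes, sketched for clarity: the Tanh -> Sigmoid gate stack can only emit values in (sigmoid(-1), sigmoid(1)), roughly (0.27, 0.73), so delta is never fully open or fully shut, and g_pred stays inside (0, max_guidance) because per-token guidance is squashed by a sigmoid before the mean:

    import torch

    # Gate range: sigmoid(tanh(x)) saturates well inside (0, 1).
    x = torch.tensor([-1e6, 0.0, 1e6])
    print(torch.sigmoid(torch.tanh(x)))  # ~ [0.2689, 0.5000, 0.7311]

    # Guidance range: sigmoid token scores averaged, then scaled.
    max_guidance = 2.0
    g_tok = torch.sigmoid(torch.randn(2, 77))
    g_pred = g_tok.mean(1, keepdim=True) * max_guidance
    assert (g_pred > 0).all() and (g_pred < max_guidance).all()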


+from safetensors.torch import save_file, load_file
+
+
+def save_safetensors(adapter: nn.Module, path: str, metadata: dict = None):
+    """
+    Save the current adapter state to safetensors format.
+    All tensors are moved to CPU and saved as float32 for compatibility.
+    Optional metadata may be embedded (e.g., version, prompt_mode).
+    """
+    state = {k: v.float().cpu() for k, v in adapter.state_dict().items()}
+    save_file(state, path, metadata=metadata or {})
+    print(f"✅ Model saved to {path}")
+
+
+def load_safetensors(adapter: nn.Module, path: str, map_location="cpu"):
+    """
+    Load a safetensors checkpoint into the adapter.
+    Uses strict key matching. Tensors are loaded to the specified device.
+    """
+    state = load_file(path, device=map_location)
+    adapter.load_state_dict(state, strict=True)
+    print(f"✅ Model loaded from {path}")


+def load_converted_safetensors(adapter: nn.Module, path: str, map_location="cpu"):
+    """
+    Load a legacy-format adapter into the updated dual-shunt schema.
+    Converts key names according to the DualConversionNames mapping.
+    """
+    state = load_file(path, device=map_location)
+    new_state = {}
+
+    rename_map = DualConversionNames.LAYER_NAMES
+    matched, renamed, skipped = 0, 0, 0
+
+    for key, tensor in state.items():
+        found = False
+        for old, new in rename_map.items():
+            if old in key:
+                new_key = key.replace(old, new)
+                new_state[new_key] = tensor
+                print(f"[MIGRATE] {key} -> {new_key}")
+                renamed += 1
+                found = True
+                break
+        if not found:
+            if key in adapter.state_dict():
+                new_state[key] = tensor
+                matched += 1
+            else:
+                print(f"[SKIP] {key} not found in target adapter.")
+                skipped += 1
+
+    adapter.load_state_dict(new_state, strict=False)
+
+    print(f"\n✅ Converted model loaded from {path}")
+    print(f"   🔁 Renamed Keys:   {renamed}")
+    print(f"   ✅ Direct Matches: {matched}")
+    print(f"   ⚠️ Skipped Keys:   {skipped}")


+def reshape_for_shunt(
+    encoder_embeddings: torch.Tensor,
+    clip_slice: torch.Tensor,
+    adapter_model
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Ensures encoder_embeddings and clip_slice match the required dimensions
+    for adapter_model: [B, adapter_seq, adapter_dim].
+
+    Applies sequence interpolation and feature projection as needed.
+    """
+    # NOTE: this early return bypasses all of the reshape logic below,
+    # so the function currently passes both tensors through unchanged.
+    return encoder_embeddings, clip_slice
+
+    B, encoder_seq, encoder_dim = encoder_embeddings.shape
+    B2, clip_seq, clip_dim = clip_slice.shape
+
+    assert B == B2, "Batch sizes must match"
+
+    # -- Step 1: Interpolate SEQUENCE LENGTH (dim=1) if needed --
+    target_seq = max(adapter_model.condition_dim, adapter_model.modulation_dim)
+
+    if clip_seq != target_seq:
+        clip_slice = clip_slice.permute(0, 2, 1)  # [B, T, C] -> [B, C, T]
+        clip_slice = torch.nn.functional.interpolate(
+            clip_slice.float(),
+            size=target_seq,
+            mode="nearest"
+        )
+        clip_slice = clip_slice.permute(0, 2, 1)  # back to [B, T, C]
+
+    if encoder_seq != target_seq:
+        encoder_embeddings = encoder_embeddings.permute(0, 2, 1)
+        encoder_embeddings = torch.nn.functional.interpolate(
+            encoder_embeddings.float(),
+            size=target_seq,
+            mode="nearest"
+        )
+        encoder_embeddings = encoder_embeddings.permute(0, 2, 1)
+
+    # -- Step 2: Project FEATURE DIMENSION (dim=2) if needed --
+    if clip_slice.size(-1) != adapter_model.condition_dim:
+        # Freshly initialized (untrained) projection, created on the fly.
+        projection_clip = torch.nn.Linear(
+            clip_slice.size(-1),
+            adapter_model.condition_dim,
+            bias=True,
+            device=clip_slice.device
+        )
+        clip_slice = projection_clip(clip_slice)
+        del projection_clip
+
+    if encoder_embeddings.size(-1) != adapter_model.modulation_dim:
+        projection_encoder = torch.nn.Linear(
+            encoder_embeddings.size(-1),
+            adapter_model.modulation_dim,
+            bias=True,
+            device=encoder_embeddings.device
+        )
+        encoder_embeddings = projection_encoder(encoder_embeddings)
+        del projection_encoder
+
+    return encoder_embeddings, clip_slice
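
For reference, the resampling idiom the dormant Step 1 relies on, shown standalone (sizes illustrative): F.interpolate resamples the trailing dimension, hence the permutes around it. Note that target_seq above is derived from the feature dimensions, so enabling the path would stretch sequences to the larger hidden size:

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 77, 768)   # [B, T, C]
    target_seq = 64               # illustrative target length

    x = x.permute(0, 2, 1)        # [B, C, T]; interpolate works on the last dim
    x = F.interpolate(x.float(), size=target_seq, mode="nearest")
    x = x.permute(0, 2, 1)        # back to [B, T, C]
    print(x.shape)                # torch.Size([2, 64, 768])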