AbstractPhil committed
Commit d3479d5 · 1 Parent(s): c557c56

local project created to properly edit and debug

__init__.py ADDED
File without changes
app.py CHANGED
@@ -1,56 +1,49 @@
1
  # app.py ────────────────────────────────────────────────────────────────
2
  import io, warnings, numpy as np, matplotlib.pyplot as plt
 
 
3
  from pathlib import Path
4
- from typing import Dict, List, Optional, Tuple
5
 
6
  import gradio as gr
7
- import torch, torch.nn.functional as F
8
- from PIL import Image
 
9
  from transformers import T5Tokenizer, T5EncoderModel
10
- from diffusers import (
11
- StableDiffusionXLPipeline,
12
- DDIMScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler,
13
- )
14
  from huggingface_hub import hf_hub_download
15
  from safetensors.torch import load_file
16
 
17
- # local modules
18
  from two_stream_shunt_adapter import TwoStreamShuntAdapter
19
  from conditioning_shifter import ConditioningShifter, ShiftConfig, AdapterOutput
20
- from configs import T5_SHUNT_REPOS
21
 
22
  warnings.filterwarnings("ignore")
23
 
24
- # ─── GLOBALS ────────────────────────────────────────────────────────────
25
- dtype = torch.float16
26
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
 
28
-
29
- _t5_tok: Optional[T5Tokenizer] = None
30
- _t5_mod: Optional[T5EncoderModel] = None
31
- _pipe : Optional[StableDiffusionXLPipeline] = None
32
 
33
  SCHEDULERS = {
34
  "DPM++ 2M": DPMSolverMultistepScheduler,
35
- "DDIM": DDIMScheduler,
36
- "Euler": EulerDiscreteScheduler,
37
  }
38
 
39
- # adapter-meta from configs.py
40
- clip_l_opts = T5_SHUNT_REPOS["clip_l"]["shunts_available"]["shunt_list"]
41
- clip_g_opts = T5_SHUNT_REPOS["clip_g"]["shunts_available"]["shunt_list"]
42
- repo_l, conf_l = T5_SHUNT_REPOS["clip_l"]["repo"], T5_SHUNT_REPOS["clip_l"]["config"]
43
- repo_g, conf_g = T5_SHUNT_REPOS["clip_g"]["repo"], T5_SHUNT_REPOS["clip_g"]["config"]
44
 
45
-
46
- # ─── INITIALISERS ────────────────────────────────────────────────────────
47
  def _init_t5():
48
  global _t5_tok, _t5_mod
49
  if _t5_tok is None:
50
  _t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
51
- _t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base") \
52
- .to(device).eval()
53
-
54
 
55
  def _init_pipe():
56
  global _pipe
@@ -61,16 +54,15 @@ def _init_pipe():
61
  ).to(device)
62
  _pipe.enable_xformers_memory_efficient_attention()
63
 
64
-
65
- # ─── HELPERS ─────────────────────────────────────────────────────────────
66
- def load_adapter(repo: str, filename: str, cfg: dict,
67
- device: torch.device) -> TwoStreamShuntAdapter:
68
- path = hf_hub_download(repo_id=repo, filename=filename)
69
- model = TwoStreamShuntAdapter(cfg).eval()
70
  model.load_state_dict(load_file(path))
71
  return model.to(device)
72
 
73
-
74
  def plot_heat(mat: torch.Tensor | np.ndarray, title: str) -> np.ndarray:
75
  if isinstance(mat, torch.Tensor):
76
  mat = mat.detach().cpu().numpy()
@@ -90,34 +82,25 @@ def plot_heat(mat: torch.Tensor | np.ndarray, title: str) -> np.ndarray:
90
  plt.close(); buf.seek(0)
91
  return np.array(Image.open(buf))
92
 
93
-
94
  def encode_prompt_xl(pipe, prompt: str, negative: str) -> Dict[str, torch.Tensor]:
95
- tok_l = pipe.tokenizer (prompt, max_length=77, truncation=True,
96
- padding="max_length", return_tensors="pt").input_ids.to(device)
97
- tok_g = pipe.tokenizer_2(prompt, max_length=77, truncation=True,
98
- padding="max_length", return_tensors="pt").input_ids.to(device)
99
- ntok_l = pipe.tokenizer (negative,max_length=77, truncation=True,
100
- padding="max_length", return_tensors="pt").input_ids.to(device)
101
- ntok_g = pipe.tokenizer_2(negative,max_length=77, truncation=True,
102
- padding="max_length", return_tensors="pt").input_ids.to(device)
103
 
104
  with torch.no_grad():
105
- clip_l = pipe.text_encoder(tok_l)[0]
106
- neg_clip_l = pipe.text_encoder(ntok_l)[0]
107
-
108
- g_out = pipe.text_encoder_2(tok_g, output_hidden_states=False)
109
- clip_g, pl = g_out[1], g_out[0]
110
- ng_out = pipe.text_encoder_2(ntok_g, output_hidden_states=False)
111
  neg_clip_g, npl = ng_out[1], ng_out[0]
112
 
113
- return {"clip_l": clip_l, "clip_g": clip_g,
114
- "neg_l": neg_clip_l, "neg_g": neg_clip_g,
115
- "pooled": pl, "neg_pooled": npl}
116
-
117
 
118
  # ─── INFERENCE ───────────────────────────────────────────────────────────
119
  def infer(prompt: str, negative_prompt: str,
120
- adapter_l_file: str, adapter_g_file: str,
121
  strength: float, delta_scale: float, sigma_scale: float,
122
  gpred_scale: float, noise: float, gate_prob: float, use_anchor: bool,
123
  steps: int, cfg_scale: float, scheduler_name: str,
@@ -129,91 +112,76 @@ def infer(prompt: str, negative_prompt: str,
129
  if scheduler_name in SCHEDULERS:
130
  _pipe.scheduler = SCHEDULERS[scheduler_name].from_config(_pipe.scheduler.config)
131
 
132
- generator = (torch.Generator(device=device).manual_seed(seed)
133
- if seed != -1 else None)
134
 
135
- # build ShiftConfig (one per request)
136
  cfg_shift = ShiftConfig(
137
- prompt = prompt,
138
- seed = seed,
139
- strength = strength,
140
- delta_scale = delta_scale,
141
- sigma_scale = sigma_scale,
142
- gate_probability = gate_prob,
143
- noise_injection = noise,
144
- use_anchor = use_anchor,
145
- guidance_scale = gpred_scale,
146
  )
147
 
148
- # encoder (T5) embeddings
149
  t5_seq = ConditioningShifter.extract_encoder_embeddings(
150
  {"tokenizer": _t5_tok, "model": _t5_mod, "config": {"config": {}}},
151
  device, cfg_shift
152
  )
153
 
154
- # CLIP embeddings
155
  embeds = encode_prompt_xl(_pipe, prompt, negative_prompt)
156
-
157
- # run adapters --------------------------------------------------------
158
  outputs: List[AdapterOutput] = []
159
- if adapter_l_file and adapter_l_file != "None":
160
- ada_l = load_adapter(repo_l, adapter_l_file, conf_l, device)
 
161
  outputs.append(ConditioningShifter.run_adapter(
162
  ada_l, t5_seq, embeds["clip_l"],
163
  cfg_shift.guidance_scale, "clip_l", (0, 768)))
164
 
165
- if adapter_g_file and adapter_g_file != "None":
166
- ada_g = load_adapter(repo_g, adapter_g_file, conf_g, device)
167
  outputs.append(ConditioningShifter.run_adapter(
168
  ada_g, t5_seq, embeds["clip_g"],
169
  cfg_shift.guidance_scale, "clip_g", (768, 2048)))
170
 
171
- # apply modifications -------------------------------------------------
172
  clip_l_mod, clip_g_mod = embeds["clip_l"], embeds["clip_g"]
173
- delta_viz = {"clip_l": torch.zeros_like(clip_l_mod),
174
- "clip_g": torch.zeros_like(clip_g_mod)}
175
- gate_viz = {"clip_l": torch.zeros_like(clip_l_mod[..., :1]),
176
- "clip_g": torch.zeros_like(clip_g_mod[..., :1])}
177
 
178
  for out in outputs:
179
  target = clip_l_mod if out.adapter_type == "clip_l" else clip_g_mod
180
- mod = ConditioningShifter.apply_modifications(target, [out], cfg_shift)
181
- if out.adapter_type == "clip_l":
182
- clip_l_mod = mod
183
- else:
184
- clip_g_mod = mod
185
  delta_viz[out.adapter_type] = out.delta.detach()
186
- gate_viz [out.adapter_type] = out.gate.detach()
187
 
188
- # prepare for SDXL ----------------------------------------------------
189
  prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1)
190
- neg_embeds = torch.cat([embeds["neg_l"], embeds["neg_g"]], dim=-1)
191
 
192
  image = _pipe(
193
- prompt_embeds = prompt_embeds,
194
- negative_prompt_embeds = neg_embeds,
195
- pooled_prompt_embeds = embeds["pooled"],
196
- negative_pooled_prompt_embeds = embeds["neg_pooled"],
197
- num_inference_steps = steps,
198
- guidance_scale = cfg_scale,
199
- width = width, height = height, generator = generator
200
  ).images[0]
201
 
202
- # diagnostics ---------------------------------------------------------
203
  delta_l_img = plot_heat(delta_viz["clip_l"].squeeze(), "Δ CLIP-L")
204
- gate_l_img = plot_heat(gate_viz ["clip_l"].squeeze().mean(-1, keepdims=True), "Gate L")
205
  delta_g_img = plot_heat(delta_viz["clip_g"].squeeze(), "Δ CLIP-G")
206
- gate_g_img = plot_heat(gate_viz ["clip_g"].squeeze().mean(-1, keepdims=True), "Gate G")
207
 
208
- stats_l = (f"τ̄_L = {outputs[0].tau.mean().item():.3f}"
209
- if outputs and outputs[0].adapter_type == "clip_l" else "-")
210
- stats_g = (f"τ̄_G = {outputs[-1].tau.mean().item():.3f}"
211
- if len(outputs) > 1 and outputs[-1].adapter_type == "clip_g" else "-")
212
 
213
  return image, delta_l_img, gate_l_img, delta_g_img, gate_g_img, stats_l, stats_g
214
 
215
-
216
- # ─── GRADIO UI ────────────────────────────────────────────────────────────
217
  def create_interface():
218
  with gr.Blocks(title="SDXL Dual-Shunt Tester", theme=gr.themes.Soft()) as demo:
219
  gr.Markdown("# 🧠 SDXL Dual-Shunt Tester")
@@ -221,18 +189,12 @@ def create_interface():
221
  with gr.Row():
222
  with gr.Column(scale=1):
223
  gr.Markdown("### Prompts")
224
- prompt = gr.Textbox(label="Prompt", lines=3,
225
- value="a futuristic control station with holographic displays")
226
- negative = gr.Textbox(label="Negative", lines=2,
227
- value="blurry, low quality, distorted")
228
 
229
  gr.Markdown("### Adapters")
230
- adapter_l = gr.Dropdown(["None"] + clip_l_opts,
231
- value="t5-vit-l-14-dual_shunt_caption.safetensors",
232
- label="CLIP-L Adapter")
233
- adapter_g = gr.Dropdown(["None"] + clip_g_opts,
234
- value="dual_shunt_omega_no_caption_noised_e1_step_10000.safetensors",
235
- label="CLIP-G Adapter")
236
 
237
  gr.Markdown("### Adapter Controls")
238
  strength = gr.Slider(0, 10, 4.0, 0.05, label="Strength")
@@ -247,8 +209,7 @@ def create_interface():
247
  with gr.Row():
248
  steps = gr.Slider(1, 50, 20, 1, label="Steps")
249
  cfg_scale = gr.Slider(1, 15, 7.5, 0.1, label="CFG")
250
- scheduler = gr.Dropdown(list(SCHEDULERS.keys()),
251
- value="DPM++ 2M", label="Scheduler")
252
  with gr.Row():
253
  width = gr.Slider(512, 1536, 1024, 64, label="Width")
254
  height = gr.Slider(512, 1536, 1024, 64, label="Height")
@@ -266,13 +227,8 @@ def create_interface():
266
  stats_l = gr.Textbox(label="Stats L", interactive=False)
267
  stats_g = gr.Textbox(label="Stats G", interactive=False)
268
 
269
- def _run(*args):
270
- pl, npl = args[0], args[1]
271
- al, ag = (None if v == "None" else v for v in args[2:4])
272
- return infer(pl, npl, al, ag, *args[4:])
273
-
274
  run_btn.click(
275
- fn=_run,
276
  inputs=[prompt, negative, adapter_l, adapter_g, strength, delta_scale,
277
  sigma_scale, gpred_scale, noise, gate_prob, use_anchor, steps,
278
  cfg_scale, scheduler, width, height, seed],
@@ -280,6 +236,5 @@ def create_interface():
280
  )
281
  return demo
282
 
283
-
284
  if __name__ == "__main__":
285
  create_interface().launch()
 
1
  # app.py ────────────────────────────────────────────────────────────────
2
  import io, warnings, numpy as np, matplotlib.pyplot as plt
3
+ from typing import Dict, List, Optional
4
+ from PIL import Image
5
  from pathlib import Path
 
6
 
7
  import gradio as gr
8
+ import torch
9
+ import torch.nn.functional as F
10
+
11
  from transformers import T5Tokenizer, T5EncoderModel
12
+ from diffusers import StableDiffusionXLPipeline, DDIMScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler
13
  from huggingface_hub import hf_hub_download
14
  from safetensors.torch import load_file
15
 
 
16
  from two_stream_shunt_adapter import TwoStreamShuntAdapter
17
  from conditioning_shifter import ConditioningShifter, ShiftConfig, AdapterOutput
18
+ from configs import ShuntUtil
19
 
20
  warnings.filterwarnings("ignore")
21
 
22
+ # ─── GLOBALS ─────────────────────────────────────────────────────────────
23
+ dtype = torch.float16
24
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
 
26
+ _t5_tok: Optional[T5Tokenizer] = None
27
+ _t5_mod: Optional[T5EncoderModel] = None
28
+ _pipe: Optional[StableDiffusionXLPipeline] = None
 
29
 
30
  SCHEDULERS = {
31
  "DPM++ 2M": DPMSolverMultistepScheduler,
32
+ "DDIM": DDIMScheduler,
33
+ "Euler": EulerDiscreteScheduler,
34
  }
35
 
36
+ clip_l_shunts = ShuntUtil.get_shunts_by_clip_type("clip_l")
37
+ clip_g_shunts = ShuntUtil.get_shunts_by_clip_type("clip_g")
38
+ clip_l_opts = ["None"] + [s.name for s in clip_l_shunts]
39
+ clip_g_opts = ["None"] + [s.name for s in clip_g_shunts]
 
40
 
41
+ # ─── INIT ───────────────────────────────────────────────────────────────
 
42
  def _init_t5():
43
  global _t5_tok, _t5_mod
44
  if _t5_tok is None:
45
  _t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
46
+ _t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()
 
 
47
 
48
  def _init_pipe():
49
  global _pipe
 
54
  ).to(device)
55
  _pipe.enable_xformers_memory_efficient_attention()
56
 
57
+ # ─── UTILITY ────────────────────────────────────────────────────────────
58
+ def load_adapter_by_name(name: str, device: torch.device) -> TwoStreamShuntAdapter:
59
+ shunt = ShuntUtil.get_shunt_by_name(name)
60
+ assert shunt, f"Shunt '{name}' not found."
61
+ path = hf_hub_download(repo_id=shunt.repo, filename=shunt.file)
62
+ model = TwoStreamShuntAdapter(shunt.config).eval()
63
  model.load_state_dict(load_file(path))
64
  return model.to(device)
65
 
 
66
  def plot_heat(mat: torch.Tensor | np.ndarray, title: str) -> np.ndarray:
67
  if isinstance(mat, torch.Tensor):
68
  mat = mat.detach().cpu().numpy()
 
82
  plt.close(); buf.seek(0)
83
  return np.array(Image.open(buf))
84
 
 
85
  def encode_prompt_xl(pipe, prompt: str, negative: str) -> Dict[str, torch.Tensor]:
86
+ tok_l = pipe.tokenizer(prompt, max_length=77, truncation=True, padding="max_length", return_tensors="pt").input_ids.to(device)
87
+ tok_g = pipe.tokenizer_2(prompt, max_length=77, truncation=True, padding="max_length", return_tensors="pt").input_ids.to(device)
88
+ ntok_l = pipe.tokenizer(negative, max_length=77, truncation=True, padding="max_length", return_tensors="pt").input_ids.to(device)
89
+ ntok_g = pipe.tokenizer_2(negative, max_length=77, truncation=True, padding="max_length", return_tensors="pt").input_ids.to(device)
 
90
 
91
  with torch.no_grad():
92
+ clip_l = pipe.text_encoder(tok_l)[0]
93
+ neg_clip_l = pipe.text_encoder(ntok_l)[0]
94
+ g_out = pipe.text_encoder_2(tok_g, output_hidden_states=False)
95
+ clip_g, pl = g_out[1], g_out[0]
96
+ ng_out = pipe.text_encoder_2(ntok_g, output_hidden_states=False)
 
97
  neg_clip_g, npl = ng_out[1], ng_out[0]
98
 
99
+ return {"clip_l": clip_l, "clip_g": clip_g, "neg_l": neg_clip_l, "neg_g": neg_clip_g, "pooled": pl, "neg_pooled": npl}
 
 
 
100
 
101
  # ─── INFERENCE ───────────────────────────────────────────────────────────
102
  def infer(prompt: str, negative_prompt: str,
103
+ adapter_l_name: str, adapter_g_name: str,
104
  strength: float, delta_scale: float, sigma_scale: float,
105
  gpred_scale: float, noise: float, gate_prob: float, use_anchor: bool,
106
  steps: int, cfg_scale: float, scheduler_name: str,
 
112
  if scheduler_name in SCHEDULERS:
113
  _pipe.scheduler = SCHEDULERS[scheduler_name].from_config(_pipe.scheduler.config)
114
 
115
+ generator = (torch.Generator(device=device).manual_seed(seed) if seed != -1 else None)
 
116
 
 
117
  cfg_shift = ShiftConfig(
118
+ prompt=prompt,
119
+ seed=seed,
120
+ strength=strength,
121
+ delta_scale=delta_scale,
122
+ sigma_scale=sigma_scale,
123
+ gate_probability=gate_prob,
124
+ noise_injection=noise,
125
+ use_anchor=use_anchor,
126
+ guidance_scale=gpred_scale,
127
  )
128
 
 
129
  t5_seq = ConditioningShifter.extract_encoder_embeddings(
130
  {"tokenizer": _t5_tok, "model": _t5_mod, "config": {"config": {}}},
131
  device, cfg_shift
132
  )
133
 
 
134
  embeds = encode_prompt_xl(_pipe, prompt, negative_prompt)
 
 
135
  outputs: List[AdapterOutput] = []
136
+
137
+ if adapter_l_name and adapter_l_name != "None":
138
+ ada_l = load_adapter_by_name(adapter_l_name, device)
139
  outputs.append(ConditioningShifter.run_adapter(
140
  ada_l, t5_seq, embeds["clip_l"],
141
  cfg_shift.guidance_scale, "clip_l", (0, 768)))
142
 
143
+ if adapter_g_name and adapter_g_name != "None":
144
+ ada_g = load_adapter_by_name(adapter_g_name, device)
145
  outputs.append(ConditioningShifter.run_adapter(
146
  ada_g, t5_seq, embeds["clip_g"],
147
  cfg_shift.guidance_scale, "clip_g", (768, 2048)))
148
 
 
149
  clip_l_mod, clip_g_mod = embeds["clip_l"], embeds["clip_g"]
150
+ delta_viz = {"clip_l": torch.zeros_like(clip_l_mod), "clip_g": torch.zeros_like(clip_g_mod)}
151
+ gate_viz = {"clip_l": torch.zeros_like(clip_l_mod[..., :1]), "clip_g": torch.zeros_like(clip_g_mod[..., :1])}
 
 
152
 
153
  for out in outputs:
154
  target = clip_l_mod if out.adapter_type == "clip_l" else clip_g_mod
155
+ mod = ConditioningShifter.apply_modifications(target, [out], cfg_shift)
156
+ if out.adapter_type == "clip_l": clip_l_mod = mod
157
+ else: clip_g_mod = mod
 
 
158
  delta_viz[out.adapter_type] = out.delta.detach()
159
+ gate_viz[out.adapter_type] = out.gate.detach()
160
 
 
161
  prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1)
162
+ neg_embeds = torch.cat([embeds["neg_l"], embeds["neg_g"]], dim=-1)
163
 
164
  image = _pipe(
165
+ prompt_embeds=prompt_embeds,
166
+ negative_prompt_embeds=neg_embeds,
167
+ pooled_prompt_embeds=embeds["pooled"],
168
+ negative_pooled_prompt_embeds=embeds["neg_pooled"],
169
+ num_inference_steps=steps,
170
+ guidance_scale=cfg_scale,
171
+ width=width, height=height, generator=generator
172
  ).images[0]
173
 
 
174
  delta_l_img = plot_heat(delta_viz["clip_l"].squeeze(), "Δ CLIP-L")
175
+ gate_l_img = plot_heat(gate_viz["clip_l"].squeeze().mean(-1, keepdims=True), "Gate L")
176
  delta_g_img = plot_heat(delta_viz["clip_g"].squeeze(), "Δ CLIP-G")
177
+ gate_g_img = plot_heat(gate_viz["clip_g"].squeeze().mean(-1, keepdims=True), "Gate G")
178
 
179
+ stats_l = (f"τ̄_L = {outputs[0].tau.mean().item():.3f}" if outputs and outputs[0].adapter_type == "clip_l" else "-")
180
+ stats_g = (f"τ̄_G = {outputs[-1].tau.mean().item():.3f}" if len(outputs) > 1 and outputs[-1].adapter_type == "clip_g" else "-")
 
 
181
 
182
  return image, delta_l_img, gate_l_img, delta_g_img, gate_g_img, stats_l, stats_g
183
 
184
+ # ─── GRADIO UI ───────────────────────────────────────────────────────────
 
185
  def create_interface():
186
  with gr.Blocks(title="SDXL Dual-Shunt Tester", theme=gr.themes.Soft()) as demo:
187
  gr.Markdown("# 🧠 SDXL Dual-Shunt Tester")
 
189
  with gr.Row():
190
  with gr.Column(scale=1):
191
  gr.Markdown("### Prompts")
192
+ prompt = gr.Textbox(label="Prompt", lines=3)
193
+ negative = gr.Textbox(label="Negative", lines=2)
 
 
194
 
195
  gr.Markdown("### Adapters")
196
+ adapter_l = gr.Dropdown(clip_l_opts, value=clip_l_opts[1], label="CLIP-L Adapter")
197
+ adapter_g = gr.Dropdown(clip_g_opts, value=clip_g_opts[1], label="CLIP-G Adapter")
198
 
199
  gr.Markdown("### Adapter Controls")
200
  strength = gr.Slider(0, 10, 4.0, 0.05, label="Strength")
 
209
  with gr.Row():
210
  steps = gr.Slider(1, 50, 20, 1, label="Steps")
211
  cfg_scale = gr.Slider(1, 15, 7.5, 0.1, label="CFG")
212
+ scheduler = gr.Dropdown(list(SCHEDULERS.keys()), value="DPM++ 2M", label="Scheduler")
 
213
  with gr.Row():
214
  width = gr.Slider(512, 1536, 1024, 64, label="Width")
215
  height = gr.Slider(512, 1536, 1024, 64, label="Height")
 
227
  stats_l = gr.Textbox(label="Stats L", interactive=False)
228
  stats_g = gr.Textbox(label="Stats G", interactive=False)
229
230
  run_btn.click(
231
+ fn=infer,
232
  inputs=[prompt, negative, adapter_l, adapter_g, strength, delta_scale,
233
  sigma_scale, gpred_scale, noise, gate_prob, use_anchor, steps,
234
  cfg_scale, scheduler, width, height, seed],
 
236
  )
237
  return demo
238
 
 
239
  if __name__ == "__main__":
240
  create_interface().launch()
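
Because the commit turns the Space into a runnable local project, the rewritten app.py can be smoke-tested without going through the Gradio UI. The sketch below is a hypothetical local check, not part of the commit: the slider values are assumptions taken from the UI defaults where they are visible in the diff, and it expects the sibling modules (two_stream_shunt_adapter.py, conditioning_shifter.py, configs.py) plus a CUDA device for the fp16 pipeline.

# Hypothetical local smoke test for the rewritten app.py (values are assumptions).
from app import create_interface, infer

demo = create_interface()                     # builds the Blocks UI defined above
image, delta_l, gate_l, delta_g, gate_g, stats_l, stats_g = infer(
    "a futuristic control station with holographic displays",   # prompt
    "blurry, low quality",                                       # negative prompt
    "None", "None",             # adapter names; "None" skips both shunt adapters
    4.0, 1.0, 1.0, 1.0,         # strength, delta_scale, sigma_scale, gpred_scale (assumed)
    0.0, 1.0, True,             # noise, gate_prob, use_anchor (assumed)
    20, 7.5, "DPM++ 2M",        # steps, cfg_scale, scheduler
    1024, 1024, -1,             # width, height, seed (-1 → random generator)
)
image.save("smoke_test.png")
demo.launch()                   # equivalent to running `python app.py`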
conditioning_shifter.py CHANGED
@@ -4,7 +4,7 @@ import logging
4
  from typing import Dict, List, Tuple, Optional, Any
5
  from dataclasses import dataclass
6
 
7
- from . import ConditionModulationShuntAdapter, reshape_for_shunt
8
 
9
  logger = logging.getLogger(__name__)
10
 
 
4
  from typing import Dict, List, Tuple, Optional, Any
5
  from dataclasses import dataclass
6
 
7
+ from two_stream_shunt_adapter import ConditionModulationShuntAdapter, reshape_for_shunt
8
 
9
  logger = logging.getLogger(__name__)
10
 
configs.py CHANGED
@@ -801,3 +801,16 @@ class ShuntUtil:
801
  """
802
  return [shunt.name for shunt in SHUNT_DATAS]
803
801
  """
802
  return [shunt.name for shunt in SHUNT_DATAS]
803
 
804
+ @staticmethod
805
+ def get_shunts_by_clip_type(clip_type: str) -> list[ShuntData]:
806
+ """
807
+ Returns a list of shunts that match the given clip type.
808
+
809
+ Args:
810
+ clip_type (str): The type of clip to filter by (e.g., "clip_l", "clip_g").
811
+
812
+ Returns:
813
+ list[ShuntData]: List of shunts that match the clip type.
814
+ """
815
+ return [shunt for shunt in SHUNT_DATAS if any(mod["type"] == clip_type for mod in shunt.modulation_encoders)]
816
+
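
For context, a minimal sketch of how the new helper is consumed; it mirrors the calls made in the rewritten app.py above, and the ShuntData field names (name, repo, file, config) are taken from that usage rather than from configs.py itself.

# Hypothetical usage of the new ShuntUtil helpers, mirroring app.py.
from configs import ShuntUtil

clip_l_shunts = ShuntUtil.get_shunts_by_clip_type("clip_l")
print([s.name for s in clip_l_shunts])       # names shown in the CLIP-L dropdown

shunt = ShuntUtil.get_shunt_by_name(clip_l_shunts[0].name)
print(shunt.repo, shunt.file)                # repo/file that load_adapter_by_name downloads
print(type(shunt.config))                    # adapter config passed to TwoStreamShuntAdapter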
custom/__init__.py ADDED
File without changes
custom/t5_encoder_with_projection.py ADDED
@@ -0,0 +1,62 @@
1
+ import torch
2
+ from transformers import T5EncoderModel, T5Config, T5PreTrainedModel
3
+ from transformers.modeling_outputs import BaseModelOutput
4
+ from typing import List, Optional, Tuple, Union
5
+ from torch import nn, Tensor
6
+
7
+
8
+ class T5ProjectionConfig(T5Config):
9
+ def __init__(self, **kwargs):
10
+ super().__init__(**kwargs)
11
+ self.project_in_dim = kwargs.get("project_in_dim", 768)
12
+ self.project_out_dim = kwargs.get("out_dim", 4096)
13
+
14
+
15
+ class T5EncoderWithProjection(T5PreTrainedModel):
16
+ config_class = T5ProjectionConfig
17
+
18
+ def __init__(self, config):
19
+ super().__init__(config)
20
+ # self.encoder = encoder
21
+ self.encoder = T5EncoderModel(config)
22
+
23
+ self.final_projection = nn.Sequential(
24
+ nn.Linear(config.project_in_dim, config.project_out_dim, bias=False),
25
+ nn.ReLU(),
26
+ nn.Dropout(0.0),
27
+ nn.Linear(config.project_out_dim, config.project_out_dim, bias=False)
28
+ )
29
+
30
+ def forward(
31
+ self,
32
+ input_ids: Optional[torch.LongTensor] = None,
33
+ attention_mask: Optional[torch.FloatTensor] = None,
34
+ head_mask: Optional[torch.FloatTensor] = None,
35
+ inputs_embeds: Optional[torch.FloatTensor] = None,
36
+ output_attentions: Optional[bool] = None,
37
+ output_hidden_states: Optional[bool] = None,
38
+ return_dict: Optional[bool] = None,
39
+ ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
40
+
41
+ return_dict = return_dict if return_dict is not None else False
42
+
43
+ encoder_outputs = self.encoder(
44
+ input_ids=input_ids,
45
+ attention_mask=attention_mask,
46
+ inputs_embeds=inputs_embeds,
47
+ head_mask=head_mask,
48
+ output_attentions=output_attentions,
49
+ output_hidden_states=output_hidden_states,
50
+ return_dict=return_dict,
51
+ )
52
+ last_hidden_state = self.final_projection(encoder_outputs[0])
53
+ # last_hidden_state = self.final_block(last_hidden_state)[0]
54
+
55
+ if not return_dict:
56
+ return tuple(
57
+ v for v in [last_hidden_state] if v is not None
58
+ )
59
+
60
+ return BaseModelOutput(
61
+ last_hidden_state=last_hidden_state
62
+ )
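
A minimal, hypothetical shape check for the new wrapper follows. It builds a randomly initialised encoder (no pretrained weights) just to confirm the output width; note that the projection size is read from the "out_dim" kwarg and defaults to 4096, while project_in_dim defaults to 768.

# Hypothetical shape check for T5EncoderWithProjection (weights are random, not pretrained).
import torch
from transformers import AutoTokenizer
from custom.t5_encoder_with_projection import T5ProjectionConfig, T5EncoderWithProjection

cfg = T5ProjectionConfig.from_pretrained("google/flan-t5-base")   # project_in_dim=768, project_out_dim=4096 defaults
model = T5EncoderWithProjection(cfg).eval()

tok = AutoTokenizer.from_pretrained("google/flan-t5-base")
ids = tok("a test prompt", return_tensors="pt").input_ids
with torch.no_grad():
    (hidden,) = model(input_ids=ids)          # return_dict defaults to False → 1-tuple
print(hidden.shape)                            # torch.Size([1, seq_len, 4096])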
model_manager.py ADDED
@@ -0,0 +1,615 @@
1
+ from typing import Dict, Optional, Any, Union, Tuple
2
+ import os
3
+ import torch
4
+ import torch.nn as nn
5
+ import logging
6
+ from pathlib import Path
7
+ from dataclasses import dataclass
8
+ from enum import Enum
9
+
10
+ from safetensors.torch import load_file
11
+ from torch.nn import Module
12
+ from transformers import AutoModel, AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM, BertModel, BertTokenizer, \
13
+ PreTrainedTokenizerFast, T5TokenizerFast, T5EncoderModel
14
+
15
+ from .custom.t5_encoder_with_projection import T5EncoderWithProjection
16
+
17
+ logger = logging.getLogger(__name__)
18
+ # --------------------------------------------------------------------------- #
19
+ # Helper for namespaced cache keys
20
+ def _make_key(model_type: str, model_id: str) -> str:
21
+ """
22
+ Produce a unique key for the internal cache.
23
+
24
+ Example
25
+ -------
26
+ >>> _make_key("bert", "bert-base")
27
+ 'bert:bert-base'
28
+ """
29
+ return f"{model_type}:{model_id}"
30
+
31
+
32
+ # Thread-safe registry wrapper
33
+ class _SafeDict(dict):
34
+ """A dict protected by a re-entrant lock for thread-safe writes."""
35
+ def __init__(self):
36
+ super().__init__()
37
+ import threading
38
+ self._lock = threading.RLock()
39
+
40
+ def safe_set(self, key, value):
41
+ with self._lock:
42
+ super().__setitem__(key, value)
43
+
44
+ def safe_get(self, key, default=None):
45
+ with self._lock:
46
+ return super().get(key, default)
47
+
48
+ def safe_del(self, key):
49
+ with self._lock:
50
+ if key in self:
51
+ super().__delitem__(key)
52
+ return True
53
+ return False
54
+
55
+
56
+ # -------------------------------------------------------------------------------------------------------------------- #
57
+ # WARNING: ENABLING THIS TRUST_REMOTE_CODE FLAG WILL ALLOW EXECUTION OF ARBITRARY CODE FROM THE MODEL REPOSITORY.
58
+ # USE WITH EXTREME CAUTION, AS IT CAN POTENTIALLY EXECUTE MALICIOUS CODE FROM UNTRUSTED SOURCES.
59
+
60
+ TRUST_REMOTE_CODE = False # Set to True only if you trust the source of the models you are loading.
61
+
62
+ # I advise leaving this OFF for any production or sensitive environments, and for any government or enterprise use.
63
+ # Ensure you fully trust the model repository and its maintainers, and review the code thoroughly.
64
+ # You cannot ONLY trust an AI's response to the question of whether it is safe to enable this flag,
65
+ # as it may not have the full context of security implications or the specific model's behavior.
66
+ # -------------------------------------------------------------------------------------------------------------------- #
67
+ # COMFYUI operates within a form of sandbox, but enabling remote code execution can still pose many unseen risks.
68
+ # -------------------------------------------------------------------------------------------------------------------- #
69
+
70
+
71
+ class ModelType(Enum):
72
+ """Enum for different model types"""
73
+ SHUNT_ADAPTER = "shunt_adapter"
74
+ T5_MODEL = "t5_model"
75
+ BERT_MODEL = "bert"
76
+ NOMIC_BERT_MODEL = "nomic_bert"
77
+ GENERIC = "generic"
78
+ TOKENIZER = "tokenizer"
79
+
80
+ @dataclass
81
+ class ModelInfo:
82
+ """Container for model information"""
83
+ model: nn.Module
84
+ model_type: ModelType
85
+ config: Dict[str, Any]
86
+ device: torch.device
87
+ dtype: torch.dtype
88
+ metadata: Dict[str, Any] = None
89
+ trust_remote_code: bool = TRUST_REMOTE_CODE # Use global setting by default
90
+
91
+
92
+ class ModelManager:
93
+ """
94
+ Centralized model loader / cache with thread-safety and namespaced keys.
95
+ """
96
+
97
+ def __init__(self, cache_dir: Optional[str] = None):
98
+ # Thread-safe model cache
99
+ self.models: _SafeDict = _SafeDict()
100
+
101
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
102
+ self.cache_dir = self._setup_cache_dir(cache_dir)
103
+
104
+ # be VERY careful with huggingface keys, remote code execution, and model downloads.
105
+ # If you are using private models or need to authenticate, set the HuggingFace API key.
106
+ def set_huggingface_key(self, key: str):
107
+ """
108
+ Set the HuggingFace API key for model downloads.
109
+ This is useful if you have a private model or need to authenticate.
110
+ """
111
+ os.environ["HF_TOKEN"] = key
112
+ logger.info("HuggingFace API key set successfully.")
113
+
114
+ def get_huggingface_key(self) -> Optional[str]:
115
+ """
116
+ Get the HuggingFace API key if set.
117
+ This is useful for debugging or checking if authentication is needed.
118
+ """
119
+ return os.environ.get("HF_TOKEN")
120
+
121
+ def set_huggingface_cache_directory(self, directory: str):
122
+ """
123
+ Set the cache directory for HuggingFace model downloads.
124
+ This is useful if you want to change the cache location.
125
+ This does not move existing models; it only changes the default download directory.
126
+ """
127
+ os.environ["HF_HOME"] = directory
128
+ logger.info(f"HuggingFace default directory set to: {directory}")
129
+
130
+ def get_huggingface_cache_directory(self) -> Optional[str]:
131
+ """
132
+ Get the cache directory for HuggingFace model downloads.
133
+ This is useful for debugging or checking where models are stored.
134
+ """
135
+ return os.environ.get("HF_HOME", str(self.cache_dir))
136
+
137
+ # --------------------------------------------------------------------- #
138
+ # Internal helpers
139
+ def _store(self, key: str, info: "ModelInfo") -> None:
140
+ """Thread-safe insertion into the model cache."""
141
+ self.models.safe_set(key, info)
142
+
143
+
144
+ def _setup_cache_dir(self, cache_dir: Optional[str]) -> Path:
145
+ """Setup and validate cache directory"""
146
+ if cache_dir:
147
+ cache_path = Path(cache_dir)
148
+ else:
149
+ # Use default HuggingFace cache location
150
+ cache_path = Path.home() / ".cache" / "huggingface" / "transformers"
151
+
152
+ cache_path.mkdir(parents=True, exist_ok=True)
153
+ logger.info(f"Using cache directory: {cache_path}")
154
+ return cache_path
155
+
156
+ def get_model(self, key: str) -> Optional["ModelInfo"]:
157
+ """Retrieve a model by its namespaced key."""
158
+ return self.models.safe_get(key)
159
+
160
+ def is_loaded(self, key: str) -> bool:
161
+ """Return True if the namespaced key is present in the cache."""
162
+ return self.models.safe_get(key) is not None
163
+
164
+
165
+ def move_model(
166
+ self,
167
+ namespaced_key: str,
168
+ *,
169
+ device: Optional[torch.device] = None,
170
+ dtype: Optional[torch.dtype] = None,
171
+ ) -> Optional[nn.Module]:
172
+ """
173
+ Convert device/dtype of a cached model and return the updated object.
174
+ """
175
+ model = self._maybe_convert_dtype(namespaced_key, dtype, device)
176
+ if model is None:
177
+ logger.warning("move_model: %s not found", namespaced_key)
178
+ return model
179
+
180
+
181
+ def load_tokenizer(
182
+ self,
183
+ id: str,
184
+ tokenizer_name_or_path: str,
185
+ target_output_device: Optional[torch.device] = None,
186
+ force_reload: bool = False,
187
+ trust_remote_code: Optional[bool] = None,
188
+ ) -> Optional[tuple[PreTrainedTokenizerFast, dict[str, Any]]]:
189
+ """Load or fetch from cache a Hugging-Face tokenizer."""
190
+ key = _make_key("tokenizer", id)
191
+ if not force_reload and self.is_loaded(key):
192
+ model_info = self.get_model(key)
193
+ return model_info.model, model_info.metadata
194
+
195
+ try:
196
+ trust_remote_code = (
197
+ trust_remote_code if trust_remote_code is not None else TRUST_REMOTE_CODE
198
+ )
199
+ tok = AutoTokenizer.from_pretrained(
200
+ tokenizer_name_or_path, trust_remote_code=trust_remote_code
201
+ )
202
+
203
+ self._store(
204
+ key,
205
+ ModelInfo(
206
+ model=tok,
207
+ model_type=ModelType.TOKENIZER,
208
+ config={"tokenizer_name": tokenizer_name_or_path},
209
+ device=target_output_device or torch.device("cpu"),
210
+ dtype=torch.float32,
211
+ metadata={"source": "huggingface", "trust_remote_code": trust_remote_code},
212
+ ),
213
+ )
214
+ logger.info("Loaded tokenizer %s", key)
215
+ return tok, self.get_model(key).metadata
216
+
217
+ except Exception:
218
+ logger.exception("Failed to load tokenizer %s", id)
219
+ return None
220
+
221
+
222
+ def load_shunt_adapter(
223
+ self,
224
+ adapter_id: str,
225
+ config: Dict[str, Any],
226
+ path: Optional[str] = None,
227
+ repo_id: Optional[str] = None,
228
+ filename: Optional[str] = None,
229
+ device: Optional[torch.device] = None,
230
+ dtype: Optional[torch.dtype] = None,
231
+ force_reload: bool = False
232
+ ) -> Optional[nn.Module]:
233
+ """
234
+ Load a shunt adapter from local path or HuggingFace.
235
+
236
+ Args:
237
+ adapter_id: Unique identifier for the adapter
238
+ config: Configuration dictionary for the adapter
239
+ path: Local path to the adapter file
240
+ repo_id: HuggingFace repository ID
241
+ filename: Filename in the HuggingFace repository
242
+ device: Target device
243
+ dtype: Target dtype
244
+ force_reload: Force reload even if cached
245
+
246
+ Returns:
247
+ Loaded adapter model or None if failed
248
+ """
249
+ if not force_reload and self.is_loaded(adapter_id):
250
+ logger.info(f"Using cached adapter: {adapter_id}")
251
+ return self._maybe_convert_dtype(adapter_id, dtype, device)
252
+ try:
253
+ # Import here to avoid circular imports
254
+ from two_stream_shunt_adapter import ConditionModulationShuntAdapter
255
+
256
+ # Determine file location
257
+ file_path = self._resolve_file_path(path, repo_id, filename)
258
+ if not file_path:
259
+ raise FileNotFoundError(f"Could not find adapter file for {adapter_id}")
260
+ # Initialize adapter
261
+ # NOTE: head-count selection is not implemented here. t5-vit-l-14-dual_shunt_booru_13_000_000.safetensors expects 4 attention heads, the others 12, so the passed config must already specify this.
262
+ logger.info(f"Loading adapter {adapter_id} from {file_path}")
263
+ adapter = ConditionModulationShuntAdapter(config=config)
264
+ logger.info(f"Initialized adapter {adapter_id} with config: {config}")
265
+ # Load weights
266
+ state_dict = load_file(file_path)
267
+ logger.info(f"Loaded state_dict for adapter {adapter_id} from {file_path}")
268
+ adapter.load_state_dict(state_dict, strict=False)
269
+ logger.info(f"Adapter {adapter_id} state_dict loaded successfully")
270
+
271
+ # Move to device and dtype
272
+ device = device or self.device
273
+ dtype = dtype or torch.float32
274
+ logger.info(f"Moving adapter {adapter_id} to device: {device}, dtype: {dtype}")
275
+ adapter = adapter.to(device=device, dtype=dtype)
276
+ logger.info(f"Adapter {adapter_id} moved to device and dtype successfully")
277
+
278
+ # Cache the model
279
+ self.models[adapter_id] = ModelInfo(
280
+ model=adapter,
281
+ model_type=ModelType.SHUNT_ADAPTER,
282
+ config=config,
283
+ device=device,
284
+ dtype=dtype,
285
+ metadata={"file_path": str(file_path)}
286
+ )
287
+ logger.info(f"Adapter {adapter_id} cached successfully")
288
+
289
+ logger.info(f"Successfully loaded adapter: {adapter_id}")
290
+ return adapter
291
+
292
+ except Exception as e:
293
+ logger.error(f"Failed to load adapter {adapter_id} from {path or repo_id}/{filename}: {e}")
294
+ logger.debug(f"Traceback: {e.__traceback__}")
295
+ return None
296
+
297
+ def load_encoder_model(self,
298
+ model_type: str, # use this to see if it's compatible with the current model manager
299
+ model_id: str,
300
+ model_name_or_path: str,
301
+ device: Optional[torch.device] = None,
302
+ dtype: Optional[torch.dtype] = None,
303
+ force_reload: bool = False,
304
+ trust_remote_code: Optional[bool] = None, # Overrides the global TRUST_REMOTE_CODE setting.
305
+ config: Optional[Dict[str, Any]] = None # Additional configuration for the model
306
+ ) -> Optional[nn.Module]:
307
+ """
308
+ Load an encoder model (e.g., BERT, T5) and return it.
309
+
310
+ Args:
311
+ model_type: Type of the model (e.g., "bert", "t5")
312
+ model_id: Unique identifier for the model
313
+ model_name_or_path: Model name or path
314
+ device: Target device
315
+ dtype: Target dtype
316
+ force_reload: Force reload even if cached
317
+
318
+ Returns:
319
+ Loaded model or None if failed
320
+ """
321
+ if model_type == "bert":
322
+ return self.load_bert_model(model_id, model_name_or_path, device, dtype, force_reload, trust_remote_code)
323
+ elif model_type == "nomic_bert":
324
+ # Nomic BERT is a specific variant of BERT, so we can use the same loading function
325
+ return self.load_bert_model(model_id, model_name_or_path, device, dtype, force_reload, trust_remote_code)
326
+ elif "t5" in model_type:
327
+ return self.load_t5_model(model_id, model_name_or_path, device, dtype, force_reload, trust_remote_code, config)
328
+ else:
329
+ logger.error(f"Unsupported model type: {model_type}")
330
+ return None
331
+
332
+ def load_bert_model(
333
+ self,
334
+ model_id: str,
335
+ model_name_or_path: str,
336
+ device: Optional[torch.device] = None,
337
+ dtype: Optional[torch.dtype] = None,
338
+ force_reload: bool = False,
339
+ trust_remote_code: Optional[bool] = None # Overrides the global TRUST_REMOTE_CODE setting.
340
+ ) -> Optional[Tuple[nn.Module, Any]]:
341
+
342
+ """
343
+ Load a BERT model and tokenizer.
344
+
345
+ Returns:
346
+ Tuple of (model, tokenizer) or None if failed
347
+ """
348
+ if not force_reload and self.is_loaded(model_id):
349
+ logger.info(f"Using cached BERT model: {model_id}")
350
+ model_info = self.get_model(model_id)
351
+ return model_info.model, model_info.metadata.get("tokenizer")
352
+
353
+ try:
354
+ device = device or self.device
355
+ dtype = dtype or torch.float32
356
+
357
+ config = AutoConfig.from_pretrained(
358
+ model_name_or_path,
359
+ trust_remote_code=trust_remote_code if trust_remote_code is not None else TRUST_REMOTE_CODE # Use the global flag for remote code execution
360
+ )
361
+
362
+ # Load tokenizer and model
363
+ tokenizer = AutoTokenizer.from_pretrained(
364
+ model_name_or_path,
365
+ config=config,
366
+ use_special_tokens=True, # Ensure special tokens are used
367
+ trust_remote_code=trust_remote_code if trust_remote_code is not None else TRUST_REMOTE_CODE # Use the global flag for remote code execution
368
+ )
369
+ model = AutoModel.from_pretrained(
370
+ model_name_or_path,
371
+ config=config,
372
+ torch_dtype=dtype,
373
+ trust_remote_code=trust_remote_code if trust_remote_code is not None else TRUST_REMOTE_CODE # Use the global flag for remote code execution
374
+ ).to(device)
375
+
376
+ # Cache the model
377
+
378
+ self._store(_make_key("bert", model_id), ModelInfo(
379
+ model=model,
380
+ model_type=ModelType.BERT_MODEL,
381
+ config={"model_name": model_name_or_path},
382
+ device=device,
383
+ dtype=dtype,
384
+ metadata={"tokenizer": tokenizer},
385
+ trust_remote_code=trust_remote_code if trust_remote_code is not None else TRUST_REMOTE_CODE
386
+ ))
387
+
388
+ logger.info(f"Successfully loaded BERT model: {model_id}")
389
+ return model, tokenizer
390
+
391
+ except Exception as e:
392
+ logger.error(f"Failed to load BERT model {model_id}: {e}")
393
+ return None
394
+
395
+ def load_t5_model(
396
+ self,
397
+ model_id: str,
398
+ model_name_or_path: str,
399
+ device: Optional[torch.device] = None,
400
+ dtype: Optional[torch.dtype] = None,
401
+ force_reload: bool = False,
402
+ override_remote_code: Optional[bool] = None, # Overrides the global TRUST_REMOTE_CODE setting.
403
+ config: Optional[Dict[str, Any]] = None # Additional configuration for the model
404
+ ) -> Optional[Tuple[nn.Module, Any]]:
405
+ """
406
+ Load a T5 model and tokenizer.
407
+
408
+ Returns:
409
+ Tuple of (model, tokenizer) or None if failed
410
+ """
411
+ if not force_reload and self.is_loaded(model_id):
412
+ logger.info(f"Using cached T5 model: {model_id}")
413
+ model_info = self.get_model(model_id)
414
+ return model_info.model, model_info.metadata.get("tokenizer")
415
+
416
+ try:
417
+ device = device or self.device
418
+ dtype = dtype or torch.float32
419
+ trust_remote_code = override_remote_code if override_remote_code is not None else TRUST_REMOTE_CODE
420
+ # Load tokenizer and model
421
+ if config.get("type", "t5") == "t5":
422
+ tokenizer = AutoTokenizer.from_pretrained(
423
+ "google/flan-t5-base",
424
+ trust_remote_code=trust_remote_code # Use the global flag for remote code execution
425
+ )
426
+ elif config.get("type", "t5") == "t5_unchained":
427
+ tokenizer = T5TokenizerFast.from_pretrained(
428
+ "AbstractPhil/t5xxl-unchained",
429
+ trust_remote_code=trust_remote_code # Use the global flag for remote code execution
430
+ )
431
+ else:
432
+ tokenizer = T5TokenizerFast.from_pretrained(
433
+ "google/flan-t5-base",
434
+ trust_remote_code=trust_remote_code # Use the global flag for remote code execution
435
+ )
436
+
437
+ if config.get("type", "t5") == "t5":
438
+ logger.info(f"Loading T5ForConditionalGeneration model from {model_name_or_path}")
439
+ model = AutoModelForSeq2SeqLM.from_pretrained(
440
+ model_name_or_path,
441
+ torch_dtype=dtype,
442
+ trust_remote_code=trust_remote_code # Use the global flag for remote code execution
443
+ ).to(device)
444
+ elif config.get("type", "t5") == "t5_encoder_with_projection":
445
+ # Load T5EncoderModel with projection layer
446
+ logger.info(f"Loading T5EncoderWithProjection model from {model_name_or_path}")
447
+ model = T5EncoderWithProjection.from_pretrained(
448
+ model_name_or_path,
449
+ torch_dtype=dtype,
450
+ trust_remote_code=trust_remote_code # Use the global flag for remote code execution
451
+ ).to(device)
452
+
453
+ else:
454
+ # Load standard T5 model
455
+ logger.info(f"Loading T5EncoderModel from {model_name_or_path}")
456
+ model = AutoModel.from_pretrained(
457
+ model_name_or_path,
458
+ torch_dtype=dtype,
459
+ trust_remote_code=trust_remote_code # Use the global flag for remote code execution
460
+ ).to(device)
461
+
462
+ # Cache the model
463
+ self._store(_make_key("t5", model_id), ModelInfo(
464
+ model=model,
465
+ model_type=ModelType.T5_MODEL,
466
+ config={"model_name": model_name_or_path},
467
+ device=device,
468
+ dtype=dtype,
469
+ metadata={"tokenizer": tokenizer}
470
+ ))
471
+
472
+ logger.info(f"Successfully loaded T5 model: {model_id}")
473
+ return model, tokenizer
474
+
475
+ except Exception as e:
476
+ logger.error(f"Failed to load T5 model {model_id}: {e}")
477
+ return None
478
+
479
+ def unload_model(self, model_id: str) -> bool:
480
+ """
481
+ Unload a model to free memory.
482
+
483
+ Returns:
484
+ True if successfully unloaded, False otherwise
485
+ """
486
+ if model_id in self.models:
487
+ try:
488
+ # Move to CPU first to free GPU memory
489
+ model_info = self.models[model_id]
490
+ model_info.model.cpu()
491
+
492
+ # Delete the model
493
+ del self.models[model_id]
494
+
495
+ # Force garbage collection
496
+ import gc
497
+ gc.collect()
498
+ if torch.cuda.is_available():
499
+ torch.cuda.empty_cache()
500
+
501
+ logger.info(f"Successfully unloaded model: {model_id}")
502
+ return True
503
+
504
+ except Exception as e:
505
+ logger.error(f"Failed to unload model {model_id}: {e}")
506
+ return False
507
+ else:
508
+ logger.warning(f"Model {model_id} not found in cache")
509
+ return False
510
+
511
+ def list_models(self) -> Dict[str, Dict[str, Any]]:
512
+ """List all loaded models with their information"""
513
+ return {
514
+ model_id: {
515
+ "type": info.model_type.value,
516
+ "device": str(info.device),
517
+ "dtype": str(info.dtype),
518
+ "config": info.config
519
+ }
520
+ for model_id, info in self.models.items()
521
+ }
522
+
523
+ def clear_all(self):
524
+ """Clear all loaded models"""
525
+ model_ids = list(self.models.keys())
526
+ for model_id in model_ids:
527
+ self.unload_model(model_id)
528
+ logger.info("All models cleared from memory")
529
+
530
+ def _resolve_file_path(
531
+ self,
532
+ local_path: Optional[str],
533
+ repo_id: Optional[str],
534
+ filename: Optional[str]
535
+ ) -> Optional[Path]:
536
+ """Resolve file path from local or HuggingFace"""
537
+ # Try local path first
538
+ if local_path and os.path.exists(local_path):
539
+ return Path(local_path)
540
+
541
+ # Try HuggingFace
542
+ if repo_id and filename:
543
+ try:
544
+ from huggingface_hub import hf_hub_download
545
+
546
+ file_path = hf_hub_download(
547
+ repo_id=repo_id,
548
+ filename=filename,
549
+ cache_dir=str(self.cache_dir),
550
+ repo_type="model"
551
+ )
552
+ return Path(file_path)
553
+
554
+ except Exception as e:
555
+ logger.error(f"Failed to download from HuggingFace: {e}")
556
+
557
+ return None
558
+
559
+ def _maybe_convert_dtype(
560
+ self,
561
+ model_id: str,
562
+ target_dtype: Optional[torch.dtype],
563
+ target_device: Optional[torch.device]
564
+ ) -> Optional[nn.Module]:
565
+ """Convert model dtype/device if needed"""
566
+ model_info = self.get_model(model_id)
567
+ if not model_info:
568
+ return None
569
+
570
+ model = model_info.model
571
+ changed = False
572
+
573
+ # Check dtype conversion
574
+ if target_dtype and model_info.dtype != target_dtype:
575
+ try:
576
+ model = model.to(dtype=target_dtype)
577
+ model_info.dtype = target_dtype
578
+ changed = True
579
+ logger.info(f"Converted {model_id} to dtype: {target_dtype}")
580
+ except Exception as e:
581
+ logger.error(f"Failed to convert dtype for {model_id}: {e}")
582
+
583
+ # Check device conversion
584
+ if target_device and model_info.device != target_device:
585
+ try:
586
+ model = model.to(device=target_device)
587
+ model_info.device = target_device
588
+ changed = True
589
+ logger.info(f"Moved {model_id} to device: {target_device}")
590
+ except Exception as e:
591
+ logger.error(f"Failed to move {model_id} to device: {e}")
592
+
593
+ if changed:
594
+ model_info.model = model
595
+
596
+ return model
597
+
598
+
599
+ def __del__(self):
600
+ """Cleanup on deletion"""
601
+ self.clear_all()
602
+
603
+
604
+ # Global instance (singleton pattern)
605
+ _global_model_manager: Optional[ModelManager] = None
606
+
607
+
608
+ def get_model_manager(cache_dir: Optional[str] = None) -> ModelManager:
609
+ """Get or create the global model manager instance"""
610
+ global _global_model_manager
611
+
612
+ if _global_model_manager is None:
613
+ _global_model_manager = ModelManager(cache_dir=cache_dir)
614
+
615
+ return _global_model_manager
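
Finally, a rough usage sketch for the new ModelManager. The identifiers below are placeholders; also note that model_manager.py imports custom.t5_encoder_with_projection relatively, so depending on how the project is laid out the import below may need to go through the enclosing package instead of the loose module.

# Hypothetical usage of the ModelManager singleton (ids/paths are placeholders).
from model_manager import get_model_manager

mm = get_model_manager()                       # creates the global instance on first call

loaded = mm.load_t5_model(
    model_id="flan-t5-base",
    model_name_or_path="google/flan-t5-base",
    config={"type": "t5"},                     # required: load_t5_model reads config["type"]
)
if loaded is not None:
    model, tokenizer = loaded
    print(mm.list_models())                    # cached under the namespaced key "t5:flan-t5-base"
    mm.unload_model("t5:flan-t5-base")         # moves to CPU, drops the cache entry, frees GPU memory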
requirements.txt CHANGED
@@ -1,8 +1,8 @@
1
  sentencepiece
2
  accelerate
3
  diffusers
4
- invisible_watermark
5
  torch
6
  transformers
7
  xformers
8
- matplotlib
 
 
1
  sentencepiece
2
  accelerate
3
  diffusers
 
4
  torch
5
  transformers
6
  xformers
7
+ matplotlib
8
+ gradio
two_stream_shunt_adapter.py CHANGED
@@ -2,7 +2,6 @@ from typing import Tuple
2
 
3
  import torch
4
  import torch.nn as nn
5
- from . import ENCODER_CONFIGS, HARMONIC_SHUNT_REPOS
6
 
7
 
8
  class DualConversionNames:
 
2
 
3
  import torch
4
  import torch.nn as nn
 
5
 
6
 
7
  class DualConversionNames: