ginipick committed
Commit ae8c703 · verified · 1 Parent(s): a12a030

Update app.py

Files changed (1)
  1. app.py +105 -257
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # import os
2
  import spaces
3
 
4
  import time
@@ -6,7 +6,7 @@ import gradio as gr
6
  import torch
7
  from PIL import Image
8
  from torchvision import transforms
9
- from dataclasses import dataclass
10
  import math
11
  from typing import Callable
12
 
@@ -21,11 +21,8 @@ from diffusers import AutoencoderKL
21
  from torch import Tensor, nn
22
  from transformers import CLIPTextModel, CLIPTokenizer
23
  from transformers import T5EncoderModel, T5Tokenizer
24
- # from optimum.quanto import freeze, qfloat8, quantize
25
- from transformers import pipeline
26
 
27
- ko_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
28
- ja_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ja-en")
29
 
30
  class HFEmbedder(nn.Module):
31
  def __init__(self, version: str, max_length: int, **hf_kwargs):
@@ -60,58 +57,24 @@ class HFEmbedder(nn.Module):
60
  output_hidden_states=False,
61
  )
62
  return outputs[self.output_key]
63
-
64
 
65
  device = "cuda"
66
  t5 = HFEmbedder("DeepFloyd/t5-v1_1-xxl", max_length=512, torch_dtype=torch.bfloat16).to(device)
67
  clip = HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device)
68
  ae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to(device)
69
- # quantize(t5, weights=qfloat8)
70
- # freeze(t5)
71
-
72
 
73
  # ---------------- NF4 ----------------
74
 
75
-
76
  def functional_linear_4bits(x, weight, bias):
 
77
  out = bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state)
78
  out = out.to(x)
79
  return out
80
 
81
-
82
- def copy_quant_state(state: QuantState, device: torch.device = None) -> QuantState:
83
- if state is None:
84
- return None
85
-
86
- device = device or state.absmax.device
87
-
88
- state2 = (
89
- QuantState(
90
- absmax=state.state2.absmax.to(device),
91
- shape=state.state2.shape,
92
- code=state.state2.code.to(device),
93
- blocksize=state.state2.blocksize,
94
- quant_type=state.state2.quant_type,
95
- dtype=state.state2.dtype,
96
- )
97
- if state.nested
98
- else None
99
- )
100
-
101
- return QuantState(
102
- absmax=state.absmax.to(device),
103
- shape=state.shape,
104
- code=state.code.to(device),
105
- blocksize=state.blocksize,
106
- quant_type=state.quant_type,
107
- dtype=state.dtype,
108
- offset=state.offset.to(device) if state.nested else None,
109
- state2=state2,
110
- )
111
-
112
-
113
  class ForgeParams4bit(Params4bit):
 
114
  def to(self, *args, **kwargs):
 
115
  device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
116
  if device is not None and device.type == "cuda" and not self.bnb_quantized:
117
  return self._quantize(device)
@@ -119,9 +82,7 @@ class ForgeParams4bit(Params4bit):
119
  n = ForgeParams4bit(
120
  torch.nn.Parameter.to(self, device=device, dtype=dtype, non_blocking=non_blocking),
121
  requires_grad=self.requires_grad,
122
- quant_state=copy_quant_state(self.quant_state, device),
123
- # blocksize=self.blocksize,
124
- # compress_statistics=self.compress_statistics,
125
  compress_statistics=False,
126
  blocksize=64,
127
  quant_type=self.quant_type,
@@ -134,11 +95,10 @@ class ForgeParams4bit(Params4bit):
134
  self.quant_state = n.quant_state
135
  return n
136
 
137
-
138
- class ForgeLoader4Bit(torch.nn.Module):
139
  def __init__(self, *, device, dtype, quant_type, **kwargs):
140
  super().__init__()
141
- self.dummy = torch.nn.Parameter(torch.empty(1, device=device, dtype=dtype))
142
  self.weight = None
143
  self.quant_state = None
144
  self.bias = None
@@ -146,23 +106,26 @@ class ForgeLoader4Bit(torch.nn.Module):
146
 
147
  def _save_to_state_dict(self, destination, prefix, keep_vars):
148
  super()._save_to_state_dict(destination, prefix, keep_vars)
 
149
  quant_state = getattr(self.weight, "quant_state", None)
150
  if quant_state is not None:
151
  for k, v in quant_state.as_dict(packed=True).items():
152
  destination[prefix + "weight." + k] = v if keep_vars else v.detach()
153
  return
154
 
155
- def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
156
- quant_state_keys = {k[len(prefix + "weight."):] for k in state_dict.keys() if k.startswith(prefix + "weight.")}
157
158
  if any('bitsandbytes' in k for k in quant_state_keys):
159
  quant_state_dict = {k: state_dict[prefix + "weight." + k] for k in quant_state_keys}
160
-
161
  self.weight = ForgeParams4bit.from_prequantized(
162
  data=state_dict[prefix + 'weight'],
163
  quantized_stats=quant_state_dict,
164
  requires_grad=False,
165
- # device=self.dummy.device,
166
  device=torch.device('cuda'),
167
  module=self
168
  )
@@ -170,7 +133,6 @@ class ForgeLoader4Bit(torch.nn.Module):
170
 
171
  if prefix + 'bias' in state_dict:
172
  self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
173
-
174
  del self.dummy
175
  elif hasattr(self, 'dummy'):
176
  if prefix + 'weight' in state_dict:
@@ -191,56 +153,39 @@ class ForgeLoader4Bit(torch.nn.Module):
191
  else:
192
  super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
193
 
194
-
195
  class Linear(ForgeLoader4Bit):
196
  def __init__(self, *args, device=None, dtype=None, **kwargs):
197
  super().__init__(device=device, dtype=dtype, quant_type='nf4')
198
 
199
  def forward(self, x):
200
  self.weight.quant_state = self.quant_state
201
-
202
  if self.bias is not None and self.bias.dtype != x.dtype:
203
- # This could also be applied to all non-bnb ops, since the cost is very low.
204
- # It only runs once, and most linear layers have no bias.
205
  self.bias.data = self.bias.data.to(x.dtype)
206
-
207
  return functional_linear_4bits(x, self.weight, self.bias)
208
-
209
 
 
210
  nn.Linear = Linear
211
 
212
-
213
  # ---------------- Model ----------------
214
 
215
-
216
  def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
217
  q, k = apply_rope(q, k, pe)
218
-
219
  x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
220
- # x = rearrange(x, "B H L D -> B L (H D)")
221
  x = x.permute(0, 2, 1, 3).reshape(x.size(0), x.size(2), -1)
222
-
223
  return x
224
 
225
-
226
  def rope(pos, dim, theta):
 
227
  scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
228
  omega = 1.0 / (theta ** scale)
229
-
230
- # out = torch.einsum("...n,d->...nd", pos, omega)
231
  out = pos.unsqueeze(-1) * omega.unsqueeze(0)
232
-
233
  cos_out = torch.cos(out)
234
  sin_out = torch.sin(out)
235
  out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
236
-
237
- # out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
238
  b, n, d, _ = out.shape
239
  out = out.view(b, n, d, 2, 2)
240
-
241
  return out.float()
242
 
243
-
244
  def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
245
  xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
246
  xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
@@ -248,7 +193,6 @@ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tenso
248
  xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
249
  return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
250
 
251
-
252
  class EmbedND(nn.Module):
253
  def __init__(self, dim: int, theta: int, axes_dim: list[int]):
254
  super().__init__()
@@ -257,33 +201,19 @@ class EmbedND(nn.Module):
257
  self.axes_dim = axes_dim
258
 
259
  def forward(self, ids: Tensor) -> Tensor:
 
260
  n_axes = ids.shape[-1]
261
  emb = torch.cat(
262
  [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
263
  dim=-3,
264
  )
265
-
266
  return emb.unsqueeze(1)
267
 
268
-
269
  def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
270
- """
271
- Create sinusoidal timestep embeddings.
272
- :param t: a 1-D Tensor of N indices, one per batch element.
273
- These may be fractional.
274
- :param dim: the dimension of the output.
275
- :param max_period: controls the minimum frequency of the embeddings.
276
- :return: an (N, D) Tensor of positional embeddings.
277
- """
278
  t = time_factor * t
279
  half = dim // 2
280
-
281
- # Does not block the CUDA stream, but differs from the official Flux code by about 1e-4:
282
- # freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)
283
-
284
- # Blocks the CUDA stream, but is consistent with the official code:
285
  freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
286
-
287
  args = t[:, None].float() * freqs[None]
288
  embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
289
  if dim % 2:
@@ -292,7 +222,6 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
292
  embedding = embedding.to(t)
293
  return embedding
294
 
295
-
296
  class MLPEmbedder(nn.Module):
297
  def __init__(self, in_dim: int, hidden_dim: int):
298
  super().__init__()
@@ -303,19 +232,18 @@ class MLPEmbedder(nn.Module):
303
  def forward(self, x: Tensor) -> Tensor:
304
  return self.out_layer(self.silu(self.in_layer(x)))
305
 
306
-
307
  class RMSNorm(torch.nn.Module):
308
  def __init__(self, dim: int):
309
  super().__init__()
310
  self.scale = nn.Parameter(torch.ones(dim))
311
 
312
  def forward(self, x: Tensor):
 
313
  x_dtype = x.dtype
314
  x = x.float()
315
  rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
316
  return (x * rrms).to(dtype=x_dtype) * self.scale
317
 
318
-
319
  class QKNorm(torch.nn.Module):
320
  def __init__(self, dim: int):
321
  super().__init__()
@@ -327,20 +255,17 @@ class QKNorm(torch.nn.Module):
327
  k = self.key_norm(k)
328
  return q.to(v), k.to(v)
329
 
330
-
331
  class SelfAttention(nn.Module):
332
  def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
333
  super().__init__()
334
  self.num_heads = num_heads
335
- head_dim = dim // num_heads
336
-
337
  self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
 
338
  self.norm = QKNorm(head_dim)
339
  self.proj = nn.Linear(dim, dim)
340
 
341
  def forward(self, x: Tensor, pe: Tensor) -> Tensor:
342
  qkv = self.qkv(x)
343
- # q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
344
  B, L, _ = qkv.shape
345
  qkv = qkv.view(B, L, 3, self.num_heads, -1)
346
  q, k, v = qkv.permute(2, 0, 3, 1, 4)
@@ -349,6 +274,7 @@ class SelfAttention(nn.Module):
349
  x = self.proj(x)
350
  return x
351
 
 
352
 
353
  @dataclass
354
  class ModulationOut:
@@ -356,7 +282,6 @@ class ModulationOut:
356
  scale: Tensor
357
  gate: Tensor
358
 
359
-
360
  class Modulation(nn.Module):
361
  def __init__(self, dim: int, double: bool):
362
  super().__init__()
@@ -364,37 +289,30 @@ class Modulation(nn.Module):
364
  self.multiplier = 6 if double else 3
365
  self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
366
 
367
- def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
368
  out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
369
-
370
- return (
371
- ModulationOut(*out[:3]),
372
- ModulationOut(*out[3:]) if self.is_double else None,
373
- )
374
-
375
 
376
  class DoubleStreamBlock(nn.Module):
377
  def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
378
  super().__init__()
379
-
380
  mlp_hidden_dim = int(hidden_size * mlp_ratio)
381
  self.num_heads = num_heads
382
  self.hidden_size = hidden_size
383
  self.img_mod = Modulation(hidden_size, double=True)
384
  self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
385
  self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
386
-
387
  self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
388
  self.img_mlp = nn.Sequential(
389
  nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
390
  nn.GELU(approximate="tanh"),
391
  nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
392
  )
393
-
394
  self.txt_mod = Modulation(hidden_size, double=True)
395
  self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
396
  self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
397
-
398
  self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
399
  self.txt_mlp = nn.Sequential(
400
  nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
@@ -406,50 +324,41 @@ class DoubleStreamBlock(nn.Module):
406
  img_mod1, img_mod2 = self.img_mod(vec)
407
  txt_mod1, txt_mod2 = self.txt_mod(vec)
408
 
409
- # prepare image for attention
410
  img_modulated = self.img_norm1(img)
411
  img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
412
  img_qkv = self.img_attn.qkv(img_modulated)
413
- # img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
414
  B, L, _ = img_qkv.shape
415
  H = self.num_heads
416
  D = img_qkv.shape[-1] // (3 * H)
417
  img_q, img_k, img_v = img_qkv.view(B, L, 3, H, D).permute(2, 0, 3, 1, 4)
418
  img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
419
 
420
- # prepare txt for attention
421
  txt_modulated = self.txt_norm1(txt)
422
  txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
423
  txt_qkv = self.txt_attn.qkv(txt_modulated)
424
- # txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
425
  B, L, _ = txt_qkv.shape
426
  txt_q, txt_k, txt_v = txt_qkv.view(B, L, 3, H, D).permute(2, 0, 3, 1, 4)
427
  txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
428
 
429
- # run actual attention
430
  q = torch.cat((txt_q, img_q), dim=2)
431
  k = torch.cat((txt_k, img_k), dim=2)
432
  v = torch.cat((txt_v, img_v), dim=2)
433
-
434
  attn = attention(q, k, v, pe=pe)
435
  txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
436
 
437
- # calculate the img blocks
438
  img = img + img_mod1.gate * self.img_attn.proj(img_attn)
439
  img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
440
 
441
- # calculate the txt blocks
442
  txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
443
  txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
444
  return img, txt
445
 
446
-
447
  class SingleStreamBlock(nn.Module):
448
- """
449
- A DiT block with parallel linear layers as described in
450
- https://arxiv.org/abs/2302.05442 and adapted modulation interface.
451
- """
452
-
453
  def __init__(
454
  self,
455
  hidden_size: int,
@@ -462,18 +371,12 @@ class SingleStreamBlock(nn.Module):
462
  self.num_heads = num_heads
463
  head_dim = hidden_size // num_heads
464
  self.scale = qk_scale or head_dim**-0.5
465
-
466
  self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
467
- # qkv and mlp_in
468
  self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
469
- # proj and mlp_out
470
  self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
471
-
472
  self.norm = QKNorm(head_dim)
473
-
474
  self.hidden_size = hidden_size
475
  self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
476
-
477
  self.mlp_act = nn.GELU(approximate="tanh")
478
  self.modulation = Modulation(hidden_size, double=False)
479
 
@@ -481,18 +384,12 @@ class SingleStreamBlock(nn.Module):
481
  mod, _ = self.modulation(vec)
482
  x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
483
  qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
484
-
485
- # q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
486
  qkv = qkv.view(qkv.size(0), qkv.size(1), 3, self.num_heads, self.hidden_size // self.num_heads)
487
  q, k, v = qkv.permute(2, 0, 3, 1, 4)
488
  q, k = self.norm(q, k, v)
489
-
490
- # compute attention
491
  attn = attention(q, k, v, pe=pe)
492
- # compute activation in mlp stream, cat again and run second linear layer
493
  output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
494
  return x + mod.gate * output
495
-
496
 
497
  class LastLayer(nn.Module):
498
  def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
@@ -506,8 +403,10 @@ class LastLayer(nn.Module):
506
  x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
507
  x = self.linear(x)
508
  return x
509
-
510
-
 
 
511
  class FluxParams:
512
  in_channels: int = 64
513
  vec_in_dim: int = 768
@@ -517,20 +416,14 @@ class FluxParams:
517
  num_heads: int = 24
518
  depth: int = 19
519
  depth_single_blocks: int = 38
520
- axes_dim: list = [16, 56, 56]
521
- theta: int = 10_000
522
  qkv_bias: bool = True
523
  guidance_embed: bool = True
524
 
525
-
526
  class Flux(nn.Module):
527
- """
528
- Transformer model for flow matching on sequences.
529
- """
530
-
531
  def __init__(self, params = FluxParams()):
532
  super().__init__()
533
-
534
  self.params = params
535
  self.in_channels = params.in_channels
536
  self.out_channels = self.in_channels
@@ -585,57 +478,46 @@ class Flux(nn.Module):
585
  ) -> Tensor:
586
  if img.ndim != 3 or txt.ndim != 3:
587
  raise ValueError("Input img and txt tensors must have 3 dimensions.")
588
-
589
- # running on sequences img
590
  img = self.img_in(img)
591
  vec = self.time_in(timestep_embedding(timesteps, 256))
592
  if self.params.guidance_embed:
593
  if guidance is None:
594
- raise ValueError("Didn't get guidance strength for guidance distilled model.")
595
  vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
596
  vec = vec + self.vector_in(y)
597
  txt = self.txt_in(txt)
598
-
599
  ids = torch.cat((txt_ids, img_ids), dim=1)
600
  pe = self.pe_embedder(ids)
601
-
602
  for block in self.double_blocks:
603
  img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
604
-
605
  img = torch.cat((txt, img), 1)
606
  for block in self.single_blocks:
607
  img = block(img, vec=vec, pe=pe)
608
  img = img[:, txt.shape[1] :, ...]
609
-
610
- img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
611
  return img
612
 
613
-
614
  def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
 
615
  bs, c, h, w = img.shape
616
  if bs == 1 and not isinstance(prompt, str):
617
  bs = len(prompt)
618
-
619
  img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
620
  if img.shape[0] == 1 and bs > 1:
621
  img = repeat(img, "1 ... -> bs ...", bs=bs)
622
-
623
  img_ids = torch.zeros(h // 2, w // 2, 3)
624
  img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
625
  img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
626
  img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
627
-
628
  if isinstance(prompt, str):
629
  prompt = [prompt]
630
  txt = t5(prompt)
631
  if txt.shape[0] == 1 and bs > 1:
632
  txt = repeat(txt, "1 ... -> bs ...", bs=bs)
633
  txt_ids = torch.zeros(bs, txt.shape[1], 3)
634
-
635
  vec = clip(prompt)
636
  if vec.shape[0] == 1 and bs > 1:
637
  vec = repeat(vec, "1 ... -> bs ...", bs=bs)
638
-
639
  return {
640
  "img": img,
641
  "img_ids": img_ids.to(img.device),
@@ -644,19 +526,18 @@ def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[st
644
  "vec": vec.to(img.device),
645
  }
646
 
647
-
648
  def time_shift(mu: float, sigma: float, t: Tensor):
 
649
  return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
650
 
651
-
652
  def get_lin_function(
653
  x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
654
  ) -> Callable[[float], float]:
 
655
  m = (y2 - y1) / (x2 - x1)
656
  b = y1 - m * x1
657
  return lambda x: m * x + b
658
 
659
-
660
  def get_schedule(
661
  num_steps: int,
662
  image_seq_len: int,
@@ -664,31 +545,25 @@ def get_schedule(
664
  max_shift: float = 1.15,
665
  shift: bool = True,
666
  ) -> list[float]:
667
- # extra step for zero
 
668
  timesteps = torch.linspace(1, 0, num_steps + 1)
669
-
670
- # shifting the schedule to favor high timesteps for higher signal images
671
  if shift:
672
- # estimate mu based on linear interpolation between two points
673
  mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
674
  timesteps = time_shift(mu, 1.0, timesteps)
675
-
676
  return timesteps.tolist()
677
 
678
-
679
  def denoise(
680
  model: Flux,
681
- # model input
682
  img: Tensor,
683
  img_ids: Tensor,
684
  txt: Tensor,
685
  txt_ids: Tensor,
686
  vec: Tensor,
687
- # sampling parameters
688
  timesteps: list[float],
689
  guidance: float = 4.0,
690
  ):
691
- # this is ignored for schnell
692
  guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
693
  for t_curr, t_prev in tqdm(zip(timesteps[:-1], timesteps[1:]), total=len(timesteps) - 1):
694
  t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
@@ -704,7 +579,6 @@ def denoise(
704
  img = img + (t_prev - t_curr) * pred
705
  return img
706
 
707
-
708
  def unpack(x: Tensor, height: int, width: int) -> Tensor:
709
  return rearrange(
710
  x,
@@ -722,13 +596,11 @@ class SamplingOptions:
722
  height: int
723
  guidance: float
724
  seed: int | None
725
-
726
 
727
  def get_image(image) -> torch.Tensor | None:
728
  if image is None:
729
  return None
730
  image = Image.fromarray(image).convert("RGB")
731
-
732
  transform = transforms.Compose([
733
  transforms.ToTensor(),
734
  transforms.Lambda(lambda x: 2.0 * x - 1.0),
@@ -736,10 +608,7 @@ def get_image(image) -> torch.Tensor | None:
736
  img: torch.Tensor = transform(image)
737
  return img[None, ...]
738
 
739
-
740
- # ---------------- Demo ----------------
741
-
742
-
743
  from huggingface_hub import hf_hub_download
744
  from safetensors.torch import load_file
745
 
@@ -749,10 +618,6 @@ model = Flux().to(dtype=torch.bfloat16, device="cuda")
749
  result = model.load_state_dict(sd)
750
  model_zero_init = False
751
 
752
- # model = Flux().to(dtype=torch.bfloat16, device="cuda")
753
- # result = model.load_state_dict(load_file("/storage/dev/nyanko/flux-dev/flux1-dev.sft"))
754
-
755
-
756
  @spaces.GPU
757
  @torch.no_grad()
758
  def generate_image(
@@ -760,38 +625,17 @@ def generate_image(
760
  do_img2img, init_image, image2image_strength, resize_img,
761
  progress=gr.Progress(track_tqdm=True),
762
  ):
763
- translated_prompt = prompt
764
-
765
- # Detect Korean or Japanese characters
766
- def contains_korean(text):
767
- return any('\u3131' <= c <= '\u318E' or '\uAC00' <= c <= '\uD7A3' for c in text)
768
-
769
- def contains_japanese(text):
770
- return any('\u3040' <= c <= '\u309F' or '\u30A0' <= c <= '\u30FF' or '\u4E00' <= c <= '\u9FFF' for c in text)
771
-
772
- # Translate if the prompt contains Korean or Japanese
773
- if contains_korean(prompt):
774
- translated_prompt = ko_translator(prompt, max_length=512)[0]['translation_text']
775
- print(f"Translated Korean prompt: {translated_prompt}")
776
- prompt = translated_prompt
777
- elif contains_japanese(prompt):
778
- translated_prompt = ja_translator(prompt, max_length=512)[0]['translation_text']
779
- print(f"Translated Japanese prompt: {translated_prompt}")
780
- prompt = translated_prompt
781
-
782
  if seed == 0:
783
- seed = int(random.random() * 1000000)
784
-
785
  device = "cuda" if torch.cuda.is_available() else "cpu"
786
  torch_device = torch.device(device)
787
 
788
-
789
-
790
  global model, model_zero_init
791
  if not model_zero_init:
792
  model = model.to(torch_device)
793
  model_zero_init = True
794
-
795
  if do_img2img and init_image is not None:
796
  init_image = get_image(init_image)
797
  if resize_img:
@@ -802,84 +646,80 @@ def generate_image(
802
  height = init_image.shape[-2]
803
  width = init_image.shape[-1]
804
  init_image = ae.encode(init_image.to(torch_device).to(torch.bfloat16)).latent_dist.sample()
805
- init_image = (init_image - ae.config.shift_factor) * ae.config.scaling_factor
806
 
807
  generator = torch.Generator(device=device).manual_seed(seed)
808
- x = torch.randn(1, 16, 2 * math.ceil(height / 16), 2 * math.ceil(width / 16), device=device, dtype=torch.bfloat16, generator=generator)
809
-
810
- num_steps = inference_steps
811
- timesteps = get_schedule(num_steps, (x.shape[-1] * x.shape[-2]) // 4, shift=True)
812
 
813
  if do_img2img and init_image is not None:
814
- t_idx = int((1 - image2image_strength) * num_steps)
815
  t = timesteps[t_idx]
816
  timesteps = timesteps[t_idx:]
817
  x = t * x + (1.0 - t) * init_image.to(x.dtype)
818
 
819
  inp = prepare(t5=t5, clip=clip, img=x, prompt=prompt)
820
  x = denoise(model, **inp, timesteps=timesteps, guidance=guidance)
821
-
822
- # with profile(activities=[ProfilerActivity.CPU],record_shapes=True,profile_memory=True) as prof:
823
- # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
824
-
825
  x = unpack(x.float(), height, width)
 
826
  with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
827
- x = x = (x / ae.config.scaling_factor) + ae.config.shift_factor
828
  x = ae.decode(x).sample
829
 
830
  x = x.clamp(-1, 1)
831
  x = rearrange(x[0], "c h w -> h w c")
832
  img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
833
-
834
-
835
- return img, seed, translated_prompt
836
-
837
- css = """
838
- footer {
839
- visibility: hidden;
840
- }
841
- """
842
 
843
  def create_demo():
844
- with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
845
-
846
- gr.Markdown("# FLUXllama Multilingual v3")
847
-
 
 
 
 
 
 
848
  with gr.Row():
849
  with gr.Column():
850
- prompt = gr.Textbox(label="Prompt(Supports English, Korean, and Japanese)", value="A cute and fluffy golden retriever puppy sitting upright, holding a neatly designed white sign with bold, colorful lettering that reads 'Have a Happy Day!' in cheerful fonts. The puppy has expressive, sparkling eyes, a happy smile, and fluffy ears slightly flopped. The background is a vibrant and sunny meadow with soft-focus flowers, glowing sunlight filtering through the trees, and a warm golden glow that enhances the joyful atmosphere. The sign is framed with small decorative flowers, adding a charming and wholesome touch. Ensure the text on the sign is clear and legible.")
851
-
852
- width = gr.Slider(minimum=128, maximum=2048, step=64, label="Width", value=768)
853
- height = gr.Slider(minimum=128, maximum=2048, step=64, label="Height", value=768)
854
  guidance = gr.Slider(minimum=1.0, maximum=5.0, step=0.1, label="Guidance", value=3.5)
855
  inference_steps = gr.Slider(
856
  label="Inference steps",
857
  minimum=1,
858
  maximum=30,
859
  step=1,
860
- value=30,
861
  )
862
  seed = gr.Number(label="Seed", precision=-1)
863
  do_img2img = gr.Checkbox(label="Image to Image", value=False)
864
- init_image = gr.Image(label="Input Image", visible=False)
865
- image2image_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Noising strength", value=0.8, visible=False)
866
- resize_img = gr.Checkbox(label="Resize image", value=True, visible=False)
867
- generate_button = gr.Button("Generate")
868
-
 
 
 
 
 
 
869
  with gr.Column():
870
- output_image = gr.Image(label="Generated Image")
871
- output_seed = gr.Text(label="Used Seed")
872
- output_translated = gr.Text(label="Translated Prompt")
873
-
874
- # Add the Examples component
875
- gr.Examples(
876
- examples=[
877
- "a tiny astronaut hatching from an egg on the moon",
878
- "썬글라스 착용한 귀여운 흰색 고양이가 'LOVE'라는 표지판을 들고있다",
879
- "桜が流れる夜の街、照明",
880
- ],
881
- inputs=prompt, # component that receives the selected example
882
- )
883
 
884
  do_img2img.change(
885
  fn=lambda x: [gr.update(visible=x), gr.update(visible=x), gr.update(visible=x)],
@@ -889,12 +729,20 @@ def create_demo():
889
 
890
  generate_button.click(
891
  fn=generate_image,
892
- inputs=[prompt, width, height, guidance, inference_steps, seed, do_img2img, init_image, image2image_strength, resize_img],
893
- outputs=[output_image, output_seed, output_translated]
894
  )
895
-
896
  return demo
897
 
898
  if __name__ == "__main__":
 
899
  demo = create_demo()
900
- demo.launch()

1
+ import os
2
  import spaces
3
 
4
  import time
 
6
  import torch
7
  from PIL import Image
8
  from torchvision import transforms
9
+ from dataclasses import dataclass, field
10
  import math
11
  from typing import Callable
12
 
 
21
  from torch import Tensor, nn
22
  from transformers import CLIPTextModel, CLIPTokenizer
23
  from transformers import T5EncoderModel, T5Tokenizer
 
 
24
 
25
+ # ---------------- Encoders ----------------
 
26
 
27
  class HFEmbedder(nn.Module):
28
  def __init__(self, version: str, max_length: int, **hf_kwargs):
 
57
  output_hidden_states=False,
58
  )
59
  return outputs[self.output_key]
 
60
 
61
  device = "cuda"
62
  t5 = HFEmbedder("DeepFloyd/t5-v1_1-xxl", max_length=512, torch_dtype=torch.bfloat16).to(device)
63
  clip = HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device)
64
  ae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to(device)
 
 
 
65
 
66
  # ---------------- NF4 ----------------
67
 
 
68
  def functional_linear_4bits(x, weight, bias):
69
+ import bitsandbytes as bnb
70
  out = bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state)
71
  out = out.to(x)
72
  return out
73
74
  class ForgeParams4bit(Params4bit):
75
+ """Subclass to force re-quantization to GPU if needed."""
76
  def to(self, *args, **kwargs):
77
+ import torch
78
  device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
79
  if device is not None and device.type == "cuda" and not self.bnb_quantized:
80
  return self._quantize(device)
 
82
  n = ForgeParams4bit(
83
  torch.nn.Parameter.to(self, device=device, dtype=dtype, non_blocking=non_blocking),
84
  requires_grad=self.requires_grad,
85
+ quant_state=self.quant_state,
 
 
86
  compress_statistics=False,
87
  blocksize=64,
88
  quant_type=self.quant_type,
 
95
  self.quant_state = n.quant_state
96
  return n
97
 
98
+ class ForgeLoader4Bit(nn.Module):
 
99
  def __init__(self, *, device, dtype, quant_type, **kwargs):
100
  super().__init__()
101
+ self.dummy = nn.Parameter(torch.empty(1, device=device, dtype=dtype))
102
  self.weight = None
103
  self.quant_state = None
104
  self.bias = None
 
106
 
107
  def _save_to_state_dict(self, destination, prefix, keep_vars):
108
  super()._save_to_state_dict(destination, prefix, keep_vars)
109
+ from bitsandbytes.nn.modules import QuantState
110
  quant_state = getattr(self.weight, "quant_state", None)
111
  if quant_state is not None:
112
  for k, v in quant_state.as_dict(packed=True).items():
113
  destination[prefix + "weight." + k] = v if keep_vars else v.detach()
114
  return
115
 
116
+ def _load_from_state_dict(
117
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
118
+ ):
119
+ from bitsandbytes.nn.modules import Params4bit
120
+ import torch
121
 
122
+ quant_state_keys = {k[len(prefix + "weight."):] for k in state_dict.keys() if k.startswith(prefix + "weight.")}
123
  if any('bitsandbytes' in k for k in quant_state_keys):
124
  quant_state_dict = {k: state_dict[prefix + "weight." + k] for k in quant_state_keys}
 
125
  self.weight = ForgeParams4bit.from_prequantized(
126
  data=state_dict[prefix + 'weight'],
127
  quantized_stats=quant_state_dict,
128
  requires_grad=False,
 
129
  device=torch.device('cuda'),
130
  module=self
131
  )
 
133
 
134
  if prefix + 'bias' in state_dict:
135
  self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
 
136
  del self.dummy
137
  elif hasattr(self, 'dummy'):
138
  if prefix + 'weight' in state_dict:
 
153
  else:
154
  super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
155
 
 
156
  class Linear(ForgeLoader4Bit):
157
  def __init__(self, *args, device=None, dtype=None, **kwargs):
158
  super().__init__(device=device, dtype=dtype, quant_type='nf4')
159
 
160
  def forward(self, x):
161
  self.weight.quant_state = self.quant_state
 
162
  if self.bias is not None and self.bias.dtype != x.dtype:
 
 
163
  self.bias.data = self.bias.data.to(x.dtype)
 
164
  return functional_linear_4bits(x, self.weight, self.bias)
 
165
 
166
+ import torch.nn as nn
167
  nn.Linear = Linear
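A minimal illustration (not part of the commit) of what the monkey-patch above implies: any nn.Linear constructed after this point is actually the NF4 Linear subclass defined above, whose real weights only appear once a pre-quantized state dict is loaded.
layer = nn.Linear(64, 64)                  # actually builds the NF4 `Linear` subclass
assert isinstance(layer, ForgeLoader4Bit)  # weight stays None until load_state_dict() runs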
168
 
 
169
  # ---------------- Model ----------------
170
 
 
171
  def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
172
  q, k = apply_rope(q, k, pe)
 
173
  x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
 
174
  x = x.permute(0, 2, 1, 3).reshape(x.size(0), x.size(2), -1)
 
175
  return x
176
 
 
177
  def rope(pos, dim, theta):
178
+ import torch
179
  scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
180
  omega = 1.0 / (theta ** scale)
 
 
181
  out = pos.unsqueeze(-1) * omega.unsqueeze(0)
 
182
  cos_out = torch.cos(out)
183
  sin_out = torch.sin(out)
184
  out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
 
 
185
  b, n, d, _ = out.shape
186
  out = out.view(b, n, d, 2, 2)
 
187
  return out.float()
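Shape sketch (illustrative, not in the commit): for positions of shape (B, N), rope returns a (B, N, dim // 2, 2, 2) rotation tensor that apply_rope later contracts against the paired query/key components.
pe = rope(torch.arange(8)[None].float(), dim=16, theta=10_000)  # positions of shape (1, 8)
# pe.shape == (1, 8, 8, 2, 2)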
188
 
 
189
  def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
190
  xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
191
  xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
 
193
  xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
194
  return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
195
 
 
196
  class EmbedND(nn.Module):
197
  def __init__(self, dim: int, theta: int, axes_dim: list[int]):
198
  super().__init__()
 
201
  self.axes_dim = axes_dim
202
 
203
  def forward(self, ids: Tensor) -> Tensor:
204
+ import torch
205
  n_axes = ids.shape[-1]
206
  emb = torch.cat(
207
  [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
208
  dim=-3,
209
  )
 
210
  return emb.unsqueeze(1)
211
 
 
212
  def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
213
+ import torch, math
214
  t = time_factor * t
215
  half = dim // 2
216
  freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
 
217
  args = t[:, None].float() * freqs[None]
218
  embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
219
  if dim % 2:
 
222
  embedding = embedding.to(t)
223
  return embedding
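Usage sketch (illustrative, not in the commit): timestep_embedding maps a batch of scalar timesteps to sinusoidal features of width dim, returned in the dtype of t.
t = torch.tensor([0.25, 0.5])       # two timesteps in [0, 1]
emb = timestep_embedding(t, 256)    # shape (2, 256)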
224
 
 
225
  class MLPEmbedder(nn.Module):
226
  def __init__(self, in_dim: int, hidden_dim: int):
227
  super().__init__()
 
232
  def forward(self, x: Tensor) -> Tensor:
233
  return self.out_layer(self.silu(self.in_layer(x)))
234
 
 
235
  class RMSNorm(torch.nn.Module):
236
  def __init__(self, dim: int):
237
  super().__init__()
238
  self.scale = nn.Parameter(torch.ones(dim))
239
 
240
  def forward(self, x: Tensor):
241
+ import torch
242
  x_dtype = x.dtype
243
  x = x.float()
244
  rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
245
  return (x * rrms).to(dtype=x_dtype) * self.scale
246
 
 
247
  class QKNorm(torch.nn.Module):
248
  def __init__(self, dim: int):
249
  super().__init__()
 
255
  k = self.key_norm(k)
256
  return q.to(v), k.to(v)
257
 
 
258
  class SelfAttention(nn.Module):
259
  def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
260
  super().__init__()
261
  self.num_heads = num_heads
 
 
262
  self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
263
+ head_dim = dim // num_heads
264
  self.norm = QKNorm(head_dim)
265
  self.proj = nn.Linear(dim, dim)
266
 
267
  def forward(self, x: Tensor, pe: Tensor) -> Tensor:
268
  qkv = self.qkv(x)
 
269
  B, L, _ = qkv.shape
270
  qkv = qkv.view(B, L, 3, self.num_heads, -1)
271
  q, k, v = qkv.permute(2, 0, 3, 1, 4)
 
274
  x = self.proj(x)
275
  return x
276
 
277
+ from dataclasses import dataclass
278
 
279
  @dataclass
280
  class ModulationOut:
 
282
  scale: Tensor
283
  gate: Tensor
284
 
 
285
  class Modulation(nn.Module):
286
  def __init__(self, dim: int, double: bool):
287
  super().__init__()
 
289
  self.multiplier = 6 if double else 3
290
  self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
291
 
292
+ def forward(self, vec: Tensor):
293
  out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
294
+ first = ModulationOut(*out[:3])
295
+ second = ModulationOut(*out[3:]) if self.is_double else None
296
+ return first, second
 
 
 
297
 
298
  class DoubleStreamBlock(nn.Module):
299
  def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
300
  super().__init__()
 
301
  mlp_hidden_dim = int(hidden_size * mlp_ratio)
302
  self.num_heads = num_heads
303
  self.hidden_size = hidden_size
304
  self.img_mod = Modulation(hidden_size, double=True)
305
  self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
306
  self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
 
307
  self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
308
  self.img_mlp = nn.Sequential(
309
  nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
310
  nn.GELU(approximate="tanh"),
311
  nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
312
  )
 
313
  self.txt_mod = Modulation(hidden_size, double=True)
314
  self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
315
  self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
 
316
  self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
317
  self.txt_mlp = nn.Sequential(
318
  nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
 
324
  img_mod1, img_mod2 = self.img_mod(vec)
325
  txt_mod1, txt_mod2 = self.txt_mod(vec)
326
 
327
+ # Image attention
328
  img_modulated = self.img_norm1(img)
329
  img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
330
  img_qkv = self.img_attn.qkv(img_modulated)
 
331
  B, L, _ = img_qkv.shape
332
  H = self.num_heads
333
  D = img_qkv.shape[-1] // (3 * H)
334
  img_q, img_k, img_v = img_qkv.view(B, L, 3, H, D).permute(2, 0, 3, 1, 4)
335
  img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
336
 
337
+ # Text attention
338
  txt_modulated = self.txt_norm1(txt)
339
  txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
340
  txt_qkv = self.txt_attn.qkv(txt_modulated)
 
341
  B, L, _ = txt_qkv.shape
342
  txt_q, txt_k, txt_v = txt_qkv.view(B, L, 3, H, D).permute(2, 0, 3, 1, 4)
343
  txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
344
 
345
+ # Combined attention
346
  q = torch.cat((txt_q, img_q), dim=2)
347
  k = torch.cat((txt_k, img_k), dim=2)
348
  v = torch.cat((txt_v, img_v), dim=2)
 
349
  attn = attention(q, k, v, pe=pe)
350
  txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
351
 
352
+ # Img final
353
  img = img + img_mod1.gate * self.img_attn.proj(img_attn)
354
  img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
355
 
356
+ # Text final
357
  txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
358
  txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
359
  return img, txt
360
 
 
361
  class SingleStreamBlock(nn.Module):
362
  def __init__(
363
  self,
364
  hidden_size: int,
 
371
  self.num_heads = num_heads
372
  head_dim = hidden_size // num_heads
373
  self.scale = qk_scale or head_dim**-0.5
 
374
  self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
 
375
  self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
 
376
  self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
 
377
  self.norm = QKNorm(head_dim)
 
378
  self.hidden_size = hidden_size
379
  self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
 
380
  self.mlp_act = nn.GELU(approximate="tanh")
381
  self.modulation = Modulation(hidden_size, double=False)
382
 
 
384
  mod, _ = self.modulation(vec)
385
  x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
386
  qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
 
 
387
  qkv = qkv.view(qkv.size(0), qkv.size(1), 3, self.num_heads, self.hidden_size // self.num_heads)
388
  q, k, v = qkv.permute(2, 0, 3, 1, 4)
389
  q, k = self.norm(q, k, v)
 
 
390
  attn = attention(q, k, v, pe=pe)
 
391
  output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
392
  return x + mod.gate * output
 
393
 
394
  class LastLayer(nn.Module):
395
  def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
 
403
  x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
404
  x = self.linear(x)
405
  return x
406
+
407
+ from dataclasses import dataclass, field
408
+
409
+ @dataclass
410
  class FluxParams:
411
  in_channels: int = 64
412
  vec_in_dim: int = 768
 
416
  num_heads: int = 24
417
  depth: int = 19
418
  depth_single_blocks: int = 38
419
+ axes_dim: list[int] = field(default_factory=lambda: [16, 56, 56])
420
+ theta: int = 10000
421
  qkv_bias: bool = True
422
  guidance_embed: bool = True
423
 
 
424
  class Flux(nn.Module):
 
 
 
 
425
  def __init__(self, params = FluxParams()):
426
  super().__init__()
 
427
  self.params = params
428
  self.in_channels = params.in_channels
429
  self.out_channels = self.in_channels
 
478
  ) -> Tensor:
479
  if img.ndim != 3 or txt.ndim != 3:
480
  raise ValueError("Input img and txt tensors must have 3 dimensions.")
 
 
481
  img = self.img_in(img)
482
  vec = self.time_in(timestep_embedding(timesteps, 256))
483
  if self.params.guidance_embed:
484
  if guidance is None:
485
+ raise ValueError("No guidance strength provided for guidance-distilled model.")
486
  vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
487
  vec = vec + self.vector_in(y)
488
  txt = self.txt_in(txt)
 
489
  ids = torch.cat((txt_ids, img_ids), dim=1)
490
  pe = self.pe_embedder(ids)
 
491
  for block in self.double_blocks:
492
  img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
 
493
  img = torch.cat((txt, img), 1)
494
  for block in self.single_blocks:
495
  img = block(img, vec=vec, pe=pe)
496
  img = img[:, txt.shape[1] :, ...]
497
+ img = self.final_layer(img, vec)
 
498
  return img
499
 
 
500
  def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
501
+ import torch
502
  bs, c, h, w = img.shape
503
  if bs == 1 and not isinstance(prompt, str):
504
  bs = len(prompt)
 
505
  img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
506
  if img.shape[0] == 1 and bs > 1:
507
  img = repeat(img, "1 ... -> bs ...", bs=bs)
 
508
  img_ids = torch.zeros(h // 2, w // 2, 3)
509
  img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
510
  img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
511
  img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
 
512
  if isinstance(prompt, str):
513
  prompt = [prompt]
514
  txt = t5(prompt)
515
  if txt.shape[0] == 1 and bs > 1:
516
  txt = repeat(txt, "1 ... -> bs ...", bs=bs)
517
  txt_ids = torch.zeros(bs, txt.shape[1], 3)
 
518
  vec = clip(prompt)
519
  if vec.shape[0] == 1 and bs > 1:
520
  vec = repeat(vec, "1 ... -> bs ...", bs=bs)
 
521
  return {
522
  "img": img,
523
  "img_ids": img_ids.to(img.device),
 
526
  "vec": vec.to(img.device),
527
  }
528
 
 
529
  def time_shift(mu: float, sigma: float, t: Tensor):
530
+ import math
531
  return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
532
 
 
533
  def get_lin_function(
534
  x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
535
  ) -> Callable[[float], float]:
536
+ import math
537
  m = (y2 - y1) / (x2 - x1)
538
  b = y1 - m * x1
539
  return lambda x: m * x + b
540
 
 
541
  def get_schedule(
542
  num_steps: int,
543
  image_seq_len: int,
 
545
  max_shift: float = 1.15,
546
  shift: bool = True,
547
  ) -> list[float]:
548
+ import torch
549
+ import math
550
  timesteps = torch.linspace(1, 0, num_steps + 1)
 
 
551
  if shift:
 
552
  mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
553
  timesteps = time_shift(mu, 1.0, timesteps)
 
554
  return timesteps.tolist()
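Worked example (illustrative, not in the commit): with 4 steps the schedule has 5 strictly decreasing entries from 1.0 down to 0.0; shift=True bends it toward high timesteps for longer image sequences.
ts = get_schedule(num_steps=4, image_seq_len=4096)
# len(ts) == 5, ts[0] == 1.0, ts[-1] == 0.0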
555
 
 
556
  def denoise(
557
  model: Flux,
 
558
  img: Tensor,
559
  img_ids: Tensor,
560
  txt: Tensor,
561
  txt_ids: Tensor,
562
  vec: Tensor,
 
563
  timesteps: list[float],
564
  guidance: float = 4.0,
565
  ):
566
+ import torch
567
  guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
568
  for t_curr, t_prev in tqdm(zip(timesteps[:-1], timesteps[1:]), total=len(timesteps) - 1):
569
  t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
 
579
  img = img + (t_prev - t_curr) * pred
580
  return img
581
 
 
582
  def unpack(x: Tensor, height: int, width: int) -> Tensor:
583
  return rearrange(
584
  x,
 
596
  height: int
597
  guidance: float
598
  seed: int | None
 
599
 
600
  def get_image(image) -> torch.Tensor | None:
601
  if image is None:
602
  return None
603
  image = Image.fromarray(image).convert("RGB")
 
604
  transform = transforms.Compose([
605
  transforms.ToTensor(),
606
  transforms.Lambda(lambda x: 2.0 * x - 1.0),
 
608
  img: torch.Tensor = transform(image)
609
  return img[None, ...]
610
 
611
+ # Load the NF4 quantized checkpoint
612
  from huggingface_hub import hf_hub_download
613
  from safetensors.torch import load_file
614
 
 
618
  result = model.load_state_dict(sd)
619
  model_zero_init = False
620
621
  @spaces.GPU
622
  @torch.no_grad()
623
  def generate_image(
 
625
  do_img2img, init_image, image2image_strength, resize_img,
626
  progress=gr.Progress(track_tqdm=True),
627
  ):
628
  if seed == 0:
629
+ seed = int(random.random() * 1_000_000)
630
+
631
  device = "cuda" if torch.cuda.is_available() else "cpu"
632
  torch_device = torch.device(device)
633
 
 
 
634
  global model, model_zero_init
635
  if not model_zero_init:
636
  model = model.to(torch_device)
637
  model_zero_init = True
638
+
639
  if do_img2img and init_image is not None:
640
  init_image = get_image(init_image)
641
  if resize_img:
 
646
  height = init_image.shape[-2]
647
  width = init_image.shape[-1]
648
  init_image = ae.encode(init_image.to(torch_device).to(torch.bfloat16)).latent_dist.sample()
649
+ init_image = (init_image - ae.config.shift_factor) * ae.config.scaling_factor
650
 
651
  generator = torch.Generator(device=device).manual_seed(seed)
652
+ x = torch.randn(
653
+ 1,
654
+ 16,
655
+ 2 * math.ceil(height / 16),
656
+ 2 * math.ceil(width / 16),
657
+ device=device,
658
+ dtype=torch.bfloat16,
659
+ generator=generator
660
+ )
661
+
662
+ timesteps = get_schedule(inference_steps, (x.shape[-1] * x.shape[-2]) // 4, shift=True)
663
 
664
  if do_img2img and init_image is not None:
665
+ t_idx = int((1 - image2image_strength) * inference_steps)
666
  t = timesteps[t_idx]
667
  timesteps = timesteps[t_idx:]
668
  x = t * x + (1.0 - t) * init_image.to(x.dtype)
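For example (illustrative): with image2image_strength=0.8 and 16 inference steps, t_idx = int((1 - 0.8) * 16) = 3, so denoising starts from timesteps[3] and the encoded init image is blended in with weight 1 - timesteps[3].
# e.g. image2image_strength=0.8, inference_steps=16  ->  t_idx = int(0.2 * 16) = 3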
669
 
670
  inp = prepare(t5=t5, clip=clip, img=x, prompt=prompt)
671
  x = denoise(model, **inp, timesteps=timesteps, guidance=guidance)
672
  x = unpack(x.float(), height, width)
673
+
674
  with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
675
+ x = (x / ae.config.scaling_factor) + ae.config.shift_factor
676
  x = ae.decode(x).sample
677
 
678
  x = x.clamp(-1, 1)
679
  x = rearrange(x[0], "c h w -> h w c")
680
  img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
681
+ return img, seed
682
 
683
  def create_demo():
684
+ with gr.Blocks(css=".gradio-container {background-color: #282828 !important;}") as demo:
685
+ gr.HTML(
686
+ """
687
+ <div style="text-align: center; margin: 0 auto;">
688
+ <h1 style="color: #ffffff; font-weight: 900;">
689
+ FluxLLama
690
+ </h1>
691
+ </div>
692
+ """
693
+ )
694
  with gr.Row():
695
  with gr.Column():
696
+ prompt = gr.Textbox(label="Prompt", value="A majestic castle on top of a floating island")
697
+ width = gr.Slider(minimum=128, maximum=2048, step=64, label="Width", value=640)
698
+ height = gr.Slider(minimum=128, maximum=2048, step=64, label="Height", value=640)
 
699
  guidance = gr.Slider(minimum=1.0, maximum=5.0, step=0.1, label="Guidance", value=3.5)
700
  inference_steps = gr.Slider(
701
  label="Inference steps",
702
  minimum=1,
703
  maximum=30,
704
  step=1,
705
+ value=16,
706
  )
707
  seed = gr.Number(label="Seed", precision=-1)
708
  do_img2img = gr.Checkbox(label="Image to Image", value=False)
709
+ init_image = gr.Image(label="Initial Image", visible=False)
710
+ image2image_strength = gr.Slider(
711
+ minimum=0.0,
712
+ maximum=1.0,
713
+ step=0.01,
714
+ label="Noising Strength",
715
+ value=0.8,
716
+ visible=False
717
+ )
718
+ resize_img = gr.Checkbox(label="Resize Initial Image", value=True, visible=False)
719
+ generate_button = gr.Button("Generate", variant="primary")
720
  with gr.Column():
721
+ output_image = gr.Image(label="Result")
722
+ output_seed = gr.Text(label="Seed Used")
723
 
724
  do_img2img.change(
725
  fn=lambda x: [gr.update(visible=x), gr.update(visible=x), gr.update(visible=x)],
 
729
 
730
  generate_button.click(
731
  fn=generate_image,
732
+ inputs=[
733
+ prompt, width, height, guidance,
734
+ inference_steps, seed, do_img2img,
735
+ init_image, image2image_strength, resize_img
736
+ ],
737
+ outputs=[output_image, output_seed]
738
  )
 
739
  return demo
740
 
741
  if __name__ == "__main__":
742
+ # Create the demo
743
  demo = create_demo()
744
+ # Enable the queue to handle concurrency
745
+ demo.queue()
746
+ # Launch with show_api=False and share=True to avoid the "bool is not iterable" error
747
+ # and the "ValueError: When localhost is not accessible..." error.
748
+ demo.launch(show_api=False, share=True, server_name="0.0.0.0", mcp_server=True)