import os
import time
import math
import random
from dataclasses import dataclass
from typing import Callable

import spaces
import gradio as gr
import torch
from torch import Tensor, nn
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

import bitsandbytes as bnb
from bitsandbytes.nn.modules import Params4bit, QuantState
from transformers import (
    MarianTokenizer,
    MarianMTModel,
    CLIPTextModel,
    CLIPTokenizer,
    T5EncoderModel,
    T5Tokenizer,
)
from diffusers import AutoencoderKL
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from einops import rearrange, repeat
|
|
|
|
|
torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Korean -> English translation model; deliberately kept on the CPU.
trans_tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ko-en")
trans_model = MarianMTModel.from_pretrained(
    "Helsinki-NLP/opus-mt-ko-en",
    from_tf=True,
    torch_dtype=torch.float32,
).to(torch.device("cpu"))
|
|
|
def translate_ko_to_en(text: str, max_length: int = 512) -> str:
    """Translate Korean text to English (runs on CPU)."""
    batch = trans_tokenizer([text], return_tensors="pt", padding=True)
    gen = trans_model.generate(**batch, max_length=max_length)
    return trans_tokenizer.batch_decode(gen, skip_special_tokens=True)[0]
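
# generate_image() below routes the prompt through translate_ko_to_en() whenever it
# contains Hangul characters, so the CLIP/T5 text encoders always receive English text.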
|
|
|
|
|
|
|
class HFEmbedder(nn.Module):
    def __init__(self, version: str, max_length: int, **hf_kwargs):
        super().__init__()
        self.is_clip = version.startswith("openai")
        self.max_length = max_length
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"

        if self.is_clip:
            self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
                version, max_length=max_length
            )
            self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(
                version, **hf_kwargs
            )
        else:
            self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(
                version, max_length=max_length
            )
            self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(
                version, **hf_kwargs
            )

        self.hf_module = self.hf_module.eval().requires_grad_(False)

    def forward(self, text: list[str]) -> Tensor:
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        outputs = self.hf_module(
            input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
            attention_mask=None,
            output_hidden_states=False,
        )
        return outputs[self.output_key]
|
|
|
|
|
t5 = HFEmbedder(
    "DeepFloyd/t5-v1_1-xxl",
    max_length=512,
    torch_dtype=torch.bfloat16,
).to(torch_device)
clip = HFEmbedder(
    "openai/clip-vit-large-patch14",
    max_length=77,
    torch_dtype=torch.bfloat16,
).to(torch_device)
ae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="vae",
    torch_dtype=torch.bfloat16,
).to(torch_device)
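
# Role of the three frozen models above: t5 provides per-token embeddings
# (last_hidden_state) for cross attention, clip provides a single pooled vector
# (pooler_output) for global conditioning, and ae is the FLUX VAE used to map
# between pixel space and the 16-channel latent space.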
|
|
|
|
|
|
|
def functional_linear_4bits(x, weight, bias):
    # Linear layer that keeps the weight in bitsandbytes 4-bit form; matmul_4bit
    # consumes the packed weight together with its quantization state.
    out = bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state)
    return out.to(x)
|
|
|
def copy_quant_state(state: QuantState, device: torch.device = None) -> QuantState:
    if state is None:
        return None
    device = device or state.absmax.device
    state2 = (
        QuantState(
            absmax=state.state2.absmax.to(device),
            shape=state.state2.shape,
            code=state.state2.code.to(device),
            blocksize=state.state2.blocksize,
            quant_type=state.state2.quant_type,
            dtype=state.state2.dtype,
        )
        if state.nested
        else None
    )
    return QuantState(
        absmax=state.absmax.to(device),
        shape=state.shape,
        code=state.code.to(device),
        blocksize=state.blocksize,
        quant_type=state.quant_type,
        dtype=state.dtype,
        offset=state.offset.to(device) if state.nested else None,
        state2=state2,
    )
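
# Note: a QuantState bundles the tensors bitsandbytes needs to dequantize a 4-bit
# weight (absmax, codebook, optional nested statistics); copy_quant_state() moves
# all of them to the target device together so the weight stays usable after .to().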
|
|
|
class ForgeParams4bit(Params4bit):
    def to(self, *args, **kwargs):
        device, dtype, non_blocking, _ = torch._C._nn._parse_to(*args, **kwargs)
        if device is not None and device.type == "cuda" and not self.bnb_quantized:
            return self._quantize(device)
        new = ForgeParams4bit(
            torch.nn.Parameter.to(
                self, device=device, dtype=dtype, non_blocking=non_blocking
            ),
            requires_grad=self.requires_grad,
            quant_state=copy_quant_state(self.quant_state, device),
            compress_statistics=False,
            blocksize=self.blocksize,
            quant_type=self.quant_type,
            quant_storage=self.quant_storage,
            bnb_quantized=self.bnb_quantized,
            module=self.module,
        )
        self.module.quant_state = new.quant_state
        self.data = new.data
        self.quant_state = new.quant_state
        return new
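
# ForgeParams4bit defers quantization: moving a still-unquantized parameter to CUDA
# triggers bitsandbytes' _quantize(), while already-quantized parameters are copied
# together with their QuantState so device moves never lose the quantization metadata.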
|
|
|
class ForgeLoader4Bit(torch.nn.Module):
    def __init__(self, *, device, dtype, quant_type, **kwargs):
        super().__init__()
        self.dummy = torch.nn.Parameter(torch.empty(1, device=device, dtype=dtype))
        self.weight = None
        self.quant_state = None
        self.bias = None
        self.quant_type = quant_type

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        qs_keys = {
            k[len(prefix + "weight.") :]
            for k in state_dict
            if k.startswith(prefix + "weight.")
        }
        if any("bitsandbytes" in k for k in qs_keys):
            qs = {k: state_dict[prefix + "weight." + k] for k in qs_keys}
            self.weight = ForgeParams4bit.from_prequantized(
                data=state_dict[prefix + "weight"],
                quantized_stats=qs,
                requires_grad=False,
                device=torch.device("cuda"),
                module=self,
            )
            self.quant_state = self.weight.quant_state
            if prefix + "bias" in state_dict:
                self.bias = torch.nn.Parameter(
                    state_dict[prefix + "bias"].to(self.dummy)
                )
            del self.dummy
        else:
            super()._load_from_state_dict(
                state_dict,
                prefix,
                local_metadata,
                strict,
                missing_keys,
                unexpected_keys,
                error_msgs,
            )
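
# ForgeLoader4Bit inspects the incoming state dict for extra "weight.*" keys that
# contain "bitsandbytes" (the serialized quantization statistics); if found, it
# rebuilds the prequantized NF4 weight directly, otherwise it falls back to the
# regular nn.Module loading path.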
|
|
|
class Linear(ForgeLoader4Bit):
    def __init__(self, *args, device=None, dtype=None, **kwargs):
        # The in/out feature sizes are ignored here: the NF4 weight is loaded
        # prequantized by ForgeLoader4Bit._load_from_state_dict.
        super().__init__(device=device, dtype=dtype, quant_type="nf4")

    def forward(self, x):
        self.weight.quant_state = self.quant_state
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)
        return functional_linear_4bits(x, self.weight, self.bias)


nn.Linear = Linear
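
# The monkey patch above must run before the Flux transformer is defined and
# instantiated, so that every nn.Linear inside it is created as an NF4-aware Linear
# and can load the prequantized checkpoint below.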
|
|
|
|
|
|
|
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    q, k = apply_rope(q, k, pe)
    x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    x = rearrange(x, "b h l d -> b l (h d)")
    return x
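
# attention() applies rotary position embeddings (apply_rope, defined alongside the
# Flux model) to q/k, runs PyTorch's fused scaled-dot-product attention, and merges
# the heads back into a single channel dimension.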
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sd = load_file(
    hf_hub_download(
        repo_id="lllyasviel/flux1-dev-bnb-nf4",
        filename="flux1-dev-bnb-nf4-v2.safetensors",
    )
)
# Keep only the transformer weights and strip the "model.diffusion_model." prefix.
sd = {
    k.replace("model.diffusion_model.", ""): v
    for k, v in sd.items()
    if "model.diffusion_model" in k
}

model = Flux().to(torch_device, dtype=torch.bfloat16)
model.load_state_dict(sd)
model_zero_init = False
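
# load_state_dict() triggers ForgeLoader4Bit._load_from_state_dict for every patched
# Linear, so the NF4 checkpoint is attached without materializing full-precision
# weights; model_zero_init tracks whether the model has been moved to the GPU yet.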
|
|
|
|
|
|
|
def get_image(image) -> torch.Tensor | None:
    if image is None:
        return None
    image = Image.fromarray(image).convert("RGB")
    tfm = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Lambda(lambda x: 2.0 * x - 1.0),  # scale to [-1, 1]
        ]
    )
    return tfm(image)[None, ...]
|
|
|
def prepare(t5, clip, img, prompt):
    bs, c, h, w = img.shape
    # Pack 2x2 latent patches into tokens: (b, 16, h, w) -> (b, h*w/4, 64).
    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    if bs == 1 and isinstance(prompt, list):
        img = repeat(img, "1 ... -> bs ...", bs=len(prompt))

    # Per-token (row, column) position ids for the image tokens.
    img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device)
    img_ids[..., 1] = torch.arange(h // 2, device=img.device)[:, None]
    img_ids[..., 2] = torch.arange(w // 2, device=img.device)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=img.shape[0])

    txt = t5([prompt] if isinstance(prompt, str) else prompt)
    if txt.shape[0] == 1 and img.shape[0] > 1:
        txt = repeat(txt, "1 ... -> bs ...", bs=img.shape[0])
    txt_ids = torch.zeros(txt.size(0), txt.size(1), 3, device=img.device)

    vec = clip([prompt] if isinstance(prompt, str) else prompt)
    if vec.shape[0] == 1 and img.shape[0] > 1:
        vec = repeat(vec, "1 ... -> bs ...", bs=img.shape[0])

    return {
        "img": img,
        "img_ids": img_ids,
        "txt": txt,
        "txt_ids": txt_ids,
        "vec": vec,
    }
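
# Illustrative sizes (assuming the default 768x768 output): the VAE latent is
# 16 x 96 x 96, so packing 2x2 patches yields 48 * 48 = 2304 image tokens of
# dimension 64.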
|
|
|
def get_schedule(num_steps, image_seq_len, base_shift=0.5, max_shift=1.15, shift=True):
    timesteps = torch.linspace(1, 0, num_steps + 1)
    if shift:
        # Resolution-dependent shift: interpolate mu between base_shift (256 tokens)
        # and max_shift (4096 tokens), then warp the linear schedule with it.
        mu = ((max_shift - base_shift) / (4096 - 256)) * image_seq_len + (
            base_shift - (256 * (max_shift - base_shift) / (4096 - 256))
        )
        timesteps = math.exp(mu) / (math.exp(mu) + (1 / timesteps - 1) ** 1.0)
    return timesteps.tolist()
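
# With the defaults above, a 768x768 image (2304 tokens) gives mu of roughly 0.85,
# so the shifted schedule spends more of its steps near t = 1 (high noise) than a
# plain linear schedule would.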
|
|
|
def denoise(model, img, img_ids, txt, txt_ids, vec, timesteps, guidance):
    guidance_vec = torch.full(
        (img.size(0),), guidance, device=img.device, dtype=img.dtype
    )
    for t_curr, t_prev in tqdm(
        zip(timesteps[:-1], timesteps[1:]), total=len(timesteps) - 1
    ):
        t_vec = torch.full(
            (img.size(0),), t_curr, device=img.device, dtype=img.dtype
        )
        pred = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
        )
        img = img + (t_prev - t_curr) * pred
    return img
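
# Each loop iteration above is one Euler step of the rectified-flow ODE: the model
# predicts a velocity and the latent moves by (t_prev - t_curr) * pred toward t = 0.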
|
|
|
|
|
|
|
@spaces.GPU
@torch.no_grad()
def generate_image(
    prompt,
    width,
    height,
    guidance,
    inference_steps,
    seed,
    do_img2img,
    init_image,
    image2image_strength,
    resize_img,
    progress=gr.Progress(track_tqdm=True),
):
    # Translate Korean prompts to English before text encoding.
    if any("\u3131" <= c <= "\u318E" or "\uAC00" <= c <= "\uD7A3" for c in prompt):
        prompt = translate_ko_to_en(prompt)

    if seed == 0:
        seed = random.randint(1, 1_000_000)

    global model_zero_init, model
    if not model_zero_init:
        model = model.to(torch_device)
        model_zero_init = True

    if do_img2img and init_image is not None:
        init_img = get_image(init_image)
        if resize_img:
            init_img = torch.nn.functional.interpolate(init_img, (height, width))
        else:
            h0, w0 = init_img.shape[-2:]
            init_img = init_img[..., : 16 * (h0 // 16), : 16 * (w0 // 16)]
            height, width = init_img.shape[-2:]
        # Encode the init image into the scaled/shifted VAE latent space.
        init_img = ae.encode(
            init_img.to(torch_device).to(torch.bfloat16)
        ).latent_dist.sample()
        init_img = (init_img - ae.config.shift_factor) * ae.config.scaling_factor
    else:
        init_img = None
|
|
|
    generator = torch.Generator(device=str(torch_device)).manual_seed(seed)
    # Initial noise: 16 latent channels at 1/8 of the requested pixel resolution.
    x = torch.randn(
        1,
        16,
        2 * math.ceil(height / 16),
        2 * math.ceil(width / 16),
        device=torch_device,
        dtype=torch.bfloat16,
        generator=generator,
    )
    timesteps = get_schedule(
        inference_steps, (x.shape[-1] * x.shape[-2]) // 4, shift=True
    )
    if do_img2img and init_img is not None:
        # Start partway through the schedule and blend noise with the init latent.
        t_idx = int((1 - image2image_strength) * inference_steps)
        t = timesteps[t_idx]
        timesteps = timesteps[t_idx:]
        x = t * x + (1 - t) * init_img.to(x.dtype)

    inp = prepare(t5, clip, x, prompt)
    x = denoise(model, **inp, timesteps=timesteps, guidance=guidance)

    # Unpack the token sequence back into a latent image and decode with the VAE.
    x = rearrange(
        x[:, inp["txt"].shape[1] :, ...].float(),
        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        h=math.ceil(height / 16),
        w=math.ceil(width / 16),
        ph=2,
        pw=2,
    )
    with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
        x = (x / ae.config.scaling_factor) + ae.config.shift_factor
        x = ae.decode(x).sample

    x = x.clamp(-1, 1)
    img = Image.fromarray(
        (127.5 * (rearrange(x[0], "c h w -> h w c") + 1.0)).cpu().byte().numpy()
    )

    return img, seed
|
|
|
css = """
footer {
    visibility: hidden;
}
"""
|
|
|
def create_demo():
    with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
        gr.Markdown(
            "# News! Multilingual version "
            "[https://huggingface.co/spaces/ginigen/FLUXllama-Multilingual]"
            "(https://huggingface.co/spaces/ginigen/FLUXllama-Multilingual)"
        )
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(
                    label="Prompt (Korean or English)",
                    value="A cute and fluffy golden retriever puppy sitting upright...",
                )
                width = gr.Slider(128, 2048, step=64, label="Width", value=768)
                height = gr.Slider(128, 2048, step=64, label="Height", value=768)
                guidance = gr.Slider(1.0, 5.0, step=0.1, label="Guidance", value=3.5)
                steps = gr.Slider(1, 30, step=1, label="Inference steps", value=30)
                seed = gr.Number(label="Seed", precision=0, value=0)  # 0 -> random seed
                do_i2i = gr.Checkbox(label="Image to Image", value=False)
                init_img = gr.Image(label="Input Image", visible=False)
                strength = gr.Slider(
                    0.0, 1.0, step=0.01, label="Noising strength", value=0.8, visible=False
                )
                resize = gr.Checkbox(label="Resize image", value=True, visible=False)
                btn = gr.Button("Generate")
            with gr.Column():
                out_img = gr.Image(label="Generated Image")
                out_seed = gr.Text(label="Used Seed")

        do_i2i.change(
            fn=lambda x: [gr.update(visible=x)] * 3,
            inputs=[do_i2i],
            outputs=[init_img, strength, resize],
        )
        btn.click(
            fn=generate_image,
            inputs=[
                prompt,
                width,
                height,
                guidance,
                steps,
                seed,
                do_i2i,
                init_img,
                strength,
                resize,
            ],
            outputs=[out_img, out_seed],
        )
    return demo
|
|
|
if __name__ == "__main__":
    create_demo().launch()
|
|