import sys, os

import gradio as gr

# if kgen is not installed, install it from the private GitHub repo
try:
    import kgen
except ImportError:
    GH_TOKEN = os.getenv("GITHUB_TOKEN")
    git_url = f"https://{GH_TOKEN}@github.com/KohakuBlueleaf/TIPO-KGen@tipo"
    # install the package directly with pip
    os.system(f"pip install git+{git_url}")

import re
import random
from time import time

import torch
from transformers import set_seed
if sys.platform == "win32":
    # dev env on Windows, where @spaces.GPU causes problems,
    # so replace it with a no-op decorator factory
    def GPU(**kwargs):
        return lambda x: x

else:
    from spaces import GPU
import kgen.models as models
import kgen.executor.tipo as tipo
from kgen.formatter import seperate_tags, apply_format
from kgen.generate import generate

from diff import load_model, encode_prompts
from meta import DEFAULT_NEGATIVE_PROMPT, DEFAULT_FORMAT
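
# Load the SDXL pipeline once at startup and park every submodule on CPU;
# modules are moved to CUDA only while generate_image() actually needs them,
# keeping GPU memory free for the TIPO model between image requests.
# The 4-token generate() call below is a short warm-up so the first real
# request does not pay the model initialization cost.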
sdxl_pipe = load_model()
sdxl_pipe.text_encoder.to("cpu")
sdxl_pipe.text_encoder_2.to("cpu")
sdxl_pipe.vae.to("cpu")
sdxl_pipe.k_diffusion_model.to("cpu")

models.load_model("Amber-River/tipo", device="cuda", subfolder="500M-epoch3")
generate(max_new_tokens=4)
torch.cuda.empty_cache()
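
# Default example inputs. The tag prompt follows the usual Danbooru ordering:
# character/copyright tags, artist tags, composition tags, then quality/meta
# tags; the NL prompt is a plain one-line caption.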
DEFAULT_TAGS = """ | |
1girl, king halo (umamusume), umamusume, | |
ningen mame, ciloranko, ogipote, misu kasumi, | |
solo, leaning forward, sky, | |
masterpiece, absurdres, sensitive, newest | |
""".strip() | |
DEFAULT_NL = """ | |
An illustration of a girl | |
""".strip() | |
def format_time(timing):
    total = timing["total"]
    generate_pass = timing["generate_pass"]

    result = ""
    result += f"""
### Process Time
| Total | {total:5.2f} sec / {generate_pass:5} Passes | {generate_pass/total:7.2f} Passes Per Second|
|-|-|-|
"""
    if "generated_tokens" in timing:
        total_generated_tokens = timing["generated_tokens"]
        total_input_tokens = timing["input_tokens"]
    if "generated_tokens" in timing and "total_sampling" in timing:
        sampling_time = timing["total_sampling"] / 1000
        process_time = timing["prompt_process"] / 1000
        model_time = timing["total_eval"] / 1000
        result += f"""| Process | {process_time:5.2f} sec / {total_input_tokens:5} Tokens | {total_input_tokens/process_time:7.2f} Tokens Per Second|
| Sampling | {sampling_time:5.2f} sec / {total_generated_tokens:5} Tokens | {total_generated_tokens/sampling_time:7.2f} Tokens Per Second|
| Eval | {model_time:5.2f} sec / {total_generated_tokens:5} Tokens | {total_generated_tokens/model_time:7.2f} Tokens Per Second|
"""
    if "generated_tokens" in timing:
        result += f"""
### Processed Tokens:
* {total_input_tokens:} Input Tokens
* {total_generated_tokens:} Output Tokens
"""
    return result
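
# Streaming TIPO runner. Note that this shadows the kgen.generate.generate
# imported above; that import is only needed for the startup warm-up call.
# Yields (refined prompt, formatted input prompt, timing markdown) once per
# generation pass so the UI can update progressively.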
def generate(
    tags,
    nl_prompt,
    black_list,
    temp,
    output_format,
    target_length,
    top_p,
    min_p,
    top_k,
    seed,
    escape_brackets,
):
    torch.cuda.empty_cache()
    default_format = DEFAULT_FORMAT[output_format]
    tipo.BAN_TAGS = [t.strip() for t in black_list.split(",") if t.strip()]
    generation_setting = {
        "seed": seed,
        "temperature": temp,
        "top_p": top_p,
        "min_p": min_p,
        "top_k": top_k,
    }
    inputs = seperate_tags(tags.split(","))
    if nl_prompt:
        if "<|extended|>" in default_format:
            inputs["extended"] = nl_prompt
        elif "<|generated|>" in default_format:
            inputs["generated"] = nl_prompt
    input_prompt = apply_format(inputs, default_format)
    if escape_brackets:
        input_prompt = re.sub(r"([()\[\]])", r"\\\1", input_prompt)

    meta, operations, general, nl_prompt = tipo.parse_tipo_request(
        seperate_tags(tags.split(",")),
        nl_prompt,
        tag_length_target=target_length,
        generate_extra_nl_prompt="<|generated|>" in default_format or not nl_prompt,
    )
    t0 = time()
    for result, timing in tipo.tipo_runner_generator(
        meta, operations, general, nl_prompt, **generation_setting
    ):
        result = apply_format(result, default_format)
        if escape_brackets:
            result = re.sub(r"([()\[\]])", r"\\\1", result)
        timing["total"] = time() - t0
        yield result, input_prompt, format_time(timing)
    torch.cuda.empty_cache()
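
# Two-pass image generation for the side-by-side comparison: first render
# prompt2 (the original input prompt) and yield it immediately, then render
# prompt (the TIPO-refined result) with the same seed. The text encoders, VAE,
# and diffusion model are shuttled between CPU and CUDA around each pass to
# bound peak GPU memory.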
def generate_image(
    seed,
    prompt,
    prompt2,
):
    torch.cuda.empty_cache()

    # first pass: render prompt2 (the unrefined input prompt)
    set_seed(seed)
    sdxl_pipe.text_encoder.to("cuda")
    sdxl_pipe.text_encoder_2.to("cuda")
    prompt_embeds, negative_prompt_embeds, pooled_embeds2, neg_pooled_embeds2 = (
        encode_prompts(sdxl_pipe, prompt2, DEFAULT_NEGATIVE_PROMPT)
    )
    sdxl_pipe.vae.to("cuda")
    sdxl_pipe.k_diffusion_model.to("cuda")
    result2 = sdxl_pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_embeds2,
        negative_pooled_prompt_embeds=neg_pooled_embeds2,
        num_inference_steps=24,
        width=1024,
        height=1024,
        guidance_scale=6.0,
    ).images[0]
    sdxl_pipe.text_encoder.to("cpu")
    sdxl_pipe.text_encoder_2.to("cpu")
    sdxl_pipe.vae.to("cpu")
    sdxl_pipe.k_diffusion_model.to("cpu")
    torch.cuda.empty_cache()
    yield result2, None

    # second pass: render prompt (the TIPO-refined result) with the same seed
    set_seed(seed)
    sdxl_pipe.text_encoder.to("cuda")
    sdxl_pipe.text_encoder_2.to("cuda")
    prompt_embeds, negative_prompt_embeds, pooled_embeds2, neg_pooled_embeds2 = (
        encode_prompts(sdxl_pipe, prompt, DEFAULT_NEGATIVE_PROMPT)
    )
    sdxl_pipe.vae.to("cuda")
    sdxl_pipe.k_diffusion_model.to("cuda")
    result = sdxl_pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_embeds2,
        negative_pooled_prompt_embeds=neg_pooled_embeds2,
        num_inference_steps=24,
        width=1024,
        height=1024,
        guidance_scale=6.0,
    ).images[0]
    sdxl_pipe.text_encoder.to("cpu")
    sdxl_pipe.text_encoder_2.to("cpu")
    sdxl_pipe.vae.to("cpu")
    sdxl_pipe.k_diffusion_model.to("cpu")
    torch.cuda.empty_cache()
    yield result2, result
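
# Gradio UI: the left column holds the prompt inputs and sampling settings,
# the right column shows the refined prompt and the two comparison images.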
if __name__ == "__main__": | |
with gr.Blocks(theme=gr.themes.Soft()) as demo: | |
with gr.Accordion("Introduction and Instructions", open=False): | |
gr.Markdown( | |
""" | |
## TIPO Demo | |
### What is this | |
TIPO is a tool to extend, generate, refine the input prompt for T2I models. | |
<br>It can work on both Danbooru tags and Natural Language. Which means you can use it on almost all the existed T2I models. | |
<br>You can take it as "pro max" version of [DTG](https://huggingface.co/KBlueLeaf/DanTagGen-delta-rev2) | |
### How to use this demo | |
1. Enter your tags(optional): put the desired tags into "danboru tags" box | |
2. Enter your NL Prompt(optional): put the desired natural language prompt into "Natural Language Prompt" box | |
3. Enter your black list(optional): put the desired black list into "black list" box | |
4. Adjust the settings: length, temp, top_p, min_p, top_k, seed ... | |
4. Click "TIPO" button: you will see refined prompt on "result" box | |
5. If you like the result, click "Generate Image From Result" button | |
* You will see 2 generated images, left one is based on your prompt, right one is based on refined prompt | |
* The backend is diffusers, there are no weighting mechanism, so Escape Brackets is default to False | |
### Why inference code is private? When will it be open sourced? | |
1. This model/tool is still under development, currently is early Alpha version. | |
2. I'm doing some research and projects based on this. | |
3. The model is released under CC-BY-NC-ND License currently. If you have interest, you can implement inference by yourself. | |
4. Once the project/research are done, I will open source all these models/codes with Apache2 license. | |
### Notification | |
**TIPO is NOT a T2I model. It is Prompt Gen, or, Text-to-Text model. | |
<br>The generated image is come from [Kohaku-XL-Zeta](https://huggingface.co/KBlueLeaf/Kohaku-XL-Zeta) model** | |
""" | |
) | |
        with gr.Row():
            with gr.Column(scale=5):
                with gr.Row():
                    with gr.Column(scale=3):
                        tags_input = gr.TextArea(
                            label="Danbooru Tags",
                            lines=7,
                            show_copy_button=True,
                            interactive=True,
                            value=DEFAULT_TAGS,
                            placeholder="Enter danbooru tags here",
                        )
                        nl_prompt_input = gr.Textbox(
                            label="Natural Language Prompt",
                            lines=7,
                            show_copy_button=True,
                            interactive=True,
                            value=DEFAULT_NL,
                            placeholder="Enter Natural Language Prompt here",
                        )
                        black_list = gr.TextArea(
                            label="Black List (separated by comma)",
                            lines=4,
                            interactive=True,
                            value="monochrome",
                            placeholder="Enter tag/nl black list here",
                        )
                    with gr.Column(scale=2):
                        output_format = gr.Dropdown(
                            label="Output Format",
                            choices=list(DEFAULT_FORMAT.keys()),
                            value="Both, tag first (recommend)",
                        )
                        target_length = gr.Dropdown(
                            label="Target Length",
                            choices=["very_short", "short", "long", "very_long"],
                            value="long",
                        )
                        temp = gr.Slider(
                            label="Temp",
                            minimum=0.0,
                            maximum=1.5,
                            value=0.5,
                            step=0.05,
                        )
                        top_p = gr.Slider(
                            label="Top P",
                            minimum=0.0,
                            maximum=1.0,
                            value=0.95,
                            step=0.05,
                        )
                        min_p = gr.Slider(
                            label="Min P",
                            minimum=0.0,
                            maximum=0.2,
                            value=0.05,
                            step=0.01,
                        )
                        top_k = gr.Slider(
                            label="Top K", minimum=0, maximum=120, value=60, step=1
                        )
                with gr.Row():
                    seed = gr.Number(
                        label="Seed",
                        minimum=0,
                        maximum=2147483647,
                        value=20090220,
                        step=1,
                    )
                    escape_brackets = gr.Checkbox(
                        label="Escape Brackets", value=False
                    )
                submit = gr.Button("TIPO!", variant="primary")
                with gr.Accordion("Speed statistics", open=False):
                    cost_time = gr.Markdown()
            with gr.Column(scale=5):
                result = gr.TextArea(
                    label="Result", lines=8, show_copy_button=True, interactive=False
                )
                input_prompt = gr.Textbox(
                    label="Input Prompt", lines=1, interactive=False, visible=False
                )
                gen_img = gr.Button(
                    "Generate Image from Result", variant="primary", interactive=False
                )
                with gr.Row():
                    with gr.Column():
                        img1 = gr.Image(label="Original Prompt", interactive=False)
                    with gr.Column():
                        img2 = gr.Image(label="Generated Prompt", interactive=False)
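
        # Stream generate() output to the UI while keeping the
        # "Generate Image from Result" button disabled until the final pass,
        # so an image cannot be rendered from a half-finished prompt.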
        def generate_wrapper(*args):
            yield "", "", "", gr.update(interactive=False)
            for i in generate(*args):
                yield *i, gr.update(interactive=False)
            yield *i, gr.update(interactive=True)

        submit.click(
            generate_wrapper,
            [
                tags_input,
                nl_prompt_input,
                black_list,
                temp,
                output_format,
                target_length,
                top_p,
                min_p,
                top_k,
                seed,
                escape_brackets,
            ],
            [
                result,
                input_prompt,
                cost_time,
                gen_img,
            ],
            queue=True,
        )

        def generate_image_wrapper(seed, result, input_prompt):
            for img1, img2 in generate_image(seed, result, input_prompt):
                yield img1, img2, gr.update(interactive=False)
            yield img1, img2, gr.update(interactive=True)

        gen_img.click(
            generate_image_wrapper,
            [seed, result, input_prompt],
            [img1, img2, submit],
            queue=True,
        )
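
        # A second, unqueued click handler disables the TIPO button instantly;
        # the queued handler above re-enables it once image generation finishes.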
        gen_img.click(
            lambda *args: gr.update(interactive=False),
            None,
            [submit],
            queue=False,
        )

    demo.launch()