Spaces:
Runtime error
Runtime error
File size: 5,121 Bytes
8b54513 3c55139 22ff2b2 5bbfa70 3f9caff 3c55139 8b54513 3c55139 288480f bb89818 3c55139 8b54513 3c55139 8b54513 3c55139 22ff2b2 3f9caff 3c55139 3f9caff 3c55139 8b54513 3c55139 8b54513 3c55139 8b54513 3c55139 a2552b3 8b54513 3c55139 8b54513 3c55139 8b54513 5bb0eb0 5813aac 3c55139 8b54513 134a323 a66a8c1 134a323 8b54513 134a323 8b54513 d25ae12 59bb14b 134a323 59bb14b 134a323 3c55139 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
import os
import gradio as gr
from torchvision.transforms.functional import to_tensor
from huggingface_hub import hf_hub_download, snapshot_download, login
from tok.ar_dtok.ar_model import ARModel
from t2i_inference import T2IConfig, TextToImageInference
def generate_text(self, image, prompt: str) -> str:
    """Answer *prompt* about *image* by feeding visual tokens plus text to the LM.

    Args:
        image: a PIL Image (the previous ``image: str`` annotation was wrong —
            the very first line calls ``image.convert('RGB')``, a PIL method).
        prompt: the user's instruction/question about the image.

    Returns:
        The decoded model response with special tokens stripped.
    """
    image = image.convert('RGB')
    image = to_tensor(image).unsqueeze(0).to(self.device)
    # Quantize the image into discrete codes with the visual tokenizer encoder.
    image_code = self.visual_tokenizer.encoder(image.to(self.config.dtype))['bottleneck_rep']
    # Render each code index as an <I{n}> special token the LM was trained on.
    image_text = "".join([f"<I{x}>" for x in image_code[0].cpu().tolist()])
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": f"{image_text}\n{prompt}"}
    ]
    input_text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = self.tokenizer(input_text, return_tensors="pt")
    gen_ids = self.model.generate(
        inputs.input_ids.to(self.device),
        max_new_tokens=512,
        do_sample=True)
    # Slice off the prompt portion so only newly generated tokens are decoded.
    return self.tokenizer.batch_decode(gen_ids[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)[0]
# One-time setup at import: authenticate with the Hugging Face Hub
# (HF_TOKEN must be set in the environment) and download all checkpoints.
login(token=os.getenv('HF_TOKEN'))
config = T2IConfig()
# Full model repo for the unified generation/understanding LM.
config.model = snapshot_download("csuhan/Tar-7B-v0.1")
# AR de-tokenizer weights, one checkpoint per supported output resolution;
# keys match the "Resolution" radio choices in the UI below.
config.ar_path = {
    "1024px": hf_hub_download("csuhan/TA-Tok", "ar_dtok_lp_1024px.pth"),
    "512px": hf_hub_download("csuhan/TA-Tok", "ar_dtok_lp_512px.pth"),
}
# TA-Tok visual tokenizer encoder and LlamaGen VQ decoder weights.
config.encoder_path = hf_hub_download("csuhan/TA-Tok", "ta_tok.pth")
config.decoder_path = hf_hub_download("peizesun/llamagen_t2i", "vq_ds16_t2i.pt")
# Shared inference engine used by both Gradio tabs.
inference = TextToImageInference(config)
def generate_image(prompt, resolution, top_p, top_k, cfg_scale):
    """Text-to-image callback for the Gradio tab: delegate to the shared engine."""
    return inference.generate_image(prompt, resolution, top_p, top_k, cfg_scale)
def clear_inputs_t2i():
    """Reset the generation tab: empty the prompt box and drop the output image."""
    return ("", None)
def understand_image(image, prompt):
    """Understanding callback: route the uploaded image and instruction to generate_text."""
    answer = generate_text(inference, image, prompt)
    return answer
def clear_inputs_i2t():
    """Reset the understanding tab: clear the image, the instruction, and the response."""
    return (None, "", "")
# Build the two-tab Gradio UI (generation + understanding) and launch it.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Page header with project / paper / code / model links.
    gr.Markdown(
        """
        <div align="center">
        
        ### Tar: Unifying Visual Understanding and Generation via Text-Aligned Representations
        
        [🕸️ Project Page](http://tar.csuhan.com) • [📄 Paper](http://arxiv.org/abs/2506.18898) • [💻 Code](https://github.com/csuhan/Tar) • [📦 Model](https://huggingface.co/collections/csuhan/tar-68538273b5537d0bee712648)
        
        </div>
        """,
        elem_id="title",
    )
    # --- Tab 1: text-to-image generation ---
    with gr.Tab("Image Generation"):
        with gr.Row():
            # Left column: prompt plus sampling controls.
            with gr.Column(scale=1):
                prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt")
                with gr.Accordion("Advanced Settings", open=False):
                    # Resolution values match the keys of config.ar_path.
                    resolution = gr.Radio(
                        ["512px", "1024px"], value="1024px", label="Resolution"
                    )
                    top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")
                    top_k = gr.Slider(1, 2000, value=1200, step=10, label="Top-k")
                    cfg_scale = gr.Slider(1.0, 20.0, value=4.0, step=0.5, label="CFG Scale")
                with gr.Row():
                    generate_btn = gr.Button("Generate")
                    clear_btn = gr.Button("Clear")
            # Right column: generated output.
            with gr.Column(scale=2):
                output_image = gr.Image(label="Generated Image")

        # Clickable example prompts pre-filling all generation inputs.
        with gr.Row():
            gr.Examples(
                examples=[
                    ["a vibrant arrangement of blue and yellow flowers, with delicate petals and lush green stems, placed in a clear glass vase. The vase is situated on a polished wooden table, which reflects the soft light illuminating the room. Around the vase, there are a few scattered leaves, adding a touch of natural charm to the setting.", "1024px", 0.95, 1200, 4.0],
                    ["a cat", "512px", 0.95, 1200, 4.0],
                ],
                inputs=[prompt, resolution, top_p, top_k, cfg_scale],
                label="Example"
            )

        # Wire the generation tab's buttons to their callbacks.
        generate_btn.click(
            generate_image,
            inputs=[prompt, resolution, top_p, top_k, cfg_scale],
            outputs=output_image
        )
        clear_btn.click(
            clear_inputs_t2i,
            outputs=[prompt, output_image]
        )

    # --- Tab 2: image understanding (image + instruction -> text) ---
    with gr.Tab("Image Understanding"):
        with gr.Row():
            # Left column: image upload and instruction; type="pil" so the
            # callback receives a PIL Image as generate_text expects.
            with gr.Column(scale=1):
                image_input = gr.Image(label="Upload Image", type="pil")
                question_input = gr.Textbox(label="Instruction", value="Describe the image shortly.")
                with gr.Row():
                    qa_btn = gr.Button("Generate")
                    clear_btn_i2t = gr.Button("Clear")
            # Right column: model response.
            with gr.Column(scale=1):
                answer_output = gr.Textbox(label="Response", lines=4)

        # Wire the understanding tab's buttons to their callbacks.
        qa_btn.click(
            understand_image,
            inputs=[image_input, question_input],
            outputs=answer_output
        )

        clear_btn_i2t.click(
            clear_inputs_i2t,
            outputs=[image_input, question_input, answer_output]
        )

# share=True exposes a public gradio.live URL in addition to localhost.
demo.launch(share=True)
|