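"""Gradio demo for Tar: Unifying Visual Understanding and Generation via
Text-Aligned Representations.

Downloads the Tar-7B checkpoint and TA-Tok tokenizer weights from the
Hugging Face Hub, then serves a two-tab UI: text-to-image generation and
image understanding (image + instruction -> text).
"""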
import os
import gradio as gr
from PIL import Image
from torchvision.transforms.functional import to_tensor
from huggingface_hub import hf_hub_download, snapshot_download, login

from tok.ar_dtok.ar_model import ARModel
from t2i_inference import T2IConfig, TextToImageInference
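# tok.* and t2i_inference ship with the Tar code base (github.com/csuhan/Tar);
# this script assumes they are importable from the working directory.
# ARModel is not referenced below; presumably it is imported so the class can
# be resolved when the AR de-tokenizer checkpoints are loaded.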

def generate_text(self, image: Image.Image, prompt: str) -> str:
    """Answer `prompt` about `image`.

    Defined at module level: `self` is the TextToImageInference instance,
    so the function is called as generate_text(inference, image, prompt).
    """
    image = image.convert('RGB')
    image = to_tensor(image).unsqueeze(0).to(self.device)

    # Encode the image into discrete codes with the visual tokenizer, then
    # render each code as a special <I...> token the language model was
    # trained on.
    image_code = self.visual_tokenizer.encoder(image.to(self.config.dtype))['bottleneck_rep']
    image_text = "".join([f"<I{x}>" for x in image_code[0].cpu().tolist()])

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": f"{image_text}\n{prompt}"}
    ]

    input_text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = self.tokenizer(input_text, return_tensors="pt")

    gen_ids = self.model.generate(
        inputs.input_ids.to(self.device),
        max_new_tokens=512,
        do_sample=True)
    # Keep only the newly generated tokens; drop the echoed prompt.
    return self.tokenizer.batch_decode(gen_ids[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)[0]

# Authenticate with the Hugging Face Hub when a token is available (required
# for gated or private checkpoints); skip login if HF_TOKEN is unset.
if os.getenv('HF_TOKEN'):
    login(token=os.getenv('HF_TOKEN'))

# Fetch checkpoints and assemble the inference config: the Tar-7B language
# model, the TA-Tok AR de-tokenizers (one per output resolution), the TA-Tok
# encoder, and the LlamaGen VQ model used as the pixel decoder.
config = T2IConfig()
config.model = snapshot_download("csuhan/Tar-7B-v0.1")
config.ar_path = {
    "1024px": hf_hub_download("csuhan/TA-Tok", "ar_dtok_lp_1024px.pth"),
    "512px": hf_hub_download("csuhan/TA-Tok", "ar_dtok_lp_512px.pth"),
}
config.encoder_path = hf_hub_download("csuhan/TA-Tok", "ta_tok.pth")
config.decoder_path = hf_hub_download("peizesun/llamagen_t2i", "vq_ds16_t2i.pt")
inference = TextToImageInference(config)
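
# Headless usage sketch (assumption: generate_image returns a PIL image and
# takes the same positional arguments the UI passes below):
#   img = inference.generate_image("a cat", "512px", 0.95, 1200, 4.0)
#   img.save("cat.png")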

def generate_image(prompt, resolution, top_p, top_k, cfg_scale):
    image = inference.generate_image(prompt, resolution, top_p, top_k, cfg_scale)
    return image

def clear_inputs_t2i():
    return "", None

def understand_image(image, prompt):
    return generate_text(inference, image, prompt)

def clear_inputs_i2t():
    return None, "", ""
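
# Two-tab Gradio UI: "Image Generation" and "Image Understanding".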

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        <div align="center">

        ### Tar: Unifying Visual Understanding and Generation via Text-Aligned Representations  

        [πŸ•ΈοΈ Project Page](http://tar.csuhan.com) β€’ [πŸ“„ Paper](http://arxiv.org/abs/2506.18898) β€’ [πŸ’» Code](https://github.com/csuhan/Tar) β€’ [πŸ“¦ Model](https://huggingface.co/collections/csuhan/tar-68538273b5537d0bee712648)

        </div>
        """,
        elem_id="title",
    )
    with gr.Tab("Image Generation"):
        with gr.Row():
            with gr.Column(scale=1):
                prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt")
                with gr.Accordion("Advanced Settings", open=False):
                    resolution = gr.Radio(
                        ["512px", "1024px"], value="1024px", label="Resolution"
                    )
                    top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")
                    top_k = gr.Slider(1, 2000, value=1200, step=10, label="Top-k")
                    cfg_scale = gr.Slider(1.0, 20.0, value=4.0, step=0.5, label="CFG Scale")
                with gr.Row():
                    generate_btn = gr.Button("Generate")
                    clear_btn = gr.Button("Clear")
            with gr.Column(scale=2):
                output_image = gr.Image(label="Generated Image")
        with gr.Row():
            gr.Examples(
                examples=[
                    ["a vibrant arrangement of blue and yellow flowers, with delicate petals and lush green stems, placed in a clear glass vase. The vase is situated on a polished wooden table, which reflects the soft light illuminating the room. Around the vase, there are a few scattered leaves, adding a touch of natural charm to the setting.", "1024px", 0.95, 1200, 4.0],
                    ["a cat", "512px", 0.95, 1200, 4.0],
                ],
                inputs=[prompt, resolution, top_p, top_k, cfg_scale],
                label="Example"
            )

        generate_btn.click(
            generate_image,
            inputs=[prompt, resolution, top_p, top_k, cfg_scale],
            outputs=output_image
        )
        clear_btn.click(
            clear_inputs_t2i,
            outputs=[prompt, output_image]
        )

    with gr.Tab("Image Understanding"):
        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(label="Upload Image", type="pil")
                question_input = gr.Textbox(label="Instruction", value="Briefly describe the image.")
                with gr.Row():
                    qa_btn = gr.Button("Generate")
                    clear_btn_i2t = gr.Button("Clear")
            with gr.Column(scale=1):
                answer_output = gr.Textbox(label="Response", lines=4)

        qa_btn.click(
            understand_image,
            inputs=[image_input, question_input],
            outputs=answer_output
        )

        clear_btn_i2t.click(
            clear_inputs_i2t,
            outputs=[image_input, question_input, answer_output]
        )

# share=True additionally requests a temporary public gradio.live link so the
# demo is reachable beyond the local machine.
demo.launch(share=True)