stzhao committed
Commit 5a639eb · verified · 1 Parent(s): 2795e8b

Update app.py

Files changed (1)
  1. app.py +146 -133
app.py CHANGED
@@ -1,154 +1,167 @@
 import gradio as gr
-import numpy as np
-import random
-
-# import spaces #[uncomment to use ZeroGPU]
-from diffusers import DiffusionPipeline
 import torch
 
-device = "cuda" if torch.cuda.is_available() else "cpu"
-model_repo_id = "stabilityai/sdxl-turbo"  # Replace to the model you would like to use
-
-if torch.cuda.is_available():
-    torch_dtype = torch.float16
-else:
-    torch_dtype = torch.float32
-
-pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
-pipe = pipe.to(device)
-
-MAX_SEED = np.iinfo(np.int32).max
-MAX_IMAGE_SIZE = 1024
-
-
-# @spaces.GPU #[uncomment to use ZeroGPU]
-def infer(
-    prompt,
-    negative_prompt,
-    seed,
-    randomize_seed,
-    width,
-    height,
-    guidance_scale,
-    num_inference_steps,
-    progress=gr.Progress(track_tqdm=True),
-):
-    if randomize_seed:
-        seed = random.randint(0, MAX_SEED)
 
-    generator = torch.Generator().manual_seed(seed)
 
     image = pipe(
-        prompt=prompt,
-        negative_prompt=negative_prompt,
         guidance_scale=guidance_scale,
         num_inference_steps=num_inference_steps,
-        width=width,
-        height=height,
         generator=generator,
     ).images[0]
-
-    return image, seed
-
-
-examples = [
-    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-    "An astronaut riding a green horse",
-    "A delicious ceviche cheesecake slice",
-]
-
-css = """
-#col-container {
-    margin: 0 auto;
-    max-width: 640px;
-}
-"""
-
-with gr.Blocks(css=css) as demo:
-    with gr.Column(elem_id="col-container"):
-        gr.Markdown(" # Text-to-Image Gradio Template")
-
-        with gr.Row():
-            prompt = gr.Text(
-                label="Prompt",
-                show_label=False,
-                max_lines=1,
-                placeholder="Enter your prompt",
-                container=False,
-            )
-
-            run_button = gr.Button("Run", scale=0, variant="primary")
-
-        result = gr.Image(label="Result", show_label=False)
-
-        with gr.Accordion("Advanced Settings", open=False):
-            negative_prompt = gr.Text(
-                label="Negative prompt",
-                max_lines=1,
-                placeholder="Enter a negative prompt",
-                visible=False,
             )
-
-            seed = gr.Slider(
-                label="Seed",
-                minimum=0,
-                maximum=MAX_SEED,
-                step=1,
-                value=0,
             )
-
-            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-            with gr.Row():
-                width = gr.Slider(
-                    label="Width",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=1024,  # Replace with defaults that work for your model
                 )
-
-                height = gr.Slider(
-                    label="Height",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=1024,  # Replace with defaults that work for your model
                 )
-
-            with gr.Row():
                 guidance_scale = gr.Slider(
-                    label="Guidance scale",
-                    minimum=0.0,
                     maximum=10.0,
                     step=0.1,
-                    value=0.0,  # Replace with defaults that work for your model
-                )
-
-                num_inference_steps = gr.Slider(
-                    label="Number of inference steps",
-                    minimum=1,
-                    maximum=50,
-                    step=1,
-                    value=2,  # Replace with defaults that work for your model
                 )
-
-        gr.Examples(examples=examples, inputs=[prompt])
-    gr.on(
-        triggers=[run_button.click, prompt.submit],
-        fn=infer,
-        inputs=[
-            prompt,
-            negative_prompt,
-            seed,
-            randomize_seed,
-            width,
-            height,
-            guidance_scale,
-            num_inference_steps,
-        ],
-        outputs=[result, seed],
     )
 
 if __name__ == "__main__":
-    demo.launch()
 
+import os
 import gradio as gr
 import torch
+from diffusers import Lumina2Pipeline
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+# Set up environment
+os.environ['CUDA_VISIBLE_DEVICES'] = "0"
+
+# Load models
+def load_models():
+    model_name = "/mnt/petrelfs/zhaoshitian/models/LeX-Enhancer-full"
+
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        torch_dtype="auto",
+        device_map="auto"
+    )
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+    pipe = Lumina2Pipeline.from_pretrained(
+        "/mnt/hwfile/alpha_vl/qilongwu/checkpoints/LeX-Lumina",
+        torch_dtype=torch.bfloat16
+    )
+    pipe.to("cuda")
+
+    return model, tokenizer, pipe
 
+model, tokenizer, pipe = load_models()
 
+def generate_enhanced_caption(image_caption, text_caption):
+    """Generate an enhanced caption using the LeX-Enhancer model"""
+    combined_caption = f"{image_caption}, with the text on it: {text_caption}."
+    instruction = """
+    Below is the simple caption of an image with text. Please deduce the detailed description of the image based on this simple caption. Note: 1. The description should only include visual elements and should not contain any extended meanings. 2. The visual elements should be as rich as possible, such as the main objects in the image, their respective attributes, the spatial relationships between the objects, lighting and shadows, color style, any text in the image and its style, etc. 3. The output description should be a single paragraph and should not be structured. 4. The description should avoid certain situations, such as pure white or black backgrounds, blurry text, excessive rendering of text, or harsh visual styles. 5. The detailed caption should be human readable and fluent. 6. Avoid using vague expressions such as "may be" or "might be"; the generated caption must be in a definitive, narrative tone. 7. Do not use negative sentence structures, such as "there is nothing in the image," etc. The entire caption should directly describe the content of the image. 8. The entire output should be limited to 200 words.
+    """
+    messages = [
+        {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
+        {"role": "user", "content": instruction + "\nSimple Caption:\n" + combined_caption}
+    ]
+    text = tokenizer.apply_chat_template(
+        messages,
+        tokenize=False,
+        add_generation_prompt=True
+    )
+    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
 
+    generated_ids = model.generate(
+        **model_inputs,
+        max_new_tokens=1024
+    )
+    # Keep only the newly generated tokens (drop the echoed prompt)
+    generated_ids = [
+        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+    ]
+
+    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+    # Keep only the text after a closing </think> tag, if the model emitted one
+    enhanced_caption = response.split("</think>", -1)[-1].strip(" ").strip("\n")
+
+    return combined_caption, enhanced_caption
+
+def generate_image(enhanced_caption, seed, num_inference_steps, guidance_scale):
+    """Generate image using Lumina2Pipeline"""
+    generator = torch.Generator("cpu").manual_seed(seed) if seed != 0 else None
+
     image = pipe(
+        enhanced_caption,
+        height=1024,
+        width=1024,
         guidance_scale=guidance_scale,
         num_inference_steps=num_inference_steps,
+        cfg_trunc_ratio=1,
+        cfg_normalization=True,
+        max_sequence_length=256,
         generator=generator,
+        system_prompt="You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts.",
     ).images[0]
+
+    return image
+
+def run_pipeline(image_caption, text_caption, seed, num_inference_steps, guidance_scale):
+    """Run the complete pipeline from captions to final image"""
+    combined_caption, enhanced_caption = generate_enhanced_caption(image_caption, text_caption)
+    image = generate_image(enhanced_caption, seed, num_inference_steps, guidance_scale)
+
+    # Return positionally, in the same order as the click handler's outputs
+    # list: [output_image, combined_caption_box, enhanced_caption_box]
+    return image, combined_caption, enhanced_caption
+
+# Gradio interface
+with gr.Blocks() as demo:
+    gr.Markdown("# LeX-Enhancer & Lumina2 Demo")
+    gr.Markdown("Generate enhanced captions from simple image and text descriptions, then create images with Lumina2")
+
+    with gr.Row():
+        with gr.Column():
+            image_caption = gr.Textbox(
+                lines=2,
+                label="Image Caption",
+                placeholder="Describe the visual content of the image",
+                value="A picture of a group of people gathered in front of a world map"
             )
+            text_caption = gr.Textbox(
+                lines=2,
+                label="Text Caption",
+                placeholder="Describe any text that should appear in the image",
+                value="\"Communicate\" in purple, \"Execute\" in yellow"
             )
+
+            with gr.Accordion("Advanced Settings", open=False):
+                seed = gr.Slider(
+                    minimum=0,
+                    maximum=100000,
+                    value=0,
+                    step=1,
+                    label="Seed (0 for random)"
                 )
+                num_inference_steps = gr.Slider(
+                    minimum=20,
+                    maximum=100,
+                    value=80,
+                    step=1,
+                    label="Number of Inference Steps"
                 )
                 guidance_scale = gr.Slider(
+                    minimum=1.0,
                     maximum=10.0,
+                    value=4.0,
                     step=0.1,
+                    label="Guidance Scale"
                 )
+
+            submit_btn = gr.Button("Generate", variant="primary")
+
+        with gr.Column():
+            output_image = gr.Image(label="Generated Image")
+            combined_caption_box = gr.Textbox(
+                label="Combined Caption",
+                interactive=False
+            )
+            enhanced_caption_box = gr.Textbox(
+                label="Enhanced Caption",
+                interactive=False,
+                lines=5
+            )
+
+    # Example prompts
+    examples = [
+        ["A modern office workspace", "\"Innovation\" in bold blue letters at the center"],
+        ["A beach sunset scene", "\"Relax\" in cursive white text in the corner"],
+        ["A futuristic city skyline", "\"The Future is Now\" in neon pink glowing letters"]
+    ]
+    gr.Examples(
+        examples=examples,
+        inputs=[image_caption, text_caption],
+        label="Example Inputs"
+    )
+
+    submit_btn.click(
+        fn=run_pipeline,
+        inputs=[image_caption, text_caption, seed, num_inference_steps, guidance_scale],
+        outputs=[output_image, combined_caption_box, enhanced_caption_box]
+    )
 
 if __name__ == "__main__":
+    demo.launch(server_name="0.0.0.0")
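
A quick way to sanity-check the new flow outside the UI (a minimal sketch, not part of the commit): import the two helpers from app.py and run them directly. This assumes the cluster-local checkpoint paths in load_models() resolve on your machine; elsewhere they would need to point at local copies of LeX-Enhancer and LeX-Lumina. Input values mirror the demo defaults, and the output filename is arbitrary.

from app import generate_enhanced_caption, generate_image  # loads both models at import time

# Stage 1: expand the simple captions into a detailed, text-aware prompt
combined, enhanced = generate_enhanced_caption(
    image_caption="A picture of a group of people gathered in front of a world map",
    text_caption='"Communicate" in purple, "Execute" in yellow',
)

# Stage 2: render the image; a nonzero seed gives a reproducible result,
# while seed=0 falls back to an unseeded generator
image = generate_image(enhanced, seed=42, num_inference_steps=80, guidance_scale=4.0)
image.save("lex_lumina_sample.png")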