openfree committed
Commit 5297ee2 · verified · 1 Parent(s): cdb6b46

Update app.py

Files changed (1)
  1. app.py +313 -11
app.py CHANGED
@@ -4,9 +4,20 @@ import random
import torch
from PIL import Image
import os
+ import sys
+ import importlib.util

import spaces

+ # Important: apply a patch - add a cached_download function to huggingface_hub
+ import huggingface_hub
+ if not hasattr(huggingface_hub, "cached_download"):
+     # Alias the existing hf_hub_download function as cached_download
+     huggingface_hub.cached_download = huggingface_hub.hf_hub_download
+
+ # Then proceed with the remaining imports
+ from huggingface_hub import snapshot_download, hf_hub_download, model_info
+
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter import StableDiffusionXLPipeline
from kolors.models.modeling_chatglm import ChatGLMModel
@@ -14,16 +25,307 @@ from kolors.models.tokenization_chatglm import ChatGLMTokenizer
from kolors.models.unet_2d_condition import UNet2DConditionModel
from diffusers import AutoencoderKL, EulerDiscreteScheduler

- from huggingface_hub import snapshot_download

- import ast  # added import; requirements: add albumentations
- script_repr = os.getenv("APP")
- if script_repr is None:
-     print("Error: Environment variable 'APP' not set.")
-     sys.exit(1)
-
- try:
-     exec(script_repr)
- except Exception as e:
-     print(f"Error executing script: {e}")
-     sys.exit(1)
+ device = "cuda"
+ root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+ ckpt_dir = f'{root_dir}/weights/Kolors'
+
+ snapshot_download(repo_id="Kwai-Kolors/Kolors", local_dir=ckpt_dir)
+ snapshot_download(repo_id="Kwai-Kolors/Kolors-IP-Adapter-Plus", local_dir=f"{root_dir}/weights/Kolors-IP-Adapter-Plus")
+
+ # Load models
+ text_encoder = ChatGLMModel.from_pretrained(f'{ckpt_dir}/text_encoder', torch_dtype=torch.float16).half().to(device)
+ tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')
+ vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half().to(device)
+ scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
+ unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half().to(device)
+
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+     f'{root_dir}/weights/Kolors-IP-Adapter-Plus/image_encoder',
+     ignore_mismatched_sizes=True
+ ).to(dtype=torch.float16, device=device)
+
+ ip_img_size = 336
+ clip_image_processor = CLIPImageProcessor(size=ip_img_size, crop_size=ip_img_size)
+
+ pipe = StableDiffusionXLPipeline(
+     vae=vae,
+     text_encoder=text_encoder,
+     tokenizer=tokenizer,
+     unet=unet,
+     scheduler=scheduler,
+     image_encoder=image_encoder,
+     feature_extractor=clip_image_processor,
+     force_zeros_for_empty_prompt=False
+ ).to(device)
+
+ # Compatibility alias: expose encoder_hid_proj under the text_encoder_hid_proj name
+ if hasattr(pipe.unet, 'encoder_hid_proj'):
+     pipe.unet.text_encoder_hid_proj = pipe.unet.encoder_hid_proj
+
+ pipe.load_ip_adapter(f'{root_dir}/weights/Kolors-IP-Adapter-Plus', subfolder="", weight_name=["ip_adapter_plus_general.bin"])
+
+ MAX_SEED = np.iinfo(np.int32).max
+ MAX_IMAGE_SIZE = 1024
+
+ # ----------------------------------------------
+ # infer function (existing logic kept unchanged)
+ # ----------------------------------------------
+ @spaces.GPU(duration=80)
+ def infer(
+     user_prompt,
+     ip_adapter_image,
+     ip_adapter_scale=0.5,
+     negative_prompt="",
+     seed=100,
+     randomize_seed=False,
+     width=1024,
+     height=1024,
+     guidance_scale=5.0,
+     num_inference_steps=50,
+     progress=gr.Progress(track_tqdm=True)
+ ):
+     # Hidden (base/required) prompt
+     hidden_prompt = (
+         "Ghibli Studio style, Charming hand-drawn anime-style illustration"
+     )
+
+     # Final prompt actually passed to the pipeline
+     prompt = f"{hidden_prompt}, {user_prompt}"
+
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+
+     generator = torch.Generator(device="cuda").manual_seed(seed)
+     pipe.to("cuda")
+     image_encoder.to("cuda")
+     pipe.image_encoder = image_encoder
+     pipe.set_ip_adapter_scale([ip_adapter_scale])
+
+     image = pipe(
+         prompt=prompt,
+         ip_adapter_image=[ip_adapter_image],
+         negative_prompt=negative_prompt,
+         height=height,
+         width=width,
+         num_inference_steps=num_inference_steps,
+         guidance_scale=guidance_scale,
+         num_images_per_prompt=1,
+         generator=generator,
+     ).images[0]
+
+     return image, seed
+
+ examples = [
+     [
+         "background alps",
+         "gh0.webp",
+         0.5
+     ],
+     [
+         "dancing",
+         "gh5.jpg",
+         0.5
+     ],
+     [
+         "smile",
+         "gh2.jpg",
+         0.5
+     ],
+     [
+         "3d style",
+         "gh3.webp",
+         0.6
+     ],
+     [
+         "with Pikachu",
+         "gh4.jpg",
+         0.5
+     ],
+     [
+         "Ghibli Studio style, Charming hand-drawn anime-style illustration",
+         "gh7.jpg",
+         0.5
+     ],
+     [
+         "Ghibli Studio style, Charming hand-drawn anime-style illustration",
+         "gh1.jpg",
+         0.5
+     ],
+ ]
+
+ # --------------------------
+ # CSS for the improved UI
+ # --------------------------
+ css = """
+ body {
+     background: linear-gradient(135deg, #f5f7fa, #c3cfe2);
+     font-family: 'Helvetica Neue', Arial, sans-serif;
+     color: #333;
+     margin: 0;
+     padding: 0;
+ }
+ #col-container {
+     margin: 0 auto !important;
+     max-width: 720px;
+     background: rgba(255,255,255,0.85);
+     border-radius: 16px;
+     padding: 2rem;
+     box-shadow: 0 8px 24px rgba(0,0,0,0.1);
+ }
+ #header-title {
+     text-align: center;
+     font-size: 2rem;
+     font-weight: bold;
+     margin-bottom: 1rem;
+ }
+ #prompt-row {
+     display: flex;
+     gap: 0.5rem;
+     align-items: center;
+     margin-bottom: 1rem;
+ }
+ #prompt-text {
+     flex: 1;
+ }
+ #result img {
+     object-position: top;
+     border-radius: 8px;
+ }
+ #result .image-container {
+     height: 100%;
+ }
+ .gr-button {
+     background-color: #2E8BFB !important;
+     color: white !important;
+     border: none !important;
+     transition: background-color 0.2s ease;
+ }
+ .gr-button:hover {
+     background-color: #186EDB !important;
+ }
+ .gr-slider input[type=range] {
+     accent-color: #2E8BFB !important;
+ }
+ .gr-box {
+     background-color: #fafafa !important;
+     border: 1px solid #ddd !important;
+     border-radius: 8px !important;
+     padding: 1rem !important;
+ }
+ #advanced-settings {
+     margin-top: 1rem;
+     border-radius: 8px;
+ }
+ """
+
+ with gr.Blocks(theme="apriel", css=css) as demo:
+     with gr.Column(elem_id="col-container"):
+         gr.Markdown("<div id='header-title'>Open Meme Studio</div>")
+         gr.Markdown("<div id='header-title' style='font-size: 12px;'>Community: https://discord.gg/openfreeai</div>")
+
+         # Top: prompt input + run button
+         with gr.Row(elem_id="prompt-row"):
+             prompt = gr.Text(
+                 label="Prompt",
+                 show_label=False,
+                 max_lines=1,
+                 placeholder="Enter your prompt",
+                 elem_id="prompt-text",
+             )
+             run_button = gr.Button("Run", elem_id="run-button")
+
+         # Middle: image input, influence slider, and result image
+         with gr.Row():
+             with gr.Column():
+                 ip_adapter_image = gr.Image(label="IP-Adapter Image", type="pil")
+                 ip_adapter_scale = gr.Slider(
+                     label="Image influence scale",
+                     info="Use 1 for creating variations",
+                     minimum=0.0,
+                     maximum=1.0,
+                     step=0.05,
+                     value=0.5,
+                 )
+             result = gr.Image(label="Result", elem_id="result")
+
+         # Bottom: advanced settings (Accordion)
+         with gr.Accordion("Advanced Settings", open=False, elem_id="advanced-settings"):
+             negative_prompt = gr.Text(
+                 label="Negative prompt",
+                 max_lines=2,
+                 placeholder=(
+                     "(worst quality, low quality:1.4), bad anatomy, bad hands, text, error, "
+                     "missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, "
+                     "normal quality, jpeg artifacts, signature, watermark, username, blurry, "
+                     "artist name, (deformed iris, deformed pupils:1.2), (semi-realistic, cgi, "
+                     "3d, render:1.1), amateur, (poorly drawn hands, poorly drawn face:1.2)"
+                 ),
+             )
+             seed = gr.Slider(
+                 label="Seed",
+                 minimum=0,
+                 maximum=MAX_SEED,
+                 step=1,
+                 value=0,
+             )
+             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+             with gr.Row():
+                 width = gr.Slider(
+                     label="Width",
+                     minimum=256,
+                     maximum=MAX_IMAGE_SIZE,
+                     step=32,
+                     value=1024,
+                 )
+                 height = gr.Slider(
+                     label="Height",
+                     minimum=256,
+                     maximum=MAX_IMAGE_SIZE,
+                     step=32,
+                     value=1024,
+                 )
+             with gr.Row():
+                 guidance_scale = gr.Slider(
+                     label="Guidance scale",
+                     minimum=0.0,
+                     maximum=10.0,
+                     step=0.1,
+                     value=5.0,
+                 )
+                 num_inference_steps = gr.Slider(
+                     label="Number of inference steps",
+                     minimum=1,
+                     maximum=100,
+                     step=1,
+                     value=50,
+                 )
+
+         # Examples
+         gr.Examples(
+             examples=examples,
+             fn=infer,
+             inputs=[prompt, ip_adapter_image, ip_adapter_scale],
+             outputs=[result, seed],
+             cache_examples="lazy"
+         )
+
+         # Run on button click or when Enter is pressed in the prompt box
+         gr.on(
+             triggers=[run_button.click, prompt.submit],
+             fn=infer,
+             inputs=[
+                 prompt,
+                 ip_adapter_image,
+                 ip_adapter_scale,
+                 negative_prompt,
+                 seed,
+                 randomize_seed,
+                 width,
+                 height,
+                 guidance_scale,
+                 num_inference_steps
+             ],
+             outputs=[result, seed]
+         )
+
+ demo.queue().launch()
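
Context for the patch at the top of this commit: recent huggingface_hub releases removed the legacy cached_download helper, which older Kolors/diffusers import paths may still reference. Below is a minimal standalone sketch of the same shim; the repo_id and filename are illustrative, and hf_hub_download's signature differs from the old URL-based cached_download, so the alias only satisfies call sites that pass hf_hub_download-style arguments.

import huggingface_hub
from huggingface_hub import hf_hub_download

# Apply the alias before importing any library that does
# `from huggingface_hub import cached_download`; once the attribute
# exists on the module, that import succeeds.
if not hasattr(huggingface_hub, "cached_download"):
    huggingface_hub.cached_download = hf_hub_download

# Illustrative use of the alias (hypothetical file choice):
path = huggingface_hub.cached_download(
    repo_id="Kwai-Kolors/Kolors",
    filename="model_index.json",
)
print(path)

This is also why the patch sits above the kolors imports in app.py: the alias has to exist by the time those modules are first imported.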