prithivMLmods committed · Commit 7f96124 · verified · 1 Parent(s): 187df91

add lora support (qwen image)

Files changed (1): app.py (+81, −3)
app.py CHANGED
@@ -10,6 +10,10 @@ import numpy as np
 import time
 import zipfile
 import os
+import requests
+from urllib.parse import urlparse
+import tempfile
+import shutil
 
 # Description for the app
 DESCRIPTION = """## Qwen Image Hpc/."""
@@ -44,6 +48,45 @@ aspect_ratios = {
     "3:4": (1140, 1472)
 }
 
+def load_lora_opt(pipe, lora_input):
+    lora_input = lora_input.strip()
+    if not lora_input:
+        return
+
+    # If it's just an ID like "author/model"
+    if "/" in lora_input and not lora_input.startswith("http"):
+        pipe.load_lora_weights(lora_input, adapter_name="default")
+        return
+
+    if lora_input.startswith("http"):
+        url = lora_input
+
+        # Repo page (no blob/resolve)
+        if "huggingface.co" in url and "/blob/" not in url and "/resolve/" not in url:
+            repo_id = urlparse(url).path.strip("/")
+            pipe.load_lora_weights(repo_id, adapter_name="default")
+            return
+
+        # Blob link → convert to resolve link
+        if "/blob/" in url:
+            url = url.replace("/blob/", "/resolve/")
+
+        # Download direct file
+        tmp_dir = tempfile.mkdtemp()
+        local_path = os.path.join(tmp_dir, os.path.basename(urlparse(url).path))
+
+        try:
+            print(f"Downloading LoRA from {url}...")
+            resp = requests.get(url, stream=True)
+            resp.raise_for_status()
+            with open(local_path, "wb") as f:
+                for chunk in resp.iter_content(chunk_size=8192):
+                    f.write(chunk)
+            print(f"Saved LoRA to {local_path}")
+            pipe.load_lora_weights(local_path, adapter_name="default")
+        finally:
+            shutil.rmtree(tmp_dir, ignore_errors=True)
+
 # Generation function for Qwen/Qwen-Image
 @spaces.GPU(duration=120)
 def generate_qwen(
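The new load_lora_opt helper accepts three input forms: a bare Hub repo ID, a huggingface.co repo page URL (reduced to its path and loaded as a repo ID), and a direct file link, where a /blob/ URL is rewritten to /resolve/ and the file is streamed into a temporary directory before loading. A minimal sketch of the three paths; the repo ID is the UI placeholder added later in this diff, while the URLs are illustrative and not from the commit:

```python
# Illustrative calls; only the first repo ID appears in the diff itself.
load_lora_opt(pipe_qwen, "flymy-ai/qwen-image-anime-irl-lora")       # bare repo ID
load_lora_opt(pipe_qwen, "https://huggingface.co/author/some-lora")  # repo page -> "author/some-lora"
load_lora_opt(                                                       # blob link -> resolve link, then download
    pipe_qwen,
    "https://huggingface.co/author/some-lora/blob/main/lora.safetensors",
)
```

One caveat: requests.get is called without a timeout (requests applies none by default), so a stalled download would hold the GPU worker for the full @spaces.GPU duration; passing something like timeout=60 would bound that.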
 
@@ -57,6 +100,8 @@ def generate_qwen(
     num_inference_steps: int = 50,
     num_images: int = 1,
     zip_images: bool = False,
+    lora_input: str = "",
+    lora_scale: float = 1.0,
     progress=gr.Progress(track_tqdm=True),
 ):
     if randomize_seed:
@@ -64,10 +109,21 @@ def generate_qwen(
     generator = torch.Generator(device).manual_seed(seed)
 
     start_time = time.time()
+
+    current_adapters = pipe_qwen.get_list_adapters()
+    for adapter in current_adapters:
+        pipe_qwen.delete_adapters(adapter)
+    pipe_qwen.disable_lora()
+
+    use_lora = False
+    if lora_input and lora_input.strip() != "":
+        load_lora_opt(pipe_qwen, lora_input)
+        pipe_qwen.set_adapters(["default"], adapter_weights=[lora_scale])
+        use_lora = True
 
     images = pipe_qwen(
         prompt=prompt,
-        negative_prompt=negative_prompt if negative_prompt else None,
+        negative_prompt=negative_prompt if negative_prompt else "",
         height=height,
         width=width,
         guidance_scale=guidance_scale,
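Two hedged notes on this block. First, in recent diffusers releases get_list_adapters() returns a dict mapping component names to lists of adapter names (e.g. {"transformer": ["default"]}), so iterating over it yields component names rather than adapter names; if that holds for the diffusers version used here, a reset that targets the adapters themselves would look roughly like this sketch. Second, the use_lora flag is set but never read anywhere else in the diff.

```python
# Hedged sketch: delete adapters by adapter name, not by component key.
def reset_adapters(pipe):
    adapter_names = sorted(
        {name for names in pipe.get_list_adapters().values() for name in names}
    )
    if adapter_names:
        pipe.delete_adapters(adapter_names)  # accepts a name or a list of names
    pipe.disable_lora()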
 
@@ -88,6 +144,12 @@ def generate_qwen(
             for i, img_path in enumerate(image_paths):
                 zipf.write(img_path, arcname=f"Img_{i}.png")
         zip_path = zip_name
+
+    # Clean up adapters
+    current_adapters = pipe_qwen.get_list_adapters()
+    for adapter in current_adapters:
+        pipe_qwen.delete_adapters(adapter)
+    pipe_qwen.disable_lora()
 
     return image_paths, seed, f"{duration:.2f}", zip_path
 
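This mirrors the reset at the top of generate_qwen, but it only runs when generation succeeds; an exception inside pipe_qwen(...) would leave the LoRA attached for the next request. A small guard, sketched under the same assumptions (reset_adapters is the hypothetical helper from the previous note):

```python
# Hedged sketch: guarantee adapter cleanup even when the pipeline call raises.
def run_with_lora_cleanup(pipe, **pipe_kwargs):
    try:
        return pipe(**pipe_kwargs).images
    finally:
        reset_adapters(pipe)
```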
 
 
@@ -105,6 +167,8 @@ def generate(
     num_inference_steps: int,
     num_images: int,
     zip_images: bool,
+    lora_input: str,
+    lora_scale: float,
     progress=gr.Progress(track_tqdm=True),
 ):
     final_negative_prompt = negative_prompt if use_negative_prompt else ""
@@ -119,6 +183,8 @@ def generate(
         num_inference_steps=num_inference_steps,
         num_images=num_images,
         zip_images=zip_images,
+        lora_input=lora_input,
+        lora_scale=lora_scale,
         progress=progress,
     )
 
 
@@ -146,7 +212,7 @@ footer {
 '''
 
 # Gradio interface
-with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
+with gr.Blocks(css=css, theme="bethecloud/storj_theme", delete_cache=(240, 240)) as demo:
     gr.Markdown(DESCRIPTION)
     with gr.Row():
         prompt = gr.Text(
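If the Gradio version here matches current documentation, delete_cache=(240, 240) tells Gradio to sweep this app's cached files every 240 seconds, deleting those older than 240 seconds, which should keep generated images and zip archives from accumulating on the Space.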
 
@@ -165,6 +231,8 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
             choices=list(aspect_ratios.keys()),
             value="1:1",
         )
+    with gr.Row():
+        lora = gr.Textbox(label="qwen image lora (optional)", placeholder="flymy-ai/qwen-image-anime-irl-lora")
     with gr.Accordion("Additional Options", open=False):
         use_negative_prompt = gr.Checkbox(
             label="Use negative prompt",
 
@@ -223,6 +291,14 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
                 value=1,
             )
         zip_images = gr.Checkbox(label="Zip generated images", value=False)
+        with gr.Row():
+            lora_scale = gr.Slider(
+                label="LoRA Scale",
+                minimum=0,
+                maximum=2,
+                step=0.01,
+                value=1,
+            )
 
     gr.Markdown("### Output Information")
     seed_display = gr.Textbox(label="Seed used", interactive=False)
 
@@ -263,6 +339,8 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
             num_inference_steps,
             num_images,
             zip_images,
+            lora,
+            lora_scale,
         ],
         outputs=[result, seed_display, generation_time, zip_file],
         api_name="run",
 
@@ -278,4 +356,4 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
     )
 
 if __name__ == "__main__":
-    demo.queue(max_size=50).launch(share=False, mcp_server=True, ssr_mode=False, show_error=True)
+    demo.queue(max_size=50).launch(share=False, mcp_server=True, ssr_mode=False, debug=True, show_error=True)
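Since the new lora and lora_scale components are appended to the inputs wired to api_name="run", the endpoint can be exercised programmatically. A hedged sketch with gradio_client, assuming a recent client where keyword arguments map to generate's parameter names and omitted inputs fall back to the UI defaults; the Space ID is a placeholder, not taken from the commit:

```python
from gradio_client import Client

client = Client("user/qwen-image-space")  # hypothetical Space ID
paths, seed, duration, zip_path = client.predict(
    prompt="a watercolor fox in a forest",
    lora_input="flymy-ai/qwen-image-anime-irl-lora",  # repo ID from the UI placeholder
    lora_scale=1.0,
    api_name="/run",
)
```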