A24005179 committed
Commit b62cca0 · verified · Parent: b08f86e

Update app.py

Files changed (1):
  1. app.py +22 -54
app.py CHANGED
@@ -14,7 +14,7 @@ from PIL import Image
 from diffusers import AutoencoderKLHunyuanVideo
 from transformers import (
     LlamaModel, CLIPTextModel,
-    LlamaTokenizerFast, CLIPTokenizer
+    LlamaTokenizerFast, CLIPTokenizer, AutoImageProcessor
 )
 from diffusers_helper.hunyuan import (
     encode_prompt_conds, vae_decode,
@@ -28,8 +28,10 @@ from diffusers_helper.utils import (
 )
 from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked
 from diffusers_helper.pipelines.k_diffusion_hunyuan import sample_hunyuan
+from diffusers_helper.clip_vision import hf_clip_vision_encode
+from diffusers_helper.bucket_tools import find_nearest_bucket
 
-# Remove or replace GPU-specific imports
+# Set device to CPU
 device = torch.device("cpu")
 
 # Load models
@@ -61,7 +63,8 @@ vae = AutoencoderKLHunyuanVideo.from_pretrained(
     torch_dtype=torch.float16
 ).to(device)
 
-feature_extractor = SiglipImageProcessor.from_pretrained(
+# Use AutoImageProcessor instead of SiglipImageProcessor
+feature_extractor = AutoImageProcessor.from_pretrained(
     "lllyasviel/flux_redux_bfl",
     subfolder='feature_extractor'
 )
@@ -177,7 +180,6 @@ def worker(
     total_latent_sections = int(max(round(total_latent_sections), 1))
     job_id = generate_timestamp()
     stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Starting ...'))))
-
     try:
         llama_vec, clip_l_pooler = encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
         if cfg == 1:
@@ -186,44 +188,35 @@ def worker(
         llama_vec_n, clip_l_pooler_n = encode_prompt_conds(n_prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
         llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)
         llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(llama_vec_n, length=512)
-
         H, W, C = input_image.shape
         height, width = find_nearest_bucket(H, W, resolution=640)
         input_image_np = resize_and_center_crop(input_image, target_width=width, target_height=height)
         Image.fromarray(input_image_np).save(os.path.join(outputs_folder, f'{job_id}.png'))
         input_image_pt = torch.from_numpy(input_image_np).float() / 127.5 - 1
         input_image_pt = input_image_pt.permute(2, 0, 1)[None, :, None]
-
         start_latent = vae_encode(input_image_pt, vae).to(device)
-
         image_encoder_output = hf_clip_vision_encode(input_image_np, feature_extractor, image_encoder)
         image_encoder_last_hidden_state = image_encoder_output.last_hidden_state
-
         llama_vec = llama_vec.to(transformer.dtype).to(device)
         llama_vec_n = llama_vec_n.to(transformer.dtype).to(device)
         clip_l_pooler = clip_l_pooler.to(transformer.dtype).to(device)
         clip_l_pooler_n = clip_l_pooler_n.to(transformer.dtype).to(device)
         image_encoder_last_hidden_state = image_encoder_last_hidden_state.to(transformer.dtype).to(device)
-
         rnd = torch.Generator("cpu").manual_seed(seed)
         history_latents = torch.zeros(
             size=(1, 16, 16 + 2 + 1, height // 8, width // 8),
             dtype=torch.float32
         ).to(device)
-
         history_latents = torch.cat([history_latents, start_latent.to(history_latents)], dim=2)
         total_generated_latent_frames = 1
-
         for section_index in range(total_latent_sections):
             if stream.input_queue.top() == 'end':
                 stream.output_queue.push(('end', None))
                 return
-
             if use_teacache:
                 transformer.initialize_teacache(enable_teacache=True, num_steps=steps)
             else:
                 transformer.initialize_teacache(enable_teacache=False)
-
             def callback(d):
                 preview = d['denoised']
                 preview = vae_decode_fake(preview)
@@ -238,7 +231,6 @@ def worker(
                 desc = f'Total generated frames: {int(max(0, total_generated_latent_frames * 4 - 3))}'
                 stream.output_queue.push(('progress', (preview, desc, make_progress_bar_html(percentage, hint))))
                 return
-
             indices = torch.arange(
                 0, sum([1, 16, 2, 1, latent_window_size])
             ).unsqueeze(0)
@@ -249,7 +241,6 @@ def worker(
                 clean_latent_1x_indices,
                 latent_indices
             ) = indices.split([1, 16, 2, 1, latent_window_size], dim=1)
-
             clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)
             clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents[
                 :, :, -sum([16, 2, 1]):, :, :
@@ -258,7 +249,6 @@ def worker(
                 [start_latent.to(history_latents), clean_latents_1x],
                 dim=2
             )
-
             generated_latents = sample_hunyuan(
                 transformer=transformer,
                 sampler='unipc',
@@ -288,10 +278,8 @@ def worker(
                 clean_latent_4x_indices=clean_latent_4x_indices,
                 callback=callback,
             )
-
             total_generated_latent_frames += int(generated_latents.shape[2])
             history_latents = torch.cat([history_latents, generated_latents.to(history_latents)], dim=2)
-
             real_history_latents = history_latents[:, :, -total_generated_latent_frames:, :, :]
             if history_pixels is None:
                 history_pixels = vae_decode(real_history_latents, vae).cpu()
@@ -304,14 +292,11 @@ def worker(
                 history_pixels = soft_append_bcthw(
                     history_pixels, current_pixels, overlapped_frames
                 )
-
             output_filename = os.path.join(outputs_folder, f'{job_id}_{total_generated_latent_frames}.mp4')
             save_bcthw_as_mp4(history_pixels, output_filename, fps=30)
             stream.output_queue.push(('file', output_filename))
-
     except Exception as e:
         traceback.print_exc()
-
     stream.output_queue.push(('end', None))
     return
 
@@ -332,14 +317,17 @@ def process(
     use_teacache=True, mp4_crf=16, quality_radio="640x360", aspect_ratio="1:1"
 ):
     global stream
+
+    # Map quality options to actual resolutions
     quality_map = {
         "360p": (640, 360),
         "480p": (854, 480),
         "540p": (960, 540),
         "720p": (1280, 720),
-        "640x360": (640, 360),  # fallback for default
+        "640x360": (640, 360),  # fallback
     }
-    # Aspect ratio map: (width, height)
+
+    # Map aspect ratio strings to width/height ratios
     aspect_map = {
         "1:1": (1, 1),
         "3:4": (3, 4),
@@ -347,53 +335,36 @@ def process(
         "16:9": (16, 9),
         "9:16": (9, 16),
     }
-    selected_quality = quality_map.get(quality_radio, (640, 360))
-    base_width, base_height = selected_quality
 
+    # Get target resolution based on selected quality
+    target_width, target_height = quality_map.get(quality_radio, (640, 360))
+
     if t2v:
-        # Use aspect ratio to determine final width/height
         ar_w, ar_h = aspect_map.get(aspect_ratio, (1, 1))
+        # Recalculate based on aspect ratio
         if ar_w >= ar_h:
-            target_height = base_height
-            target_width = int(round(target_height * ar_w / ar_h))
-        else:
-            target_width = base_width
             target_height = int(round(target_width * ar_h / ar_w))
+        else:
+            target_width = int(round(target_height * ar_w / ar_h))
         input_image = np.ones((target_height, target_width, 3), dtype=np.uint8) * 255
         print(f"Using blank white image for text-to-video mode, {target_width}x{target_height} ({aspect_ratio})")
     else:
-        target_width, target_height = selected_quality
-        if isinstance(input_image, dict) and "composite" in input_image:
-            composite_rgba_uint8 = input_image["composite"]
-            rgb_uint8 = composite_rgba_uint8[:, :, :3]
-            mask_uint8 = composite_rgba_uint8[:, :, 3]
-            h, w = rgb_uint8.shape[:2]
-            background_uint8 = np.full((h, w, 3), 255, dtype=np.uint8)
-            alpha_normalized_float32 = mask_uint8.astype(np.float32) / 255.0
-            alpha_mask_float32 = np.stack([alpha_normalized_float32]*3, axis=2)
-            blended_image_float32 = rgb_uint8.astype(np.float32) * alpha_mask_float32 + \
-                background_uint8.astype(np.float32) * (1.0 - alpha_mask_float32)
-            input_image = np.clip(blended_image_float32, 0, 255).astype(np.uint8)
-        elif input_image is None:
-            raise ValueError("Please provide an input image or enable Text to Video mode")
-        else:
-            input_image = input_image.astype(np.uint8)
+        # Resize and crop input image to match selected resolution
+        H, W, C = input_image.shape
+        height, width = find_nearest_bucket(H, W, resolution=target_width)
+        input_image_np = resize_and_center_crop(input_image, target_width=width, target_height=height)
+        Image.fromarray(input_image_np).save(os.path.join(outputs_folder, f'{job_id}.png'))
 
     yield None, None, '', '', gr.update(interactive=False), gr.update(interactive=True)
-
     stream = AsyncStream()
-
     async_run(
         worker, input_image, prompt, n_prompt, seed,
         total_second_length, latent_window_size, steps,
         cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf
     )
-
     output_filename = None
-
     while True:
         flag, data = stream.output_queue.next()
-
         if flag == 'file':
             output_filename = data
             yield (
@@ -404,7 +375,6 @@ def process(
                 gr.update(interactive=False),
                 gr.update(interactive=True)
             )
-
         elif flag == 'progress':
             preview, desc, html = data
             yield (
@@ -415,7 +385,6 @@ def process(
                 gr.update(interactive=False),
                 gr.update(interactive=True)
             )
-
         elif flag == 'end':
             yield (
                 output_filename,
@@ -430,7 +399,6 @@ def process(
 def end_process():
     stream.input_queue.push('end')
 
-
 quick_prompts = [
     'The girl dances gracefully, with clear movements, full of charm.',
     'A character doing some simple body movements.'
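
For reference, the resolution handling this commit adds to process() can be read as a small standalone helper. The sketch below only restates the logic visible in the added lines (the quality presets, the aspect-ratio table, and the rounding); the helper name and the example calls are illustrative and not part of app.py, and aspect_map may define more ratios than the hunks show.

# Illustrative sketch of the target-resolution logic added to process(); not part of app.py.
QUALITY_MAP = {
    "360p": (640, 360),
    "480p": (854, 480),
    "540p": (960, 540),
    "720p": (1280, 720),
    "640x360": (640, 360),  # fallback
}

# Only the ratios visible in the diff; the full file may define more.
ASPECT_MAP = {
    "1:1": (1, 1),
    "3:4": (3, 4),
    "16:9": (16, 9),
    "9:16": (9, 16),
}

def t2v_target_resolution(quality_radio, aspect_ratio):
    # Start from the quality preset, then reshape it to the requested aspect ratio,
    # mirroring the branch the commit adds for text-to-video mode.
    target_width, target_height = QUALITY_MAP.get(quality_radio, (640, 360))
    ar_w, ar_h = ASPECT_MAP.get(aspect_ratio, (1, 1))
    if ar_w >= ar_h:
        # Landscape or square: keep the preset width, rescale the height.
        target_height = int(round(target_width * ar_h / ar_w))
    else:
        # Portrait: keep the preset height, rescale the width.
        target_width = int(round(target_height * ar_w / ar_h))
    return target_width, target_height

print(t2v_target_resolution("720p", "16:9"))  # (1280, 720)
print(t2v_target_resolution("720p", "9:16"))  # (405, 720)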