A24005179 committed
Commit b08f86e · verified · 1 parent: f991534

Update app.py

Files changed (1)
  1. app.py +22 -138
app.py CHANGED
@@ -1,9 +1,7 @@
 import os
-
 os.environ['HF_HOME'] = os.path.abspath(
     os.path.realpath(os.path.join(os.path.dirname(__file__), './hf_download'))
 )
-
 import gradio as gr
 import torch
 import traceback
@@ -12,7 +10,6 @@ import safetensors.torch as sf
 import numpy as np
 import math
 import spaces
-
 from PIL import Image
 from diffusers import AutoencoderKLHunyuanVideo
 from transformers import (
@@ -31,68 +28,54 @@ from diffusers_helper.utils import (
 )
 from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked
 from diffusers_helper.pipelines.k_diffusion_hunyuan import sample_hunyuan
-from diffusers_helper.memory import (
-    cpu, gpu,
-    get_cuda_free_memory_gb,
-    move_model_to_device_with_memory_preservation,
-    offload_model_from_device_for_memory_preservation,
-    fake_diffusers_current_device,
-    DynamicSwapInstaller,
-    unload_complete_models,
-    load_model_as_complete
-)
-from diffusers_helper.thread_utils import AsyncStream, async_run
-from diffusers_helper.gradio.progress_bar import make_progress_bar_css, make_progress_bar_html
-from transformers import SiglipImageProcessor, SiglipVisionModel
-from diffusers_helper.clip_vision import hf_clip_vision_encode
-from diffusers_helper.bucket_tools import find_nearest_bucket
 
-# Check GPU memory
-free_mem_gb = get_cuda_free_memory_gb(gpu)
-high_vram = free_mem_gb > 60
-
-print(f'Free VRAM {free_mem_gb} GB')
-print(f'High-VRAM Mode: {high_vram}')
+# Remove or replace GPU-specific imports
+device = torch.device("cpu")
 
 # Load models
 text_encoder = LlamaModel.from_pretrained(
     "hunyuanvideo-community/HunyuanVideo",
     subfolder='text_encoder',
     torch_dtype=torch.float16
-).cpu()
+).to(device)
+
 text_encoder_2 = CLIPTextModel.from_pretrained(
     "hunyuanvideo-community/HunyuanVideo",
     subfolder='text_encoder_2',
     torch_dtype=torch.float16
-).cpu()
+).to(device)
+
 tokenizer = LlamaTokenizerFast.from_pretrained(
     "hunyuanvideo-community/HunyuanVideo",
     subfolder='tokenizer'
 )
+
 tokenizer_2 = CLIPTokenizer.from_pretrained(
     "hunyuanvideo-community/HunyuanVideo",
     subfolder='tokenizer_2'
 )
+
 vae = AutoencoderKLHunyuanVideo.from_pretrained(
     "hunyuanvideo-community/HunyuanVideo",
     subfolder='vae',
     torch_dtype=torch.float16
-).cpu()
+).to(device)
 
 feature_extractor = SiglipImageProcessor.from_pretrained(
     "lllyasviel/flux_redux_bfl",
     subfolder='feature_extractor'
 )
+
 image_encoder = SiglipVisionModel.from_pretrained(
     "lllyasviel/flux_redux_bfl",
     subfolder='image_encoder',
     torch_dtype=torch.float16
-).cpu()
+).to(device)
 
 transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained(
     'lllyasviel/FramePack_F1_I2V_HY_20250503',
     torch_dtype=torch.bfloat16
-).cpu()
+).to(device)
 
 # Evaluation mode
 vae.eval()
@@ -101,14 +84,6 @@ text_encoder_2.eval()
 image_encoder.eval()
 transformer.eval()
 
-# Slicing/Tiling for low VRAM
-if not high_vram:
-    vae.enable_slicing()
-    vae.enable_tiling()
-
-transformer.high_quality_fp32_output_for_inference = True
-print('transformer.high_quality_fp32_output_for_inference = True')
-
 # Move to correct dtype
 transformer.to(dtype=torch.bfloat16)
 vae.to(dtype=torch.float16)
@@ -123,19 +98,7 @@ text_encoder_2.requires_grad_(False)
 image_encoder.requires_grad_(False)
 transformer.requires_grad_(False)
 
-# DynamicSwap if low VRAM
-if not high_vram:
-    DynamicSwapInstaller.install_model(transformer, device=gpu)
-    DynamicSwapInstaller.install_model(text_encoder, device=gpu)
-else:
-    text_encoder.to(gpu)
-    text_encoder_2.to(gpu)
-    image_encoder.to(gpu)
-    vae.to(gpu)
-    transformer.to(gpu)
-
 stream = AsyncStream()
-
 outputs_folder = './outputs/'
 os.makedirs(outputs_folder, exist_ok=True)
 
@@ -145,7 +108,6 @@ examples = [
     ["img_examples/3.png", "The woman dances elegantly among the blossoms, spinning slowly with flowing sleeves and graceful hand movements."]
 ]
 
-# Example generation (optional)
 def generate_examples(input_image, prompt):
     t2v=False
     n_prompt=""
@@ -156,32 +118,24 @@ def generate_examples(input_image, prompt):
     cfg=1.0
     gs=10.0
     rs=0.0
-    gpu_memory_preservation=6
+    gpu_memory_preservation=6 # unused
     use_teacache=True
     mp4_crf=16
-
     global stream
-
     if t2v:
         default_height, default_width = 640, 640
         input_image = np.ones((default_height, default_width, 3), dtype=np.uint8) * 255
         print("No input image provided. Using a blank white image.")
-
     yield None, None, '', '', gr.update(interactive=False), gr.update(interactive=True)
-
     stream = AsyncStream()
-
     async_run(
         worker, input_image, prompt, n_prompt, seed,
         total_second_length, latent_window_size, steps,
         cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf
     )
-
     output_filename = None
-
     while True:
         flag, data = stream.output_queue.next()
-
         if flag == 'file':
             output_filename = data
             yield (
@@ -192,7 +146,6 @@ def generate_examples(input_image, prompt):
                 gr.update(interactive=False),
                 gr.update(interactive=True)
             )
-
         if flag == 'progress':
            preview, desc, html = data
            yield (
@@ -203,7 +156,6 @@ def generate_examples(input_image, prompt):
                 gr.update(interactive=False),
                 gr.update(interactive=True)
             )
-
         if flag == 'end':
             yield (
                 output_filename,
@@ -221,84 +173,44 @@ def worker(
     total_second_length, latent_window_size, steps,
     cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf
 ):
-    # Calculate total sections
     total_latent_sections = (total_second_length * 30) / (latent_window_size * 4)
     total_latent_sections = int(max(round(total_latent_sections), 1))
-
     job_id = generate_timestamp()
-
     stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Starting ...'))))
 
     try:
-        # Unload if VRAM is low
-        if not high_vram:
-            unload_complete_models(
-                text_encoder, text_encoder_2, image_encoder, vae, transformer
-            )
-
-        # Text encoding
-        stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Text encoding ...'))))
-
-        if not high_vram:
-            fake_diffusers_current_device(text_encoder, gpu)
-            load_model_as_complete(text_encoder_2, target_device=gpu)
-
         llama_vec, clip_l_pooler = encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
-
         if cfg == 1:
             llama_vec_n, clip_l_pooler_n = torch.zeros_like(llama_vec), torch.zeros_like(clip_l_pooler)
         else:
             llama_vec_n, clip_l_pooler_n = encode_prompt_conds(n_prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
-
         llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)
         llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(llama_vec_n, length=512)
 
-        # Process image
-        stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Image processing ...'))))
-
         H, W, C = input_image.shape
         height, width = find_nearest_bucket(H, W, resolution=640)
         input_image_np = resize_and_center_crop(input_image, target_width=width, target_height=height)
-
         Image.fromarray(input_image_np).save(os.path.join(outputs_folder, f'{job_id}.png'))
-
         input_image_pt = torch.from_numpy(input_image_np).float() / 127.5 - 1
         input_image_pt = input_image_pt.permute(2, 0, 1)[None, :, None]
 
-        # VAE encoding
-        stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'VAE encoding ...'))))
-
-        if not high_vram:
-            load_model_as_complete(vae, target_device=gpu)
-        start_latent = vae_encode(input_image_pt, vae)
-
-        # CLIP Vision
-        stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'CLIP Vision encoding ...'))))
+        start_latent = vae_encode(input_image_pt, vae).to(device)
 
-        if not high_vram:
-            load_model_as_complete(image_encoder, target_device=gpu)
         image_encoder_output = hf_clip_vision_encode(input_image_np, feature_extractor, image_encoder)
         image_encoder_last_hidden_state = image_encoder_output.last_hidden_state
 
-        # Convert dtype
-        llama_vec = llama_vec.to(transformer.dtype)
-        llama_vec_n = llama_vec_n.to(transformer.dtype)
-        clip_l_pooler = clip_l_pooler.to(transformer.dtype)
-        clip_l_pooler_n = clip_l_pooler_n.to(transformer.dtype)
-        image_encoder_last_hidden_state = image_encoder_last_hidden_state.to(transformer.dtype)
-
-        # Start sampling
-        stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Start sampling ...'))))
+        llama_vec = llama_vec.to(transformer.dtype).to(device)
+        llama_vec_n = llama_vec_n.to(transformer.dtype).to(device)
+        clip_l_pooler = clip_l_pooler.to(transformer.dtype).to(device)
+        clip_l_pooler_n = clip_l_pooler_n.to(transformer.dtype).to(device)
+        image_encoder_last_hidden_state = image_encoder_last_hidden_state.to(transformer.dtype).to(device)
 
         rnd = torch.Generator("cpu").manual_seed(seed)
-
         history_latents = torch.zeros(
             size=(1, 16, 16 + 2 + 1, height // 8, width // 8),
             dtype=torch.float32
-        ).cpu()
-        history_pixels = None
+        ).to(device)
 
-        # Add start_latent
         history_latents = torch.cat([history_latents, start_latent.to(history_latents)], dim=2)
         total_generated_latent_frames = 1
 
@@ -307,15 +219,6 @@ def worker(
                 stream.output_queue.push(('end', None))
                 return
 
-            print(f'section_index = {section_index}, total_latent_sections = {total_latent_sections}')
-
-            if not high_vram:
-                unload_complete_models()
-                move_model_to_device_with_memory_preservation(
-                    transformer, target_device=gpu,
-                    preserved_memory_gb=gpu_memory_preservation
-                )
-
             if use_teacache:
                 transformer.initialize_teacache(enable_teacache=True, num_steps=steps)
             else:
@@ -326,11 +229,9 @@ def worker(
                 preview = vae_decode_fake(preview)
                 preview = (preview * 255.0).detach().cpu().numpy().clip(0, 255).astype(np.uint8)
                 preview = einops.rearrange(preview, 'b c t h w -> (b h) (t w) c')
-
                 if stream.input_queue.top() == 'end':
                     stream.output_queue.push(('end', None))
                     raise KeyboardInterrupt('User ends the task.')
-
                 current_step = d['i'] + 1
                 percentage = int(100.0 * current_step / steps)
                 hint = f'Sampling {current_step}/{steps}'
@@ -350,11 +251,9 @@ def worker(
            ) = indices.split([1, 16, 2, 1, latent_window_size], dim=1)
 
            clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)
-
            clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents[
                :, :, -sum([16, 2, 1]):, :, :
            ].split([16, 2, 1], dim=2)
-
            clean_latents = torch.cat(
                [start_latent.to(history_latents), clean_latents_1x],
                dim=2
@@ -377,7 +276,7 @@ def worker(
                 negative_prompt_embeds=llama_vec_n,
                 negative_prompt_embeds_mask=llama_attention_mask_n,
                 negative_prompt_poolers=clip_l_pooler_n,
-                device=gpu,
+                device=device,
                 dtype=torch.bfloat16,
                 image_embeddings=image_encoder_last_hidden_state,
                 latent_indices=latent_indices,
@@ -393,18 +292,12 @@ def worker(
            total_generated_latent_frames += int(generated_latents.shape[2])
            history_latents = torch.cat([history_latents, generated_latents.to(history_latents)], dim=2)
 
-            if not high_vram:
-                offload_model_from_device_for_memory_preservation(transformer, target_device=gpu, preserved_memory_gb=8)
-                load_model_as_complete(vae, target_device=gpu)
-
            real_history_latents = history_latents[:, :, -total_generated_latent_frames:, :, :]
-
            if history_pixels is None:
                history_pixels = vae_decode(real_history_latents, vae).cpu()
            else:
                section_latent_frames = latent_window_size * 2
                overlapped_frames = latent_window_size * 4 - 3
-
                current_pixels = vae_decode(
                    real_history_latents[:, :, -section_latent_frames:], vae
                ).cpu()
@@ -412,21 +305,12 @@ def worker(
                history_pixels, current_pixels, overlapped_frames
            )
 
-            if not high_vram:
-                unload_complete_models()
-
            output_filename = os.path.join(outputs_folder, f'{job_id}_{total_generated_latent_frames}.mp4')
-
            save_bcthw_as_mp4(history_pixels, output_filename, fps=30)
-
-            print(f'Decoded. Latent shape {real_history_latents.shape}; pixel shape {history_pixels.shape}')
-
            stream.output_queue.push(('file', output_filename))
 
-    except:
+    except Exception as e:
        traceback.print_exc()
-        if not high_vram:
-            unload_complete_models(text_encoder, text_encoder_2, image_encoder, vae, transformer)
 
     stream.output_queue.push(('end', None))
     return
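
For reference, a minimal standalone sketch (not part of this commit) of the CPU-only device-placement pattern the updated app.py relies on: create a hard-coded torch.device("cpu"), move a model onto it with .to(device), switch to eval mode, and freeze gradients. The small nn.Sequential below is a hypothetical stand-in; the real script applies the same calls to the HunyuanVideo text encoders, VAE, image encoder, and transformer.

# Illustrative sketch only; the tiny module is a stand-in, not one of the models loaded in app.py.
import torch
import torch.nn as nn

device = torch.device("cpu")  # the commit hard-codes CPU; no CUDA/VRAM checks remain

encoder = nn.Sequential(nn.Linear(16, 32), nn.GELU(), nn.Linear(32, 16))
encoder = encoder.to(device).eval()  # same .to(device) / .eval() pattern as app.py
encoder.requires_grad_(False)        # inference only, as in app.py

with torch.no_grad():
    x = torch.randn(1, 16, device=device)
    print(encoder(x).shape)  # torch.Size([1, 16])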