#!/usr/bin/env python3
# Combined and Corrected Script

import argparse
import os
import sys
import time
import random
import traceback
from datetime import datetime
from pathlib import Path
import re # For parsing section args

import einops
import numpy as np
import torch
import av # For saving video (used by save_bcthw_as_mp4)
from PIL import Image
from tqdm import tqdm
import cv2


# --- Dependencies from diffusers_helper ---
# Ensure this library is installed or in the PYTHONPATH
try:
    # from diffusers_helper.hf_login import login # Not strictly needed for inference if models public/cached
    from diffusers_helper.hunyuan import encode_prompt_conds, vae_decode, vae_encode #, vae_decode_fake # vae_decode_fake not used here
    from diffusers_helper.utils import (save_bcthw_as_mp4, crop_or_pad_yield_mask, soft_append_bcthw,
                                        resize_and_center_crop, generate_timestamp)
    from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked
    from diffusers_helper.pipelines.k_diffusion_hunyuan import sample_hunyuan
    from diffusers_helper.memory import (cpu, gpu, get_cuda_free_memory_gb, move_model_to_device_with_memory_preservation,
                                         offload_model_from_device_for_memory_preservation, fake_diffusers_current_device,
                                         DynamicSwapInstaller, unload_complete_models, load_model_as_complete)
    from diffusers_helper.clip_vision import hf_clip_vision_encode
    from diffusers_helper.bucket_tools import find_nearest_bucket#, bucket_options # bucket_options no longer needed here
except ImportError:
    print("Error: Could not import modules from 'diffusers_helper'.")
    print("Please ensure the 'diffusers_helper' library is installed and accessible.")
    print("You might need to clone the repository and add it to your PYTHONPATH.")
    sys.exit(1)
# --- End Dependencies ---

from diffusers import AutoencoderKLHunyuanVideo
from transformers import LlamaModel, CLIPTextModel, LlamaTokenizerFast, CLIPTokenizer
from transformers import SiglipImageProcessor, SiglipVisionModel

# --- Constants ---
DIMENSION_MULTIPLE = 16 # VAE and model constraints often require divisibility by 8 or 16. 16 is safer.
SECTION_ARG_PATTERN = re.compile(r"^(\d+):([^:]+)(?::(.*))?$") # Regex for section arg: number:image_path[:prompt]

def parse_section_args(section_strings):
    """ Parses the --section arguments into a dictionary. """
    section_data = {}
    if not section_strings:
        return section_data
    for section_str in section_strings:
        match = SECTION_ARG_PATTERN.match(section_str)
        if not match:
            print(f"Warning: Invalid section format: '{section_str}'. Expected 'number:image_path[:prompt]'. Skipping.")
            continue
        section_index_str, image_path, prompt_text = match.groups()
        section_index = int(section_index_str)
        prompt_text = prompt_text if prompt_text else None
        if not os.path.exists(image_path):
            print(f"Warning: Image path for section {section_index} ('{image_path}') not found. Skipping section.")
            continue
        if section_index in section_data:
             print(f"Warning: Duplicate section index {section_index}. Overwriting previous entry.")
        section_data[section_index] = (image_path, prompt_text)
        print(f"Parsed section {section_index}: Image='{image_path}', Prompt='{prompt_text}'")
    return section_data
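
# Example (hypothetical paths): parse_section_args(["0:keyframes/a.png:A sunrise", "2:keyframes/b.png"])
# would return {0: ("keyframes/a.png", "A sunrise"), 2: ("keyframes/b.png", None)},
# assuming both image files exist on disk (entries with missing paths are skipped with a warning).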


def parse_args():
    parser = argparse.ArgumentParser(description="FramePack HunyuanVideo inference script (CLI version with Advanced End Frame & Section Control)")

    # --- Model Paths ---
    parser.add_argument('--transformer_path', type=str, default='lllyasviel/FramePackI2V_HY', help="Path to the FramePack Transformer model")
    parser.add_argument('--vae_path', type=str, default='hunyuanvideo-community/HunyuanVideo', help="Path to the VAE model directory")
    parser.add_argument('--text_encoder_path', type=str, default='hunyuanvideo-community/HunyuanVideo', help="Path to the Llama text encoder directory")
    parser.add_argument('--text_encoder_2_path', type=str, default='hunyuanvideo-community/HunyuanVideo', help="Path to the CLIP text encoder directory")
    parser.add_argument('--image_encoder_path', type=str, default='lllyasviel/flux_redux_bfl', help="Path to the SigLIP image encoder directory")
    parser.add_argument('--hf_home', type=str, default='./hf_download', help="Directory to download/cache Hugging Face models")

    # --- Input ---
    parser.add_argument("--input_image", type=str, required=True, help="Path to the input image (start frame)")
    parser.add_argument("--end_frame", type=str, default=None, help="Path to the optional end frame image (video end)")
    parser.add_argument("--prompt", type=str, required=True, help="Default prompt for generation")
    parser.add_argument("--negative_prompt", type=str, default="", help="Negative prompt for generation")
    # <<< START: Modified Arguments for End Frame >>>
    parser.add_argument("--end_frame_weight", type=float, default=0.3, help="End frame influence weight (0.0-1.0) for blending modes ('half', 'progressive'). Higher blends more end frame *conditioning latent*.") # Default lowered further
    parser.add_argument("--end_frame_influence", type=str, default="last",
                       choices=["last", "half", "progressive", "bookend"],
                       help="How to use the global end frame: 'last' (uses end frame for initial context only, no latent blending), 'half' (blends start/end conditioning latents for second half of video), 'progressive' (gradually blends conditioning latents from end to start), 'bookend' (uses end frame conditioning latent ONLY for first generated section IF no section keyframe set, no blending otherwise). All modes use start image embedding.") # Help text updated
    # <<< END: Modified Arguments for End Frame >>>
    # <<< START: New Arguments for Section Control >>>
    parser.add_argument("--section", type=str, action='append',
                        help="Define a keyframe section. Format: 'index:image_path[:prompt]'. Index 0 is the last generated section (video start), 1 is second last, etc. Repeat for multiple sections. Example: --section 0:path/to/start_like.png:'A sunrise' --section 2:path/to/mid.png")
    # <<< END: New Arguments for Section Control >>>

    # --- Output Resolution (Choose ONE method) ---
    parser.add_argument("--target_resolution", type=int, default=None, help=f"Target resolution for the longer side for automatic aspect ratio calculation (bucketing). Used if --width and --height are not specified. Must be positive and ideally divisible by {DIMENSION_MULTIPLE}.")
    parser.add_argument("--width", type=int, default=None, help=f"Explicit target width for the output video. Overrides --target_resolution. Must be positive and ideally divisible by {DIMENSION_MULTIPLE}.")
    parser.add_argument("--height", type=int, default=None, help=f"Explicit target height for the output video. Overrides --target_resolution. Must be positive and ideally divisible by {DIMENSION_MULTIPLE}.")

    # --- Output ---
    parser.add_argument("--save_path", type=str, required=True, help="Directory to save the generated video")
    parser.add_argument("--save_intermediate_sections", action='store_true', help="Save the video after each section is generated and decoded.")
    parser.add_argument("--save_section_final_frames", action='store_true', help="Save the final decoded frame of each generated section as a PNG image.")


    # --- Generation Parameters (Matching Gradio Demo Defaults where applicable) ---
    parser.add_argument("--seed", type=int, default=None, help="Seed for generation. Random if not set.")
    parser.add_argument("--total_second_length", type=float, default=5.0, help="Total desired video length in seconds")
    parser.add_argument("--fps", type=int, default=30, help="Frames per second for the output video")
    parser.add_argument("--steps", type=int, default=25, help="Number of inference steps (changing not recommended)")
    parser.add_argument("--distilled_guidance_scale", "--gs", type=float, default=10.0, help="Distilled CFG Scale (gs)")
    parser.add_argument("--cfg", type=float, default=1.0, help="Classifier-Free Guidance Scale (fixed at 1.0 for FramePack usually)")
    parser.add_argument("--rs", type=float, default=0.0, help="CFG Rescale (fixed at 0.0 for FramePack usually)")
    parser.add_argument("--latent_window_size", type=int, default=9, help="Latent window size (changing not recommended)")

    # --- Performance / Memory ---
    parser.add_argument('--high_vram', action='store_true', help="Force high VRAM mode (loads all models to GPU)")
    parser.add_argument('--low_vram', action='store_true', help="Force low VRAM mode (uses dynamic swapping)")
    parser.add_argument("--gpu_memory_preservation", type=float, default=6.0, help="GPU memory (GB) to preserve when offloading (low VRAM mode)")
    parser.add_argument('--use_teacache', action='store_true', default=True, help="Use TeaCache optimization (default: True)")
    parser.add_argument('--no_teacache', action='store_false', dest='use_teacache', help="Disable TeaCache optimization")
    parser.add_argument("--device", type=str, default=None, help="Device to use (e.g., 'cuda', 'cpu'). Auto-detects if None.")

    args = parser.parse_args()

    # --- Argument Validation ---
    if args.seed is None:
        args.seed = random.randint(0, 2**32 - 1)
        print(f"Generated random seed: {args.seed}")

    if args.width is not None and args.height is not None:
        if args.width <= 0 or args.height <= 0:
            print(f"Error: Explicit --width ({args.width}) and --height ({args.height}) must be positive.")
            sys.exit(1)
        if args.target_resolution is not None:
            print("Warning: Both --width/--height and --target_resolution specified. Using explicit --width and --height.")
            args.target_resolution = None
    elif args.target_resolution is not None:
        if args.target_resolution <= 0:
            print(f"Error: --target_resolution ({args.target_resolution}) must be positive.")
            sys.exit(1)
        if args.width is not None or args.height is not None:
            print("Error: Cannot specify --target_resolution with only one of --width or --height. Provide both or neither.")
            sys.exit(1)
    else:
        print(f"Warning: No resolution specified. Defaulting to --target_resolution 640.")
        args.target_resolution = 640

    if args.end_frame_weight < 0.0 or args.end_frame_weight > 1.0:
        print(f"Error: --end_frame_weight must be between 0.0 and 1.0 (got {args.end_frame_weight}).")
        sys.exit(1)

    if args.width is not None and args.width % DIMENSION_MULTIPLE != 0:
         print(f"Warning: Specified --width ({args.width}) is not divisible by {DIMENSION_MULTIPLE}. It will be rounded down.")
    if args.height is not None and args.height % DIMENSION_MULTIPLE != 0:
         print(f"Warning: Specified --height ({args.height}) is not divisible by {DIMENSION_MULTIPLE}. It will be rounded down.")
    if args.target_resolution is not None and args.target_resolution % DIMENSION_MULTIPLE != 0:
         print(f"Warning: Specified --target_resolution ({args.target_resolution}) is not divisible by {DIMENSION_MULTIPLE}. The calculated dimensions will be rounded down.")

    if args.end_frame and not os.path.exists(args.end_frame):
        print(f"Error: End frame image not found at '{args.end_frame}'.")
        sys.exit(1)

    args.section_data = parse_section_args(args.section)

    # Note: some Hugging Face libraries resolve HF_HOME when they are first imported, so setting it
    # here (after the imports at the top of the file) may not redirect the default cache location.
    os.environ['HF_HOME'] = os.path.abspath(os.path.realpath(args.hf_home))
    os.makedirs(os.environ['HF_HOME'], exist_ok=True)

    return args


def load_models(args):
    """Loads all necessary models."""
    print("Loading models...")
    if args.device:
        device = torch.device(args.device)
    else:
        device = torch.device(gpu if torch.cuda.is_available() else cpu)
    print(f"Using device: {device}")

    print("  Loading Text Encoder 1 (Llama)...")
    text_encoder = LlamaModel.from_pretrained(args.text_encoder_path, subfolder='text_encoder', torch_dtype=torch.float16).cpu()
    print("  Loading Text Encoder 2 (CLIP)...")
    text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, subfolder='text_encoder_2', torch_dtype=torch.float16).cpu()
    print("  Loading Tokenizer 1 (Llama)...")
    tokenizer = LlamaTokenizerFast.from_pretrained(args.text_encoder_path, subfolder='tokenizer')
    print("  Loading Tokenizer 2 (CLIP)...")
    tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path, subfolder='tokenizer_2')
    print("  Loading VAE...")
    vae = AutoencoderKLHunyuanVideo.from_pretrained(args.vae_path, subfolder='vae', torch_dtype=torch.float16).cpu()
    print("  Loading Image Feature Extractor (SigLIP)...")
    feature_extractor = SiglipImageProcessor.from_pretrained(args.image_encoder_path, subfolder='feature_extractor')
    print("  Loading Image Encoder (SigLIP)...")
    image_encoder = SiglipVisionModel.from_pretrained(args.image_encoder_path, subfolder='image_encoder', torch_dtype=torch.float16).cpu()
    print("  Loading Transformer (FramePack)...")
    transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained(args.transformer_path, torch_dtype=torch.bfloat16).cpu()

    vae.eval()
    text_encoder.eval()
    text_encoder_2.eval()
    image_encoder.eval()
    transformer.eval()

    transformer.high_quality_fp32_output_for_inference = True
    print('transformer.high_quality_fp32_output_for_inference = True')

    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    text_encoder_2.requires_grad_(False)
    image_encoder.requires_grad_(False)
    transformer.requires_grad_(False)

    print("Models loaded.")
    return {
        "text_encoder": text_encoder,
        "text_encoder_2": text_encoder_2,
        "tokenizer": tokenizer,
        "tokenizer_2": tokenizer_2,
        "vae": vae,
        "feature_extractor": feature_extractor,
        "image_encoder": image_encoder,
        "transformer": transformer,
        "device": device
    }

def adjust_to_multiple(value, multiple):
    """Rounds down value to the nearest multiple."""
    return (value // multiple) * multiple
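# Example: adjust_to_multiple(1080, 16) -> 1072; adjust_to_multiple(640, 16) -> 640 (already a multiple).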

def mix_latents(latent_a, latent_b, weight_b):
    """Mix two latents with the specified weight for latent_b."""
    if latent_a is None: return latent_b
    if latent_b is None: return latent_a

    target_device = latent_a.device
    target_dtype = latent_a.dtype
    if latent_b.device != target_device:
        latent_b = latent_b.to(target_device)
    if latent_b.dtype != target_dtype:
        latent_b = latent_b.to(dtype=target_dtype)

    if isinstance(weight_b, torch.Tensor):
        weight_b = weight_b.item()

    weight_b = max(0.0, min(1.0, weight_b))

    if weight_b == 0.0:
        return latent_a
    elif weight_b == 1.0:
        return latent_b
    else:
        return (1.0 - weight_b) * latent_a + weight_b * latent_b
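
# Quick sanity-check sketch for mix_latents (illustrative values, not part of the pipeline):
#   a = torch.zeros(1, 4); b = torch.ones(1, 4)
#   mix_latents(a, b, 0.25)  # -> tensor of 0.25s, i.e. (1 - 0.25) * a + 0.25 * b
# mix_embeddings below applies the same linear interpolation to embedding tensors.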

def mix_embeddings(embed_a, embed_b, weight_b):
    """Mix two embedding tensors (like CLIP image embeddings) with the specified weight for embed_b."""
    if embed_a is None: return embed_b
    if embed_b is None: return embed_a

    target_device = embed_a.device
    target_dtype = embed_a.dtype
    if embed_b.device != target_device:
        embed_b = embed_b.to(target_device)
    if embed_b.dtype != target_dtype:
        embed_b = embed_b.to(dtype=target_dtype)

    if isinstance(weight_b, torch.Tensor):
        weight_b = weight_b.item()

    weight_b = max(0.0, min(1.0, weight_b))

    if weight_b == 0.0:
        return embed_a
    elif weight_b == 1.0:
        return embed_b
    else:
        return (1.0 - weight_b) * embed_a + weight_b * embed_b


def preprocess_image_for_generation(image_path, target_width, target_height, job_id, output_dir, frame_name="input"):
    """Loads, processes, and saves a single image."""
    try:
        image = Image.open(image_path).convert('RGB')
        image_np = np.array(image)
    except Exception as e:
        print(f"Error loading image '{image_path}': {e}")
        raise

    H_orig, W_orig, _ = image_np.shape
    print(f"  {frame_name.capitalize()} image loaded ({W_orig}x{H_orig}): '{image_path}'")

    image_resized_np = resize_and_center_crop(image_np, target_width=target_width, target_height=target_height)
    try:
        Image.fromarray(image_resized_np).save(output_dir / f'{job_id}_{frame_name}_resized_{target_width}x{target_height}.png')
    except Exception as e:
        print(f"Warning: Could not save resized image preview for {frame_name}: {e}")

    image_pt = torch.from_numpy(image_resized_np).float() / 127.5 - 1.0
    image_pt = image_pt.permute(2, 0, 1)[None, :, None] # B=1, C=3, T=1, H, W
    print(f"  {frame_name.capitalize()} image processed to tensor shape: {image_pt.shape}")

    return image_np, image_resized_np, image_pt


@torch.no_grad()
def generate_video(args, models):
    """Generates the video using the loaded models and arguments."""

    # Unpack models
    text_encoder = models["text_encoder"]
    text_encoder_2 = models["text_encoder_2"]
    tokenizer = models["tokenizer"]
    tokenizer_2 = models["tokenizer_2"]
    vae = models["vae"]
    feature_extractor = models["feature_extractor"]
    image_encoder = models["image_encoder"]
    transformer = models["transformer"]
    device = models["device"]

    # --- Determine Memory Mode ---
    if args.high_vram and args.low_vram:
        print("Warning: Both --high_vram and --low_vram specified. Defaulting to auto-detection.")
        force_high_vram = force_low_vram = False
    else:
        force_high_vram = args.high_vram
        force_low_vram = args.low_vram

    if force_high_vram:
        high_vram = True
    elif force_low_vram:
        high_vram = False
    else:
        free_mem_gb = get_cuda_free_memory_gb(device) if device.type == 'cuda' else 0
        high_vram = free_mem_gb > 60
        print(f'Auto-detected Free VRAM {free_mem_gb:.2f} GB -> High-VRAM Mode: {high_vram}')

    # --- Configure Models based on VRAM mode ---
    if not high_vram:
        print("Configuring for Low VRAM mode...")
        vae.enable_slicing()
        vae.enable_tiling()
        print("  Installing DynamicSwap for Transformer...")
        DynamicSwapInstaller.install_model(transformer, device=device)
        print("  Installing DynamicSwap for Text Encoder 1...")
        DynamicSwapInstaller.install_model(text_encoder, device=device)
        print("Unloading models from GPU (Low VRAM setup)...")
        unload_complete_models(text_encoder, text_encoder_2, image_encoder, vae, transformer)
    else:
        print("Configuring for High VRAM mode (moving models to GPU)...")
        text_encoder.to(device)
        text_encoder_2.to(device)
        image_encoder.to(device)
        vae.to(device)
        transformer.to(device)
        print("  Models moved to GPU.")

    # --- Prepare Inputs ---
    print("Preparing inputs...")
    prompt = args.prompt
    n_prompt = args.negative_prompt
    seed = args.seed
    total_second_length = args.total_second_length
    latent_window_size = args.latent_window_size
    steps = args.steps
    cfg = args.cfg
    gs = args.distilled_guidance_scale
    rs = args.rs
    gpu_memory_preservation = args.gpu_memory_preservation
    use_teacache = args.use_teacache
    fps = args.fps
    end_frame_path = args.end_frame
    end_frame_influence = args.end_frame_influence
    end_frame_weight = args.end_frame_weight
    section_data = args.section_data
    save_intermediate = args.save_intermediate_sections
    save_section_frames = args.save_section_final_frames

    total_latent_sections = (total_second_length * 30) / (latent_window_size * 4)
    total_latent_sections = int(max(round(total_latent_sections), 1))
    print(f"Calculated total latent sections: {total_latent_sections}")

    job_id = generate_timestamp() + f"_seed{seed}"
    output_dir = Path(args.save_path)
    output_dir.mkdir(parents=True, exist_ok=True)
    final_video_path = None

    # --- Section Preprocessing Storage ---
    section_latents = {}
    section_image_embeddings = {} # Still store, might be useful later
    section_prompt_embeddings = {}

    try:
        # --- Text Encoding (Global Prompts) ---
        print("Encoding global text prompts...")
        if not high_vram:
            print("  Low VRAM mode: Loading Text Encoders to GPU...")
            fake_diffusers_current_device(text_encoder, device)
            load_model_as_complete(text_encoder_2, target_device=device)
            print("  Text Encoders loaded.")

        global_llama_vec, global_clip_l_pooler = encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)

        if cfg == 1.0:
             print("  CFG scale is 1.0, using zero negative embeddings.")
             global_llama_vec_n, global_clip_l_pooler_n = torch.zeros_like(global_llama_vec), torch.zeros_like(global_clip_l_pooler)
        else:
             print(f"  Encoding negative prompt: '{n_prompt}'")
             global_llama_vec_n, global_clip_l_pooler_n = encode_prompt_conds(n_prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)

        global_llama_vec, global_llama_attention_mask = crop_or_pad_yield_mask(global_llama_vec, length=512)
        global_llama_vec_n, global_llama_attention_mask_n = crop_or_pad_yield_mask(global_llama_vec_n, length=512)
        print("  Global text encoded and processed.")

        # --- Section Text Encoding ---
        if section_data:
            print("Encoding section-specific prompts...")
            for section_index, (img_path, prompt_text) in section_data.items():
                if prompt_text:
                    print(f"  Encoding prompt for section {section_index}: '{prompt_text}'")
                    sec_llama_vec, sec_clip_pooler = encode_prompt_conds(prompt_text, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
                    sec_llama_vec, _ = crop_or_pad_yield_mask(sec_llama_vec, length=512)
                    section_prompt_embeddings[section_index] = (
                        sec_llama_vec.cpu().to(transformer.dtype),
                        sec_clip_pooler.cpu().to(transformer.dtype)
                    )
                    print(f"    Section {section_index} prompt encoded and stored on CPU.")
                else:
                     print(f"  Section {section_index} has no specific prompt, will use global prompt.")

        if not high_vram:
            print("  Low VRAM mode: Unloading Text Encoders from GPU...")
            unload_complete_models(text_encoder_2)
            print("  Text Encoder 2 unloaded.")

        # --- Input Image Processing & Dimension Calculation ---
        print("Processing input image and determining dimensions...")
        try:
            input_image_np_orig, _, _ = preprocess_image_for_generation(
                args.input_image, 1, 1, job_id, output_dir, "temp_input_orig"
            )
        except Exception as e:
             print(f"Error loading input image '{args.input_image}' for dimension check: {e}")
             raise
        H_orig, W_orig, _ = input_image_np_orig.shape
        print(f"  Input image original size: {W_orig}x{H_orig}")

        if args.width is not None and args.height is not None:
            target_w, target_h = args.width, args.height
            print(f"  Using explicit target dimensions: {target_w}x{target_h}")
        elif args.target_resolution is not None:
            print(f"  Calculating dimensions based on target resolution for longer side: {args.target_resolution}")
            target_h, target_w = find_nearest_bucket(H_orig, W_orig, resolution=args.target_resolution)
            print(f"  Calculated dimensions (before adjustment): {target_w}x{target_h}")
        else:
            raise ValueError("Internal Error: Resolution determination failed.")

        final_w = adjust_to_multiple(target_w, DIMENSION_MULTIPLE)
        final_h = adjust_to_multiple(target_h, DIMENSION_MULTIPLE)

        if final_w <= 0 or final_h <= 0:
            print(f"Error: Calculated dimensions ({target_w}x{target_h}) resulted in non-positive dimensions after adjusting to be divisible by {DIMENSION_MULTIPLE} ({final_w}x{final_h}).")
            raise ValueError("Adjusted dimensions are invalid.")

        if final_w != target_w or final_h != target_h:
            print(f"Warning: Adjusted dimensions from {target_w}x{target_h} to {final_w}x{final_h} to be divisible by {DIMENSION_MULTIPLE}.")
        else:
            print(f"  Final dimensions confirmed: {final_w}x{final_h}")

        width, height = final_w, final_h

        if width * height > 1024 * 1024:
             print(f"Warning: Target resolution {width}x{height} is large. Ensure you have sufficient VRAM.")

        _, input_image_resized_np, input_image_pt = preprocess_image_for_generation(
            args.input_image, width, height, job_id, output_dir, "input"
        )

        end_frame_resized_np = None
        end_frame_pt = None
        if end_frame_path:
            _, end_frame_resized_np, end_frame_pt = preprocess_image_for_generation(
                end_frame_path, width, height, job_id, output_dir, "end"
            )

        section_images_resized_np = {}
        section_images_pt = {}
        if section_data:
            print("Processing section keyframe images...")
            for section_index, (img_path, _) in section_data.items():
                _, sec_resized_np, sec_pt = preprocess_image_for_generation(
                    img_path, width, height, job_id, output_dir, f"section{section_index}"
                )
                section_images_resized_np[section_index] = sec_resized_np
                section_images_pt[section_index] = sec_pt

        # --- VAE Encoding ---
        print("VAE encoding initial frame...")
        if not high_vram:
            print("  Low VRAM mode: Loading VAE to GPU...")
            load_model_as_complete(vae, target_device=device)
            print("  VAE loaded.")

        input_image_pt_dev = input_image_pt.to(device=device, dtype=vae.dtype)
        start_latent = vae_encode(input_image_pt_dev, vae) # GPU, vae.dtype
        print(f"  Initial latent shape: {start_latent.shape}")
        print(f"  Start latent stats - Min: {start_latent.min().item():.4f}, Max: {start_latent.max().item():.4f}, Mean: {start_latent.mean().item():.4f}")

        end_frame_latent = None
        if end_frame_pt is not None:
            print("VAE encoding end frame...")
            end_frame_pt_dev = end_frame_pt.to(device=device, dtype=vae.dtype)
            end_frame_latent = vae_encode(end_frame_pt_dev, vae) # GPU, vae.dtype
            print(f"  End frame latent shape: {end_frame_latent.shape}")
            print(f"  End frame latent stats - Min: {end_frame_latent.min().item():.4f}, Max: {end_frame_latent.max().item():.4f}, Mean: {end_frame_latent.mean().item():.4f}")
            if end_frame_latent.shape != start_latent.shape:
                print(f"Warning: End frame latent shape mismatch. Reshaping.")
                try:
                    end_frame_latent = end_frame_latent.reshape(start_latent.shape)
                except Exception as reshape_err:
                     print(f"Error reshaping end frame latent: {reshape_err}. Disabling end frame.")
                     end_frame_latent = None

        if section_images_pt:
             print("VAE encoding section keyframes...")
             for section_index, sec_pt in section_images_pt.items():
                 sec_pt_dev = sec_pt.to(device=device, dtype=vae.dtype)
                 sec_latent = vae_encode(sec_pt_dev, vae) # GPU, vae.dtype
                 print(f"  Section {section_index} latent shape: {sec_latent.shape}")
                 if sec_latent.shape != start_latent.shape:
                     print(f"  Warning: Section {section_index} latent shape mismatch. Reshaping.")
                     try:
                         sec_latent = sec_latent.reshape(start_latent.shape)
                     except Exception as reshape_err:
                         print(f"  Error reshaping section {section_index} latent: {reshape_err}. Skipping section latent.")
                         continue
                 # Store on CPU as float32 for context/blending later
                 section_latents[section_index] = sec_latent.cpu().float()
                 print(f"  Section {section_index} latent encoded and stored on CPU.")

        if not high_vram:
            print("  Low VRAM mode: Unloading VAE from GPU...")
            unload_complete_models(vae)
            print("  VAE unloaded.")

        # Move essential latents to CPU as float32 for context/blending
        start_latent = start_latent.cpu().float()
        if end_frame_latent is not None:
            end_frame_latent = end_frame_latent.cpu().float()

        # --- CLIP Vision Encoding ---
        print("CLIP Vision encoding image(s)...")
        if not high_vram:
            print("  Low VRAM mode: Loading Image Encoder to GPU...")
            load_model_as_complete(image_encoder, target_device=device)
            print("  Image Encoder loaded.")

        # Encode start frame - WILL BE USED CONSISTENTLY for image_embeddings
        image_encoder_output = hf_clip_vision_encode(input_image_resized_np, feature_extractor, image_encoder)
        start_image_embedding = image_encoder_output.last_hidden_state # GPU, image_encoder.dtype
        print(f"  Start image embedding shape: {start_image_embedding.shape}")

        # Encode end frame (if provided) - Only needed if extending later
        # end_frame_embedding = None # Not needed for this strategy
        # if end_frame_resized_np is not None:
        #     pass # Skip encoding for now

        # Encode section frames (if provided) - Store for potential future use
        if section_images_resized_np:
             print("CLIP Vision encoding section keyframes (storing on CPU)...")
             for section_index, sec_resized_np in section_images_resized_np.items():
                 sec_output = hf_clip_vision_encode(sec_resized_np, feature_extractor, image_encoder)
                 sec_embedding = sec_output.last_hidden_state
                 section_image_embeddings[section_index] = sec_embedding.cpu().to(transformer.dtype)
                 print(f"  Section {section_index} embedding shape: {sec_embedding.shape}. Stored on CPU.")

        if not high_vram:
            print("  Low VRAM mode: Unloading Image Encoder from GPU...")
            unload_complete_models(image_encoder)
            print("  Image Encoder unloaded.")

        # Move start image embedding to CPU (transformer dtype)
        target_dtype = transformer.dtype
        start_image_embedding = start_image_embedding.cpu().to(target_dtype)

        # --- Prepare Global Embeddings for Transformer (CPU, transformer.dtype) ---
        print("Preparing global embeddings for Transformer...")
        global_llama_vec = global_llama_vec.cpu().to(target_dtype)
        global_llama_vec_n = global_llama_vec_n.cpu().to(target_dtype)
        global_clip_l_pooler = global_clip_l_pooler.cpu().to(target_dtype)
        global_clip_l_pooler_n = global_clip_l_pooler_n.cpu().to(target_dtype)
        print(f"  Global Embeddings prepared on CPU with dtype {target_dtype}.")

        # --- Sampling Setup ---
        print("Setting up sampling...")
        rnd = torch.Generator(cpu).manual_seed(seed)
        num_frames = latent_window_size * 4 - 3
        print(f"  Latent frames per sampling step (num_frames input): {num_frames}")

        latent_c, latent_h, latent_w = start_latent.shape[1], start_latent.shape[3], start_latent.shape[4]
        context_latents = torch.zeros(size=(1, latent_c, 1 + 2 + 16, latent_h, latent_w), dtype=torch.float32).cpu()
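        # The 1 + 2 + 16 slots hold the clean_latents_post (1), 2x (2), and 4x (16) context frames
        # consumed by each sampling step (see the split further below); slot 0 is optionally seeded
        # with the end-frame latent just after this setup.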

        accumulated_generated_latents = None
        history_pixels = None

        latent_paddings = list(reversed(range(total_latent_sections)))
        if total_latent_sections > 4:
            latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0]
            print(f"  Using adjusted padding sequence for >4 sections: {latent_paddings}")
        else:
            print(f"  Using standard padding sequence: {latent_paddings}")

        # --- [MODIFIED] Restore Initial Context Initialization ---
        if end_frame_latent is not None:
            print("  Initializing context buffer's first slot with end frame latent.")
            context_latents[:, :, 0:1, :, :] = end_frame_latent.cpu().float() # Ensure float32 CPU
        else:
            print("  No end frame latent available. Initial context remains zeros.")
        # --- End Modified Context Initialization ---

        # --- Main Sampling Loop (Generates Backward: End -> Start) ---
        start_time = time.time()
        num_loops = len(latent_paddings)

        for i_loop, latent_padding in enumerate(latent_paddings):
            section_start_time = time.time()
            current_section_index_from_end = latent_padding
            is_first_generation_step = (i_loop == 0)
            is_last_generation_step = (latent_padding == 0)

            print(f"\n--- Starting Generation Step {i_loop+1}/{num_loops} (Section Index from End: {current_section_index_from_end}, First Step: {is_first_generation_step}, Last Step: {is_last_generation_step}) ---")
            latent_padding_size = latent_padding * latent_window_size
            print(f'  Padding size (latent frames): {latent_padding_size}, Window size (latent frames): {latent_window_size}')

            # --- Select Conditioning Inputs for this Section ---

            # 1. Conditioning Latent (`clean_latents_pre`) - Calculate Blend
            # Determine the base latent (start or section-specific)
            base_conditioning_latent = start_latent # Default to start (float32 CPU)
            if current_section_index_from_end in section_latents:
                base_conditioning_latent = section_latents[current_section_index_from_end] # Use section if available (float32 CPU)
                print(f"  Using SECTION {current_section_index_from_end} latent as base conditioning latent.")
            else:
                print(f"  Using START frame latent as base conditioning latent.")

            # Apply 'bookend' override to the base latent for the first step only
            if end_frame_influence == "bookend" and is_first_generation_step and end_frame_latent is not None:
                if current_section_index_from_end not in section_latents:
                     base_conditioning_latent = end_frame_latent # float32 CPU
                     print("  Applying 'bookend': Overriding base conditioning latent with END frame latent for first step.")

            # Blend the base conditioning latent with the end frame latent based on mode/weight
            current_conditioning_latent = base_conditioning_latent # Initialize with base
            current_end_frame_latent_weight = 0.0
            if end_frame_latent is not None: # Only blend if end frame exists
                if end_frame_influence == 'progressive':
                    progress = i_loop / max(1, num_loops - 1)
                    current_end_frame_latent_weight = args.end_frame_weight * (1.0 - progress)
                elif end_frame_influence == 'half':
                    if i_loop < num_loops / 2:
                        current_end_frame_latent_weight = args.end_frame_weight
                # For 'last' and 'bookend', weight remains 0, no blending needed

                current_end_frame_latent_weight = max(0.0, min(1.0, current_end_frame_latent_weight))

                if current_end_frame_latent_weight > 1e-4: # Mix only if weight is significant
                    print(f"  Blending Conditioning Latent: Base<-{1.0-current_end_frame_latent_weight:.3f} | End->{current_end_frame_latent_weight:.3f} (Mode: {end_frame_influence})")
                    # Ensure both inputs to mix_latents are float32 CPU
                    current_conditioning_latent = mix_latents(base_conditioning_latent.cpu().float(),
                                                              end_frame_latent.cpu().float(),
                                                              current_end_frame_latent_weight)
                #else:
                #    print(f"  Using BASE conditioning latent (Mode: {end_frame_influence}, Blend Weight near zero).") # Can be verbose
            #else:
            #    print(f"  Using BASE conditioning latent (No end frame specified for blending).") # Can be verbose


            # 2. Image Embedding - Use Fixed Start Embedding
            current_image_embedding = start_image_embedding # transformer.dtype CPU
            print(f"  Using fixed START frame image embedding.")


            # 3. Text Embedding (Select section or global)
            if current_section_index_from_end in section_prompt_embeddings:
                 current_llama_vec, current_clip_pooler = section_prompt_embeddings[current_section_index_from_end]
                 print(f"  Using SECTION {current_section_index_from_end} prompt embeddings.")
            else:
                 current_llama_vec = global_llama_vec
                 current_clip_pooler = global_clip_l_pooler
                 print(f"  Using GLOBAL prompt embeddings.")

            current_llama_vec_n = global_llama_vec_n
            current_clip_pooler_n = global_clip_l_pooler_n
            # Note: the global Llama attention masks are reused even when section-specific prompt
            # embeddings are selected above, because the section-specific masks are not stored.
            current_llama_attention_mask = global_llama_attention_mask
            current_llama_attention_mask_n = global_llama_attention_mask_n

            # --- Prepare Sampler Inputs ---
            indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, 1, 2, 16])).unsqueeze(0)
            clean_latent_indices_pre, blank_indices, latent_indices, clean_latent_indices_post, clean_latent_2x_indices, clean_latent_4x_indices = \
                indices.split([1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1)
            clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)
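            # Index layout for this step: [clean_pre (1) | blank padding | generated window | clean_post (1) | 2x (2) | 4x (16)].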

            # Prepare conditioning latents (float32 CPU)
            clean_latents_pre = current_conditioning_latent # Use the potentially blended one
            clean_latents_post, clean_latents_2x, clean_latents_4x = \
                context_latents[:, :, :1 + 2 + 16, :, :].split([1, 2, 16], dim=2)
            clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)
            print(f"  Final Conditioning shapes (CPU): clean={clean_latents.shape}, 2x={clean_latents_2x.shape}, 4x={clean_latents_4x.shape}")
            print(f"  Clean Latents Pre stats - Min: {clean_latents_pre.min().item():.4f}, Max: {clean_latents_pre.max().item():.4f}, Mean: {clean_latents_pre.mean().item():.4f}")


            # Load Transformer (Low VRAM)
            if not high_vram:
                print("  Moving Transformer to GPU...")
                unload_complete_models()
                move_model_to_device_with_memory_preservation(transformer, target_device=device, preserved_memory_gb=gpu_memory_preservation)
                fake_diffusers_current_device(text_encoder, device)

            # Configure TeaCache
            if use_teacache:
                transformer.initialize_teacache(enable_teacache=True, num_steps=steps)
                print("  TeaCache enabled.")
            else:
                transformer.initialize_teacache(enable_teacache=False)
                print("  TeaCache disabled.")

            # --- Run Sampling ---
            print(f"  Starting sampling ({steps} steps) for {num_frames} latent frames...")
            sampling_step_start_time = time.time()

            pbar = tqdm(total=steps, desc=f"  Section {current_section_index_from_end} Sampling", leave=False)
            def callback(d):
                pbar.update(1)
                return

            current_sampler_device = transformer.device
            current_text_encoder_device = text_encoder.device if not high_vram else device

            # Move tensors to device just before sampling
            _prompt_embeds = current_llama_vec.to(current_text_encoder_device)
            _prompt_embeds_mask = current_llama_attention_mask.to(current_text_encoder_device)
            _prompt_poolers = current_clip_pooler.to(current_sampler_device)
            _negative_prompt_embeds = current_llama_vec_n.to(current_text_encoder_device)
            _negative_prompt_embeds_mask = current_llama_attention_mask_n.to(current_text_encoder_device)
            _negative_prompt_poolers = current_clip_pooler_n.to(current_sampler_device)
            _image_embeddings = current_image_embedding.to(current_sampler_device) # Fixed start embedding
            _latent_indices = latent_indices.to(current_sampler_device)
            # Pass conditioning latents (now potentially blended) to sampler
            _clean_latents = clean_latents.to(current_sampler_device, dtype=transformer.dtype)
            _clean_latent_indices = clean_latent_indices.to(current_sampler_device)
            _clean_latents_2x = clean_latents_2x.to(current_sampler_device, dtype=transformer.dtype)
            _clean_latent_2x_indices = clean_latent_2x_indices.to(current_sampler_device)
            _clean_latents_4x = clean_latents_4x.to(current_sampler_device, dtype=transformer.dtype)
            _clean_latent_4x_indices = clean_latent_4x_indices.to(current_sampler_device)

            generated_latents_gpu = sample_hunyuan(
                transformer=transformer,
                sampler='unipc',
                width=width,
                height=height,
                frames=num_frames,
                real_guidance_scale=cfg,
                distilled_guidance_scale=gs,
                guidance_rescale=rs,
                num_inference_steps=steps,
                generator=rnd,
                prompt_embeds=_prompt_embeds,
                prompt_embeds_mask=_prompt_embeds_mask,
                prompt_poolers=_prompt_poolers,
                negative_prompt_embeds=_negative_prompt_embeds,
                negative_prompt_embeds_mask=_negative_prompt_embeds_mask,
                negative_prompt_poolers=_negative_prompt_poolers,
                device=current_sampler_device,
                dtype=transformer.dtype,
                image_embeddings=_image_embeddings, # Using fixed start embedding
                latent_indices=_latent_indices,
                clean_latents=_clean_latents, # Using potentially blended latents
                clean_latent_indices=_clean_latent_indices,
                clean_latents_2x=_clean_latents_2x,
                clean_latent_2x_indices=_clean_latent_2x_indices,
                clean_latents_4x=_clean_latents_4x,
                clean_latent_4x_indices=_clean_latent_4x_indices,
                callback=callback,
            )
            pbar.close()
            sampling_step_end_time = time.time()
            print(f"  Sampling finished in {sampling_step_end_time - sampling_step_start_time:.2f} seconds.")
            print(f"  Raw generated latent shape for this step: {generated_latents_gpu.shape}")
            print(f"  Generated latents stats (GPU) - Min: {generated_latents_gpu.min().item():.4f}, Max: {generated_latents_gpu.max().item():.4f}, Mean: {generated_latents_gpu.mean().item():.4f}")

            # Move generated latents to CPU as float32
            generated_latents_cpu = generated_latents_gpu.cpu().float()
            del generated_latents_gpu, _prompt_embeds, _prompt_embeds_mask, _prompt_poolers, _negative_prompt_embeds, _negative_prompt_embeds_mask, _negative_prompt_poolers
            del _image_embeddings, _latent_indices, _clean_latents, _clean_latent_indices, _clean_latents_2x, _clean_latent_2x_indices, _clean_latents_4x, _clean_latent_4x_indices
            if device.type == 'cuda': torch.cuda.empty_cache()

            # Offload Transformer and TE1 (Low VRAM)
            if not high_vram:
                print("  Low VRAM mode: Offloading Transformer and Text Encoder from GPU...")
                offload_model_from_device_for_memory_preservation(transformer, target_device=device, preserved_memory_gb=gpu_memory_preservation)
                offload_model_from_device_for_memory_preservation(text_encoder, target_device=device, preserved_memory_gb=gpu_memory_preservation)
                print("  Transformer and Text Encoder offloaded.")

            # --- History/Context Update ---
            if is_last_generation_step:
                print("  Last generation step: Prepending start frame latent to generated latents.")
                generated_latents_cpu = torch.cat([start_latent.cpu().float(), generated_latents_cpu], dim=2)
                print(f"  Shape after prepending start latent: {generated_latents_cpu.shape}")

            context_latents = torch.cat([generated_latents_cpu, context_latents], dim=2)
            print(f"  Context buffer updated. New shape: {context_latents.shape}")

            # Accumulate the generated latents for the final video output
            if accumulated_generated_latents is None:
                 accumulated_generated_latents = generated_latents_cpu
            else:
                 accumulated_generated_latents = torch.cat([generated_latents_cpu, accumulated_generated_latents], dim=2)

            current_total_latent_frames = accumulated_generated_latents.shape[2]
            print(f"  Accumulated generated latents updated. Total latent frames: {current_total_latent_frames}")
            print(f"  Accumulated latents stats - Min: {accumulated_generated_latents.min().item():.4f}, Max: {accumulated_generated_latents.max().item():.4f}, Mean: {accumulated_generated_latents.mean().item():.4f}")

            # --- VAE Decoding & Merging ---
            print("  Decoding generated latents and merging video...")
            decode_start_time = time.time()

            if not high_vram:
                print("    Moving VAE to GPU...")
                offload_model_from_device_for_memory_preservation(transformer, target_device=device, preserved_memory_gb=gpu_memory_preservation)
                unload_complete_models(text_encoder, text_encoder_2, image_encoder)
                load_model_as_complete(vae, target_device=device)
                print("    VAE loaded.")

            print(f"    Decoding current section's latents (shape: {generated_latents_cpu.shape}) for append.")
            latents_to_decode_for_append = generated_latents_cpu.to(device=device, dtype=vae.dtype)
            current_pixels = vae_decode(latents_to_decode_for_append, vae).cpu().float() # Decode and move to CPU float32
            print(f"    Decoded pixels for append shape: {current_pixels.shape}")
            del latents_to_decode_for_append
            if device.type == 'cuda': torch.cuda.empty_cache()

            if history_pixels is None:
                 history_pixels = current_pixels
                 print(f"    Initialized history_pixels shape: {history_pixels.shape}")
            else:
                append_overlap = 3
                print(f"    Appending section with pixel overlap: {append_overlap}")
                history_pixels = soft_append_bcthw(current_pixels, history_pixels, overlap=append_overlap)
                print(f"    Appended. New total pixel shape: {history_pixels.shape}")

            if not high_vram:
                print("    Low VRAM mode: Unloading VAE from GPU...")
                unload_complete_models(vae)
                print("    VAE unloaded.")

            decode_end_time = time.time()
            print(f"  Decoding and merging finished in {decode_end_time - decode_start_time:.2f} seconds.")

            # --- Save Intermediate/Section Output ---
            current_num_pixel_frames = history_pixels.shape[2]

            if save_section_frames:
                try:
                    first_frame_index = 0 # Index 0 of the newly decoded chunk is the first frame generated in this step
                    frame_to_save = current_pixels[0, :, first_frame_index, :, :]
                    frame_to_save = einops.rearrange(frame_to_save, 'c h w -> h w c')
                    frame_to_save_np = frame_to_save.cpu().numpy()
                    frame_to_save_np = np.clip((frame_to_save_np * 127.5 + 127.5), 0, 255).astype(np.uint8)
                    section_frame_filename = output_dir / f'{job_id}_section_start_frame_idx{current_section_index_from_end}.png' # Renamed for clarity
                    Image.fromarray(frame_to_save_np).save(section_frame_filename)
                    print(f"  Saved first generated pixel frame of section {current_section_index_from_end} (from decoded chunk) to: {section_frame_filename}")
                except Exception as e:
                     print(f"  [WARN] Error saving section {current_section_index_from_end} start frame image: {e}")

            if save_intermediate or is_last_generation_step:
                output_filename = output_dir / f'{job_id}_step{i_loop+1}_idx{current_section_index_from_end}_frames{current_num_pixel_frames}_{width}x{height}.mp4'
                print(f"  Saving {'intermediate' if not is_last_generation_step else 'final'} video ({current_num_pixel_frames} frames) to: {output_filename}")
                try:
                    save_bcthw_as_mp4(history_pixels.float(), str(output_filename), fps=int(fps))
                    print(f"  Saved video using save_bcthw_as_mp4")
                    if not is_last_generation_step:
                        print(f"INTERMEDIATE_VIDEO_PATH:{output_filename}")
                    final_video_path = str(output_filename)
                except Exception as e:
                    print(f"  Error saving video using save_bcthw_as_mp4: {e}")
                    traceback.print_exc()
                    # Fallback save attempt
                    try:
                        first_frame_img = history_pixels.float()[0, :, 0].permute(1, 2, 0).cpu().numpy()
                        first_frame_img = (first_frame_img * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
                        frame_path = str(output_filename).replace('.mp4', '_first_frame_ERROR.png')
                        Image.fromarray(first_frame_img).save(frame_path)
                        print(f"  Saved first frame as image to {frame_path} due to video saving error.")
                    except Exception as frame_err:
                        print(f"  Could not save first frame either: {frame_err}")

            section_end_time = time.time()
            print(f"--- Generation Step {i_loop+1} finished in {section_end_time - section_start_time:.2f} seconds ---")

            if is_last_generation_step:
                print("\nFinal generation step completed.")
                break

        # --- Final Video Saved During Last Step ---
        if final_video_path and os.path.exists(final_video_path):
            print(f"\nSuccessfully generated: {final_video_path}")
            print(f"ACTUAL_FINAL_PATH:{final_video_path}")
            return final_video_path
        else:
             print("\nError: Final video path not found or not saved correctly.")
             return None

    except Exception as e:
        print("\n--- ERROR DURING GENERATION ---")
        traceback.print_exc()
        print("-----------------------------")
        if 'history_pixels' in locals() and history_pixels is not None and history_pixels.shape[2] > 0:
             partial_output_name = output_dir / f"{job_id}_partial_ERROR_{history_pixels.shape[2]}_frames_{width}x{height}.mp4"
             print(f"Attempting to save partial video to: {partial_output_name}")
             try:
                 save_bcthw_as_mp4(history_pixels.float(), str(partial_output_name), fps=fps)
                 print(f"ACTUAL_FINAL_PATH:{partial_output_name}")
                 return str(partial_output_name)
             except Exception as save_err:
                 print(f"Error saving partial video during error handling: {save_err}")
                 traceback.print_exc()

        print("Status: Error occurred, no video saved.")
        return None

    finally:
        print("Performing final model cleanup...")
        try:
            unload_complete_models(text_encoder, text_encoder_2, image_encoder, vae, transformer)
        except Exception as e:
             print(f"Error during final model unload: {e}")
             pass
        if device.type == 'cuda':
            torch.cuda.empty_cache()
            print("CUDA cache cleared.")

def main():
    args = parse_args()
    models = load_models(args)
    final_path = generate_video(args, models)
    if final_path:
        print(f"\nVideo generation finished. Final path: {final_path}")
        sys.exit(0)
    else:
        print("\nVideo generation failed.")
        sys.exit(1)

if __name__ == "__main__":
    main()
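
# Example invocation (hypothetical file name and paths; adjust to your environment):
#   python framepack_cli.py \
#       --input_image inputs/start.png \
#       --end_frame inputs/end.png \
#       --end_frame_influence half --end_frame_weight 0.3 \
#       --section 0:inputs/keyframe_start.png:"A sunrise over the lake" \
#       --prompt "A timelapse of clouds over a mountain lake" \
#       --target_resolution 640 --total_second_length 5 --seed 42 \
#       --save_path outputs/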