import os
import torch
import traceback
import einops
import numpy as np
import argparse
import math
import decord
from tqdm import tqdm
import pathlib
from datetime import datetime
import imageio_ffmpeg
import tempfile
import shutil
import subprocess
import sys

from PIL import Image
try:
    from frame_pack.hunyuan_video_packed import load_packed_model
    from frame_pack.framepack_utils import (
        load_vae,
        load_text_encoder1,
        load_text_encoder2,
        load_image_encoders
    )
    from frame_pack.hunyuan import encode_prompt_conds, vae_decode, vae_encode # vae_decode_fake might be needed for previews if added
    from frame_pack.utils import crop_or_pad_yield_mask, soft_append_bcthw, resize_and_center_crop, generate_timestamp
    from frame_pack.k_diffusion_hunyuan import sample_hunyuan
    from frame_pack.clip_vision import hf_clip_vision_encode
    from frame_pack.bucket_tools import find_nearest_bucket
    from diffusers_helper.utils import save_bcthw_as_mp4 # from a common helper library
    from diffusers_helper.memory import cpu, gpu, get_cuda_free_memory_gb, \
                                       move_model_to_device_with_memory_preservation, \
                                       offload_model_from_device_for_memory_preservation, \
                                       fake_diffusers_current_device, DynamicSwapInstaller, \
                                       unload_complete_models, load_model_as_complete
    # For LoRA
    from networks import lora_framepack 
    try:
        from lycoris.kohya import create_network_from_weights
    except ImportError:
        pass # Lycoris optional
    from base_wan_generate_video import merge_lora_weights # Assuming this is accessible
except ImportError as e:
    print(f"Error importing FramePack related modules: {e}. Ensure they are in PYTHONPATH.")
    sys.exit(1)


# --- Global Model Variables ---
text_encoder = None
text_encoder_2 = None
tokenizer = None
tokenizer_2 = None
vae = None
feature_extractor = None
image_encoder = None
transformer = None

high_vram = False
free_mem_gb = 0.0

outputs_folder = './outputs/' # Default, can be overridden by --output_dir
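# Example invocation (script name and all paths below are placeholders; the required flags match the
# argparse definitions at the bottom of this file):
#   python fpack_extend_video_endframe.py --input_video ./input.mp4 --prompt "a description of the motion" \
#       --dit <dit_weights> --vae <vae_weights> --text_encoder1 <llama_weights> \
#       --text_encoder2 <clip_weights> --image_encoder <siglip_weights> \
#       --end_frame ./end_frame.png --total_second_length 5.0 --output_dir ./outputs/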

@torch.no_grad()
def video_encode(video_path, resolution, no_resize, vae_model, vae_batch_size=16, device="cuda", width=None, height=None):
    video_path = str(pathlib.Path(video_path).resolve())
    print(f"Processing video for encoding: {video_path}")

    if device == "cuda" and not torch.cuda.is_available():
        print("CUDA is not available, falling back to CPU for video_encode")
        device = "cpu"

    try:
        print("Initializing VideoReader...")
        vr = decord.VideoReader(video_path)
        fps = vr.get_avg_fps()
        if fps == 0:
            print("Warning: VideoReader reported FPS as 0. Attempting to get it via OpenCV.")
            import cv2
            cap = cv2.VideoCapture(video_path)
            fps_cv = cap.get(cv2.CAP_PROP_FPS)
            cap.release()
            if fps_cv > 0:
                fps = fps_cv
                print(f"Using FPS from OpenCV: {fps}")
            else:
                # Fallback FPS if all else fails
                fps = 25
                print(f"Failed to determine FPS for the input video. Defaulting to {fps} FPS.")


        num_real_frames = len(vr)
        print(f"Video loaded: {num_real_frames} frames, FPS: {fps}")

        latent_size_factor = 4 # Hunyuan video VAE compresses time by 4x (and space by 8x); keep frame count a multiple of 4
        num_frames = (num_real_frames // latent_size_factor) * latent_size_factor
        if num_frames != num_real_frames:
            print(f"Truncating video from {num_real_frames} to {num_frames} frames for latent size compatibility (multiple of {latent_size_factor})")

        if num_frames == 0:
            raise ValueError(f"Video too short ({num_real_frames} frames) or becomes 0 after truncation. Needs at least {latent_size_factor} frames.")
        num_real_frames = num_frames

        print("Reading video frames...")
        frames_np_all = vr.get_batch(range(num_real_frames)).asnumpy()
        print(f"Frames read: {frames_np_all.shape}")

        native_height, native_width = frames_np_all.shape[1], frames_np_all.shape[2]
        print(f"Native video resolution: {native_width}x{native_height}")

        target_h_arg = native_height if height is None else height
        target_w_arg = native_width if width is None else width

        if not no_resize:
            actual_target_height, actual_target_width = find_nearest_bucket(target_h_arg, target_w_arg, resolution=resolution)
            print(f"Adjusted resolution for VAE encoding: {actual_target_width}x{actual_target_height}")
        else:
            actual_target_width = (native_width // 8) * 8
            actual_target_height = (native_height // 8) * 8
            if actual_target_width != native_width or actual_target_height != native_height:
                 print(f"Using native resolution, adjusted to be divisible by 8: {actual_target_width}x{actual_target_height}")
            else:
                print(f"Using native resolution without resizing: {actual_target_width}x{actual_target_height}")

        processed_frames_list = []
        for frame_idx in range(frames_np_all.shape[0]):
            frame = frames_np_all[frame_idx]
            frame_resized_np = resize_and_center_crop(frame, target_width=actual_target_width, target_height=actual_target_height)
            processed_frames_list.append(frame_resized_np)

        processed_frames_np_stack = np.stack(processed_frames_list)
        print(f"Frames preprocessed: {processed_frames_np_stack.shape}")

        input_image_np_for_clip_first = processed_frames_np_stack[0]
        input_image_np_for_clip_last = processed_frames_np_stack[-1]


        print("Converting frames to tensor...")
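        # Normalize uint8 pixels from [0, 255] to [-1, 1], the range the VAE expects.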
        frames_pt = torch.from_numpy(processed_frames_np_stack).float() / 127.5 - 1.0
        frames_pt = frames_pt.permute(0, 3, 1, 2) # B, H, W, C -> B, C, H, W
        frames_pt = frames_pt.unsqueeze(0).permute(0, 2, 1, 3, 4) # B, C, H, W -> 1, C, B, H, W (as VAE expects 1,C,F,H,W)
        print(f"Tensor shape for VAE: {frames_pt.shape}")

        input_video_pixels_cpu = frames_pt.clone().cpu() 

        print(f"Moving VAE and tensor to device: {device}")
        vae_model.to(device)
        frames_pt = frames_pt.to(device)

        print(f"Encoding input video frames with VAE (batch size: {vae_batch_size})")
        all_latents_list = []
        vae_model.eval()
        with torch.no_grad():
            for i in tqdm(range(0, frames_pt.shape[2], vae_batch_size), desc="VAE Encoding Video Frames", mininterval=0.1):
                batch_frames_pt = frames_pt[:, :, i:i + vae_batch_size]
                try:
                    batch_latents = vae_encode(batch_frames_pt, vae_model)
                    all_latents_list.append(batch_latents.cpu())
                except RuntimeError as e:
                    print(f"Error during VAE encoding: {str(e)}")
                    if "out of memory" in str(e).lower() and device == "cuda":
                        print("CUDA out of memory during VAE encoding. Try reducing --vae_batch_size or use CPU for VAE.")
                    raise

        history_latents_cpu = torch.cat(all_latents_list, dim=2)
        print(f"History latents shape (original video): {history_latents_cpu.shape}")

        start_latent_cpu = history_latents_cpu[:, :, :1].clone()
        end_of_input_video_latent_cpu = history_latents_cpu[:, :, -1:].clone()
        print(f"Start latent shape (for conditioning): {start_latent_cpu.shape}")
        print(f"End of input video latent shape: {end_of_input_video_latent_cpu.shape}")


        if device == "cuda":
            vae_model.to(cpu) # Move VAE back to CPU
            torch.cuda.empty_cache()
            print("VAE moved back to CPU, CUDA cache cleared")

        return (start_latent_cpu, input_image_np_for_clip_first, 
                history_latents_cpu, fps, 
                actual_target_height, actual_target_width, 
                input_video_pixels_cpu, 
                end_of_input_video_latent_cpu, input_image_np_for_clip_last)

    except Exception as e:
        print(f"Error in video_encode: {str(e)}")
        traceback.print_exc()
        raise

@torch.no_grad()
def image_encode(image_np, target_width, target_height, vae_model, image_encoder_model, feature_extractor_model, device="cuda"):
    """
    Encode a single image into a latent and compute its CLIP vision embedding.
    """
    global high_vram # Use global high_vram status
    print("Processing single image for encoding (e.g., end_frame)...")
    try:
        print(f"Using target resolution for image encoding: {target_width}x{target_height}")

        processed_image_np = resize_and_center_crop(image_np, target_width=target_width, target_height=target_height)

        image_pt = torch.from_numpy(processed_image_np).float() / 127.5 - 1.0
        image_pt = image_pt.permute(2, 0, 1).unsqueeze(0).unsqueeze(2)  # N C F H W (N=1, F=1)
        
        target_vae_device = device
        if not high_vram: load_model_as_complete(vae_model, target_device=target_vae_device)
        else: vae_model.to(target_vae_device)
        image_pt_device = image_pt.to(target_vae_device)
        
        latent = vae_encode(image_pt_device, vae_model).cpu() # Encode and move to CPU
        print(f"Single image VAE output shape (latent): {latent.shape}")

        if not high_vram: unload_complete_models(vae_model) # Offload VAE if low VRAM

        target_img_enc_device = device
        if not high_vram: load_model_as_complete(image_encoder_model, target_device=target_img_enc_device)
        else: image_encoder_model.to(target_img_enc_device)

        clip_embedding_output = hf_clip_vision_encode(processed_image_np, feature_extractor_model, image_encoder_model)
        clip_embedding = clip_embedding_output.last_hidden_state.cpu() # Encode and move to CPU
        print(f"Single image CLIP embedding shape: {clip_embedding.shape}")

        if not high_vram: unload_complete_models(image_encoder_model) # Offload image encoder if low VRAM
        
        if device == "cuda":
            torch.cuda.empty_cache()
            # print("CUDA cache cleared after single image_encode")

        return latent, clip_embedding, processed_image_np

    except Exception as e:
        print(f"Error in image_encode: {str(e)}")
        traceback.print_exc()
        raise

def set_mp4_comments_imageio_ffmpeg(input_file, comments):
    try:
        ffmpeg_path = imageio_ffmpeg.get_ffmpeg_exe()
        if not os.path.exists(input_file):
            print(f"Error: Input file {input_file} does not exist")
            return False
        temp_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False).name
        command = [
            ffmpeg_path, '-i', input_file, '-metadata', f'comment={comments}',
            '-c:v', 'copy', '-c:a', 'copy', '-y', temp_file
        ]
        result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False)
        if result.returncode == 0:
            shutil.move(temp_file, input_file)
            print(f"Successfully added comments to {input_file}")
            return True
        else:
            if os.path.exists(temp_file): os.remove(temp_file)
            print(f"Error: FFmpeg failed with message:\n{result.stderr}")
            return False
    except Exception as e:
        if 'temp_file' in locals() and os.path.exists(temp_file): os.remove(temp_file)
        print(f"Error saving prompt to video metadata; ffmpeg may be required: {e}")
        return False

@torch.no_grad()
def do_generation_work(
    input_video_path, prompt, n_prompt, seed,
    end_frame_path, end_frame_weight, # New arguments
    resolution_max_dim,
    additional_second_length, 
    latent_window_size, steps, cfg, gs, rs,
    gpu_memory_preservation, use_teacache, no_resize, mp4_crf,
    num_clean_frames, vae_batch_size,
    extension_only
):
    global high_vram, text_encoder, text_encoder_2, tokenizer, tokenizer_2, vae, feature_extractor, image_encoder, transformer

    print('--- Starting Video Generation (with End Frame support) ---')

    try:
        # --- Text Encoding ---
        print('Text encoding...')
        target_text_enc_device = str(gpu if torch.cuda.is_available() else cpu)
        if not high_vram:
            if text_encoder: fake_diffusers_current_device(text_encoder, target_text_enc_device) # DynamicSwapInstaller for text_encoder
            if text_encoder_2: load_model_as_complete(text_encoder_2, target_device=target_text_enc_device)
        else:
            if text_encoder: text_encoder.to(target_text_enc_device)
            if text_encoder_2: text_encoder_2.to(target_text_enc_device)

        llama_vec_gpu, clip_l_pooler_gpu = encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
        if cfg == 1.0: # Note: Original FramePack usually uses gs, cfg=1 means gs is active
            llama_vec_n_gpu, clip_l_pooler_n_gpu = torch.zeros_like(llama_vec_gpu), torch.zeros_like(clip_l_pooler_gpu)
        else: # If cfg > 1.0, it implies standard CFG, so n_prompt is used. gs should be 1.0 in this case.
            llama_vec_n_gpu, clip_l_pooler_n_gpu = encode_prompt_conds(n_prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)

        # Store on CPU
        llama_vec_padded_cpu, llama_attention_mask_cpu = crop_or_pad_yield_mask(llama_vec_gpu.cpu(), length=512)
        llama_vec_n_padded_cpu, llama_attention_mask_n_cpu = crop_or_pad_yield_mask(llama_vec_n_gpu.cpu(), length=512)
        clip_l_pooler_cpu = clip_l_pooler_gpu.cpu()
        clip_l_pooler_n_cpu = clip_l_pooler_n_gpu.cpu()
        
        if not high_vram: unload_complete_models(text_encoder_2) # text_encoder is managed by DynamicSwap

        # --- Video and End Frame Encoding ---
        print('Encoding input video...')
        video_encode_device = str(gpu if torch.cuda.is_available() else cpu)
        (start_latent_input_cpu, input_image_np_first, 
         video_latents_history_cpu, fps, height, width, 
         input_video_pixels_cpu,
         end_of_input_video_latent_cpu, input_image_np_last) = video_encode(
            input_video_path, resolution_max_dim, no_resize, vae, 
            vae_batch_size=vae_batch_size, device=video_encode_device,
            width=None, height=None # video_encode will use resolution_max_dim
        )
        if fps <= 0: raise ValueError("FPS from input video is 0 or invalid.")

        end_latent_from_file_cpu, end_clip_embedding_from_file_cpu = None, None
        if end_frame_path:
            print(f"Encoding provided end frame from: {end_frame_path}")
            end_frame_pil = Image.open(end_frame_path).convert("RGB")
            end_frame_np = np.array(end_frame_pil)
            end_latent_from_file_cpu, end_clip_embedding_from_file_cpu, _ = image_encode(
                end_frame_np, target_width=width, target_height=height, 
                vae_model=vae, image_encoder_model=image_encoder, 
                feature_extractor_model=feature_extractor, device=video_encode_device
            )

        # --- CLIP Vision Encoding for first and last frames of input video ---
        print('CLIP Vision encoding for input video frames...')
        target_img_enc_device = str(gpu if torch.cuda.is_available() else cpu)
        if not high_vram: load_model_as_complete(image_encoder, target_device=target_img_enc_device)
        else: image_encoder.to(target_img_enc_device)

        # For original FramePack, image_embeddings in sample_hunyuan often comes from the *start* image.
        # Script 2 uses end_of_input_video_embedding or a blend with the explicit end_frame.
        # We will follow script 2 for conditioning.
        # start_clip_embedding_cpu = hf_clip_vision_encode(input_image_np_first, feature_extractor, image_encoder).last_hidden_state.cpu()
        end_of_input_video_clip_embedding_cpu = hf_clip_vision_encode(input_image_np_last, feature_extractor, image_encoder).last_hidden_state.cpu()

        if not high_vram: unload_complete_models(image_encoder)

        # Determine final image embedding for sampling loop
        if end_clip_embedding_from_file_cpu is not None:
            print(f"Blending end-of-input-video embedding with provided end_frame embedding (weight: {end_frame_weight})")
            final_clip_embedding_for_sampling_cpu = \
                (1.0 - end_frame_weight) * end_of_input_video_clip_embedding_cpu + \
                end_frame_weight * end_clip_embedding_from_file_cpu
        else:
            print("Using end-of-input-video's last frame embedding for image conditioning.")
            final_clip_embedding_for_sampling_cpu = end_of_input_video_clip_embedding_cpu.clone()
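        # end_frame_weight=1.0 uses only the provided end frame's embedding; 0.0 falls back entirely
        # to the last frame of the input video (linear blend in between).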
        
        # --- Prepare for Sampling Loop ---
        target_transformer_device = str(gpu if torch.cuda.is_available() else cpu)
        if not high_vram:
            if transformer: move_model_to_device_with_memory_preservation(transformer, target_device=target_transformer_device, preserved_memory_gb=gpu_memory_preservation)
        else:
            if transformer: transformer.to(target_transformer_device)
        
        cond_device = transformer.device
        cond_dtype = transformer.dtype

        # Move conditioning tensors to transformer's device and dtype
        llama_vec = llama_vec_padded_cpu.to(device=cond_device, dtype=cond_dtype)
        llama_attention_mask = llama_attention_mask_cpu.to(device=cond_device) # Mask is usually bool/int
        clip_l_pooler = clip_l_pooler_cpu.to(device=cond_device, dtype=cond_dtype)
        llama_vec_n = llama_vec_n_padded_cpu.to(device=cond_device, dtype=cond_dtype)
        llama_attention_mask_n = llama_attention_mask_n_cpu.to(device=cond_device)
        clip_l_pooler_n = clip_l_pooler_n_cpu.to(device=cond_device, dtype=cond_dtype)
        
        # This is the image embedding that will be used in the sampling loop
        image_embeddings_for_sampling_loop = final_clip_embedding_for_sampling_cpu.to(device=cond_device, dtype=cond_dtype)
        
        # start_latent_for_initial_cond_gpu is the first frame of input video, used for clean_latents_pre
        # However, script 2 uses `video_latents[:, :, -min(effective_clean_frames, video_latents.shape[2]):]` for clean_latents_pre.
        # And `start_latent` for sample_hunyuan's `clean_latents` is `torch.cat([start_latent.to(history_latents), clean_latents_1x], dim=2)`
        # For backward generation, the "start_latent" concept for `sample_hunyuan`'s `clean_latents` argument
        # is often the *last frame of the input video* when generating the chunk closest to the input video.
        # Let's use end_of_input_video_latent_cpu for this role when appropriate.

        num_output_pixel_frames_per_section = latent_window_size * 4 # Not -3 here, as this is for total section calc
        if num_output_pixel_frames_per_section == 0:
             raise ValueError("latent_window_size * 4 is zero, cannot calculate total_extension_latent_sections.")
        total_extension_latent_sections = int(max(round((additional_second_length * fps) / num_output_pixel_frames_per_section), 1))

        print(f"Input video FPS: {fps}, Target additional length: {additional_second_length}s")
        print(f"Generating {total_extension_latent_sections} new sections for extension (approx {total_extension_latent_sections * num_output_pixel_frames_per_section / fps:.2f}s).")
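        # Worked example: latent_window_size=9 -> 36 pixel frames per section for this estimate,
        # so a 5 s extension at 30 fps needs round(150 / 36) = 4 sections.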

        job_id_base = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + \
                 f"_framepack-vidEndFrm_{width}x{height}_{additional_second_length:.1f}s_seed{seed}_s{steps}_gs{gs}_cfg{cfg}"
        
        job_id = job_id_base
        if extension_only:
            job_id += "_extonly"
            print("Extension-only mode enabled. Filenames will reflect this.")

        rnd = torch.Generator("cpu").manual_seed(seed)
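        # The noise generator is created on CPU and seeded once, so the sampled noise sequence is
        # reproducible for a given --seed regardless of which device the transformer runs on.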
        
        # Latent geometry, taken from the encoded input video; needed later for zero-padding context latents
        channels_dim = video_latents_history_cpu.shape[1]
        latent_h, latent_w = height // 8, width // 8

        # Initialize history for generated latents (starts empty, or seeded with end_latent_from_file)
        if end_latent_from_file_cpu is not None:
            # end_latent_from_file_cpu is [1, C, 1, H, W]; a single frame seeds the backward generation loop.
            # Script 2's logic for clean_latents_post when is_end_of_video uses just this one frame.
            history_latents_generated_cpu = end_latent_from_file_cpu.clone()
        else:
            history_latents_generated_cpu = torch.empty((1, channels_dim, 0, latent_h, latent_w), dtype=torch.float32, device='cpu')
        
        # Initialize history for decoded pixels (starts empty)
        history_pixels_decoded_cpu = None
        
        total_generated_latent_frames_count = history_latents_generated_cpu.shape[2]
        previous_video_path_for_cleanup = None

        # Backward generation loop (from demo_gradio_video+endframe.py)
        latent_paddings = list(reversed(range(total_extension_latent_sections)))
        if total_extension_latent_sections > 4: # Heuristic from script 2
            latent_paddings = [3] + [2] * (total_extension_latent_sections - 3) + [1, 0]
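        # Example schedule: 6 sections -> paddings [3, 2, 2, 2, 1, 0]; larger padding values are generated
        # first and sit furthest from the input video, while padding 0 is the chunk adjacent to it.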

        for loop_idx, latent_padding_val in enumerate(latent_paddings):
            current_section_num_from_end = loop_idx + 1
            is_start_of_extension = (latent_padding_val == 0) # This is the chunk closest to input video
            is_end_of_extension = (latent_padding_val == latent_paddings[0]) # This is the chunk furthest from input video

            print(f"--- Generating Extension: Seed {seed}: Section {current_section_num_from_end}/{total_extension_latent_sections} (backward), padding={latent_padding_val} ---")

            if transformer: transformer.initialize_teacache(enable_teacache=use_teacache, num_steps=steps if use_teacache else 0)
            progress_bar_sampler = tqdm(total=steps, desc=f"Sampling Extension Section {current_section_num_from_end}/{total_extension_latent_sections}", file=sys.stdout, dynamic_ncols=True)
            def sampler_callback_cli(d): progress_bar_sampler.update(1)

            # Context frame calculation (from demo_gradio_video+endframe.py worker)
            # `available_frames` for context refers to previously *generated* frames or input video frames
            # For `clean_latents_pre`, it's always from `video_latents_history_cpu`
            # For `clean_latents_post`, `_2x`, `_4x`, it's from `history_latents_generated_cpu`
            
            effective_clean_frames_count = max(0, num_clean_frames - 1) if num_clean_frames > 1 else 1
            
            # For clean_latents_pre (from input video)
            # If is_start_of_extension, we might want stronger anchoring to input video. Script 2 uses full `effective_clean_frames_count`.
            clean_latent_pre_frames_num = effective_clean_frames_count
            if is_start_of_extension: # Closest to input video
                 clean_latent_pre_frames_num = 1 # Script 2 uses 1 to avoid jumpcuts from input video when generating chunk closest to it.

            # For clean_latents_post, _2x, _4x (from previously generated extension chunks)
            available_generated_latents = history_latents_generated_cpu.shape[2]
            
            # `post_frames_num` is for clean_latents_post; anchor to a single frame when an explicit end frame is provided (script 2 detail)
            post_frames_num = 1 if (is_end_of_extension and end_latent_from_file_cpu is not None) else effective_clean_frames_count

            num_2x_frames_count = min(2, max(0, available_generated_latents - post_frames_num -1))
            num_4x_frames_count = min(16, max(0, available_generated_latents - post_frames_num - num_2x_frames_count))
            
            # Latent indexing for sample_hunyuan (from script 2)
            latent_padding_size_for_indices = latent_padding_val * latent_window_size
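            # Each sampling call denoises latent_window_size * 4 - 3 pixel frames (33 with the default window of 9).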
            pixel_frames_to_generate_this_step = latent_window_size * 4 - 3
            
            indices_tensor_gpu = torch.arange(0, 
                clean_latent_pre_frames_num + 
                latent_padding_size_for_indices + 
                latent_window_size + # Note: script 2 uses latent_window_size here for `latent_indices` count
                post_frames_num + 
                num_2x_frames_count + 
                num_4x_frames_count
            ).unsqueeze(0).to(cond_device)

            (clean_latent_indices_pre_gpu, 
             blank_indices_gpu, # For padding
             latent_indices_for_denoising_gpu, # For new generation
             clean_latent_indices_post_gpu,
             clean_latent_2x_indices_gpu,
             clean_latent_4x_indices_gpu
            ) = indices_tensor_gpu.split(
                [clean_latent_pre_frames_num, latent_padding_size_for_indices, latent_window_size, 
                 post_frames_num, num_2x_frames_count, num_4x_frames_count], dim=1
            )
            clean_latent_indices_combined_gpu = torch.cat([clean_latent_indices_pre_gpu, clean_latent_indices_post_gpu], dim=1)

            # Prepare conditioning latents
            # clean_latents_pre_cpu: from end of input video
            actual_pre_frames_to_take = min(clean_latent_pre_frames_num, video_latents_history_cpu.shape[2])
            clean_latents_pre_cpu = video_latents_history_cpu[:, :, -actual_pre_frames_to_take:].clone()
            if clean_latents_pre_cpu.shape[2] < clean_latent_pre_frames_num and clean_latents_pre_cpu.shape[2] > 0: # Pad if necessary
                repeats = math.ceil(clean_latent_pre_frames_num / clean_latents_pre_cpu.shape[2])
                clean_latents_pre_cpu = clean_latents_pre_cpu.repeat(1,1,repeats,1,1)[:,:,:clean_latent_pre_frames_num]
            elif clean_latents_pre_cpu.shape[2] == 0 and clean_latent_pre_frames_num > 0: # Should not happen if video_latents_history_cpu is valid
                clean_latents_pre_cpu = torch.zeros((1,channels_dim,clean_latent_pre_frames_num,latent_h,latent_w),dtype=torch.float32)


            # clean_latents_post_cpu, _2x_cpu, _4x_cpu: from start of `history_latents_generated_cpu`
            current_offset_in_generated = 0
            
            # Post frames
            actual_post_frames_to_take = min(post_frames_num, history_latents_generated_cpu.shape[2])
            if is_end_of_extension and end_latent_from_file_cpu is not None:
                clean_latents_post_cpu = end_latent_from_file_cpu.clone() # Should be [1,C,1,H,W]
            else:
                clean_latents_post_cpu = history_latents_generated_cpu[:,:, current_offset_in_generated : current_offset_in_generated + actual_post_frames_to_take].clone()
            current_offset_in_generated += clean_latents_post_cpu.shape[2]
            
            if clean_latents_post_cpu.shape[2] < post_frames_num and clean_latents_post_cpu.shape[2] > 0: # Pad
                repeats = math.ceil(post_frames_num / clean_latents_post_cpu.shape[2])
                clean_latents_post_cpu = clean_latents_post_cpu.repeat(1,1,repeats,1,1)[:,:,:post_frames_num]
            elif clean_latents_post_cpu.shape[2] == 0 and post_frames_num > 0: # Fill with zeros if no history and no end_latent
                 clean_latents_post_cpu = torch.zeros((1,channels_dim,post_frames_num,latent_h,latent_w),dtype=torch.float32)

            # 2x frames
            actual_2x_frames_to_take = min(num_2x_frames_count, history_latents_generated_cpu.shape[2] - current_offset_in_generated)
            clean_latents_2x_cpu = history_latents_generated_cpu[:,:, current_offset_in_generated : current_offset_in_generated + actual_2x_frames_to_take].clone()
            current_offset_in_generated += clean_latents_2x_cpu.shape[2]
            if clean_latents_2x_cpu.shape[2] < num_2x_frames_count and clean_latents_2x_cpu.shape[2] > 0: # Pad
                repeats = math.ceil(num_2x_frames_count / clean_latents_2x_cpu.shape[2])
                clean_latents_2x_cpu = clean_latents_2x_cpu.repeat(1,1,repeats,1,1)[:,:,:num_2x_frames_count]
            elif clean_latents_2x_cpu.shape[2] == 0 and num_2x_frames_count > 0:
                clean_latents_2x_cpu = torch.zeros((1,channels_dim,num_2x_frames_count,latent_h,latent_w),dtype=torch.float32)

            # 4x frames
            actual_4x_frames_to_take = min(num_4x_frames_count, history_latents_generated_cpu.shape[2] - current_offset_in_generated)
            clean_latents_4x_cpu = history_latents_generated_cpu[:,:, current_offset_in_generated : current_offset_in_generated + actual_4x_frames_to_take].clone()
            if clean_latents_4x_cpu.shape[2] < num_4x_frames_count and clean_latents_4x_cpu.shape[2] > 0: # Pad
                repeats = math.ceil(num_4x_frames_count / clean_latents_4x_cpu.shape[2])
                clean_latents_4x_cpu = clean_latents_4x_cpu.repeat(1,1,repeats,1,1)[:,:,:num_4x_frames_count]
            elif clean_latents_4x_cpu.shape[2] == 0 and num_4x_frames_count > 0:
                clean_latents_4x_cpu = torch.zeros((1,channels_dim,num_4x_frames_count,latent_h,latent_w),dtype=torch.float32)

            # Combine pre and post for the `clean_latents` argument
            clean_latents_for_sampler_gpu = torch.cat([
                clean_latents_pre_cpu.to(device=cond_device, dtype=torch.float32), 
                clean_latents_post_cpu.to(device=cond_device, dtype=torch.float32)
            ], dim=2)

            # Ensure 2x and 4x latents are None if their frame counts are 0
            # The k_diffusion_hunyuan.sample_hunyuan and the DiT should handle None for these if indices are also empty.
            clean_latents_2x_gpu = None
            if num_2x_frames_count > 0 and clean_latents_2x_cpu.shape[2] > 0:
                clean_latents_2x_gpu = clean_latents_2x_cpu.to(device=cond_device, dtype=torch.float32)
            elif num_2x_frames_count > 0 and clean_latents_2x_cpu.shape[2] == 0: # Should have been filled with zeros if count > 0
                 print(f"Warning: num_2x_frames_count is {num_2x_frames_count} but clean_latents_2x_cpu is empty. Defaulting to None.")


            clean_latents_4x_gpu = None
            if num_4x_frames_count > 0 and clean_latents_4x_cpu.shape[2] > 0:
                clean_latents_4x_gpu = clean_latents_4x_cpu.to(device=cond_device, dtype=torch.float32)
            elif num_4x_frames_count > 0 and clean_latents_4x_cpu.shape[2] == 0:
                 print(f"Warning: num_4x_frames_count is {num_4x_frames_count} but clean_latents_4x_cpu is empty. Defaulting to None.")

            # Also, ensure indices are None or empty if counts are zero.
            # The split logic already ensures this if the split size is 0.
            # clean_latent_2x_indices_gpu will be shape (B, 0) if num_2x_frames_count is 0.
            # The DiT model should correctly interpret an empty indices tensor or None for the corresponding latent.
            generated_latents_gpu_step = sample_hunyuan( 
                transformer=transformer, sampler='unipc', width=width, height=height,
                frames=pixel_frames_to_generate_this_step, # Num frames for current chunk
                real_guidance_scale=cfg, distilled_guidance_scale=gs, guidance_rescale=rs,
                num_inference_steps=steps, generator=rnd,
                prompt_embeds=llama_vec, prompt_embeds_mask=llama_attention_mask, prompt_poolers=clip_l_pooler,
                negative_prompt_embeds=llama_vec_n, negative_prompt_embeds_mask=llama_attention_mask_n, negative_prompt_poolers=clip_l_pooler_n,
                device=cond_device, dtype=cond_dtype, 
                image_embeddings=image_embeddings_for_sampling_loop, # Use the blended/final one
                latent_indices=latent_indices_for_denoising_gpu, 
                clean_latents=clean_latents_for_sampler_gpu, 
                clean_latent_indices=clean_latent_indices_combined_gpu,
                clean_latents_2x=clean_latents_2x_gpu, # Can be None
                clean_latent_2x_indices=clean_latent_2x_indices_gpu if num_2x_frames_count > 0 else None, # Pass None if count is 0
                clean_latents_4x=clean_latents_4x_gpu, # Can be None
                clean_latent_4x_indices=clean_latent_4x_indices_gpu if num_4x_frames_count > 0 else None, # Pass None if count is 0
                callback=sampler_callback_cli,
            )
            if progress_bar_sampler: progress_bar_sampler.close()

            # If this was the chunk closest to input video, prepend the last frame of input video for smoother transition
            if is_start_of_extension:
                generated_latents_gpu_step = torch.cat([
                    end_of_input_video_latent_cpu.to(generated_latents_gpu_step), # Use actual last frame latent
                    generated_latents_gpu_step
                ], dim=2)

            # Prepend generated latents to history
            history_latents_generated_cpu = torch.cat([generated_latents_gpu_step.cpu(), history_latents_generated_cpu], dim=2)
            total_generated_latent_frames_count = history_latents_generated_cpu.shape[2]
            
            # --- Decode and Append Pixels ---
            target_vae_device = str(gpu if torch.cuda.is_available() else cpu)
            if not high_vram: 
                if transformer: offload_model_from_device_for_memory_preservation(transformer, target_device=target_transformer_device, preserved_memory_gb=gpu_memory_preservation)
                if vae: load_model_as_complete(vae, target_device=target_vae_device)
            else: 
                if vae: vae.to(target_vae_device)
            
            # Decode the newly generated part (or a relevant segment for stitching)
            # Script 2 decodes `real_history_latents[:, :, :section_latent_frames]`
            # section_latent_frames = (latent_window_size * 2 + 1) if is_start_of_video else (latent_window_size * 2)
            num_latents_to_decode_for_stitch = (latent_window_size * 2 + 1) if is_start_of_extension else (latent_window_size * 2)
            num_latents_to_decode_for_stitch = min(num_latents_to_decode_for_stitch, history_latents_generated_cpu.shape[2])
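            # With the default window of 9 this decodes 2*9+1 = 19 latents for the chunk adjacent to the
            # input video, and 2*9 = 18 latents otherwise.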
            
            latents_for_current_decode_gpu = history_latents_generated_cpu[:, :, :num_latents_to_decode_for_stitch].to(target_vae_device)
            
            pixels_for_current_part_decoded_cpu = vae_decode(latents_for_current_decode_gpu, vae).cpu()

            # Soft append pixels (current_pixels, history_pixels, overlap)
            overlap_for_soft_append = latent_window_size * 4 - 3 
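            # 33 pixel frames of overlap with the default window of 9; soft_append_bcthw blends the newly
            # decoded segment into the existing history over this region.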
            
            if history_pixels_decoded_cpu is None:
                history_pixels_decoded_cpu = pixels_for_current_part_decoded_cpu
            else:
                overlap_actual = min(overlap_for_soft_append, history_pixels_decoded_cpu.shape[2], pixels_for_current_part_decoded_cpu.shape[2])
                if overlap_actual <=0: # Should not happen with proper windowing
                    history_pixels_decoded_cpu = torch.cat([pixels_for_current_part_decoded_cpu, history_pixels_decoded_cpu], dim=2) # Simple prepend
                else:
                    history_pixels_decoded_cpu = soft_append_bcthw(
                        pixels_for_current_part_decoded_cpu, # Current (prepended)
                        history_pixels_decoded_cpu,          # History
                        overlap=overlap_actual
                    )
            
            if not high_vram: 
                if vae: unload_complete_models(vae) 
                if transformer and not is_start_of_extension : # Reload transformer for next iter
                     move_model_to_device_with_memory_preservation(transformer, target_device=target_transformer_device, preserved_memory_gb=gpu_memory_preservation)

            # Save intermediate video
            current_output_filename = os.path.join(outputs_folder, f'{job_id}_part{current_section_num_from_end}_totalframes{history_pixels_decoded_cpu.shape[2]}.mp4')
            save_bcthw_as_mp4(history_pixels_decoded_cpu, current_output_filename, fps=fps, crf=mp4_crf)
            print(f"MP4 Preview for section {current_section_num_from_end} saved: {current_output_filename}")
            set_mp4_comments_imageio_ffmpeg(current_output_filename, f"Prompt: {prompt} | Neg: {n_prompt} | Seed: {seed}")
    
            if previous_video_path_for_cleanup is not None and os.path.exists(previous_video_path_for_cleanup):
                try: os.remove(previous_video_path_for_cleanup)
                except Exception as e_del: print(f"Error deleting {previous_video_path_for_cleanup}: {e_del}")
            previous_video_path_for_cleanup = current_output_filename
            
            if is_start_of_extension: # Last iteration of backward loop
                break
        
        # --- Final Video Assembly ---
        if extension_only:
            print("Saving only the generated extension...")
            # history_pixels_decoded_cpu already contains only the generated extension due to backward generation
            # and how it's accumulated.
            video_to_save_cpu = history_pixels_decoded_cpu 
            final_output_filename_suffix = "_extension_only_final.mp4"
            final_log_message = "Final extension-only video saved:"
        else:
            print("Appending generated extension to the input video...")
            # input_video_pixels_cpu is (1, C, F_in, H, W)
            # history_pixels_decoded_cpu is (1, C, F_ext, H, W)
            video_to_save_cpu = torch.cat([input_video_pixels_cpu, history_pixels_decoded_cpu], dim=2)
            final_output_filename_suffix = "_final.mp4"
            final_log_message = "Final extended video saved:"
        
        final_output_filename = os.path.join(outputs_folder, f'{job_id}{final_output_filename_suffix}') # job_id already has _extonly if needed
        save_bcthw_as_mp4(video_to_save_cpu, final_output_filename, fps=fps, crf=mp4_crf)
        print(f"{final_log_message} {final_output_filename}")
        set_mp4_comments_imageio_ffmpeg(final_output_filename, f"Prompt: {prompt} | Neg: {n_prompt} | Seed: {seed}")

        if previous_video_path_for_cleanup is not None and os.path.exists(previous_video_path_for_cleanup) and previous_video_path_for_cleanup != final_output_filename:
            try: os.remove(previous_video_path_for_cleanup)
            except Exception as e_del: print(f"Error deleting last part: {e_del}")

    except Exception as e_outer:
        traceback.print_exc()
        print(f"Error during generation: {e_outer}")
    finally:
        if not high_vram: 
            unload_complete_models(text_encoder, text_encoder_2, image_encoder, vae, transformer)
        print("--- Generation work cycle finished. ---")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="FramePack Video Generation CLI (with End Frame)")
    
    # Inputs
    parser.add_argument('--input_video', type=str, required=True, help='Path to the input video file.')
    parser.add_argument('--prompt', type=str, required=True, help='Prompt for video generation.')
    parser.add_argument('--n_prompt', type=str, default="", help='Negative prompt.')
    parser.add_argument('--end_frame', type=str, default=None, help='Optional path to an image to guide the end of the video.')
    parser.add_argument('--end_frame_weight', type=float, default=1.0, help='Weight for the end_frame image conditioning (0.0 to 1.0). Default 1.0.')

    # Generation parameters
    parser.add_argument('--seed', type=int, default=31337, help='Seed for generation.')
    parser.add_argument('--resolution_max_dim', type=int, default=640, help='Target resolution (max width or height for bucket search).')
    parser.add_argument('--total_second_length', type=float, default=5.0, help='Additional video length to generate (seconds).') 
    parser.add_argument('--latent_window_size', type=int, default=9, help='Latent window size (frames for DiT). Original FramePack default is 9.')
    parser.add_argument('--steps', type=int, default=25, help='Number of inference steps.')
    parser.add_argument('--cfg', type=float, default=1.0, help='CFG Scale. If > 1.0, n_prompt is used and gs is set to 1.0. Default 1.0 (for distilled guidance).')
    parser.add_argument('--gs', type=float, default=10.0, help='Distilled CFG Scale (Embedded CFG for Original FramePack). Default 10.0.') # Original default
    parser.add_argument('--rs', type=float, default=0.0, help='CFG Re-Scale (usually 0.0).')
    parser.add_argument('--num_clean_frames', type=int, default=5, help='Number of 1x context frames for DiT conditioning. Script2 default 5.')
    
    # Technical parameters
    parser.add_argument('--gpu_memory_preservation', type=float, default=6.0, help='GPU memory to preserve (GB) for low VRAM mode.')
    parser.add_argument('--use_teacache', action='store_true', default=False, help='Enable TeaCache (if DiT supports it).')
    parser.add_argument('--no_resize', action='store_true', default=False, help='Force original video resolution for input video encoding (VAE).')
    parser.add_argument('--mp4_crf', type=int, default=16, help='MP4 CRF value (0-51, lower is better quality).')
    parser.add_argument('--vae_batch_size', type=int, default=-1, help='VAE batch size for input video encoding. Default: auto based on VRAM.')
    parser.add_argument('--output_dir', type=str, default='./outputs/', help="Directory to save output videos.")

    # Model paths
    parser.add_argument('--dit', type=str, required=True, help="Path to local DiT model weights file or directory (e.g., for lllyasviel/FramePackI2V_HY).")
    parser.add_argument('--vae', type=str, required=True, help="Path to local VAE model weights file or directory.")
    parser.add_argument('--text_encoder1', type=str, required=True, help="Path to Text Encoder 1 (Llama) WEIGHT FILE.")
    parser.add_argument('--text_encoder2', type=str, required=True, help="Path to Text Encoder 2 (CLIP) WEIGHT FILE.")
    parser.add_argument('--image_encoder', type=str, required=True, help="Path to Image Encoder (SigLIP) WEIGHT FILE.")
    
    # Advanced model settings
    parser.add_argument('--attn_mode', type=str, default="torch", help="Attention mode for DiT (torch, flash, xformers, etc.).")
    parser.add_argument('--fp8_llm', action='store_true', help="Use fp8 for Text Encoder 1 (Llama).") # from fpack_generate_video
    parser.add_argument("--vae_chunk_size", type=int, default=None, help="Chunk size for CausalConv3d in VAE.")
    parser.add_argument("--vae_spatial_tile_sample_min_size", type=int, default=None, help="Spatial tile sample min size for VAE.")
    
    # LoRA
    parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path(s).")
    parser.add_argument("--lora_multiplier", type=float, nargs="*", default=[1.0], help="LoRA multiplier(s).")
    parser.add_argument("--include_patterns", type=str, nargs="*", default=None, help="LoRA module include patterns.")
    parser.add_argument("--exclude_patterns", type=str, nargs="*", default=None, help="LoRA module exclude patterns.")
    parser.add_argument('--extension_only', action='store_true', help="Save only the extension video without the input video attached.")

    args = parser.parse_args()
    
    current_device_str = str(gpu if torch.cuda.is_available() else cpu)
    args.device = current_device_str 

    for model_arg_name in ['dit', 'vae', 'text_encoder1', 'text_encoder2', 'image_encoder']:
        path_val = getattr(args, model_arg_name)
        if not os.path.exists(path_val): 
            parser.error(f"Path for --{model_arg_name} not found: {path_val}")

    outputs_folder = args.output_dir 
    os.makedirs(outputs_folder, exist_ok=True)
    print(f"Outputting videos to: {outputs_folder}")

    free_mem_gb = get_cuda_free_memory_gb(gpu if torch.cuda.is_available() else None)
    # Adjusted high_vram threshold, can be tuned
    high_vram = free_mem_gb > 30 # Example: 30GB+ for "high_vram"
    print(f'Free VRAM {free_mem_gb:.2f} GB. High-VRAM Mode: {high_vram}')

    if args.vae_batch_size == -1: 
        if free_mem_gb >= 18: args.vae_batch_size = 64 
        elif free_mem_gb >= 10: args.vae_batch_size = 32
        else: args.vae_batch_size = 16 
        print(f"Auto-set VAE batch size to: {args.vae_batch_size}")
    
    print("Loading models...")
    loading_device_str = str(cpu) # Load to CPU first

    transformer = load_packed_model(
        device=loading_device_str, 
        dit_path=args.dit,
        attn_mode=args.attn_mode, 
        loading_device=loading_device_str 
    )
    print("DiT loaded.")

    if args.lora_weight is not None and len(args.lora_weight) > 0:
        print("Merging LoRA weights...")
        if len(args.lora_multiplier) == 1 and len(args.lora_weight) > 1:
            args.lora_multiplier = args.lora_multiplier * len(args.lora_weight)
        elif len(args.lora_multiplier) != len(args.lora_weight):
            parser.error(f"Number of LoRA weights ({len(args.lora_weight)}) and multipliers ({len(args.lora_multiplier)}) must match, or provide a single multiplier.")
        
        try:
            # Mimic fpack_generate_video.py's LoRA args structure if needed by merge_lora_weights
            if not hasattr(args, 'lycoris'): args.lycoris = False 
            if not hasattr(args, 'save_merged_model'): args.save_merged_model = None 
            
            current_device_for_lora = torch.device(loading_device_str)
            merge_lora_weights(lora_framepack, transformer, args, current_device_for_lora)
            print("LoRA weights merged successfully.")
        except Exception as e_lora:
            print(f"Error merging LoRA weights: {e_lora}")
            traceback.print_exc()

    vae = load_vae(
        vae_path=args.vae, 
        vae_chunk_size=args.vae_chunk_size, 
        vae_spatial_tile_sample_min_size=args.vae_spatial_tile_sample_min_size, 
        device=loading_device_str 
    )
    print("VAE loaded.")
    
    # For text_encoder loading, fpack_generate_video.py uses args.fp8_llm for text_encoder1
    # The f1_video_cli_local.py passes `args` directly. We'll do the same.
    tokenizer, text_encoder = load_text_encoder1(args, device=loading_device_str) 
    print("Text Encoder 1 and Tokenizer 1 loaded.")
    tokenizer_2, text_encoder_2 = load_text_encoder2(args)
    print("Text Encoder 2 and Tokenizer 2 loaded.")
    feature_extractor, image_encoder = load_image_encoders(args)
    print("Image Encoder and Feature Extractor loaded.")

    all_models_list = [transformer, vae, text_encoder, text_encoder_2, image_encoder]
    for model_obj in all_models_list:
        if model_obj is not None:
            model_obj.eval().requires_grad_(False)

    # Set dtypes (Original FramePack typically bfloat16 for DiT, float16 for others)
    if transformer: transformer.to(dtype=torch.bfloat16)
    if vae: vae.to(dtype=torch.float16) 
    if image_encoder: image_encoder.to(dtype=torch.float16)
    if text_encoder: text_encoder.to(dtype=torch.float16) # Or bfloat16 if fp8_llm implies that
    if text_encoder_2: text_encoder_2.to(dtype=torch.float16)
    
    if transformer:
        transformer.high_quality_fp32_output_for_inference = True # Common setting
        print('Transformer: high_quality_fp32_output_for_inference = True')
    
    if vae and not high_vram: 
        vae.enable_slicing()
        vae.enable_tiling()

    target_gpu_device_str = str(gpu if torch.cuda.is_available() else cpu)
    if not high_vram and torch.cuda.is_available():
        print("Low VRAM mode: Setting up dynamic swapping for DiT and Text Encoder 1.")
        if transformer: DynamicSwapInstaller.install_model(transformer, device=target_gpu_device_str)
        if text_encoder: DynamicSwapInstaller.install_model(text_encoder, device=target_gpu_device_str)
        # Other models (VAE, TE2, ImgEnc) will be loaded/offloaded as needed by `load_model_as_complete` / `unload_complete_models`
        if vae: vae.to(cpu)
        if text_encoder_2: text_encoder_2.to(cpu)
        if image_encoder: image_encoder.to(cpu)
    elif torch.cuda.is_available(): 
        print(f"High VRAM mode: Moving all models to {target_gpu_device_str}.")
        for model_obj in all_models_list:
            if model_obj is not None: model_obj.to(target_gpu_device_str)
    else:
        print("Running on CPU. Models remain on CPU.")
    
    print("All models loaded and configured.")
    
    # Adjust gs if cfg > 1.0 (standard CFG mode)
    actual_gs_cli = args.gs
    if args.cfg > 1.0: 
        actual_gs_cli = 1.0 # For standard CFG, distilled guidance is turned off
        print(f"CFG > 1.0 detected ({args.cfg}), this implies standard CFG. Overriding GS to 1.0 from {args.gs}.")

    do_generation_work(
        input_video_path=args.input_video, 
        prompt=args.prompt, 
        n_prompt=args.n_prompt, 
        seed=args.seed,
        end_frame_path=args.end_frame,
        end_frame_weight=args.end_frame_weight,
        resolution_max_dim=args.resolution_max_dim, 
        additional_second_length=args.total_second_length,
        latent_window_size=args.latent_window_size, 
        steps=args.steps, 
        cfg=args.cfg, 
        gs=actual_gs_cli, 
        rs=args.rs, 
        gpu_memory_preservation=args.gpu_memory_preservation, 
        use_teacache=args.use_teacache, 
        no_resize=args.no_resize, 
        mp4_crf=args.mp4_crf, 
        num_clean_frames=args.num_clean_frames, 
        vae_batch_size=args.vae_batch_size,
        extension_only=args.extension_only
    )

    print("Video generation process completed.")