diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..cd5917f570faa272edb0a774e063d4bcbba416fb 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +images/screenshot.png filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..732e83dd206d67c22356104ba66c46967b1cd38b --- /dev/null +++ b/.gitignore @@ -0,0 +1,18 @@ +__pycache__/ +.venv +venv/ +logs/ +uv.lock +.env +env/ +outputs/ +GIMM-VFI/ +hunyuan/ +temp_frames/ +wan/wan2.1_i2v_480p_14B_bf16.safetensors +wan/wan2.1_t2v_14B_bf16.safetensors +wan/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth +wan/models_t5_umt5-xxl-enc-bf16.pth +triton-3.0.0-cp310-cp310-win_amd64.whl +wan/Wan2.1_VAE.pth +flash_attn-2.7.4+cu124torch2.5.1cxx11abiFALSE-cp310-cp310-win_amd64.whl diff --git a/.python-version b/.python-version new file mode 100644 index 0000000000000000000000000000000000000000..c8cfe3959183f8e9a50f83f54cd723f2dc9c252d --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.10 diff --git a/README.ja.md b/README.ja.md new file mode 100644 index 0000000000000000000000000000000000000000..d372c6e97340e03b7ee5bd0b8b1ebfcd50326eab --- /dev/null +++ b/README.ja.md @@ -0,0 +1,377 @@ +# Musubi Tuner + +[English](./README.md) | [日本語](./README.ja.md) + +## 目次 + +- [はじめに](#はじめに) + - [最近の更新](#最近の更新) + - [リリースについて](#リリースについて) +- [概要](#概要) + - [ハードウェア要件](#ハードウェア要件) + - [特徴](#特徴) +- [インストール](#インストール) +- [モデルのダウンロード](#モデルのダウンロード) + - [HunyuanVideoの公式モデルを使う](#HunyuanVideoの公式モデルを使う) + - [Text EncoderにComfyUI提供のモデルを使う](#Text-EncoderにComfyUI提供のモデルを使う) +- [使い方](#使い方) + - [データセット設定](#データセット設定) + - [latentの事前キャッシュ](#latentの事前キャッシュ) + - [Text Encoder出力の事前キャッシュ](#Text-Encoder出力の事前キャッシュ) + - [学習](#学習) + - [LoRAの重みのマージ](#LoRAの重みのマージ) + - [推論](#推論) + - [LoRAの形式の変換](#LoRAの形式の変換) +- [その他](#その他) + - [SageAttentionのインストール方法](#SageAttentionのインストール方法) +- [免責事項](#免責事項) +- [コントリビューションについて](#コントリビューションについて) +- [ライセンス](#ライセンス) + +## はじめに + +このリポジトリは、HunyuanVideoのLoRA学習用のコマンドラインツールです。このリポジトリは非公式であり、公式のHunyuanVideoリポジトリとは関係ありません。 + +*リポジトリは開発中です。* + +### 最近の更新 + +- 2025/01/20 + - uv によるインストール手順を試験的に追加しました。PR [#51](https://github.com/kohya-ss/musubi-tuner/pull/51) bmaltais 氏に感謝いたします。ただ、設定等は詰められていないため、フィードバックを歓迎します。 + - 高度な設定に、[TensorBoard形式のログの保存と参照](./docs/advanced_config.md#save-and-view-logs-in-tensorboard-format--tensorboard形式のログの保存と参照)を追加しました。 + +- 2025/01/19 + - latentとText Encoder出力の事前キャッシュ時に、データセットに含まれないキャッシュファイルを自動で消去するようにしました。これにより予期しないファイルが残り、学習に使用されてしまう問題が解消されます。 + - `--keep_cache`で今までと同様にキャッシュファイルを残すことができます。 + - Text Encoder出力の事前キャッシュ時に、`--skip_existing`を指定すると正しく動作しない問題を修正しました。 + +- 2025/01/18 + - `hv_generate_video.py`でvideo2videoの推論が可能になりました。詳細は[推論](#推論)を参照してください。 + +- 2025/01/16 + - LoRAの重みをマージするスクリプト、`merge_lora.py`が追加されました。PR [#37](https://github.com/kohya-ss/musubi-tuner/pull/37) kaykyr氏に感謝いたします。詳細は[LoRAの重みのマージ](#LoRAの重みのマージ)を参照してください。 + - サンプルの学習設定を、学習率を2e-4に、`--timestep_sampling`を`shift`に、`--discrete_flow_shift`を7.0に変更しました。より高速な学習が期待されます。詳細は[学習](#学習)を参照してください。 + +- 2025/01/14 + - `hv_generate_video.py`に、LoRAマージ後のDiTモデルを保存する`--save_merged_model`オプションを暫定的に追加しました。詳細は[推論](#推論)を参照してください。 + +- 2025/01/13 + - 学習中のサンプル画像(動画)がぼやける現象に対応するため、サンプル生成時の設定を変更しました。詳細は[こちら](./docs/sampling_during_training.md)をご参照ください。 + - 推論時にdiscrete flow shiftとguidance 
scaleを正しく設定する必要がありますが、学習時の設定がそのまま使われていたため、この事象が発生していました。デフォルト値を設定したため、改善されると思われます。また`--fs`でdiscrete flow shiftを、`--g`でguidance scaleを指定できます。 + +### リリースについて + +Musubi Tunerの解説記事執筆や、関連ツールの開発に取り組んでくださる方々に感謝いたします。このプロジェクトは開発中のため、互換性のない変更や機能追加が起きる可能性があります。想定外の互換性問題を避けるため、参照用として[リリース](https://github.com/kohya-ss/musubi-tuner/releases)をお使いください。 + +最新のリリースとバージョン履歴は[リリースページ](https://github.com/kohya-ss/musubi-tuner/releases)で確認できます。 + +## 概要 + +### ハードウェア要件 + +- VRAM: 静止画での学習は12GB以上推奨、動画での学習は24GB以上推奨。 + - *解像度等の学習設定により異なります。*12GBでは解像度 960x544 以下とし、`--blocks_to_swap`、`--fp8_llm`等の省メモリオプションを使用してください。 +- メインメモリ: 64GB以上を推奨、32GB+スワップで動作するかもしれませんが、未検証です。 + +### 特徴 + +- 省メモリに特化 +- Windows対応(Linuxでの動作報告もあります) +- マルチGPUには対応していません + +## インストール + +### pipによるインストール + +Python 3.10以上を使用してください(3.10で動作確認済み)。 + +適当な仮想環境を作成し、ご利用のCUDAバージョンに合わせたPyTorchとtorchvisionをインストールしてください。 + +PyTorchはバージョン2.5.1以上を使用してください([補足](#PyTorchのバージョンについて))。 + +```bash +pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124 +``` + +以下のコマンドを使用して、必要な依存関係をインストールします。 + +```bash +pip install -r requirements.txt +``` + +オプションとして、FlashAttention、SageAttention(推論にのみ使用、インストール方法は[こちら](#SageAttentionのインストール方法)を参照)を使用できます。 + +また、`ascii-magic`(データセットの確認に使用)、`matplotlib`(timestepsの可視化に使用)、`tensorboard`(学習ログの記録に使用)を必要に応じてインストールしてください。 + +```bash +pip install ascii-magic matplotlib tensorboard +``` +### uvによるインストール + +uvを使用してインストールすることもできますが、uvによるインストールは試験的なものです。フィードバックを歓迎します。 + +#### Linux/MacOS + +```sh +curl -LsSf https://astral.sh/uv/install.sh | sh +``` + +表示される指示に従い、pathを設定してください。 + +#### Windows + +```powershell +powershell -c "irm https://astral.sh/uv/install.ps1 | iex" +``` + +表示される指示に従い、PATHを設定するか、この時点でシステムを再起動してください。 + +## モデルのダウンロード + +以下のいずれかの方法で、モデルをダウンロードしてください。 + +### HunyuanVideoの公式モデルを使う + +[公式のREADME](https://github.com/Tencent/HunyuanVideo/blob/main/ckpts/README.md)を参考にダウンロードし、任意のディレクトリに以下のように配置します。 + +``` + ckpts + ├──hunyuan-video-t2v-720p + │ ├──transformers + │ ├──vae + ├──text_encoder + ├──text_encoder_2 + ├──... 
+``` + +### Text EncoderにComfyUI提供のモデルを使う + +こちらの方法の方がより簡単です。DiTとVAEのモデルはHumyuanVideoのものを使用します。 + +https://huggingface.co/tencent/HunyuanVideo/tree/main/hunyuan-video-t2v-720p/transformers から、[mp_rank_00_model_states.pt](https://huggingface.co/tencent/HunyuanVideo/resolve/main/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt) をダウンロードし、任意のディレクトリに配置します。 + +(同じページにfp8のモデルもありますが、未検証です。) + +`--fp8_base`を指定して学習する場合は、`mp_rank_00_model_states.pt`の代わりに、[こちら](https://huggingface.co/kohya-ss/HunyuanVideo-fp8_e4m3fn-unofficial)の`mp_rank_00_model_states_fp8.safetensors`を使用可能です。(このファイルは非公式のもので、重みを単純にfloat8_e4m3fnに変換したものです。) + +また、https://huggingface.co/tencent/HunyuanVideo/tree/main/hunyuan-video-t2v-720p/vae から、[pytorch_model.pt](https://huggingface.co/tencent/HunyuanVideo/resolve/main/hunyuan-video-t2v-720p/vae/pytorch_model.pt) をダウンロードし、任意のディレクトリに配置します。 + +Text EncoderにはComfyUI提供のモデルを使用させていただきます。[ComyUIのページ](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)を参考に、https://huggingface.co/Comfy-Org/HunyuanVideo_repackaged/tree/main/split_files/text_encoders から、llava_llama3_fp16.safetensors (Text Encoder 1、LLM)と、clip_l.safetensors (Text Encoder 2、CLIP)をダウンロードし、任意のディレクトリに配置します。 + +(同じページにfp8のLLMモデルもありますが、動作未検証です。) + +## 使い方 + +### データセット設定 + +[こちら](./dataset/dataset_config.md)を参照してください。 + +### latentの事前キャッシュ + +latentの事前キャッシュは必須です。以下のコマンドを使用して、事前キャッシュを作成してください。(pipによるインストールの場合) + +```bash +python cache_latents.py --dataset_config path/to/toml --vae path/to/ckpts/hunyuan-video-t2v-720p/vae/pytorch_model.pt --vae_chunk_size 32 --vae_tiling +``` + +uvでインストールした場合は、`uv run python cache_latents.py ...`のように、`uv run`を先頭につけてください。以下のコマンドも同様です。 + +その他のオプションは`python cache_latents.py --help`で確認できます。 + +VRAMが足りない場合は、`--vae_spatial_tile_sample_min_size`を128程度に減らし、`--batch_size`を小さくしてください。 + +`--debug_mode image` を指定するとデータセットの画像とキャプションが新規ウィンドウに表示されます。`--debug_mode console`でコンソールに表示されます(`ascii-magic`が必要)。 + +デフォルトではデータセットに含まれないキャッシュファイルは自動的に削除されます。`--keep_cache`を指定すると、キャッシュファイルを残すことができます。 + +### Text Encoder出力の事前キャッシュ + +Text Encoder出力の事前キャッシュは必須です。以下のコマンドを使用して、事前キャッシュを作成してください。 + +```bash +python cache_text_encoder_outputs.py --dataset_config path/to/toml --text_encoder1 path/to/ckpts/text_encoder --text_encoder2 path/to/ckpts/text_encoder_2 --batch_size 16 +``` + +その他のオプションは`python cache_text_encoder_outputs.py --help`で確認できます。 + +`--batch_size`はVRAMに合わせて調整してください。 + +VRAMが足りない場合(16GB程度未満の場合)は、`--fp8_llm`を指定して、fp8でLLMを実行してください。 + +デフォルトではデータセットに含まれないキャッシュファイルは自動的に削除されます。`--keep_cache`を指定すると、キャッシュファイルを残すことができます。 + +### 学習 + +以下のコマンドを使用して、学習を開始します(実際には一行で入力してください)。 + +```bash +accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16 hv_train_network.py + --dit path/to/ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt + --dataset_config path/to/toml --sdpa --mixed_precision bf16 --fp8_base + --optimizer_type adamw8bit --learning_rate 2e-4 --gradient_checkpointing + --max_data_loader_n_workers 2 --persistent_data_loader_workers + --network_module networks.lora --network_dim 32 + --timestep_sampling shift --discrete_flow_shift 7.0 + --max_train_epochs 16 --save_every_n_epochs 1 --seed 42 + --output_dir path/to/output_dir --output_name name-of-lora +``` + +__更新__:サンプルの学習率を1e-3から2e-4に、`--timestep_sampling`を`sigmoid`から`shift`に、`--discrete_flow_shift`を1.0から7.0に変更しました。より高速な学習が期待されます。ディテールが甘くなる場合は、discrete flow shiftを3.0程度に下げてみてください。 + +ただ、適切な学習率、学習ステップ数、timestepsの分布、loss weightingなどのパラメータは、以前として不明な点が数多くあります。情報提供をお待ちしています。 + +その他のオプションは`python hv_train_network.py 
--help`で確認できます(ただし多くのオプションは動作未確認です)。 + +`--fp8_base`を指定すると、DiTがfp8で学習されます。未指定時はmixed precisionのデータ型が使用されます。fp8は大きく消費メモリを削減できますが、品質は低下する可能性があります。`--fp8_base`を指定しない場合はVRAM 24GB以上を推奨します。また必要に応じて`--blocks_to_swap`を使用してください。 + +VRAMが足りない場合は、`--blocks_to_swap`を指定して、一部のブロックをCPUにオフロードしてください。最大36が指定できます。 + +(block swapのアイデアは2kpr氏の実装に基づくものです。2kpr氏にあらためて感謝します。) + +`--sdpa`でPyTorchのscaled dot product attentionを使用します。`--flash_attn`で[FlashAttention]:(https://github.com/Dao-AILab/flash-attention)を使用します。`--xformers`でxformersの利用も可能ですが、xformersを使う場合は`--split_attn`を指定してください。`--sage_attn`でSageAttentionを使用しますが、SageAttentionは現時点では学習に未対応のため、正しく動作しません。 + +`--split_attn`を指定すると、attentionを分割して処理します。速度が多少低下しますが、VRAM使用量はわずかに減ります。 + +学習されるLoRAの形式は、`sd-scripts`と同じです。 + +`--show_timesteps`に`image`(`matplotlib`が必要)または`console`を指定すると、学習時のtimestepsの分布とtimestepsごとのloss weightingが確認できます。 + +学習時のログの記録が可能です。[TensorBoard形式のログの保存と参照](./docs/advanced_config.md#save-and-view-logs-in-tensorboard-format--tensorboard形式のログの保存と参照)を参照してください。 + +学習中のサンプル画像生成については、[こちらのドキュメント](./docs/sampling_during_training.md)を参照してください。その他の高度な設定については[こちらのドキュメント](./docs/advanced_config.md)を参照してください。 + +### LoRAの重みのマージ + +```bash +python merge_lora.py \ + --dit path/to/ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt \ + --lora_weight path/to/lora.safetensors \ + --save_merged_model path/to/merged_model.safetensors \ + --device cpu \ + --lora_multiplier 1.0 +``` + +`--device`には計算を行うデバイス(`cpu`または`cuda`等)を指定してください。`cuda`を指定すると計算が高速化されます。 + +`--lora_weight`にはマージするLoRAの重みを、`--lora_multiplier`にはLoRAの重みの係数を、それぞれ指定してください。複数個が指定可能で、両者の数は一致させてください。 + +### 推論 + +以下のコマンドを使用して動画を生成します。 + +```bash +python hv_generate_video.py --fp8 --video_size 544 960 --video_length 5 --infer_steps 30 + --prompt "A cat walks on the grass, realistic style." 
--save_path path/to/save/dir --output_type both + --dit path/to/ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt --attn_mode sdpa --split_attn + --vae path/to/ckpts/hunyuan-video-t2v-720p/vae/pytorch_model.pt + --vae_chunk_size 32 --vae_spatial_tile_sample_min_size 128 + --text_encoder1 path/to/ckpts/text_encoder + --text_encoder2 path/to/ckpts/text_encoder_2 + --seed 1234 --lora_multiplier 1.0 --lora_weight path/to/lora.safetensors +``` + +その他のオプションは`python hv_generate_video.py --help`で確認できます。 + +`--fp8`を指定すると、DiTがfp8で推論されます。fp8は大きく消費メモリを削減できますが、品質は低下する可能性があります。 + +VRAMが足りない場合は、`--blocks_to_swap`を指定して、一部のブロックをCPUにオフロードしてください。最大38が指定できます。 + +`--attn_mode`には`flash`、`torch`、`sageattn`、`xformers`または`sdpa`(`torch`指定時と同じ)のいずれかを指定してください。それぞれFlashAttention、scaled dot product attention、SageAttention、xformersに対応します。デフォルトは`torch`です。SageAttentionはVRAMの削減に有効です。 + +`--split_attn`を指定すると、attentionを分割して処理します。SageAttention利用時で10%程度の高速化が見込まれます。 + +`--output_type`には`both`、`latent`、`video`、`images`のいずれかを指定してください。`both`はlatentと動画の両方を出力します。VAEでOut of Memoryエラーが発生する場合に備えて、`both`を指定することをお勧めします。`--latent_path`に保存されたlatentを指定し、`--output_type video` (または`images`)としてスクリプトを実行すると、VAEのdecodeのみを行えます。 + +`--seed`は省略可能です。指定しない場合はランダムなシードが使用されます。 + +`--video_length`は「4の倍数+1」を指定してください。 + +`--flow_shift`にタイムステップのシフト値(discrete flow shift)を指定可能です。省略時のデフォルト値は7.0で、これは推論ステップ数が50の時の推奨値です。HunyuanVideoの論文では、ステップ数50の場合は7.0、ステップ数20未満(10など)で17.0が推奨されています。 + +`--video_path`に読み込む動画を指定すると、video2videoの推論が可能です。動画ファイルを指定するか、複数の画像ファイルが入ったディレクトリを指定してください(画像ファイルはファイル名でソートされ、各フレームとして用いられます)。`--video_length`よりも短い動画を指定するとエラーになります。`--strength`で強度を指定できます。0~1.0で指定でき、大きいほど元の動画からの変化が大きくなります。 + +なおvideo2video推論の処理は実験的なものです。 + +`--save_merged_model`オプションで、LoRAマージ後のDiTモデルを保存できます。`--save_merged_model path/to/merged_model.safetensors`のように指定してください。なおこのオプションを指定すると推論は行われません。 + +### LoRAの形式の変換 + +ComfyUIで使用可能な形式(Diffusion-pipeと思われる)への変換は以下のコマンドで行えます。 + +```bash +python convert_lora.py --input path/to/musubi_lora.safetensors --output path/to/another_format.safetensors --target other +``` + +`--input`と`--output`はそれぞれ入力と出力のファイルパスを指定してください。 + +`--target`には`other`を指定してください。`default`を指定すると、他の形式から当リポジトリの形式に変換できます。 + +## その他 + +### SageAttentionのインストール方法 + +sdbds氏によるWindows対応のSageAttentionのwheelが https://github.com/sdbds/SageAttention-for-windows で公開されています。triton をインストールし、Python、PyTorch、CUDAのバージョンが一致する場合は、[Releases](https://github.com/sdbds/SageAttention-for-windows/releases)からビルド済みwheelをダウンロードしてインストールすることが可能です。sdbds氏に感謝します。 + +参考までに、以下は、SageAttentionをビルドしインストールするための簡単な手順です。Microsoft Visual C++ 再頒布可能パッケージを最新にする必要があるかもしれません。 + +1. Pythonのバージョンに応じたtriton 3.1.0のwhellを[こちら](https://github.com/woct0rdho/triton-windows/releases/tag/v3.1.0-windows.post5)からダウンロードしてインストールします。 + +2. Microsoft Visual Studio 2022かBuild Tools for Visual Studio 2022を、C++のビルドができるよう設定し、インストールします。(上のRedditの投稿を参照してください)。 + +3. 任意のフォルダにSageAttentionのリポジトリをクローンします。 + ```shell + git clone https://github.com/thu-ml/SageAttention.git + ``` + + なお `git clone https://github.com/sdbds/SageAttention-for-windows.git` で、前述のsdbds氏のリポジトリを使用することで、手順4.を省略できます。 + +4. `SageAttention/csrc`フォルダ内の`math.cuh`を開き、71行目と146行目の `ushort` を `unsigned short` に変更して保存します。 + +5. スタートメニューから Visual Studio 2022 内の `x64 Native Tools Command Prompt for VS 2022` を選択してコマンドプロンプトを開きます。 + +6. 
venvを有効にし、SageAttentionのフォルダに移動して以下のコマンドを実行します。DISTUTILSが設定されていない、のようなエラーが出た場合は `set DISTUTILS_USE_SDK=1`としてから再度実行してください。 + ```shell + python setup.py install + ``` + +以上でSageAttentionのインストールが完了です。 + +### PyTorchのバージョンについて + +`--attn_mode`に`torch`を指定する場合、2.5.1以降のPyTorchを使用してください(それより前のバージョンでは生成される動画が真っ黒になるようです)。 + +古いバージョンを使う場合、xformersやSageAttentionを使用してください。 + +## 免責事項 + +このリポジトリは非公式であり、公式のHunyuanVideoリポジトリとは関係ありません。また、このリポジトリは開発中で、実験的なものです。テストおよびフィードバックを歓迎しますが、以下の点にご注意ください: + +- 実際の稼働環境での動作を意図したものではありません +- 機能やAPIは予告なく変更されることがあります +- いくつもの機能が未検証です +- 動画学習機能はまだ開発中です + +問題やバグについては、以下の情報とともにIssueを作成してください: + +- 問題の詳細な説明 +- 再現手順 +- 環境の詳細(OS、GPU、VRAM、Pythonバージョンなど) +- 関連するエラーメッセージやログ + +## コントリビューションについて + +コントリビューションを歓迎します。ただし、以下にご注意ください: + +- メンテナーのリソースが限られているため、PRのレビューやマージには時間がかかる場合があります +- 大きな変更に取り組む前には、議論のためのIssueを作成してください +- PRに関して: + - 変更は焦点を絞り、適度なサイズにしてください + - 明確な説明をお願いします + - 既存のコードスタイルに従ってください + - ドキュメントが更新されていることを確認してください + +## ライセンス + +`hunyuan_model`ディレクトリ以下のコードは、[HunyuanVideo](https://github.com/Tencent/HunyuanVideo)のコードを一部改変して使用しているため、そちらのライセンスに従います。 + +他のコードはApache License 2.0に従います。一部Diffusersのコードをコピー、改変して使用しています。 diff --git a/README.md b/README.md index 5a2a60a8699dfce2a5746fbea99f1727db0643ca..f7b0c81d7a0ffa6ab14b79d642fd0e2911cedaa1 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,269 @@ ---- -title: Framepack H111 -emoji: 🌖 -colorFrom: pink -colorTo: yellow -sdk: gradio -sdk_version: 5.31.0 -app_file: app.py -pinned: false ---- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +![GUI Screenshot](images/screenshot.png) + +# Recent update +5/25/2025 + Enable full intermediate previews for framepack tab, some change to framepack extension with image input logic. +5/24/2025 + Batch images from folder now available for framepack. Save only extension option and experimental start image for f1 in framepack extension tab. +5/23/2025 + Added ability to use the normal framepack model with endframe support in the framepack extension tab. Support additional bucket sizes. +5/18/2025 + Add video extension tab. Currently only works with f1 model. Full credit to @pfqt and @chaojie for their amazing work! + +# H1111 + +This is a GUI for tech wizard kohya-ss's musubi tuner's inference script. +https://github.com/kohya-ss/musubi-tuner + +It allows inference with these models: +FramePack +Hunyuan-t2v +Hunyuan-i2v +Hunyuan-v2v +WanX-t2v +WanX-i2v +WanX-v2v +SkyReels-i2v +SkyReels-t2v + +I have mostly been workiing on the framepack tab and the WanX-i2v tab. They are the best to use right now. WanX-i2v is used for skyreels v2 and the fun control models. + +This supports queuing multiple different jobs if you open 2+ browser tabs and use the same model. + +If you are running out of vram use more block swapping. Using FP8 scaled is also a decent option to lower memory usage, select fp8 and fp8 scaled to use it. Scaled fp8 tries to duplicate the important parts of the model from FP16. Sage attention is the fastest/lowest vram but difficult to install in windows. + +Best quality will be obtained with only enabling block swapping and using the fp16 model with sdpa attention. You can speed things up with cfg skip, fp8 scaled, slg skip is small speedup, sage attention is fastest but all speedups come with quality degradations. I designed this to try to focus on quality over speed. + +If you are using a lora that you didn't train with musubi you need to drag it to the convert lora tab and convert it to the default format. 
It should spit the converted file out into the /lora folder.
+
+If you need additional installation instructions or information, create an issue and I will try to help. There are also a lot of settings notes on the musubi-tuner GitHub linked above.
+
+For torch 2.7.0 on Windows, try:
+pip install typing-extensions
+pip install torch==2.7.0+cu128 torchvision==0.22.0+cu128 --index-url https://download.pytorch.org/whl/cu128
+pip install -r requirementsTorch27.txt
+
+## To Use FramePack
+
+Download these 5 files from https://huggingface.co/maybleMyers/framepack_h1111 and put them in a subfolder named hunyuan (H1111/hunyuan), or point the GUI at wherever they live if you have already acquired them:
+
+FramePackI2V_HY_bf16.safetensors (or FramePack_F1_I2V_HY_20250503.safetensors for F1)
+
+clip_l.safetensors
+
+llava_llama3_fp16.safetensors
+
+model.safetensors
+
+pytorch_model.pt
+
+LoRAs trained with musubi tuner's FramePack training are confirmed to work great; normal LoRAs trained for Hunyuan kinda suck. Use a lot of block swap: this is a different back end than the official repo. If you select fp8 and fp8 scaled it will all fit on a 24GB GPU for the fastest speed, about 3 s/it or 1:17 per second of video with a 4090. Best quality is obtained with just block swapping, sdpa attention, and the full model, though.
+
+Put LoRAs in a /lora subfolder; if they were not trained with musubi you need to convert them.
+
+Only unipc is supported for now. Sage attention is experimental. When using the F1 model, not all options available for the original FramePack model will work, such as end frame and sectional images.
+
+Here is an example prompt for a 5 second video with 4 sections using sectional prompting; longer videos are also supported with index ranges, e.g. 0-2 ;;;3-5 etc.:
+
+0:A cinematic video showcases a cute blue penguin wearing sunglasses. The penguin runs quickly into mcdonalds.;;;1:The penguin runs quickly into mcdonalds and jumps up on a table and starts eating his food. The penguin's name is Piplup he is a famous Pokemon actor. The video is a fast action sequence animation showing the penguin running into a mcdonalds and jumping up onto a table.;;;2:The penguin is seated at a table and is enjoying his happy meal. The penguin's name is Piplup he is a famous Pokemon actor. The video is a fast action sequence animation showing the penguin running into a mcdonalds and jumping up onto a table.;;;3:The penguin is seated at a table and is happily enjoying his happy meal. The penguin's name is Piplup he is a famous Pokemon actor. The penguin flexes his huge arm muscles at the end of the video.
+
+I have added support for 4 sectional images during inference. It works best when the images are close together. Refer to the screenshot for an example of a working 5 second video.
+
+For more details on using FramePack with musubi, see https://github.com/kohya-ss/musubi-tuner/blob/main/docs/framepack.md
+
+The fastest speed is achieved with fp8 and fp8 scaled; from there, reduce block swapping to fit your memory constraints (leave about 1GB free).
+
+The FramePack Extension tab is still a work in progress.
+Thanks to @pftq https://github.com/pftq and @chaojie https://github.com/chaojie for their work on the extension logic.
+
+## To Use the new Skyreels-V2 models
+
+I have provided these two models at https://huggingface.co/maybleMyers/wan_files_for_h1111
+
+ SkyReels-V2-I2V-14B-720P-FP16.safetensors
+ SkyReels-V2-I2V-14B-540P-FP16.safetensors
+
+You can just drop them into the wan folder and use them in the WanX-i2v tab. SkyReels-V2 is a fine tune of Wan2.1.
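+
+As a quick sanity check on a downloaded checkpoint (for example, to spot repackaged variants that carry extra keys, as noted below), you can list the tensor names stored in a `.safetensors` file with the `safetensors` library this repo already depends on. This is only an illustrative sketch, not an H1111 feature; the checkpoint path is a placeholder.
+
+```python
+# Sketch: list the tensor keys stored in a .safetensors checkpoint.
+# Useful for spotting repackaged variants with extra or renamed keys.
+from safetensors import safe_open
+
+checkpoint = "wan/SkyReels-V2-I2V-14B-720P-FP16.safetensors"  # placeholder path
+
+with safe_open(checkpoint, framework="pt", device="cpu") as f:
+    keys = list(f.keys())
+    print(f"{len(keys)} tensors in {checkpoint}")
+    for key in keys[:10]:  # show the first few names and shapes only
+        print(key, f.get_slice(key).get_shape())
+```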
+If you have downloaded the Kijai variants, they will not work because he added extra keys to the model.
+
+## To Use WanX
+
+To use WanX download these and toss them in the wan subfolder:
+Download the T5 `models_t5_umt5-xxl-enc-bf16.pth`, VAE `Wan2.1_VAE.pth` and CLIP `models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` from the following page: https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P/tree/main
+
+Download the DiT weights from the following page: https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/diffusion_models
+e.g. wan2.1_i2v_720p_14B_fp16.safetensors
+
+For the fun control option in WanX-i2v I recommend the fp16 weights here: https://huggingface.co/maybleMyers/wan_files_for_h1111/tree/main
+Wan2.1-Fun-14B-Control_fp16.safetensors
+
+To update the installation, run git pull and then:
+pip install -r requirements.txt
+
+I have tested the 14B i2v and t2v models so far and they work.
+
+## Requirements
+
+- Python 3.10
+- CUDA 12.4
+
+## Basic Installation (Linux)
+
+Tested on Ubuntu 24.
+
+To update, navigate to H1111 and run git pull.
+
+```bash
+git clone https://github.com/maybleMyers/H1111
+cd H1111
+python -m venv env
+# (if you have another version of python, do python3.10 -m venv env after installing it with sudo apt install python3.10 python3.10-venv python3.10-distutils)
+source env/bin/activate
+pip install torch==2.5.1 torchvision --index-url https://download.pytorch.org/whl/cu124
+pip install -r requirements.txt
+pip install flash-attn --no-build-isolation
+pip install sageattention==1.0.6
+# might need python3.10-dev as well for sage attention to work
+```
+
+Run with:
+source env/bin/activate
+python h1111.py
+
+For GPU 1:
+CUDA_VISIBLE_DEVICES=1 python h1111.py
+
+## Basic Installation (Windows)
+
+First, open PowerShell and navigate to your desired installation directory. Then run these commands:
+
+```powershell
+git clone https://github.com/maybleMyers/H1111
+cd H1111
+python -m venv env
+./env/scripts/activate
+pip install torch==2.5.1 torchvision --index-url https://download.pytorch.org/whl/cu124
+pip install -r requirements.txt
+```
+
+## To run
+
+```
+env/scripts/activate
+python h1111.py
+```
+
+Open 127.0.0.1:7860 in a browser.
+
+You can set the CUDA device to 1,2,3,4,5,6,7 etc. in the env (activated in a separate terminal) to run multiple copies at once if you have another GPU.
+e.g. for Linux on the second GPU: CUDA_VISIBLE_DEVICES=1 python h1111.py
+
+## Full changelog
+5/25/2025
+ Enable full intermediate previews for framepack tab, some changes to framepack extension image input logic.
+5/24/2025
+ Batch images from folder now available for framepack. Save only extension option and experimental start image for f1 in framepack extension tab.
+5/23/2025
+ Added ability to use the normal framepack model with endframe support in the framepack extension tab. Support additional bucket sizes.
+5/18/2025
+ Add video extension tab. Currently only works with f1 model. Full credit to @pftq and @chaojie for their amazing work!
+5/12/2025
+ Add skip button to framepack.
+5/9/2025
+ Add testing branch for framepack F1 end image, kinda glitchy: https://github.com/maybleMyers/H1111/tree/f1_end
+5/5/2025
+ Update an experimental hunyuan to framepack convert lora option in the convert lora tab.
+ Add tea cache to frame pack.
+5/3/2025
+ Add support for framepack F1! Download from https://huggingface.co/maybleMyers/wan_files_for_h1111/blob/main/FramePack_F1_I2V_HY_20250503.safetensors and put it in your hunyuan folder.
You might need to reinstall reqs "pip install -r requirements.txt" + Add support for Wan2.1 i2v-14B-FC-1.1. It is a fun control model and is very good. Use it in the WanX-i2v tab and make sure to select the task i2v-14B-FC-1.1 at the bottom of the page. Download the weights from https://huggingface.co/maybleMyers/wan_files_for_h1111 +4/30/2025 + Previews for framepack. +4/29/2025 + Add initial preview support to the wanX-i2v tab based. If you want to use them use the preview branch. Thanks to Sarania. + Wan2.1-Fun-V1.1-14B-InP-FP16.safetensors is available at https://huggingface.co/maybleMyers/wan_files_for_h1111 + Fix bug in hunyuan-t2v not loading lora. +4/26/2025 + Add SkyReels-V2-I2V-14B-720P-FP16.safetensors to supported models. + Added alot better options for Framepack including working sectional images, Thanks to kohya! +4/25/2025 + Framepack backend updates for better LoRa support for LoRa's trained with musubi tuner. Also better weighting options. +4/24/2025 + Update FramePack backend to musubi backend instead of original. Offers much improved speed and some quality improvements. + Add support for torch 2.7.0 + cuda 12.8 +4/18/2025 + Add initial support for FramePack. https://github.com/lllyasviel/FramePack +4/15/2025 + Add much improved functionality for the wan fun control model. Added strength imrpovements and dropoff code to choose when to apply the control video. Thanks wordbrew. +4/3/2025 + Add support for hunyuan i2v model. Download the clip vision from https://huggingface.co/maybleMyers/H1111_Hunyuan_i2v And download the official model from hunyuan's website and rename it to mp_rank_00_model_states_i2v.pt https://huggingface.co/tencent/HunyuanVideo-I2V/tree/main/hunyuan-video-i2v-720p/transformers add both to your hunyuan folder. +3/29/2025 + Added support for fun models! download dit from https://huggingface.co/alibaba-pai/Wan2.1-Fun-14B-Control and specify correct task type and dit location. I renamed it from diffusion_pytorch_model to Wan2.1-Fun-14B-control. Works in the normal WanX-i2v tab when you select the control option at the bottom of the page. +3/23/2025 + Added Wanx cfg skip functionality to skip cfg guidance during inference for faster generations but less following of the prompt +3/22/2025 + Added WanX-i2v end frame functionality +3/20/2025 + Added WanX-v2v functionality. +3/18/2025 + Added Skip Layer Guidance for WanX-i2v. +3/13/2025 + Added extend video functionality to WanX-i2v. It kind of works . +3/12/2025 + Added ability to send the last frame of a video to the input in WanX-i2v. Also you can now use this to extend the video. You can do multiple batches at each step and pick the best extended video then generate an even longer one. +3/9/2025 + Added batching ability for a folder full of images in WanX-i2v tab. Added flash attn for windows prebuilt wheel. +3/8/2025 + Added support for wan lora's. Remember to convert them first in the convert lora tab. +3/5/2025 + Added ability to batch a folder of images with skyreels i2v, so you can make a video with every image in a folder. +3/2/2025 + Added initial support for wanX-2.1 Image to Video and Text to Video inference. +3/1/2025 + Added support for Skyreels Video to Video and Text to Video. +2/23/2025 + Added initial support for skyreels-V1 using musubi's skyreel implementation. 
(thanks sdbds) +download models from https://huggingface.co/Kijai/SkyReels-V1-Hunyuan_comfy and add them to your hunyuan folder +skyreels_hunyuan_i2v_bf16.safetensors +skyreels_hunyuan_t2v_bf16.safetensors + + +## to use stock hunyuan models + +https://huggingface.co/tencent/HunyuanVideo/resolve/main/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt + +https://huggingface.co/tencent/HunyuanVideo/resolve/main/hunyuan-video-t2v-720p/vae/pytorch_model.pt + +https://huggingface.co/Comfy-Org/HunyuanVideo_repackaged/resolve/main/split_files/text_encoders/llava_llama3_fp16.safetensors + +https://huggingface.co/Comfy-Org/HunyuanVideo_repackaged/resolve/main/split_files/text_encoders/clip_l.safetensors + +#fp8 dit model + +https://huggingface.co/kohya-ss/HunyuanVideo-fp8_e4m3fn-unofficial/resolve/main/mp_rank_00_model_states_fp8.safetensors + +place models in H1111/hunyuan folder + +### Optional: Install Xformers +```powershell +pip install --no-deps xformers --index-url https://download.pytorch.org/whl/cu124 +``` + +### Optional: Install Flash Attention +Note: This can take 1-5 hour to install even on a good CPU, but provides faster generation. +I have uploaded a wheel for windows users to match cuda 12.4 and python 3.10.(thanks lldacing) +https://huggingface.co/maybleMyers/wan_files_for_h1111/resolve/main/flash_attn-2.7.4%2Bcu124torch2.5.1cxx11abiFALSE-cp310-cp310-win_amd64.whl?download=true + +```powershell +pip install flash-attn --no-build-isolation + +If you have downloaded the wheel you can install it with: + +pip install "flash_attn-2.7.4+cu124torch2.5.1cxx11abiFALSE-cp310-cp310-win_amd64.whl" +``` +``` diff --git a/base_hv_generate_video.py b/base_hv_generate_video.py new file mode 100644 index 0000000000000000000000000000000000000000..3b57335f3000d8a8292cb4ca619bfa4fcc92bba3 --- /dev/null +++ b/base_hv_generate_video.py @@ -0,0 +1,936 @@ +import argparse +from datetime import datetime +from pathlib import Path +import random +import sys +import os +import time +from typing import Optional, Union + +import numpy as np +import torch +import torchvision +import accelerate +from diffusers.utils.torch_utils import randn_tensor +from transformers.models.llama import LlamaModel +from tqdm import tqdm +import av +from einops import rearrange +from safetensors.torch import load_file, save_file +from safetensors import safe_open +from PIL import Image + +from hunyuan_model import vae +from hunyuan_model.text_encoder import TextEncoder +from hunyuan_model.text_encoder import PROMPT_TEMPLATE +from hunyuan_model.vae import load_vae +from hunyuan_model.models import load_transformer, get_rotary_pos_embed +from hunyuan_model.fp8_optimization import convert_fp8_linear +from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler +from networks import lora + +try: + from lycoris.kohya import create_network_from_weights +except: + pass + +from utils.model_utils import str_to_dtype +from utils.safetensors_utils import mem_eff_save_file +from dataset.image_video_dataset import load_video, glob_images, resize_image_to_bucket + +import logging + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +def clean_memory_on_device(device): + if device.type == "cuda": + torch.cuda.empty_cache() + elif device.type == "cpu": + pass + elif device.type == "mps": # not tested + torch.mps.empty_cache() + + +def synchronize_device(device: torch.device): + if device.type == "cuda": + torch.cuda.synchronize() + elif device.type == "xpu": + torch.xpu.synchronize() + 
elif device.type == "mps": + torch.mps.synchronize() + + +def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=1, fps=24): + """save videos by video tensor + copy from https://github.com/guoyww/AnimateDiff/blob/e92bd5671ba62c0d774a32951453e328018b7c5b/animatediff/utils/util.py#L61 + + Args: + videos (torch.Tensor): video tensor predicted by the model + path (str): path to save video + rescale (bool, optional): rescale the video tensor from [-1, 1] to . Defaults to False. + n_rows (int, optional): Defaults to 1. + fps (int, optional): video save fps. Defaults to 8. + """ + videos = rearrange(videos, "b c t h w -> t b c h w") + outputs = [] + for x in videos: + x = torchvision.utils.make_grid(x, nrow=n_rows) + x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) + if rescale: + x = (x + 1.0) / 2.0 # -1,1 -> 0,1 + x = torch.clamp(x, 0, 1) + x = (x * 255).numpy().astype(np.uint8) + outputs.append(x) + + os.makedirs(os.path.dirname(path), exist_ok=True) + + # # save video with av + # container = av.open(path, "w") + # stream = container.add_stream("libx264", rate=fps) + # for x in outputs: + # frame = av.VideoFrame.from_ndarray(x, format="rgb24") + # packet = stream.encode(frame) + # container.mux(packet) + # packet = stream.encode(None) + # container.mux(packet) + # container.close() + + height, width, _ = outputs[0].shape + + # create output container + container = av.open(path, mode="w") + + # create video stream + codec = "libx264" + pixel_format = "yuv420p" + stream = container.add_stream(codec, rate=fps) + stream.width = width + stream.height = height + stream.pix_fmt = pixel_format + stream.bit_rate = 4000000 # 4Mbit/s + + for frame_array in outputs: + frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24") + packets = stream.encode(frame) + for packet in packets: + container.mux(packet) + + for packet in stream.encode(): + container.mux(packet) + + container.close() + + +def save_images_grid( + videos: torch.Tensor, parent_dir: str, image_name: str, rescale: bool = False, n_rows: int = 1, create_subdir=True +): + videos = rearrange(videos, "b c t h w -> t b c h w") + outputs = [] + for x in videos: + x = torchvision.utils.make_grid(x, nrow=n_rows) + x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) + if rescale: + x = (x + 1.0) / 2.0 # -1,1 -> 0,1 + x = torch.clamp(x, 0, 1) + x = (x * 255).numpy().astype(np.uint8) + outputs.append(x) + + if create_subdir: + output_dir = os.path.join(parent_dir, image_name) + else: + output_dir = parent_dir + + os.makedirs(output_dir, exist_ok=True) + for i, x in enumerate(outputs): + image_path = os.path.join(output_dir, f"{image_name}_{i:03d}.png") + image = Image.fromarray(x) + image.save(image_path) + + +# region Encoding prompt + + +def encode_prompt(prompt: Union[str, list[str]], device: torch.device, num_videos_per_prompt: int, text_encoder: TextEncoder): + r""" + Encodes the prompt into text encoder hidden states. 
+ + Args: + prompt (`str` or `List[str]`): + prompt to be encoded + device: (`torch.device`): + torch device + num_videos_per_prompt (`int`): + number of videos that should be generated per prompt + text_encoder (TextEncoder): + text encoder to be used for encoding the prompt + """ + # LoRA and Textual Inversion are not supported in this script + # negative prompt and prompt embedding are not supported in this script + # clip_skip is not supported in this script because it is not used in the original script + data_type = "video" # video only, image is not supported + + text_inputs = text_encoder.text2tokens(prompt, data_type=data_type) + + with torch.no_grad(): + prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type, device=device) + prompt_embeds = prompt_outputs.hidden_state + + attention_mask = prompt_outputs.attention_mask + if attention_mask is not None: + attention_mask = attention_mask.to(device) + bs_embed, seq_len = attention_mask.shape + attention_mask = attention_mask.repeat(1, num_videos_per_prompt) + attention_mask = attention_mask.view(bs_embed * num_videos_per_prompt, seq_len) + + prompt_embeds_dtype = text_encoder.dtype + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + if prompt_embeds.ndim == 2: + bs_embed, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt) + prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, -1) + else: + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds, attention_mask + + +def encode_input_prompt(prompt: Union[str, list[str]], args, device, fp8_llm=False, accelerator=None): + # constants + prompt_template_video = "dit-llm-encode-video" + prompt_template = "dit-llm-encode" + text_encoder_dtype = torch.float16 + text_encoder_type = "llm" + text_len = 256 + hidden_state_skip_layer = 2 + apply_final_norm = False + reproduce = False + + text_encoder_2_type = "clipL" + text_len_2 = 77 + + num_videos = 1 + + # if args.prompt_template_video is not None: + # crop_start = PROMPT_TEMPLATE[args.prompt_template_video].get("crop_start", 0) + # elif args.prompt_template is not None: + # crop_start = PROMPT_TEMPLATE[args.prompt_template].get("crop_start", 0) + # else: + # crop_start = 0 + crop_start = PROMPT_TEMPLATE[prompt_template_video].get("crop_start", 0) + max_length = text_len + crop_start + + # prompt_template + prompt_template = PROMPT_TEMPLATE[prompt_template] + + # prompt_template_video + prompt_template_video = PROMPT_TEMPLATE[prompt_template_video] # if args.prompt_template_video is not None else None + + # load text encoders + logger.info(f"loading text encoder: {args.text_encoder1}") + text_encoder = TextEncoder( + text_encoder_type=text_encoder_type, + max_length=max_length, + text_encoder_dtype=text_encoder_dtype, + text_encoder_path=args.text_encoder1, + tokenizer_type=text_encoder_type, + prompt_template=prompt_template, + prompt_template_video=prompt_template_video, + hidden_state_skip_layer=hidden_state_skip_layer, + apply_final_norm=apply_final_norm, + reproduce=reproduce, + ) + text_encoder.eval() + if fp8_llm: + org_dtype = text_encoder.dtype + logger.info(f"Moving and casting text encoder to {device} and 
torch.float8_e4m3fn") + text_encoder.to(device=device, dtype=torch.float8_e4m3fn) + + # prepare LLM for fp8 + def prepare_fp8(llama_model: LlamaModel, target_dtype): + def forward_hook(module): + def forward(hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + module.variance_epsilon) + return module.weight.to(input_dtype) * hidden_states.to(input_dtype) + + return forward + + for module in llama_model.modules(): + if module.__class__.__name__ in ["Embedding"]: + # print("set", module.__class__.__name__, "to", target_dtype) + module.to(target_dtype) + if module.__class__.__name__ in ["LlamaRMSNorm"]: + # print("set", module.__class__.__name__, "hooks") + module.forward = forward_hook(module) + + prepare_fp8(text_encoder.model, org_dtype) + + logger.info(f"loading text encoder 2: {args.text_encoder2}") + text_encoder_2 = TextEncoder( + text_encoder_type=text_encoder_2_type, + max_length=text_len_2, + text_encoder_dtype=text_encoder_dtype, + text_encoder_path=args.text_encoder2, + tokenizer_type=text_encoder_2_type, + reproduce=reproduce, + ) + text_encoder_2.eval() + + # encode prompt + logger.info(f"Encoding prompt with text encoder 1") + text_encoder.to(device=device) + if fp8_llm: + with accelerator.autocast(): + prompt_embeds, prompt_mask = encode_prompt(prompt, device, num_videos, text_encoder) + else: + prompt_embeds, prompt_mask = encode_prompt(prompt, device, num_videos, text_encoder) + text_encoder = None + clean_memory_on_device(device) + + logger.info(f"Encoding prompt with text encoder 2") + text_encoder_2.to(device=device) + prompt_embeds_2, prompt_mask_2 = encode_prompt(prompt, device, num_videos, text_encoder_2) + + prompt_embeds = prompt_embeds.to("cpu") + prompt_mask = prompt_mask.to("cpu") + prompt_embeds_2 = prompt_embeds_2.to("cpu") + prompt_mask_2 = prompt_mask_2.to("cpu") + + text_encoder_2 = None + clean_memory_on_device(device) + + return prompt_embeds, prompt_mask, prompt_embeds_2, prompt_mask_2 + + +# endregion + + +def prepare_vae(args, device): + vae_dtype = torch.float16 if args.vae_dtype is None else str_to_dtype(args.vae_dtype) + vae, _, s_ratio, t_ratio = load_vae(vae_dtype=vae_dtype, device=device, vae_path=args.vae) + vae.eval() + # vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio} + + # set chunk_size to CausalConv3d recursively + chunk_size = args.vae_chunk_size + if chunk_size is not None: + vae.set_chunk_size_for_causal_conv_3d(chunk_size) + logger.info(f"Set chunk_size to {chunk_size} for CausalConv3d") + + if args.vae_spatial_tile_sample_min_size is not None: + vae.enable_spatial_tiling(True) + vae.tile_sample_min_size = args.vae_spatial_tile_sample_min_size + vae.tile_latent_min_size = args.vae_spatial_tile_sample_min_size // 8 + # elif args.vae_tiling: + else: + vae.enable_spatial_tiling(True) + + return vae, vae_dtype + + +def encode_to_latents(args, video, device): + vae, vae_dtype = prepare_vae(args, device) + + video = video.to(device=device, dtype=vae_dtype) + video = video * 2 - 1 # 0, 1 -> -1, 1 + with torch.no_grad(): + latents = vae.encode(video).latent_dist.sample() + + if hasattr(vae.config, "shift_factor") and vae.config.shift_factor: + latents = (latents - vae.config.shift_factor) * vae.config.scaling_factor + else: + latents = latents * vae.config.scaling_factor + + return latents + + +def decode_latents(args, latents, device): + vae, vae_dtype = prepare_vae(args, device) 
+ + expand_temporal_dim = False + if len(latents.shape) == 4: + latents = latents.unsqueeze(2) + expand_temporal_dim = True + elif len(latents.shape) == 5: + pass + else: + raise ValueError(f"Only support latents with shape (b, c, h, w) or (b, c, f, h, w), but got {latents.shape}.") + + if hasattr(vae.config, "shift_factor") and vae.config.shift_factor: + latents = latents / vae.config.scaling_factor + vae.config.shift_factor + else: + latents = latents / vae.config.scaling_factor + + latents = latents.to(device=device, dtype=vae_dtype) + with torch.no_grad(): + image = vae.decode(latents, return_dict=False)[0] + + if expand_temporal_dim: + image = image.squeeze(2) + + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().float() + + return image + + +def parse_args(): + parser = argparse.ArgumentParser(description="HunyuanVideo inference script") + + parser.add_argument("--dit", type=str, required=True, help="DiT checkpoint path or directory") + parser.add_argument( + "--dit_in_channels", + type=int, + default=None, + help="input channels for DiT, default is None (automatically detect). 32 for SkyReels-I2V, 16 for others", + ) + parser.add_argument("--vae", type=str, required=True, help="VAE checkpoint path or directory") + parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is float16") + parser.add_argument("--text_encoder1", type=str, required=True, help="Text Encoder 1 directory") + parser.add_argument("--text_encoder2", type=str, required=True, help="Text Encoder 2 directory") + + # LoRA + parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path") + parser.add_argument("--lora_multiplier", type=float, nargs="*", default=1.0, help="LoRA multiplier") + parser.add_argument( + "--save_merged_model", + type=str, + default=None, + help="Save merged model to path. If specified, no inference will be performed.", + ) + parser.add_argument("--exclude_single_blocks", action="store_true", help="Exclude single blocks when loading LoRA weights") + + # inference + parser.add_argument("--prompt", type=str, required=True, help="prompt for generation") + parser.add_argument("--negative_prompt", type=str, default=None, help="negative prompt for generation") + parser.add_argument("--video_size", type=int, nargs=2, default=[256, 256], help="video size") + parser.add_argument("--video_length", type=int, default=129, help="video length") + parser.add_argument("--fps", type=int, default=24, help="video fps") + parser.add_argument("--infer_steps", type=int, default=50, help="number of inference steps") + parser.add_argument("--save_path", type=str, required=True, help="path to save generated video") + parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.") + parser.add_argument( + "--guidance_scale", + type=float, + default=1.0, + help="Guidance scale for classifier free guidance. 
Default is 1.0 (means no guidance)", + ) + parser.add_argument("--embedded_cfg_scale", type=float, default=6.0, help="Embeded classifier free guidance scale.") + parser.add_argument("--video_path", type=str, default=None, help="path to video for video2video inference") + parser.add_argument( + "--image_path", type=str, default=None, help="path to image for image2video inference, only works for SkyReels-I2V model" + ) + parser.add_argument( + "--split_uncond", + action="store_true", + help="split unconditional call for classifier free guidance, slower but less memory usage", + ) + parser.add_argument("--strength", type=float, default=0.8, help="strength for video2video inference") + + # Flow Matching + parser.add_argument("--flow_shift", type=float, default=7.0, help="Shift factor for flow matching schedulers.") + + parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model") + parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)") + parser.add_argument( + "--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU" + ) + parser.add_argument( + "--attn_mode", type=str, default="torch", choices=["flash", "torch", "sageattn", "xformers", "sdpa"], help="attention mode" + ) + parser.add_argument( + "--split_attn", action="store_true", help="use split attention, default is False. if True, --split_uncond becomes True" + ) + parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE") + parser.add_argument( + "--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256" + ) + parser.add_argument("--blocks_to_swap", type=int, default=None, help="number of blocks to swap in the model") + parser.add_argument("--img_in_txt_in_offloading", action="store_true", help="offload img_in and txt_in to cpu") + parser.add_argument( + "--output_type", type=str, default="video", choices=["video", "images", "latent", "both"], help="output type" + ) + parser.add_argument("--no_metadata", action="store_true", help="do not save metadata") + parser.add_argument("--latent_path", type=str, nargs="*", default=None, help="path to latent for decode. 
no inference") + parser.add_argument("--lycoris", action="store_true", help="use lycoris for inference") + parser.add_argument("--fp8_fast", action="store_true", help="Enable fast FP8 arthimetic(RTX 4XXX+)") + parser.add_argument("--compile", action="store_true", help="Enable torch.compile") + parser.add_argument( + "--compile_args", + nargs=4, + metavar=("BACKEND", "MODE", "DYNAMIC", "FULLGRAPH"), + default=["inductor", "max-autotune-no-cudagraphs", "False", "False"], + help="Torch.compile settings", + ) + + args = parser.parse_args() + + assert (args.latent_path is None or len(args.latent_path) == 0) or ( + args.output_type == "images" or args.output_type == "video" + ), "latent_path is only supported for images or video output" + + # update dit_weight based on model_base if not exists + + if args.fp8_fast and not args.fp8: + raise ValueError("--fp8_fast requires --fp8") + + return args + + +def check_inputs(args): + height = args.video_size[0] + width = args.video_size[1] + video_length = args.video_length + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + return height, width, video_length + + +def main(): + args = parse_args() + + device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device(device) + dit_dtype = torch.bfloat16 + dit_weight_dtype = torch.float8_e4m3fn if args.fp8 else dit_dtype + logger.info(f"Using device: {device}, DiT precision: {dit_dtype}, weight precision: {dit_weight_dtype}") + + original_base_names = None + if args.latent_path is not None and len(args.latent_path) > 0: + original_base_names = [] + latents_list = [] + seeds = [] + for latent_path in args.latent_path: + original_base_names.append(os.path.splitext(os.path.basename(latent_path))[0]) + seed = 0 + + if os.path.splitext(latent_path)[1] != ".safetensors": + latents = torch.load(latent_path, map_location="cpu") + else: + latents = load_file(latent_path)["latent"] + with safe_open(latent_path, framework="pt") as f: + metadata = f.metadata() + if metadata is None: + metadata = {} + logger.info(f"Loaded metadata: {metadata}") + + if "seeds" in metadata: + seed = int(metadata["seeds"]) + + seeds.append(seed) + latents_list.append(latents) + + logger.info(f"Loaded latent from {latent_path}. 
Shape: {latents.shape}") + latents = torch.stack(latents_list, dim=0) + else: + # prepare accelerator + mixed_precision = "bf16" if dit_dtype == torch.bfloat16 else "fp16" + accelerator = accelerate.Accelerator(mixed_precision=mixed_precision) + + # load prompt + prompt = args.prompt # TODO load prompts from file + assert prompt is not None, "prompt is required" + + # check inputs: may be height, width, video_length etc will be changed for each generation in future + height, width, video_length = check_inputs(args) + + # encode prompt with LLM and Text Encoder + logger.info(f"Encoding prompt: {prompt}") + + do_classifier_free_guidance = args.guidance_scale != 1.0 + if do_classifier_free_guidance: + negative_prompt = args.negative_prompt + if negative_prompt is None: + logger.info("Negative prompt is not provided, using empty prompt") + negative_prompt = "" + logger.info(f"Encoding negative prompt: {negative_prompt}") + prompt = [negative_prompt, prompt] + else: + if args.negative_prompt is not None: + logger.warning("Negative prompt is provided but guidance_scale is 1.0, negative prompt will be ignored.") + + prompt_embeds, prompt_mask, prompt_embeds_2, prompt_mask_2 = encode_input_prompt( + prompt, args, device, args.fp8_llm, accelerator + ) + + # encode latents for video2video inference + video_latents = None + if args.video_path is not None: + # v2v inference + logger.info(f"Video2Video inference: {args.video_path}") + video = load_video(args.video_path, 0, video_length, bucket_reso=(width, height)) # list of frames + if len(video) < video_length: + raise ValueError(f"Video length is less than {video_length}") + video = np.stack(video, axis=0) # F, H, W, C + video = torch.from_numpy(video).permute(3, 0, 1, 2).unsqueeze(0).float() # 1, C, F, H, W + video = video / 255.0 + + logger.info(f"Encoding video to latents") + video_latents = encode_to_latents(args, video, device) + video_latents = video_latents.to(device=device, dtype=dit_dtype) + + clean_memory_on_device(device) + + # encode latents for image2video inference + image_latents = None + if args.image_path is not None: + # i2v inference + logger.info(f"Image2Video inference: {args.image_path}") + + image = Image.open(args.image_path) + image = resize_image_to_bucket(image, (width, height)) # returns a numpy array + image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).unsqueeze(2).float() # 1, C, 1, H, W + image = image / 255.0 + + logger.info(f"Encoding image to latents") + image_latents = encode_to_latents(args, image, device) # 1, C, 1, H, W + image_latents = image_latents.to(device=device, dtype=dit_dtype) + + clean_memory_on_device(device) + + # load DiT model + blocks_to_swap = args.blocks_to_swap if args.blocks_to_swap else 0 + loading_device = "cpu" # if blocks_to_swap > 0 else device + + logger.info(f"Loading DiT model from {args.dit}") + if args.attn_mode == "sdpa": + args.attn_mode = "torch" + + # if image_latents is given, the model should be I2V model, so the in_channels should be 32 + dit_in_channels = args.dit_in_channels if args.dit_in_channels is not None else (32 if image_latents is not None else 16) + + # if we use LoRA, weigths should be bf16 instead of fp8, because merging should be done in bf16 + # the model is too large, so we load the model to cpu. 
in addition, the .pt file is loaded to cpu anyway + # on the fly merging will be a solution for this issue for .safetenors files (not implemented yet) + transformer = load_transformer( + args.dit, args.attn_mode, args.split_attn, loading_device, dit_dtype, in_channels=dit_in_channels + ) + transformer.eval() + + # load LoRA weights + if args.lora_weight is not None and len(args.lora_weight) > 0: + for i, lora_weight in enumerate(args.lora_weight): + if args.lora_multiplier is not None and len(args.lora_multiplier) > i: + lora_multiplier = args.lora_multiplier[i] + else: + lora_multiplier = 1.0 + + logger.info(f"Loading LoRA weights from {lora_weight} with multiplier {lora_multiplier}") + weights_sd = load_file(lora_weight) + + # Filter to exclude keys that are part of single_blocks + if args.exclude_single_blocks: + filtered_weights = {k: v for k, v in weights_sd.items() if "single_blocks" not in k} + weights_sd = filtered_weights + + if args.lycoris: + lycoris_net, _ = create_network_from_weights( + multiplier=lora_multiplier, + file=None, + weights_sd=weights_sd, + unet=transformer, + text_encoder=None, + vae=None, + for_inference=True, + ) + else: + network = lora.create_arch_network_from_weights( + lora_multiplier, weights_sd, unet=transformer, for_inference=True + ) + logger.info("Merging LoRA weights to DiT model") + + # try: + # network.apply_to(None, transformer, apply_text_encoder=False, apply_unet=True) + # info = network.load_state_dict(weights_sd, strict=True) + # logger.info(f"Loaded LoRA weights from {weights_file}: {info}") + # network.eval() + # network.to(device) + # except Exception as e: + if args.lycoris: + lycoris_net.merge_to(None, transformer, weights_sd, dtype=None, device=device) + else: + network.merge_to(None, transformer, weights_sd, device=device, non_blocking=True) + + synchronize_device(device) + + logger.info("LoRA weights loaded") + + # save model here before casting to dit_weight_dtype + if args.save_merged_model: + logger.info(f"Saving merged model to {args.save_merged_model}") + mem_eff_save_file(transformer.state_dict(), args.save_merged_model) # save_file needs a lot of memory + logger.info("Merged model saved") + return + + logger.info(f"Casting model to {dit_weight_dtype}") + transformer.to(dtype=dit_weight_dtype) + + if args.fp8_fast: + logger.info("Enabling FP8 acceleration") + params_to_keep = {"norm", "bias", "time_in", "vector_in", "guidance_in", "txt_in", "img_in"} + for name, param in transformer.named_parameters(): + dtype_to_use = dit_dtype if any(keyword in name for keyword in params_to_keep) else dit_weight_dtype + param.to(dtype=dtype_to_use) + convert_fp8_linear(transformer, dit_dtype, params_to_keep=params_to_keep) + + if args.compile: + compile_backend, compile_mode, compile_dynamic, compile_fullgraph = args.compile_args + logger.info( + f"Torch Compiling[Backend: {compile_backend}; Mode: {compile_mode}; Dynamic: {compile_dynamic}; Fullgraph: {compile_fullgraph}]" + ) + torch._dynamo.config.cache_size_limit = 32 + for i, block in enumerate(transformer.single_blocks): + compiled_block = torch.compile( + block, + backend=compile_backend, + mode=compile_mode, + dynamic=compile_dynamic.lower() in "true", + fullgraph=compile_fullgraph.lower() in "true", + ) + transformer.single_blocks[i] = compiled_block + for i, block in enumerate(transformer.double_blocks): + compiled_block = torch.compile( + block, + backend=compile_backend, + mode=compile_mode, + dynamic=compile_dynamic.lower() in "true", + fullgraph=compile_fullgraph.lower() in "true", 
+ ) + transformer.double_blocks[i] = compiled_block + + if blocks_to_swap > 0: + logger.info(f"Enable swap {blocks_to_swap} blocks to CPU from device: {device}") + transformer.enable_block_swap(blocks_to_swap, device, supports_backward=False) + transformer.move_to_device_except_swap_blocks(device) + transformer.prepare_block_swap_before_forward() + else: + logger.info(f"Moving model to {device}") + transformer.to(device=device) + if args.img_in_txt_in_offloading: + logger.info("Enable offloading img_in and txt_in to CPU") + transformer.enable_img_in_txt_in_offloading() + + # load scheduler + logger.info(f"Loading scheduler") + scheduler = FlowMatchDiscreteScheduler(shift=args.flow_shift, reverse=True, solver="euler") + + # Prepare timesteps + num_inference_steps = args.infer_steps + scheduler.set_timesteps(num_inference_steps, device=device) # n_tokens is not used in FlowMatchDiscreteScheduler + timesteps = scheduler.timesteps + + # Prepare generator + num_videos_per_prompt = 1 # args.num_videos # currently only support 1 video per prompt, this is a batch size + seed = args.seed + if seed is None: + seeds = [random.randint(0, 2**32 - 1) for _ in range(num_videos_per_prompt)] + elif isinstance(seed, int): + seeds = [seed + i for i in range(num_videos_per_prompt)] + else: + raise ValueError(f"Seed must be an integer or None, got {seed}.") + generator = [torch.Generator(device).manual_seed(seed) for seed in seeds] + + # Prepare noisy latents + num_channels_latents = 16 # transformer.config.in_channels + vae_scale_factor = 2 ** (4 - 1) # len(self.vae.config.block_out_channels) == 4 + + vae_ver = vae.VAE_VER + if "884" in vae_ver: + latent_video_length = (video_length - 1) // 4 + 1 + elif "888" in vae_ver: + latent_video_length = (video_length - 1) // 8 + 1 + else: + latent_video_length = video_length + + # shape = ( + # num_videos_per_prompt, + # num_channels_latents, + # latent_video_length, + # height // vae_scale_factor, + # width // vae_scale_factor, + # ) + # latents = randn_tensor(shape, generator=generator, device=device, dtype=dit_dtype) + + # make first N frames to be the same if the given seed is same + shape_of_frame = (num_videos_per_prompt, num_channels_latents, 1, height // vae_scale_factor, width // vae_scale_factor) + latents = [] + for i in range(latent_video_length): + latents.append(randn_tensor(shape_of_frame, generator=generator, device=device, dtype=dit_dtype)) + latents = torch.cat(latents, dim=2) + + # pad image_latents to match the length of video_latents + if image_latents is not None: + zero_latents = torch.zeros_like(latents) + zero_latents[:, :, :1, :, :] = image_latents + image_latents = zero_latents + + if args.video_path is not None: + # v2v inference + noise = latents + assert noise.shape == video_latents.shape, f"noise shape {noise.shape} != video_latents shape {video_latents.shape}" + + num_inference_steps = int(num_inference_steps * args.strength) + timestep_start = scheduler.timesteps[-num_inference_steps] # larger strength, less inference steps and more start time + t = timestep_start / 1000.0 + latents = noise * t + video_latents * (1 - t) + + timesteps = timesteps[-num_inference_steps:] + + logger.info(f"strength: {args.strength}, num_inference_steps: {num_inference_steps}, timestep_start: {timestep_start}") + + # FlowMatchDiscreteScheduler does not have init_noise_sigma + + # Denoising loop + embedded_guidance_scale = args.embedded_cfg_scale + if embedded_guidance_scale is not None: + guidance_expand = torch.tensor([embedded_guidance_scale * 1000.0] * 
latents.shape[0], dtype=torch.float32, device="cpu") + guidance_expand = guidance_expand.to(device=device, dtype=dit_dtype) + if do_classifier_free_guidance: + guidance_expand = torch.cat([guidance_expand, guidance_expand], dim=0) + else: + guidance_expand = None + freqs_cos, freqs_sin = get_rotary_pos_embed(vae_ver, transformer, video_length, height, width) + # n_tokens = freqs_cos.shape[0] + + # move and cast all inputs to the correct device and dtype + prompt_embeds = prompt_embeds.to(device=device, dtype=dit_dtype) + prompt_mask = prompt_mask.to(device=device) + prompt_embeds_2 = prompt_embeds_2.to(device=device, dtype=dit_dtype) + prompt_mask_2 = prompt_mask_2.to(device=device) + + freqs_cos = freqs_cos.to(device=device, dtype=dit_dtype) + freqs_sin = freqs_sin.to(device=device, dtype=dit_dtype) + + num_warmup_steps = len(timesteps) - num_inference_steps * scheduler.order # this should be 0 in v2v inference + + # assert split_uncond and split_attn + if args.split_attn and do_classifier_free_guidance and not args.split_uncond: + logger.warning("split_attn is enabled, split_uncond will be enabled as well.") + args.split_uncond = True + + # with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA]) as p: + with tqdm(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + latents = scheduler.scale_model_input(latents, t) + + # predict the noise residual + with torch.no_grad(), accelerator.autocast(): + latents_input = latents if not do_classifier_free_guidance else torch.cat([latents, latents], dim=0) + if image_latents is not None: + latents_image_input = ( + image_latents if not do_classifier_free_guidance else torch.cat([image_latents, image_latents], dim=0) + ) + latents_input = torch.cat([latents_input, latents_image_input], dim=1) # 1 or 2, C*2, F, H, W + + batch_size = 1 if args.split_uncond else latents_input.shape[0] + + noise_pred_list = [] + for j in range(0, latents_input.shape[0], batch_size): + noise_pred = transformer( # For an input image (129, 192, 336) (1, 256, 256) + latents_input[j : j + batch_size], # [1, 16, 33, 24, 42] + t.repeat(batch_size).to(device=device, dtype=dit_dtype), # [1] + text_states=prompt_embeds[j : j + batch_size], # [1, 256, 4096] + text_mask=prompt_mask[j : j + batch_size], # [1, 256] + text_states_2=prompt_embeds_2[j : j + batch_size], # [1, 768] + freqs_cos=freqs_cos, # [seqlen, head_dim] + freqs_sin=freqs_sin, # [seqlen, head_dim] + guidance=guidance_expand[j : j + batch_size], # [1] + return_dict=True, + )["x"] + noise_pred_list.append(noise_pred) + noise_pred = torch.cat(noise_pred_list, dim=0) + + # perform classifier free guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond) + + # # SkyReels' rescale noise config is omitted for now + # if guidance_rescale > 0.0: + # # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf + # noise_pred = rescale_noise_cfg( + # noise_pred, + # noise_pred_cond, + # guidance_rescale=self.guidance_rescale, + # ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # update progress bar + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0): + if progress_bar is not None: + progress_bar.update() + + # print(p.key_averages().table(sort_by="self_cpu_time_total", row_limit=-1)) + # print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) + + latents = latents.detach().cpu() + transformer = None + clean_memory_on_device(device) + + # Save samples + output_type = args.output_type + save_path = args.save_path # if args.save_path_suffix == "" else f"{args.save_path}_{args.save_path_suffix}" + os.makedirs(save_path, exist_ok=True) + time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S") + + if output_type == "latent" or output_type == "both": + # save latent + for i, latent in enumerate(latents): + latent_path = f"{save_path}/{time_flag}_{i}_{seeds[i]}_latent.safetensors" + + if args.no_metadata: + metadata = None + else: + metadata = { + "seeds": f"{seeds[i]}", + "prompt": f"{args.prompt}", + "height": f"{height}", + "width": f"{width}", + "video_length": f"{video_length}", + "infer_steps": f"{num_inference_steps}", + "guidance_scale": f"{args.guidance_scale}", + "embedded_cfg_scale": f"{args.embedded_cfg_scale}", + } + if args.negative_prompt is not None: + metadata["negative_prompt"] = f"{args.negative_prompt}" + sd = {"latent": latent} + save_file(sd, latent_path, metadata=metadata) + + logger.info(f"Latent save to: {latent_path}") + if output_type == "video" or output_type == "both": + # save video + videos = decode_latents(args, latents, device) + for i, sample in enumerate(videos): + original_name = "" if original_base_names is None else f"_{original_base_names[i]}" + sample = sample.unsqueeze(0) + video_path = f"{save_path}/{time_flag}_{i}_{seeds[i]}{original_name}.mp4" + save_videos_grid(sample, video_path, fps=args.fps) + logger.info(f"Sample save to: {video_path}") + elif output_type == "images": + # save images + videos = decode_latents(args, latents, device) + for i, sample in enumerate(videos): + original_name = "" if original_base_names is None else f"_{original_base_names[i]}" + sample = sample.unsqueeze(0) + image_name = f"{time_flag}_{i}_{seeds[i]}{original_name}" + save_images_grid(sample, save_path, image_name) + logger.info(f"Sample images save to: {save_path}/{image_name}") + + logger.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/base_wan_generate_video.py b/base_wan_generate_video.py new file mode 100644 index 0000000000000000000000000000000000000000..aec5a7200e04a83a2ef6adc11447244f766944e8 --- /dev/null +++ b/base_wan_generate_video.py @@ -0,0 +1,1892 @@ +import argparse +from datetime import datetime +import gc +import random +import os +import re +import time +import math +import copy +from types import ModuleType, SimpleNamespace +from typing import Tuple, Optional, List, Union, Any, Dict + +import torch +import accelerate +from accelerate import Accelerator +from safetensors.torch import load_file, save_file +from safetensors import safe_open +from PIL import Image +import cv2 +import numpy as np +import torchvision.transforms.functional as TF +from tqdm import tqdm + +from networks import lora_wan +from utils.safetensors_utils import 
mem_eff_save_file, load_safetensors +from wan.configs import WAN_CONFIGS, SUPPORTED_SIZES +import wan +from wan.modules.model import WanModel, load_wan_model, detect_wan_sd_dtype +from wan.modules.vae import WanVAE +from wan.modules.t5 import T5EncoderModel +from wan.modules.clip import CLIPModel +from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler +from wan.utils.fm_solvers import FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps +from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler + +try: + from lycoris.kohya import create_network_from_weights +except: + pass + +from utils.model_utils import str_to_dtype +from utils.device_utils import clean_memory_on_device +from hv_generate_video import save_images_grid, save_videos_grid, synchronize_device +from dataset.image_video_dataset import load_video + +import logging + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +class GenerationSettings: + def __init__( + self, device: torch.device, cfg, dit_dtype: torch.dtype, dit_weight_dtype: Optional[torch.dtype], vae_dtype: torch.dtype + ): + self.device = device + self.cfg = cfg + self.dit_dtype = dit_dtype + self.dit_weight_dtype = dit_weight_dtype + self.vae_dtype = vae_dtype + + +def parse_args() -> argparse.Namespace: + """parse command line arguments""" + parser = argparse.ArgumentParser(description="Wan 2.1 inference script") + + # WAN arguments + parser.add_argument("--ckpt_dir", type=str, default=None, help="The path to the checkpoint directory (Wan 2.1 official).") + parser.add_argument("--task", type=str, default="t2v-14B", choices=list(WAN_CONFIGS.keys()), help="The task to run.") + parser.add_argument( + "--sample_solver", type=str, default="unipc", choices=["unipc", "dpm++", "vanilla"], help="The solver used to sample." + ) + + parser.add_argument("--dit", type=str, default=None, help="DiT checkpoint path") + parser.add_argument("--vae", type=str, default=None, help="VAE checkpoint path") + parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is bfloat16") + parser.add_argument("--vae_cache_cpu", action="store_true", help="cache features in VAE on CPU") + parser.add_argument("--t5", type=str, default=None, help="text encoder (T5) checkpoint path") + parser.add_argument("--clip", type=str, default=None, help="text encoder (CLIP) checkpoint path") + # LoRA + parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path") + parser.add_argument("--lora_multiplier", type=float, nargs="*", default=1.0, help="LoRA multiplier") + parser.add_argument("--include_patterns", type=str, nargs="*", default=None, help="LoRA module include patterns") + parser.add_argument("--exclude_patterns", type=str, nargs="*", default=None, help="LoRA module exclude patterns") + parser.add_argument( + "--save_merged_model", + type=str, + default=None, + help="Save merged model to path. 
If specified, no inference will be performed.", + ) + + # inference + parser.add_argument("--prompt", type=str, default=None, help="prompt for generation") + parser.add_argument( + "--negative_prompt", + type=str, + default=None, + help="negative prompt for generation, use default negative prompt if not specified", + ) + parser.add_argument("--video_size", type=int, nargs=2, default=[256, 256], help="video size, height and width") + parser.add_argument("--video_length", type=int, default=None, help="video length, Default depends on task") + parser.add_argument("--fps", type=int, default=16, help="video fps, Default is 16") + parser.add_argument("--infer_steps", type=int, default=None, help="number of inference steps") + parser.add_argument("--save_path", type=str, required=True, help="path to save generated video") + parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.") + parser.add_argument( + "--cpu_noise", action="store_true", help="Use CPU to generate noise (compatible with ComfyUI). Default is False." + ) + parser.add_argument( + "--guidance_scale", + type=float, + default=5.0, + help="Guidance scale for classifier free guidance. Default is 5.0.", + ) + parser.add_argument("--video_path", type=str, default=None, help="path to video for video2video inference") + parser.add_argument("--image_path", type=str, default=None, help="path to image for image2video inference") + parser.add_argument("--end_image_path", type=str, default=None, help="path to end image for image2video inference") + parser.add_argument( + "--control_path", + type=str, + default=None, + help="path to control video for inference with controlnet. video file or directory with images", + ) + parser.add_argument("--trim_tail_frames", type=int, default=0, help="trim tail N frames from the video before saving") + parser.add_argument( + "--cfg_skip_mode", + type=str, + default="none", + choices=["early", "late", "middle", "early_late", "alternate", "none"], + help="CFG skip mode. each mode skips different parts of the CFG. " + " early: initial steps, late: later steps, middle: middle steps, early_late: both early and late, alternate: alternate, none: no skip (default)", + ) + parser.add_argument( + "--cfg_apply_ratio", + type=float, + default=None, + help="The ratio of steps to apply CFG (0.0 to 1.0). Default is None (apply all steps).", + ) + parser.add_argument( + "--slg_layers", type=str, default=None, help="Skip block (layer) indices for SLG (Skip Layer Guidance), comma separated" + ) + parser.add_argument( + "--slg_scale", + type=float, + default=3.0, + help="scale for SLG classifier free guidance. Default is 3.0. Ignored if slg_mode is None or uncond", + ) + parser.add_argument("--slg_start", type=float, default=0.0, help="start ratio for inference steps for SLG. Default is 0.0.") + parser.add_argument("--slg_end", type=float, default=0.3, help="end ratio for inference steps for SLG. Default is 0.3.") + parser.add_argument( + "--slg_mode", + type=str, + default=None, + choices=["original", "uncond"], + help="SLG mode. original: same as SD3, uncond: replace uncond pred with SLG pred", + ) + + # Flow Matching + parser.add_argument( + "--flow_shift", + type=float, + default=None, + help="Shift factor for flow matching schedulers. 
Default depends on task.", + ) + + parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model") + parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT, only for fp8") + parser.add_argument("--fp8_fast", action="store_true", help="Enable fast FP8 arithmetic (RTX 4XXX+), only for fp8_scaled") + parser.add_argument("--fp8_t5", action="store_true", help="use fp8 for Text Encoder model") + parser.add_argument( + "--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU" + ) + parser.add_argument( + "--attn_mode", + type=str, + default="torch", + choices=["flash", "flash2", "flash3", "torch", "sageattn", "xformers", "sdpa"], + help="attention mode", + ) + parser.add_argument("--blocks_to_swap", type=int, default=0, help="number of blocks to swap in the model") + parser.add_argument( + "--output_type", type=str, default="video", choices=["video", "images", "latent", "both"], help="output type" + ) + parser.add_argument("--no_metadata", action="store_true", help="do not save metadata") + parser.add_argument("--latent_path", type=str, nargs="*", default=None, help="path to latent for decode. no inference") + parser.add_argument("--lycoris", action="store_true", help="use lycoris for inference") + parser.add_argument("--compile", action="store_true", help="Enable torch.compile") + parser.add_argument( + "--compile_args", + nargs=4, + metavar=("BACKEND", "MODE", "DYNAMIC", "FULLGRAPH"), + default=["inductor", "max-autotune-no-cudagraphs", "False", "False"], + help="Torch.compile settings", + ) + + # New arguments for batch and interactive modes + parser.add_argument("--from_file", type=str, default=None, help="Read prompts from a file") + parser.add_argument("--interactive", action="store_true", help="Interactive mode: read prompts from console") + + args = parser.parse_args() + + # Validate arguments + if args.from_file and args.interactive: + raise ValueError("Cannot use both --from_file and --interactive at the same time") + + if args.prompt is None and not args.from_file and not args.interactive and args.latent_path is None: + raise ValueError("Either --prompt, --from_file, --interactive, or --latent_path must be specified") + + assert (args.latent_path is None or len(args.latent_path) == 0) or ( + args.output_type == "images" or args.output_type == "video" + ), "latent_path is only supported for images or video output" + + return args + + +def parse_prompt_line(line: str) -> Dict[str, Any]: + """Parse a prompt line into a dictionary of argument overrides + + Args: + line: Prompt line with options + + Returns: + Dict[str, Any]: Dictionary of argument overrides + """ + # TODO common function with hv_train_network.line_to_prompt_dict + parts = line.split(" --") + prompt = parts[0].strip() + + # Create dictionary of overrides + overrides = {"prompt": prompt} + + for part in parts[1:]: + if not part.strip(): + continue + option_parts = part.split(" ", 1) + option = option_parts[0].strip() + value = option_parts[1].strip() if len(option_parts) > 1 else "" + + # Map options to argument names + if option == "w": + overrides["video_size_width"] = int(value) + elif option == "h": + overrides["video_size_height"] = int(value) + elif option == "f": + overrides["video_length"] = int(value) + elif option == "d": + overrides["seed"] = int(value) + elif option == "s": + overrides["infer_steps"] = int(value) + elif option == "g" or option == "l": + overrides["guidance_scale"] = float(value) + elif 
option == "fs": + overrides["flow_shift"] = float(value) + elif option == "i": + overrides["image_path"] = value + elif option == "cn": + overrides["control_path"] = value + elif option == "n": + overrides["negative_prompt"] = value + + return overrides + + +def apply_overrides(args: argparse.Namespace, overrides: Dict[str, Any]) -> argparse.Namespace: + """Apply overrides to args + + Args: + args: Original arguments + overrides: Dictionary of overrides + + Returns: + argparse.Namespace: New arguments with overrides applied + """ + args_copy = copy.deepcopy(args) + + for key, value in overrides.items(): + if key == "video_size_width": + args_copy.video_size[1] = value + elif key == "video_size_height": + args_copy.video_size[0] = value + else: + setattr(args_copy, key, value) + + return args_copy + + +def get_task_defaults(task: str, size: Optional[Tuple[int, int]] = None) -> Tuple[int, float, int, bool]: + """Return default values for each task + + Args: + task: task name (t2v, t2i, i2v etc.) + size: size of the video (width, height) + + Returns: + Tuple[int, float, int, bool]: (infer_steps, flow_shift, video_length, needs_clip) + """ + width, height = size if size else (0, 0) + + if "t2i" in task: + return 50, 5.0, 1, False + elif "i2v" in task: + flow_shift = 3.0 if (width == 832 and height == 480) or (width == 480 and height == 832) else 5.0 + return 40, flow_shift, 81, True + else: # t2v or default + return 50, 5.0, 81, False + + +def setup_args(args: argparse.Namespace) -> argparse.Namespace: + """Validate and set default values for optional arguments + + Args: + args: command line arguments + + Returns: + argparse.Namespace: updated arguments + """ + # Get default values for the task + infer_steps, flow_shift, video_length, _ = get_task_defaults(args.task, tuple(args.video_size)) + + # Apply default values to unset arguments + if args.infer_steps is None: + args.infer_steps = infer_steps + if args.flow_shift is None: + args.flow_shift = flow_shift + if args.video_length is None: + args.video_length = video_length + + # Force video_length to 1 for t2i tasks + if "t2i" in args.task: + assert args.video_length == 1, f"video_length should be 1 for task {args.task}" + + # parse slg_layers + if args.slg_layers is not None: + args.slg_layers = list(map(int, args.slg_layers.split(","))) + + return args + + +def check_inputs(args: argparse.Namespace) -> Tuple[int, int, int]: + """Validate video size and length + + Args: + args: command line arguments + + Returns: + Tuple[int, int, int]: (height, width, video_length) + """ + height = args.video_size[0] + width = args.video_size[1] + size = f"{width}*{height}" + + if size not in SUPPORTED_SIZES[args.task]: + logger.warning(f"Size {size} is not supported for task {args.task}. 
Supported sizes are {SUPPORTED_SIZES[args.task]}.") + + video_length = args.video_length + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + return height, width, video_length + + +def calculate_dimensions(video_size: Tuple[int, int], video_length: int, config) -> Tuple[Tuple[int, int, int, int], int]: + """calculate dimensions for the generation + + Args: + video_size: video frame size (height, width) + video_length: number of frames in the video + config: model configuration + + Returns: + Tuple[Tuple[int, int, int, int], int]: + ((channels, frames, height, width), seq_len) + """ + height, width = video_size + frames = video_length + + # calculate latent space dimensions + lat_f = (frames - 1) // config.vae_stride[0] + 1 + lat_h = height // config.vae_stride[1] + lat_w = width // config.vae_stride[2] + + # calculate sequence length + seq_len = math.ceil((lat_h * lat_w) / (config.patch_size[1] * config.patch_size[2]) * lat_f) + + return ((16, lat_f, lat_h, lat_w), seq_len) + + +def load_vae(args: argparse.Namespace, config, device: torch.device, dtype: torch.dtype) -> WanVAE: + """load VAE model + + Args: + args: command line arguments + config: model configuration + device: device to use + dtype: data type for the model + + Returns: + WanVAE: loaded VAE model + """ + vae_path = args.vae if args.vae is not None else os.path.join(args.ckpt_dir, config.vae_checkpoint) + + logger.info(f"Loading VAE model from {vae_path}") + cache_device = torch.device("cpu") if args.vae_cache_cpu else None + vae = WanVAE(vae_path=vae_path, device=device, dtype=dtype, cache_device=cache_device) + return vae + + +def load_text_encoder(args: argparse.Namespace, config, device: torch.device) -> T5EncoderModel: + """load text encoder (T5) model + + Args: + args: command line arguments + config: model configuration + device: device to use + + Returns: + T5EncoderModel: loaded text encoder model + """ + checkpoint_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.t5_checkpoint) + tokenizer_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.t5_tokenizer) + + text_encoder = T5EncoderModel( + text_len=config.text_len, + dtype=config.t5_dtype, + device=device, + checkpoint_path=checkpoint_path, + tokenizer_path=tokenizer_path, + weight_path=args.t5, + fp8=args.fp8_t5, + ) + + return text_encoder + + +def load_clip_model(args: argparse.Namespace, config, device: torch.device) -> CLIPModel: + """load CLIP model (for I2V only) + + Args: + args: command line arguments + config: model configuration + device: device to use + + Returns: + CLIPModel: loaded CLIP model + """ + checkpoint_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.clip_checkpoint) + tokenizer_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.clip_tokenizer) + + clip = CLIPModel( + dtype=config.clip_dtype, + device=device, + checkpoint_path=checkpoint_path, + tokenizer_path=tokenizer_path, + weight_path=args.clip, + ) + + return clip + + +def load_dit_model( + args: argparse.Namespace, + config, + device: torch.device, + dit_dtype: torch.dtype, + dit_weight_dtype: Optional[torch.dtype] = None, + is_i2v: bool = False, +) -> WanModel: + """load DiT model + + Args: + args: command line arguments + config: model configuration + device: device to use + dit_dtype: data type for the model + dit_weight_dtype: data type for the model weights. 
None for as-is + is_i2v: I2V mode + + Returns: + WanModel: loaded DiT model + """ + loading_device = "cpu" + if args.blocks_to_swap == 0 and args.lora_weight is None and not args.fp8_scaled: + loading_device = device + + loading_weight_dtype = dit_weight_dtype + if args.fp8_scaled or args.lora_weight is not None: + loading_weight_dtype = dit_dtype # load as-is + + # do not fp8 optimize because we will merge LoRA weights + model = load_wan_model(config, device, args.dit, args.attn_mode, False, loading_device, loading_weight_dtype, False) + + return model + + +def merge_lora_weights(lora_module: ModuleType, model: torch.nn.Module, args: argparse.Namespace, device: torch.device) -> None: + """merge LoRA weights to the model + + Args: + model: DiT model + args: command line arguments + device: device to use + """ + if args.lora_weight is None or len(args.lora_weight) == 0: + return + + for i, lora_weight in enumerate(args.lora_weight): + if args.lora_multiplier is not None and len(args.lora_multiplier) > i: + lora_multiplier = args.lora_multiplier[i] + else: + lora_multiplier = 1.0 + + logger.info(f"Loading LoRA weights from {lora_weight} with multiplier {lora_multiplier}") + weights_sd = load_file(lora_weight) + + # apply include/exclude patterns + original_key_count = len(weights_sd.keys()) + if args.include_patterns is not None and len(args.include_patterns) > i: + include_pattern = args.include_patterns[i] + regex_include = re.compile(include_pattern) + weights_sd = {k: v for k, v in weights_sd.items() if regex_include.search(k)} + logger.info(f"Filtered keys with include pattern {include_pattern}: {original_key_count} -> {len(weights_sd.keys())}") + if args.exclude_patterns is not None and len(args.exclude_patterns) > i: + original_key_count_ex = len(weights_sd.keys()) + exclude_pattern = args.exclude_patterns[i] + regex_exclude = re.compile(exclude_pattern) + weights_sd = {k: v for k, v in weights_sd.items() if not regex_exclude.search(k)} + logger.info( + f"Filtered keys with exclude pattern {exclude_pattern}: {original_key_count_ex} -> {len(weights_sd.keys())}" + ) + if len(weights_sd) != original_key_count: + remaining_keys = list(set([k.split(".", 1)[0] for k in weights_sd.keys()])) + remaining_keys.sort() + logger.info(f"Remaining LoRA modules after filtering: {remaining_keys}") + if len(weights_sd) == 0: + logger.warning(f"No keys left after filtering.") + + if args.lycoris: + lycoris_net, _ = create_network_from_weights( + multiplier=lora_multiplier, + file=None, + weights_sd=weights_sd, + unet=model, + text_encoder=None, + vae=None, + for_inference=True, + ) + lycoris_net.merge_to(None, model, weights_sd, dtype=None, device=device) + else: + network = lora_module.create_arch_network_from_weights(lora_multiplier, weights_sd, unet=model, for_inference=True) + network.merge_to(None, model, weights_sd, device=device, non_blocking=True) + + synchronize_device(device) + logger.info("LoRA weights loaded") + + # save model here before casting to dit_weight_dtype + if args.save_merged_model: + logger.info(f"Saving merged model to {args.save_merged_model}") + mem_eff_save_file(model.state_dict(), args.save_merged_model) # save_file needs a lot of memory + logger.info("Merged model saved") + + +def optimize_model( + model: WanModel, args: argparse.Namespace, device: torch.device, dit_dtype: torch.dtype, dit_weight_dtype: torch.dtype +) -> None: + """optimize the model (FP8 conversion, device move etc.) 
+ + Args: + model: dit model + args: command line arguments + device: device to use + dit_dtype: dtype for the model + dit_weight_dtype: dtype for the model weights + """ + if args.fp8_scaled: + # load state dict as-is and optimize to fp8 + state_dict = model.state_dict() + + # if no blocks to swap, we can move the weights to GPU after optimization on GPU (omit redundant CPU->GPU copy) + move_to_device = args.blocks_to_swap == 0 # if blocks_to_swap > 0, we will keep the model on CPU + state_dict = model.fp8_optimization(state_dict, device, move_to_device, use_scaled_mm=args.fp8_fast) + + info = model.load_state_dict(state_dict, strict=True, assign=True) + logger.info(f"Loaded FP8 optimized weights: {info}") + + if args.blocks_to_swap == 0: + model.to(device) # make sure all parameters are on the right device (e.g. RoPE etc.) + else: + # simple cast to dit_dtype + target_dtype = None # load as-is (dit_weight_dtype == dtype of the weights in state_dict) + target_device = None + + if dit_weight_dtype is not None: # in case of args.fp8 and not args.fp8_scaled + logger.info(f"Convert model to {dit_weight_dtype}") + target_dtype = dit_weight_dtype + + if args.blocks_to_swap == 0: + logger.info(f"Move model to device: {device}") + target_device = device + + model.to(target_device, target_dtype) # move and cast at the same time. this reduces redundant copy operations + + if args.compile: + compile_backend, compile_mode, compile_dynamic, compile_fullgraph = args.compile_args + logger.info( + f"Torch Compiling[Backend: {compile_backend}; Mode: {compile_mode}; Dynamic: {compile_dynamic}; Fullgraph: {compile_fullgraph}]" + ) + torch._dynamo.config.cache_size_limit = 32 + for i in range(len(model.blocks)): + model.blocks[i] = torch.compile( + model.blocks[i], + backend=compile_backend, + mode=compile_mode, + dynamic=compile_dynamic.lower() in "true", + fullgraph=compile_fullgraph.lower() in "true", + ) + + if args.blocks_to_swap > 0: + logger.info(f"Enable swap {args.blocks_to_swap} blocks to CPU from device: {device}") + model.enable_block_swap(args.blocks_to_swap, device, supports_backward=False) + model.move_to_device_except_swap_blocks(device) + model.prepare_block_swap_before_forward() + else: + # make sure the model is on the right device + model.to(device) + + model.eval().requires_grad_(False) + clean_memory_on_device(device) + + +def prepare_t2v_inputs( + args: argparse.Namespace, + config, + accelerator: Accelerator, + device: torch.device, + vae: Optional[WanVAE] = None, + encoded_context: Optional[Dict] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]: + """Prepare inputs for T2V + + Args: + args: command line arguments + config: model configuration + accelerator: Accelerator instance + device: device to use + vae: VAE model for control video encoding + encoded_context: Pre-encoded text context + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]: + (noise, context, context_null, (arg_c, arg_null)) + """ + # Prepare inputs for T2V + # calculate dimensions and sequence length + height, width = args.video_size + frames = args.video_length + (_, lat_f, lat_h, lat_w), seq_len = calculate_dimensions(args.video_size, args.video_length, config) + target_shape = (16, lat_f, lat_h, lat_w) + + # configure negative prompt + n_prompt = args.negative_prompt if args.negative_prompt else config.sample_neg_prompt + + # set seed + seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1) + if not args.cpu_noise: + seed_g = 
torch.Generator(device=device) + seed_g.manual_seed(seed) + else: + # ComfyUI compatible noise + seed_g = torch.manual_seed(seed) + + if encoded_context is None: + # load text encoder + text_encoder = load_text_encoder(args, config, device) + text_encoder.model.to(device) + + # encode prompt + with torch.no_grad(): + if args.fp8_t5: + with torch.amp.autocast(device_type=device.type, dtype=config.t5_dtype): + context = text_encoder([args.prompt], device) + context_null = text_encoder([n_prompt], device) + else: + context = text_encoder([args.prompt], device) + context_null = text_encoder([n_prompt], device) + + # free text encoder and clean memory + del text_encoder + clean_memory_on_device(device) + else: + # Use pre-encoded context + context = encoded_context["context"] + context_null = encoded_context["context_null"] + + # Fun-Control: encode control video to latent space + if config.is_fun_control: + # TODO use same resizing as for image + logger.info(f"Encoding control video to latent space") + # C, F, H, W + control_video = load_control_video(args.control_path, frames, height, width).to(device) + vae.to_device(device) + with torch.autocast(device_type=device.type, dtype=vae.dtype), torch.no_grad(): + control_latent = vae.encode([control_video])[0] + y = torch.concat([control_latent, torch.zeros_like(control_latent)], dim=0) # add control video latent + vae.to_device("cpu") + else: + y = None + + # generate noise + noise = torch.randn(target_shape, dtype=torch.float32, generator=seed_g, device=device if not args.cpu_noise else "cpu") + noise = noise.to(device) + + # prepare model input arguments + arg_c = {"context": context, "seq_len": seq_len} + arg_null = {"context": context_null, "seq_len": seq_len} + if y is not None: + arg_c["y"] = [y] + arg_null["y"] = [y] + + return noise, context, context_null, (arg_c, arg_null) + + +def prepare_i2v_inputs( + args: argparse.Namespace, + config, + accelerator: Accelerator, + device: torch.device, + vae: WanVAE, + encoded_context: Optional[Dict] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]: + """Prepare inputs for I2V + + Args: + args: command line arguments + config: model configuration + accelerator: Accelerator instance + device: device to use + vae: VAE model, used for image encoding + encoded_context: Pre-encoded text context + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]: + (noise, context, context_null, y, (arg_c, arg_null)) + """ + # get video dimensions + height, width = args.video_size + frames = args.video_length + max_area = width * height + + # load image + img = Image.open(args.image_path).convert("RGB") + + # convert to numpy + img_cv2 = np.array(img) # PIL to numpy + + # convert to tensor (-1 to 1) + img_tensor = TF.to_tensor(img).sub_(0.5).div_(0.5).to(device) + + # end frame image + if args.end_image_path is not None: + end_img = Image.open(args.end_image_path).convert("RGB") + end_img_cv2 = np.array(end_img) # PIL to numpy + else: + end_img = None + end_img_cv2 = None + has_end_image = end_img is not None + + # calculate latent dimensions: keep aspect ratio + height, width = img_tensor.shape[1:] + aspect_ratio = height / width + lat_h = round(np.sqrt(max_area * aspect_ratio) // config.vae_stride[1] // config.patch_size[1] * config.patch_size[1]) + lat_w = round(np.sqrt(max_area / aspect_ratio) // config.vae_stride[2] // config.patch_size[2] * config.patch_size[2]) + height = lat_h * config.vae_stride[1] + width = lat_w * 
config.vae_stride[2] + lat_f = (frames - 1) // config.vae_stride[0] + 1 # size of latent frames + max_seq_len = (lat_f + (1 if has_end_image else 0)) * lat_h * lat_w // (config.patch_size[1] * config.patch_size[2]) + + # set seed + seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1) + if not args.cpu_noise: + seed_g = torch.Generator(device=device) + seed_g.manual_seed(seed) + else: + # ComfyUI compatible noise + seed_g = torch.manual_seed(seed) + + # generate noise + noise = torch.randn( + 16, + lat_f + (1 if has_end_image else 0), + lat_h, + lat_w, + dtype=torch.float32, + generator=seed_g, + device=device if not args.cpu_noise else "cpu", + ) + noise = noise.to(device) + + # configure negative prompt + n_prompt = args.negative_prompt if args.negative_prompt else config.sample_neg_prompt + + if encoded_context is None: + # load text encoder + text_encoder = load_text_encoder(args, config, device) + text_encoder.model.to(device) + + # encode prompt + with torch.no_grad(): + if args.fp8_t5: + with torch.amp.autocast(device_type=device.type, dtype=config.t5_dtype): + context = text_encoder([args.prompt], device) + context_null = text_encoder([n_prompt], device) + else: + context = text_encoder([args.prompt], device) + context_null = text_encoder([n_prompt], device) + + # free text encoder and clean memory + del text_encoder + clean_memory_on_device(device) + + # load CLIP model + clip = load_clip_model(args, config, device) + clip.model.to(device) + + # encode image to CLIP context + logger.info(f"Encoding image to CLIP context") + with torch.amp.autocast(device_type=device.type, dtype=torch.float16), torch.no_grad(): + clip_context = clip.visual([img_tensor[:, None, :, :]]) + logger.info(f"Encoding complete") + + # free CLIP model and clean memory + del clip + clean_memory_on_device(device) + else: + # Use pre-encoded context + context = encoded_context["context"] + context_null = encoded_context["context_null"] + clip_context = encoded_context["clip_context"] + + # encode image to latent space with VAE + logger.info(f"Encoding image to latent space") + vae.to_device(device) + + # resize image + interpolation = cv2.INTER_AREA if height < img_cv2.shape[0] else cv2.INTER_CUBIC + img_resized = cv2.resize(img_cv2, (width, height), interpolation=interpolation) + img_resized = TF.to_tensor(img_resized).sub_(0.5).div_(0.5).to(device) # -1 to 1, CHW + img_resized = img_resized.unsqueeze(1) # CFHW + + if has_end_image: + interpolation = cv2.INTER_AREA if height < end_img_cv2.shape[1] else cv2.INTER_CUBIC + end_img_resized = cv2.resize(end_img_cv2, (width, height), interpolation=interpolation) + end_img_resized = TF.to_tensor(end_img_resized).sub_(0.5).div_(0.5).to(device) # -1 to 1, CHW + end_img_resized = end_img_resized.unsqueeze(1) # CFHW + + # create mask for the first frame + msk = torch.zeros(4, lat_f + (1 if has_end_image else 0), lat_h, lat_w, device=device) + msk[:, 0] = 1 + if has_end_image: + msk[:, -1] = 1 + + # encode image to latent space + with accelerator.autocast(), torch.no_grad(): + # padding to match the required number of frames + padding_frames = frames - 1 # the first frame is image + img_resized = torch.concat([img_resized, torch.zeros(3, padding_frames, height, width, device=device)], dim=1) + y = vae.encode([img_resized])[0] + + if has_end_image: + y_end = vae.encode([end_img_resized])[0] + y = torch.concat([y, y_end], dim=1) # add end frame + + y = torch.concat([msk, y]) + logger.info(f"Encoding complete") + + # Fun-Control: encode control video to 
latent space + if config.is_fun_control: + # TODO use same resizing as for image + logger.info(f"Encoding control video to latent space") + # C, F, H, W + control_video = load_control_video(args.control_path, frames + (1 if has_end_image else 0), height, width).to(device) + with accelerator.autocast(), torch.no_grad(): + control_latent = vae.encode([control_video])[0] + y = y[msk.shape[0] :] # remove mask because Fun-Control does not need it + if has_end_image: + y[:, 1:-1] = 0 # remove image latent except first and last frame. according to WanVideoWrapper, this doesn't work + else: + y[:, 1:] = 0 # remove image latent except first frame + y = torch.concat([control_latent, y], dim=0) # add control video latent + + # prepare model input arguments + arg_c = { + "context": [context[0]], + "clip_fea": clip_context, + "seq_len": max_seq_len, + "y": [y], + } + + arg_null = { + "context": context_null, + "clip_fea": clip_context, + "seq_len": max_seq_len, + "y": [y], + } + + vae.to_device("cpu") # move VAE to CPU to save memory + clean_memory_on_device(device) + + return noise, context, context_null, y, (arg_c, arg_null) + + +def load_control_video(control_path: str, frames: int, height: int, width: int) -> torch.Tensor: + """load control video to latent space + + Args: + control_path: path to control video + frames: number of frames in the video + height: height of the video + width: width of the video + + Returns: + torch.Tensor: control video latent, CFHW + """ + logger.info(f"Load control video from {control_path}") + video = load_video(control_path, 0, frames, bucket_reso=(width, height)) # list of frames + if len(video) < frames: + raise ValueError(f"Video length is less than {frames}") + # video = np.stack(video, axis=0) # F, H, W, C + video = torch.stack([TF.to_tensor(frame).sub_(0.5).div_(0.5) for frame in video], dim=0) # F, C, H, W, -1 to 1 + video = video.permute(1, 0, 2, 3) # C, F, H, W + return video + + +def setup_scheduler(args: argparse.Namespace, config, device: torch.device) -> Tuple[Any, torch.Tensor]: + """setup scheduler for sampling + + Args: + args: command line arguments + config: model configuration + device: device to use + + Returns: + Tuple[Any, torch.Tensor]: (scheduler, timesteps) + """ + if args.sample_solver == "unipc": + scheduler = FlowUniPCMultistepScheduler(num_train_timesteps=config.num_train_timesteps, shift=1, use_dynamic_shifting=False) + scheduler.set_timesteps(args.infer_steps, device=device, shift=args.flow_shift) + timesteps = scheduler.timesteps + elif args.sample_solver == "dpm++": + scheduler = FlowDPMSolverMultistepScheduler( + num_train_timesteps=config.num_train_timesteps, shift=1, use_dynamic_shifting=False + ) + sampling_sigmas = get_sampling_sigmas(args.infer_steps, args.flow_shift) + timesteps, _ = retrieve_timesteps(scheduler, device=device, sigmas=sampling_sigmas) + elif args.sample_solver == "vanilla": + scheduler = FlowMatchDiscreteScheduler(num_train_timesteps=config.num_train_timesteps, shift=args.flow_shift) + scheduler.set_timesteps(args.infer_steps, device=device) + timesteps = scheduler.timesteps + + # FlowMatchDiscreteScheduler does not support generator argument in step method + org_step = scheduler.step + + def step_wrapper( + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor, + return_dict: bool = True, + generator=None, + ): + return org_step(model_output, timestep, sample, return_dict=return_dict) + + scheduler.step = step_wrapper + else: + raise NotImplementedError("Unsupported solver.") 
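+    # Note: run_sampling() below always calls
+    #   scheduler.step(noise_pred, t, latent, return_dict=False, generator=seed_g)
+    # regardless of the selected solver. The "vanilla" branch therefore wraps
+    # FlowMatchDiscreteScheduler.step, whose signature has no generator parameter,
+    # so the extra argument is silently dropped; the unipc and dpm++ schedulers are
+    # expected to accept the generator argument directly.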
+ + return scheduler, timesteps + + +def run_sampling( + model: WanModel, + noise: torch.Tensor, + scheduler: Any, + timesteps: torch.Tensor, + args: argparse.Namespace, + inputs: Tuple[dict, dict], + device: torch.device, + seed_g: torch.Generator, + accelerator: Accelerator, + is_i2v: bool = False, + use_cpu_offload: bool = True, +) -> torch.Tensor: + """run sampling + Args: + model: dit model + noise: initial noise + scheduler: scheduler for sampling + timesteps: time steps for sampling + args: command line arguments + inputs: model input (arg_c, arg_null) + device: device to use + seed_g: random generator + accelerator: Accelerator instance + is_i2v: I2V mode (False means T2V mode) + use_cpu_offload: Whether to offload tensors to CPU during processing + Returns: + torch.Tensor: generated latent + """ + arg_c, arg_null = inputs + + latent = noise + latent_storage_device = device if not use_cpu_offload else "cpu" + latent = latent.to(latent_storage_device) + + # cfg skip + apply_cfg_array = [] + num_timesteps = len(timesteps) + + if args.cfg_skip_mode != "none" and args.cfg_apply_ratio is not None: + # Calculate thresholds based on cfg_apply_ratio + apply_steps = int(num_timesteps * args.cfg_apply_ratio) + + if args.cfg_skip_mode == "early": + # Skip CFG in early steps, apply in late steps + start_index = num_timesteps - apply_steps + end_index = num_timesteps + elif args.cfg_skip_mode == "late": + # Skip CFG in late steps, apply in early steps + start_index = 0 + end_index = apply_steps + elif args.cfg_skip_mode == "early_late": + # Skip CFG in early and late steps, apply in middle steps + start_index = (num_timesteps - apply_steps) // 2 + end_index = start_index + apply_steps + elif args.cfg_skip_mode == "middle": + # Skip CFG in middle steps, apply in early and late steps + skip_steps = num_timesteps - apply_steps + middle_start = (num_timesteps - skip_steps) // 2 + middle_end = middle_start + skip_steps + + w = 0.0 + for step_idx in range(num_timesteps): + if args.cfg_skip_mode == "alternate": + # accumulate w and apply CFG when w >= 1.0 + w += args.cfg_apply_ratio + apply = w >= 1.0 + if apply: + w -= 1.0 + elif args.cfg_skip_mode == "middle": + # Skip CFG in early and late steps, apply in middle steps + apply = step_idx < middle_start or step_idx >= middle_end + else: + # Apply CFG on some steps based on ratio + apply = step_idx >= start_index and step_idx < end_index + + apply_cfg_array.append(apply) + + pattern = ["A" if apply else "S" for apply in apply_cfg_array] + pattern = "".join(pattern) + logger.info(f"CFG skip mode: {args.cfg_skip_mode}, apply ratio: {args.cfg_apply_ratio}, pattern: {pattern}") + else: + # Apply CFG on all steps + apply_cfg_array = [True] * num_timesteps + + # SLG original implementation is based on https://github.com/Stability-AI/sd3.5/blob/main/sd3_impls.py + slg_start_step = int(args.slg_start * num_timesteps) + slg_end_step = int(args.slg_end * num_timesteps) + + for i, t in enumerate(tqdm(timesteps)): + # latent is on CPU if use_cpu_offload is True + latent_model_input = [latent.to(device)] + timestep = torch.stack([t]).to(device) + + with accelerator.autocast(), torch.no_grad(): + noise_pred_cond = model(latent_model_input, t=timestep, **arg_c)[0].to(latent_storage_device) + + apply_cfg = apply_cfg_array[i] # apply CFG or not + if apply_cfg: + apply_slg = i >= slg_start_step and i < slg_end_step + # print(f"Applying SLG: {apply_slg}, i: {i}, slg_start_step: {slg_start_step}, slg_end_step: {slg_end_step}") + if args.slg_mode == "original" and 
apply_slg: + noise_pred_uncond = model(latent_model_input, t=timestep, **arg_null)[0].to(latent_storage_device) + + # apply guidance + # SD3 formula: scaled = neg_out + (pos_out - neg_out) * cond_scale + noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond) + + # calculate skip layer out + skip_layer_out = model(latent_model_input, t=timestep, skip_block_indices=args.slg_layers, **arg_null)[0].to( + latent_storage_device + ) + + # apply skip layer guidance + # SD3 formula: scaled = scaled + (pos_out - skip_layer_out) * self.slg + noise_pred = noise_pred + args.slg_scale * (noise_pred_cond - skip_layer_out) + elif args.slg_mode == "uncond" and apply_slg: + # noise_pred_uncond is skip layer out + noise_pred_uncond = model(latent_model_input, t=timestep, skip_block_indices=args.slg_layers, **arg_null)[0].to( + latent_storage_device + ) + + # apply guidance + noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond) + + else: + # normal guidance + noise_pred_uncond = model(latent_model_input, t=timestep, **arg_null)[0].to(latent_storage_device) + + # apply guidance + noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond) + else: + noise_pred = noise_pred_cond + + # step + latent_input = latent.unsqueeze(0) + temp_x0 = scheduler.step(noise_pred.unsqueeze(0), t, latent_input, return_dict=False, generator=seed_g)[0] + + # update latent + latent = temp_x0.squeeze(0) + + return latent + + +def generate(args: argparse.Namespace, gen_settings: GenerationSettings, shared_models: Optional[Dict] = None) -> torch.Tensor: + """main function for generation + + Args: + args: command line arguments + shared_models: dictionary containing pre-loaded models and encoded data + + Returns: + torch.Tensor: generated latent + """ + device, cfg, dit_dtype, dit_weight_dtype, vae_dtype = ( + gen_settings.device, + gen_settings.cfg, + gen_settings.dit_dtype, + gen_settings.dit_weight_dtype, + gen_settings.vae_dtype, + ) + + # prepare accelerator + mixed_precision = "bf16" if dit_dtype == torch.bfloat16 else "fp16" + accelerator = accelerate.Accelerator(mixed_precision=mixed_precision) + + # I2V or T2V + is_i2v = "i2v" in args.task + + # prepare seed + seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1) + args.seed = seed # set seed to args for saving + + # Check if we have shared models + if shared_models is not None: + # Use shared models and encoded data + vae = shared_models.get("vae") + model = shared_models.get("model") + encoded_context = shared_models.get("encoded_contexts", {}).get(args.prompt) + + # prepare inputs + if is_i2v: + # I2V + noise, context, context_null, y, inputs = prepare_i2v_inputs(args, cfg, accelerator, device, vae, encoded_context) + else: + # T2V + noise, context, context_null, inputs = prepare_t2v_inputs(args, cfg, accelerator, device, vae, encoded_context) + else: + # prepare inputs without shared models + if is_i2v: + # I2V: need text encoder, VAE and CLIP + vae = load_vae(args, cfg, device, vae_dtype) + noise, context, context_null, y, inputs = prepare_i2v_inputs(args, cfg, accelerator, device, vae) + # vae is on CPU after prepare_i2v_inputs + else: + # T2V: need text encoder + vae = None + if cfg.is_fun_control: + # Fun-Control: need VAE for encoding control video + vae = load_vae(args, cfg, device, vae_dtype) + noise, context, context_null, inputs = prepare_t2v_inputs(args, cfg, accelerator, device, vae) + + # load DiT model + model = 
load_dit_model(args, cfg, device, dit_dtype, dit_weight_dtype, is_i2v) + + # merge LoRA weights + if args.lora_weight is not None and len(args.lora_weight) > 0: + merge_lora_weights(lora_wan, model, args, device) + + # if we only want to save the model, we can skip the rest + if args.save_merged_model: + return None + + # optimize model: fp8 conversion, block swap etc. + optimize_model(model, args, device, dit_dtype, dit_weight_dtype) + + # setup scheduler + scheduler, timesteps = setup_scheduler(args, cfg, device) + + # set random generator + seed_g = torch.Generator(device=device) + seed_g.manual_seed(seed) + + # run sampling + latent = run_sampling(model, noise, scheduler, timesteps, args, inputs, device, seed_g, accelerator, is_i2v) + + # Only clean up shared models if they were created within this function + if shared_models is None: + # free memory + del model + del scheduler + synchronize_device(device) + + # wait for 5 seconds until block swap is done + logger.info("Waiting for 5 seconds to finish block swap") + time.sleep(5) + + gc.collect() + clean_memory_on_device(device) + + # save VAE model for decoding + if vae is None: + args._vae = None + else: + args._vae = vae + + return latent + + +def decode_latent(latent: torch.Tensor, args: argparse.Namespace, cfg) -> torch.Tensor: + """decode latent + + Args: + latent: latent tensor + args: command line arguments + cfg: model configuration + + Returns: + torch.Tensor: decoded video or image + """ + device = torch.device(args.device) + + # load VAE model or use the one from the generation + vae_dtype = str_to_dtype(args.vae_dtype) if args.vae_dtype is not None else torch.bfloat16 + if hasattr(args, "_vae") and args._vae is not None: + vae = args._vae + else: + vae = load_vae(args, cfg, device, vae_dtype) + + vae.to_device(device) + + logger.info(f"Decoding video from latents: {latent.shape}") + x0 = latent.to(device) + + with torch.autocast(device_type=device.type, dtype=vae_dtype), torch.no_grad(): + videos = vae.decode(x0) + + # some tail frames may be corrupted when end frame is used, we add an option to remove them + if args.trim_tail_frames: + videos[0] = videos[0][:, : -args.trim_tail_frames] + + logger.info(f"Decoding complete") + video = videos[0] + del videos + video = video.to(torch.float32).cpu() + + return video + + +def save_latent(latent: torch.Tensor, args: argparse.Namespace, height: int, width: int) -> str: + """Save latent to file + + Args: + latent: latent tensor + args: command line arguments + height: height of frame + width: width of frame + + Returns: + str: Path to saved latent file + """ + save_path = args.save_path + os.makedirs(save_path, exist_ok=True) + time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S") + + seed = args.seed + video_length = args.video_length + latent_path = f"{save_path}/{time_flag}_{seed}_latent.safetensors" + + if args.no_metadata: + metadata = None + else: + metadata = { + "seeds": f"{seed}", + "prompt": f"{args.prompt}", + "height": f"{height}", + "width": f"{width}", + "video_length": f"{video_length}", + "infer_steps": f"{args.infer_steps}", + "guidance_scale": f"{args.guidance_scale}", + } + if args.negative_prompt is not None: + metadata["negative_prompt"] = f"{args.negative_prompt}" + + sd = {"latent": latent} + save_file(sd, latent_path, metadata=metadata) + logger.info(f"Latent saved to: {latent_path}") + + return latent_path + + +def save_video(video: torch.Tensor, args: argparse.Namespace, original_base_name: Optional[str] = None) -> str: + """Save video 
to file + + Args: + video: Video tensor + args: command line arguments + original_base_name: Original base name (if latents are loaded from files) + + Returns: + str: Path to saved video file + """ + save_path = args.save_path + os.makedirs(save_path, exist_ok=True) + time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S") + + seed = args.seed + original_name = "" if original_base_name is None else f"_{original_base_name}" + video_path = f"{save_path}/{time_flag}_{seed}{original_name}.mp4" + + video = video.unsqueeze(0) + save_videos_grid(video, video_path, fps=args.fps, rescale=True) + logger.info(f"Video saved to: {video_path}") + + return video_path + + +def save_images(sample: torch.Tensor, args: argparse.Namespace, original_base_name: Optional[str] = None) -> str: + """Save images to directory + + Args: + sample: Video tensor + args: command line arguments + original_base_name: Original base name (if latents are loaded from files) + + Returns: + str: Path to saved images directory + """ + save_path = args.save_path + os.makedirs(save_path, exist_ok=True) + time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S") + + seed = args.seed + original_name = "" if original_base_name is None else f"_{original_base_name}" + image_name = f"{time_flag}_{seed}{original_name}" + sample = sample.unsqueeze(0) + save_images_grid(sample, save_path, image_name, rescale=True) + logger.info(f"Sample images saved to: {save_path}/{image_name}") + + return f"{save_path}/{image_name}" + + +def save_output( + latent: torch.Tensor, args: argparse.Namespace, cfg, height: int, width: int, original_base_names: Optional[List[str]] = None +) -> None: + """save output + + Args: + latent: latent tensor + args: command line arguments + cfg: model configuration + height: height of frame + width: width of frame + original_base_names: original base names (if latents are loaded from files) + """ + if args.output_type == "latent" or args.output_type == "both": + # save latent + save_latent(latent, args, height, width) + + if args.output_type == "video" or args.output_type == "both": + # save video + sample = decode_latent(latent.unsqueeze(0), args, cfg) + original_name = "" if original_base_names is None else f"_{original_base_names[0]}" + save_video(sample, args, original_name) + + elif args.output_type == "images": + # save images + sample = decode_latent(latent.unsqueeze(0), args, cfg) + original_name = "" if original_base_names is None else f"_{original_base_names[0]}" + save_images(sample, args, original_name) + + +def preprocess_prompts_for_batch(prompt_lines: List[str], base_args: argparse.Namespace) -> List[Dict]: + """Process multiple prompts for batch mode + + Args: + prompt_lines: List of prompt lines + base_args: Base command line arguments + + Returns: + List[Dict]: List of prompt data dictionaries + """ + prompts_data = [] + + for line in prompt_lines: + line = line.strip() + if not line or line.startswith("#"): # Skip empty lines and comments + continue + + # Parse prompt line and create override dictionary + prompt_data = parse_prompt_line(line) + logger.info(f"Parsed prompt data: {prompt_data}") + prompts_data.append(prompt_data) + + return prompts_data + + +def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) -> None: + """Process multiple prompts with model reuse + + Args: + prompts_data: List of prompt data dictionaries + args: Base command line arguments + """ + if not prompts_data: + logger.warning("No valid prompts found") + return + + # 1. 
Load configuration + gen_settings = get_generation_settings(args) + device, cfg, dit_dtype, dit_weight_dtype, vae_dtype = ( + gen_settings.device, + gen_settings.cfg, + gen_settings.dit_dtype, + gen_settings.dit_weight_dtype, + gen_settings.vae_dtype, + ) + is_i2v = "i2v" in args.task + + # 2. Encode all prompts + logger.info("Loading text encoder to encode all prompts") + text_encoder = load_text_encoder(args, cfg, device) + text_encoder.model.to(device) + + encoded_contexts = {} + + with torch.no_grad(): + for prompt_data in prompts_data: + prompt = prompt_data["prompt"] + prompt_args = apply_overrides(args, prompt_data) + n_prompt = prompt_data.get( + "negative_prompt", prompt_args.negative_prompt if prompt_args.negative_prompt else cfg.sample_neg_prompt + ) + + if args.fp8_t5: + with torch.amp.autocast(device_type=device.type, dtype=cfg.t5_dtype): + context = text_encoder([prompt], device) + context_null = text_encoder([n_prompt], device) + else: + context = text_encoder([prompt], device) + context_null = text_encoder([n_prompt], device) + + encoded_contexts[prompt] = {"context": context, "context_null": context_null} + + # Free text encoder and clean memory + del text_encoder + clean_memory_on_device(device) + + # 3. Process I2V additional encodings if needed + vae = None + if is_i2v: + logger.info("Loading VAE and CLIP for I2V preprocessing") + vae = load_vae(args, cfg, device, vae_dtype) + vae.to_device(device) + + clip = load_clip_model(args, cfg, device) + clip.model.to(device) + + # Process each image and encode with CLIP + for prompt_data in prompts_data: + if "image_path" not in prompt_data: + continue + + prompt_args = apply_overrides(args, prompt_data) + if not os.path.exists(prompt_args.image_path): + logger.warning(f"Image path not found: {prompt_args.image_path}") + continue + + # Load and encode image with CLIP + img = Image.open(prompt_args.image_path).convert("RGB") + img_tensor = TF.to_tensor(img).sub_(0.5).div_(0.5).to(device) + + with torch.amp.autocast(device_type=device.type, dtype=torch.float16), torch.no_grad(): + clip_context = clip.visual([img_tensor[:, None, :, :]]) + + encoded_contexts[prompt_data["prompt"]]["clip_context"] = clip_context + + # Free CLIP and clean memory + del clip + clean_memory_on_device(device) + + # Keep VAE in CPU memory for later use + vae.to_device("cpu") + elif cfg.is_fun_control: + # For Fun-Control, we need VAE but keep it on CPU + vae = load_vae(args, cfg, device, vae_dtype) + vae.to_device("cpu") + + # 4. Load DiT model + logger.info("Loading DiT model") + model = load_dit_model(args, cfg, device, dit_dtype, dit_weight_dtype, is_i2v) + + # 5. Merge LoRA weights if needed + if args.lora_weight is not None and len(args.lora_weight) > 0: + merge_lora_weights(lora_wan, model, args, device) + if args.save_merged_model: + logger.info("Model merged and saved. Exiting.") + return + + # 6. Optimize model + optimize_model(model, args, device, dit_dtype, dit_weight_dtype) + + # Create shared models dict for generate function + shared_models = {"vae": vae, "model": model, "encoded_contexts": encoded_contexts} + + # 7. 
Generate for each prompt + all_latents = [] + all_prompt_args = [] + + for i, prompt_data in enumerate(prompts_data): + logger.info(f"Processing prompt {i+1}/{len(prompts_data)}: {prompt_data['prompt'][:50]}...") + + # Apply overrides for this prompt + prompt_args = apply_overrides(args, prompt_data) + + # Generate latent + latent = generate(prompt_args, gen_settings, shared_models) + + # Save latent if needed + height, width, _ = check_inputs(prompt_args) + if prompt_args.output_type == "latent" or prompt_args.output_type == "both": + save_latent(latent, prompt_args, height, width) + + all_latents.append(latent) + all_prompt_args.append(prompt_args) + + # 8. Free DiT model + del model + clean_memory_on_device(device) + synchronize_device(device) + + # wait for 5 seconds until block swap is done + logger.info("Waiting for 5 seconds to finish block swap") + time.sleep(5) + + gc.collect() + clean_memory_on_device(device) + + # 9. Decode latents if needed + if args.output_type != "latent": + logger.info("Decoding latents to videos/images") + + if vae is None: + vae = load_vae(args, cfg, device, vae_dtype) + + vae.to_device(device) + + for i, (latent, prompt_args) in enumerate(zip(all_latents, all_prompt_args)): + logger.info(f"Decoding output {i+1}/{len(all_latents)}") + + # Decode latent + video = decode_latent(latent.unsqueeze(0), prompt_args, cfg) + + # Save as video or images + if prompt_args.output_type == "video" or prompt_args.output_type == "both": + save_video(video, prompt_args) + elif prompt_args.output_type == "images": + save_images(video, prompt_args) + + # Free VAE + del vae + + clean_memory_on_device(device) + gc.collect() + + +def process_interactive(args: argparse.Namespace) -> None: + """Process prompts in interactive mode + + Args: + args: Base command line arguments + """ + gen_settings = get_generation_settings(args) + device, cfg, dit_dtype, dit_weight_dtype, vae_dtype = ( + gen_settings.device, + gen_settings.cfg, + gen_settings.dit_dtype, + gen_settings.dit_weight_dtype, + gen_settings.vae_dtype, + ) + is_i2v = "i2v" in args.task + + # Initialize models to None + text_encoder = None + vae = None + model = None + clip = None + + print("Interactive mode. Enter prompts (Ctrl+D to exit):") + + try: + while True: + try: + line = input("> ") + if not line.strip(): + continue + + # Parse prompt + prompt_data = parse_prompt_line(line) + prompt_args = apply_overrides(args, prompt_data) + + # Ensure we have all the models we need + + # 1. Load text encoder if not already loaded + if text_encoder is None: + logger.info("Loading text encoder") + text_encoder = load_text_encoder(args, cfg, device) + + text_encoder.model.to(device) + + # Encode prompt + n_prompt = prompt_data.get( + "negative_prompt", prompt_args.negative_prompt if prompt_args.negative_prompt else cfg.sample_neg_prompt + ) + + with torch.no_grad(): + if args.fp8_t5: + with torch.amp.autocast(device_type=device.type, dtype=cfg.t5_dtype): + context = text_encoder([prompt_data["prompt"]], device) + context_null = text_encoder([n_prompt], device) + else: + context = text_encoder([prompt_data["prompt"]], device) + context_null = text_encoder([n_prompt], device) + + encoded_context = {"context": context, "context_null": context_null} + + # Move text encoder to CPU after use + text_encoder.model.to("cpu") + + # 2. 
For I2V, we need CLIP and VAE + if is_i2v: + if clip is None: + logger.info("Loading CLIP model") + clip = load_clip_model(args, cfg, device) + + clip.model.to(device) + + # Encode image with CLIP if there's an image path + if prompt_args.image_path and os.path.exists(prompt_args.image_path): + img = Image.open(prompt_args.image_path).convert("RGB") + img_tensor = TF.to_tensor(img).sub_(0.5).div_(0.5).to(device) + + with torch.amp.autocast(device_type=device.type, dtype=torch.float16), torch.no_grad(): + clip_context = clip.visual([img_tensor[:, None, :, :]]) + + encoded_context["clip_context"] = clip_context + + # Move CLIP to CPU after use + clip.model.to("cpu") + + # Load VAE if needed + if vae is None: + logger.info("Loading VAE model") + vae = load_vae(args, cfg, device, vae_dtype) + elif cfg.is_fun_control and vae is None: + # For Fun-Control, we need VAE + logger.info("Loading VAE model for Fun-Control") + vae = load_vae(args, cfg, device, vae_dtype) + + # 3. Load DiT model if not already loaded + if model is None: + logger.info("Loading DiT model") + model = load_dit_model(args, cfg, device, dit_dtype, dit_weight_dtype, is_i2v) + + # Merge LoRA weights if needed + if args.lora_weight is not None and len(args.lora_weight) > 0: + merge_lora_weights(lora_wan, model, args, device) + + # Optimize model + optimize_model(model, args, device, dit_dtype, dit_weight_dtype) + else: + # Move model to GPU if it was offloaded + model.to(device) + + # Create shared models dict + shared_models = {"vae": vae, "model": model, "encoded_contexts": {prompt_data["prompt"]: encoded_context}} + + # Generate latent + latent = generate(prompt_args, gen_settings, shared_models) + + # Move model to CPU after generation + model.to("cpu") + + # Save latent if needed + height, width, _ = check_inputs(prompt_args) + if prompt_args.output_type == "latent" or prompt_args.output_type == "both": + save_latent(latent, prompt_args, height, width) + + # Decode and save output + if prompt_args.output_type != "latent": + if vae is None: + vae = load_vae(args, cfg, device, vae_dtype) + + vae.to_device(device) + video = decode_latent(latent.unsqueeze(0), prompt_args, cfg) + + if prompt_args.output_type == "video" or prompt_args.output_type == "both": + save_video(video, prompt_args) + elif prompt_args.output_type == "images": + save_images(video, prompt_args) + + # Move VAE to CPU after use + vae.to_device("cpu") + + clean_memory_on_device(device) + + except KeyboardInterrupt: + print("\nInterrupted. Continue (Ctrl+D or Ctrl+Z (Windows) to exit)") + continue + + except EOFError: + print("\nExiting interactive mode") + + # Clean up all models + if text_encoder is not None: + del text_encoder + if clip is not None: + del clip + if vae is not None: + del vae + if model is not None: + del model + + clean_memory_on_device(device) + gc.collect() + + +def get_generation_settings(args: argparse.Namespace) -> GenerationSettings: + device = torch.device(args.device) + + cfg = WAN_CONFIGS[args.task] + + # select dtype + dit_dtype = detect_wan_sd_dtype(args.dit) if args.dit is not None else torch.bfloat16 + if dit_dtype.itemsize == 1: + # if weight is in fp8, use bfloat16 for DiT (input/output) + dit_dtype = torch.bfloat16 + if args.fp8_scaled: + raise ValueError( + "DiT weights is already in fp8 format, cannot scale to fp8. 
Please use fp16/bf16 weights / DiTの重みはすでにfp8形式です。fp8にスケーリングできません。fp16/bf16の重みを使用してください" + ) + + dit_weight_dtype = dit_dtype # default + if args.fp8_scaled: + dit_weight_dtype = None # various precision weights, so don't cast to specific dtype + elif args.fp8: + dit_weight_dtype = torch.float8_e4m3fn + + vae_dtype = str_to_dtype(args.vae_dtype) if args.vae_dtype is not None else dit_dtype + logger.info( + f"Using device: {device}, DiT precision: {dit_dtype}, weight precision: {dit_weight_dtype}, VAE precision: {vae_dtype}" + ) + + gen_settings = GenerationSettings( + device=device, + cfg=cfg, + dit_dtype=dit_dtype, + dit_weight_dtype=dit_weight_dtype, + vae_dtype=vae_dtype, + ) + return gen_settings + + +def main(): + # Parse arguments + args = parse_args() + + # Check if latents are provided + latents_mode = args.latent_path is not None and len(args.latent_path) > 0 + + # Set device + device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device(device) + logger.info(f"Using device: {device}") + args.device = device + + if latents_mode: + # Original latent decode mode + cfg = WAN_CONFIGS[args.task] # any task is fine + original_base_names = [] + latents_list = [] + seeds = [] + + assert len(args.latent_path) == 1, "Only one latent path is supported for now" + + for latent_path in args.latent_path: + original_base_names.append(os.path.splitext(os.path.basename(latent_path))[0]) + seed = 0 + + if os.path.splitext(latent_path)[1] != ".safetensors": + latents = torch.load(latent_path, map_location="cpu") + else: + latents = load_file(latent_path)["latent"] + with safe_open(latent_path, framework="pt") as f: + metadata = f.metadata() + if metadata is None: + metadata = {} + logger.info(f"Loaded metadata: {metadata}") + + if "seeds" in metadata: + seed = int(metadata["seeds"]) + if "height" in metadata and "width" in metadata: + height = int(metadata["height"]) + width = int(metadata["width"]) + args.video_size = [height, width] + if "video_length" in metadata: + args.video_length = int(metadata["video_length"]) + + seeds.append(seed) + latents_list.append(latents) + + logger.info(f"Loaded latent from {latent_path}. 
Shape: {latents.shape}") + + latent = torch.stack(latents_list, dim=0) # [N, ...], must be same shape + + height = latents.shape[-2] + width = latents.shape[-1] + height *= cfg.patch_size[1] * cfg.vae_stride[1] + width *= cfg.patch_size[2] * cfg.vae_stride[2] + video_length = latents.shape[1] + video_length = (video_length - 1) * cfg.vae_stride[0] + 1 + args.seed = seeds[0] + + # Decode and save + save_output(latent[0], args, cfg, height, width, original_base_names) + + elif args.from_file: + # Batch mode from file + args = setup_args(args) + + # Read prompts from file + with open(args.from_file, "r", encoding="utf-8") as f: + prompt_lines = f.readlines() + + # Process prompts + prompts_data = preprocess_prompts_for_batch(prompt_lines, args) + process_batch_prompts(prompts_data, args) + + elif args.interactive: + # Interactive mode + args = setup_args(args) + process_interactive(args) + + else: + # Single prompt mode (original behavior) + args = setup_args(args) + height, width, video_length = check_inputs(args) + + logger.info( + f"Video size: {height}x{width}@{video_length} (HxW@F), fps: {args.fps}, " + f"infer_steps: {args.infer_steps}, flow_shift: {args.flow_shift}" + ) + + # Generate latent + gen_settings = get_generation_settings(args) + latent = generate(args, gen_settings) + + # Make sure the model is freed from GPU memory + gc.collect() + clean_memory_on_device(args.device) + + # Save latent and video + if args.save_merged_model: + return + + # Add batch dimension + latent = latent.unsqueeze(0) + save_output(latent[0], args, WAN_CONFIGS[args.task], height, width) + + logger.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/blissful_tuner/GIMMVFI.py b/blissful_tuner/GIMMVFI.py new file mode 100644 index 0000000000000000000000000000000000000000..78f709160defbaf43ac49d3cd04b0e12745b9d67 --- /dev/null +++ b/blissful_tuner/GIMMVFI.py @@ -0,0 +1,208 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Frame Rate Interpolation using GIMM-VFI +----------------------------------- +This specific code file as well as all files in ./blissful_tuner/gimmvfi and subfolders (all GIMM-VFI related code) licensed: + +S-Lab License 1.0 +Copyright 2024 S-Lab + +Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met: +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
+IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work. +--------------------------------------- +Created on Mon Apr 14 12:23:15 2025 +@author: blyss +""" + +import os +import warnings +from typing import List +import torch +import yaml +from tqdm import tqdm +from omegaconf import OmegaConf +from rich.traceback import install as install_rich_tracebacks + +# Importing necessary modules from our project +from gimmvfi.generalizable_INR.gimmvfi_r import GIMMVFI_R +from gimmvfi.generalizable_INR.gimmvfi_f import GIMMVFI_F +from gimmvfi.generalizable_INR.configs import GIMMVFIConfig +from gimmvfi.generalizable_INR.raft import RAFT +from gimmvfi.generalizable_INR.flowformer.core.FlowFormer.LatentCostFormer.transformer import FlowFormer +from gimmvfi.generalizable_INR.flowformer.configs.submission import get_cfg +from gimmvfi.utils.utils import InputPadder, RaftArgs, easydict_to_dict +from utils import load_torch_file, setup_compute_context +from video_processing_common import BlissfulVideoProcessor, setup_parser_video_common, set_seed +warnings.filterwarnings("ignore") +install_rich_tracebacks() + + +def load_model(model_path: str, device: torch.device, dtype: torch.dtype, mode: str = "gimmvfi_r") -> torch.nn.Module: + """ + Loads the GIMM-VFI model along with its required flow estimator. + + Depending on the mode ("gimmvfi_r" or "gimmvfi_f") a different configuration, + checkpoint, and flow estimation network are loaded. + """ + + # Select proper configuration, checkpoint, and flow model based on mode. 
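+    # "gimmvfi_r" pairs GIMMVFI_R with a RAFT optical-flow estimator, while
+    # "gimmvfi_f" pairs GIMMVFI_F with FlowFormer; both expect the YAML config,
+    # main checkpoint, and flow-estimator weights to sit inside model_path.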
+ if "gimmvfi_r" in mode: + config_path = os.path.join(model_path, "gimmvfi_r_arb.yaml") + flow_model_filename = "raft-things_fp32.safetensors" + checkpoint = os.path.join(model_path, "gimmvfi_r_arb_lpips_fp32.safetensors") + elif "gimmvfi_f" in mode: + config_path = os.path.join(model_path, "gimmvfi_f_arb.yaml") + checkpoint = os.path.join(model_path, "gimmvfi_f_arb_lpips_fp32.safetensors") + flow_model_filename = "flowformer_sintel_fp32.safetensors" + else: + raise ValueError(f"Unsupported mode: {mode}") + + flow_model_path = os.path.join(model_path, flow_model_filename) + + # Load and merge YAML configuration + with open(config_path) as f: + config = yaml.load(f, Loader=yaml.FullLoader) + config = easydict_to_dict(config) + config = OmegaConf.create(config) + arch_defaults = GIMMVFIConfig.create(config.arch) + config = OmegaConf.merge(arch_defaults, config.arch) + + # Initialize the model and its associated flow estimator + if "gimmvfi_r" in mode: + model = GIMMVFI_R(config) + # Setup RAFT as flow estimator + raft_args = RaftArgs(small=False, mixed_precision=False, alternate_corr=False) + raft_model = RAFT(raft_args) + raft_sd = load_torch_file(flow_model_path) + raft_model.load_state_dict(raft_sd, strict=True) + flow_estimator = raft_model.to(device, dtype) + else: # mode == "gimmvfi_f" + model = GIMMVFI_F(config) + cfg = get_cfg() + flowformer = FlowFormer(cfg.latentcostformer) + flowformer_sd = load_torch_file(flow_model_path) + flowformer.load_state_dict(flowformer_sd, strict=True) + flow_estimator = flowformer.to(device, dtype) + + # Load main model checkpoint + sd = load_torch_file(checkpoint) + model.load_state_dict(sd, strict=False) + + # Attach the flow estimator to the model, set evaluation mode, and move to device + model.flow_estimator = flow_estimator + model = model.eval().to(device, dtype) + + return model + + +def interpolate(model: torch.nn.Module, frames: List[torch.Tensor], ds_factor: float, N: int, VideoProcessor: BlissfulVideoProcessor): + """ + Interpolates frames using the provided model. + + Args: + model: The loaded interpolation model. + frames: List of input frame tensors. + ds_factor: Downsampling factor used by the model. + N: Number of interpolation steps between two frames. + """ + device = VideoProcessor.device + dtype = VideoProcessor.dtype + start = 0 + end = len(frames) - 1 + + # Process each adjacent pair of frames. + for j in tqdm(range(start, end), desc="Interpolating frames"): + I0 = frames[j] + I2 = frames[j + 1] + + # For the very first frame, add it directly. + if j == start: + VideoProcessor.write_np_or_tensor_to_png(I0) + + # Pad both images so that their dimensions are divisible by 32. + padder = InputPadder(I0.shape, 32) + I0_padded, I2_padded = padder.pad(I0, I2) + # Concatenate along a new dimension to create a tensor of shape [batch, 2, C, H, W] + xs = torch.cat((I0_padded.unsqueeze(2), I2_padded.unsqueeze(2)), dim=2).to(device, dtype, non_blocking=True) + + model.zero_grad() + + batch_size = xs.shape[0] + s_shape = xs.shape[-2:] + + with torch.no_grad(): + # Prepare coordinate inputs and timesteps for interpolation. 
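+            # For an interpolation factor N, N-1 intermediate frames are predicted
+            # per input pair, at fractional timesteps t = i/N for i = 1..N-1.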
+ coord_inputs = [ + ( + model.sample_coord_input( + batch_size, + s_shape, + [1 / N * i], + device=xs.device, + upsample_ratio=ds_factor, + ), + None, + ) + for i in range(1, N) + ] + timesteps = [ + i / N * torch.ones(batch_size, device=xs.device, dtype=dtype) + for i in range(1, N) + ] + if dtype != torch.float32: + with torch.autocast(device_type=str(device), dtype=dtype): + all_outputs = model(xs, coord_inputs, t=timesteps, ds_factor=ds_factor) + else: + all_outputs = model(xs, coord_inputs, t=timesteps, ds_factor=ds_factor) + # Unpad the outputs to get back to original image size. + out_frames = [padder.unpad(im) for im in all_outputs["imgt_pred"]] + + # Convert each interpolated frame tensor to an image array. + I1_pred_images = [I1_pred[0] for I1_pred in out_frames] + + # Append the interpolated frames and corresponding flow images. + for i in range(N - 1): + VideoProcessor.write_np_or_tensor_to_png(I1_pred_images[i]) + + # Append the next original frame. + VideoProcessor.write_np_or_tensor_to_png(I2) + + +def main(): + parser = setup_parser_video_common(description="Frame rate interpolation using GIMM-VFI") + parser.add_argument("--ds_factor", type=float, default=1.0, help="Downsampling factor") + parser.add_argument("--mode", type=str, default="gimmvfi_f", help="Model mode: 'gimmvfi_r' or 'gimmvfi_f' for RAFT or FlowFormer version respectively") + parser.add_argument( + "--factor", type=int, default=2, help="Factor to increase the number of frames by. \ + A factor of 2 will double the fps, taking e.g. a 16fps video to 32fps. Can go up to 8 but higher values have more artifacts" + ) + args = parser.parse_args() + device, dtype = setup_compute_context(None, args.dtype) + VideoProcessor = BlissfulVideoProcessor(device, dtype) + VideoProcessor.prepare_files_and_path(args.input, args.output, "VFI", args.codec, args.container) + model = load_model(args.model, device, dtype, args.mode) + frames, fps, _, _ = VideoProcessor.load_frames(make_rgb=True) + frames = VideoProcessor.np_image_to_tensor(frames) + new_fps = fps * args.factor # Adjust the frame rate according to the interpolation + + # Set seed for reproducibility. + set_seed(args.seed) + + # Perform the frame interpolation. + interpolate(model, frames, args.ds_factor, args.factor, VideoProcessor) + + # Save the interpolated video. 
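+    # Encode the frames buffered as PNGs during interpolation into the output
+    # container at the adjusted frame rate (original fps * factor).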
+ VideoProcessor.write_buffered_frames_to_output(new_fps, args.keep_pngs) + + +if __name__ == "__main__": + main() diff --git a/blissful_tuner/__init__.py b/blissful_tuner/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/blissful_tuner/advanced_rope.py b/blissful_tuner/advanced_rope.py new file mode 100644 index 0000000000000000000000000000000000000000..3613cf351e971afb621576ca56cc9c9c93a501e0 --- /dev/null +++ b/blissful_tuner/advanced_rope.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Wed Apr 16 19:25:53 2025 +Advanced rope functions for Blissful Tuner extension +License: Apache 2.0 + +@author: blyss +""" +import torch +import torch.nn as nn +from einops import rearrange +from typing import List +from blissful_tuner.hvw_posemb_layers import get_nd_rotary_pos_embed + + +# From ComfyUI +def apply_rope_comfy(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) + + +# From WanVideoWrapper +def rope_riflex(pos, dim, theta, L_test, k, temporal): + assert dim % 2 == 0 + device = pos.device + scale = torch.linspace(0, (dim - 2) / dim, steps=dim // 2, dtype=torch.float64, device=device) + omega = 1.0 / (theta**scale) + # RIFLEX modification - adjust last frequency component if L_test and k are provided + if temporal and k > 0 and L_test: + omega[k - 1] = 0.9 * 2 * torch.pi / L_test + out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega) + out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) + out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) + return out.to(dtype=torch.float32, device=pos.device) + + +class EmbedND_RifleX(nn.Module): + def __init__(self: nn.Module, dim: int, theta: float, axes_dim: List[int], num_frames: int, k: int): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + self.num_frames = num_frames + self.k = k + + def forward(self, ids): + n_axes = ids.shape[-1] + emb = torch.cat( + [rope_riflex(ids[..., i], self.axes_dim[i], self.theta, self.num_frames, self.k, temporal=True if i == 0 else False) for i in range(n_axes)], + dim=-3, + ) + return emb.unsqueeze(1) + + +# Modified from HunyuanVideo Wrapper +def get_rotary_pos_embed_riflex(vae_ver, transformer, latent_video_length, height, width, k=0): + if "884" in vae_ver: + latents_size = [(latent_video_length - 1) // 4 + 1, height // 8, width // 8] + elif "888" in vae_ver: + latents_size = [(latent_video_length - 1) // 8 + 1, height // 8, width // 8] + else: + latents_size = [latent_video_length, height // 8, width // 8] + + target_ndim = 3 + ndim = 5 - 2 + rope_theta = 256 # 225 + patch_size = transformer.patch_size + rope_dim_list = transformer.rope_dim_list + hidden_size = transformer.hidden_size + heads_num = transformer.heads_num + head_dim = hidden_size // heads_num + + if isinstance(patch_size, int): + assert all(s % patch_size == 0 for s in latents_size), ( + f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), " + f"but got {latents_size}." 
+ ) + rope_sizes = [s // patch_size for s in latents_size] + elif isinstance(patch_size, list): + assert all( + s % patch_size[idx] == 0 + for idx, s in enumerate(latents_size) + ), ( + f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), " + f"but got {latents_size}." + ) + rope_sizes = [ + s // patch_size[idx] for idx, s in enumerate(latents_size) + ] + + if len(rope_sizes) != target_ndim: + rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis + + if rope_dim_list is None: + rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)] + assert ( + sum(rope_dim_list) == head_dim + ), "sum(rope_dim_list) should equal to head_dim of attention layer" + freqs_cos, freqs_sin = get_nd_rotary_pos_embed( + rope_dim_list, + rope_sizes, + theta=rope_theta, + use_real=True, + theta_rescale_factor=1, + num_frames=latent_video_length, + k=k, + ) + return freqs_cos, freqs_sin diff --git a/blissful_tuner/blissful_args.py b/blissful_tuner/blissful_args.py new file mode 100644 index 0000000000000000000000000000000000000000..9855572c61762909a575db63f811aae6ed7b73c0 --- /dev/null +++ b/blissful_tuner/blissful_args.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Sat Apr 26 15:11:58 2025 + +@author: blyss +""" +import sys +import os +import argparse +import torch +from rich.traceback import install as install_rich_tracebacks +from blissful_tuner.utils import BlissfulLogger, string_to_seed, parse_scheduled_cfg, error_out +logger = BlissfulLogger(__name__, "#8e00ed") + +BLISSFUL_VERSION = "0.4.0" + +CFG_SCHEDULE_HELP = """ +Comma-separated list of steps/ranges where CFG should be applied. + +You can specify: +- Single steps (e.g., '5') +- Ranges (e.g., '1-10') +- Modulus patterns (e.g., 'e~2' for every 2 steps) +- Guidance scale overrides (e.g., '1-10:5.0') + +Example schedule: + 'e~2:6.4, 1-10, 46-50' + +This would apply: +- Default CFG scale for steps 1-10 and 46-50 +- 6.4 CFG scale every 2 steps outside that range +- No CFG otherwise + +You can exclude steps using '!', e.g., '!32' skips step 32. +Note: The list is processed left to right, so modulus ranges should come first and exclusions at the end! 
+""" + +ROOT_SCRIPT = os.path.basename(sys.argv[0]).lower() +if "hv_" in ROOT_SCRIPT: + DIFFUSION_MODEL = "hunyuan" +elif "wan_" in ROOT_SCRIPT: + DIFFUSION_MODEL = "wan" +elif "fpack_" in ROOT_SCRIPT: + DIFFUSION_MODEL = "framepack" +else: + raise ValueError("Unsupported root_script for Blissful Extension") + +if "generate" in ROOT_SCRIPT: + MODE = "generate" +elif "train" in ROOT_SCRIPT: + MODE = "train" +else: + raise ValueError("Unsupported root script for Blissful Extension!") + + +def blissful_prefunc(args: argparse.Namespace): + """Simple function to print about version, environment, and things""" + cuda_list = [f"PyTorch: {torch.__version__}"] + if torch.cuda.is_available(): + allocator = torch.cuda.get_allocator_backend() + cuda = torch.cuda.get_device_properties(0) + cuda_list[0] += f", CUDA: {torch.version.cuda} CC: {cuda.major}.{cuda.minor}" + cuda_list.append(f"Device: '{cuda.name}', VRAM: '{cuda.total_memory // 1024 ** 2}MB'") + for string in cuda_list: + logger.info(string) + if args.fp16_accumulation and MODE == "generate": + logger.info("Enabling FP16 accumulation") + if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"): + torch.backends.cuda.matmul.allow_fp16_accumulation = True + else: + raise ValueError("torch.backends.cuda.matmul.allow_fp16_accumulation is not available in this version of torch, requires torch 2.7.0.dev2025 02 26 nightly minimum") + + +def add_blissful_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + install_rich_tracebacks() + if DIFFUSION_MODEL == "wan": + parser.add_argument("--noise_aug_strength", type=float, default=0.0, help="Additional multiplier for i2v noise, higher might help motion/quality") + parser.add_argument("--prompt_weighting", action="store_true", help="Enable (prompt weighting:1.2)") + parser.add_argument( + "--rope_func", type=str, default="default", + help="Function to use for ROPE. Choose from 'default' or 'comfy' the latter of which uses ComfyUI implementation and is compilable with torch.compile to enable BIG VRAM savings" + ) + + elif DIFFUSION_MODEL == "hunyuan": + parser.add_argument("--hidden_state_skip_layer", type=int, default=2, help="Hidden state skip layer for LLM. Default is 2. Think 'clip skip' for the LLM") + parser.add_argument("--apply_final_norm", type=bool, default=False, help="Apply final norm for LLM. Default is False. Usually makes things worse.") + parser.add_argument("--reproduce", action="store_true", help="Enable reproducible output(Same seed = same result. Default is False.") + parser.add_argument("--fp8_scaled", action="store_true", help="Scaled FP8 quantization. Better quality/accuracy with slightly more VRAM usage.") + parser.add_argument("--prompt_2", type=str, required=False, help="Optional different prompt for CLIP") + parser.add_argument("--te_multiplier", nargs=2, metavar=("llm_multiplier", "clip_multiplier"), help="Scale clip and llm influence") + elif DIFFUSION_MODEL == "framepack": + parser.add_argument("--preview_latent_every", type=int, default=None, help="Enable latent preview every N sections. If --preview_vae is not specified it will use latent2rgb") + + if DIFFUSION_MODEL in ["wan", "hunyuan"]: + parser.add_argument("--riflex_index", type=int, default=0, help="Frequency for RifleX extension. 4 is good for Hunyuan, 6 is good for Wan. 
Only 'comfy' rope_func supports this with Wan!") + parser.add_argument("--cfgzerostar_scaling", action="store_true", help="Enables CFG-Zero* scaling - https://github.com/WeichenFan/CFG-Zero-star") + parser.add_argument("--cfgzerostar_init_steps", type=int, default=-1, help="Enables CFGZero* zeroing out the first N steps. 2 is good for Wan T2V, 1 for I2V") + parser.add_argument("--preview_latent_every", type=int, default=None, help="Enable latent preview every N steps. If --preview_vae is not specified it will use latent2rgb") + + # Common + + parser.add_argument("--preview_vae", type=str, help="Path to TAE vae for taehv previews") + parser.add_argument("--cfg_schedule", type=str, help=CFG_SCHEDULE_HELP) + parser.add_argument("--keep_pngs", action="store_true", help="Save frames as PNGs in addition to output video") + parser.add_argument("--codec", choices=["prores", "h264", "h265"], default=None, help="Codec to use, choose from 'prores', 'h264', or 'h265'") + parser.add_argument("--container", choices=["mkv", "mp4"], default="mkv", help="Container format to use, choose from 'mkv' or 'mp4'. Note prores can only go in MKV!") + parser.add_argument("--fp16_accumulation", action="store_true", help="Enable full FP16 Accmumulation in FP16 GEMMs, requires Pytorch 2.7.0 or higher") + return parser + + +def parse_blissful_args(args: argparse.Namespace) -> argparse.Namespace: + if args.seed is not None: + try: + args.seed = int(args.seed) + except ValueError: + string_seed = args.seed + args.seed = string_to_seed(args.seed) + logger.info(f"Seed {args.seed} was generated from string '{string_seed}'!") + if DIFFUSION_MODEL == "wan": + if args.riflex_index != 0 and args.rope_func.lower() != "comfy": + logger.error("RIFLEx can only be used with rope_func == 'comfy'!") + raise ValueError("RIFLEx can only be used with rope_func =='comfy'!") + if DIFFUSION_MODEL in ["wan", "hunyuan"]: + if args.cfg_schedule: + args.cfg_schedule = parse_scheduled_cfg(args.cfg_schedule, args.infer_steps, args.guidance_scale) + if args.cfgzerostar_scaling or args.cfgzerostar_init_steps != -1: + if args.guidance_scale == 1 and not args.cfg_schedule: + error_out(AttributeError, "Requested CFGZero* but CFG is not enabled so it will have no effect!") + blissful_prefunc(args) + return args diff --git a/blissful_tuner/blissful_settings.py b/blissful_tuner/blissful_settings.py new file mode 100644 index 0000000000000000000000000000000000000000..988cccb25ff193127cb0c34500814acc4ced44ca --- /dev/null +++ b/blissful_tuner/blissful_settings.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Tue Mar 11 19:08:55 2025 + +@author: blyss +""" +import os +import json + + +class SingletonMeta(type): + """ + The SingletonMeta class is useful for creating objects that persist as a single instance across the whole program. Basically a global class. + """ + _instances = {} + + def __call__(cls, *Parameters, **kwParameters): + if cls not in cls._instances: + cls._instances[cls] = super(SingletonMeta, cls).__call__(*Parameters, **kwParameters) + return cls._instances[cls] + + +class BlissfulSettings(metaclass=SingletonMeta): + def __init__(self): + """ + Loads the settings from a 'settings.json' file, creating it with default settings if it doesn't exist. + + This method attempts to read the program's settings from a JSON file. If the file does not exist, + it creates a new file with default settings. This ensures that the program can start with a known + set of configurations and modify them as needed. 
+ + This class is a SingletonMeta so even if we reinstantiate the class, this only happens the first time + """ + # These are globals that do not persist + self.generating = 0 + self.last_preview_file = "" + + default_settings = { + "prompt": "a cat walks on the grass, realistic style", + "resolution_x": 960, + "resolution_y": 544, + "fps": 24, + "embedded_guidance": 6.0, + "flow_shift": 7.0, + "infer_steps": 50, + "seed": 42, + "video_length": 129, + "attention": "sage", + "blocks_to_swap": 0, + "hidden_state_skip_layer": 2, + "apply_final_norm": False, + "reproduce": False, + "fp8": True, + "fp8_fast": False, + "do_compile": False, + "transformer_path": "", + "text_encoder_1_path": "", + "text_encoder_2_path": "", + "vae_path": "", + "lora_path": "", + } + + if not os.path.exists("./settings.json"): + with open("./settings.json", "w", encoding="utf-8") as file: + json.dump(default_settings, file, indent=4) + print("No existing settings found. Created default settings file.") + + with open("./settings.json", "r", encoding="utf-8") as file: + data = json.load(file) + + for key, default_value in default_settings.items(): + setattr(self, key, data.get(key, default_value)) + + def save_to_file(self): + """ + Saves the current settings to a JSON file named 'settings.json'. + """ + settings = { + "prompt": self.prompt, + "resolution_x": self.resolution_x, + "resolution_y": self.resolution_y, + "fps": self.fps, + "embedded_guidance": self.embedded_guidance, + "flow_shift": self.flow_shift, + "infer_steps": self.infer_steps, + "seed": self.seed, + "video_length": self.video_length, + "attention": self.attention, + "blocks_to_swap": self.blocks_to_swap, + "hidden_state_skip_layer": self.hidden_state_skip_layer, + "apply_final_norm": self.apply_final_norm, + "reproduce": self.reproduce, + "fp8": self.fp8, + "fp8_fast": self.fp8_fast, + "do_compile": self.do_compile, + "transformer_path": self.transformer_path, + "text_encoder_1_path": self.text_encoder_1_path, + "text_encoder_2_path": self.text_encoder_2_path, + "vae_path": self.vae_path, + "lora_path": self.lora_path, + } + + with open("./settings.json", "w", encoding="utf-8") as file: + json.dump(settings, file, indent=4) + + def update(self, option, value, label_target=None, label_value=None): + """Method for updating various settings called via QT connection and may update an associated label/value""" + setattr(self, option, value) + if label_target is not None and label_value is not None: + label_target.setText(str(label_value)) diff --git a/blissful_tuner/cfgzerostar.py b/blissful_tuner/cfgzerostar.py new file mode 100644 index 0000000000000000000000000000000000000000..7e4ea1e57bcf1fca0ea13f233860a340b7abe3a3 --- /dev/null +++ b/blissful_tuner/cfgzerostar.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Wed Apr 16 17:23:28 2025 +CFGZero* implementation for Blissful Tuner extension based on https://github.com/WeichenFan/CFG-Zero-star/blob/main/models/wan/wan_pipeline.py +License: Apache 2.0 +@author: blyss +""" +import torch + + +def apply_zerostar(cond: torch.Tensor, uncond: torch.Tensor, current_step: int, guidance_scale: float, use_scaling: bool = True, zero_init_steps: int = -1) -> torch.Tensor: + + if (current_step <= zero_init_steps): + return cond * 0 + if not use_scaling: + # CFG formula + noise_pred = uncond + guidance_scale * (cond - uncond) + else: + batch_size = cond.shape[0] + positive_flat = cond.view(batch_size, -1) + negative_flat = uncond.view(batch_size, -1) + alpha = 
optimized_scale(positive_flat, negative_flat) + alpha = alpha.view(batch_size, *([1] * (len(cond.shape) - 1))) + alpha = alpha.to(cond.dtype) + # CFG formula modified with alpha + noise_pred = uncond * alpha + guidance_scale * (cond - uncond * alpha) + return noise_pred + + +def optimized_scale(positive_flat, negative_flat): + + dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) + squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8 + + # st_star = v_cond^T * v_uncond / ||v_uncond||^2 + st_star = dot_product / squared_norm + + return st_star diff --git a/blissful_tuner/codeformer/LICENSE b/blissful_tuner/codeformer/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..6e1924f4b4e9cdf520a0f7ace2eabf028b4c6c9b --- /dev/null +++ b/blissful_tuner/codeformer/LICENSE @@ -0,0 +1,15 @@ +THIS FOLDER AND SUBFOLDERS (all CodeFormer related code and files) LICENSED AS BELOW + +S-Lab License 1.0 +Copyright 2024 S-Lab + +Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met: +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work. 
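# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this patch): how a denoising loop might
# combine conditional and unconditional noise predictions with the CFG-Zero*
# helper defined in blissful_tuner/cfgzerostar.py above. The import path, the
# guidance scale, and the random stand-in predictions are assumptions for
# illustration; in the real pipelines both predictions come from the DiT at
# each sampling step.
import torch

from blissful_tuner.cfgzerostar import apply_zerostar

cond = torch.randn(1, 16, 21, 60, 104)   # conditional prediction, e.g. [B, C, F, H, W]
uncond = torch.randn_like(cond)          # unconditional prediction

for step in range(50):
    # Steps with step <= zero_init_steps return an all-zero prediction; later
    # steps use uncond * alpha + guidance_scale * (cond - uncond * alpha),
    # where alpha is the per-sample projection computed by optimized_scale().
    noise_pred = apply_zerostar(
        cond,
        uncond,
        current_step=step,
        guidance_scale=5.0,
        use_scaling=True,
        zero_init_steps=1,
    )
# ---------------------------------------------------------------------------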
diff --git a/blissful_tuner/codeformer/basicsr/VERSION b/blissful_tuner/codeformer/basicsr/VERSION new file mode 100644 index 0000000000000000000000000000000000000000..1892b926767774e9ba91f1e584fa71b4c56abb69 --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/VERSION @@ -0,0 +1 @@ +1.3.2 diff --git a/blissful_tuner/codeformer/basicsr/__init__.py b/blissful_tuner/codeformer/basicsr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..05f4ae3b1a9acd81b61b6067748ff5765d459006 --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/__init__.py @@ -0,0 +1,11 @@ +# https://github.com/xinntao/BasicSR +# flake8: noqa +from .archs import * +from .data import * +from .losses import * +from .metrics import * +from .models import * +from .ops import * +from .train import * +from .utils import * +#from .version import __gitsha__, __version__ diff --git a/blissful_tuner/codeformer/basicsr/archs/__init__.py b/blissful_tuner/codeformer/basicsr/archs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1285aadc37561772b05c2be1f31b9fd4438c9f56 --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/archs/__init__.py @@ -0,0 +1,25 @@ +import importlib +from copy import deepcopy +from os import path as osp + +from codeformer.basicsr.utils import get_root_logger, scandir +from codeformer.basicsr.utils.registry import ARCH_REGISTRY + +__all__ = ['build_network'] + +# automatically scan and import arch modules for registry +# scan all the files under the 'archs' folder and collect files ending with +# '_arch.py' +arch_folder = osp.dirname(osp.abspath(__file__)) +arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] +# import all the arch modules +_arch_modules = [importlib.import_module(f'codeformer.basicsr.archs.{file_name}') for file_name in arch_filenames] + + +def build_network(opt): + opt = deepcopy(opt) + network_type = opt.pop('type') + net = ARCH_REGISTRY.get(network_type)(**opt) + logger = get_root_logger() + logger.info(f'Network [{net.__class__.__name__}] is created.') + return net diff --git a/blissful_tuner/codeformer/basicsr/archs/arcface_arch.py b/blissful_tuner/codeformer/basicsr/archs/arcface_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..f3c18695b069dc1cae3b0cd6d54aeb802ef40d4b --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/archs/arcface_arch.py @@ -0,0 +1,245 @@ +import torch.nn as nn +from codeformer.basicsr.utils.registry import ARCH_REGISTRY + + +def conv3x3(inplanes, outplanes, stride=1): + """A simple wrapper for 3x3 convolution with padding. + + Args: + inplanes (int): Channel number of inputs. + outplanes (int): Channel number of outputs. + stride (int): Stride in convolution. Default: 1. + """ + return nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False) + + +class BasicBlock(nn.Module): + """Basic residual block used in the ResNetArcFace architecture. + + Args: + inplanes (int): Channel number of inputs. + planes (int): Channel number of outputs. + stride (int): Stride in convolution. Default: 1. + downsample (nn.Module): The downsample module. Default: None. 
+ """ + expansion = 1 # output channel expansion ratio + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class IRBlock(nn.Module): + """Improved residual block (IR Block) used in the ResNetArcFace architecture. + + Args: + inplanes (int): Channel number of inputs. + planes (int): Channel number of outputs. + stride (int): Stride in convolution. Default: 1. + downsample (nn.Module): The downsample module. Default: None. + use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True. + """ + expansion = 1 # output channel expansion ratio + + def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True): + super(IRBlock, self).__init__() + self.bn0 = nn.BatchNorm2d(inplanes) + self.conv1 = conv3x3(inplanes, inplanes) + self.bn1 = nn.BatchNorm2d(inplanes) + self.prelu = nn.PReLU() + self.conv2 = conv3x3(inplanes, planes, stride) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + self.use_se = use_se + if self.use_se: + self.se = SEBlock(planes) + + def forward(self, x): + residual = x + out = self.bn0(x) + out = self.conv1(out) + out = self.bn1(out) + out = self.prelu(out) + + out = self.conv2(out) + out = self.bn2(out) + if self.use_se: + out = self.se(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.prelu(out) + + return out + + +class Bottleneck(nn.Module): + """Bottleneck block used in the ResNetArcFace architecture. + + Args: + inplanes (int): Channel number of inputs. + planes (int): Channel number of outputs. + stride (int): Stride in convolution. Default: 1. + downsample (nn.Module): The downsample module. Default: None. + """ + expansion = 4 # output channel expansion ratio + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class SEBlock(nn.Module): + """The squeeze-and-excitation block (SEBlock) used in the IRBlock. + + Args: + channel (int): Channel number of inputs. + reduction (int): Channel reduction ration. Default: 16. 
+ """ + + def __init__(self, channel, reduction=16): + super(SEBlock, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) # pool to 1x1 without spatial information + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel), + nn.Sigmoid()) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return x * y + + +@ARCH_REGISTRY.register() +class ResNetArcFace(nn.Module): + """ArcFace with ResNet architectures. + + Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition. + + Args: + block (str): Block used in the ArcFace architecture. + layers (tuple(int)): Block numbers in each layer. + use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True. + """ + + def __init__(self, block, layers, use_se=True): + if block == 'IRBlock': + block = IRBlock + self.inplanes = 64 + self.use_se = use_se + super(ResNetArcFace, self).__init__() + + self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.prelu = nn.PReLU() + self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.bn4 = nn.BatchNorm2d(512) + self.dropout = nn.Dropout() + self.fc5 = nn.Linear(512 * 8 * 8, 512) + self.bn5 = nn.BatchNorm1d(512) + + # initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.xavier_normal_(m.weight) + elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.xavier_normal_(m.weight) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, planes, num_blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se)) + self.inplanes = planes + for _ in range(1, num_blocks): + layers.append(block(self.inplanes, planes, use_se=self.use_se)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.prelu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.bn4(x) + x = self.dropout(x) + x = x.view(x.size(0), -1) + x = self.fc5(x) + x = self.bn5(x) + + return x \ No newline at end of file diff --git a/blissful_tuner/codeformer/basicsr/archs/arch_util.py b/blissful_tuner/codeformer/basicsr/archs/arch_util.py new file mode 100644 index 0000000000000000000000000000000000000000..cb8518a65174f17a5954b56174cc5e07e9a8372e --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/archs/arch_util.py @@ -0,0 +1,318 @@ +import collections.abc +import math +import torch +import torchvision +import warnings +from distutils.version import LooseVersion +from itertools import repeat +from torch import nn as nn +from torch.nn import functional as F +from torch.nn import init as init +from torch.nn.modules.batchnorm import _BatchNorm + +from codeformer.basicsr.ops.dcn import ModulatedDeformConvPack, 
modulated_deform_conv +from codeformer.basicsr.utils import get_root_logger + + +@torch.no_grad() +def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs): + """Initialize network weights. + + Args: + module_list (list[nn.Module] | nn.Module): Modules to be initialized. + scale (float): Scale initialized weights, especially for residual + blocks. Default: 1. + bias_fill (float): The value to fill bias. Default: 0 + kwargs (dict): Other arguments for initialization function. + """ + if not isinstance(module_list, list): + module_list = [module_list] + for module in module_list: + for m in module.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, nn.Linear): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, _BatchNorm): + init.constant_(m.weight, 1) + if m.bias is not None: + m.bias.data.fill_(bias_fill) + + +def make_layer(basic_block, num_basic_block, **kwarg): + """Make layers by stacking the same blocks. + + Args: + basic_block (nn.module): nn.module class for basic block. + num_basic_block (int): number of blocks. + + Returns: + nn.Sequential: Stacked blocks in nn.Sequential. + """ + layers = [] + for _ in range(num_basic_block): + layers.append(basic_block(**kwarg)) + return nn.Sequential(*layers) + + +class ResidualBlockNoBN(nn.Module): + """Residual block without BN. + + It has a style of: + ---Conv-ReLU-Conv-+- + |________________| + + Args: + num_feat (int): Channel number of intermediate features. + Default: 64. + res_scale (float): Residual scale. Default: 1. + pytorch_init (bool): If set to True, use pytorch default init, + otherwise, use default_init_weights. Default: False. + """ + + def __init__(self, num_feat=64, res_scale=1, pytorch_init=False): + super(ResidualBlockNoBN, self).__init__() + self.res_scale = res_scale + self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + self.relu = nn.ReLU(inplace=True) + + if not pytorch_init: + default_init_weights([self.conv1, self.conv2], 0.1) + + def forward(self, x): + identity = x + out = self.conv2(self.relu(self.conv1(x))) + return identity + out * self.res_scale + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True): + """Warp an image or feature map with optical flow. + + Args: + x (Tensor): Tensor with size (n, c, h, w). + flow (Tensor): Tensor with size (n, h, w, 2), normal value. + interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'. + padding_mode (str): 'zeros' or 'border' or 'reflection'. + Default: 'zeros'. + align_corners (bool): Before pytorch 1.3, the default value is + align_corners=True. 
After pytorch 1.3, the default value is + align_corners=False. Here, we use the True as default. + + Returns: + Tensor: Warped image or feature map. + """ + assert x.size()[-2:] == flow.size()[1:3] + _, _, h, w = x.size() + # create mesh grid + grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x)) + grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 + grid.requires_grad = False + + vgrid = grid + flow + # scale grid to [-1,1] + vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0 + vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0 + vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) + output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners) + + # TODO, what if align_corners=False + return output + + +def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False): + """Resize a flow according to ratio or shape. + + Args: + flow (Tensor): Precomputed flow. shape [N, 2, H, W]. + size_type (str): 'ratio' or 'shape'. + sizes (list[int | float]): the ratio for resizing or the final output + shape. + 1) The order of ratio should be [ratio_h, ratio_w]. For + downsampling, the ratio should be smaller than 1.0 (i.e., ratio + < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e., + ratio > 1.0). + 2) The order of output_size should be [out_h, out_w]. + interp_mode (str): The mode of interpolation for resizing. + Default: 'bilinear'. + align_corners (bool): Whether align corners. Default: False. + + Returns: + Tensor: Resized flow. + """ + _, _, flow_h, flow_w = flow.size() + if size_type == 'ratio': + output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1]) + elif size_type == 'shape': + output_h, output_w = sizes[0], sizes[1] + else: + raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.') + + input_flow = flow.clone() + ratio_h = output_h / flow_h + ratio_w = output_w / flow_w + input_flow[:, 0, :, :] *= ratio_w + input_flow[:, 1, :, :] *= ratio_h + resized_flow = F.interpolate( + input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners) + return resized_flow + + +# TODO: may write a cpp file +def pixel_unshuffle(x, scale): + """ Pixel unshuffle. + + Args: + x (Tensor): Input feature with shape (b, c, hh, hw). + scale (int): Downsample ratio. + + Returns: + Tensor: the pixel unshuffled feature. + """ + b, c, hh, hw = x.size() + out_channel = c * (scale**2) + assert hh % scale == 0 and hw % scale == 0 + h = hh // scale + w = hw // scale + x_view = x.view(b, c, h, scale, w, scale) + return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w) + + +class DCNv2Pack(ModulatedDeformConvPack): + """Modulated deformable conv for deformable alignment. + + Different from the official DCNv2Pack, which generates offsets and masks + from the preceding features, this DCNv2Pack takes another different + features to generate offsets and masks. + + Ref: + Delving Deep into Deformable Alignment in Video Super-Resolution. 
+ """ + + def forward(self, x, feat): + out = self.conv_offset(feat) + o1, o2, mask = torch.chunk(out, 3, dim=1) + offset = torch.cat((o1, o2), dim=1) + mask = torch.sigmoid(mask) + + offset_absmean = torch.mean(torch.abs(offset)) + if offset_absmean > 50: + logger = get_root_logger() + logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.') + + if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'): + return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding, + self.dilation, mask) + else: + return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, + self.dilation, self.groups, self.deformable_groups) + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. ' + 'The distribution of values may be incorrect.', + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + low = norm_cdf((a - mean) / std) + up = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [low, up], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * low - 1, 2 * up - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. + + From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py + + The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. 
+ + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +# From PyTorch +def _ntuple(n): + + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple \ No newline at end of file diff --git a/blissful_tuner/codeformer/basicsr/archs/codeformer_arch.py b/blissful_tuner/codeformer/basicsr/archs/codeformer_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..bbe1452a18219b325148e59aa0252f91fd1bbfb6 --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/archs/codeformer_arch.py @@ -0,0 +1,280 @@ +import math +import numpy as np +import torch +from torch import nn, Tensor +import torch.nn.functional as F +from typing import Optional, List + +from codeformer.basicsr.archs.vqgan_arch import * +from codeformer.basicsr.utils import get_root_logger +from codeformer.basicsr.utils.registry import ARCH_REGISTRY + +def calc_mean_std(feat, eps=1e-5): + """Calculate mean and std for adaptive_instance_normalization. + + Args: + feat (Tensor): 4D tensor. + eps (float): A small value added to the variance to avoid + divide-by-zero. Default: 1e-5. + """ + size = feat.size() + assert len(size) == 4, 'The input feature should be 4D tensor.' + b, c = size[:2] + feat_var = feat.view(b, c, -1).var(dim=2) + eps + feat_std = feat_var.sqrt().view(b, c, 1, 1) + feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1) + return feat_mean, feat_std + + +def adaptive_instance_normalization(content_feat, style_feat): + """Adaptive instance normalization. + + Adjust the reference features to have the similar color and illuminations + as those in the degradate features. + + Args: + content_feat (Tensor): The reference feature. + style_feat (Tensor): The degradate features. + """ + size = content_feat.size() + style_mean, style_std = calc_mean_std(style_feat) + content_mean, content_std = calc_mean_std(content_feat) + normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) + return normalized_feat * style_std.expand(size) + style_mean.expand(size) + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. 
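+    Sine/cosine features are computed separately for the row (y) and column (x)
+    positions, optionally normalised to the ``(0, 2*pi]`` range, and concatenated
+    along the channel dimension for a total of ``2 * num_pos_feats`` channels.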
+ """ + + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, x, mask=None): + if mask is None: + mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") + + +class TransformerSALayer(nn.Module): + def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"): + super().__init__() + self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout) + # Implementation of Feedforward model - MLP + self.linear1 = nn.Linear(embed_dim, dim_mlp) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_mlp, embed_dim) + + self.norm1 = nn.LayerNorm(embed_dim) + self.norm2 = nn.LayerNorm(embed_dim) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward(self, tgt, + tgt_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + + # self attention + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + + # ffn + tgt2 = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout2(tgt2) + return tgt + +class Fuse_sft_block(nn.Module): + def __init__(self, in_ch, out_ch): + super().__init__() + self.encode_enc = ResBlock(2*in_ch, out_ch) + + self.scale = nn.Sequential( + nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, True), + nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)) + + self.shift = nn.Sequential( + nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, True), + nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)) + + def forward(self, enc_feat, dec_feat, w=1): + enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1)) + scale = self.scale(enc_feat) + shift = self.shift(enc_feat) + 
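+        # SFT-style modulation: the encoded feature predicts a per-pixel scale and
+        # shift for the decoder feature, and w blends the modulated residual in
+        # (w=0 keeps dec_feat unchanged, w=1 applies the fusion fully)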
residual = w * (dec_feat * scale + shift) + out = dec_feat + residual + return out + + +@ARCH_REGISTRY.register() +class CodeFormer(VQAutoEncoder): + def __init__(self, dim_embd=512, n_head=8, n_layers=9, + codebook_size=1024, latent_size=256, + connect_list=['32', '64', '128', '256'], + fix_modules=['quantize','generator'], vqgan_path=None): + super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size) + + if vqgan_path is not None: + self.load_state_dict( + torch.load(vqgan_path, map_location='cpu')['params_ema']) + + if fix_modules is not None: + for module in fix_modules: + for param in getattr(self, module).parameters(): + param.requires_grad = False + + self.connect_list = connect_list + self.n_layers = n_layers + self.dim_embd = dim_embd + self.dim_mlp = dim_embd*2 + + self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd)) + self.feat_emb = nn.Linear(256, self.dim_embd) + + # transformer + self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0) + for _ in range(self.n_layers)]) + + # logits_predict head + self.idx_pred_layer = nn.Sequential( + nn.LayerNorm(dim_embd), + nn.Linear(dim_embd, codebook_size, bias=False)) + + self.channels = { + '16': 512, + '32': 256, + '64': 256, + '128': 128, + '256': 128, + '512': 64, + } + + # after second residual block for > 16, before attn layer for ==16 + self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18} + # after first residual block for > 16, before attn layer for ==16 + self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21} + + # fuse_convs_dict + self.fuse_convs_dict = nn.ModuleDict() + for f_size in self.connect_list: + in_ch = self.channels[f_size] + self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch) + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def forward(self, x, w=0, detach_16=True, code_only=False, adain=False): + # ################### Encoder ##################### + enc_feat_dict = {} + out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list] + for i, block in enumerate(self.encoder.blocks): + x = block(x) + if i in out_list: + enc_feat_dict[str(x.shape[-1])] = x.clone() + + lq_feat = x + # ################# Transformer ################### + # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat) + pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1) + # BCHW -> BC(HW) -> (HW)BC + feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1)) + query_emb = feat_emb + # Transformer encoder + for layer in self.ft_layers: + query_emb = layer(query_emb, query_pos=pos_emb) + + # output logits + logits = self.idx_pred_layer(query_emb) # (hw)bn + logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n + + if code_only: # for training stage II + # logits doesn't need softmax before cross_entropy loss + return logits, lq_feat + + # ################# Quantization ################### + # if self.training: + # quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight]) + # # b(hw)c -> bc(hw) -> bchw + # quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape) + # ------------ + soft_one_hot = F.softmax(logits, dim=2) 
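+        # take the top-1 code index per spatial token and gather the matching
+        # entries from the (by default frozen) VQGAN codebook, restoring a
+        # (batch, 256, 16, 16) feature map for the generator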
+ _, top_idx = torch.topk(soft_one_hot, 1, dim=2) + quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256]) + # preserve gradients + # quant_feat = lq_feat + (quant_feat - lq_feat).detach() + + if detach_16: + quant_feat = quant_feat.detach() # for training stage III + if adain: + quant_feat = adaptive_instance_normalization(quant_feat, lq_feat) + + # ################## Generator #################### + x = quant_feat + fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list] + + for i, block in enumerate(self.generator.blocks): + x = block(x) + if i in fuse_list: # fuse after i-th block + f_size = str(x.shape[-1]) + if w>0: + x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w) + out = x + # logits doesn't need softmax before cross_entropy loss + return out, logits, lq_feat \ No newline at end of file diff --git a/blissful_tuner/codeformer/basicsr/archs/rrdbnet_arch.py b/blissful_tuner/codeformer/basicsr/archs/rrdbnet_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..adb799c0825752ec87a4497fc7801131e09c0f8b --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/archs/rrdbnet_arch.py @@ -0,0 +1,119 @@ +import torch +from torch import nn as nn +from torch.nn import functional as F + +from codeformer.basicsr.utils.registry import ARCH_REGISTRY +from .arch_util import default_init_weights, make_layer, pixel_unshuffle + + +class ResidualDenseBlock(nn.Module): + """Residual Dense Block. + + Used in RRDB block in ESRGAN. + + Args: + num_feat (int): Channel number of intermediate features. + num_grow_ch (int): Channels for each growth. + """ + + def __init__(self, num_feat=64, num_grow_ch=32): + super(ResidualDenseBlock, self).__init__() + self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1) + self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + # initialization + default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) + + def forward(self, x): + x1 = self.lrelu(self.conv1(x)) + x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) + x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) + x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + # Emperically, we use 0.2 to scale the residual for better performance + return x5 * 0.2 + x + + +class RRDB(nn.Module): + """Residual in Residual Dense Block. + + Used in RRDB-Net in ESRGAN. + + Args: + num_feat (int): Channel number of intermediate features. + num_grow_ch (int): Channels for each growth. + """ + + def __init__(self, num_feat, num_grow_ch=32): + super(RRDB, self).__init__() + self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch) + self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch) + self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch) + + def forward(self, x): + out = self.rdb1(x) + out = self.rdb2(out) + out = self.rdb3(out) + # Emperically, we use 0.2 to scale the residual for better performance + return out * 0.2 + x + + +@ARCH_REGISTRY.register() +class RRDBNet(nn.Module): + """Networks consisting of Residual in Residual Dense Block, which is used + in ESRGAN. + + ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks. 
+ + We extend ESRGAN for scale x2 and scale x1. + Note: This is one option for scale 1, scale 2 in RRDBNet. + We first employ the pixel-unshuffle (an inverse operation of pixelshuffle to reduce the spatial size + and enlarge the channel size before feeding inputs into the main ESRGAN architecture. + + Args: + num_in_ch (int): Channel number of inputs. + num_out_ch (int): Channel number of outputs. + num_feat (int): Channel number of intermediate features. + Default: 64 + num_block (int): Block number in the trunk network. Defaults: 23 + num_grow_ch (int): Channels for each growth. Default: 32. + """ + + def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32): + super(RRDBNet, self).__init__() + self.scale = scale + if scale == 2: + num_in_ch = num_in_ch * 4 + elif scale == 1: + num_in_ch = num_in_ch * 16 + self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) + self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch) + self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + # upsample + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x): + if self.scale == 2: + feat = pixel_unshuffle(x, scale=2) + elif self.scale == 1: + feat = pixel_unshuffle(x, scale=4) + else: + feat = x + feat = self.conv_first(feat) + body_feat = self.conv_body(self.body(feat)) + feat = feat + body_feat + # upsample + feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))) + feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest'))) + out = self.conv_last(self.lrelu(self.conv_hr(feat))) + return out \ No newline at end of file diff --git a/blissful_tuner/codeformer/basicsr/archs/vgg_arch.py b/blissful_tuner/codeformer/basicsr/archs/vgg_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..9ec8c9883e9afa9242dd5145af9aa50492fce194 --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/archs/vgg_arch.py @@ -0,0 +1,161 @@ +import os +import torch +from collections import OrderedDict +from torch import nn as nn +from torchvision.models import vgg as vgg + +from codeformer.basicsr.utils.registry import ARCH_REGISTRY + +VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth' +NAMES = { + 'vgg11': [ + 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', + 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', + 'pool5' + ], + 'vgg13': [ + 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', + 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', + 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5' + ], + 'vgg16': [ + 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', + 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', + 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', + 'pool5' + ], + 'vgg19': [ + 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', + 'conv3_1', 'relu3_1', 
'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1', + 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1', + 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5' + ] +} + + +def insert_bn(names): + """Insert bn layer after each conv. + + Args: + names (list): The list of layer names. + + Returns: + list: The list of layer names with bn layers. + """ + names_bn = [] + for name in names: + names_bn.append(name) + if 'conv' in name: + position = name.replace('conv', '') + names_bn.append('bn' + position) + return names_bn + + +@ARCH_REGISTRY.register() +class VGGFeatureExtractor(nn.Module): + """VGG network for feature extraction. + + In this implementation, we allow users to choose whether use normalization + in the input feature and the type of vgg network. Note that the pretrained + path must fit the vgg type. + + Args: + layer_name_list (list[str]): Forward function returns the corresponding + features according to the layer_name_list. + Example: {'relu1_1', 'relu2_1', 'relu3_1'}. + vgg_type (str): Set the type of vgg network. Default: 'vgg19'. + use_input_norm (bool): If True, normalize the input image. Importantly, + the input feature must in the range [0, 1]. Default: True. + range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. + Default: False. + requires_grad (bool): If true, the parameters of VGG network will be + optimized. Default: False. + remove_pooling (bool): If true, the max pooling operations in VGG net + will be removed. Default: False. + pooling_stride (int): The stride of max pooling operation. Default: 2. + """ + + def __init__(self, + layer_name_list, + vgg_type='vgg19', + use_input_norm=True, + range_norm=False, + requires_grad=False, + remove_pooling=False, + pooling_stride=2): + super(VGGFeatureExtractor, self).__init__() + + self.layer_name_list = layer_name_list + self.use_input_norm = use_input_norm + self.range_norm = range_norm + + self.names = NAMES[vgg_type.replace('_bn', '')] + if 'bn' in vgg_type: + self.names = insert_bn(self.names) + + # only borrow layers that will be used to avoid unused params + max_idx = 0 + for v in layer_name_list: + idx = self.names.index(v) + if idx > max_idx: + max_idx = idx + + if os.path.exists(VGG_PRETRAIN_PATH): + vgg_net = getattr(vgg, vgg_type)(pretrained=False) + state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage) + vgg_net.load_state_dict(state_dict) + else: + vgg_net = getattr(vgg, vgg_type)(pretrained=True) + + features = vgg_net.features[:max_idx + 1] + + modified_net = OrderedDict() + for k, v in zip(self.names, features): + if 'pool' in k: + # if remove_pooling is true, pooling operation will be removed + if remove_pooling: + continue + else: + # in some cases, we may want to change the default stride + modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride) + else: + modified_net[k] = v + + self.vgg_net = nn.Sequential(modified_net) + + if not requires_grad: + self.vgg_net.eval() + for param in self.parameters(): + param.requires_grad = False + else: + self.vgg_net.train() + for param in self.parameters(): + param.requires_grad = True + + if self.use_input_norm: + # the mean is for image with range [0, 1] + self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) + # the std is for image with range [0, 1] + self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) + + def forward(self, x): + 
"""Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results. + """ + if self.range_norm: + x = (x + 1) / 2 + if self.use_input_norm: + x = (x - self.mean) / self.std + output = {} + + for key, layer in self.vgg_net._modules.items(): + x = layer(x) + if key in self.layer_name_list: + output[key] = x.clone() + + return output diff --git a/blissful_tuner/codeformer/basicsr/archs/vqgan_arch.py b/blissful_tuner/codeformer/basicsr/archs/vqgan_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..7c11ede250224fae0430f7a3a4047c3e191c0c49 --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/archs/vqgan_arch.py @@ -0,0 +1,434 @@ +''' +VQGAN code, adapted from the original created by the Unleashing Transformers authors: +https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py + +''' +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import copy +from codeformer.basicsr.utils import get_root_logger +from codeformer.basicsr.utils.registry import ARCH_REGISTRY + +def normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +@torch.jit.script +def swish(x): + return x*torch.sigmoid(x) + + +# Define VQVAE classes +class VectorQuantizer(nn.Module): + def __init__(self, codebook_size, emb_dim, beta): + super(VectorQuantizer, self).__init__() + self.codebook_size = codebook_size # number of embeddings + self.emb_dim = emb_dim # dimension of embedding + self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 + self.embedding = nn.Embedding(self.codebook_size, self.emb_dim) + self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size) + + def forward(self, z): + # reshape z -> (batch, height, width, channel) and flatten + z = z.permute(0, 2, 3, 1).contiguous() + z_flattened = z.view(-1, self.emb_dim) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \ + 2 * torch.matmul(z_flattened, self.embedding.weight.t()) + + mean_distance = torch.mean(d) + # find closest encodings + min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) + # min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False) + # [0-1], higher score, higher confidence + # min_encoding_scores = torch.exp(-min_encoding_scores/10) + + min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z) + min_encodings.scatter_(1, min_encoding_indices, 1) + + # get quantized latent vectors + z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape) + # compute loss for embedding + loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2) + # preserve gradients + z_q = z + (z_q - z).detach() + + # perplexity + e_mean = torch.mean(min_encodings, dim=0) + perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10))) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q, loss, { + "perplexity": perplexity, + "min_encodings": min_encodings, + "min_encoding_indices": min_encoding_indices, + "mean_distance": mean_distance + } + + def get_codebook_feat(self, indices, shape): + # input indices: batch*token_num -> (batch*token_num)*1 + # shape: batch, height, width, channel + indices = indices.view(-1,1) + min_encodings = 
torch.zeros(indices.shape[0], self.codebook_size).to(indices) + min_encodings.scatter_(1, indices, 1) + # get quantized latent vectors + z_q = torch.matmul(min_encodings.float(), self.embedding.weight) + + if shape is not None: # reshape back to match original input shape + z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous() + + return z_q + + +class GumbelQuantizer(nn.Module): + def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0): + super().__init__() + self.codebook_size = codebook_size # number of embeddings + self.emb_dim = emb_dim # dimension of embedding + self.straight_through = straight_through + self.temperature = temp_init + self.kl_weight = kl_weight + self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits + self.embed = nn.Embedding(codebook_size, emb_dim) + + def forward(self, z): + hard = self.straight_through if self.training else True + + logits = self.proj(z) + + soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard) + + z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight) + + # + kl divergence to the prior loss + qy = F.softmax(logits, dim=1) + diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean() + min_encoding_indices = soft_one_hot.argmax(dim=1) + + return z_q, diff, { + "min_encoding_indices": min_encoding_indices + } + + +class Downsample(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x): + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + + return x + + +class ResBlock(nn.Module): + def __init__(self, in_channels, out_channels=None): + super(ResBlock, self).__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.norm1 = normalize(in_channels) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = normalize(out_channels) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x_in): + x = x_in + x = self.norm1(x) + x = swish(x) + x = self.conv1(x) + x = self.norm2(x) + x = swish(x) + x = self.conv2(x) + if self.in_channels != self.out_channels: + x_in = self.conv_out(x_in) + + return x + x_in + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0 + ) + + def forward(self, x): + h_ = x + 
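+        # single-head non-local self-attention: 1x1 convs produce q, k and v, the
+        # attention map is computed over all h*w positions, and the attended
+        # values are added back to the input as a residual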
h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h*w) + q = q.permute(0, 2, 1) + k = k.reshape(b, c, h*w) + w_ = torch.bmm(q, k) + w_ = w_ * (int(c)**(-0.5)) + w_ = F.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h*w) + w_ = w_.permute(0, 2, 1) + h_ = torch.bmm(v, w_) + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x+h_ + + +class Encoder(nn.Module): + def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions): + super().__init__() + self.nf = nf + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.attn_resolutions = attn_resolutions + + curr_res = self.resolution + in_ch_mult = (1,)+tuple(ch_mult) + + blocks = [] + # initial convultion + blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1)) + + # residual and downsampling blocks, with attention on smaller res (16x16) + for i in range(self.num_resolutions): + block_in_ch = nf * in_ch_mult[i] + block_out_ch = nf * ch_mult[i] + for _ in range(self.num_res_blocks): + blocks.append(ResBlock(block_in_ch, block_out_ch)) + block_in_ch = block_out_ch + if curr_res in attn_resolutions: + blocks.append(AttnBlock(block_in_ch)) + + if i != self.num_resolutions - 1: + blocks.append(Downsample(block_in_ch)) + curr_res = curr_res // 2 + + # non-local attention block + blocks.append(ResBlock(block_in_ch, block_in_ch)) + blocks.append(AttnBlock(block_in_ch)) + blocks.append(ResBlock(block_in_ch, block_in_ch)) + + # normalise and convert to latent size + blocks.append(normalize(block_in_ch)) + blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1)) + self.blocks = nn.ModuleList(blocks) + + def forward(self, x): + for block in self.blocks: + x = block(x) + + return x + + +class Generator(nn.Module): + def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions): + super().__init__() + self.nf = nf + self.ch_mult = ch_mult + self.num_resolutions = len(self.ch_mult) + self.num_res_blocks = res_blocks + self.resolution = img_size + self.attn_resolutions = attn_resolutions + self.in_channels = emb_dim + self.out_channels = 3 + block_in_ch = self.nf * self.ch_mult[-1] + curr_res = self.resolution // 2 ** (self.num_resolutions-1) + + blocks = [] + # initial conv + blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1)) + + # non-local attention block + blocks.append(ResBlock(block_in_ch, block_in_ch)) + blocks.append(AttnBlock(block_in_ch)) + blocks.append(ResBlock(block_in_ch, block_in_ch)) + + for i in reversed(range(self.num_resolutions)): + block_out_ch = self.nf * self.ch_mult[i] + + for _ in range(self.num_res_blocks): + blocks.append(ResBlock(block_in_ch, block_out_ch)) + block_in_ch = block_out_ch + + if curr_res in self.attn_resolutions: + blocks.append(AttnBlock(block_in_ch)) + + if i != 0: + blocks.append(Upsample(block_in_ch)) + curr_res = curr_res * 2 + + blocks.append(normalize(block_in_ch)) + blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1)) + + self.blocks = nn.ModuleList(blocks) + + + def forward(self, x): + for block in self.blocks: + x = block(x) + + return x + + +@ARCH_REGISTRY.register() +class VQAutoEncoder(nn.Module): + def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256, + beta=0.25, 
gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None): + super().__init__() + logger = get_root_logger() + self.in_channels = 3 + self.nf = nf + self.n_blocks = res_blocks + self.codebook_size = codebook_size + self.embed_dim = emb_dim + self.ch_mult = ch_mult + self.resolution = img_size + self.attn_resolutions = attn_resolutions + self.quantizer_type = quantizer + self.encoder = Encoder( + self.in_channels, + self.nf, + self.embed_dim, + self.ch_mult, + self.n_blocks, + self.resolution, + self.attn_resolutions + ) + if self.quantizer_type == "nearest": + self.beta = beta #0.25 + self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta) + elif self.quantizer_type == "gumbel": + self.gumbel_num_hiddens = emb_dim + self.straight_through = gumbel_straight_through + self.kl_weight = gumbel_kl_weight + self.quantize = GumbelQuantizer( + self.codebook_size, + self.embed_dim, + self.gumbel_num_hiddens, + self.straight_through, + self.kl_weight + ) + self.generator = Generator( + self.nf, + self.embed_dim, + self.ch_mult, + self.n_blocks, + self.resolution, + self.attn_resolutions + ) + + if model_path is not None: + chkpt = torch.load(model_path, map_location='cpu') + if 'params_ema' in chkpt: + self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema']) + logger.info(f'vqgan is loaded from: {model_path} [params_ema]') + elif 'params' in chkpt: + self.load_state_dict(torch.load(model_path, map_location='cpu')['params']) + logger.info(f'vqgan is loaded from: {model_path} [params]') + else: + raise ValueError(f'Wrong params!') + + + def forward(self, x): + x = self.encoder(x) + quant, codebook_loss, quant_stats = self.quantize(x) + x = self.generator(quant) + return x, codebook_loss, quant_stats + + + +# patch based discriminator +@ARCH_REGISTRY.register() +class VQGANDiscriminator(nn.Module): + def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None): + super().__init__() + + layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)] + ndf_mult = 1 + ndf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + ndf_mult_prev = ndf_mult + ndf_mult = min(2 ** n, 8) + layers += [ + nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False), + nn.BatchNorm2d(ndf * ndf_mult), + nn.LeakyReLU(0.2, True) + ] + + ndf_mult_prev = ndf_mult + ndf_mult = min(2 ** n_layers, 8) + + layers += [ + nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False), + nn.BatchNorm2d(ndf * ndf_mult), + nn.LeakyReLU(0.2, True) + ] + + layers += [ + nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map + self.main = nn.Sequential(*layers) + + if model_path is not None: + chkpt = torch.load(model_path, map_location='cpu') + if 'params_d' in chkpt: + self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d']) + elif 'params' in chkpt: + self.load_state_dict(torch.load(model_path, map_location='cpu')['params']) + else: + raise ValueError(f'Wrong params!') + + def forward(self, x): + return self.main(x) \ No newline at end of file diff --git a/blissful_tuner/codeformer/basicsr/data/__init__.py b/blissful_tuner/codeformer/basicsr/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1ce98519109fe2735563355ee6e47863a9d6811a --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/data/__init__.py @@ -0,0 +1,100 @@ +import importlib +import numpy 
as np +import random +import torch +import torch.utils.data +from copy import deepcopy +from functools import partial +from os import path as osp + +from codeformer.basicsr.data.prefetch_dataloader import PrefetchDataLoader +from codeformer.basicsr.utils import get_root_logger, scandir +from codeformer.basicsr.utils.dist_util import get_dist_info +from codeformer.basicsr.utils.registry import DATASET_REGISTRY + +__all__ = ['build_dataset', 'build_dataloader'] + +# automatically scan and import dataset modules for registry +# scan all the files under the data folder with '_dataset' in file names +data_folder = osp.dirname(osp.abspath(__file__)) +dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] +# import all the dataset modules +_dataset_modules = [importlib.import_module(f'codeformer.basicsr.data.{file_name}') for file_name in dataset_filenames] + + +def build_dataset(dataset_opt): + """Build dataset from options. + + Args: + dataset_opt (dict): Configuration for dataset. It must constain: + name (str): Dataset name. + type (str): Dataset type. + """ + dataset_opt = deepcopy(dataset_opt) + dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt) + logger = get_root_logger() + logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} ' 'is built.') + return dataset + + +def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None): + """Build dataloader. + + Args: + dataset (torch.utils.data.Dataset): Dataset. + dataset_opt (dict): Dataset options. It contains the following keys: + phase (str): 'train' or 'val'. + num_worker_per_gpu (int): Number of workers for each GPU. + batch_size_per_gpu (int): Training batch size for each GPU. + num_gpu (int): Number of GPUs. Used only in the train phase. + Default: 1. + dist (bool): Whether in distributed training. Used only in the train + phase. Default: False. + sampler (torch.utils.data.sampler): Data sampler. Default: None. + seed (int | None): Seed. Default: None + """ + phase = dataset_opt['phase'] + rank, _ = get_dist_info() + if phase == 'train': + if dist: # distributed training + batch_size = dataset_opt['batch_size_per_gpu'] + num_workers = dataset_opt['num_worker_per_gpu'] + else: # non-distributed training + multiplier = 1 if num_gpu == 0 else num_gpu + batch_size = dataset_opt['batch_size_per_gpu'] * multiplier + num_workers = dataset_opt['num_worker_per_gpu'] * multiplier + dataloader_args = dict( + dataset=dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + sampler=sampler, + drop_last=True) + if sampler is None: + dataloader_args['shuffle'] = True + dataloader_args['worker_init_fn'] = partial( + worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None + elif phase in ['val', 'test']: # validation + dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0) + else: + raise ValueError(f'Wrong dataset phase: {phase}. 
' "Supported ones are 'train', 'val' and 'test'.") + + dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False) + + prefetch_mode = dataset_opt.get('prefetch_mode') + if prefetch_mode == 'cpu': # CPUPrefetcher + num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1) + logger = get_root_logger() + logger.info(f'Use {prefetch_mode} prefetch dataloader: ' f'num_prefetch_queue = {num_prefetch_queue}') + return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args) + else: + # prefetch_mode=None: Normal dataloader + # prefetch_mode='cuda': dataloader for CUDAPrefetcher + return torch.utils.data.DataLoader(**dataloader_args) + + +def worker_init_fn(worker_id, num_workers, rank, seed): + # Set the worker seed to num_workers * rank + worker_id + seed + worker_seed = num_workers * rank + worker_id + seed + np.random.seed(worker_seed) + random.seed(worker_seed) diff --git a/blissful_tuner/codeformer/basicsr/data/data_sampler.py b/blissful_tuner/codeformer/basicsr/data/data_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..575452d9f844a928f7f42296c81635cfbadec7c2 --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/data/data_sampler.py @@ -0,0 +1,48 @@ +import math +import torch +from torch.utils.data.sampler import Sampler + + +class EnlargedSampler(Sampler): + """Sampler that restricts data loading to a subset of the dataset. + + Modified from torch.utils.data.distributed.DistributedSampler + Support enlarging the dataset for iteration-based training, for saving + time when restart the dataloader after each epoch + + Args: + dataset (torch.utils.data.Dataset): Dataset used for sampling. + num_replicas (int | None): Number of processes participating in + the training. It is usually the world_size. + rank (int | None): Rank of the current process within num_replicas. + ratio (int): Enlarging ratio. Default: 1. + """ + + def __init__(self, dataset, num_replicas, rank, ratio=1): + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas) + self.total_size = self.num_samples * self.num_replicas + + def __iter__(self): + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + indices = torch.randperm(self.total_size, generator=g).tolist() + + dataset_size = len(self.dataset) + indices = [v % dataset_size for v in indices] + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/blissful_tuner/codeformer/basicsr/data/data_util.py b/blissful_tuner/codeformer/basicsr/data/data_util.py new file mode 100644 index 0000000000000000000000000000000000000000..67e507ed8c0df0a69a12dd11d72c1611240ecd20 --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/data/data_util.py @@ -0,0 +1,392 @@ +import cv2 +import math +import numpy as np +import torch +from os import path as osp +from PIL import Image, ImageDraw +from torch.nn import functional as F + +from codeformer.basicsr.data.transforms import mod_crop +from codeformer.basicsr.utils import img2tensor, scandir + + +def read_img_seq(path, require_mod_crop=False, scale=1): + """Read a sequence of images from a given folder path. + + Args: + path (list[str] | str): List of image paths or image folder path. 
+ require_mod_crop (bool): Require mod crop for each image. + Default: False. + scale (int): Scale factor for mod_crop. Default: 1. + + Returns: + Tensor: size (t, c, h, w), RGB, [0, 1]. + """ + if isinstance(path, list): + img_paths = path + else: + img_paths = sorted(list(scandir(path, full_path=True))) + imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths] + if require_mod_crop: + imgs = [mod_crop(img, scale) for img in imgs] + imgs = img2tensor(imgs, bgr2rgb=True, float32=True) + imgs = torch.stack(imgs, dim=0) + return imgs + + +def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'): + """Generate an index list for reading `num_frames` frames from a sequence + of images. + + Args: + crt_idx (int): Current center index. + max_frame_num (int): Max number of the sequence of images (from 1). + num_frames (int): Reading num_frames frames. + padding (str): Padding mode, one of + 'replicate' | 'reflection' | 'reflection_circle' | 'circle' + Examples: current_idx = 0, num_frames = 5 + The generated frame indices under different padding mode: + replicate: [0, 0, 0, 1, 2] + reflection: [2, 1, 0, 1, 2] + reflection_circle: [4, 3, 0, 1, 2] + circle: [3, 4, 0, 1, 2] + + Returns: + list[int]: A list of indices. + """ + assert num_frames % 2 == 1, 'num_frames should be an odd number.' + assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.' + + max_frame_num = max_frame_num - 1 # start from 0 + num_pad = num_frames // 2 + + indices = [] + for i in range(crt_idx - num_pad, crt_idx + num_pad + 1): + if i < 0: + if padding == 'replicate': + pad_idx = 0 + elif padding == 'reflection': + pad_idx = -i + elif padding == 'reflection_circle': + pad_idx = crt_idx + num_pad - i + else: + pad_idx = num_frames + i + elif i > max_frame_num: + if padding == 'replicate': + pad_idx = max_frame_num + elif padding == 'reflection': + pad_idx = max_frame_num * 2 - i + elif padding == 'reflection_circle': + pad_idx = (crt_idx - num_pad) - (i - max_frame_num) + else: + pad_idx = i - num_frames + else: + pad_idx = i + indices.append(pad_idx) + return indices + + +def paired_paths_from_lmdb(folders, keys): + """Generate paired paths from lmdb files. + + Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is: + + lq.lmdb + ├── data.mdb + ├── lock.mdb + ├── meta_info.txt + + The data.mdb and lock.mdb are standard lmdb files and you can refer to + https://lmdb.readthedocs.io/en/release/ for more details. + + The meta_info.txt is a specified txt file to record the meta information + of our datasets. It will be automatically created when preparing + datasets by our provided dataset tools. + Each line in the txt file records + 1)image name (with extension), + 2)image shape, + 3)compression level, separated by a white space. + Example: `baboon.png (120,125,3) 1` + + We use the image name without extension as the lmdb key. + Note that we use the same key for the corresponding lq and gt images. + + Args: + folders (list[str]): A list of folder path. The order of list should + be [input_folder, gt_folder]. + keys (list[str]): A list of keys identifying folders. The order should + be in consistent with folders, e.g., ['lq', 'gt']. + Note that this key is different from lmdb keys. + + Returns: + list[str]: Returned path list. + """ + assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. 
' + f'But got {len(folders)}') + assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}') + input_folder, gt_folder = folders + input_key, gt_key = keys + + if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')): + raise ValueError(f'{input_key} folder and {gt_key} folder should both in lmdb ' + f'formats. But received {input_key}: {input_folder}; ' + f'{gt_key}: {gt_folder}') + # ensure that the two meta_info files are the same + with open(osp.join(input_folder, 'meta_info.txt')) as fin: + input_lmdb_keys = [line.split('.')[0] for line in fin] + with open(osp.join(gt_folder, 'meta_info.txt')) as fin: + gt_lmdb_keys = [line.split('.')[0] for line in fin] + if set(input_lmdb_keys) != set(gt_lmdb_keys): + raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.') + else: + paths = [] + for lmdb_key in sorted(input_lmdb_keys): + paths.append(dict([(f'{input_key}_path', lmdb_key), (f'{gt_key}_path', lmdb_key)])) + return paths + + +def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl): + """Generate paired paths from an meta information file. + + Each line in the meta information file contains the image names and + image shape (usually for gt), separated by a white space. + + Example of an meta information file: + ``` + 0001_s001.png (480,480,3) + 0001_s002.png (480,480,3) + ``` + + Args: + folders (list[str]): A list of folder path. The order of list should + be [input_folder, gt_folder]. + keys (list[str]): A list of keys identifying folders. The order should + be in consistent with folders, e.g., ['lq', 'gt']. + meta_info_file (str): Path to the meta information file. + filename_tmpl (str): Template for each filename. Note that the + template excludes the file extension. Usually the filename_tmpl is + for files in the input folder. + + Returns: + list[str]: Returned path list. + """ + assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. ' + f'But got {len(folders)}') + assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}') + input_folder, gt_folder = folders + input_key, gt_key = keys + + with open(meta_info_file, 'r') as fin: + gt_names = [line.split(' ')[0] for line in fin] + + paths = [] + for gt_name in gt_names: + basename, ext = osp.splitext(osp.basename(gt_name)) + input_name = f'{filename_tmpl.format(basename)}{ext}' + input_path = osp.join(input_folder, input_name) + gt_path = osp.join(gt_folder, gt_name) + paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)])) + return paths + + +def paired_paths_from_folder(folders, keys, filename_tmpl): + """Generate paired paths from folders. + + Args: + folders (list[str]): A list of folder path. The order of list should + be [input_folder, gt_folder]. + keys (list[str]): A list of keys identifying folders. The order should + be in consistent with folders, e.g., ['lq', 'gt']. + filename_tmpl (str): Template for each filename. Note that the + template excludes the file extension. Usually the filename_tmpl is + for files in the input folder. + + Returns: + list[str]: Returned path list. + """ + assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. ' + f'But got {len(folders)}') + assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. 
' f'But got {len(keys)}') + input_folder, gt_folder = folders + input_key, gt_key = keys + + input_paths = list(scandir(input_folder)) + gt_paths = list(scandir(gt_folder)) + assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: ' + f'{len(input_paths)}, {len(gt_paths)}.') + paths = [] + for gt_path in gt_paths: + basename, ext = osp.splitext(osp.basename(gt_path)) + input_name = f'{filename_tmpl.format(basename)}{ext}' + input_path = osp.join(input_folder, input_name) + assert input_name in input_paths, (f'{input_name} is not in ' f'{input_key}_paths.') + gt_path = osp.join(gt_folder, gt_path) + paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)])) + return paths + + +def paths_from_folder(folder): + """Generate paths from folder. + + Args: + folder (str): Folder path. + + Returns: + list[str]: Returned path list. + """ + + paths = list(scandir(folder)) + paths = [osp.join(folder, path) for path in paths] + return paths + + +def paths_from_lmdb(folder): + """Generate paths from lmdb. + + Args: + folder (str): Folder path. + + Returns: + list[str]: Returned path list. + """ + if not folder.endswith('.lmdb'): + raise ValueError(f'Folder {folder}folder should in lmdb format.') + with open(osp.join(folder, 'meta_info.txt')) as fin: + paths = [line.split('.')[0] for line in fin] + return paths + + +def generate_gaussian_kernel(kernel_size=13, sigma=1.6): + """Generate Gaussian kernel used in `duf_downsample`. + + Args: + kernel_size (int): Kernel size. Default: 13. + sigma (float): Sigma of the Gaussian kernel. Default: 1.6. + + Returns: + np.array: The Gaussian kernel. + """ + from scipy.ndimage import filters as filters + kernel = np.zeros((kernel_size, kernel_size)) + # set element at the middle to one, a dirac delta + kernel[kernel_size // 2, kernel_size // 2] = 1 + # gaussian-smooth the dirac, resulting in a gaussian filter + return filters.gaussian_filter(kernel, sigma) + + +def duf_downsample(x, kernel_size=13, scale=4): + """Downsamping with Gaussian kernel used in the DUF official code. + + Args: + x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w). + kernel_size (int): Kernel size. Default: 13. + scale (int): Downsampling factor. Supported scale: (2, 3, 4). + Default: 4. + + Returns: + Tensor: DUF downsampled frames. + """ + assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.' 
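+    # pad by kernel_size // 2 + 2 * scale per side, convolve with the Gaussian
+    # kernel at stride `scale`, then crop 2 px from each border; when h and w are
+    # divisible by `scale` the output is exactly (h // scale, w // scale)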
+ + squeeze_flag = False + if x.ndim == 4: + squeeze_flag = True + x = x.unsqueeze(0) + b, t, c, h, w = x.size() + x = x.view(-1, 1, h, w) + pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2 + x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect') + + gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale) + gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0) + x = F.conv2d(x, gaussian_filter, stride=scale) + x = x[:, :, 2:-2, 2:-2] + x = x.view(b, t, c, x.size(2), x.size(3)) + if squeeze_flag: + x = x.squeeze(0) + return x + + +def brush_stroke_mask(img, color=(255,255,255)): + min_num_vertex = 8 + max_num_vertex = 28 + mean_angle = 2*math.pi / 5 + angle_range = 2*math.pi / 12 + # training large mask ratio (training setting) + min_width = 30 + max_width = 70 + # very large mask ratio (test setting and refine after 200k) + # min_width = 80 + # max_width = 120 + def generate_mask(H, W, img=None): + average_radius = math.sqrt(H*H+W*W) / 8 + mask = Image.new('RGB', (W, H), 0) + if img is not None: mask = img # Image.fromarray(img) + + for _ in range(np.random.randint(1, 4)): + num_vertex = np.random.randint(min_num_vertex, max_num_vertex) + angle_min = mean_angle - np.random.uniform(0, angle_range) + angle_max = mean_angle + np.random.uniform(0, angle_range) + angles = [] + vertex = [] + for i in range(num_vertex): + if i % 2 == 0: + angles.append(2*math.pi - np.random.uniform(angle_min, angle_max)) + else: + angles.append(np.random.uniform(angle_min, angle_max)) + + h, w = mask.size + vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h)))) + for i in range(num_vertex): + r = np.clip( + np.random.normal(loc=average_radius, scale=average_radius//2), + 0, 2*average_radius) + new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w) + new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h) + vertex.append((int(new_x), int(new_y))) + + draw = ImageDraw.Draw(mask) + width = int(np.random.uniform(min_width, max_width)) + draw.line(vertex, fill=color, width=width) + for v in vertex: + draw.ellipse((v[0] - width//2, + v[1] - width//2, + v[0] + width//2, + v[1] + width//2), + fill=color) + + return mask + + width, height = img.size + mask = generate_mask(height, width, img) + return mask + + +def random_ff_mask(shape, max_angle = 10, max_len = 100, max_width = 70, times = 10): + """Generate a random free form mask with configuration. + Args: + config: Config should have configuration including IMG_SHAPES, + VERTICAL_MARGIN, HEIGHT, HORIZONTAL_MARGIN, WIDTH. 
+ Returns: + tuple: (top, left, height, width) + Link: + https://github.com/csqiangwen/DeepFillv2_Pytorch/blob/master/train_dataset.py + """ + height = shape[0] + width = shape[1] + mask = np.zeros((height, width), np.float32) + times = np.random.randint(times-5, times) + for i in range(times): + start_x = np.random.randint(width) + start_y = np.random.randint(height) + for j in range(1 + np.random.randint(5)): + angle = 0.01 + np.random.randint(max_angle) + if i % 2 == 0: + angle = 2 * 3.1415926 - angle + length = 10 + np.random.randint(max_len-20, max_len) + brush_w = 5 + np.random.randint(max_width-30, max_width) + end_x = (start_x + length * np.sin(angle)).astype(np.int32) + end_y = (start_y + length * np.cos(angle)).astype(np.int32) + cv2.line(mask, (start_y, start_x), (end_y, end_x), 1.0, brush_w) + start_x, start_y = end_x, end_y + return mask.astype(np.float32) \ No newline at end of file diff --git a/blissful_tuner/codeformer/basicsr/data/ffhq_blind_dataset.py b/blissful_tuner/codeformer/basicsr/data/ffhq_blind_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..94157b8f580204251e73d46cda6d8a230ae32f80 --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/data/ffhq_blind_dataset.py @@ -0,0 +1,299 @@ +import cv2 +import math +import random +import numpy as np +import os.path as osp +from scipy.io import loadmat +from PIL import Image +import torch +import torch.utils.data as data +from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, + adjust_hue, adjust_saturation, normalize) +from codeformer.basicsr.data import gaussian_kernels as gaussian_kernels +from codeformer.basicsr.data.transforms import augment +from codeformer.basicsr.data.data_util import paths_from_folder, brush_stroke_mask, random_ff_mask +from codeformer.basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor +from codeformer.basicsr.utils.registry import DATASET_REGISTRY + +@DATASET_REGISTRY.register() +class FFHQBlindDataset(data.Dataset): + + def __init__(self, opt): + super(FFHQBlindDataset, self).__init__() + logger = get_root_logger() + self.opt = opt + # file client (io backend) + self.file_client = None + self.io_backend_opt = opt['io_backend'] + + self.gt_folder = opt['dataroot_gt'] + self.gt_size = opt.get('gt_size', 512) + self.in_size = opt.get('in_size', 512) + assert self.gt_size >= self.in_size, 'Wrong setting.' 
+ + self.mean = opt.get('mean', [0.5, 0.5, 0.5]) + self.std = opt.get('std', [0.5, 0.5, 0.5]) + + self.component_path = opt.get('component_path', None) + self.latent_gt_path = opt.get('latent_gt_path', None) + + if self.component_path is not None: + self.crop_components = True + self.components_dict = torch.load(self.component_path) + self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1.4) + self.nose_enlarge_ratio = opt.get('nose_enlarge_ratio', 1.1) + self.mouth_enlarge_ratio = opt.get('mouth_enlarge_ratio', 1.3) + else: + self.crop_components = False + + if self.latent_gt_path is not None: + self.load_latent_gt = True + self.latent_gt_dict = torch.load(self.latent_gt_path) + else: + self.load_latent_gt = False + + if self.io_backend_opt['type'] == 'lmdb': + self.io_backend_opt['db_paths'] = self.gt_folder + if not self.gt_folder.endswith('.lmdb'): + raise ValueError("'dataroot_gt' should end with '.lmdb', "f'but received {self.gt_folder}') + with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin: + self.paths = [line.split('.')[0] for line in fin] + else: + self.paths = paths_from_folder(self.gt_folder) + + # inpainting mask + self.gen_inpaint_mask = opt.get('gen_inpaint_mask', False) + if self.gen_inpaint_mask: + logger.info(f'generate mask ...') + # self.mask_max_angle = opt.get('mask_max_angle', 10) + # self.mask_max_len = opt.get('mask_max_len', 150) + # self.mask_max_width = opt.get('mask_max_width', 50) + # self.mask_draw_times = opt.get('mask_draw_times', 10) + # # print + # logger.info(f'mask_max_angle: {self.mask_max_angle}') + # logger.info(f'mask_max_len: {self.mask_max_len}') + # logger.info(f'mask_max_width: {self.mask_max_width}') + # logger.info(f'mask_draw_times: {self.mask_draw_times}') + + # perform corrupt + self.use_corrupt = opt.get('use_corrupt', True) + self.use_motion_kernel = False + # self.use_motion_kernel = opt.get('use_motion_kernel', True) + + if self.use_motion_kernel: + self.motion_kernel_prob = opt.get('motion_kernel_prob', 0.001) + motion_kernel_path = opt.get('motion_kernel_path', 'basicsr/data/motion-blur-kernels-32.pth') + self.motion_kernels = torch.load(motion_kernel_path) + + if self.use_corrupt and not self.gen_inpaint_mask: + # degradation configurations + self.blur_kernel_size = opt['blur_kernel_size'] + self.blur_sigma = opt['blur_sigma'] + self.kernel_list = opt['kernel_list'] + self.kernel_prob = opt['kernel_prob'] + self.downsample_range = opt['downsample_range'] + self.noise_range = opt['noise_range'] + self.jpeg_range = opt['jpeg_range'] + # print + logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]') + logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]') + logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]') + logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]') + + # color jitter + self.color_jitter_prob = opt.get('color_jitter_prob', None) + self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob', None) + self.color_jitter_shift = opt.get('color_jitter_shift', 20) + if self.color_jitter_prob is not None: + logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}') + + # to gray + self.gray_prob = opt.get('gray_prob', 0.0) + if self.gray_prob is not None: + logger.info(f'Use random gray. Prob: {self.gray_prob}') + self.color_jitter_shift /= 255. 
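For reference, a rough sketch of the options dict this dataset consumes, wired through the `build_dataset` helper from `codeformer.basicsr.data`. The key names follow what `__init__` reads above, but the path and all degradation values below are illustrative placeholders, not the official training configuration.

```python
from codeformer.basicsr.data import build_dataset

# Hypothetical configuration: key names mirror FFHQBlindDataset.__init__,
# values are placeholders chosen only for illustration.
opt = {
    'name': 'FFHQBlind_example',
    'type': 'FFHQBlindDataset',         # resolved through DATASET_REGISTRY
    'phase': 'train',
    'dataroot_gt': 'datasets/ffhq512',  # placeholder path to the GT images
    'io_backend': {'type': 'disk'},
    'use_hflip': True,
    'gt_size': 512,
    'in_size': 512,
    'use_corrupt': True,
    # degradation settings read when use_corrupt is enabled
    'blur_kernel_size': 41,
    'blur_sigma': [0.1, 10],
    'kernel_list': ['iso', 'aniso'],
    'kernel_prob': [0.5, 0.5],
    'downsample_range': [0.8, 8],
    'noise_range': [0, 20],
    'jpeg_range': [60, 100],
}

dataset = build_dataset(opt)  # returns the registered FFHQBlindDataset
sample = dataset[0]           # yields one degraded-input / ground-truth pair
```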
+ + @staticmethod + def color_jitter(img, shift): + """jitter color: randomly jitter the RGB values, in numpy formats""" + jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32) + img = img + jitter_val + img = np.clip(img, 0, 1) + return img + + @staticmethod + def color_jitter_pt(img, brightness, contrast, saturation, hue): + """jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats""" + fn_idx = torch.randperm(4) + for fn_id in fn_idx: + if fn_id == 0 and brightness is not None: + brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item() + img = adjust_brightness(img, brightness_factor) + + if fn_id == 1 and contrast is not None: + contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item() + img = adjust_contrast(img, contrast_factor) + + if fn_id == 2 and saturation is not None: + saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item() + img = adjust_saturation(img, saturation_factor) + + if fn_id == 3 and hue is not None: + hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item() + img = adjust_hue(img, hue_factor) + return img + + + def get_component_locations(self, name, status): + components_bbox = self.components_dict[name] + if status[0]: # hflip + # exchange right and left eye + tmp = components_bbox['left_eye'] + components_bbox['left_eye'] = components_bbox['right_eye'] + components_bbox['right_eye'] = tmp + # modify the width coordinate + components_bbox['left_eye'][0] = self.gt_size - components_bbox['left_eye'][0] + components_bbox['right_eye'][0] = self.gt_size - components_bbox['right_eye'][0] + components_bbox['nose'][0] = self.gt_size - components_bbox['nose'][0] + components_bbox['mouth'][0] = self.gt_size - components_bbox['mouth'][0] + + locations_gt = {} + locations_in = {} + for part in ['left_eye', 'right_eye', 'nose', 'mouth']: + mean = components_bbox[part][0:2] + half_len = components_bbox[part][2] + if 'eye' in part: + half_len *= self.eye_enlarge_ratio + elif part == 'nose': + half_len *= self.nose_enlarge_ratio + elif part == 'mouth': + half_len *= self.mouth_enlarge_ratio + loc = np.hstack((mean - half_len + 1, mean + half_len)) + loc = torch.from_numpy(loc).float() + locations_gt[part] = loc + loc_in = loc/(self.gt_size//self.in_size) + locations_in[part] = loc_in + return locations_gt, locations_in + + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) + + # load gt image + gt_path = self.paths[index] + name = osp.basename(gt_path)[:-4] + img_bytes = self.file_client.get(gt_path) + img_gt = imfrombytes(img_bytes, float32=True) + + # random horizontal flip + img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True) + + if self.load_latent_gt: + if status[0]: + latent_gt = self.latent_gt_dict['hflip'][name] + else: + latent_gt = self.latent_gt_dict['orig'][name] + + if self.crop_components: + locations_gt, locations_in = self.get_component_locations(name, status) + + # generate in image + img_in = img_gt + if self.use_corrupt and not self.gen_inpaint_mask: + # motion blur + if self.use_motion_kernel and random.random() < self.motion_kernel_prob: + m_i = random.randint(0,31) + k = self.motion_kernels[f'{m_i:02d}'] + img_in = cv2.filter2D(img_in,-1,k) + + # gaussian blur + kernel = gaussian_kernels.random_mixed_kernels( + self.kernel_list, + self.kernel_prob, + self.blur_kernel_size, + 
self.blur_sigma, + self.blur_sigma, + [-math.pi, math.pi], + noise_range=None) + img_in = cv2.filter2D(img_in, -1, kernel) + + # downsample + scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1]) + img_in = cv2.resize(img_in, (int(self.gt_size // scale), int(self.gt_size // scale)), interpolation=cv2.INTER_LINEAR) + + # noise + if self.noise_range is not None: + noise_sigma = np.random.uniform(self.noise_range[0] / 255., self.noise_range[1] / 255.) + noise = np.float32(np.random.randn(*(img_in.shape))) * noise_sigma + img_in = img_in + noise + img_in = np.clip(img_in, 0, 1) + + # jpeg + if self.jpeg_range is not None: + jpeg_p = np.random.uniform(self.jpeg_range[0], self.jpeg_range[1]) + encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(jpeg_p)] + _, encimg = cv2.imencode('.jpg', img_in * 255., encode_param) + img_in = np.float32(cv2.imdecode(encimg, 1)) / 255. + + # resize to in_size + img_in = cv2.resize(img_in, (self.in_size, self.in_size), interpolation=cv2.INTER_LINEAR) + + # if self.gen_inpaint_mask: + # inpaint_mask = random_ff_mask(shape=(self.gt_size,self.gt_size), + # max_angle = self.mask_max_angle, max_len = self.mask_max_len, + # max_width = self.mask_max_width, times = self.mask_draw_times) + # img_in = img_in * (1 - inpaint_mask.reshape(self.gt_size,self.gt_size,1)) + \ + # 1.0 * inpaint_mask.reshape(self.gt_size,self.gt_size,1) + + # inpaint_mask = torch.from_numpy(inpaint_mask).view(1,self.gt_size,self.gt_size) + + if self.gen_inpaint_mask: + img_in = (img_in*255).astype('uint8') + img_in = brush_stroke_mask(Image.fromarray(img_in)) + img_in = np.array(img_in) / 255. + + # random color jitter (only for lq) + if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob): + img_in = self.color_jitter(img_in, self.color_jitter_shift) + # random to gray (only for lq) + if self.gray_prob and np.random.uniform() < self.gray_prob: + img_in = cv2.cvtColor(img_in, cv2.COLOR_BGR2GRAY) + img_in = np.tile(img_in[:, :, None], [1, 1, 3]) + + # BGR to RGB, HWC to CHW, numpy to tensor + img_in, img_gt = img2tensor([img_in, img_gt], bgr2rgb=True, float32=True) + + # random color jitter (pytorch version) (only for lq) + if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob): + brightness = self.opt.get('brightness', (0.5, 1.5)) + contrast = self.opt.get('contrast', (0.5, 1.5)) + saturation = self.opt.get('saturation', (0, 1.5)) + hue = self.opt.get('hue', (-0.1, 0.1)) + img_in = self.color_jitter_pt(img_in, brightness, contrast, saturation, hue) + + # round and clip + img_in = np.clip((img_in * 255.0).round(), 0, 255) / 255. 
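+        # The line above quantizes the degraded tensor to 8-bit levels
+        # (x255 -> round -> clip -> /255) so it matches what a real uint8
+        # image would contain before normalization.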
+ + # Set vgg range_norm=True if use the normalization here + # normalize + normalize(img_in, self.mean, self.std, inplace=True) + normalize(img_gt, self.mean, self.std, inplace=True) + + return_dict = {'in': img_in, 'gt': img_gt, 'gt_path': gt_path} + + if self.crop_components: + return_dict['locations_in'] = locations_in + return_dict['locations_gt'] = locations_gt + + if self.load_latent_gt: + return_dict['latent_gt'] = latent_gt + + # if self.gen_inpaint_mask: + # return_dict['inpaint_mask'] = inpaint_mask + + return return_dict + + + def __len__(self): + return len(self.paths) \ No newline at end of file diff --git a/blissful_tuner/codeformer/basicsr/data/ffhq_blind_joint_dataset.py b/blissful_tuner/codeformer/basicsr/data/ffhq_blind_joint_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d8e8e0c1b1a128c74b40071a2f076969bc3f7784 --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/data/ffhq_blind_joint_dataset.py @@ -0,0 +1,324 @@ +import cv2 +import math +import random +import numpy as np +import os.path as osp +from scipy.io import loadmat +import torch +import torch.utils.data as data +from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, + adjust_hue, adjust_saturation, normalize) +from codeformer.basicsr.data import gaussian_kernels as gaussian_kernels +from codeformer.basicsr.data.transforms import augment +from codeformer.basicsr.data.data_util import paths_from_folder +from codeformer.basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor +from codeformer.basicsr.utils.registry import DATASET_REGISTRY + +@DATASET_REGISTRY.register() +class FFHQBlindJointDataset(data.Dataset): + + def __init__(self, opt): + super(FFHQBlindJointDataset, self).__init__() + logger = get_root_logger() + self.opt = opt + # file client (io backend) + self.file_client = None + self.io_backend_opt = opt['io_backend'] + + self.gt_folder = opt['dataroot_gt'] + self.gt_size = opt.get('gt_size', 512) + self.in_size = opt.get('in_size', 512) + assert self.gt_size >= self.in_size, 'Wrong setting.' 
+ + self.mean = opt.get('mean', [0.5, 0.5, 0.5]) + self.std = opt.get('std', [0.5, 0.5, 0.5]) + + self.component_path = opt.get('component_path', None) + self.latent_gt_path = opt.get('latent_gt_path', None) + + if self.component_path is not None: + self.crop_components = True + self.components_dict = torch.load(self.component_path) + self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1.4) + self.nose_enlarge_ratio = opt.get('nose_enlarge_ratio', 1.1) + self.mouth_enlarge_ratio = opt.get('mouth_enlarge_ratio', 1.3) + else: + self.crop_components = False + + if self.latent_gt_path is not None: + self.load_latent_gt = True + self.latent_gt_dict = torch.load(self.latent_gt_path) + else: + self.load_latent_gt = False + + if self.io_backend_opt['type'] == 'lmdb': + self.io_backend_opt['db_paths'] = self.gt_folder + if not self.gt_folder.endswith('.lmdb'): + raise ValueError("'dataroot_gt' should end with '.lmdb', "f'but received {self.gt_folder}') + with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin: + self.paths = [line.split('.')[0] for line in fin] + else: + self.paths = paths_from_folder(self.gt_folder) + + # perform corrupt + self.use_corrupt = opt.get('use_corrupt', True) + self.use_motion_kernel = False + # self.use_motion_kernel = opt.get('use_motion_kernel', True) + + if self.use_motion_kernel: + self.motion_kernel_prob = opt.get('motion_kernel_prob', 0.001) + motion_kernel_path = opt.get('motion_kernel_path', 'basicsr/data/motion-blur-kernels-32.pth') + self.motion_kernels = torch.load(motion_kernel_path) + + if self.use_corrupt: + # degradation configurations + self.blur_kernel_size = self.opt['blur_kernel_size'] + self.kernel_list = self.opt['kernel_list'] + self.kernel_prob = self.opt['kernel_prob'] + # Small degradation + self.blur_sigma = self.opt['blur_sigma'] + self.downsample_range = self.opt['downsample_range'] + self.noise_range = self.opt['noise_range'] + self.jpeg_range = self.opt['jpeg_range'] + # Large degradation + self.blur_sigma_large = self.opt['blur_sigma_large'] + self.downsample_range_large = self.opt['downsample_range_large'] + self.noise_range_large = self.opt['noise_range_large'] + self.jpeg_range_large = self.opt['jpeg_range_large'] + + # print + logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]') + logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]') + logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]') + logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]') + + # color jitter + self.color_jitter_prob = opt.get('color_jitter_prob', None) + self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob', None) + self.color_jitter_shift = opt.get('color_jitter_shift', 20) + if self.color_jitter_prob is not None: + logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}') + + # to gray + self.gray_prob = opt.get('gray_prob', 0.0) + if self.gray_prob is not None: + logger.info(f'Use random gray. Prob: {self.gray_prob}') + self.color_jitter_shift /= 255. 
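+    # Compared with FFHQBlindDataset, this joint variant builds two degraded
+    # versions of every GT image in __getitem__: 'in' uses the small-degradation
+    # ranges (blur_sigma, downsample_range, noise_range, jpeg_range) while
+    # 'in_large_de' uses the corresponding *_large ranges, so both severities
+    # are returned for the same sample.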
+ + @staticmethod + def color_jitter(img, shift): + """jitter color: randomly jitter the RGB values, in numpy formats""" + jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32) + img = img + jitter_val + img = np.clip(img, 0, 1) + return img + + @staticmethod + def color_jitter_pt(img, brightness, contrast, saturation, hue): + """jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats""" + fn_idx = torch.randperm(4) + for fn_id in fn_idx: + if fn_id == 0 and brightness is not None: + brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item() + img = adjust_brightness(img, brightness_factor) + + if fn_id == 1 and contrast is not None: + contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item() + img = adjust_contrast(img, contrast_factor) + + if fn_id == 2 and saturation is not None: + saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item() + img = adjust_saturation(img, saturation_factor) + + if fn_id == 3 and hue is not None: + hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item() + img = adjust_hue(img, hue_factor) + return img + + + def get_component_locations(self, name, status): + components_bbox = self.components_dict[name] + if status[0]: # hflip + # exchange right and left eye + tmp = components_bbox['left_eye'] + components_bbox['left_eye'] = components_bbox['right_eye'] + components_bbox['right_eye'] = tmp + # modify the width coordinate + components_bbox['left_eye'][0] = self.gt_size - components_bbox['left_eye'][0] + components_bbox['right_eye'][0] = self.gt_size - components_bbox['right_eye'][0] + components_bbox['nose'][0] = self.gt_size - components_bbox['nose'][0] + components_bbox['mouth'][0] = self.gt_size - components_bbox['mouth'][0] + + locations_gt = {} + locations_in = {} + for part in ['left_eye', 'right_eye', 'nose', 'mouth']: + mean = components_bbox[part][0:2] + half_len = components_bbox[part][2] + if 'eye' in part: + half_len *= self.eye_enlarge_ratio + elif part == 'nose': + half_len *= self.nose_enlarge_ratio + elif part == 'mouth': + half_len *= self.mouth_enlarge_ratio + loc = np.hstack((mean - half_len + 1, mean + half_len)) + loc = torch.from_numpy(loc).float() + locations_gt[part] = loc + loc_in = loc/(self.gt_size//self.in_size) + locations_in[part] = loc_in + return locations_gt, locations_in + + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) + + # load gt image + gt_path = self.paths[index] + name = osp.basename(gt_path)[:-4] + img_bytes = self.file_client.get(gt_path) + img_gt = imfrombytes(img_bytes, float32=True) + + # random horizontal flip + img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True) + + if self.load_latent_gt: + if status[0]: + latent_gt = self.latent_gt_dict['hflip'][name] + else: + latent_gt = self.latent_gt_dict['orig'][name] + + if self.crop_components: + locations_gt, locations_in = self.get_component_locations(name, status) + + # generate in image + img_in = img_gt + if self.use_corrupt: + # motion blur + if self.use_motion_kernel and random.random() < self.motion_kernel_prob: + m_i = random.randint(0,31) + k = self.motion_kernels[f'{m_i:02d}'] + img_in = cv2.filter2D(img_in,-1,k) + + # gaussian blur + kernel = gaussian_kernels.random_mixed_kernels( + self.kernel_list, + self.kernel_prob, + self.blur_kernel_size, + self.blur_sigma, + 
self.blur_sigma, + [-math.pi, math.pi], + noise_range=None) + img_in = cv2.filter2D(img_in, -1, kernel) + + # downsample + scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1]) + img_in = cv2.resize(img_in, (int(self.gt_size // scale), int(self.gt_size // scale)), interpolation=cv2.INTER_LINEAR) + + # noise + if self.noise_range is not None: + noise_sigma = np.random.uniform(self.noise_range[0] / 255., self.noise_range[1] / 255.) + noise = np.float32(np.random.randn(*(img_in.shape))) * noise_sigma + img_in = img_in + noise + img_in = np.clip(img_in, 0, 1) + + # jpeg + if self.jpeg_range is not None: + jpeg_p = np.random.uniform(self.jpeg_range[0], self.jpeg_range[1]) + encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(jpeg_p)] + _, encimg = cv2.imencode('.jpg', img_in * 255., encode_param) + img_in = np.float32(cv2.imdecode(encimg, 1)) / 255. + + # resize to in_size + img_in = cv2.resize(img_in, (self.in_size, self.in_size), interpolation=cv2.INTER_LINEAR) + + + # generate in_large with large degradation + img_in_large = img_gt + + if self.use_corrupt: + # motion blur + if self.use_motion_kernel and random.random() < self.motion_kernel_prob: + m_i = random.randint(0,31) + k = self.motion_kernels[f'{m_i:02d}'] + img_in_large = cv2.filter2D(img_in_large,-1,k) + + # gaussian blur + kernel = gaussian_kernels.random_mixed_kernels( + self.kernel_list, + self.kernel_prob, + self.blur_kernel_size, + self.blur_sigma_large, + self.blur_sigma_large, + [-math.pi, math.pi], + noise_range=None) + img_in_large = cv2.filter2D(img_in_large, -1, kernel) + + # downsample + scale = np.random.uniform(self.downsample_range_large[0], self.downsample_range_large[1]) + img_in_large = cv2.resize(img_in_large, (int(self.gt_size // scale), int(self.gt_size // scale)), interpolation=cv2.INTER_LINEAR) + + # noise + if self.noise_range_large is not None: + noise_sigma = np.random.uniform(self.noise_range_large[0] / 255., self.noise_range_large[1] / 255.) + noise = np.float32(np.random.randn(*(img_in_large.shape))) * noise_sigma + img_in_large = img_in_large + noise + img_in_large = np.clip(img_in_large, 0, 1) + + # jpeg + if self.jpeg_range_large is not None: + jpeg_p = np.random.uniform(self.jpeg_range_large[0], self.jpeg_range_large[1]) + encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(jpeg_p)] + _, encimg = cv2.imencode('.jpg', img_in_large * 255., encode_param) + img_in_large = np.float32(cv2.imdecode(encimg, 1)) / 255. 
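+            # Same blur -> downsample -> noise -> JPEG chain as for img_in above,
+            # but driven by the *_large ranges to produce a more heavily degraded copy.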
+ + # resize to in_size + img_in_large = cv2.resize(img_in_large, (self.in_size, self.in_size), interpolation=cv2.INTER_LINEAR) + + + # random color jitter (only for lq) + if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob): + img_in = self.color_jitter(img_in, self.color_jitter_shift) + img_in_large = self.color_jitter(img_in_large, self.color_jitter_shift) + # random to gray (only for lq) + if self.gray_prob and np.random.uniform() < self.gray_prob: + img_in = cv2.cvtColor(img_in, cv2.COLOR_BGR2GRAY) + img_in = np.tile(img_in[:, :, None], [1, 1, 3]) + img_in_large = cv2.cvtColor(img_in_large, cv2.COLOR_BGR2GRAY) + img_in_large = np.tile(img_in_large[:, :, None], [1, 1, 3]) + + # BGR to RGB, HWC to CHW, numpy to tensor + img_in, img_in_large, img_gt = img2tensor([img_in, img_in_large, img_gt], bgr2rgb=True, float32=True) + + # random color jitter (pytorch version) (only for lq) + if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob): + brightness = self.opt.get('brightness', (0.5, 1.5)) + contrast = self.opt.get('contrast', (0.5, 1.5)) + saturation = self.opt.get('saturation', (0, 1.5)) + hue = self.opt.get('hue', (-0.1, 0.1)) + img_in = self.color_jitter_pt(img_in, brightness, contrast, saturation, hue) + img_in_large = self.color_jitter_pt(img_in_large, brightness, contrast, saturation, hue) + + # round and clip + img_in = np.clip((img_in * 255.0).round(), 0, 255) / 255. + img_in_large = np.clip((img_in_large * 255.0).round(), 0, 255) / 255. + + # Set vgg range_norm=True if use the normalization here + # normalize + normalize(img_in, self.mean, self.std, inplace=True) + normalize(img_in_large, self.mean, self.std, inplace=True) + normalize(img_gt, self.mean, self.std, inplace=True) + + return_dict = {'in': img_in, 'in_large_de': img_in_large, 'gt': img_gt, 'gt_path': gt_path} + + if self.crop_components: + return_dict['locations_in'] = locations_in + return_dict['locations_gt'] = locations_gt + + if self.load_latent_gt: + return_dict['latent_gt'] = latent_gt + + return return_dict + + + def __len__(self): + return len(self.paths) diff --git a/blissful_tuner/codeformer/basicsr/data/gaussian_kernels.py b/blissful_tuner/codeformer/basicsr/data/gaussian_kernels.py new file mode 100644 index 0000000000000000000000000000000000000000..0ce57f0ae52bb4efce9212dd09960ac9c7358c3a --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/data/gaussian_kernels.py @@ -0,0 +1,690 @@ +import math +import numpy as np +import random +from scipy.ndimage.interpolation import shift +from scipy.stats import multivariate_normal + + +def sigma_matrix2(sig_x, sig_y, theta): + """Calculate the rotated sigma matrix (two dimensional matrix). + Args: + sig_x (float): + sig_y (float): + theta (float): Radian measurement. + Returns: + ndarray: Rotated sigma matrix. + """ + D = np.array([[sig_x**2, 0], [0, sig_y**2]]) + U = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + return np.dot(U, np.dot(D, U.T)) + + +def mesh_grid(kernel_size): + """Generate the mesh grid, centering at zero. + Args: + kernel_size (int): + Returns: + xy (ndarray): with the shape (kernel_size, kernel_size, 2) + xx (ndarray): with the shape (kernel_size, kernel_size) + yy (ndarray): with the shape (kernel_size, kernel_size) + """ + ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.) 
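+    # e.g. kernel_size=21 gives ax = [-10., ..., 10.], so the grid below is
+    # centered on zero as described in the docstring.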
+ xx, yy = np.meshgrid(ax, ax) + xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)), + yy.reshape(kernel_size * kernel_size, + 1))).reshape(kernel_size, kernel_size, 2) + return xy, xx, yy + + +def pdf2(sigma_matrix, grid): + """Calculate PDF of the bivariate Gaussian distribution. + Args: + sigma_matrix (ndarray): with the shape (2, 2) + grid (ndarray): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. + Returns: + kernel (ndarrray): un-normalized kernel. + """ + inverse_sigma = np.linalg.inv(sigma_matrix) + kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2)) + return kernel + + +def cdf2(D, grid): + """Calculate the CDF of the standard bivariate Gaussian distribution. + Used in skewed Gaussian distribution. + Args: + D (ndarrasy): skew matrix. + grid (ndarray): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. + Returns: + cdf (ndarray): skewed cdf. + """ + rv = multivariate_normal([0, 0], [[1, 0], [0, 1]]) + grid = np.dot(grid, D) + cdf = rv.cdf(grid) + return cdf + + +def bivariate_skew_Gaussian(kernel_size, sig_x, sig_y, theta, D, grid=None): + """Generate a bivariate skew Gaussian kernel. + Described in `A multivariate skew normal distribution`_ by Shi et. al (2004). + Args: + kernel_size (int): + sig_x (float): + sig_y (float): + theta (float): Radian measurement. + D (ndarrasy): skew matrix. + grid (ndarray, optional): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. Default: None + Returns: + kernel (ndarray): normalized kernel. + .. _A multivariate skew normal distribution: + https://www.sciencedirect.com/science/article/pii/S0047259X03001313 + """ + if grid is None: + grid, _, _ = mesh_grid(kernel_size) + sigma_matrix = sigma_matrix2(sig_x, sig_y, theta) + pdf = pdf2(sigma_matrix, grid) + cdf = cdf2(D, grid) + kernel = pdf * cdf + kernel = kernel / np.sum(kernel) + return kernel + + +def mass_center_shift(kernel_size, kernel): + """Calculate the shift of the mass center of a kenrel. + Args: + kernel_size (int): + kernel (ndarray): normalized kernel. + Returns: + delta_h (float): + delta_w (float): + """ + ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.) + col_sum, row_sum = np.sum(kernel, axis=0), np.sum(kernel, axis=1) + delta_h = np.dot(row_sum, ax) + delta_w = np.dot(col_sum, ax) + return delta_h, delta_w + + +def bivariate_skew_Gaussian_center(kernel_size, + sig_x, + sig_y, + theta, + D, + grid=None): + """Generate a bivariate skew Gaussian kernel at center. Shift with nearest padding. + Args: + kernel_size (int): + sig_x (float): + sig_y (float): + theta (float): Radian measurement. + D (ndarrasy): skew matrix. + grid (ndarray, optional): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. Default: None + Returns: + kernel (ndarray): centered and normalized kernel. + """ + if grid is None: + grid, _, _ = mesh_grid(kernel_size) + kernel = bivariate_skew_Gaussian(kernel_size, sig_x, sig_y, theta, D, grid) + delta_h, delta_w = mass_center_shift(kernel_size, kernel) + kernel = shift(kernel, [-delta_h, -delta_w], mode='nearest') + kernel = kernel / np.sum(kernel) + return kernel + + +def bivariate_anisotropic_Gaussian(kernel_size, + sig_x, + sig_y, + theta, + grid=None): + """Generate a bivariate anisotropic Gaussian kernel. + Args: + kernel_size (int): + sig_x (float): + sig_y (float): + theta (float): Radian measurement. 
+ grid (ndarray, optional): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. Default: None + Returns: + kernel (ndarray): normalized kernel. + """ + if grid is None: + grid, _, _ = mesh_grid(kernel_size) + sigma_matrix = sigma_matrix2(sig_x, sig_y, theta) + kernel = pdf2(sigma_matrix, grid) + kernel = kernel / np.sum(kernel) + return kernel + + +def bivariate_isotropic_Gaussian(kernel_size, sig, grid=None): + """Generate a bivariate isotropic Gaussian kernel. + Args: + kernel_size (int): + sig (float): + grid (ndarray, optional): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. Default: None + Returns: + kernel (ndarray): normalized kernel. + """ + if grid is None: + grid, _, _ = mesh_grid(kernel_size) + sigma_matrix = np.array([[sig**2, 0], [0, sig**2]]) + kernel = pdf2(sigma_matrix, grid) + kernel = kernel / np.sum(kernel) + return kernel + + +def bivariate_generalized_Gaussian(kernel_size, + sig_x, + sig_y, + theta, + beta, + grid=None): + """Generate a bivariate generalized Gaussian kernel. + Described in `Parameter Estimation For Multivariate Generalized Gaussian Distributions`_ + by Pascal et. al (2013). + Args: + kernel_size (int): + sig_x (float): + sig_y (float): + theta (float): Radian measurement. + beta (float): shape parameter, beta = 1 is the normal distribution. + grid (ndarray, optional): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. Default: None + Returns: + kernel (ndarray): normalized kernel. + .. _Parameter Estimation For Multivariate Generalized Gaussian Distributions: + https://arxiv.org/abs/1302.6498 + """ + if grid is None: + grid, _, _ = mesh_grid(kernel_size) + sigma_matrix = sigma_matrix2(sig_x, sig_y, theta) + inverse_sigma = np.linalg.inv(sigma_matrix) + kernel = np.exp( + -0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta)) + kernel = kernel / np.sum(kernel) + return kernel + + +def bivariate_plateau_type1(kernel_size, sig_x, sig_y, theta, beta, grid=None): + """Generate a plateau-like anisotropic kernel. + 1 / (1+x^(beta)) + Args: + kernel_size (int): + sig_x (float): + sig_y (float): + theta (float): Radian measurement. + beta (float): shape parameter, beta = 1 is the normal distribution. + grid (ndarray, optional): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. Default: None + Returns: + kernel (ndarray): normalized kernel. + """ + if grid is None: + grid, _, _ = mesh_grid(kernel_size) + sigma_matrix = sigma_matrix2(sig_x, sig_y, theta) + inverse_sigma = np.linalg.inv(sigma_matrix) + kernel = np.reciprocal( + np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1) + kernel = kernel / np.sum(kernel) + return kernel + + +def bivariate_plateau_type1_iso(kernel_size, sig, beta, grid=None): + """Generate a plateau-like isotropic kernel. + 1 / (1+x^(beta)) + Args: + kernel_size (int): + sig (float): + beta (float): shape parameter, beta = 1 is the normal distribution. + grid (ndarray, optional): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. Default: None + Returns: + kernel (ndarray): normalized kernel. 
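+    Example (illustrative):
+        >>> k = bivariate_plateau_type1_iso(21, sig=2.0, beta=2.0)
+        >>> k.shape
+        (21, 21)
+        >>> round(float(k.sum()), 6)
+        1.0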
+ """ + if grid is None: + grid, _, _ = mesh_grid(kernel_size) + sigma_matrix = np.array([[sig**2, 0], [0, sig**2]]) + inverse_sigma = np.linalg.inv(sigma_matrix) + kernel = np.reciprocal( + np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1) + kernel = kernel / np.sum(kernel) + return kernel + + +def random_bivariate_skew_Gaussian_center(kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + noise_range=None, + strict=False): + """Randomly generate bivariate skew Gaussian kernels at center. + Args: + kernel_size (int): + sigma_x_range (tuple): [0.6, 5] + sigma_y_range (tuple): [0.6, 5] + rotation range (tuple): [-math.pi, math.pi] + noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None + Returns: + kernel (ndarray): + """ + assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' + assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.' + assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.' + assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.' + sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1]) + sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1]) + if strict: + sigma_max = np.max([sigma_x, sigma_y]) + sigma_min = np.min([sigma_x, sigma_y]) + sigma_x, sigma_y = sigma_max, sigma_min + rotation = np.random.uniform(rotation_range[0], rotation_range[1]) + + sigma_max = np.max([sigma_x, sigma_y]) + thres = 3 / sigma_max + D = [[np.random.uniform(-thres, thres), + np.random.uniform(-thres, thres)], + [np.random.uniform(-thres, thres), + np.random.uniform(-thres, thres)]] + + kernel = bivariate_skew_Gaussian_center(kernel_size, sigma_x, sigma_y, + rotation, D) + + # add multiplicative noise + if noise_range is not None: + assert noise_range[0] < noise_range[1], 'Wrong noise range.' + noise = np.random.uniform( + noise_range[0], noise_range[1], size=kernel.shape) + kernel = kernel * noise + kernel = kernel / np.sum(kernel) + if strict: + return kernel, sigma_x, sigma_y, rotation, D + else: + return kernel + + +def random_bivariate_anisotropic_Gaussian(kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + noise_range=None, + strict=False): + """Randomly generate bivariate anisotropic Gaussian kernels. + Args: + kernel_size (int): + sigma_x_range (tuple): [0.6, 5] + sigma_y_range (tuple): [0.6, 5] + rotation range (tuple): [-math.pi, math.pi] + noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None + Returns: + kernel (ndarray): + """ + assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' + assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.' + assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.' + assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.' + sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1]) + sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1]) + if strict: + sigma_max = np.max([sigma_x, sigma_y]) + sigma_min = np.min([sigma_x, sigma_y]) + sigma_x, sigma_y = sigma_max, sigma_min + rotation = np.random.uniform(rotation_range[0], rotation_range[1]) + + kernel = bivariate_anisotropic_Gaussian(kernel_size, sigma_x, sigma_y, + rotation) + + # add multiplicative noise + if noise_range is not None: + assert noise_range[0] < noise_range[1], 'Wrong noise range.' 
+ noise = np.random.uniform( + noise_range[0], noise_range[1], size=kernel.shape) + kernel = kernel * noise + kernel = kernel / np.sum(kernel) + if strict: + return kernel, sigma_x, sigma_y, rotation + else: + return kernel + + +def random_bivariate_isotropic_Gaussian(kernel_size, + sigma_range, + noise_range=None, + strict=False): + """Randomly generate bivariate isotropic Gaussian kernels. + Args: + kernel_size (int): + sigma_range (tuple): [0.6, 5] + noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None + Returns: + kernel (ndarray): + """ + assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' + assert sigma_range[0] < sigma_range[1], 'Wrong sigma_x_range.' + sigma = np.random.uniform(sigma_range[0], sigma_range[1]) + + kernel = bivariate_isotropic_Gaussian(kernel_size, sigma) + + # add multiplicative noise + if noise_range is not None: + assert noise_range[0] < noise_range[1], 'Wrong noise range.' + noise = np.random.uniform( + noise_range[0], noise_range[1], size=kernel.shape) + kernel = kernel * noise + kernel = kernel / np.sum(kernel) + if strict: + return kernel, sigma + else: + return kernel + + +def random_bivariate_generalized_Gaussian(kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + beta_range, + noise_range=None, + strict=False): + """Randomly generate bivariate generalized Gaussian kernels. + Args: + kernel_size (int): + sigma_x_range (tuple): [0.6, 5] + sigma_y_range (tuple): [0.6, 5] + rotation range (tuple): [-math.pi, math.pi] + beta_range (tuple): [0.5, 8] + noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None + Returns: + kernel (ndarray): + """ + assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' + assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.' + assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.' + assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.' + sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1]) + sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1]) + if strict: + sigma_max = np.max([sigma_x, sigma_y]) + sigma_min = np.min([sigma_x, sigma_y]) + sigma_x, sigma_y = sigma_max, sigma_min + rotation = np.random.uniform(rotation_range[0], rotation_range[1]) + if np.random.uniform() < 0.5: + beta = np.random.uniform(beta_range[0], 1) + else: + beta = np.random.uniform(1, beta_range[1]) + + kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y, + rotation, beta) + + # add multiplicative noise + if noise_range is not None: + assert noise_range[0] < noise_range[1], 'Wrong noise range.' + noise = np.random.uniform( + noise_range[0], noise_range[1], size=kernel.shape) + kernel = kernel * noise + kernel = kernel / np.sum(kernel) + if strict: + return kernel, sigma_x, sigma_y, rotation, beta + else: + return kernel + + +def random_bivariate_plateau_type1(kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + beta_range, + noise_range=None, + strict=False): + """Randomly generate bivariate plateau type1 kernels. + Args: + kernel_size (int): + sigma_x_range (tuple): [0.6, 5] + sigma_y_range (tuple): [0.6, 5] + rotation range (tuple): [-math.pi/2, math.pi/2] + beta_range (tuple): [1, 4] + noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None + Returns: + kernel (ndarray): + """ + assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' + assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.' 
+ assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.' + assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.' + sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1]) + sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1]) + if strict: + sigma_max = np.max([sigma_x, sigma_y]) + sigma_min = np.min([sigma_x, sigma_y]) + sigma_x, sigma_y = sigma_max, sigma_min + rotation = np.random.uniform(rotation_range[0], rotation_range[1]) + if np.random.uniform() < 0.5: + beta = np.random.uniform(beta_range[0], 1) + else: + beta = np.random.uniform(1, beta_range[1]) + + kernel = bivariate_plateau_type1(kernel_size, sigma_x, sigma_y, rotation, + beta) + + # add multiplicative noise + if noise_range is not None: + assert noise_range[0] < noise_range[1], 'Wrong noise range.' + noise = np.random.uniform( + noise_range[0], noise_range[1], size=kernel.shape) + kernel = kernel * noise + kernel = kernel / np.sum(kernel) + if strict: + return kernel, sigma_x, sigma_y, rotation, beta + else: + return kernel + + +def random_bivariate_plateau_type1_iso(kernel_size, + sigma_range, + beta_range, + noise_range=None, + strict=False): + """Randomly generate bivariate plateau type1 kernels (iso). + Args: + kernel_size (int): + sigma_range (tuple): [0.6, 5] + beta_range (tuple): [1, 4] + noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None + Returns: + kernel (ndarray): + """ + assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' + assert sigma_range[0] < sigma_range[1], 'Wrong sigma_x_range.' + sigma = np.random.uniform(sigma_range[0], sigma_range[1]) + beta = np.random.uniform(beta_range[0], beta_range[1]) + + kernel = bivariate_plateau_type1_iso(kernel_size, sigma, beta) + + # add multiplicative noise + if noise_range is not None: + assert noise_range[0] < noise_range[1], 'Wrong noise range.' + noise = np.random.uniform( + noise_range[0], noise_range[1], size=kernel.shape) + kernel = kernel * noise + kernel = kernel / np.sum(kernel) + if strict: + return kernel, sigma, beta + else: + return kernel + + +def random_mixed_kernels(kernel_list, + kernel_prob, + kernel_size=21, + sigma_x_range=[0.6, 5], + sigma_y_range=[0.6, 5], + rotation_range=[-math.pi, math.pi], + beta_range=[0.5, 8], + noise_range=None): + """Randomly generate mixed kernels. + Args: + kernel_list (tuple): a list name of kenrel types, + support ['iso', 'aniso', 'skew', 'generalized', 'plateau_iso', 'plateau_aniso'] + kernel_prob (tuple): corresponding kernel probability for each kernel type + kernel_size (int): + sigma_x_range (tuple): [0.6, 5] + sigma_y_range (tuple): [0.6, 5] + rotation range (tuple): [-math.pi, math.pi] + beta_range (tuple): [0.5, 8] + noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. 
Default: None + Returns: + kernel (ndarray): + """ + kernel_type = random.choices(kernel_list, kernel_prob)[0] + if kernel_type == 'iso': + kernel = random_bivariate_isotropic_Gaussian( + kernel_size, sigma_x_range, noise_range=noise_range) + elif kernel_type == 'aniso': + kernel = random_bivariate_anisotropic_Gaussian( + kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + noise_range=noise_range) + elif kernel_type == 'skew': + kernel = random_bivariate_skew_Gaussian_center( + kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + noise_range=noise_range) + elif kernel_type == 'generalized': + kernel = random_bivariate_generalized_Gaussian( + kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + beta_range, + noise_range=noise_range) + elif kernel_type == 'plateau_iso': + kernel = random_bivariate_plateau_type1_iso( + kernel_size, sigma_x_range, beta_range, noise_range=noise_range) + elif kernel_type == 'plateau_aniso': + kernel = random_bivariate_plateau_type1( + kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + beta_range, + noise_range=noise_range) + # add multiplicative noise + if noise_range is not None: + assert noise_range[0] < noise_range[1], 'Wrong noise range.' + noise = np.random.uniform( + noise_range[0], noise_range[1], size=kernel.shape) + kernel = kernel * noise + kernel = kernel / np.sum(kernel) + return kernel + + +def show_one_kernel(): + import matplotlib.pyplot as plt + kernel_size = 21 + + # bivariate skew Gaussian + D = [[0, 0], [0, 0]] + D = [[3 / 4, 0], [0, 0.5]] + kernel = bivariate_skew_Gaussian_center(kernel_size, 2, 4, -math.pi / 4, D) + # bivariate anisotropic Gaussian + kernel = bivariate_anisotropic_Gaussian(kernel_size, 2, 4, -math.pi / 4) + # bivariate anisotropic Gaussian + kernel = bivariate_isotropic_Gaussian(kernel_size, 1) + # bivariate generalized Gaussian + kernel = bivariate_generalized_Gaussian( + kernel_size, 2, 4, -math.pi / 4, beta=4) + + delta_h, delta_w = mass_center_shift(kernel_size, kernel) + print(delta_h, delta_w) + + fig, axs = plt.subplots(nrows=2, ncols=2) + # axs.set_axis_off() + ax = axs[0][0] + im = ax.matshow(kernel, cmap='jet', origin='upper') + fig.colorbar(im, ax=ax) + + # image + ax = axs[0][1] + kernel_vis = kernel - np.min(kernel) + kernel_vis = kernel_vis / np.max(kernel_vis) * 255. 
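+    # min-max scaled to [0, 255] purely for visualization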
+ ax.imshow(kernel_vis, interpolation='nearest') + + _, xx, yy = mesh_grid(kernel_size) + # contour + ax = axs[1][0] + CS = ax.contour(xx, yy, kernel, origin='upper') + ax.clabel(CS, inline=1, fontsize=3) + + # contourf + ax = axs[1][1] + kernel = kernel / np.max(kernel) + p = ax.contourf( + xx, yy, kernel, origin='upper', levels=np.linspace(-0.05, 1.05, 10)) + fig.colorbar(p) + + plt.show() + + +def show_plateau_kernel(): + import matplotlib.pyplot as plt + kernel_size = 21 + + kernel = plateau_type1(kernel_size, 2, 4, -math.pi / 8, 2, grid=None) + kernel_norm = bivariate_isotropic_Gaussian(kernel_size, 5) + kernel_gau = bivariate_generalized_Gaussian( + kernel_size, 2, 4, -math.pi / 8, 2, grid=None) + delta_h, delta_w = mass_center_shift(kernel_size, kernel) + print(delta_h, delta_w) + + # kernel_slice = kernel[10, :] + # kernel_gau_slice = kernel_gau[10, :] + # kernel_norm_slice = kernel_norm[10, :] + # fig, ax = plt.subplots() + # t = list(range(1, 22)) + + # ax.plot(t, kernel_gau_slice) + # ax.plot(t, kernel_slice) + # ax.plot(t, kernel_norm_slice) + + # t = np.arange(0, 10, 0.1) + # y = np.exp(-0.5 * t) + # y2 = np.reciprocal(1 + t) + # print(t.shape) + # print(y.shape) + # ax.plot(t, y) + # ax.plot(t, y2) + # plt.show() + + fig, axs = plt.subplots(nrows=2, ncols=2) + # axs.set_axis_off() + ax = axs[0][0] + im = ax.matshow(kernel, cmap='jet', origin='upper') + fig.colorbar(im, ax=ax) + + # image + ax = axs[0][1] + kernel_vis = kernel - np.min(kernel) + kernel_vis = kernel_vis / np.max(kernel_vis) * 255. + ax.imshow(kernel_vis, interpolation='nearest') + + _, xx, yy = mesh_grid(kernel_size) + # contour + ax = axs[1][0] + CS = ax.contour(xx, yy, kernel, origin='upper') + ax.clabel(CS, inline=1, fontsize=3) + + # contourf + ax = axs[1][1] + kernel = kernel / np.max(kernel) + p = ax.contourf( + xx, yy, kernel, origin='upper', levels=np.linspace(-0.05, 1.05, 10)) + fig.colorbar(p) + + plt.show() diff --git a/blissful_tuner/codeformer/basicsr/data/paired_image_dataset.py b/blissful_tuner/codeformer/basicsr/data/paired_image_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..6e38f598f0f4da77aa0fa3ac7ebb460dca6064de --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/data/paired_image_dataset.py @@ -0,0 +1,101 @@ +from torch.utils import data as data +from torchvision.transforms.functional import normalize + +from codeformer.basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file +from codeformer.basicsr.data.transforms import augment, paired_random_crop +from codeformer.basicsr.utils import FileClient, imfrombytes, img2tensor +from codeformer.basicsr.utils.registry import DATASET_REGISTRY + + +@DATASET_REGISTRY.register() +class PairedImageDataset(data.Dataset): + """Paired image dataset for image restoration. + + Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and + GT image pairs. + + There are three modes: + 1. 'lmdb': Use lmdb files. + If opt['io_backend'] == lmdb. + 2. 'meta_info_file': Use meta information file to generate paths. + If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None. + 3. 'folder': Scan folders to generate paths. + The rest. + + Args: + opt (dict): Config for train datasets. It contains the following keys: + dataroot_gt (str): Data root path for gt. + dataroot_lq (str): Data root path for lq. + meta_info_file (str): Path for meta information file. + io_backend (dict): IO backend type and other kwarg. 
+ filename_tmpl (str): Template for each filename. Note that the + template excludes the file extension. Default: '{}'. + gt_size (int): Cropped patched size for gt patches. + use_flip (bool): Use horizontal flips. + use_rot (bool): Use rotation (use vertical flip and transposing h + and w for implementation). + + scale (bool): Scale, which will be added automatically. + phase (str): 'train' or 'val'. + """ + + def __init__(self, opt): + super(PairedImageDataset, self).__init__() + self.opt = opt + # file client (io backend) + self.file_client = None + self.io_backend_opt = opt['io_backend'] + self.mean = opt['mean'] if 'mean' in opt else None + self.std = opt['std'] if 'std' in opt else None + + self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq'] + if 'filename_tmpl' in opt: + self.filename_tmpl = opt['filename_tmpl'] + else: + self.filename_tmpl = '{}' + + if self.io_backend_opt['type'] == 'lmdb': + self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder] + self.io_backend_opt['client_keys'] = ['lq', 'gt'] + self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt']) + elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None: + self.paths = paired_paths_from_meta_info_file([self.lq_folder, self.gt_folder], ['lq', 'gt'], + self.opt['meta_info_file'], self.filename_tmpl) + else: + self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl) + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) + + scale = self.opt['scale'] + + # Load gt and lq images. Dimension order: HWC; channel order: BGR; + # image range: [0, 1], float32. + gt_path = self.paths[index]['gt_path'] + img_bytes = self.file_client.get(gt_path, 'gt') + img_gt = imfrombytes(img_bytes, float32=True) + lq_path = self.paths[index]['lq_path'] + img_bytes = self.file_client.get(lq_path, 'lq') + img_lq = imfrombytes(img_bytes, float32=True) + + # augmentation for training + if self.opt['phase'] == 'train': + gt_size = self.opt['gt_size'] + # random crop + img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path) + # flip, rotation + img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_flip'], self.opt['use_rot']) + + # TODO: color space transform + # BGR to RGB, HWC to CHW, numpy to tensor + img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) + # normalize + if self.mean is not None or self.std is not None: + normalize(img_lq, self.mean, self.std, inplace=True) + normalize(img_gt, self.mean, self.std, inplace=True) + + return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path} + + def __len__(self): + return len(self.paths) diff --git a/blissful_tuner/codeformer/basicsr/data/prefetch_dataloader.py b/blissful_tuner/codeformer/basicsr/data/prefetch_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..5088425050d4cc98114a9b93eb50ea60273f35a0 --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/data/prefetch_dataloader.py @@ -0,0 +1,125 @@ +import queue as Queue +import threading +import torch +from torch.utils.data import DataLoader + + +class PrefetchGenerator(threading.Thread): + """A general prefetch generator. + + Ref: + https://stackoverflow.com/questions/7323664/python-generator-pre-fetch + + Args: + generator: Python generator. + num_prefetch_queue (int): Number of prefetch queue. 
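+
+    Example (illustrative):
+        >>> gen = PrefetchGenerator(iter(range(3)), num_prefetch_queue=1)
+        >>> list(gen)
+        [0, 1, 2]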
+ """ + + def __init__(self, generator, num_prefetch_queue): + threading.Thread.__init__(self) + self.queue = Queue.Queue(num_prefetch_queue) + self.generator = generator + self.daemon = True + self.start() + + def run(self): + for item in self.generator: + self.queue.put(item) + self.queue.put(None) + + def __next__(self): + next_item = self.queue.get() + if next_item is None: + raise StopIteration + return next_item + + def __iter__(self): + return self + + +class PrefetchDataLoader(DataLoader): + """Prefetch version of dataloader. + + Ref: + https://github.com/IgorSusmelj/pytorch-styleguide/issues/5# + + TODO: + Need to test on single gpu and ddp (multi-gpu). There is a known issue in + ddp. + + Args: + num_prefetch_queue (int): Number of prefetch queue. + kwargs (dict): Other arguments for dataloader. + """ + + def __init__(self, num_prefetch_queue, **kwargs): + self.num_prefetch_queue = num_prefetch_queue + super(PrefetchDataLoader, self).__init__(**kwargs) + + def __iter__(self): + return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue) + + +class CPUPrefetcher(): + """CPU prefetcher. + + Args: + loader: Dataloader. + """ + + def __init__(self, loader): + self.ori_loader = loader + self.loader = iter(loader) + + def next(self): + try: + return next(self.loader) + except StopIteration: + return None + + def reset(self): + self.loader = iter(self.ori_loader) + + +class CUDAPrefetcher(): + """CUDA prefetcher. + + Ref: + https://github.com/NVIDIA/apex/issues/304# + + It may consums more GPU memory. + + Args: + loader: Dataloader. + opt (dict): Options. + """ + + def __init__(self, loader, opt): + self.ori_loader = loader + self.loader = iter(loader) + self.opt = opt + self.stream = torch.cuda.Stream() + self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') + self.preload() + + def preload(self): + try: + self.batch = next(self.loader) # self.batch is a dict + except StopIteration: + self.batch = None + return None + # put tensors to gpu + with torch.cuda.stream(self.stream): + for k, v in self.batch.items(): + if torch.is_tensor(v): + self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True) + + def next(self): + torch.cuda.current_stream().wait_stream(self.stream) + batch = self.batch + self.preload() + return batch + + def reset(self): + self.loader = iter(self.ori_loader) + self.preload() diff --git a/blissful_tuner/codeformer/basicsr/data/transforms.py b/blissful_tuner/codeformer/basicsr/data/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..aead9dc73ed063e1c5865040eaa2652b26aa3ad3 --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/data/transforms.py @@ -0,0 +1,165 @@ +import cv2 +import random + + +def mod_crop(img, scale): + """Mod crop images, used during testing. + + Args: + img (ndarray): Input image. + scale (int): Scale factor. + + Returns: + ndarray: Result image. + """ + img = img.copy() + if img.ndim in (2, 3): + h, w = img.shape[0], img.shape[1] + h_remainder, w_remainder = h % scale, w % scale + img = img[:h - h_remainder, :w - w_remainder, ...] + else: + raise ValueError(f'Wrong img ndim: {img.ndim}.') + return img + + +def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path): + """Paired random crop. + + It crops lists of lq and gt images with corresponding locations. + + Args: + img_gts (list[ndarray] | ndarray): GT images. Note that all images + should have the same shape. If the input is an ndarray, it will + be transformed to a list containing itself. 
+ img_lqs (list[ndarray] | ndarray): LQ images. Note that all images + should have the same shape. If the input is an ndarray, it will + be transformed to a list containing itself. + gt_patch_size (int): GT patch size. + scale (int): Scale factor. + gt_path (str): Path to ground-truth. + + Returns: + list[ndarray] | ndarray: GT images and LQ images. If returned results + only have one element, just return ndarray. + """ + + if not isinstance(img_gts, list): + img_gts = [img_gts] + if not isinstance(img_lqs, list): + img_lqs = [img_lqs] + + h_lq, w_lq, _ = img_lqs[0].shape + h_gt, w_gt, _ = img_gts[0].shape + lq_patch_size = gt_patch_size // scale + + if h_gt != h_lq * scale or w_gt != w_lq * scale: + raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ', + f'multiplication of LQ ({h_lq}, {w_lq}).') + if h_lq < lq_patch_size or w_lq < lq_patch_size: + raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size ' + f'({lq_patch_size}, {lq_patch_size}). ' + f'Please remove {gt_path}.') + + # randomly choose top and left coordinates for lq patch + top = random.randint(0, h_lq - lq_patch_size) + left = random.randint(0, w_lq - lq_patch_size) + + # crop lq patch + img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs] + + # crop corresponding gt patch + top_gt, left_gt = int(top * scale), int(left * scale) + img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts] + if len(img_gts) == 1: + img_gts = img_gts[0] + if len(img_lqs) == 1: + img_lqs = img_lqs[0] + return img_gts, img_lqs + + +def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False): + """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees). + + We use vertical flip and transpose for rotation implementation. + All the images in the list use the same augmentation. + + Args: + imgs (list[ndarray] | ndarray): Images to be augmented. If the input + is an ndarray, it will be transformed to a list. + hflip (bool): Horizontal flip. Default: True. + rotation (bool): Ratotation. Default: True. + flows (list[ndarray]: Flows to be augmented. If the input is an + ndarray, it will be transformed to a list. + Dimension is (h, w, 2). Default: None. + return_status (bool): Return the status of flip and rotation. + Default: False. + + Returns: + list[ndarray] | ndarray: Augmented images and flows. If returned + results only have one element, just return ndarray. + + """ + hflip = hflip and random.random() < 0.5 + vflip = rotation and random.random() < 0.5 + rot90 = rotation and random.random() < 0.5 + + def _augment(img): + if hflip: # horizontal + cv2.flip(img, 1, img) + if vflip: # vertical + cv2.flip(img, 0, img) + if rot90: + img = img.transpose(1, 0, 2) + return img + + def _augment_flow(flow): + if hflip: # horizontal + cv2.flip(flow, 1, flow) + flow[:, :, 0] *= -1 + if vflip: # vertical + cv2.flip(flow, 0, flow) + flow[:, :, 1] *= -1 + if rot90: + flow = flow.transpose(1, 0, 2) + flow = flow[:, :, [1, 0]] + return flow + + if not isinstance(imgs, list): + imgs = [imgs] + imgs = [_augment(img) for img in imgs] + if len(imgs) == 1: + imgs = imgs[0] + + if flows is not None: + if not isinstance(flows, list): + flows = [flows] + flows = [_augment_flow(flow) for flow in flows] + if len(flows) == 1: + flows = flows[0] + return imgs, flows + else: + if return_status: + return imgs, (hflip, vflip, rot90) + else: + return imgs + + +def img_rotate(img, angle, center=None, scale=1.0): + """Rotate image. 
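+    The affine warp keeps the original (w, h) size of the input image.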
+ + Args: + img (ndarray): Image to be rotated. + angle (float): Rotation angle in degrees. Positive values mean + counter-clockwise rotation. + center (tuple[int]): Rotation center. If the center is None, + initialize it as the center of the image. Default: None. + scale (float): Isotropic scale factor. Default: 1.0. + """ + (h, w) = img.shape[:2] + + if center is None: + center = (w // 2, h // 2) + + matrix = cv2.getRotationMatrix2D(center, angle, scale) + rotated_img = cv2.warpAffine(img, matrix, (w, h)) + return rotated_img diff --git a/blissful_tuner/codeformer/basicsr/losses/__init__.py b/blissful_tuner/codeformer/basicsr/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6ec3c399902b805de1a0809a1dfab5d672fe8117 --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/losses/__init__.py @@ -0,0 +1,26 @@ +from copy import deepcopy + +from codeformer.basicsr.utils import get_root_logger +from codeformer.basicsr.utils.registry import LOSS_REGISTRY +from .losses import (CharbonnierLoss, GANLoss, L1Loss, MSELoss, PerceptualLoss, WeightedTVLoss, g_path_regularize, + gradient_penalty_loss, r1_penalty) + +__all__ = [ + 'L1Loss', 'MSELoss', 'CharbonnierLoss', 'WeightedTVLoss', 'PerceptualLoss', 'GANLoss', 'gradient_penalty_loss', + 'r1_penalty', 'g_path_regularize' +] + + +def build_loss(opt): + """Build loss from options. + + Args: + opt (dict): Configuration. It must constain: + type (str): Model type. + """ + opt = deepcopy(opt) + loss_type = opt.pop('type') + loss = LOSS_REGISTRY.get(loss_type)(**opt) + logger = get_root_logger() + logger.info(f'Loss [{loss.__class__.__name__}] is created.') + return loss diff --git a/blissful_tuner/codeformer/basicsr/losses/loss_util.py b/blissful_tuner/codeformer/basicsr/losses/loss_util.py new file mode 100644 index 0000000000000000000000000000000000000000..744eeb46d1f3b5a7b4553ca23237ddd9c899a698 --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/losses/loss_util.py @@ -0,0 +1,95 @@ +import functools +from torch.nn import functional as F + + +def reduce_loss(loss, reduction): + """Reduce loss as specified. + + Args: + loss (Tensor): Elementwise loss tensor. + reduction (str): Options are 'none', 'mean' and 'sum'. + + Returns: + Tensor: Reduced loss tensor. + """ + reduction_enum = F._Reduction.get_enum(reduction) + # none: 0, elementwise_mean:1, sum: 2 + if reduction_enum == 0: + return loss + elif reduction_enum == 1: + return loss.mean() + else: + return loss.sum() + + +def weight_reduce_loss(loss, weight=None, reduction='mean'): + """Apply element-wise weight and reduce loss. + + Args: + loss (Tensor): Element-wise loss. + weight (Tensor): Element-wise weights. Default: None. + reduction (str): Same as built-in losses of PyTorch. Options are + 'none', 'mean' and 'sum'. Default: 'mean'. + + Returns: + Tensor: Loss values. + """ + # if weight is specified, apply element-wise weight + if weight is not None: + assert weight.dim() == loss.dim() + assert weight.size(1) == 1 or weight.size(1) == loss.size(1) + loss = loss * weight + + # if weight is not specified or reduction is sum, just reduce the loss + if weight is None or reduction == 'sum': + loss = reduce_loss(loss, reduction) + # if reduction is mean, then compute mean over weight region + elif reduction == 'mean': + if weight.size(1) > 1: + weight = weight.sum() + else: + weight = weight.sum() * loss.size(1) + loss = loss.sum() / weight + + return loss + + +def weighted_loss(loss_func): + """Create a weighted version of a given loss function. 
+ + To use this decorator, the loss function must have the signature like + `loss_func(pred, target, **kwargs)`. The function only needs to compute + element-wise loss without any reduction. This decorator will add weight + and reduction arguments to the function. The decorated function will have + the signature like `loss_func(pred, target, weight=None, reduction='mean', + **kwargs)`. + + :Example: + + >>> import torch + >>> @weighted_loss + >>> def l1_loss(pred, target): + >>> return (pred - target).abs() + + >>> pred = torch.Tensor([0, 2, 3]) + >>> target = torch.Tensor([1, 1, 1]) + >>> weight = torch.Tensor([1, 0, 1]) + + >>> l1_loss(pred, target) + tensor(1.3333) + >>> l1_loss(pred, target, weight) + tensor(1.5000) + >>> l1_loss(pred, target, reduction='none') + tensor([1., 1., 2.]) + >>> l1_loss(pred, target, weight, reduction='sum') + tensor(3.) + """ + + @functools.wraps(loss_func) + def wrapper(pred, target, weight=None, reduction='mean', **kwargs): + # get element-wise loss + loss = loss_func(pred, target, **kwargs) + loss = weight_reduce_loss(loss, weight, reduction) + return loss + + return wrapper diff --git a/blissful_tuner/codeformer/basicsr/losses/losses.py b/blissful_tuner/codeformer/basicsr/losses/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..2dc225f1a2c244ec842e02934e837203d65b2914 --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/losses/losses.py @@ -0,0 +1,455 @@ +import math +import lpips +import torch +from torch import autograd as autograd +from torch import nn as nn +from torch.nn import functional as F + +from codeformer.basicsr.archs.vgg_arch import VGGFeatureExtractor +from codeformer.basicsr.utils.registry import LOSS_REGISTRY +from .loss_util import weighted_loss + +_reduction_modes = ['none', 'mean', 'sum'] + + +@weighted_loss +def l1_loss(pred, target): + return F.l1_loss(pred, target, reduction='none') + + +@weighted_loss +def mse_loss(pred, target): + return F.mse_loss(pred, target, reduction='none') + + +@weighted_loss +def charbonnier_loss(pred, target, eps=1e-12): + return torch.sqrt((pred - target)**2 + eps) + + +@LOSS_REGISTRY.register() +class L1Loss(nn.Module): + """L1 (mean absolute error, MAE) loss. + + Args: + loss_weight (float): Loss weight for L1 loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + """ + + def __init__(self, loss_weight=1.0, reduction='mean'): + super(L1Loss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}') + + self.loss_weight = loss_weight + self.reduction = reduction + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise + weights. Default: None. + """ + return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction) + + +@LOSS_REGISTRY.register() +class MSELoss(nn.Module): + """MSE (L2) loss. + + Args: + loss_weight (float): Loss weight for MSE loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. 
+ """ + + def __init__(self, loss_weight=1.0, reduction='mean'): + super(MSELoss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}') + + self.loss_weight = loss_weight + self.reduction = reduction + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise + weights. Default: None. + """ + return self.loss_weight * mse_loss(pred, target, weight, reduction=self.reduction) + + +@LOSS_REGISTRY.register() +class CharbonnierLoss(nn.Module): + """Charbonnier loss (one variant of Robust L1Loss, a differentiable + variant of L1Loss). + + Described in "Deep Laplacian Pyramid Networks for Fast and Accurate + Super-Resolution". + + Args: + loss_weight (float): Loss weight for L1 loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + eps (float): A value used to control the curvature near zero. + Default: 1e-12. + """ + + def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12): + super(CharbonnierLoss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}') + + self.loss_weight = loss_weight + self.reduction = reduction + self.eps = eps + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise + weights. Default: None. + """ + return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction) + + +@LOSS_REGISTRY.register() +class WeightedTVLoss(L1Loss): + """Weighted TV loss. + + Args: + loss_weight (float): Loss weight. Default: 1.0. + """ + + def __init__(self, loss_weight=1.0): + super(WeightedTVLoss, self).__init__(loss_weight=loss_weight) + + def forward(self, pred, weight=None): + y_diff = super(WeightedTVLoss, self).forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=weight[:, :, :-1, :]) + x_diff = super(WeightedTVLoss, self).forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=weight[:, :, :, :-1]) + + loss = x_diff + y_diff + + return loss + + +@LOSS_REGISTRY.register() +class PerceptualLoss(nn.Module): + """Perceptual loss with commonly used style loss. + + Args: + layer_weights (dict): The weight for each layer of vgg feature. + Here is an example: {'conv5_4': 1.}, which means the conv5_4 + feature layer (before relu5_4) will be extracted with weight + 1.0 in calculting losses. + vgg_type (str): The type of vgg network used as feature extractor. + Default: 'vgg19'. + use_input_norm (bool): If True, normalize the input image in vgg. + Default: True. + range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. + Default: False. + perceptual_weight (float): If `perceptual_weight > 0`, the perceptual + loss will be calculated and the loss will multiplied by the + weight. Default: 1.0. + style_weight (float): If `style_weight > 0`, the style loss will be + calculated and the loss will multiplied by the weight. + Default: 0. + criterion (str): Criterion used for perceptual loss. Default: 'l1'. 
+ """ + + def __init__(self, + layer_weights, + vgg_type='vgg19', + use_input_norm=True, + range_norm=False, + perceptual_weight=1.0, + style_weight=0., + criterion='l1'): + super(PerceptualLoss, self).__init__() + self.perceptual_weight = perceptual_weight + self.style_weight = style_weight + self.layer_weights = layer_weights + self.vgg = VGGFeatureExtractor( + layer_name_list=list(layer_weights.keys()), + vgg_type=vgg_type, + use_input_norm=use_input_norm, + range_norm=range_norm) + + self.criterion_type = criterion + if self.criterion_type == 'l1': + self.criterion = torch.nn.L1Loss() + elif self.criterion_type == 'l2': + self.criterion = torch.nn.L2loss() + elif self.criterion_type == 'mse': + self.criterion = torch.nn.MSELoss(reduction='mean') + elif self.criterion_type == 'fro': + self.criterion = None + else: + raise NotImplementedError(f'{criterion} criterion has not been supported.') + + def forward(self, x, gt): + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + gt (Tensor): Ground-truth tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results. + """ + # extract vgg features + x_features = self.vgg(x) + gt_features = self.vgg(gt.detach()) + + # calculate perceptual loss + if self.perceptual_weight > 0: + percep_loss = 0 + for k in x_features.keys(): + if self.criterion_type == 'fro': + percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k] + else: + percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k] + percep_loss *= self.perceptual_weight + else: + percep_loss = None + + # calculate style loss + if self.style_weight > 0: + style_loss = 0 + for k in x_features.keys(): + if self.criterion_type == 'fro': + style_loss += torch.norm( + self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k] + else: + style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat( + gt_features[k])) * self.layer_weights[k] + style_loss *= self.style_weight + else: + style_loss = None + + return percep_loss, style_loss + + def _gram_mat(self, x): + """Calculate Gram matrix. + + Args: + x (torch.Tensor): Tensor with shape of (n, c, h, w). + + Returns: + torch.Tensor: Gram matrix. + """ + n, c, h, w = x.size() + features = x.view(n, c, w * h) + features_t = features.transpose(1, 2) + gram = features.bmm(features_t) / (c * h * w) + return gram + + +@LOSS_REGISTRY.register() +class LPIPSLoss(nn.Module): + def __init__(self, + loss_weight=1.0, + use_input_norm=True, + range_norm=False,): + super(LPIPSLoss, self).__init__() + self.perceptual = lpips.LPIPS(net="vgg", spatial=False).eval() + self.loss_weight = loss_weight + self.use_input_norm = use_input_norm + self.range_norm = range_norm + + if self.use_input_norm: + # the mean is for image with range [0, 1] + self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) + # the std is for image with range [0, 1] + self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) + + def forward(self, pred, target): + if self.range_norm: + pred = (pred + 1) / 2 + target = (target + 1) / 2 + if self.use_input_norm: + pred = (pred - self.mean) / self.std + target = (target - self.mean) / self.std + lpips_loss = self.perceptual(target.contiguous(), pred.contiguous()) + return self.loss_weight * lpips_loss.mean() + + +@LOSS_REGISTRY.register() +class GANLoss(nn.Module): + """Define GAN loss. 
+ + Args: + gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'. + real_label_val (float): The value for real label. Default: 1.0. + fake_label_val (float): The value for fake label. Default: 0.0. + loss_weight (float): Loss weight. Default: 1.0. + Note that loss_weight is only for generators; and it is always 1.0 + for discriminators. + """ + + def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0): + super(GANLoss, self).__init__() + self.gan_type = gan_type + self.loss_weight = loss_weight + self.real_label_val = real_label_val + self.fake_label_val = fake_label_val + + if self.gan_type == 'vanilla': + self.loss = nn.BCEWithLogitsLoss() + elif self.gan_type == 'lsgan': + self.loss = nn.MSELoss() + elif self.gan_type == 'wgan': + self.loss = self._wgan_loss + elif self.gan_type == 'wgan_softplus': + self.loss = self._wgan_softplus_loss + elif self.gan_type == 'hinge': + self.loss = nn.ReLU() + else: + raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.') + + def _wgan_loss(self, input, target): + """wgan loss. + + Args: + input (Tensor): Input tensor. + target (bool): Target label. + + Returns: + Tensor: wgan loss. + """ + return -input.mean() if target else input.mean() + + def _wgan_softplus_loss(self, input, target): + """wgan loss with soft plus. softplus is a smooth approximation to the + ReLU function. + + In StyleGAN2, it is called: + Logistic loss for discriminator; + Non-saturating loss for generator. + + Args: + input (Tensor): Input tensor. + target (bool): Target label. + + Returns: + Tensor: wgan loss. + """ + return F.softplus(-input).mean() if target else F.softplus(input).mean() + + def get_target_label(self, input, target_is_real): + """Get target label. + + Args: + input (Tensor): Input tensor. + target_is_real (bool): Whether the target is real or fake. + + Returns: + (bool | Tensor): Target tensor. Return bool for wgan, otherwise, + return Tensor. + """ + + if self.gan_type in ['wgan', 'wgan_softplus']: + return target_is_real + target_val = (self.real_label_val if target_is_real else self.fake_label_val) + return input.new_ones(input.size()) * target_val + + def forward(self, input, target_is_real, is_disc=False): + """ + Args: + input (Tensor): The input for the loss module, i.e., the network + prediction. + target_is_real (bool): Whether the targe is real or fake. + is_disc (bool): Whether the loss for discriminators or not. + Default: False. + + Returns: + Tensor: GAN loss value. + """ + if self.gan_type == 'hinge': + if is_disc: # for discriminators in hinge-gan + input = -input if target_is_real else input + loss = self.loss(1 + input).mean() + else: # for generators in hinge-gan + loss = -input.mean() + else: # other gan types + target_label = self.get_target_label(input, target_is_real) + loss = self.loss(input, target_label) + + # loss_weight is always 1.0 for discriminators + return loss if is_disc else loss * self.loss_weight + + +def r1_penalty(real_pred, real_img): + """R1 regularization for discriminator. The core idea is to + penalize the gradient on real data alone: when the + generator distribution produces the true data distribution + and the discriminator is equal to 0 on the data manifold, the + gradient penalty ensures that the discriminator cannot create + a non-zero gradient orthogonal to the data manifold without + suffering a loss in the GAN game. + + Ref: + Eq. 9 in Which training methods for GANs do actually converge. 
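+ In practice this is the mean of ||grad_x D(x)||^2 over the real batch, as computed below.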
+ """ + grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0] + grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean() + return grad_penalty + + +def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01): + noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3]) + grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0] + path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1)) + + path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length) + + path_penalty = (path_lengths - path_mean).pow(2).mean() + + return path_penalty, path_lengths.detach().mean(), path_mean.detach() + + +def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None): + """Calculate gradient penalty for wgan-gp. + + Args: + discriminator (nn.Module): Network for the discriminator. + real_data (Tensor): Real input data. + fake_data (Tensor): Fake input data. + weight (Tensor): Weight tensor. Default: None. + + Returns: + Tensor: A tensor for gradient penalty. + """ + + batch_size = real_data.size(0) + alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1)) + + # interpolate between real_data and fake_data + interpolates = alpha * real_data + (1. - alpha) * fake_data + interpolates = autograd.Variable(interpolates, requires_grad=True) + + disc_interpolates = discriminator(interpolates) + gradients = autograd.grad( + outputs=disc_interpolates, + inputs=interpolates, + grad_outputs=torch.ones_like(disc_interpolates), + create_graph=True, + retain_graph=True, + only_inputs=True)[0] + + if weight is not None: + gradients = gradients * weight + + gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean() + if weight is not None: + gradients_penalty /= torch.mean(weight) + + return gradients_penalty diff --git a/blissful_tuner/codeformer/basicsr/metrics/__init__.py b/blissful_tuner/codeformer/basicsr/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7b3021b3f65291be577b8b62fbe89f64a61a8077 --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/metrics/__init__.py @@ -0,0 +1,19 @@ +from copy import deepcopy + +from codeformer.basicsr.utils.registry import METRIC_REGISTRY +from .psnr_ssim import calculate_psnr, calculate_ssim + +__all__ = ['calculate_psnr', 'calculate_ssim'] + + +def calculate_metric(data, opt): + """Calculate metric from data and options. + + Args: + opt (dict): Configuration. It must constain: + type (str): Model type. + """ + opt = deepcopy(opt) + metric_type = opt.pop('type') + metric = METRIC_REGISTRY.get(metric_type)(**data, **opt) + return metric diff --git a/blissful_tuner/codeformer/basicsr/metrics/metric_util.py b/blissful_tuner/codeformer/basicsr/metrics/metric_util.py new file mode 100644 index 0000000000000000000000000000000000000000..60884ee20690724b85905f5bf5ef77d6618550bc --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/metrics/metric_util.py @@ -0,0 +1,45 @@ +import numpy as np + +from codeformer.basicsr.utils.matlab_functions import bgr2ycbcr + + +def reorder_image(img, input_order='HWC'): + """Reorder images to 'HWC' order. + + If the input_order is (h, w), return (h, w, 1); + If the input_order is (c, h, w), return (h, w, c); + If the input_order is (h, w, c), return as it is. + + Args: + img (ndarray): Input image. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + If the input image shape is (h, w), input_order will not have + effects. Default: 'HWC'. 
+ + Returns: + ndarray: reordered image. + """ + + if input_order not in ['HWC', 'CHW']: + raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' "'HWC' and 'CHW'") + if len(img.shape) == 2: + img = img[..., None] + if input_order == 'CHW': + img = img.transpose(1, 2, 0) + return img + + +def to_y_channel(img): + """Change to Y channel of YCbCr. + + Args: + img (ndarray): Images with range [0, 255]. + + Returns: + (ndarray): Images with range [0, 255] (float type) without round. + """ + img = img.astype(np.float32) / 255. + if img.ndim == 3 and img.shape[2] == 3: + img = bgr2ycbcr(img, y_only=True) + img = img[..., None] + return img * 255. diff --git a/blissful_tuner/codeformer/basicsr/metrics/psnr_ssim.py b/blissful_tuner/codeformer/basicsr/metrics/psnr_ssim.py new file mode 100644 index 0000000000000000000000000000000000000000..8cd9911781f96c9f2d2410b3054aa7f37a8f2284 --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/metrics/psnr_ssim.py @@ -0,0 +1,128 @@ +import cv2 +import numpy as np + +from codeformer.basicsr.metrics.metric_util import reorder_image, to_y_channel +from codeformer.basicsr.utils.registry import METRIC_REGISTRY + + +@METRIC_REGISTRY.register() +def calculate_psnr(img1, img2, crop_border, input_order='HWC', test_y_channel=False): + """Calculate PSNR (Peak Signal-to-Noise Ratio). + + Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio + + Args: + img1 (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + crop_border (int): Cropped pixels in each edge of an image. These + pixels are not involved in the PSNR calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + Default: 'HWC'. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: psnr result. + """ + + assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.') + if input_order not in ['HWC', 'CHW']: + raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"') + img1 = reorder_image(img1, input_order=input_order) + img2 = reorder_image(img2, input_order=input_order) + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + + if crop_border != 0: + img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] + img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] + + if test_y_channel: + img1 = to_y_channel(img1) + img2 = to_y_channel(img2) + + mse = np.mean((img1 - img2)**2) + if mse == 0: + return float('inf') + return 20. * np.log10(255. / np.sqrt(mse)) + + +def _ssim(img1, img2): + """Calculate SSIM (structural similarity) for one channel images. + + It is called by func:`calculate_ssim`. + + Args: + img1 (ndarray): Images with range [0, 255] with order 'HWC'. + img2 (ndarray): Images with range [0, 255] with order 'HWC'. + + Returns: + float: ssim result. 
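+ + Uses an 11x11 Gaussian window (sigma 1.5) and the constants C1 = (0.01*255)^2 and C2 = (0.03*255)^2.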
+ """ + + C1 = (0.01 * 255)**2 + C2 = (0.03 * 255)**2 + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + return ssim_map.mean() + + +@METRIC_REGISTRY.register() +def calculate_ssim(img1, img2, crop_border, input_order='HWC', test_y_channel=False): + """Calculate SSIM (structural similarity). + + Ref: + Image quality assessment: From error visibility to structural similarity + + The results are the same as that of the official released MATLAB code in + https://ece.uwaterloo.ca/~z70wang/research/ssim/. + + For three-channel images, SSIM is calculated for each channel and then + averaged. + + Args: + img1 (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + crop_border (int): Cropped pixels in each edge of an image. These + pixels are not involved in the SSIM calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + Default: 'HWC'. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: ssim result. + """ + + assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.') + if input_order not in ['HWC', 'CHW']: + raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"') + img1 = reorder_image(img1, input_order=input_order) + img2 = reorder_image(img2, input_order=input_order) + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + + if crop_border != 0: + img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] + img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] + + if test_y_channel: + img1 = to_y_channel(img1) + img2 = to_y_channel(img2) + + ssims = [] + for i in range(img1.shape[2]): + ssims.append(_ssim(img1[..., i], img2[..., i])) + return np.array(ssims).mean() diff --git a/blissful_tuner/codeformer/basicsr/models/__init__.py b/blissful_tuner/codeformer/basicsr/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7a9610c07a266ba491117c337b78492585027cad --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/models/__init__.py @@ -0,0 +1,30 @@ +import importlib +from copy import deepcopy +from os import path as osp + +from codeformer.basicsr.utils import get_root_logger, scandir +from codeformer.basicsr.utils.registry import MODEL_REGISTRY + +__all__ = ['build_model'] + +# automatically scan and import model modules for registry +# scan all the files under the 'models' folder and collect files ending with +# '_model.py' +model_folder = osp.dirname(osp.abspath(__file__)) +model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] +# import all the model modules +_model_modules = [importlib.import_module(f'codeformer.basicsr.models.{file_name}') for file_name in model_filenames] + + +def build_model(opt): + """Build model from options. + + Args: + opt (dict): Configuration. 
It must constain: + model_type (str): Model type. + """ + opt = deepcopy(opt) + model = MODEL_REGISTRY.get(opt['model_type'])(opt) + logger = get_root_logger() + logger.info(f'Model [{model.__class__.__name__}] is created.') + return model diff --git a/blissful_tuner/codeformer/basicsr/models/base_model.py b/blissful_tuner/codeformer/basicsr/models/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..91750594db8e598f20e5131d6a71c18a4dd5a1e3 --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/models/base_model.py @@ -0,0 +1,322 @@ +import logging +import os +import torch +from collections import OrderedDict +from copy import deepcopy +from torch.nn.parallel import DataParallel, DistributedDataParallel + +from codeformer.basicsr.models import lr_scheduler as lr_scheduler +from codeformer.basicsr.utils.dist_util import master_only + +logger = logging.getLogger('basicsr') + + +class BaseModel(): + """Base model.""" + + def __init__(self, opt): + self.opt = opt + self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') + self.is_train = opt['is_train'] + self.schedulers = [] + self.optimizers = [] + + def feed_data(self, data): + pass + + def optimize_parameters(self): + pass + + def get_current_visuals(self): + pass + + def save(self, epoch, current_iter): + """Save networks and training state.""" + pass + + def validation(self, dataloader, current_iter, tb_logger, save_img=False): + """Validation function. + + Args: + dataloader (torch.utils.data.DataLoader): Validation dataloader. + current_iter (int): Current iteration. + tb_logger (tensorboard logger): Tensorboard logger. + save_img (bool): Whether to save images. Default: False. + """ + if self.opt['dist']: + self.dist_validation(dataloader, current_iter, tb_logger, save_img) + else: + self.nondist_validation(dataloader, current_iter, tb_logger, save_img) + + def model_ema(self, decay=0.999): + net_g = self.get_bare_model(self.net_g) + + net_g_params = dict(net_g.named_parameters()) + net_g_ema_params = dict(self.net_g_ema.named_parameters()) + + for k in net_g_ema_params.keys(): + net_g_ema_params[k].data.mul_(decay).add_(net_g_params[k].data, alpha=1 - decay) + + def get_current_log(self): + return self.log_dict + + def model_to_device(self, net): + """Model to device. It also warps models with DistributedDataParallel + or DataParallel. 
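+ DistributedDataParallel is used when opt['dist'] is set, otherwise DataParallel when opt['num_gpu'] > 1.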
+ + Args: + net (nn.Module) + """ + net = net.to(self.device) + if self.opt['dist']: + find_unused_parameters = self.opt.get('find_unused_parameters', False) + net = DistributedDataParallel( + net, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters) + elif self.opt['num_gpu'] > 1: + net = DataParallel(net) + return net + + def get_optimizer(self, optim_type, params, lr, **kwargs): + if optim_type == 'Adam': + optimizer = torch.optim.Adam(params, lr, **kwargs) + else: + raise NotImplementedError(f'optimizer {optim_type} is not supperted yet.') + return optimizer + + def setup_schedulers(self): + """Set up schedulers.""" + train_opt = self.opt['train'] + scheduler_type = train_opt['scheduler'].pop('type') + if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']: + for optimizer in self.optimizers: + self.schedulers.append(lr_scheduler.MultiStepRestartLR(optimizer, **train_opt['scheduler'])) + elif scheduler_type == 'CosineAnnealingRestartLR': + for optimizer in self.optimizers: + self.schedulers.append(lr_scheduler.CosineAnnealingRestartLR(optimizer, **train_opt['scheduler'])) + else: + raise NotImplementedError(f'Scheduler {scheduler_type} is not implemented yet.') + + def get_bare_model(self, net): + """Get bare model, especially under wrapping with + DistributedDataParallel or DataParallel. + """ + if isinstance(net, (DataParallel, DistributedDataParallel)): + net = net.module + return net + + @master_only + def print_network(self, net): + """Print the str and parameter number of a network. + + Args: + net (nn.Module) + """ + if isinstance(net, (DataParallel, DistributedDataParallel)): + net_cls_str = (f'{net.__class__.__name__} - ' f'{net.module.__class__.__name__}') + else: + net_cls_str = f'{net.__class__.__name__}' + + net = self.get_bare_model(net) + net_str = str(net) + net_params = sum(map(lambda x: x.numel(), net.parameters())) + + logger.info(f'Network: {net_cls_str}, with parameters: {net_params:,d}') + logger.info(net_str) + + def _set_lr(self, lr_groups_l): + """Set learning rate for warmup. + + Args: + lr_groups_l (list): List for lr_groups, each for an optimizer. + """ + for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): + for param_group, lr in zip(optimizer.param_groups, lr_groups): + param_group['lr'] = lr + + def _get_init_lr(self): + """Get the initial lr, which is set by the scheduler. + """ + init_lr_groups_l = [] + for optimizer in self.optimizers: + init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups]) + return init_lr_groups_l + + def update_learning_rate(self, current_iter, warmup_iter=-1): + """Update learning rate. + + Args: + current_iter (int): Current iteration. + warmup_iter (int): Warmup iter numbers. -1 for no warmup. + Default: -1. + """ + if current_iter > 1: + for scheduler in self.schedulers: + scheduler.step() + # set up warm-up learning rate + if current_iter < warmup_iter: + # get initial lr for each group + init_lr_g_l = self._get_init_lr() + # modify warming-up learning rates + # currently only support linearly warm up + warm_up_lr_l = [] + for init_lr_g in init_lr_g_l: + warm_up_lr_l.append([v / warmup_iter * current_iter for v in init_lr_g]) + # set learning rate + self._set_lr(warm_up_lr_l) + + def get_current_learning_rate(self): + return [param_group['lr'] for param_group in self.optimizers[0].param_groups] + + @master_only + def save_network(self, net, net_label, current_iter, param_key='params'): + """Save networks. 
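+ Checkpoints are written to opt['path']['models'] as '{net_label}_{current_iter}.pth' ('{net_label}_latest.pth' when current_iter is -1).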
+ + Args: + net (nn.Module | list[nn.Module]): Network(s) to be saved. + net_label (str): Network label. + current_iter (int): Current iter number. + param_key (str | list[str]): The parameter key(s) to save network. + Default: 'params'. + """ + if current_iter == -1: + current_iter = 'latest' + save_filename = f'{net_label}_{current_iter}.pth' + save_path = os.path.join(self.opt['path']['models'], save_filename) + + net = net if isinstance(net, list) else [net] + param_key = param_key if isinstance(param_key, list) else [param_key] + assert len(net) == len(param_key), 'The lengths of net and param_key should be the same.' + + save_dict = {} + for net_, param_key_ in zip(net, param_key): + net_ = self.get_bare_model(net_) + state_dict = net_.state_dict() + for key, param in state_dict.items(): + if key.startswith('module.'): # remove unnecessary 'module.' + key = key[7:] + state_dict[key] = param.cpu() + save_dict[param_key_] = state_dict + + torch.save(save_dict, save_path) + + def _print_different_keys_loading(self, crt_net, load_net, strict=True): + """Print keys with differnet name or different size when loading models. + + 1. Print keys with differnet names. + 2. If strict=False, print the same key but with different tensor size. + It also ignore these keys with different sizes (not load). + + Args: + crt_net (torch model): Current network. + load_net (dict): Loaded network. + strict (bool): Whether strictly loaded. Default: True. + """ + crt_net = self.get_bare_model(crt_net) + crt_net = crt_net.state_dict() + crt_net_keys = set(crt_net.keys()) + load_net_keys = set(load_net.keys()) + + if crt_net_keys != load_net_keys: + logger.warning('Current net - loaded net:') + for v in sorted(list(crt_net_keys - load_net_keys)): + logger.warning(f' {v}') + logger.warning('Loaded net - current net:') + for v in sorted(list(load_net_keys - crt_net_keys)): + logger.warning(f' {v}') + + # check the size for the same keys + if not strict: + common_keys = crt_net_keys & load_net_keys + for k in common_keys: + if crt_net[k].size() != load_net[k].size(): + logger.warning(f'Size different, ignore [{k}]: crt_net: ' + f'{crt_net[k].shape}; load_net: {load_net[k].shape}') + load_net[k + '.ignore'] = load_net.pop(k) + + def load_network(self, net, load_path, strict=True, param_key='params'): + """Load network. + + Args: + load_path (str): The path of networks to be loaded. + net (nn.Module): Network. + strict (bool): Whether strictly loaded. + param_key (str): The parameter key of loaded network. If set to + None, use the root 'path'. + Default: 'params'. + """ + net = self.get_bare_model(net) + logger.info(f'Loading {net.__class__.__name__} model from {load_path}.') + load_net = torch.load(load_path, map_location=lambda storage, loc: storage) + if param_key is not None: + if param_key not in load_net and 'params' in load_net: + param_key = 'params' + logger.info('Loading: params_ema does not exist, use params.') + load_net = load_net[param_key] + # remove unnecessary 'module.' + for k, v in deepcopy(load_net).items(): + if k.startswith('module.'): + load_net[k[7:]] = v + load_net.pop(k) + self._print_different_keys_loading(net, load_net, strict) + net.load_state_dict(load_net, strict=strict) + + @master_only + def save_training_state(self, epoch, current_iter): + """Save training states during training, which will be used for + resuming. + + Args: + epoch (int): Current epoch. + current_iter (int): Current iteration. 
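+ + The state is written as '{current_iter}.state' under opt['path']['training_states'].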
+ """ + if current_iter != -1: + state = {'epoch': epoch, 'iter': current_iter, 'optimizers': [], 'schedulers': []} + for o in self.optimizers: + state['optimizers'].append(o.state_dict()) + for s in self.schedulers: + state['schedulers'].append(s.state_dict()) + save_filename = f'{current_iter}.state' + save_path = os.path.join(self.opt['path']['training_states'], save_filename) + torch.save(state, save_path) + + def resume_training(self, resume_state): + """Reload the optimizers and schedulers for resumed training. + + Args: + resume_state (dict): Resume state. + """ + resume_optimizers = resume_state['optimizers'] + resume_schedulers = resume_state['schedulers'] + assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers' + assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers' + for i, o in enumerate(resume_optimizers): + self.optimizers[i].load_state_dict(o) + for i, s in enumerate(resume_schedulers): + self.schedulers[i].load_state_dict(s) + + def reduce_loss_dict(self, loss_dict): + """reduce loss dict. + + In distributed training, it averages the losses among different GPUs . + + Args: + loss_dict (OrderedDict): Loss dict. + """ + with torch.no_grad(): + if self.opt['dist']: + keys = [] + losses = [] + for name, value in loss_dict.items(): + keys.append(name) + losses.append(value) + losses = torch.stack(losses, 0) + torch.distributed.reduce(losses, dst=0) + if self.opt['rank'] == 0: + losses /= self.opt['world_size'] + loss_dict = {key: loss for key, loss in zip(keys, losses)} + + log_dict = OrderedDict() + for name, value in loss_dict.items(): + log_dict[name] = value.mean().item() + + return log_dict diff --git a/blissful_tuner/codeformer/basicsr/models/codeformer_idx_model.py b/blissful_tuner/codeformer/basicsr/models/codeformer_idx_model.py new file mode 100644 index 0000000000000000000000000000000000000000..8da0024d805cb2514afd235515a17691cc44e96c --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/models/codeformer_idx_model.py @@ -0,0 +1,220 @@ +import torch +from collections import OrderedDict +from os import path as osp +from tqdm import tqdm + +from codeformer.basicsr.archs import build_network +from codeformer.basicsr.metrics import calculate_metric +from codeformer.basicsr.utils import get_root_logger, imwrite, tensor2img +from codeformer.basicsr.utils.registry import MODEL_REGISTRY +import torch.nn.functional as F +from .sr_model import SRModel + + +@MODEL_REGISTRY.register() +class CodeFormerIdxModel(SRModel): + def feed_data(self, data): + self.gt = data['gt'].to(self.device) + self.input = data['in'].to(self.device) + self.b = self.gt.shape[0] + + if 'latent_gt' in data: + self.idx_gt = data['latent_gt'].to(self.device) + self.idx_gt = self.idx_gt.view(self.b, -1) + else: + self.idx_gt = None + + def init_training_settings(self): + logger = get_root_logger() + train_opt = self.opt['train'] + + self.ema_decay = train_opt.get('ema_decay', 0) + if self.ema_decay > 0: + logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}') + # define network net_g with Exponential Moving Average (EMA) + # net_g_ema is used only for testing on one GPU and saving + # There is no need to wrap with DistributedDataParallel + self.net_g_ema = build_network(self.opt['network_g']).to(self.device) + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema') 
+ else: + self.model_ema(0) # copy net_g weight + self.net_g_ema.eval() + + if self.opt['datasets']['train'].get('latent_gt_path', None) is not None: + self.generate_idx_gt = False + elif self.opt.get('network_vqgan', None) is not None: + self.hq_vqgan_fix = build_network(self.opt['network_vqgan']).to(self.device) + self.hq_vqgan_fix.eval() + self.generate_idx_gt = True + for param in self.hq_vqgan_fix.parameters(): + param.requires_grad = False + else: + raise NotImplementedError(f'Shoule have network_vqgan config or pre-calculated latent code.') + + logger.info(f'Need to generate latent GT code: {self.generate_idx_gt}') + + self.hq_feat_loss = train_opt.get('use_hq_feat_loss', True) + self.feat_loss_weight = train_opt.get('feat_loss_weight', 1.0) + self.cross_entropy_loss = train_opt.get('cross_entropy_loss', True) + self.entropy_loss_weight = train_opt.get('entropy_loss_weight', 0.5) + + self.net_g.train() + + # set up optimizers and schedulers + self.setup_optimizers() + self.setup_schedulers() + + + def setup_optimizers(self): + train_opt = self.opt['train'] + # optimizer g + optim_params_g = [] + for k, v in self.net_g.named_parameters(): + if v.requires_grad: + optim_params_g.append(v) + else: + logger = get_root_logger() + logger.warning(f'Params {k} will not be optimized.') + optim_type = train_opt['optim_g'].pop('type') + self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g']) + self.optimizers.append(self.optimizer_g) + + + def optimize_parameters(self, current_iter): + logger = get_root_logger() + # optimize net_g + self.optimizer_g.zero_grad() + + if self.generate_idx_gt: + x = self.hq_vqgan_fix.encoder(self.gt) + _, _, quant_stats = self.hq_vqgan_fix.quantize(x) + min_encoding_indices = quant_stats['min_encoding_indices'] + self.idx_gt = min_encoding_indices.view(self.b, -1) + + if self.hq_feat_loss: + # quant_feats + quant_feat_gt = self.net_g.module.quantize.get_codebook_feat(self.idx_gt, shape=[self.b,16,16,256]) + + logits, lq_feat = self.net_g(self.input, w=0, code_only=True) + + l_g_total = 0 + loss_dict = OrderedDict() + # hq_feat_loss + if self.hq_feat_loss: # codebook loss + l_feat_encoder = torch.mean((quant_feat_gt.detach()-lq_feat)**2) * self.feat_loss_weight + l_g_total += l_feat_encoder + loss_dict['l_feat_encoder'] = l_feat_encoder + + # cross_entropy_loss + if self.cross_entropy_loss: + # b(hw)n -> bn(hw) + cross_entropy_loss = F.cross_entropy(logits.permute(0, 2, 1), self.idx_gt) * self.entropy_loss_weight + l_g_total += cross_entropy_loss + loss_dict['cross_entropy_loss'] = cross_entropy_loss + + l_g_total.backward() + self.optimizer_g.step() + + if self.ema_decay > 0: + self.model_ema(decay=self.ema_decay) + + self.log_dict = self.reduce_loss_dict(loss_dict) + + + def test(self): + with torch.no_grad(): + if hasattr(self, 'net_g_ema'): + self.net_g_ema.eval() + self.output, _, _ = self.net_g_ema(self.input, w=0) + else: + logger = get_root_logger() + logger.warning('Do not have self.net_g_ema, use self.net_g.') + self.net_g.eval() + self.output, _, _ = self.net_g(self.input, w=0) + self.net_g.train() + + + def dist_validation(self, dataloader, current_iter, tb_logger, save_img): + if self.opt['rank'] == 0: + self.nondist_validation(dataloader, current_iter, tb_logger, save_img) + + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + dataset_name = dataloader.dataset.opt['name'] + with_metrics = self.opt['val'].get('metrics') is not None + if with_metrics: + self.metric_results = {metric: 0 for 
metric in self.opt['val']['metrics'].keys()} + pbar = tqdm(total=len(dataloader), unit='image') + + for idx, val_data in enumerate(dataloader): + img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0] + self.feed_data(val_data) + self.test() + + visuals = self.get_current_visuals() + sr_img = tensor2img([visuals['result']]) + if 'gt' in visuals: + gt_img = tensor2img([visuals['gt']]) + del self.gt + + # tentative for out of GPU memory + del self.lq + del self.output + torch.cuda.empty_cache() + + if save_img: + if self.opt['is_train']: + save_img_path = osp.join(self.opt['path']['visualization'], img_name, + f'{img_name}_{current_iter}.png') + else: + if self.opt['val']['suffix']: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["val"]["suffix"]}.png') + else: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["name"]}.png') + imwrite(sr_img, save_img_path) + + if with_metrics: + # calculate metrics + for name, opt_ in self.opt['val']['metrics'].items(): + metric_data = dict(img1=sr_img, img2=gt_img) + self.metric_results[name] += calculate_metric(metric_data, opt_) + pbar.update(1) + pbar.set_description(f'Test {img_name}') + pbar.close() + + if with_metrics: + for metric in self.metric_results.keys(): + self.metric_results[metric] /= (idx + 1) + + self._log_validation_metric_values(current_iter, dataset_name, tb_logger) + + + def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger): + log_str = f'Validation {dataset_name}\n' + for metric, value in self.metric_results.items(): + log_str += f'\t # {metric}: {value:.4f}\n' + logger = get_root_logger() + logger.info(log_str) + if tb_logger: + for metric, value in self.metric_results.items(): + tb_logger.add_scalar(f'metrics/{metric}', value, current_iter) + + + def get_current_visuals(self): + out_dict = OrderedDict() + out_dict['gt'] = self.gt.detach().cpu() + out_dict['result'] = self.output.detach().cpu() + return out_dict + + + def save(self, epoch, current_iter): + if self.ema_decay > 0: + self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema']) + else: + self.save_network(self.net_g, 'net_g', current_iter) + self.save_training_state(epoch, current_iter) diff --git a/blissful_tuner/codeformer/basicsr/models/codeformer_joint_model.py b/blissful_tuner/codeformer/basicsr/models/codeformer_joint_model.py new file mode 100644 index 0000000000000000000000000000000000000000..eaaca3714685a4c19907618ab9db4c5fa3760af6 --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/models/codeformer_joint_model.py @@ -0,0 +1,350 @@ +import torch +from collections import OrderedDict +from os import path as osp +from tqdm import tqdm + + +from codeformer.basicsr.archs import build_network +from codeformer.basicsr.losses import build_loss +from codeformer.basicsr.metrics import calculate_metric +from codeformer.basicsr.utils import get_root_logger, imwrite, tensor2img +from codeformer.basicsr.utils.registry import MODEL_REGISTRY +import torch.nn.functional as F +from .sr_model import SRModel + + +@MODEL_REGISTRY.register() +class CodeFormerJointModel(SRModel): + def feed_data(self, data): + self.gt = data['gt'].to(self.device) + self.input = data['in'].to(self.device) + self.input_large_de = data['in_large_de'].to(self.device) + self.b = self.gt.shape[0] + + if 'latent_gt' in data: + self.idx_gt = data['latent_gt'].to(self.device) + self.idx_gt = self.idx_gt.view(self.b, -1) + 
else: + self.idx_gt = None + + def init_training_settings(self): + logger = get_root_logger() + train_opt = self.opt['train'] + + self.ema_decay = train_opt.get('ema_decay', 0) + if self.ema_decay > 0: + logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}') + # define network net_g with Exponential Moving Average (EMA) + # net_g_ema is used only for testing on one GPU and saving + # There is no need to wrap with DistributedDataParallel + self.net_g_ema = build_network(self.opt['network_g']).to(self.device) + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema') + else: + self.model_ema(0) # copy net_g weight + self.net_g_ema.eval() + + if self.opt['datasets']['train'].get('latent_gt_path', None) is not None: + self.generate_idx_gt = False + elif self.opt.get('network_vqgan', None) is not None: + self.hq_vqgan_fix = build_network(self.opt['network_vqgan']).to(self.device) + self.hq_vqgan_fix.eval() + self.generate_idx_gt = True + for param in self.hq_vqgan_fix.parameters(): + param.requires_grad = False + else: + raise NotImplementedError(f'Shoule have network_vqgan config or pre-calculated latent code.') + + logger.info(f'Need to generate latent GT code: {self.generate_idx_gt}') + + self.hq_feat_loss = train_opt.get('use_hq_feat_loss', True) + self.feat_loss_weight = train_opt.get('feat_loss_weight', 1.0) + self.cross_entropy_loss = train_opt.get('cross_entropy_loss', True) + self.entropy_loss_weight = train_opt.get('entropy_loss_weight', 0.5) + self.scale_adaptive_gan_weight = train_opt.get('scale_adaptive_gan_weight', 0.8) + + # define network net_d + self.net_d = build_network(self.opt['network_d']) + self.net_d = self.model_to_device(self.net_d) + self.print_network(self.net_d) + + # load pretrained models + load_path = self.opt['path'].get('pretrain_network_d', None) + if load_path is not None: + self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True)) + + self.net_g.train() + self.net_d.train() + + # define losses + if train_opt.get('pixel_opt'): + self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device) + else: + self.cri_pix = None + + if train_opt.get('perceptual_opt'): + self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device) + else: + self.cri_perceptual = None + + if train_opt.get('gan_opt'): + self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device) + + + self.fix_generator = train_opt.get('fix_generator', True) + logger.info(f'fix_generator: {self.fix_generator}') + + self.net_g_start_iter = train_opt.get('net_g_start_iter', 0) + self.net_d_iters = train_opt.get('net_d_iters', 1) + self.net_d_start_iter = train_opt.get('net_d_start_iter', 0) + + # set up optimizers and schedulers + self.setup_optimizers() + self.setup_schedulers() + + def calculate_adaptive_weight(self, recon_loss, g_loss, last_layer, disc_weight_max): + recon_grads = torch.autograd.grad(recon_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + + d_weight = torch.norm(recon_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, disc_weight_max).detach() + return d_weight + + def setup_optimizers(self): + train_opt = self.opt['train'] + # optimizer g + optim_params_g = [] + for k, v in self.net_g.named_parameters(): + if v.requires_grad: + optim_params_g.append(v) + else: + 
logger = get_root_logger() + logger.warning(f'Params {k} will not be optimized.') + optim_type = train_opt['optim_g'].pop('type') + self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g']) + self.optimizers.append(self.optimizer_g) + # optimizer d + optim_type = train_opt['optim_d'].pop('type') + self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d']) + self.optimizers.append(self.optimizer_d) + + def gray_resize_for_identity(self, out, size=128): + out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :]) + out_gray = out_gray.unsqueeze(1) + out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False) + return out_gray + + def optimize_parameters(self, current_iter): + logger = get_root_logger() + # optimize net_g + for p in self.net_d.parameters(): + p.requires_grad = False + + self.optimizer_g.zero_grad() + + if self.generate_idx_gt: + x = self.hq_vqgan_fix.encoder(self.gt) + output, _, quant_stats = self.hq_vqgan_fix.quantize(x) + min_encoding_indices = quant_stats['min_encoding_indices'] + self.idx_gt = min_encoding_indices.view(self.b, -1) + + if current_iter <= 40000: # small degradation + small_per_n = 1 + w = 1 + elif current_iter <= 80000: # small degradation + small_per_n = 1 + w = 1.3 + elif current_iter <= 120000: # large degradation + small_per_n = 120000 + w = 0 + else: # mixed degradation + small_per_n = 15 + w = 1.3 + + if current_iter % small_per_n == 0: + self.output, logits, lq_feat = self.net_g(self.input, w=w, detach_16=True) + large_de = False + else: + logits, lq_feat = self.net_g(self.input_large_de, code_only=True) + large_de = True + + if self.hq_feat_loss: + # quant_feats + quant_feat_gt = self.net_g.module.quantize.get_codebook_feat(self.idx_gt, shape=[self.b,16,16,256]) + + l_g_total = 0 + loss_dict = OrderedDict() + if current_iter % self.net_d_iters == 0 and current_iter > self.net_g_start_iter: + # hq_feat_loss + if not 'transformer' in self.opt['network_g']['fix_modules']: + if self.hq_feat_loss: # codebook loss + l_feat_encoder = torch.mean((quant_feat_gt.detach()-lq_feat)**2) * self.feat_loss_weight + l_g_total += l_feat_encoder + loss_dict['l_feat_encoder'] = l_feat_encoder + + # cross_entropy_loss + if self.cross_entropy_loss: + # b(hw)n -> bn(hw) + cross_entropy_loss = F.cross_entropy(logits.permute(0, 2, 1), self.idx_gt) * self.entropy_loss_weight + l_g_total += cross_entropy_loss + loss_dict['cross_entropy_loss'] = cross_entropy_loss + + # pixel loss + if not large_de: # when large degradation don't need image-level loss + if self.cri_pix: + l_g_pix = self.cri_pix(self.output, self.gt) + l_g_total += l_g_pix + loss_dict['l_g_pix'] = l_g_pix + + # perceptual loss + if self.cri_perceptual: + l_g_percep = self.cri_perceptual(self.output, self.gt) + l_g_total += l_g_percep + loss_dict['l_g_percep'] = l_g_percep + + # gan loss + if current_iter > self.net_d_start_iter: + fake_g_pred = self.net_d(self.output) + l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False) + recon_loss = l_g_pix + l_g_percep + if not self.fix_generator: + last_layer = self.net_g.module.generator.blocks[-1].weight + d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0) + else: + largest_fuse_size = self.opt['network_g']['connect_list'][-1] + last_layer = self.net_g.module.fuse_convs_dict[largest_fuse_size].shift[-1].weight + d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, 
disc_weight_max=1.0) + + d_weight *= self.scale_adaptive_gan_weight # 0.8 + loss_dict['d_weight'] = d_weight + l_g_total += d_weight * l_g_gan + loss_dict['l_g_gan'] = d_weight * l_g_gan + + l_g_total.backward() + self.optimizer_g.step() + + if self.ema_decay > 0: + self.model_ema(decay=self.ema_decay) + + # optimize net_d + if not large_de: + if current_iter > self.net_d_start_iter: + for p in self.net_d.parameters(): + p.requires_grad = True + + self.optimizer_d.zero_grad() + # real + real_d_pred = self.net_d(self.gt) + l_d_real = self.cri_gan(real_d_pred, True, is_disc=True) + loss_dict['l_d_real'] = l_d_real + loss_dict['out_d_real'] = torch.mean(real_d_pred.detach()) + l_d_real.backward() + # fake + fake_d_pred = self.net_d(self.output.detach()) + l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True) + loss_dict['l_d_fake'] = l_d_fake + loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach()) + l_d_fake.backward() + + self.optimizer_d.step() + + self.log_dict = self.reduce_loss_dict(loss_dict) + + + def test(self): + with torch.no_grad(): + if hasattr(self, 'net_g_ema'): + self.net_g_ema.eval() + self.output, _, _ = self.net_g_ema(self.input, w=1) + else: + logger = get_root_logger() + logger.warning('Do not have self.net_g_ema, use self.net_g.') + self.net_g.eval() + self.output, _, _ = self.net_g(self.input, w=1) + self.net_g.train() + + + def dist_validation(self, dataloader, current_iter, tb_logger, save_img): + if self.opt['rank'] == 0: + self.nondist_validation(dataloader, current_iter, tb_logger, save_img) + + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + dataset_name = dataloader.dataset.opt['name'] + with_metrics = self.opt['val'].get('metrics') is not None + if with_metrics: + self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()} + pbar = tqdm(total=len(dataloader), unit='image') + + for idx, val_data in enumerate(dataloader): + img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0] + self.feed_data(val_data) + self.test() + + visuals = self.get_current_visuals() + sr_img = tensor2img([visuals['result']]) + if 'gt' in visuals: + gt_img = tensor2img([visuals['gt']]) + del self.gt + + # tentative for out of GPU memory + del self.lq + del self.output + torch.cuda.empty_cache() + + if save_img: + if self.opt['is_train']: + save_img_path = osp.join(self.opt['path']['visualization'], img_name, + f'{img_name}_{current_iter}.png') + else: + if self.opt['val']['suffix']: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["val"]["suffix"]}.png') + else: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["name"]}.png') + imwrite(sr_img, save_img_path) + + if with_metrics: + # calculate metrics + for name, opt_ in self.opt['val']['metrics'].items(): + metric_data = dict(img1=sr_img, img2=gt_img) + self.metric_results[name] += calculate_metric(metric_data, opt_) + pbar.update(1) + pbar.set_description(f'Test {img_name}') + pbar.close() + + if with_metrics: + for metric in self.metric_results.keys(): + self.metric_results[metric] /= (idx + 1) + + self._log_validation_metric_values(current_iter, dataset_name, tb_logger) + + + def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger): + log_str = f'Validation {dataset_name}\n' + for metric, value in self.metric_results.items(): + log_str += f'\t # {metric}: {value:.4f}\n' + logger = get_root_logger() + logger.info(log_str) + if tb_logger: + for 
metric, value in self.metric_results.items(): + tb_logger.add_scalar(f'metrics/{metric}', value, current_iter) + + + def get_current_visuals(self): + out_dict = OrderedDict() + out_dict['gt'] = self.gt.detach().cpu() + out_dict['result'] = self.output.detach().cpu() + return out_dict + + + def save(self, epoch, current_iter): + if self.ema_decay > 0: + self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema']) + else: + self.save_network(self.net_g, 'net_g', current_iter) + self.save_network(self.net_d, 'net_d', current_iter) + self.save_training_state(epoch, current_iter) diff --git a/blissful_tuner/codeformer/basicsr/models/codeformer_model.py b/blissful_tuner/codeformer/basicsr/models/codeformer_model.py new file mode 100644 index 0000000000000000000000000000000000000000..d829651022c776c5ad29d3551ea85c9846c2de24 --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/models/codeformer_model.py @@ -0,0 +1,332 @@ +import torch +from collections import OrderedDict +from os import path as osp +from tqdm import tqdm + +from codeformer.basicsr.archs import build_network +from codeformer.basicsr.losses import build_loss +from codeformer.basicsr.metrics import calculate_metric +from codeformer.basicsr.utils import get_root_logger, imwrite, tensor2img +from codeformer.basicsr.utils.registry import MODEL_REGISTRY +import torch.nn.functional as F +from .sr_model import SRModel + + +@MODEL_REGISTRY.register() +class CodeFormerModel(SRModel): + def feed_data(self, data): + self.gt = data['gt'].to(self.device) + self.input = data['in'].to(self.device) + self.b = self.gt.shape[0] + + if 'latent_gt' in data: + self.idx_gt = data['latent_gt'].to(self.device) + self.idx_gt = self.idx_gt.view(self.b, -1) + else: + self.idx_gt = None + + def init_training_settings(self): + logger = get_root_logger() + train_opt = self.opt['train'] + + self.ema_decay = train_opt.get('ema_decay', 0) + if self.ema_decay > 0: + logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}') + # define network net_g with Exponential Moving Average (EMA) + # net_g_ema is used only for testing on one GPU and saving + # There is no need to wrap with DistributedDataParallel + self.net_g_ema = build_network(self.opt['network_g']).to(self.device) + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema') + else: + self.model_ema(0) # copy net_g weight + self.net_g_ema.eval() + + if self.opt.get('network_vqgan', None) is not None and self.opt['datasets'].get('latent_gt_path') is None: + self.hq_vqgan_fix = build_network(self.opt['network_vqgan']).to(self.device) + self.hq_vqgan_fix.eval() + self.generate_idx_gt = True + for param in self.hq_vqgan_fix.parameters(): + param.requires_grad = False + else: + self.generate_idx_gt = False + + self.hq_feat_loss = train_opt.get('use_hq_feat_loss', True) + self.feat_loss_weight = train_opt.get('feat_loss_weight', 1.0) + self.cross_entropy_loss = train_opt.get('cross_entropy_loss', True) + self.entropy_loss_weight = train_opt.get('entropy_loss_weight', 0.5) + self.fidelity_weight = train_opt.get('fidelity_weight', 1.0) + self.scale_adaptive_gan_weight = train_opt.get('scale_adaptive_gan_weight', 0.8) + + + self.net_g.train() + # define network net_d + if self.fidelity_weight > 0: + self.net_d = build_network(self.opt['network_d']) + self.net_d = self.model_to_device(self.net_d) 
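+ # note: the discriminator branch (net_d, its pretrained weights and optimizer_d) is only set up when fidelity_weight > 0.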
+ self.print_network(self.net_d) + + # load pretrained models + load_path = self.opt['path'].get('pretrain_network_d', None) + if load_path is not None: + self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True)) + + self.net_d.train() + + # define losses + if train_opt.get('pixel_opt'): + self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device) + else: + self.cri_pix = None + + if train_opt.get('perceptual_opt'): + self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device) + else: + self.cri_perceptual = None + + if train_opt.get('gan_opt'): + self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device) + + + self.fix_generator = train_opt.get('fix_generator', True) + logger.info(f'fix_generator: {self.fix_generator}') + + self.net_g_start_iter = train_opt.get('net_g_start_iter', 0) + self.net_d_iters = train_opt.get('net_d_iters', 1) + self.net_d_start_iter = train_opt.get('net_d_start_iter', 0) + + # set up optimizers and schedulers + self.setup_optimizers() + self.setup_schedulers() + + def calculate_adaptive_weight(self, recon_loss, g_loss, last_layer, disc_weight_max): + recon_grads = torch.autograd.grad(recon_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + + d_weight = torch.norm(recon_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, disc_weight_max).detach() + return d_weight + + def setup_optimizers(self): + train_opt = self.opt['train'] + # optimizer g + optim_params_g = [] + for k, v in self.net_g.named_parameters(): + if v.requires_grad: + optim_params_g.append(v) + else: + logger = get_root_logger() + logger.warning(f'Params {k} will not be optimized.') + optim_type = train_opt['optim_g'].pop('type') + self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g']) + self.optimizers.append(self.optimizer_g) + # optimizer d + if self.fidelity_weight > 0: + optim_type = train_opt['optim_d'].pop('type') + self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d']) + self.optimizers.append(self.optimizer_d) + + def gray_resize_for_identity(self, out, size=128): + out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :]) + out_gray = out_gray.unsqueeze(1) + out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False) + return out_gray + + def optimize_parameters(self, current_iter): + logger = get_root_logger() + # optimize net_g + for p in self.net_d.parameters(): + p.requires_grad = False + + self.optimizer_g.zero_grad() + + if self.generate_idx_gt: + x = self.hq_vqgan_fix.encoder(self.gt) + output, _, quant_stats = self.hq_vqgan_fix.quantize(x) + min_encoding_indices = quant_stats['min_encoding_indices'] + self.idx_gt = min_encoding_indices.view(self.b, -1) + + if self.fidelity_weight > 0: + self.output, logits, lq_feat = self.net_g(self.input, w=self.fidelity_weight, detach_16=True) + else: + logits, lq_feat = self.net_g(self.input, w=0, code_only=True) + + if self.hq_feat_loss: + # quant_feats + quant_feat_gt = self.net_g.module.quantize.get_codebook_feat(self.idx_gt, shape=[self.b,16,16,256]) + + l_g_total = 0 + loss_dict = OrderedDict() + if current_iter % self.net_d_iters == 0 and current_iter > self.net_g_start_iter: + # hq_feat_loss + if self.hq_feat_loss: # codebook loss + l_feat_encoder = torch.mean((quant_feat_gt.detach()-lq_feat)**2) * self.feat_loss_weight + l_g_total += l_feat_encoder + 
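The `calculate_adaptive_weight` helper defined just above balances the reconstruction and adversarial objectives by comparing their gradient norms at a reference layer, clamping the ratio, and then damping it by `scale_adaptive_gan_weight` (0.8). Below is a minimal, self-contained sketch of that computation on a toy generator; the toy modules, shapes, and loss stand-ins are illustrative assumptions and not part of this patch.

```python
# Hedged sketch of the adaptive discriminator weighting in
# calculate_adaptive_weight: rescale the GAN loss so its gradient magnitude
# at a reference layer matches the reconstruction loss, clamp, then damp.
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Toy stand-ins; only the last layer's weight is used as the reference.
generator = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8))
disc_head = nn.Linear(8, 1)
last_layer = generator[-1].weight

x, target = torch.randn(4, 8), torch.randn(4, 8)
out = generator(x)
recon_loss = F.l1_loss(out, target)            # pixel/perceptual stand-in
g_loss = F.softplus(-disc_head(out)).mean()    # non-saturating GAN stand-in

recon_grads = torch.autograd.grad(recon_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

d_weight = torch.norm(recon_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1.0).detach()  # disc_weight_max = 1.0
d_weight = d_weight * 0.8                            # scale_adaptive_gan_weight
total = recon_loss + d_weight * g_loss
print(f"d_weight={d_weight.item():.4f}, total={total.item():.4f}")
```

Scaling the adversarial term this way keeps its gradient at the reference layer from overwhelming the reconstruction gradient, regardless of the raw loss magnitudes.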
loss_dict['l_feat_encoder'] = l_feat_encoder + + # cross_entropy_loss + if self.cross_entropy_loss: + # b(hw)n -> bn(hw) + cross_entropy_loss = F.cross_entropy(logits.permute(0, 2, 1), self.idx_gt) * self.entropy_loss_weight + l_g_total += cross_entropy_loss + loss_dict['cross_entropy_loss'] = cross_entropy_loss + + if self.fidelity_weight > 0: # when fidelity_weight == 0 don't need image-level loss + # pixel loss + if self.cri_pix: + l_g_pix = self.cri_pix(self.output, self.gt) + l_g_total += l_g_pix + loss_dict['l_g_pix'] = l_g_pix + + # perceptual loss + if self.cri_perceptual: + l_g_percep = self.cri_perceptual(self.output, self.gt) + l_g_total += l_g_percep + loss_dict['l_g_percep'] = l_g_percep + + # gan loss + if current_iter > self.net_d_start_iter: + fake_g_pred = self.net_d(self.output) + l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False) + recon_loss = l_g_pix + l_g_percep + if not self.fix_generator: + last_layer = self.net_g.module.generator.blocks[-1].weight + d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0) + else: + largest_fuse_size = self.opt['network_g']['connect_list'][-1] + last_layer = self.net_g.module.fuse_convs_dict[largest_fuse_size].shift[-1].weight + d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0) + + d_weight *= self.scale_adaptive_gan_weight # 0.8 + loss_dict['d_weight'] = d_weight + l_g_total += d_weight * l_g_gan + loss_dict['l_g_gan'] = d_weight * l_g_gan + + l_g_total.backward() + self.optimizer_g.step() + + if self.ema_decay > 0: + self.model_ema(decay=self.ema_decay) + + # optimize net_d + if current_iter > self.net_d_start_iter and self.fidelity_weight > 0: + for p in self.net_d.parameters(): + p.requires_grad = True + + self.optimizer_d.zero_grad() + # real + real_d_pred = self.net_d(self.gt) + l_d_real = self.cri_gan(real_d_pred, True, is_disc=True) + loss_dict['l_d_real'] = l_d_real + loss_dict['out_d_real'] = torch.mean(real_d_pred.detach()) + l_d_real.backward() + # fake + fake_d_pred = self.net_d(self.output.detach()) + l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True) + loss_dict['l_d_fake'] = l_d_fake + loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach()) + l_d_fake.backward() + + self.optimizer_d.step() + + self.log_dict = self.reduce_loss_dict(loss_dict) + + + def test(self): + with torch.no_grad(): + if hasattr(self, 'net_g_ema'): + self.net_g_ema.eval() + self.output, _, _ = self.net_g_ema(self.input, w=self.fidelity_weight) + else: + logger = get_root_logger() + logger.warning('Do not have self.net_g_ema, use self.net_g.') + self.net_g.eval() + self.output, _, _ = self.net_g(self.input, w=self.fidelity_weight) + self.net_g.train() + + + def dist_validation(self, dataloader, current_iter, tb_logger, save_img): + if self.opt['rank'] == 0: + self.nondist_validation(dataloader, current_iter, tb_logger, save_img) + + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + dataset_name = dataloader.dataset.opt['name'] + with_metrics = self.opt['val'].get('metrics') is not None + if with_metrics: + self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()} + pbar = tqdm(total=len(dataloader), unit='image') + + for idx, val_data in enumerate(dataloader): + img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0] + self.feed_data(val_data) + self.test() + + visuals = self.get_current_visuals() + sr_img = tensor2img([visuals['result']]) + if 'gt' in visuals: + gt_img = 
tensor2img([visuals['gt']]) + del self.gt + + # tentative for out of GPU memory + del self.lq + del self.output + torch.cuda.empty_cache() + + if save_img: + if self.opt['is_train']: + save_img_path = osp.join(self.opt['path']['visualization'], img_name, + f'{img_name}_{current_iter}.png') + else: + if self.opt['val']['suffix']: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["val"]["suffix"]}.png') + else: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["name"]}.png') + imwrite(sr_img, save_img_path) + + if with_metrics: + # calculate metrics + for name, opt_ in self.opt['val']['metrics'].items(): + metric_data = dict(img1=sr_img, img2=gt_img) + self.metric_results[name] += calculate_metric(metric_data, opt_) + pbar.update(1) + pbar.set_description(f'Test {img_name}') + pbar.close() + + if with_metrics: + for metric in self.metric_results.keys(): + self.metric_results[metric] /= (idx + 1) + + self._log_validation_metric_values(current_iter, dataset_name, tb_logger) + + + def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger): + log_str = f'Validation {dataset_name}\n' + for metric, value in self.metric_results.items(): + log_str += f'\t # {metric}: {value:.4f}\n' + logger = get_root_logger() + logger.info(log_str) + if tb_logger: + for metric, value in self.metric_results.items(): + tb_logger.add_scalar(f'metrics/{metric}', value, current_iter) + + + def get_current_visuals(self): + out_dict = OrderedDict() + out_dict['gt'] = self.gt.detach().cpu() + out_dict['result'] = self.output.detach().cpu() + return out_dict + + + def save(self, epoch, current_iter): + if self.ema_decay > 0: + self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema']) + else: + self.save_network(self.net_g, 'net_g', current_iter) + if self.fidelity_weight > 0: + self.save_network(self.net_d, 'net_d', current_iter) + self.save_training_state(epoch, current_iter) diff --git a/blissful_tuner/codeformer/basicsr/models/lr_scheduler.py b/blissful_tuner/codeformer/basicsr/models/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..a423ce656c044ed5861a0056020074a517f04156 --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/models/lr_scheduler.py @@ -0,0 +1,96 @@ +import math +from collections import Counter +from torch.optim.lr_scheduler import _LRScheduler + + +class MultiStepRestartLR(_LRScheduler): + """ MultiStep with restarts learning rate scheme. + + Args: + optimizer (torch.nn.optimizer): Torch optimizer. + milestones (list): Iterations that will decrease learning rate. + gamma (float): Decrease ratio. Default: 0.1. + restarts (list): Restart iterations. Default: [0]. + restart_weights (list): Restart weights at each restart iteration. + Default: [1]. + last_epoch (int): Used in _LRScheduler. Default: -1. + """ + + def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1): + self.milestones = Counter(milestones) + self.gamma = gamma + self.restarts = restarts + self.restart_weights = restart_weights + assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.' 
+ super(MultiStepRestartLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + if self.last_epoch in self.restarts: + weight = self.restart_weights[self.restarts.index(self.last_epoch)] + return [group['initial_lr'] * weight for group in self.optimizer.param_groups] + if self.last_epoch not in self.milestones: + return [group['lr'] for group in self.optimizer.param_groups] + return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups] + + +def get_position_from_periods(iteration, cumulative_period): + """Get the position from a period list. + + It will return the index of the right-closest number in the period list. + For example, the cumulative_period = [100, 200, 300, 400], + if iteration == 50, return 0; + if iteration == 210, return 2; + if iteration == 300, return 2. + + Args: + iteration (int): Current iteration. + cumulative_period (list[int]): Cumulative period list. + + Returns: + int: The position of the right-closest number in the period list. + """ + for i, period in enumerate(cumulative_period): + if iteration <= period: + return i + + +class CosineAnnealingRestartLR(_LRScheduler): + """ Cosine annealing with restarts learning rate scheme. + + An example of config: + periods = [10, 10, 10, 10] + restart_weights = [1, 0.5, 0.5, 0.5] + eta_min=1e-7 + + It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the + scheduler will restart with the weights in restart_weights. + + Args: + optimizer (torch.nn.optimizer): Torch optimizer. + periods (list): Period for each cosine anneling cycle. + restart_weights (list): Restart weights at each restart iteration. + Default: [1]. + eta_min (float): The mimimum lr. Default: 0. + last_epoch (int): Used in _LRScheduler. Default: -1. + """ + + def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1): + self.periods = periods + self.restart_weights = restart_weights + self.eta_min = eta_min + assert (len(self.periods) == len( + self.restart_weights)), 'periods and restart_weights should have the same length.' 
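For reference, here is a hedged usage sketch of `CosineAnnealingRestartLR` (its `get_lr` follows just below), driven by the example configuration from its docstring: four 10-iteration cycles with restart weights `[1, 0.5, 0.5, 0.5]` and `eta_min=1e-7`. The dummy parameter, the SGD optimizer, and the `codeformer.basicsr` import path are assumptions about how this vendored package is resolved.

```python
# Minimal usage sketch for CosineAnnealingRestartLR with the docstring's
# example config. The learning rate follows a cosine curve within each
# 10-iteration period and restarts at iterations 10, 20 and 30, with the
# restarted peak scaled by the matching restart weight.
import torch
from codeformer.basicsr.models.lr_scheduler import CosineAnnealingRestartLR

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=2e-4)
scheduler = CosineAnnealingRestartLR(
    optimizer,
    periods=[10, 10, 10, 10],
    restart_weights=[1, 0.5, 0.5, 0.5],
    eta_min=1e-7,
)

for it in range(40):
    optimizer.step()   # the actual training step would go here
    scheduler.step()   # advances last_epoch and applies get_lr()
    if it % 10 == 0:
        print(it, optimizer.param_groups[0]["lr"])
```

Each restart rescales the peak of the new cosine cycle by its restart weight, while `eta_min` remains the common floor.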
+ self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))] + super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + idx = get_position_from_periods(self.last_epoch, self.cumulative_period) + current_weight = self.restart_weights[idx] + nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1] + current_period = self.periods[idx] + + return [ + self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) * + (1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period))) + for base_lr in self.base_lrs + ] diff --git a/blissful_tuner/codeformer/basicsr/models/sr_model.py b/blissful_tuner/codeformer/basicsr/models/sr_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5f29c6a0a23080d7352f1554ef2748a2b11bde7d --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/models/sr_model.py @@ -0,0 +1,209 @@ +import torch +from collections import OrderedDict +from os import path as osp +from tqdm import tqdm + +from codeformer.basicsr.archs import build_network +from codeformer.basicsr.losses import build_loss +from codeformer.basicsr.metrics import calculate_metric +from codeformer.basicsr.utils import get_root_logger, imwrite, tensor2img +from codeformer.basicsr.utils.registry import MODEL_REGISTRY +from .base_model import BaseModel + +@MODEL_REGISTRY.register() +class SRModel(BaseModel): + """Base SR model for single image super-resolution.""" + + def __init__(self, opt): + super(SRModel, self).__init__(opt) + + # define network + self.net_g = build_network(opt['network_g']) + self.net_g = self.model_to_device(self.net_g) + self.print_network(self.net_g) + + # load pretrained models + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + param_key = self.opt['path'].get('param_key_g', 'params') + self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key) + + if self.is_train: + self.init_training_settings() + + def init_training_settings(self): + self.net_g.train() + train_opt = self.opt['train'] + + self.ema_decay = train_opt.get('ema_decay', 0) + if self.ema_decay > 0: + logger = get_root_logger() + logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}') + # define network net_g with Exponential Moving Average (EMA) + # net_g_ema is used only for testing on one GPU and saving + # There is no need to wrap with DistributedDataParallel + self.net_g_ema = build_network(self.opt['network_g']).to(self.device) + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema') + else: + self.model_ema(0) # copy net_g weight + self.net_g_ema.eval() + + # define losses + if train_opt.get('pixel_opt'): + self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device) + else: + self.cri_pix = None + + if train_opt.get('perceptual_opt'): + self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device) + else: + self.cri_perceptual = None + + if self.cri_pix is None and self.cri_perceptual is None: + raise ValueError('Both pixel and perceptual losses are None.') + + # set up optimizers and schedulers + self.setup_optimizers() + self.setup_schedulers() + + def setup_optimizers(self): + train_opt = self.opt['train'] + optim_params = [] + for k, v in self.net_g.named_parameters(): + if v.requires_grad: + 
optim_params.append(v) + else: + logger = get_root_logger() + logger.warning(f'Params {k} will not be optimized.') + + optim_type = train_opt['optim_g'].pop('type') + self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g']) + self.optimizers.append(self.optimizer_g) + + def feed_data(self, data): + self.lq = data['lq'].to(self.device) + if 'gt' in data: + self.gt = data['gt'].to(self.device) + + def optimize_parameters(self, current_iter): + self.optimizer_g.zero_grad() + self.output = self.net_g(self.lq) + + l_total = 0 + loss_dict = OrderedDict() + # pixel loss + if self.cri_pix: + l_pix = self.cri_pix(self.output, self.gt) + l_total += l_pix + loss_dict['l_pix'] = l_pix + # perceptual loss + if self.cri_perceptual: + l_percep, l_style = self.cri_perceptual(self.output, self.gt) + if l_percep is not None: + l_total += l_percep + loss_dict['l_percep'] = l_percep + if l_style is not None: + l_total += l_style + loss_dict['l_style'] = l_style + + l_total.backward() + self.optimizer_g.step() + + self.log_dict = self.reduce_loss_dict(loss_dict) + + if self.ema_decay > 0: + self.model_ema(decay=self.ema_decay) + + def test(self): + if hasattr(self, 'ema_decay'): + self.net_g_ema.eval() + with torch.no_grad(): + self.output = self.net_g_ema(self.lq) + else: + self.net_g.eval() + with torch.no_grad(): + self.output = self.net_g(self.lq) + self.net_g.train() + + def dist_validation(self, dataloader, current_iter, tb_logger, save_img): + if self.opt['rank'] == 0: + self.nondist_validation(dataloader, current_iter, tb_logger, save_img) + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + dataset_name = dataloader.dataset.opt['name'] + with_metrics = self.opt['val'].get('metrics') is not None + if with_metrics: + self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()} + pbar = tqdm(total=len(dataloader), unit='image') + + for idx, val_data in enumerate(dataloader): + img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0] + self.feed_data(val_data) + self.test() + + visuals = self.get_current_visuals() + sr_img = tensor2img([visuals['result']]) + if 'gt' in visuals: + gt_img = tensor2img([visuals['gt']]) + del self.gt + + # tentative for out of GPU memory + del self.lq + del self.output + torch.cuda.empty_cache() + + if save_img: + if self.opt['is_train']: + save_img_path = osp.join(self.opt['path']['visualization'], img_name, + f'{img_name}_{current_iter}.png') + else: + if self.opt['val']['suffix']: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["val"]["suffix"]}.png') + else: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["name"]}.png') + imwrite(sr_img, save_img_path) + + if with_metrics: + # calculate metrics + for name, opt_ in self.opt['val']['metrics'].items(): + metric_data = dict(img1=sr_img, img2=gt_img) + self.metric_results[name] += calculate_metric(metric_data, opt_) + pbar.update(1) + pbar.set_description(f'Test {img_name}') + pbar.close() + + if with_metrics: + for metric in self.metric_results.keys(): + self.metric_results[metric] /= (idx + 1) + + self._log_validation_metric_values(current_iter, dataset_name, tb_logger) + + def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger): + log_str = f'Validation {dataset_name}\n' + for metric, value in self.metric_results.items(): + log_str += f'\t # {metric}: {value:.4f}\n' + logger = get_root_logger() + 
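When `ema_decay > 0`, these models also keep an EMA copy `net_g_ema` that is used only for testing and saving. The `model_ema` helper itself lives in `BaseModel` and is not shown in this hunk, so the sketch below only illustrates the standard exponential-moving-average update that such a helper performs; the function name and update rule here are assumptions for illustration, not the vendored implementation.

```python
# Hedged sketch of a standard EMA update of the kind model_ema(decay) applies:
# ema_param <- decay * ema_param + (1 - decay) * online_param.
import copy
import torch
import torch.nn as nn

net_g = nn.Linear(4, 4)
net_g_ema = copy.deepcopy(net_g).eval()
for p in net_g_ema.parameters():
    p.requires_grad_(False)

@torch.no_grad()
def model_ema_sketch(online: nn.Module, ema: nn.Module, decay: float = 0.999) -> None:
    online_params = dict(online.named_parameters())
    for name, ema_param in ema.named_parameters():
        ema_param.mul_(decay).add_(online_params[name].data, alpha=1 - decay)

# Called after each optimizer step on net_g:
model_ema_sketch(net_g, net_g_ema, decay=0.999)
```

With a decay of 0 the update reduces to copying the online weights, which matches the `self.model_ema(0)  # copy net_g weight` call in `init_training_settings`.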
logger.info(log_str) + if tb_logger: + for metric, value in self.metric_results.items(): + tb_logger.add_scalar(f'metrics/{metric}', value, current_iter) + + def get_current_visuals(self): + out_dict = OrderedDict() + out_dict['lq'] = self.lq.detach().cpu() + out_dict['result'] = self.output.detach().cpu() + if hasattr(self, 'gt'): + out_dict['gt'] = self.gt.detach().cpu() + return out_dict + + def save(self, epoch, current_iter): + if hasattr(self, 'ema_decay'): + self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema']) + else: + self.save_network(self.net_g, 'net_g', current_iter) + self.save_training_state(epoch, current_iter) diff --git a/blissful_tuner/codeformer/basicsr/models/vqgan_model.py b/blissful_tuner/codeformer/basicsr/models/vqgan_model.py new file mode 100644 index 0000000000000000000000000000000000000000..46769e7ba42d888b22e98e64604478e6c747b41d --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/models/vqgan_model.py @@ -0,0 +1,285 @@ +import torch +from collections import OrderedDict +from os import path as osp +from tqdm import tqdm + +from codeformer.basicsr.archs import build_network +from codeformer.basicsr.losses import build_loss +from codeformer.basicsr.metrics import calculate_metric +from codeformer.basicsr.utils import get_root_logger, imwrite, tensor2img +from codeformer.basicsr.utils.registry import MODEL_REGISTRY +import torch.nn.functional as F +from .sr_model import SRModel + + +@MODEL_REGISTRY.register() +class VQGANModel(SRModel): + def feed_data(self, data): + self.gt = data['gt'].to(self.device) + self.b = self.gt.shape[0] + + + def init_training_settings(self): + logger = get_root_logger() + train_opt = self.opt['train'] + + self.ema_decay = train_opt.get('ema_decay', 0) + if self.ema_decay > 0: + logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}') + # define network net_g with Exponential Moving Average (EMA) + # net_g_ema is used only for testing on one GPU and saving + # There is no need to wrap with DistributedDataParallel + self.net_g_ema = build_network(self.opt['network_g']).to(self.device) + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema') + else: + self.model_ema(0) # copy net_g weight + self.net_g_ema.eval() + + # define network net_d + self.net_d = build_network(self.opt['network_d']) + self.net_d = self.model_to_device(self.net_d) + self.print_network(self.net_d) + + # load pretrained models + load_path = self.opt['path'].get('pretrain_network_d', None) + if load_path is not None: + self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True)) + + self.net_g.train() + self.net_d.train() + + # define losses + if train_opt.get('pixel_opt'): + self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device) + else: + self.cri_pix = None + + if train_opt.get('perceptual_opt'): + self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device) + else: + self.cri_perceptual = None + + if train_opt.get('gan_opt'): + self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device) + + if train_opt.get('codebook_opt'): + self.l_weight_codebook = train_opt['codebook_opt'].get('loss_weight', 1.0) + else: + self.l_weight_codebook = 1.0 + + self.vqgan_quantizer = self.opt['network_g']['quantizer'] + logger.info(f'vqgan_quantizer: {self.vqgan_quantizer}') + + 
self.net_g_start_iter = train_opt.get('net_g_start_iter', 0) + self.net_d_iters = train_opt.get('net_d_iters', 1) + self.net_d_start_iter = train_opt.get('net_d_start_iter', 0) + self.disc_weight = train_opt.get('disc_weight', 0.8) + + # set up optimizers and schedulers + self.setup_optimizers() + self.setup_schedulers() + + def calculate_adaptive_weight(self, recon_loss, g_loss, last_layer, disc_weight_max): + recon_grads = torch.autograd.grad(recon_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + + d_weight = torch.norm(recon_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, disc_weight_max).detach() + return d_weight + + def adopt_weight(self, weight, global_step, threshold=0, value=0.): + if global_step < threshold: + weight = value + return weight + + def setup_optimizers(self): + train_opt = self.opt['train'] + # optimizer g + optim_params_g = [] + for k, v in self.net_g.named_parameters(): + if v.requires_grad: + optim_params_g.append(v) + else: + logger = get_root_logger() + logger.warning(f'Params {k} will not be optimized.') + optim_type = train_opt['optim_g'].pop('type') + self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g']) + self.optimizers.append(self.optimizer_g) + # optimizer d + optim_type = train_opt['optim_d'].pop('type') + self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d']) + self.optimizers.append(self.optimizer_d) + + + def optimize_parameters(self, current_iter): + logger = get_root_logger() + loss_dict = OrderedDict() + if self.opt['network_g']['quantizer'] == 'gumbel': + self.net_g.module.quantize.temperature = max(1/16, ((-1/160000) * current_iter) + 1) + if current_iter%1000 == 0: + logger.info(f'temperature: {self.net_g.module.quantize.temperature}') + + # optimize net_g + for p in self.net_d.parameters(): + p.requires_grad = False + + self.optimizer_g.zero_grad() + self.output, l_codebook, quant_stats = self.net_g(self.gt) + + l_codebook = l_codebook*self.l_weight_codebook + + l_g_total = 0 + if current_iter % self.net_d_iters == 0 and current_iter > self.net_g_start_iter: + # pixel loss + if self.cri_pix: + l_g_pix = self.cri_pix(self.output, self.gt) + l_g_total += l_g_pix + loss_dict['l_g_pix'] = l_g_pix + # perceptual loss + if self.cri_perceptual: + l_g_percep = self.cri_perceptual(self.output, self.gt) + l_g_total += l_g_percep + loss_dict['l_g_percep'] = l_g_percep + + # gan loss + if current_iter > self.net_d_start_iter: + # fake_g_pred = self.net_d(self.output_1024) + fake_g_pred = self.net_d(self.output) + l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False) + recon_loss = l_g_total + last_layer = self.net_g.module.generator.blocks[-1].weight + d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0) + d_weight *= self.adopt_weight(1, current_iter, self.net_d_start_iter) + d_weight *= self.disc_weight # tamming setting 0.8 + l_g_total += d_weight * l_g_gan + loss_dict['l_g_gan'] = d_weight * l_g_gan + + l_g_total += l_codebook + loss_dict['l_codebook'] = l_codebook + + l_g_total.backward() + self.optimizer_g.step() + + # optimize net_d + if current_iter > self.net_d_start_iter: + for p in self.net_d.parameters(): + p.requires_grad = True + + self.optimizer_d.zero_grad() + # real + real_d_pred = self.net_d(self.gt) + l_d_real = self.cri_gan(real_d_pred, True, is_disc=True) + loss_dict['l_d_real'] = l_d_real + loss_dict['out_d_real'] = 
torch.mean(real_d_pred.detach()) + l_d_real.backward() + # fake + fake_d_pred = self.net_d(self.output.detach()) + l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True) + loss_dict['l_d_fake'] = l_d_fake + loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach()) + l_d_fake.backward() + self.optimizer_d.step() + + self.log_dict = self.reduce_loss_dict(loss_dict) + + if self.ema_decay > 0: + self.model_ema(decay=self.ema_decay) + + + def test(self): + with torch.no_grad(): + if hasattr(self, 'net_g_ema'): + self.net_g_ema.eval() + self.output, _, _ = self.net_g_ema(self.gt) + else: + logger = get_root_logger() + logger.warning('Do not have self.net_g_ema, use self.net_g.') + self.net_g.eval() + self.output, _, _ = self.net_g(self.gt) + self.net_g.train() + + + def dist_validation(self, dataloader, current_iter, tb_logger, save_img): + if self.opt['rank'] == 0: + self.nondist_validation(dataloader, current_iter, tb_logger, save_img) + + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + dataset_name = dataloader.dataset.opt['name'] + with_metrics = self.opt['val'].get('metrics') is not None + if with_metrics: + self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()} + pbar = tqdm(total=len(dataloader), unit='image') + + for idx, val_data in enumerate(dataloader): + img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0] + self.feed_data(val_data) + self.test() + + visuals = self.get_current_visuals() + sr_img = tensor2img([visuals['result']]) + if 'gt' in visuals: + gt_img = tensor2img([visuals['gt']]) + del self.gt + + # tentative for out of GPU memory + del self.lq + del self.output + torch.cuda.empty_cache() + + if save_img: + if self.opt['is_train']: + save_img_path = osp.join(self.opt['path']['visualization'], img_name, + f'{img_name}_{current_iter}.png') + else: + if self.opt['val']['suffix']: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["val"]["suffix"]}.png') + else: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["name"]}.png') + imwrite(sr_img, save_img_path) + + if with_metrics: + # calculate metrics + for name, opt_ in self.opt['val']['metrics'].items(): + metric_data = dict(img1=sr_img, img2=gt_img) + self.metric_results[name] += calculate_metric(metric_data, opt_) + pbar.update(1) + pbar.set_description(f'Test {img_name}') + pbar.close() + + if with_metrics: + for metric in self.metric_results.keys(): + self.metric_results[metric] /= (idx + 1) + + self._log_validation_metric_values(current_iter, dataset_name, tb_logger) + + + def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger): + log_str = f'Validation {dataset_name}\n' + for metric, value in self.metric_results.items(): + log_str += f'\t # {metric}: {value:.4f}\n' + logger = get_root_logger() + logger.info(log_str) + if tb_logger: + for metric, value in self.metric_results.items(): + tb_logger.add_scalar(f'metrics/{metric}', value, current_iter) + + + def get_current_visuals(self): + out_dict = OrderedDict() + out_dict['gt'] = self.gt.detach().cpu() + out_dict['result'] = self.output.detach().cpu() + return out_dict + + def save(self, epoch, current_iter): + if self.ema_decay > 0: + self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema']) + else: + self.save_network(self.net_g, 'net_g', current_iter) + self.save_network(self.net_d, 'net_d', current_iter) + 
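For the Gumbel quantizer branch in `optimize_parameters` above, the softmax temperature is annealed linearly from 1.0 and clamped at 1/16, a floor it reaches at iteration (1 - 1/16) * 160000 = 150000. A small sketch of that schedule in isolation (plain arithmetic, no model required):

```python
# Sketch of the Gumbel-quantizer temperature schedule used in
# VQGANModel.optimize_parameters: linear decay from 1.0, clamped at 1/16.
def gumbel_temperature(current_iter: int) -> float:
    return max(1 / 16, ((-1 / 160000) * current_iter) + 1)

for it in (0, 1000, 80000, 150000, 200000):
    print(it, gumbel_temperature(it))
# 0 -> 1.0, 80000 -> 0.5, 150000 and beyond -> 0.0625
```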
self.save_training_state(epoch, current_iter) diff --git a/blissful_tuner/codeformer/basicsr/ops/__init__.py b/blissful_tuner/codeformer/basicsr/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/blissful_tuner/codeformer/basicsr/ops/dcn/__init__.py b/blissful_tuner/codeformer/basicsr/ops/dcn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..32e3592f896d61b4127e09d0476381b9d55e32ff --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/ops/dcn/__init__.py @@ -0,0 +1,7 @@ +from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv, + modulated_deform_conv) + +__all__ = [ + 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv', + 'modulated_deform_conv' +] diff --git a/blissful_tuner/codeformer/basicsr/ops/dcn/deform_conv.py b/blissful_tuner/codeformer/basicsr/ops/dcn/deform_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..734154f9ed9447d585eae7df6886acb136f8a3cf --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/ops/dcn/deform_conv.py @@ -0,0 +1,377 @@ +import math +import torch +from torch import nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn import functional as F +from torch.nn.modules.utils import _pair, _single + +try: + from . import deform_conv_ext +except ImportError: + import os + BASICSR_JIT = os.getenv('BASICSR_JIT') + if BASICSR_JIT == 'True': + from torch.utils.cpp_extension import load + module_path = os.path.dirname(__file__) + deform_conv_ext = load( + 'deform_conv', + sources=[ + os.path.join(module_path, 'src', 'deform_conv_ext.cpp'), + os.path.join(module_path, 'src', 'deform_conv_cuda.cpp'), + os.path.join(module_path, 'src', 'deform_conv_cuda_kernel.cu'), + ], + ) + + +class DeformConvFunction(Function): + + @staticmethod + def forward(ctx, + input, + offset, + weight, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + im2col_step=64): + if input is not None and input.dim() != 4: + raise ValueError(f'Expected 4D tensor as input, got {input.dim()}' 'D tensor instead.') + ctx.stride = _pair(stride) + ctx.padding = _pair(padding) + ctx.dilation = _pair(dilation) + ctx.groups = groups + ctx.deformable_groups = deformable_groups + ctx.im2col_step = im2col_step + + ctx.save_for_backward(input, offset, weight) + + output = input.new_empty(DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride)) + + ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones + + if not input.is_cuda: + raise NotImplementedError + else: + cur_im2col_step = min(ctx.im2col_step, input.shape[0]) + assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize' + deform_conv_ext.deform_conv_forward(input, weight, + offset, output, ctx.bufs_[0], ctx.bufs_[1], weight.size(3), + weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1], + ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups, + ctx.deformable_groups, cur_im2col_step) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + input, offset, weight = ctx.saved_tensors + + grad_input = grad_offset = grad_weight = None + + if not grad_output.is_cuda: + raise NotImplementedError + else: + cur_im2col_step = min(ctx.im2col_step, input.shape[0]) + assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step 
must divide batchsize' + + if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: + grad_input = torch.zeros_like(input) + grad_offset = torch.zeros_like(offset) + deform_conv_ext.deform_conv_backward_input(input, offset, grad_output, grad_input, + grad_offset, weight, ctx.bufs_[0], weight.size(3), + weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1], + ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups, + ctx.deformable_groups, cur_im2col_step) + + if ctx.needs_input_grad[2]: + grad_weight = torch.zeros_like(weight) + deform_conv_ext.deform_conv_backward_parameters(input, offset, grad_output, grad_weight, + ctx.bufs_[0], ctx.bufs_[1], weight.size(3), + weight.size(2), ctx.stride[1], ctx.stride[0], + ctx.padding[1], ctx.padding[0], ctx.dilation[1], + ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1, + cur_im2col_step) + + return (grad_input, grad_offset, grad_weight, None, None, None, None, None) + + @staticmethod + def _output_size(input, weight, padding, dilation, stride): + channels = weight.size(0) + output_size = (input.size(0), channels) + for d in range(input.dim() - 2): + in_size = input.size(d + 2) + pad = padding[d] + kernel = dilation[d] * (weight.size(d + 2) - 1) + 1 + stride_ = stride[d] + output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, ) + if not all(map(lambda s: s > 0, output_size)): + raise ValueError('convolution input is too small (output would be ' f'{"x".join(map(str, output_size))})') + return output_size + + +class ModulatedDeformConvFunction(Function): + + @staticmethod + def forward(ctx, + input, + offset, + mask, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1): + ctx.stride = stride + ctx.padding = padding + ctx.dilation = dilation + ctx.groups = groups + ctx.deformable_groups = deformable_groups + ctx.with_bias = bias is not None + if not ctx.with_bias: + bias = input.new_empty(1) # fake tensor + if not input.is_cuda: + raise NotImplementedError + if weight.requires_grad or mask.requires_grad or offset.requires_grad \ + or input.requires_grad: + ctx.save_for_backward(input, offset, mask, weight, bias) + output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight)) + ctx._bufs = [input.new_empty(0), input.new_empty(0)] + deform_conv_ext.modulated_deform_conv_forward(input, weight, bias, ctx._bufs[0], offset, mask, output, + ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride, + ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation, + ctx.groups, ctx.deformable_groups, ctx.with_bias) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + if not grad_output.is_cuda: + raise NotImplementedError + input, offset, mask, weight, bias = ctx.saved_tensors + grad_input = torch.zeros_like(input) + grad_offset = torch.zeros_like(offset) + grad_mask = torch.zeros_like(mask) + grad_weight = torch.zeros_like(weight) + grad_bias = torch.zeros_like(bias) + deform_conv_ext.modulated_deform_conv_backward(input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1], + grad_input, grad_weight, grad_bias, grad_offset, grad_mask, + grad_output, weight.shape[2], weight.shape[3], ctx.stride, + ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation, + ctx.groups, ctx.deformable_groups, ctx.with_bias) + if not ctx.with_bias: + grad_bias = None + + return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None, None) + + @staticmethod + def _infer_shape(ctx, input, weight): + n = 
input.size(0) + channels_out = weight.size(0) + height, width = input.shape[2:4] + kernel_h, kernel_w = weight.shape[2:4] + height_out = (height + 2 * ctx.padding - (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1 + width_out = (width + 2 * ctx.padding - (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1 + return n, channels_out, height_out, width_out + + +deform_conv = DeformConvFunction.apply +modulated_deform_conv = ModulatedDeformConvFunction.apply + + +class DeformConv(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + bias=False): + super(DeformConv, self).__init__() + + assert not bias + assert in_channels % groups == 0, \ + f'in_channels {in_channels} is not divisible by groups {groups}' + assert out_channels % groups == 0, \ + f'out_channels {out_channels} is not divisible ' \ + f'by groups {groups}' + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + self.padding = _pair(padding) + self.dilation = _pair(dilation) + self.groups = groups + self.deformable_groups = deformable_groups + # enable compatibility with nn.Conv2d + self.transposed = False + self.output_padding = _single(0) + + self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size)) + + self.reset_parameters() + + def reset_parameters(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1. / math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + + def forward(self, x, offset): + # To fix an assert error in deform_conv_cuda.cpp:128 + # input image is smaller than kernel + input_pad = (x.size(2) < self.kernel_size[0] or x.size(3) < self.kernel_size[1]) + if input_pad: + pad_h = max(self.kernel_size[0] - x.size(2), 0) + pad_w = max(self.kernel_size[1] - x.size(3), 0) + x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous() + offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0).contiguous() + out = deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups, + self.deformable_groups) + if input_pad: + out = out[:, :, :out.size(2) - pad_h, :out.size(3) - pad_w].contiguous() + return out + + +class DeformConvPack(DeformConv): + """A Deformable Conv Encapsulation that acts as normal Conv layers. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int or tuple[int]): Same as nn.Conv2d. + stride (int or tuple[int]): Same as nn.Conv2d. + padding (int or tuple[int]): Same as nn.Conv2d. + dilation (int or tuple[int]): Same as nn.Conv2d. + groups (int): Same as nn.Conv2d. + bias (bool or str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if norm_cfg is None, otherwise + False. 
+ """ + + _version = 2 + + def __init__(self, *args, **kwargs): + super(DeformConvPack, self).__init__(*args, **kwargs) + + self.conv_offset = nn.Conv2d( + self.in_channels, + self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1], + kernel_size=self.kernel_size, + stride=_pair(self.stride), + padding=_pair(self.padding), + dilation=_pair(self.dilation), + bias=True) + self.init_offset() + + def init_offset(self): + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() + + def forward(self, x): + offset = self.conv_offset(x) + return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups, + self.deformable_groups) + + +class ModulatedDeformConv(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + bias=True): + super(ModulatedDeformConv, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.deformable_groups = deformable_groups + self.with_bias = bias + # enable compatibility with nn.Conv2d + self.transposed = False + self.output_padding = _single(0) + + self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter('bias', None) + self.init_weights() + + def init_weights(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1. / math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + if self.bias is not None: + self.bias.data.zero_() + + def forward(self, x, offset, mask): + return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation, + self.groups, self.deformable_groups) + + +class ModulatedDeformConvPack(ModulatedDeformConv): + """A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int or tuple[int]): Same as nn.Conv2d. + stride (int or tuple[int]): Same as nn.Conv2d. + padding (int or tuple[int]): Same as nn.Conv2d. + dilation (int or tuple[int]): Same as nn.Conv2d. + groups (int): Same as nn.Conv2d. + bias (bool or str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if norm_cfg is None, otherwise + False. 
+ """ + + _version = 2 + + def __init__(self, *args, **kwargs): + super(ModulatedDeformConvPack, self).__init__(*args, **kwargs) + + self.conv_offset = nn.Conv2d( + self.in_channels, + self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1], + kernel_size=self.kernel_size, + stride=_pair(self.stride), + padding=_pair(self.padding), + dilation=_pair(self.dilation), + bias=True) + self.init_weights() + + def init_weights(self): + super(ModulatedDeformConvPack, self).init_weights() + if hasattr(self, 'conv_offset'): + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() + + def forward(self, x): + out = self.conv_offset(x) + o1, o2, mask = torch.chunk(out, 3, dim=1) + offset = torch.cat((o1, o2), dim=1) + mask = torch.sigmoid(mask) + return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation, + self.groups, self.deformable_groups) diff --git a/blissful_tuner/codeformer/basicsr/ops/dcn/src/deform_conv_cuda.cpp b/blissful_tuner/codeformer/basicsr/ops/dcn/src/deform_conv_cuda.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5d9424908ed2dbd4ac3cdb98d13e09287a4d2f2d --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/ops/dcn/src/deform_conv_cuda.cpp @@ -0,0 +1,685 @@ +// modify from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c + +#include +#include + +#include +#include + +void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset, + const int channels, const int height, const int width, + const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + at::Tensor data_col); + +void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset, + const int channels, const int height, const int width, + const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + at::Tensor grad_im); + +void deformable_col2im_coord( + const at::Tensor data_col, const at::Tensor data_im, + const at::Tensor data_offset, const int channels, const int height, + const int width, const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deformable_group, at::Tensor grad_offset); + +void modulated_deformable_im2col_cuda( + const at::Tensor data_im, const at::Tensor data_offset, + const at::Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deformable_group, + at::Tensor data_col); + +void modulated_deformable_col2im_cuda( + const at::Tensor data_col, const at::Tensor data_offset, + const at::Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const 
int deformable_group, + at::Tensor grad_im); + +void modulated_deformable_col2im_coord_cuda( + const at::Tensor data_col, const at::Tensor data_im, + const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, + const int width_im, const int height_col, const int width_col, + const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int deformable_group, at::Tensor grad_offset, + at::Tensor grad_mask); + +void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput, + at::Tensor weight, int kH, int kW, int dH, int dW, int padH, + int padW, int dilationH, int dilationW, int group, + int deformable_group) { + TORCH_CHECK(weight.ndimension() == 4, + "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, " + "but got: %s", + weight.ndimension()); + + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + + TORCH_CHECK(kW > 0 && kH > 0, + "kernel size should be greater than zero, but got kH: %d kW: %d", kH, + kW); + + TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW), + "kernel size should be consistent with weight, ", + "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH, + kW, weight.size(2), weight.size(3)); + + TORCH_CHECK(dW > 0 && dH > 0, + "stride should be greater than zero, but got dH: %d dW: %d", dH, dW); + + TORCH_CHECK( + dilationW > 0 && dilationH > 0, + "dilation should be greater than 0, but got dilationH: %d dilationW: %d", + dilationH, dilationW); + + int ndim = input.ndimension(); + int dimf = 0; + int dimh = 1; + int dimw = 2; + + if (ndim == 4) { + dimf++; + dimh++; + dimw++; + } + + TORCH_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s", + ndim); + + long nInputPlane = weight.size(1) * group; + long inputHeight = input.size(dimh); + long inputWidth = input.size(dimw); + long nOutputPlane = weight.size(0); + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + + TORCH_CHECK(nInputPlane % deformable_group == 0, + "input channels must divide deformable group size"); + + if (outputWidth < 1 || outputHeight < 1) + AT_ERROR( + "Given input size: (%ld x %ld x %ld). " + "Calculated output size: (%ld x %ld x %ld). 
Output size is too small", + nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight, + outputWidth); + + TORCH_CHECK(input.size(1) == nInputPlane, + "invalid number of input planes, expected: %d, but got: %d", + nInputPlane, input.size(1)); + + TORCH_CHECK((inputHeight >= kH && inputWidth >= kW), + "input image is smaller than kernel"); + + TORCH_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth), + "invalid spatial size of offset, expected height: %d width: %d, but " + "got height: %d width: %d", + outputHeight, outputWidth, offset.size(2), offset.size(3)); + + TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW), + "invalid number of channels of offset"); + + if (gradOutput != NULL) { + TORCH_CHECK(gradOutput->size(dimf) == nOutputPlane, + "invalid number of gradOutput planes, expected: %d, but got: %d", + nOutputPlane, gradOutput->size(dimf)); + + TORCH_CHECK((gradOutput->size(dimh) == outputHeight && + gradOutput->size(dimw) == outputWidth), + "invalid size of gradOutput, expected height: %d width: %d , but " + "got height: %d width: %d", + outputHeight, outputWidth, gradOutput->size(dimh), + gradOutput->size(dimw)); + } +} + +int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, + at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor ones, int kW, + int kH, int dW, int dH, int padW, int padH, + int dilationW, int dilationH, int group, + int deformable_group, int im2col_step) { + // todo: resize columns to include im2col: done + // todo: add im2col_step as input + // todo: add new output buffer and transpose it to output (or directly + // transpose output) todo: possibly change data indexing because of + // parallel_imgs + + shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW, + dilationH, dilationW, group, deformable_group); + at::DeviceGuard guard(input.device()); + + input = input.contiguous(); + offset = offset.contiguous(); + weight = weight.contiguous(); + + int batch = 1; + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input.unsqueeze_(0); + offset.unsqueeze_(0); + } + + // todo: assert batchsize dividable by im2col_step + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = weight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); + + output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane, + outputHeight, outputWidth}); + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < outputHeight * outputWidth) { + ones = at::ones({outputHeight, outputWidth}, input.options()); + } + + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + at::Tensor output_buffer = + at::zeros({batchSize / im2col_step, nOutputPlane, + im2col_step * outputHeight, outputWidth}, + output.options()); + + output_buffer = output_buffer.view( + {output_buffer.size(0), group, output_buffer.size(1) / group, + output_buffer.size(2), output_buffer.size(3)}); + + for (int elt = 0; elt 
< batchSize / im2col_step; elt++) { + deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, columns); + + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + + for (int g = 0; g < group; g++) { + output_buffer[elt][g] = output_buffer[elt][g] + .flatten(1) + .addmm_(weight[g].flatten(1), columns[g]) + .view_as(output_buffer[elt][g]); + } + } + + output_buffer = output_buffer.view( + {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2), + output_buffer.size(3), output_buffer.size(4)}); + + output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane, + im2col_step, outputHeight, outputWidth}); + output_buffer.transpose_(1, 2); + output.copy_(output_buffer); + output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + output = output.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); + } + + return 1; +} + +int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset, + at::Tensor gradOutput, at::Tensor gradInput, + at::Tensor gradOffset, at::Tensor weight, + at::Tensor columns, int kW, int kH, int dW, + int dH, int padW, int padH, int dilationW, + int dilationH, int group, + int deformable_group, int im2col_step) { + shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW, + dilationH, dilationW, group, deformable_group); + at::DeviceGuard guard(input.device()); + + input = input.contiguous(); + offset = offset.contiguous(); + gradOutput = gradOutput.contiguous(); + weight = weight.contiguous(); + + int batch = 1; + + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input = input.view({1, input.size(0), input.size(1), input.size(2)}); + offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)}); + gradOutput = gradOutput.view( + {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); + } + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = weight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset"); + gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + // change order of grad output + gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step, + nOutputPlane, outputHeight, outputWidth}); + gradOutput.transpose_(1, 2); + + gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, 
+ outputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + // divide into groups + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + gradOutput = gradOutput.view( + {gradOutput.size(0), group, gradOutput.size(1) / group, + gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)}); + + for (int g = 0; g < group; g++) { + columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), + gradOutput[elt][g].flatten(1), 0.0f, 1.0f); + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + gradOutput = gradOutput.view( + {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2), + gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)}); + + deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane, + inputHeight, inputWidth, kH, kW, padH, padW, dH, dW, + dilationH, dilationW, im2col_step, deformable_group, + gradOffset[elt]); + + deformable_col2im(columns, offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, gradInput[elt]); + } + + gradOutput.transpose_(1, 2); + gradOutput = + gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + gradOffset = gradOffset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth}); + offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); + gradOffset = + gradOffset.view({offset.size(1), offset.size(2), offset.size(3)}); + } + + return 1; +} + +int deform_conv_backward_parameters_cuda( + at::Tensor input, at::Tensor offset, at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, + int padW, int padH, int dilationW, int dilationH, int group, + int deformable_group, float scale, int im2col_step) { + // todo: transpose and reshape outGrad + // todo: reshape columns + // todo: add im2col_step as input + + shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH, + padW, dilationH, dilationW, group, deformable_group); + at::DeviceGuard guard(input.device()); + + input = input.contiguous(); + offset = offset.contiguous(); + gradOutput = gradOutput.contiguous(); + + int batch = 1; + + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input = input.view( + at::IntList({1, input.size(0), input.size(1), input.size(2)})); + gradOutput = gradOutput.view( + {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); + } + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = gradWeight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight 
= + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); + + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step, + nOutputPlane, outputHeight, outputWidth}); + gradOutput.transpose_(1, 2); + + at::Tensor gradOutputBuffer = at::zeros_like(gradOutput); + gradOutputBuffer = + gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step, + outputHeight, outputWidth}); + gradOutputBuffer.copy_(gradOutput); + gradOutputBuffer = + gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, + im2col_step * outputHeight, outputWidth}); + + gradOutput.transpose_(1, 2); + gradOutput = + gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, columns); + + // divide into group + gradOutputBuffer = gradOutputBuffer.view( + {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group, + gradOutputBuffer.size(2), gradOutputBuffer.size(3)}); + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + gradWeight = + gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1), + gradWeight.size(2), gradWeight.size(3)}); + + for (int g = 0; g < group; g++) { + gradWeight[g] = gradWeight[g] + .flatten(1) + .addmm_(gradOutputBuffer[elt][g].flatten(1), + columns[g].transpose(1, 0), 1.0, scale) + .view_as(gradWeight[g]); + } + gradOutputBuffer = gradOutputBuffer.view( + {gradOutputBuffer.size(0), + gradOutputBuffer.size(1) * gradOutputBuffer.size(2), + gradOutputBuffer.size(3), gradOutputBuffer.size(4)}); + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1), + gradWeight.size(2), gradWeight.size(3), + gradWeight.size(4)}); + } + + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + } + + return 1; +} + +void modulated_deform_conv_cuda_forward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, + int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, const int deformable_group, + const bool with_bias) { + TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + at::DeviceGuard guard(input.device()); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + + const int channels_out = weight.size(0); + const int 
channels_kernel = weight.size(1); + const int kernel_h_ = weight.size(2); + const int kernel_w_ = weight.size(3); + + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).", + kernel_h_, kernel_w, kernel_h_, kernel_w_); + if (channels != channels_kernel * group) + AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).", + channels, channels_kernel * group); + + const int height_out = + (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = + (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < height_out * width_out) { + // Resize plane and fill with ones... + ones = at::ones({height_out, width_out}, input.options()); + } + + // resize output + output = output.view({batch, channels_out, height_out, width_out}).zero_(); + // resize temporary columns + columns = + at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out}, + input.options()); + + output = output.view({output.size(0), group, output.size(1) / group, + output.size(2), output.size(3)}); + + for (int b = 0; b < batch; b++) { + modulated_deformable_im2col_cuda( + input[b], offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, columns); + + // divide into group + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + + for (int g = 0; g < group; g++) { + output[b][g] = output[b][g] + .flatten(1) + .addmm_(weight[g].flatten(1), columns[g]) + .view_as(output[b][g]); + } + + weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), + weight.size(3), weight.size(4)}); + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + } + + output = output.view({output.size(0), output.size(1) * output.size(2), + output.size(3), output.size(4)}); + + if (with_bias) { + output += bias.view({1, bias.size(0), 1, 1}); + } +} + +void modulated_deform_conv_cuda_backward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor columns, + at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, + at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, + const bool with_bias) { + TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + at::DeviceGuard guard(input.device()); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + + const int channels_kernel = weight.size(1); + const int kernel_h_ = weight.size(2); + const int kernel_w_ = weight.size(3); + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).", + kernel_h_, kernel_w, kernel_h_, kernel_w_); + if (channels != channels_kernel * group) + AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).", + channels, channels_kernel * group); + + const int height_out = + (height + 2 * pad_h - (dilation_h * 
(kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = + (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < height_out * width_out) { + // Resize plane and fill with ones... + ones = at::ones({height_out, width_out}, input.options()); + } + + grad_input = grad_input.view({batch, channels, height, width}); + columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out}, + input.options()); + + grad_output = + grad_output.view({grad_output.size(0), group, grad_output.size(1) / group, + grad_output.size(2), grad_output.size(3)}); + + for (int b = 0; b < batch; b++) { + // divide int group + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + + for (int g = 0; g < group; g++) { + columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), + grad_output[b][g].flatten(1), 0.0f, 1.0f); + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), + weight.size(3), weight.size(4)}); + + // gradient w.r.t. input coordinate data + modulated_deformable_col2im_coord_cuda( + columns, input[b], offset[b], mask[b], 1, channels, height, width, + height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, + stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b], + grad_mask[b]); + // gradient w.r.t. input data + modulated_deformable_col2im_cuda( + columns, offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, grad_input[b]); + + // gradient w.r.t. 
weight, dWeight should accumulate across the batch and + // group + modulated_deformable_im2col_cuda( + input[b], offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, columns); + + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + grad_weight = grad_weight.view({group, grad_weight.size(0) / group, + grad_weight.size(1), grad_weight.size(2), + grad_weight.size(3)}); + if (with_bias) + grad_bias = grad_bias.view({group, grad_bias.size(0) / group}); + + for (int g = 0; g < group; g++) { + grad_weight[g] = + grad_weight[g] + .flatten(1) + .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1)) + .view_as(grad_weight[g]); + if (with_bias) { + grad_bias[g] = + grad_bias[g] + .view({-1, 1}) + .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1})) + .view(-1); + } + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1), + grad_weight.size(2), grad_weight.size(3), + grad_weight.size(4)}); + if (with_bias) + grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)}); + } + grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1), + grad_output.size(2), grad_output.size(3), + grad_output.size(4)}); +} diff --git a/blissful_tuner/codeformer/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu b/blissful_tuner/codeformer/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..98752dccf8c58817ca1a952554dd3f33188a2d34 --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu @@ -0,0 +1,867 @@ +/*! + ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** + * + * COPYRIGHT + * + * All contributions by the University of California: + * Copyright (c) 2014-2017 The Regents of the University of California (Regents) + * All rights reserved. + * + * All other contributions: + * Copyright (c) 2014-2017, the respective contributors + * All rights reserved. + * + * Caffe uses a shared copyright model: each contributor holds copyright over + * their contributions to Caffe. The project versioning records all such + * contribution and copyright details. If a contributor wants to further mark + * their specific copyright on a particular contribution, they should indicate + * their copyright solely in the commit message of the change when it is + * committed. + * + * LICENSE + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+ * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ * CONTRIBUTION AGREEMENT
+ *
+ * By contributing to the BVLC/caffe repository through pull-request, comment,
+ * or otherwise, the contributor releases their content to the
+ * license and copyright terms herein.
+ *
+ ***************** END Caffe Copyright Notice and Disclaimer ********************
+ *
+ * Copyright (c) 2018 Microsoft
+ * Licensed under The MIT License [see LICENSE for details]
+ * \file modulated_deformable_im2col.cuh
+ * \brief Function definitions of converting an image to
+ * column matrix based on kernel, padding, dilation, and offset.
+ * These functions are mainly used in deformable convolution operators.
+ * \ref: https://arxiv.org/abs/1703.06211
+ * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
+ */
+
+// modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <THC/THCAtomics.cuh>
+#include <stdio.h>
+#include <math.h>
+#include <float.h>
+
+using namespace at;
+
+#define CUDA_KERNEL_LOOP(i, n)                                 \
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+       i += blockDim.x * gridDim.x)
+
+const int CUDA_NUM_THREADS = 1024;
+const int kMaxGridNum = 65535;
+
+inline int GET_BLOCKS(const int N)
+{
+  return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS);
+}
+
+template <typename scalar_t>
+__device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
+                                               const int height, const int width, scalar_t h, scalar_t w)
+{
+
+  int h_low = floor(h);
+  int w_low = floor(w);
+  int h_high = h_low + 1;
+  int w_high = w_low + 1;
+
+  scalar_t lh = h - h_low;
+  scalar_t lw = w - w_low;
+  scalar_t hh = 1 - lh, hw = 1 - lw;
+
+  scalar_t v1 = 0;
+  if (h_low >= 0 && w_low >= 0)
+    v1 = bottom_data[h_low * data_width + w_low];
+  scalar_t v2 = 0;
+  if (h_low >= 0 && w_high <= width - 1)
+    v2 = bottom_data[h_low * data_width + w_high];
+  scalar_t v3 = 0;
+  if (h_high <= height - 1 && w_low >= 0)
+    v3 = bottom_data[h_high * data_width + w_low];
+  scalar_t v4 = 0;
+  if (h_high <= height - 1 && w_high <= width - 1)
+    v4 = bottom_data[h_high * data_width + w_high];
+
+  scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+  scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+  return val;
+}
+
+template <typename scalar_t>
+__device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
+                                        const int h, const int w, const int height, const int width)
+{
+
+  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+  {
+    //empty
+    return 0;
+  }
+
+  int argmax_h_low = floor(argmax_h);
+  int argmax_w_low = floor(argmax_w);
+  int argmax_h_high = argmax_h_low + 1;
+  int argmax_w_high = argmax_w_low + 1;
+
+  scalar_t weight = 0;
+  if (h == argmax_h_low && w == argmax_w_low)
+    weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
+  if (h == argmax_h_low && w == argmax_w_high)
+    weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
+  if (h == argmax_h_high && w == argmax_w_low)
+    weight = (argmax_h + 1 - h) * (w
+ 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +template +__device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w, + const int height, const int width, const scalar_t *im_data, + const int data_width, const int bp_dir) +{ + + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + + if (bp_dir == 0) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + else if (bp_dir == 1) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +template +__global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int channel_per_deformable_group, + const int batch_size, const int num_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + //const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; + const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; + const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) + { + for (int j = 0; j < kernel_w; ++j) + { + const int 
data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + scalar_t val = static_cast(0); + const scalar_t h_im = h_in + i * dilation_h + offset_h; + const scalar_t w_im = w_in + j * dilation_w + offset_w; + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) + { + //const scalar_t map_h = i * dilation_h + offset_h; + //const scalar_t map_w = j * dilation_w + offset_w; + //const int cur_height = height - h_in; + //const int cur_width = width - w_in; + //val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); + val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val; + data_col_ptr += batch_size * height_col * width_col; + } + } + } +} + +void deformable_im2col( + const at::Tensor data_im, const at::Tensor data_offset, const int channels, + const int height, const int width, const int ksize_h, const int ksize_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deformable_group, at::Tensor data_col) +{ + // num_axes should be smaller than block size + // todo: check parallel_imgs is correctly passed in + int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = channels * height_col * width_col * parallel_imgs; + int channel_per_deformable_group = channels / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_im.scalar_type(), "deformable_im2col_gpu", ([&] { + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *data_col_ = data_col.data_ptr(); + + deformable_im2col_gpu_kernel<<>>( + num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + channel_per_deformable_group, parallel_imgs, channels, deformable_group, + height_col, width_col, data_col_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in deformable_im2col: %s\n", cudaGetErrorString(err)); + } +} + +template +__global__ void deformable_col2im_gpu_kernel( + const int n, const scalar_t *data_col, const scalar_t *data_offset, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_im) +{ + CUDA_KERNEL_LOOP(index, n) + { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = 
w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * + 2 * kernel_h * kernel_w * height_col * width_col; + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h; + const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const scalar_t cur_top_grad = data_col[index]; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) + { + for (int dx = -2; dx <= 2; dx++) + { + if (cur_h + dy >= 0 && cur_h + dy < height && + cur_w + dx >= 0 && cur_w + dx < width && + abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) + { + int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + scalar_t weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +void deformable_col2im( + const at::Tensor data_col, const at::Tensor data_offset, const int channels, + const int height, const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + at::Tensor grad_im) +{ + + // todo: make sure parallel_imgs is passed in correctly + int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs; + int channel_per_deformable_group = channels / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "deformable_col2im_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *grad_im_ = grad_im.data_ptr(); + + deformable_col2im_gpu_kernel<<>>( + num_kernels, data_col_, data_offset_, channels, height, width, ksize_h, + ksize_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + parallel_imgs, deformable_group, height_col, width_col, grad_im_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in deformable_col2im: %s\n", cudaGetErrorString(err)); + } +} + +template +__global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col, + const scalar_t *data_im, const scalar_t *data_offset, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int offset_channels, const int deformable_group, + const int height_col, const int width_col, scalar_t *grad_offset) +{ + CUDA_KERNEL_LOOP(index, n) + { + scalar_t val = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = 
(index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * + batch_size * width_col * height_col; + const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * + channel_per_deformable_group / kernel_h / kernel_w * height * width; + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * + kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) + { + const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + scalar_t inv_h = h_in + i * dilation_h + offset_h; + scalar_t inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + { + inv_h = inv_w = -2; + } + const scalar_t weight = get_coordinate_weight( + inv_h, inv_w, + height, width, data_im_ptr + cnt * height * width, width, bp_dir); + val += weight * data_col_ptr[col_pos]; + cnt += 1; + } + + grad_offset[index] = val; + } +} + +void deformable_col2im_coord( + const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, + const int channels, const int height, const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, const int stride_h, + const int stride_w, const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, at::Tensor grad_offset) +{ + + int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs; + int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *grad_offset_ = grad_offset.data_ptr(); + + deformable_col2im_coord_gpu_kernel<<>>( + num_kernels, data_col_, data_im_, data_offset_, channels, height, width, + ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group, + height_col, 
width_col, grad_offset_); + })); +} + +template +__device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width, + const int height, const int width, scalar_t h, scalar_t w) +{ + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + scalar_t lh = h - h_low; + scalar_t lw = w - w_low; + scalar_t hh = 1 - lh, hw = 1 - lw; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = bottom_data[h_low * data_width + w_low]; + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = bottom_data[h_low * data_width + w_high]; + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = bottom_data[h_high * data_width + w_low]; + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = bottom_data[h_high * data_width + w_high]; + + scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w, + const int h, const int w, const int height, const int width) +{ + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +template +__device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w, + const int height, const int width, const scalar_t *im_data, + const int data_width, const int bp_dir) +{ + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + + if (bp_dir == 0) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + else if (bp_dir == 1) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += 
(argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +template +__global__ void modulated_deformable_im2col_gpu_kernel(const int n, + const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int num_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + + scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; + const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; + const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + + const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) + { + for (int j = 0; j < kernel_w; ++j) + { + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + scalar_t val = static_cast(0); + const scalar_t h_im = h_in + i * dilation_h + offset_h; + const scalar_t w_im = w_in + j * dilation_w + offset_w; + //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) + { + //const float map_h = i * dilation_h + offset_h; + //const float map_w = j * dilation_w + offset_w; + //const int cur_height = height - h_in; + //const int cur_width = width - w_in; + //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); + val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val * mask; + data_col_ptr += batch_size * height_col * width_col; + //data_col_ptr += height_col * width_col; + } + } + } +} + +template +__global__ void modulated_deformable_col2im_gpu_kernel(const int n, + const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + 
const int channel_per_deformable_group, + const int batch_size, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_im) +{ + CUDA_KERNEL_LOOP(index, n) + { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h; + const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const scalar_t cur_top_grad = data_col[index] * mask; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) + { + for (int dx = -2; dx <= 2; dx++) + { + if (cur_h + dy >= 0 && cur_h + dy < height && + cur_w + dx >= 0 && cur_w + dx < width && + abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) + { + int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +template +__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n, + const scalar_t *data_col, const scalar_t *data_im, + const scalar_t *data_offset, const scalar_t *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int offset_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_offset, scalar_t *grad_mask) +{ + CUDA_KERNEL_LOOP(index, n) + { + scalar_t val = 0, mval = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * 
batch_size * width_col * height_col; + const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width; + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) + { + const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); + const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + scalar_t inv_h = h_in + i * dilation_h + offset_h; + scalar_t inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + { + inv_h = inv_w = -2; + } + else + { + mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w); + } + const scalar_t weight = dmcn_get_coordinate_weight( + inv_h, inv_w, + height, width, data_im_ptr + cnt * height * width, width, bp_dir); + val += weight * data_col_ptr[col_pos] * mask; + cnt += 1; + } + // KERNEL_ASSIGN(grad_offset[index], offset_req, val); + grad_offset[index] = val; + if (offset_c % 2 == 0) + // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval); + grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval; + } +} + +void modulated_deformable_im2col_cuda( + const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, at::Tensor data_col) +{ + // num_axes should be smaller than block size + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * batch_size * height_col * width_col; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] { + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *data_col_ = 
data_col.data_ptr(); + + modulated_deformable_im2col_gpu_kernel<<>>( + num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kenerl_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group, + batch_size, channels, deformable_group, height_col, width_col, data_col_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); + } +} + +void modulated_deformable_col2im_cuda( + const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, at::Tensor grad_im) +{ + + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *grad_im_ = grad_im.data_ptr(); + + modulated_deformable_col2im_gpu_kernel<<>>( + num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im, + kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + batch_size, deformable_group, height_col, width_col, grad_im_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); + } +} + +void modulated_deformable_col2im_coord_cuda( + const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, + at::Tensor grad_offset, at::Tensor grad_mask) +{ + const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group; + const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *grad_offset_ = grad_offset.data_ptr(); + scalar_t *grad_mask_ = grad_mask.data_ptr(); + + modulated_deformable_col2im_coord_gpu_kernel<<>>( + num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im, + kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col, + grad_offset_, grad_mask_); + })); + cudaError_t err = cudaGetLastError(); + if (err != 
cudaSuccess)
+  {
+    printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err));
+  }
+}
diff --git a/blissful_tuner/codeformer/basicsr/ops/dcn/src/deform_conv_ext.cpp b/blissful_tuner/codeformer/basicsr/ops/dcn/src/deform_conv_ext.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..41c6df6f721bd95a525fd6a03dd9882e863de042
--- /dev/null
+++ b/blissful_tuner/codeformer/basicsr/ops/dcn/src/deform_conv_ext.cpp
@@ -0,0 +1,164 @@
+// modify from
+// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
+
+#include <torch/extension.h>
+#include <ATen/DeviceGuard.h>
+
+#include <cmath>
+#include <vector>
+
+#define WITH_CUDA // always use cuda
+#ifdef WITH_CUDA
+int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
+                             at::Tensor offset, at::Tensor output,
+                             at::Tensor columns, at::Tensor ones, int kW,
+                             int kH, int dW, int dH, int padW, int padH,
+                             int dilationW, int dilationH, int group,
+                             int deformable_group, int im2col_step);
+
+int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
+                                    at::Tensor gradOutput, at::Tensor gradInput,
+                                    at::Tensor gradOffset, at::Tensor weight,
+                                    at::Tensor columns, int kW, int kH, int dW,
+                                    int dH, int padW, int padH, int dilationW,
+                                    int dilationH, int group,
+                                    int deformable_group, int im2col_step);
+
+int deform_conv_backward_parameters_cuda(
+    at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
+    at::Tensor gradWeight, // at::Tensor gradBias,
+    at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
+    int padW, int padH, int dilationW, int dilationH, int group,
+    int deformable_group, float scale, int im2col_step);
+
+void modulated_deform_conv_cuda_forward(
+    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+    at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
+    int kernel_h, int kernel_w, const int stride_h, const int stride_w,
+    const int pad_h, const int pad_w, const int dilation_h,
+    const int dilation_w, const int group, const int deformable_group,
+    const bool with_bias);
+
+void modulated_deform_conv_cuda_backward(
+    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+    at::Tensor offset, at::Tensor mask, at::Tensor columns,
+    at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
+    at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
+    int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
+    int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
+    const bool with_bias);
+#endif
+
+int deform_conv_forward(at::Tensor input, at::Tensor weight,
+                        at::Tensor offset, at::Tensor output,
+                        at::Tensor columns, at::Tensor ones, int kW,
+                        int kH, int dW, int dH, int padW, int padH,
+                        int dilationW, int dilationH, int group,
+                        int deformable_group, int im2col_step)
{ + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return deform_conv_backward_input_cuda(input, offset, gradOutput, + gradInput, gradOffset, weight, columns, kW, kH, dW, dH, padW, padH, + dilationW, dilationH, group, deformable_group, im2col_step); +#else + AT_ERROR("deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("deform conv is not implemented on CPU"); +} + +int deform_conv_backward_parameters( + at::Tensor input, at::Tensor offset, at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, + int padW, int padH, int dilationW, int dilationH, int group, + int deformable_group, float scale, int im2col_step) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return deform_conv_backward_parameters_cuda(input, offset, gradOutput, + gradWeight, columns, ones, kW, kH, dW, dH, padW, padH, dilationW, + dilationH, group, deformable_group, scale, im2col_step); +#else + AT_ERROR("deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("deform conv is not implemented on CPU"); +} + +void modulated_deform_conv_forward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, + int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, const int deformable_group, + const bool with_bias) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return modulated_deform_conv_cuda_forward(input, weight, bias, ones, + offset, mask, output, columns, kernel_h, kernel_w, stride_h, + stride_w, pad_h, pad_w, dilation_h, dilation_w, group, + deformable_group, with_bias); +#else + AT_ERROR("modulated deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("modulated deform conv is not implemented on CPU"); +} + +void modulated_deform_conv_backward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor columns, + at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, + at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, + const bool with_bias) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return modulated_deform_conv_cuda_backward(input, weight, bias, ones, + offset, mask, columns, grad_input, grad_weight, grad_bias, grad_offset, + grad_mask, grad_output, kernel_h, kernel_w, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, group, deformable_group, + with_bias); +#else + AT_ERROR("modulated deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("modulated deform conv is not implemented on CPU"); +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("deform_conv_forward", &deform_conv_forward, + "deform forward"); + m.def("deform_conv_backward_input", &deform_conv_backward_input, + "deform_conv_backward_input"); + m.def("deform_conv_backward_parameters", + &deform_conv_backward_parameters, + "deform_conv_backward_parameters"); + m.def("modulated_deform_conv_forward", + &modulated_deform_conv_forward, + "modulated deform conv forward"); + m.def("modulated_deform_conv_backward", + &modulated_deform_conv_backward, + "modulated deform conv backward"); +} diff --git 
a/blissful_tuner/codeformer/basicsr/ops/fused_act/__init__.py b/blissful_tuner/codeformer/basicsr/ops/fused_act/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..241dc0754fae7d88dbbd9a02e665ca30a73c7422 --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/ops/fused_act/__init__.py @@ -0,0 +1,3 @@ +from .fused_act import FusedLeakyReLU, fused_leaky_relu + +__all__ = ['FusedLeakyReLU', 'fused_leaky_relu'] diff --git a/blissful_tuner/codeformer/basicsr/ops/fused_act/fused_act.py b/blissful_tuner/codeformer/basicsr/ops/fused_act/fused_act.py new file mode 100644 index 0000000000000000000000000000000000000000..588f815e596ab0fc83ab0f9d21426c22ec5ed7c3 --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/ops/fused_act/fused_act.py @@ -0,0 +1,89 @@ +# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501 + +import torch +from torch import nn +from torch.autograd import Function + +try: + from . import fused_act_ext +except ImportError: + import os + BASICSR_JIT = os.getenv('BASICSR_JIT') + if BASICSR_JIT == 'True': + from torch.utils.cpp_extension import load + module_path = os.path.dirname(__file__) + fused_act_ext = load( + 'fused', + sources=[ + os.path.join(module_path, 'src', 'fused_bias_act.cpp'), + os.path.join(module_path, 'src', 'fused_bias_act_kernel.cu'), + ], + ) + + +class FusedLeakyReLUFunctionBackward(Function): + + @staticmethod + def forward(ctx, grad_output, out, negative_slope, scale): + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + empty = grad_output.new_empty(0) + + grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale) + + dim = [0] + + if grad_input.ndim > 2: + dim += list(range(2, grad_input.ndim)) + + grad_bias = grad_input.sum(dim).detach() + + return grad_input, grad_bias + + @staticmethod + def backward(ctx, gradgrad_input, gradgrad_bias): + out, = ctx.saved_tensors + gradgrad_out = fused_act_ext.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, + ctx.scale) + + return gradgrad_out, None, None, None + + +class FusedLeakyReLUFunction(Function): + + @staticmethod + def forward(ctx, input, bias, negative_slope, scale): + empty = input.new_empty(0) + out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + return out + + @staticmethod + def backward(ctx, grad_output): + out, = ctx.saved_tensors + + grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale) + + return grad_input, grad_bias, None, None + + +class FusedLeakyReLU(nn.Module): + + def __init__(self, channel, negative_slope=0.2, scale=2**0.5): + super().__init__() + + self.bias = nn.Parameter(torch.zeros(channel)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) + + +def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5): + return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) diff --git a/blissful_tuner/codeformer/basicsr/ops/fused_act/src/fused_bias_act.cpp b/blissful_tuner/codeformer/basicsr/ops/fused_act/src/fused_bias_act.cpp new file mode 100644 index 0000000000000000000000000000000000000000..85ed0a79fb9c75f83470ac834090f03608d998ee --- /dev/null +++ 
b/blissful_tuner/codeformer/basicsr/ops/fused_act/src/fused_bias_act.cpp @@ -0,0 +1,26 @@ +// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp +#include <torch/extension.h> + + +torch::Tensor fused_bias_act_op(const torch::Tensor& input, + const torch::Tensor& bias, + const torch::Tensor& refer, + int act, int grad, float alpha, float scale); + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +torch::Tensor fused_bias_act(const torch::Tensor& input, + const torch::Tensor& bias, + const torch::Tensor& refer, + int act, int grad, float alpha, float scale) { + CHECK_CUDA(input); + CHECK_CUDA(bias); + + return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); +} diff --git a/blissful_tuner/codeformer/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu b/blissful_tuner/codeformer/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..54c7ff53ce8306db2b3c582ec7fa6696a38b4df0 --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu @@ -0,0 +1,100 @@ +// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu +// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. +// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include <torch/types.h> + +#include <ATen/ATen.h> +#include <ATen/AccumulateType.h> +#include <ATen/cuda/CUDAApplyUtils.cuh> +#include <ATen/cuda/CUDAContext.h> + +#include <cuda.h> +#include <cuda_runtime.h> + + +template <typename scalar_t> +static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, + int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { + int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; + + scalar_t zero = 0.0; + + for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { + scalar_t x = p_x[xi]; + + if (use_bias) { + x += p_b[(xi / step_b) % size_b]; + } + + scalar_t ref = use_ref ? p_ref[xi] : zero; + + scalar_t y; + + switch (act * 10 + grad) { + default: + case 10: y = x; break; + case 11: y = x; break; + case 12: y = 0.0; break; + + case 30: y = (x > 0.0) ? x : x * alpha; break; + case 31: y = (ref > 0.0) ? x : x * alpha; break; + case 32: y = 0.0; break; + } + + out[xi] = y * scale; + } +} + + +torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + auto x = input.contiguous(); + auto b = bias.contiguous(); + auto ref = refer.contiguous(); + + int use_bias = b.numel() ? 1 : 0; + int use_ref = ref.numel() ?
1 : 0; + + int size_x = x.numel(); + int size_b = b.numel(); + int step_b = 1; + + for (int i = 1 + 1; i < x.dim(); i++) { + step_b *= x.size(i); + } + + int loop_x = 4; + int block_size = 4 * 32; + int grid_size = (size_x - 1) / (loop_x * block_size) + 1; + + auto y = torch::empty_like(x); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { + fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>( + y.data_ptr<scalar_t>(), + x.data_ptr<scalar_t>(), + b.data_ptr<scalar_t>(), + ref.data_ptr<scalar_t>(), + act, + grad, + alpha, + scale, + loop_x, + size_x, + step_b, + size_b, + use_bias, + use_ref + ); + }); + + return y; +} diff --git a/blissful_tuner/codeformer/basicsr/ops/upfirdn2d/__init__.py b/blissful_tuner/codeformer/basicsr/ops/upfirdn2d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..397e85bea063e97fc4c12ad4d3e15669b69290bd --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/ops/upfirdn2d/__init__.py @@ -0,0 +1,3 @@ +from .upfirdn2d import upfirdn2d + +__all__ = ['upfirdn2d'] diff --git a/blissful_tuner/codeformer/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp b/blissful_tuner/codeformer/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..43d0b6783a5b512b55815a291fcac2bebeea31e0 --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp @@ -0,0 +1,24 @@ +// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp +#include <torch/extension.h> + + +torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1); + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1) { + CHECK_CUDA(input); + CHECK_CUDA(kernel); + + return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); +} diff --git a/blissful_tuner/codeformer/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu b/blissful_tuner/codeformer/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..8870063bae4468deab2e721f0978fe9facfb01b1 --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu @@ -0,0 +1,370 @@ +// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d_kernel.cu +// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC.
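For reference, the path exercised by `FusedLeakyReLU` above calls `fused_bias_act` with `act=3, grad=0`, which the kernel implements as a biased leaky ReLU followed by a constant scale. The following is a minimal pure-PyTorch sketch of that forward computation, useful only for sanity-checking the compiled extension; the function name `fused_leaky_relu_reference` is illustrative and not part of the codebase.

```python
import torch
import torch.nn.functional as F


def fused_leaky_relu_reference(x: torch.Tensor, bias: torch.Tensor,
                               negative_slope: float = 0.2,
                               scale: float = 2 ** 0.5) -> torch.Tensor:
    """Pure-PyTorch equivalent of the act=3, grad=0 branch of fused_bias_act."""
    # Broadcast the per-channel bias over dim 1, matching the
    # p_b[(xi / step_b) % size_b] indexing in the CUDA kernel.
    shape = [1, -1] + [1] * (x.dim() - 2)
    return F.leaky_relu(x + bias.view(shape), negative_slope) * scale


if __name__ == "__main__":
    x = torch.randn(2, 8, 16, 16)
    b = torch.randn(8)
    print(fused_leaky_relu_reference(x, b).shape)  # torch.Size([2, 8, 16, 16])
```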
+// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include + +#include +#include +#include +#include + +#include +#include + +static __host__ __device__ __forceinline__ int floor_div(int a, int b) { + int c = a / b; + + if (c * b > a) { + c--; + } + + return c; +} + +struct UpFirDn2DKernelParams { + int up_x; + int up_y; + int down_x; + int down_y; + int pad_x0; + int pad_x1; + int pad_y0; + int pad_y1; + + int major_dim; + int in_h; + int in_w; + int minor_dim; + int kernel_h; + int kernel_w; + int out_h; + int out_w; + int loop_major; + int loop_x; +}; + +template +__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, + const scalar_t *kernel, + const UpFirDn2DKernelParams p) { + int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; + int out_y = minor_idx / p.minor_dim; + minor_idx -= out_y * p.minor_dim; + int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; + int major_idx_base = blockIdx.z * p.loop_major; + + if (out_x_base >= p.out_w || out_y >= p.out_h || + major_idx_base >= p.major_dim) { + return; + } + + int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; + int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); + int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; + int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; + + for (int loop_major = 0, major_idx = major_idx_base; + loop_major < p.loop_major && major_idx < p.major_dim; + loop_major++, major_idx++) { + for (int loop_x = 0, out_x = out_x_base; + loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { + int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; + int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); + int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; + int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; + + const scalar_t *x_p = + &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + + minor_idx]; + const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; + int x_px = p.minor_dim; + int k_px = -p.up_x; + int x_py = p.in_w * p.minor_dim; + int k_py = -p.up_y * p.kernel_w; + + scalar_t v = 0.0f; + + for (int y = 0; y < h; y++) { + for (int x = 0; x < w; x++) { + v += static_cast(*x_p) * static_cast(*k_p); + x_p += x_px; + k_p += k_px; + } + + x_p += x_py - w * x_px; + k_p += k_py - w * k_px; + } + + out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + + minor_idx] = v; + } + } +} + +template +__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, + const scalar_t *kernel, + const UpFirDn2DKernelParams p) { + const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; + const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; + + __shared__ volatile float sk[kernel_h][kernel_w]; + __shared__ volatile float sx[tile_in_h][tile_in_w]; + + int minor_idx = blockIdx.x; + int tile_out_y = minor_idx / p.minor_dim; + minor_idx -= tile_out_y * p.minor_dim; + tile_out_y *= tile_out_h; + int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; + int major_idx_base = blockIdx.z * p.loop_major; + + if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | + major_idx_base >= p.major_dim) { + return; + } + + for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; + tap_idx += blockDim.x) { + int ky = tap_idx / kernel_w; + int kx = tap_idx - ky * kernel_w; + scalar_t v = 0.0; + + if (kx < p.kernel_w & ky < p.kernel_h) { + v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w 
+ (p.kernel_w - 1 - kx)]; + } + + sk[ky][kx] = v; + } + + for (int loop_major = 0, major_idx = major_idx_base; + loop_major < p.loop_major & major_idx < p.major_dim; + loop_major++, major_idx++) { + for (int loop_x = 0, tile_out_x = tile_out_x_base; + loop_x < p.loop_x & tile_out_x < p.out_w; + loop_x++, tile_out_x += tile_out_w) { + int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; + int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; + int tile_in_x = floor_div(tile_mid_x, up_x); + int tile_in_y = floor_div(tile_mid_y, up_y); + + __syncthreads(); + + for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; + in_idx += blockDim.x) { + int rel_in_y = in_idx / tile_in_w; + int rel_in_x = in_idx - rel_in_y * tile_in_w; + int in_x = rel_in_x + tile_in_x; + int in_y = rel_in_y + tile_in_y; + + scalar_t v = 0.0; + + if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { + v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * + p.minor_dim + + minor_idx]; + } + + sx[rel_in_y][rel_in_x] = v; + } + + __syncthreads(); + for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; + out_idx += blockDim.x) { + int rel_out_y = out_idx / tile_out_w; + int rel_out_x = out_idx - rel_out_y * tile_out_w; + int out_x = rel_out_x + tile_out_x; + int out_y = rel_out_y + tile_out_y; + + int mid_x = tile_mid_x + rel_out_x * down_x; + int mid_y = tile_mid_y + rel_out_y * down_y; + int in_x = floor_div(mid_x, up_x); + int in_y = floor_div(mid_y, up_y); + int rel_in_x = in_x - tile_in_x; + int rel_in_y = in_y - tile_in_y; + int kernel_x = (in_x + 1) * up_x - mid_x - 1; + int kernel_y = (in_y + 1) * up_y - mid_y - 1; + + scalar_t v = 0.0; + +#pragma unroll + for (int y = 0; y < kernel_h / up_y; y++) +#pragma unroll + for (int x = 0; x < kernel_w / up_x; x++) + v += sx[rel_in_y + y][rel_in_x + x] * + sk[kernel_y + y * up_y][kernel_x + x * up_x]; + + if (out_x < p.out_w & out_y < p.out_h) { + out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + + minor_idx] = v; + } + } + } + } +} + +torch::Tensor upfirdn2d_op(const torch::Tensor &input, + const torch::Tensor &kernel, int up_x, int up_y, + int down_x, int down_y, int pad_x0, int pad_x1, + int pad_y0, int pad_y1) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + UpFirDn2DKernelParams p; + + auto x = input.contiguous(); + auto k = kernel.contiguous(); + + p.major_dim = x.size(0); + p.in_h = x.size(1); + p.in_w = x.size(2); + p.minor_dim = x.size(3); + p.kernel_h = k.size(0); + p.kernel_w = k.size(1); + p.up_x = up_x; + p.up_y = up_y; + p.down_x = down_x; + p.down_y = down_y; + p.pad_x0 = pad_x0; + p.pad_x1 = pad_x1; + p.pad_y0 = pad_y0; + p.pad_y1 = pad_y1; + + p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / + p.down_y; + p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / + p.down_x; + + auto out = + at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); + + int mode = -1; + + int tile_out_h = -1; + int tile_out_w = -1; + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 1; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 3 && p.kernel_w <= 3) { + mode = 2; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 3; + 
tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 2 && p.kernel_w <= 2) { + mode = 4; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 5; + tile_out_h = 8; + tile_out_w = 32; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && + p.kernel_h <= 2 && p.kernel_w <= 2) { + mode = 6; + tile_out_h = 8; + tile_out_w = 32; + } + + dim3 block_size; + dim3 grid_size; + + if (tile_out_h > 0 && tile_out_w > 0) { + p.loop_major = (p.major_dim - 1) / 16384 + 1; + p.loop_x = 1; + block_size = dim3(32 * 8, 1, 1); + grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, + (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, + (p.major_dim - 1) / p.loop_major + 1); + } else { + p.loop_major = (p.major_dim - 1) / 16384 + 1; + p.loop_x = 4; + block_size = dim3(4, 32, 1); + grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1, + (p.out_w - 1) / (p.loop_x * block_size.y) + 1, + (p.major_dim - 1) / p.loop_major + 1); + } + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { + switch (mode) { + case 1: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 2: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 3: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 4: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 5: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 6: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + default: + upfirdn2d_kernel_large<<>>( + out.data_ptr(), x.data_ptr(), + k.data_ptr(), p); + } + }); + + return out; +} diff --git a/blissful_tuner/codeformer/basicsr/ops/upfirdn2d/upfirdn2d.py b/blissful_tuner/codeformer/basicsr/ops/upfirdn2d/upfirdn2d.py new file mode 100644 index 0000000000000000000000000000000000000000..667f96e1ded35d48f163f37e21d1ed8ff191aac3 --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/ops/upfirdn2d/upfirdn2d.py @@ -0,0 +1,186 @@ +# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501 + +import torch +from torch.autograd import Function +from torch.nn import functional as F + +try: + from . 
import upfirdn2d_ext +except ImportError: + import os + BASICSR_JIT = os.getenv('BASICSR_JIT') + if BASICSR_JIT == 'True': + from torch.utils.cpp_extension import load + module_path = os.path.dirname(__file__) + upfirdn2d_ext = load( + 'upfirdn2d', + sources=[ + os.path.join(module_path, 'src', 'upfirdn2d.cpp'), + os.path.join(module_path, 'src', 'upfirdn2d_kernel.cu'), + ], + ) + + +class UpFirDn2dBackward(Function): + + @staticmethod + def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size): + + up_x, up_y = up + down_x, down_y = down + g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad + + grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) + + grad_input = upfirdn2d_ext.upfirdn2d( + grad_output, + grad_kernel, + down_x, + down_y, + up_x, + up_y, + g_pad_x0, + g_pad_x1, + g_pad_y0, + g_pad_y1, + ) + grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) + + ctx.save_for_backward(kernel) + + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + ctx.up_x = up_x + ctx.up_y = up_y + ctx.down_x = down_x + ctx.down_y = down_y + ctx.pad_x0 = pad_x0 + ctx.pad_x1 = pad_x1 + ctx.pad_y0 = pad_y0 + ctx.pad_y1 = pad_y1 + ctx.in_size = in_size + ctx.out_size = out_size + + return grad_input + + @staticmethod + def backward(ctx, gradgrad_input): + kernel, = ctx.saved_tensors + + gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) + + gradgrad_out = upfirdn2d_ext.upfirdn2d( + gradgrad_input, + kernel, + ctx.up_x, + ctx.up_y, + ctx.down_x, + ctx.down_y, + ctx.pad_x0, + ctx.pad_x1, + ctx.pad_y0, + ctx.pad_y1, + ) + # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], + # ctx.out_size[1], ctx.in_size[3]) + gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]) + + return gradgrad_out, None, None, None, None, None, None, None, None + + +class UpFirDn2d(Function): + + @staticmethod + def forward(ctx, input, kernel, up, down, pad): + up_x, up_y = up + down_x, down_y = down + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + kernel_h, kernel_w = kernel.shape + batch, channel, in_h, in_w = input.shape + ctx.in_size = input.shape + + input = input.reshape(-1, in_h, in_w, 1) + + ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + ctx.out_size = (out_h, out_w) + + ctx.up = (up_x, up_y) + ctx.down = (down_x, down_y) + ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) + + g_pad_x0 = kernel_w - pad_x0 - 1 + g_pad_y0 = kernel_h - pad_y0 - 1 + g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 + g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 + + ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) + + out = upfirdn2d_ext.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1) + # out = out.view(major, out_h, out_w, minor) + out = out.view(-1, channel, out_h, out_w) + + return out + + @staticmethod + def backward(ctx, grad_output): + kernel, grad_kernel = ctx.saved_tensors + + grad_input = UpFirDn2dBackward.apply( + grad_output, + kernel, + grad_kernel, + ctx.up, + ctx.down, + ctx.pad, + ctx.g_pad, + ctx.in_size, + ctx.out_size, + ) + + return grad_input, None, None, None, None + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + if input.device.type == 'cpu': + out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) + else: + out = UpFirDn2d.apply(input, 
kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])) + + return out + + +def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): + _, channel, in_h, in_w = input.shape + input = input.reshape(-1, in_h, in_w, 1) + + _, in_h, in_w, minor = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, in_h, 1, in_w, 1, minor) + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + out = out[:, max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0), max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + out = out[:, ::down_y, ::down_x, :] + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + + return out.view(-1, channel, out_h, out_w) diff --git a/blissful_tuner/codeformer/basicsr/setup.py b/blissful_tuner/codeformer/basicsr/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..b24d0450a016f2f36d634c4f79516e8137639cb8 --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/setup.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python + +from setuptools import find_packages, setup + +import os +import subprocess +import sys +import time +from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension +from utils.misc import gpu_is_available + +version_file = './basicsr/version.py' + + +def readme(): + with open('README.md', encoding='utf-8') as f: + content = f.read() + return content + + +def get_git_hash(): + + def _minimal_ext_cmd(cmd): + # construct minimal environment + env = {} + for k in ['SYSTEMROOT', 'PATH', 'HOME']: + v = os.environ.get(k) + if v is not None: + env[k] = v + # LANGUAGE is used on win32 + env['LANGUAGE'] = 'C' + env['LANG'] = 'C' + env['LC_ALL'] = 'C' + out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0] + return out + + try: + out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD']) + sha = out.strip().decode('ascii') + except OSError: + sha = 'unknown' + + return sha + + +def get_hash(): + if os.path.exists('.git'): + sha = get_git_hash()[:7] + elif os.path.exists(version_file): + try: + from version import __version__ + sha = __version__.split('+')[-1] + except ImportError: + raise ImportError('Unable to get git version') + else: + sha = 'unknown' + + return sha + + +def write_version_py(): + content = """# GENERATED VERSION FILE +# TIME: {} +__version__ = '{}' +__gitsha__ = '{}' +version_info = ({}) +""" + sha = get_hash() + with open('./basicsr/VERSION', 'r') as f: + SHORT_VERSION = f.read().strip() + VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')]) + + version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO) + with open(version_file, 'w') as f: + f.write(version_file_str) + + +def get_version(): + with open(version_file, 'r') as f: + exec(compile(f.read(), version_file, 'exec')) + return locals()['__version__'] + + +def make_cuda_ext(name, module, sources, sources_cuda=None): + if 
sources_cuda is None: + sources_cuda = [] + define_macros = [] + extra_compile_args = {'cxx': []} + + # if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1': + if gpu_is_available or os.getenv('FORCE_CUDA', '0') == '1': + define_macros += [('WITH_CUDA', None)] + extension = CUDAExtension + extra_compile_args['nvcc'] = [ + '-D__CUDA_NO_HALF_OPERATORS__', + '-D__CUDA_NO_HALF_CONVERSIONS__', + '-D__CUDA_NO_HALF2_OPERATORS__', + ] + sources += sources_cuda + else: + print(f'Compiling {name} without CUDA') + extension = CppExtension + + return extension( + name=f'{module}.{name}', + sources=[os.path.join(*module.split('.'), p) for p in sources], + define_macros=define_macros, + extra_compile_args=extra_compile_args) + + +def get_requirements(filename='requirements.txt'): + with open(os.path.join('.', filename), 'r') as f: + requires = [line.replace('\n', '') for line in f.readlines()] + return requires + + +if __name__ == '__main__': + if '--cuda_ext' in sys.argv: + ext_modules = [ + make_cuda_ext( + name='deform_conv_ext', + module='ops.dcn', + sources=['src/deform_conv_ext.cpp'], + sources_cuda=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu']), + make_cuda_ext( + name='fused_act_ext', + module='ops.fused_act', + sources=['src/fused_bias_act.cpp'], + sources_cuda=['src/fused_bias_act_kernel.cu']), + make_cuda_ext( + name='upfirdn2d_ext', + module='ops.upfirdn2d', + sources=['src/upfirdn2d.cpp'], + sources_cuda=['src/upfirdn2d_kernel.cu']), + ] + sys.argv.remove('--cuda_ext') + else: + ext_modules = [] + + write_version_py() + setup( + name='basicsr', + version=get_version(), + description='Open Source Image and Video Super-Resolution Toolbox', + long_description=readme(), + long_description_content_type='text/markdown', + author='Xintao Wang', + author_email='xintao.wang@outlook.com', + keywords='computer vision, restoration, super resolution', + url='https://github.com/xinntao/BasicSR', + include_package_data=True, + packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')), + classifiers=[ + 'Development Status :: 4 - Beta', + 'License :: OSI Approved :: Apache Software License', + 'Operating System :: OS Independent', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + ], + license='Apache License 2.0', + setup_requires=['cython', 'numpy'], + install_requires=get_requirements(), + ext_modules=ext_modules, + cmdclass={'build_ext': BuildExtension}, + zip_safe=False) diff --git a/blissful_tuner/codeformer/basicsr/train.py b/blissful_tuner/codeformer/basicsr/train.py new file mode 100644 index 0000000000000000000000000000000000000000..eb5a932a876bd72c0af6e5acd12d3b99adbf7348 --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/train.py @@ -0,0 +1,225 @@ +import argparse +import datetime +import logging +import math +import copy +import random +import time +import torch +from os import path as osp + +from codeformer.basicsr.data import build_dataloader, build_dataset +from codeformer.basicsr.data.data_sampler import EnlargedSampler +from codeformer.basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher +from codeformer.basicsr.models import build_model +from codeformer.basicsr.utils import (MessageLogger, check_resume, get_env_info, get_root_logger, init_tb_logger, + init_wandb_logger, make_exp_dirs, mkdir_and_rename, set_random_seed) +from codeformer.basicsr.utils.dist_util import get_dist_info, init_dist +from 
codeformer.basicsr.utils.options import dict2str, parse + +import warnings +# ignore UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. +warnings.filterwarnings("ignore", category=UserWarning) + +def parse_options(root_path, is_train=True): + parser = argparse.ArgumentParser() + parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.') + parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher') + parser.add_argument('--local_rank', type=int, default=0) + args = parser.parse_args() + opt = parse(args.opt, root_path, is_train=is_train) + + # distributed settings + if args.launcher == 'none': + opt['dist'] = False + print('Disable distributed.', flush=True) + else: + opt['dist'] = True + if args.launcher == 'slurm' and 'dist_params' in opt: + init_dist(args.launcher, **opt['dist_params']) + else: + init_dist(args.launcher) + + opt['rank'], opt['world_size'] = get_dist_info() + + # random seed + seed = opt.get('manual_seed') + if seed is None: + seed = random.randint(1, 10000) + opt['manual_seed'] = seed + set_random_seed(seed + opt['rank']) + + return opt + + +def init_loggers(opt): + log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log") + logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file) + logger.info(get_env_info()) + logger.info(dict2str(opt)) + + # initialize wandb logger before tensorboard logger to allow proper sync: + if (opt['logger'].get('wandb') is not None) and (opt['logger']['wandb'].get('project') is not None): + assert opt['logger'].get('use_tb_logger') is True, ('should turn on tensorboard when using wandb') + init_wandb_logger(opt) + tb_logger = None + if opt['logger'].get('use_tb_logger'): + tb_logger = init_tb_logger(log_dir=osp.join('tb_logger', opt['name'])) + return logger, tb_logger + + +def create_train_val_dataloader(opt, logger): + # create train and val dataloaders + train_loader, val_loader = None, None + for phase, dataset_opt in opt['datasets'].items(): + if phase == 'train': + dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1) + train_set = build_dataset(dataset_opt) + train_sampler = EnlargedSampler(train_set, opt['world_size'], opt['rank'], dataset_enlarge_ratio) + train_loader = build_dataloader( + train_set, + dataset_opt, + num_gpu=opt['num_gpu'], + dist=opt['dist'], + sampler=train_sampler, + seed=opt['manual_seed']) + + num_iter_per_epoch = math.ceil( + len(train_set) * dataset_enlarge_ratio / (dataset_opt['batch_size_per_gpu'] * opt['world_size'])) + total_iters = int(opt['train']['total_iter']) + total_epochs = math.ceil(total_iters / (num_iter_per_epoch)) + logger.info('Training statistics:' + f'\n\tNumber of train images: {len(train_set)}' + f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}' + f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}' + f'\n\tWorld size (gpu number): {opt["world_size"]}' + f'\n\tRequire iter number per epoch: {num_iter_per_epoch}' + f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.') + + elif phase == 'val': + val_set = build_dataset(dataset_opt) + val_loader = build_dataloader( + val_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed']) + logger.info(f'Number of val images/folders in {dataset_opt["name"]}: ' f'{len(val_set)}') + else: + raise ValueError(f'Dataset phase {phase} is not recognized.') + + return train_loader, train_sampler, val_loader, total_epochs, total_iters + + +def 
train_pipeline(root_path): + # parse options, set distributed setting, set ramdom seed + opt = parse_options(root_path, is_train=True) + + torch.backends.cudnn.benchmark = True + # torch.backends.cudnn.deterministic = True + + # load resume states if necessary + if opt['path'].get('resume_state'): + device_id = torch.cuda.current_device() + resume_state = torch.load( + opt['path']['resume_state'], map_location=lambda storage, loc: storage.cuda(device_id)) + else: + resume_state = None + + # mkdir for experiments and logger + if resume_state is None: + make_exp_dirs(opt) + if opt['logger'].get('use_tb_logger') and opt['rank'] == 0: + mkdir_and_rename(osp.join('tb_logger', opt['name'])) + + # initialize loggers + logger, tb_logger = init_loggers(opt) + + # create train and validation dataloaders + result = create_train_val_dataloader(opt, logger) + train_loader, train_sampler, val_loader, total_epochs, total_iters = result + + # create model + if resume_state: # resume training + check_resume(opt, resume_state['iter']) + model = build_model(opt) + model.resume_training(resume_state) # handle optimizers and schedulers + logger.info(f"Resuming training from epoch: {resume_state['epoch']}, " f"iter: {resume_state['iter']}.") + start_epoch = resume_state['epoch'] + current_iter = resume_state['iter'] + else: + model = build_model(opt) + start_epoch = 0 + current_iter = 0 + + # create message logger (formatted outputs) + msg_logger = MessageLogger(opt, current_iter, tb_logger) + + # dataloader prefetcher + prefetch_mode = opt['datasets']['train'].get('prefetch_mode') + if prefetch_mode is None or prefetch_mode == 'cpu': + prefetcher = CPUPrefetcher(train_loader) + elif prefetch_mode == 'cuda': + prefetcher = CUDAPrefetcher(train_loader, opt) + logger.info(f'Use {prefetch_mode} prefetch dataloader') + if opt['datasets']['train'].get('pin_memory') is not True: + raise ValueError('Please set pin_memory=True for CUDAPrefetcher.') + else: + raise ValueError(f'Wrong prefetch_mode {prefetch_mode}.' 
"Supported ones are: None, 'cuda', 'cpu'.") + + # training + logger.info(f'Start training from epoch: {start_epoch}, iter: {current_iter+1}') + data_time, iter_time = time.time(), time.time() + start_time = time.time() + + for epoch in range(start_epoch, total_epochs + 1): + train_sampler.set_epoch(epoch) + prefetcher.reset() + train_data = prefetcher.next() + + while train_data is not None: + data_time = time.time() - data_time + + current_iter += 1 + if current_iter > total_iters: + break + # update learning rate + model.update_learning_rate(current_iter, warmup_iter=opt['train'].get('warmup_iter', -1)) + # training + model.feed_data(train_data) + model.optimize_parameters(current_iter) + iter_time = time.time() - iter_time + # log + if current_iter % opt['logger']['print_freq'] == 0: + log_vars = {'epoch': epoch, 'iter': current_iter} + log_vars.update({'lrs': model.get_current_learning_rate()}) + log_vars.update({'time': iter_time, 'data_time': data_time}) + log_vars.update(model.get_current_log()) + msg_logger(log_vars) + + # save models and training states + if current_iter % opt['logger']['save_checkpoint_freq'] == 0: + logger.info('Saving models and training states.') + model.save(epoch, current_iter) + + # validation + if opt.get('val') is not None and opt['datasets'].get('val') is not None \ + and (current_iter % opt['val']['val_freq'] == 0): + model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img']) + + data_time = time.time() + iter_time = time.time() + train_data = prefetcher.next() + # end of iter + + # end of epoch + + consumed_time = str(datetime.timedelta(seconds=int(time.time() - start_time))) + logger.info(f'End of training. Time consumed: {consumed_time}') + logger.info('Save the latest model.') + model.save(epoch=-1, current_iter=-1) # -1 stands for the latest + if opt.get('val') is not None and opt['datasets'].get('val'): + model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img']) + if tb_logger: + tb_logger.close() + + +if __name__ == '__main__': + root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) + train_pipeline(root_path) diff --git a/blissful_tuner/codeformer/basicsr/utils/__init__.py b/blissful_tuner/codeformer/basicsr/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5fcc1d540462712387523d1e326d1dfc2bcfbf32 --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/utils/__init__.py @@ -0,0 +1,29 @@ +from .file_client import FileClient +from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img +from .logger import MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger +from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt + +__all__ = [ + # file_client.py + 'FileClient', + # img_util.py + 'img2tensor', + 'tensor2img', + 'imfrombytes', + 'imwrite', + 'crop_border', + # logger.py + 'MessageLogger', + 'init_tb_logger', + 'init_wandb_logger', + 'get_root_logger', + 'get_env_info', + # misc.py + 'set_random_seed', + 'get_time_str', + 'mkdir_and_rename', + 'make_exp_dirs', + 'scandir', + 'check_resume', + 'sizeof_fmt' +] diff --git a/blissful_tuner/codeformer/basicsr/utils/dist_util.py b/blissful_tuner/codeformer/basicsr/utils/dist_util.py new file mode 100644 index 0000000000000000000000000000000000000000..0fab887b2cb1ce8533d2e8fdee72ae0c24f68fd0 --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/utils/dist_util.py @@ -0,0 +1,82 @@ +# Modified from 
https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 +import functools +import os +import subprocess +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + + +def init_dist(launcher, backend='nccl', **kwargs): + if mp.get_start_method(allow_none=True) is None: + mp.set_start_method('spawn') + if launcher == 'pytorch': + _init_dist_pytorch(backend, **kwargs) + elif launcher == 'slurm': + _init_dist_slurm(backend, **kwargs) + else: + raise ValueError(f'Invalid launcher type: {launcher}') + + +def _init_dist_pytorch(backend, **kwargs): + rank = int(os.environ['RANK']) + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + dist.init_process_group(backend=backend, **kwargs) + + +def _init_dist_slurm(backend, port=None): + """Initialize slurm distributed training environment. + + If argument ``port`` is not specified, then the master port will be system + environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system + environment variable, then a default port ``29500`` will be used. + + Args: + backend (str): Backend of torch.distributed. + port (int, optional): Master port. Defaults to None. + """ + proc_id = int(os.environ['SLURM_PROCID']) + ntasks = int(os.environ['SLURM_NTASKS']) + node_list = os.environ['SLURM_NODELIST'] + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(proc_id % num_gpus) + addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1') + # specify master port + if port is not None: + os.environ['MASTER_PORT'] = str(port) + elif 'MASTER_PORT' in os.environ: + pass # use MASTER_PORT in the environment variable + else: + # 29500 is torch.distributed default port + os.environ['MASTER_PORT'] = '29500' + os.environ['MASTER_ADDR'] = addr + os.environ['WORLD_SIZE'] = str(ntasks) + os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) + os.environ['RANK'] = str(proc_id) + dist.init_process_group(backend=backend) + + +def get_dist_info(): + if dist.is_available(): + initialized = dist.is_initialized() + else: + initialized = False + if initialized: + rank = dist.get_rank() + world_size = dist.get_world_size() + else: + rank = 0 + world_size = 1 + return rank, world_size + + +def master_only(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + rank, _ = get_dist_info() + if rank == 0: + return func(*args, **kwargs) + + return wrapper diff --git a/blissful_tuner/codeformer/basicsr/utils/download_util.py b/blissful_tuner/codeformer/basicsr/utils/download_util.py new file mode 100644 index 0000000000000000000000000000000000000000..2a267915743ee3f3232bc8fe992466b52468979a --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/utils/download_util.py @@ -0,0 +1,95 @@ +import math +import os +import requests +from torch.hub import download_url_to_file, get_dir +from tqdm import tqdm +from urllib.parse import urlparse + +from .misc import sizeof_fmt + + +def download_file_from_google_drive(file_id, save_path): + """Download files from google drive. + Ref: + https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501 + Args: + file_id (str): File id. + save_path (str): Save path. 
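The `master_only` decorator and `get_dist_info` helper above are how the rest of the codebase keeps logging and other side effects on rank 0. A minimal usage sketch, assuming the package import path used elsewhere in this diff (`codeformer.basicsr.utils.dist_util`); the `report` function is made up for illustration:

```python
from codeformer.basicsr.utils.dist_util import get_dist_info, master_only


@master_only
def report(message: str) -> None:
    # Runs only on rank 0; on other ranks the wrapper simply returns None.
    print(message)


rank, world_size = get_dist_info()  # (0, 1) when torch.distributed is not initialized
report(f"rank {rank} of {world_size} is the master process")
```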
+ """ + + session = requests.Session() + URL = 'https://docs.google.com/uc?export=download' + params = {'id': file_id} + + response = session.get(URL, params=params, stream=True) + token = get_confirm_token(response) + if token: + params['confirm'] = token + response = session.get(URL, params=params, stream=True) + + # get file size + response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) + print(response_file_size) + if 'Content-Range' in response_file_size.headers: + file_size = int(response_file_size.headers['Content-Range'].split('/')[1]) + else: + file_size = None + + save_response_content(response, save_path, file_size) + + +def get_confirm_token(response): + for key, value in response.cookies.items(): + if key.startswith('download_warning'): + return value + return None + + +def save_response_content(response, destination, file_size=None, chunk_size=32768): + if file_size is not None: + pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') + + readable_file_size = sizeof_fmt(file_size) + else: + pbar = None + + with open(destination, 'wb') as f: + downloaded_size = 0 + for chunk in response.iter_content(chunk_size): + downloaded_size += chunk_size + if pbar is not None: + pbar.update(1) + pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}') + if chunk: # filter out keep-alive new chunks + f.write(chunk) + if pbar is not None: + pbar.close() + + +def load_file_from_url(url, model_dir=None, progress=True, file_name=None): + """Load file form http url, will download models if necessary. + Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py + Args: + url (str): URL to be downloaded. + model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. + Default: None. + progress (bool): Whether to show the download progress. Default: True. + file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. + Returns: + str: The path to the downloaded file. + """ + if model_dir is None: # use the pytorch hub_dir + hub_dir = get_dir() + model_dir = os.path.join(hub_dir, 'checkpoints') + + os.makedirs(model_dir, exist_ok=True) + + parts = urlparse(url) + filename = os.path.basename(parts.path) + if file_name is not None: + filename = file_name + cached_file = os.path.abspath(os.path.join(model_dir, filename)) + if not os.path.exists(cached_file): + print(f'Downloading: "{url}" to {cached_file}\n') + download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) + return cached_file \ No newline at end of file diff --git a/blissful_tuner/codeformer/basicsr/utils/file_client.py b/blissful_tuner/codeformer/basicsr/utils/file_client.py new file mode 100644 index 0000000000000000000000000000000000000000..7f38d9796da3899048924f2f803d1088927966b0 --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/utils/file_client.py @@ -0,0 +1,167 @@ +# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501 +from abc import ABCMeta, abstractmethod + + +class BaseStorageBackend(metaclass=ABCMeta): + """Abstract class of storage backends. + + All backends need to implement two apis: ``get()`` and ``get_text()``. + ``get()`` reads the file as a byte stream and ``get_text()`` reads the file + as texts. 
+ """ + + @abstractmethod + def get(self, filepath): + pass + + @abstractmethod + def get_text(self, filepath): + pass + + +class MemcachedBackend(BaseStorageBackend): + """Memcached storage backend. + + Attributes: + server_list_cfg (str): Config file for memcached server list. + client_cfg (str): Config file for memcached client. + sys_path (str | None): Additional path to be appended to `sys.path`. + Default: None. + """ + + def __init__(self, server_list_cfg, client_cfg, sys_path=None): + if sys_path is not None: + import sys + sys.path.append(sys_path) + try: + import mc + except ImportError: + raise ImportError('Please install memcached to enable MemcachedBackend.') + + self.server_list_cfg = server_list_cfg + self.client_cfg = client_cfg + self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg) + # mc.pyvector servers as a point which points to a memory cache + self._mc_buffer = mc.pyvector() + + def get(self, filepath): + filepath = str(filepath) + import mc + self._client.Get(filepath, self._mc_buffer) + value_buf = mc.ConvertBuffer(self._mc_buffer) + return value_buf + + def get_text(self, filepath): + raise NotImplementedError + + +class HardDiskBackend(BaseStorageBackend): + """Raw hard disks storage backend.""" + + def get(self, filepath): + filepath = str(filepath) + with open(filepath, 'rb') as f: + value_buf = f.read() + return value_buf + + def get_text(self, filepath): + filepath = str(filepath) + with open(filepath, 'r') as f: + value_buf = f.read() + return value_buf + + +class LmdbBackend(BaseStorageBackend): + """Lmdb storage backend. + + Args: + db_paths (str | list[str]): Lmdb database paths. + client_keys (str | list[str]): Lmdb client keys. Default: 'default'. + readonly (bool, optional): Lmdb environment parameter. If True, + disallow any write operations. Default: True. + lock (bool, optional): Lmdb environment parameter. If False, when + concurrent access occurs, do not lock the database. Default: False. + readahead (bool, optional): Lmdb environment parameter. If False, + disable the OS filesystem readahead mechanism, which may improve + random read performance when a database is larger than RAM. + Default: False. + + Attributes: + db_paths (list): Lmdb database path. + _client (list): A list of several lmdb envs. + """ + + def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs): + try: + import lmdb + except ImportError: + raise ImportError('Please install lmdb to enable LmdbBackend.') + + if isinstance(client_keys, str): + client_keys = [client_keys] + + if isinstance(db_paths, list): + self.db_paths = [str(v) for v in db_paths] + elif isinstance(db_paths, str): + self.db_paths = [str(db_paths)] + assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, ' + f'but received {len(client_keys)} and {len(self.db_paths)}.') + + self._client = {} + for client, path in zip(client_keys, self.db_paths): + self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs) + + def get(self, filepath, client_key): + """Get values according to the filepath from one lmdb named client_key. + + Args: + filepath (str | obj:`Path`): Here, filepath is the lmdb key. + client_key (str): Used for distinguishing differnet lmdb envs. 
+ """ + filepath = str(filepath) + assert client_key in self._client, (f'client_key {client_key} is not ' 'in lmdb clients.') + client = self._client[client_key] + with client.begin(write=False) as txn: + value_buf = txn.get(filepath.encode('ascii')) + return value_buf + + def get_text(self, filepath): + raise NotImplementedError + + +class FileClient(object): + """A general file client to access files in different backend. + + The client loads a file or text in a specified backend from its path + and return it as a binary file. it can also register other backend + accessor with a given name and backend class. + + Attributes: + backend (str): The storage backend type. Options are "disk", + "memcached" and "lmdb". + client (:obj:`BaseStorageBackend`): The backend object. + """ + + _backends = { + 'disk': HardDiskBackend, + 'memcached': MemcachedBackend, + 'lmdb': LmdbBackend, + } + + def __init__(self, backend='disk', **kwargs): + if backend not in self._backends: + raise ValueError(f'Backend {backend} is not supported. Currently supported ones' + f' are {list(self._backends.keys())}') + self.backend = backend + self.client = self._backends[backend](**kwargs) + + def get(self, filepath, client_key='default'): + # client_key is used only for lmdb, where different fileclients have + # different lmdb environments. + if self.backend == 'lmdb': + return self.client.get(filepath, client_key) + else: + return self.client.get(filepath) + + def get_text(self, filepath): + return self.client.get_text(filepath) diff --git a/blissful_tuner/codeformer/basicsr/utils/img_util.py b/blissful_tuner/codeformer/basicsr/utils/img_util.py new file mode 100644 index 0000000000000000000000000000000000000000..5aba82ce08eefaeb3e56ea5a3a09c342ae513522 --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/utils/img_util.py @@ -0,0 +1,171 @@ +import cv2 +import math +import numpy as np +import os +import torch +from torchvision.utils import make_grid + + +def img2tensor(imgs, bgr2rgb=True, float32=True): + """Numpy array to tensor. + + Args: + imgs (list[ndarray] | ndarray): Input images. + bgr2rgb (bool): Whether to change bgr to rgb. + float32 (bool): Whether to change to float32. + + Returns: + list[tensor] | tensor: Tensor images. If returned results only have + one element, just return tensor. + """ + + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + if img.dtype == 'float64': + img = img.astype('float32') + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose(2, 0, 1)) + if float32: + img = img.float() + return img + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + else: + return _totensor(imgs, bgr2rgb, float32) + + +def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): + """Convert torch Tensors into image numpy arrays. + + After clamping to [min, max], values will be normalized to [0, 1]. + + Args: + tensor (Tensor or list[Tensor]): Accept shapes: + 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); + 2) 3D Tensor of shape (3/1 x H x W); + 3) 2D Tensor of shape (H x W). + Tensor channel should be in RGB order. + rgb2bgr (bool): Whether to change rgb to bgr. + out_type (numpy type): output types. If ``np.uint8``, transform outputs + to uint8 type with range [0, 255]; otherwise, float type with + range [0, 1]. Default: ``np.uint8``. + min_max (tuple[int]): min and max values for clamp. + + Returns: + (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of + shape (H x W). 
The channel order is BGR. + """ + if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): + raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}') + + if torch.is_tensor(tensor): + tensor = [tensor] + result = [] + for _tensor in tensor: + _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) + _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) + + n_dim = _tensor.dim() + if n_dim == 4: + img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy() + img_np = img_np.transpose(1, 2, 0) + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 3: + img_np = _tensor.numpy() + img_np = img_np.transpose(1, 2, 0) + if img_np.shape[2] == 1: # gray image + img_np = np.squeeze(img_np, axis=2) + else: + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 2: + img_np = _tensor.numpy() + else: + raise TypeError('Only support 4D, 3D or 2D tensor. ' f'But received with dimension: {n_dim}') + if out_type == np.uint8: + # Unlike MATLAB, numpy.unit8() WILL NOT round by default. + img_np = (img_np * 255.0).round() + img_np = img_np.astype(out_type) + result.append(img_np) + if len(result) == 1: + result = result[0] + return result + + +def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)): + """This implementation is slightly faster than tensor2img. + It now only supports torch tensor with shape (1, c, h, w). + + Args: + tensor (Tensor): Now only support torch tensor with (1, c, h, w). + rgb2bgr (bool): Whether to change rgb to bgr. Default: True. + min_max (tuple[int]): min and max values for clamp. + """ + output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0) + output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255 + output = output.type(torch.uint8).cpu().numpy() + if rgb2bgr: + output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) + return output + + +def imfrombytes(content, flag='color', float32=False): + """Read an image from bytes. + + Args: + content (bytes): Image bytes got from files or other streams. + flag (str): Flags specifying the color type of a loaded image, + candidates are `color`, `grayscale` and `unchanged`. + float32 (bool): Whether to change to float32., If True, will also norm + to [0, 1]. Default: False. + + Returns: + ndarray: Loaded image array. + """ + img_np = np.frombuffer(content, np.uint8) + imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED} + img = cv2.imdecode(img_np, imread_flags[flag]) + if float32: + img = img.astype(np.float32) / 255. + return img + + +def imwrite(img, file_path, params=None, auto_mkdir=True): + """Write image to file. + + Args: + img (ndarray): Image array to be written. + file_path (str): Image file path. + params (None or list): Same as opencv's :func:`imwrite` interface. + auto_mkdir (bool): If the parent folder of `file_path` does not exist, + whether to create it automatically. + + Returns: + bool: Successful or not. + """ + if auto_mkdir: + dir_name = os.path.abspath(os.path.dirname(file_path)) + os.makedirs(dir_name, exist_ok=True) + return cv2.imwrite(file_path, img, params) + + +def crop_border(imgs, crop_border): + """Crop borders of images. + + Args: + imgs (list[ndarray] | ndarray): Images with shape (h, w, c). + crop_border (int): Crop border for each end of height and weight. + + Returns: + list[ndarray]: Cropped images. 
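Taken together, `imfrombytes`, `img2tensor`, and `tensor2img` above cover the usual BGR-uint8 file to RGB-float tensor round trip. A minimal sketch, assuming an image exists at the placeholder path `input.png`:

```python
import numpy as np
from codeformer.basicsr.utils.img_util import img2tensor, imfrombytes, imwrite, tensor2img

with open("input.png", "rb") as f:                     # placeholder input image
    img = imfrombytes(f.read(), float32=True)          # HWC, BGR, float32 in [0, 1]

tensor = img2tensor(img, bgr2rgb=True, float32=True)   # CHW, RGB, float32

# ... a model would run on `tensor` here ...

restored = tensor2img(tensor, rgb2bgr=True, min_max=(0, 1))  # HWC, BGR, uint8
assert restored.dtype == np.uint8
imwrite(restored, "output.png")
```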
+ """ + if crop_border == 0: + return imgs + else: + if isinstance(imgs, list): + return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs] + else: + return imgs[crop_border:-crop_border, crop_border:-crop_border, ...] + \ No newline at end of file diff --git a/blissful_tuner/codeformer/basicsr/utils/lmdb_util.py b/blissful_tuner/codeformer/basicsr/utils/lmdb_util.py new file mode 100644 index 0000000000000000000000000000000000000000..e0a10f60ffca2e36ac5f5564aafd70e79d06a723 --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/utils/lmdb_util.py @@ -0,0 +1,196 @@ +import cv2 +import lmdb +import sys +from multiprocessing import Pool +from os import path as osp +from tqdm import tqdm + + +def make_lmdb_from_imgs(data_path, + lmdb_path, + img_path_list, + keys, + batch=5000, + compress_level=1, + multiprocessing_read=False, + n_thread=40, + map_size=None): + """Make lmdb from images. + + Contents of lmdb. The file structure is: + example.lmdb + ├── data.mdb + ├── lock.mdb + ├── meta_info.txt + + The data.mdb and lock.mdb are standard lmdb files and you can refer to + https://lmdb.readthedocs.io/en/release/ for more details. + + The meta_info.txt is a specified txt file to record the meta information + of our datasets. It will be automatically created when preparing + datasets by our provided dataset tools. + Each line in the txt file records 1)image name (with extension), + 2)image shape, and 3)compression level, separated by a white space. + + For example, the meta information could be: + `000_00000000.png (720,1280,3) 1`, which means: + 1) image name (with extension): 000_00000000.png; + 2) image shape: (720,1280,3); + 3) compression level: 1 + + We use the image name without extension as the lmdb key. + + If `multiprocessing_read` is True, it will read all the images to memory + using multiprocessing. Thus, your server needs to have enough memory. + + Args: + data_path (str): Data path for reading images. + lmdb_path (str): Lmdb save path. + img_path_list (str): Image path list. + keys (str): Used for lmdb keys. + batch (int): After processing batch images, lmdb commits. + Default: 5000. + compress_level (int): Compress level when encoding images. Default: 1. + multiprocessing_read (bool): Whether use multiprocessing to read all + the images to memory. Default: False. + n_thread (int): For multiprocessing. + map_size (int | None): Map size for lmdb env. If None, use the + estimated size from images. Default: None + """ + + assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, ' + f'but got {len(img_path_list)} and {len(keys)}') + print(f'Create lmdb for {data_path}, save to {lmdb_path}...') + print(f'Totoal images: {len(img_path_list)}') + if not lmdb_path.endswith('.lmdb'): + raise ValueError("lmdb_path must end with '.lmdb'.") + if osp.exists(lmdb_path): + print(f'Folder {lmdb_path} already exists. 
Exit.') + sys.exit(1) + + if multiprocessing_read: + # read all the images to memory (multiprocessing) + dataset = {} # use dict to keep the order for multiprocessing + shapes = {} + print(f'Read images with multiprocessing, #thread: {n_thread} ...') + pbar = tqdm(total=len(img_path_list), unit='image') + + def callback(arg): + """get the image data and update pbar.""" + key, dataset[key], shapes[key] = arg + pbar.update(1) + pbar.set_description(f'Read {key}') + + pool = Pool(n_thread) + for path, key in zip(img_path_list, keys): + pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback) + pool.close() + pool.join() + pbar.close() + print(f'Finish reading {len(img_path_list)} images.') + + # create lmdb environment + if map_size is None: + # obtain data size for one image + img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED) + _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) + data_size_per_img = img_byte.nbytes + print('Data size per image is: ', data_size_per_img) + data_size = data_size_per_img * len(img_path_list) + map_size = data_size * 10 + + env = lmdb.open(lmdb_path, map_size=map_size) + + # write data to lmdb + pbar = tqdm(total=len(img_path_list), unit='chunk') + txn = env.begin(write=True) + txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') + for idx, (path, key) in enumerate(zip(img_path_list, keys)): + pbar.update(1) + pbar.set_description(f'Write {key}') + key_byte = key.encode('ascii') + if multiprocessing_read: + img_byte = dataset[key] + h, w, c = shapes[key] + else: + _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level) + h, w, c = img_shape + + txn.put(key_byte, img_byte) + # write meta information + txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n') + if idx % batch == 0: + txn.commit() + txn = env.begin(write=True) + pbar.close() + txn.commit() + env.close() + txt_file.close() + print('\nFinish writing lmdb.') + + +def read_img_worker(path, key, compress_level): + """Read image worker. + + Args: + path (str): Image path. + key (str): Image key. + compress_level (int): Compress level when encoding images. + + Returns: + str: Image key. + byte: Image byte. + tuple[int]: Image shape. + """ + + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) + if img.ndim == 2: + h, w = img.shape + c = 1 + else: + h, w, c = img.shape + _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) + return (key, img_byte, (h, w, c)) + + +class LmdbMaker(): + """LMDB Maker. + + Args: + lmdb_path (str): Lmdb save path. + map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB. + batch (int): After processing batch images, lmdb commits. + Default: 5000. + compress_level (int): Compress level when encoding images. Default: 1. + """ + + def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1): + if not lmdb_path.endswith('.lmdb'): + raise ValueError("lmdb_path must end with '.lmdb'.") + if osp.exists(lmdb_path): + print(f'Folder {lmdb_path} already exists. 
Exit.') + sys.exit(1) + + self.lmdb_path = lmdb_path + self.batch = batch + self.compress_level = compress_level + self.env = lmdb.open(lmdb_path, map_size=map_size) + self.txn = self.env.begin(write=True) + self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') + self.counter = 0 + + def put(self, img_byte, key, img_shape): + self.counter += 1 + key_byte = key.encode('ascii') + self.txn.put(key_byte, img_byte) + # write meta information + h, w, c = img_shape + self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n') + if self.counter % self.batch == 0: + self.txn.commit() + self.txn = self.env.begin(write=True) + + def close(self): + self.txn.commit() + self.env.close() + self.txt_file.close() diff --git a/blissful_tuner/codeformer/basicsr/utils/logger.py b/blissful_tuner/codeformer/basicsr/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..a390ff2a2193e461c4e9ce25d1625bbef6bfb48d --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/utils/logger.py @@ -0,0 +1,169 @@ +import datetime +import logging +import time + +from .dist_util import get_dist_info, master_only + +initialized_logger = {} + + +class MessageLogger(): + """Message logger for printing. + Args: + opt (dict): Config. It contains the following keys: + name (str): Exp name. + logger (dict): Contains 'print_freq' (str) for logger interval. + train (dict): Contains 'total_iter' (int) for total iters. + use_tb_logger (bool): Use tensorboard logger. + start_iter (int): Start iter. Default: 1. + tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None. + """ + + def __init__(self, opt, start_iter=1, tb_logger=None): + self.exp_name = opt['name'] + self.interval = opt['logger']['print_freq'] + self.start_iter = start_iter + self.max_iters = opt['train']['total_iter'] + self.use_tb_logger = opt['logger']['use_tb_logger'] + self.tb_logger = tb_logger + self.start_time = time.time() + self.logger = get_root_logger() + + @master_only + def __call__(self, log_vars): + """Format logging message. + Args: + log_vars (dict): It contains the following keys: + epoch (int): Epoch number. + iter (int): Current iter. + lrs (list): List for learning rates. + time (float): Iter time. + data_time (float): Data time for each iter. 
+ """ + # epoch, iter, learning rates + epoch = log_vars.pop('epoch') + current_iter = log_vars.pop('iter') + lrs = log_vars.pop('lrs') + + message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, ' f'iter:{current_iter:8,d}, lr:(') + for v in lrs: + message += f'{v:.3e},' + message += ')] ' + + # time and estimated time + if 'time' in log_vars.keys(): + iter_time = log_vars.pop('time') + data_time = log_vars.pop('data_time') + + total_time = time.time() - self.start_time + time_sec_avg = total_time / (current_iter - self.start_iter + 1) + eta_sec = time_sec_avg * (self.max_iters - current_iter - 1) + eta_str = str(datetime.timedelta(seconds=int(eta_sec))) + message += f'[eta: {eta_str}, ' + message += f'time (data): {iter_time:.3f} ({data_time:.3f})] ' + + # other items, especially losses + for k, v in log_vars.items(): + message += f'{k}: {v:.4e} ' + # tensorboard logger + if self.use_tb_logger: + # if k.startswith('l_'): + # self.tb_logger.add_scalar(f'losses/{k}', v, current_iter) + # else: + self.tb_logger.add_scalar(k, v, current_iter) + self.logger.info(message) + + +@master_only +def init_tb_logger(log_dir): + from torch.utils.tensorboard import SummaryWriter + tb_logger = SummaryWriter(log_dir=log_dir) + return tb_logger + + +@master_only +def init_wandb_logger(opt): + """We now only use wandb to sync tensorboard log.""" + import wandb + logger = logging.getLogger('basicsr') + + project = opt['logger']['wandb']['project'] + resume_id = opt['logger']['wandb'].get('resume_id') + if resume_id: + wandb_id = resume_id + resume = 'allow' + logger.warning(f'Resume wandb logger with id={wandb_id}.') + else: + wandb_id = wandb.util.generate_id() + resume = 'never' + + wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True) + + logger.info(f'Use wandb logger with id={wandb_id}; project={project}.') + + +def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None): + """Get the root logger. + The logger will be initialized if it has not been initialized. By default a + StreamHandler will be added. If `log_file` is specified, a FileHandler will + also be added. + Args: + logger_name (str): root logger name. Default: 'basicsr'. + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the root logger. + log_level (int): The root logger level. Note that only the process of + rank 0 is affected, while other processes will set the level to + "Error" and be silent most of the time. + Returns: + logging.Logger: The root logger. + """ + logger = logging.getLogger(logger_name) + # if the logger has been initialized, just return it + if logger_name in initialized_logger: + return logger + + format_str = '%(asctime)s %(levelname)s: %(message)s' + stream_handler = logging.StreamHandler() + stream_handler.setFormatter(logging.Formatter(format_str)) + logger.addHandler(stream_handler) + logger.propagate = False + rank, _ = get_dist_info() + if rank != 0: + logger.setLevel('ERROR') + elif log_file is not None: + logger.setLevel(log_level) + # add file handler + # file_handler = logging.FileHandler(log_file, 'w') + file_handler = logging.FileHandler(log_file, 'a') #Shangchen: keep the previous log + file_handler.setFormatter(logging.Formatter(format_str)) + file_handler.setLevel(log_level) + logger.addHandler(file_handler) + initialized_logger[logger_name] = True + return logger + + +def get_env_info(): + """Get environment information. + Currently, only log the software version. 
+ """ + import torch + import torchvision + + from codeformer.basicsr.version import __version__ + msg = r""" + ____ _ _____ ____ + / __ ) ____ _ _____ (_)_____/ ___/ / __ \ + / __ |/ __ `// ___// // ___/\__ \ / /_/ / + / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/ + /_____/ \__,_//____//_/ \___//____//_/ |_| + ______ __ __ __ __ + / ____/____ ____ ____/ / / / __ __ _____ / /__ / / + / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / / + / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/ + \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_) + """ + msg += ('\nVersion Information: ' + f'\n\tBasicSR: {__version__}' + f'\n\tPyTorch: {torch.__version__}' + f'\n\tTorchVision: {torchvision.__version__}') + return msg \ No newline at end of file diff --git a/blissful_tuner/codeformer/basicsr/utils/matlab_functions.py b/blissful_tuner/codeformer/basicsr/utils/matlab_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..c6ce1004a2c9f8521505c4b5889d3c24a909c70d --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/utils/matlab_functions.py @@ -0,0 +1,347 @@ +import math +import numpy as np +import torch + + +def cubic(x): + """cubic function used for calculate_weights_indices.""" + absx = torch.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return (1.5 * absx3 - 2.5 * absx2 + 1) * ( + (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) * + (absx <= 2)).type_as(absx)) + + +def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): + """Calculate weights and indices, used for imresize function. + + Args: + in_length (int): Input length. + out_length (int): Output length. + scale (float): Scale factor. + kernel_width (int): Kernel width. + antialisaing (bool): Whether to apply anti-aliasing when downsampling. + """ + + if (scale < 1) and antialiasing: + # Use a modified kernel (larger kernel width) to simultaneously + # interpolate and antialias + kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5 + scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? + left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + p = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. + indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand( + out_length, p) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. + distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices + + # apply cubic kernel + if (scale < 1) and antialiasing: + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, p) + + # If a column in weights is all zero, get rid of it. 
only consider the + # first and last column. + weights_zero_tmp = torch.sum((weights == 0), 0) + if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): + indices = indices.narrow(1, 1, p - 2) + weights = weights.narrow(1, 1, p - 2) + if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): + indices = indices.narrow(1, 0, p - 2) + weights = weights.narrow(1, 0, p - 2) + weights = weights.contiguous() + indices = indices.contiguous() + sym_len_s = -indices.min() + 1 + sym_len_e = indices.max() - in_length + indices = indices + sym_len_s - 1 + return weights, indices, int(sym_len_s), int(sym_len_e) + + +@torch.no_grad() +def imresize(img, scale, antialiasing=True): + """imresize function same as MATLAB. + + It now only supports bicubic. + The same scale applies for both height and width. + + Args: + img (Tensor | Numpy array): + Tensor: Input image with shape (c, h, w), [0, 1] range. + Numpy: Input image with shape (h, w, c), [0, 1] range. + scale (float): Scale factor. The same scale applies for both height + and width. + antialisaing (bool): Whether to apply anti-aliasing when downsampling. + Default: True. + + Returns: + Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round. + """ + if type(img).__module__ == np.__name__: # numpy type + numpy_type = True + img = torch.from_numpy(img.transpose(2, 0, 1)).float() + else: + numpy_type = False + + in_c, in_h, in_w = img.size() + out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale) + kernel_width = 4 + kernel = 'cubic' + + # get weights and indices + weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width, + antialiasing) + weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width, + antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w) + img_aug.narrow(1, sym_len_hs, in_h).copy_(img) + + sym_patch = img[:, :sym_len_hs, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv) + + sym_patch = img[:, -sym_len_he:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(in_c, out_h, in_w) + kernel_width = weights_h.size(1) + for i in range(out_h): + idx = int(indices_h[i][0]) + for j in range(in_c): + out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we) + out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1) + + sym_patch = out_1[:, :, :sym_len_ws] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, :, -sym_len_we:] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(in_c, out_h, out_w) + kernel_width = weights_w.size(1) + for i in range(out_w): + idx = int(indices_w[i][0]) + for j in range(in_c): + out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i]) + + if 
numpy_type: + out_2 = out_2.numpy().transpose(1, 2, 0) + return out_2 + + +def rgb2ycbcr(img, y_only=False): + """Convert a RGB image to YCbCr image. + + This function produces the same results as Matlab's `rgb2ycbcr` function. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0 + else: + out_img = np.matmul( + img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [16, 128, 128] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def bgr2ycbcr(img, y_only=False): + """Convert a BGR image to YCbCr image. + + The bgr version of rgb2ycbcr. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0 + else: + out_img = np.matmul( + img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def ycbcr2rgb(img): + """Convert a YCbCr image to RGB image. + + This function produces the same results as Matlab's ycbcr2rgb function. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + ndarray: The converted RGB image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) * 255 + out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] # noqa: E126 + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def ycbcr2bgr(img): + """Convert a YCbCr image to BGR image. 
+ + The bgr version of ycbcr2rgb. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + ndarray: The converted BGR image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) * 255 + out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0], + [0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921] # noqa: E126 + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def _convert_input_type_range(img): + """Convert the type and range of the input image. + + It converts the input image to np.float32 type and range of [0, 1]. + It is mainly used for pre-processing the input image in colorspace + convertion functions such as rgb2ycbcr and ycbcr2rgb. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + (ndarray): The converted image with type of np.float32 and range of + [0, 1]. + """ + img_type = img.dtype + img = img.astype(np.float32) + if img_type == np.float32: + pass + elif img_type == np.uint8: + img /= 255. + else: + raise TypeError('The img type should be np.float32 or np.uint8, ' f'but got {img_type}') + return img + + +def _convert_output_type_range(img, dst_type): + """Convert the type and range of the image according to dst_type. + + It converts the image to desired type and range. If `dst_type` is np.uint8, + images will be converted to np.uint8 type with range [0, 255]. If + `dst_type` is np.float32, it converts the image to np.float32 type with + range [0, 1]. + It is mainly used for post-processing images in colorspace convertion + functions such as rgb2ycbcr and ycbcr2rgb. + + Args: + img (ndarray): The image to be converted with np.float32 type and + range [0, 255]. + dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it + converts the image to np.uint8 type with range [0, 255]. If + dst_type is np.float32, it converts the image to np.float32 type + with range [0, 1]. + + Returns: + (ndarray): The converted image with desired type and range. + """ + if dst_type not in (np.uint8, np.float32): + raise TypeError('The dst_type should be np.float32 or np.uint8, ' f'but got {dst_type}') + if dst_type == np.uint8: + img = img.round() + else: + img /= 255. 
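+    # By this point `img` holds rounded values in [0, 255] for a uint8 target, or
+    # values rescaled to [0, 1] for a float32 target; the cast below fixes the dtype.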
+ return img.astype(dst_type) diff --git a/blissful_tuner/codeformer/basicsr/utils/misc.py b/blissful_tuner/codeformer/basicsr/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..6b4993167693b831fbde18e8af8d88a3e51b648c --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/utils/misc.py @@ -0,0 +1,156 @@ +import os +import re +import random +import time +import torch +import numpy as np +from os import path as osp + +from .dist_util import master_only +from .logger import get_root_logger + +IS_HIGH_VERSION = True + +def gpu_is_available(): + if IS_HIGH_VERSION: + if torch.backends.mps.is_available(): + return True + return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False + +def get_device(gpu_id=None): + if gpu_id is None: + gpu_str = '' + elif isinstance(gpu_id, int): + gpu_str = f':{gpu_id}' + else: + raise TypeError('Input should be int value.') + + if IS_HIGH_VERSION: + if torch.backends.mps.is_available(): + return torch.device('mps'+gpu_str) + return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu') + + +def set_random_seed(seed): + """Set random seeds.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def get_time_str(): + return time.strftime('%Y%m%d_%H%M%S', time.localtime()) + + +def mkdir_and_rename(path): + """mkdirs. If path exists, rename it with timestamp and create a new one. + + Args: + path (str): Folder path. + """ + if osp.exists(path): + new_name = path + '_archived_' + get_time_str() + print(f'Path already exists. Rename it to {new_name}', flush=True) + os.rename(path, new_name) + os.makedirs(path, exist_ok=True) + + +@master_only +def make_exp_dirs(opt): + """Make dirs for experiments.""" + path_opt = opt['path'].copy() + if opt['is_train']: + mkdir_and_rename(path_opt.pop('experiments_root')) + else: + mkdir_and_rename(path_opt.pop('results_root')) + for key, path in path_opt.items(): + if ('strict_load' not in key) and ('pretrain_network' not in key) and ('resume' not in key): + os.makedirs(path, exist_ok=True) + + +def scandir(dir_path, suffix=None, recursive=False, full_path=False): + """Scan a directory to find the interested files. + + Args: + dir_path (str): Path of the directory. + suffix (str | tuple(str), optional): File suffix that we are + interested in. Default: None. + recursive (bool, optional): If set to True, recursively scan the + directory. Default: False. + full_path (bool, optional): If set to True, include the dir_path. + Default: False. + + Returns: + A generator for all the interested files with relative pathes. + """ + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('"suffix" must be a string or tuple of strings') + + root = dir_path + + def _scandir(dir_path, suffix, recursive): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + if full_path: + return_path = entry.path + else: + return_path = osp.relpath(entry.path, root) + + if suffix is None: + yield return_path + elif return_path.endswith(suffix): + yield return_path + else: + if recursive: + yield from _scandir(entry.path, suffix=suffix, recursive=recursive) + else: + continue + + return _scandir(dir_path, suffix=suffix, recursive=recursive) + + +def check_resume(opt, resume_iter): + """Check resume states and pretrain_network paths. + + Args: + opt (dict): Options. 
+ resume_iter (int): Resume iteration. + """ + logger = get_root_logger() + if opt['path']['resume_state']: + # get all the networks + networks = [key for key in opt.keys() if key.startswith('network_')] + flag_pretrain = False + for network in networks: + if opt['path'].get(f'pretrain_{network}') is not None: + flag_pretrain = True + if flag_pretrain: + logger.warning('pretrain_network path will be ignored during resuming.') + # set pretrained model paths + for network in networks: + name = f'pretrain_{network}' + basename = network.replace('network_', '') + if opt['path'].get('ignore_resume_networks') is None or (basename + not in opt['path']['ignore_resume_networks']): + opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth') + logger.info(f"Set {name} to {opt['path'][name]}") + + +def sizeof_fmt(size, suffix='B'): + """Get human readable file size. + + Args: + size (int): File size. + suffix (str): Suffix. Default: 'B'. + + Return: + str: Formated file siz. + """ + for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: + if abs(size) < 1024.0: + return f'{size:3.1f} {unit}{suffix}' + size /= 1024.0 + return f'{size:3.1f} Y{suffix}' diff --git a/blissful_tuner/codeformer/basicsr/utils/options.py b/blissful_tuner/codeformer/basicsr/utils/options.py new file mode 100644 index 0000000000000000000000000000000000000000..856e6a08429fc477e877c62a659595855b43f528 --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/utils/options.py @@ -0,0 +1,108 @@ +import yaml +import time +from collections import OrderedDict +from os import path as osp +from codeformer.basicsr.utils.misc import get_time_str + +def ordered_yaml(): + """Support OrderedDict for yaml. + + Returns: + yaml Loader and Dumper. + """ + try: + from yaml import CDumper as Dumper + from yaml import CLoader as Loader + except ImportError: + from yaml import Dumper, Loader + + _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG + + def dict_representer(dumper, data): + return dumper.represent_dict(data.items()) + + def dict_constructor(loader, node): + return OrderedDict(loader.construct_pairs(node)) + + Dumper.add_representer(OrderedDict, dict_representer) + Loader.add_constructor(_mapping_tag, dict_constructor) + return Loader, Dumper + + +def parse(opt_path, root_path, is_train=True): + """Parse option file. + + Args: + opt_path (str): Option file path. + is_train (str): Indicate whether in training or not. Default: True. + + Returns: + (dict): Options. 
+ """ + with open(opt_path, mode='r') as f: + Loader, _ = ordered_yaml() + opt = yaml.load(f, Loader=Loader) + + opt['is_train'] = is_train + + # opt['name'] = f"{get_time_str()}_{opt['name']}" + if opt['path'].get('resume_state', None): # Shangchen added + resume_state_path = opt['path'].get('resume_state') + opt['name'] = resume_state_path.split("/")[-3] + else: + opt['name'] = f"{get_time_str()}_{opt['name']}" + + + # datasets + for phase, dataset in opt['datasets'].items(): + # for several datasets, e.g., test_1, test_2 + phase = phase.split('_')[0] + dataset['phase'] = phase + if 'scale' in opt: + dataset['scale'] = opt['scale'] + if dataset.get('dataroot_gt') is not None: + dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt']) + if dataset.get('dataroot_lq') is not None: + dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq']) + + # paths + for key, val in opt['path'].items(): + if (val is not None) and ('resume_state' in key or 'pretrain_network' in key): + opt['path'][key] = osp.expanduser(val) + + if is_train: + experiments_root = osp.join(root_path, 'experiments', opt['name']) + opt['path']['experiments_root'] = experiments_root + opt['path']['models'] = osp.join(experiments_root, 'models') + opt['path']['training_states'] = osp.join(experiments_root, 'training_states') + opt['path']['log'] = experiments_root + opt['path']['visualization'] = osp.join(experiments_root, 'visualization') + + else: # test + results_root = osp.join(root_path, 'results', opt['name']) + opt['path']['results_root'] = results_root + opt['path']['log'] = results_root + opt['path']['visualization'] = osp.join(results_root, 'visualization') + + return opt + + +def dict2str(opt, indent_level=1): + """dict to string for printing options. + + Args: + opt (dict): Option dict. + indent_level (int): Indent level. Default: 1. + + Return: + (str): Option string for printing. + """ + msg = '\n' + for k, v in opt.items(): + if isinstance(v, dict): + msg += ' ' * (indent_level * 2) + k + ':[' + msg += dict2str(v, indent_level + 1) + msg += ' ' * (indent_level * 2) + ']\n' + else: + msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n' + return msg diff --git a/blissful_tuner/codeformer/basicsr/utils/realesrgan_utils.py b/blissful_tuner/codeformer/basicsr/utils/realesrgan_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d3283f9466f894e96ccfa8ed3338cb0c7a1c9675 --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/utils/realesrgan_utils.py @@ -0,0 +1,302 @@ +import cv2 +import math +import numpy as np +import os +import queue +import threading +import torch +from torch.nn import functional as F +from codeformer.basicsr.utils.download_util import load_file_from_url +from codeformer.basicsr.utils.misc import get_device + +# ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + +class RealESRGANer(): + """A helper class for upsampling images with RealESRGAN. + + Args: + scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4. + model_path (str): The path to the pretrained model. It can be urls (will first download it automatically). + model (nn.Module): The defined network. Default: None. + tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop + input images into tiles, and then process each of them. Finally, they will be merged into one image. + 0 denotes for do not use tile. Default: 0. + tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10. 
+ pre_pad (int): Pad the input images to avoid border artifacts. Default: 10. + half (float): Whether to use half precision during inference. Default: False. + """ + + def __init__(self, + scale, + model_path, + model=None, + tile=0, + tile_pad=10, + pre_pad=10, + half=False, + device=None, + gpu_id=None): + self.scale = scale + self.tile_size = tile + self.tile_pad = tile_pad + self.pre_pad = pre_pad + self.mod_scale = None + self.half = half + + # initialize model + # if gpu_id: + # self.device = torch.device( + # f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device + # else: + # self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device + + self.device = get_device(gpu_id) if device is None else device + + # if the model_path starts with https, it will first download models to the folder: realesrgan/weights + if model_path.startswith('https://'): + model_path = load_file_from_url( + url=model_path, model_dir=os.path.join('weights/realesrgan'), progress=True, file_name=None) + loadnet = torch.load(model_path, map_location=torch.device('cpu')) + # prefer to use params_ema + if 'params_ema' in loadnet: + keyname = 'params_ema' + else: + keyname = 'params' + model.load_state_dict(loadnet[keyname], strict=True) + model.eval() + self.model = model.to(self.device) + if self.half: + self.model = self.model.half() + + def pre_process(self, img): + """Pre-process, such as pre-pad and mod pad, so that the images can be divisible + """ + img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float() + self.img = img.unsqueeze(0).to(self.device) + if self.half: + self.img = self.img.half() + + # pre_pad + if self.pre_pad != 0: + self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect') + # mod pad for divisible borders + if self.scale == 2: + self.mod_scale = 2 + elif self.scale == 1: + self.mod_scale = 4 + if self.mod_scale is not None: + self.mod_pad_h, self.mod_pad_w = 0, 0 + _, _, h, w = self.img.size() + if (h % self.mod_scale != 0): + self.mod_pad_h = (self.mod_scale - h % self.mod_scale) + if (w % self.mod_scale != 0): + self.mod_pad_w = (self.mod_scale - w % self.mod_scale) + self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect') + + def process(self): + # model inference + self.output = self.model(self.img) + + def tile_process(self): + """It will first crop input images to tiles, and then process each tile. + Finally, all the processed tiles are merged into one images. 
+ + Modified from: https://github.com/ata4/esrgan-launcher + """ + batch, channel, height, width = self.img.shape + output_height = height * self.scale + output_width = width * self.scale + output_shape = (batch, channel, output_height, output_width) + + # start with black image + self.output = self.img.new_zeros(output_shape) + tiles_x = math.ceil(width / self.tile_size) + tiles_y = math.ceil(height / self.tile_size) + + # loop over all tiles + for y in range(tiles_y): + for x in range(tiles_x): + # extract tile from input image + ofs_x = x * self.tile_size + ofs_y = y * self.tile_size + # input tile area on total image + input_start_x = ofs_x + input_end_x = min(ofs_x + self.tile_size, width) + input_start_y = ofs_y + input_end_y = min(ofs_y + self.tile_size, height) + + # input tile area on total image with padding + input_start_x_pad = max(input_start_x - self.tile_pad, 0) + input_end_x_pad = min(input_end_x + self.tile_pad, width) + input_start_y_pad = max(input_start_y - self.tile_pad, 0) + input_end_y_pad = min(input_end_y + self.tile_pad, height) + + # input tile dimensions + input_tile_width = input_end_x - input_start_x + input_tile_height = input_end_y - input_start_y + tile_idx = y * tiles_x + x + 1 + input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad] + + # upscale tile + try: + with torch.no_grad(): + output_tile = self.model(input_tile) + except RuntimeError as error: + print('Error', error) + # print(f'\tTile {tile_idx}/{tiles_x * tiles_y}') + + # output tile area on total image + output_start_x = input_start_x * self.scale + output_end_x = input_end_x * self.scale + output_start_y = input_start_y * self.scale + output_end_y = input_end_y * self.scale + + # output tile area without padding + output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale + output_end_x_tile = output_start_x_tile + input_tile_width * self.scale + output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale + output_end_y_tile = output_start_y_tile + input_tile_height * self.scale + + # put tile into output image + self.output[:, :, output_start_y:output_end_y, + output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile, + output_start_x_tile:output_end_x_tile] + + def post_process(self): + # remove extra pad + if self.mod_scale is not None: + _, _, h, w = self.output.size() + self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale] + # remove prepad + if self.pre_pad != 0: + _, _, h, w = self.output.size() + self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale] + return self.output + + @torch.no_grad() + def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'): + h_input, w_input = img.shape[0:2] + # img: numpy + img = img.astype(np.float32) + if np.max(img) > 256: # 16-bit image + max_range = 65535 + print('\tInput is a 16-bit image') + else: + max_range = 255 + img = img / max_range + if len(img.shape) == 2: # gray image + img_mode = 'L' + img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) + elif img.shape[2] == 4: # RGBA image with alpha channel + img_mode = 'RGBA' + alpha = img[:, :, 3] + img = img[:, :, 0:3] + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + if alpha_upsampler == 'realesrgan': + alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB) + else: + img_mode = 'RGB' + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + # ------------------- process image (without the alpha channel) ------------------- # + try: 
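+            # Flow of this block: pre_process() moves the float image onto the device and
+            # adds reflection / mod-scale padding, the network then runs on the whole frame
+            # or on padded tiles (tile_size > 0) to bound GPU memory, and post_process()
+            # strips the padding before the result is clamped to [0, 1] and returned to HWC.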
+ with torch.no_grad(): + self.pre_process(img) + if self.tile_size > 0: + self.tile_process() + else: + self.process() + output_img_t = self.post_process() + output_img = output_img_t.data.squeeze().float().cpu().clamp_(0, 1).numpy() + output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0)) + if img_mode == 'L': + output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY) + del output_img_t + torch.cuda.empty_cache() + except RuntimeError as error: + print(f"Failed inference for RealESRGAN: {error}") + + # ------------------- process the alpha channel if necessary ------------------- # + if img_mode == 'RGBA': + if alpha_upsampler == 'realesrgan': + self.pre_process(alpha) + if self.tile_size > 0: + self.tile_process() + else: + self.process() + output_alpha = self.post_process() + output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy() + output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0)) + output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY) + else: # use the cv2 resize for alpha channel + h, w = alpha.shape[0:2] + output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR) + + # merge the alpha channel + output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA) + output_img[:, :, 3] = output_alpha + + # ------------------------------ return ------------------------------ # + if max_range == 65535: # 16-bit image + output = (output_img * 65535.0).round().astype(np.uint16) + else: + output = (output_img * 255.0).round().astype(np.uint8) + + if outscale is not None and outscale != float(self.scale): + output = cv2.resize( + output, ( + int(w_input * outscale), + int(h_input * outscale), + ), interpolation=cv2.INTER_LANCZOS4) + + return output, img_mode + + +class PrefetchReader(threading.Thread): + """Prefetch images. + + Args: + img_list (list[str]): A image list of image paths to be read. + num_prefetch_queue (int): Number of prefetch queue. + """ + + def __init__(self, img_list, num_prefetch_queue): + super().__init__() + self.que = queue.Queue(num_prefetch_queue) + self.img_list = img_list + + def run(self): + for img_path in self.img_list: + img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) + self.que.put(img) + + self.que.put(None) + + def __next__(self): + next_item = self.que.get() + if next_item is None: + raise StopIteration + return next_item + + def __iter__(self): + return self + + +class IOConsumer(threading.Thread): + + def __init__(self, opt, que, qid): + super().__init__() + self._queue = que + self.qid = qid + self.opt = opt + + def run(self): + while True: + msg = self._queue.get() + if isinstance(msg, str) and msg == 'quit': + break + + output = msg['output'] + save_path = msg['save_path'] + cv2.imwrite(save_path, output) + print(f'IO worker {self.qid} is done.') \ No newline at end of file diff --git a/blissful_tuner/codeformer/basicsr/utils/registry.py b/blissful_tuner/codeformer/basicsr/utils/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..655753b3b9cbd0cfe73fe93a77cf1fcc3db6d827 --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/utils/registry.py @@ -0,0 +1,82 @@ +# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501 + + +class Registry(): + """ + The registry that provides name -> object mapping, to support third-party + users' custom modules. + + To create a registry (e.g. a backbone registry): + + .. 
code-block:: python + + BACKBONE_REGISTRY = Registry('BACKBONE') + + To register an object: + + .. code-block:: python + + @BACKBONE_REGISTRY.register() + class MyBackbone(): + ... + + Or: + + .. code-block:: python + + BACKBONE_REGISTRY.register(MyBackbone) + """ + + def __init__(self, name): + """ + Args: + name (str): the name of this registry + """ + self._name = name + self._obj_map = {} + + def _do_register(self, name, obj): + assert (name not in self._obj_map), (f"An object named '{name}' was already registered " + f"in '{self._name}' registry!") + self._obj_map[name] = obj + + def register(self, obj=None): + """ + Register the given object under the the name `obj.__name__`. + Can be used as either a decorator or not. + See docstring of this class for usage. + """ + if obj is None: + # used as a decorator + def deco(func_or_class): + name = func_or_class.__name__ + self._do_register(name, func_or_class) + return func_or_class + + return deco + + # used as a function call + name = obj.__name__ + self._do_register(name, obj) + + def get(self, name): + ret = self._obj_map.get(name) + if ret is None: + raise KeyError(f"No object named '{name}' found in '{self._name}' registry!") + return ret + + def __contains__(self, name): + return name in self._obj_map + + def __iter__(self): + return iter(self._obj_map.items()) + + def keys(self): + return self._obj_map.keys() + + +DATASET_REGISTRY = Registry('dataset') +ARCH_REGISTRY = Registry('arch') +MODEL_REGISTRY = Registry('model') +LOSS_REGISTRY = Registry('loss') +METRIC_REGISTRY = Registry('metric') diff --git a/blissful_tuner/codeformer/basicsr/utils/video_util.py b/blissful_tuner/codeformer/basicsr/utils/video_util.py new file mode 100644 index 0000000000000000000000000000000000000000..20a2ff14c4016b4ec543051471fc930ad71d83f9 --- /dev/null +++ b/blissful_tuner/codeformer/basicsr/utils/video_util.py @@ -0,0 +1,125 @@ +''' +The code is modified from the Real-ESRGAN: +https://github.com/xinntao/Real-ESRGAN/blob/master/inference_realesrgan_video.py + +''' +import cv2 +import sys +import numpy as np + +try: + import ffmpeg +except ImportError: + import pip + pip.main(['install', '--user', 'ffmpeg-python']) + import ffmpeg + +def get_video_meta_info(video_path): + ret = {} + probe = ffmpeg.probe(video_path) + video_streams = [stream for stream in probe['streams'] if stream['codec_type'] == 'video'] + has_audio = any(stream['codec_type'] == 'audio' for stream in probe['streams']) + ret['width'] = video_streams[0]['width'] + ret['height'] = video_streams[0]['height'] + ret['fps'] = eval(video_streams[0]['avg_frame_rate']) + ret['audio'] = ffmpeg.input(video_path).audio if has_audio else None + ret['nb_frames'] = int(video_streams[0]['nb_frames']) + return ret + +class VideoReader: + def __init__(self, video_path): + self.paths = [] # for image&folder type + self.audio = None + try: + self.stream_reader = ( + ffmpeg.input(video_path).output('pipe:', format='rawvideo', pix_fmt='bgr24', + loglevel='error').run_async( + pipe_stdin=True, pipe_stdout=True, cmd='ffmpeg')) + except FileNotFoundError: + print('Please install ffmpeg (not ffmpeg-python) by running\n', + '\t$ conda install -c conda-forge ffmpeg') + sys.exit(0) + + meta = get_video_meta_info(video_path) + self.width = meta['width'] + self.height = meta['height'] + self.input_fps = meta['fps'] + self.audio = meta['audio'] + self.nb_frames = meta['nb_frames'] + + self.idx = 0 + + def get_resolution(self): + return self.height, self.width + + def get_fps(self): + if self.input_fps is not None: 
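+            # Prefer the frame rate probed from the container; the fixed 24 fps below is
+            # only a fallback when probing did not produce a usable value.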
+ return self.input_fps + return 24 + + def get_audio(self): + return self.audio + + def __len__(self): + return self.nb_frames + + def get_frame_from_stream(self): + img_bytes = self.stream_reader.stdout.read(self.width * self.height * 3) # 3 bytes for one pixel + if not img_bytes: + return None + img = np.frombuffer(img_bytes, np.uint8).reshape([self.height, self.width, 3]) + return img + + def get_frame_from_list(self): + if self.idx >= self.nb_frames: + return None + img = cv2.imread(self.paths[self.idx]) + self.idx += 1 + return img + + def get_frame(self): + return self.get_frame_from_stream() + + + def close(self): + self.stream_reader.stdin.close() + self.stream_reader.wait() + + +class VideoWriter: + def __init__(self, video_save_path, height, width, fps, audio): + if height > 2160: + print('You are generating video that is larger than 4K, which will be very slow due to IO speed.', + 'We highly recommend to decrease the outscale(aka, -s).') + if audio is not None: + self.stream_writer = ( + ffmpeg.input('pipe:', format='rawvideo', pix_fmt='bgr24', s=f'{width}x{height}', + framerate=fps).output( + audio, + video_save_path, + pix_fmt='yuv420p', + vcodec='libx264', + loglevel='error', + acodec='copy').overwrite_output().run_async( + pipe_stdin=True, pipe_stdout=True, cmd='ffmpeg')) + else: + self.stream_writer = ( + ffmpeg.input('pipe:', format='rawvideo', pix_fmt='bgr24', s=f'{width}x{height}', + framerate=fps).output( + video_save_path, pix_fmt='yuv420p', vcodec='libx264', + loglevel='error').overwrite_output().run_async( + pipe_stdin=True, pipe_stdout=True, cmd='ffmpeg')) + + def write_frame(self, frame): + try: + frame = frame.astype(np.uint8).tobytes() + self.stream_writer.stdin.write(frame) + except BrokenPipeError: + print('Please re-install ffmpeg and libx264 by running\n', + '\t$ conda install -c conda-forge ffmpeg\n', + '\t$ conda install -c conda-forge x264') + sys.exit(0) + + def close(self): + self.stream_writer.stdin.close() + self.stream_writer.wait() \ No newline at end of file diff --git a/blissful_tuner/esrgan/model.py b/blissful_tuner/esrgan/model.py new file mode 100644 index 0000000000000000000000000000000000000000..d12e12d319d7e68f9f609b6811026d308c645066 --- /dev/null +++ b/blissful_tuner/esrgan/model.py @@ -0,0 +1,414 @@ +# Copyright 2021 Dakewe Biotech Corporation. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +import os +from typing import Any, cast, Dict, List, Union + +import torch +from torch import nn, Tensor +from torch.nn import functional as F_torch +from torchvision import models, transforms +from torchvision.models.feature_extraction import create_feature_extractor + +__all__ = [ + "DiscriminatorForVGG", "RRDBNet", "ContentLoss", + "discriminator_for_vgg", "rrdbnet_x2", "rrdbnet_x4", "rrdbnet_x8" +] + +feature_extractor_net_cfgs: Dict[str, List[Union[str, int]]] = { + "vgg11": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], + "vgg13": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], + "vgg16": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"], + "vgg19": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"], +} + + +def _make_layers(net_cfg_name: str, batch_norm: bool = False) -> nn.Sequential: + net_cfg = feature_extractor_net_cfgs[net_cfg_name] + layers: nn.Sequential[nn.Module] = nn.Sequential() + in_channels = 3 + for v in net_cfg: + if v == "M": + layers.append(nn.MaxPool2d((2, 2), (2, 2))) + else: + v = cast(int, v) + conv2d = nn.Conv2d(in_channels, v, (3, 3), (1, 1), (1, 1)) + if batch_norm: + layers.append(conv2d) + layers.append(nn.BatchNorm2d(v)) + layers.append(nn.ReLU(True)) + else: + layers.append(conv2d) + layers.append(nn.ReLU(True)) + in_channels = v + + return layers + + +class _FeatureExtractor(nn.Module): + def __init__( + self, + net_cfg_name: str = "vgg19", + batch_norm: bool = False, + num_classes: int = 1000) -> None: + super(_FeatureExtractor, self).__init__() + self.features = _make_layers(net_cfg_name, batch_norm) + + self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) + + self.classifier = nn.Sequential( + nn.Linear(512 * 7 * 7, 4096), + nn.ReLU(True), + nn.Dropout(0.5), + nn.Linear(4096, 4096), + nn.ReLU(True), + nn.Dropout(0.5), + nn.Linear(4096, num_classes), + ) + + # Initialize neural network weights + for module in self.modules(): + if isinstance(module, nn.Conv2d): + nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + if module.bias is not None: + nn.init.constant_(module.bias, 0) + elif isinstance(module, nn.BatchNorm2d): + nn.init.constant_(module.weight, 1) + nn.init.constant_(module.bias, 0) + elif isinstance(module, nn.Linear): + nn.init.normal_(module.weight, 0, 0.01) + nn.init.constant_(module.bias, 0) + + def forward(self, x: Tensor) -> Tensor: + return self._forward_impl(x) + + # Support torch.script function + def _forward_impl(self, x: Tensor) -> Tensor: + x = self.features(x) + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.classifier(x) + + return x + + +class RRDBNet(nn.Module): + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + channels: int = 64, + growth_channels: int = 32, + num_rrdb: int = 23, + upscale: int = 4, + ) -> None: + super(RRDBNet, self).__init__() + self.upscale = upscale + + # The first layer of convolutional layer. + self.conv1 = nn.Conv2d(in_channels, channels, (3, 3), (1, 1), (1, 1)) + + # Feature extraction backbone network. + trunk = [] + for _ in range(num_rrdb): + trunk.append(_ResidualResidualDenseBlock(channels, growth_channels)) + self.trunk = nn.Sequential(*trunk) + + # After the feature extraction network, reconnect a layer of convolutional blocks. 
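+        # Overall layout: conv1 -> RRDB trunk -> conv2, with a global residual back to
+        # conv1's output, followed by nearest-neighbour upsampling (one conv block per 2x
+        # stage) and the conv3 / conv4 reconstruction head; see _forward_impl below.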
+ self.conv2 = nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1)) + + # Upsampling convolutional layer. + if upscale == 2: + self.upsampling1 = nn.Sequential( + nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1)), + nn.LeakyReLU(0.2, True) + ) + if upscale == 4: + self.upsampling1 = nn.Sequential( + nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1)), + nn.LeakyReLU(0.2, True) + ) + self.upsampling2 = nn.Sequential( + nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1)), + nn.LeakyReLU(0.2, True) + ) + if upscale == 8: + self.upsampling1 = nn.Sequential( + nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1)), + nn.LeakyReLU(0.2, True) + ) + self.upsampling2 = nn.Sequential( + nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1)), + nn.LeakyReLU(0.2, True) + ) + self.upsampling3 = nn.Sequential( + nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1)), + nn.LeakyReLU(0.2, True) + ) + + # Reconnect a layer of convolution block after upsampling. + self.conv3 = nn.Sequential( + nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1)), + nn.LeakyReLU(0.2, True) + ) + + # Output layer. + self.conv4 = nn.Conv2d(channels, out_channels, (3, 3), (1, 1), (1, 1)) + + # Initialize all layer + for module in self.modules(): + if isinstance(module, nn.Conv2d): + nn.init.kaiming_normal_(module.weight) + module.weight.data *= 0.2 + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + # The model should be defined in the Torch.script method. + def _forward_impl(self, x: Tensor) -> Tensor: + conv1 = self.conv1(x) + x = self.trunk(conv1) + x = self.conv2(x) + x = torch.add(x, conv1) + + if self.upscale == 2: + x = self.upsampling1(F_torch.interpolate(x, scale_factor=2, mode="nearest")) + if self.upscale == 4: + x = self.upsampling1(F_torch.interpolate(x, scale_factor=2, mode="nearest")) + x = self.upsampling2(F_torch.interpolate(x, scale_factor=2, mode="nearest")) + if self.upscale == 8: + x = self.upsampling1(F_torch.interpolate(x, scale_factor=2, mode="nearest")) + x = self.upsampling2(F_torch.interpolate(x, scale_factor=2, mode="nearest")) + x = self.upsampling3(F_torch.interpolate(x, scale_factor=2, mode="nearest")) + + x = self.conv3(x) + x = self.conv4(x) + + return x + + def forward(self, x: Tensor) -> Tensor: + return self._forward_impl(x) + + +class _ResidualDenseBlock(nn.Module): + """Achieves densely connected convolutional layers. + `Densely Connected Convolutional Networks ` paper. + + Args: + channels (int): The number of channels in the input image. + growth_channels (int): The number of channels that increase in each layer of convolution. 
+ """ + + def __init__(self, channels: int, growth_channels: int) -> None: + super(_ResidualDenseBlock, self).__init__() + self.conv1 = nn.Conv2d(channels + growth_channels * 0, growth_channels, (3, 3), (1, 1), (1, 1)) + self.conv2 = nn.Conv2d(channels + growth_channels * 1, growth_channels, (3, 3), (1, 1), (1, 1)) + self.conv3 = nn.Conv2d(channels + growth_channels * 2, growth_channels, (3, 3), (1, 1), (1, 1)) + self.conv4 = nn.Conv2d(channels + growth_channels * 3, growth_channels, (3, 3), (1, 1), (1, 1)) + self.conv5 = nn.Conv2d(channels + growth_channels * 4, channels, (3, 3), (1, 1), (1, 1)) + + self.leaky_relu = nn.LeakyReLU(0.2, True) + self.identity = nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out1 = self.leaky_relu(self.conv1(x)) + out2 = self.leaky_relu(self.conv2(torch.cat([x, out1], 1))) + out3 = self.leaky_relu(self.conv3(torch.cat([x, out1, out2], 1))) + out4 = self.leaky_relu(self.conv4(torch.cat([x, out1, out2, out3], 1))) + out5 = self.identity(self.conv5(torch.cat([x, out1, out2, out3, out4], 1))) + + x = torch.mul(out5, 0.2) + x = torch.add(x, identity) + + return x + + +class _ResidualResidualDenseBlock(nn.Module): + """Multi-layer residual dense convolution block. + + Args: + channels (int): The number of channels in the input image. + growth_channels (int): The number of channels that increase in each layer of convolution. + """ + + def __init__(self, channels: int, growth_channels: int) -> None: + super(_ResidualResidualDenseBlock, self).__init__() + self.rdb1 = _ResidualDenseBlock(channels, growth_channels) + self.rdb2 = _ResidualDenseBlock(channels, growth_channels) + self.rdb3 = _ResidualDenseBlock(channels, growth_channels) + + def forward(self, x: Tensor) -> Tensor: + identity = x + + x = self.rdb1(x) + x = self.rdb2(x) + x = self.rdb3(x) + + x = torch.mul(x, 0.2) + x = torch.add(x, identity) + + return x + + +class DiscriminatorForVGG(nn.Module): + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + channels: int = 64, + ) -> None: + super(DiscriminatorForVGG, self).__init__() + self.features = nn.Sequential( + # input size. (3) x 128 x 128 + nn.Conv2d(in_channels, channels, (3, 3), (1, 1), (1, 1), bias=True), + nn.LeakyReLU(0.2, True), + # state size. (64) x 64 x 64 + nn.Conv2d(channels, channels, (4, 4), (2, 2), (1, 1), bias=False), + nn.BatchNorm2d(channels), + nn.LeakyReLU(0.2, True), + nn.Conv2d(channels, int(2 * channels), (3, 3), (1, 1), (1, 1), bias=False), + nn.BatchNorm2d(int(2 * channels)), + nn.LeakyReLU(0.2, True), + # state size. (128) x 32 x 32 + nn.Conv2d(int(2 * channels), int(2 * channels), (4, 4), (2, 2), (1, 1), bias=False), + nn.BatchNorm2d(int(2 * channels)), + nn.LeakyReLU(0.2, True), + nn.Conv2d(int(2 * channels), int(4 * channels), (3, 3), (1, 1), (1, 1), bias=False), + nn.BatchNorm2d(int(4 * channels)), + nn.LeakyReLU(0.2, True), + # state size. (256) x 16 x 16 + nn.Conv2d(int(4 * channels), int(4 * channels), (4, 4), (2, 2), (1, 1), bias=False), + nn.BatchNorm2d(int(4 * channels)), + nn.LeakyReLU(0.2, True), + nn.Conv2d(int(4 * channels), int(8 * channels), (3, 3), (1, 1), (1, 1), bias=False), + nn.BatchNorm2d(int(8 * channels)), + nn.LeakyReLU(0.2, True), + # state size. 
(512) x 8 x 8 + nn.Conv2d(int(8 * channels), int(8 * channels), (4, 4), (2, 2), (1, 1), bias=False), + nn.BatchNorm2d(int(8 * channels)), + nn.LeakyReLU(0.2, True), + nn.Conv2d(int(8 * channels), int(8 * channels), (3, 3), (1, 1), (1, 1), bias=False), + nn.BatchNorm2d(int(8 * channels)), + nn.LeakyReLU(0.2, True), + # state size. (512) x 4 x 4 + nn.Conv2d(int(8 * channels), int(8 * channels), (4, 4), (2, 2), (1, 1), bias=False), + nn.BatchNorm2d(int(8 * channels)), + nn.LeakyReLU(0.2, True) + ) + + self.classifier = nn.Sequential( + nn.Linear(int(8 * channels) * 4 * 4, 100), + nn.LeakyReLU(0.2, True), + nn.Linear(100, out_channels) + ) + + def forward(self, x: Tensor) -> Tensor: + out = self.features(x) + out = torch.flatten(out, 1) + out = self.classifier(out) + + return out + + +class ContentLoss(nn.Module): + """Constructs a content loss function based on the VGG19 network. + Using high-level feature mapping layers from the latter layers will focus more on the texture content of the image. + + Paper reference list: + -`Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network ` paper. + -`ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks ` paper. + -`Perceptual Extreme Super Resolution Network with Receptive Field Block ` paper. + + """ + + def __init__( + self, + net_cfg_name: str = "vgg19", + batch_norm: bool = False, + num_classes: int = 1000, + model_weights_path: str = "", + feature_nodes: list = None, + feature_normalize_mean: list = None, + feature_normalize_std: list = None, + ) -> None: + super(ContentLoss, self).__init__() + # Define the feature extraction model + model = _FeatureExtractor(net_cfg_name, batch_norm, num_classes) + # Load the pre-trained model + if model_weights_path == "": + model = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1) + elif model_weights_path is not None and os.path.exists(model_weights_path): + checkpoint = torch.load(model_weights_path, map_location=lambda storage, loc: storage) + if "state_dict" in checkpoint.keys(): + model.load_state_dict(checkpoint["state_dict"]) + else: + model.load_state_dict(checkpoint) + else: + raise FileNotFoundError("Model weight file not found") + # Extract the output of the feature extraction layer + self.feature_extractor = create_feature_extractor(model, feature_nodes) + # Select the specified layers as the feature extraction layer + self.feature_extractor_nodes = feature_nodes + # input normalization + self.normalize = transforms.Normalize(feature_normalize_mean, feature_normalize_std) + # Freeze model parameters without derivatives + for model_parameters in self.feature_extractor.parameters(): + model_parameters.requires_grad = False + self.feature_extractor.eval() + + def forward(self, sr_tensor: Tensor, gt_tensor: Tensor) -> [Tensor]: + assert sr_tensor.size() == gt_tensor.size(), "Two tensor must have the same size" + device = sr_tensor.device + + losses = [] + # input normalization + sr_tensor = self.normalize(sr_tensor) + gt_tensor = self.normalize(gt_tensor) + + # Get the output of the feature extraction layer + sr_feature = self.feature_extractor(sr_tensor) + gt_feature = self.feature_extractor(gt_tensor) + + # Compute feature loss + for i in range(len(self.feature_extractor_nodes)): + losses.append(F_torch.l1_loss(sr_feature[self.feature_extractor_nodes[i]], + gt_feature[self.feature_extractor_nodes[i]])) + + losses = torch.Tensor([losses]).to(device) + + return losses + + +def rrdbnet_x2(**kwargs: Any) -> RRDBNet: + model = RRDBNet(upscale=2, **kwargs) 
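+    # Usage sketch (illustrative; the tensor below is an assumption, not part of this file):
+    # the x2/x4/x8 factories only differ in the fixed `upscale` they pass through, e.g.
+    #   net = rrdbnet_x4()                    # 3-channel in/out, 23 RRDB blocks by default
+    #   sr  = net(torch.rand(1, 3, 64, 64))   # -> shape (1, 3, 256, 256)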
+ + return model + + +def rrdbnet_x4(**kwargs: Any) -> RRDBNet: + model = RRDBNet(upscale=4, **kwargs) + + return model + + +def rrdbnet_x8(**kwargs: Any) -> RRDBNet: + model = RRDBNet(upscale=8, **kwargs) + + return model + + +def discriminator_for_vgg(**kwargs) -> DiscriminatorForVGG: + model = DiscriminatorForVGG(**kwargs) + + return model diff --git a/blissful_tuner/facefix.py b/blissful_tuner/facefix.py new file mode 100644 index 0000000000000000000000000000000000000000..a40653d027cd827a91d0cc5effe807269162c5c1 --- /dev/null +++ b/blissful_tuner/facefix.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Face restoration for Blissful Tuner Extension + +License: Apache 2.0 +Created on Wed Apr 23 10:19:19 2025 +@author: blyss +""" +from rich.traceback import install as install_rich_tracebacks +from tqdm import tqdm +from gfpgan import GFPGANer +import torch +from torchvision.transforms.functional import normalize +from facexlib.utils.face_restoration_helper import FaceRestoreHelper +from codeformer.basicsr.utils.registry import ARCH_REGISTRY +from basicsr.utils import img2tensor, tensor2img +from video_processing_common import BlissfulVideoProcessor, setup_parser_video_common, set_seed +from utils import BlissfulLogger +logger = BlissfulLogger(__name__, "#8e00ed") +install_rich_tracebacks() + + +def main(): + parser = setup_parser_video_common(description="Restore faces with GFPGAN or CODEFORMER") + parser.add_argument("--only_center", action="store_true", help="Only process center face") + parser.add_argument("--weight", type=float, default=0.5, help="Strength of GFPGAN or CodeFormer power") + parser.add_argument('-s', '--upscale', type=float, default=1, help='The final upsampling scale of the image. Default: 1') + parser.add_argument('--detection_model', type=str, default='retinaface_resnet50', help='Face detector. 
Default: retinaface_resnet50') + parser.add_argument("--mode", type=str, default="gfpgan", help="Mode - either gfpgan or codeformer") + device = "cuda" if torch.cuda.is_available() else "cpu" + args = parser.parse_args() + logger.info("Loading input...") + VideoProcessor = BlissfulVideoProcessor(device, torch.float32) + VideoProcessor.prepare_files_and_path(args.input, args.output, args.mode.upper()) + frames, fps, _, _ = VideoProcessor.load_frames() + set_seed(args.seed) + if args.mode.lower() == "gfpgan": + restorer = GFPGANer( + model_path=args.model, + upscale=args.upscale, + arch='clean', + channel_multiplier=2, + bg_upsampler=None) + # ------------------------ restore ------------------------ + for frame in tqdm(frames): + # restore faces and background if necessary + _, _, restored_frame = restorer.enhance( + frame, + has_aligned=False, + only_center_face=args.only_center, + paste_back=True, + weight=args.weight) + VideoProcessor.write_np_or_tensor_to_png(restored_frame) + del restored_frame + elif args.mode.lower() == "codeformer": + net = ARCH_REGISTRY.get('CodeFormer')( + dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, + connect_list=['32', '64', '128', '256']).to(device) + checkpoint = torch.load(args.model)['params_ema'] + net.load_state_dict(checkpoint) + net.eval() + + face_helper = FaceRestoreHelper( + args.upscale, + face_size=512, + crop_ratio=(1, 1), + det_model=args.detection_model, + save_ext='png', + use_parse=True, + device=device) + + for frame in tqdm(frames): + # clean all the intermediate results to process the next image + face_helper.clean_all() + face_helper.read_image(frame) + # get face landmarks for each face + _ = face_helper.get_face_landmarks_5( + only_center_face=args.only_center, resize=640, eye_dist_threshold=5) + # align and warp each face + face_helper.align_warp_face() + # face restoration for each cropped face + for cropped_face in face_helper.cropped_faces: + # prepare data + cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) + normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) + cropped_face_t = cropped_face_t.unsqueeze(0).to(device) + + try: + with torch.no_grad(): + output = net(cropped_face_t, w=args.weight, adain=True)[0] + restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) + del output + torch.cuda.empty_cache() + except Exception as error: + logger.info(f'\tFailed inference for CodeFormer: {error}') + restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) + + restored_face = restored_face.astype('uint8') + face_helper.add_restored_face(restored_face) + + face_helper.get_inverse_affine(None) + restored_img = face_helper.paste_faces_to_input_image() + VideoProcessor.write_np_or_tensor_to_png(restored_img) + del restored_img + + VideoProcessor.write_buffered_frames_to_output(fps, args.keep_pngs) + + +if __name__ == '__main__': + main() diff --git a/blissful_tuner/fp8_optimization.py b/blissful_tuner/fp8_optimization.py new file mode 100644 index 0000000000000000000000000000000000000000..76128da0c5991ce1bb31e300bf303ae1cb65387f --- /dev/null +++ b/blissful_tuner/fp8_optimization.py @@ -0,0 +1,339 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from blissful_tuner.utils import BlissfulLogger + +from tqdm import tqdm + +logger = BlissfulLogger(__name__, "#8e00ed") + + + +# based on ComfyUI's and MinusZoneAI's fp8_linear optimization +def fp8_linear_forward(cls, original_dtype, input): + weight_dtype = cls.weight.dtype + if 
weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: + if len(input.shape) == 3: + target_dtype = torch.float8_e5m2 if weight_dtype == torch.float8_e4m3fn else torch.float8_e4m3fn + inn = input.reshape(-1, input.shape[2]).to(target_dtype) + w = cls.weight.t() + + scale = torch.ones((1), device=input.device, dtype=torch.float32) + bias = cls.bias.to(original_dtype) if cls.bias is not None else None + + if bias is not None: + o = torch._scaled_mm(inn, w, out_dtype=original_dtype, bias=bias, scale_a=scale, scale_b=scale) + else: + o = torch._scaled_mm(inn, w, out_dtype=original_dtype, scale_a=scale, scale_b=scale) + + if isinstance(o, tuple): + o = o[0] + + return o.reshape((-1, input.shape[1], cls.weight.shape[0])) + else: + return cls.original_forward(input.to(original_dtype)) + else: + return cls.original_forward(input) + + +def convert_fp8_linear(module, original_dtype, params_to_keep={}): + setattr(module, "fp8_matmul_enabled", True) + + for name, module in module.named_modules(): + if not any(keyword in name for keyword in params_to_keep): + if isinstance(module, nn.Linear): + original_forward = module.forward + setattr(module, "original_forward", original_forward) + setattr(module, "forward", lambda input, m=module: fp8_linear_forward(m, original_dtype, input)) + + +# Below has been ported from https://github.com/kohya-ss/musubi-tuner/ +def calculate_fp8_maxval(exp_bits=4, mantissa_bits=3, sign_bits=1): + """ + Calculate the maximum representable value in FP8 format. + Default is E4M3 format (4-bit exponent, 3-bit mantissa, 1-bit sign). + + Args: + exp_bits (int): Number of exponent bits + mantissa_bits (int): Number of mantissa bits + sign_bits (int): Number of sign bits (0 or 1) + + Returns: + float: Maximum value representable in FP8 format + """ + assert exp_bits + mantissa_bits + sign_bits == 8, "Total bits must be 8" + if exp_bits == 5: + return 57344 + # Calculate exponent bias + bias = 2 ** (exp_bits - 1) - 1 + + # Calculate maximum mantissa value + mantissa_max = 1.0 + for i in range(mantissa_bits - 1): + mantissa_max += 2 ** -(i + 1) + + # Calculate maximum value + max_value = mantissa_max * (2 ** (2**exp_bits - 1 - bias)) + + return max_value + + +def quantize_tensor_to_fp8(tensor, scale, exp_bits=4, mantissa_bits=3, sign_bits=1, max_value=None, min_value=None): + """ + Quantize a tensor to FP8 format. 
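For intuition on the arithmetic in `calculate_fp8_maxval` and the quantize/dequantize pair defined in this file, here is a minimal round-trip sketch. It assumes this module is importable as `blissful_tuner.fp8_optimization` and a PyTorch build that provides `torch.float8_e4m3fn`; the tensor and scale choices mirror what `optimize_state_dict_with_fp8` later in this file does per weight.

```python
import torch

from blissful_tuner.fp8_optimization import calculate_fp8_maxval, quantize_tensor_to_fp8

max_e4m3 = calculate_fp8_maxval(4, 3)   # 1.75 * 2**8 = 448.0 for E4M3
w = torch.randn(64, 64)                 # toy weight tensor
scale = w.abs().max() / max_e4m3        # per-tensor scale, as used for each weight

# quantized values live in the scaled domain; the second return value is just the scale
q, _ = quantize_tensor_to_fp8(w, scale, 4, 3, 1, max_e4m3, -max_e4m3)
w_fp8 = q.to(torch.float8_e4m3fn)       # storage dtype written into the state dict

# dequantization is a single multiply, which the patched forward undoes at runtime
w_rec = w_fp8.to(w.dtype) * scale
rel_err = ((w - w_rec).abs().mean() / (w.abs().mean() + 1e-8)).item()
print(f"mean relative error: {rel_err * 100:.2f}%")
```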
+ + Args: + tensor (torch.Tensor): Tensor to quantize + scale (float or torch.Tensor): Scale factor + exp_bits (int): Number of exponent bits + mantissa_bits (int): Number of mantissa bits + sign_bits (int): Number of sign bits + + Returns: + tuple: (quantized_tensor, scale_factor) + """ + # Create scaled tensor + scaled_tensor = tensor / scale + + # Calculate FP8 parameters + bias = 2 ** (exp_bits - 1) - 1 + + if max_value is None: + # Calculate max and min values + max_value = calculate_fp8_maxval(exp_bits, mantissa_bits, sign_bits) + min_value = -max_value if sign_bits > 0 else 0.0 + + # Clamp tensor to range + clamped_tensor = torch.clamp(scaled_tensor, min_value, max_value) + + # Quantization process + abs_values = torch.abs(clamped_tensor) + nonzero_mask = abs_values > 0 + + # Calculate log scales (only for non-zero elements) + log_scales = torch.zeros_like(clamped_tensor) + if nonzero_mask.any(): + log_scales[nonzero_mask] = torch.floor(torch.log2(abs_values[nonzero_mask]) + bias).detach() + + # Limit log scales and calculate quantization factor + log_scales = torch.clamp(log_scales, min=1.0) + quant_factor = 2.0 ** (log_scales - mantissa_bits - bias) + + # Quantize and dequantize + quantized = torch.round(clamped_tensor / quant_factor) * quant_factor + + return quantized, scale + + +def optimize_state_dict_with_fp8( + state_dict, + calc_device, + target_layer_keys=None, + exclude_layer_keys=None, + exp_bits=4, + mantissa_bits=3, + move_to_device=False +): + """ + Optimize Linear layer weights in a model's state dict to FP8 format + + Args: + state_dict (dict): State dict to optimize, replaced in-place + calc_device (str): Device to quantize tensors on + target_layer_keys (list, optional): Layer key patterns to target (None for all Linear layers) + exclude_layer_keys (list, optional): Layer key patterns to exclude + exp_bits (int): Number of exponent bits + mantissa_bits (int): Number of mantissa bits + move_to_device (bool): Move optimized tensors to the calculating device + + Returns: + dict: FP8 optimized state dict with FP8 quantized weights and corresponding scale values + """ + if exp_bits == 4 and mantissa_bits == 3: + fp8_dtype = torch.float8_e4m3fn + elif exp_bits == 5 and mantissa_bits == 2: + fp8_dtype = torch.float8_e5m2 + else: + raise ValueError(f"Unsupported FP8 format: E{exp_bits}M{mantissa_bits}") + + # Calculate FP8 max value + max_value = calculate_fp8_maxval(exp_bits, mantissa_bits) + min_value = -max_value # this function supports only signed FP8 + + optimized_count = 0 + average_quantization_error = 0.0 + + # Find target keys for Linear layer weights + target_state_dict_keys = [] + for key in state_dict.keys(): + is_target = (target_layer_keys is None or any(pattern in key for pattern in target_layer_keys)) and key.endswith(".weight") + is_excluded = exclude_layer_keys is not None and any(pattern in key for pattern in exclude_layer_keys) + is_target = is_target and not is_excluded + + if is_target and isinstance(state_dict[key], torch.Tensor): + target_state_dict_keys.append(key) + + # Process each target weight tensor + for key in tqdm(target_state_dict_keys): + value = state_dict[key] + + # Save original device and dtype + original_device = value.device + original_dtype = value.dtype + + # Move to calculation device if provided + if calc_device is not None: + value = value.to(calc_device) + + # Calculate scale factor based on the maximum absolute value in the tensor + scale = torch.max(torch.abs(value.flatten())) / max_value + + # Quantize weight to FP8 
format + quantized_weight, _ = quantize_tensor_to_fp8(value, scale, exp_bits, mantissa_bits, 1, max_value, min_value) + + # Otherwise, store the quantized weight and corresponding scale value. + fp8_key = key # Use the original key for the quantized weight + scale_key = key.replace(".weight", ".scale_weight") + + quantized_weight = quantized_weight.to(fp8_dtype) + + # Reconstruct tensor by scaling back up + reconstructed = quantized_weight.to(original_dtype) * scale + + # Calculate the mean relative error (in percent) + average_quantization_error += (torch.mean(torch.abs(value - reconstructed)) / (torch.mean(torch.abs(value)) + 1e-8)) * 100 # Adding a small epsilon to avoid division by zero issues if necessary. + + if not move_to_device: + quantized_weight = quantized_weight.to(original_device) + + scale_tensor = torch.tensor([scale], dtype=original_dtype, device=quantized_weight.device) + + state_dict[fp8_key] = quantized_weight + state_dict[scale_key] = scale_tensor + + optimized_count += 1 + + # Optionally free memory on the calculation device every 16 optimizations. + if calc_device is not None and optimized_count % 16 == 0: + torch.cuda.empty_cache() + if optimized_count > 0: + average_quantization_error /= optimized_count + logger.info(f"Number of optimized Linear layers: {optimized_count}") + logger.info(f"Mean quantization error: {average_quantization_error:.2f}%") + else: + logger.info("optimize_state_dict_with_fp8 didn't optimize any layers! Maybe check your include/exclude keys?") + return state_dict + + +def fp8_linear_forward_patch(self: nn.Linear, x, use_scaled_mm=False, max_value=None): + """ + Patched forward method for Linear layers with FP8 weights. + + Args: + self: Linear layer instance + x (torch.Tensor): Input tensor + use_scaled_mm (bool): Use scaled_mm for FP8 Linear layers, requires SM 8.9+ (RTX 40 series) + max_value (float): Maximum value for FP8 quantization. If None, no quantization is applied for input tensor. 
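The typical way these helpers compose is a quantize-then-patch-then-load flow. The sketch below is an assumption-laden illustration, not a documented API of this project: `build_model` is a hypothetical constructor, the target/exclude key patterns are illustrative and should select only Linear layers for your model, and `assign=True` relies on PyTorch 2.1+ keeping the FP8 storage rather than copying into higher-precision parameters.

```python
import torch

model = build_model()  # hypothetical: any nn.Module whose Linear weights you want in FP8
sd = model.state_dict()

# 1) quantize matching Linear ".weight" tensors and add ".scale_weight" entries
sd = optimize_state_dict_with_fp8(
    sd,
    calc_device="cuda" if torch.cuda.is_available() else "cpu",
    target_layer_keys=["blocks"],          # illustrative: restrict to transformer blocks
    exclude_layer_keys=["norm", "embed"],  # illustrative: keep norms/embeddings in high precision
)

# 2) register scale buffers and swap in the FP8-aware forward (apply_fp8_monkey_patch, below)
model = apply_fp8_monkey_patch(model, sd, use_scaled_mm=False)

# 3) load the optimized weights; scale_weight buffers now exist on the patched modules
model.load_state_dict(sd, strict=True, assign=True)
```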
+ + Returns: + torch.Tensor: Result of linear transformation + """ + if use_scaled_mm and x.ndim == 3: + input_dtype = x.dtype + original_weight_dtype = self.scale_weight.dtype + weight_dtype = self.weight.dtype + assert weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2], "Only FP8 E4M3FN/E5M2 format is supported" + target_dtype = torch.float8_e5m2 if weight_dtype == torch.float8_e4m3fn else torch.float8_e4m3fn + e_bits = 5 if target_dtype == torch.float8_e5m2 else 4 + m_bits = 2 if target_dtype == torch.float8_e5m2 else 3 + + if max_value is None: + # no input quantization + scale_x = torch.tensor(1.0, dtype=torch.float32, device=x.device) + else: + # calculate scale factor for input tensor + scale_x = (torch.max(torch.abs(x.flatten())) / max_value).to(torch.float32) + + # quantize input tensor to FP8: this seems to consume a lot of memory + x, _ = quantize_tensor_to_fp8(x, scale_x, e_bits, m_bits, 1, max_value, -max_value) + + original_shape = x.shape + x = x.reshape(-1, x.shape[2]).to(target_dtype) + + weight = self.weight.t() + scale_weight = self.scale_weight.to(torch.float32) + + if self.bias is not None: + # float32 is not supported with bias in scaled_mm + o = torch._scaled_mm(x, weight, out_dtype=original_weight_dtype, bias=self.bias, scale_a=scale_x, scale_b=scale_weight) + else: + o = torch._scaled_mm(x, weight, out_dtype=input_dtype, scale_a=scale_x, scale_b=scale_weight) + + return o.reshape(original_shape[0], original_shape[1], -1) + + else: + # Dequantize the weight + original_dtype = self.scale_weight.dtype + dequantized_weight = self.weight.to(original_dtype) * self.scale_weight + + # Perform linear transformation + if self.bias is not None: + output = F.linear(x, dequantized_weight, self.bias) + else: + output = F.linear(x, dequantized_weight) + + return output + + +def apply_fp8_monkey_patch(model, optimized_state_dict, use_scaled_mm=False, scale_input_tensor=None): + """ + Apply monkey patching to a model using FP8 optimized state dict. + + Args: + model (nn.Module): Model instance to patch + optimized_state_dict (dict): FP8 optimized state dict + use_scaled_mm (bool): Use scaled_mm for FP8 Linear layers, requires SM 8.9+ (RTX 40 series) + + Returns: + nn.Module: The patched model (same instance, modified in-place) + """ + max_value = None + if use_scaled_mm: + setattr(model, "fp8_matmul_enabled", True) + if scale_input_tensor is not None: + max_value = calculate_fp8_maxval(4, 3) if "e4m3" in scale_input_tensor else calculate_fp8_maxval(5, 2) if "e5m2" in scale_input_tensor else None + + # Find all scale keys to identify FP8-optimized layers + scale_keys = [k for k in optimized_state_dict.keys() if k.endswith(".scale_weight")] + + # Enumerate patched layers + patched_module_paths = set() + for scale_key in scale_keys: + # Extract module path from scale key (remove .scale_weight) + module_path = scale_key.rsplit(".scale_weight", 1)[0] + patched_module_paths.add(module_path) + + patched_count = 0 + + # Apply monkey patch to each layer with FP8 weights + for name, module in model.named_modules(): + # Check if this module has a corresponding scale_weight + has_scale = name in patched_module_paths + + # Apply patch if it's a Linear layer with FP8 scale + if isinstance(module, nn.Linear) and has_scale: + # register the scale_weight as a buffer to load the state_dict + module.register_buffer("scale_weight", torch.tensor(1.0, dtype=module.weight.dtype)) + + # Create a new forward method with the patched version. 
+ def new_forward(self, x): + return fp8_linear_forward_patch(self, x, use_scaled_mm, max_value) + + # Bind method to module + module.forward = new_forward.__get__(module, type(module)) + + patched_count += 1 + + logger.info(f"Number of monkey-patched Linear layers: {patched_count}") + return model diff --git a/blissful_tuner/gfpgan/LICENSE b/blissful_tuner/gfpgan/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..1e947f96094dfcf7a0c3500e6a60ac124af00258 --- /dev/null +++ b/blissful_tuner/gfpgan/LICENSE @@ -0,0 +1,353 @@ +THIS FOLDER AND SUBFOLDERS (all GFPGAN related code and files) LICENSED AS BELOW + +Tencent is pleased to support the open source community by making GFPGAN available. + +Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. + +GFPGAN is licensed under the Apache License Version 2.0 except for the third-party components listed below. + + +Terms of the Apache License Version 2.0: +--------------------------------------------- +Apache License + +Version 2.0, January 2004 + +http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION +1. Definitions. + +“License” shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. + +“Licensor” shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. + +“Legal Entity” shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, “control” means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. + +“You” (or “Your”) shall mean an individual or Legal Entity exercising permissions granted by this License. + +“Source” form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. + +“Object” form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. + +“Work” shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). + +“Derivative Works” shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. + +“Contribution” shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
For the purposes of this definition, “submitted” means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as “Not a Contribution.” + +“Contributor” shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: + +You must give any other recipients of the Work or Derivative Works a copy of this License; and + +You must cause any modified files to carry prominent notices stating that You changed the files; and + +You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and + +If the Work includes a “NOTICE” text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. 
You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. + +You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + + + +Other dependencies and licenses: + + +Open Source Software licensed under the Apache 2.0 license and Other Licenses of the Third-Party Components therein: +--------------------------------------------- +1. 
basicsr +Copyright 2018-2020 BasicSR Authors + + +This BasicSR project is released under the Apache 2.0 license. + +A copy of Apache 2.0 is included in this file. + +StyleGAN2 +The codes are modified from the repository stylegan2-pytorch. Many thanks to the author - Kim Seonghyeon 😊 for translating from the official TensorFlow codes to PyTorch ones. Here is the license of stylegan2-pytorch. +The official repository is https://github.com/NVlabs/stylegan2, and here is the NVIDIA license. +DFDNet +The codes are largely modified from the repository DFDNet. Their license is Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License. + +Terms of the Nvidia License: +--------------------------------------------- + +1. Definitions + +"Licensor" means any person or entity that distributes its Work. + +"Software" means the original work of authorship made available under +this License. + +"Work" means the Software and any additions to or derivative works of +the Software that are made available under this License. + +"Nvidia Processors" means any central processing unit (CPU), graphics +processing unit (GPU), field-programmable gate array (FPGA), +application-specific integrated circuit (ASIC) or any combination +thereof designed, made, sold, or provided by Nvidia or its affiliates. + +The terms "reproduce," "reproduction," "derivative works," and +"distribution" have the meaning as provided under U.S. copyright law; +provided, however, that for the purposes of this License, derivative +works shall not include works that remain separable from, or merely +link (or bind by name) to the interfaces of, the Work. + +Works, including the Software, are "made available" under this License +by including in or with the Work either (a) a copyright notice +referencing the applicability of this License to the Work, or (b) a +copy of this License. + +2. License Grants + + 2.1 Copyright Grant. Subject to the terms and conditions of this + License, each Licensor grants to you a perpetual, worldwide, + non-exclusive, royalty-free, copyright license to reproduce, + prepare derivative works of, publicly display, publicly perform, + sublicense and distribute its Work and any resulting derivative + works in any form. + +3. Limitations + + 3.1 Redistribution. You may reproduce or distribute the Work only + if (a) you do so under this License, (b) you include a complete + copy of this License with your distribution, and (c) you retain + without modification any copyright, patent, trademark, or + attribution notices that are present in the Work. + + 3.2 Derivative Works. You may specify that additional or different + terms apply to the use, reproduction, and distribution of your + derivative works of the Work ("Your Terms") only if (a) Your Terms + provide that the use limitation in Section 3.3 applies to your + derivative works, and (b) you identify the specific derivative + works that are subject to Your Terms. Notwithstanding Your Terms, + this License (including the redistribution requirements in Section + 3.1) will continue to apply to the Work itself. + + 3.3 Use Limitation. The Work and any derivative works thereof only + may be used or intended for use non-commercially. The Work or + derivative works thereof may be used or intended for use by Nvidia + or its affiliates commercially or non-commercially. As used herein, + "non-commercially" means for research or evaluation purposes only. + + 3.4 Patent Claims. 
If you bring or threaten to bring a patent claim + against any Licensor (including any claim, cross-claim or + counterclaim in a lawsuit) to enforce any patents that you allege + are infringed by any Work, then your rights under this License from + such Licensor (including the grants in Sections 2.1 and 2.2) will + terminate immediately. + + 3.5 Trademarks. This License does not grant any rights to use any + Licensor's or its affiliates' names, logos, or trademarks, except + as necessary to reproduce the notices described in this License. + + 3.6 Termination. If you violate any term of this License, then your + rights under this License (including the grants in Sections 2.1 and + 2.2) will terminate immediately. + +4. Disclaimer of Warranty. + +THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR +NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER +THIS LICENSE. + +5. Limitation of Liability. + +EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL +THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE +SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, +INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF +OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK +(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, +LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER +COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF +THE POSSIBILITY OF SUCH DAMAGES. + +MIT License + +Copyright (c) 2019 Kim Seonghyeon + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + + + +Open Source Software licensed under the BSD 3-Clause license: +--------------------------------------------- +1. torchvision +Copyright (c) Soumith Chintala 2016, +All rights reserved. + +2. 
torch +Copyright (c) 2016- Facebook, Inc (Adam Paszke) +Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +Copyright (c) 2011-2013 NYU (Clement Farabet) +Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) + + +Terms of the BSD 3-Clause License: +--------------------------------------------- +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + + +Open Source Software licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein: +--------------------------------------------- +1. numpy +Copyright (c) 2005-2020, NumPy Developers. +All rights reserved. + +A copy of BSD 3-Clause License is included in this file. + +The NumPy repository and source distributions bundle several libraries that are +compatibly licensed. We list these here. + +Name: Numpydoc +Files: doc/sphinxext/numpydoc/* +License: BSD-2-Clause + For details, see doc/sphinxext/LICENSE.txt + +Name: scipy-sphinx-theme +Files: doc/scipy-sphinx-theme/* +License: BSD-3-Clause AND PSF-2.0 AND Apache-2.0 + For details, see doc/scipy-sphinx-theme/LICENSE.txt + +Name: lapack-lite +Files: numpy/linalg/lapack_lite/* +License: BSD-3-Clause + For details, see numpy/linalg/lapack_lite/LICENSE.txt + +Name: tempita +Files: tools/npy_tempita/* +License: MIT + For details, see tools/npy_tempita/license.txt + +Name: dragon4 +Files: numpy/core/src/multiarray/dragon4.c +License: MIT + For license text, see numpy/core/src/multiarray/dragon4.c + + + +Open Source Software licensed under the MIT license: +--------------------------------------------- +1. facexlib +Copyright (c) 2020 Xintao Wang + +2. opencv-python +Copyright (c) Olli-Pekka Heinisuo +Please note that only files in cv2 package are used. 
+ + +Terms of the MIT License: +--------------------------------------------- +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + + +Open Source Software licensed under the MIT license and Other Licenses of the Third-Party Components therein: +--------------------------------------------- +1. tqdm +Copyright (c) 2013 noamraph + +`tqdm` is a product of collaborative work. +Unless otherwise stated, all authors (see commit logs) retain copyright +for their respective work, and release the work under the MIT licence +(text below). + +Exceptions or notable authors are listed below +in reverse chronological order: + +* files: * + MPLv2.0 2015-2020 (c) Casper da Costa-Luis + [casperdcl](https://github.com/casperdcl). +* files: tqdm/_tqdm.py + MIT 2016 (c) [PR #96] on behalf of Google Inc. +* files: tqdm/_tqdm.py setup.py README.rst MANIFEST.in .gitignore + MIT 2013 (c) Noam Yorav-Raphael, original author. + +[PR #96]: tqdm/tqdm#96 + + +Mozilla Public Licence (MPL) v. 2.0 - Exhibit A +----------------------------------------------- + +This Source Code Form is subject to the terms of the +Mozilla Public License, v. 2.0. +If a copy of the MPL was not distributed with this file, +You can obtain one at https://mozilla.org/MPL/2.0/. + + +MIT License (MIT) +----------------- + +Copyright (c) 2013 noamraph + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
diff --git a/blissful_tuner/gfpgan/__init__.py b/blissful_tuner/gfpgan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..94daaeebce5604d61999f0b1b354b9a9e299b991 --- /dev/null +++ b/blissful_tuner/gfpgan/__init__.py @@ -0,0 +1,7 @@ +# flake8: noqa +from .archs import * +from .data import * +from .models import * +from .utils import * + +# from .version import * diff --git a/blissful_tuner/gfpgan/archs/__init__.py b/blissful_tuner/gfpgan/archs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bec5f17bfa38729b55f57cae8e40c27310db2b7b --- /dev/null +++ b/blissful_tuner/gfpgan/archs/__init__.py @@ -0,0 +1,10 @@ +import importlib +from basicsr.utils import scandir +from os import path as osp + +# automatically scan and import arch modules for registry +# scan all the files that end with '_arch.py' under the archs folder +arch_folder = osp.dirname(osp.abspath(__file__)) +arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] +# import all the arch modules +_arch_modules = [importlib.import_module(f'gfpgan.archs.{file_name}') for file_name in arch_filenames] diff --git a/blissful_tuner/gfpgan/archs/arcface_arch.py b/blissful_tuner/gfpgan/archs/arcface_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..e6d3bd97f83334450bd78ad2c3b9871102a56b70 --- /dev/null +++ b/blissful_tuner/gfpgan/archs/arcface_arch.py @@ -0,0 +1,245 @@ +import torch.nn as nn +from basicsr.utils.registry import ARCH_REGISTRY + + +def conv3x3(inplanes, outplanes, stride=1): + """A simple wrapper for 3x3 convolution with padding. + + Args: + inplanes (int): Channel number of inputs. + outplanes (int): Channel number of outputs. + stride (int): Stride in convolution. Default: 1. + """ + return nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False) + + +class BasicBlock(nn.Module): + """Basic residual block used in the ResNetArcFace architecture. + + Args: + inplanes (int): Channel number of inputs. + planes (int): Channel number of outputs. + stride (int): Stride in convolution. Default: 1. + downsample (nn.Module): The downsample module. Default: None. + """ + expansion = 1 # output channel expansion ratio + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class IRBlock(nn.Module): + """Improved residual block (IR Block) used in the ResNetArcFace architecture. + + Args: + inplanes (int): Channel number of inputs. + planes (int): Channel number of outputs. + stride (int): Stride in convolution. Default: 1. + downsample (nn.Module): The downsample module. Default: None. + use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True. 
+ """ + expansion = 1 # output channel expansion ratio + + def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True): + super(IRBlock, self).__init__() + self.bn0 = nn.BatchNorm2d(inplanes) + self.conv1 = conv3x3(inplanes, inplanes) + self.bn1 = nn.BatchNorm2d(inplanes) + self.prelu = nn.PReLU() + self.conv2 = conv3x3(inplanes, planes, stride) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + self.use_se = use_se + if self.use_se: + self.se = SEBlock(planes) + + def forward(self, x): + residual = x + out = self.bn0(x) + out = self.conv1(out) + out = self.bn1(out) + out = self.prelu(out) + + out = self.conv2(out) + out = self.bn2(out) + if self.use_se: + out = self.se(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.prelu(out) + + return out + + +class Bottleneck(nn.Module): + """Bottleneck block used in the ResNetArcFace architecture. + + Args: + inplanes (int): Channel number of inputs. + planes (int): Channel number of outputs. + stride (int): Stride in convolution. Default: 1. + downsample (nn.Module): The downsample module. Default: None. + """ + expansion = 4 # output channel expansion ratio + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class SEBlock(nn.Module): + """The squeeze-and-excitation block (SEBlock) used in the IRBlock. + + Args: + channel (int): Channel number of inputs. + reduction (int): Channel reduction ration. Default: 16. + """ + + def __init__(self, channel, reduction=16): + super(SEBlock, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) # pool to 1x1 without spatial information + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel), + nn.Sigmoid()) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return x * y + + +@ARCH_REGISTRY.register() +class ResNetArcFace(nn.Module): + """ArcFace with ResNet architectures. + + Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition. + + Args: + block (str): Block used in the ArcFace architecture. + layers (tuple(int)): Block numbers in each layer. + use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True. 
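A quick shape check for `ResNetArcFace` as defined above. The block/layer configuration here is illustrative, and the import path assumes the package is importable as `gfpgan`, matching the imports in `gfpgan/archs/__init__.py`. With IR blocks and four halvings of the spatial size (the initial max pool plus the three strided stages), a 1-channel 128x128 crop is reduced to the 8x8 map that `fc5` expects.

```python
import torch

from gfpgan.archs.arcface_arch import ResNetArcFace

# illustrative configuration: IR blocks, two blocks per stage, SE disabled
net = ResNetArcFace(block='IRBlock', layers=(2, 2, 2, 2), use_se=False).eval()

x = torch.randn(2, 1, 128, 128)   # grayscale face crops
with torch.no_grad():
    emb = net(x)
print(emb.shape)                  # torch.Size([2, 512])
```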
+ """ + + def __init__(self, block, layers, use_se=True): + if block == 'IRBlock': + block = IRBlock + self.inplanes = 64 + self.use_se = use_se + super(ResNetArcFace, self).__init__() + + self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.prelu = nn.PReLU() + self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.bn4 = nn.BatchNorm2d(512) + self.dropout = nn.Dropout() + self.fc5 = nn.Linear(512 * 8 * 8, 512) + self.bn5 = nn.BatchNorm1d(512) + + # initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.xavier_normal_(m.weight) + elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.xavier_normal_(m.weight) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, planes, num_blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se)) + self.inplanes = planes + for _ in range(1, num_blocks): + layers.append(block(self.inplanes, planes, use_se=self.use_se)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.prelu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.bn4(x) + x = self.dropout(x) + x = x.view(x.size(0), -1) + x = self.fc5(x) + x = self.bn5(x) + + return x diff --git a/blissful_tuner/gfpgan/archs/gfpgan_bilinear_arch.py b/blissful_tuner/gfpgan/archs/gfpgan_bilinear_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..52e0de88de8543cf4afdc3988c4cdfc7c7060687 --- /dev/null +++ b/blissful_tuner/gfpgan/archs/gfpgan_bilinear_arch.py @@ -0,0 +1,312 @@ +import math +import random +import torch +from basicsr.utils.registry import ARCH_REGISTRY +from torch import nn + +from .gfpganv1_arch import ResUpBlock +from .stylegan2_bilinear_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU, + StyleGAN2GeneratorBilinear) + + +class StyleGAN2GeneratorBilinearSFT(StyleGAN2GeneratorBilinear): + """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform). + + It is the bilinear version. It does not use the complicated UpFirDnSmooth function that is not friendly for + deployment. It can be easily converted to the clean version: StyleGAN2GeneratorCSFT. + + Args: + out_size (int): The spatial size of outputs. + num_style_feat (int): Channel number of style features. Default: 512. + num_mlp (int): Layer number of MLP style layers. Default: 8. + channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. + lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01. + narrow (float): The narrow ratio for channels. Default: 1. + sft_half (bool): Whether to apply SFT on half of the input channels. Default: False. 
+ """ + + def __init__(self, + out_size, + num_style_feat=512, + num_mlp=8, + channel_multiplier=2, + lr_mlp=0.01, + narrow=1, + sft_half=False): + super(StyleGAN2GeneratorBilinearSFT, self).__init__( + out_size, + num_style_feat=num_style_feat, + num_mlp=num_mlp, + channel_multiplier=channel_multiplier, + lr_mlp=lr_mlp, + narrow=narrow) + self.sft_half = sft_half + + def forward(self, + styles, + conditions, + input_is_latent=False, + noise=None, + randomize_noise=True, + truncation=1, + truncation_latent=None, + inject_index=None, + return_latents=False): + """Forward function for StyleGAN2GeneratorBilinearSFT. + + Args: + styles (list[Tensor]): Sample codes of styles. + conditions (list[Tensor]): SFT conditions to generators. + input_is_latent (bool): Whether input is latent style. Default: False. + noise (Tensor | None): Input noise or None. Default: None. + randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True. + truncation (float): The truncation ratio. Default: 1. + truncation_latent (Tensor | None): The truncation latent tensor. Default: None. + inject_index (int | None): The injection index for mixing noise. Default: None. + return_latents (bool): Whether to return style latents. Default: False. + """ + # style codes -> latents with Style MLP layer + if not input_is_latent: + styles = [self.style_mlp(s) for s in styles] + # noises + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers # for each style conv layer + else: # use the stored noise + noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)] + # style truncation + if truncation < 1: + style_truncation = [] + for style in styles: + style_truncation.append(truncation_latent + truncation * (style - truncation_latent)) + styles = style_truncation + # get style latents with injection + if len(styles) == 1: + inject_index = self.num_latent + + if styles[0].ndim < 3: + # repeat latent code for all the layers + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + else: # used for encoder with different latent code for each layer + latent = styles[0] + elif len(styles) == 2: # mixing noises + if inject_index is None: + inject_index = random.randint(1, self.num_latent - 1) + latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1) + latent = torch.cat([latent1, latent2], 1) + + # main generation + out = self.constant_input(latent.shape[0]) + out = self.style_conv1(out, latent[:, 0], noise=noise[0]) + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2], + noise[2::2], self.to_rgbs): + out = conv1(out, latent[:, i], noise=noise1) + + # the conditions may have fewer levels + if i < len(conditions): + # SFT part to combine the conditions + if self.sft_half: # only apply SFT to half of the channels + out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1) + out_sft = out_sft * conditions[i - 1] + conditions[i] + out = torch.cat([out_same, out_sft], dim=1) + else: # apply SFT to all the channels + out = out * conditions[i - 1] + conditions[i] + + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space + i += 2 + + image = skip + + if return_latents: + return image, latent + else: + return image, None + + +@ARCH_REGISTRY.register() +class GFPGANBilinear(nn.Module): + """The GFPGAN architecture: 
Unet + StyleGAN2 decoder with SFT. + + It is the bilinear version and it does not use the complicated UpFirDnSmooth function that is not friendly for + deployment. It can be easily converted to the clean version: GFPGANv1Clean. + + + Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior. + + Args: + out_size (int): The spatial size of outputs. + num_style_feat (int): Channel number of style features. Default: 512. + channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. + decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None. + fix_decoder (bool): Whether to fix the decoder. Default: True. + + num_mlp (int): Layer number of MLP style layers. Default: 8. + lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01. + input_is_latent (bool): Whether input is latent style. Default: False. + different_w (bool): Whether to use different latent w for different layers. Default: False. + narrow (float): The narrow ratio for channels. Default: 1. + sft_half (bool): Whether to apply SFT on half of the input channels. Default: False. + """ + + def __init__( + self, + out_size, + num_style_feat=512, + channel_multiplier=1, + decoder_load_path=None, + fix_decoder=True, + # for stylegan decoder + num_mlp=8, + lr_mlp=0.01, + input_is_latent=False, + different_w=False, + narrow=1, + sft_half=False): + + super(GFPGANBilinear, self).__init__() + self.input_is_latent = input_is_latent + self.different_w = different_w + self.num_style_feat = num_style_feat + + unet_narrow = narrow * 0.5 # by default, use a half of input channels + channels = { + '4': int(512 * unet_narrow), + '8': int(512 * unet_narrow), + '16': int(512 * unet_narrow), + '32': int(512 * unet_narrow), + '64': int(256 * channel_multiplier * unet_narrow), + '128': int(128 * channel_multiplier * unet_narrow), + '256': int(64 * channel_multiplier * unet_narrow), + '512': int(32 * channel_multiplier * unet_narrow), + '1024': int(16 * channel_multiplier * unet_narrow) + } + + self.log_size = int(math.log(out_size, 2)) + first_out_size = 2**(int(math.log(out_size, 2))) + + self.conv_body_first = ConvLayer(3, channels[f'{first_out_size}'], 1, bias=True, activate=True) + + # downsample + in_channels = channels[f'{first_out_size}'] + self.conv_body_down = nn.ModuleList() + for i in range(self.log_size, 2, -1): + out_channels = channels[f'{2**(i - 1)}'] + self.conv_body_down.append(ResBlock(in_channels, out_channels)) + in_channels = out_channels + + self.final_conv = ConvLayer(in_channels, channels['4'], 3, bias=True, activate=True) + + # upsample + in_channels = channels['4'] + self.conv_body_up = nn.ModuleList() + for i in range(3, self.log_size + 1): + out_channels = channels[f'{2**i}'] + self.conv_body_up.append(ResUpBlock(in_channels, out_channels)) + in_channels = out_channels + + # to RGB + self.toRGB = nn.ModuleList() + for i in range(3, self.log_size + 1): + self.toRGB.append(EqualConv2d(channels[f'{2**i}'], 3, 1, stride=1, padding=0, bias=True, bias_init_val=0)) + + if different_w: + linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat + else: + linear_out_channel = num_style_feat + + self.final_linear = EqualLinear( + channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None) + + # the decoder: stylegan2 generator with SFT modulations + self.stylegan_decoder = StyleGAN2GeneratorBilinearSFT( + out_size=out_size, + num_style_feat=num_style_feat, + 
num_mlp=num_mlp, + channel_multiplier=channel_multiplier, + lr_mlp=lr_mlp, + narrow=narrow, + sft_half=sft_half) + + # load pre-trained stylegan2 model if necessary + if decoder_load_path: + self.stylegan_decoder.load_state_dict( + torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema']) + # fix decoder without updating params + if fix_decoder: + for _, param in self.stylegan_decoder.named_parameters(): + param.requires_grad = False + + # for SFT modulations (scale and shift) + self.condition_scale = nn.ModuleList() + self.condition_shift = nn.ModuleList() + for i in range(3, self.log_size + 1): + out_channels = channels[f'{2**i}'] + if sft_half: + sft_out_channels = out_channels + else: + sft_out_channels = out_channels * 2 + self.condition_scale.append( + nn.Sequential( + EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0), + ScaledLeakyReLU(0.2), + EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=1))) + self.condition_shift.append( + nn.Sequential( + EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0), + ScaledLeakyReLU(0.2), + EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0))) + + def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True): + """Forward function for GFPGANBilinear. + + Args: + x (Tensor): Input images. + return_latents (bool): Whether to return style latents. Default: False. + return_rgb (bool): Whether return intermediate rgb images. Default: True. + randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True. + """ + conditions = [] + unet_skips = [] + out_rgbs = [] + + # encoder + feat = self.conv_body_first(x) + for i in range(self.log_size - 2): + feat = self.conv_body_down[i](feat) + unet_skips.insert(0, feat) + + feat = self.final_conv(feat) + + # style code + style_code = self.final_linear(feat.view(feat.size(0), -1)) + if self.different_w: + style_code = style_code.view(style_code.size(0), -1, self.num_style_feat) + + # decode + for i in range(self.log_size - 2): + # add unet skip + feat = feat + unet_skips[i] + # ResUpLayer + feat = self.conv_body_up[i](feat) + # generate scale and shift for SFT layers + scale = self.condition_scale[i](feat) + conditions.append(scale.clone()) + shift = self.condition_shift[i](feat) + conditions.append(shift.clone()) + # generate rgb images + if return_rgb: + out_rgbs.append(self.toRGB[i](feat)) + + # decoder + image, _ = self.stylegan_decoder([style_code], + conditions, + return_latents=return_latents, + input_is_latent=self.input_is_latent, + randomize_noise=randomize_noise) + + return image, out_rgbs diff --git a/blissful_tuner/gfpgan/archs/gfpganv1_arch.py b/blissful_tuner/gfpgan/archs/gfpganv1_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..eaf316200b386bc6aa7a8829655828f71893473b --- /dev/null +++ b/blissful_tuner/gfpgan/archs/gfpganv1_arch.py @@ -0,0 +1,439 @@ +import math +import random +import torch +from basicsr.archs.stylegan2_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU, + StyleGAN2Generator) +from basicsr.ops.fused_act import FusedLeakyReLU +from basicsr.utils.registry import ARCH_REGISTRY +from torch import nn +from torch.nn import functional as F + + +class StyleGAN2GeneratorSFT(StyleGAN2Generator): + """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform). 
+ + Args: + out_size (int): The spatial size of outputs. + num_style_feat (int): Channel number of style features. Default: 512. + num_mlp (int): Layer number of MLP style layers. Default: 8. + channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. + resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross production will be + applied to extent 1D resample kernel to 2D resample kernel. Default: (1, 3, 3, 1). + lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01. + narrow (float): The narrow ratio for channels. Default: 1. + sft_half (bool): Whether to apply SFT on half of the input channels. Default: False. + """ + + def __init__(self, + out_size, + num_style_feat=512, + num_mlp=8, + channel_multiplier=2, + resample_kernel=(1, 3, 3, 1), + lr_mlp=0.01, + narrow=1, + sft_half=False): + super(StyleGAN2GeneratorSFT, self).__init__( + out_size, + num_style_feat=num_style_feat, + num_mlp=num_mlp, + channel_multiplier=channel_multiplier, + resample_kernel=resample_kernel, + lr_mlp=lr_mlp, + narrow=narrow) + self.sft_half = sft_half + + def forward(self, + styles, + conditions, + input_is_latent=False, + noise=None, + randomize_noise=True, + truncation=1, + truncation_latent=None, + inject_index=None, + return_latents=False): + """Forward function for StyleGAN2GeneratorSFT. + + Args: + styles (list[Tensor]): Sample codes of styles. + conditions (list[Tensor]): SFT conditions to generators. + input_is_latent (bool): Whether input is latent style. Default: False. + noise (Tensor | None): Input noise or None. Default: None. + randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True. + truncation (float): The truncation ratio. Default: 1. + truncation_latent (Tensor | None): The truncation latent tensor. Default: None. + inject_index (int | None): The injection index for mixing noise. Default: None. + return_latents (bool): Whether to return style latents. Default: False. 
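+
+        Note:
+            Illustrative clarification: ``conditions`` is expected to be a flat,
+            interleaved list ``[scale_0, shift_0, scale_1, shift_1, ...]`` with one
+            (scale, shift) pair per upsampling level, each matching the spatial size
+            of the corresponding style conv output. The transform applied below is
+            essentially ``feat = feat * scale + shift`` (on half of the channels when
+            ``sft_half=True``). Example shapes for a 512x512 generator and batch
+            size 2 (placeholders, not prescribed values):
+
+                # styles: [latent code of shape (2, 512)]
+                # conditions[0], conditions[1]: scale / shift of shape (2, C, 8, 8)
+                # later pairs double spatially, up to (2, C', 512, 512)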
+ """ + # style codes -> latents with Style MLP layer + if not input_is_latent: + styles = [self.style_mlp(s) for s in styles] + # noises + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers # for each style conv layer + else: # use the stored noise + noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)] + # style truncation + if truncation < 1: + style_truncation = [] + for style in styles: + style_truncation.append(truncation_latent + truncation * (style - truncation_latent)) + styles = style_truncation + # get style latents with injection + if len(styles) == 1: + inject_index = self.num_latent + + if styles[0].ndim < 3: + # repeat latent code for all the layers + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + else: # used for encoder with different latent code for each layer + latent = styles[0] + elif len(styles) == 2: # mixing noises + if inject_index is None: + inject_index = random.randint(1, self.num_latent - 1) + latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1) + latent = torch.cat([latent1, latent2], 1) + + # main generation + out = self.constant_input(latent.shape[0]) + out = self.style_conv1(out, latent[:, 0], noise=noise[0]) + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2], + noise[2::2], self.to_rgbs): + out = conv1(out, latent[:, i], noise=noise1) + + # the conditions may have fewer levels + if i < len(conditions): + # SFT part to combine the conditions + if self.sft_half: # only apply SFT to half of the channels + out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1) + out_sft = out_sft * conditions[i - 1] + conditions[i] + out = torch.cat([out_same, out_sft], dim=1) + else: # apply SFT to all the channels + out = out * conditions[i - 1] + conditions[i] + + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space + i += 2 + + image = skip + + if return_latents: + return image, latent + else: + return image, None + + +class ConvUpLayer(nn.Module): + """Convolutional upsampling layer. It uses bilinear upsampler + Conv. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Size of the convolving kernel. + stride (int): Stride of the convolution. Default: 1 + padding (int): Zero-padding added to both sides of the input. Default: 0. + bias (bool): If ``True``, adds a learnable bias to the output. Default: ``True``. + bias_init_val (float): Bias initialized value. Default: 0. + activate (bool): Whether use activateion. Default: True. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + bias=True, + bias_init_val=0, + activate=True): + super(ConvUpLayer, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + # self.scale is used to scale the convolution weights, which is related to the common initializations. 
+ self.scale = 1 / math.sqrt(in_channels * kernel_size**2) + + self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size)) + + if bias and not activate: + self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val)) + else: + self.register_parameter('bias', None) + + # activation + if activate: + if bias: + self.activation = FusedLeakyReLU(out_channels) + else: + self.activation = ScaledLeakyReLU(0.2) + else: + self.activation = None + + def forward(self, x): + # bilinear upsample + out = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) + # conv + out = F.conv2d( + out, + self.weight * self.scale, + bias=self.bias, + stride=self.stride, + padding=self.padding, + ) + # activation + if self.activation is not None: + out = self.activation(out) + return out + + +class ResUpBlock(nn.Module): + """Residual block with upsampling. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + """ + + def __init__(self, in_channels, out_channels): + super(ResUpBlock, self).__init__() + + self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True) + self.conv2 = ConvUpLayer(in_channels, out_channels, 3, stride=1, padding=1, bias=True, activate=True) + self.skip = ConvUpLayer(in_channels, out_channels, 1, bias=False, activate=False) + + def forward(self, x): + out = self.conv1(x) + out = self.conv2(out) + skip = self.skip(x) + out = (out + skip) / math.sqrt(2) + return out + + +@ARCH_REGISTRY.register() +class GFPGANv1(nn.Module): + """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT. + + Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior. + + Args: + out_size (int): The spatial size of outputs. + num_style_feat (int): Channel number of style features. Default: 512. + channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. + resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross production will be + applied to extent 1D resample kernel to 2D resample kernel. Default: (1, 3, 3, 1). + decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None. + fix_decoder (bool): Whether to fix the decoder. Default: True. + + num_mlp (int): Layer number of MLP style layers. Default: 8. + lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01. + input_is_latent (bool): Whether input is latent style. Default: False. + different_w (bool): Whether to use different latent w for different layers. Default: False. + narrow (float): The narrow ratio for channels. Default: 1. + sft_half (bool): Whether to apply SFT on half of the input channels. Default: False. 
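+
+    Example:
+        An illustrative sketch of building and running the network; the sizes and
+        flags below are placeholders rather than prescribed settings, and the decoder
+        stays randomly initialized because ``decoder_load_path`` is None::
+
+            net = GFPGANv1(
+                out_size=512, channel_multiplier=1, decoder_load_path=None,
+                fix_decoder=True, different_w=True, narrow=1, sft_half=True)
+            x = torch.randn(1, 3, 512, 512)        # face crop, roughly in [-1, 1]
+            image, out_rgbs = net(x, return_rgb=True)
+            # image: (1, 3, 512, 512) restored face
+            # out_rgbs: list of intermediate RGB previews from the U-Net branch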
+ """ + + def __init__( + self, + out_size, + num_style_feat=512, + channel_multiplier=1, + resample_kernel=(1, 3, 3, 1), + decoder_load_path=None, + fix_decoder=True, + # for stylegan decoder + num_mlp=8, + lr_mlp=0.01, + input_is_latent=False, + different_w=False, + narrow=1, + sft_half=False): + + super(GFPGANv1, self).__init__() + self.input_is_latent = input_is_latent + self.different_w = different_w + self.num_style_feat = num_style_feat + + unet_narrow = narrow * 0.5 # by default, use a half of input channels + channels = { + '4': int(512 * unet_narrow), + '8': int(512 * unet_narrow), + '16': int(512 * unet_narrow), + '32': int(512 * unet_narrow), + '64': int(256 * channel_multiplier * unet_narrow), + '128': int(128 * channel_multiplier * unet_narrow), + '256': int(64 * channel_multiplier * unet_narrow), + '512': int(32 * channel_multiplier * unet_narrow), + '1024': int(16 * channel_multiplier * unet_narrow) + } + + self.log_size = int(math.log(out_size, 2)) + first_out_size = 2**(int(math.log(out_size, 2))) + + self.conv_body_first = ConvLayer(3, channels[f'{first_out_size}'], 1, bias=True, activate=True) + + # downsample + in_channels = channels[f'{first_out_size}'] + self.conv_body_down = nn.ModuleList() + for i in range(self.log_size, 2, -1): + out_channels = channels[f'{2**(i - 1)}'] + self.conv_body_down.append(ResBlock(in_channels, out_channels, resample_kernel)) + in_channels = out_channels + + self.final_conv = ConvLayer(in_channels, channels['4'], 3, bias=True, activate=True) + + # upsample + in_channels = channels['4'] + self.conv_body_up = nn.ModuleList() + for i in range(3, self.log_size + 1): + out_channels = channels[f'{2**i}'] + self.conv_body_up.append(ResUpBlock(in_channels, out_channels)) + in_channels = out_channels + + # to RGB + self.toRGB = nn.ModuleList() + for i in range(3, self.log_size + 1): + self.toRGB.append(EqualConv2d(channels[f'{2**i}'], 3, 1, stride=1, padding=0, bias=True, bias_init_val=0)) + + if different_w: + linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat + else: + linear_out_channel = num_style_feat + + self.final_linear = EqualLinear( + channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None) + + # the decoder: stylegan2 generator with SFT modulations + self.stylegan_decoder = StyleGAN2GeneratorSFT( + out_size=out_size, + num_style_feat=num_style_feat, + num_mlp=num_mlp, + channel_multiplier=channel_multiplier, + resample_kernel=resample_kernel, + lr_mlp=lr_mlp, + narrow=narrow, + sft_half=sft_half) + + # load pre-trained stylegan2 model if necessary + if decoder_load_path: + self.stylegan_decoder.load_state_dict( + torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema']) + # fix decoder without updating params + if fix_decoder: + for _, param in self.stylegan_decoder.named_parameters(): + param.requires_grad = False + + # for SFT modulations (scale and shift) + self.condition_scale = nn.ModuleList() + self.condition_shift = nn.ModuleList() + for i in range(3, self.log_size + 1): + out_channels = channels[f'{2**i}'] + if sft_half: + sft_out_channels = out_channels + else: + sft_out_channels = out_channels * 2 + self.condition_scale.append( + nn.Sequential( + EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0), + ScaledLeakyReLU(0.2), + EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=1))) + self.condition_shift.append( + nn.Sequential( + EqualConv2d(out_channels, 
out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0), + ScaledLeakyReLU(0.2), + EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0))) + + def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True, **kwargs): + """Forward function for GFPGANv1. + + Args: + x (Tensor): Input images. + return_latents (bool): Whether to return style latents. Default: False. + return_rgb (bool): Whether return intermediate rgb images. Default: True. + randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True. + """ + conditions = [] + unet_skips = [] + out_rgbs = [] + + # encoder + feat = self.conv_body_first(x) + for i in range(self.log_size - 2): + feat = self.conv_body_down[i](feat) + unet_skips.insert(0, feat) + + feat = self.final_conv(feat) + + # style code + style_code = self.final_linear(feat.view(feat.size(0), -1)) + if self.different_w: + style_code = style_code.view(style_code.size(0), -1, self.num_style_feat) + + # decode + for i in range(self.log_size - 2): + # add unet skip + feat = feat + unet_skips[i] + # ResUpLayer + feat = self.conv_body_up[i](feat) + # generate scale and shift for SFT layers + scale = self.condition_scale[i](feat) + conditions.append(scale.clone()) + shift = self.condition_shift[i](feat) + conditions.append(shift.clone()) + # generate rgb images + if return_rgb: + out_rgbs.append(self.toRGB[i](feat)) + + # decoder + image, _ = self.stylegan_decoder([style_code], + conditions, + return_latents=return_latents, + input_is_latent=self.input_is_latent, + randomize_noise=randomize_noise) + + return image, out_rgbs + + +@ARCH_REGISTRY.register() +class FacialComponentDiscriminator(nn.Module): + """Facial component (eyes, mouth, noise) discriminator used in GFPGAN. + """ + + def __init__(self): + super(FacialComponentDiscriminator, self).__init__() + # It now uses a VGG-style architectrue with fixed model size + self.conv1 = ConvLayer(3, 64, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True) + self.conv2 = ConvLayer(64, 128, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True) + self.conv3 = ConvLayer(128, 128, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True) + self.conv4 = ConvLayer(128, 256, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True) + self.conv5 = ConvLayer(256, 256, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True) + self.final_conv = ConvLayer(256, 1, 3, bias=True, activate=False) + + def forward(self, x, return_feats=False, **kwargs): + """Forward function for FacialComponentDiscriminator. + + Args: + x (Tensor): Input images. + return_feats (bool): Whether to return intermediate features. Default: False. 
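+
+        Example:
+            Illustrative sketch; the patch size is a placeholder::
+
+                d = FacialComponentDiscriminator()
+                patch = torch.randn(4, 3, 80, 80)           # ROI-cropped eye/mouth patches
+                logits, feats = d(patch, return_feats=True)
+                # logits: (4, 1, 20, 20) low-resolution realness map
+                # feats: two intermediate feature maps, usable for feature matching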
+ """ + feat = self.conv1(x) + feat = self.conv3(self.conv2(feat)) + rlt_feats = [] + if return_feats: + rlt_feats.append(feat.clone()) + feat = self.conv5(self.conv4(feat)) + if return_feats: + rlt_feats.append(feat.clone()) + out = self.final_conv(feat) + + if return_feats: + return out, rlt_feats + else: + return out, None diff --git a/blissful_tuner/gfpgan/archs/gfpganv1_clean_arch.py b/blissful_tuner/gfpgan/archs/gfpganv1_clean_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..d6c2705876d18ccae69e0ef9e7678a456f86bb58 --- /dev/null +++ b/blissful_tuner/gfpgan/archs/gfpganv1_clean_arch.py @@ -0,0 +1,324 @@ +import math +import random +import torch +from basicsr.utils.registry import ARCH_REGISTRY +from torch import nn +from torch.nn import functional as F + +from .stylegan2_clean_arch import StyleGAN2GeneratorClean + + +class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean): + """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform). + + It is the clean version without custom compiled CUDA extensions used in StyleGAN2. + + Args: + out_size (int): The spatial size of outputs. + num_style_feat (int): Channel number of style features. Default: 512. + num_mlp (int): Layer number of MLP style layers. Default: 8. + channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. + narrow (float): The narrow ratio for channels. Default: 1. + sft_half (bool): Whether to apply SFT on half of the input channels. Default: False. + """ + + def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1, sft_half=False): + super(StyleGAN2GeneratorCSFT, self).__init__( + out_size, + num_style_feat=num_style_feat, + num_mlp=num_mlp, + channel_multiplier=channel_multiplier, + narrow=narrow) + self.sft_half = sft_half + + def forward(self, + styles, + conditions, + input_is_latent=False, + noise=None, + randomize_noise=True, + truncation=1, + truncation_latent=None, + inject_index=None, + return_latents=False): + """Forward function for StyleGAN2GeneratorCSFT. + + Args: + styles (list[Tensor]): Sample codes of styles. + conditions (list[Tensor]): SFT conditions to generators. + input_is_latent (bool): Whether input is latent style. Default: False. + noise (Tensor | None): Input noise or None. Default: None. + randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True. + truncation (float): The truncation ratio. Default: 1. + truncation_latent (Tensor | None): The truncation latent tensor. Default: None. + inject_index (int | None): The injection index for mixing noise. Default: None. + return_latents (bool): Whether to return style latents. Default: False. 
+ """ + # style codes -> latents with Style MLP layer + if not input_is_latent: + styles = [self.style_mlp(s) for s in styles] + # noises + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers # for each style conv layer + else: # use the stored noise + noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)] + # style truncation + if truncation < 1: + style_truncation = [] + for style in styles: + style_truncation.append(truncation_latent + truncation * (style - truncation_latent)) + styles = style_truncation + # get style latents with injection + if len(styles) == 1: + inject_index = self.num_latent + + if styles[0].ndim < 3: + # repeat latent code for all the layers + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + else: # used for encoder with different latent code for each layer + latent = styles[0] + elif len(styles) == 2: # mixing noises + if inject_index is None: + inject_index = random.randint(1, self.num_latent - 1) + latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1) + latent = torch.cat([latent1, latent2], 1) + + # main generation + out = self.constant_input(latent.shape[0]) + out = self.style_conv1(out, latent[:, 0], noise=noise[0]) + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2], + noise[2::2], self.to_rgbs): + out = conv1(out, latent[:, i], noise=noise1) + + # the conditions may have fewer levels + if i < len(conditions): + # SFT part to combine the conditions + if self.sft_half: # only apply SFT to half of the channels + out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1) + out_sft = out_sft * conditions[i - 1] + conditions[i] + out = torch.cat([out_same, out_sft], dim=1) + else: # apply SFT to all the channels + out = out * conditions[i - 1] + conditions[i] + + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space + i += 2 + + image = skip + + if return_latents: + return image, latent + else: + return image, None + + +class ResBlock(nn.Module): + """Residual block with bilinear upsampling/downsampling. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + mode (str): Upsampling/downsampling mode. Options: down | up. Default: down. + """ + + def __init__(self, in_channels, out_channels, mode='down'): + super(ResBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1) + self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1) + self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False) + if mode == 'down': + self.scale_factor = 0.5 + elif mode == 'up': + self.scale_factor = 2 + + def forward(self, x): + out = F.leaky_relu_(self.conv1(x), negative_slope=0.2) + # upsample/downsample + out = F.interpolate(out, scale_factor=self.scale_factor, mode='bilinear', align_corners=False) + out = F.leaky_relu_(self.conv2(out), negative_slope=0.2) + # skip + x = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=False) + skip = self.skip(x) + out = out + skip + return out + + +@ARCH_REGISTRY.register() +class GFPGANv1Clean(nn.Module): + """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT. + + It is the clean version without custom compiled CUDA extensions used in StyleGAN2. 
+ + Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior. + + Args: + out_size (int): The spatial size of outputs. + num_style_feat (int): Channel number of style features. Default: 512. + channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. + decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None. + fix_decoder (bool): Whether to fix the decoder. Default: True. + + num_mlp (int): Layer number of MLP style layers. Default: 8. + input_is_latent (bool): Whether input is latent style. Default: False. + different_w (bool): Whether to use different latent w for different layers. Default: False. + narrow (float): The narrow ratio for channels. Default: 1. + sft_half (bool): Whether to apply SFT on half of the input channels. Default: False. + """ + + def __init__( + self, + out_size, + num_style_feat=512, + channel_multiplier=1, + decoder_load_path=None, + fix_decoder=True, + # for stylegan decoder + num_mlp=8, + input_is_latent=False, + different_w=False, + narrow=1, + sft_half=False): + + super(GFPGANv1Clean, self).__init__() + self.input_is_latent = input_is_latent + self.different_w = different_w + self.num_style_feat = num_style_feat + + unet_narrow = narrow * 0.5 # by default, use a half of input channels + channels = { + '4': int(512 * unet_narrow), + '8': int(512 * unet_narrow), + '16': int(512 * unet_narrow), + '32': int(512 * unet_narrow), + '64': int(256 * channel_multiplier * unet_narrow), + '128': int(128 * channel_multiplier * unet_narrow), + '256': int(64 * channel_multiplier * unet_narrow), + '512': int(32 * channel_multiplier * unet_narrow), + '1024': int(16 * channel_multiplier * unet_narrow) + } + + self.log_size = int(math.log(out_size, 2)) + first_out_size = 2**(int(math.log(out_size, 2))) + + self.conv_body_first = nn.Conv2d(3, channels[f'{first_out_size}'], 1) + + # downsample + in_channels = channels[f'{first_out_size}'] + self.conv_body_down = nn.ModuleList() + for i in range(self.log_size, 2, -1): + out_channels = channels[f'{2**(i - 1)}'] + self.conv_body_down.append(ResBlock(in_channels, out_channels, mode='down')) + in_channels = out_channels + + self.final_conv = nn.Conv2d(in_channels, channels['4'], 3, 1, 1) + + # upsample + in_channels = channels['4'] + self.conv_body_up = nn.ModuleList() + for i in range(3, self.log_size + 1): + out_channels = channels[f'{2**i}'] + self.conv_body_up.append(ResBlock(in_channels, out_channels, mode='up')) + in_channels = out_channels + + # to RGB + self.toRGB = nn.ModuleList() + for i in range(3, self.log_size + 1): + self.toRGB.append(nn.Conv2d(channels[f'{2**i}'], 3, 1)) + + if different_w: + linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat + else: + linear_out_channel = num_style_feat + + self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel) + + # the decoder: stylegan2 generator with SFT modulations + self.stylegan_decoder = StyleGAN2GeneratorCSFT( + out_size=out_size, + num_style_feat=num_style_feat, + num_mlp=num_mlp, + channel_multiplier=channel_multiplier, + narrow=narrow, + sft_half=sft_half) + + # load pre-trained stylegan2 model if necessary + if decoder_load_path: + self.stylegan_decoder.load_state_dict( + torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema']) + # fix decoder without updating params + if fix_decoder: + for _, param in self.stylegan_decoder.named_parameters(): + param.requires_grad = False + + # for SFT 
modulations (scale and shift) + self.condition_scale = nn.ModuleList() + self.condition_shift = nn.ModuleList() + for i in range(3, self.log_size + 1): + out_channels = channels[f'{2**i}'] + if sft_half: + sft_out_channels = out_channels + else: + sft_out_channels = out_channels * 2 + self.condition_scale.append( + nn.Sequential( + nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True), + nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1))) + self.condition_shift.append( + nn.Sequential( + nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True), + nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1))) + + def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True, **kwargs): + """Forward function for GFPGANv1Clean. + + Args: + x (Tensor): Input images. + return_latents (bool): Whether to return style latents. Default: False. + return_rgb (bool): Whether return intermediate rgb images. Default: True. + randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True. + """ + conditions = [] + unet_skips = [] + out_rgbs = [] + + # encoder + feat = F.leaky_relu_(self.conv_body_first(x), negative_slope=0.2) + for i in range(self.log_size - 2): + feat = self.conv_body_down[i](feat) + unet_skips.insert(0, feat) + feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2) + + # style code + style_code = self.final_linear(feat.view(feat.size(0), -1)) + if self.different_w: + style_code = style_code.view(style_code.size(0), -1, self.num_style_feat) + + # decode + for i in range(self.log_size - 2): + # add unet skip + feat = feat + unet_skips[i] + # ResUpLayer + feat = self.conv_body_up[i](feat) + # generate scale and shift for SFT layers + scale = self.condition_scale[i](feat) + conditions.append(scale.clone()) + shift = self.condition_shift[i](feat) + conditions.append(shift.clone()) + # generate rgb images + if return_rgb: + out_rgbs.append(self.toRGB[i](feat)) + + # decoder + image, _ = self.stylegan_decoder([style_code], + conditions, + return_latents=return_latents, + input_is_latent=self.input_is_latent, + randomize_noise=randomize_noise) + + return image, out_rgbs diff --git a/blissful_tuner/gfpgan/archs/restoreformer_arch.py b/blissful_tuner/gfpgan/archs/restoreformer_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..66cdff3e542061c27d6fdc3d32b8bb28011d95d6 --- /dev/null +++ b/blissful_tuner/gfpgan/archs/restoreformer_arch.py @@ -0,0 +1,658 @@ +"""Modified from https://github.com/wzhouxiff/RestoreFormer +""" +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class VectorQuantizer(nn.Module): + """ + see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py + ____________________________________________ + Discretization bottleneck part of the VQ-VAE. 
+ Inputs: + - n_e : number of embeddings + - e_dim : dimension of embedding + - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 + _____________________________________________ + """ + + def __init__(self, n_e, e_dim, beta): + super(VectorQuantizer, self).__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + def forward(self, z): + """ + Inputs the output of the encoder network z and maps it to a discrete + one-hot vector that is the index of the closest embedding vector e_j + z (continuous) -> z_q (discrete) + z.shape = (batch, channel, height, width) + quantization pipeline: + 1. get encoder input (B,C,H,W) + 2. flatten input to (B*H*W,C) + """ + # reshape z -> (batch, height, width, channel) and flatten + z = z.permute(0, 2, 3, 1).contiguous() + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight**2, dim=1) - 2 * \ + torch.matmul(z_flattened, self.embedding.weight.t()) + + # could possible replace this here + # #\start... + # find closest encodings + + min_value, min_encoding_indices = torch.min(d, dim=1) + + min_encoding_indices = min_encoding_indices.unsqueeze(1) + + min_encodings = torch.zeros(min_encoding_indices.shape[0], self.n_e).to(z) + min_encodings.scatter_(1, min_encoding_indices, 1) + + # dtype min encodings: torch.float32 + # min_encodings shape: torch.Size([2048, 512]) + # min_encoding_indices.shape: torch.Size([2048, 1]) + + # get quantized latent vectors + z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape) + # .........\end + + # with: + # .........\start + # min_encoding_indices = torch.argmin(d, dim=1) + # z_q = self.embedding(min_encoding_indices) + # ......\end......... 
(TODO) + + # compute loss for embedding + loss = torch.mean((z_q.detach() - z)**2) + self.beta * torch.mean((z_q - z.detach())**2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # perplexity + + e_mean = torch.mean(min_encodings, dim=0) + perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10))) + + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices, d) + + def get_codebook_entry(self, indices, shape): + # shape specifying (batch, height, width, channel) + # TODO: check for more easy handling with nn.Embedding + min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices) + min_encodings.scatter_(1, indices[:, None], 1) + + # get quantized latent vectors + z_q = torch.matmul(min_encodings.float(), self.embedding.weight) + + if shape is not None: + z_q = z_q.view(shape) + + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + + +# pytorch_diffusion + derived encoder decoder +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode='nearest') + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode='constant', value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != 
self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class MultiHeadAttnBlock(nn.Module): + + def __init__(self, in_channels, head_size=1): + super().__init__() + self.in_channels = in_channels + self.head_size = head_size + self.att_size = in_channels // head_size + assert (in_channels % head_size == 0), 'The size of head should be divided by the number of channels.' + + self.norm1 = Normalize(in_channels) + self.norm2 = Normalize(in_channels) + + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.num = 0 + + def forward(self, x, y=None): + h_ = x + h_ = self.norm1(h_) + if y is None: + y = h_ + else: + y = self.norm2(y) + + q = self.q(y) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, self.head_size, self.att_size, h * w) + q = q.permute(0, 3, 1, 2) # b, hw, head, att + + k = k.reshape(b, self.head_size, self.att_size, h * w) + k = k.permute(0, 3, 1, 2) + + v = v.reshape(b, self.head_size, self.att_size, h * w) + v = v.permute(0, 3, 1, 2) + + q = q.transpose(1, 2) + v = v.transpose(1, 2) + k = k.transpose(1, 2).transpose(2, 3) + + scale = int(self.att_size)**(-0.5) + q.mul_(scale) + w_ = torch.matmul(q, k) + w_ = F.softmax(w_, dim=3) + + w_ = w_.matmul(v) + + w_ = w_.transpose(1, 2).contiguous() # [b, h*w, head, att] + w_ = w_.view(b, h, w, -1) + w_ = w_.permute(0, 3, 1, 2) + + w_ = self.proj_out(w_) + + return x + w_ + + +class MultiHeadEncoder(nn.Module): + + def __init__(self, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks=2, + attn_resolutions=(16, ), + dropout=0.0, + resamp_with_conv=True, + in_channels=3, + resolution=512, + z_channels=256, + double_z=True, + enable_mid=True, + head_size=1, + **ignore_kwargs): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.enable_mid = enable_mid + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1, ) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(MultiHeadAttnBlock(block_in, head_size)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + if self.enable_mid: + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) + self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, 
temb_channels=self.temb_ch, dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + hs = {} + # timestep embedding + temb = None + + # downsampling + h = self.conv_in(x) + hs['in'] = h + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](h, temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + + if i_level != self.num_resolutions - 1: + # hs.append(h) + hs['block_' + str(i_level)] = h + h = self.down[i_level].downsample(h) + + # middle + # h = hs[-1] + if self.enable_mid: + h = self.mid.block_1(h, temb) + hs['block_' + str(i_level) + '_atten'] = h + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + hs['mid_atten'] = h + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + # hs.append(h) + hs['out'] = h + + return hs + + +class MultiHeadDecoder(nn.Module): + + def __init__(self, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks=2, + attn_resolutions=(16, ), + dropout=0.0, + resamp_with_conv=True, + in_channels=3, + resolution=512, + z_channels=256, + give_pre_end=False, + enable_mid=True, + head_size=1, + **ignorekwargs): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.enable_mid = enable_mid + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2**(self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + print('Working with z of shape {} = {} dimensions.'.format(self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + if self.enable_mid: + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) + self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(MultiHeadAttnBlock(block_in, head_size)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z): + # assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + if self.enable_mid: + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # 
upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class MultiHeadDecoderTransformer(nn.Module): + + def __init__(self, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks=2, + attn_resolutions=(16, ), + dropout=0.0, + resamp_with_conv=True, + in_channels=3, + resolution=512, + z_channels=256, + give_pre_end=False, + enable_mid=True, + head_size=1, + **ignorekwargs): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.enable_mid = enable_mid + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2**(self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + print('Working with z of shape {} = {} dimensions.'.format(self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + if self.enable_mid: + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) + self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(MultiHeadAttnBlock(block_in, head_size)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z, hs): + # assert z.shape[1:] == self.z_shape[1:] + # self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + if self.enable_mid: + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h, hs['mid_atten']) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h, hs['block_' + str(i_level) + '_atten']) + # hfeature = h.clone() + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class RestoreFormer(nn.Module): + + def __init__(self, + n_embed=1024, + embed_dim=256, + 
ch=64, + out_ch=3, + ch_mult=(1, 2, 2, 4, 4, 8), + num_res_blocks=2, + attn_resolutions=(16, ), + dropout=0.0, + in_channels=3, + resolution=512, + z_channels=256, + double_z=False, + enable_mid=True, + fix_decoder=False, + fix_codebook=True, + fix_encoder=False, + head_size=8): + super(RestoreFormer, self).__init__() + + self.encoder = MultiHeadEncoder( + ch=ch, + out_ch=out_ch, + ch_mult=ch_mult, + num_res_blocks=num_res_blocks, + attn_resolutions=attn_resolutions, + dropout=dropout, + in_channels=in_channels, + resolution=resolution, + z_channels=z_channels, + double_z=double_z, + enable_mid=enable_mid, + head_size=head_size) + self.decoder = MultiHeadDecoderTransformer( + ch=ch, + out_ch=out_ch, + ch_mult=ch_mult, + num_res_blocks=num_res_blocks, + attn_resolutions=attn_resolutions, + dropout=dropout, + in_channels=in_channels, + resolution=resolution, + z_channels=z_channels, + enable_mid=enable_mid, + head_size=head_size) + + self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25) + + self.quant_conv = torch.nn.Conv2d(z_channels, embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1) + + if fix_decoder: + for _, param in self.decoder.named_parameters(): + param.requires_grad = False + for _, param in self.post_quant_conv.named_parameters(): + param.requires_grad = False + for _, param in self.quantize.named_parameters(): + param.requires_grad = False + elif fix_codebook: + for _, param in self.quantize.named_parameters(): + param.requires_grad = False + + if fix_encoder: + for _, param in self.encoder.named_parameters(): + param.requires_grad = False + + def encode(self, x): + + hs = self.encoder(x) + h = self.quant_conv(hs['out']) + quant, emb_loss, info = self.quantize(h) + return quant, emb_loss, info, hs + + def decode(self, quant, hs): + quant = self.post_quant_conv(quant) + dec = self.decoder(quant, hs) + + return dec + + def forward(self, input, **kwargs): + quant, diff, info, hs = self.encode(input) + dec = self.decode(quant, hs) + + return dec, None diff --git a/blissful_tuner/gfpgan/archs/stylegan2_bilinear_arch.py b/blissful_tuner/gfpgan/archs/stylegan2_bilinear_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..1342ee3c9a6b8f742fb76ce7d5b907cd39fbc350 --- /dev/null +++ b/blissful_tuner/gfpgan/archs/stylegan2_bilinear_arch.py @@ -0,0 +1,613 @@ +import math +import random +import torch +from basicsr.ops.fused_act import FusedLeakyReLU, fused_leaky_relu +from basicsr.utils.registry import ARCH_REGISTRY +from torch import nn +from torch.nn import functional as F + + +class NormStyleCode(nn.Module): + + def forward(self, x): + """Normalize the style codes. + + Args: + x (Tensor): Style codes with shape (b, c). + + Returns: + Tensor: Normalized tensor. + """ + return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8) + + +class EqualLinear(nn.Module): + """Equalized Linear as StyleGAN2. + + Args: + in_channels (int): Size of each sample. + out_channels (int): Size of each output sample. + bias (bool): If set to ``False``, the layer will not learn an additive + bias. Default: ``True``. + bias_init_val (float): Bias initialized value. Default: 0. + lr_mul (float): Learning rate multiplier. Default: 1. + activation (None | str): The activation after ``linear`` operation. + Supported: 'fused_lrelu', None. Default: None. 
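+
+    Example:
+        Illustrative sketch (shapes and settings are placeholders)::
+
+            linear = EqualLinear(512, 512, lr_mul=0.01, activation='fused_lrelu')
+            w = torch.randn(8, 512)    # style inputs
+            out = linear(w)            # (8, 512); the stored weight is rescaled by
+                                       # (1 / sqrt(512)) * lr_mul inside forward()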
+ """ + + def __init__(self, in_channels, out_channels, bias=True, bias_init_val=0, lr_mul=1, activation=None): + super(EqualLinear, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.lr_mul = lr_mul + self.activation = activation + if self.activation not in ['fused_lrelu', None]: + raise ValueError(f'Wrong activation value in EqualLinear: {activation}' + "Supported ones are: ['fused_lrelu', None].") + self.scale = (1 / math.sqrt(in_channels)) * lr_mul + + self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val)) + else: + self.register_parameter('bias', None) + + def forward(self, x): + if self.bias is None: + bias = None + else: + bias = self.bias * self.lr_mul + if self.activation == 'fused_lrelu': + out = F.linear(x, self.weight * self.scale) + out = fused_leaky_relu(out, bias) + else: + out = F.linear(x, self.weight * self.scale, bias=bias) + return out + + def __repr__(self): + return (f'{self.__class__.__name__}(in_channels={self.in_channels}, ' + f'out_channels={self.out_channels}, bias={self.bias is not None})') + + +class ModulatedConv2d(nn.Module): + """Modulated Conv2d used in StyleGAN2. + + There is no bias in ModulatedConv2d. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Size of the convolving kernel. + num_style_feat (int): Channel number of style features. + demodulate (bool): Whether to demodulate in the conv layer. + Default: True. + sample_mode (str | None): Indicating 'upsample', 'downsample' or None. + Default: None. + eps (float): A value added to the denominator for numerical stability. + Default: 1e-8. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + num_style_feat, + demodulate=True, + sample_mode=None, + eps=1e-8, + interpolation_mode='bilinear'): + super(ModulatedConv2d, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.demodulate = demodulate + self.sample_mode = sample_mode + self.eps = eps + self.interpolation_mode = interpolation_mode + if self.interpolation_mode == 'nearest': + self.align_corners = None + else: + self.align_corners = False + + self.scale = 1 / math.sqrt(in_channels * kernel_size**2) + # modulation inside each modulated conv + self.modulation = EqualLinear( + num_style_feat, in_channels, bias=True, bias_init_val=1, lr_mul=1, activation=None) + + self.weight = nn.Parameter(torch.randn(1, out_channels, in_channels, kernel_size, kernel_size)) + self.padding = kernel_size // 2 + + def forward(self, x, style): + """Forward function. + + Args: + x (Tensor): Tensor with shape (b, c, h, w). + style (Tensor): Tensor with shape (b, num_style_feat). + + Returns: + Tensor: Modulated tensor after convolution. 
+ """ + b, c, h, w = x.shape # c = c_in + # weight modulation + style = self.modulation(style).view(b, 1, c, 1, 1) + # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1) + weight = self.scale * self.weight * style # (b, c_out, c_in, k, k) + + if self.demodulate: + demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps) + weight = weight * demod.view(b, self.out_channels, 1, 1, 1) + + weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size) + + if self.sample_mode == 'upsample': + x = F.interpolate(x, scale_factor=2, mode=self.interpolation_mode, align_corners=self.align_corners) + elif self.sample_mode == 'downsample': + x = F.interpolate(x, scale_factor=0.5, mode=self.interpolation_mode, align_corners=self.align_corners) + + b, c, h, w = x.shape + x = x.view(1, b * c, h, w) + # weight: (b*c_out, c_in, k, k), groups=b + out = F.conv2d(x, weight, padding=self.padding, groups=b) + out = out.view(b, self.out_channels, *out.shape[2:4]) + + return out + + def __repr__(self): + return (f'{self.__class__.__name__}(in_channels={self.in_channels}, ' + f'out_channels={self.out_channels}, ' + f'kernel_size={self.kernel_size}, ' + f'demodulate={self.demodulate}, sample_mode={self.sample_mode})') + + +class StyleConv(nn.Module): + """Style conv. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Size of the convolving kernel. + num_style_feat (int): Channel number of style features. + demodulate (bool): Whether demodulate in the conv layer. Default: True. + sample_mode (str | None): Indicating 'upsample', 'downsample' or None. + Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + num_style_feat, + demodulate=True, + sample_mode=None, + interpolation_mode='bilinear'): + super(StyleConv, self).__init__() + self.modulated_conv = ModulatedConv2d( + in_channels, + out_channels, + kernel_size, + num_style_feat, + demodulate=demodulate, + sample_mode=sample_mode, + interpolation_mode=interpolation_mode) + self.weight = nn.Parameter(torch.zeros(1)) # for noise injection + self.activate = FusedLeakyReLU(out_channels) + + def forward(self, x, style, noise=None): + # modulate + out = self.modulated_conv(x, style) + # noise injection + if noise is None: + b, _, h, w = out.shape + noise = out.new_empty(b, 1, h, w).normal_() + out = out + self.weight * noise + # activation (with bias) + out = self.activate(out) + return out + + +class ToRGB(nn.Module): + """To RGB from features. + + Args: + in_channels (int): Channel number of input. + num_style_feat (int): Channel number of style features. + upsample (bool): Whether to upsample. Default: True. + """ + + def __init__(self, in_channels, num_style_feat, upsample=True, interpolation_mode='bilinear'): + super(ToRGB, self).__init__() + self.upsample = upsample + self.interpolation_mode = interpolation_mode + if self.interpolation_mode == 'nearest': + self.align_corners = None + else: + self.align_corners = False + self.modulated_conv = ModulatedConv2d( + in_channels, + 3, + kernel_size=1, + num_style_feat=num_style_feat, + demodulate=False, + sample_mode=None, + interpolation_mode=interpolation_mode) + self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) + + def forward(self, x, style, skip=None): + """Forward function. + + Args: + x (Tensor): Feature tensor with shape (b, c, h, w). + style (Tensor): Tensor with shape (b, num_style_feat). + skip (Tensor): Base/skip tensor. Default: None. 
+ + Returns: + Tensor: RGB images. + """ + out = self.modulated_conv(x, style) + out = out + self.bias + if skip is not None: + if self.upsample: + skip = F.interpolate( + skip, scale_factor=2, mode=self.interpolation_mode, align_corners=self.align_corners) + out = out + skip + return out + + +class ConstantInput(nn.Module): + """Constant input. + + Args: + num_channel (int): Channel number of constant input. + size (int): Spatial size of constant input. + """ + + def __init__(self, num_channel, size): + super(ConstantInput, self).__init__() + self.weight = nn.Parameter(torch.randn(1, num_channel, size, size)) + + def forward(self, batch): + out = self.weight.repeat(batch, 1, 1, 1) + return out + + +@ARCH_REGISTRY.register() +class StyleGAN2GeneratorBilinear(nn.Module): + """StyleGAN2 Generator. + + Args: + out_size (int): The spatial size of outputs. + num_style_feat (int): Channel number of style features. Default: 512. + num_mlp (int): Layer number of MLP style layers. Default: 8. + channel_multiplier (int): Channel multiplier for large networks of + StyleGAN2. Default: 2. + lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01. + narrow (float): Narrow ratio for channels. Default: 1.0. + """ + + def __init__(self, + out_size, + num_style_feat=512, + num_mlp=8, + channel_multiplier=2, + lr_mlp=0.01, + narrow=1, + interpolation_mode='bilinear'): + super(StyleGAN2GeneratorBilinear, self).__init__() + # Style MLP layers + self.num_style_feat = num_style_feat + style_mlp_layers = [NormStyleCode()] + for i in range(num_mlp): + style_mlp_layers.append( + EqualLinear( + num_style_feat, num_style_feat, bias=True, bias_init_val=0, lr_mul=lr_mlp, + activation='fused_lrelu')) + self.style_mlp = nn.Sequential(*style_mlp_layers) + + channels = { + '4': int(512 * narrow), + '8': int(512 * narrow), + '16': int(512 * narrow), + '32': int(512 * narrow), + '64': int(256 * channel_multiplier * narrow), + '128': int(128 * channel_multiplier * narrow), + '256': int(64 * channel_multiplier * narrow), + '512': int(32 * channel_multiplier * narrow), + '1024': int(16 * channel_multiplier * narrow) + } + self.channels = channels + + self.constant_input = ConstantInput(channels['4'], size=4) + self.style_conv1 = StyleConv( + channels['4'], + channels['4'], + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode=None, + interpolation_mode=interpolation_mode) + self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False, interpolation_mode=interpolation_mode) + + self.log_size = int(math.log(out_size, 2)) + self.num_layers = (self.log_size - 2) * 2 + 1 + self.num_latent = self.log_size * 2 - 2 + + self.style_convs = nn.ModuleList() + self.to_rgbs = nn.ModuleList() + self.noises = nn.Module() + + in_channels = channels['4'] + # noise + for layer_idx in range(self.num_layers): + resolution = 2**((layer_idx + 5) // 2) + shape = [1, 1, resolution, resolution] + self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape)) + # style convs and to_rgbs + for i in range(3, self.log_size + 1): + out_channels = channels[f'{2**i}'] + self.style_convs.append( + StyleConv( + in_channels, + out_channels, + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode='upsample', + interpolation_mode=interpolation_mode)) + self.style_convs.append( + StyleConv( + out_channels, + out_channels, + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode=None, + interpolation_mode=interpolation_mode)) + self.to_rgbs.append( + 
ToRGB(out_channels, num_style_feat, upsample=True, interpolation_mode=interpolation_mode)) + in_channels = out_channels + + def make_noise(self): + """Make noise for noise injection.""" + device = self.constant_input.weight.device + noises = [torch.randn(1, 1, 4, 4, device=device)] + + for i in range(3, self.log_size + 1): + for _ in range(2): + noises.append(torch.randn(1, 1, 2**i, 2**i, device=device)) + + return noises + + def get_latent(self, x): + return self.style_mlp(x) + + def mean_latent(self, num_latent): + latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device) + latent = self.style_mlp(latent_in).mean(0, keepdim=True) + return latent + + def forward(self, + styles, + input_is_latent=False, + noise=None, + randomize_noise=True, + truncation=1, + truncation_latent=None, + inject_index=None, + return_latents=False): + """Forward function for StyleGAN2Generator. + + Args: + styles (list[Tensor]): Sample codes of styles. + input_is_latent (bool): Whether input is latent style. + Default: False. + noise (Tensor | None): Input noise or None. Default: None. + randomize_noise (bool): Randomize noise, used when 'noise' is + False. Default: True. + truncation (float): TODO. Default: 1. + truncation_latent (Tensor | None): TODO. Default: None. + inject_index (int | None): The injection index for mixing noise. + Default: None. + return_latents (bool): Whether to return style latents. + Default: False. + """ + # style codes -> latents with Style MLP layer + if not input_is_latent: + styles = [self.style_mlp(s) for s in styles] + # noises + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers # for each style conv layer + else: # use the stored noise + noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)] + # style truncation + if truncation < 1: + style_truncation = [] + for style in styles: + style_truncation.append(truncation_latent + truncation * (style - truncation_latent)) + styles = style_truncation + # get style latent with injection + if len(styles) == 1: + inject_index = self.num_latent + + if styles[0].ndim < 3: + # repeat latent code for all the layers + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + else: # used for encoder with different latent code for each layer + latent = styles[0] + elif len(styles) == 2: # mixing noises + if inject_index is None: + inject_index = random.randint(1, self.num_latent - 1) + latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1) + latent = torch.cat([latent1, latent2], 1) + + # main generation + out = self.constant_input(latent.shape[0]) + out = self.style_conv1(out, latent[:, 0], noise=noise[0]) + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2], + noise[2::2], self.to_rgbs): + out = conv1(out, latent[:, i], noise=noise1) + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) + i += 2 + + image = skip + + if return_latents: + return image, latent + else: + return image, None + + +class ScaledLeakyReLU(nn.Module): + """Scaled LeakyReLU. + + Args: + negative_slope (float): Negative slope. Default: 0.2. 
+ """ + + def __init__(self, negative_slope=0.2): + super(ScaledLeakyReLU, self).__init__() + self.negative_slope = negative_slope + + def forward(self, x): + out = F.leaky_relu(x, negative_slope=self.negative_slope) + return out * math.sqrt(2) + + +class EqualConv2d(nn.Module): + """Equalized Linear as StyleGAN2. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Size of the convolving kernel. + stride (int): Stride of the convolution. Default: 1 + padding (int): Zero-padding added to both sides of the input. + Default: 0. + bias (bool): If ``True``, adds a learnable bias to the output. + Default: ``True``. + bias_init_val (float): Bias initialized value. Default: 0. + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, bias_init_val=0): + super(EqualConv2d, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.scale = 1 / math.sqrt(in_channels * kernel_size**2) + + self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val)) + else: + self.register_parameter('bias', None) + + def forward(self, x): + out = F.conv2d( + x, + self.weight * self.scale, + bias=self.bias, + stride=self.stride, + padding=self.padding, + ) + + return out + + def __repr__(self): + return (f'{self.__class__.__name__}(in_channels={self.in_channels}, ' + f'out_channels={self.out_channels}, ' + f'kernel_size={self.kernel_size},' + f' stride={self.stride}, padding={self.padding}, ' + f'bias={self.bias is not None})') + + +class ConvLayer(nn.Sequential): + """Conv Layer used in StyleGAN2 Discriminator. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Kernel size. + downsample (bool): Whether downsample by a factor of 2. + Default: False. + bias (bool): Whether with bias. Default: True. + activate (bool): Whether use activateion. Default: True. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + downsample=False, + bias=True, + activate=True, + interpolation_mode='bilinear'): + layers = [] + self.interpolation_mode = interpolation_mode + # downsample + if downsample: + if self.interpolation_mode == 'nearest': + self.align_corners = None + else: + self.align_corners = False + + layers.append( + torch.nn.Upsample(scale_factor=0.5, mode=interpolation_mode, align_corners=self.align_corners)) + stride = 1 + self.padding = kernel_size // 2 + # conv + layers.append( + EqualConv2d( + in_channels, out_channels, kernel_size, stride=stride, padding=self.padding, bias=bias + and not activate)) + # activation + if activate: + if bias: + layers.append(FusedLeakyReLU(out_channels)) + else: + layers.append(ScaledLeakyReLU(0.2)) + + super(ConvLayer, self).__init__(*layers) + + +class ResBlock(nn.Module): + """Residual block used in StyleGAN2 Discriminator. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. 
+ """ + + def __init__(self, in_channels, out_channels, interpolation_mode='bilinear'): + super(ResBlock, self).__init__() + + self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True) + self.conv2 = ConvLayer( + in_channels, + out_channels, + 3, + downsample=True, + interpolation_mode=interpolation_mode, + bias=True, + activate=True) + self.skip = ConvLayer( + in_channels, + out_channels, + 1, + downsample=True, + interpolation_mode=interpolation_mode, + bias=False, + activate=False) + + def forward(self, x): + out = self.conv1(x) + out = self.conv2(out) + skip = self.skip(x) + out = (out + skip) / math.sqrt(2) + return out diff --git a/blissful_tuner/gfpgan/archs/stylegan2_clean_arch.py b/blissful_tuner/gfpgan/archs/stylegan2_clean_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..9e2ee94e50401b95e4c9997adef5581d521d725f --- /dev/null +++ b/blissful_tuner/gfpgan/archs/stylegan2_clean_arch.py @@ -0,0 +1,368 @@ +import math +import random +import torch +from basicsr.archs.arch_util import default_init_weights +from basicsr.utils.registry import ARCH_REGISTRY +from torch import nn +from torch.nn import functional as F + + +class NormStyleCode(nn.Module): + + def forward(self, x): + """Normalize the style codes. + + Args: + x (Tensor): Style codes with shape (b, c). + + Returns: + Tensor: Normalized tensor. + """ + return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8) + + +class ModulatedConv2d(nn.Module): + """Modulated Conv2d used in StyleGAN2. + + There is no bias in ModulatedConv2d. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Size of the convolving kernel. + num_style_feat (int): Channel number of style features. + demodulate (bool): Whether to demodulate in the conv layer. Default: True. + sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None. + eps (float): A value added to the denominator for numerical stability. Default: 1e-8. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + num_style_feat, + demodulate=True, + sample_mode=None, + eps=1e-8): + super(ModulatedConv2d, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.demodulate = demodulate + self.sample_mode = sample_mode + self.eps = eps + + # modulation inside each modulated conv + self.modulation = nn.Linear(num_style_feat, in_channels, bias=True) + # initialization + default_init_weights(self.modulation, scale=1, bias_fill=1, a=0, mode='fan_in', nonlinearity='linear') + + self.weight = nn.Parameter( + torch.randn(1, out_channels, in_channels, kernel_size, kernel_size) / + math.sqrt(in_channels * kernel_size**2)) + self.padding = kernel_size // 2 + + def forward(self, x, style): + """Forward function. + + Args: + x (Tensor): Tensor with shape (b, c, h, w). + style (Tensor): Tensor with shape (b, num_style_feat). + + Returns: + Tensor: Modulated tensor after convolution. 
+ """ + b, c, h, w = x.shape # c = c_in + # weight modulation + style = self.modulation(style).view(b, 1, c, 1, 1) + # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1) + weight = self.weight * style # (b, c_out, c_in, k, k) + + if self.demodulate: + demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps) + weight = weight * demod.view(b, self.out_channels, 1, 1, 1) + + weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size) + + # upsample or downsample if necessary + if self.sample_mode == 'upsample': + x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) + elif self.sample_mode == 'downsample': + x = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False) + + b, c, h, w = x.shape + x = x.view(1, b * c, h, w) + # weight: (b*c_out, c_in, k, k), groups=b + out = F.conv2d(x, weight, padding=self.padding, groups=b) + out = out.view(b, self.out_channels, *out.shape[2:4]) + + return out + + def __repr__(self): + return (f'{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, ' + f'kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})') + + +class StyleConv(nn.Module): + """Style conv used in StyleGAN2. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Size of the convolving kernel. + num_style_feat (int): Channel number of style features. + demodulate (bool): Whether demodulate in the conv layer. Default: True. + sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None. + """ + + def __init__(self, in_channels, out_channels, kernel_size, num_style_feat, demodulate=True, sample_mode=None): + super(StyleConv, self).__init__() + self.modulated_conv = ModulatedConv2d( + in_channels, out_channels, kernel_size, num_style_feat, demodulate=demodulate, sample_mode=sample_mode) + self.weight = nn.Parameter(torch.zeros(1)) # for noise injection + self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1)) + self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x, style, noise=None): + # modulate + out = self.modulated_conv(x, style) * 2**0.5 # for conversion + # noise injection + if noise is None: + b, _, h, w = out.shape + noise = out.new_empty(b, 1, h, w).normal_() + out = out + self.weight * noise + # add bias + out = out + self.bias + # activation + out = self.activate(out) + return out + + +class ToRGB(nn.Module): + """To RGB (image space) from features. + + Args: + in_channels (int): Channel number of input. + num_style_feat (int): Channel number of style features. + upsample (bool): Whether to upsample. Default: True. + """ + + def __init__(self, in_channels, num_style_feat, upsample=True): + super(ToRGB, self).__init__() + self.upsample = upsample + self.modulated_conv = ModulatedConv2d( + in_channels, 3, kernel_size=1, num_style_feat=num_style_feat, demodulate=False, sample_mode=None) + self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) + + def forward(self, x, style, skip=None): + """Forward function. + + Args: + x (Tensor): Feature tensor with shape (b, c, h, w). + style (Tensor): Tensor with shape (b, num_style_feat). + skip (Tensor): Base/skip tensor. Default: None. + + Returns: + Tensor: RGB images. 
+ """ + out = self.modulated_conv(x, style) + out = out + self.bias + if skip is not None: + if self.upsample: + skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False) + out = out + skip + return out + + +class ConstantInput(nn.Module): + """Constant input. + + Args: + num_channel (int): Channel number of constant input. + size (int): Spatial size of constant input. + """ + + def __init__(self, num_channel, size): + super(ConstantInput, self).__init__() + self.weight = nn.Parameter(torch.randn(1, num_channel, size, size)) + + def forward(self, batch): + out = self.weight.repeat(batch, 1, 1, 1) + return out + + +@ARCH_REGISTRY.register() +class StyleGAN2GeneratorClean(nn.Module): + """Clean version of StyleGAN2 Generator. + + Args: + out_size (int): The spatial size of outputs. + num_style_feat (int): Channel number of style features. Default: 512. + num_mlp (int): Layer number of MLP style layers. Default: 8. + channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. + narrow (float): Narrow ratio for channels. Default: 1.0. + """ + + def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1): + super(StyleGAN2GeneratorClean, self).__init__() + # Style MLP layers + self.num_style_feat = num_style_feat + style_mlp_layers = [NormStyleCode()] + for i in range(num_mlp): + style_mlp_layers.extend( + [nn.Linear(num_style_feat, num_style_feat, bias=True), + nn.LeakyReLU(negative_slope=0.2, inplace=True)]) + self.style_mlp = nn.Sequential(*style_mlp_layers) + # initialization + default_init_weights(self.style_mlp, scale=1, bias_fill=0, a=0.2, mode='fan_in', nonlinearity='leaky_relu') + + # channel list + channels = { + '4': int(512 * narrow), + '8': int(512 * narrow), + '16': int(512 * narrow), + '32': int(512 * narrow), + '64': int(256 * channel_multiplier * narrow), + '128': int(128 * channel_multiplier * narrow), + '256': int(64 * channel_multiplier * narrow), + '512': int(32 * channel_multiplier * narrow), + '1024': int(16 * channel_multiplier * narrow) + } + self.channels = channels + + self.constant_input = ConstantInput(channels['4'], size=4) + self.style_conv1 = StyleConv( + channels['4'], + channels['4'], + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode=None) + self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False) + + self.log_size = int(math.log(out_size, 2)) + self.num_layers = (self.log_size - 2) * 2 + 1 + self.num_latent = self.log_size * 2 - 2 + + self.style_convs = nn.ModuleList() + self.to_rgbs = nn.ModuleList() + self.noises = nn.Module() + + in_channels = channels['4'] + # noise + for layer_idx in range(self.num_layers): + resolution = 2**((layer_idx + 5) // 2) + shape = [1, 1, resolution, resolution] + self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape)) + # style convs and to_rgbs + for i in range(3, self.log_size + 1): + out_channels = channels[f'{2**i}'] + self.style_convs.append( + StyleConv( + in_channels, + out_channels, + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode='upsample')) + self.style_convs.append( + StyleConv( + out_channels, + out_channels, + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode=None)) + self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True)) + in_channels = out_channels + + def make_noise(self): + """Make noise for noise injection.""" + device = self.constant_input.weight.device + noises = [torch.randn(1, 1, 4, 4, 
device=device)] + + for i in range(3, self.log_size + 1): + for _ in range(2): + noises.append(torch.randn(1, 1, 2**i, 2**i, device=device)) + + return noises + + def get_latent(self, x): + return self.style_mlp(x) + + def mean_latent(self, num_latent): + latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device) + latent = self.style_mlp(latent_in).mean(0, keepdim=True) + return latent + + def forward(self, + styles, + input_is_latent=False, + noise=None, + randomize_noise=True, + truncation=1, + truncation_latent=None, + inject_index=None, + return_latents=False): + """Forward function for StyleGAN2GeneratorClean. + + Args: + styles (list[Tensor]): Sample codes of styles. + input_is_latent (bool): Whether input is latent style. Default: False. + noise (Tensor | None): Input noise or None. Default: None. + randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True. + truncation (float): The truncation ratio. Default: 1. + truncation_latent (Tensor | None): The truncation latent tensor. Default: None. + inject_index (int | None): The injection index for mixing noise. Default: None. + return_latents (bool): Whether to return style latents. Default: False. + """ + # style codes -> latents with Style MLP layer + if not input_is_latent: + styles = [self.style_mlp(s) for s in styles] + # noises + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers # for each style conv layer + else: # use the stored noise + noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)] + # style truncation + if truncation < 1: + style_truncation = [] + for style in styles: + style_truncation.append(truncation_latent + truncation * (style - truncation_latent)) + styles = style_truncation + # get style latents with injection + if len(styles) == 1: + inject_index = self.num_latent + + if styles[0].ndim < 3: + # repeat latent code for all the layers + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + else: # used for encoder with different latent code for each layer + latent = styles[0] + elif len(styles) == 2: # mixing noises + if inject_index is None: + inject_index = random.randint(1, self.num_latent - 1) + latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1) + latent = torch.cat([latent1, latent2], 1) + + # main generation + out = self.constant_input(latent.shape[0]) + out = self.style_conv1(out, latent[:, 0], noise=noise[0]) + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2], + noise[2::2], self.to_rgbs): + out = conv1(out, latent[:, i], noise=noise1) + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space + i += 2 + + image = skip + + if return_latents: + return image, latent + else: + return image, None diff --git a/blissful_tuner/gfpgan/data/__init__.py b/blissful_tuner/gfpgan/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..69fd9f9026407c4d185f86b122000485b06fd986 --- /dev/null +++ b/blissful_tuner/gfpgan/data/__init__.py @@ -0,0 +1,10 @@ +import importlib +from basicsr.utils import scandir +from os import path as osp + +# automatically scan and import dataset modules for registry +# scan all the files that end with '_dataset.py' under the data folder +data_folder = 
osp.dirname(osp.abspath(__file__)) +dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] +# import all the dataset modules +_dataset_modules = [importlib.import_module(f'gfpgan.data.{file_name}') for file_name in dataset_filenames] diff --git a/blissful_tuner/gfpgan/data/ffhq_degradation_dataset.py b/blissful_tuner/gfpgan/data/ffhq_degradation_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..64e5755e1211f171cb2a883d47e8d253061f90aa --- /dev/null +++ b/blissful_tuner/gfpgan/data/ffhq_degradation_dataset.py @@ -0,0 +1,230 @@ +import cv2 +import math +import numpy as np +import os.path as osp +import torch +import torch.utils.data as data +from basicsr.data import degradations as degradations +from basicsr.data.data_util import paths_from_folder +from basicsr.data.transforms import augment +from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor +from basicsr.utils.registry import DATASET_REGISTRY +from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation, + normalize) + + +@DATASET_REGISTRY.register() +class FFHQDegradationDataset(data.Dataset): + """FFHQ dataset for GFPGAN. + + It reads high resolution images, and then generate low-quality (LQ) images on-the-fly. + + Args: + opt (dict): Config for train datasets. It contains the following keys: + dataroot_gt (str): Data root path for gt. + io_backend (dict): IO backend type and other kwarg. + mean (list | tuple): Image mean. + std (list | tuple): Image std. + use_hflip (bool): Whether to horizontally flip. + Please see more options in the codes. + """ + + def __init__(self, opt): + super(FFHQDegradationDataset, self).__init__() + self.opt = opt + # file client (io backend) + self.file_client = None + self.io_backend_opt = opt['io_backend'] + + self.gt_folder = opt['dataroot_gt'] + self.mean = opt['mean'] + self.std = opt['std'] + self.out_size = opt['out_size'] + + self.crop_components = opt.get('crop_components', False) # facial components + self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1) # whether enlarge eye regions + + if self.crop_components: + # load component list from a pre-process pth files + self.components_list = torch.load(opt.get('component_path')) + + # file client (lmdb io backend) + if self.io_backend_opt['type'] == 'lmdb': + self.io_backend_opt['db_paths'] = self.gt_folder + if not self.gt_folder.endswith('.lmdb'): + raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}") + with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin: + self.paths = [line.split('.')[0] for line in fin] + else: + # disk backend: scan file list from a folder + self.paths = paths_from_folder(self.gt_folder) + + # degradation configurations + self.blur_kernel_size = opt['blur_kernel_size'] + self.kernel_list = opt['kernel_list'] + self.kernel_prob = opt['kernel_prob'] + self.blur_sigma = opt['blur_sigma'] + self.downsample_range = opt['downsample_range'] + self.noise_range = opt['noise_range'] + self.jpeg_range = opt['jpeg_range'] + + # color jitter + self.color_jitter_prob = opt.get('color_jitter_prob') + self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob') + self.color_jitter_shift = opt.get('color_jitter_shift', 20) + # to gray + self.gray_prob = opt.get('gray_prob') + + logger = get_root_logger() + logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]') + 
logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]') + logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]') + logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]') + + if self.color_jitter_prob is not None: + logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}') + if self.gray_prob is not None: + logger.info(f'Use random gray. Prob: {self.gray_prob}') + self.color_jitter_shift /= 255. + + @staticmethod + def color_jitter(img, shift): + """jitter color: randomly jitter the RGB values, in numpy formats""" + jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32) + img = img + jitter_val + img = np.clip(img, 0, 1) + return img + + @staticmethod + def color_jitter_pt(img, brightness, contrast, saturation, hue): + """jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats""" + fn_idx = torch.randperm(4) + for fn_id in fn_idx: + if fn_id == 0 and brightness is not None: + brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item() + img = adjust_brightness(img, brightness_factor) + + if fn_id == 1 and contrast is not None: + contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item() + img = adjust_contrast(img, contrast_factor) + + if fn_id == 2 and saturation is not None: + saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item() + img = adjust_saturation(img, saturation_factor) + + if fn_id == 3 and hue is not None: + hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item() + img = adjust_hue(img, hue_factor) + return img + + def get_component_coordinates(self, index, status): + """Get facial component (left_eye, right_eye, mouth) coordinates from a pre-loaded pth file""" + components_bbox = self.components_list[f'{index:08d}'] + if status[0]: # hflip + # exchange right and left eye + tmp = components_bbox['left_eye'] + components_bbox['left_eye'] = components_bbox['right_eye'] + components_bbox['right_eye'] = tmp + # modify the width coordinate + components_bbox['left_eye'][0] = self.out_size - components_bbox['left_eye'][0] + components_bbox['right_eye'][0] = self.out_size - components_bbox['right_eye'][0] + components_bbox['mouth'][0] = self.out_size - components_bbox['mouth'][0] + + # get coordinates + locations = [] + for part in ['left_eye', 'right_eye', 'mouth']: + mean = components_bbox[part][0:2] + half_len = components_bbox[part][2] + if 'eye' in part: + half_len *= self.eye_enlarge_ratio + loc = np.hstack((mean - half_len + 1, mean + half_len)) + loc = torch.from_numpy(loc).float() + locations.append(loc) + return locations + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) + + # load gt image + # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32. 
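+        # The LQ input is synthesized on the fly from this GT image further below:
+        # random blur -> random downsample -> Gaussian noise -> JPEG compression -> resize back,
+        # then optional color jitter and/or grayscale conversion.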
+ gt_path = self.paths[index] + img_bytes = self.file_client.get(gt_path) + img_gt = imfrombytes(img_bytes, float32=True) + + # random horizontal flip + img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True) + h, w, _ = img_gt.shape + + # get facial component coordinates + if self.crop_components: + locations = self.get_component_coordinates(index, status) + loc_left_eye, loc_right_eye, loc_mouth = locations + + # ------------------------ generate lq image ------------------------ # + # blur + kernel = degradations.random_mixed_kernels( + self.kernel_list, + self.kernel_prob, + self.blur_kernel_size, + self.blur_sigma, + self.blur_sigma, [-math.pi, math.pi], + noise_range=None) + img_lq = cv2.filter2D(img_gt, -1, kernel) + # downsample + scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1]) + img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR) + # noise + if self.noise_range is not None: + img_lq = degradations.random_add_gaussian_noise(img_lq, self.noise_range) + # jpeg compression + if self.jpeg_range is not None: + img_lq = degradations.random_add_jpg_compression(img_lq, self.jpeg_range) + + # resize to original size + img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR) + + # random color jitter (only for lq) + if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob): + img_lq = self.color_jitter(img_lq, self.color_jitter_shift) + # random to gray (only for lq) + if self.gray_prob and np.random.uniform() < self.gray_prob: + img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY) + img_lq = np.tile(img_lq[:, :, None], [1, 1, 3]) + if self.opt.get('gt_gray'): # whether convert GT to gray images + img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY) + img_gt = np.tile(img_gt[:, :, None], [1, 1, 3]) # repeat the color channels + + # BGR to RGB, HWC to CHW, numpy to tensor + img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) + + # random color jitter (pytorch version) (only for lq) + if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob): + brightness = self.opt.get('brightness', (0.5, 1.5)) + contrast = self.opt.get('contrast', (0.5, 1.5)) + saturation = self.opt.get('saturation', (0, 1.5)) + hue = self.opt.get('hue', (-0.1, 0.1)) + img_lq = self.color_jitter_pt(img_lq, brightness, contrast, saturation, hue) + + # round and clip + img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255. 
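+        # The round/clamp above simulates 8-bit quantization of the degraded image
+        # before both GT and LQ are normalized with the dataset mean/std below.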
+ + # normalize + normalize(img_gt, self.mean, self.std, inplace=True) + normalize(img_lq, self.mean, self.std, inplace=True) + + if self.crop_components: + return_dict = { + 'lq': img_lq, + 'gt': img_gt, + 'gt_path': gt_path, + 'loc_left_eye': loc_left_eye, + 'loc_right_eye': loc_right_eye, + 'loc_mouth': loc_mouth + } + return return_dict + else: + return {'lq': img_lq, 'gt': img_gt, 'gt_path': gt_path} + + def __len__(self): + return len(self.paths) diff --git a/blissful_tuner/gfpgan/models/__init__.py b/blissful_tuner/gfpgan/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6afad57a3794b867dabbdb617a16355a24d6a8b3 --- /dev/null +++ b/blissful_tuner/gfpgan/models/__init__.py @@ -0,0 +1,10 @@ +import importlib +from basicsr.utils import scandir +from os import path as osp + +# automatically scan and import model modules for registry +# scan all the files that end with '_model.py' under the model folder +model_folder = osp.dirname(osp.abspath(__file__)) +model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] +# import all the model modules +_model_modules = [importlib.import_module(f'gfpgan.models.{file_name}') for file_name in model_filenames] diff --git a/blissful_tuner/gfpgan/models/gfpgan_model.py b/blissful_tuner/gfpgan/models/gfpgan_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b5fb8c953b1ef67b457f56492ad3291d6e5f126d --- /dev/null +++ b/blissful_tuner/gfpgan/models/gfpgan_model.py @@ -0,0 +1,579 @@ +import math +import os.path as osp +import torch +from basicsr.archs import build_network +from basicsr.losses import build_loss +from basicsr.losses.gan_loss import r1_penalty +from basicsr.metrics import calculate_metric +from basicsr.models.base_model import BaseModel +from basicsr.utils import get_root_logger, imwrite, tensor2img +from basicsr.utils.registry import MODEL_REGISTRY +from collections import OrderedDict +from torch.nn import functional as F +from torchvision.ops import roi_align +from tqdm import tqdm + + +@MODEL_REGISTRY.register() +class GFPGANModel(BaseModel): + """The GFPGAN model for Towards real-world blind face restoratin with generative facial prior""" + + def __init__(self, opt): + super(GFPGANModel, self).__init__(opt) + self.idx = 0 # it is used for saving data for check + + # define network + self.net_g = build_network(opt['network_g']) + self.net_g = self.model_to_device(self.net_g) + self.print_network(self.net_g) + + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + param_key = self.opt['path'].get('param_key_g', 'params') + self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key) + + self.log_size = int(math.log(self.opt['network_g']['out_size'], 2)) + + if self.is_train: + self.init_training_settings() + + def init_training_settings(self): + train_opt = self.opt['train'] + + # ----------- define net_d ----------- # + self.net_d = build_network(self.opt['network_d']) + self.net_d = self.model_to_device(self.net_d) + self.print_network(self.net_d) + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_d', None) + if load_path is not None: + self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True)) + + # ----------- define net_g with Exponential Moving Average (EMA) ----------- # + # net_g_ema only used for testing on one GPU and saving. 
There is no need to wrap with DistributedDataParallel + self.net_g_ema = build_network(self.opt['network_g']).to(self.device) + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema') + else: + self.model_ema(0) # copy net_g weight + + self.net_g.train() + self.net_d.train() + self.net_g_ema.eval() + + # ----------- facial component networks ----------- # + if ('network_d_left_eye' in self.opt and 'network_d_right_eye' in self.opt and 'network_d_mouth' in self.opt): + self.use_facial_disc = True + else: + self.use_facial_disc = False + + if self.use_facial_disc: + # left eye + self.net_d_left_eye = build_network(self.opt['network_d_left_eye']) + self.net_d_left_eye = self.model_to_device(self.net_d_left_eye) + self.print_network(self.net_d_left_eye) + load_path = self.opt['path'].get('pretrain_network_d_left_eye') + if load_path is not None: + self.load_network(self.net_d_left_eye, load_path, True, 'params') + # right eye + self.net_d_right_eye = build_network(self.opt['network_d_right_eye']) + self.net_d_right_eye = self.model_to_device(self.net_d_right_eye) + self.print_network(self.net_d_right_eye) + load_path = self.opt['path'].get('pretrain_network_d_right_eye') + if load_path is not None: + self.load_network(self.net_d_right_eye, load_path, True, 'params') + # mouth + self.net_d_mouth = build_network(self.opt['network_d_mouth']) + self.net_d_mouth = self.model_to_device(self.net_d_mouth) + self.print_network(self.net_d_mouth) + load_path = self.opt['path'].get('pretrain_network_d_mouth') + if load_path is not None: + self.load_network(self.net_d_mouth, load_path, True, 'params') + + self.net_d_left_eye.train() + self.net_d_right_eye.train() + self.net_d_mouth.train() + + # ----------- define facial component gan loss ----------- # + self.cri_component = build_loss(train_opt['gan_component_opt']).to(self.device) + + # ----------- define losses ----------- # + # pixel loss + if train_opt.get('pixel_opt'): + self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device) + else: + self.cri_pix = None + + # perceptual loss + if train_opt.get('perceptual_opt'): + self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device) + else: + self.cri_perceptual = None + + # L1 loss is used in pyramid loss, component style loss and identity loss + self.cri_l1 = build_loss(train_opt['L1_opt']).to(self.device) + + # gan loss (wgan) + self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device) + + # ----------- define identity loss ----------- # + if 'network_identity' in self.opt: + self.use_identity = True + else: + self.use_identity = False + + if self.use_identity: + # define identity network + self.network_identity = build_network(self.opt['network_identity']) + self.network_identity = self.model_to_device(self.network_identity) + self.print_network(self.network_identity) + load_path = self.opt['path'].get('pretrain_network_identity') + if load_path is not None: + self.load_network(self.network_identity, load_path, True, None) + self.network_identity.eval() + for param in self.network_identity.parameters(): + param.requires_grad = False + + # regularization weights + self.r1_reg_weight = train_opt['r1_reg_weight'] # for discriminator + self.net_d_iters = train_opt.get('net_d_iters', 1) + self.net_d_init_iters = train_opt.get('net_d_init_iters', 0) + self.net_d_reg_every = train_opt['net_d_reg_every'] + + # set 
up optimizers and schedulers + self.setup_optimizers() + self.setup_schedulers() + + def setup_optimizers(self): + train_opt = self.opt['train'] + + # ----------- optimizer g ----------- # + net_g_reg_ratio = 1 + normal_params = [] + for _, param in self.net_g.named_parameters(): + normal_params.append(param) + optim_params_g = [{ # add normal params first + 'params': normal_params, + 'lr': train_opt['optim_g']['lr'] + }] + optim_type = train_opt['optim_g'].pop('type') + lr = train_opt['optim_g']['lr'] * net_g_reg_ratio + betas = (0**net_g_reg_ratio, 0.99**net_g_reg_ratio) + self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, lr, betas=betas) + self.optimizers.append(self.optimizer_g) + + # ----------- optimizer d ----------- # + net_d_reg_ratio = self.net_d_reg_every / (self.net_d_reg_every + 1) + normal_params = [] + for _, param in self.net_d.named_parameters(): + normal_params.append(param) + optim_params_d = [{ # add normal params first + 'params': normal_params, + 'lr': train_opt['optim_d']['lr'] + }] + optim_type = train_opt['optim_d'].pop('type') + lr = train_opt['optim_d']['lr'] * net_d_reg_ratio + betas = (0**net_d_reg_ratio, 0.99**net_d_reg_ratio) + self.optimizer_d = self.get_optimizer(optim_type, optim_params_d, lr, betas=betas) + self.optimizers.append(self.optimizer_d) + + # ----------- optimizers for facial component networks ----------- # + if self.use_facial_disc: + # setup optimizers for facial component discriminators + optim_type = train_opt['optim_component'].pop('type') + lr = train_opt['optim_component']['lr'] + # left eye + self.optimizer_d_left_eye = self.get_optimizer( + optim_type, self.net_d_left_eye.parameters(), lr, betas=(0.9, 0.99)) + self.optimizers.append(self.optimizer_d_left_eye) + # right eye + self.optimizer_d_right_eye = self.get_optimizer( + optim_type, self.net_d_right_eye.parameters(), lr, betas=(0.9, 0.99)) + self.optimizers.append(self.optimizer_d_right_eye) + # mouth + self.optimizer_d_mouth = self.get_optimizer( + optim_type, self.net_d_mouth.parameters(), lr, betas=(0.9, 0.99)) + self.optimizers.append(self.optimizer_d_mouth) + + def feed_data(self, data): + self.lq = data['lq'].to(self.device) + if 'gt' in data: + self.gt = data['gt'].to(self.device) + + if 'loc_left_eye' in data: + # get facial component locations, shape (batch, 4) + self.loc_left_eyes = data['loc_left_eye'] + self.loc_right_eyes = data['loc_right_eye'] + self.loc_mouths = data['loc_mouth'] + + # uncomment to check data + # import torchvision + # if self.opt['rank'] == 0: + # import os + # os.makedirs('tmp/gt', exist_ok=True) + # os.makedirs('tmp/lq', exist_ok=True) + # print(self.idx) + # torchvision.utils.save_image( + # self.gt, f'tmp/gt/gt_{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1)) + # torchvision.utils.save_image( + # self.lq, f'tmp/lq/lq{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1)) + # self.idx = self.idx + 1 + + def construct_img_pyramid(self): + """Construct image pyramid for intermediate restoration loss""" + pyramid_gt = [self.gt] + down_img = self.gt + for _ in range(0, self.log_size - 3): + down_img = F.interpolate(down_img, scale_factor=0.5, mode='bilinear', align_corners=False) + pyramid_gt.insert(0, down_img) + return pyramid_gt + + def get_roi_regions(self, eye_out_size=80, mouth_out_size=120): + face_ratio = int(self.opt['network_g']['out_size'] / 512) + eye_out_size *= face_ratio + mouth_out_size *= face_ratio + + rois_eyes = [] + rois_mouths = [] + for b in range(self.loc_left_eyes.size(0)): # 
loop for batch size + # left eye and right eye + img_inds = self.loc_left_eyes.new_full((2, 1), b) + bbox = torch.stack([self.loc_left_eyes[b, :], self.loc_right_eyes[b, :]], dim=0) # shape: (2, 4) + rois = torch.cat([img_inds, bbox], dim=-1) # shape: (2, 5) + rois_eyes.append(rois) + # mouse + img_inds = self.loc_left_eyes.new_full((1, 1), b) + rois = torch.cat([img_inds, self.loc_mouths[b:b + 1, :]], dim=-1) # shape: (1, 5) + rois_mouths.append(rois) + + rois_eyes = torch.cat(rois_eyes, 0).to(self.device) + rois_mouths = torch.cat(rois_mouths, 0).to(self.device) + + # real images + all_eyes = roi_align(self.gt, boxes=rois_eyes, output_size=eye_out_size) * face_ratio + self.left_eyes_gt = all_eyes[0::2, :, :, :] + self.right_eyes_gt = all_eyes[1::2, :, :, :] + self.mouths_gt = roi_align(self.gt, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio + # output + all_eyes = roi_align(self.output, boxes=rois_eyes, output_size=eye_out_size) * face_ratio + self.left_eyes = all_eyes[0::2, :, :, :] + self.right_eyes = all_eyes[1::2, :, :, :] + self.mouths = roi_align(self.output, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio + + def _gram_mat(self, x): + """Calculate Gram matrix. + + Args: + x (torch.Tensor): Tensor with shape of (n, c, h, w). + + Returns: + torch.Tensor: Gram matrix. + """ + n, c, h, w = x.size() + features = x.view(n, c, w * h) + features_t = features.transpose(1, 2) + gram = features.bmm(features_t) / (c * h * w) + return gram + + def gray_resize_for_identity(self, out, size=128): + out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :]) + out_gray = out_gray.unsqueeze(1) + out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False) + return out_gray + + def optimize_parameters(self, current_iter): + # optimize net_g + for p in self.net_d.parameters(): + p.requires_grad = False + self.optimizer_g.zero_grad() + + # do not update facial component net_d + if self.use_facial_disc: + for p in self.net_d_left_eye.parameters(): + p.requires_grad = False + for p in self.net_d_right_eye.parameters(): + p.requires_grad = False + for p in self.net_d_mouth.parameters(): + p.requires_grad = False + + # image pyramid loss weight + pyramid_loss_weight = self.opt['train'].get('pyramid_loss_weight', 0) + if pyramid_loss_weight > 0 and current_iter > self.opt['train'].get('remove_pyramid_loss', float('inf')): + pyramid_loss_weight = 1e-12 # very small weight to avoid unused param error + if pyramid_loss_weight > 0: + self.output, out_rgbs = self.net_g(self.lq, return_rgb=True) + pyramid_gt = self.construct_img_pyramid() + else: + self.output, out_rgbs = self.net_g(self.lq, return_rgb=False) + + # get roi-align regions + if self.use_facial_disc: + self.get_roi_regions(eye_out_size=80, mouth_out_size=120) + + l_g_total = 0 + loss_dict = OrderedDict() + if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters): + # pixel loss + if self.cri_pix: + l_g_pix = self.cri_pix(self.output, self.gt) + l_g_total += l_g_pix + loss_dict['l_g_pix'] = l_g_pix + + # image pyramid loss + if pyramid_loss_weight > 0: + for i in range(0, self.log_size - 2): + l_pyramid = self.cri_l1(out_rgbs[i], pyramid_gt[i]) * pyramid_loss_weight + l_g_total += l_pyramid + loss_dict[f'l_p_{2**(i+3)}'] = l_pyramid + + # perceptual loss + if self.cri_perceptual: + l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt) + if l_g_percep is not None: + l_g_total += l_g_percep + loss_dict['l_g_percep'] = l_g_percep + if 
l_g_style is not None: + l_g_total += l_g_style + loss_dict['l_g_style'] = l_g_style + + # gan loss + fake_g_pred = self.net_d(self.output) + l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False) + l_g_total += l_g_gan + loss_dict['l_g_gan'] = l_g_gan + + # facial component loss + if self.use_facial_disc: + # left eye + fake_left_eye, fake_left_eye_feats = self.net_d_left_eye(self.left_eyes, return_feats=True) + l_g_gan = self.cri_component(fake_left_eye, True, is_disc=False) + l_g_total += l_g_gan + loss_dict['l_g_gan_left_eye'] = l_g_gan + # right eye + fake_right_eye, fake_right_eye_feats = self.net_d_right_eye(self.right_eyes, return_feats=True) + l_g_gan = self.cri_component(fake_right_eye, True, is_disc=False) + l_g_total += l_g_gan + loss_dict['l_g_gan_right_eye'] = l_g_gan + # mouth + fake_mouth, fake_mouth_feats = self.net_d_mouth(self.mouths, return_feats=True) + l_g_gan = self.cri_component(fake_mouth, True, is_disc=False) + l_g_total += l_g_gan + loss_dict['l_g_gan_mouth'] = l_g_gan + + if self.opt['train'].get('comp_style_weight', 0) > 0: + # get gt feat + _, real_left_eye_feats = self.net_d_left_eye(self.left_eyes_gt, return_feats=True) + _, real_right_eye_feats = self.net_d_right_eye(self.right_eyes_gt, return_feats=True) + _, real_mouth_feats = self.net_d_mouth(self.mouths_gt, return_feats=True) + + def _comp_style(feat, feat_gt, criterion): + return criterion(self._gram_mat(feat[0]), self._gram_mat( + feat_gt[0].detach())) * 0.5 + criterion( + self._gram_mat(feat[1]), self._gram_mat(feat_gt[1].detach())) + + # facial component style loss + comp_style_loss = 0 + comp_style_loss += _comp_style(fake_left_eye_feats, real_left_eye_feats, self.cri_l1) + comp_style_loss += _comp_style(fake_right_eye_feats, real_right_eye_feats, self.cri_l1) + comp_style_loss += _comp_style(fake_mouth_feats, real_mouth_feats, self.cri_l1) + comp_style_loss = comp_style_loss * self.opt['train']['comp_style_weight'] + l_g_total += comp_style_loss + loss_dict['l_g_comp_style_loss'] = comp_style_loss + + # identity loss + if self.use_identity: + identity_weight = self.opt['train']['identity_weight'] + # get gray images and resize + out_gray = self.gray_resize_for_identity(self.output) + gt_gray = self.gray_resize_for_identity(self.gt) + + identity_gt = self.network_identity(gt_gray).detach() + identity_out = self.network_identity(out_gray) + l_identity = self.cri_l1(identity_out, identity_gt) * identity_weight + l_g_total += l_identity + loss_dict['l_identity'] = l_identity + + l_g_total.backward() + self.optimizer_g.step() + + # EMA + self.model_ema(decay=0.5**(32 / (10 * 1000))) + + # ----------- optimize net_d ----------- # + for p in self.net_d.parameters(): + p.requires_grad = True + self.optimizer_d.zero_grad() + if self.use_facial_disc: + for p in self.net_d_left_eye.parameters(): + p.requires_grad = True + for p in self.net_d_right_eye.parameters(): + p.requires_grad = True + for p in self.net_d_mouth.parameters(): + p.requires_grad = True + self.optimizer_d_left_eye.zero_grad() + self.optimizer_d_right_eye.zero_grad() + self.optimizer_d_mouth.zero_grad() + + fake_d_pred = self.net_d(self.output.detach()) + real_d_pred = self.net_d(self.gt) + l_d = self.cri_gan(real_d_pred, True, is_disc=True) + self.cri_gan(fake_d_pred, False, is_disc=True) + loss_dict['l_d'] = l_d + # In WGAN, real_score should be positive and fake_score should be negative + loss_dict['real_score'] = real_d_pred.detach().mean() + loss_dict['fake_score'] = fake_d_pred.detach().mean() + l_d.backward() + + # regularization 
loss + if current_iter % self.net_d_reg_every == 0: + self.gt.requires_grad = True + real_pred = self.net_d(self.gt) + l_d_r1 = r1_penalty(real_pred, self.gt) + l_d_r1 = (self.r1_reg_weight / 2 * l_d_r1 * self.net_d_reg_every + 0 * real_pred[0]) + loss_dict['l_d_r1'] = l_d_r1.detach().mean() + l_d_r1.backward() + + self.optimizer_d.step() + + # optimize facial component discriminators + if self.use_facial_disc: + # left eye + fake_d_pred, _ = self.net_d_left_eye(self.left_eyes.detach()) + real_d_pred, _ = self.net_d_left_eye(self.left_eyes_gt) + l_d_left_eye = self.cri_component( + real_d_pred, True, is_disc=True) + self.cri_gan( + fake_d_pred, False, is_disc=True) + loss_dict['l_d_left_eye'] = l_d_left_eye + l_d_left_eye.backward() + # right eye + fake_d_pred, _ = self.net_d_right_eye(self.right_eyes.detach()) + real_d_pred, _ = self.net_d_right_eye(self.right_eyes_gt) + l_d_right_eye = self.cri_component( + real_d_pred, True, is_disc=True) + self.cri_gan( + fake_d_pred, False, is_disc=True) + loss_dict['l_d_right_eye'] = l_d_right_eye + l_d_right_eye.backward() + # mouth + fake_d_pred, _ = self.net_d_mouth(self.mouths.detach()) + real_d_pred, _ = self.net_d_mouth(self.mouths_gt) + l_d_mouth = self.cri_component( + real_d_pred, True, is_disc=True) + self.cri_gan( + fake_d_pred, False, is_disc=True) + loss_dict['l_d_mouth'] = l_d_mouth + l_d_mouth.backward() + + self.optimizer_d_left_eye.step() + self.optimizer_d_right_eye.step() + self.optimizer_d_mouth.step() + + self.log_dict = self.reduce_loss_dict(loss_dict) + + def test(self): + with torch.no_grad(): + if hasattr(self, 'net_g_ema'): + self.net_g_ema.eval() + self.output, _ = self.net_g_ema(self.lq) + else: + logger = get_root_logger() + logger.warning('Do not have self.net_g_ema, use self.net_g.') + self.net_g.eval() + self.output, _ = self.net_g(self.lq) + self.net_g.train() + + def dist_validation(self, dataloader, current_iter, tb_logger, save_img): + if self.opt['rank'] == 0: + self.nondist_validation(dataloader, current_iter, tb_logger, save_img) + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + dataset_name = dataloader.dataset.opt['name'] + with_metrics = self.opt['val'].get('metrics') is not None + use_pbar = self.opt['val'].get('pbar', False) + + if with_metrics: + if not hasattr(self, 'metric_results'): # only execute in the first run + self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()} + # initialize the best metric results for each dataset_name (supporting multiple validation datasets) + self._initialize_best_metric_results(dataset_name) + # zero self.metric_results + self.metric_results = {metric: 0 for metric in self.metric_results} + + metric_data = dict() + if use_pbar: + pbar = tqdm(total=len(dataloader), unit='image') + + for idx, val_data in enumerate(dataloader): + img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0] + self.feed_data(val_data) + self.test() + + sr_img = tensor2img(self.output.detach().cpu(), min_max=(-1, 1)) + metric_data['img'] = sr_img + if hasattr(self, 'gt'): + gt_img = tensor2img(self.gt.detach().cpu(), min_max=(-1, 1)) + metric_data['img2'] = gt_img + del self.gt + + # tentative for out of GPU memory + del self.lq + del self.output + torch.cuda.empty_cache() + + if save_img: + if self.opt['is_train']: + save_img_path = osp.join(self.opt['path']['visualization'], img_name, + f'{img_name}_{current_iter}.png') + else: + if self.opt['val']['suffix']: + save_img_path = osp.join(self.opt['path']['visualization'], 
dataset_name, + f'{img_name}_{self.opt["val"]["suffix"]}.png') + else: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["name"]}.png') + imwrite(sr_img, save_img_path) + + if with_metrics: + # calculate metrics + for name, opt_ in self.opt['val']['metrics'].items(): + self.metric_results[name] += calculate_metric(metric_data, opt_) + if use_pbar: + pbar.update(1) + pbar.set_description(f'Test {img_name}') + if use_pbar: + pbar.close() + + if with_metrics: + for metric in self.metric_results.keys(): + self.metric_results[metric] /= (idx + 1) + # update the best metric result + self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter) + + self._log_validation_metric_values(current_iter, dataset_name, tb_logger) + + def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger): + log_str = f'Validation {dataset_name}\n' + for metric, value in self.metric_results.items(): + log_str += f'\t # {metric}: {value:.4f}' + if hasattr(self, 'best_metric_results'): + log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ ' + f'{self.best_metric_results[dataset_name][metric]["iter"]} iter') + log_str += '\n' + + logger = get_root_logger() + logger.info(log_str) + if tb_logger: + for metric, value in self.metric_results.items(): + tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter) + + def save(self, epoch, current_iter): + # save net_g and net_d + self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema']) + self.save_network(self.net_d, 'net_d', current_iter) + # save component discriminators + if self.use_facial_disc: + self.save_network(self.net_d_left_eye, 'net_d_left_eye', current_iter) + self.save_network(self.net_d_right_eye, 'net_d_right_eye', current_iter) + self.save_network(self.net_d_mouth, 'net_d_mouth', current_iter) + # save training state + self.save_training_state(epoch, current_iter) diff --git a/blissful_tuner/gfpgan/train.py b/blissful_tuner/gfpgan/train.py new file mode 100644 index 0000000000000000000000000000000000000000..fe5f1f909ae15a8d830ef65dcb43436d4f4ee7ae --- /dev/null +++ b/blissful_tuner/gfpgan/train.py @@ -0,0 +1,11 @@ +# flake8: noqa +import os.path as osp +from basicsr.train import train_pipeline + +import gfpgan.archs +import gfpgan.data +import gfpgan.models + +if __name__ == '__main__': + root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) + train_pipeline(root_path) diff --git a/blissful_tuner/gfpgan/utils.py b/blissful_tuner/gfpgan/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..231cde713d289a8115c888fb5a55a1501ccd56ac --- /dev/null +++ b/blissful_tuner/gfpgan/utils.py @@ -0,0 +1,148 @@ +import cv2 +import os +import torch +from basicsr.utils import img2tensor, tensor2img +from basicsr.utils.download_util import load_file_from_url +from facexlib.utils.face_restoration_helper import FaceRestoreHelper +from torchvision.transforms.functional import normalize + +from gfpgan.archs.gfpgan_bilinear_arch import GFPGANBilinear +from gfpgan.archs.gfpganv1_arch import GFPGANv1 +from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean + +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + + +class GFPGANer(): + """Helper for restoration with GFPGAN. + + It will detect and crop faces, and then resize the faces to 512x512. + GFPGAN is used to restored the resized faces. 
+ The background is upsampled with the bg_upsampler. + Finally, the faces will be pasted back to the upsample background image. + + Args: + model_path (str): The path to the GFPGAN model. It can be urls (will first download it automatically). + upscale (float): The upscale of the final output. Default: 2. + arch (str): The GFPGAN architecture. Option: clean | original. Default: clean. + channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. + bg_upsampler (nn.Module): The upsampler for the background. Default: None. + """ + + def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None, device=None): + self.upscale = upscale + self.bg_upsampler = bg_upsampler + + # initialize model + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device + # initialize the GFP-GAN + if arch == 'clean': + self.gfpgan = GFPGANv1Clean( + out_size=512, + num_style_feat=512, + channel_multiplier=channel_multiplier, + decoder_load_path=None, + fix_decoder=False, + num_mlp=8, + input_is_latent=True, + different_w=True, + narrow=1, + sft_half=True) + elif arch == 'bilinear': + self.gfpgan = GFPGANBilinear( + out_size=512, + num_style_feat=512, + channel_multiplier=channel_multiplier, + decoder_load_path=None, + fix_decoder=False, + num_mlp=8, + input_is_latent=True, + different_w=True, + narrow=1, + sft_half=True) + elif arch == 'original': + self.gfpgan = GFPGANv1( + out_size=512, + num_style_feat=512, + channel_multiplier=channel_multiplier, + decoder_load_path=None, + fix_decoder=True, + num_mlp=8, + input_is_latent=True, + different_w=True, + narrow=1, + sft_half=True) + elif arch == 'RestoreFormer': + from gfpgan.archs.restoreformer_arch import RestoreFormer + self.gfpgan = RestoreFormer() + # initialize face helper + self.face_helper = FaceRestoreHelper( + upscale, + face_size=512, + crop_ratio=(1, 1), + det_model='retinaface_resnet50', + save_ext='png', + use_parse=True, + device=self.device, + model_rootpath=model_path) + + if model_path.startswith('https://'): + model_path = load_file_from_url( + url=model_path, model_dir=os.path.join(ROOT_DIR, 'gfpgan/weights'), progress=True, file_name=None) + loadnet = torch.load(os.path.join(model_path, "GFPGANv1.4.pth")) + if 'params_ema' in loadnet: + keyname = 'params_ema' + else: + keyname = 'params' + self.gfpgan.load_state_dict(loadnet[keyname], strict=True) + self.gfpgan.eval() + self.gfpgan = self.gfpgan.to(self.device) + + @torch.no_grad() + def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True, weight=0.5): + self.face_helper.clean_all() + + if has_aligned: # the inputs are already aligned + img = cv2.resize(img, (512, 512)) + self.face_helper.cropped_faces = [img] + else: + self.face_helper.read_image(img) + # get face landmarks for each face + self.face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5) + # eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels + # TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations. 
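+            # align_warp_face() below warps each detected face to the canonical 512x512
+            # face template using the 5 landmarks; the inverse affine is recovered later
+            # (get_inverse_affine) so restored faces can be pasted back onto the input image.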
+ # align and warp each face + self.face_helper.align_warp_face() + + # face restoration + for cropped_face in self.face_helper.cropped_faces: + # prepare data + cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) + normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) + cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device) + + try: + output = self.gfpgan(cropped_face_t, return_rgb=False, weight=weight)[0] + # convert to image + restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1)) + except RuntimeError as error: + print(f'\tFailed inference for GFPGAN: {error}.') + restored_face = cropped_face + + restored_face = restored_face.astype('uint8') + self.face_helper.add_restored_face(restored_face) + + if not has_aligned and paste_back: + # upsample the background + if self.bg_upsampler is not None: + # Now only support RealESRGAN for upsampling background + bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0] + else: + bg_img = None + + self.face_helper.get_inverse_affine(None) + # paste each restored face to the input image + restored_img = self.face_helper.paste_faces_to_input_image(upsample_img=bg_img) + return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img + else: + return self.face_helper.cropped_faces, self.face_helper.restored_faces, None diff --git a/blissful_tuner/gimmvfi/LICENSE b/blissful_tuner/gimmvfi/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..1cb1f4a1c2dec6b6a4cb18eaecc3953c16b41e08 --- /dev/null +++ b/blissful_tuner/gimmvfi/LICENSE @@ -0,0 +1,15 @@ +THIS FOLDER AND SUBFOLDERS IN ADDITION TO GIMMVFI.PY (all GIMM-VFI related code and files) LICENSED AS BELOW + +S-Lab License 1.0 +Copyright 2024 S-Lab + +Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met: +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work. 
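For orientation, here is a minimal usage sketch for the `GFPGANer` helper added in `blissful_tuner/gfpgan/utils.py` above. It is not part of the diff: the input/output file names and the `weights` directory are placeholder assumptions, and `model_path` is assumed to be a local directory containing `GFPGANv1.4.pth`, matching the `torch.load(os.path.join(model_path, "GFPGANv1.4.pth"))` call in the constructor.

```python
# Illustrative sketch only (not part of the diff). Assumes the gfpgan package from
# blissful_tuner is importable and that "weights/" holds GFPGANv1.4.pth locally.
import cv2
from gfpgan.utils import GFPGANer

restorer = GFPGANer(
    model_path="weights",  # placeholder: directory containing GFPGANv1.4.pth
    upscale=2,             # final output is 2x the input resolution
    arch="clean",          # selects the GFPGANv1Clean branch in __init__ above
    channel_multiplier=2,
    bg_upsampler=None,     # e.g. a RealESRGAN instance; None leaves the background untouched
)

img = cv2.imread("input.jpg", cv2.IMREAD_COLOR)  # BGR image, as enhance() expects
cropped_faces, restored_faces, restored_img = restorer.enhance(
    img, has_aligned=False, only_center_face=False, paste_back=True, weight=0.5
)
if restored_img is not None:
    cv2.imwrite("restored.jpg", restored_img)
```

With `paste_back=True`, the third return value is the full image with the restored faces pasted back; supplying a `bg_upsampler` (the inline comment above notes that only RealESRGAN-style upsamplers are supported) additionally enhances the background before pasting.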
diff --git a/blissful_tuner/gimmvfi/generalizable_INR/__init__.py b/blissful_tuner/gimmvfi/generalizable_INR/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4496bce54cdf230398b87366fac1b931fc8e140e --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/__init__.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# ginr-ipc: https://github.com/kakaobrain/ginr-ipc +# -------------------------------------------------------- + +# from .gimm import GIMM + +# from .gimmvfi_f import GIMMVFI_F +# from .gimmvfi_r import GIMMVFI_R + + +# def gimm(config): +# return GIMM(config) + + +# def gimmvfi_f(config): +# return GIMMVFI_F(config) + + +# def gimmvfi_r(config): +# return GIMMVFI_R(config) diff --git a/blissful_tuner/gimmvfi/generalizable_INR/configs.py b/blissful_tuner/gimmvfi/generalizable_INR/configs.py new file mode 100644 index 0000000000000000000000000000000000000000..3e804604b44d8743dbf47ec2cc4d3e9c1d7e2f5d --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/configs.py @@ -0,0 +1,57 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# ginr-ipc: https://github.com/kakaobrain/ginr-ipc +# -------------------------------------------------------- + +from typing import List, Optional +from dataclasses import dataclass, field + +from omegaconf import OmegaConf, MISSING +from .modules.module_config import HypoNetConfig + + +@dataclass +class GIMMConfig: + type: str = "gimm" + ema: Optional[bool] = None + ema_value: Optional[float] = None + fwarp_type: str = "linear" + hyponet: HypoNetConfig = field(default_factory=HypoNetConfig) + coord_range: List[float] = MISSING + modulated_layer_idxs: Optional[List[int]] = None + + @classmethod + def create(cls, config): + # We need to specify the type of the default DataEncoderConfig. + # Otherwise, data_encoder will be initialized & structured as "unfold" type (which is default value) + # hence merging with the config with other type would cause config error. + defaults = OmegaConf.structured(cls(ema=False)) + config = OmegaConf.merge(defaults, config) + return config + + +@dataclass +class GIMMVFIConfig: + type: str = "gimmvfi" + ema: Optional[bool] = None + ema_value: Optional[float] = None + fwarp_type: str = "linear" + rec_weight: float = 0.1 + hyponet: HypoNetConfig = field(default_factory=HypoNetConfig) + raft_iter: int = 20 + coord_range: List[float] = MISSING + modulated_layer_idxs: Optional[List[int]] = None + + @classmethod + def create(cls, config): + # We need to specify the type of the default DataEncoderConfig. + # Otherwise, data_encoder will be initialized & structured as "unfold" type (which is default value) + # hence merging with the config with other type would cause config error. 
+ defaults = OmegaConf.structured(cls(ema=False)) + config = OmegaConf.merge(defaults, config) + return config diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/LICENSE b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." 
+ + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. 
We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/README.md b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/README.md new file mode 100644 index 0000000000000000000000000000000000000000..0365a39977c2903b19eb7a6a65aced439d289696 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/README.md @@ -0,0 +1,164 @@ +# FlowFormer: A Transformer Architecture for Optical Flow +### [Project Page](https://drinkingcoder.github.io/publication/flowformer/) + +> FlowFormer: A Transformer Architecture for Optical Flow +> [Zhaoyang Huang](https://drinkingcoder.github.io)\*, Xiaoyu Shi\*, Chao Zhang, Qiang Wang, Ka Chun Cheung, [Hongwei Qin](http://qinhongwei.com/academic/), [Jifeng Dai](https://jifengdai.org/), [Hongsheng Li](https://www.ee.cuhk.edu.hk/~hsli/) +> ECCV 2022 + + + + +## News +Our FlowFormer++ and VideoFlow are accepted by CVPR and ICCV, which ranks 2nd and 1st on the Sintel benchmark! +Please also refer to our [FlowFormer++](https://github.com/XiaoyuShi97/FlowFormerPlusPlus) and [VideoFlow](https://github.com/XiaoyuShi97/VideoFlow). + +## TODO List +- [x] Code release (2022-8-1) +- [x] Models release (2022-8-1) + +## Data Preparation +Similar to RAFT, to evaluate/train FlowFormer, you will need to download the required datasets. +* [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs) +* [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html) +* [Sintel](http://sintel.is.tue.mpg.de/) +* [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow) +* [HD1K](http://hci-benchmark.iwr.uni-heidelberg.de/) (optional) + +By default `datasets.py` will search for the datasets in these locations. You can create symbolic links to wherever the datasets were downloaded in the `datasets` folder + +```Shell +├── datasets + ├── Sintel + ├── test + ├── training + ├── KITTI + ├── testing + ├── training + ├── devkit + ├── FlyingChairs_release + ├── data + ├── FlyingThings3D + ├── frames_cleanpass + ├── frames_finalpass + ├── optical_flow +``` + +## Requirements +```shell +conda create --name flowformer +conda activate flowformer +conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 matplotlib tensorboard scipy opencv -c pytorch +pip install yacs loguru einops timm==0.4.12 imageio +``` + +## Training +The script will load the config according to the training stage. The trained model will be saved in a directory in `logs` and `checkpoints`. For example, the following script will load the config `configs/default.py`. The trained model will be saved as `logs/xxxx/final` and `checkpoints/chairs.pth`. 
+```shell +python -u train_FlowFormer.py --name chairs --stage chairs --validation chairs +``` +To finish the entire training schedule, you can run: +```shell +./run_train.sh +``` + +## Models +We provide [models](https://drive.google.com/drive/folders/1K2dcWxaqOLiQ3PoqRdokrgWsGIf3yBA_?usp=sharing) trained in the four stages. The default path of the models for evaluation is: +```Shell +├── checkpoints + ├── chairs.pth + ├── things.pth + ├── sintel.pth + ├── kitti.pth + ├── flowformer-small.pth + ├── things_kitti.pth +``` +flowformer-small.pth is a small version of our flowformer. things_kitti.pth is the FlowFormer# introduced in our [supplementary](https://drinkingcoder.github.io/publication/flowformer/images/FlowFormer-supp.pdf), used for KITTI training set evaluation. + +## Evaluation +The model to be evaluated is assigned by the `_CN.model` in the config file. + +Evaluating the model on the Sintel training set and the KITTI training set. The corresponding config file is `configs/things_eval.py`. +```Shell +# with tiling technique +python evaluate_FlowFormer_tile.py --eval sintel_validation +python evaluate_FlowFormer_tile.py --eval kitti_validation --model checkpoints/things_kitti.pth +# without tiling technique +python evaluate_FlowFormer.py --dataset sintel +``` +||with tile|w/o tile| +|----|-----|--------| +|clean|0.94|1.01| +|final|2.33|2.40| + +Evaluating the small version model. The corresponding config file is `configs/small_things_eval.py`. +```Shell +# with tiling technique +python evaluate_FlowFormer_tile.py --eval sintel_validation --small +# without tiling technique +python evaluate_FlowFormer.py --dataset sintel --small +``` +||with tile|w/o tile| +|----|-----|--------| +|clean|1.21|1.32| +|final|2.61|2.68| + + +Generating the submission for the Sintel and KITTI benchmarks. The corresponding config file is `configs/submission.py`. +```Shell +python evaluate_FlowFormer_tile.py --eval sintel_submission +python evaluate_FlowFormer_tile.py --eval kitti_submission +``` +Visualizing the sintel dataset: +```Shell +python visualize_flow.py --eval_type sintel --keep_size +``` +Visualizing an image sequence extracted from a video: +```Shell +python visualize_flow.py --eval_type seq +``` +The default image sequence format is: +```Shell +├── demo_data + ├── mihoyo + ├── 000001.png + ├── 000002.png + ├── 000003.png + . + . + . 
+ ├── 001000.png +``` + + +## License +FlowFormer is released under the Apache License + +## Citation +```bibtex +@article{huang2022flowformer, + title={{FlowFormer}: A Transformer Architecture for Optical Flow}, + author={Huang, Zhaoyang and Shi, Xiaoyu and Zhang, Chao and Wang, Qiang and Cheung, Ka Chun and Qin, Hongwei and Dai, Jifeng and Li, Hongsheng}, + journal={{ECCV}}, + year={2022} +} +@inproceedings{shi2023flowformer++, + title={Flowformer++: Masked cost volume autoencoding for pretraining optical flow estimation}, + author={Shi, Xiaoyu and Huang, Zhaoyang and Li, Dasong and Zhang, Manyuan and Cheung, Ka Chun and See, Simon and Qin, Hongwei and Dai, Jifeng and Li, Hongsheng}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, + pages={1599--1610}, + year={2023} +} +@article{huang2023flowformer, + title={FlowFormer: A Transformer Architecture and Its Masked Cost Volume Autoencoding for Optical Flow}, + author={Huang, Zhaoyang and Shi, Xiaoyu and Zhang, Chao and Wang, Qiang and Li, Yijin and Qin, Hongwei and Dai, Jifeng and Wang, Xiaogang and Li, Hongsheng}, + journal={arXiv preprint arXiv:2306.05442}, + year={2023} +} +``` + +## Acknowledgement + +In this project, we use parts of codes in: +- [RAFT](https://github.com/princeton-vl/RAFT) +- [GMA](https://github.com/zacjiang/GMA) +- [timm](https://github.com/rwightman/pytorch-image-models) diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/__init__.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d13b162df2c38ec265ae2e21c3389ac8eb2a047f --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/__init__.py @@ -0,0 +1,18 @@ +import torch +from .configs.submission import get_cfg +from .core.FlowFormer import build_flowformer + + +def initialize_Flowformer(): + cfg = get_cfg() + model = build_flowformer(cfg) + + ckpt = torch.load(cfg.model, map_location="cpu") + + def convert(param): + return {k.replace("module.", ""): v for k, v in param.items() if "module" in k} + + ckpt = convert(ckpt) + model.load_state_dict(ckpt) + + return model diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/alt_cuda_corr/correlation.cpp b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/alt_cuda_corr/correlation.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b01584d19edb99e7feec5f2e4c51169a1ed208db --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/alt_cuda_corr/correlation.cpp @@ -0,0 +1,54 @@ +#include +#include + +// CUDA forward declarations +std::vector corr_cuda_forward( + torch::Tensor fmap1, + torch::Tensor fmap2, + torch::Tensor coords, + int radius); + +std::vector corr_cuda_backward( + torch::Tensor fmap1, + torch::Tensor fmap2, + torch::Tensor coords, + torch::Tensor corr_grad, + int radius); + +// C++ interface +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +std::vector corr_forward( + torch::Tensor fmap1, + torch::Tensor fmap2, + torch::Tensor coords, + int radius) { + CHECK_INPUT(fmap1); + CHECK_INPUT(fmap2); + CHECK_INPUT(coords); + + return corr_cuda_forward(fmap1, fmap2, coords, radius); +} + + +std::vector corr_backward( + torch::Tensor fmap1, + torch::Tensor fmap2, + torch::Tensor coords, + torch::Tensor corr_grad, + int radius) { + 
CHECK_INPUT(fmap1); + CHECK_INPUT(fmap2); + CHECK_INPUT(coords); + CHECK_INPUT(corr_grad); + + return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius); +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &corr_forward, "CORR forward"); + m.def("backward", &corr_backward, "CORR backward"); +} \ No newline at end of file diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/alt_cuda_corr/correlation_kernel.cu b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/alt_cuda_corr/correlation_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..145e5804a16ece51b8ff5f1cb61ae8dab4fc3bb7 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/alt_cuda_corr/correlation_kernel.cu @@ -0,0 +1,324 @@ +#include +#include +#include +#include + + +#define BLOCK_H 4 +#define BLOCK_W 8 +#define BLOCK_HW BLOCK_H * BLOCK_W +#define CHANNEL_STRIDE 32 + + +__forceinline__ __device__ +bool within_bounds(int h, int w, int H, int W) { + return h >= 0 && h < H && w >= 0 && w < W; +} + +template +__global__ void corr_forward_kernel( + const torch::PackedTensorAccessor32 fmap1, + const torch::PackedTensorAccessor32 fmap2, + const torch::PackedTensorAccessor32 coords, + torch::PackedTensorAccessor32 corr, + int r) +{ + const int b = blockIdx.x; + const int h0 = blockIdx.y * blockDim.x; + const int w0 = blockIdx.z * blockDim.y; + const int tid = threadIdx.x * blockDim.y + threadIdx.y; + + const int H1 = fmap1.size(1); + const int W1 = fmap1.size(2); + const int H2 = fmap2.size(1); + const int W2 = fmap2.size(2); + const int N = coords.size(1); + const int C = fmap1.size(3); + + __shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1]; + __shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1]; + __shared__ scalar_t x2s[BLOCK_HW]; + __shared__ scalar_t y2s[BLOCK_HW]; + + for (int c=0; c(floor(y2s[k1]))-r+iy; + int w2 = static_cast(floor(x2s[k1]))-r+ix; + int c2 = tid % CHANNEL_STRIDE; + + auto fptr = fmap2[b][h2][w2]; + if (within_bounds(h2, w2, H2, W2)) + f2[c2][k1] = fptr[c+c2]; + else + f2[c2][k1] = 0.0; + } + + __syncthreads(); + + scalar_t s = 0.0; + for (int k=0; k 0 && ix > 0 && within_bounds(h1, w1, H1, W1)) + *(corr_ptr + ix_nw) += nw; + + if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1)) + *(corr_ptr + ix_ne) += ne; + + if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1)) + *(corr_ptr + ix_sw) += sw; + + if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1)) + *(corr_ptr + ix_se) += se; + } + } + } + } +} + + +template +__global__ void corr_backward_kernel( + const torch::PackedTensorAccessor32 fmap1, + const torch::PackedTensorAccessor32 fmap2, + const torch::PackedTensorAccessor32 coords, + const torch::PackedTensorAccessor32 corr_grad, + torch::PackedTensorAccessor32 fmap1_grad, + torch::PackedTensorAccessor32 fmap2_grad, + torch::PackedTensorAccessor32 coords_grad, + int r) +{ + + const int b = blockIdx.x; + const int h0 = blockIdx.y * blockDim.x; + const int w0 = blockIdx.z * blockDim.y; + const int tid = threadIdx.x * blockDim.y + threadIdx.y; + + const int H1 = fmap1.size(1); + const int W1 = fmap1.size(2); + const int H2 = fmap2.size(1); + const int W2 = fmap2.size(2); + const int N = coords.size(1); + const int C = fmap1.size(3); + + __shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1]; + __shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1]; + + __shared__ scalar_t f1_grad[CHANNEL_STRIDE][BLOCK_HW+1]; + __shared__ scalar_t f2_grad[CHANNEL_STRIDE][BLOCK_HW+1]; + + __shared__ scalar_t x2s[BLOCK_HW]; + __shared__ scalar_t 
y2s[BLOCK_HW]; + + for (int c=0; c(floor(y2s[k1]))-r+iy; + int w2 = static_cast(floor(x2s[k1]))-r+ix; + int c2 = tid % CHANNEL_STRIDE; + + auto fptr = fmap2[b][h2][w2]; + if (within_bounds(h2, w2, H2, W2)) + f2[c2][k1] = fptr[c+c2]; + else + f2[c2][k1] = 0.0; + + f2_grad[c2][k1] = 0.0; + } + + __syncthreads(); + + const scalar_t* grad_ptr = &corr_grad[b][n][0][h1][w1]; + scalar_t g = 0.0; + + int ix_nw = H1*W1*((iy-1) + rd*(ix-1)); + int ix_ne = H1*W1*((iy-1) + rd*ix); + int ix_sw = H1*W1*(iy + rd*(ix-1)); + int ix_se = H1*W1*(iy + rd*ix); + + if (iy > 0 && ix > 0 && within_bounds(h1, w1, H1, W1)) + g += *(grad_ptr + ix_nw) * dy * dx; + + if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1)) + g += *(grad_ptr + ix_ne) * dy * (1-dx); + + if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1)) + g += *(grad_ptr + ix_sw) * (1-dy) * dx; + + if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1)) + g += *(grad_ptr + ix_se) * (1-dy) * (1-dx); + + for (int k=0; k(floor(y2s[k1]))-r+iy; + int w2 = static_cast(floor(x2s[k1]))-r+ix; + int c2 = tid % CHANNEL_STRIDE; + + scalar_t* fptr = &fmap2_grad[b][h2][w2][0]; + if (within_bounds(h2, w2, H2, W2)) + atomicAdd(fptr+c+c2, f2_grad[c2][k1]); + } + } + } + } + __syncthreads(); + + + for (int k=0; k corr_cuda_forward( + torch::Tensor fmap1, + torch::Tensor fmap2, + torch::Tensor coords, + int radius) +{ + const auto B = coords.size(0); + const auto N = coords.size(1); + const auto H = coords.size(2); + const auto W = coords.size(3); + + const auto rd = 2 * radius + 1; + auto opts = fmap1.options(); + auto corr = torch::zeros({B, N, rd*rd, H, W}, opts); + + const dim3 blocks(B, (H+BLOCK_H-1)/BLOCK_H, (W+BLOCK_W-1)/BLOCK_W); + const dim3 threads(BLOCK_H, BLOCK_W); + + corr_forward_kernel<<>>( + fmap1.packed_accessor32(), + fmap2.packed_accessor32(), + coords.packed_accessor32(), + corr.packed_accessor32(), + radius); + + return {corr}; +} + +std::vector corr_cuda_backward( + torch::Tensor fmap1, + torch::Tensor fmap2, + torch::Tensor coords, + torch::Tensor corr_grad, + int radius) +{ + const auto B = coords.size(0); + const auto N = coords.size(1); + + const auto H1 = fmap1.size(1); + const auto W1 = fmap1.size(2); + const auto H2 = fmap2.size(1); + const auto W2 = fmap2.size(2); + const auto C = fmap1.size(3); + + auto opts = fmap1.options(); + auto fmap1_grad = torch::zeros({B, H1, W1, C}, opts); + auto fmap2_grad = torch::zeros({B, H2, W2, C}, opts); + auto coords_grad = torch::zeros({B, N, H1, W1, 2}, opts); + + const dim3 blocks(B, (H1+BLOCK_H-1)/BLOCK_H, (W1+BLOCK_W-1)/BLOCK_W); + const dim3 threads(BLOCK_H, BLOCK_W); + + + corr_backward_kernel<<>>( + fmap1.packed_accessor32(), + fmap2.packed_accessor32(), + coords.packed_accessor32(), + corr_grad.packed_accessor32(), + fmap1_grad.packed_accessor32(), + fmap2_grad.packed_accessor32(), + coords_grad.packed_accessor32(), + radius); + + return {fmap1_grad, fmap2_grad, coords_grad}; +} \ No newline at end of file diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/alt_cuda_corr/setup.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/alt_cuda_corr/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..9dfa19d3a88125e4f6ecccc445213234138673df --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/alt_cuda_corr/setup.py @@ -0,0 +1,15 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + + +setup( + name="correlation", + ext_modules=[ + CUDAExtension( + "alt_cuda_corr", + 
sources=["correlation.cpp", "correlation_kernel.cu"], + extra_compile_args={"cxx": [], "nvcc": ["-O3"]}, + ), + ], + cmdclass={"build_ext": BuildExtension}, +) diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/chairs_split.txt b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/chairs_split.txt new file mode 100644 index 0000000000000000000000000000000000000000..6ae8f0b72a22fc061552604c94664e3a0287914e --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/chairs_split.txt @@ -0,0 +1,22872 @@ +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +2 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 
+1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 
+1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +2 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 
+1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 
+1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 
+1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 
+1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 \ No newline at end of file diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/configs/default.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/configs/default.py new file mode 100644 index 0000000000000000000000000000000000000000..7e5e5036cd25d469f7498d82f41cc81f5e1c458b --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/configs/default.py @@ -0,0 +1,78 @@ +from yacs.config import CfgNode as CN + +_CN = CN() + +_CN.name = "default" +_CN.suffix = "arxiv2" +_CN.gamma = 0.8 +_CN.max_flow = 400 +_CN.batch_size = 8 +_CN.sum_freq = 100 +_CN.val_freq = 5000 +_CN.image_size = [368, 496] +_CN.add_noise = True +_CN.critical_params = [] + +_CN.transformer = "latentcostformer" +_CN.restore_ckpt = None + +########################################### +# latentcostformer +_CN.latentcostformer = CN() +_CN.latentcostformer.pe = "linear" +_CN.latentcostformer.dropout = 0.0 +_CN.latentcostformer.encoder_latent_dim = 256 # in twins, this is 256 +_CN.latentcostformer.query_latent_dim = 64 +_CN.latentcostformer.cost_latent_input_dim = 64 +_CN.latentcostformer.cost_latent_token_num = 8 +_CN.latentcostformer.cost_latent_dim = 128 +_CN.latentcostformer.predictor_dim = 128 +_CN.latentcostformer.motion_feature_dim = 209 # use concat, so double query_latent_dim +_CN.latentcostformer.arc_type = "transformer" +_CN.latentcostformer.cost_heads_num = 1 +# encoder +_CN.latentcostformer.pretrain = True +_CN.latentcostformer.context_concat = False +_CN.latentcostformer.encoder_depth = 3 +_CN.latentcostformer.feat_cross_attn = False +_CN.latentcostformer.patch_size = 8 +_CN.latentcostformer.patch_embed = "single" +_CN.latentcostformer.gma = True +_CN.latentcostformer.rm_res = True +_CN.latentcostformer.vert_c_dim = 64 +_CN.latentcostformer.cost_encoder_res = True +_CN.latentcostformer.cnet = "twins" +_CN.latentcostformer.fnet = "twins" +_CN.latentcostformer.only_global = False 
+_CN.latentcostformer.add_flow_token = True +_CN.latentcostformer.use_mlp = False +_CN.latentcostformer.vertical_conv = False +# decoder +_CN.latentcostformer.decoder_depth = 12 +_CN.latentcostformer.critical_params = [ + "cost_heads_num", + "vert_c_dim", + "cnet", + "pretrain", + "add_flow_token", + "encoder_depth", + "gma", + "cost_encoder_res", +] +########################################## + +### TRAINER +_CN.trainer = CN() +_CN.trainer.scheduler = "OneCycleLR" + +_CN.trainer.optimizer = "adamw" +_CN.trainer.canonical_lr = 25e-5 +_CN.trainer.adamw_decay = 1e-4 +_CN.trainer.clip = 1.0 +_CN.trainer.num_steps = 120000 +_CN.trainer.epsilon = 1e-8 +_CN.trainer.anneal_strategy = "linear" + + +def get_cfg(): + return _CN.clone() diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/configs/kitti.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/configs/kitti.py new file mode 100644 index 0000000000000000000000000000000000000000..74b663fb51281cce1dfb25637f31900d4842c0da --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/configs/kitti.py @@ -0,0 +1,83 @@ +from yacs.config import CfgNode as CN + +_CN = CN() + +_CN.name = "kitti" +_CN.suffix = "kitti" +_CN.gamma = 0.85 +_CN.max_flow = 400 +_CN.batch_size = 6 +_CN.sum_freq = 100 +_CN.val_freq = 499999999 +_CN.image_size = [432, 960] +_CN.add_noise = True +_CN.critical_params = [] + +_CN.transformer = "latentcostformer" + +_CN.model = None + +_CN.restore_ckpt = "checkpoints/sintel.pth" + +# latentcostformer +_CN.latentcostformer = CN() +_CN.latentcostformer.pe = "linear" +_CN.latentcostformer.dropout = 0.0 +_CN.latentcostformer.encoder_latent_dim = 256 # in twins, this is 256 +_CN.latentcostformer.query_latent_dim = 64 +_CN.latentcostformer.cost_latent_input_dim = 64 +_CN.latentcostformer.cost_latent_token_num = 8 +_CN.latentcostformer.cost_latent_dim = 128 +_CN.latentcostformer.predictor_dim = 128 +_CN.latentcostformer.motion_feature_dim = 209 # use concat, so double query_latent_dim +_CN.latentcostformer.arc_type = "transformer" +_CN.latentcostformer.cost_heads_num = 1 +# encoder +_CN.latentcostformer.pretrain = True +_CN.latentcostformer.context_concat = False +_CN.latentcostformer.encoder_depth = 3 +_CN.latentcostformer.feat_cross_attn = False +_CN.latentcostformer.vertical_encoder_attn = "twins" +_CN.latentcostformer.patch_size = 8 +_CN.latentcostformer.patch_embed = "single" +_CN.latentcostformer.gma = "GMA" +_CN.latentcostformer.rm_res = True +_CN.latentcostformer.vert_c_dim = 64 +_CN.latentcostformer.cost_encoder_res = True +_CN.latentcostformer.pwc_aug = False +_CN.latentcostformer.cnet = "twins" +_CN.latentcostformer.fnet = "twins" +_CN.latentcostformer.no_sc = False +_CN.latentcostformer.use_rpe = False +_CN.latentcostformer.only_global = False +_CN.latentcostformer.add_flow_token = True +_CN.latentcostformer.use_mlp = False +_CN.latentcostformer.vertical_conv = False +# decoder +_CN.latentcostformer.decoder_depth = 12 +_CN.latentcostformer.critical_params = [ + "cost_heads_num", + "vert_c_dim", + "cnet", + "pretrain", + "add_flow_token", + "encoder_depth", + "gma", + "cost_encoder_res", +] + + +### TRAINER +_CN.trainer = CN() +_CN.trainer.scheduler = "OneCycleLR" +_CN.trainer.optimizer = "adamw" +_CN.trainer.canonical_lr = 12.5e-5 +_CN.trainer.adamw_decay = 1e-5 +_CN.trainer.clip = 1.0 +_CN.trainer.num_steps = 50000 +_CN.trainer.epsilon = 1e-8 +_CN.trainer.anneal_strategy = "linear" + + +def get_cfg(): + return _CN.clone() diff --git 
a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/configs/sintel.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/configs/sintel.py new file mode 100644 index 0000000000000000000000000000000000000000..56d45538c487f90064ef89e8ec03b5cdadbb4579 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/configs/sintel.py @@ -0,0 +1,77 @@ +from yacs.config import CfgNode as CN + +_CN = CN() + +_CN.name = "default" +_CN.suffix = "sintel" +_CN.gamma = 0.85 +_CN.max_flow = 400 +_CN.batch_size = 6 +_CN.sum_freq = 100 +_CN.val_freq = 5000000 +_CN.image_size = [432, 960] +_CN.add_noise = True +_CN.critical_params = [] + +_CN.transformer = "latentcostformer" +_CN.restore_ckpt = "checkpoints/things.pth" + +# latentcostformer +_CN.latentcostformer = CN() +_CN.latentcostformer.pe = "linear" +_CN.latentcostformer.dropout = 0.0 +_CN.latentcostformer.encoder_latent_dim = 256 # in twins, this is 256 +_CN.latentcostformer.query_latent_dim = 64 +_CN.latentcostformer.cost_latent_input_dim = 64 +_CN.latentcostformer.cost_latent_token_num = 8 +_CN.latentcostformer.cost_latent_dim = 128 +_CN.latentcostformer.arc_type = "transformer" +_CN.latentcostformer.cost_heads_num = 1 +# encoder +_CN.latentcostformer.pretrain = True +_CN.latentcostformer.context_concat = False +_CN.latentcostformer.encoder_depth = 3 +_CN.latentcostformer.feat_cross_attn = False +_CN.latentcostformer.patch_size = 8 +_CN.latentcostformer.patch_embed = "single" +_CN.latentcostformer.no_pe = False +_CN.latentcostformer.gma = "GMA" +_CN.latentcostformer.kernel_size = 9 +_CN.latentcostformer.rm_res = True +_CN.latentcostformer.vert_c_dim = 64 +_CN.latentcostformer.cost_encoder_res = True +_CN.latentcostformer.cnet = "twins" +_CN.latentcostformer.fnet = "twins" +_CN.latentcostformer.no_sc = False +_CN.latentcostformer.only_global = False +_CN.latentcostformer.add_flow_token = True +_CN.latentcostformer.use_mlp = False +_CN.latentcostformer.vertical_conv = False + +# decoder +_CN.latentcostformer.decoder_depth = 12 +_CN.latentcostformer.critical_params = [ + "cost_heads_num", + "vert_c_dim", + "cnet", + "pretrain", + "add_flow_token", + "encoder_depth", + "gma", + "cost_encoder_res", +] + +### TRAINER +_CN.trainer = CN() +_CN.trainer.scheduler = "OneCycleLR" +_CN.trainer.optimizer = "adamw" +_CN.trainer.canonical_lr = 12.5e-5 +_CN.trainer.adamw_decay = 1e-5 +_CN.trainer.clip = 1.0 +_CN.trainer.num_steps = 120000 +_CN.trainer.epsilon = 1e-8 +_CN.trainer.anneal_strategy = "linear" + + +def get_cfg(): + return _CN.clone() diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/configs/small_things_eval.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/configs/small_things_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..7c1fa553ec92f515a6141fdd46d2899dcb933647 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/configs/small_things_eval.py @@ -0,0 +1,77 @@ +from yacs.config import CfgNode as CN + +_CN = CN() + +_CN.name = "" +_CN.suffix = "" +_CN.gamma = 0.8 +_CN.max_flow = 400 +_CN.batch_size = 6 +_CN.sum_freq = 100 +_CN.val_freq = 5000000 +_CN.image_size = [432, 960] +_CN.add_noise = False +_CN.critical_params = [] + +_CN.transformer = "latentcostformer" +_CN.model = "checkpoints/flowformer-small/things.pth" + +# latentcostformer +_CN.latentcostformer = CN() +_CN.latentcostformer.pe = "linear" +_CN.latentcostformer.dropout = 0.0 +_CN.latentcostformer.encoder_latent_dim = 256 # in twins, this is 256 +_CN.latentcostformer.query_latent_dim = 64 
+_CN.latentcostformer.cost_latent_input_dim = 64 +_CN.latentcostformer.cost_latent_token_num = 4 +_CN.latentcostformer.cost_latent_dim = 32 +_CN.latentcostformer.arc_type = "transformer" +_CN.latentcostformer.cost_heads_num = 1 +# encoder +_CN.latentcostformer.pretrain = True +_CN.latentcostformer.context_concat = False +_CN.latentcostformer.encoder_depth = 1 +_CN.latentcostformer.feat_cross_attn = False +_CN.latentcostformer.patch_size = 8 +_CN.latentcostformer.patch_embed = "single" +_CN.latentcostformer.no_pe = False +_CN.latentcostformer.gma = "GMA" +_CN.latentcostformer.kernel_size = 9 +_CN.latentcostformer.rm_res = True +_CN.latentcostformer.vert_c_dim = 0 +_CN.latentcostformer.cost_encoder_res = True +_CN.latentcostformer.cnet = "basicencoder" +_CN.latentcostformer.fnet = "basicencoder" +_CN.latentcostformer.no_sc = False +_CN.latentcostformer.only_global = False +_CN.latentcostformer.add_flow_token = True +_CN.latentcostformer.use_mlp = False +_CN.latentcostformer.vertical_conv = False + +# decoder +_CN.latentcostformer.decoder_depth = 32 +_CN.latentcostformer.critical_params = [ + "cost_heads_num", + "vert_c_dim", + "cnet", + "pretrain", + "add_flow_token", + "encoder_depth", + "gma", + "cost_encoder_res", +] + +### TRAINER +_CN.trainer = CN() +_CN.trainer.scheduler = "OneCycleLR" +_CN.trainer.optimizer = "adamw" +_CN.trainer.canonical_lr = 12.5e-5 +_CN.trainer.adamw_decay = 1e-4 +_CN.trainer.clip = 1.0 +_CN.trainer.num_steps = 120000 +_CN.trainer.epsilon = 1e-8 +_CN.trainer.anneal_strategy = "linear" + + +def get_cfg(): + return _CN.clone() diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/configs/submission.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/configs/submission.py new file mode 100644 index 0000000000000000000000000000000000000000..4d225d331fe96981d3413c1fe90ed668053f03ad --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/configs/submission.py @@ -0,0 +1,77 @@ +from yacs.config import CfgNode as CN + +_CN = CN() + +_CN.name = "" +_CN.suffix = "" +_CN.gamma = 0.8 +_CN.max_flow = 400 +_CN.batch_size = 6 +_CN.sum_freq = 100 +_CN.val_freq = 5000000 +_CN.image_size = [432, 960] +_CN.add_noise = False +_CN.critical_params = [] + +_CN.transformer = "latentcostformer" +_CN.model = "pretrained_ckpt/flowformer_sintel.pth" + +# latentcostformer +_CN.latentcostformer = CN() +_CN.latentcostformer.pe = "linear" +_CN.latentcostformer.dropout = 0.0 +_CN.latentcostformer.encoder_latent_dim = 256 # in twins, this is 256 +_CN.latentcostformer.query_latent_dim = 64 +_CN.latentcostformer.cost_latent_input_dim = 64 +_CN.latentcostformer.cost_latent_token_num = 8 +_CN.latentcostformer.cost_latent_dim = 128 +_CN.latentcostformer.arc_type = "transformer" +_CN.latentcostformer.cost_heads_num = 1 +# encoder +_CN.latentcostformer.pretrain = True +_CN.latentcostformer.context_concat = False +_CN.latentcostformer.encoder_depth = 3 +_CN.latentcostformer.feat_cross_attn = False +_CN.latentcostformer.patch_size = 8 +_CN.latentcostformer.patch_embed = "single" +_CN.latentcostformer.no_pe = False +_CN.latentcostformer.gma = "GMA" +_CN.latentcostformer.kernel_size = 9 +_CN.latentcostformer.rm_res = True +_CN.latentcostformer.vert_c_dim = 64 +_CN.latentcostformer.cost_encoder_res = True +_CN.latentcostformer.cnet = "twins" +_CN.latentcostformer.fnet = "twins" +_CN.latentcostformer.no_sc = False +_CN.latentcostformer.only_global = False +_CN.latentcostformer.add_flow_token = True +_CN.latentcostformer.use_mlp = False +_CN.latentcostformer.vertical_conv = False 
+ +# decoder +_CN.latentcostformer.decoder_depth = 32 +_CN.latentcostformer.critical_params = [ + "cost_heads_num", + "vert_c_dim", + "cnet", + "pretrain", + "add_flow_token", + "encoder_depth", + "gma", + "cost_encoder_res", +] + +### TRAINER +_CN.trainer = CN() +_CN.trainer.scheduler = "OneCycleLR" +_CN.trainer.optimizer = "adamw" +_CN.trainer.canonical_lr = 12.5e-5 +_CN.trainer.adamw_decay = 1e-4 +_CN.trainer.clip = 1.0 +_CN.trainer.num_steps = 120000 +_CN.trainer.epsilon = 1e-8 +_CN.trainer.anneal_strategy = "linear" + + +def get_cfg(): + return _CN.clone() diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/configs/things.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/configs/things.py new file mode 100644 index 0000000000000000000000000000000000000000..85049605c2667fe8ea112a3fa36e0644f5f65b02 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/configs/things.py @@ -0,0 +1,76 @@ +from yacs.config import CfgNode as CN + +_CN = CN() + +_CN.name = "" +_CN.suffix = "" +_CN.gamma = 0.8 +_CN.max_flow = 400 +_CN.batch_size = 6 +_CN.sum_freq = 100 +_CN.val_freq = 5000000 +_CN.image_size = [432, 960] +_CN.add_noise = True +_CN.critical_params = [] + +_CN.transformer = "latentcostformer" +_CN.restore_ckpt = "checkpoints/chairs.pth" + +####################################### +_CN.latentcostformer = CN() +_CN.latentcostformer.pe = "linear" +_CN.latentcostformer.dropout = 0.0 +_CN.latentcostformer.encoder_latent_dim = 256 # in twins, this is 256 +_CN.latentcostformer.query_latent_dim = 64 +_CN.latentcostformer.cost_latent_input_dim = 64 +_CN.latentcostformer.cost_latent_token_num = 8 +_CN.latentcostformer.cost_latent_dim = 128 +_CN.latentcostformer.cost_heads_num = 1 +# encoder +_CN.latentcostformer.pretrain = True +_CN.latentcostformer.context_concat = False +_CN.latentcostformer.encoder_depth = 3 +_CN.latentcostformer.feat_cross_attn = False +_CN.latentcostformer.nat_rep = "abs" +_CN.latentcostformer.patch_size = 8 +_CN.latentcostformer.patch_embed = "single" +_CN.latentcostformer.no_pe = False +_CN.latentcostformer.gma = "GMA" +_CN.latentcostformer.kernel_size = 9 +_CN.latentcostformer.rm_res = True +_CN.latentcostformer.vert_c_dim = 64 +_CN.latentcostformer.cost_encoder_res = True +_CN.latentcostformer.cnet = "twins" +_CN.latentcostformer.fnet = "twins" +_CN.latentcostformer.only_global = False +_CN.latentcostformer.add_flow_token = True +_CN.latentcostformer.use_mlp = False +_CN.latentcostformer.vertical_conv = False + +# decoder +_CN.latentcostformer.decoder_depth = 12 +_CN.latentcostformer.critical_params = [ + "cost_heads_num", + "vert_c_dim", + "cnet", + "pretrain", + "add_flow_token", + "encoder_depth", + "gma", + "cost_encoder_res", +] + +### TRAINER +_CN.trainer = CN() +_CN.trainer.scheduler = "OneCycleLR" +_CN.trainer.optimizer = "adamw" +_CN.trainer.canonical_lr = 12.5e-5 +_CN.trainer.adamw_decay = 1e-4 +_CN.trainer.clip = 1.0 +_CN.trainer.num_steps = 120000 +_CN.trainer.epsilon = 1e-8 +_CN.trainer.anneal_strategy = "linear" + + +def get_cfg(): + return _CN.clone() diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/configs/things_eval.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/configs/things_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..812d0fc034ca11b0e0f730193aace718008159d0 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/configs/things_eval.py @@ -0,0 +1,77 @@ +from yacs.config import CfgNode as CN + +_CN = CN() + +_CN.name = "" +_CN.suffix = "" +_CN.gamma 
= 0.8 +_CN.max_flow = 400 +_CN.batch_size = 6 +_CN.sum_freq = 100 +_CN.val_freq = 5000000 +_CN.image_size = [432, 960] +_CN.add_noise = False +_CN.critical_params = [] + +_CN.transformer = "latentcostformer" +_CN.model = "checkpoints/things.pth" + +# latentcostformer +_CN.latentcostformer = CN() +_CN.latentcostformer.pe = "linear" +_CN.latentcostformer.dropout = 0.0 +_CN.latentcostformer.encoder_latent_dim = 256 # in twins, this is 256 +_CN.latentcostformer.query_latent_dim = 64 +_CN.latentcostformer.cost_latent_input_dim = 64 +_CN.latentcostformer.cost_latent_token_num = 8 +_CN.latentcostformer.cost_latent_dim = 128 +_CN.latentcostformer.arc_type = "transformer" +_CN.latentcostformer.cost_heads_num = 1 +# encoder +_CN.latentcostformer.pretrain = True +_CN.latentcostformer.context_concat = False +_CN.latentcostformer.encoder_depth = 3 +_CN.latentcostformer.feat_cross_attn = False +_CN.latentcostformer.patch_size = 8 +_CN.latentcostformer.patch_embed = "single" +_CN.latentcostformer.no_pe = False +_CN.latentcostformer.gma = "GMA" +_CN.latentcostformer.kernel_size = 9 +_CN.latentcostformer.rm_res = True +_CN.latentcostformer.vert_c_dim = 64 +_CN.latentcostformer.cost_encoder_res = True +_CN.latentcostformer.cnet = "twins" +_CN.latentcostformer.fnet = "twins" +_CN.latentcostformer.no_sc = False +_CN.latentcostformer.only_global = False +_CN.latentcostformer.add_flow_token = True +_CN.latentcostformer.use_mlp = False +_CN.latentcostformer.vertical_conv = False + +# decoder +_CN.latentcostformer.decoder_depth = 32 +_CN.latentcostformer.critical_params = [ + "cost_heads_num", + "vert_c_dim", + "cnet", + "pretrain", + "add_flow_token", + "encoder_depth", + "gma", + "cost_encoder_res", +] + +### TRAINER +_CN.trainer = CN() +_CN.trainer.scheduler = "OneCycleLR" +_CN.trainer.optimizer = "adamw" +_CN.trainer.canonical_lr = 12.5e-5 +_CN.trainer.adamw_decay = 1e-4 +_CN.trainer.clip = 1.0 +_CN.trainer.num_steps = 120000 +_CN.trainer.epsilon = 1e-8 +_CN.trainer.anneal_strategy = "linear" + + +def get_cfg(): + return _CN.clone() diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/configs/things_flowformer_sharp.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/configs/things_flowformer_sharp.py new file mode 100644 index 0000000000000000000000000000000000000000..29b6525a435de7113b22d426a922e99d13b14f7a --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/configs/things_flowformer_sharp.py @@ -0,0 +1,76 @@ +from yacs.config import CfgNode as CN + +_CN = CN() + +_CN.name = "" +_CN.suffix = "" +_CN.gamma = 0.8 +_CN.max_flow = 400 +_CN.batch_size = 6 +_CN.sum_freq = 100 +_CN.val_freq = 5000000 +_CN.image_size = [400, 720] +_CN.add_noise = True +_CN.critical_params = [] + +_CN.transformer = "latentcostformer" +_CN.restore_ckpt = "checkpoints/chairs.pth" + +####################################### +_CN.latentcostformer = CN() +_CN.latentcostformer.pe = "linear" +_CN.latentcostformer.dropout = 0.0 +_CN.latentcostformer.encoder_latent_dim = 256 # in twins, this is 256 +_CN.latentcostformer.query_latent_dim = 64 +_CN.latentcostformer.cost_latent_input_dim = 64 +_CN.latentcostformer.cost_latent_token_num = 8 +_CN.latentcostformer.cost_latent_dim = 128 +_CN.latentcostformer.cost_heads_num = 1 +# encoder +_CN.latentcostformer.pretrain = True +_CN.latentcostformer.context_concat = False +_CN.latentcostformer.encoder_depth = 3 +_CN.latentcostformer.feat_cross_attn = False +_CN.latentcostformer.nat_rep = "abs" +_CN.latentcostformer.patch_size = 8 +_CN.latentcostformer.patch_embed = 
"single" +_CN.latentcostformer.no_pe = False +_CN.latentcostformer.gma = "GMA" +_CN.latentcostformer.kernel_size = 9 +_CN.latentcostformer.rm_res = True +_CN.latentcostformer.vert_c_dim = 64 +_CN.latentcostformer.cost_encoder_res = True +_CN.latentcostformer.cnet = "twins" +_CN.latentcostformer.fnet = "twins" +_CN.latentcostformer.only_global = False +_CN.latentcostformer.add_flow_token = True +_CN.latentcostformer.use_mlp = False +_CN.latentcostformer.vertical_conv = False + +# decoder +_CN.latentcostformer.decoder_depth = 12 +_CN.latentcostformer.critical_params = [ + "cost_heads_num", + "vert_c_dim", + "cnet", + "pretrain", + "add_flow_token", + "encoder_depth", + "gma", + "cost_encoder_res", +] + +### TRAINER +_CN.trainer = CN() +_CN.trainer.scheduler = "OneCycleLR" +_CN.trainer.optimizer = "adamw" +_CN.trainer.canonical_lr = 12.5e-5 +_CN.trainer.adamw_decay = 1e-4 +_CN.trainer.clip = 1.0 +_CN.trainer.num_steps = 120000 +_CN.trainer.epsilon = 1e-8 +_CN.trainer.anneal_strategy = "linear" + + +def get_cfg(): + return _CN.clone() diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/__init__.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/attention.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..00959171d8ac8ebb751aba9582825b66193b2706 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/attention.py @@ -0,0 +1,197 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import einsum + +from einops.layers.torch import Rearrange +from einops import rearrange + + +class BroadMultiHeadAttention(nn.Module): + def __init__(self, dim, heads): + super(BroadMultiHeadAttention, self).__init__() + self.dim = dim + self.heads = heads + self.scale = (dim / heads) ** -0.5 + self.attend = nn.Softmax(dim=-1) + + def attend_with_rpe(self, Q, K): + Q = rearrange(Q.squeeze(), "i (heads d) -> heads i d", heads=self.heads) + K = rearrange(K, "b j (heads d) -> b heads j d", heads=self.heads) + + dots = einsum("hid, bhjd -> bhij", Q, K) * self.scale # (b hw) heads 1 pointnum + + return self.attend(dots) + + def forward(self, Q, K, V): + attn = self.attend_with_rpe(Q, K) + B, _, _ = K.shape + _, N, _ = Q.shape + + V = rearrange(V, "b j (heads d) -> b heads j d", heads=self.heads) + + out = einsum("bhij, bhjd -> bhid", attn, V) + out = rearrange(out, "b heads n d -> b n (heads d)", b=B, n=N) + + return out + + +class MultiHeadAttention(nn.Module): + def __init__(self, dim, heads): + super(MultiHeadAttention, self).__init__() + self.dim = dim + self.heads = heads + self.scale = (dim / heads) ** -0.5 + self.attend = nn.Softmax(dim=-1) + + def attend_with_rpe(self, Q, K): + Q = rearrange(Q, "b i (heads d) -> b heads i d", heads=self.heads) + K = rearrange(K, "b j (heads d) -> b heads j d", heads=self.heads) + + dots = ( + einsum("bhid, bhjd -> bhij", Q, K) * self.scale + ) # (b hw) heads 1 pointnum + + return self.attend(dots) + + def forward(self, Q, K, V): + attn = self.attend_with_rpe(Q, K) + B, HW, _ = Q.shape + + V = rearrange(V, "b j (heads d) -> b heads j d", heads=self.heads) + + out = einsum("bhij, bhjd 
-> bhid", attn, V) + out = rearrange(out, "b heads hw d -> b hw (heads d)", b=B, hw=HW) + + return out + + +# class MultiHeadAttentionRelative_encoder(nn.Module): +# def __init__(self, dim, heads): +# super(MultiHeadAttentionRelative, self).__init__() +# self.dim = dim +# self.heads = heads +# self.scale = (dim/heads) ** -0.5 +# self.attend = nn.Softmax(dim=-1) + +# def attend_with_rpe(self, Q, K, Q_r, K_r): +# """ +# Q: [BH1W1, H3W3, dim] +# K: [BH1W1, H3W3, dim] +# Q_r: [BH1W1, H3W3, H3W3, dim] +# K_r: [BH1W1, H3W3, H3W3, dim] +# """ + +# Q = rearrange(Q, 'b i (heads d) -> b heads i d', heads=self.heads) # [BH1W1, heads, H3W3, dim] +# K = rearrange(K, 'b j (heads d) -> b heads j d', heads=self.heads) # [BH1W1, heads, H3W3, dim] +# K_r = rearrange(K_r, 'b j (heads d) -> b heads j d', heads=self.heads) # [BH1W1, heads, H3W3, dim] +# Q_r = rearrange(Q_r, 'b j (heads d) -> b heads j d', heads=self.heads) # [BH1W1, heads, H3W3, dim] + +# # context-context similarity +# c_c = einsum('bhid, bhjd -> bhij', Q, K) * self.scale # [(B H1W1) heads H3W3 H3W3] +# # context-position similarity +# c_p = einsum('bhid, bhjd -> bhij', Q, K_r) * self.scale # [(B H1W1) heads 1 H3W3] +# # position-context similarity +# p_c = einsum('bhijd, bhikd -> bhijk', Q_r[:,:,:,None,:], K[:,:,:,None,:]) +# p_c = torch.squeeze(p_c, dim=4) +# p_c = p_c.permute(0, 1, 3, 2) +# dots = c_c + c_p + p_c +# return self.attend(dots) + +# def forward(self, Q, K, V, Q_r, K_r): +# attn = self.attend_with_rpe(Q, K, Q_r, K_r) +# B, HW, _ = Q.shape + +# V = rearrange(V, 'b j (heads d) -> b heads j d', heads=self.heads) + +# out = einsum('bhij, bhjd -> bhid', attn, V) +# out = rearrange(out, 'b heads hw d -> b hw (heads d)', b=B, hw=HW) + +# return out + + +class MultiHeadAttentionRelative(nn.Module): + def __init__(self, dim, heads): + super(MultiHeadAttentionRelative, self).__init__() + self.dim = dim + self.heads = heads + self.scale = (dim / heads) ** -0.5 + self.attend = nn.Softmax(dim=-1) + + def attend_with_rpe(self, Q, K, Q_r, K_r): + """ + Q: [BH1W1, 1, dim] + K: [BH1W1, H3W3, dim] + Q_r: [BH1W1, H3W3, dim] + K_r: [BH1W1, H3W3, dim] + """ + + Q = rearrange( + Q, "b i (heads d) -> b heads i d", heads=self.heads + ) # [BH1W1, heads, 1, dim] + K = rearrange( + K, "b j (heads d) -> b heads j d", heads=self.heads + ) # [BH1W1, heads, H3W3, dim] + K_r = rearrange( + K_r, "b j (heads d) -> b heads j d", heads=self.heads + ) # [BH1W1, heads, H3W3, dim] + Q_r = rearrange( + Q_r, "b j (heads d) -> b heads j d", heads=self.heads + ) # [BH1W1, heads, H3W3, dim] + + # context-context similarity + c_c = einsum("bhid, bhjd -> bhij", Q, K) * self.scale # [(B H1W1) heads 1 H3W3] + # context-position similarity + c_p = ( + einsum("bhid, bhjd -> bhij", Q, K_r) * self.scale + ) # [(B H1W1) heads 1 H3W3] + # position-context similarity + p_c = ( + einsum("bhijd, bhikd -> bhijk", Q_r[:, :, :, None, :], K[:, :, :, None, :]) + * self.scale + ) + p_c = torch.squeeze(p_c, dim=4) + p_c = p_c.permute(0, 1, 3, 2) + dots = c_c + c_p + p_c + return self.attend(dots) + + def forward(self, Q, K, V, Q_r, K_r): + attn = self.attend_with_rpe(Q, K, Q_r, K_r) + B, HW, _ = Q.shape + + V = rearrange(V, "b j (heads d) -> b heads j d", heads=self.heads) + + out = einsum("bhij, bhjd -> bhid", attn, V) + out = rearrange(out, "b heads hw d -> b hw (heads d)", b=B, hw=HW) + + return out + + +def LinearPositionEmbeddingSine(x, dim=128, NORMALIZE_FACOR=1 / 200): + # 200 should be enough for a 8x downsampled image + # assume x to be [_, _, 2] + freq_bands = torch.linspace(0, 
dim // 4 - 1, dim // 4).to(x.device) + return torch.cat( + [ + torch.sin(3.14 * x[..., -2:-1] * freq_bands * NORMALIZE_FACOR), + torch.cos(3.14 * x[..., -2:-1] * freq_bands * NORMALIZE_FACOR), + torch.sin(3.14 * x[..., -1:] * freq_bands * NORMALIZE_FACOR), + torch.cos(3.14 * x[..., -1:] * freq_bands * NORMALIZE_FACOR), + ], + dim=-1, + ) + + +def ExpPositionEmbeddingSine(x, dim=128, NORMALIZE_FACOR=1 / 200): + # 200 should be enough for a 8x downsampled image + # assume x to be [_, _, 2] + freq_bands = torch.linspace(0, dim // 4 - 1, dim // 4).to(x.device) + return torch.cat( + [ + torch.sin(x[..., -2:-1] * (NORMALIZE_FACOR * 2**freq_bands)), + torch.cos(x[..., -2:-1] * (NORMALIZE_FACOR * 2**freq_bands)), + torch.sin(x[..., -1:] * (NORMALIZE_FACOR * 2**freq_bands)), + torch.cos(x[..., -1:] * (NORMALIZE_FACOR * 2**freq_bands)), + ], + dim=-1, + ) diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/cnn.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/cnn.py new file mode 100644 index 0000000000000000000000000000000000000000..da1887cf4d5860c592d6c5e72c4f46a0b5052eb6 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/cnn.py @@ -0,0 +1,649 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import Mlp, DropPath, to_2tuple, trunc_normal_ +import math +import numpy as np + + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn="group", stride=1): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d( + in_planes, planes, kernel_size=3, padding=1, stride=stride + ) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == "none": + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3 + ) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class BottleneckBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn="group", stride=1): + super(BottleneckBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes // 4, kernel_size=1, padding=0) + self.conv2 = nn.Conv2d( + planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride + ) + self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4) + self.norm2 = 
nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4) + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(planes // 4) + self.norm2 = nn.BatchNorm2d(planes // 4) + self.norm3 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm4 = nn.BatchNorm2d(planes) + + elif norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(planes // 4) + self.norm2 = nn.InstanceNorm2d(planes // 4) + self.norm3 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm4 = nn.InstanceNorm2d(planes) + + elif norm_fn == "none": + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + self.norm3 = nn.Sequential() + if not stride == 1: + self.norm4 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4 + ) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + y = self.relu(self.norm3(self.conv3(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class BasicEncoder(nn.Module): + def __init__(self, input_dim=3, output_dim=128, norm_fn="batch", dropout=0.0): + super(BasicEncoder, self).__init__() + self.norm_fn = norm_fn + mul = input_dim // 3 + + if self.norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64 * mul) + + elif self.norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(64 * mul) + + elif self.norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(64 * mul) + + elif self.norm_fn == "none": + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(input_dim, 64 * mul, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 64 * mul + self.layer1 = self._make_layer(64 * mul, stride=1) + self.layer2 = self._make_layer(96 * mul, stride=2) + self.layer3 = self._make_layer(128 * mul, stride=2) + + # output convolution + self.conv2 = nn.Conv2d(128 * mul, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def compute_params(self): + num = 0 + for param in self.parameters(): + num += np.prod(param.size()) + + return num + + def forward(self, x): + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + + +class SmallEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn="batch", 
dropout=0.0): + super(SmallEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) + + elif self.norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(32) + + elif self.norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(32) + + elif self.norm_fn == "none": + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 32 + self.layer1 = self._make_layer(32, stride=1) + self.layer2 = self._make_layer(64, stride=2) + self.layer3 = self._make_layer(96, stride=2) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + + +class ConvNets(nn.Module): + def __init__(self, in_dim, out_dim, inter_dim, depth, stride=1): + super(ConvNets, self).__init__() + + self.conv_first = nn.Conv2d( + in_dim, inter_dim, kernel_size=3, padding=1, stride=stride + ) + self.conv_last = nn.Conv2d( + inter_dim, out_dim, kernel_size=3, padding=1, stride=stride + ) + self.relu = nn.ReLU(inplace=True) + self.inter_convs = nn.ModuleList( + [ + ResidualBlock(inter_dim, inter_dim, norm_fn="none", stride=1) + for i in range(depth) + ] + ) + + def forward(self, x): + x = self.relu(self.conv_first(x)) + for inter_conv in self.inter_convs: + x = inter_conv(x) + x = self.conv_last(x) + return x + + +class FlowHead(nn.Module): + def __init__(self, input_dim=128, hidden_dim=256): + super(FlowHead, self).__init__() + self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) + self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + return self.conv2(self.relu(self.conv1(x))) + + +class ConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192 + 128): + super(ConvGRU, self).__init__() + self.convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) + self.convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) + self.convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) + + def forward(self, h, x): + hx = torch.cat([h, x], dim=1) + + z = torch.sigmoid(self.convz(hx)) + r = torch.sigmoid(self.convr(hx)) + q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1))) + + h = (1 - z) * h + z * q + return h + + +class SepConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192 + 128): + 
super(SepConvGRU, self).__init__() + self.convz1 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) + ) + self.convr1 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) + ) + self.convq1 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) + ) + + self.convz2 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0) + ) + self.convr2 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0) + ) + self.convq2 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0) + ) + + def forward(self, h, x): + # horizontal + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz1(hx)) + r = torch.sigmoid(self.convr1(hx)) + q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1))) + h = (1 - z) * h + z * q + + # vertical + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz2(hx)) + r = torch.sigmoid(self.convr2(hx)) + q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1))) + h = (1 - z) * h + z * q + + return h + + +class BasicMotionEncoder(nn.Module): + def __init__(self, args): + super(BasicMotionEncoder, self).__init__() + cor_planes = args.motion_feature_dim + self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) + self.convc2 = nn.Conv2d(256, 192, 3, padding=1) + self.convf1 = nn.Conv2d(2, 128, 7, padding=3) + self.convf2 = nn.Conv2d(128, 64, 3, padding=1) + self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + cor = F.relu(self.convc2(cor)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + + +class BasicFuseMotion(nn.Module): + def __init__(self, args): + super(BasicFuseMotion, self).__init__() + cor_planes = args.motion_feature_dim + out_planes = args.query_latent_dim + + self.normf1 = nn.InstanceNorm2d(128) + self.normf2 = nn.InstanceNorm2d(128) + + self.convf1 = nn.Conv2d(2, 128, 3, padding=1) + self.convf2 = nn.Conv2d(128, 128, 3, padding=1) + self.convf3 = nn.Conv2d(128, 64, 3, padding=1) + + s = 1 + self.normc1 = nn.InstanceNorm2d(256 * s) + self.normc2 = nn.InstanceNorm2d(256 * s) + self.normc3 = nn.InstanceNorm2d(256 * s) + + self.convc1 = nn.Conv2d(cor_planes + 128, 256 * s, 1, padding=0) + self.convc2 = nn.Conv2d(256 * s, 256 * s, 3, padding=1) + self.convc3 = nn.Conv2d(256 * s, 256 * s, 3, padding=1) + self.convc4 = nn.Conv2d(256 * s, 256 * s, 3, padding=1) + self.conv = nn.Conv2d(256 * s + 64, out_planes, 1, padding=0) + + def forward(self, flow, feat, context1=None): + flo = F.relu(self.normf1(self.convf1(flow))) + flo = F.relu(self.normf2(self.convf2(flo))) + flo = self.convf3(flo) + + feat = torch.cat([feat, context1], dim=1) + feat = F.relu(self.normc1(self.convc1(feat))) + feat = F.relu(self.normc2(self.convc2(feat))) + feat = F.relu(self.normc3(self.convc3(feat))) + feat = self.convc4(feat) + + feat = torch.cat([flo, feat], dim=1) + feat = F.relu(self.conv(feat)) + + return feat + + +class BasicUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=128, input_dim=128): + super(BasicUpdateBlock, self).__init__() + self.args = args + self.encoder = BasicMotionEncoder(args) + self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim) + self.flow_head = FlowHead(hidden_dim, hidden_dim=256) + + self.mask = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 64 * 9, 1, padding=0), + ) + 
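+ # forward() below encodes (flow, corr) into motion features, updates the
+ # hidden state with the separable ConvGRU, and predicts a residual flow
+ # together with a 64*9-channel mask (nine convex-combination weights per
+ # cell of the 8x8 upsampling grid); the 0.25 factor keeps the mask
+ # gradients balanced.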
+ def forward(self, net, inp, corr, flow, upsample=True): + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + # scale mask to balence gradients + mask = 0.25 * self.mask(net) + return net, mask, delta_flow + + +class DirectMeanMaskPredictor(nn.Module): + def __init__(self, args): + super(DirectMeanMaskPredictor, self).__init__() + self.flow_head = FlowHead(args.predictor_dim, hidden_dim=256) + self.mask = nn.Sequential( + nn.Conv2d(args.predictor_dim, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 64 * 9, 1, padding=0), + ) + + def forward(self, motion_features): + delta_flow = self.flow_head(motion_features) + mask = 0.25 * self.mask(motion_features) + + return mask, delta_flow + + +class BaiscMeanPredictor(nn.Module): + def __init__(self, args, hidden_dim=128): + super(BaiscMeanPredictor, self).__init__() + self.args = args + self.encoder = BasicMotionEncoder(args) + self.flow_head = FlowHead(hidden_dim, hidden_dim=256) + + self.mask = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 64 * 9, 1, padding=0), + ) + + def forward(self, latent, flow): + motion_features = self.encoder(flow, latent) + delta_flow = self.flow_head(motion_features) + mask = 0.25 * self.mask(motion_features) + + return mask, delta_flow + + +class BasicRPEEncoder(nn.Module): + def __init__(self, args): + super(BasicRPEEncoder, self).__init__() + self.args = args + dim = args.query_latent_dim + self.encoder = nn.Sequential( + nn.Linear(2, dim // 2), + nn.ReLU(inplace=True), + nn.Linear(dim // 2, dim), + nn.ReLU(inplace=True), + nn.Linear(dim, dim), + ) + + def forward(self, rpe_tokens): + return self.encoder(rpe_tokens) + + +from .twins import Block, CrossBlock + + +class TwinsSelfAttentionLayer(nn.Module): + def __init__(self, args): + super(TwinsSelfAttentionLayer, self).__init__() + self.args = args + embed_dim = 256 + num_heads = 8 + mlp_ratio = 4 + ws = 7 + sr_ratio = 4 + dpr = 0.0 + drop_rate = 0.0 + attn_drop_rate = 0.0 + + self.local_block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr, + sr_ratio=sr_ratio, + ws=ws, + with_rpe=True, + ) + self.global_block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr, + sr_ratio=sr_ratio, + ws=1, + with_rpe=True, + ) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + + def forward(self, x, tgt, size): + x = self.local_block(x, size) + x = self.global_block(x, size) + + tgt = self.local_block(tgt, size) + tgt = self.global_block(tgt, size) + return x, tgt + + +class TwinsCrossAttentionLayer(nn.Module): + def __init__(self, args): + super(TwinsCrossAttentionLayer, self).__init__() + self.args = args + embed_dim = 256 + num_heads = 8 + mlp_ratio = 4 + ws = 7 + sr_ratio = 
4 + dpr = 0.0 + drop_rate = 0.0 + attn_drop_rate = 0.0 + + self.local_block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr, + sr_ratio=sr_ratio, + ws=ws, + with_rpe=True, + ) + self.global_block = CrossBlock( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr, + sr_ratio=sr_ratio, + ws=1, + with_rpe=True, + ) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + + def forward(self, x, tgt, size): + x = self.local_block(x, size) + tgt = self.local_block(tgt, size) + x, tgt = self.global_block(x, tgt, size) + + return x, tgt diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/convnext.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/convnext.py new file mode 100644 index 0000000000000000000000000000000000000000..667534c4df4c6ceaa83abfd7c07b099f2d647e20 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/convnext.py @@ -0,0 +1,98 @@ +#from turtle import forward +import torch +from torch import nn +import torch.nn.functional as F +import numpy as np + + +class ConvNextLayer(nn.Module): + def __init__(self, dim, depth=4): + super().__init__() + self.net = nn.Sequential(*[ConvNextBlock(dim=dim) for j in range(depth)]) + + def forward(self, x): + return self.net(x) + + def compute_params(self): + num = 0 + for param in self.parameters(): + num += np.prod(param.size()) + + return num + + +class ConvNextBlock(nn.Module): + r"""ConvNeXt Block. There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + We use (2) as we find it slightly faster in PyTorch + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + """ + + def __init__(self, dim, layer_scale_init_value=1e-6): + super().__init__() + self.dwconv = nn.Conv2d( + dim, dim, kernel_size=7, padding=3, groups=dim + ) # depthwise conv + self.norm = LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear( + dim, 4 * dim + ) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(4 * dim, dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + # self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + # print(f"conv next layer") + + def forward(self, x): + input = x + x = self.dwconv(x) + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = input + x + return x + + +class LayerNorm(nn.Module): + r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with + shape (batch_size, height, width, channels) while channels_first corresponds to inputs + with shape (batch_size, channels, height, width). + """ + + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError + self.normalized_shape = (normalized_shape,) + + def forward(self, x): + if self.data_format == "channels_last": + return F.layer_norm( + x, self.normalized_shape, self.weight, self.bias, self.eps + ) + elif self.data_format == "channels_first": + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/decoder.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..9d7b50088562a30a458f47a2633db86fadc6dd77 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/decoder.py @@ -0,0 +1,321 @@ +import loguru +import torch +import math +import torch.nn as nn +import torch.nn.functional as F +from torch import einsum + +from einops.layers.torch import Rearrange +from einops import rearrange + +from ...utils.utils import coords_grid, bilinear_sampler, upflow8 +from .attention import ( + MultiHeadAttention, + LinearPositionEmbeddingSine, + ExpPositionEmbeddingSine, +) +from typing import Optional, Tuple + +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + +from .gru import BasicUpdateBlock, GMAUpdateBlock +from .gma import Attention + + +def initialize_flow(img): + """Flow is represented as difference between two means flow = mean1 - mean0""" + N, C, H, W = img.shape + mean = coords_grid(N, H, W).to(img.device) + mean_init = coords_grid(N, H, W).to(img.device) + + # optical flow computed as difference: flow = mean1 - mean0 + return mean, mean_init + + +class CrossAttentionLayer(nn.Module): + # def __init__(self, dim, cfg, num_heads=8, attn_drop=0., proj_drop=0., drop_path=0., dropout=0.): + def __init__( + self, + qk_dim, + v_dim, + query_token_dim, + tgt_token_dim, + add_flow_token=True, + num_heads=8, + attn_drop=0.0, + proj_drop=0.0, + drop_path=0.0, + dropout=0.0, + pe="linear", + ): + super(CrossAttentionLayer, self).__init__() + + head_dim = qk_dim // num_heads + self.scale = head_dim**-0.5 + self.query_token_dim = query_token_dim + self.pe = pe + + self.norm1 = nn.LayerNorm(query_token_dim) + self.norm2 = nn.LayerNorm(query_token_dim) + self.multi_head_attn = MultiHeadAttention(qk_dim, num_heads) + 
self.q, self.k, self.v = ( + nn.Linear(query_token_dim, qk_dim, bias=True), + nn.Linear(tgt_token_dim, qk_dim, bias=True), + nn.Linear(tgt_token_dim, v_dim, bias=True), + ) + + self.proj = nn.Linear(v_dim * 2, query_token_dim) + self.proj_drop = nn.Dropout(proj_drop) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.ffn = nn.Sequential( + nn.Linear(query_token_dim, query_token_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(query_token_dim, query_token_dim), + nn.Dropout(dropout), + ) + self.add_flow_token = add_flow_token + self.dim = qk_dim + + def forward(self, query, key, value, memory, query_coord, patch_size, size_h3w3): + """ + query_coord [B, 2, H1, W1] + """ + B, _, H1, W1 = query_coord.shape + + if key is None and value is None: + key = self.k(memory) + value = self.v(memory) + + # [B, 2, H1, W1] -> [BH1W1, 1, 2] + query_coord = query_coord.contiguous() + query_coord = ( + query_coord.view(B, 2, -1) + .permute(0, 2, 1)[:, :, None, :] + .contiguous() + .view(B * H1 * W1, 1, 2) + ) + if self.pe == "linear": + query_coord_enc = LinearPositionEmbeddingSine(query_coord, dim=self.dim) + elif self.pe == "exp": + query_coord_enc = ExpPositionEmbeddingSine(query_coord, dim=self.dim) + + short_cut = query + query = self.norm1(query) + + if self.add_flow_token: + q = self.q(query + query_coord_enc) + else: + q = self.q(query_coord_enc) + k, v = key, value + + x = self.multi_head_attn(q, k, v) + + x = self.proj(torch.cat([x, short_cut], dim=2)) + x = short_cut + self.proj_drop(x) + + x = x + self.drop_path(self.ffn(self.norm2(x))) + + return x, k, v + + +class MemoryDecoderLayer(nn.Module): + def __init__(self, dim, cfg): + super(MemoryDecoderLayer, self).__init__() + self.cfg = cfg + self.patch_size = cfg.patch_size # for converting coords into H2', W2' space + + query_token_dim, tgt_token_dim = cfg.query_latent_dim, cfg.cost_latent_dim + qk_dim, v_dim = query_token_dim, query_token_dim + self.cross_attend = CrossAttentionLayer( + qk_dim, + v_dim, + query_token_dim, + tgt_token_dim, + add_flow_token=cfg.add_flow_token, + dropout=cfg.dropout, + ) + + def forward(self, query, key, value, memory, coords1, size, size_h3w3): + """ + x: [B*H1*W1, 1, C] + memory: [B*H1*W1, H2'*W2', C] + coords1 [B, 2, H2, W2] + size: B, C, H1, W1 + 1. Note that here coords0 and coords1 are in H2, W2 space. + Should first convert it into H2', W2' space. + 2. 
We assume the upper-left point to be [0, 0], instead of letting center of upper-left patch to be [0, 0] + """ + x_global, k, v = self.cross_attend( + query, key, value, memory, coords1, self.patch_size, size_h3w3 + ) + B, C, H1, W1 = size + C = self.cfg.query_latent_dim + x_global = x_global.view(B, H1, W1, C).permute(0, 3, 1, 2) + return x_global, k, v + + +class ReverseCostExtractor(nn.Module): + def __init__(self, cfg): + super(ReverseCostExtractor, self).__init__() + self.cfg = cfg + + def forward(self, cost_maps, coords0, coords1): + """ + cost_maps - B*H1*W1, cost_heads_num, H2, W2 + coords - B, 2, H1, W1 + """ + BH1W1, heads, H2, W2 = cost_maps.shape + B, _, H1, W1 = coords1.shape + + assert (H1 == H2) and (W1 == W2) + assert BH1W1 == B * H1 * W1 + + cost_maps = cost_maps.reshape(B, H1 * W1 * heads, H2, W2) + coords = coords1.permute(0, 2, 3, 1) + corr = bilinear_sampler(cost_maps, coords) # [B, H1*W1*heads, H2, W2] + corr = rearrange( + corr, + "b (h1 w1 heads) h2 w2 -> (b h2 w2) heads h1 w1", + b=B, + heads=heads, + h1=H1, + w1=W1, + h2=H2, + w2=W2, + ) + + r = 4 + dx = torch.linspace(-r, r, 2 * r + 1) + dy = torch.linspace(-r, r, 2 * r + 1) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords0.device) + centroid = coords0.permute(0, 2, 3, 1).reshape(BH1W1, 1, 1, 2) + delta = delta.view(1, 2 * r + 1, 2 * r + 1, 2) + coords = centroid + delta + corr = bilinear_sampler(corr, coords) + corr = corr.view(B, H1, W1, -1).permute(0, 3, 1, 2) + return corr + + +class MemoryDecoder(nn.Module): + def __init__(self, cfg): + super(MemoryDecoder, self).__init__() + dim = self.dim = cfg.query_latent_dim + self.cfg = cfg + + self.flow_token_encoder = nn.Sequential( + nn.Conv2d(81 * cfg.cost_heads_num, dim, 1, 1), + nn.GELU(), + nn.Conv2d(dim, dim, 1, 1), + ) + self.proj = nn.Conv2d(256, 256, 1) + self.depth = cfg.decoder_depth + self.decoder_layer = MemoryDecoderLayer(dim, cfg) + + if self.cfg.gma: + self.update_block = GMAUpdateBlock(self.cfg, hidden_dim=128) + self.att = Attention( + args=self.cfg, dim=128, heads=1, max_pos_size=160, dim_head=128 + ) + else: + self.update_block = BasicUpdateBlock(self.cfg, hidden_dim=128) + + def upsample_flow(self, flow, mask): + """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination""" + N, _, H, W = flow.shape + mask = mask.view(N, 1, 9, 8, 8, H, W) + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(8 * flow, [3, 3], padding=1) + up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) + + up_flow = torch.sum(mask * up_flow, dim=2) + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) + return up_flow.reshape(N, 2, 8 * H, 8 * W) + + def encode_flow_token(self, cost_maps, coords): + """ + cost_maps - B*H1*W1, cost_heads_num, H2, W2 + coords - B, 2, H1, W1 + """ + coords = coords.permute(0, 2, 3, 1) + batch, h1, w1, _ = coords.shape + + r = 4 + dx = torch.linspace(-r, r, 2 * r + 1) + dy = torch.linspace(-r, r, 2 * r + 1) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) + + centroid = coords.reshape(batch * h1 * w1, 1, 1, 2) + delta = delta.view(1, 2 * r + 1, 2 * r + 1, 2) + coords = centroid + delta + corr = bilinear_sampler(cost_maps, coords) + corr = corr.view(batch, h1, w1, -1).permute(0, 3, 1, 2) + return corr + + def forward(self, cost_memory, context, data={}, flow_init=None, iters=None): + """ + memory: [B*H1*W1, H2'*W2', C] + context: [B, D, H1, W1] + """ + cost_maps = data["cost_maps"] + coords0, coords1 = initialize_flow(context) + + if flow_init is not None: + # print("[Using warm start]") + coords1 = 
coords1 + flow_init + + # flow = coords1 + + flow_predictions = [] + + context = self.proj(context) + net, inp = torch.split(context, [128, 128], dim=1) + net = torch.tanh(net) + inp = torch.relu(inp) + if self.cfg.gma: + attention = self.att(inp) + + size = net.shape + key, value = None, None + if iters is None: + iters = self.depth + for idx in range(iters): + coords1 = coords1.detach() + + cost_forward = self.encode_flow_token(cost_maps, coords1) + # cost_backward = self.reverse_cost_extractor(cost_maps, coords0, coords1) + + query = self.flow_token_encoder(cost_forward) + query = ( + query.permute(0, 2, 3, 1) + .contiguous() + .view(size[0] * size[2] * size[3], 1, self.dim) + ) + cost_global, key, value = self.decoder_layer( + query, key, value, cost_memory, coords1, size, data["H3W3"] + ) + if self.cfg.only_global: + corr = cost_global + else: + corr = torch.cat([cost_global, cost_forward], dim=1) + + flow = coords1 - coords0 + + if self.cfg.gma: + net, up_mask, delta_flow = self.update_block( + net, inp, corr, flow, attention + ) + else: + net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) + + # flow = delta_flow + coords1 = coords1 + delta_flow + flow_up = self.upsample_flow(coords1 - coords0, up_mask) + flow_predictions.append(flow_up) + + # if self.training: + # return flow_predictions + # else: + return flow_predictions[-1], coords1 - coords0 diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/encoder.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..eb09013da5271d45c6f0a11707abaca7c8d72c1b --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/encoder.py @@ -0,0 +1,539 @@ +import loguru +import torch +import math +import torch.nn as nn +import torch.nn.functional as F +from torch import einsum +import numpy as np + +from einops.layers.torch import Rearrange +from einops import rearrange +import sys +from ...utils.utils import coords_grid, bilinear_sampler, upflow8 +from .attention import ( + BroadMultiHeadAttention, + MultiHeadAttention, + LinearPositionEmbeddingSine, + ExpPositionEmbeddingSine, +) +from ..encoders import twins_svt_large +from typing import Optional, Tuple +from .twins import Size_, PosConv +from .cnn import TwinsSelfAttentionLayer, TwinsCrossAttentionLayer, BasicEncoder +from .mlpmixer import MLPMixerLayer +from .convnext import ConvNextLayer +import time + +from timm.models.layers import Mlp, DropPath, to_2tuple, trunc_normal_ + + +class PatchEmbed(nn.Module): + def __init__(self, patch_size=16, in_chans=1, embed_dim=64, pe="linear"): + super().__init__() + self.patch_size = patch_size + self.dim = embed_dim + self.pe = pe + + # assert patch_size == 8 + if patch_size == 8: + self.proj = nn.Sequential( + nn.Conv2d(in_chans, embed_dim // 4, kernel_size=6, stride=2, padding=2), + nn.ReLU(), + nn.Conv2d( + embed_dim // 4, embed_dim // 2, kernel_size=6, stride=2, padding=2 + ), + nn.ReLU(), + nn.Conv2d( + embed_dim // 2, embed_dim, kernel_size=6, stride=2, padding=2 + ), + ) + elif patch_size == 4: + self.proj = nn.Sequential( + nn.Conv2d(in_chans, embed_dim // 4, kernel_size=6, stride=2, padding=2), + nn.ReLU(), + nn.Conv2d( + embed_dim // 4, embed_dim, kernel_size=6, stride=2, padding=2 + ), + ) + else: + print(f"patch size = {patch_size} is unacceptable.") + + self.ffn_with_coord = nn.Sequential( + nn.Conv2d(embed_dim * 2, embed_dim * 
2, kernel_size=1), + nn.ReLU(), + nn.Conv2d(embed_dim * 2, embed_dim * 2, kernel_size=1), + ) + self.norm = nn.LayerNorm(embed_dim * 2) + + def forward(self, x) -> Tuple[torch.Tensor, Size_]: + B, C, H, W = x.shape # C == 1 + + pad_l = pad_t = 0 + pad_r = (self.patch_size - W % self.patch_size) % self.patch_size + pad_b = (self.patch_size - H % self.patch_size) % self.patch_size + x = F.pad(x, (pad_l, pad_r, pad_t, pad_b)) + + x = self.proj(x) + out_size = x.shape[2:] + + patch_coord = ( + coords_grid(B, out_size[0], out_size[1]).to(x.device) * self.patch_size + + self.patch_size / 2 + ) # in feature coordinate space + patch_coord = patch_coord.view(B, 2, -1).permute(0, 2, 1) + if self.pe == "linear": + patch_coord_enc = LinearPositionEmbeddingSine(patch_coord, dim=self.dim) + elif self.pe == "exp": + patch_coord_enc = ExpPositionEmbeddingSine(patch_coord, dim=self.dim) + patch_coord_enc = patch_coord_enc.permute(0, 2, 1).view( + B, -1, out_size[0], out_size[1] + ) + + x_pe = torch.cat([x, patch_coord_enc], dim=1) + x = self.ffn_with_coord(x_pe) + x = self.norm(x.flatten(2).transpose(1, 2)) + + return x, out_size + + +from .twins import Block, CrossBlock + + +class GroupVerticalSelfAttentionLayer(nn.Module): + def __init__( + self, + dim, + cfg, + num_heads=8, + attn_drop=0.0, + proj_drop=0.0, + drop_path=0.0, + dropout=0.0, + ): + super(GroupVerticalSelfAttentionLayer, self).__init__() + self.cfg = cfg + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + embed_dim = dim + mlp_ratio = 4 + ws = 7 + sr_ratio = 4 + dpr = 0.0 + drop_rate = dropout + attn_drop_rate = 0.0 + + self.block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr, + sr_ratio=sr_ratio, + ws=ws, + with_rpe=True, + vert_c_dim=cfg.vert_c_dim, + groupattention=True, + cfg=self.cfg, + ) + + def forward(self, x, size, context=None): + x = self.block(x, size, context) + + return x + + +class VerticalSelfAttentionLayer(nn.Module): + def __init__( + self, + dim, + cfg, + num_heads=8, + attn_drop=0.0, + proj_drop=0.0, + drop_path=0.0, + dropout=0.0, + ): + super(VerticalSelfAttentionLayer, self).__init__() + self.cfg = cfg + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + embed_dim = dim + mlp_ratio = 4 + ws = 7 + sr_ratio = 4 + dpr = 0.0 + drop_rate = dropout + attn_drop_rate = 0.0 + + self.local_block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr, + sr_ratio=sr_ratio, + ws=ws, + with_rpe=True, + vert_c_dim=cfg.vert_c_dim, + ) + self.global_block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr, + sr_ratio=sr_ratio, + ws=1, + with_rpe=True, + vert_c_dim=cfg.vert_c_dim, + ) + + def forward(self, x, size, context=None): + x = self.local_block(x, size, context) + x = self.global_block(x, size, context) + + return x + + def compute_params(self): + num = 0 + for param in self.parameters(): + num += np.prod(param.size()) + + return num + + +class SelfAttentionLayer(nn.Module): + def __init__( + self, + dim, + cfg, + num_heads=8, + attn_drop=0.0, + proj_drop=0.0, + drop_path=0.0, + dropout=0.0, + ): + super(SelfAttentionLayer, self).__init__() + assert ( + dim % num_heads == 0 + ), f"dim {dim} should be divided by num_heads {num_heads}." 
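+ # Pre-norm multi-head self-attention followed by a GELU feed-forward
+ # block; both branches are residual, with dropout on the output
+ # projection and DropPath on the feed-forward path.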
+ + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.multi_head_attn = MultiHeadAttention(dim, num_heads) + self.q, self.k, self.v = ( + nn.Linear(dim, dim, bias=True), + nn.Linear(dim, dim, bias=True), + nn.Linear(dim, dim, bias=True), + ) + + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.ffn = nn.Sequential( + nn.Linear(dim, dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(dim, dim), + nn.Dropout(dropout), + ) + + def forward(self, x): + """ + x: [BH1W1, H3W3, D] + """ + short_cut = x + x = self.norm1(x) + + q, k, v = self.q(x), self.k(x), self.v(x) + + x = self.multi_head_attn(q, k, v) + + x = self.proj(x) + x = short_cut + self.proj_drop(x) + + x = x + self.drop_path(self.ffn(self.norm2(x))) + + return x + + def compute_params(self): + num = 0 + for param in self.parameters(): + num += np.prod(param.size()) + + return num + + +class CrossAttentionLayer(nn.Module): + def __init__( + self, + qk_dim, + v_dim, + query_token_dim, + tgt_token_dim, + num_heads=8, + attn_drop=0.0, + proj_drop=0.0, + drop_path=0.0, + dropout=0.0, + ): + super(CrossAttentionLayer, self).__init__() + assert ( + qk_dim % num_heads == 0 + ), f"dim {qk_dim} should be divided by num_heads {num_heads}." + assert ( + v_dim % num_heads == 0 + ), f"dim {v_dim} should be divided by num_heads {num_heads}." + """ + Query Token: [N, C] -> [N, qk_dim] (Q) + Target Token: [M, D] -> [M, qk_dim] (K), [M, v_dim] (V) + """ + self.num_heads = num_heads + head_dim = qk_dim // num_heads + self.scale = head_dim**-0.5 + + self.norm1 = nn.LayerNorm(query_token_dim) + self.norm2 = nn.LayerNorm(query_token_dim) + self.multi_head_attn = BroadMultiHeadAttention(qk_dim, num_heads) + self.q, self.k, self.v = ( + nn.Linear(query_token_dim, qk_dim, bias=True), + nn.Linear(tgt_token_dim, qk_dim, bias=True), + nn.Linear(tgt_token_dim, v_dim, bias=True), + ) + + self.proj = nn.Linear(v_dim, query_token_dim) + self.proj_drop = nn.Dropout(proj_drop) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.ffn = nn.Sequential( + nn.Linear(query_token_dim, query_token_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(query_token_dim, query_token_dim), + nn.Dropout(dropout), + ) + + def forward(self, query, tgt_token): + """ + x: [BH1W1, H3W3, D] + """ + short_cut = query + query = self.norm1(query) + + q, k, v = self.q(query), self.k(tgt_token), self.v(tgt_token) + + x = self.multi_head_attn(q, k, v) + + x = short_cut + self.proj_drop(self.proj(x)) + + x = x + self.drop_path(self.ffn(self.norm2(x))) + + return x + + +class CostPerceiverEncoder(nn.Module): + def __init__(self, cfg): + super(CostPerceiverEncoder, self).__init__() + self.cfg = cfg + self.patch_size = cfg.patch_size + self.patch_embed = PatchEmbed( + in_chans=self.cfg.cost_heads_num, + patch_size=self.patch_size, + embed_dim=cfg.cost_latent_input_dim, + pe=cfg.pe, + ) + + self.depth = cfg.encoder_depth + + self.latent_tokens = nn.Parameter( + torch.randn(1, cfg.cost_latent_token_num, cfg.cost_latent_dim) + ) + + query_token_dim, tgt_token_dim = ( + cfg.cost_latent_dim, + cfg.cost_latent_input_dim * 2, + ) + qk_dim, v_dim = query_token_dim, query_token_dim + self.input_layer = CrossAttentionLayer( + qk_dim, v_dim, query_token_dim, tgt_token_dim, dropout=cfg.dropout + ) + + if cfg.use_mlp: + self.encoder_layers = 
nn.ModuleList( + [ + MLPMixerLayer(cfg.cost_latent_dim, cfg, dropout=cfg.dropout) + for idx in range(self.depth) + ] + ) + else: + self.encoder_layers = nn.ModuleList( + [ + SelfAttentionLayer(cfg.cost_latent_dim, cfg, dropout=cfg.dropout) + for idx in range(self.depth) + ] + ) + + if self.cfg.vertical_conv: + self.vertical_encoder_layers = nn.ModuleList( + [ConvNextLayer(cfg.cost_latent_dim) for idx in range(self.depth)] + ) + else: + self.vertical_encoder_layers = nn.ModuleList( + [ + VerticalSelfAttentionLayer( + cfg.cost_latent_dim, cfg, dropout=cfg.dropout + ) + for idx in range(self.depth) + ] + ) + self.cost_scale_aug = None + if "cost_scale_aug" in cfg.keys(): + self.cost_scale_aug = cfg.cost_scale_aug + print("[Using cost_scale_aug: {}]".format(self.cost_scale_aug)) + + def forward(self, cost_volume, data, context=None): + B, heads, H1, W1, H2, W2 = cost_volume.shape + cost_maps = ( + cost_volume.permute(0, 2, 3, 1, 4, 5) + .contiguous() + .view(B * H1 * W1, self.cfg.cost_heads_num, H2, W2) + ) + data["cost_maps"] = cost_maps + + if self.cost_scale_aug is not None: + scale_factor = ( + torch.FloatTensor(B * H1 * W1, self.cfg.cost_heads_num, H2, W2) + .uniform_(self.cost_scale_aug[0], self.cost_scale_aug[1]) + .cuda() + ) + cost_maps = cost_maps * scale_factor + + x, size = self.patch_embed(cost_maps) # B*H1*W1, size[0]*size[1], C + data["H3W3"] = size + H3, W3 = size + + x = self.input_layer(self.latent_tokens, x) + + short_cut = x + + for idx, layer in enumerate(self.encoder_layers): + x = layer(x) + if self.cfg.vertical_conv: + # B, H1*W1, K, D -> B, K, D, H1*W1 -> B*K, D, H1, W1 + x = ( + x.view(B, H1 * W1, self.cfg.cost_latent_token_num, -1) + .permute(0, 3, 1, 2) + .reshape(B * self.cfg.cost_latent_token_num, -1, H1, W1) + ) + x = self.vertical_encoder_layers[idx](x) + # B*K, D, H1, W1 -> B, K, D, H1*W1 -> B, H1*W1, K, D + x = ( + x.view(B, self.cfg.cost_latent_token_num, -1, H1 * W1) + .permute(0, 2, 3, 1) + .reshape(B * H1 * W1, self.cfg.cost_latent_token_num, -1) + ) + else: + x = ( + x.view(B, H1 * W1, self.cfg.cost_latent_token_num, -1) + .permute(0, 2, 1, 3) + .reshape(B * self.cfg.cost_latent_token_num, H1 * W1, -1) + ) + x = self.vertical_encoder_layers[idx](x, (H1, W1), context) + x = ( + x.view(B, self.cfg.cost_latent_token_num, H1 * W1, -1) + .permute(0, 2, 1, 3) + .reshape(B * H1 * W1, self.cfg.cost_latent_token_num, -1) + ) + + if self.cfg.cost_encoder_res is True: + x = x + short_cut + # print("~~~~") + return x + + +class MemoryEncoder(nn.Module): + def __init__(self, cfg): + super(MemoryEncoder, self).__init__() + self.cfg = cfg + + if cfg.fnet == "twins": + self.feat_encoder = twins_svt_large(pretrained=self.cfg.pretrain) + elif cfg.fnet == "basicencoder": + self.feat_encoder = BasicEncoder(output_dim=256, norm_fn="instance") + else: + exit() + self.channel_convertor = nn.Conv2d( + cfg.encoder_latent_dim, cfg.encoder_latent_dim, 1, padding=0, bias=False + ) + self.cost_perceiver_encoder = CostPerceiverEncoder(cfg) + + def corr(self, fmap1, fmap2): + batch, dim, ht, wd = fmap1.shape + fmap1 = rearrange( + fmap1, "b (heads d) h w -> b heads (h w) d", heads=self.cfg.cost_heads_num + ) + fmap2 = rearrange( + fmap2, "b (heads d) h w -> b heads (h w) d", heads=self.cfg.cost_heads_num + ) + corr = einsum("bhid, bhjd -> bhij", fmap1, fmap2) + corr = corr.permute(0, 2, 1, 3).view( + batch * ht * wd, self.cfg.cost_heads_num, ht, wd + ) + # corr = self.norm(self.relu(corr)) + corr = corr.view(batch, ht * wd, self.cfg.cost_heads_num, ht * wd).permute( + 0, 2, 1, 3 + ) + 
corr = corr.view(batch, self.cfg.cost_heads_num, ht, wd, ht, wd) + + return corr + + def forward(self, img1, img2, data, context=None, return_feat=False): + # The original implementation + # feat_s = self.feat_encoder(img1) + # feat_t = self.feat_encoder(img2) + # feat_s = self.channel_convertor(feat_s) + # feat_t = self.channel_convertor(feat_t) + + imgs = torch.cat([img1, img2], dim=0) + feats = self.feat_encoder(imgs) + feats = self.channel_convertor(feats) + B = feats.shape[0] // 2 + feat_s = feats[:B] + if return_feat: + ffeat = feats[:B] + feat_t = feats[B:] + + B, C, H, W = feat_s.shape + size = (H, W) + + if self.cfg.feat_cross_attn: + feat_s = feat_s.flatten(2).transpose(1, 2) + feat_t = feat_t.flatten(2).transpose(1, 2) + + for layer in self.layers: + feat_s, feat_t = layer(feat_s, feat_t, size) + + feat_s = feat_s.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() + feat_t = feat_t.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() + + cost_volume = self.corr(feat_s, feat_t) + x = self.cost_perceiver_encoder(cost_volume, data, context) + + if return_feat: + return x, ffeat + return x diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/gma.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/gma.py new file mode 100644 index 0000000000000000000000000000000000000000..0394543cdbf48af871918b80d3403132be963655 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/gma.py @@ -0,0 +1,123 @@ +import torch +from torch import nn, einsum +from einops import rearrange + + +class RelPosEmb(nn.Module): + def __init__(self, max_pos_size, dim_head): + super().__init__() + self.rel_height = nn.Embedding(2 * max_pos_size - 1, dim_head) + self.rel_width = nn.Embedding(2 * max_pos_size - 1, dim_head) + + deltas = torch.arange(max_pos_size).view(1, -1) - torch.arange( + max_pos_size + ).view(-1, 1) + rel_ind = deltas + max_pos_size - 1 + self.register_buffer("rel_ind", rel_ind) + + def forward(self, q): + batch, heads, h, w, c = q.shape + height_emb = self.rel_height(self.rel_ind[:h, :h].reshape(-1)) + width_emb = self.rel_width(self.rel_ind[:w, :w].reshape(-1)) + + height_emb = rearrange(height_emb, "(x u) d -> x u () d", x=h) + width_emb = rearrange(width_emb, "(y v) d -> y () v d", y=w) + + height_score = einsum("b h x y d, x u v d -> b h x y u v", q, height_emb) + width_score = einsum("b h x y d, y u v d -> b h x y u v", q, width_emb) + + return height_score + width_score + + +class Attention(nn.Module): + def __init__( + self, + *, + args, + dim, + max_pos_size=100, + heads=4, + dim_head=128, + ): + super().__init__() + self.args = args + self.heads = heads + self.scale = dim_head**-0.5 + inner_dim = heads * dim_head + + self.to_qk = nn.Conv2d(dim, inner_dim * 2, 1, bias=False) + + self.pos_emb = RelPosEmb(max_pos_size, dim_head) + for param in self.pos_emb.parameters(): + param.requires_grad = False + + def forward(self, fmap): + heads, b, c, h, w = self.heads, *fmap.shape + + q, k = self.to_qk(fmap).chunk(2, dim=1) + + q, k = map(lambda t: rearrange(t, "b (h d) x y -> b h x y d", h=heads), (q, k)) + q = self.scale * q + + # if self.args.position_only: + # sim = self.pos_emb(q) + + # elif self.args.position_and_content: + # sim_content = einsum('b h x y d, b h u v d -> b h x y u v', q, k) + # sim_pos = self.pos_emb(q) + # sim = sim_content + sim_pos + + # else: + sim = einsum("b h x y d, b h u v d -> b h x y u v", q, k) + + sim = rearrange(sim, "b h x y u 
v -> b h (x y) (u v)") + attn = sim.softmax(dim=-1) + + return attn + + +class Aggregate(nn.Module): + def __init__( + self, + args, + dim, + heads=4, + dim_head=128, + ): + super().__init__() + self.args = args + self.heads = heads + self.scale = dim_head**-0.5 + inner_dim = heads * dim_head + + self.to_v = nn.Conv2d(dim, inner_dim, 1, bias=False) + + self.gamma = nn.Parameter(torch.zeros(1)) + + if dim != inner_dim: + self.project = nn.Conv2d(inner_dim, dim, 1, bias=False) + else: + self.project = None + + def forward(self, attn, fmap): + heads, b, c, h, w = self.heads, *fmap.shape + + v = self.to_v(fmap) + v = rearrange(v, "b (h d) x y -> b h (x y) d", h=heads) + out = einsum("b h i j, b h j d -> b h i d", attn, v) + out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w) + + if self.project is not None: + out = self.project(out) + + out = fmap + self.gamma * out + + return out + + +if __name__ == "__main__": + att = Attention(dim=128, heads=1) + fmap = torch.randn(2, 128, 40, 90) + out = att(fmap) + + print(out.shape) diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/gru.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/gru.py new file mode 100644 index 0000000000000000000000000000000000000000..ec8e9f0d6f70ceeb29b04c51d803afdf1ad6aa80 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/gru.py @@ -0,0 +1,160 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class FlowHead(nn.Module): + def __init__(self, input_dim=128, hidden_dim=256): + super(FlowHead, self).__init__() + self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) + self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + return self.conv2(self.relu(self.conv1(x))) + + +class ConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192 + 128): + super(ConvGRU, self).__init__() + self.convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) + self.convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) + self.convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) + + def forward(self, h, x): + hx = torch.cat([h, x], dim=1) + + z = torch.sigmoid(self.convz(hx)) + r = torch.sigmoid(self.convr(hx)) + q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1))) + + h = (1 - z) * h + z * q + return h + + +class SepConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192 + 128): + super(SepConvGRU, self).__init__() + self.convz1 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) + ) + self.convr1 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) + ) + self.convq1 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) + ) + + self.convz2 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0) + ) + self.convr2 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0) + ) + self.convq2 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0) + ) + + def forward(self, h, x): + # horizontal + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz1(hx)) + r = torch.sigmoid(self.convr1(hx)) + q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1))) + h = (1 - z) * h + z * q + + # vertical + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz2(hx)) + r = torch.sigmoid(self.convr2(hx)) + q = 
torch.tanh(self.convq2(torch.cat([r * h, x], dim=1))) + h = (1 - z) * h + z * q + + return h + + +class BasicMotionEncoder(nn.Module): + def __init__(self, args): + super(BasicMotionEncoder, self).__init__() + if args.only_global: + print("[Decoding with only global cost]") + cor_planes = args.query_latent_dim + else: + cor_planes = 81 + args.query_latent_dim + self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) + self.convc2 = nn.Conv2d(256, 192, 3, padding=1) + self.convf1 = nn.Conv2d(2, 128, 7, padding=3) + self.convf2 = nn.Conv2d(128, 64, 3, padding=1) + self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + cor = F.relu(self.convc2(cor)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + + +class BasicUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=128, input_dim=128): + super(BasicUpdateBlock, self).__init__() + self.args = args + self.encoder = BasicMotionEncoder(args) + self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim) + self.flow_head = FlowHead(hidden_dim, hidden_dim=256) + + self.mask = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 64 * 9, 1, padding=0), + ) + + def forward(self, net, inp, corr, flow, upsample=True): + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + # scale mask to balence gradients + mask = 0.25 * self.mask(net) + return net, mask, delta_flow + + +from .gma import Aggregate + + +class GMAUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=128): + super().__init__() + self.args = args + self.encoder = BasicMotionEncoder(args) + self.gru = SepConvGRU( + hidden_dim=hidden_dim, input_dim=128 + hidden_dim + hidden_dim + ) + self.flow_head = FlowHead(hidden_dim, hidden_dim=256) + + self.mask = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 64 * 9, 1, padding=0), + ) + + self.aggregator = Aggregate(args=self.args, dim=128, dim_head=128, heads=1) + + def forward(self, net, inp, corr, flow, attention): + motion_features = self.encoder(flow, corr) + motion_features_global = self.aggregator(attention, motion_features) + inp_cat = torch.cat([inp, motion_features, motion_features_global], dim=1) + + # Attentional update + net = self.gru(net, inp_cat) + + delta_flow = self.flow_head(net) + + # scale mask to balence gradients + mask = 0.25 * self.mask(net) + return net, mask, delta_flow diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/mlpmixer.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/mlpmixer.py new file mode 100644 index 0000000000000000000000000000000000000000..f3646aae56c6de91508ff3458eb3c7e592abbe10 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/mlpmixer.py @@ -0,0 +1,55 @@ +from torch import nn +from einops.layers.torch import Rearrange, Reduce +from functools import partial +import numpy as np + + +class PreNormResidual(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.fn = fn + self.norm = nn.LayerNorm(dim) + + def forward(self, x): + return self.fn(self.norm(x)) + x + + +def FeedForward(dim, expansion_factor=4, dropout=0.0, 
dense=nn.Linear): + return nn.Sequential( + dense(dim, dim * expansion_factor), + nn.GELU(), + nn.Dropout(dropout), + dense(dim * expansion_factor, dim), + nn.Dropout(dropout), + ) + + +class MLPMixerLayer(nn.Module): + def __init__(self, dim, cfg, drop_path=0.0, dropout=0.0): + super(MLPMixerLayer, self).__init__() + + # print(f"use mlp mixer layer") + K = cfg.cost_latent_token_num + expansion_factor = cfg.mlp_expansion_factor + chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), nn.Linear + + self.mlpmixer = nn.Sequential( + PreNormResidual(dim, FeedForward(K, expansion_factor, dropout, chan_first)), + PreNormResidual( + dim, FeedForward(dim, expansion_factor, dropout, chan_last) + ), + ) + + def compute_params(self): + num = 0 + for param in self.mlpmixer.parameters(): + num += np.prod(param.size()) + + return num + + def forward(self, x): + """ + x: [BH1W1, K, D] + """ + + return self.mlpmixer(x) diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/transformer.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..1e5b981ee699d9d62e1f1e6455520765485510ac --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/transformer.py @@ -0,0 +1,74 @@ +import loguru +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import einsum + +from einops.layers.torch import Rearrange +from einops import rearrange + +from ...utils.utils import coords_grid, bilinear_sampler, upflow8 +from ..common import ( + FeedForward, + pyramid_retrieve_tokens, + sampler, + sampler_gaussian_fix, + retrieve_tokens, + MultiHeadAttention, + MLP, +) +from ..encoders import twins_svt_large_context, twins_svt_large +from ...position_encoding import PositionEncodingSine, LinearPositionEncoding +from .twins import PosConv +from .encoder import MemoryEncoder +from .decoder import MemoryDecoder +from .cnn import BasicEncoder + + +class FlowFormer(nn.Module): + def __init__(self, cfg): + super(FlowFormer, self).__init__() + self.cfg = cfg + + self.memory_encoder = MemoryEncoder(cfg) + self.memory_decoder = MemoryDecoder(cfg) + if cfg.cnet == "twins": + self.context_encoder = twins_svt_large(pretrained=self.cfg.pretrain) + elif cfg.cnet == "basicencoder": + self.context_encoder = BasicEncoder(output_dim=256, norm_fn="instance") + + def build_coord(self, img): + N, C, H, W = img.shape + coords = coords_grid(N, H // 8, W // 8) + return coords + + def forward( + self, image1, image2, output=None, flow_init=None, return_feat=False, iters=None + ): + # Following https://github.com/princeton-vl/RAFT/ + image1 = 2 * (image1 / 255.0) - 1.0 + image2 = 2 * (image2 / 255.0) - 1.0 + + data = {} + + if self.cfg.context_concat: + context = self.context_encoder(torch.cat([image1, image2], dim=1)) + else: + if return_feat: + context, cfeat = self.context_encoder(image1, return_feat=return_feat) + else: + context = self.context_encoder(image1) + if return_feat: + cost_memory, ffeat = self.memory_encoder( + image1, image2, data, context, return_feat=return_feat + ) + else: + cost_memory = self.memory_encoder(image1, image2, data, context) + + flow_predictions = self.memory_decoder( + cost_memory, context, data, flow_init=flow_init, iters=iters + ) + + if return_feat: + return flow_predictions, cfeat, ffeat + return flow_predictions diff --git 
a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/twins.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/twins.py new file mode 100644 index 0000000000000000000000000000000000000000..78531d17710bb8c078215421bb7b3d7fb059eb3d --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/twins.py @@ -0,0 +1,1360 @@ +""" Twins +A PyTorch impl of : `Twins: Revisiting the Design of Spatial Attention in Vision Transformers` + - https://arxiv.org/pdf/2104.13840.pdf +Code/weights from https://github.com/Meituan-AutoML/Twins, original copyright/license info below +""" +# -------------------------------------------------------- +# Twins +# Copyright (c) 2021 Meituan +# Licensed under The Apache 2.0 License [see LICENSE for details] +# Written by Xinjie Li, Xiangxiang Chu +# -------------------------------------------------------- +import math +from copy import deepcopy +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.layers import Mlp, DropPath, to_2tuple, trunc_normal_ +from timm.models.registry import register_model +from timm.models.vision_transformer import Attention +from timm.models.helpers import build_model_with_cfg#, overlay_external_default_cfg +from .attention import MultiHeadAttention, LinearPositionEmbeddingSine +from ...utils.utils import coords_grid, bilinear_sampler, upflow8 + + +def _cfg(url="", **kwargs): + return { + "url": url, + "num_classes": 1000, + "input_size": (3, 224, 224), + "pool_size": None, + "crop_pct": 0.9, + "interpolation": "bicubic", + "fixed_input_size": True, + "mean": IMAGENET_DEFAULT_MEAN, + "std": IMAGENET_DEFAULT_STD, + "first_conv": "patch_embeds.0.proj", + "classifier": "head", + **kwargs, + } + + +default_cfgs = { + "twins_pcpvt_small": _cfg( + url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_pcpvt_small-e70e7e7a.pth", + ), + "twins_pcpvt_base": _cfg( + url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_pcpvt_base-e5ecb09b.pth", + ), + "twins_pcpvt_large": _cfg( + url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_pcpvt_large-d273f802.pth", + ), + "twins_svt_small": _cfg( + url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_svt_small-42e5f78c.pth", + ), + "twins_svt_base": _cfg( + url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_svt_base-c2265010.pth", + ), + "twins_svt_large": _cfg( + url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_svt_large-90f6aaa9.pth", + ), +} + +Size_ = Tuple[int, int] + + +class GroupAttnRPEContext(nn.Module): + """Latent cost tokens attend to different group""" + + def __init__( + self, + dim, + num_heads=8, + attn_drop=0.0, + proj_drop=0.0, + ws=1, + cfg=None, + vert_c_dim=0, + ): + super(GroupAttnRPEContext, self).__init__() + assert ws != 1 + assert cfg is not None + assert ( + dim % num_heads == 0 + ), f"dim {dim} should be divided by num_heads {num_heads}." + assert ( + cfg.cost_latent_token_num % 5 == 0 + ), "cost_latent_token_num should be divided by 5." 
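+        # [Added comment, not from the upstream FlowFormer source] forward()
+        # splits the batch into 5 groups (batch_num = B // 5); each group
+        # attends to key/value maps shifted up, down, left, right, or kept
+        # un-shifted, which is why cost_latent_token_num must be divisible by 5.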
+ assert vert_c_dim > 0, "vert_c_dim should not be 0" + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + self.vert_c_dim = vert_c_dim + + self.cfg = cfg + + self.context_proj = nn.Linear(256, vert_c_dim) + self.q = nn.Linear(dim + vert_c_dim, dim, bias=True) + self.k = nn.Linear(dim + vert_c_dim, dim, bias=True) + self.v = nn.Linear(dim, dim, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.ws = ws + + def forward(self, x, size: Size_, context=None): + B, N, C = x.shape + C_qk = C + self.vert_c_dim + H, W = size + batch_num = B // 5 + + context = context.repeat(B // context.shape[0], 1, 1, 1) + context = context.view(B, -1, H * W).permute(0, 2, 1) + context = self.context_proj(context) + context = context.view(B, H, W, -1) + + x = x.view(B, H, W, C) + x_qk = torch.cat([x, context], dim=-1) + + pad_l = pad_t = 0 + pad_r = (self.ws - W % self.ws) % self.ws + pad_b = (self.ws - H % self.ws) % self.ws + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + x_qk = F.pad(x_qk, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + _h, _w = Hp // self.ws, Wp // self.ws + padded_N = Hp * Wp + + coords = coords_grid(B, Hp, Wp).to(x.device) + coords = coords.view(B, 2, -1).permute(0, 2, 1) + coords_enc = LinearPositionEmbeddingSine(coords, dim=C_qk) + coords_enc = coords_enc.reshape(B, Hp, Wp, C_qk) + + q = ( + self.q(x_qk + coords_enc) + .reshape(B, _h, self.ws, _w, self.ws, self.num_heads, C // self.num_heads) + .transpose(2, 3) + ) + q = q.reshape( + B, _h * _w, self.ws * self.ws, self.num_heads, C // self.num_heads + ).permute(0, 1, 3, 2, 4) + + v = self.v(x) + k = self.k(x_qk + coords_enc) + # concate and do shifting operation together + kv = torch.cat([k, v], dim=-1) + kv_up = torch.cat( + [ + kv[:batch_num, self.ws : Hp, :, :], + kv[:batch_num, Hp - self.ws : Hp, :, :], + ], + dim=1, + ) + kv_down = torch.cat( + [ + kv[batch_num : batch_num * 2, : self.ws, :, :], + kv[batch_num : batch_num * 2, : Hp - self.ws, :, :], + ], + dim=1, + ) + kv_left = torch.cat( + [ + kv[batch_num * 2 : batch_num * 3, :, self.ws : Wp, :], + kv[batch_num * 2 : batch_num * 3, :, Wp - self.ws : Wp, :], + ], + dim=2, + ) + kv_right = torch.cat( + [ + kv[batch_num * 3 : batch_num * 4, :, : self.ws, :], + kv[batch_num * 3 : batch_num * 4, :, : Wp - self.ws, :], + ], + dim=2, + ) + kv_center = kv[batch_num * 4 : batch_num * 5, :, :, :] + kv_shifted = torch.cat([kv_up, kv_down, kv_left, kv_right, kv_center], dim=0) + k, v = torch.split(kv_shifted, [self.dim, self.dim], dim=-1) + + k = k.reshape( + B, _h, self.ws, _w, self.ws, self.num_heads, C // self.num_heads + ).transpose(2, 3) + k = k.reshape( + B, _h * _w, self.ws * self.ws, self.num_heads, C // self.num_heads + ).permute(0, 1, 3, 2, 4) + + v = v.reshape( + B, _h, self.ws, _w, self.ws, self.num_heads, C // self.num_heads + ).transpose(2, 3) + v = v.reshape( + B, _h * _w, self.ws * self.ws, self.num_heads, C // self.num_heads + ).permute(0, 1, 3, 2, 4) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C) + x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C) + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class GroupAttnRPE(nn.Module): + """Latent cost tokens attend to different 
group""" + + def __init__(self, dim, num_heads=8, attn_drop=0.0, proj_drop=0.0, ws=1, cfg=None): + super(GroupAttnRPE, self).__init__() + assert ws != 1 + assert cfg is not None + assert ( + dim % num_heads == 0 + ), f"dim {dim} should be divided by num_heads {num_heads}." + assert ( + cfg.cost_latent_token_num % 5 == 0 + ), "cost_latent_token_num should be divided by 5." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.cfg = cfg + + self.q = nn.Linear(dim, dim, bias=True) + self.k = nn.Linear(dim, dim, bias=True) + self.v = nn.Linear(dim, dim, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.ws = ws + + def forward(self, x, size: Size_, context=None): + B, N, C = x.shape + H, W = size + batch_num = B // 5 + x = x.view(B, H, W, C) + pad_l = pad_t = 0 + pad_r = (self.ws - W % self.ws) % self.ws + pad_b = (self.ws - H % self.ws) % self.ws + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + _h, _w = Hp // self.ws, Wp // self.ws + padded_N = Hp * Wp + + coords = coords_grid(B, Hp, Wp).to(x.device) + coords = coords.view(B, 2, -1).permute(0, 2, 1) + coords_enc = LinearPositionEmbeddingSine(coords, dim=C) + coords_enc = coords_enc.reshape(B, Hp, Wp, C) + + q = ( + self.q(x + coords_enc) + .reshape(B, _h, self.ws, _w, self.ws, self.num_heads, C // self.num_heads) + .transpose(2, 3) + ) + q = q.reshape( + B, _h * _w, self.ws * self.ws, self.num_heads, C // self.num_heads + ).permute(0, 1, 3, 2, 4) + + v = self.v(x) + k = self.k(x + coords_enc) + # concate and do shifting operation together + kv = torch.cat([k, v], dim=-1) + kv_up = torch.cat( + [ + kv[:batch_num, self.ws : Hp, :, :], + kv[:batch_num, Hp - self.ws : Hp, :, :], + ], + dim=1, + ) + kv_down = torch.cat( + [ + kv[batch_num : batch_num * 2, : self.ws, :, :], + kv[batch_num : batch_num * 2, : Hp - self.ws, :, :], + ], + dim=1, + ) + kv_left = torch.cat( + [ + kv[batch_num * 2 : batch_num * 3, :, self.ws : Wp, :], + kv[batch_num * 2 : batch_num * 3, :, Wp - self.ws : Wp, :], + ], + dim=2, + ) + kv_right = torch.cat( + [ + kv[batch_num * 3 : batch_num * 4, :, : self.ws, :], + kv[batch_num * 3 : batch_num * 4, :, : Wp - self.ws, :], + ], + dim=2, + ) + kv_center = kv[batch_num * 4 : batch_num * 5, :, :, :] + kv_shifted = torch.cat([kv_up, kv_down, kv_left, kv_right, kv_center], dim=0) + k, v = torch.split(kv_shifted, [self.dim, self.dim], dim=-1) + + k = k.reshape( + B, _h, self.ws, _w, self.ws, self.num_heads, C // self.num_heads + ).transpose(2, 3) + k = k.reshape( + B, _h * _w, self.ws * self.ws, self.num_heads, C // self.num_heads + ).permute(0, 1, 3, 2, 4) + + v = v.reshape( + B, _h, self.ws, _w, self.ws, self.num_heads, C // self.num_heads + ).transpose(2, 3) + v = v.reshape( + B, _h * _w, self.ws * self.ws, self.num_heads, C // self.num_heads + ).permute(0, 1, 3, 2, 4) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C) + x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C) + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class LocallyGroupedAttnRPEContext(nn.Module): + """LSA: self attention within a group""" + + def __init__( + self, dim, num_heads=8, attn_drop=0.0, proj_drop=0.0, ws=1, vert_c_dim=0 + ): + assert 
ws != 1 + super(LocallyGroupedAttnRPEContext, self).__init__() + assert ( + dim % num_heads == 0 + ), f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + self.vert_c_dim = vert_c_dim + + self.context_proj = nn.Linear(256, vert_c_dim) + # context are not added to value + self.q = nn.Linear(dim + vert_c_dim, dim, bias=True) + self.k = nn.Linear(dim + vert_c_dim, dim, bias=True) + self.v = nn.Linear(dim, dim, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.ws = ws + + def forward(self, x, size: Size_, context=None): + # There are two implementations for this function, zero padding or mask. We don't observe obvious difference for + # both. You can choose any one, we recommend forward_padding because it's neat. However, + # the masking implementation is more reasonable and accurate. + B, N, C = x.shape + H, W = size + C_qk = C + self.vert_c_dim + + context = context.repeat(B // context.shape[0], 1, 1, 1) + context = context.view(B, -1, H * W).permute(0, 2, 1) + context = self.context_proj(context) + context = context.view(B, H, W, -1) + + x = x.view(B, H, W, C) + x_qk = torch.cat([x, context], dim=-1) + + pad_l = pad_t = 0 + pad_r = (self.ws - W % self.ws) % self.ws + pad_b = (self.ws - H % self.ws) % self.ws + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + x_qk = F.pad(x_qk, (0, 0, pad_l, pad_r, pad_t, pad_b)) + + _, Hp, Wp, _ = x.shape + _h, _w = Hp // self.ws, Wp // self.ws + x = x.reshape(B, _h, self.ws, _w, self.ws, C).transpose(2, 3) + x_qk = x_qk.reshape(B, _h, self.ws, _w, self.ws, C_qk).transpose(2, 3) + + v = ( + self.v(x) + .reshape( + B, _h * _w, self.ws * self.ws, 1, self.num_heads, C // self.num_heads + ) + .permute(3, 0, 1, 4, 2, 5)[0] + ) + + coords = coords_grid(B, self.ws, self.ws).to(x.device) + coords = coords.view(B, 2, -1).permute(0, 2, 1) + coords_enc = LinearPositionEmbeddingSine(coords, dim=C_qk).view( + B, self.ws, self.ws, C_qk + ) + # coords_enc: B, ws, ws, C + # x: B, _h, _w, self.ws, self.ws, C + x_qk = x_qk + coords_enc[:, None, None, :, :, :] + + q = ( + self.q(x_qk) + .reshape( + B, _h * _w, self.ws * self.ws, 1, self.num_heads, C // self.num_heads + ) + .permute(3, 0, 1, 4, 2, 5)[0] + ) + k = ( + self.k(x_qk) + .reshape( + B, _h * _w, self.ws * self.ws, 1, self.num_heads, C // self.num_heads + ) + .permute(3, 0, 1, 4, 2, 5)[0] + ) + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C) + x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C) + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class GlobalSubSampleAttnRPEContext(nn.Module): + """GSA: using a key to summarize the information for a group to be efficient.""" + + def __init__( + self, dim, num_heads=8, attn_drop=0.0, proj_drop=0.0, sr_ratio=1, vert_c_dim=0 + ): + super().__init__() + assert ( + dim % num_heads == 0 + ), f"dim {dim} should be divided by num_heads {num_heads}." 
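+        # [Added comment, not from the upstream FlowFormer source] a 256-channel
+        # context map is projected to vert_c_dim and concatenated onto the
+        # query/key inputs only; when sr_ratio > 1 the key/value maps are
+        # spatially sub-sampled by the strided convs sr_key / sr_value.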
+ + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.vert_c_dim = vert_c_dim + self.context_proj = nn.Linear(256, vert_c_dim) + self.q = nn.Linear(dim + vert_c_dim, dim, bias=True) + self.k = nn.Linear(dim, dim, bias=True) + self.v = nn.Linear(dim, dim, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr_key = nn.Conv2d( + dim + vert_c_dim, dim, kernel_size=sr_ratio, stride=sr_ratio + ) + self.sr_value = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + else: + self.sr = None + self.norm = None + + def forward(self, x, size: Size_, context=None): + B, N, C = x.shape + C_qk = C + self.vert_c_dim + H, W = size + context = context.repeat(B // context.shape[0], 1, 1, 1) + context = context.view(B, -1, H * W).permute(0, 2, 1) + context = self.context_proj(context) + context = context.view(B, H, W, -1) + x = x.view(B, H, W, C) + x_qk = torch.cat([x, context], dim=-1) + pad_l = pad_t = 0 + pad_r = (self.sr_ratio - W % self.sr_ratio) % self.sr_ratio + pad_b = (self.sr_ratio - H % self.sr_ratio) % self.sr_ratio + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + x_qk = F.pad(x_qk, (0, 0, pad_l, pad_r, pad_t, pad_b)) + + _, Hp, Wp, _ = x.shape + padded_size = (Hp, Wp) + padded_N = Hp * Wp + x = x.view(B, -1, C) + x_qk = x_qk.view(B, -1, C_qk) + + coords = coords_grid(B, *padded_size).to(x.device) + coords = coords.view(B, 2, -1).permute(0, 2, 1) + coords_enc = LinearPositionEmbeddingSine(coords, dim=C_qk) + # coords_enc: B, Hp*Wp, C + # x: B, Hp*Wp, C + q = ( + self.q(x_qk + coords_enc) + .reshape(B, padded_N, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + + if self.sr_key is not None: + x = x.permute(0, 2, 1).reshape(B, C, *padded_size) + x_qk = x_qk.permute(0, 2, 1).reshape(B, C_qk, *padded_size) + x = self.sr_value(x).reshape(B, C, -1).permute(0, 2, 1) + x_qk = self.sr_key(x_qk).reshape(B, C, -1).permute(0, 2, 1) + x = self.norm(x) + x_qk = self.norm(x_qk) + + coords = coords_grid( + B, padded_size[0] // self.sr_ratio, padded_size[1] // self.sr_ratio + ).to(x.device) + coords = coords.view(B, 2, -1).permute(0, 2, 1) * self.sr_ratio + # align the coordinate of local and global + coords_enc = LinearPositionEmbeddingSine(coords, dim=C) + k = ( + self.k(x_qk + coords_enc) + .reshape( + B, + (padded_size[0] // self.sr_ratio) * (padded_size[1] // self.sr_ratio), + self.num_heads, + C // self.num_heads, + ) + .permute(0, 2, 1, 3) + ) + v = ( + self.v(x) + .reshape( + B, + (padded_size[0] // self.sr_ratio) * (padded_size[1] // self.sr_ratio), + self.num_heads, + C // self.num_heads, + ) + .permute(0, 2, 1, 3) + ) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, Hp, Wp, C) + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class LocallyGroupedAttnRPE(nn.Module): + """LSA: self attention within a group""" + + def __init__(self, dim, num_heads=8, attn_drop=0.0, proj_drop=0.0, ws=1): + assert ws != 1 + super(LocallyGroupedAttnRPE, self).__init__() + assert ( + dim % num_heads == 0 + ), f"dim {dim} should be divided by num_heads {num_heads}." 
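+        # [Added comment, not from the upstream FlowFormer source] the feature
+        # map is zero-padded to a multiple of ws, split into ws x ws windows,
+        # and attention runs within each window with a sine encoding of the
+        # window coordinates added to the query/key inputs.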
+ + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.q = nn.Linear(dim, dim, bias=True) + self.k = nn.Linear(dim, dim, bias=True) + self.v = nn.Linear(dim, dim, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.ws = ws + + def forward(self, x, size: Size_, context=None): + # There are two implementations for this function, zero padding or mask. We don't observe obvious difference for + # both. You can choose any one, we recommend forward_padding because it's neat. However, + # the masking implementation is more reasonable and accurate. + B, N, C = x.shape + H, W = size + x = x.view(B, H, W, C) + pad_l = pad_t = 0 + pad_r = (self.ws - W % self.ws) % self.ws + pad_b = (self.ws - H % self.ws) % self.ws + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + _h, _w = Hp // self.ws, Wp // self.ws + x = x.reshape(B, _h, self.ws, _w, self.ws, C).transpose(2, 3) + v = ( + self.v(x) + .reshape( + B, _h * _w, self.ws * self.ws, 1, self.num_heads, C // self.num_heads + ) + .permute(3, 0, 1, 4, 2, 5)[0] + ) + + coords = coords_grid(B, self.ws, self.ws).to(x.device) + coords = coords.view(B, 2, -1).permute(0, 2, 1) + coords_enc = LinearPositionEmbeddingSine(coords, dim=C).view( + B, self.ws, self.ws, C + ) + # coords_enc: B, ws, ws, C + # x: B, _h, _w, self.ws, self.ws, C + x = x + coords_enc[:, None, None, :, :, :] + + q = ( + self.q(x) + .reshape( + B, _h * _w, self.ws * self.ws, 1, self.num_heads, C // self.num_heads + ) + .permute(3, 0, 1, 4, 2, 5)[0] + ) + k = ( + self.k(x) + .reshape( + B, _h * _w, self.ws * self.ws, 1, self.num_heads, C // self.num_heads + ) + .permute(3, 0, 1, 4, 2, 5)[0] + ) + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C) + x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C) + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class GlobalSubSampleAttnRPE(nn.Module): + """GSA: using a key to summarize the information for a group to be efficient.""" + + def __init__(self, dim, num_heads=8, attn_drop=0.0, proj_drop=0.0, sr_ratio=1): + super().__init__() + assert ( + dim % num_heads == 0 + ), f"dim {dim} should be divided by num_heads {num_heads}." 
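+        # [Added comment, not from the upstream FlowFormer source] when
+        # sr_ratio > 1 the key/value map is reduced by the strided conv
+        # self.sr, and the key-side coordinates are multiplied by sr_ratio so
+        # they stay aligned with the full-resolution query coordinates.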
+ + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.q = nn.Linear(dim, dim, bias=True) + self.k = nn.Linear(dim, dim, bias=True) + self.v = nn.Linear(dim, dim, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + else: + self.sr = None + self.norm = None + + def forward(self, x, size: Size_, context=None): + B, N, C = x.shape + H, W = size + x = x.view(B, H, W, C) + pad_l = pad_t = 0 + pad_r = (self.sr_ratio - W % self.sr_ratio) % self.sr_ratio + pad_b = (self.sr_ratio - H % self.sr_ratio) % self.sr_ratio + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + padded_size = (Hp, Wp) + padded_N = Hp * Wp + x = x.view(B, -1, C) + + coords = coords_grid(B, *padded_size).to(x.device) + coords = coords.view(B, 2, -1).permute(0, 2, 1) + coords_enc = LinearPositionEmbeddingSine(coords, dim=C) + # coords_enc: B, Hp*Wp, C + # x: B, Hp*Wp, C + q = ( + self.q(x + coords_enc) + .reshape(B, padded_N, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + + if self.sr is not None: + x = x.permute(0, 2, 1).reshape(B, C, *padded_size) + x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1) + x = self.norm(x) + + coords = coords_grid( + B, padded_size[0] // self.sr_ratio, padded_size[1] // self.sr_ratio + ).to(x.device) + coords = coords.view(B, 2, -1).permute(0, 2, 1) * self.sr_ratio + # align the coordinate of local and global + coords_enc = LinearPositionEmbeddingSine(coords, dim=C) + k = ( + self.k(x + coords_enc) + .reshape( + B, + (padded_size[0] // self.sr_ratio) * (padded_size[1] // self.sr_ratio), + self.num_heads, + C // self.num_heads, + ) + .permute(0, 2, 1, 3) + ) + v = ( + self.v(x) + .reshape( + B, + (padded_size[0] // self.sr_ratio) * (padded_size[1] // self.sr_ratio), + self.num_heads, + C // self.num_heads, + ) + .permute(0, 2, 1, 3) + ) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, Hp, Wp, C) + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class CrossGlobalSubSampleAttnRPE(nn.Module): + """GSA: using a key to summarize the information for a group to be efficient.""" + + def __init__(self, dim, num_heads=8, attn_drop=0.0, proj_drop=0.0, sr_ratio=1): + super().__init__() + assert ( + dim % num_heads == 0 + ), f"dim {dim} should be divided by num_heads {num_heads}." 
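+        # [Added comment, not from the upstream FlowFormer source] cross-attention
+        # variant: queries come from x while keys/values come from tgt, which is
+        # sub-sampled by self.sr in the same way when sr_ratio > 1; both sides
+        # use the sine coordinate encoding.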
+ + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.q = nn.Linear(dim, dim, bias=True) + self.k = nn.Linear(dim, dim, bias=True) + self.v = nn.Linear(dim, dim, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + else: + self.sr = None + self.norm = None + + def forward(self, x, tgt, size: Size_): + B, N, C = x.shape + coords = coords_grid(B, *size).to(x.device) + coords = coords.view(B, 2, -1).permute(0, 2, 1) + coords_enc = LinearPositionEmbeddingSine(coords, dim=C) + # coords_enc: B, H*W, C + # x: B, H*W, C + q = ( + self.q(x + coords_enc) + .reshape(B, N, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + + if self.sr is not None: + tgt = tgt.permute(0, 2, 1).reshape(B, C, *size) + tgt = self.sr(tgt).reshape(B, C, -1).permute(0, 2, 1) + tgt = self.norm(tgt) + coords = coords_grid(B, size[0] // self.sr_ratio, size[1] // self.sr_ratio).to( + x.device + ) + coords = coords.view(B, 2, -1).permute(0, 2, 1) * self.sr_ratio + # align the coordinate of local and global + coords_enc = LinearPositionEmbeddingSine(coords, dim=C) + k = ( + self.k(tgt + coords_enc) + .reshape( + B, + (size[0] // self.sr_ratio) * (size[1] // self.sr_ratio), + self.num_heads, + C // self.num_heads, + ) + .permute(0, 2, 1, 3) + ) + v = ( + self.v(tgt) + .reshape( + B, + (size[0] // self.sr_ratio) * (size[1] // self.sr_ratio), + self.num_heads, + C // self.num_heads, + ) + .permute(0, 2, 1, 3) + ) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class LocallyGroupedAttn(nn.Module): + """LSA: self attention within a group""" + + def __init__(self, dim, num_heads=8, attn_drop=0.0, proj_drop=0.0, ws=1): + assert ws != 1 + super(LocallyGroupedAttn, self).__init__() + assert ( + dim % num_heads == 0 + ), f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.ws = ws + + def forward(self, x, size: Size_): + # There are two implementations for this function, zero padding or mask. We don't observe obvious difference for + # both. You can choose any one, we recommend forward_padding because it's neat. However, + # the masking implementation is more reasonable and accurate. 
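+        # [Added comment, not from the upstream FlowFormer source] only the
+        # zero-padding variant described above is implemented here.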
+ B, N, C = x.shape + H, W = size + x = x.view(B, H, W, C) + pad_l = pad_t = 0 + pad_r = (self.ws - W % self.ws) % self.ws + pad_b = (self.ws - H % self.ws) % self.ws + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + _h, _w = Hp // self.ws, Wp // self.ws + x = x.reshape(B, _h, self.ws, _w, self.ws, C).transpose(2, 3) + qkv = ( + self.qkv(x) + .reshape( + B, _h * _w, self.ws * self.ws, 3, self.num_heads, C // self.num_heads + ) + .permute(3, 0, 1, 4, 2, 5) + ) + q, k, v = qkv[0], qkv[1], qkv[2] + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C) + x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C) + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class GlobalSubSampleAttn(nn.Module): + """GSA: using a key to summarize the information for a group to be efficient.""" + + def __init__(self, dim, num_heads=8, attn_drop=0.0, proj_drop=0.0, sr_ratio=1): + super().__init__() + assert ( + dim % num_heads == 0 + ), f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.q = nn.Linear(dim, dim, bias=True) + self.kv = nn.Linear(dim, dim * 2, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + else: + self.sr = None + self.norm = None + + def forward(self, x, size: Size_): + B, N, C = x.shape + q = ( + self.q(x) + .reshape(B, N, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + + if self.sr is not None: + x = x.permute(0, 2, 1).reshape(B, C, *size) + x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1) + x = self.norm(x) + kv = ( + self.kv(x) + .reshape(B, -1, 2, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + k, v = kv[0], kv[1] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class CrossGlobalSubSampleAttn(nn.Module): + """GSA: using a key to summarize the information for a group to be efficient.""" + + def __init__(self, dim, num_heads=8, attn_drop=0.0, proj_drop=0.0, sr_ratio=1): + super().__init__() + assert ( + dim % num_heads == 0 + ), f"dim {dim} should be divided by num_heads {num_heads}." 
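+        # [Added comment, not from the upstream FlowFormer source] same
+        # cross-attention pattern as the RPE variant above, but without
+        # positional encodings: queries come from x and a fused kv projection
+        # is applied to the (optionally sub-sampled) tgt features.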
+ + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.q = nn.Linear(dim, dim, bias=True) + self.kv = nn.Linear(dim, dim * 2, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + else: + self.sr = None + self.norm = None + + def forward(self, x, tgt, size: Size_): + B, N, C = x.shape + q = ( + self.q(x) + .reshape(B, N, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + + if self.sr is not None: + tgt = tgt.permute(0, 2, 1).reshape(B, C, *size) + tgt = self.sr(tgt).reshape(B, C, -1).permute(0, 2, 1) + tgt = self.norm(tgt) + kv = ( + self.kv(tgt) + .reshape(B, -1, 2, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + k, v = kv[0], kv[1] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class CrossBlock(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + sr_ratio=1, + ws=None, + with_rpe=True, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = CrossGlobalSubSampleAttnRPE( + dim, num_heads, attn_drop, drop, sr_ratio + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + def forward(self, src, tgt, size: Size_): + src_shortcut, tgt_shortcut = src, tgt + + src, tgt = self.norm1(src), self.norm1(tgt) + src = src_shortcut + self.drop_path(self.attn(src, tgt, size)) + tgt = tgt_shortcut + self.drop_path(self.attn(tgt, src, size)) + + src = src + self.drop_path(self.mlp(self.norm2(src))) + tgt = tgt + self.drop_path(self.mlp(self.norm2(tgt))) + return src, tgt + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + sr_ratio=1, + ws=None, + with_rpe=False, + vert_c_dim=0, + groupattention=False, + cfg=None, + ): + super().__init__() + self.norm1 = norm_layer(dim) + if groupattention: + assert with_rpe, "Not implementing groupattention without rpe" + if vert_c_dim > 0: + self.attn = GroupAttnRPEContext( + dim, num_heads, attn_drop, drop, ws, cfg, vert_c_dim + ) + else: + self.attn = GroupAttnRPE(dim, num_heads, attn_drop, drop, ws, cfg) + elif ws is None: + self.attn = Attention(dim, num_heads, False, None, attn_drop, drop) + elif ws == 1: + if with_rpe: + if vert_c_dim > 0: + self.attn = GlobalSubSampleAttnRPEContext( + dim, num_heads, attn_drop, drop, sr_ratio, vert_c_dim + ) + else: + self.attn = GlobalSubSampleAttnRPE( + dim, num_heads, attn_drop, drop, sr_ratio + ) + else: + self.attn = GlobalSubSampleAttn( + dim, num_heads, attn_drop, drop, sr_ratio + ) + else: + if with_rpe: + if vert_c_dim > 0: + self.attn = LocallyGroupedAttnRPEContext( + dim, num_heads, attn_drop, drop, ws, vert_c_dim + ) + else: + self.attn = LocallyGroupedAttnRPE( + dim, num_heads, attn_drop, drop, ws + ) + else: + self.attn = LocallyGroupedAttn(dim, num_heads, 
attn_drop, drop, ws) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + def forward(self, x, size: Size_, context=None): + x = x + self.drop_path(self.attn(self.norm1(x), size, context)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PosConv(nn.Module): + # PEG from https://arxiv.org/abs/2102.10882 + def __init__(self, in_chans, embed_dim=768, stride=1): + super(PosConv, self).__init__() + self.proj = nn.Sequential( + nn.Conv2d(in_chans, embed_dim, 3, stride, 1, bias=True, groups=embed_dim), + ) + self.stride = stride + + def forward(self, x, size: Size_): + B, N, C = x.shape + cnn_feat_token = x.transpose(1, 2).view(B, C, *size) + x = self.proj(cnn_feat_token) + if self.stride == 1: + x += cnn_feat_token + x = x.flatten(2).transpose(1, 2) + return x + + def no_weight_decay(self): + return ["proj.%d.weight" % i for i in range(4)] + + +class PatchEmbed(nn.Module): + """Image to Patch Embedding""" + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + self.img_size = img_size + self.patch_size = patch_size + assert ( + img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0 + ), f"img_size {img_size} should be divided by patch_size {patch_size}." + self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] + self.num_patches = self.H * self.W + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size + ) + self.norm = nn.LayerNorm(embed_dim) + + def forward(self, x) -> Tuple[torch.Tensor, Size_]: + B, C, H, W = x.shape + + x = self.proj(x).flatten(2).transpose(1, 2) + x = self.norm(x) + out_size = (H // self.patch_size[0], W // self.patch_size[1]) + + return x, out_size + + +class Twins(nn.Module): + """Twins Vision Transfomer (Revisiting Spatial Attention) + Adapted from PVT (PyramidVisionTransformer) class at https://github.com/whai362/PVT.git + """ + + def __init__( + self, + img_size=224, + patch_size=4, + in_chans=3, + num_classes=1000, + embed_dims=(64, 128, 256, 512), + num_heads=(1, 2, 4, 8), + mlp_ratios=(4, 4, 4, 4), + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=(3, 4, 6, 3), + sr_ratios=(8, 4, 2, 1), + wss=None, + block_cls=Block, + init_weight=True, + ): + super().__init__() + self.num_classes = num_classes + self.depths = depths + self.embed_dims = embed_dims + self.num_features = embed_dims[-1] + + img_size = to_2tuple(img_size) + prev_chs = in_chans + self.patch_embeds = nn.ModuleList() + self.pos_drops = nn.ModuleList() + for i in range(len(depths)): + self.patch_embeds.append( + PatchEmbed(img_size, patch_size, prev_chs, embed_dims[i]) + ) + self.pos_drops.append(nn.Dropout(p=drop_rate)) + prev_chs = embed_dims[i] + img_size = tuple(t // patch_size for t in img_size) + patch_size = 2 + + self.blocks = nn.ModuleList() + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + cur = 0 + for k in range(len(depths)): + _block = nn.ModuleList( + [ + block_cls( + dim=embed_dims[k], + num_heads=num_heads[k], + mlp_ratio=mlp_ratios[k], + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[cur + i], + norm_layer=norm_layer, + 
sr_ratio=sr_ratios[k], + ws=1 if wss is None or i % 2 == 1 else wss[k], + ) + for i in range(depths[k]) + ] + ) + self.blocks.append(_block) + cur += depths[k] + + self.pos_block = nn.ModuleList( + [PosConv(embed_dim, embed_dim) for embed_dim in embed_dims] + ) + + self.norm = norm_layer(self.num_features) + + # classification head + self.head = ( + nn.Linear(self.num_features, num_classes) + if num_classes > 0 + else nn.Identity() + ) + + # init weights + if init_weight: + self.apply(self._init_weights) + + @torch.jit.ignore + def no_weight_decay(self): + return set(["pos_block." + n for n, p in self.pos_block.named_parameters()]) + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=""): + self.num_classes = num_classes + self.head = ( + nn.Linear(self.num_features, num_classes) + if num_classes > 0 + else nn.Identity() + ) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + + def forward_features(self, x): + B = x.shape[0] + for i, (embed, drop, blocks, pos_blk) in enumerate( + zip(self.patch_embeds, self.pos_drops, self.blocks, self.pos_block) + ): + x, size = embed(x) + x = drop(x) + for j, blk in enumerate(blocks): + x = blk(x, size) + if j == 0: + x = pos_blk(x, size) # PEG here + if i < len(self.depths) - 1: + x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() + x = self.norm(x) + return x.mean(dim=1) # GAP here + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + +# def _create_twins(variant, pretrained=False, **kwargs): +# if kwargs.get('features_only', None): +# raise RuntimeError('features_only not implemented for Vision Transformer models.') + +# model = build_model_with_cfg( +# Twins, variant, pretrained, +# default_cfg=default_cfgs[variant], +# **kwargs) +# return model + + +# @register_model +# def twins_pcpvt_small(pretrained=False, **kwargs): +# model_kwargs = dict( +# patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], +# depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], **kwargs) +# return _create_twins('twins_pcpvt_small', pretrained=pretrained, **model_kwargs) + + +# @register_model +# def twins_pcpvt_base(pretrained=False, **kwargs): +# model_kwargs = dict( +# patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], +# depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], **kwargs) +# return _create_twins('twins_pcpvt_base', pretrained=pretrained, **model_kwargs) + + +# @register_model +# def twins_pcpvt_large(pretrained=False, **kwargs): +# model_kwargs = dict( +# patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], +# depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], **kwargs) +# return _create_twins('twins_pcpvt_large', pretrained=pretrained, **model_kwargs) + + +# @register_model +# def twins_svt_small(pretrained=False, **kwargs): +# model_kwargs = dict( +# patch_size=4, embed_dims=[64, 128, 256, 512], num_heads=[2, 4, 8, 16], 
mlp_ratios=[4, 4, 4, 4], +# depths=[2, 2, 10, 4], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], **kwargs) +# return _create_twins('twins_svt_small', pretrained=pretrained, **model_kwargs) + + +# @register_model +# def twins_svt_base(pretrained=False, **kwargs): +# model_kwargs = dict( +# patch_size=4, embed_dims=[96, 192, 384, 768], num_heads=[3, 6, 12, 24], mlp_ratios=[4, 4, 4, 4], +# depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], **kwargs) +# return _create_twins('twins_svt_base', pretrained=pretrained, **model_kwargs) + + +# @register_model +# def twins_svt_large(pretrained=False, **kwargs): +# model_kwargs = dict( +# patch_size=4, embed_dims=[128, 256, 512, 1024], num_heads=[4, 8, 16, 32], mlp_ratios=[4, 4, 4, 4], +# depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], **kwargs) +# return _create_twins('twins_svt_large', pretrained=pretrained, **model_kwargs) + +# @register_model +# def twins_svt_large_context(pretrained=False, **kwargs): +# model_kwargs = dict( +# patch_size=4, embed_dims=[128, 256, 512, 1024], num_heads=[4, 8, 16, 32], mlp_ratios=[4, 4, 4, 4], +# depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], in_chans=6, init_weight=False, **kwargs) +# return _create_twins('twins_svt_large', pretrained=pretrained, **model_kwargs) +# # def twins_svt_large_context(pretrained=False, **kwargs): +# # model_kwargs = dict( +# # patch_size=4, embed_dims=[128, 256], num_heads=[4, 8], mlp_ratios=[4, 4], +# # depths=[2, 2], wss=[7, 7], sr_ratios=[8, 4], in_chans=6, init_weight=False, **kwargs) +# # return _create_twins('twins_svt_large', pretrained=pretrained, **model_kwargs) diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/FlowFormer/__init__.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/FlowFormer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5cacb241fd63c6a37774a31a0b4439f29f496c44 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/FlowFormer/__init__.py @@ -0,0 +1,11 @@ +import torch + + +def build_flowformer(cfg): + name = cfg.transformer + if name == "latentcostformer": + from .LatentCostFormer.transformer import FlowFormer + else: + raise ValueError(f"FlowFormer = {name} is not a valid architecture!") + + return FlowFormer(cfg[name]) diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/FlowFormer/common.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/FlowFormer/common.py new file mode 100644 index 0000000000000000000000000000000000000000..70cc859f86e54a2e3fcacdb7807420e66ae49871 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/FlowFormer/common.py @@ -0,0 +1,566 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import einsum + +from einops.layers.torch import Rearrange +from einops import rearrange + +from ..utils.utils import coords_grid, bilinear_sampler, indexing +from loguru import logger + +import math + + +def nerf_encoding(x, L=6, NORMALIZE_FACOR=1 / 300): + """ + x is of shape [*, 2]. The last dimension are two coordinates (x and y). 
+ """ + freq_bands = 2.0 ** torch.linspace(0, L, L - 1).to(x.device) + return torch.cat( + [ + x * NORMALIZE_FACOR, + torch.sin(3.14 * x[..., -2:-1] * freq_bands * NORMALIZE_FACOR), + torch.cos(3.14 * x[..., -2:-1] * freq_bands * NORMALIZE_FACOR), + torch.sin(3.14 * x[..., -1:] * freq_bands * NORMALIZE_FACOR), + torch.cos(3.14 * x[..., -1:] * freq_bands * NORMALIZE_FACOR), + ], + dim=-1, + ) + + +def sampler_gaussian(latent, mean, std, image_size, point_num=25): + # latent [B, H*W, D] + # mean [B, 2, H, W] + # std [B, 1, H, W] + H, W = image_size + B, HW, D = latent.shape + STD_MAX = 20 + latent = rearrange( + latent, "b (h w) c -> b c h w", h=H, w=W + ) # latent = latent.view(B, H, W, D).permute(0, 3, 1, 2) + mean = mean.permute(0, 2, 3, 1) # [B, H, W, 2] + + dx = torch.linspace(-1, 1, int(point_num**0.5)) + dy = torch.linspace(-1, 1, int(point_num**0.5)) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to( + mean.device + ) # [B*H*W, point_num**0.5, point_num**0.5, 2] + delta_3sigma = ( + F.sigmoid(std.permute(0, 2, 3, 1).reshape(B * HW, 1, 1, 1)) + * STD_MAX + * delta + * 3 + ) # [B*H*W, point_num**0.5, point_num**0.5, 2] + + centroid = mean.reshape(B * H * W, 1, 1, 2) + coords = centroid + delta_3sigma + + coords = rearrange(coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W) + sampled_latents = bilinear_sampler( + latent, coords + ) # [B*H*W, dim, point_num**0.5, point_num**0.5] + sampled_latents = sampled_latents.permute(0, 2, 3, 1) + sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) + + return sampled_latents, sampled_weights + + +def sampler_gaussian_zy( + latent, mean, std, image_size, point_num=25, return_deltaXY=False, beta=1 +): + # latent [B, H*W, D] + # mean [B, 2, H, W] + # std [B, 1, H, W] + H, W = image_size + B, HW, D = latent.shape + latent = rearrange( + latent, "b (h w) c -> b c h w", h=H, w=W + ) # latent = latent.view(B, H, W, D).permute(0, 3, 1, 2) + mean = mean.permute(0, 2, 3, 1) # [B, H, W, 2] + + dx = torch.linspace(-1, 1, int(point_num**0.5)) + dy = torch.linspace(-1, 1, int(point_num**0.5)) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to( + mean.device + ) # [B*H*W, point_num**0.5, point_num**0.5, 2] + delta_3sigma = ( + std.permute(0, 2, 3, 1).reshape(B * HW, 1, 1, 1) * delta * 3 + ) # [B*H*W, point_num**0.5, point_num**0.5, 2] + + centroid = mean.reshape(B * H * W, 1, 1, 2) + coords = centroid + delta_3sigma + + coords = rearrange(coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W) + sampled_latents = bilinear_sampler( + latent, coords + ) # [B*H*W, dim, point_num**0.5, point_num**0.5] + sampled_latents = sampled_latents.permute(0, 2, 3, 1) + sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / beta + + if return_deltaXY: + return sampled_latents, sampled_weights, delta_3sigma + else: + return sampled_latents, sampled_weights + + +def sampler_gaussian(latent, mean, std, image_size, point_num=25, return_deltaXY=False): + # latent [B, H*W, D] + # mean [B, 2, H, W] + # std [B, 1, H, W] + H, W = image_size + B, HW, D = latent.shape + STD_MAX = 20 + latent = rearrange( + latent, "b (h w) c -> b c h w", h=H, w=W + ) # latent = latent.view(B, H, W, D).permute(0, 3, 1, 2) + mean = mean.permute(0, 2, 3, 1) # [B, H, W, 2] + + dx = torch.linspace(-1, 1, int(point_num**0.5)) + dy = torch.linspace(-1, 1, int(point_num**0.5)) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to( + mean.device + ) # [B*H*W, point_num**0.5, point_num**0.5, 2] + delta_3sigma = ( + F.sigmoid(std.permute(0, 2, 3, 1).reshape(B * HW, 1, 1, 1)) + * 
STD_MAX + * delta + * 3 + ) # [B*H*W, point_num**0.5, point_num**0.5, 2] + + centroid = mean.reshape(B * H * W, 1, 1, 2) + coords = centroid + delta_3sigma + + coords = rearrange(coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W) + sampled_latents = bilinear_sampler( + latent, coords + ) # [B*H*W, dim, point_num**0.5, point_num**0.5] + sampled_latents = sampled_latents.permute(0, 2, 3, 1) + sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) + + if return_deltaXY: + return sampled_latents, sampled_weights, delta_3sigma + else: + return sampled_latents, sampled_weights + + +def sampler_gaussian_fix(latent, mean, image_size, point_num=49): + # latent [B, H*W, D] + # mean [B, 2, H, W] + H, W = image_size + B, HW, D = latent.shape + STD_MAX = 20 + latent = rearrange( + latent, "b (h w) c -> b c h w", h=H, w=W + ) # latent = latent.view(B, H, W, D).permute(0, 3, 1, 2) + mean = mean.permute(0, 2, 3, 1) # [B, H, W, 2] + + radius = int((int(point_num**0.5) - 1) / 2) + + dx = torch.linspace(-radius, radius, 2 * radius + 1) + dy = torch.linspace(-radius, radius, 2 * radius + 1) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to( + mean.device + ) # [B*H*W, point_num**0.5, point_num**0.5, 2] + + centroid = mean.reshape(B * H * W, 1, 1, 2) + coords = centroid + delta + + coords = rearrange(coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W) + sampled_latents = bilinear_sampler( + latent, coords + ) # [B*H*W, dim, point_num**0.5, point_num**0.5] + sampled_latents = sampled_latents.permute(0, 2, 3, 1) + sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / point_num # smooth term + + return sampled_latents, sampled_weights + + +def sampler_gaussian_fix_pyramid( + latent, feat_pyramid, scale_weight, mean, image_size, point_num=25 +): + # latent [B, H*W, D] + # mean [B, 2, H, W] + # scale weight [B, H*W, layer_num] + + H, W = image_size + B, HW, D = latent.shape + STD_MAX = 20 + latent = rearrange( + latent, "b (h w) c -> b c h w", h=H, w=W + ) # latent = latent.view(B, H, W, D).permute(0, 3, 1, 2) + mean = mean.permute(0, 2, 3, 1) # [B, H, W, 2] + + radius = int((int(point_num**0.5) - 1) / 2) + + dx = torch.linspace(-radius, radius, 2 * radius + 1) + dy = torch.linspace(-radius, radius, 2 * radius + 1) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to( + mean.device + ) # [B*H*W, point_num**0.5, point_num**0.5, 2] + + sampled_latents = [] + for i in range(len(feat_pyramid)): + centroid = mean.reshape(B * H * W, 1, 1, 2) + coords = (centroid + delta) / 2**i + coords = rearrange( + coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W + ) + sampled_latents.append(bilinear_sampler(feat_pyramid[i], coords)) + + sampled_latents = torch.stack( + sampled_latents, dim=1 + ) # [B, layer_num, dim, H*W, point_num] + sampled_latents = sampled_latents.permute( + 0, 3, 4, 2, 1 + ) # [B, H*W, point_num, dim, layer_num] + scale_weight = F.softmax(scale_weight, dim=2) # [B, H*W, layer_num] + vis_out = scale_weight + scale_weight = torch.unsqueeze( + torch.unsqueeze(scale_weight, dim=2), dim=2 + ) # [B, HW, 1, 1, layer_num] + + weighted_latent = torch.sum( + sampled_latents * scale_weight, dim=-1 + ) # [B, H*W, point_num, dim] + + sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / point_num # smooth term + + return weighted_latent, sampled_weights, vis_out + + +def sampler_gaussian_pyramid( + latent, feat_pyramid, scale_weight, mean, std, image_size, point_num=25 +): + # latent [B, H*W, D] + # mean [B, 2, H, W] + # scale weight [B, H*W, layer_num] + + H, W = image_size + 
B, HW, D = latent.shape + STD_MAX = 20 + latent = rearrange( + latent, "b (h w) c -> b c h w", h=H, w=W + ) # latent = latent.view(B, H, W, D).permute(0, 3, 1, 2) + mean = mean.permute(0, 2, 3, 1) # [B, H, W, 2] + + radius = int((int(point_num**0.5) - 1) / 2) + + dx = torch.linspace(-1, 1, int(point_num**0.5)) + dy = torch.linspace(-1, 1, int(point_num**0.5)) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to( + mean.device + ) # [B*H*W, point_num**0.5, point_num**0.5, 2] + delta_3sigma = ( + std.permute(0, 2, 3, 1).reshape(B * HW, 1, 1, 1) * delta * 3 + ) # [B*H*W, point_num**0.5, point_num**0.5, 2] + + sampled_latents = [] + for i in range(len(feat_pyramid)): + centroid = mean.reshape(B * H * W, 1, 1, 2) + coords = (centroid + delta_3sigma) / 2**i + coords = rearrange( + coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W + ) + sampled_latents.append(bilinear_sampler(feat_pyramid[i], coords)) + + sampled_latents = torch.stack( + sampled_latents, dim=1 + ) # [B, layer_num, dim, H*W, point_num] + sampled_latents = sampled_latents.permute( + 0, 3, 4, 2, 1 + ) # [B, H*W, point_num, dim, layer_num] + scale_weight = F.softmax(scale_weight, dim=2) # [B, H*W, layer_num] + vis_out = scale_weight + scale_weight = torch.unsqueeze( + torch.unsqueeze(scale_weight, dim=2), dim=2 + ) # [B, HW, 1, 1, layer_num] + + weighted_latent = torch.sum( + sampled_latents * scale_weight, dim=-1 + ) # [B, H*W, point_num, dim] + + sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / point_num # smooth term + + return weighted_latent, sampled_weights, vis_out + + +def sampler_gaussian_fix_MH(latent, mean, image_size, point_num=25): + """different heads have different mean""" + # latent [B, H*W, D] + # mean [B, 2, H, W, heands] + + H, W = image_size + B, HW, D = latent.shape + _, _, _, _, HEADS = mean.shape + STD_MAX = 20 + latent = rearrange(latent, "b (h w) c -> b c h w", h=H, w=W) + mean = mean.permute(0, 2, 3, 4, 1) # [B, H, W, heads, 2] + + radius = int((int(point_num**0.5) - 1) / 2) + + dx = torch.linspace(-radius, radius, 2 * radius + 1) + dy = torch.linspace(-radius, radius, 2 * radius + 1) + delta = ( + torch.stack(torch.meshgrid(dy, dx), axis=-1) + .to(mean.device) + .repeat(HEADS, 1, 1, 1) + ) # [HEADS, point_num**0.5, point_num**0.5, 2] + + centroid = mean.reshape(B * H * W, HEADS, 1, 1, 2) + coords = centroid + delta + coords = rearrange( + coords, "(b h w) H r1 r2 c -> b (h w H) (r1 r2) c", b=B, h=H, w=W, H=HEADS + ) + sampled_latents = bilinear_sampler(latent, coords) # [B, dim, H*W*HEADS, pointnum] + sampled_latents = sampled_latents.permute( + 0, 2, 3, 1 + ) # [B, H*W*HEADS, pointnum, dim] + sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / point_num # smooth term + return sampled_latents, sampled_weights + + +def sampler_gaussian_fix_pyramid_MH( + latent, feat_pyramid, scale_head_weight, mean, image_size, point_num=25 +): + # latent [B, H*W, D] + # mean [B, 2, H, W, heands] + # scale_head weight [B, H*W, layer_num*heads] + + H, W = image_size + B, HW, D = latent.shape + _, _, _, _, HEADS = mean.shape + + latent = rearrange(latent, "b (h w) c -> b c h w", h=H, w=W) + mean = mean.permute(0, 2, 3, 4, 1) # [B, H, W, heads, 2] + + radius = int((int(point_num**0.5) - 1) / 2) + + dx = torch.linspace(-radius, radius, 2 * radius + 1) + dy = torch.linspace(-radius, radius, 2 * radius + 1) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to( + mean.device + ) # [B*H*W, point_num**0.5, point_num**0.5, 2] + + sampled_latents = [] + centroid = mean.reshape(B * H * W, HEADS, 1, 1, 2) + 
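+    # Unlike sampler_gaussian_fix_pyramid, only the centroid is divided by 2**i here;
+    # the integer offset grid `delta` is added at full resolution for every level.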
for i in range(len(feat_pyramid)): + coords = (centroid) / 2**i + delta + coords = rearrange( + coords, "(b h w) H r1 r2 c -> b (h w H) (r1 r2) c", b=B, h=H, w=W, H=HEADS + ) + sampled_latents.append( + bilinear_sampler(feat_pyramid[i], coords) + ) # [B, dim, H*W*HEADS, point_num] + + sampled_latents = torch.stack( + sampled_latents, dim=1 + ) # [B, layer_num, dim, H*W*HEADS, point_num] + sampled_latents = sampled_latents.permute( + 0, 3, 4, 2, 1 + ) # [B, H*W*HEADS, point_num, dim, layer_num] + + scale_head_weight = scale_head_weight.reshape(B, H * W * HEADS, -1) + scale_head_weight = F.softmax(scale_head_weight, dim=2) # [B, H*W*HEADS, layer_num] + scale_head_weight = torch.unsqueeze( + torch.unsqueeze(scale_head_weight, dim=2), dim=2 + ) # [B, H*W*HEADS, 1, 1, layer_num] + + weighted_latent = torch.sum( + sampled_latents * scale_head_weight, dim=-1 + ) # [B, H*W*HEADS, point_num, dim] + + sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / point_num # smooth term + + return weighted_latent, sampled_weights + + +def sampler(feat, center, window_size): + # feat [B, C, H, W] + # center [B, 2, H, W] + center = center.permute(0, 2, 3, 1) # [B, H, W, 2] + B, H, W, C = center.shape + + radius = window_size // 2 + dx = torch.linspace(-radius, radius, 2 * radius + 1) + dy = torch.linspace(-radius, radius, 2 * radius + 1) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to( + center.device + ) # [B*H*W, window_size, point_num**0.5, 2] + + center = center.reshape(B * H * W, 1, 1, 2) + coords = center + delta + + coords = rearrange(coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W) + sampled_latents = bilinear_sampler( + feat, coords + ) # [B*H*W, dim, window_size, window_size] + # sampled_latents = sampled_latents.permute(0, 2, 3, 1) + + return sampled_latents + + +def retrieve_tokens(feat, center, window_size, sampler): + # feat [B, C, H, W] + # center [B, 2, H, W] + radius = window_size // 2 + dx = torch.linspace(-radius, radius, 2 * radius + 1) + dy = torch.linspace(-radius, radius, 2 * radius + 1) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to( + center.device + ) # [B*H*W, point_num**0.5, point_num**0.5, 2] + + B, H, W, C = center.shape + centroid = center.reshape(B * H * W, 1, 1, 2) + coords = centroid + delta + + coords = rearrange(coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W) + if sampler == "nn": + sampled_latents = indexing(feat, coords) + elif sampler == "bilinear": + sampled_latents = bilinear_sampler(feat, coords) + else: + raise ValueError("invalid sampler") + # [B, dim, H*W, point_num] + + return sampled_latents + + +def pyramid_retrieve_tokens( + feat_pyramid, center, image_size, window_sizes, sampler="bilinear" +): + center = center.permute(0, 2, 3, 1) # [B, H, W, 2] + sampled_latents_pyramid = [] + for idx in range(len(window_sizes)): + sampled_latents_pyramid.append( + retrieve_tokens(feat_pyramid[idx], center, window_sizes[idx], sampler) + ) + center = center / 2 + + return torch.cat(sampled_latents_pyramid, dim=-1) + + +class FeedForward(nn.Module): + def __init__(self, dim, dropout=0.0): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(dim, dim), + nn.Dropout(dropout), + ) + + def forward(self, x): + x = self.net(x) + return x + + +class MLP(nn.Module): + def __init__(self, in_dim=22, out_dim=1, innter_dim=96, depth=5): + super().__init__() + self.FC1 = nn.Linear(in_dim, innter_dim) + self.FC_out = nn.Linear(innter_dim, out_dim) + self.relu = 
torch.nn.LeakyReLU(0.2) + self.FC_inter = nn.ModuleList( + [nn.Linear(innter_dim, innter_dim) for i in range(depth)] + ) + + def forward(self, x): + x = self.FC1(x) + x = self.relu(x) + for inter_fc in self.FC_inter: + x = inter_fc(x) + x = self.relu(x) + x = self.FC_out(x) + return x + + +class MultiHeadAttention(nn.Module): + def __init__(self, dim, heads, num_kv_tokens, cfg, rpe_bias=None, use_rpe=False): + super(MultiHeadAttention, self).__init__() + self.dim = dim + self.heads = heads + self.num_kv_tokens = num_kv_tokens + self.scale = (dim / heads) ** -0.5 + self.rpe = cfg.rpe + self.attend = nn.Softmax(dim=-1) + self.use_rpe = use_rpe + + if use_rpe: + if rpe_bias is None: + if self.rpe == "element-wise": + self.rpe_bias = nn.Parameter( + torch.zeros(heads, self.num_kv_tokens, dim // heads) + ) + elif self.rpe == "head-wise": + self.rpe_bias = nn.Parameter( + torch.zeros(1, heads, 1, self.num_kv_tokens) + ) + elif self.rpe == "token-wise": + self.rpe_bias = nn.Parameter( + torch.zeros(1, 1, 1, self.num_kv_tokens) + ) # 81 is point_num + elif self.rpe == "implicit": + pass + # self.implicit_pe_fn = MLP(in_dim=22, out_dim=self.dim, innter_dim=int(self.dim//2.4), depth=2) + # raise ValueError('Implicit Encoding Not Implemented') + elif self.rpe == "element-wise-value": + self.rpe_bias = nn.Parameter( + torch.zeros(heads, self.num_kv_tokens, dim // heads) + ) + self.rpe_value = nn.Parameter(torch.randn(self.num_kv_tokens, dim)) + else: + raise ValueError("Not Implemented") + else: + self.rpe_bias = rpe_bias + + def attend_with_rpe(self, Q, K, rpe_bias): + Q = rearrange(Q, "b i (heads d) -> b heads i d", heads=self.heads) + K = rearrange(K, "b j (heads d) -> b heads j d", heads=self.heads) + + dots = ( + einsum("bhid, bhjd -> bhij", Q, K) * self.scale + ) # (b hw) heads 1 pointnum + if self.use_rpe: + if self.rpe == "element-wise": + rpe_bias_weight = ( + einsum("bhid, hjd -> bhij", Q, rpe_bias) * self.scale + ) # (b hw) heads 1 pointnum + dots = dots + rpe_bias_weight + elif self.rpe == "implicit": + pass + rpe_bias_weight = ( + einsum("bhid, bhjd -> bhij", Q, rpe_bias) * self.scale + ) # (b hw) heads 1 pointnum + dots = dots + rpe_bias_weight + elif self.rpe == "head-wise" or self.rpe == "token-wise": + dots = dots + rpe_bias + + return self.attend(dots), dots + + def forward(self, Q, K, V, rpe_bias=None): + if self.use_rpe: + if rpe_bias is None or self.rpe == "element-wise": + rpe_bias = self.rpe_bias + else: + rpe_bias = rearrange( + rpe_bias, "b hw pn (heads d) -> (b hw) heads pn d", heads=self.heads + ) + attn, dots = self.attend_with_rpe(Q, K, rpe_bias) + else: + attn, dots = self.attend_with_rpe(Q, K, None) + B, HW, _ = Q.shape + + if V is not None: + V = rearrange(V, "b j (heads d) -> b heads j d", heads=self.heads) + + out = einsum("bhij, bhjd -> bhid", attn, V) + out = rearrange(out, "b heads hw d -> b hw (heads d)", b=B, hw=HW) + else: + out = None + + # dots = torch.squeeze(dots, 2) + # dots = rearrange(dots, '(b hw) heads d -> b hw (heads d)', b=B, hw=HW) + + return out, dots diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/FlowFormer/encoders.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/FlowFormer/encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..92d2d6316dceb26d2e45399f50f94e8fcb367c71 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/FlowFormer/encoders.py @@ -0,0 +1,115 @@ +import torch +import torch.nn as nn +import timm +import numpy as np + + +class 
twins_svt_large(nn.Module): + def __init__(self, pretrained=True): + super().__init__() + self.svt = timm.create_model("twins_svt_large", pretrained=pretrained) + + del self.svt.head + del self.svt.patch_embeds[2] + del self.svt.patch_embeds[2] + del self.svt.blocks[2] + del self.svt.blocks[2] + del self.svt.pos_block[2] + del self.svt.pos_block[2] + self.svt.norm.weight.requires_grad = False + self.svt.norm.bias.requires_grad = False + + def forward(self, x, data=None, layer=2, return_feat=False): + B = x.shape[0] + if return_feat: + feat = [] + for i, (embed, drop, blocks, pos_blk) in enumerate( + zip( + self.svt.patch_embeds, + self.svt.pos_drops, + self.svt.blocks, + self.svt.pos_block, + ) + ): + x, size = embed(x) + x = drop(x) + for j, blk in enumerate(blocks): + x = blk(x, size) + if j == 0: + x = pos_blk(x, size) + if i < len(self.svt.depths) - 1: + x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() + if return_feat: + feat.append(x) + if i == layer - 1: + break + if return_feat: + return x, feat + return x + + def compute_params(self, layer=2): + num = 0 + for i, (embed, drop, blocks, pos_blk) in enumerate( + zip( + self.svt.patch_embeds, + self.svt.pos_drops, + self.svt.blocks, + self.svt.pos_block, + ) + ): + for param in embed.parameters(): + num += np.prod(param.size()) + + for param in drop.parameters(): + num += np.prod(param.size()) + + for param in blocks.parameters(): + num += np.prod(param.size()) + + for param in pos_blk.parameters(): + num += np.prod(param.size()) + + if i == layer - 1: + break + + for param in self.svt.head.parameters(): + num += np.prod(param.size()) + + return num + + +class twins_svt_large_context(nn.Module): + def __init__(self, pretrained=True): + super().__init__() + self.svt = timm.create_model("twins_svt_large_context", pretrained=pretrained) + + def forward(self, x, data=None, layer=2): + B = x.shape[0] + for i, (embed, drop, blocks, pos_blk) in enumerate( + zip( + self.svt.patch_embeds, + self.svt.pos_drops, + self.svt.blocks, + self.svt.pos_block, + ) + ): + x, size = embed(x) + x = drop(x) + for j, blk in enumerate(blocks): + x = blk(x, size) + if j == 0: + x = pos_blk(x, size) + if i < len(self.svt.depths) - 1: + x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() + + if i == layer - 1: + break + + return x + + +if __name__ == "__main__": + m = twins_svt_large() + input = torch.randn(2, 3, 400, 800) + out = m.extract_feature(input) + print(out.shape) diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/__init__.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/corr.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/corr.py new file mode 100644 index 0000000000000000000000000000000000000000..3c5819f0d56e07de2738b6e5523eada84078a302 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/corr.py @@ -0,0 +1,90 @@ +import torch +import torch.nn.functional as F +from .utils.utils import bilinear_sampler, coords_grid + +try: + import alt_cuda_corr +except: + # alt_cuda_corr is not compiled + pass + + +class CorrBlock: + def __init__(self, fmap1, fmap2, num_levels=4, radius=4): + self.num_levels = num_levels + self.radius = radius + self.corr_pyramid = [] + + # all pairs correlation + corr = CorrBlock.corr(fmap1, fmap2) + + batch, h1, w1, dim, h2, w2 = corr.shape + corr = 
corr.reshape(batch * h1 * w1, dim, h2, w2) + + self.corr_pyramid.append(corr) + for i in range(self.num_levels - 1): + corr = F.avg_pool2d(corr, 2, stride=2) + self.corr_pyramid.append(corr) + + def __call__(self, coords): + r = self.radius + coords = coords.permute(0, 2, 3, 1) + batch, h1, w1, _ = coords.shape + + out_pyramid = [] + for i in range(self.num_levels): + corr = self.corr_pyramid[i] + dx = torch.linspace(-r, r, 2 * r + 1) + dy = torch.linspace(-r, r, 2 * r + 1) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) + + centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2**i + delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) + coords_lvl = centroid_lvl + delta_lvl + corr = bilinear_sampler(corr, coords_lvl) + corr = corr.view(batch, h1, w1, -1) + out_pyramid.append(corr) + + out = torch.cat(out_pyramid, dim=-1) + return out.permute(0, 3, 1, 2).contiguous().float() + + @staticmethod + def corr(fmap1, fmap2): + batch, dim, ht, wd = fmap1.shape + fmap1 = fmap1.view(batch, dim, ht * wd) + fmap2 = fmap2.view(batch, dim, ht * wd) + + corr = torch.matmul(fmap1.transpose(1, 2), fmap2) + corr = corr.view(batch, ht, wd, 1, ht, wd) + return corr / torch.sqrt(torch.tensor(dim).float()) + + +class AlternateCorrBlock: + def __init__(self, fmap1, fmap2, num_levels=4, radius=4): + self.num_levels = num_levels + self.radius = radius + + self.pyramid = [(fmap1, fmap2)] + for i in range(self.num_levels): + fmap1 = F.avg_pool2d(fmap1, 2, stride=2) + fmap2 = F.avg_pool2d(fmap2, 2, stride=2) + self.pyramid.append((fmap1, fmap2)) + + def __call__(self, coords): + coords = coords.permute(0, 2, 3, 1) + B, H, W, _ = coords.shape + dim = self.pyramid[0][0].shape[1] + + corr_list = [] + for i in range(self.num_levels): + r = self.radius + fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() + fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() + + coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() + (corr,) = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) + corr_list.append(corr.squeeze(1)) + + corr = torch.stack(corr_list, dim=1) + corr = corr.reshape(B, -1, H, W) + return corr / torch.sqrt(torch.tensor(dim).float()) diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/datasets.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..106867cb9328cffd1201a2c863a35bb1813d6133 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/datasets.py @@ -0,0 +1,297 @@ +# Data loading based on https://github.com/NVIDIA/flownet2-pytorch + +import numpy as np +import torch +import torch.utils.data as data +import torch.nn.functional as F + +import os +import math +import random +from glob import glob +import os.path as osp + +from .utils import frame_utils +from .utils.augmentor import FlowAugmentor, SparseFlowAugmentor + + +class FlowDataset(data.Dataset): + def __init__(self, aug_params=None, sparse=False): + self.augmentor = None + self.sparse = sparse + if aug_params is not None: + if sparse: + self.augmentor = SparseFlowAugmentor(**aug_params) + else: + self.augmentor = FlowAugmentor(**aug_params) + + self.is_test = False + self.init_seed = False + self.flow_list = [] + self.image_list = [] + self.extra_info = [] + + def __getitem__(self, index): + if self.is_test: + img1 = frame_utils.read_gen(self.image_list[index][0]) + img2 = frame_utils.read_gen(self.image_list[index][1]) + img1 = np.array(img1).astype(np.uint8)[..., :3] + 
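+            # test split: keep RGB channels only and return the image pair plus extra_info;
+            # no ground-truth flow is loaded in this branch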
img2 = np.array(img2).astype(np.uint8)[..., :3] + img1 = torch.from_numpy(img1).permute(2, 0, 1).float() + img2 = torch.from_numpy(img2).permute(2, 0, 1).float() + return img1, img2, self.extra_info[index] + + if not self.init_seed: + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None: + torch.manual_seed(worker_info.id) + np.random.seed(worker_info.id) + random.seed(worker_info.id) + self.init_seed = True + + index = index % len(self.image_list) + valid = None + if self.sparse: + flow, valid = frame_utils.readFlowKITTI(self.flow_list[index]) + else: + flow = frame_utils.read_gen(self.flow_list[index]) + + img1 = frame_utils.read_gen(self.image_list[index][0]) + img2 = frame_utils.read_gen(self.image_list[index][1]) + + flow = np.array(flow).astype(np.float32) + img1 = np.array(img1).astype(np.uint8) + img2 = np.array(img2).astype(np.uint8) + + # grayscale images + if len(img1.shape) == 2: + img1 = np.tile(img1[..., None], (1, 1, 3)) + img2 = np.tile(img2[..., None], (1, 1, 3)) + else: + img1 = img1[..., :3] + img2 = img2[..., :3] + + if self.augmentor is not None: + if self.sparse: + img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid) + else: + img1, img2, flow = self.augmentor(img1, img2, flow) + + img1 = torch.from_numpy(img1).permute(2, 0, 1).float() + img2 = torch.from_numpy(img2).permute(2, 0, 1).float() + flow = torch.from_numpy(flow).permute(2, 0, 1).float() + + if valid is not None: + valid = torch.from_numpy(valid) + else: + valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000) + + return img1, img2, flow, valid.float() + + def __rmul__(self, v): + self.flow_list = v * self.flow_list + self.image_list = v * self.image_list + return self + + def __len__(self): + return len(self.image_list) + + +class MpiSintel(FlowDataset): + def __init__( + self, aug_params=None, split="training", root="datasets/Sintel", dstype="clean" + ): + super(MpiSintel, self).__init__(aug_params) + flow_root = osp.join(root, split, "flow") + image_root = osp.join(root, split, dstype) + + if split == "test": + self.is_test = True + + for scene in os.listdir(image_root): + image_list = sorted(glob(osp.join(image_root, scene, "*.png"))) + for i in range(len(image_list) - 1): + self.image_list += [[image_list[i], image_list[i + 1]]] + self.extra_info += [(scene, i)] # scene and frame_id + + if split != "test": + self.flow_list += sorted(glob(osp.join(flow_root, scene, "*.flo"))) + + +class FlyingChairs(FlowDataset): + def __init__( + self, aug_params=None, split="train", root="datasets/FlyingChairs_release/data" + ): + super(FlyingChairs, self).__init__(aug_params) + + images = sorted(glob(osp.join(root, "*.ppm"))) + flows = sorted(glob(osp.join(root, "*.flo"))) + assert len(images) // 2 == len(flows) + + split_list = np.loadtxt("chairs_split.txt", dtype=np.int32) + for i in range(len(flows)): + xid = split_list[i] + if (split == "training" and xid == 1) or ( + split == "validation" and xid == 2 + ): + self.flow_list += [flows[i]] + self.image_list += [[images[2 * i], images[2 * i + 1]]] + + +class FlyingThings3D(FlowDataset): + def __init__( + self, + aug_params=None, + root="datasets/FlyingThings3D", + dstype="frames_cleanpass", + split="training", + ): + super(FlyingThings3D, self).__init__(aug_params) + + split_dir = "TRAIN" if split == "training" else "TEST" + for cam in ["left"]: + for direction in ["into_future", "into_past"]: + image_dirs = sorted(glob(osp.join(root, dstype, f"{split_dir}/*/*"))) + image_dirs = sorted([osp.join(f, cam) for f in image_dirs]) + + 
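+                # pair each image directory with the matching per-direction flow directory
+                # for the current camera before collecting frame pairs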
flow_dirs = sorted( + glob(osp.join(root, f"optical_flow/{split_dir}/*/*")) + ) + flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs]) + + for idir, fdir in zip(image_dirs, flow_dirs): + images = sorted(glob(osp.join(idir, "*.png"))) + flows = sorted(glob(osp.join(fdir, "*.pfm"))) + for i in range(len(flows) - 1): + if direction == "into_future": + self.image_list += [[images[i], images[i + 1]]] + self.flow_list += [flows[i]] + elif direction == "into_past": + self.image_list += [[images[i + 1], images[i]]] + self.flow_list += [flows[i + 1]] + + +class KITTI(FlowDataset): + def __init__(self, aug_params=None, split="training", root="datasets/KITTI"): + super(KITTI, self).__init__(aug_params, sparse=True) + if split == "testing": + self.is_test = True + + root = osp.join(root, split) + images1 = sorted(glob(osp.join(root, "image_2/*_10.png"))) + images2 = sorted(glob(osp.join(root, "image_2/*_11.png"))) + + for img1, img2 in zip(images1, images2): + frame_id = img1.split("/")[-1] + self.extra_info += [[frame_id]] + self.image_list += [[img1, img2]] + + if split == "training": + self.flow_list = sorted(glob(osp.join(root, "flow_occ/*_10.png"))) + + +class HD1K(FlowDataset): + def __init__(self, aug_params=None, root="datasets/HD1k"): + super(HD1K, self).__init__(aug_params, sparse=True) + + seq_ix = 0 + while 1: + flows = sorted( + glob(os.path.join(root, "hd1k_flow_gt", "flow_occ/%06d_*.png" % seq_ix)) + ) + images = sorted( + glob(os.path.join(root, "hd1k_input", "image_2/%06d_*.png" % seq_ix)) + ) + + if len(flows) == 0: + break + + for i in range(len(flows) - 1): + self.flow_list += [flows[i]] + self.image_list += [[images[i], images[i + 1]]] + + seq_ix += 1 + + +def fetch_dataloader(args, TRAIN_DS="C+T+K+S+H"): + """Create the data loader for the corresponding trainign set""" + + if args.stage == "chairs": + aug_params = { + "crop_size": args.image_size, + "min_scale": -0.1, + "max_scale": 1.0, + "do_flip": True, + } + train_dataset = FlyingChairs(aug_params, split="training") + + elif args.stage == "things": + aug_params = { + "crop_size": args.image_size, + "min_scale": -0.4, + "max_scale": 0.8, + "do_flip": True, + } + clean_dataset = FlyingThings3D(aug_params, dstype="frames_cleanpass") + final_dataset = FlyingThings3D(aug_params, dstype="frames_finalpass") + train_dataset = clean_dataset + final_dataset + + elif args.stage == "sintel": + aug_params = { + "crop_size": args.image_size, + "min_scale": -0.2, + "max_scale": 0.6, + "do_flip": True, + } + things = FlyingThings3D(aug_params, dstype="frames_cleanpass") + sintel_clean = MpiSintel(aug_params, split="training", dstype="clean") + sintel_final = MpiSintel(aug_params, split="training", dstype="final") + + if TRAIN_DS == "C+T+K+S+H": + kitti = KITTI( + { + "crop_size": args.image_size, + "min_scale": -0.3, + "max_scale": 0.5, + "do_flip": True, + } + ) + hd1k = HD1K( + { + "crop_size": args.image_size, + "min_scale": -0.5, + "max_scale": 0.2, + "do_flip": True, + } + ) + train_dataset = ( + 100 * sintel_clean + + 100 * sintel_final + + 200 * kitti + + 5 * hd1k + + things + ) + + elif TRAIN_DS == "C+T+K/S": + train_dataset = 100 * sintel_clean + 100 * sintel_final + things + + elif args.stage == "kitti": + aug_params = { + "crop_size": args.image_size, + "min_scale": -0.2, + "max_scale": 0.4, + "do_flip": False, + } + train_dataset = KITTI(aug_params, split="training") + + train_loader = data.DataLoader( + train_dataset, + batch_size=args.batch_size, + pin_memory=False, + shuffle=True, + num_workers=128, + 
drop_last=True, + ) + + print("Training with %d image pairs" % len(train_dataset)) + return train_loader diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/extractor.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..1e7f2dba3ff940708b4db73134e4f912c5734f8c --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/extractor.py @@ -0,0 +1,267 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn="group", stride=1): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d( + in_planes, planes, kernel_size=3, padding=1, stride=stride + ) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == "none": + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3 + ) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class BottleneckBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn="group", stride=1): + super(BottleneckBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes // 4, kernel_size=1, padding=0) + self.conv2 = nn.Conv2d( + planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride + ) + self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4) + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(planes // 4) + self.norm2 = nn.BatchNorm2d(planes // 4) + self.norm3 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm4 = nn.BatchNorm2d(planes) + + elif norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(planes // 4) + self.norm2 = nn.InstanceNorm2d(planes // 4) + self.norm3 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm4 = nn.InstanceNorm2d(planes) + + elif norm_fn == "none": + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + self.norm3 = nn.Sequential() + if not stride == 1: + self.norm4 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = 
nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4 + ) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + y = self.relu(self.norm3(self.conv3(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class BasicEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0): + super(BasicEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) + + elif self.norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(64) + + elif self.norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(64) + + elif self.norm_fn == "none": + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 64 + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(96, stride=2) + self.layer3 = self._make_layer(128, stride=2) + + # output convolution + self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + + +class SmallEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0): + super(SmallEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) + + elif self.norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(32) + + elif self.norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(32) + + elif self.norm_fn == "none": + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 32 + self.layer1 = self._make_layer(32, stride=1) + self.layer2 = self._make_layer(64, stride=2) + self.layer3 = self._make_layer(96, stride=2) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, 
dim, stride=1): + layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/loss.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..1a5c3ff7bb3bd117c9f9735282e9eb1f42e50f1c --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/loss.py @@ -0,0 +1,40 @@ +import torch + +MAX_FLOW = 400 + + +def sequence_loss(flow_preds, flow_gt, valid, cfg): + """Loss function defined over sequence of flow predictions""" + + gamma = cfg.gamma + max_flow = cfg.max_flow + n_predictions = len(flow_preds) + flow_loss = 0.0 + flow_gt_thresholds = [5, 10, 20] + + # exlude invalid pixels and extremely large diplacements + mag = torch.sum(flow_gt**2, dim=1).sqrt() + valid = (valid >= 0.5) & (mag < max_flow) + + for i in range(n_predictions): + i_weight = gamma ** (n_predictions - i - 1) + i_loss = (flow_preds[i] - flow_gt).abs() + flow_loss += i_weight * (valid[:, None] * i_loss).mean() + + epe = torch.sum((flow_preds[-1] - flow_gt) ** 2, dim=1).sqrt() + epe = epe.view(-1)[valid.view(-1)] + + metrics = { + "epe": epe.mean().item(), + "1px": (epe < 1).float().mean().item(), + "3px": (epe < 3).float().mean().item(), + "5px": (epe < 5).float().mean().item(), + } + + flow_gt_length = torch.sum(flow_gt**2, dim=1).sqrt() + flow_gt_length = flow_gt_length.view(-1)[valid.view(-1)] + for t in flow_gt_thresholds: + e = epe[flow_gt_length < t] + metrics.update({f"{t}-th-5px": (e < 5).float().mean().item()}) + + return flow_loss, metrics diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/optimizer/__init__.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/optimizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..91b70e4218cacb83afa708dcc46edb0460e5def5 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/optimizer/__init__.py @@ -0,0 +1,118 @@ +import torch +from torch.optim.lr_scheduler import ( + MultiStepLR, + CosineAnnealingLR, + ExponentialLR, + OneCycleLR, +) + + +def fetch_optimizer(model, cfg): + """Create the optimizer and learning rate scheduler""" + # optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon) + + # scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps+100, + # pct_start=0.05, cycle_momentum=False, anneal_strategy='linear') + optimizer = build_optimizer(model, cfg) + scheduler = build_scheduler(cfg, optimizer) + + return optimizer, scheduler + + +def build_optimizer(model, config): + name = config.optimizer + lr = config.canonical_lr + + if name == "adam": + return torch.optim.Adam( + model.parameters(), + lr=lr, + weight_decay=config.adam_decay, + eps=config.epsilon, + ) + elif name == "adamw": + if hasattr(config, 
"twins_lr_factor"): + factor = config.twins_lr_factor + print("[Decrease lr of pre-trained model by factor {}]".format(factor)) + param_dicts = [ + { + "params": [ + p + for n, p in model.named_parameters() + if "feat_encoder" not in n + and "context_encoder" not in n + and p.requires_grad + ] + }, + { + "params": [ + p + for n, p in model.named_parameters() + if ("feat_encoder" in n or "context_encoder" in n) + and p.requires_grad + ], + "lr": lr * factor, + }, + ] + full = [n for n, _ in model.named_parameters()] + return torch.optim.AdamW( + param_dicts, lr=lr, weight_decay=config.adamw_decay, eps=config.epsilon + ) + else: + return torch.optim.AdamW( + model.parameters(), + lr=lr, + weight_decay=config.adamw_decay, + eps=config.epsilon, + ) + else: + raise ValueError(f"TRAINER.OPTIMIZER = {name} is not a valid optimizer!") + + +def build_scheduler(config, optimizer): + """ + Returns: + scheduler (dict):{ + 'scheduler': lr_scheduler, + 'interval': 'step', # or 'epoch' + } + """ + # scheduler = {'interval': config.TRAINER.SCHEDULER_INTERVAL} + name = config.scheduler + lr = config.canonical_lr + + if name == "OneCycleLR": + # scheduler = OneCycleLR(optimizer, ) + if hasattr(config, "twins_lr_factor"): + factor = config.twins_lr_factor + scheduler = OneCycleLR( + optimizer, + [lr, lr * factor], + config.num_steps + 100, + pct_start=0.05, + cycle_momentum=False, + anneal_strategy=config.anneal_strategy, + ) + else: + scheduler = OneCycleLR( + optimizer, + lr, + config.num_steps + 100, + pct_start=0.05, + cycle_momentum=False, + anneal_strategy=config.anneal_strategy, + ) + # elif name == 'MultiStepLR': + # scheduler.update( + # {'scheduler': MultiStepLR(optimizer, config.TRAINER.MSLR_MILESTONES, gamma=config.TRAINER.MSLR_GAMMA)}) + # elif name == 'CosineAnnealing': + # scheduler = CosineAnnealingLR(optimizer, config.num_steps+100) + # scheduler.update( + # {'scheduler': CosineAnnealingLR(optimizer, config.TRAINER.COSA_TMAX)}) + # elif name == 'ExponentialLR': + # scheduler.update( + # {'scheduler': ExponentialLR(optimizer, config.TRAINER.ELR_GAMMA)}) + else: + raise NotImplementedError() + + return scheduler diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/position_encoding.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..e1921f8ca650b8f226fce31794cae1332df31f48 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/position_encoding.py @@ -0,0 +1,101 @@ +from loguru import logger +import math +import torch +from torch import nn + + +class PositionEncodingSine(nn.Module): + """ + This is a sinusoidal position encoding that generalized to 2-dimensional images + """ + + def __init__(self, d_model, max_shape=(256, 256)): + """ + Args: + max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels + """ + super().__init__() + + pe = torch.zeros((d_model, *max_shape)) + y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0) + x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0) + div_term = torch.exp( + torch.arange(0, d_model // 2, 2).float() + * (-math.log(10000.0) / d_model // 2) + ) + div_term = div_term[:, None, None] # [C//4, 1, 1] + pe[0::4, :, :] = torch.sin(x_position * div_term) + pe[1::4, :, :] = torch.cos(x_position * div_term) + pe[2::4, :, :] = torch.sin(y_position * div_term) + pe[3::4, :, :] = torch.cos(y_position * div_term) + + self.register_buffer("pe", pe.unsqueeze(0)) # [1, C, H, W] 
+ + def forward(self, x): + """ + Args: + x: [N, C, H, W] + """ + return x + self.pe[:, :, : x.size(2), : x.size(3)] + + +class LinearPositionEncoding(nn.Module): + """ + This is a sinusoidal position encoding that generalized to 2-dimensional images + """ + + def __init__(self, d_model, max_shape=(256, 256)): + """ + Args: + max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels + """ + super().__init__() + + pe = torch.zeros((d_model, *max_shape)) + y_position = ( + torch.ones(max_shape).cumsum(0).float().unsqueeze(0) - 1 + ) / max_shape[0] + x_position = ( + torch.ones(max_shape).cumsum(1).float().unsqueeze(0) - 1 + ) / max_shape[1] + div_term = torch.arange(0, d_model // 2, 2).float() + div_term = div_term[:, None, None] # [C//4, 1, 1] + pe[0::4, :, :] = torch.sin(x_position * div_term * math.pi) + pe[1::4, :, :] = torch.cos(x_position * div_term * math.pi) + pe[2::4, :, :] = torch.sin(y_position * div_term * math.pi) + pe[3::4, :, :] = torch.cos(y_position * div_term * math.pi) + + self.register_buffer("pe", pe.unsqueeze(0), persistent=False) # [1, C, H, W] + + def forward(self, x): + """ + Args: + x: [N, C, H, W] + """ + # assert x.shape[2] == 80 and x.shape[3] == 80 + + return x + self.pe[:, :, : x.size(2), : x.size(3)] + + +class LearnedPositionEncoding(nn.Module): + """ + This is a sinusoidal position encoding that generalized to 2-dimensional images + """ + + def __init__(self, d_model, max_shape=(80, 80)): + """ + Args: + max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels + """ + super().__init__() + + self.pe = nn.Parameter(torch.randn(1, max_shape[0], max_shape[1], d_model)) + + def forward(self, x): + """ + Args: + x: [N, C, H, W] + """ + # assert x.shape[2] == 80 and x.shape[3] == 80 + + return x + self.pe diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/raft.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/raft.py new file mode 100644 index 0000000000000000000000000000000000000000..999f6e7cce7947b8a2536997dcfb962b2604736b --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/raft.py @@ -0,0 +1,155 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from update import BasicUpdateBlock, SmallUpdateBlock +from extractor import BasicEncoder, SmallEncoder +from corr import CorrBlock, AlternateCorrBlock +from .utils.utils import bilinear_sampler, coords_grid, upflow8 + +try: + autocast = torch.cuda.amp.autocast +except: + # dummy autocast for PyTorch < 1.6 + class autocast: + def __init__(self, enabled): + pass + + def __enter__(self): + pass + + def __exit__(self, *args): + pass + + +class RAFT(nn.Module): + def __init__(self, args): + super(RAFT, self).__init__() + self.args = args + + if args.small: + self.hidden_dim = hdim = 96 + self.context_dim = cdim = 64 + args.corr_levels = 4 + args.corr_radius = 3 + + else: + self.hidden_dim = hdim = 128 + self.context_dim = cdim = 128 + args.corr_levels = 4 + args.corr_radius = 4 + + if "dropout" not in self.args: + self.args.dropout = 0 + + if "alternate_corr" not in self.args: + self.args.alternate_corr = False + + # feature network, context network, and update block + if args.small: + self.fnet = SmallEncoder( + output_dim=128, norm_fn="instance", dropout=args.dropout + ) + self.cnet = SmallEncoder( + output_dim=hdim + cdim, norm_fn="none", dropout=args.dropout + ) + self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim) + + else: + self.fnet = BasicEncoder( + 
output_dim=256, norm_fn="instance", dropout=args.dropout + ) + self.cnet = BasicEncoder( + output_dim=hdim + cdim, norm_fn="batch", dropout=args.dropout + ) + self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) + + def freeze_bn(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + + def initialize_flow(self, img): + """Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" + N, C, H, W = img.shape + coords0 = coords_grid(N, H // 8, W // 8).to(img.device) + coords1 = coords_grid(N, H // 8, W // 8).to(img.device) + + # optical flow computed as difference: flow = coords1 - coords0 + return coords0, coords1 + + def upsample_flow(self, flow, mask): + """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination""" + N, _, H, W = flow.shape + mask = mask.view(N, 1, 9, 8, 8, H, W) + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(8 * flow, [3, 3], padding=1) + up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) + + up_flow = torch.sum(mask * up_flow, dim=2) + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) + return up_flow.reshape(N, 2, 8 * H, 8 * W) + + def forward( + self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False + ): + """Estimate optical flow between pair of frames""" + + image1 = 2 * (image1 / 255.0) - 1.0 + image2 = 2 * (image2 / 255.0) - 1.0 + + image1 = image1.contiguous() + image2 = image2.contiguous() + + hdim = self.hidden_dim + cdim = self.context_dim + + # run the feature network + with autocast(enabled=self.args.mixed_precision): + fmap1, fmap2 = self.fnet([image1, image2]) + + fmap1 = fmap1.float() + fmap2 = fmap2.float() + if self.args.alternate_corr: + corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) + else: + corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) + + # run the context network + with autocast(enabled=self.args.mixed_precision): + cnet = self.cnet(image1) + net, inp = torch.split(cnet, [hdim, cdim], dim=1) + net = torch.tanh(net) + inp = torch.relu(inp) + + coords0, coords1 = self.initialize_flow(image1) + + if flow_init is not None: + coords1 = coords1 + flow_init + + flow_predictions = [] + for itr in range(iters): + coords1 = coords1.detach() + corr = corr_fn(coords1) # index correlation volume + + flow = coords1 - coords0 + with autocast(enabled=self.args.mixed_precision): + net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) + + # F(t+1) = F(t) + \Delta(t) + coords1 = coords1 + delta_flow + + # upsample predictions + if up_mask is None: + flow_up = upflow8(coords1 - coords0) + else: + flow_up = self.upsample_flow(coords1 - coords0, up_mask) + + flow_predictions.append(flow_up) + + if test_mode: + return coords1 - coords0, flow_up + + return flow_predictions diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/update.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/update.py new file mode 100644 index 0000000000000000000000000000000000000000..ced6df0658da475bf660b98475acab7e100cabaa --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/update.py @@ -0,0 +1,154 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class FlowHead(nn.Module): + def __init__(self, input_dim=128, hidden_dim=256): + super(FlowHead, self).__init__() + self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) + self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + return 
self.conv2(self.relu(self.conv1(x))) + + +class ConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192 + 128): + super(ConvGRU, self).__init__() + self.convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) + self.convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) + self.convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) + + def forward(self, h, x): + hx = torch.cat([h, x], dim=1) + + z = torch.sigmoid(self.convz(hx)) + r = torch.sigmoid(self.convr(hx)) + q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1))) + + h = (1 - z) * h + z * q + return h + + +class SepConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192 + 128): + super(SepConvGRU, self).__init__() + self.convz1 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) + ) + self.convr1 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) + ) + self.convq1 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) + ) + + self.convz2 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0) + ) + self.convr2 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0) + ) + self.convq2 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0) + ) + + def forward(self, h, x): + # horizontal + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz1(hx)) + r = torch.sigmoid(self.convr1(hx)) + q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1))) + h = (1 - z) * h + z * q + + # vertical + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz2(hx)) + r = torch.sigmoid(self.convr2(hx)) + q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1))) + h = (1 - z) * h + z * q + + return h + + +class SmallMotionEncoder(nn.Module): + def __init__(self, args): + super(SmallMotionEncoder, self).__init__() + cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2 + self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) + self.convf1 = nn.Conv2d(2, 64, 7, padding=3) + self.convf2 = nn.Conv2d(64, 32, 3, padding=1) + self.conv = nn.Conv2d(128, 80, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + + +class BasicMotionEncoder(nn.Module): + def __init__(self, args): + super(BasicMotionEncoder, self).__init__() + cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2 + self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) + self.convc2 = nn.Conv2d(256, 192, 3, padding=1) + self.convf1 = nn.Conv2d(2, 128, 7, padding=3) + self.convf2 = nn.Conv2d(128, 64, 3, padding=1) + self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + cor = F.relu(self.convc2(cor)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + + +class SmallUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=96): + super(SmallUpdateBlock, self).__init__() + self.encoder = SmallMotionEncoder(args) + self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82 + 64) + self.flow_head = FlowHead(hidden_dim, hidden_dim=128) + + def forward(self, net, inp, corr, flow): + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + net = 
self.gru(net, inp) + delta_flow = self.flow_head(net) + + return net, None, delta_flow + + +class BasicUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=128, input_dim=128): + super(BasicUpdateBlock, self).__init__() + self.args = args + self.encoder = BasicMotionEncoder(args) + self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim) + self.flow_head = FlowHead(hidden_dim, hidden_dim=256) + + self.mask = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 64 * 9, 1, padding=0), + ) + + def forward(self, net, inp, corr, flow, upsample=True): + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + # scale mask to balence gradients + mask = 0.25 * self.mask(net) + return net, mask, delta_flow diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/utils/__init__.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/utils/augmentor.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/utils/augmentor.py new file mode 100644 index 0000000000000000000000000000000000000000..aeb770bb1844c8d586eb51f2b530e6037f3e6cb4 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/utils/augmentor.py @@ -0,0 +1,336 @@ +import numpy as np +import random +import math +from PIL import Image + +import cv2 + +cv2.setNumThreads(0) +cv2.ocl.setUseOpenCL(False) + +import torch +from torchvision.transforms import ColorJitter +import torch.nn.functional as F +from . 
import flow_transforms + + +class FlowAugmentor: + def __init__( + self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True, pwc_aug=False + ): + # spatial augmentation params + self.crop_size = crop_size + self.min_scale = min_scale + self.max_scale = max_scale + self.spatial_aug_prob = 0.8 + self.stretch_prob = 0.8 + self.max_stretch = 0.2 + + # flip augmentation params + self.do_flip = do_flip + self.h_flip_prob = 0.5 + self.v_flip_prob = 0.1 + + # photometric augmentation params + self.photo_aug = ColorJitter( + brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5 / 3.14 + ) + self.asymmetric_color_aug_prob = 0.2 + self.eraser_aug_prob = 0.5 + self.pwc_aug = pwc_aug + if self.pwc_aug: + print("[Using pwc-style spatial augmentation]") + + def color_transform(self, img1, img2): + """Photometric augmentation""" + + # asymmetric + if np.random.rand() < self.asymmetric_color_aug_prob: + img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8) + img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8) + + # symmetric + else: + image_stack = np.concatenate([img1, img2], axis=0) + image_stack = np.array( + self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8 + ) + img1, img2 = np.split(image_stack, 2, axis=0) + + return img1, img2 + + def eraser_transform(self, img1, img2, bounds=[50, 100]): + """Occlusion augmentation""" + + ht, wd = img1.shape[:2] + if np.random.rand() < self.eraser_aug_prob: + mean_color = np.mean(img2.reshape(-1, 3), axis=0) + for _ in range(np.random.randint(1, 3)): + x0 = np.random.randint(0, wd) + y0 = np.random.randint(0, ht) + dx = np.random.randint(bounds[0], bounds[1]) + dy = np.random.randint(bounds[0], bounds[1]) + img2[y0 : y0 + dy, x0 : x0 + dx, :] = mean_color + + return img1, img2 + + def spatial_transform(self, img1, img2, flow): + # randomly sample scale + ht, wd = img1.shape[:2] + min_scale = np.maximum( + (self.crop_size[0] + 8) / float(ht), (self.crop_size[1] + 8) / float(wd) + ) + + scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) + scale_x = scale + scale_y = scale + if np.random.rand() < self.stretch_prob: + scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) + scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) + scale_x = np.clip(scale_x, min_scale, None) + scale_y = np.clip(scale_y, min_scale, None) + + if np.random.rand() < self.spatial_aug_prob: + # rescale the images + img1 = cv2.resize( + img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR + ) + img2 = cv2.resize( + img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR + ) + flow = cv2.resize( + flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR + ) + flow = flow * [scale_x, scale_y] + + if self.do_flip: + if np.random.rand() < self.h_flip_prob: # h-flip + img1 = img1[:, ::-1] + img2 = img2[:, ::-1] + flow = flow[:, ::-1] * [-1.0, 1.0] + + if np.random.rand() < self.v_flip_prob: # v-flip + img1 = img1[::-1, :] + img2 = img2[::-1, :] + flow = flow[::-1, :] * [1.0, -1.0] + + if img1.shape[0] == self.crop_size[0]: + y0 = 0 + else: + y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) + if img1.shape[1] == self.crop_size[1]: + x0 = 0 + else: + x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) + + img1 = img1[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] + img2 = img2[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] + flow = flow[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] + + return img1, img2, flow + + def 
__call__(self, img1, img2, flow): + img1, img2 = self.color_transform(img1, img2) + img1, img2 = self.eraser_transform(img1, img2) + if self.pwc_aug: + th, tw = self.crop_size + schedule = [0.5, 1.0] # initial coeff, final_coeff, half life + difficulty = np.random.uniform(0, 1) + schedule_coeff = schedule[0] + (schedule[1] - schedule[0]) * ( + 2 / (1 + np.exp(-1.0986 * difficulty)) - 1 + ) + spatial_augmentor = flow_transforms.SpatialAug( + [th, tw], + scale=[0.4, 0.03, 0.2], + rot=[0.4, 0.03], + trans=[0.4, 0.03], + squeeze=[0.3, 0.0], + schedule_coeff=schedule_coeff, + order=1, + black=False, + ) + flow = np.concatenate( + [flow, np.ones((flow.shape[0], flow.shape[1], 1))], axis=-1 + ) + augmented, flow_valid = spatial_augmentor([img1, img2], flow) + flow = flow_valid[:, :, :2] + img1 = augmented[0] + img2 = augmented[1] + + else: + img1, img2, flow = self.spatial_transform(img1, img2, flow) + + img1 = np.ascontiguousarray(img1) + img2 = np.ascontiguousarray(img2) + flow = np.ascontiguousarray(flow) + + return img1, img2, flow + + +class SparseFlowAugmentor: + def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False): + # spatial augmentation params + self.crop_size = crop_size + self.min_scale = min_scale + self.max_scale = max_scale + self.spatial_aug_prob = 0.8 + self.stretch_prob = 0.8 + self.max_stretch = 0.2 + + # flip augmentation params + self.do_flip = do_flip + self.h_flip_prob = 0.5 + self.v_flip_prob = 0.1 + + # photometric augmentation params + self.photo_aug = ColorJitter( + brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3 / 3.14 + ) + self.asymmetric_color_aug_prob = 0.2 + self.eraser_aug_prob = 0.5 + + def color_transform(self, img1, img2): + image_stack = np.concatenate([img1, img2], axis=0) + image_stack = np.array( + self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8 + ) + img1, img2 = np.split(image_stack, 2, axis=0) + return img1, img2 + + def eraser_transform(self, img1, img2): + ht, wd = img1.shape[:2] + if np.random.rand() < self.eraser_aug_prob: + mean_color = np.mean(img2.reshape(-1, 3), axis=0) + for _ in range(np.random.randint(1, 3)): + x0 = np.random.randint(0, wd) + y0 = np.random.randint(0, ht) + dx = np.random.randint(50, 100) + dy = np.random.randint(50, 100) + img2[y0 : y0 + dy, x0 : x0 + dx, :] = mean_color + + return img1, img2 + + def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): + ht, wd = flow.shape[:2] + coords = np.meshgrid(np.arange(wd), np.arange(ht)) + coords = np.stack(coords, axis=-1) + + coords = coords.reshape(-1, 2).astype(np.float32) + flow = flow.reshape(-1, 2).astype(np.float32) + valid = valid.reshape(-1).astype(np.float32) + + coords0 = coords[valid >= 1] + flow0 = flow[valid >= 1] + + ht1 = int(round(ht * fy)) + wd1 = int(round(wd * fx)) + + coords1 = coords0 * [fx, fy] + flow1 = flow0 * [fx, fy] + + xx = np.round(coords1[:, 0]).astype(np.int32) + yy = np.round(coords1[:, 1]).astype(np.int32) + + v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1) + xx = xx[v] + yy = yy[v] + flow1 = flow1[v] + + flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32) + valid_img = np.zeros([ht1, wd1], dtype=np.int32) + + flow_img[yy, xx] = flow1 + valid_img[yy, xx] = 1 + + return flow_img, valid_img + + def spatial_transform(self, img1, img2, flow, valid): + pad_t = 0 + pad_b = 0 + pad_l = 0 + pad_r = 0 + if self.crop_size[0] > img1.shape[0]: + pad_b = self.crop_size[0] - img1.shape[0] + if self.crop_size[1] > img1.shape[1]: + pad_r = self.crop_size[1] - img1.shape[1] + if pad_b != 0 or pad_r != 0: + img1 = 
np.pad( + img1, + ((pad_t, pad_b), (pad_l, pad_r), (0, 0)), + "constant", + constant_values=((0, 0), (0, 0), (0, 0)), + ) + img2 = np.pad( + img2, + ((pad_t, pad_b), (pad_l, pad_r), (0, 0)), + "constant", + constant_values=((0, 0), (0, 0), (0, 0)), + ) + flow = np.pad( + flow, + ((pad_t, pad_b), (pad_l, pad_r), (0, 0)), + "constant", + constant_values=((0, 0), (0, 0), (0, 0)), + ) + valid = np.pad( + valid, + ((pad_t, pad_b), (pad_l, pad_r)), + "constant", + constant_values=((0, 0), (0, 0)), + ) + # randomly sample scale + + ht, wd = img1.shape[:2] + min_scale = np.maximum( + (self.crop_size[0] + 1) / float(ht), (self.crop_size[1] + 1) / float(wd) + ) + + scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) + scale_x = np.clip(scale, min_scale, None) + scale_y = np.clip(scale, min_scale, None) + + if np.random.rand() < self.spatial_aug_prob: + # rescale the images + img1 = cv2.resize( + img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR + ) + img2 = cv2.resize( + img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR + ) + flow, valid = self.resize_sparse_flow_map( + flow, valid, fx=scale_x, fy=scale_y + ) + + if self.do_flip: + if np.random.rand() < 0.5: # h-flip + img1 = img1[:, ::-1] + img2 = img2[:, ::-1] + flow = flow[:, ::-1] * [-1.0, 1.0] + valid = valid[:, ::-1] + + margin_y = 20 + margin_x = 50 + + y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y) + x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x) + + y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0]) + x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1]) + + img1 = img1[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] + img2 = img2[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] + flow = flow[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] + valid = valid[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] + return img1, img2, flow, valid + + def __call__(self, img1, img2, flow, valid): + img1, img2 = self.color_transform(img1, img2) + img1, img2 = self.eraser_transform(img1, img2) + img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid) + + img1 = np.ascontiguousarray(img1) + img2 = np.ascontiguousarray(img2) + flow = np.ascontiguousarray(flow) + valid = np.ascontiguousarray(valid) + + return img1, img2, flow, valid diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/utils/datasets.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/utils/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..265bf09f3a6232f026b82a5435f9c6d1084d2874 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/utils/datasets.py @@ -0,0 +1,577 @@ +import numpy as np +import torch +import torch.utils.data as data +import torch.nn.functional as F + +import os +import math +import random +from glob import glob +import os.path as osp + +from utils import frame_utils +from utils.augmentor import FlowAugmentor, SparseFlowAugmentor + +# from utils import flow_transforms + +from torchvision.utils import save_image + +from utils import flow_viz +import cv2 +from utils.utils import coords_grid, bilinear_sampler + + +class FlowDataset(data.Dataset): + def __init__(self, aug_params=None, sparse=False): + self.augmentor = None + self.sparse = sparse + + if aug_params is not None: + if sparse: + self.augmentor = SparseFlowAugmentor(**aug_params) + else: + self.augmentor = FlowAugmentor(**aug_params) + + self.is_test = False + self.init_seed = 
False + self.flow_list = [] + self.image_list = [] + self.extra_info = [] + + def __getitem__(self, index): + # print(self.flow_list[index]) + if self.is_test: + img1 = frame_utils.read_gen(self.image_list[index][0], test=self.is_test) + img2 = frame_utils.read_gen(self.image_list[index][1], test=self.is_test) + img1 = np.array(img1).astype(np.uint8)[..., :3] + img2 = np.array(img2).astype(np.uint8)[..., :3] + img1 = torch.from_numpy(img1).permute(2, 0, 1).float() + img2 = torch.from_numpy(img2).permute(2, 0, 1).float() + return img1, img2, self.extra_info[index] + + if not self.init_seed: + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None: + torch.manual_seed(worker_info.id) + np.random.seed(worker_info.id) + random.seed(worker_info.id) + self.init_seed = True + + index = index % len(self.image_list) + valid = None + if self.sparse: + flow, valid = frame_utils.readFlowKITTI(self.flow_list[index]) + else: + flow = frame_utils.read_gen(self.flow_list[index]) + + img1 = frame_utils.read_gen(self.image_list[index][0]) + img2 = frame_utils.read_gen(self.image_list[index][1]) + + flow = np.array(flow).astype(np.float32) + img1 = np.array(img1).astype(np.uint8) + img2 = np.array(img2).astype(np.uint8) + # grayscale images + if len(img1.shape) == 2: + img1 = np.tile(img1[..., None], (1, 1, 3)) + img2 = np.tile(img2[..., None], (1, 1, 3)) + else: + img1 = img1[..., :3] + img2 = img2[..., :3] + + if self.augmentor is not None: + if self.sparse: + img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid) + else: + img1, img2, flow = self.augmentor(img1, img2, flow) + + img1 = torch.from_numpy(img1).permute(2, 0, 1).float() + img2 = torch.from_numpy(img2).permute(2, 0, 1).float() + flow = torch.from_numpy(flow).permute(2, 0, 1).float() + + if valid is not None: + valid = torch.from_numpy(valid) + else: + valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000) + return img1, img2, flow, valid.float() + + def __rmul__(self, v): + self.flow_list = v * self.flow_list + self.image_list = v * self.image_list + return self + + def __len__(self): + return len(self.image_list) + + +class MpiSintel_submission(FlowDataset): + def __init__( + self, aug_params=None, split="test", root="datasets/Sintel", dstype="clean" + ): + super(MpiSintel_submission, self).__init__(aug_params) + flow_root = osp.join(root, split, "flow") + image_root = osp.join(root, split, dstype) + + if split == "test": + self.is_test = True + + for scene in os.listdir(image_root): + image_list = sorted(glob(osp.join(image_root, scene, "*.png"))) + for i in range(len(image_list) - 1): + self.image_list += [[image_list[i], image_list[i + 1]]] + self.extra_info += [(scene, i)] # scene and frame_id + + if split != "test": + self.flow_list += sorted(glob(osp.join(flow_root, scene, "*.flo"))) + + +class MpiSintel(FlowDataset): + def __init__( + self, aug_params=None, split="training", root="datasets/Sintel", dstype="clean" + ): + super(MpiSintel, self).__init__(aug_params) + + root = "s3://" + + self.image_list = [] + with open("./flow_dataset/Sintel/Sintel_" + dstype + "_png.txt") as f: + images = f.readlines() + for img1, img2 in zip(images[0::2], images[1::2]): + self.image_list.append([root + img1.strip(), root + img2.strip()]) + + self.flow_list = [] + with open("./flow_dataset/Sintel/Sintel_" + dstype + "_flo.txt") as f: + flows = f.readlines() + for flow in flows: + self.flow_list.append(root + flow.strip()) + + assert len(self.image_list) == len(self.flow_list) + + self.extra_info = [] + with 
open("./flow_dataset/Sintel/Sintel_" + dstype + "_extra_info.txt") as f: + info = f.readlines() + for scene, id in zip(info[0::2], info[1::2]): + self.extra_info.append((scene.strip(), int(id.strip()))) + # flow_root = osp.join(root, split, 'flow') + # image_root = osp.join(root, split, dstype) + + # if split == 'test': + # self.is_test = True + + # for scene in os.listdir(image_root): + # image_list = sorted(glob(osp.join(image_root, scene, '*.png'))) + # for i in range(len(image_list)-1): + # self.image_list += [ [image_list[i], image_list[i+1]] ] + # self.extra_info += [ (scene, i) ] # scene and frame_id + + # if split != 'test': + # self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo'))) + + +class FlyingChairs(FlowDataset): + def __init__( + self, aug_params=None, split="train", root="datasets/FlyingChairs_release/data" + ): + super(FlyingChairs, self).__init__(aug_params) + + root = "s3://" + + with open("./flow_dataset/flying_chairs/flyingchairs_ppm.txt") as f: + images = f.readlines() + images = [root + img.strip() for img in images] + with open("./flow_dataset/flying_chairs/flyingchairs_flo.txt") as f: + flows = f.readlines() + flows = [root + flo.strip() for flo in flows] + + # images = sorted(glob(osp.join(root, '*.ppm'))) + # flows = sorted(glob(osp.join(root, '*.flo'))) + assert len(images) // 2 == len(flows) + + split_list = np.loadtxt("chairs_split.txt", dtype=np.int32) + for i in range(len(flows)): + xid = split_list[i] + if (split == "training" and xid == 1) or ( + split == "validation" and xid == 2 + ): + self.flow_list += [flows[i]] + self.image_list += [[images[2 * i], images[2 * i + 1]]] + + +class FlyingThings3D(FlowDataset): + def __init__( + self, aug_params=None, root="datasets/FlyingThings3D", dstype="frames_cleanpass" + ): + super(FlyingThings3D, self).__init__(aug_params) + + root = "s3://" + + self.image_list = [] + with open( + "./flow_dataset/flying_things/flyingthings_" + dstype + "_png.txt" + ) as f: + images = f.readlines() + for img1, img2 in zip(images[0::2], images[1::2]): + self.image_list.append([root + img1.strip(), root + img2.strip()]) + self.flow_list = [] + with open( + "./flow_dataset/flying_things/flyingthings_" + dstype + "_pfm.txt" + ) as f: + flows = f.readlines() + for flow in flows: + self.flow_list.append(root + flow.strip()) + + # for cam in ['left']: + # for direction in ['into_future', 'into_past']: + # image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*'))) + # image_dirs = sorted([osp.join(f, cam) for f in image_dirs]) + + # flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*'))) + # flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs]) + + # for idir, fdir in zip(image_dirs, flow_dirs): + # images = sorted(glob(osp.join(idir, '*.png')) ) + # flows = sorted(glob(osp.join(fdir, '*.pfm')) ) + # for i in range(len(flows)-1): + # if direction == 'into_future': + # self.image_list += [ [images[i], images[i+1]] ] + # self.flow_list += [ flows[i] ] + # elif direction == 'into_past': + # self.image_list += [ [images[i+1], images[i]] ] + # self.flow_list += [ flows[i+1] ] + + +class KITTI(FlowDataset): + def __init__(self, aug_params=None, split="training", root="datasets/KITTI"): + super(KITTI, self).__init__(aug_params, sparse=True) + if split == "testing": + self.is_test = True + + root = "s3://" + + self.image_list = [] + with open("./flow_dataset/KITTI/KITTI_{}_image.txt".format(split)) as f: + images = f.readlines() + for img1, img2 in zip(images[0::2], images[1::2]): + 
self.image_list.append([root + img1.strip(), root + img2.strip()]) + + self.extra_info = [] + with open("./flow_dataset/KITTI/KITTI_{}_extra_info.txt".format(split)) as f: + info = f.readlines() + for id in info: + self.extra_info.append([id.strip()]) + + if split == "training": + self.flow_list = [] + with open("./flow_dataset/KITTI/KITTI_{}_flow.txt".format(split)) as f: + flow = f.readlines() + for flo in flow: + self.flow_list.append(root + flo.strip()) + # root = osp.join(root, split) + # images1 = sorted(glob(osp.join(root, 'image_2/*_10.png'))) + # images2 = sorted(glob(osp.join(root, 'image_2/*_11.png'))) + + # for img1, img2 in zip(images1, images2): + # frame_id = img1.split('/')[-1] + # self.extra_info += [ [frame_id] ] + # self.image_list += [ [img1, img2] ] + + # if split == 'training': + # self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png'))) + + +class AutoFlow(data.Dataset): + def __init__(self, num_steps, crop_size, log_dir, root="datasets/"): + super(AutoFlow, self).__init__() + + root = "s3://" + self.image_list = [] + with open("./flow_dataset/AutoFlow/AutoFlow_image.txt") as f: + images = f.readlines() + for img1, img2 in zip(images[0::2], images[1::2]): + self.image_list.append([root + img1.strip(), root + img2.strip()]) + self.flow_list = [] + with open("./flow_dataset/AutoFlow/AutoFlow_flow.txt") as f: + flows = f.readlines() + for flow in flows: + self.flow_list.append(root + flow.strip()) + + self.crop_size = crop_size + self.log_dir = log_dir + self.num_steps = num_steps + self.scale = 1 + self.order = 1 + self.black = False + self.noise = 0 + self.is_test = False + self.init_seed = False + + self.iter_counts = 0 + + def __rmul__(self, v): + self.flow_list = v * self.flow_list + self.image_list = v * self.image_list + return self + + def __len__(self): + return len(self.image_list) * 100 + + def __getitem__(self, index): + # print(self.flow_list[index]) + if self.is_test: + img1 = frame_utils.read_gen(self.image_list[index][0], test=self.is_test) + img2 = frame_utils.read_gen(self.image_list[index][1], test=self.is_test) + img1 = np.array(img1).astype(np.uint8)[..., :3] + img2 = np.array(img2).astype(np.uint8)[..., :3] + img1 = torch.from_numpy(img1).permute(2, 0, 1).float() + img2 = torch.from_numpy(img2).permute(2, 0, 1).float() + return img1, img2, self.extra_info[index] + + if not self.init_seed: + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None: + torch.manual_seed(worker_info.id) + np.random.seed(worker_info.id) + random.seed(worker_info.id) + self.init_seed = True + index = index % len(self.image_list) + valid = None + + flow = frame_utils.read_gen(self.flow_list[index]) + + img1 = frame_utils.read_gen(self.image_list[index][0]) + img2 = frame_utils.read_gen(self.image_list[index][1]) + + flow = np.array(flow).astype(np.float32) + # For PWC-style augmentation, pixel values are in [0, 1] + img1 = np.array(img1).astype(np.uint8) / 255.0 + img2 = np.array(img2).astype(np.uint8) / 255.0 + + # grayscale images + if len(img1.shape) == 2: + img1 = np.tile(img1[..., None], (1, 1, 3)) + img2 = np.tile(img2[..., None], (1, 1, 3)) + else: + img1 = img1[..., :3] + img2 = img2[..., :3] + + iter_counts = self.iter_counts + self.iter_counts = self.iter_counts + 1 + print(self.iter_counts) + th, tw = self.crop_size + schedule = [0.5, 1.0, self.num_steps] # initial coeff, final_coeff, half life + schedule_coeff = schedule[0] + (schedule[1] - schedule[0]) * ( + 2 / (1 + np.exp(-1.0986 * iter_counts / schedule[2])) - 1 + ) + + 
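+        # Curriculum coefficient: grows from 0.5 toward 1.0 as iter_counts increases
+        # (1.0986 ≈ ln 3, so the midpoint 0.75 is reached when iter_counts == num_steps, the "half life");
+        # it scales the strength of the spatial and chromatic augmentations composed below.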
co_transform = flow_transforms.Compose( + [ + flow_transforms.Scale(self.scale, order=self.order), + flow_transforms.SpatialAug( + [th, tw], + scale=[0.4, 0.03, 0.2], + rot=[0.4, 0.03], + trans=[0.4, 0.03], + squeeze=[0.3, 0.0], + schedule_coeff=schedule_coeff, + order=self.order, + black=self.black, + ), + flow_transforms.PCAAug(schedule_coeff=schedule_coeff), + flow_transforms.ChromaticAug( + schedule_coeff=schedule_coeff, noise=self.noise + ), + ] + ) + + flow = np.concatenate( + [flow, np.ones((flow.shape[0], flow.shape[1], 1))], axis=-1 + ) + augmented, flow_valid = co_transform([img1, img2], flow) + flow = flow_valid[:, :, :2] + valid = flow_valid[:, :, 2:3] + + img1 = augmented[0] + img2 = augmented[1] + if np.random.binomial(1, 0.5): + # sx = int(np.random.uniform(25,100)) + # sy = int(np.random.uniform(25,100)) + sx = int(np.random.uniform(50, 125)) + sy = int(np.random.uniform(50, 125)) + # sx = int(np.random.uniform(50,150)) + # sy = int(np.random.uniform(50,150)) + cx = int(np.random.uniform(sx, img2.shape[0] - sx)) + cy = int(np.random.uniform(sy, img2.shape[1] - sy)) + img2[cx - sx : cx + sx, cy - sy : cy + sy] = np.mean(np.mean(img2, 0), 0)[ + np.newaxis, np.newaxis + ] + + img1 = torch.from_numpy(img1).permute(2, 0, 1).float() + img2 = torch.from_numpy(img2).permute(2, 0, 1).float() + flow = torch.from_numpy(flow).permute(2, 0, 1).float() + + if valid is not None: + valid = torch.from_numpy(valid).permute(2, 0, 1).float() + valid = valid[0] + else: + valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000) + + return img1 * 255, img2 * 255, flow, valid.float() + + +class HD1K(FlowDataset): + def __init__(self, aug_params=None, root="datasets/HD1k"): + super(HD1K, self).__init__(aug_params, sparse=True) + + root = "s3://" + self.image_list = [] + with open("./flow_dataset/HD1K/HD1K_image.txt") as f: + images = f.readlines() + for img1, img2 in zip(images[0::2], images[1::2]): + self.image_list.append([root + img1.strip(), root + img2.strip()]) + self.flow_list = [] + with open("./flow_dataset/HD1K/HD1K_flow.txt") as f: + flows = f.readlines() + for flow in flows: + self.flow_list.append(root + flow.strip()) + + # seq_ix = 0 + # while 1: + # flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix))) + # images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix))) + + # if len(flows) == 0: + # break + + # for i in range(len(flows)-1): + # self.flow_list += [flows[i]] + # self.image_list += [ [images[i], images[i+1]] ] + + # seq_ix += 1 + + +def fetch_dataloader(args, TRAIN_DS="C+T+K+S+H"): + """Create the data loader for the corresponding trainign set""" + + if args.stage == "chairs": + if hasattr(args.percostformer, "pwc_aug") and args.percostformer.pwc_aug: + aug_params = { + "crop_size": args.image_size, + "min_scale": -0.1, + "max_scale": 1.0, + "do_flip": True, + "pwc_aug": True, + } + else: + aug_params = { + "crop_size": args.image_size, + "min_scale": -0.1, + "max_scale": 1.0, + "do_flip": True, + } + train_dataset = FlyingChairs(aug_params, split="training") + + elif args.stage == "things": + aug_params = { + "crop_size": args.image_size, + "min_scale": -0.4, + "max_scale": 0.8, + "do_flip": True, + } + clean_dataset = FlyingThings3D(aug_params, dstype="frames_cleanpass") + final_dataset = FlyingThings3D(aug_params, dstype="frames_finalpass") + train_dataset = clean_dataset + final_dataset + + elif args.stage == "sintel": + aug_params = { + "crop_size": args.image_size, + "min_scale": -0.2, + "max_scale": 0.6, + 
"do_flip": True, + } + things = FlyingThings3D(aug_params, dstype="frames_cleanpass") + sintel_clean = MpiSintel(aug_params, split="training", dstype="clean") + sintel_final = MpiSintel(aug_params, split="training", dstype="final") + + if TRAIN_DS == "C+T+K+S+H": + kitti = KITTI( + { + "crop_size": args.image_size, + "min_scale": -0.3, + "max_scale": 0.5, + "do_flip": True, + } + ) + hd1k = HD1K( + { + "crop_size": args.image_size, + "min_scale": -0.5, + "max_scale": 0.2, + "do_flip": True, + } + ) + train_dataset = ( + 100 * sintel_clean + + 100 * sintel_final + + 200 * kitti + + 5 * hd1k + + things + ) + + elif TRAIN_DS == "C+T+K/S": + train_dataset = 100 * sintel_clean + 100 * sintel_final + things + + elif args.stage == "kitti": + aug_params = { + "crop_size": args.image_size, + "min_scale": -0.2, + "max_scale": 0.4, + "do_flip": False, + } + train_dataset = KITTI(aug_params, split="training") + + elif args.stage == "autoflow-pwcaug": + aug_params = { + "num_steps": args.trainer.num_steps, + "crop_size": args.image_size, + "log_dir": args.log_dir, + } + train_dataset = AutoFlow(**aug_params) + + train_loader = data.DataLoader( + train_dataset, + batch_size=args.batch_size, + pin_memory=False, + shuffle=True, + num_workers=args.batch_size, + drop_last=True, + ) + + print("Training with %d image pairs" % len(train_dataset)) + return train_loader + + +if __name__ == "__main__": + aug_params = { + "crop_size": [400, 720], + "min_scale": -0.2, + "max_scale": 0, + "do_flip": True, + } + aug_params["min_scale"] = -0.2 + aug_params["min_stretch"] = -0.2 + sintel_clean = MpiSintel(aug_params, split="training", dstype="clean") + + train_loader = data.DataLoader( + sintel_clean, + batch_size=1, + pin_memory=False, + shuffle=True, + num_workers=1, + drop_last=True, + ) + + for i_batch, data_blob in enumerate(train_loader): + image1, image2, flow, valid = [x for x in data_blob] + print(i_batch, image1.shape) + + # if i_batch==5: + # exit() diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/utils/flow_transforms.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/utils/flow_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..e3f1332bcca04bf718f3b6e250c031d3e0cd1e58 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/utils/flow_transforms.py @@ -0,0 +1,657 @@ +from __future__ import division +import torch +import random +import numpy as np +import numbers +import types +import scipy.ndimage as ndimage +import pdb +import torchvision +import PIL.Image as Image +import cv2 +from torch.nn import functional as F + + +class Compose(object): + """Composes several co_transforms together. + For example: + >>> co_transforms.Compose([ + >>> co_transforms.CenterCrop(10), + >>> co_transforms.ToTensor(), + >>> ]) + """ + + def __init__(self, co_transforms): + self.co_transforms = co_transforms + + def __call__(self, input, target): + for t in self.co_transforms: + input, target = t(input, target) + return input, target + + +class Scale(object): + """Rescales the inputs and target arrays to the given 'size'. + 'size' will be the size of the smaller edge. 
+ For example, if height > width, then image will be + rescaled to (size * height / width, size) + size: size of the smaller edge + interpolation order: Default: 2 (bilinear) + """ + + def __init__(self, size, order=1): + self.ratio = size + self.order = order + if order == 0: + self.code = cv2.INTER_NEAREST + elif order == 1: + self.code = cv2.INTER_LINEAR + elif order == 2: + self.code = cv2.INTER_CUBIC + + def __call__(self, inputs, target): + if self.ratio == 1: + return inputs, target + h, w, _ = inputs[0].shape + ratio = self.ratio + + inputs[0] = cv2.resize( + inputs[0], None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR + ) + inputs[1] = cv2.resize( + inputs[1], None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR + ) + # keep the mask same + tmp = cv2.resize( + target[:, :, 2], None, fx=ratio, fy=ratio, interpolation=cv2.INTER_NEAREST + ) + target = ( + cv2.resize(target, None, fx=ratio, fy=ratio, interpolation=self.code) + * ratio + ) + target[:, :, 2] = tmp + + return inputs, target + + +class SpatialAug(object): + def __init__( + self, + crop, + scale=None, + rot=None, + trans=None, + squeeze=None, + schedule_coeff=1, + order=1, + black=False, + ): + self.crop = crop + self.scale = scale + self.rot = rot + self.trans = trans + self.squeeze = squeeze + self.t = np.zeros(6) + self.schedule_coeff = schedule_coeff + self.order = order + self.black = black + + def to_identity(self): + self.t[0] = 1 + self.t[2] = 0 + self.t[4] = 0 + self.t[1] = 0 + self.t[3] = 1 + self.t[5] = 0 + + def left_multiply(self, u0, u1, u2, u3, u4, u5): + result = np.zeros(6) + result[0] = self.t[0] * u0 + self.t[1] * u2 + result[1] = self.t[0] * u1 + self.t[1] * u3 + + result[2] = self.t[2] * u0 + self.t[3] * u2 + result[3] = self.t[2] * u1 + self.t[3] * u3 + + result[4] = self.t[4] * u0 + self.t[5] * u2 + u4 + result[5] = self.t[4] * u1 + self.t[5] * u3 + u5 + self.t = result + + def inverse(self): + result = np.zeros(6) + a = self.t[0] + c = self.t[2] + e = self.t[4] + b = self.t[1] + d = self.t[3] + f = self.t[5] + + denom = a * d - b * c + + result[0] = d / denom + result[1] = -b / denom + result[2] = -c / denom + result[3] = a / denom + result[4] = (c * f - d * e) / denom + result[5] = (b * e - a * f) / denom + + return result + + def grid_transform(self, meshgrid, t, normalize=True, gridsize=None): + if gridsize is None: + h, w = meshgrid[0].shape + else: + h, w = gridsize + vgrid = torch.cat( + [ + (meshgrid[0] * t[0] + meshgrid[1] * t[2] + t[4])[:, :, np.newaxis], + (meshgrid[0] * t[1] + meshgrid[1] * t[3] + t[5])[:, :, np.newaxis], + ], + -1, + ) + if normalize: + vgrid[:, :, 0] = 2.0 * vgrid[:, :, 0] / max(w - 1, 1) - 1.0 + vgrid[:, :, 1] = 2.0 * vgrid[:, :, 1] / max(h - 1, 1) - 1.0 + return vgrid + + def __call__(self, inputs, target): + h, w, _ = inputs[0].shape + th, tw = self.crop + meshgrid = torch.meshgrid([torch.Tensor(range(th)), torch.Tensor(range(tw))])[ + ::-1 + ] + cornergrid = torch.meshgrid( + [torch.Tensor([0, th - 1]), torch.Tensor([0, tw - 1])] + )[::-1] + + for i in range(50): + # im0 + self.to_identity() + # TODO add mirror + if np.random.binomial(1, 0.5): + mirror = True + else: + mirror = False + ##TODO + # mirror = False + if mirror: + self.left_multiply(-1, 0, 0, 1, 0.5 * tw, -0.5 * th) + else: + self.left_multiply(1, 0, 0, 1, -0.5 * tw, -0.5 * th) + scale0 = 1 + scale1 = 1 + squeeze0 = 1 + squeeze1 = 1 + if not self.rot is None: + rot0 = np.random.uniform(-self.rot[0], +self.rot[0]) + rot1 = ( + np.random.uniform( + -self.rot[1] * self.schedule_coeff, + 
self.rot[1] * self.schedule_coeff, + ) + + rot0 + ) + self.left_multiply( + np.cos(rot0), np.sin(rot0), -np.sin(rot0), np.cos(rot0), 0, 0 + ) + if not self.trans is None: + trans0 = np.random.uniform(-self.trans[0], +self.trans[0], 2) + trans1 = ( + np.random.uniform( + -self.trans[1] * self.schedule_coeff, + +self.trans[1] * self.schedule_coeff, + 2, + ) + + trans0 + ) + self.left_multiply(1, 0, 0, 1, trans0[0] * tw, trans0[1] * th) + if not self.squeeze is None: + squeeze0 = np.exp(np.random.uniform(-self.squeeze[0], self.squeeze[0])) + squeeze1 = ( + np.exp( + np.random.uniform( + -self.squeeze[1] * self.schedule_coeff, + self.squeeze[1] * self.schedule_coeff, + ) + ) + * squeeze0 + ) + if not self.scale is None: + scale0 = np.exp( + np.random.uniform( + self.scale[2] - self.scale[0], self.scale[2] + self.scale[0] + ) + ) + scale1 = ( + np.exp( + np.random.uniform( + -self.scale[1] * self.schedule_coeff, + self.scale[1] * self.schedule_coeff, + ) + ) + * scale0 + ) + self.left_multiply( + 1.0 / (scale0 * squeeze0), 0, 0, 1.0 / (scale0 / squeeze0), 0, 0 + ) + + self.left_multiply(1, 0, 0, 1, 0.5 * w, 0.5 * h) + transmat0 = self.t.copy() + + # im1 + self.to_identity() + if mirror: + self.left_multiply(-1, 0, 0, 1, 0.5 * tw, -0.5 * th) + else: + self.left_multiply(1, 0, 0, 1, -0.5 * tw, -0.5 * th) + if not self.rot is None: + self.left_multiply( + np.cos(rot1), np.sin(rot1), -np.sin(rot1), np.cos(rot1), 0, 0 + ) + if not self.trans is None: + self.left_multiply(1, 0, 0, 1, trans1[0] * tw, trans1[1] * th) + self.left_multiply( + 1.0 / (scale1 * squeeze1), 0, 0, 1.0 / (scale1 / squeeze1), 0, 0 + ) + self.left_multiply(1, 0, 0, 1, 0.5 * w, 0.5 * h) + transmat1 = self.t.copy() + transmat1_inv = self.inverse() + + if self.black: + # black augmentation, allowing 0 values in the input images + # https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/black_augmentation_layer.cu + break + else: + if ( + ( + self.grid_transform( + cornergrid, transmat0, gridsize=[float(h), float(w)] + ).abs() + > 1 + ).sum() + + ( + self.grid_transform( + cornergrid, transmat1, gridsize=[float(h), float(w)] + ).abs() + > 1 + ).sum() + ) == 0: + break + if i == 49: + print("max_iter in augmentation") + self.to_identity() + self.left_multiply(1, 0, 0, 1, -0.5 * tw, -0.5 * th) + self.left_multiply(1, 0, 0, 1, 0.5 * w, 0.5 * h) + transmat0 = self.t.copy() + transmat1 = self.t.copy() + + # do the real work + vgrid = self.grid_transform(meshgrid, transmat0, gridsize=[float(h), float(w)]) + inputs_0 = F.grid_sample( + torch.Tensor(inputs[0]).permute(2, 0, 1)[np.newaxis], vgrid[np.newaxis] + )[0].permute(1, 2, 0) + if self.order == 0: + target_0 = F.grid_sample( + torch.Tensor(target).permute(2, 0, 1)[np.newaxis], + vgrid[np.newaxis], + mode="nearest", + )[0].permute(1, 2, 0) + else: + target_0 = F.grid_sample( + torch.Tensor(target).permute(2, 0, 1)[np.newaxis], vgrid[np.newaxis] + )[0].permute(1, 2, 0) + + mask_0 = target[:, :, 2:3].copy() + mask_0[mask_0 == 0] = np.nan + if self.order == 0: + mask_0 = F.grid_sample( + torch.Tensor(mask_0).permute(2, 0, 1)[np.newaxis], + vgrid[np.newaxis], + mode="nearest", + )[0].permute(1, 2, 0) + else: + mask_0 = F.grid_sample( + torch.Tensor(mask_0).permute(2, 0, 1)[np.newaxis], vgrid[np.newaxis] + )[0].permute(1, 2, 0) + mask_0[torch.isnan(mask_0)] = 0 + + vgrid = self.grid_transform(meshgrid, transmat1, gridsize=[float(h), float(w)]) + inputs_1 = F.grid_sample( + torch.Tensor(inputs[1]).permute(2, 0, 1)[np.newaxis], vgrid[np.newaxis] + )[0].permute(1, 2, 0) + + # flow 
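+        # Recompute the ground-truth flow in the augmented frame: the sampled flow (target_0)
+        # maps original image-0 coords to original image-1 coords, so add it to the warped crop
+        # grid, project the result through the inverse of image-1's transform, and take the
+        # displacement relative to the crop grid.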
+ pos = target_0[:, :, :2] + self.grid_transform( + meshgrid, transmat0, normalize=False + ) + pos = self.grid_transform(pos.permute(2, 0, 1), transmat1_inv, normalize=False) + if target_0.shape[2] >= 4: + # scale + exp = target_0[:, :, 3:] * scale1 / scale0 + target = torch.cat( + [ + (pos[:, :, 0] - meshgrid[0]).unsqueeze(-1), + (pos[:, :, 1] - meshgrid[1]).unsqueeze(-1), + mask_0, + exp, + ], + -1, + ) + else: + target = torch.cat( + [ + (pos[:, :, 0] - meshgrid[0]).unsqueeze(-1), + (pos[:, :, 1] - meshgrid[1]).unsqueeze(-1), + mask_0, + ], + -1, + ) + # target_0[:,:,2].unsqueeze(-1) ], -1) + inputs = [np.asarray(inputs_0), np.asarray(inputs_1)] + target = np.asarray(target) + return inputs, target + + +class pseudoPCAAug(object): + """ + Chromatic Eigen Augmentation: https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu + This version is faster. + """ + + def __init__(self, schedule_coeff=1): + self.augcolor = torchvision.transforms.ColorJitter( + brightness=0.4, contrast=0.4, saturation=0.5, hue=0.5 / 3.14 + ) + + def __call__(self, inputs, target): + inputs[0] = ( + np.asarray(self.augcolor(Image.fromarray(np.uint8(inputs[0] * 255)))) + / 255.0 + ) + inputs[1] = ( + np.asarray(self.augcolor(Image.fromarray(np.uint8(inputs[1] * 255)))) + / 255.0 + ) + return inputs, target + + +class PCAAug(object): + """ + Chromatic Eigen Augmentation: https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu + """ + + def __init__( + self, + lmult_pow=[0.4, 0, -0.2], + lmult_mult=[ + 0.4, + 0, + 0, + ], + lmult_add=[ + 0.03, + 0, + 0, + ], + sat_pow=[ + 0.4, + 0, + 0, + ], + sat_mult=[0.5, 0, -0.3], + sat_add=[ + 0.03, + 0, + 0, + ], + col_pow=[ + 0.4, + 0, + 0, + ], + col_mult=[ + 0.2, + 0, + 0, + ], + col_add=[ + 0.02, + 0, + 0, + ], + ladd_pow=[ + 0.4, + 0, + 0, + ], + ladd_mult=[ + 0.4, + 0, + 0, + ], + ladd_add=[ + 0.04, + 0, + 0, + ], + col_rotate=[ + 1.0, + 0, + 0, + ], + schedule_coeff=1, + ): + # no mean + self.pow_nomean = [1, 1, 1] + self.add_nomean = [0, 0, 0] + self.mult_nomean = [1, 1, 1] + self.pow_withmean = [1, 1, 1] + self.add_withmean = [0, 0, 0] + self.mult_withmean = [1, 1, 1] + self.lmult_pow = 1 + self.lmult_mult = 1 + self.lmult_add = 0 + self.col_angle = 0 + if not ladd_pow is None: + self.pow_nomean[0] = np.exp(np.random.normal(ladd_pow[2], ladd_pow[0])) + if not col_pow is None: + self.pow_nomean[1] = np.exp(np.random.normal(col_pow[2], col_pow[0])) + self.pow_nomean[2] = np.exp(np.random.normal(col_pow[2], col_pow[0])) + + if not ladd_add is None: + self.add_nomean[0] = np.random.normal(ladd_add[2], ladd_add[0]) + if not col_add is None: + self.add_nomean[1] = np.random.normal(col_add[2], col_add[0]) + self.add_nomean[2] = np.random.normal(col_add[2], col_add[0]) + + if not ladd_mult is None: + self.mult_nomean[0] = np.exp(np.random.normal(ladd_mult[2], ladd_mult[0])) + if not col_mult is None: + self.mult_nomean[1] = np.exp(np.random.normal(col_mult[2], col_mult[0])) + self.mult_nomean[2] = np.exp(np.random.normal(col_mult[2], col_mult[0])) + + # with mean + if not sat_pow is None: + self.pow_withmean[1] = np.exp( + np.random.uniform(sat_pow[2] - sat_pow[0], sat_pow[2] + sat_pow[0]) + ) + self.pow_withmean[2] = self.pow_withmean[1] + if not sat_add is None: + self.add_withmean[1] = np.random.uniform( + sat_add[2] - sat_add[0], sat_add[2] + sat_add[0] + ) + self.add_withmean[2] = self.add_withmean[1] + if not sat_mult is None: + self.mult_withmean[1] = np.exp( + np.random.uniform(sat_mult[2] 
- sat_mult[0], sat_mult[2] + sat_mult[0]) + ) + self.mult_withmean[2] = self.mult_withmean[1] + + if not lmult_pow is None: + self.lmult_pow = np.exp( + np.random.uniform( + lmult_pow[2] - lmult_pow[0], lmult_pow[2] + lmult_pow[0] + ) + ) + if not lmult_mult is None: + self.lmult_mult = np.exp( + np.random.uniform( + lmult_mult[2] - lmult_mult[0], lmult_mult[2] + lmult_mult[0] + ) + ) + if not lmult_add is None: + self.lmult_add = np.random.uniform( + lmult_add[2] - lmult_add[0], lmult_add[2] + lmult_add[0] + ) + if not col_rotate is None: + self.col_angle = np.random.uniform( + col_rotate[2] - col_rotate[0], col_rotate[2] + col_rotate[0] + ) + + # eigen vectors + self.eigvec = np.reshape( + [0.51, 0.56, 0.65, 0.79, 0.01, -0.62, 0.35, -0.83, 0.44], [3, 3] + ).transpose() + + def __call__(self, inputs, target): + inputs[0] = self.pca_image(inputs[0]) + inputs[1] = self.pca_image(inputs[1]) + return inputs, target + + def pca_image(self, rgb): + eig = np.dot(rgb, self.eigvec) + max_rgb = np.clip(rgb, 0, np.inf).max((0, 1)) + min_rgb = rgb.min((0, 1)) + mean_rgb = rgb.mean((0, 1)) + max_abs_eig = np.abs(eig).max((0, 1)) + max_l = np.sqrt(np.sum(max_abs_eig * max_abs_eig)) + mean_eig = np.dot(mean_rgb, self.eigvec) + + # no-mean stuff + eig -= mean_eig[np.newaxis, np.newaxis] + + for c in range(3): + if max_abs_eig[c] > 1e-2: + mean_eig[c] /= max_abs_eig[c] + eig[:, :, c] = eig[:, :, c] / max_abs_eig[c] + eig[:, :, c] = ( + np.power(np.abs(eig[:, :, c]), self.pow_nomean[c]) + * ((eig[:, :, c] > 0) - 0.5) + * 2 + ) + eig[:, :, c] = eig[:, :, c] + self.add_nomean[c] + eig[:, :, c] = eig[:, :, c] * self.mult_nomean[c] + eig += mean_eig[np.newaxis, np.newaxis] + + # withmean stuff + if max_abs_eig[0] > 1e-2: + eig[:, :, 0] = ( + np.power(np.abs(eig[:, :, 0]), self.pow_withmean[0]) + * ((eig[:, :, 0] > 0) - 0.5) + * 2 + ) + eig[:, :, 0] = eig[:, :, 0] + self.add_withmean[0] + eig[:, :, 0] = eig[:, :, 0] * self.mult_withmean[0] + + s = np.sqrt(eig[:, :, 1] * eig[:, :, 1] + eig[:, :, 2] * eig[:, :, 2]) + smask = s > 1e-2 + s1 = np.power(s, self.pow_withmean[1]) + s1 = np.clip(s1 + self.add_withmean[1], 0, np.inf) + s1 = s1 * self.mult_withmean[1] + s1 = s1 * smask + s * (1 - smask) + + # color angle + if self.col_angle != 0: + temp1 = ( + np.cos(self.col_angle) * eig[:, :, 1] + - np.sin(self.col_angle) * eig[:, :, 2] + ) + temp2 = ( + np.sin(self.col_angle) * eig[:, :, 1] + + np.cos(self.col_angle) * eig[:, :, 2] + ) + eig[:, :, 1] = temp1 + eig[:, :, 2] = temp2 + + # to origin magnitude + for c in range(3): + if max_abs_eig[c] > 1e-2: + eig[:, :, c] = eig[:, :, c] * max_abs_eig[c] + + if max_l > 1e-2: + l1 = np.sqrt( + eig[:, :, 0] * eig[:, :, 0] + + eig[:, :, 1] * eig[:, :, 1] + + eig[:, :, 2] * eig[:, :, 2] + ) + l1 = l1 / max_l + + eig[:, :, 1][smask] = (eig[:, :, 1] / s * s1)[smask] + eig[:, :, 2][smask] = (eig[:, :, 2] / s * s1)[smask] + # eig[:,:,1] = (eig[:,:,1] / s * s1) * smask + eig[:,:,1] * (1-smask) + # eig[:,:,2] = (eig[:,:,2] / s * s1) * smask + eig[:,:,2] * (1-smask) + + if max_l > 1e-2: + l = np.sqrt( + eig[:, :, 0] * eig[:, :, 0] + + eig[:, :, 1] * eig[:, :, 1] + + eig[:, :, 2] * eig[:, :, 2] + ) + l1 = np.power(l1, self.lmult_pow) + l1 = np.clip(l1 + self.lmult_add, 0, np.inf) + l1 = l1 * self.lmult_mult + l1 = l1 * max_l + lmask = l > 1e-2 + eig[lmask] = (eig / l[:, :, np.newaxis] * l1[:, :, np.newaxis])[lmask] + for c in range(3): + eig[:, :, c][lmask] = (np.clip(eig[:, :, c], -np.inf, max_abs_eig[c]))[ + lmask + ] + # for c in range(3): + # # eig[:,:,c][lmask] = (eig[:,:,c] / l 
* l1)[lmask] * lmask + eig[:,:,c] * (1-lmask) + # eig[:,:,c][lmask] = (eig[:,:,c] / l * l1)[lmask] + # eig[:,:,c] = (np.clip(eig[:,:,c], -np.inf, max_abs_eig[c])) * lmask + eig[:,:,c] * (1-lmask) + + return np.clip(np.dot(eig, self.eigvec.transpose()), 0, 1) + + +class ChromaticAug(object): + """ + Chromatic augmentation: https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu + """ + + def __init__( + self, + noise=0.06, + gamma=0.02, + brightness=0.02, + contrast=0.02, + color=0.02, + schedule_coeff=1, + ): + self.noise = np.random.uniform(0, noise) + self.gamma = np.exp(np.random.normal(0, gamma * schedule_coeff)) + self.brightness = np.random.normal(0, brightness * schedule_coeff) + self.contrast = np.exp(np.random.normal(0, contrast * schedule_coeff)) + self.color = np.exp(np.random.normal(0, color * schedule_coeff, 3)) + + def __call__(self, inputs, target): + inputs[1] = self.chrom_aug(inputs[1]) + # noise + inputs[0] += np.random.normal(0, self.noise, inputs[0].shape) + inputs[1] += np.random.normal(0, self.noise, inputs[0].shape) + return inputs, target + + def chrom_aug(self, rgb): + # color change + mean_in = rgb.sum(-1) + rgb = rgb * self.color[np.newaxis, np.newaxis] + brightness_coeff = mean_in / (rgb.sum(-1) + 0.01) + rgb = np.clip(rgb * brightness_coeff[:, :, np.newaxis], 0, 1) + # gamma + rgb = np.power(rgb, self.gamma) + # brightness + rgb += self.brightness + # contrast + rgb = 0.5 + (rgb - 0.5) * self.contrast + rgb = np.clip(rgb, 0, 1) + return diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/utils/flow_viz.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/utils/flow_viz.py new file mode 100644 index 0000000000000000000000000000000000000000..b8d93345c88238b40c07ddd9b914cb9a66ad8224 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/utils/flow_viz.py @@ -0,0 +1,136 @@ +# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization + + +# MIT License +# +# Copyright (c) 2018 Tom Runia +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to conditions. +# +# Author: Tom Runia +# Date Created: 2018-08-03 + +import numpy as np + + +def make_colorwheel(): + """ + Generates a color wheel for optical flow visualization as presented in: + Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) + URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf + + Code follows the original C++ source code of Daniel Scharstein. + Code follows the the Matlab source code of Deqing Sun. 
+ + Returns: + np.ndarray: Color wheel + """ + + RY = 15 + YG = 6 + GC = 4 + CB = 11 + BM = 13 + MR = 6 + + ncols = RY + YG + GC + CB + BM + MR + colorwheel = np.zeros((ncols, 3)) + col = 0 + + # RY + colorwheel[0:RY, 0] = 255 + colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY) + col = col + RY + # YG + colorwheel[col : col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG) + colorwheel[col : col + YG, 1] = 255 + col = col + YG + # GC + colorwheel[col : col + GC, 1] = 255 + colorwheel[col : col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC) + col = col + GC + # CB + colorwheel[col : col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB) + colorwheel[col : col + CB, 2] = 255 + col = col + CB + # BM + colorwheel[col : col + BM, 2] = 255 + colorwheel[col : col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM) + col = col + BM + # MR + colorwheel[col : col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR) + colorwheel[col : col + MR, 0] = 255 + return colorwheel + + +def flow_uv_to_colors(u, v, convert_to_bgr=False): + """ + Applies the flow color wheel to (possibly clipped) flow components u and v. + + According to the C++ source code of Daniel Scharstein + According to the Matlab source code of Deqing Sun + + Args: + u (np.ndarray): Input horizontal flow of shape [H,W] + v (np.ndarray): Input vertical flow of shape [H,W] + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. + + Returns: + np.ndarray: Flow visualization image of shape [H,W,3] + """ + flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) + colorwheel = make_colorwheel() # shape [55x3] + ncols = colorwheel.shape[0] + rad = np.sqrt(np.square(u) + np.square(v)) + a = np.arctan2(-v, -u) / np.pi + fk = (a + 1) / 2 * (ncols - 1) + k0 = np.floor(fk).astype(np.int32) + k1 = k0 + 1 + k1[k1 == ncols] = 0 + f = fk - k0 + for i in range(colorwheel.shape[1]): + tmp = colorwheel[:, i] + col0 = tmp[k0] / 255.0 + col1 = tmp[k1] / 255.0 + col = (1 - f) * col0 + f * col1 + idx = rad <= 1 + col[idx] = 1 - rad[idx] * (1 - col[idx]) + col[~idx] = col[~idx] * 0.75 # out of range + # Note the 2-i => BGR instead of RGB + ch_idx = 2 - i if convert_to_bgr else i + flow_image[:, :, ch_idx] = np.floor(255 * col) + return flow_image + + +def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False, max_flow=None): + """ + Expects a two dimensional flow image of shape. + + Args: + flow_uv (np.ndarray): Flow UV image of shape [H,W,2] + clip_flow (float, optional): Clip maximum of flow values. Defaults to None. + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
+ + Returns: + np.ndarray: Flow visualization image of shape [H,W,3] + """ + assert flow_uv.ndim == 3, "input flow must have three dimensions" + assert flow_uv.shape[2] == 2, "input flow must have shape [H,W,2]" + if clip_flow is not None: + flow_uv = np.clip(flow_uv, 0, clip_flow) + u = flow_uv[:, :, 0] + v = flow_uv[:, :, 1] + if max_flow is None: + rad = np.sqrt(np.square(u) + np.square(v)) + rad_max = np.max(rad) + else: + rad_max = max_flow + epsilon = 1e-5 + u = u / (rad_max + epsilon) + v = v / (rad_max + epsilon) + return flow_uv_to_colors(u, v, convert_to_bgr) diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/utils/frame_utils.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/utils/frame_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bee554e4bb1596bfe15945198f4cb1b5f1d786bf --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/utils/frame_utils.py @@ -0,0 +1,142 @@ +import numpy as np +from PIL import Image +from os.path import * +import re + +import cv2 + +cv2.setNumThreads(0) +cv2.ocl.setUseOpenCL(False) + +TAG_CHAR = np.array([202021.25], np.float32) + + +def readFlow(fn): + """Read .flo file in Middlebury format""" + # Code adapted from: + # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy + + # WARNING: this will work on little-endian architectures (eg Intel x86) only! + # print 'fn = %s'%(fn) + with open(fn, "rb") as f: + magic = np.fromfile(f, np.float32, count=1) + if 202021.25 != magic: + print("Magic number incorrect. Invalid .flo file") + return None + else: + w = np.fromfile(f, np.int32, count=1) + h = np.fromfile(f, np.int32, count=1) + # print 'Reading %d x %d flo file\n' % (w, h) + data = np.fromfile(f, np.float32, count=2 * int(w) * int(h)) + # Reshape data into 3D array (columns, rows, bands) + # The reshape here is for visualization, the original code is (w,h,2) + return np.resize(data, (int(h), int(w), 2)) + + +def readPFM(file): + file = open(file, "rb") + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header == b"PF": + color = True + elif header == b"Pf": + color = False + else: + raise Exception("Not a PFM file.") + + dim_match = re.match(rb"^(\d+)\s(\d+)\s$", file.readline()) + if dim_match: + width, height = map(int, dim_match.groups()) + else: + raise Exception("Malformed PFM header.") + + scale = float(file.readline().rstrip()) + if scale < 0: # little-endian + endian = "<" + scale = -scale + else: + endian = ">" # big-endian + + data = np.fromfile(file, endian + "f") + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + return data + + +def writeFlow(filename, uv, v=None): + """Write optical flow to file. + + If v is None, uv is assumed to contain both u and v channels, + stacked in depth. + Original code by Deqing Sun, adapted from Daniel Scharstein. 
+ """ + nBands = 2 + + if v is None: + assert uv.ndim == 3 + assert uv.shape[2] == 2 + u = uv[:, :, 0] + v = uv[:, :, 1] + else: + u = uv + + assert u.shape == v.shape + height, width = u.shape + f = open(filename, "wb") + # write the header + f.write(TAG_CHAR) + np.array(width).astype(np.int32).tofile(f) + np.array(height).astype(np.int32).tofile(f) + # arrange into matrix form + tmp = np.zeros((height, width * nBands)) + tmp[:, np.arange(width) * 2] = u + tmp[:, np.arange(width) * 2 + 1] = v + tmp.astype(np.float32).tofile(f) + f.close() + + +def readFlowKITTI(filename): + flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR) + flow = flow[:, :, ::-1].astype(np.float32) + flow, valid = flow[:, :, :2], flow[:, :, 2] + flow = (flow - 2**15) / 64.0 + return flow, valid + + +def readDispKITTI(filename): + disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 + valid = disp > 0.0 + flow = np.stack([-disp, np.zeros_like(disp)], -1) + return flow, valid + + +def writeFlowKITTI(filename, uv): + uv = 64.0 * uv + 2**15 + valid = np.ones([uv.shape[0], uv.shape[1], 1]) + uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) + cv2.imwrite(filename, uv[..., ::-1]) + + +def read_gen(file_name, pil=False): + ext = splitext(file_name)[-1] + if ext == ".png" or ext == ".jpeg" or ext == ".ppm" or ext == ".jpg": + return Image.open(file_name) + elif ext == ".bin" or ext == ".raw": + return np.load(file_name) + elif ext == ".flo": + return readFlow(file_name).astype(np.float32) + elif ext == ".pfm": + flow = readPFM(file_name).astype(np.float32) + if len(flow.shape) == 2: + return flow + else: + return flow[:, :, :-1] + return [] diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/utils/logger.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..ed98f77dbf68a0f41529f96ac526c520fe991705 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/utils/logger.py @@ -0,0 +1,60 @@ +from torch.utils.tensorboard import SummaryWriter +from loguru import logger as loguru_logger + + +class Logger: + def __init__(self, model, scheduler, cfg): + self.model = model + self.scheduler = scheduler + self.total_steps = 0 + self.running_loss = {} + self.writer = None + self.cfg = cfg + + def _print_training_status(self): + metrics_data = [ + self.running_loss[k] / self.cfg.sum_freq + for k in sorted(self.running_loss.keys()) + ] + training_str = "[{:6d}, {}] ".format( + self.total_steps + 1, self.scheduler.get_last_lr() + ) + metrics_str = ("{:10.4f}, " * len(metrics_data)).format(*metrics_data) + + # print the training status + loguru_logger.info(training_str + metrics_str) + + if self.writer is None: + if self.cfg.log_dir is None: + self.writer = SummaryWriter() + else: + self.writer = SummaryWriter(self.cfg.log_dir) + + for k in self.running_loss: + self.writer.add_scalar( + k, self.running_loss[k] / self.cfg.sum_freq, self.total_steps + ) + self.running_loss[k] = 0.0 + + def push(self, metrics): + self.total_steps += 1 + + for key in metrics: + if key not in self.running_loss: + self.running_loss[key] = 0.0 + + self.running_loss[key] += metrics[key] + + if self.total_steps % self.cfg.sum_freq == self.cfg.sum_freq - 1: + self._print_training_status() + self.running_loss = {} + + def write_dict(self, results): + if self.writer is None: + self.writer = SummaryWriter() + + for key in results: + self.writer.add_scalar(key, results[key], self.total_steps) + + def close(self): + 
self.writer.close() diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/utils/misc.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..8bfbe12baafeea8e664b7688efc1002b36d93244 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/utils/misc.py @@ -0,0 +1,33 @@ +import time +import os +import shutil + + +def process_transformer_cfg(cfg): + log_dir = "" + if "critical_params" in cfg: + critical_params = [cfg[key] for key in cfg.critical_params] + for name, param in zip(cfg["critical_params"], critical_params): + log_dir += "{:s}[{:s}]".format(name, str(param)) + + return log_dir + + +def process_cfg(cfg): + log_dir = "logs/" + cfg.name + "/" + cfg.transformer + "/" + critical_params = [cfg.trainer[key] for key in cfg.critical_params] + for name, param in zip(cfg["critical_params"], critical_params): + log_dir += "{:s}[{:s}]".format(name, str(param)) + + log_dir += process_transformer_cfg(cfg[cfg.transformer]) + + now = time.localtime() + now_time = "{:02d}_{:02d}_{:02d}_{:02d}".format( + now.tm_mon, now.tm_mday, now.tm_hour, now.tm_min + ) + log_dir += cfg.suffix + "(" + now_time + ")" + cfg.log_dir = log_dir + os.makedirs(log_dir) + + shutil.copytree("configs", f"{log_dir}/configs") + shutil.copytree("core/FlowFormer", f"{log_dir}/FlowFormer") diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/utils/utils.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8718cbeb89d0fad6275697f6178d83fb643ebef2 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/core/utils/utils.py @@ -0,0 +1,113 @@ +import torch +import torch.nn.functional as F +import numpy as np +from scipy import interpolate + + +class InputPadder: + """Pads images such that dimensions are divisible by 8""" + + def __init__(self, dims, mode="sintel"): + self.ht, self.wd = dims[-2:] + pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 + pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 + if mode == "sintel": + self._pad = [ + pad_wd // 2, + pad_wd - pad_wd // 2, + pad_ht // 2, + pad_ht - pad_ht // 2, + ] + elif mode == "kitti400": + self._pad = [0, 0, 0, 400 - self.ht] + else: + self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht] + + def pad(self, *inputs): + return [F.pad(x, self._pad, mode="replicate") for x in inputs] + + def unpad(self, x): + ht, wd = x.shape[-2:] + c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]] + return x[..., c[0] : c[1], c[2] : c[3]] + + +def forward_interpolate(flow): + flow = flow.detach().cpu().numpy() + dx, dy = flow[0], flow[1] + + ht, wd = dx.shape + x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) + + x1 = x0 + dx + y1 = y0 + dy + + x1 = x1.reshape(-1) + y1 = y1.reshape(-1) + dx = dx.reshape(-1) + dy = dy.reshape(-1) + + valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) + x1 = x1[valid] + y1 = y1[valid] + dx = dx[valid] + dy = dy[valid] + + flow_x = interpolate.griddata( + (x1, y1), dx, (x0, y0), method="nearest", fill_value=0 + ) + + flow_y = interpolate.griddata( + (x1, y1), dy, (x0, y0), method="nearest", fill_value=0 + ) + + flow = np.stack([flow_x, flow_y], axis=0) + return torch.from_numpy(flow).float() + + +def bilinear_sampler(img, coords, mode="bilinear", mask=False): + """Wrapper for grid_sample, uses pixel coordinates""" + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1, 1], dim=-1) + xgrid = 
2 * xgrid / (W - 1) - 1 + ygrid = 2 * ygrid / (H - 1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) + img = F.grid_sample(img, grid, align_corners=True) + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + + return img + + +def indexing(img, coords, mask=False): + """Wrapper for grid_sample, uses pixel coordinates""" + """ + TODO: directly indexing features instead of sampling + """ + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1, 1], dim=-1) + xgrid = 2 * xgrid / (W - 1) - 1 + ygrid = 2 * ygrid / (H - 1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) + img = F.grid_sample(img, grid, align_corners=True, mode="nearest") + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + + return img + + +def coords_grid(batch, ht, wd): + coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) + + +def upflow8(flow, mode="bilinear"): + new_size = (8 * flow.shape[2], 8 * flow.shape[3]) + return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/evaluate_FlowFormer.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/evaluate_FlowFormer.py new file mode 100644 index 0000000000000000000000000000000000000000..b6c80432c48262839afd107352ec0a4d3ef18bbd --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/evaluate_FlowFormer.py @@ -0,0 +1,201 @@ +import sys + +sys.path.append("core") + +from PIL import Image +import argparse +import os +import time +import numpy as np +import torch +import torch.nn.functional as F +import matplotlib.pyplot as plt +from configs.default import get_cfg +from configs.things_eval import get_cfg as get_things_cfg +from configs.small_things_eval import get_cfg as get_small_things_cfg +from core.utils.misc import process_cfg +import datasets +from utils import flow_viz +from utils import frame_utils + +# from FlowFormer import FlowFormer +from core.FlowFormer import build_flowformer +from raft import RAFT + +from utils.utils import InputPadder, forward_interpolate + + +@torch.no_grad() +def validate_chairs(model): + """Perform evaluation on the FlyingChairs (test) split""" + model.eval() + epe_list = [] + + val_dataset = datasets.FlyingChairs(split="validation") + for val_id in range(len(val_dataset)): + image1, image2, flow_gt, _ = val_dataset[val_id] + image1 = image1[None].cuda() + image2 = image2[None].cuda() + flow_pre, _ = model(image1, image2) + + epe = torch.sum((flow_pre[0].cpu() - flow_gt) ** 2, dim=0).sqrt() + epe_list.append(epe.view(-1).numpy()) + + epe = np.mean(np.concatenate(epe_list)) + print("Validation Chairs EPE: %f" % epe) + return {"chairs": epe} + + +@torch.no_grad() +def validate_sintel(model): + """Peform validation using the Sintel (train) split""" + model.eval() + results = {} + for dstype in ["clean", "final"]: + val_dataset = datasets.MpiSintel(split="training", dstype=dstype) + epe_list = [] + + for val_id in range(len(val_dataset)): + image1, image2, flow_gt, _ = val_dataset[val_id] + image1 = image1[None].cuda() + image2 = image2[None].cuda() + padder = InputPadder(image1.shape) + image1, image2 = padder.pad(image1, image2) + + flow_pre = model(image1, image2) + + flow_pre = padder.unpad(flow_pre[0]).cpu()[0] + + epe = torch.sum((flow_pre - flow_gt) ** 2, dim=0).sqrt() + epe_list.append(epe.view(-1).numpy()) + + epe_all = 
np.concatenate(epe_list) + epe = np.mean(epe_all) + px1 = np.mean(epe_all < 1) + px3 = np.mean(epe_all < 3) + px5 = np.mean(epe_all < 5) + + print( + "Validation (%s) EPE: %f, 1px: %f, 3px: %f, 5px: %f" + % (dstype, epe, px1, px3, px5) + ) + results[dstype] = np.mean(epe_list) + + return results + + +@torch.no_grad() +def create_sintel_submission(model, output_path="sintel_submission"): + """Create submission for the Sintel leaderboard""" + + model.eval() + for dstype in ["final", "clean"]: + test_dataset = datasets.MpiSintel(split="test", aug_params=None, dstype=dstype) + + for test_id in range(len(test_dataset)): + if (test_id + 1) % 100 == 0: + print(f"{test_id} / {len(test_dataset)}") + image1, image2, (sequence, frame) = test_dataset[test_id] + image1, image2 = image1[None].cuda(), image2[None].cuda() + + padder = InputPadder(image1.shape) + image1, image2 = padder.pad(image1, image2) + + flow_pre = model(image1, image2) + + flow_pre = padder.unpad(flow_pre[0]).cpu() + flow = flow_pre[0].permute(1, 2, 0).cpu().numpy() + + output_dir = os.path.join(output_path, dstype, sequence) + output_file = os.path.join(output_dir, "frame%04d.flo" % (frame + 1)) + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + frame_utils.writeFlow(output_file, flow) + + +@torch.no_grad() +def validate_kitti(model): + """Peform validation using the KITTI-2015 (train) split""" + model.eval() + val_dataset = datasets.KITTI(split="training") + + out_list, epe_list = [], [] + for val_id in range(len(val_dataset)): + image1, image2, flow_gt, valid_gt = val_dataset[val_id] + image1 = image1[None].cuda() + image2 = image2[None].cuda() + + padder = InputPadder(image1.shape) + image1, image2 = padder.pad(image1, image2) + + flow_pre = model(image1, image2) + + flow_pre = padder.unpad(flow_pre[0]).cpu()[0] + + epe = torch.sum((flow_pre - flow_gt) ** 2, dim=0).sqrt() + mag = torch.sum(flow_gt**2, dim=0).sqrt() + + epe = epe.view(-1) + mag = mag.view(-1) + val = valid_gt.view(-1) >= 0.5 + + out = ((epe > 3.0) & ((epe / mag) > 0.05)).float() + epe_list.append(epe[val].mean().item()) + out_list.append(out[val].cpu().numpy()) + + epe_list = np.array(epe_list) + out_list = np.concatenate(out_list) + + epe = np.mean(epe_list) + f1 = 100 * np.mean(out_list) + + print("Validation KITTI: %f, %f" % (epe, f1)) + return {"kitti-epe": epe, "kitti-f1": f1} + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", help="restore checkpoint") + parser.add_argument("--dataset", help="dataset for evaluation") + parser.add_argument("--small", action="store_true", help="use small model") + parser.add_argument( + "--mixed_precision", action="store_true", help="use mixed precision" + ) + parser.add_argument( + "--alternate_corr", + action="store_true", + help="use efficent correlation implementation", + ) + args = parser.parse_args() + # cfg = get_cfg() + if args.small: + cfg = get_small_things_cfg() + else: + cfg = get_things_cfg() + cfg.update(vars(args)) + + model = torch.nn.DataParallel(build_flowformer(cfg)) + model.load_state_dict(torch.load(cfg.model)) + + print(args) + + model.cuda() + model.eval() + + # create_sintel_submission(model.module, warm_start=True) + # create_kitti_submission(model.module) + + with torch.no_grad(): + if args.dataset == "chairs": + validate_chairs(model.module) + + elif args.dataset == "sintel": + validate_sintel(model.module) + + elif args.dataset == "kitti": + validate_kitti(model.module) + + elif args.dataset == "sintel_submission": + 
create_sintel_submission(model.module) diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/evaluate_FlowFormer_tile.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/evaluate_FlowFormer_tile.py new file mode 100644 index 0000000000000000000000000000000000000000..8755b75c0183a12f504384c37f340a79647888d3 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/evaluate_FlowFormer_tile.py @@ -0,0 +1,410 @@ +import sys + +from attr import validate + +sys.path.append("core") + +from PIL import Image +import argparse +import os +import time +import numpy as np +import torch +import torch.nn.functional as F +import matplotlib.pyplot as plt +from configs.submission import get_cfg as get_submission_cfg + +# from configs.kitti_submission import get_cfg as get_kitti_cfg +from configs.things_eval import get_cfg as get_things_cfg +from configs.small_things_eval import get_cfg as get_small_things_cfg +from core.utils.misc import process_cfg +import datasets +from utils import flow_viz +from utils import frame_utils + +from core.FlowFormer import build_flowformer +from raft import RAFT + +from utils.utils import InputPadder, forward_interpolate +import imageio +import itertools + +TRAIN_SIZE = [432, 960] + + +class InputPadder: + """Pads images such that dimensions are divisible by 8""" + + def __init__(self, dims, mode="sintel"): + self.ht, self.wd = dims[-2:] + pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 + pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 + if mode == "sintel": + self._pad = [ + pad_wd // 2, + pad_wd - pad_wd // 2, + pad_ht // 2, + pad_ht - pad_ht // 2, + ] + elif mode == "kitti432": + self._pad = [0, 0, 0, 432 - self.ht] + elif mode == "kitti400": + self._pad = [0, 0, 0, 400 - self.ht] + elif mode == "kitti376": + self._pad = [0, 0, 0, 376 - self.ht] + else: + self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht] + + def pad(self, *inputs): + return [F.pad(x, self._pad, mode="constant", value=0.0) for x in inputs] + + def unpad(self, x): + ht, wd = x.shape[-2:] + c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]] + return x[..., c[0] : c[1], c[2] : c[3]] + + +def compute_grid_indices(image_shape, patch_size=TRAIN_SIZE, min_overlap=20): + if min_overlap >= patch_size[0] or min_overlap >= patch_size[1]: + raise ValueError("!!") + hs = list(range(0, image_shape[0], patch_size[0] - min_overlap)) + ws = list(range(0, image_shape[1], patch_size[1] - min_overlap)) + # Make sure the final patch is flush with the image boundary + hs[-1] = image_shape[0] - patch_size[0] + ws[-1] = image_shape[1] - patch_size[1] + return [(h, w) for h in hs for w in ws] + + +import math + + +def compute_weight( + hws, image_shape, patch_size=TRAIN_SIZE, sigma=1.0, wtype="gaussian" +): + patch_num = len(hws) + h, w = torch.meshgrid(torch.arange(patch_size[0]), torch.arange(patch_size[1])) + h, w = h / float(patch_size[0]), w / float(patch_size[1]) + c_h, c_w = 0.5, 0.5 + h, w = h - c_h, w - c_w + weights_hw = (h**2 + w**2) ** 0.5 / sigma + denorm = 1 / (sigma * math.sqrt(2 * math.pi)) + weights_hw = denorm * torch.exp(-0.5 * (weights_hw) ** 2) + + weights = torch.zeros(1, patch_num, *image_shape) + for idx, (h, w) in enumerate(hws): + weights[:, idx, h : h + patch_size[0], w : w + patch_size[1]] = weights_hw + weights = weights.cuda() + patch_weights = [] + for idx, (h, w) in enumerate(hws): + patch_weights.append( + weights[:, idx : idx + 1, h : h + patch_size[0], w : w + patch_size[1]] + ) + + return patch_weights + + +@torch.no_grad() +def 
create_sintel_submission( + model, output_path="sintel_submission_multi8_768", sigma=0.05 +): + """Create submission for the Sintel leaderboard""" + print("no warm start") + # print(f"output path: {output_path}") + IMAGE_SIZE = [436, 1024] + + hws = compute_grid_indices(IMAGE_SIZE) + weights = compute_weight(hws, IMAGE_SIZE, TRAIN_SIZE, sigma) + + model.eval() + for dstype in ["final", "clean"]: + test_dataset = datasets.MpiSintel_submission( + split="test", aug_params=None, dstype=dstype, root="./dataset/Sintel/test" + ) + epe_list = [] + for test_id in range(len(test_dataset)): + if (test_id + 1) % 100 == 0: + print(f"{test_id} / {len(test_dataset)}") + # break + image1, image2, (sequence, frame) = test_dataset[test_id] + image1, image2 = image1[None].cuda(), image2[None].cuda() + + flows = 0 + flow_count = 0 + + for idx, (h, w) in enumerate(hws): + image1_tile = image1[:, :, h : h + TRAIN_SIZE[0], w : w + TRAIN_SIZE[1]] + image2_tile = image2[:, :, h : h + TRAIN_SIZE[0], w : w + TRAIN_SIZE[1]] + flow_pre, flow_low = model(image1_tile, image2_tile) + + padding = ( + w, + IMAGE_SIZE[1] - w - TRAIN_SIZE[1], + h, + IMAGE_SIZE[0] - h - TRAIN_SIZE[0], + 0, + 0, + ) + flows += F.pad(flow_pre * weights[idx], padding) + flow_count += F.pad(weights[idx], padding) + + flow_pre = flows / flow_count + flow = flow_pre[0].permute(1, 2, 0).cpu().numpy() + + output_dir = os.path.join(output_path, dstype, sequence) + output_file = os.path.join(output_dir, "frame%04d.flo" % (frame + 1)) + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + frame_utils.writeFlow(output_file, flow) + + +@torch.no_grad() +def create_kitti_submission(model, output_path="kitti_submission", sigma=0.05): + """Create submission for the Sintel leaderboard""" + + IMAGE_SIZE = [432, 1242] + + print(f"output path: {output_path}") + print(f"image size: {IMAGE_SIZE}") + print(f"training size: {TRAIN_SIZE}") + + hws = compute_grid_indices(IMAGE_SIZE) + weights = compute_weight(hws, (432, 1242), TRAIN_SIZE, sigma) + model.eval() + test_dataset = datasets.KITTI(split="testing", aug_params=None) + + if not os.path.exists(output_path): + os.makedirs(output_path) + + for test_id in range(len(test_dataset)): + image1, image2, (frame_id,) = test_dataset[test_id] + new_shape = image1.shape[1:] + if ( + new_shape[1] != IMAGE_SIZE[1] + ): # fix the height=432, adaptive ajust the width + print(f"replace {IMAGE_SIZE} with {new_shape}") + IMAGE_SIZE[0] = 432 + IMAGE_SIZE[1] = new_shape[1] + hws = compute_grid_indices(IMAGE_SIZE) + weights = compute_weight(hws, IMAGE_SIZE, TRAIN_SIZE, sigma) + + padder = InputPadder( + image1.shape, mode="kitti432" + ) # padding the image to height of 432 + image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda()) + + flows = 0 + flow_count = 0 + + for idx, (h, w) in enumerate(hws): + image1_tile = image1[:, :, h : h + TRAIN_SIZE[0], w : w + TRAIN_SIZE[1]] + image2_tile = image2[:, :, h : h + TRAIN_SIZE[0], w : w + TRAIN_SIZE[1]] + flow_pre, _ = model(image1_tile, image2_tile) + + padding = ( + w, + IMAGE_SIZE[1] - w - TRAIN_SIZE[1], + h, + IMAGE_SIZE[0] - h - TRAIN_SIZE[0], + 0, + 0, + ) + flows += F.pad(flow_pre * weights[idx], padding) + flow_count += F.pad(weights[idx], padding) + + flow_pre = flows / flow_count + flow = padder.unpad(flow_pre[0]).permute(1, 2, 0).cpu().numpy() + + output_filename = os.path.join(output_path, frame_id) + frame_utils.writeFlowKITTI(output_filename, flow) + + flow_img = flow_viz.flow_to_image(flow) + image = Image.fromarray(flow_img) + if not 
os.path.exists(f"vis_kitti_3patch"): + os.makedirs(f"vis_kitti_3patch/flow") + os.makedirs(f"vis_kitti_3patch/image") + + image.save(f"vis_kitti_3patch/flow/{test_id}.png") + imageio.imwrite( + f"vis_kitti_3patch/image/{test_id}_0.png", + image1[0].cpu().permute(1, 2, 0).numpy(), + ) + imageio.imwrite( + f"vis_kitti_3patch/image/{test_id}_1.png", + image2[0].cpu().permute(1, 2, 0).numpy(), + ) + + +@torch.no_grad() +def validate_kitti(model, sigma=0.05): + IMAGE_SIZE = [376, 1242] + TRAIN_SIZE = [376, 720] + + hws = compute_grid_indices(IMAGE_SIZE, TRAIN_SIZE) + weights = compute_weight(hws, IMAGE_SIZE, TRAIN_SIZE, sigma) + model.eval() + val_dataset = datasets.KITTI(split="training") + + out_list, epe_list = [], [] + for val_id in range(len(val_dataset)): + image1, image2, flow_gt, valid_gt = val_dataset[val_id] + new_shape = image1.shape[1:] + if new_shape[1] != IMAGE_SIZE[1]: + print(f"replace {IMAGE_SIZE} with {new_shape}") + IMAGE_SIZE[0] = 376 + IMAGE_SIZE[1] = new_shape[1] + hws = compute_grid_indices(IMAGE_SIZE, TRAIN_SIZE) + weights = compute_weight(hws, IMAGE_SIZE, TRAIN_SIZE, sigma) + + padder = InputPadder(image1.shape, mode="kitti376") + image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda()) + + flows = 0 + flow_count = 0 + + for idx, (h, w) in enumerate(hws): + image1_tile = image1[:, :, h : h + TRAIN_SIZE[0], w : w + TRAIN_SIZE[1]] + image2_tile = image2[:, :, h : h + TRAIN_SIZE[0], w : w + TRAIN_SIZE[1]] + flow_pre, flow_low = model(image1_tile, image2_tile) + + padding = ( + w, + IMAGE_SIZE[1] - w - TRAIN_SIZE[1], + h, + IMAGE_SIZE[0] - h - TRAIN_SIZE[0], + 0, + 0, + ) + flows += F.pad(flow_pre * weights[idx], padding) + flow_count += F.pad(weights[idx], padding) + + flow_pre = flows / flow_count + flow = padder.unpad(flow_pre[0]).cpu() + epe = torch.sum((flow - flow_gt) ** 2, dim=0).sqrt() + mag = torch.sum(flow_gt**2, dim=0).sqrt() + + epe = epe.view(-1) + mag = mag.view(-1) + val = valid_gt.view(-1) >= 0.5 + + out = ((epe > 3.0) & ((epe / mag) > 0.05)).float() + epe_list.append(epe[val].mean().item()) + out_list.append(out[val].cpu().numpy()) + + epe_list = np.array(epe_list) + out_list = np.concatenate(out_list) + + epe = np.mean(epe_list) + f1 = 100 * np.mean(out_list) + + print("Validation KITTI: %f, %f" % (epe, f1)) + return {"kitti-epe": epe, "kitti-f1": f1} + + +@torch.no_grad() +def validate_sintel(model, sigma=0.05): + """Peform validation using the Sintel (train) split""" + + IMAGE_SIZE = [436, 1024] + + hws = compute_grid_indices(IMAGE_SIZE) + weights = compute_weight(hws, IMAGE_SIZE, TRAIN_SIZE, sigma) + + model.eval() + results = {} + for dstype in ["final", "clean"]: + val_dataset = datasets.MpiSintel(split="training", dstype=dstype) + + epe_list = [] + + for val_id in range(len(val_dataset)): + if val_id % 50 == 0: + print(val_id) + + image1, image2, flow_gt, _ = val_dataset[val_id] + image1 = image1[None].cuda() + image2 = image2[None].cuda() + + flows = 0 + flow_count = 0 + + for idx, (h, w) in enumerate(hws): + image1_tile = image1[:, :, h : h + TRAIN_SIZE[0], w : w + TRAIN_SIZE[1]] + image2_tile = image2[:, :, h : h + TRAIN_SIZE[0], w : w + TRAIN_SIZE[1]] + + flow_pre, _ = model(image1_tile, image2_tile, flow_init=None) + + padding = ( + w, + IMAGE_SIZE[1] - w - TRAIN_SIZE[1], + h, + IMAGE_SIZE[0] - h - TRAIN_SIZE[0], + 0, + 0, + ) + flows += F.pad(flow_pre * weights[idx], padding) + flow_count += F.pad(weights[idx], padding) + + flow_pre = flows / flow_count + flow_pre = flow_pre[0].cpu() + + epe = torch.sum((flow_pre - flow_gt) ** 2, 
dim=0).sqrt() + epe_list.append(epe.view(-1).numpy()) + + epe_all = np.concatenate(epe_list) + epe = np.mean(epe_all) + px1 = np.mean(epe_all < 1) + px3 = np.mean(epe_all < 3) + px5 = np.mean(epe_all < 5) + + print( + "Validation (%s) EPE: %f, 1px: %f, 3px: %f, 5px: %f" + % (dstype, epe, px1, px3, px5) + ) + results[f"{dstype}_tile"] = np.mean(epe_list) + + return results + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", help="load model") + parser.add_argument("--eval", help="eval benchmark") + parser.add_argument("--small", action="store_true", help="use small model") + args = parser.parse_args() + + exp_func = None + cfg = None + if args.eval == "sintel_submission": + exp_func = create_sintel_submission + cfg = get_submission_cfg() + elif args.eval == "kitti_submission": + exp_func = create_kitti_submission + cfg = get_submission_cfg() + cfg.latentcostformer.decoder_depth = 24 + elif args.eval == "sintel_validation": + exp_func = validate_sintel + if args.small: + cfg = get_small_things_cfg() + else: + cfg = get_things_cfg() + elif args.eval == "kitti_validation": + exp_func = validate_kitti + if args.small: + cfg = get_small_things_cfg() + else: + cfg = get_things_cfg() + cfg.latentcostformer.decoder_depth = 24 + else: + print(f"EROOR: {args.eval} is not valid") + cfg.update(vars(args)) + + print(cfg) + model = torch.nn.DataParallel(build_flowformer(cfg)) + model.load_state_dict(torch.load(cfg.model)) + + model.cuda() + model.eval() + + exp_func(model.module) diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/train_FlowFormer.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/train_FlowFormer.py new file mode 100644 index 0000000000000000000000000000000000000000..a3c0b850544ac90c16e5ceab1146032598f3f696 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/train_FlowFormer.py @@ -0,0 +1,182 @@ +from __future__ import print_function, division +import sys + +# sys.path.append('core') + +import argparse +import os +import cv2 +import time +import numpy as np +import matplotlib.pyplot as plt +from pathlib import Path + +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F + +from torch.utils.data import DataLoader +from core import optimizer +import evaluate_FlowFormer as evaluate +import evaluate_FlowFormer_tile as evaluate_tile +import core.datasets as datasets +from core.loss import sequence_loss +from core.optimizer import fetch_optimizer +from core.utils.misc import process_cfg +from loguru import logger as loguru_logger + +# from torch.utils.tensorboard import SummaryWriter +from core.utils.logger import Logger + +# from core.FlowFormer import FlowFormer +from core.FlowFormer import build_flowformer + +try: + from torch.cuda.amp import GradScaler +except: + # dummy GradScaler for PyTorch < 1.6 + class GradScaler: + def __init__(self): + pass + + def scale(self, loss): + return loss + + def unscale_(self, optimizer): + pass + + def step(self, optimizer): + optimizer.step() + + def update(self): + pass + + +# torch.autograd.set_detect_anomaly(True) + + +def count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +def train(cfg): + model = nn.DataParallel(build_flowformer(cfg)) + loguru_logger.info("Parameter Count: %d" % count_parameters(model)) + + if cfg.restore_ckpt is not None: + print("[Loading ckpt from {}]".format(cfg.restore_ckpt)) + model.load_state_dict(torch.load(cfg.restore_ckpt), strict=True) + + 
model.cuda() + model.train() + + train_loader = datasets.fetch_dataloader(cfg) + optimizer, scheduler = fetch_optimizer(model, cfg.trainer) + + total_steps = 0 + scaler = GradScaler(enabled=cfg.mixed_precision) + logger = Logger(model, scheduler, cfg) + + add_noise = False + + should_keep_training = True + while should_keep_training: + for i_batch, data_blob in enumerate(train_loader): + optimizer.zero_grad() + image1, image2, flow, valid = [x.cuda() for x in data_blob] + + if cfg.add_noise: + stdv = np.random.uniform(0.0, 5.0) + image1 = (image1 + stdv * torch.randn(*image1.shape).cuda()).clamp( + 0.0, 255.0 + ) + image2 = (image2 + stdv * torch.randn(*image2.shape).cuda()).clamp( + 0.0, 255.0 + ) + + output = {} + flow_predictions = model(image1, image2, output) + loss, metrics = sequence_loss(flow_predictions, flow, valid, cfg) + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.trainer.clip) + + scaler.step(optimizer) + scheduler.step() + scaler.update() + + metrics.update(output) + logger.push(metrics) + + ### change evaluate to functions + + if total_steps % cfg.val_freq == cfg.val_freq - 1: + PATH = "%s/%d_%s.pth" % (cfg.log_dir, total_steps + 1, cfg.name) + # torch.save(model.state_dict(), PATH) + + results = {} + for val_dataset in cfg.validation: + if val_dataset == "chairs": + results.update(evaluate.validate_chairs(model.module)) + elif val_dataset == "sintel": + results.update(evaluate.validate_sintel(model.module)) + elif val_dataset == "kitti": + results.update(evaluate.validate_kitti(model.module)) + + logger.write_dict(results) + + model.train() + + total_steps += 1 + + if total_steps > cfg.trainer.num_steps: + should_keep_training = False + break + + logger.close() + PATH = cfg.log_dir + "/final" + torch.save(model.state_dict(), PATH) + + PATH = f"checkpoints/{cfg.stage}.pth" + torch.save(model.state_dict(), PATH) + + return PATH + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--name", default="flowformer", help="name your experiment") + parser.add_argument("--stage", help="determines which dataset to use for training") + parser.add_argument("--validation", type=str, nargs="+") + + parser.add_argument( + "--mixed_precision", action="store_true", help="use mixed precision" + ) + + args = parser.parse_args() + + if args.stage == "chairs": + from configs.default import get_cfg + elif args.stage == "things": + from configs.things import get_cfg + elif args.stage == "sintel": + from configs.sintel import get_cfg + elif args.stage == "kitti": + from configs.kitti import get_cfg + elif args.stage == "autoflow": + from configs.autoflow import get_cfg + + cfg = get_cfg() + cfg.update(vars(args)) + process_cfg(cfg) + loguru_logger.add(str(Path(cfg.log_dir) / "log.txt"), encoding="utf8") + loguru_logger.info(cfg) + + torch.manual_seed(1234) + np.random.seed(1234) + + if not os.path.isdir("checkpoints"): + os.mkdir("checkpoints") + + train(cfg) diff --git a/blissful_tuner/gimmvfi/generalizable_INR/flowformer/visualize_flow.py b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/visualize_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..b81923b4451d86432161355f072fd4bac01a4463 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/flowformer/visualize_flow.py @@ -0,0 +1,238 @@ +import sys + +sys.path.append("core") + +from PIL import Image +from glob import glob +import argparse +import os +import time +import numpy as np +import torch 
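+# NOTE: this script loads the FlowFormer checkpoint referenced by the submission
+# config (see build_model below), builds frame pairs from either a Sintel-style
+# directory or a numbered image sequence, estimates optical flow for each pair,
+# and writes colorized flow maps (flow_viz.flow_to_image) under --viz_root_dir.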
+import torch.nn.functional as F +import matplotlib.pyplot as plt +from configs.submission import get_cfg +from core.utils.misc import process_cfg +import datasets +from utils import flow_viz +from utils import frame_utils +import cv2 +import math +import os.path as osp + +from core.FlowFormer import build_flowformer + +from utils.utils import InputPadder, forward_interpolate +import itertools + +TRAIN_SIZE = [432, 960] + + +def compute_grid_indices(image_shape, patch_size=TRAIN_SIZE, min_overlap=20): + if min_overlap >= TRAIN_SIZE[0] or min_overlap >= TRAIN_SIZE[1]: + raise ValueError( + f"Overlap should be less than size of patch (got {min_overlap}" + f"for patch size {patch_size})." + ) + if image_shape[0] == TRAIN_SIZE[0]: + hs = list(range(0, image_shape[0], TRAIN_SIZE[0])) + else: + hs = list(range(0, image_shape[0], TRAIN_SIZE[0] - min_overlap)) + if image_shape[1] == TRAIN_SIZE[1]: + ws = list(range(0, image_shape[1], TRAIN_SIZE[1])) + else: + ws = list(range(0, image_shape[1], TRAIN_SIZE[1] - min_overlap)) + + # Make sure the final patch is flush with the image boundary + hs[-1] = image_shape[0] - patch_size[0] + ws[-1] = image_shape[1] - patch_size[1] + return [(h, w) for h in hs for w in ws] + + +def compute_weight( + hws, image_shape, patch_size=TRAIN_SIZE, sigma=1.0, wtype="gaussian" +): + patch_num = len(hws) + h, w = torch.meshgrid(torch.arange(patch_size[0]), torch.arange(patch_size[1])) + h, w = h / float(patch_size[0]), w / float(patch_size[1]) + c_h, c_w = 0.5, 0.5 + h, w = h - c_h, w - c_w + weights_hw = (h**2 + w**2) ** 0.5 / sigma + denorm = 1 / (sigma * math.sqrt(2 * math.pi)) + weights_hw = denorm * torch.exp(-0.5 * (weights_hw) ** 2) + + weights = torch.zeros(1, patch_num, *image_shape) + for idx, (h, w) in enumerate(hws): + weights[:, idx, h : h + patch_size[0], w : w + patch_size[1]] = weights_hw + weights = weights.cuda() + patch_weights = [] + for idx, (h, w) in enumerate(hws): + patch_weights.append( + weights[:, idx : idx + 1, h : h + patch_size[0], w : w + patch_size[1]] + ) + + return patch_weights + + +def compute_flow(model, image1, image2, weights=None): + print(f"computing flow...") + + image_size = image1.shape[1:] + + image1, image2 = image1[None].cuda(), image2[None].cuda() + + hws = compute_grid_indices(image_size) + if weights is None: # no tile + padder = InputPadder(image1.shape) + image1, image2 = padder.pad(image1, image2) + + flow_pre, _ = model(image1, image2) + + flow_pre = padder.unpad(flow_pre) + flow = flow_pre[0].permute(1, 2, 0).cpu().numpy() + else: # tile + flows = 0 + flow_count = 0 + + for idx, (h, w) in enumerate(hws): + image1_tile = image1[:, :, h : h + TRAIN_SIZE[0], w : w + TRAIN_SIZE[1]] + image2_tile = image2[:, :, h : h + TRAIN_SIZE[0], w : w + TRAIN_SIZE[1]] + flow_pre, _ = model(image1_tile, image2_tile) + padding = ( + w, + image_size[1] - w - TRAIN_SIZE[1], + h, + image_size[0] - h - TRAIN_SIZE[0], + 0, + 0, + ) + flows += F.pad(flow_pre * weights[idx], padding) + flow_count += F.pad(weights[idx], padding) + + flow_pre = flows / flow_count + flow = flow_pre[0].permute(1, 2, 0).cpu().numpy() + + return flow + + +def compute_adaptive_image_size(image_size): + target_size = TRAIN_SIZE + scale0 = target_size[0] / image_size[0] + scale1 = target_size[1] / image_size[1] + + if scale0 > scale1: + scale = scale0 + else: + scale = scale1 + + image_size = (int(image_size[1] * scale), int(image_size[0] * scale)) + + return image_size + + +def prepare_image(root_dir, viz_root_dir, fn1, fn2, keep_size): + print(f"preparing image...") 
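+    # Read both frames, drop any alpha channel, optionally rescale them to cover
+    # TRAIN_SIZE (skipped when keep_size is set), convert them to CHW float
+    # tensors, and derive the output path for the flow visualization.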
+ print(f"root dir = {root_dir}, fn = {fn1}") + + image1 = frame_utils.read_gen(osp.join(root_dir, fn1)) + image2 = frame_utils.read_gen(osp.join(root_dir, fn2)) + image1 = np.array(image1).astype(np.uint8)[..., :3] + image2 = np.array(image2).astype(np.uint8)[..., :3] + if not keep_size: + dsize = compute_adaptive_image_size(image1.shape[0:2]) + image1 = cv2.resize(image1, dsize=dsize, interpolation=cv2.INTER_CUBIC) + image2 = cv2.resize(image2, dsize=dsize, interpolation=cv2.INTER_CUBIC) + image1 = torch.from_numpy(image1).permute(2, 0, 1).float() + image2 = torch.from_numpy(image2).permute(2, 0, 1).float() + + dirname = osp.dirname(fn1) + filename = osp.splitext(osp.basename(fn1))[0] + + viz_dir = osp.join(viz_root_dir, dirname) + if not osp.exists(viz_dir): + os.makedirs(viz_dir) + + viz_fn = osp.join(viz_dir, filename + ".png") + + return image1, image2, viz_fn + + +def build_model(): + print(f"building model...") + cfg = get_cfg() + model = torch.nn.DataParallel(build_flowformer(cfg)) + model.load_state_dict(torch.load(cfg.model)) + + model.cuda() + model.eval() + + return model + + +def visualize_flow(root_dir, viz_root_dir, model, img_pairs, keep_size): + weights = None + for img_pair in img_pairs: + fn1, fn2 = img_pair + print(f"processing {fn1}, {fn2}...") + + image1, image2, viz_fn = prepare_image( + root_dir, viz_root_dir, fn1, fn2, keep_size + ) + flow = compute_flow(model, image1, image2, weights) + flow_img = flow_viz.flow_to_image(flow) + cv2.imwrite(viz_fn, flow_img[:, :, [2, 1, 0]]) + + +def process_sintel(sintel_dir): + img_pairs = [] + for scene in os.listdir(sintel_dir): + dirname = osp.join(sintel_dir, scene) + image_list = sorted(glob(osp.join(dirname, "*.png"))) + for i in range(len(image_list) - 1): + img_pairs.append((image_list[i], image_list[i + 1])) + + return img_pairs + + +def generate_pairs(dirname, start_idx, end_idx): + img_pairs = [] + for idx in range(start_idx, end_idx): + img1 = osp.join(dirname, f"{idx:06}.png") + img2 = osp.join(dirname, f"{idx+1:06}.png") + # img1 = f'{idx:06}.png' + # img2 = f'{idx+1:06}.png' + img_pairs.append((img1, img2)) + + return img_pairs + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--eval_type", default="sintel") + parser.add_argument("--root_dir", default=".") + parser.add_argument("--sintel_dir", default="datasets/Sintel/test/clean") + parser.add_argument("--seq_dir", default="demo_data/mihoyo") + parser.add_argument( + "--start_idx", type=int, default=1 + ) # starting index of the image sequence + parser.add_argument( + "--end_idx", type=int, default=1200 + ) # ending index of the image sequence + parser.add_argument("--viz_root_dir", default="viz_results") + parser.add_argument( + "--keep_size", action="store_true" + ) # keep the image size, or the image will be adaptively resized. 
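+    # Illustrative invocation for a numbered frame sequence (paths are placeholders):
+    #   python visualize_flow.py --eval_type seq --seq_dir demo_data/mihoyo \
+    #       --start_idx 1 --end_idx 100 --viz_root_dir viz_results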
+ + args = parser.parse_args() + + root_dir = args.root_dir + viz_root_dir = args.viz_root_dir + + model = build_model() + + if args.eval_type == "sintel": + img_pairs = process_sintel(args.sintel_dir) + elif args.eval_type == "seq": + img_pairs = generate_pairs(args.seq_dir, args.start_idx, args.end_idx) + with torch.no_grad(): + visualize_flow(root_dir, viz_root_dir, model, img_pairs, args.keep_size) diff --git a/blissful_tuner/gimmvfi/generalizable_INR/gimm.py b/blissful_tuner/gimmvfi/generalizable_INR/gimm.py new file mode 100644 index 0000000000000000000000000000000000000000..4d18f8c8b822b78c767e992596f6ef7796f48786 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/gimm.py @@ -0,0 +1,253 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# motif: https://github.com/sichun233746/MoTIF +# ginr-ipc: https://github.com/kakaobrain/ginr-ipc +# -------------------------------------------------------- + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .configs import GIMMConfig +from .modules.coord_sampler import CoordSampler3D +from .modules.fi_components import LateralBlock +from .modules.hyponet import HypoNet +from .modules.fi_utils import warp + +from .modules.softsplat import softsplat + + +class GIMM(nn.Module): + Config = GIMMConfig + + def __init__(self, config: GIMMConfig): + super().__init__() + self.config = config = config.copy() + self.hyponet_config = config.hyponet + self.coord_sampler = CoordSampler3D(config.coord_range) + self.fwarp_type = config.fwarp_type + + # Motion Encoder + channel = 32 + in_dim = 2 + self.cnn_encoder = nn.Sequential( + nn.Conv2d(in_dim, channel // 2, 3, 1, 1, bias=True, groups=1), + nn.Conv2d(channel // 2, channel, 3, 1, 1, bias=True, groups=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + LateralBlock(channel), + LateralBlock(channel), + LateralBlock(channel), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d( + channel, channel // 2, 3, 1, 1, padding_mode="reflect", bias=True + ), + ) + + # Latent Refiner + channel = 64 + in_dim = 64 + self.res_conv = nn.Sequential( + nn.Conv2d(in_dim, channel // 2, 3, 1, 1, bias=True, groups=1), + nn.Conv2d(channel // 2, channel, 3, 1, 1, bias=True, groups=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + LateralBlock(channel), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d( + channel, channel // 2, 3, 1, 1, padding_mode="reflect", bias=True + ), + ) + self.g_filter = torch.nn.Parameter( + torch.FloatTensor( + [ + [1.0 / 16.0, 1.0 / 8.0, 1.0 / 16.0], + [1.0 / 8.0, 1.0 / 4.0, 1.0 / 8.0], + [1.0 / 16.0, 1.0 / 8.0, 1.0 / 16.0], + ] + ).reshape(1, 1, 1, 3, 3), + requires_grad=False, + ) + self.alpha_v = torch.nn.Parameter(torch.FloatTensor([1]), requires_grad=True) + self.alpha_fe = torch.nn.Parameter(torch.FloatTensor([1]), requires_grad=True) + + self.hyponet = HypoNet(config.hyponet, add_coord_dim=32) + + def cal_splatting_weights(self, raft_flow01, raft_flow10): + batch_size = raft_flow01.shape[0] + raft_flows = torch.cat([raft_flow01, raft_flow10], dim=0) + + ## flow variance metric + sqaure_mean, mean_square = torch.split( + F.conv3d( + F.pad( + torch.cat([raft_flows**2, raft_flows], 1), + (1, 1, 1, 1), + mode="reflect", + ).unsqueeze(1), + self.g_filter, + ).squeeze(1), + 2, + dim=1, + ) + var = ( + 
(sqaure_mean - mean_square**2) + .clamp(1e-9, None) + .sqrt() + .mean(1) + .unsqueeze(1) + ) + var01 = var[:batch_size] + var10 = var[batch_size:] + + ## flow warp metirc + f01_warp = -warp(raft_flow10, raft_flow01) + f10_warp = -warp(raft_flow01, raft_flow10) + err01 = ( + torch.nn.functional.l1_loss( + input=f01_warp, target=raft_flow01, reduction="none" + ) + .mean(1) + .unsqueeze(1) + ) + err02 = ( + torch.nn.functional.l1_loss( + input=f10_warp, target=raft_flow10, reduction="none" + ) + .mean(1) + .unsqueeze(1) + ) + + weights1 = 1 / (1 + err01 * self.alpha_fe) + 1 / (1 + var01 * self.alpha_v) + weights2 = 1 / (1 + err02 * self.alpha_fe) + 1 / (1 + var10 * self.alpha_v) + + return weights1, weights2 + + def forward( + self, xs, coord=None, keep_xs_shape=True, ori_flow=None, timesteps=None + ): + coord = self.sample_coord_input(xs) if coord is None else coord + raft_flow01 = ori_flow[:, :, 0] + raft_flow10 = ori_flow[:, :, 1] + + # calculate splatting metrics + weights1, weights2 = self.cal_splatting_weights(raft_flow01, raft_flow10) + # b,c,h,w + pixel_latent_0 = self.cnn_encoder(xs[:, :, 0]) + pixel_latent_1 = self.cnn_encoder(xs[:, :, 1]) + pixel_latent = [] + + modulation_params_dict = None + strtype = self.fwarp_type + if isinstance(timesteps, list): + assert isinstance(coord, list) + assert len(timesteps) == len(coord) + for i, cur_t in enumerate(timesteps): + cur_t = cur_t.reshape(-1, 1, 1, 1) + tmp_pixel_latent_0 = softsplat( + tenIn=pixel_latent_0, + tenFlow=raft_flow01 * cur_t, + tenMetric=weights1, + strMode=strtype + "-zeroeps", + ) + tmp_pixel_latent_1 = softsplat( + tenIn=pixel_latent_1, + tenFlow=raft_flow10 * (1 - cur_t), + tenMetric=weights2, + strMode=strtype + "-zeroeps", + ) + tmp_pixel_latent = torch.cat( + [tmp_pixel_latent_0, tmp_pixel_latent_1], dim=1 + ) + tmp_pixel_latent = tmp_pixel_latent + self.res_conv( + torch.cat([pixel_latent_0, pixel_latent_1, tmp_pixel_latent], dim=1) + ) + pixel_latent.append(tmp_pixel_latent.permute(0, 2, 3, 1)) + + all_outputs = [] + for idx, c in enumerate(coord): + outputs = self.hyponet( + c, + modulation_params_dict=modulation_params_dict, + pixel_latent=pixel_latent[idx], + ) + if keep_xs_shape: + permute_idx_range = [i for i in range(1, xs.ndim - 1)] + outputs = outputs.permute(0, -1, *permute_idx_range) + all_outputs.append(outputs) + return all_outputs + + else: + cur_t = timesteps.reshape(-1, 1, 1, 1) + tmp_pixel_latent_0 = softsplat( + tenIn=pixel_latent_0, + tenFlow=raft_flow01 * cur_t, + tenMetric=weights1, + strMode=strtype + "-zeroeps", + ) + tmp_pixel_latent_1 = softsplat( + tenIn=pixel_latent_1, + tenFlow=raft_flow10 * (1 - cur_t), + tenMetric=weights2, + strMode=strtype + "-zeroeps", + ) + tmp_pixel_latent = torch.cat( + [tmp_pixel_latent_0, tmp_pixel_latent_1], dim=1 + ) + tmp_pixel_latent = tmp_pixel_latent + self.res_conv( + torch.cat([pixel_latent_0, pixel_latent_1, tmp_pixel_latent], dim=1) + ) + pixel_latent = tmp_pixel_latent.permute(0, 2, 3, 1) + + # predict all pixels of coord after applying the modulation_parms into hyponet + outputs = self.hyponet( + coord, + modulation_params_dict=modulation_params_dict, + pixel_latent=pixel_latent, + ) + if keep_xs_shape: + permute_idx_range = [i for i in range(1, xs.ndim - 1)] + outputs = outputs.permute(0, -1, *permute_idx_range) + return outputs + + def compute_loss(self, preds, targets, reduction="mean", single=False): + assert reduction in ["mean", "sum", "none"] + batch_size = preds.shape[0] + sample_mses = 0 + assert preds.shape[2] == 1 + assert 
targets.shape[2] == 1 + for i in range(preds.shape[2]): + sample_mses += torch.reshape( + (preds[:, :, i] - targets[:, :, i]) ** 2, (batch_size, -1) + ).mean(dim=-1) + sample_mses = sample_mses / preds.shape[2] + if reduction == "mean": + total_loss = sample_mses.mean() + psnr = (-10 * torch.log10(sample_mses)).mean() + elif reduction == "sum": + total_loss = sample_mses.sum() + psnr = (-10 * torch.log10(sample_mses)).sum() + else: + total_loss = sample_mses + psnr = -10 * torch.log10(sample_mses) + + return {"loss_total": total_loss, "mse": total_loss, "psnr": psnr} + + def sample_coord_input( + self, + batch_size, + s_shape, + t_ids, + coord_range=None, + upsample_ratio=1.0, + device=None, + ): + assert device is not None + assert coord_range is None + coord_inputs = self.coord_sampler( + batch_size, s_shape, t_ids, coord_range, upsample_ratio, device + ) + return coord_inputs diff --git a/blissful_tuner/gimmvfi/generalizable_INR/gimmvfi_f.py b/blissful_tuner/gimmvfi/generalizable_INR/gimmvfi_f.py new file mode 100644 index 0000000000000000000000000000000000000000..73be6342a55b20f0ec668aa985dceb7e8e6cfe77 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/gimmvfi_f.py @@ -0,0 +1,468 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# amt: https://github.com/MCG-NKU/AMT +# motif: https://github.com/sichun233746/MoTIF +# ginr-ipc: https://github.com/kakaobrain/ginr-ipc +# -------------------------------------------------------- + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .configs import GIMMVFIConfig +from .modules.coord_sampler import CoordSampler3D +from .modules.hyponet import HypoNet +from .modules.fi_components import * +#from .flowformer import initialize_Flowformer +from .modules.fi_utils import normalize_flow, unnormalize_flow, warp, resize +from .raft.corr import BidirCorrBlock +from .modules.softsplat import softsplat + + +class GIMMVFI_F(nn.Module): + Config = GIMMVFIConfig + + def __init__(self, config: GIMMVFIConfig): + super().__init__() + self.config = config = config.copy() + self.hyponet_config = config.hyponet + self.raft_iter = config.raft_iter + + ######### Encoder and Decoder Settings ######### + #self.flow_estimator = initialize_Flowformer() + f_dims = [256, 128] + + skip_channels = f_dims[-1] // 2 + self.num_flows = 3 + self.amt_init_decoder = NewInitDecoder(f_dims[0], skip_channels) + self.amt_final_decoder = NewMultiFlowDecoder(f_dims[1], skip_channels) + + self.amt_update4_low = self._get_updateblock(f_dims[0] // 2, 2.0) + self.amt_update4_high = self._get_updateblock(f_dims[0] // 2, None) + + self.amt_comb_block = nn.Sequential( + nn.Conv2d(3 * self.num_flows, 6 * self.num_flows, 7, 1, 3), + nn.PReLU(6 * self.num_flows), + nn.Conv2d(6 * self.num_flows, 3, 7, 1, 3), + ) + + ################ GIMM settings ################# + self.coord_sampler = CoordSampler3D(config.coord_range) + + self.g_filter = torch.nn.Parameter( + torch.FloatTensor( + [ + [1.0 / 16.0, 1.0 / 8.0, 1.0 / 16.0], + [1.0 / 8.0, 1.0 / 4.0, 1.0 / 8.0], + [1.0 / 16.0, 1.0 / 8.0, 1.0 / 16.0], + ] + ).reshape(1, 1, 1, 3, 3), + requires_grad=False, + ) + self.fwarp_type = config.fwarp_type + + self.alpha_v = torch.nn.Parameter(torch.FloatTensor([1]), requires_grad=True) + self.alpha_fe = torch.nn.Parameter(torch.FloatTensor([1]), 
requires_grad=True) + + channel = 32 + in_dim = 2 + self.cnn_encoder = nn.Sequential( + nn.Conv2d(in_dim, channel // 2, 3, 1, 1, bias=True, groups=1), + nn.Conv2d(channel // 2, channel, 3, 1, 1, bias=True, groups=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + LateralBlock(channel), + LateralBlock(channel), + LateralBlock(channel), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d( + channel, channel // 2, 3, 1, 1, padding_mode="reflect", bias=True + ), + ) + channel = 64 + in_dim = 64 + self.res_conv = nn.Sequential( + nn.Conv2d(in_dim, channel // 2, 3, 1, 1, bias=True, groups=1), + nn.Conv2d(channel // 2, channel, 3, 1, 1, bias=True, groups=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + LateralBlock(channel), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d( + channel, channel // 2, 3, 1, 1, padding_mode="reflect", bias=True + ), + ) + + self.hyponet = HypoNet(config.hyponet, add_coord_dim=32) + + def _get_updateblock(self, cdim, scale_factor=None): + return BasicUpdateBlock( + cdim=cdim, + hidden_dim=192, + flow_dim=64, + corr_dim=256, + corr_dim2=192, + fc_dim=188, + scale_factor=scale_factor, + corr_levels=4, + radius=4, + ) + + def cal_bidirection_flow(self, im0, im1): + f01, features0, fnet0 = self.flow_estimator( + im0, im1, return_feat=True, iters=None + ) + f10, features1, fnet1 = self.flow_estimator( + im1, im0, return_feat=True, iters=None + ) + f01 = f01[0] + f10 = f10[0] + corr_fn = BidirCorrBlock(fnet0, fnet1, radius=4) + flow01 = f01.unsqueeze(2) + flow10 = f10.unsqueeze(2) + noraml_flows = torch.cat([flow01, -flow10], dim=2) + noraml_flows, flow_scalers = normalize_flow(noraml_flows) + + ori_flows = torch.cat([flow01, flow10], dim=2) + return ( + noraml_flows, + ori_flows, + flow_scalers, + features0, + features1, + corr_fn, + torch.cat([f01.unsqueeze(2), f10.unsqueeze(2)], dim=2), + ) + + def predict_flow(self, f, coord, t, flows): + raft_flow01 = flows[:, :, 0].detach() + raft_flow10 = flows[:, :, 1].detach() + + # calculate splatting metrics + weights1, weights2 = self.cal_splatting_weights(raft_flow01, raft_flow10) + strtype = self.fwarp_type + "-zeroeps" + + # b,c,h,w + pixel_latent_0 = self.cnn_encoder(f[:, :, 0]) + pixel_latent_1 = self.cnn_encoder(f[:, :, 1]) + pixel_latent = [] + + for i, cur_t in enumerate(t): + cur_t = cur_t.reshape(-1, 1, 1, 1) + + tmp_pixel_latent_0 = softsplat( + tenIn=pixel_latent_0, + tenFlow=raft_flow01 * cur_t, + tenMetric=weights1, + strMode=strtype, + ) + tmp_pixel_latent_1 = softsplat( + tenIn=pixel_latent_1, + tenFlow=raft_flow10 * (1 - cur_t), + tenMetric=weights2, + strMode=strtype, + ) + + tmp_pixel_latent = torch.cat( + [tmp_pixel_latent_0, tmp_pixel_latent_1], dim=1 + ) + tmp_pixel_latent = tmp_pixel_latent + self.res_conv( + torch.cat([pixel_latent_0, pixel_latent_1, tmp_pixel_latent], dim=1) + ) + pixel_latent.append(tmp_pixel_latent.permute(0, 2, 3, 1)) + + all_outputs = [] + permute_idx_range = [i for i in range(1, f.ndim - 1)] + for idx, c in enumerate(coord): + assert c[0][0, 0, 0, 0, 0] == t[idx][0].squeeze() + assert isinstance(c, tuple) + + if c[1] is None: + outputs = self.hyponet( + c, modulation_params_dict=None, pixel_latent=pixel_latent[idx] + ).permute(0, -1, *permute_idx_range) + else: + outputs = self.hyponet( + c, modulation_params_dict=None, pixel_latent=pixel_latent[idx] + ) + all_outputs.append(outputs) + + return all_outputs + + def warp_w_mask(self, img0, img1, ft0, ft1, mask, scale=1): + ft0 = scale * resize(ft0, scale_factor=scale) + ft1 = scale * resize(ft1, 
scale_factor=scale) + mask = resize(mask, scale_factor=scale).sigmoid() + img0_warp = warp(img0, ft0) + img1_warp = warp(img1, ft1) + img_warp = mask * img0_warp + (1 - mask) * img1_warp + return img_warp + + def frame_synthesize( + self, img_xs, flow_t, features0, features1, corr_fn, cur_t, full_img=None + ): + """ + flow_t: b,2,h,w + cur_t: b,1,1,1 + """ + batch_size = img_xs.shape[0] + img0 = 2 * img_xs[:, :, 0] - 1.0 + img1 = 2 * img_xs[:, :, 1] - 1.0 + + ##################### update the predicted flow ##################### + ## initialize coordinates for looking up + lookup_coord = self.flow_estimator.build_coord(img_xs[:, :, 0]).to( + img_xs[:, :, 0].device + ) + + flow_t0_fullsize = flow_t * (-cur_t) + flow_t1_fullsize = flow_t * (1.0 - cur_t) + + inv = 1 / 4 + flow_t0_inr4 = inv * resize(flow_t0_fullsize, inv) + flow_t1_inr4 = inv * resize(flow_t1_fullsize, inv) + + ############################# scale 1/4 ############################# + # i. Initialize feature t at scale 1/4 + flowt0_4, flowt1_4, ft_4_ = self.amt_init_decoder( + features0[-1], + features1[-1], + flow_t0_inr4, + flow_t1_inr4, + img0=img0, + img1=img1, + ) + mask_4_, ft_4_ = ft_4_[:, :1], ft_4_[:, 1:] + img_warp_4 = self.warp_w_mask(img0, img1, flowt0_4, flowt1_4, mask_4_, scale=4) + img_warp_4 = (img_warp_4 + 1.0) / 2 + img_warp_4 = torch.clamp(img_warp_4, 0, 1) + + corr_4, flow_4_lr = self._amt_corr_scale_lookup( + corr_fn, lookup_coord, flowt0_4, flowt1_4, cur_t, downsample=2 + ) + + delta_ft_4_, delta_flow_4 = self.amt_update4_low(ft_4_, flow_4_lr, corr_4) + delta_flow0_4, delta_flow1_4 = torch.chunk(delta_flow_4, 2, 1) + flowt0_4 = flowt0_4 + delta_flow0_4 + flowt1_4 = flowt1_4 + delta_flow1_4 + ft_4_ = ft_4_ + delta_ft_4_ + + # iii. residue update with lookup corr + corr_4 = resize(corr_4, scale_factor=2.0) + + flow_4 = torch.cat([flowt0_4, flowt1_4], dim=1) + delta_ft_4_, delta_flow_4 = self.amt_update4_high(ft_4_, flow_4, corr_4) + flowt0_4 = flowt0_4 + delta_flow_4[:, :2] + flowt1_4 = flowt1_4 + delta_flow_4[:, 2:4] + ft_4_ = ft_4_ + delta_ft_4_ + + ############################# scale 1/1 ############################# + flowt0_1, flowt1_1, mask, img_res = self.amt_final_decoder( + ft_4_, + features0[0], + features1[0], + flowt0_4, + flowt1_4, + mask=mask_4_, + img0=img0, + img1=img1, + ) + + if full_img is not None: + img0 = 2 * full_img[:, :, 0] - 1.0 + img1 = 2 * full_img[:, :, 1] - 1.0 + inv = img1.shape[2] / flowt0_1.shape[2] + flowt0_1 = inv * resize(flowt0_1, scale_factor=inv) + flowt1_1 = inv * resize(flowt1_1, scale_factor=inv) + flow_t0_fullsize = inv * resize(flow_t0_fullsize, scale_factor=inv) + flow_t1_fullsize = inv * resize(flow_t1_fullsize, scale_factor=inv) + mask = resize(mask, scale_factor=inv) + img_res = resize(img_res, scale_factor=inv) + + imgt_pred = multi_flow_combine( + self.amt_comb_block, img0, img1, flowt0_1, flowt1_1, mask, img_res, None + ) + imgt_pred = torch.clamp(imgt_pred, 0, 1) + + ###################################################################### + + flowt0_1 = flowt0_1.reshape( + batch_size, self.num_flows, 2, img0.shape[-2], img0.shape[-1] + ) + flowt1_1 = flowt1_1.reshape( + batch_size, self.num_flows, 2, img0.shape[-2], img0.shape[-1] + ) + + flowt0_pred = [flowt0_1, flowt0_4] + flowt1_pred = [flowt1_1, flowt1_4] + other_pred = [img_warp_4] + return imgt_pred, flowt0_pred, flowt1_pred, other_pred + + def forward(self, img_xs, coord=None, t=None, ds_factor=None): + assert isinstance(t, list) + assert isinstance(coord, list) + assert len(t) == len(coord) + 
full_size_img = None + if ds_factor is not None: + full_size_img = img_xs.clone() + img_xs = torch.cat( + [ + resize(img_xs[:, :, 0], scale_factor=ds_factor).unsqueeze(2), + resize(img_xs[:, :, 1], scale_factor=ds_factor).unsqueeze(2), + ], + dim=2, + ) + + ( + normal_flows, + flows, + flow_scalers, + features0, + features1, + corr_fn, + preserved_raft_flows, + ) = self.cal_bidirection_flow(255 * img_xs[:, :, 0], 255 * img_xs[:, :, 1]) + assert coord is not None + + # List of flows + normal_inr_flows = self.predict_flow(normal_flows, coord, t, flows) + + ############ Unnormalize the predicted/reconstructed flow ############ + start_idx = 0 + if coord[0][1] is not None: + # Subsmapled flows for reconstruction supervision in the GIMM module + # In such case, first two coords in the list are subsampled for supervision up-mentioned + # Normalized flow_t towards positive t-axis + assert len(coord) > 2 + flow_t = [ + unnormalize_flow(normal_inr_flows[i], flow_scalers).squeeze() + for i in range(2, len(coord)) + ] + start_idx = 2 + else: + flow_t = [ + unnormalize_flow(normal_inr_flows[i], flow_scalers).squeeze() + for i in range(len(coord)) + ] + + imgt_preds, flowt0_preds, flowt1_preds, all_others = [], [], [], [] + + for idx in range(start_idx, len(coord)): + cur_flow_t = flow_t[idx - start_idx] + cur_t = t[idx].reshape(-1, 1, 1, 1) + if cur_flow_t.ndim != 4: + cur_flow_t = cur_flow_t.unsqueeze(0) + assert cur_flow_t.ndim == 4 + + imgt_pred, flowt0_pred, flowt1_pred, others = self.frame_synthesize( + img_xs, + cur_flow_t, + features0, + features1, + corr_fn, + cur_t, + full_img=full_size_img, + ) + + imgt_preds.append(imgt_pred) + flowt0_preds.append(flowt0_pred) + flowt1_preds.append(flowt1_pred) + all_others.append(others) + + return { + "imgt_pred": imgt_preds, + "other_pred": all_others, + "flowt0_pred": flowt0_preds, + "flowt1_pred": flowt1_preds, + "raft_flow": preserved_raft_flows, + "ninrflow": normal_inr_flows, + "nflow": normal_flows, + "flowt": flow_t, + } + + def warp_frame(self, frame, flow): + return warp(frame, flow) + + def sample_coord_input( + self, + batch_size, + s_shape, + t_ids, + coord_range=None, + upsample_ratio=1.0, + device=None, + ): + assert device is not None + assert coord_range is None + coord_inputs = self.coord_sampler( + batch_size, s_shape, t_ids, coord_range, upsample_ratio, device + ) + return coord_inputs + + def cal_splatting_weights(self, raft_flow01, raft_flow10): + batch_size = raft_flow01.shape[0] + raft_flows = torch.cat([raft_flow01, raft_flow10], dim=0) + + ## flow variance metric + sqaure_mean, mean_square = torch.split( + F.conv3d( + F.pad( + torch.cat([raft_flows**2, raft_flows], 1), + (1, 1, 1, 1), + mode="reflect", + ).unsqueeze(1), + self.g_filter, + ).squeeze(1), + 2, + dim=1, + ) + var = ( + (sqaure_mean - mean_square**2) + .clamp(1e-9, None) + .sqrt() + .mean(1) + .unsqueeze(1) + ) + var01 = var[:batch_size] + var10 = var[batch_size:] + + ## flow warp metirc + f01_warp = -warp(raft_flow10, raft_flow01) + f10_warp = -warp(raft_flow01, raft_flow10) + err01 = ( + torch.nn.functional.l1_loss( + input=f01_warp, target=raft_flow01, reduction="none" + ) + .mean(1) + .unsqueeze(1) + ) + err02 = ( + torch.nn.functional.l1_loss( + input=f10_warp, target=raft_flow10, reduction="none" + ) + .mean(1) + .unsqueeze(1) + ) + + weights1 = 1 / (1 + err01 * self.alpha_fe) + 1 / (1 + var01 * self.alpha_v) + weights2 = 1 / (1 + err02 * self.alpha_fe) + 1 / (1 + var10 * self.alpha_v) + + return weights1, weights2 + + def _amt_corr_scale_lookup(self, corr_fn, 
coord, flow0, flow1, embt, downsample=1): + # convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0 + # based on linear assumption + t0_scale = 1.0 / embt + t1_scale = 1.0 / (1.0 - embt) + if downsample != 1: + inv = 1 / downsample + flow0 = inv * resize(flow0, scale_factor=inv) + flow1 = inv * resize(flow1, scale_factor=inv) + + corr0, corr1 = corr_fn(coord + flow1 * t1_scale, coord + flow0 * t0_scale) + corr = torch.cat([corr0, corr1], dim=1) + flow = torch.cat([flow0, flow1], dim=1) + return corr, flow diff --git a/blissful_tuner/gimmvfi/generalizable_INR/gimmvfi_r.py b/blissful_tuner/gimmvfi/generalizable_INR/gimmvfi_r.py new file mode 100644 index 0000000000000000000000000000000000000000..62e4c38a117db815967c82b2556886b5501f6608 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/gimmvfi_r.py @@ -0,0 +1,508 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# amt: https://github.com/MCG-NKU/AMT +# motif: https://github.com/sichun233746/MoTIF +# ginr-ipc: https://github.com/kakaobrain/ginr-ipc +# -------------------------------------------------------- + +import torch +import torch.nn as nn + +from .configs import GIMMVFIConfig +from .modules.coord_sampler import CoordSampler3D +from .modules.hyponet import HypoNet +from .modules.fi_components import * +from .raft import initialize_RAFT +from .modules.fi_utils import ( + normalize_flow, + unnormalize_flow, + warp, + resize, + build_coord, +) +import torch.nn.functional as F + +from .raft.corr import BidirCorrBlock +from .modules.softsplat import softsplat + + +class GIMMVFI_R(nn.Module): + Config = GIMMVFIConfig + + def __init__(self, config: GIMMVFIConfig): + super().__init__() + self.config = config = config.copy() + self.hyponet_config = config.hyponet + self.raft_iter = 20 + + ######### Encoder and Decoder Settings ######### + #self.flow_estimator = initialize_RAFT() + cur_f_dims = [128, 96] + f_dims = [256, 128] + self.dtype = torch.float32 + + skip_channels = f_dims[-1] // 2 + self.num_flows = 3 + + self.amt_last_cproj = nn.Conv2d(cur_f_dims[0], f_dims[0], 1) + self.amt_second_last_cproj = nn.Conv2d(cur_f_dims[1], f_dims[1], 1) + self.amt_fproj = nn.Conv2d(f_dims[0], f_dims[0], 1) + self.amt_init_decoder = NewInitDecoder(f_dims[0], skip_channels) + self.amt_final_decoder = NewMultiFlowDecoder(f_dims[1], skip_channels) + + self.amt_update4_low = self._get_updateblock(f_dims[0] // 2, 2.0) + self.amt_update4_high = self._get_updateblock(f_dims[0] // 2, None) + + self.amt_comb_block = nn.Sequential( + nn.Conv2d(3 * self.num_flows, 6 * self.num_flows, 7, 1, 3), + nn.PReLU(6 * self.num_flows), + nn.Conv2d(6 * self.num_flows, 3, 7, 1, 3), + ) + + ################ GIMM settings ################# + self.coord_sampler = CoordSampler3D(config.coord_range) + + self.g_filter = torch.nn.Parameter( + torch.FloatTensor( + [ + [1.0 / 16.0, 1.0 / 8.0, 1.0 / 16.0], + [1.0 / 8.0, 1.0 / 4.0, 1.0 / 8.0], + [1.0 / 16.0, 1.0 / 8.0, 1.0 / 16.0], + ] + ).reshape(1, 1, 1, 3, 3), + requires_grad=False, + ) + self.fwarp_type = config.fwarp_type + + self.alpha_v = torch.nn.Parameter(torch.FloatTensor([1]), requires_grad=True) + self.alpha_fe = torch.nn.Parameter(torch.FloatTensor([1]), requires_grad=True) + + channel = 32 + in_dim = 2 + self.cnn_encoder = nn.Sequential( + nn.Conv2d(in_dim, channel // 2, 3, 1, 1, bias=True, 
groups=1), + nn.Conv2d(channel // 2, channel, 3, 1, 1, bias=True, groups=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + LateralBlock(channel), + LateralBlock(channel), + LateralBlock(channel), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d( + channel, channel // 2, 3, 1, 1, padding_mode="reflect", bias=True + ), + ) + channel = 64 + in_dim = 64 + self.res_conv = nn.Sequential( + nn.Conv2d(in_dim, channel // 2, 3, 1, 1, bias=True, groups=1), + nn.Conv2d(channel // 2, channel, 3, 1, 1, bias=True, groups=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + LateralBlock(channel), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d( + channel, channel // 2, 3, 1, 1, padding_mode="reflect", bias=True + ), + ) + + self.hyponet = HypoNet(config.hyponet, add_coord_dim=32) + + def _get_updateblock(self, cdim, scale_factor=None): + return BasicUpdateBlock( + cdim=cdim, + hidden_dim=192, + flow_dim=64, + corr_dim=256, + corr_dim2=192, + fc_dim=188, + scale_factor=scale_factor, + corr_levels=4, + radius=4, + ) + + def cal_bidirection_flow(self, im0, im1, iters=20): + f01, features0, fnet0 = self.flow_estimator( + im0, im1, return_feat=True, iters=20 + ) + f10, features1, fnet1 = self.flow_estimator( + im1, im0, return_feat=True, iters=20 + ) + corr_fn = BidirCorrBlock(self.amt_fproj(fnet0), self.amt_fproj(fnet1), radius=4) + features0 = [ + self.amt_second_last_cproj(features0[0]), + self.amt_last_cproj(features0[1]), + ] + features1 = [ + self.amt_second_last_cproj(features1[0]), + self.amt_last_cproj(features1[1]), + ] + flow01 = f01.unsqueeze(2) + flow10 = f10.unsqueeze(2) + noraml_flows = torch.cat([flow01, -flow10], dim=2) + noraml_flows, flow_scalers = normalize_flow(noraml_flows) + + ori_flows = torch.cat([flow01, flow10], dim=2) + return ( + noraml_flows, + ori_flows, + flow_scalers, + features0, + features1, + corr_fn, + torch.cat([f01.unsqueeze(2), f10.unsqueeze(2)], dim=2), + ) + + def predict_flow(self, f, coord, t, flows): + raft_flow01 = flows[:, :, 0].detach() + raft_flow10 = flows[:, :, 1].detach() + + # calculate splatting metrics + weights1, weights2 = self.cal_splatting_weights(raft_flow01, raft_flow10) + strtype = self.fwarp_type + "-zeroeps" + + # b,c,h,w + pixel_latent_0 = self.cnn_encoder(f[:, :, 0]) + pixel_latent_1 = self.cnn_encoder(f[:, :, 1]) + pixel_latent = [] + + for i, cur_t in enumerate(t): + cur_t = cur_t.reshape(-1, 1, 1, 1) + + tmp_pixel_latent_0 = softsplat( + tenIn=pixel_latent_0, + tenFlow=raft_flow01 * cur_t, + tenMetric=weights1, + strMode=strtype, + ) + tmp_pixel_latent_1 = softsplat( + tenIn=pixel_latent_1, + tenFlow=raft_flow10 * (1 - cur_t), + tenMetric=weights2, + strMode=strtype, + ) + + tmp_pixel_latent = torch.cat( + [tmp_pixel_latent_0, tmp_pixel_latent_1], dim=1 + ) + tmp_pixel_latent = tmp_pixel_latent + self.res_conv( + torch.cat([pixel_latent_0, pixel_latent_1, tmp_pixel_latent], dim=1) + ) + pixel_latent.append(tmp_pixel_latent.permute(0, 2, 3, 1)) + + all_outputs = [] + permute_idx_range = [i for i in range(1, f.ndim - 1)] + for idx, c in enumerate(coord): + assert c[0][0, 0, 0, 0, 0] == t[idx][0].squeeze() + assert isinstance(c, tuple) + + if c[1] is None: + outputs = self.hyponet( + c, modulation_params_dict=None, pixel_latent=pixel_latent[idx] + ).permute(0, -1, *permute_idx_range) + else: + outputs = self.hyponet( + c, modulation_params_dict=None, pixel_latent=pixel_latent[idx] + ) + all_outputs.append(outputs) + + return all_outputs + + def warp_w_mask(self, img0, img1, ft0, ft1, mask, scale=1): + ft0 = 
scale * resize(ft0, scale_factor=scale) + ft1 = scale * resize(ft1, scale_factor=scale) + mask = resize(mask, scale_factor=scale).sigmoid() + img0_warp = warp(img0, ft0) + img1_warp = warp(img1, ft1) + img_warp = mask * img0_warp + (1 - mask) * img1_warp + return img_warp + + def frame_synthesize( + self, img_xs, flow_t, features0, features1, corr_fn, cur_t, full_img=None + ): + """ + flow_t: b,2,h,w + cur_t: b,1,1,1 + """ + batch_size = img_xs.shape[0] # b,c,t,h,w + img0 = 2 * img_xs[:, :, 0] - 1.0 + img1 = 2 * img_xs[:, :, 1] - 1.0 + + ##################### update the predicted flow ##################### + ##initialize coordinates for looking up + lookup_coord = build_coord(img_xs[:, :, 0]).to( + img_xs[:, :, 0].device + ) # H//8,W//8 + + flow_t0_fullsize = flow_t * (-cur_t) + flow_t1_fullsize = flow_t * (1.0 - cur_t) + + inv = 1 / 4 + flow_t0_inr4 = inv * resize(flow_t0_fullsize, inv) + flow_t1_inr4 = inv * resize(flow_t1_fullsize, inv) + + ############################# scale 1/4 ############################# + # i. Initialize feature t at scale 1/4 + flowt0_4, flowt1_4, ft_4_ = self.amt_init_decoder( + features0[-1], + features1[-1], + flow_t0_inr4, + flow_t1_inr4, + img0=img0, + img1=img1, + ) + features0, features1 = features0[:-1], features1[:-1] + + mask_4_, ft_4_ = ft_4_[:, :1], ft_4_[:, 1:] + img_warp_4 = self.warp_w_mask(img0, img1, flowt0_4, flowt1_4, mask_4_, scale=4) + img_warp_4 = (img_warp_4 + 1.0) / 2 + img_warp_4 = torch.clamp(img_warp_4, 0, 1) + + corr_4, flow_4_lr = self._amt_corr_scale_lookup( + corr_fn, lookup_coord, flowt0_4, flowt1_4, cur_t, downsample=2 + ) + + delta_ft_4_, delta_flow_4 = self.amt_update4_low(ft_4_, flow_4_lr, corr_4) + delta_flow0_4, delta_flow1_4 = torch.chunk(delta_flow_4, 2, 1) + flowt0_4 = flowt0_4 + delta_flow0_4 + flowt1_4 = flowt1_4 + delta_flow1_4 + ft_4_ = ft_4_ + delta_ft_4_ + + # iii. 
residue update with lookup corr + corr_4 = resize(corr_4, scale_factor=2.0) + + flow_4 = torch.cat([flowt0_4, flowt1_4], dim=1) + delta_ft_4_, delta_flow_4 = self.amt_update4_high(ft_4_, flow_4, corr_4) + flowt0_4 = flowt0_4 + delta_flow_4[:, :2] + flowt1_4 = flowt1_4 + delta_flow_4[:, 2:4] + ft_4_ = ft_4_ + delta_ft_4_ + + ############################# scale 1/1 ############################# + flowt0_1, flowt1_1, mask, img_res = self.amt_final_decoder( + ft_4_, + features0[0], + features1[0], + flowt0_4, + flowt1_4, + mask=mask_4_, + img0=img0, + img1=img1, + ) + + if full_img is not None: + img0 = 2 * full_img[:, :, 0] - 1.0 + img1 = 2 * full_img[:, :, 1] - 1.0 + inv = img1.shape[2] / flowt0_1.shape[2] + flowt0_1 = inv * resize(flowt0_1, scale_factor=inv) + flowt1_1 = inv * resize(flowt1_1, scale_factor=inv) + flow_t0_fullsize = inv * resize(flow_t0_fullsize, scale_factor=inv) + flow_t1_fullsize = inv * resize(flow_t1_fullsize, scale_factor=inv) + mask = resize(mask, scale_factor=inv) + img_res = resize(img_res, scale_factor=inv) + + imgt_pred = multi_flow_combine( + self.amt_comb_block, img0, img1, flowt0_1, flowt1_1, mask, img_res, None + ) + imgt_pred = torch.clamp(imgt_pred, 0, 1) + + ###################################################################### + + flowt0_1 = flowt0_1.reshape( + batch_size, self.num_flows, 2, img0.shape[-2], img0.shape[-1] + ) + flowt1_1 = flowt1_1.reshape( + batch_size, self.num_flows, 2, img0.shape[-2], img0.shape[-1] + ) + + flowt0_pred = [flowt0_1, flowt0_4] + flowt1_pred = [flowt1_1, flowt1_4] + other_pred = [img_warp_4] + return imgt_pred, flowt0_pred, flowt1_pred, other_pred + + def forward(self, img_xs, coord=None, t=None, iters=None, ds_factor=None): + assert isinstance(t, list) + assert isinstance(coord, list) + assert len(t) == len(coord) + full_size_img = None + if ds_factor is not None: + full_size_img = img_xs.clone() + img_xs = torch.cat( + [ + resize(img_xs[:, :, 0], scale_factor=ds_factor).unsqueeze(2), + resize(img_xs[:, :, 1], scale_factor=ds_factor).unsqueeze(2), + ], + dim=2, + ) + + iters = self.raft_iter if iters is None else iters + ( + normal_flows, + flows, + flow_scalers, + features0, + features1, + corr_fn, + preserved_raft_flows, + ) = self.cal_bidirection_flow( + 255 * img_xs[:, :, 0], 255 * img_xs[:, :, 1], iters=iters + ) + assert coord is not None + + # List of flows + normal_inr_flows = self.predict_flow(normal_flows, coord, t, flows) + + ############ Unnormalize the predicted/reconstructed flow ############ + start_idx = 0 + if coord[0][1] is not None: + # Subsmapled flows for reconstruction supervision in the GIMM module + # In such case, by default, first two coords are subsampled for supervision up-mentioned + # normalized flow_t versus positive t-axis + assert len(coord) > 2 + flow_t = [ + unnormalize_flow(normal_inr_flows[i], flow_scalers).squeeze() + for i in range(2, len(coord)) + ] + start_idx = 2 + else: + flow_t = [ + unnormalize_flow(normal_inr_flows[i], flow_scalers).squeeze() + for i in range(len(coord)) + ] + + imgt_preds, flowt0_preds, flowt1_preds, all_others = [], [], [], [] + + for idx in range(start_idx, len(coord)): + cur_flow_t = flow_t[idx - start_idx] + cur_t = t[idx].reshape(-1, 1, 1, 1) + if cur_flow_t.ndim != 4: + cur_flow_t = cur_flow_t.unsqueeze(0) + assert cur_flow_t.ndim == 4 + + imgt_pred, flowt0_pred, flowt1_pred, others = self.frame_synthesize( + img_xs, + cur_flow_t, + features0, + features1, + corr_fn, + cur_t, + full_img=full_size_img, + ) + + imgt_preds.append(imgt_pred) + 
+            flowt0_preds.append(flowt0_pred)
+            flowt1_preds.append(flowt1_pred)
+            all_others.append(others)
+
+        return {
+            "imgt_pred": imgt_preds,
+            "other_pred": all_others,
+            "flowt0_pred": flowt0_preds,
+            "flowt1_pred": flowt1_preds,
+            "raft_flow": preserved_raft_flows,
+            "ninrflow": normal_inr_flows,
+            "nflow": normal_flows,
+            "flowt": flow_t,
+        }
+
+    def warp_frame(self, frame, flow):
+        return warp(frame, flow)
+
+    def compute_psnr(self, preds, targets, reduction="mean"):
+        assert reduction in ["mean", "sum", "none"]
+        batch_size = preds.shape[0]
+        sample_mses = torch.reshape((preds - targets) ** 2, (batch_size, -1)).mean(
+            dim=-1
+        )
+
+        if reduction == "mean":
+            psnr = (-10 * torch.log10(sample_mses)).mean()
+        elif reduction == "sum":
+            psnr = (-10 * torch.log10(sample_mses)).sum()
+        else:
+            psnr = -10 * torch.log10(sample_mses)
+
+        return psnr
+
+    def sample_coord_input(
+        self,
+        batch_size,
+        s_shape,
+        t_ids,
+        coord_range=None,
+        upsample_ratio=1.0,
+        device=None,
+    ):
+        assert device is not None
+        assert coord_range is None
+        coord_inputs = self.coord_sampler(
+            batch_size, s_shape, t_ids, coord_range, upsample_ratio, device
+        )
+        return coord_inputs
+
+    def cal_splatting_weights(self, raft_flow01, raft_flow10):
+        batch_size = raft_flow01.shape[0]
+        raft_flows = torch.cat([raft_flow01, raft_flow10], dim=0)
+
+        ## flow variance metric
+        square_mean, mean_square = torch.split(
+            F.conv3d(
+                F.pad(
+                    torch.cat([raft_flows**2, raft_flows], 1),
+                    (1, 1, 1, 1),
+                    mode="reflect",
+                ).unsqueeze(1),
+                self.g_filter,
+            ).squeeze(1),
+            2,
+            dim=1,
+        )
+        var = (
+            (square_mean - mean_square**2)
+            .clamp(1e-9, None)
+            .sqrt()
+            .mean(1)
+            .unsqueeze(1)
+        )
+        var01 = var[:batch_size]
+        var10 = var[batch_size:]
+
+        ## flow warp metric
+        f01_warp = -warp(raft_flow10, raft_flow01)
+        f10_warp = -warp(raft_flow01, raft_flow10)
+        err01 = (
+            torch.nn.functional.l1_loss(
+                input=f01_warp, target=raft_flow01, reduction="none"
+            )
+            .mean(1)
+            .unsqueeze(1)
+        )
+        err02 = (
+            torch.nn.functional.l1_loss(
+                input=f10_warp, target=raft_flow10, reduction="none"
+            )
+            .mean(1)
+            .unsqueeze(1)
+        )
+
+        weights1 = 1 / (1 + err01 * self.alpha_fe) + 1 / (1 + var01 * self.alpha_v)
+        weights2 = 1 / (1 + err02 * self.alpha_fe) + 1 / (1 + var10 * self.alpha_v)
+
+        return weights1, weights2
+
+    def _amt_corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1):
+        # convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0
+        # based on linear assumption
+        t0_scale = 1.0 / embt
+        t1_scale = 1.0 / (1.0 - embt)
+        if downsample != 1:
+            inv = 1 / downsample
+            flow0 = inv * resize(flow0, scale_factor=inv)
+            flow1 = inv * resize(flow1, scale_factor=inv)
+
+        corr0, corr1 = corr_fn(coord + flow1 * t1_scale, coord + flow0 * t0_scale)
+        corr = torch.cat([corr0, corr1], dim=1)
+        flow = torch.cat([flow0, flow1], dim=1)
+        return corr, flow
diff --git a/blissful_tuner/gimmvfi/generalizable_INR/modules/__init__.py b/blissful_tuner/gimmvfi/generalizable_INR/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/blissful_tuner/gimmvfi/generalizable_INR/modules/coord_sampler.py b/blissful_tuner/gimmvfi/generalizable_INR/modules/coord_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b5225c2a26ef25205257429f48d5b0046bca83a
--- /dev/null
+++ b/blissful_tuner/gimmvfi/generalizable_INR/modules/coord_sampler.py
@@ -0,0 +1,91 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+ +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# ginr-ipc: https://github.com/kakaobrain/ginr-ipc +# -------------------------------------------------------- + +import torch +import torch.nn as nn + + +class CoordSampler3D(nn.Module): + def __init__(self, coord_range, t_coord_only=False): + super().__init__() + self.coord_range = coord_range + self.t_coord_only = t_coord_only + + def shape2coordinate( + self, + batch_size, + spatial_shape, + t_ids, + coord_range=(-1.0, 1.0), + upsample_ratio=1, + device=None, + ): + coords = [] + assert isinstance(t_ids, list) + _coords = torch.tensor(t_ids, device=device) / 1.0 + coords.append(_coords.to(torch.float32)) + for num_s in spatial_shape: + num_s = int(num_s * upsample_ratio) + _coords = (0.5 + torch.arange(num_s, device=device)) / num_s + _coords = coord_range[0] + (coord_range[1] - coord_range[0]) * _coords + coords.append(_coords) + coords = torch.meshgrid(*coords, indexing="ij") + coords = torch.stack(coords, dim=-1) + ones_like_shape = (1,) * coords.ndim + coords = coords.unsqueeze(0).repeat(batch_size, *ones_like_shape) + return coords # (B,T,H,W,3) + + def batchshape2coordinate( + self, + batch_size, + spatial_shape, + t_ids, + coord_range=(-1.0, 1.0), + upsample_ratio=1, + device=None, + ): + coords = [] + _coords = torch.tensor(1, device=device) + coords.append(_coords.to(torch.float32)) + for num_s in spatial_shape: + num_s = int(num_s * upsample_ratio) + _coords = (0.5 + torch.arange(num_s, device=device)) / num_s + _coords = coord_range[0] + (coord_range[1] - coord_range[0]) * _coords + coords.append(_coords) + coords = torch.meshgrid(*coords, indexing="ij") + coords = torch.stack(coords, dim=-1) + ones_like_shape = (1,) * coords.ndim + # Now coords b,1,h,w,3, coords[...,0]=1. + coords = coords.unsqueeze(0).repeat(batch_size, *ones_like_shape) + # assign per-sample timestep within the batch + coords[..., :1] = coords[..., :1] * t_ids.reshape(-1, 1, 1, 1, 1) + return coords + + def forward( + self, + batch_size, + s_shape, + t_ids, + coord_range=None, + upsample_ratio=1.0, + device=None, + ): + coord_range = self.coord_range if coord_range is None else coord_range + if isinstance(t_ids, list): + coords = self.shape2coordinate( + batch_size, s_shape, t_ids, coord_range, upsample_ratio, device + ) + elif isinstance(t_ids, torch.Tensor): + coords = self.batchshape2coordinate( + batch_size, s_shape, t_ids, coord_range, upsample_ratio, device + ) + if self.t_coord_only: + coords = coords[..., :1] + return coords diff --git a/blissful_tuner/gimmvfi/generalizable_INR/modules/fi_components.py b/blissful_tuner/gimmvfi/generalizable_INR/modules/fi_components.py new file mode 100644 index 0000000000000000000000000000000000000000..f3bfcdd8c57352b1211c107580564d7088096c1b --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/modules/fi_components.py @@ -0,0 +1,340 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
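+#
+# Overview: this file collects the AMT-style synthesis components used by
+# GIMM-VFI, namely the LateralBlock/ResBlock refinement blocks, the
+# BasicUpdateBlock flow updater, the init and multi-flow decoders, and
+# multi_flow_combine, which warps img0/img1 with several candidate flows and
+# fuses them via per-flow masks and residuals into one interpolated frame.
+#
+# Shape sketch (illustrative, assuming num_flows=3): flow0/flow1 enter
+# multi_flow_combine as (B, 6, H, W), mask as (B, 3, H, W), and img_res as
+# (B, 9, H, W) before being reshaped into per-flow candidates.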
+# -------------------------------------------------------- +# References: +# amt: https://github.com/MCG-NKU/AMT +# motif: https://github.com/sichun233746/MoTIF +# -------------------------------------------------------- + +import torch +import torch.nn as nn +from .fi_utils import warp, resize + + +class LateralBlock(nn.Module): + def __init__(self, dim): + super(LateralBlock, self).__init__() + self.layers = nn.Sequential( + nn.Conv2d(dim, dim, 3, 1, 1, bias=True), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(dim, dim, 3, 1, 1, bias=True), + ) + + def forward(self, x): + res = x + x = self.layers(x) + return x + res + + +def convrelu( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + dilation=1, + groups=1, + bias=True, +): + return nn.Sequential( + nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias=bias, + ), + nn.PReLU(out_channels), + ) + + +def multi_flow_combine( + comb_block, img0, img1, flow0, flow1, mask=None, img_res=None, mean=None +): + assert mean is None + b, c, h, w = flow0.shape + num_flows = c // 2 + flow0 = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w) + flow1 = flow1.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w) + + mask = ( + mask.reshape(b, num_flows, 1, h, w).reshape(-1, 1, h, w) + if mask is not None + else None + ) + img_res = ( + img_res.reshape(b, num_flows, 3, h, w).reshape(-1, 3, h, w) + if img_res is not None + else 0 + ) + img0 = torch.stack([img0] * num_flows, 1).reshape(-1, 3, h, w) + img1 = torch.stack([img1] * num_flows, 1).reshape(-1, 3, h, w) + mean = ( + torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1) + if mean is not None + else 0 + ) + + img0_warp = warp(img0, flow0) + img1_warp = warp(img1, flow1) + img_warps = mask * img0_warp + (1 - mask) * img1_warp + mean + img_res + img_warps = img_warps.reshape(b, num_flows, 3, h, w) + + res = comb_block(img_warps.view(b, -1, h, w)) + imgt_pred = img_warps.mean(1) + res + + imgt_pred = (imgt_pred + 1.0) / 2 + + return imgt_pred + + +class ResBlock(nn.Module): + def __init__(self, in_channels, side_channels, bias=True): + super(ResBlock, self).__init__() + self.side_channels = side_channels + self.conv1 = nn.Sequential( + nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias + ), + nn.PReLU(in_channels), + ) + self.conv2 = nn.Sequential( + nn.Conv2d( + side_channels, + side_channels, + kernel_size=3, + stride=1, + padding=1, + bias=bias, + ), + nn.PReLU(side_channels), + ) + self.conv3 = nn.Sequential( + nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias + ), + nn.PReLU(in_channels), + ) + self.conv4 = nn.Sequential( + nn.Conv2d( + side_channels, + side_channels, + kernel_size=3, + stride=1, + padding=1, + bias=bias, + ), + nn.PReLU(side_channels), + ) + self.conv5 = nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias + ) + self.prelu = nn.PReLU(in_channels) + + def forward(self, x): + out = self.conv1(x) + + res_feat = out[:, : -self.side_channels, ...] + side_feat = out[:, -self.side_channels :, :, :] + side_feat = self.conv2(side_feat) + out = self.conv3(torch.cat([res_feat, side_feat], 1)) + + res_feat = out[:, : -self.side_channels, ...] 
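+        # As in the conv2 step above, the last `side_channels` channels are
+        # refined separately (conv4) and then re-joined with the main feature
+        # path before conv5.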
+ side_feat = out[:, -self.side_channels :, :, :] + side_feat = self.conv4(side_feat) + out = self.conv5(torch.cat([res_feat, side_feat], 1)) + + out = self.prelu(x + out) + return out + + +class BasicUpdateBlock(nn.Module): + def __init__( + self, + cdim, + hidden_dim, + flow_dim, + corr_dim, + corr_dim2, + fc_dim, + corr_levels=4, + radius=3, + scale_factor=None, + out_num=1, + ): + super(BasicUpdateBlock, self).__init__() + cor_planes = corr_levels * (2 * radius + 1) ** 2 + + self.scale_factor = scale_factor + self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0) + self.convc2 = nn.Conv2d(corr_dim, corr_dim2, 3, padding=1) + self.convf1 = nn.Conv2d(4, flow_dim * 2, 7, padding=3) + self.convf2 = nn.Conv2d(flow_dim * 2, flow_dim, 3, padding=1) + self.conv = nn.Conv2d(flow_dim + corr_dim2, fc_dim, 3, padding=1) + + self.gru = nn.Sequential( + nn.Conv2d(fc_dim + 4 + cdim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + ) + + self.feat_head = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, cdim, 3, padding=1), + ) + + self.flow_head = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, 4 * out_num, 3, padding=1), + ) + + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + def forward(self, net, flow, corr): + net = ( + resize(net, 1 / self.scale_factor) if self.scale_factor is not None else net + ) + cor = self.lrelu(self.convc1(corr)) + cor = self.lrelu(self.convc2(cor)) + flo = self.lrelu(self.convf1(flow)) + flo = self.lrelu(self.convf2(flo)) + cor_flo = torch.cat([cor, flo], dim=1) + inp = self.lrelu(self.conv(cor_flo)) + inp = torch.cat([inp, flow, net], dim=1) + + out = self.gru(inp) + delta_net = self.feat_head(out) + delta_flow = self.flow_head(out) + + if self.scale_factor is not None: + delta_net = resize(delta_net, scale_factor=self.scale_factor) + delta_flow = self.scale_factor * resize( + delta_flow, scale_factor=self.scale_factor + ) + return delta_net, delta_flow + + +def get_bn(): + return nn.BatchNorm2d + + +class NewInitDecoder(nn.Module): + def __init__(self, in_ch, skip_ch): + super().__init__() + norm_layer = get_bn() + + self.upsample = nn.Sequential( + nn.PixelShuffle(2), + convrelu(in_ch // 4, in_ch // 4, 5, 1, 2), + convrelu(in_ch // 4, in_ch // 4), + convrelu(in_ch // 4, in_ch // 4), + convrelu(in_ch // 4, in_ch // 4), + convrelu(in_ch // 4, in_ch // 2), + nn.Conv2d(in_ch // 2, in_ch // 2, kernel_size=1), + norm_layer(in_ch // 2), + nn.ReLU(inplace=True), + ) + + in_ch = in_ch // 2 + self.convblock = nn.Sequential( + convrelu(in_ch * 2 + 16, in_ch, kernel_size=1, padding=0), + ResBlock(in_ch, skip_ch), + ResBlock(in_ch, skip_ch), + ResBlock(in_ch, skip_ch), + nn.Conv2d(in_ch, in_ch + 5, 3, 1, 1, 1, 1, True), + ) + + def forward(self, f0, f1, flow0_in, flow1_in, img0=None, img1=None): + f0 = self.upsample(f0) + f1 = self.upsample(f1) + f0_warp_ks = warp(f0, flow0_in) + f1_warp_ks = warp(f1, flow1_in) + + f_in = torch.cat([f0_warp_ks, f1_warp_ks, flow0_in, flow1_in], dim=1) + + assert img0 is not None + assert img1 is not None + scale_factor = f_in.shape[2] / img0.shape[2] + img0 = resize(img0, scale_factor=scale_factor) + img1 = resize(img1, scale_factor=scale_factor) + warped_img0 = warp(img0, flow0_in) + warped_img1 = warp(img1, flow1_in) + f_in = torch.cat([f_in, img0, img1, 
warped_img0, warped_img1], dim=1) + + out = self.convblock(f_in) + ft_ = out[:, 4:, ...] + flow0 = flow0_in + out[:, :2, ...] + flow1 = flow1_in + out[:, 2:4, ...] + return flow0, flow1, ft_ + + +class NewMultiFlowDecoder(nn.Module): + def __init__(self, in_ch, skip_ch, num_flows=3): + super(NewMultiFlowDecoder, self).__init__() + norm_layer = get_bn() + + self.upsample = nn.Sequential( + nn.PixelShuffle(2), + nn.PixelShuffle(2), + convrelu(in_ch // (4 * 4), in_ch // 4, 5, 1, 2), + convrelu(in_ch // 4, in_ch // 4), + convrelu(in_ch // 4, in_ch // 4), + convrelu(in_ch // 4, in_ch // 4), + convrelu(in_ch // 4, in_ch // 2), + nn.Conv2d(in_ch // 2, in_ch // 2, kernel_size=1), + norm_layer(in_ch // 2), + nn.ReLU(inplace=True), + ) + + self.num_flows = num_flows + ch_factor = 2 + self.convblock = nn.Sequential( + convrelu(in_ch * ch_factor + 17, in_ch * ch_factor), + ResBlock(in_ch * ch_factor, skip_ch), + ResBlock(in_ch * ch_factor, skip_ch), + ResBlock(in_ch * ch_factor, skip_ch), + nn.Conv2d(in_ch * ch_factor, 8 * num_flows, kernel_size=3, padding=1), + ) + + def forward(self, ft_, f0, f1, flow0, flow1, mask=None, img0=None, img1=None): + f0 = self.upsample(f0) + # print([f1.shape,f0.shape]) + f1 = self.upsample(f1) + n = self.num_flows + flow0 = 4.0 * resize(flow0, scale_factor=4.0) + flow1 = 4.0 * resize(flow1, scale_factor=4.0) + + ft_ = resize(ft_, scale_factor=4.0) + mask = resize(mask, scale_factor=4.0) + f0_warp = warp(f0, flow0) + f1_warp = warp(f1, flow1) + + f_in = torch.cat([ft_, f0_warp, f1_warp, flow0, flow1], 1) + + assert mask is not None + f_in = torch.cat([f_in, mask], 1) + + assert img0 is not None + assert img1 is not None + warped_img0 = warp(img0, flow0) + warped_img1 = warp(img1, flow1) + f_in = torch.cat([f_in, img0, img1, warped_img0, warped_img1], dim=1) + + out = self.convblock(f_in) + delta_flow0, delta_flow1, delta_mask, img_res = torch.split( + out, [2 * n, 2 * n, n, 3 * n], 1 + ) + mask = delta_mask + mask.repeat(1, self.num_flows, 1, 1) + mask = torch.sigmoid(mask) + flow0 = delta_flow0 + flow0.repeat(1, self.num_flows, 1, 1) + flow1 = delta_flow1 + flow1.repeat(1, self.num_flows, 1, 1) + + return flow0, flow1, mask, img_res diff --git a/blissful_tuner/gimmvfi/generalizable_INR/modules/fi_utils.py b/blissful_tuner/gimmvfi/generalizable_INR/modules/fi_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f7bd12d5c68bc3d9b29726b02432349707a90d92 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/modules/fi_utils.py @@ -0,0 +1,82 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
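+#
+# Overview: small tensor utilities shared by the interpolation model, namely
+# backward warping via grid_sample, bilinear resizing, flow (un)normalization,
+# and 1/8-resolution coordinate grids for correlation lookup.
+#
+# Usage sketch (illustrative only; tensors below are placeholders):
+#   frame1 = torch.rand(1, 3, 256, 448)   # (B, C, H, W)
+#   flow01 = torch.zeros(1, 2, 256, 448)  # per-pixel flow in pixel units
+#   warped = warp(frame1, flow01)         # backward-warp frame1 along flow01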
+# -------------------------------------------------------- +# References: +# raft: https://github.com/princeton-vl/RAFT +# ema-vfi: https://github.com/MCG-NJU/EMA-VFI +# -------------------------------------------------------- + +import torch +import torch.nn.functional as F + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +backwarp_tenGrid = {} + + +def warp(tenInput, tenFlow): + k = (str(tenFlow.device), str(tenFlow.size())) + if k not in backwarp_tenGrid: + tenHorizontal = ( + torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device) + .view(1, 1, 1, tenFlow.shape[3]) + .expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1) + ) + tenVertical = ( + torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device) + .view(1, 1, tenFlow.shape[2], 1) + .expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3]) + ) + backwarp_tenGrid[k] = torch.cat([tenHorizontal, tenVertical], 1).to(device) + + tenFlow = torch.cat( + [ + tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), + tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0), + ], + 1, + ) + + g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1) + return torch.nn.functional.grid_sample( + input=tenInput, + grid=g, + mode="bilinear", + padding_mode="border", + align_corners=True, + ) + + +def normalize_flow(flows): + # FIXME: MULTI-DIMENSION + flow_scaler = torch.max(torch.abs(flows).flatten(1), dim=-1)[0].reshape( + -1, 1, 1, 1, 1 + ) + flows = flows / flow_scaler # [-1,1] + # # Adapt to [0,1] + flows = (flows + 1.0) / 2.0 + return flows, flow_scaler + + +def unnormalize_flow(flows, flow_scaler): + return (flows * 2.0 - 1.0) * flow_scaler + + +def resize(x, scale_factor): + return F.interpolate( + x, scale_factor=scale_factor, mode="bilinear", align_corners=False + ) + + +def coords_grid(batch, ht, wd): + coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) + + +def build_coord(img): + N, C, H, W = img.shape + coords = coords_grid(N, H // 8, W // 8) + return coords diff --git a/blissful_tuner/gimmvfi/generalizable_INR/modules/hyponet.py b/blissful_tuner/gimmvfi/generalizable_INR/modules/hyponet.py new file mode 100644 index 0000000000000000000000000000000000000000..75310277775ce877452f0f5e98d1ddce147abe63 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/modules/hyponet.py @@ -0,0 +1,198 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# ginr-ipc: https://github.com/kakaobrain/ginr-ipc +# -------------------------------------------------------- + +import einops + +import torch +import torch.nn as nn +import torch.nn.functional as F +from omegaconf import OmegaConf + +from ..configs import HypoNetConfig +from .utils import create_params_with_init, create_activation + + +class HypoNet(nn.Module): + r""" + The Hyponetwork with a coordinate-based MLP to be modulated. 
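+
+    Each linear layer keeps a base weight (plus an optional bias row). During
+    the forward pass the weight part is multiplied elementwise by a per-sample
+    modulation tensor taken from `modulation_params_dict` and, if
+    `normalize_weight` is set, L2-normalized along the input dimension, so a
+    single shared MLP can be specialized per input instance. Input coordinates
+    are concatenated with a bilinearly resized pixel latent before the first
+    layer.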
+ """ + + def __init__(self, config: HypoNetConfig, add_coord_dim=32): + super().__init__() + self.config = config + self.use_bias = config.use_bias + self.init_config = config.initialization + self.num_layer = config.n_layer + self.hidden_dims = config.hidden_dim + self.add_coord_dim = add_coord_dim + + if len(self.hidden_dims) == 1: + self.hidden_dims = OmegaConf.to_object(self.hidden_dims) * ( + self.num_layer - 1 + ) # exclude output layer + else: + assert len(self.hidden_dims) == self.num_layer - 1 + + if self.config.activation.type == "siren": + assert self.init_config.weight_init_type == "siren" + assert self.init_config.bias_init_type == "siren" + + # after computes the shape of trainable parameters, initialize them + self.params_dict = None + self.params_shape_dict = self.compute_params_shape() + self.activation = create_activation(self.config.activation) + self.build_base_params_dict(self.config.initialization) + self.output_bias = config.output_bias + + self.normalize_weight = config.normalize_weight + + self.ignore_base_param_dict = {name: False for name in self.params_dict} + + @staticmethod + def subsample_coords(coords, subcoord_idx=None): + if subcoord_idx is None: + return coords + + batch_size = coords.shape[0] + sub_coords = [] + coords = coords.view(batch_size, -1, coords.shape[-1]) + for idx in range(batch_size): + sub_coords.append(coords[idx : idx + 1, subcoord_idx[idx]]) + sub_coords = torch.cat(sub_coords, dim=0) + return sub_coords + + def forward(self, coord, modulation_params_dict=None, pixel_latent=None): + sub_idx = None + if isinstance(coord, tuple): + coord, sub_idx = coord[0], coord[1] + + if modulation_params_dict is not None: + self.check_valid_param_keys(modulation_params_dict) + + batch_size, coord_shape, input_dim = ( + coord.shape[0], + coord.shape[1:-1], + coord.shape[-1], + ) + coord = coord.view(batch_size, -1, input_dim) # flatten the coordinates + assert pixel_latent is not None + pixel_latent = F.interpolate( + pixel_latent.permute(0, 3, 1, 2), + size=(coord_shape[1], coord_shape[2]), + mode="bilinear", + ).permute(0, 2, 3, 1) + pixel_latent_dim = pixel_latent.shape[-1] + pixel_latent = pixel_latent.view(batch_size, -1, pixel_latent_dim) + hidden = coord + + hidden = torch.cat([pixel_latent, hidden], dim=-1) + + hidden = self.subsample_coords(hidden, sub_idx) + + for idx in range(self.config.n_layer): + param_key = f"linear_wb{idx}" + base_param = einops.repeat( + self.params_dict[param_key], "n m -> b n m", b=batch_size + ) + + if (modulation_params_dict is not None) and ( + param_key in modulation_params_dict.keys() + ): + modulation_param = modulation_params_dict[param_key] + else: + if self.config.use_bias: + modulation_param = torch.ones_like(base_param[:, :-1]) + else: + modulation_param = torch.ones_like(base_param) + + if self.config.use_bias: + ones = torch.ones(*hidden.shape[:-1], 1, device=hidden.device) + hidden = torch.cat([hidden, ones], dim=-1) + + base_param_w, base_param_b = ( + base_param[:, :-1, :], + base_param[:, -1:, :], + ) + + if self.ignore_base_param_dict[param_key]: + base_param_w = 1.0 + param_w = base_param_w * modulation_param + if self.normalize_weight: + param_w = F.normalize(param_w, dim=1) + modulated_param = torch.cat([param_w, base_param_b], dim=1) + else: + if self.ignore_base_param_dict[param_key]: + base_param = 1.0 + if self.normalize_weight: + modulated_param = F.normalize(base_param * modulation_param, dim=1) + else: + modulated_param = base_param * modulation_param + # 
print([param_key,hidden.shape,modulated_param.shape]) + hidden = torch.bmm(hidden, modulated_param) + + if idx < (self.config.n_layer - 1): + hidden = self.activation(hidden) + + outputs = hidden + self.output_bias + if sub_idx is None: + outputs = outputs.view(batch_size, *coord_shape, -1) + return outputs + + def compute_params_shape(self): + """ + Computes the shape of MLP parameters. + The computed shapes are used to build the initial weights by `build_base_params_dict`. + """ + config = self.config + use_bias = self.use_bias + + param_shape_dict = dict() + + fan_in = config.input_dim + add_dim = self.add_coord_dim + fan_in = fan_in + add_dim + fan_in = fan_in + 1 if use_bias else fan_in + + for i in range(config.n_layer - 1): + fan_out = self.hidden_dims[i] + param_shape_dict[f"linear_wb{i}"] = (fan_in, fan_out) + fan_in = fan_out + 1 if use_bias else fan_out + + param_shape_dict[f"linear_wb{config.n_layer-1}"] = (fan_in, config.output_dim) + return param_shape_dict + + def build_base_params_dict(self, init_config): + assert self.params_shape_dict + params_dict = nn.ParameterDict() + for idx, (name, shape) in enumerate(self.params_shape_dict.items()): + is_first = idx == 0 + params = create_params_with_init( + shape, + init_type=init_config.weight_init_type, + include_bias=self.use_bias, + bias_init_type=init_config.bias_init_type, + is_first=is_first, + siren_w0=self.config.activation.siren_w0, # valid only for siren + ) + params = nn.Parameter(params) + params_dict[name] = params + self.set_params_dict(params_dict) + + def check_valid_param_keys(self, params_dict): + predefined_params_keys = self.params_shape_dict.keys() + for param_key in params_dict.keys(): + if param_key in predefined_params_keys: + continue + else: + raise KeyError + + def set_params_dict(self, params_dict): + self.check_valid_param_keys(params_dict) + self.params_dict = params_dict diff --git a/blissful_tuner/gimmvfi/generalizable_INR/modules/layers.py b/blissful_tuner/gimmvfi/generalizable_INR/modules/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..515de3819060e2d06953f42d3e27f23afcb40655 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/modules/layers.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- + +from torch import nn +import torch + + +# define siren layer & Siren model +class Sine(nn.Module): + """Sine activation with scaling. + + Args: + w0 (float): Omega_0 parameter from SIREN paper. + """ + + def __init__(self, w0=1.0): + super().__init__() + self.w0 = w0 + + def forward(self, x): + return torch.sin(self.w0 * x) + + +# Damping activation from http://arxiv.org/abs/2306.15242 +class Damping(nn.Module): + """Sine activation with sublinear factor + + Args: + w0 (float): Omega_0 parameter from SIREN paper. 
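+
+    The forward pass computes sin(w0 * x) * sqrt(|x|) after clamping x to a
+    minimum of 1e-30.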
+ """ + + def __init__(self, w0=1.0): + super().__init__() + self.w0 = w0 + + def forward(self, x): + x = torch.clamp(x, min=1e-30) + return torch.sin(self.w0 * x) * torch.sqrt(x.abs()) diff --git a/blissful_tuner/gimmvfi/generalizable_INR/modules/module_config.py b/blissful_tuner/gimmvfi/generalizable_INR/modules/module_config.py new file mode 100644 index 0000000000000000000000000000000000000000..a117db2160f30ac165d9d5223fbad70d66a3fb98 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/modules/module_config.py @@ -0,0 +1,52 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# ginr-ipc: https://github.com/kakaobrain/ginr-ipc +# -------------------------------------------------------- + +from typing import List, Optional +from dataclasses import dataclass, field +from omegaconf import MISSING + + +@dataclass +class HypoNetActivationConfig: + type: str = "relu" + siren_w0: Optional[float] = 30.0 + + +@dataclass +class HypoNetInitConfig: + weight_init_type: Optional[str] = "kaiming_uniform" + bias_init_type: Optional[str] = "zero" + + +@dataclass +class HypoNetConfig: + type: str = "mlp" + n_layer: int = 5 + hidden_dim: List[int] = MISSING + use_bias: bool = True + input_dim: int = 2 + output_dim: int = 3 + output_bias: float = 0.5 + activation: HypoNetActivationConfig = field(default_factory=HypoNetActivationConfig) + initialization: HypoNetInitConfig = field(default_factory=HypoNetInitConfig) + + normalize_weight: bool = True + linear_interpo: bool = False + + +@dataclass +class CoordSamplerConfig: + data_type: str = "image" + t_coord_only: bool = False + coord_range: List[float] = MISSING + time_range: List[float] = MISSING + train_strategy: Optional[str] = MISSING + val_strategy: Optional[str] = MISSING + patch_size: Optional[int] = MISSING diff --git a/blissful_tuner/gimmvfi/generalizable_INR/modules/softsplat.py b/blissful_tuner/gimmvfi/generalizable_INR/modules/softsplat.py new file mode 100644 index 0000000000000000000000000000000000000000..369b99b0a3363ca16b4a9fe36b5db4344c50326a --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/modules/softsplat.py @@ -0,0 +1,669 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
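+#
+# Overview: forward ("softmax") splatting implemented with cupy CUDA kernels.
+# softsplat() scatters tenIn along tenFlow with bilinear weights; the "avg",
+# "linear" and "softmax" modes append a normalization channel (ones, tenMetric,
+# or exp(tenMetric) respectively) that is divided out afterwards. CUDA tensors
+# are required.
+#
+# Usage sketch (illustrative only; tensors and the metric scaling below are
+# placeholders, not values prescribed by this repository):
+#   out = softsplat(tenIn=img0, tenFlow=flow0t,
+#                   tenMetric=(-20.0 * err).clip(-20.0, 20.0), strMode="softmax")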
+# -------------------------------------------------------- +# References: +# softmax-splatting: https://github.com/sniklaus/softmax-splatting +# -------------------------------------------------------- + +import collections +import cupy +import os +import re +import torch +import typing + + +########################################################## + + +objCudacache = {} + + +def cuda_int32(intIn: int): + return cupy.int32(intIn) + + +# end + + +def cuda_float32(fltIn: float): + return cupy.float32(fltIn) + + +# end + + +def cuda_kernel(strFunction: str, strKernel: str, objVariables: typing.Dict): + if "device" not in objCudacache: + objCudacache["device"] = torch.cuda.get_device_name() + # end + + strKey = strFunction + + for strVariable in objVariables: + objValue = objVariables[strVariable] + + strKey += strVariable + + if objValue is None: + continue + + elif type(objValue) == int: + strKey += str(objValue) + + elif type(objValue) == float: + strKey += str(objValue) + + elif type(objValue) == bool: + strKey += str(objValue) + + elif type(objValue) == str: + strKey += objValue + + elif type(objValue) == torch.Tensor: + strKey += str(objValue.dtype) + strKey += str(objValue.shape) + strKey += str(objValue.stride()) + + elif True: + print(strVariable, type(objValue)) + assert False + + # end + # end + + strKey += objCudacache["device"] + + if strKey not in objCudacache: + for strVariable in objVariables: + objValue = objVariables[strVariable] + + if objValue is None: + continue + + elif type(objValue) == int: + strKernel = strKernel.replace("{{" + strVariable + "}}", str(objValue)) + + elif type(objValue) == float: + strKernel = strKernel.replace("{{" + strVariable + "}}", str(objValue)) + + elif type(objValue) == bool: + strKernel = strKernel.replace("{{" + strVariable + "}}", str(objValue)) + + elif type(objValue) == str: + strKernel = strKernel.replace("{{" + strVariable + "}}", objValue) + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.uint8: + strKernel = strKernel.replace("{{type}}", "unsigned char") + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float16: + strKernel = strKernel.replace("{{type}}", "half") + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float32: + strKernel = strKernel.replace("{{type}}", "float") + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float64: + strKernel = strKernel.replace("{{type}}", "double") + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.int32: + strKernel = strKernel.replace("{{type}}", "int") + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.int64: + strKernel = strKernel.replace("{{type}}", "long") + + elif type(objValue) == torch.Tensor: + print(strVariable, objValue.dtype) + assert False + + elif True: + print(strVariable, type(objValue)) + assert False + + # end + # end + + while True: + objMatch = re.search(r"(SIZE_)([0-4])(\()([^\)]*)(\))", strKernel) + + if objMatch is None: + break + # end + + intArg = int(objMatch.group(2)) + + strTensor = objMatch.group(4) + intSizes = objVariables[strTensor].size() + + strKernel = strKernel.replace( + objMatch.group(), + str( + intSizes[intArg] + if torch.is_tensor(intSizes[intArg]) == False + else intSizes[intArg].item() + ), + ) + # end + + while True: + objMatch = re.search(r"(OFFSET_)([0-4])(\()", strKernel) + + if objMatch is None: + break + # end + + intStart = objMatch.span()[1] + intStop = objMatch.span()[1] + intParentheses = 1 + + while True: + intParentheses += 1 if 
strKernel[intStop] == "(" else 0 + intParentheses -= 1 if strKernel[intStop] == ")" else 0 + + if intParentheses == 0: + break + # end + + intStop += 1 + # end + + intArgs = int(objMatch.group(2)) + strArgs = strKernel[intStart:intStop].split(",") + + assert intArgs == len(strArgs) - 1 + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + + strIndex = [] + + for intArg in range(intArgs): + strIndex.append( + "((" + + strArgs[intArg + 1].replace("{", "(").replace("}", ")").strip() + + ")*" + + str( + intStrides[intArg] + if torch.is_tensor(intStrides[intArg]) == False + else intStrides[intArg].item() + ) + + ")" + ) + # end + + strKernel = strKernel.replace( + "OFFSET_" + str(intArgs) + "(" + strKernel[intStart:intStop] + ")", + "(" + str.join("+", strIndex) + ")", + ) + # end + + while True: + objMatch = re.search(r"(VALUE_)([0-4])(\()", strKernel) + + if objMatch is None: + break + # end + + intStart = objMatch.span()[1] + intStop = objMatch.span()[1] + intParentheses = 1 + + while True: + intParentheses += 1 if strKernel[intStop] == "(" else 0 + intParentheses -= 1 if strKernel[intStop] == ")" else 0 + + if intParentheses == 0: + break + # end + + intStop += 1 + # end + + intArgs = int(objMatch.group(2)) + strArgs = strKernel[intStart:intStop].split(",") + + assert intArgs == len(strArgs) - 1 + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + + strIndex = [] + + for intArg in range(intArgs): + strIndex.append( + "((" + + strArgs[intArg + 1].replace("{", "(").replace("}", ")").strip() + + ")*" + + str( + intStrides[intArg] + if torch.is_tensor(intStrides[intArg]) == False + else intStrides[intArg].item() + ) + + ")" + ) + # end + + strKernel = strKernel.replace( + "VALUE_" + str(intArgs) + "(" + strKernel[intStart:intStop] + ")", + strTensor + "[" + str.join("+", strIndex) + "]", + ) + # end + + objCudacache[strKey] = {"strFunction": strFunction, "strKernel": strKernel} + # end + + return strKey + + +# end + + +@cupy.memoize(for_each_device=True) +def cuda_launch(strKey: str): + try: + os.environ.setdefault("CUDA_HOME", cupy.cuda.get_cuda_path()) + except Exception: + if "CUDA_HOME" not in os.environ: + raise RuntimeError("'CUDA_HOME' not set, unable to find cuda-toolkit installation.") + + strKernel = objCudacache[strKey]["strKernel"] + strFunction = objCudacache[strKey]["strFunction"] + + return cupy.RawModule( + code=strKernel, + options=( + "-I " + os.environ["CUDA_HOME"], + "-I " + os.environ["CUDA_HOME"] + "/include", + ), + ).get_function(strFunction) + + + +########################################################## + + +def softsplat(tenIn, tenFlow, tenMetric, strMode, return_norm=False): + assert strMode.split("-")[0] in ["sum", "avg", "linear", "softmax"] + + if strMode == "sum": + assert tenMetric is None + if strMode == "avg": + assert tenMetric is None + if strMode.split("-")[0] == "linear": + assert tenMetric is not None + if strMode.split("-")[0] == "softmax": + assert tenMetric is not None + + if strMode == "avg": + tenIn = torch.cat( + [ + tenIn, + tenIn.new_ones([tenIn.shape[0], 1, tenIn.shape[2], tenIn.shape[3]]), + ], + 1, + ) + + elif strMode.split("-")[0] == "linear": + tenIn = torch.cat([tenIn * tenMetric, tenMetric], 1) + + elif strMode.split("-")[0] == "softmax": + tenIn = torch.cat([tenIn * tenMetric.exp(), tenMetric.exp()], 1) + + # end + if torch.isnan(tenIn).any(): + print("NaN values detected during training in tenIn. 
Exiting.") + assert False + + tenOut = softsplat_func.apply(tenIn, tenFlow) + + if torch.isnan(tenOut).any(): + print("NaN values detected during training in tenOut_1. Exiting.") + assert False + + if strMode.split("-")[0] in ["avg", "linear", "softmax"]: + tenNormalize = tenOut[:, -1:, :, :] + + if len(strMode.split("-")) == 1: + tenNormalize = tenNormalize + 0.0000001 + + elif strMode.split("-")[1] == "addeps": + tenNormalize = tenNormalize + 0.0000001 + + elif strMode.split("-")[1] == "zeroeps": + tenNormalize[tenNormalize == 0.0] = 1.0 + + elif strMode.split("-")[1] == "clipeps": + tenNormalize = tenNormalize.clip(0.0000001, None) + + # end + + if return_norm: + return tenOut[:, :-1, :, :], tenNormalize + + tenOut = tenOut[:, :-1, :, :] / tenNormalize + + if torch.isnan(tenOut).any(): + print("NaN values detected during training in tenOut_2. Exiting.") + assert False + + # end + + return tenOut + + +# end + + +class softsplat_func(torch.autograd.Function): + @staticmethod + @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32) + def forward(self, tenIn, tenFlow): + tenOut = tenIn.new_zeros( + [tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]] + ) + + if tenIn.is_cuda == True: + cuda_launch( + cuda_kernel( + "softsplat_out", + """ + extern "C" __global__ void __launch_bounds__(512) softsplat_out( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenFlow, + {{type}}* __restrict__ tenOut + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) / SIZE_1(tenOut) ) % SIZE_0(tenOut); + const int intC = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_1(tenOut); + const int intY = ( intIndex / SIZE_3(tenOut) ) % SIZE_2(tenOut); + const int intX = ( intIndex ) % SIZE_3(tenOut); + + assert(SIZE_1(tenFlow) == 2); + + {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); + {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); + + if (isfinite(fltX) == false) { return; } + if (isfinite(fltY) == false) { return; } + + {{type}} fltIn = VALUE_4(tenIn, intN, intC, intY, intX); + + int intNorthwestX = (int) (floor(fltX)); + int intNorthwestY = (int) (floor(fltY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY); + {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY); + {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY)); + {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY)); + + if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOut)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNorthwestY, intNorthwestX)], fltIn * fltNorthwest); + } + + if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOut)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNortheastY, intNortheastX)], fltIn * fltNortheast); + } + + if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOut)) && (intSouthwestY >= 0) && (intSouthwestY < 
SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSouthwestY, intSouthwestX)], fltIn * fltSouthwest); + } + + if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOut)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSoutheastY, intSoutheastX)], fltIn * fltSoutheast); + } + } } + """, + {"tenIn": tenIn, "tenFlow": tenFlow, "tenOut": tenOut}, + ) + )( + grid=tuple([int((tenOut.nelement() + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[ + cuda_int32(tenOut.nelement()), + tenIn.data_ptr(), + tenFlow.data_ptr(), + tenOut.data_ptr(), + ], + stream=collections.namedtuple("Stream", "ptr")( + torch.cuda.current_stream().cuda_stream + ), + ) + + elif tenIn.is_cuda != True: + assert False + + # end + + self.save_for_backward(tenIn, tenFlow) + + return tenOut + + # end + + @staticmethod + @torch.amp.custom_bwd(device_type="cuda") + def backward(self, tenOutgrad): + tenIn, tenFlow = self.saved_tensors + + tenOutgrad = tenOutgrad.contiguous() + assert tenOutgrad.is_cuda == True + + tenIngrad = ( + tenIn.new_zeros( + [tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]] + ) + if self.needs_input_grad[0] == True + else None + ) + tenFlowgrad = ( + tenFlow.new_zeros( + [tenFlow.shape[0], tenFlow.shape[1], tenFlow.shape[2], tenFlow.shape[3]] + ) + if self.needs_input_grad[1] == True + else None + ) + + if tenIngrad is not None: + cuda_launch( + cuda_kernel( + "softsplat_ingrad", + """ + extern "C" __global__ void __launch_bounds__(512) softsplat_ingrad( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenFlow, + const {{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenIngrad, + {{type}}* __restrict__ tenFlowgrad + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) / SIZE_1(tenIngrad) ) % SIZE_0(tenIngrad); + const int intC = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) ) % SIZE_1(tenIngrad); + const int intY = ( intIndex / SIZE_3(tenIngrad) ) % SIZE_2(tenIngrad); + const int intX = ( intIndex ) % SIZE_3(tenIngrad); + + assert(SIZE_1(tenFlow) == 2); + + {{type}} fltIngrad = 0.0f; + + {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); + {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); + + if (isfinite(fltX) == false) { return; } + if (isfinite(fltY) == false) { return; } + + int intNorthwestX = (int) (floor(fltX)); + int intNorthwestY = (int) (floor(fltY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY); + {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY); + {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY)); + {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY)); + + if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest; + } + + if ((intNortheastX >= 0) && 
(intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNortheastY, intNortheastX) * fltNortheast; + } + + if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest; + } + + if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast; + } + + tenIngrad[intIndex] = fltIngrad; + } } + """, + { + "tenIn": tenIn, + "tenFlow": tenFlow, + "tenOutgrad": tenOutgrad, + "tenIngrad": tenIngrad, + "tenFlowgrad": tenFlowgrad, + }, + ) + )( + grid=tuple([int((tenIngrad.nelement() + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[ + cuda_int32(tenIngrad.nelement()), + tenIn.data_ptr(), + tenFlow.data_ptr(), + tenOutgrad.data_ptr(), + tenIngrad.data_ptr(), + None, + ], + stream=collections.namedtuple("Stream", "ptr")( + torch.cuda.current_stream().cuda_stream + ), + ) + # end + + if tenFlowgrad is not None: + cuda_launch( + cuda_kernel( + "softsplat_flowgrad", + """ + extern "C" __global__ void __launch_bounds__(512) softsplat_flowgrad( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenFlow, + const {{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenIngrad, + {{type}}* __restrict__ tenFlowgrad + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) / SIZE_1(tenFlowgrad) ) % SIZE_0(tenFlowgrad); + const int intC = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) ) % SIZE_1(tenFlowgrad); + const int intY = ( intIndex / SIZE_3(tenFlowgrad) ) % SIZE_2(tenFlowgrad); + const int intX = ( intIndex ) % SIZE_3(tenFlowgrad); + + assert(SIZE_1(tenFlow) == 2); + + {{type}} fltFlowgrad = 0.0f; + + {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); + {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); + + if (isfinite(fltX) == false) { return; } + if (isfinite(fltY) == false) { return; } + + int intNorthwestX = (int) (floor(fltX)); + int intNorthwestY = (int) (floor(fltY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + {{type}} fltNorthwest = 0.0f; + {{type}} fltNortheast = 0.0f; + {{type}} fltSouthwest = 0.0f; + {{type}} fltSoutheast = 0.0f; + + if (intC == 0) { + fltNorthwest = (({{type}}) (-1.0f)) * (({{type}}) (intSoutheastY) - fltY); + fltNortheast = (({{type}}) (+1.0f)) * (({{type}}) (intSouthwestY) - fltY); + fltSouthwest = (({{type}}) (-1.0f)) * (fltY - ({{type}}) (intNortheastY)); + fltSoutheast = (({{type}}) (+1.0f)) * (fltY - ({{type}}) (intNorthwestY)); + + } else if (intC == 1) { + fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (-1.0f)); + fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (-1.0f)); + fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (({{type}}) (+1.0f)); + fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (({{type}}) (+1.0f)); + + } + + for (int intChannel = 0; intChannel < 
SIZE_1(tenOutgrad); intChannel += 1) { + {{type}} fltIn = VALUE_4(tenIn, intN, intChannel, intY, intX); + + if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNorthwestY, intNorthwestX) * fltIn * fltNorthwest; + } + + if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNortheastY, intNortheastX) * fltIn * fltNortheast; + } + + if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSouthwestY, intSouthwestX) * fltIn * fltSouthwest; + } + + if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSoutheastY, intSoutheastX) * fltIn * fltSoutheast; + } + } + + tenFlowgrad[intIndex] = fltFlowgrad; + } } + """, + { + "tenIn": tenIn, + "tenFlow": tenFlow, + "tenOutgrad": tenOutgrad, + "tenIngrad": tenIngrad, + "tenFlowgrad": tenFlowgrad, + }, + ) + )( + grid=tuple([int((tenFlowgrad.nelement() + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[ + cuda_int32(tenFlowgrad.nelement()), + tenIn.data_ptr(), + tenFlow.data_ptr(), + tenOutgrad.data_ptr(), + None, + tenFlowgrad.data_ptr(), + ], + stream=collections.namedtuple("Stream", "ptr")( + torch.cuda.current_stream().cuda_stream + ), + ) + # end + + return tenIngrad, tenFlowgrad + + # end + + +# end diff --git a/blissful_tuner/gimmvfi/generalizable_INR/modules/utils.py b/blissful_tuner/gimmvfi/generalizable_INR/modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7f6548069b2b748c8863fda6fd39725cb6fe7fdf --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/modules/utils.py @@ -0,0 +1,76 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
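+#
+# Overview: weight-initialization helpers for the hyponetwork (normal,
+# kaiming_uniform, uniform_fan_in, zero and siren schemes, with a bias row
+# appended when include_bias is set) plus create_activation, which maps an
+# activation config to ReLU, SiLU, Sine or Damping.
+#
+# Illustrative call (placeholder values, not taken from this repository):
+#   w = create_params_with_init((fan_in + 1, fan_out), init_type="siren",
+#                               include_bias=True, bias_init_type="zero",
+#                               siren_w0=30.0, is_first=True)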
+# -------------------------------------------------------- +# References: +# ginr-ipc: https://github.com/kakaobrain/ginr-ipc +# -------------------------------------------------------- + +import math +import torch +import torch.nn as nn + +from .layers import Sine, Damping + + +def convert_int_to_list(size, len_list=2): + if isinstance(size, int): + return [size] * len_list + else: + assert len(size) == len_list + return size + + +def initialize_params(params, init_type, **kwargs): + fan_in, fan_out = params.shape[0], params.shape[1] + if init_type is None or init_type == "normal": + nn.init.normal_(params) + elif init_type == "kaiming_uniform": + nn.init.kaiming_uniform_(params, a=math.sqrt(5)) + elif init_type == "uniform_fan_in": + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(params, -bound, bound) + elif init_type == "zero": + nn.init.zeros_(params) + elif "siren" == init_type: + assert "siren_w0" in kwargs.keys() and "is_first" in kwargs.keys() + w0 = kwargs["siren_w0"] + if kwargs["is_first"]: + w_std = 1 / fan_in + else: + w_std = math.sqrt(6.0 / fan_in) / w0 + nn.init.uniform_(params, -w_std, w_std) + else: + raise NotImplementedError + + +def create_params_with_init( + shape, init_type="normal", include_bias=False, bias_init_type="zero", **kwargs +): + if not include_bias: + params = torch.empty([shape[0], shape[1]]) + initialize_params(params, init_type, **kwargs) + return params + else: + params = torch.empty([shape[0] - 1, shape[1]]) + bias = torch.empty([1, shape[1]]) + + initialize_params(params, init_type, **kwargs) + initialize_params(bias, bias_init_type, **kwargs) + return torch.cat([params, bias], dim=0) + + +def create_activation(config): + if config.type == "relu": + activation = nn.ReLU() + elif config.type == "siren": + activation = Sine(config.siren_w0) + elif config.type == "silu": + activation = nn.SiLU() + elif config.type == "damp": + activation = Damping(config.siren_w0) + else: + raise NotImplementedError + return activation diff --git a/blissful_tuner/gimmvfi/generalizable_INR/raft/__init__.py b/blissful_tuner/gimmvfi/generalizable_INR/raft/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a733186cdc9005050f1083c46b94b81b22cf53e7 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/raft/__init__.py @@ -0,0 +1,24 @@ +from .raft import RAFT +import argparse +import torch +from .extractor import BasicEncoder + + +def initialize_RAFT(model_path="pretrained_ckpt/raft-things.pth", device="cuda"): + """Initializes the RAFT model.""" + args = argparse.ArgumentParser() + args.raft_model = model_path + args.small = False + args.mixed_precision = False + args.alternate_corr = False + model = RAFT(args) + ckpt = torch.load(args.raft_model, map_location="cpu") + + def convert(param): + return {k.replace("module.", ""): v for k, v in param.items() if "module" in k} + + ckpt = convert(ckpt) + model.load_state_dict(ckpt, strict=True) + print("load raft from " + model_path) + + return model diff --git a/blissful_tuner/gimmvfi/generalizable_INR/raft/corr.py b/blissful_tuner/gimmvfi/generalizable_INR/raft/corr.py new file mode 100644 index 0000000000000000000000000000000000000000..c5d678b41ec3ea46957e1602559db87ebcb816f8 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/raft/corr.py @@ -0,0 +1,175 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
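+#
+# Overview: correlation-volume helpers. BidirCorrBlock builds an all-pairs
+# correlation pyramid in both directions (average-pooled over num_levels
+# scales) and, on each call, bilinearly samples a (2r+1) x (2r+1) window
+# around the flow-displaced coordinates; CorrBlock is the single-direction
+# RAFT variant and AlternateCorrBlock defers to the optional alt_cuda_corr
+# extension.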
+# -------------------------------------------------------- +# References: +# amt: https://github.com/MCG-NKU/AMT +# raft: https://github.com/princeton-vl/RAFT +# -------------------------------------------------------- + +import torch +import torch.nn.functional as F +from .utils.utils import bilinear_sampler, coords_grid + +try: + import alt_cuda_corr +except: + # alt_cuda_corr is not compiled + pass + + +class BidirCorrBlock: + def __init__(self, fmap1, fmap2, num_levels=4, radius=4): + self.num_levels = num_levels + self.radius = radius + self.corr_pyramid = [] + self.corr_pyramid_T = [] + + corr = BidirCorrBlock.corr(fmap1, fmap2) + batch, h1, w1, dim, h2, w2 = corr.shape + corr_T = corr.clone().permute(0, 4, 5, 3, 1, 2) + + corr = corr.reshape(batch * h1 * w1, dim, h2, w2) + corr_T = corr_T.reshape(batch * h2 * w2, dim, h1, w1) + + self.corr_pyramid.append(corr) + self.corr_pyramid_T.append(corr_T) + + for _ in range(self.num_levels - 1): + corr = F.avg_pool2d(corr, 2, stride=2) + corr_T = F.avg_pool2d(corr_T, 2, stride=2) + self.corr_pyramid.append(corr) + self.corr_pyramid_T.append(corr_T) + + def __call__(self, coords0, coords1): + r = self.radius + coords0 = coords0.permute(0, 2, 3, 1) + coords1 = coords1.permute(0, 2, 3, 1) + assert ( + coords0.shape == coords1.shape + ), f"coords0 shape: [{coords0.shape}] is not equal to [{coords1.shape}]" + batch, h1, w1, _ = coords0.shape + + out_pyramid = [] + out_pyramid_T = [] + for i in range(self.num_levels): + corr = self.corr_pyramid[i] + corr_T = self.corr_pyramid_T[i] + + dx = torch.linspace(-r, r, 2 * r + 1, device=coords0.device) + dy = torch.linspace(-r, r, 2 * r + 1, device=coords0.device) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1) + delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) + + centroid_lvl_0 = coords0.reshape(batch * h1 * w1, 1, 1, 2) / 2**i + centroid_lvl_1 = coords1.reshape(batch * h1 * w1, 1, 1, 2) / 2**i + coords_lvl_0 = centroid_lvl_0 + delta_lvl + coords_lvl_1 = centroid_lvl_1 + delta_lvl + + corr = bilinear_sampler(corr, coords_lvl_0) + corr_T = bilinear_sampler(corr_T, coords_lvl_1) + corr = corr.view(batch, h1, w1, -1) + corr_T = corr_T.view(batch, h1, w1, -1) + out_pyramid.append(corr) + out_pyramid_T.append(corr_T) + + out = torch.cat(out_pyramid, dim=-1) + out_T = torch.cat(out_pyramid_T, dim=-1) + return ( + out.permute(0, 3, 1, 2).contiguous().float(), + out_T.permute(0, 3, 1, 2).contiguous().float(), + ) + + @staticmethod + def corr(fmap1, fmap2): + batch, dim, ht, wd = fmap1.shape + fmap1 = fmap1.view(batch, dim, ht * wd) + fmap2 = fmap2.view(batch, dim, ht * wd) + + corr = torch.matmul(fmap1.transpose(1, 2), fmap2) + corr = corr.view(batch, ht, wd, 1, ht, wd) + return corr / torch.sqrt(torch.tensor(dim).float()) + + +class AlternateCorrBlock: + def __init__(self, fmap1, fmap2, num_levels=4, radius=4): + self.num_levels = num_levels + self.radius = radius + + self.pyramid = [(fmap1, fmap2)] + for i in range(self.num_levels): + fmap1 = F.avg_pool2d(fmap1, 2, stride=2) + fmap2 = F.avg_pool2d(fmap2, 2, stride=2) + self.pyramid.append((fmap1, fmap2)) + + def __call__(self, coords): + coords = coords.permute(0, 2, 3, 1) + B, H, W, _ = coords.shape + dim = self.pyramid[0][0].shape[1] + + corr_list = [] + for i in range(self.num_levels): + r = self.radius + fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() + fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() + + coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() + (corr,) = alt_cuda_corr.forward(fmap1_i, fmap2_i, 
coords_i, r) + corr_list.append(corr.squeeze(1)) + + corr = torch.stack(corr_list, dim=1) + corr = corr.reshape(B, -1, H, W) + return corr / torch.sqrt(torch.tensor(dim).float()) + + +class CorrBlock: + def __init__(self, fmap1, fmap2, num_levels=4, radius=4): + self.num_levels = num_levels + self.radius = radius + self.corr_pyramid = [] + + # all pairs correlation + corr = CorrBlock.corr(fmap1, fmap2) + + batch, h1, w1, dim, h2, w2 = corr.shape + corr = corr.reshape(batch * h1 * w1, dim, h2, w2) + + self.corr_pyramid.append(corr) + for i in range(self.num_levels - 1): + corr = F.avg_pool2d(corr, 2, stride=2) + self.corr_pyramid.append(corr) + + def __call__(self, coords): + r = self.radius + coords = coords.permute(0, 2, 3, 1) + batch, h1, w1, _ = coords.shape + + out_pyramid = [] + for i in range(self.num_levels): + corr = self.corr_pyramid[i] + dx = torch.linspace(-r, r, 2 * r + 1, device=coords.device) + dy = torch.linspace(-r, r, 2 * r + 1, device=coords.device) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1) + + centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2**i + delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) + coords_lvl = centroid_lvl + delta_lvl + + corr = bilinear_sampler(corr, coords_lvl) + corr = corr.view(batch, h1, w1, -1) + out_pyramid.append(corr) + + out = torch.cat(out_pyramid, dim=-1) + return out.permute(0, 3, 1, 2).contiguous().float() + + @staticmethod + def corr(fmap1, fmap2): + batch, dim, ht, wd = fmap1.shape + fmap1 = fmap1.view(batch, dim, ht * wd) + fmap2 = fmap2.view(batch, dim, ht * wd) + + corr = torch.matmul(fmap1.transpose(1, 2), fmap2) + corr = corr.view(batch, ht, wd, 1, ht, wd) + return corr / torch.sqrt(torch.tensor(dim).float()) diff --git a/blissful_tuner/gimmvfi/generalizable_INR/raft/extractor.py b/blissful_tuner/gimmvfi/generalizable_INR/raft/extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..7ac21b55b207cb66232a8c29b9974e04bf72973a --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/raft/extractor.py @@ -0,0 +1,293 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn="group", stride=1): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d( + in_planes, planes, kernel_size=3, padding=1, stride=stride + ) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == "none": + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3 + ) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if 
self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class BottleneckBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn="group", stride=1): + super(BottleneckBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes // 4, kernel_size=1, padding=0) + self.conv2 = nn.Conv2d( + planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride + ) + self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4) + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(planes // 4) + self.norm2 = nn.BatchNorm2d(planes // 4) + self.norm3 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm4 = nn.BatchNorm2d(planes) + + elif norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(planes // 4) + self.norm2 = nn.InstanceNorm2d(planes // 4) + self.norm3 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm4 = nn.InstanceNorm2d(planes) + + elif norm_fn == "none": + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + self.norm3 = nn.Sequential() + if not stride == 1: + self.norm4 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4 + ) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + y = self.relu(self.norm3(self.conv3(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class BasicEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0, only_feat=False): + super(BasicEncoder, self).__init__() + self.norm_fn = norm_fn + self.only_feat = only_feat + + if self.norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) + + elif self.norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(64) + + elif self.norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(64) + + elif self.norm_fn == "none": + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 64 + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(96, stride=2) + self.layer3 = self._make_layer(128, stride=2) + + if not self.only_feat: + # output convolution + self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x, return_feature=False, mif=False): + features = 
[] + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x_2 = F.interpolate(x, scale_factor=1 / 2, mode="bilinear", align_corners=False) + x_4 = F.interpolate(x, scale_factor=1 / 4, mode="bilinear", align_corners=False) + + def f1(feat): + feat = self.conv1(feat) + feat = self.norm1(feat) + feat = self.relu1(feat) + feat = self.layer1(feat) + return feat + + x = f1(x) + features.append(x) + x = self.layer2(x) + if mif: + x_2_2 = f1(x_2) + features.append(torch.cat([x, x_2_2], dim=1)) + else: + features.append(x) + x = self.layer3(x) + if mif: + x_2_4 = self.layer2(x_2_2) + x_4_4 = f1(x_4) + features.append(torch.cat([x, x_2_4, x_4_4], dim=1)) + else: + features.append(x) + + if not self.only_feat: + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + features = [torch.split(f, [batch_dim, batch_dim], dim=0) for f in features] + if return_feature: + return x, features + else: + return x + + +class SmallEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0): + super(SmallEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) + + elif self.norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(32) + + elif self.norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(32) + + elif self.norm_fn == "none": + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 32 + self.layer1 = self._make_layer(32, stride=1) + self.layer2 = self._make_layer(64, stride=2) + self.layer3 = self._make_layer(96, stride=2) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x diff --git a/blissful_tuner/gimmvfi/generalizable_INR/raft/other_raft.py b/blissful_tuner/gimmvfi/generalizable_INR/raft/other_raft.py new file mode 100644 index 0000000000000000000000000000000000000000..77a02fe940f54b48feb520c37466d704f0b72ab1 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/raft/other_raft.py @@ -0,0 +1,238 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .update import 
BasicUpdateBlock, SmallUpdateBlock +from .extractor import BasicEncoder, SmallEncoder +from .corr import BidirCorrBlock, AlternateCorrBlock +from .utils.utils import bilinear_sampler, coords_grid, upflow8 + +try: + autocast = torch.cuda.amp.autocast +except: + # dummy autocast for PyTorch < 1.6 + class autocast: + def __init__(self, enabled): + pass + + def __enter__(self): + pass + + def __exit__(self, *args): + pass + + +# BiRAFT +class RAFT(nn.Module): + def __init__(self, args): + super(RAFT, self).__init__() + self.args = args + + if args.small: + self.hidden_dim = hdim = 96 + self.context_dim = cdim = 64 + args.corr_levels = 4 + args.corr_radius = 3 + self.corr_levels = 4 + self.corr_radius = 3 + + else: + self.hidden_dim = hdim = 128 + self.context_dim = cdim = 128 + args.corr_levels = 4 + args.corr_radius = 4 + self.corr_levels = 4 + self.corr_radius = 4 + + if "dropout" not in args._get_kwargs(): + self.args.dropout = 0 + + if "alternate_corr" not in args._get_kwargs(): + self.args.alternate_corr = False + + # feature network, context network, and update block + if args.small: + self.fnet = SmallEncoder( + output_dim=128, norm_fn="instance", dropout=args.dropout + ) + self.cnet = SmallEncoder( + output_dim=hdim + cdim, norm_fn="none", dropout=args.dropout + ) + self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim) + + else: + self.fnet = BasicEncoder( + output_dim=256, norm_fn="instance", dropout=args.dropout + ) + self.cnet = BasicEncoder( + output_dim=hdim + cdim, norm_fn="batch", dropout=args.dropout + ) + self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) + + def freeze_bn(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + + def build_coord(self, img): + N, C, H, W = img.shape + coords = coords_grid(N, H // 8, W // 8, device=img.device) + return coords + + def initialize_flow(self, img, img2): + """Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" + assert img.shape == img2.shape + N, C, H, W = img.shape + coords01 = coords_grid(N, H // 8, W // 8, device=img.device) + coords02 = coords_grid(N, H // 8, W // 8, device=img.device) + coords1 = coords_grid(N, H // 8, W // 8, device=img.device) + coords2 = coords_grid(N, H // 8, W // 8, device=img.device) + + # optical flow computed as difference: flow = coords1 - coords0 + return coords01, coords02, coords1, coords2 + + def upsample_flow(self, flow, mask): + """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination""" + N, _, H, W = flow.shape + mask = mask.view(N, 1, 9, 8, 8, H, W) + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(8 * flow, [3, 3], padding=1) + up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) + + up_flow = torch.sum(mask * up_flow, dim=2) + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) + return up_flow.reshape(N, 2, 8 * H, 8 * W) + + def get_corr_fn(self, image1, image2, projector=None): + # run the feature network + with autocast(enabled=self.args.mixed_precision): + fmaps, feats = self.fnet([image1, image2], return_feature=True) + fmap1, fmap2 = fmaps + fmap1 = fmap1.float() + fmap2 = fmap2.float() + corr_fn1 = None + if self.args.alternate_corr: + corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) + if projector is not None: + corr_fn1 = AlternateCorrBlock( + projector(feats[-1][0]), + projector(feats[-1][1]), + radius=self.args.corr_radius, + ) + else: + corr_fn = BidirCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) + if projector is not None: + corr_fn1 = 
BidirCorrBlock( + projector(feats[-1][0]), + projector(feats[-1][1]), + radius=self.args.corr_radius, + ) + if corr_fn1 is None: + return corr_fn, corr_fn + else: + return corr_fn, corr_fn1 + + def get_corr_fn_from_feat(self, fmap1, fmap2): + fmap1 = fmap1.float() + fmap2 = fmap2.float() + if self.args.alternate_corr: + corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) + else: + corr_fn = BidirCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) + + return corr_fn + + def forward( + self, + image1, + image2, + iters=12, + flow_init=None, + upsample=True, + test_mode=False, + corr_fn=None, + mif=False, + ): + """Estimate optical flow between pair of frames""" + assert flow_init is None + + image1 = 2 * (image1 / 255.0) - 1.0 + image2 = 2 * (image2 / 255.0) - 1.0 + + image1 = image1.contiguous() + image2 = image2.contiguous() + + hdim = self.hidden_dim + cdim = self.context_dim + + if corr_fn is None: + corr_fn, _ = self.get_corr_fn(image1, image2) + + # # run the feature network + # with autocast(enabled=self.args.mixed_precision): + # fmap1, fmap2 = self.fnet([image1, image2]) + + # fmap1 = fmap1.float() + # fmap2 = fmap2.float() + # if self.args.alternate_corr: + # corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) + # else: + # corr_fn = BidirCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) + + # run the context network + with autocast(enabled=self.args.mixed_precision): + # for image1 + cnet1, features1 = self.cnet(image1, return_feature=True, mif=mif) + net1, inp1 = torch.split(cnet1, [hdim, cdim], dim=1) + net1 = torch.tanh(net1) + inp1 = torch.relu(inp1) + # for image2 + cnet2, features2 = self.cnet(image2, return_feature=True, mif=mif) + net2, inp2 = torch.split(cnet2, [hdim, cdim], dim=1) + net2 = torch.tanh(net2) + inp2 = torch.relu(inp2) + + coords01, coords02, coords1, coords2 = self.initialize_flow(image1, image2) + + # if flow_init is not None: + # coords1 = coords1 + flow_init + + # flow_predictions1 = [] + # flow_predictions2 = [] + for itr in range(iters): + coords1 = coords1.detach() + coords2 = coords2.detach() + corr1, corr2 = corr_fn(coords1, coords2) # index correlation volume + + flow1 = coords1 - coords01 + flow2 = coords2 - coords02 + + with autocast(enabled=self.args.mixed_precision): + net1, up_mask1, delta_flow1 = self.update_block( + net1, inp1, corr1, flow1 + ) + net2, up_mask2, delta_flow2 = self.update_block( + net2, inp2, corr2, flow2 + ) + + # F(t+1) = F(t) + \Delta(t) + coords1 = coords1 + delta_flow1 + coords2 = coords2 + delta_flow2 + flow_low1 = coords1 - coords01 + flow_low2 = coords2 - coords02 + # upsample predictions + if up_mask1 is None: + flow_up1 = upflow8(coords1 - coords01) + flow_up2 = upflow8(coords2 - coords02) + else: + flow_up1 = self.upsample_flow(coords1 - coords01, up_mask1) + flow_up2 = self.upsample_flow(coords2 - coords02, up_mask2) + + # flow_predictions.append(flow_up) + return flow_up1, flow_up2, flow_low1, flow_low2, features1, features2 + # if test_mode: + # return coords1 - coords0, flow_up + + # return flow_predictions diff --git a/blissful_tuner/gimmvfi/generalizable_INR/raft/raft.py b/blissful_tuner/gimmvfi/generalizable_INR/raft/raft.py new file mode 100644 index 0000000000000000000000000000000000000000..75d9d0292f47fddbafae67f37612b4c473ad6ad5 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/raft/raft.py @@ -0,0 +1,169 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .update import BasicUpdateBlock, 
SmallUpdateBlock +from .extractor import BasicEncoder, SmallEncoder +from .corr import CorrBlock, AlternateCorrBlock +from .utils.utils import bilinear_sampler, coords_grid, upflow8 + +try: + autocast = torch.cuda.amp.autocast +except: + # dummy autocast for PyTorch < 1.6 + class autocast: + def __init__(self, enabled): + pass + + def __enter__(self): + pass + + def __exit__(self, *args): + pass + + +class RAFT(nn.Module): + def __init__(self, args): + super(RAFT, self).__init__() + self.args = args + + if args.small: + self.hidden_dim = hdim = 96 + self.context_dim = cdim = 64 + args.corr_levels = 4 + args.corr_radius = 3 + self.corr_levels = 4 + self.corr_radius = 3 + + else: + self.hidden_dim = hdim = 128 + self.context_dim = cdim = 128 + args.corr_levels = 4 + args.corr_radius = 4 + self.corr_levels = 4 + self.corr_radius = 4 + + if "dropout" not in args._get_kwargs(): + self.args.dropout = 0 + + if "alternate_corr" not in args._get_kwargs(): + self.args.alternate_corr = False + + # feature network, context network, and update block + if args.small: + self.fnet = SmallEncoder( + output_dim=128, norm_fn="instance", dropout=args.dropout + ) + self.cnet = SmallEncoder( + output_dim=hdim + cdim, norm_fn="none", dropout=args.dropout + ) + self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim) + + else: + self.fnet = BasicEncoder( + output_dim=256, norm_fn="instance", dropout=args.dropout + ) + self.cnet = BasicEncoder( + output_dim=hdim + cdim, norm_fn="batch", dropout=args.dropout + ) + self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) + + def freeze_bn(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + + def initialize_flow(self, img): + """Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" + N, C, H, W = img.shape + coords0 = coords_grid(N, H // 8, W // 8, device=img.device) + coords1 = coords_grid(N, H // 8, W // 8, device=img.device) + + # optical flow computed as difference: flow = coords1 - coords0 + return coords0, coords1 + + def upsample_flow(self, flow, mask): + """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination""" + N, _, H, W = flow.shape + mask = mask.view(N, 1, 9, 8, 8, H, W) + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(8 * flow, [3, 3], padding=1) + up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) + + up_flow = torch.sum(mask * up_flow, dim=2) + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) + return up_flow.reshape(N, 2, 8 * H, 8 * W) + + def forward( + self, + image1, + image2, + iters=12, + flow_init=None, + upsample=True, + test_mode=False, + return_feat=True, + ): + """Estimate optical flow between pair of frames""" + + image1 = 2 * (image1 / 255.0) - 1.0 + image2 = 2 * (image2 / 255.0) - 1.0 + + image1 = image1.contiguous() + image2 = image2.contiguous() + + hdim = self.hidden_dim + cdim = self.context_dim + + # run the feature network + with autocast(enabled=self.args.mixed_precision): + fmap1, fmap2 = self.fnet([image1, image2]) + + fmap1 = fmap1.float() + fmap2 = fmap2.float() + if self.args.alternate_corr: + corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) + else: + corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) + + # run the context network + with autocast(enabled=self.args.mixed_precision): + cnet, feats = self.cnet(image1, return_feature=True) + net, inp = torch.split(cnet, [hdim, cdim], dim=1) + net = torch.tanh(net) + inp = torch.relu(inp) + + coords0, coords1 = self.initialize_flow(image1) + + 
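+        # Note: the flow estimate is the displacement coords1 - coords0. coords1 is
+        # refined over `iters` update-block iterations at 1/8 input resolution and the
+        # result is upsampled 8x (via the learned convex-combination mask when one is
+        # predicted, bilinear upflow8 otherwise).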
if flow_init is not None: + coords1 = coords1 + flow_init + + flow_predictions = [] + for itr in range(iters): + coords1 = coords1.detach() + corr = corr_fn(coords1) # index correlation volume + + flow = coords1 - coords0 + with autocast(enabled=self.args.mixed_precision): + net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) + + # F(t+1) = F(t) + \Delta(t) + coords1 = coords1 + delta_flow + + # upsample predictions + if up_mask is None: + flow_up = upflow8(coords1 - coords0) + else: + flow_up = self.upsample_flow(coords1 - coords0, up_mask) + + flow_predictions.append(flow_up) + + if test_mode: + return coords1 - coords0, flow_up + + if return_feat: + return flow_up, feats[1:], fmap1 + + return flow_predictions diff --git a/blissful_tuner/gimmvfi/generalizable_INR/raft/update.py b/blissful_tuner/gimmvfi/generalizable_INR/raft/update.py new file mode 100644 index 0000000000000000000000000000000000000000..ced6df0658da475bf660b98475acab7e100cabaa --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/raft/update.py @@ -0,0 +1,154 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class FlowHead(nn.Module): + def __init__(self, input_dim=128, hidden_dim=256): + super(FlowHead, self).__init__() + self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) + self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + return self.conv2(self.relu(self.conv1(x))) + + +class ConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192 + 128): + super(ConvGRU, self).__init__() + self.convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) + self.convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) + self.convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) + + def forward(self, h, x): + hx = torch.cat([h, x], dim=1) + + z = torch.sigmoid(self.convz(hx)) + r = torch.sigmoid(self.convr(hx)) + q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1))) + + h = (1 - z) * h + z * q + return h + + +class SepConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192 + 128): + super(SepConvGRU, self).__init__() + self.convz1 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) + ) + self.convr1 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) + ) + self.convq1 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) + ) + + self.convz2 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0) + ) + self.convr2 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0) + ) + self.convq2 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0) + ) + + def forward(self, h, x): + # horizontal + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz1(hx)) + r = torch.sigmoid(self.convr1(hx)) + q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1))) + h = (1 - z) * h + z * q + + # vertical + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz2(hx)) + r = torch.sigmoid(self.convr2(hx)) + q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1))) + h = (1 - z) * h + z * q + + return h + + +class SmallMotionEncoder(nn.Module): + def __init__(self, args): + super(SmallMotionEncoder, self).__init__() + cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2 + self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) + self.convf1 = nn.Conv2d(2, 64, 7, padding=3) + self.convf2 = nn.Conv2d(64, 32, 3, padding=1) + self.conv = 
nn.Conv2d(128, 80, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + + +class BasicMotionEncoder(nn.Module): + def __init__(self, args): + super(BasicMotionEncoder, self).__init__() + cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2 + self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) + self.convc2 = nn.Conv2d(256, 192, 3, padding=1) + self.convf1 = nn.Conv2d(2, 128, 7, padding=3) + self.convf2 = nn.Conv2d(128, 64, 3, padding=1) + self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + cor = F.relu(self.convc2(cor)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + + +class SmallUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=96): + super(SmallUpdateBlock, self).__init__() + self.encoder = SmallMotionEncoder(args) + self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82 + 64) + self.flow_head = FlowHead(hidden_dim, hidden_dim=128) + + def forward(self, net, inp, corr, flow): + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + return net, None, delta_flow + + +class BasicUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=128, input_dim=128): + super(BasicUpdateBlock, self).__init__() + self.args = args + self.encoder = BasicMotionEncoder(args) + self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim) + self.flow_head = FlowHead(hidden_dim, hidden_dim=256) + + self.mask = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 64 * 9, 1, padding=0), + ) + + def forward(self, net, inp, corr, flow, upsample=True): + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + # scale mask to balence gradients + mask = 0.25 * self.mask(net) + return net, mask, delta_flow diff --git a/blissful_tuner/gimmvfi/generalizable_INR/raft/utils/__init__.py b/blissful_tuner/gimmvfi/generalizable_INR/raft/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/blissful_tuner/gimmvfi/generalizable_INR/raft/utils/augmentor.py b/blissful_tuner/gimmvfi/generalizable_INR/raft/utils/augmentor.py new file mode 100644 index 0000000000000000000000000000000000000000..565ae8e3ae8a8f8d4e499a12a12b6f5b051a0a67 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/raft/utils/augmentor.py @@ -0,0 +1,266 @@ +import numpy as np +import random +import math +from PIL import Image + +import cv2 + +cv2.setNumThreads(0) +cv2.ocl.setUseOpenCL(False) + +import torch +from torchvision.transforms import ColorJitter +import torch.nn.functional as F + + +class FlowAugmentor: + def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True): + # spatial augmentation params + self.crop_size = crop_size + self.min_scale = min_scale + self.max_scale = max_scale + self.spatial_aug_prob = 0.8 + self.stretch_prob = 0.8 + self.max_stretch = 0.2 + + # flip augmentation params + self.do_flip = do_flip + self.h_flip_prob = 0.5 + 
self.v_flip_prob = 0.1 + + # photometric augmentation params + self.photo_aug = ColorJitter( + brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5 / 3.14 + ) + self.asymmetric_color_aug_prob = 0.2 + self.eraser_aug_prob = 0.5 + + def color_transform(self, img1, img2): + """Photometric augmentation""" + + # asymmetric + if np.random.rand() < self.asymmetric_color_aug_prob: + img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8) + img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8) + + # symmetric + else: + image_stack = np.concatenate([img1, img2], axis=0) + image_stack = np.array( + self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8 + ) + img1, img2 = np.split(image_stack, 2, axis=0) + + return img1, img2 + + def eraser_transform(self, img1, img2, bounds=[50, 100]): + """Occlusion augmentation""" + + ht, wd = img1.shape[:2] + if np.random.rand() < self.eraser_aug_prob: + mean_color = np.mean(img2.reshape(-1, 3), axis=0) + for _ in range(np.random.randint(1, 3)): + x0 = np.random.randint(0, wd) + y0 = np.random.randint(0, ht) + dx = np.random.randint(bounds[0], bounds[1]) + dy = np.random.randint(bounds[0], bounds[1]) + img2[y0 : y0 + dy, x0 : x0 + dx, :] = mean_color + + return img1, img2 + + def spatial_transform(self, img1, img2, flow): + # randomly sample scale + ht, wd = img1.shape[:2] + min_scale = np.maximum( + (self.crop_size[0] + 8) / float(ht), (self.crop_size[1] + 8) / float(wd) + ) + + scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) + scale_x = scale + scale_y = scale + if np.random.rand() < self.stretch_prob: + scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) + scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) + + scale_x = np.clip(scale_x, min_scale, None) + scale_y = np.clip(scale_y, min_scale, None) + + if np.random.rand() < self.spatial_aug_prob: + # rescale the images + img1 = cv2.resize( + img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR + ) + img2 = cv2.resize( + img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR + ) + flow = cv2.resize( + flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR + ) + flow = flow * [scale_x, scale_y] + + if self.do_flip: + if np.random.rand() < self.h_flip_prob: # h-flip + img1 = img1[:, ::-1] + img2 = img2[:, ::-1] + flow = flow[:, ::-1] * [-1.0, 1.0] + + if np.random.rand() < self.v_flip_prob: # v-flip + img1 = img1[::-1, :] + img2 = img2[::-1, :] + flow = flow[::-1, :] * [1.0, -1.0] + + y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) + x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) + + img1 = img1[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] + img2 = img2[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] + flow = flow[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] + + return img1, img2, flow + + def __call__(self, img1, img2, flow): + img1, img2 = self.color_transform(img1, img2) + img1, img2 = self.eraser_transform(img1, img2) + img1, img2, flow = self.spatial_transform(img1, img2, flow) + + img1 = np.ascontiguousarray(img1) + img2 = np.ascontiguousarray(img2) + flow = np.ascontiguousarray(flow) + + return img1, img2, flow + + +class SparseFlowAugmentor: + def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False): + # spatial augmentation params + self.crop_size = crop_size + self.min_scale = min_scale + self.max_scale = max_scale + self.spatial_aug_prob = 0.8 + self.stretch_prob = 0.8 + self.max_stretch = 0.2 
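+        # scale and stretch factors are sampled in log2 space (see spatial_transform below)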
+ + # flip augmentation params + self.do_flip = do_flip + self.h_flip_prob = 0.5 + self.v_flip_prob = 0.1 + + # photometric augmentation params + self.photo_aug = ColorJitter( + brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3 / 3.14 + ) + self.asymmetric_color_aug_prob = 0.2 + self.eraser_aug_prob = 0.5 + + def color_transform(self, img1, img2): + image_stack = np.concatenate([img1, img2], axis=0) + image_stack = np.array( + self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8 + ) + img1, img2 = np.split(image_stack, 2, axis=0) + return img1, img2 + + def eraser_transform(self, img1, img2): + ht, wd = img1.shape[:2] + if np.random.rand() < self.eraser_aug_prob: + mean_color = np.mean(img2.reshape(-1, 3), axis=0) + for _ in range(np.random.randint(1, 3)): + x0 = np.random.randint(0, wd) + y0 = np.random.randint(0, ht) + dx = np.random.randint(50, 100) + dy = np.random.randint(50, 100) + img2[y0 : y0 + dy, x0 : x0 + dx, :] = mean_color + + return img1, img2 + + def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): + ht, wd = flow.shape[:2] + coords = np.meshgrid(np.arange(wd), np.arange(ht)) + coords = np.stack(coords, axis=-1) + + coords = coords.reshape(-1, 2).astype(np.float32) + flow = flow.reshape(-1, 2).astype(np.float32) + valid = valid.reshape(-1).astype(np.float32) + + coords0 = coords[valid >= 1] + flow0 = flow[valid >= 1] + + ht1 = int(round(ht * fy)) + wd1 = int(round(wd * fx)) + + coords1 = coords0 * [fx, fy] + flow1 = flow0 * [fx, fy] + + xx = np.round(coords1[:, 0]).astype(np.int32) + yy = np.round(coords1[:, 1]).astype(np.int32) + + v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1) + xx = xx[v] + yy = yy[v] + flow1 = flow1[v] + + flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32) + valid_img = np.zeros([ht1, wd1], dtype=np.int32) + + flow_img[yy, xx] = flow1 + valid_img[yy, xx] = 1 + + return flow_img, valid_img + + def spatial_transform(self, img1, img2, flow, valid): + # randomly sample scale + + ht, wd = img1.shape[:2] + min_scale = np.maximum( + (self.crop_size[0] + 1) / float(ht), (self.crop_size[1] + 1) / float(wd) + ) + + scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) + scale_x = np.clip(scale, min_scale, None) + scale_y = np.clip(scale, min_scale, None) + + if np.random.rand() < self.spatial_aug_prob: + # rescale the images + img1 = cv2.resize( + img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR + ) + img2 = cv2.resize( + img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR + ) + flow, valid = self.resize_sparse_flow_map( + flow, valid, fx=scale_x, fy=scale_y + ) + + if self.do_flip: + if np.random.rand() < 0.5: # h-flip + img1 = img1[:, ::-1] + img2 = img2[:, ::-1] + flow = flow[:, ::-1] * [-1.0, 1.0] + valid = valid[:, ::-1] + + margin_y = 20 + margin_x = 50 + + y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y) + x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x) + + y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0]) + x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1]) + + img1 = img1[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] + img2 = img2[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] + flow = flow[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] + valid = valid[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] + return img1, img2, flow, valid + + def __call__(self, img1, img2, flow, valid): + img1, img2 = self.color_transform(img1, img2) + img1, img2 = self.eraser_transform(img1, img2) + 
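+        # sparse flow is rescaled by scattering the valid flow vectors onto the resized
+        # grid (resize_sparse_flow_map) rather than by bilinear resampling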
img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid) + + img1 = np.ascontiguousarray(img1) + img2 = np.ascontiguousarray(img2) + flow = np.ascontiguousarray(flow) + valid = np.ascontiguousarray(valid) + + return img1, img2, flow, valid diff --git a/blissful_tuner/gimmvfi/generalizable_INR/raft/utils/flow_viz.py b/blissful_tuner/gimmvfi/generalizable_INR/raft/utils/flow_viz.py new file mode 100644 index 0000000000000000000000000000000000000000..fec08363580082d412f103dc3ea0154895d8169b --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/raft/utils/flow_viz.py @@ -0,0 +1,133 @@ +# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization + + +# MIT License +# +# Copyright (c) 2018 Tom Runia +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to conditions. +# +# Author: Tom Runia +# Date Created: 2018-08-03 + +import numpy as np + + +def make_colorwheel(): + """ + Generates a color wheel for optical flow visualization as presented in: + Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) + URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf + + Code follows the original C++ source code of Daniel Scharstein. + Code follows the the Matlab source code of Deqing Sun. + + Returns: + np.ndarray: Color wheel + """ + + RY = 15 + YG = 6 + GC = 4 + CB = 11 + BM = 13 + MR = 6 + + ncols = RY + YG + GC + CB + BM + MR + colorwheel = np.zeros((ncols, 3)) + col = 0 + + # RY + colorwheel[0:RY, 0] = 255 + colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY) + col = col + RY + # YG + colorwheel[col : col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG) + colorwheel[col : col + YG, 1] = 255 + col = col + YG + # GC + colorwheel[col : col + GC, 1] = 255 + colorwheel[col : col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC) + col = col + GC + # CB + colorwheel[col : col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB) + colorwheel[col : col + CB, 2] = 255 + col = col + CB + # BM + colorwheel[col : col + BM, 2] = 255 + colorwheel[col : col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM) + col = col + BM + # MR + colorwheel[col : col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR) + colorwheel[col : col + MR, 0] = 255 + return colorwheel + + +def flow_uv_to_colors(u, v, convert_to_bgr=False): + """ + Applies the flow color wheel to (possibly clipped) flow components u and v. + + According to the C++ source code of Daniel Scharstein + According to the Matlab source code of Deqing Sun + + Args: + u (np.ndarray): Input horizontal flow of shape [H,W] + v (np.ndarray): Input vertical flow of shape [H,W] + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
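+
+    Note:
+        u and v are expected to be pre-normalized (flow_to_image divides by the maximum
+        flow magnitude); vectors with magnitude above 1 are treated as out of range and
+        darkened.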
+ + Returns: + np.ndarray: Flow visualization image of shape [H,W,3] + """ + flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) + colorwheel = make_colorwheel() # shape [55x3] + ncols = colorwheel.shape[0] + rad = np.sqrt(np.square(u) + np.square(v)) + a = np.arctan2(-v, -u) / np.pi + fk = (a + 1) / 2 * (ncols - 1) + k0 = np.floor(fk).astype(np.int32) + k1 = k0 + 1 + k1[k1 == ncols] = 0 + f = fk - k0 + for i in range(colorwheel.shape[1]): + tmp = colorwheel[:, i] + col0 = tmp[k0] / 255.0 + col1 = tmp[k1] / 255.0 + col = (1 - f) * col0 + f * col1 + idx = rad <= 1 + col[idx] = 1 - rad[idx] * (1 - col[idx]) + col[~idx] = col[~idx] * 0.75 # out of range + # Note the 2-i => BGR instead of RGB + ch_idx = 2 - i if convert_to_bgr else i + flow_image[:, :, ch_idx] = np.floor(255 * col) + return flow_image + + +def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): + """ + Expects a two dimensional flow image of shape. + + Args: + flow_uv (np.ndarray): Flow UV image of shape [H,W,2] + clip_flow (float, optional): Clip maximum of flow values. Defaults to None. + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. + + Returns: + np.ndarray: Flow visualization image of shape [H,W,3] + """ + assert flow_uv.ndim == 3, "input flow must have three dimensions" + assert flow_uv.shape[2] == 2, "input flow must have shape [H,W,2]" + if clip_flow is not None: + flow_uv = np.clip(flow_uv, 0, clip_flow) + u = flow_uv[:, :, 0] + v = flow_uv[:, :, 1] + rad = np.sqrt(np.square(u) + np.square(v)) + rad_max = np.max(rad) + epsilon = 1e-5 + u = u / (rad_max + epsilon) + v = v / (rad_max + epsilon) + return flow_uv_to_colors(u, v, convert_to_bgr) diff --git a/blissful_tuner/gimmvfi/generalizable_INR/raft/utils/frame_utils.py b/blissful_tuner/gimmvfi/generalizable_INR/raft/utils/frame_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bee554e4bb1596bfe15945198f4cb1b5f1d786bf --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/raft/utils/frame_utils.py @@ -0,0 +1,142 @@ +import numpy as np +from PIL import Image +from os.path import * +import re + +import cv2 + +cv2.setNumThreads(0) +cv2.ocl.setUseOpenCL(False) + +TAG_CHAR = np.array([202021.25], np.float32) + + +def readFlow(fn): + """Read .flo file in Middlebury format""" + # Code adapted from: + # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy + + # WARNING: this will work on little-endian architectures (eg Intel x86) only! + # print 'fn = %s'%(fn) + with open(fn, "rb") as f: + magic = np.fromfile(f, np.float32, count=1) + if 202021.25 != magic: + print("Magic number incorrect. 
Invalid .flo file") + return None + else: + w = np.fromfile(f, np.int32, count=1) + h = np.fromfile(f, np.int32, count=1) + # print 'Reading %d x %d flo file\n' % (w, h) + data = np.fromfile(f, np.float32, count=2 * int(w) * int(h)) + # Reshape data into 3D array (columns, rows, bands) + # The reshape here is for visualization, the original code is (w,h,2) + return np.resize(data, (int(h), int(w), 2)) + + +def readPFM(file): + file = open(file, "rb") + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header == b"PF": + color = True + elif header == b"Pf": + color = False + else: + raise Exception("Not a PFM file.") + + dim_match = re.match(rb"^(\d+)\s(\d+)\s$", file.readline()) + if dim_match: + width, height = map(int, dim_match.groups()) + else: + raise Exception("Malformed PFM header.") + + scale = float(file.readline().rstrip()) + if scale < 0: # little-endian + endian = "<" + scale = -scale + else: + endian = ">" # big-endian + + data = np.fromfile(file, endian + "f") + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + return data + + +def writeFlow(filename, uv, v=None): + """Write optical flow to file. + + If v is None, uv is assumed to contain both u and v channels, + stacked in depth. + Original code by Deqing Sun, adapted from Daniel Scharstein. + """ + nBands = 2 + + if v is None: + assert uv.ndim == 3 + assert uv.shape[2] == 2 + u = uv[:, :, 0] + v = uv[:, :, 1] + else: + u = uv + + assert u.shape == v.shape + height, width = u.shape + f = open(filename, "wb") + # write the header + f.write(TAG_CHAR) + np.array(width).astype(np.int32).tofile(f) + np.array(height).astype(np.int32).tofile(f) + # arrange into matrix form + tmp = np.zeros((height, width * nBands)) + tmp[:, np.arange(width) * 2] = u + tmp[:, np.arange(width) * 2 + 1] = v + tmp.astype(np.float32).tofile(f) + f.close() + + +def readFlowKITTI(filename): + flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR) + flow = flow[:, :, ::-1].astype(np.float32) + flow, valid = flow[:, :, :2], flow[:, :, 2] + flow = (flow - 2**15) / 64.0 + return flow, valid + + +def readDispKITTI(filename): + disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 + valid = disp > 0.0 + flow = np.stack([-disp, np.zeros_like(disp)], -1) + return flow, valid + + +def writeFlowKITTI(filename, uv): + uv = 64.0 * uv + 2**15 + valid = np.ones([uv.shape[0], uv.shape[1], 1]) + uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) + cv2.imwrite(filename, uv[..., ::-1]) + + +def read_gen(file_name, pil=False): + ext = splitext(file_name)[-1] + if ext == ".png" or ext == ".jpeg" or ext == ".ppm" or ext == ".jpg": + return Image.open(file_name) + elif ext == ".bin" or ext == ".raw": + return np.load(file_name) + elif ext == ".flo": + return readFlow(file_name).astype(np.float32) + elif ext == ".pfm": + flow = readPFM(file_name).astype(np.float32) + if len(flow.shape) == 2: + return flow + else: + return flow[:, :, :-1] + return [] diff --git a/blissful_tuner/gimmvfi/generalizable_INR/raft/utils/utils.py b/blissful_tuner/gimmvfi/generalizable_INR/raft/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..781dc55d6e9b62f5eb9fcd275720ae57a5bafad4 --- /dev/null +++ b/blissful_tuner/gimmvfi/generalizable_INR/raft/utils/utils.py @@ -0,0 +1,93 @@ +import torch +import torch.nn.functional as F +import numpy as np +from scipy import interpolate + + +class InputPadder: + """Pads 
images such that dimensions are divisible by 8""" + + def __init__(self, dims, mode="sintel"): + self.ht, self.wd = dims[-2:] + pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 + pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 + if mode == "sintel": + self._pad = [ + pad_wd // 2, + pad_wd - pad_wd // 2, + pad_ht // 2, + pad_ht - pad_ht // 2, + ] + else: + self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht] + + def pad(self, *inputs): + return [F.pad(x, self._pad, mode="replicate") for x in inputs] + + def unpad(self, x): + ht, wd = x.shape[-2:] + c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]] + return x[..., c[0] : c[1], c[2] : c[3]] + + +def forward_interpolate(flow): + flow = flow.detach().cpu().numpy() + dx, dy = flow[0], flow[1] + + ht, wd = dx.shape + x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) + + x1 = x0 + dx + y1 = y0 + dy + + x1 = x1.reshape(-1) + y1 = y1.reshape(-1) + dx = dx.reshape(-1) + dy = dy.reshape(-1) + + valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) + x1 = x1[valid] + y1 = y1[valid] + dx = dx[valid] + dy = dy[valid] + + flow_x = interpolate.griddata( + (x1, y1), dx, (x0, y0), method="nearest", fill_value=0 + ) + + flow_y = interpolate.griddata( + (x1, y1), dy, (x0, y0), method="nearest", fill_value=0 + ) + + flow = np.stack([flow_x, flow_y], axis=0) + return torch.from_numpy(flow).float() + + +def bilinear_sampler(img, coords, mode="bilinear", mask=False): + """Wrapper for grid_sample, uses pixel coordinates""" + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1, 1], dim=-1) + xgrid = 2 * xgrid / (W - 1) - 1 + ygrid = 2 * ygrid / (H - 1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) + img = F.grid_sample(img, grid, align_corners=True) + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + + return img + + +def coords_grid(batch, ht, wd, device): + coords = torch.meshgrid( + torch.arange(ht, device=device), torch.arange(wd, device=device) + ) + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) + + +def upflow8(flow, mode="bilinear"): + new_size = (8 * flow.shape[2], 8 * flow.shape[3]) + return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) diff --git a/blissful_tuner/gimmvfi/utils/flow_viz.py b/blissful_tuner/gimmvfi/utils/flow_viz.py new file mode 100644 index 0000000000000000000000000000000000000000..b8d93345c88238b40c07ddd9b914cb9a66ad8224 --- /dev/null +++ b/blissful_tuner/gimmvfi/utils/flow_viz.py @@ -0,0 +1,136 @@ +# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization + + +# MIT License +# +# Copyright (c) 2018 Tom Runia +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to conditions. +# +# Author: Tom Runia +# Date Created: 2018-08-03 + +import numpy as np + + +def make_colorwheel(): + """ + Generates a color wheel for optical flow visualization as presented in: + Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) + URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf + + Code follows the original C++ source code of Daniel Scharstein. 
+ Code follows the the Matlab source code of Deqing Sun. + + Returns: + np.ndarray: Color wheel + """ + + RY = 15 + YG = 6 + GC = 4 + CB = 11 + BM = 13 + MR = 6 + + ncols = RY + YG + GC + CB + BM + MR + colorwheel = np.zeros((ncols, 3)) + col = 0 + + # RY + colorwheel[0:RY, 0] = 255 + colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY) + col = col + RY + # YG + colorwheel[col : col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG) + colorwheel[col : col + YG, 1] = 255 + col = col + YG + # GC + colorwheel[col : col + GC, 1] = 255 + colorwheel[col : col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC) + col = col + GC + # CB + colorwheel[col : col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB) + colorwheel[col : col + CB, 2] = 255 + col = col + CB + # BM + colorwheel[col : col + BM, 2] = 255 + colorwheel[col : col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM) + col = col + BM + # MR + colorwheel[col : col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR) + colorwheel[col : col + MR, 0] = 255 + return colorwheel + + +def flow_uv_to_colors(u, v, convert_to_bgr=False): + """ + Applies the flow color wheel to (possibly clipped) flow components u and v. + + According to the C++ source code of Daniel Scharstein + According to the Matlab source code of Deqing Sun + + Args: + u (np.ndarray): Input horizontal flow of shape [H,W] + v (np.ndarray): Input vertical flow of shape [H,W] + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. + + Returns: + np.ndarray: Flow visualization image of shape [H,W,3] + """ + flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) + colorwheel = make_colorwheel() # shape [55x3] + ncols = colorwheel.shape[0] + rad = np.sqrt(np.square(u) + np.square(v)) + a = np.arctan2(-v, -u) / np.pi + fk = (a + 1) / 2 * (ncols - 1) + k0 = np.floor(fk).astype(np.int32) + k1 = k0 + 1 + k1[k1 == ncols] = 0 + f = fk - k0 + for i in range(colorwheel.shape[1]): + tmp = colorwheel[:, i] + col0 = tmp[k0] / 255.0 + col1 = tmp[k1] / 255.0 + col = (1 - f) * col0 + f * col1 + idx = rad <= 1 + col[idx] = 1 - rad[idx] * (1 - col[idx]) + col[~idx] = col[~idx] * 0.75 # out of range + # Note the 2-i => BGR instead of RGB + ch_idx = 2 - i if convert_to_bgr else i + flow_image[:, :, ch_idx] = np.floor(255 * col) + return flow_image + + +def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False, max_flow=None): + """ + Expects a two dimensional flow image of shape. + + Args: + flow_uv (np.ndarray): Flow UV image of shape [H,W,2] + clip_flow (float, optional): Clip maximum of flow values. Defaults to None. + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
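+        max_flow (float, optional): Fixed magnitude used to normalize the flow; if None,
+            the maximum flow magnitude in the input is used. Defaults to None.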
+ + Returns: + np.ndarray: Flow visualization image of shape [H,W,3] + """ + assert flow_uv.ndim == 3, "input flow must have three dimensions" + assert flow_uv.shape[2] == 2, "input flow must have shape [H,W,2]" + if clip_flow is not None: + flow_uv = np.clip(flow_uv, 0, clip_flow) + u = flow_uv[:, :, 0] + v = flow_uv[:, :, 1] + if max_flow is None: + rad = np.sqrt(np.square(u) + np.square(v)) + rad_max = np.max(rad) + else: + rad_max = max_flow + epsilon = 1e-5 + u = u / (rad_max + epsilon) + v = v / (rad_max + epsilon) + return flow_uv_to_colors(u, v, convert_to_bgr) diff --git a/blissful_tuner/gimmvfi/utils/utils.py b/blissful_tuner/gimmvfi/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..981bff61f21f6e017bc97effcc9b64d9822ab83c --- /dev/null +++ b/blissful_tuner/gimmvfi/utils/utils.py @@ -0,0 +1,52 @@ +from easydict import EasyDict as edict +import torch.nn.functional as F + +class InputPadder: + """Pads images such that dimensions are divisible by divisor""" + + def __init__(self, dims, divisor=16): + self.ht, self.wd = dims[-2:] + pad_ht = (((self.ht // divisor) + 1) * divisor - self.ht) % divisor + pad_wd = (((self.wd // divisor) + 1) * divisor - self.wd) % divisor + self._pad = [ + pad_wd // 2, + pad_wd - pad_wd // 2, + pad_ht // 2, + pad_ht - pad_ht // 2, + ] + + def pad(self, *inputs): + if len(inputs) == 1: + return F.pad(inputs[0], self._pad, mode="replicate") + else: + return [F.pad(x, self._pad, mode="replicate") for x in inputs] + + def unpad(self, *inputs): + if len(inputs) == 1: + return self._unpad(inputs[0]) + else: + return [self._unpad(x) for x in inputs] + + def _unpad(self, x): + ht, wd = x.shape[-2:] + c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]] + return x[..., c[0] : c[1], c[2] : c[3]] +def easydict_to_dict(obj): + if not isinstance(obj, edict): + return obj + else: + return {k: easydict_to_dict(v) for k, v in obj.items()} + + +class RaftArgs: + def __init__(self, small, mixed_precision, alternate_corr): + self.small = small + self.mixed_precision = mixed_precision + self.alternate_corr = alternate_corr + + def _get_kwargs(self): + return { + "small": self.small, + "mixed_precision": self.mixed_precision, + "alternate_corr": self.alternate_corr + } \ No newline at end of file diff --git a/blissful_tuner/hvw_posemb_layers.py b/blissful_tuner/hvw_posemb_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..cf54a3981116516f401a653a951223b0921942d9 --- /dev/null +++ b/blissful_tuner/hvw_posemb_layers.py @@ -0,0 +1,305 @@ +# From HunyuanVideoWrapper +import torch +from typing import Union, Tuple, List + + +def _to_tuple(x, dim=2): + if isinstance(x, int): + return (x,) * dim + elif len(x) == dim: + return x + else: + raise ValueError(f"Expected length {dim} or int, but got {x}") + + +def get_meshgrid_nd(start, *args, dim=2): + """ + Get n-D meshgrid with start, stop and num. + + Args: + start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, + step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num + should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in + n-tuples. + *args: See above. + dim (int): Dimension of the meshgrid. Defaults to 2. + + Returns: + grid (np.ndarray): [dim, ...] 
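+
+    Example (illustrative):
+        get_meshgrid_nd((4, 8)) returns a [2, 4, 8] tensor whose first channel holds the
+        row indices 0..3 and whose second channel holds the column indices 0..7.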
+ """ + if len(args) == 0: + # start is grid_size + num = _to_tuple(start, dim=dim) + start = (0,) * dim + stop = num + elif len(args) == 1: + # start is start, args[0] is stop, step is 1 + start = _to_tuple(start, dim=dim) + stop = _to_tuple(args[0], dim=dim) + num = [stop[i] - start[i] for i in range(dim)] + elif len(args) == 2: + # start is start, args[0] is stop, args[1] is num + start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0 + stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32 + num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124 + else: + raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}") + + # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False) + axis_grid = [] + for i in range(dim): + a, b, n = start[i], stop[i], num[i] + g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n] + axis_grid.append(g) + grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D] + grid = torch.stack(grid, dim=0) # [dim, W, H, D] + + return grid + + +################################################################################# +# Rotary Positional Embedding Functions # +################################################################################# +# https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80 + + +def apply_rotary(x, cos, sin): + x_reshaped = x.view(*x.shape[:-1], -1, 2) + x1, x2 = x_reshaped.unbind(-1) + x_rotated = torch.stack([-x2, x1], dim=-1).flatten(3) + return (x * cos) + (x_rotated * sin) + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + upcast: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. + + This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided + frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are + returned as real tensors. + + Args: + xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D] + xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D] + freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + + """ + + shape = [d if i == 1 or i == xq.ndim - 1 else 1 for i, d in enumerate(xq.shape)] + cos, sin = freqs_cis[0].view(*shape), freqs_cis[1].view(*shape) + + if upcast: + xq_out = apply_rotary(xq.float(), cos, sin).to(xq.dtype) + xk_out = apply_rotary(xk.float(), cos, sin).to(xk.dtype) + else: + xq_out = apply_rotary(xq, cos, sin) + xk_out = apply_rotary(xk, cos, sin) + + return xq_out, xk_out + + +def get_nd_rotary_pos_embed( + rope_dim_list, + start, + *args, + theta=10000.0, + use_real=False, + theta_rescale_factor: Union[float, List[float]] = 1.0, + interpolation_factor: Union[float, List[float]] = 1.0, + num_frames: int = 129, + k: int = 0, +): + """ + This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure. + + Args: + rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n. + sum(rope_dim_list) should equal to head_dim of attention layer. 
+ start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start, + args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. + *args: See above. + theta (float): Scaling factor for frequency computation. Defaults to 10000.0. + use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers. + Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real + part and an imaginary part separately. + theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0. + + Returns: + pos_embed (torch.Tensor): [HW, D/2] + """ + + grid = get_meshgrid_nd( + start, *args, dim=len(rope_dim_list) + ) # [3, W, H, D] / [2, W, H] + + if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float): + theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list) + elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1: + theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list) + assert len(theta_rescale_factor) == len( + rope_dim_list + ), "len(theta_rescale_factor) should equal to len(rope_dim_list)" + + if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float): + interpolation_factor = [interpolation_factor] * len(rope_dim_list) + elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1: + interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list) + assert len(interpolation_factor) == len( + rope_dim_list + ), "len(interpolation_factor) should equal to len(rope_dim_list)" + + # use 1/ndim of dimensions to encode grid_axis + embs = [] + for i in range(len(rope_dim_list)): + if i == 0: + emb = get_1d_rotary_pos_embed_riflex( + rope_dim_list[i], + grid[i].reshape(-1), + theta, + use_real=use_real, + theta_rescale_factor=theta_rescale_factor[i], + interpolation_factor=interpolation_factor[i], + L_test=num_frames, + k=k, + ) # 2 x [WHD, rope_dim_list[i]] + else: + emb = get_1d_rotary_pos_embed( + rope_dim_list[i], + grid[i].reshape(-1), + theta, + use_real=use_real, + theta_rescale_factor=theta_rescale_factor[i], + interpolation_factor=interpolation_factor[i], + ) + embs.append(emb) + + if use_real: + cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2) + sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2) + return cos, sin + else: + emb = torch.cat(embs, dim=1) # (WHD, D/2) + return emb + + +def get_1d_rotary_pos_embed( + dim: int, + pos: Union[torch.FloatTensor, int], + theta: float = 10000.0, + use_real: bool = False, + theta_rescale_factor: float = 1.0, + interpolation_factor: float = 1.0, + L_test: int = 100, + k: int = 0, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Precompute the frequency tensor for complex exponential (cis) with given dimensions. + (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.) + + This function calculates a frequency tensor with complex exponential using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar + theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. 
+ use_real (bool, optional): If True, return real part and imaginary part separately. + Otherwise, return complex numbers. + theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0. + + Returns: + freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2] + freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D] + """ + if isinstance(pos, int): + pos = torch.arange(pos).float() + + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + if theta_rescale_factor != 1.0: + theta *= theta_rescale_factor ** (dim / (dim - 2)) + + freqs = 1.0 / ( + theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) + ) # [D/2] + # assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}" + + freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2] + if use_real: + freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D] + freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] + return freqs_cos, freqs_sin + else: + freqs_cis = torch.polar( + torch.ones_like(freqs), freqs + ) # complex64 # [S, D/2] + return freqs_cis + +def get_1d_rotary_pos_embed_riflex( + dim: int, + pos: Union[torch.FloatTensor, int], + theta: float = 10000.0, + use_real: bool = False, + theta_rescale_factor: float = 1.0, + interpolation_factor: float = 1.0, + L_test: int = 66, + k: int = 0, + N_k: int=50 +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Precompute the frequency tensor for complex exponential (cis) with given dimensions. + (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.) + + This function calculates a frequency tensor with complex exponential using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar + theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. + use_real (bool, optional): If True, return real part and imaginary part separately. + Otherwise, return complex numbers. + theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0. + + Returns: + freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2] + freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. 
[S, D] + """ + if isinstance(pos, int): + pos = torch.arange(pos).float() + + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + if theta_rescale_factor != 1.0: + theta *= theta_rescale_factor ** (dim / (dim - 2)) + + freqs = 1.0 / ( + theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) + ) # [D/2] + # assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}" + + #RIFLEx https://github.com/thu-ml/RIFLEx + if k > 0 and L_test > N_k: + freqs[k-1] = 0.9 * 2 * torch.pi / L_test + + + freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2] + if use_real: + freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D] + freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] + return freqs_cos, freqs_sin + else: + freqs_cis = torch.polar( + torch.ones_like(freqs), freqs + ) # complex64 # [S, D/2] + return freqs_cis diff --git a/blissful_tuner/latent_preview.py b/blissful_tuner/latent_preview.py new file mode 100644 index 0000000000000000000000000000000000000000..9a10c835038e864f8e2cd4c4231180dfa7a363b1 --- /dev/null +++ b/blissful_tuner/latent_preview.py @@ -0,0 +1,314 @@ +# latent_preview.py + +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Latent preview for Blissful Tuner extension +License: Apache 2.0 +Created on Mon Mar 10 16:47:29 2025 + +@author: blyss +""" +import os +import torch +import av +from PIL import Image +from .taehv import TAEHV +from .utils import load_torch_file +from blissful_tuner.utils import BlissfulLogger + +logger = BlissfulLogger(__name__, "#8e00ed") + + +class LatentPreviewer(): + @torch.inference_mode() + def __init__(self, args, original_latents, timesteps, device, dtype, model_type="hunyuan"): + self.mode = "latent2rgb" if not hasattr(args, 'preview_vae') or args.preview_vae is None else "taehv" + ##logger.info(f"Initializing latent previewer with mode {self.mode}...") + # Correctly handle framepack - it should subtract noise like others unless specifically told otherwise + self.subtract_noise = True # Default to True for all models now + # If you specifically need framepack NOT to subtract noise, you'd add a condition here + # Example: self.subtract_noise = False if model_type == "framepack" else True + self.args = args + self.model_type = model_type + self.device = device + self.dtype = dtype if dtype != torch.float8_e4m3fn else torch.float16 + if model_type != "framepack" and original_latents is not None and timesteps is not None: + self.original_latents = original_latents.to(self.device) + self.timesteps_percent = timesteps / 1000 + # Add Framepack check here too if needed for original_latents/timesteps later + # elif model_type == "framepack" and ... 
+
+        if self.model_type not in ["hunyuan", "wan", "framepack"]:
+            raise ValueError(f"Unsupported model type: {self.model_type}")
+
+        if self.mode == "taehv":
+            #logger.info(f"Loading TAEHV: {args.preview_vae}...")
+            if os.path.exists(args.preview_vae):
+                tae_sd = load_torch_file(args.preview_vae, safe_load=True, device=args.device)
+            else:
+                raise FileNotFoundError(f"{args.preview_vae} was not found!")
+            self.taehv = TAEHV(tae_sd).to("cpu", self.dtype) # Offload for VRAM and match datatype
+            self.decoder = self.decode_taehv
+            self.scale_factor = None
+            self.fps = args.fps
+        elif self.mode == "latent2rgb":
+            self.decoder = self.decode_latent2rgb
+            self.scale_factor = 8
+            # latent2rgb decodes latent frames directly, and video latents are temporally
+            # compressed (roughly 4x), so the preview fps is divided by 4 to keep the
+            # playback duration approximately correct, clamped to a minimum of 1 fps.
+            self.fps = int(args.fps / 4) if args.fps > 4 else 1
+
+
+    @torch.inference_mode()
+    def preview(self, noisy_latents, current_step=None, preview_suffix=None):
+        if self.device == "cuda" or self.device == torch.device("cuda"):
+            torch.cuda.empty_cache()
+        if self.model_type == "wan":
+            noisy_latents = noisy_latents.unsqueeze(0) # F, C, H, W -> B, F, C, H, W
+        elif self.model_type == "hunyuan" or self.model_type == "framepack": # Handle framepack like hunyuan
+            pass # already B, F, C, H, W or expected format B, C, T, H, W
+
+        # Framepack latents may arrive as B, C, T, H, W rather than B, F, C, H, W
+        if self.model_type == "framepack" and noisy_latents.ndim == 5: # B,C,T,H,W
+            # The framepack branches of decode_latent2rgb and decode_taehv work on the
+            # B, C, T, H, W layout directly, so no permute is performed here and the
+            # layout is handled inside the decoder instead.
+            pass # Keep as B, C, T, H, W
+
+        # Subtract the estimated remaining original noise only if enabled and the required inputs exist
+        if self.subtract_noise and hasattr(self, 'original_latents') and hasattr(self, 'timesteps_percent') and current_step is not None:
+            denoisy_latents = self.subtract_original_and_normalize(noisy_latents, current_step)
+        else:
+            # Without the original latents/timesteps there is nothing to subtract,
+            # so the noisy latents are passed through unchanged below.
+ denoisy_latents = noisy_latents + + + decoded = self.decoder(denoisy_latents) # Expects F, C, H, W output from decoder + + # Upscale if we used latent2rgb so output is same size as expected + if self.scale_factor is not None: + upscaled = torch.nn.functional.interpolate( + decoded, + scale_factor=self.scale_factor, + mode="bicubic", + align_corners=False + ) + else: + upscaled = decoded + + _, _, h, w = upscaled.shape + self.write_preview(upscaled, w, h, preview_suffix=preview_suffix) + + @torch.inference_mode() + def subtract_original_and_normalize(self, noisy_latents, current_step): + # Ensure original_latents and timesteps_percent were initialized + if not hasattr(self, 'original_latents') or not hasattr(self, 'timesteps_percent'): + logger.warning("Cannot subtract noise: original_latents or timesteps_percent not initialized.") + return noisy_latents # Return original if we can't process + + # Compute what percent of original noise is remaining + noise_remaining = self.timesteps_percent[current_step].to(device=noisy_latents.device) + # Subtract the portion of original latents + denoisy_latents = noisy_latents - (self.original_latents.to(device=noisy_latents.device) * noise_remaining) + + # Normalize + normalized_denoisy_latents = (denoisy_latents - denoisy_latents.mean()) / (denoisy_latents.std() + 1e-8) + return normalized_denoisy_latents + + @torch.inference_mode() + def write_preview(self, frames, width, height, preview_suffix=None): + suffix_str = f"_{preview_suffix}" if preview_suffix else "" + base_name = f"latent_preview{suffix_str}" + target = os.path.join(self.args.save_path, f"{base_name}.mp4") + target_img = os.path.join(self.args.save_path, f"{base_name}.png") + # Check if we only have a single frame. + if frames.shape[0] == 1: + # Clamp, scale, convert to byte and move to CPU + frame = frames[0].clamp(0, 1).mul(255).byte().cpu() + # Permute from (3, H, W) to (H, W, 3) for PIL. + frame_np = frame.permute(1, 2, 0).numpy() + Image.fromarray(frame_np).save(target_img) + #logger.info(f"Saved single frame preview to {target_img}") # Add log + return + + # Otherwise, write out as a video. + # Make sure fps is at least 1 + output_fps = max(1, self.fps) + #logger.info(f"Writing preview video to {target} at {output_fps} FPS") # Add log + try: + container = av.open(target, mode="w") + stream = container.add_stream("libx264", rate=output_fps) # Use output_fps + stream.pix_fmt = "yuv420p" + stream.width = width + stream.height = height + # Add option for higher quality preview encoding if needed + # stream.options = {'crf': '18'} # Example: Lower CRF = higher quality + + # Loop through each frame. + for frame_idx, frame in enumerate(frames): + # Clamp to [0,1], scale, convert to byte and move to CPU. + frame = frame.clamp(0, 1).mul(255).byte().cpu() + # Permute from (3, H, W) -> (H, W, 3) for AV. + frame_np = frame.permute(1, 2, 0).numpy() + try: + video_frame = av.VideoFrame.from_ndarray(frame_np, format="rgb24") + for packet in stream.encode(video_frame): + container.mux(packet) + except Exception as e: + logger.error(f"Error encoding frame {frame_idx}: {e}") + # Optionally break or continue if one frame fails + break + + + # Flush out any remaining packets and close. 
+            try:
+                for packet in stream.encode():
+                    container.mux(packet)
+                container.close()
+                #logger.info(f"Finished writing preview video: {target}")
+            except Exception as e:
+                logger.error(f"Error finalizing preview video: {e}")
+                # Clean up the container if possible
+                try: container.close()
+                except Exception: pass
+        except Exception as e:
+            logger.error(f"Error opening or writing to preview container {target}: {e}")
+
+    @torch.inference_mode()
+    def decode_taehv(self, latents):
+        """
+        Decodes latents with the TAEHV model, returns shape (F, C, H, W).
+        """
+        self.taehv.to(self.device) # Onload
+        # Reorder dimensions to the layout TAEHV's decode_video expects.
+        # hunyuan/wan latents are permuted with (0, 2, 1, 3, 4), i.e. the frame and channel
+        # axes are swapped, while framepack latents are assumed to already be in the
+        # expected layout (B, C, T, H, W) and are passed through unchanged.
+        if self.model_type == "framepack": # Assuming framepack latents are B,C,T,H,W
+            latents_permuted = latents # No permute needed if TAEHV handles B,C,T,H,W
+        else: # hunyuan/wan
+            # Swap dims 1 and 2, e.g. B, F, C, H, W -> B, C, F, H, W
+            # (matches the permute used by the original code)
+            latents_permuted = latents.permute(0, 2, 1, 3, 4)
+
+        latents_permuted = latents_permuted.to(device=self.device, dtype=self.dtype)
+        decoded = self.taehv.decode_video(latents_permuted, parallel=False, show_progress_bar=False)
+        self.taehv.to("cpu") # Offload
+        return decoded.squeeze(0) # squeeze off batch dimension -> F, C, H, W
+
+    @torch.inference_mode()
+    def decode_latent2rgb(self, latents):
+        """
+        Decodes latents to RGB using linear transform, returns shape (F, 3, H, W).
+        Handles different latent dimension orders (B,F,C,H,W or B,C,T,H,W).
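+
+        The conversion is a cheap per-pixel linear map rather than a true VAE decode: for each
+        frame, ``rgb = latent_pixel @ rgb_factors + bias`` with per-model factors, followed by
+        min/max normalization of the result to [0, 1], so the preview is only approximate.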
+ """ + model_params = { + "hunyuan": { + "rgb_factors": [ + [-0.0395, -0.0331, 0.0445], [ 0.0696, 0.0795, 0.0518], + [ 0.0135, -0.0945, -0.0282], [ 0.0108, -0.0250, -0.0765], + [-0.0209, 0.0032, 0.0224], [-0.0804, -0.0254, -0.0639], + [-0.0991, 0.0271, -0.0669], [-0.0646, -0.0422, -0.0400], + [-0.0696, -0.0595, -0.0894], [-0.0799, -0.0208, -0.0375], + [ 0.1166, 0.1627, 0.0962], [ 0.1165, 0.0432, 0.0407], + [-0.2315, -0.1920, -0.1355], [-0.0270, 0.0401, -0.0821], + [-0.0616, -0.0997, -0.0727], [ 0.0249, -0.0469, -0.1703] + ], + "bias": [0.0259, -0.0192, -0.0761], + }, + "wan": { + "rgb_factors": [ + [-0.1299, -0.1692, 0.2932], [ 0.0671, 0.0406, 0.0442], + [ 0.3568, 0.2548, 0.1747], [ 0.0372, 0.2344, 0.1420], + [ 0.0313, 0.0189, -0.0328], [ 0.0296, -0.0956, -0.0665], + [-0.3477, -0.4059, -0.2925], [ 0.0166, 0.1902, 0.1975], + [-0.0412, 0.0267, -0.1364], [-0.1293, 0.0740, 0.1636], + [ 0.0680, 0.3019, 0.1128], [ 0.0032, 0.0581, 0.0639], + [-0.1251, 0.0927, 0.1699], [ 0.0060, -0.0633, 0.0005], + [ 0.3477, 0.2275, 0.2950], [ 0.1984, 0.0913, 0.1861] + ], + "bias": [-0.1835, -0.0868, -0.3360], + }, + # No 'framepack' key needed, will map to 'hunyuan' below + } + + # --- FIX: Determine the correct parameter key --- + # Use 'hunyuan' parameters if the model type is 'framepack' + params_key = "hunyuan" if self.model_type == "framepack" else self.model_type + if params_key not in model_params: + logger.error(f"Unsupported model type '{self.model_type}' (key '{params_key}') for latent2rgb.") + # Optionally return a black image or raise error + # Returning black image of expected shape might prevent further crashes + b, c_or_f, t_or_c, h, w = latents.shape # Get shape + num_frames = t_or_c if self.model_type == "framepack" else c_or_f # Estimate frame dim + return torch.zeros((num_frames, 3, h * self.scale_factor, w * self.scale_factor), device='cpu') + # raise KeyError(f"Unsupported model type '{self.model_type}' (key '{params_key}') for latent2rgb decoding.") + + latent_rgb_factors_data = model_params[params_key]["rgb_factors"] + latent_rgb_factors_bias_data = model_params[params_key]["bias"] + # --- END FIX --- + + # Prepare linear transform + latent_rgb_factors = torch.tensor( + latent_rgb_factors_data, # Use data fetched with correct key + device=latents.device, + dtype=latents.dtype + ).transpose(0, 1) + latent_rgb_factors_bias = torch.tensor( + latent_rgb_factors_bias_data, # Use data fetched with correct key + device=latents.device, + dtype=latents.dtype + ) + + # Handle different dimension orders + # B, F, C, H, W (Hunyuan, Wan) vs B, C, T, H, W (Framepack) + if self.model_type == "framepack": + # Input: B, C, T, H, W + # We need to iterate through T (time/frames) dimension + num_frames = latents.shape[2] + frame_dim_idx = 2 + channel_dim_idx = 1 + else: # Wan (and potentially Hunyuan if prepared similarly) + # Input is expected as B, C, F, H, W after preview() method + num_frames = latents.shape[2] # F (frame dimension) + channel_dim_idx = 1 # C + frame_dim_idx = 2 # F + + latent_images = [] + for t in range(num_frames): + # Extract frame t, permute C to the end for linear layer + if self.model_type == "framepack": + # Extract B, C, H, W for frame t -> squeeze B -> C, H, W -> permute -> H, W, C + extracted = latents[:, :, t, :, :].squeeze(0).permute(1, 2, 0) + else: + # Extract B, C, H, W for frame t -> squeeze B -> C, H, W -> permute -> H, W, C + extracted = latents[:, :, t, :, :].squeeze(0).permute(1, 2, 0) + + # extracted should now be (H, W, C) + rgb = 
torch.nn.functional.linear(extracted, latent_rgb_factors, bias=latent_rgb_factors_bias) # shape = (H, W, 3) + latent_images.append(rgb) + + # Stack frames into (F, H, W, 3) + if not latent_images: # Handle case where loop might not run + logger.warning("No latent images generated in decode_latent2rgb.") + b, c_or_f, t_or_c, h, w = latents.shape + num_frames = t_or_c if self.model_type == "framepack" else c_or_f + return torch.zeros((num_frames, 3, h * self.scale_factor, w * self.scale_factor), device='cpu') + + latent_images_stacked = torch.stack(latent_images, dim=0) + + # Normalize to [0..1] + latent_images_min = latent_images_stacked.min() + latent_images_max = latent_images_stacked.max() + if latent_images_max > latent_images_min: + normalized_images = (latent_images_stacked - latent_images_min) / (latent_images_max - latent_images_min) + else: + # Handle case where max == min (e.g., all black image) + normalized_images = torch.zeros_like(latent_images_stacked) + + # Permute to (F, 3, H, W) before returning + final_images = normalized_images.permute(0, 3, 1, 2) + return final_images \ No newline at end of file diff --git a/blissful_tuner/model_utility.py b/blissful_tuner/model_utility.py new file mode 100644 index 0000000000000000000000000000000000000000..1660ac0456b6d65e20dd30fc9629347834e10afe --- /dev/null +++ b/blissful_tuner/model_utility.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Model inspection and conversion utility for Blissful Tuner Extension + +License: Apache 2.0 +Created on Wed Apr 23 10:19:19 2025 +@author: blyss +""" +import os +import argparse +import torch +import safetensors +from safetensors.torch import save_file +from tqdm import tqdm + +parser = argparse.ArgumentParser( + description="Convert any model checkpoint (single file or shard directory) to safetensors with dtype cast." +) +parser.add_argument( + "--input", + required=True, + help="Checkpoint file or directory of shards to convert/inspect" +) +parser.add_argument("--convert", type=str, default=None) +parser.add_argument("--inspect", action="store_true") +parser.add_argument("--key_target", type=str) +parser.add_argument("--weights_only", type=str, default="true") +parser.add_argument("--dtype", type=str) +args = parser.parse_args() + + +def load_torch_file(ckpt, weights_only=True, device=None, return_metadata=False): + """ + Load a single checkpoint file or all shards in a directory. + - If `ckpt` is a dir, iterates over supported files, loads each, and merges. + - Returns state_dict (and metadata if return_metadata=True and single file). 
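+
+    Illustrative usage (the paths below are placeholders, not files shipped with this script):
+        sd = load_torch_file("model.safetensors")                      # single checkpoint file
+        sd = load_torch_file("sharded_checkpoint_dir/")                # directory of shards, merged into one dict
+        sd, meta = load_torch_file("model.safetensors", return_metadata=True)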
+ """ + if device is None: + device = torch.device("cpu") + + # --- shard support --- + if os.path.isdir(ckpt): + all_sd = {} + for fname in sorted(os.listdir(ckpt)): + path = os.path.join(ckpt, fname) + # only load supported extensions + if not os.path.isfile(path): + continue + if not path.lower().endswith((".safetensors", ".sft", ".pt", ".pth")): + continue + # load each shard (we ignore metadata for shards) + shard_sd = load_torch_file(path, weights_only, device, return_metadata=False) + all_sd.update(shard_sd) + return (all_sd, None) if return_metadata else all_sd + + # --- single file --- + metadata = None + if ckpt.lower().endswith((".safetensors", ".sft")): + try: + with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f: + sd = {k: f.get_tensor(k) for k in f.keys()} + metadata = f.metadata() if return_metadata else None + except Exception as e: + raise ValueError(f"Safetensors load failed: {e}\nFile: {ckpt}") + else: + pl_sd = torch.load(ckpt, map_location=device, weights_only=weights_only) + sd = pl_sd.get("state_dict", pl_sd) + + return (sd, metadata) if return_metadata else sd + + +print("Loading checkpoint...") +weights_only = args.weights_only.lower() == "true" +checkpoint = load_torch_file(args.input, weights_only) + +dtype_mapping = { + "fp16": torch.float16, + "float16": torch.float16, + "bf16": torch.bfloat16, + "bfloat16": torch.bfloat16, + "fp32": torch.float32, + "float32": torch.float32, +} + +if args.convert is not None and os.path.exists(args.convert): + confirm = input(f"{args.convert} exists. Overwrite? [y/N]: ").strip().lower() + if confirm != "y": + print("Aborting.") + exit() + +converted_state_dict = {} +keys_to_process = ( + [k for k in checkpoint if args.key_target in k] if args.key_target else checkpoint.keys() +) +dtypes_in_model = {} +for key in tqdm(keys_to_process, desc="Processing tensors"): + value = checkpoint[key] + if args.inspect: + print(f"{key}: {value.shape} ({value.dtype})") + dtype_to_use = ( + dtype_mapping.get(args.dtype.lower(), value.dtype) + if args.dtype + else value.dtype + ) + if dtype_to_use not in dtypes_in_model: + dtypes_in_model[dtype_to_use] = 1 + else: + dtypes_in_model[dtype_to_use] += 1 + if args.convert: + converted_state_dict[key] = value.to(dtype_to_use) + + +print(f"Dtypes in model: {dtypes_in_model}") +if args.convert: + output_file = ( + args.convert.replace(".pth", ".safetensors") + .replace(".pt", ".safetensors") + ) + print(f"Saving converted tensors to '{output_file}'...") + save_file(converted_state_dict, output_file) diff --git a/blissful_tuner/prompt_weighting.py b/blissful_tuner/prompt_weighting.py new file mode 100644 index 0000000000000000000000000000000000000000..c99c8b0324acc1a3c2b5179ec73b828bf5ae9a7a --- /dev/null +++ b/blissful_tuner/prompt_weighting.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Sun Apr 20 12:51:05 2025 +Prompt weighting for WanVideo +Adapted and heavily modified from https://github.com/xhinker/sd_embed +License: Apache 2.0 +@author: blyss +""" +from transformers import T5Model +import torch +import re +from typing import Tuple, List, Union +from blissful_tuner.utils import BlissfulLogger +logger = BlissfulLogger(__name__, "#8e00ed") + + +class MiniT5Wrapper(): + """A mini wrapper for the T5 to make managing prompt weighting in Musubi easier""" + + def __init__(self, device: torch.device, dtype: torch.dtype, t5: T5Model): + self.device = device + self.dtype = dtype + self.t5 = t5 + self.model = t5.model + self.times_called = 0 + + 
def __call__( + self, + prompt: Union[str, List[str]], + device: torch.device, + max_len: int = None + ) -> List[torch.Tensor]: + if isinstance(prompt, list): + if len(prompt) != 1: + raise ValueError("MiniT5Wrapper expects a single prompt at a time (wrapped as a list). Got multiple prompts.") + prompt = prompt[0] + if self.times_called == 0: # Only print this notice once even if called multiple times + logger.info("Weighting prompts...") + # Split positive prompts and process each with weights + prompts_raw = [p.strip() for p in prompt.split('|')] + prompts = [] + all_weights = [] + + for p in prompts_raw: + cleaned_prompt, weights = self.parse_prompt_weights(p) + prompts.append(cleaned_prompt) + all_weights.append(weights) + context = self.t5(prompts, device) + + # Apply weights to embeddings if any were extracted + for i, weights in enumerate(all_weights): + for text, weight in weights.items(): + logger.info(f"Applying weight ({weight}) to promptchunk: '{text}'") + if len(weights) > 0: + context[i] = context[i] * weight + self.times_called += 1 + return context + + def parse_prompt_weights(self, prompt: str) -> Tuple[str, dict]: + """Extract text and weights from prompts with (text:weight) format""" + # Parse all instances of (text:weight) in the prompt + pattern = r'\((.*?):([\d\.]+)\)' + matches = re.findall(pattern, prompt) + + # Replace each match with just the text part + cleaned_prompt = prompt + weights = {} + + for match in matches: + text, weight = match + orig_text = f"({text}:{weight})" + cleaned_prompt = cleaned_prompt.replace(orig_text, text) + weights[text] = float(weight) + + return cleaned_prompt, weights diff --git a/blissful_tuner/swinir/network_swinir.py b/blissful_tuner/swinir/network_swinir.py new file mode 100644 index 0000000000000000000000000000000000000000..a3ef00baf2f38e4049b5adb38b02eb7d37a0bdea --- /dev/null +++ b/blissful_tuner/swinir/network_swinir.py @@ -0,0 +1,882 @@ +# ----------------------------------------------------------------------------------- +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. 
+# ----------------------------------------------------------------------------------- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + @torch.no_grad() + #@torch.compile(backend="inductor", mode="max-autotune") + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. 
+ input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # vectorized version of the 3×3 slicing logic + H, W = x_size + device = next(self.parameters()).device + + # boundaries for your 3 regions along each axis + b0_h = H - self.window_size + b1_h = H - self.shift_size + b0_w = W - self.window_size + b1_w = W - self.shift_size + + # 1D coords + h = torch.arange(H, device=device) + w = torch.arange(W, device=device) + + # region index 0,1,2 for each row/col + # row < b0_h → 0 ; b0_h ≤ row < b1_h → 1 ; row ≥ b1_h → 2 + h_regions = (h >= b0_h).long() + (h >= b1_h).long() # shape (H,) + w_regions = (w >= b0_w).long() + (w >= b1_w).long() # shape (W,) + + # combine to get 0–8 labels + # region_id[y,x] = 3*h_regions[y] + w_regions[x] + region_id = (h_regions[:, None] * 3 + w_regions[None, :]) # (H, W) + + # reshape to 1×H×W×1 so window_partition still works + img_mask = region_id.view(1, H, W, 1) + + # now identical to before + mask_windows = window_partition(img_mask, self.window_size) + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + return attn_mask + + @torch.no_grad() + #@torch.compile(backend="inductor", mode="max-autotune") + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + 
shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
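+
+    Note: the forward pass is residual; the input is run through the Swin blocks
+    (``residual_group``) and a convolution, and the result is added back to the input,
+    i.e. roughly ``x + conv(blocks(x))``.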
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + def forward(self, x, x_size): + return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + flops += self.patch_embed.flops() + flops += self.patch_unembed.flops() + + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +class SwinIR(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. + + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. 
or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=3, + embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', + **kwargs): + super(SwinIR, self).__init__() + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + self.window_size = window_size + + ##################################################################################################### + ################################### 1, shallow feature extraction ################################### + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build Residual Swin Transformer blocks (RSTB) + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + + ) + self.layers.append(layer) + self.norm = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + 
self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + ##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + if self.upscale == 4: + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x, x_size) + + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward(self, x): + H, W = x.shape[2:] + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.upsampler == 'pixelshuffle': + # for classical SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.upsample(x) + elif self.upsampler == 'nearest+conv': + # for real-world SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + if 
self.upscale == 4: + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.conv_last(self.lrelu(self.conv_hr(x))) + else: + # for image denoising and JPEG compression artifact reduction + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + x = x + self.conv_last(res) + + x = x / self.img_range + self.mean + + return x[:, :, :H*self.upscale, :W*self.upscale] + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + +if __name__ == '__main__': + upscale = 4 + window_size = 8 + height = (1024 // upscale // window_size + 1) * window_size + width = (720 // upscale // window_size + 1) * window_size + model = SwinIR(upscale=2, img_size=(height, width), + window_size=window_size, img_range=1., depths=[6, 6, 6, 6], + embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + print(model) + print(height, width, model.flops() / 1e9) + + x = torch.randn((1, 3, height, width)) + x = model(x) + print(x.shape) diff --git a/blissful_tuner/taehv.py b/blissful_tuner/taehv.py new file mode 100644 index 0000000000000000000000000000000000000000..476a97784864b2c1120463a4730ef9ef7c9bb41e --- /dev/null +++ b/blissful_tuner/taehv.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python3 +""" +Tiny AutoEncoder for Hunyuan Video +(DNN for encoding / decoding videos to Hunyuan Video's latent space) + From: https://github.com/madebyollin/taehv + MIT License +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +from tqdm.auto import tqdm +from collections import namedtuple + +DecoderResult = namedtuple("DecoderResult", ("frame", "memory")) +TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index")) + +def conv(n_in, n_out, **kwargs): + return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs) + +class Clamp(nn.Module): + def forward(self, x): + return torch.tanh(x / 3) * 3 + +class MemBlock(nn.Module): + def __init__(self, n_in, n_out): + super().__init__() + self.conv = nn.Sequential(conv(n_in * 2, n_out), nn.ReLU(inplace=True), conv(n_out, n_out), nn.ReLU(inplace=True), conv(n_out, n_out)) + self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() + self.act = nn.ReLU(inplace=True) + def forward(self, x, past): + return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x)) + +class TPool(nn.Module): + def __init__(self, n_f, stride): + super().__init__() + self.stride = stride + self.conv = nn.Conv2d(n_f*stride,n_f, 1, bias=False) + def forward(self, x): + _NT, C, H, W = x.shape + return self.conv(x.reshape(-1, self.stride * C, H, W)) + +class TGrow(nn.Module): + def __init__(self, n_f, stride): + super().__init__() + self.stride = stride + self.conv = nn.Conv2d(n_f, n_f*stride, 1, bias=False) + def forward(self, x): + _NT, C, H, W = x.shape + x = self.conv(x) + return x.reshape(-1, C, H, W) + +def apply_model_with_memblocks(model, x, parallel, show_progress_bar): + """ + Apply a sequential model with memblocks to the given input. 
+ Args: + - model: nn.Sequential of blocks to apply + - x: input data, of dimensions NTCHW + - parallel: if True, parallelize over timesteps (fast but uses O(T) memory) + if False, each timestep will be processed sequentially (slow but uses O(1) memory) + - show_progress_bar: if True, enables tqdm progressbar display + + Returns NTCHW tensor of output data. + """ + assert x.ndim == 5, f"TAEHV operates on NTCHW tensors, but got {x.ndim}-dim tensor" + N, T, C, H, W = x.shape + if parallel: + x = x.reshape(N*T, C, H, W) + # parallel over input timesteps, iterate over blocks + for b in tqdm(model, disable=not show_progress_bar): + if isinstance(b, MemBlock): + NT, C, H, W = x.shape + T = NT // N + _x = x.reshape(N, T, C, H, W) + mem = F.pad(_x, (0,0,0,0,0,0,1,0), value=0)[:,:T].reshape(x.shape) + x = b(x, mem) + else: + x = b(x) + NT, C, H, W = x.shape + T = NT // N + x = x.view(N, T, C, H, W) + else: + # TODO(oboerbohan): at least on macos this still gradually uses more memory during decode... + # need to fix :( + out = [] + # iterate over input timesteps and also iterate over blocks. + # because of the cursed TPool/TGrow blocks, this is not a nested loop, + # it's actually a ***graph traversal*** problem! so let's make a queue + work_queue = [TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(N, T * C, H, W).chunk(T, dim=1))] + # in addition to manually managing our queue, we also need to manually manage our progressbar. + # we'll update it for every source node that we consume. + progress_bar = tqdm(range(T), disable=not show_progress_bar) + # we'll also need a separate addressable memory per node as well + mem = [None] * len(model) + while work_queue: + xt, i = work_queue.pop(0) + if i == 0: + # new source node consumed + progress_bar.update(1) + if i == len(model): + # reached end of the graph, append result to output list + out.append(xt) + else: + # fetch the block to process + b = model[i] + if isinstance(b, MemBlock): + # mem blocks are simple since we're visiting the graph in causal order + if mem[i] is None: + xt_new = b(xt, xt * 0) + mem[i] = xt + else: + xt_new = b(xt, mem[i]) + mem[i].copy_(xt) # inplace might reduce mysterious pytorch memory allocations? 
doesn't help though + # add successor to work queue + work_queue.insert(0, TWorkItem(xt_new, i+1)) + elif isinstance(b, TPool): + # pool blocks are miserable + if mem[i] is None: + mem[i] = [] # pool memory is itself a queue of inputs to pool + mem[i].append(xt) + if len(mem[i]) > b.stride: + # pool mem is in invalid state, we should have pooled before this + raise ValueError("???") + elif len(mem[i]) < b.stride: + # pool mem is not yet full, go back to processing the work queue + pass + else: + # pool mem is ready, run the pool block + N, C, H, W = xt.shape + xt = b(torch.cat(mem[i], 1).view(N*b.stride, C, H, W)) + # reset the pool mem + mem[i] = [] + # add successor to work queue + work_queue.insert(0, TWorkItem(xt, i+1)) + elif isinstance(b, TGrow): + xt = b(xt) + NT, C, H, W = xt.shape + # each tgrow has multiple successor nodes + for xt_next in reversed(xt.view(N, b.stride*C, H, W).chunk(b.stride, 1)): + # add successor to work queue + work_queue.insert(0, TWorkItem(xt_next, i+1)) + else: + # normal block with no funny business + xt = b(xt) + # add successor to work queue + work_queue.insert(0, TWorkItem(xt, i+1)) + progress_bar.close() + x = torch.stack(out, 1) + return x + +class TAEHV(nn.Module): + latent_channels = 16 + image_channels = 3 + def __init__(self, state_dict, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True)): + """Initialize pretrained TAEHV from the given checkpoint. + + Arg: + checkpoint_path: path to weight file to load. taehv.pth for Hunyuan, taew2_1.pth for Wan 2.1. + decoder_time_upscale: whether temporal upsampling is enabled for each block. upsampling can be disabled for a cheaper preview. + decoder_space_upscale: whether spatial upsampling is enabled for each block. upsampling can be disabled for a cheaper preview. + """ + super().__init__() + self.encoder = nn.Sequential( + conv(TAEHV.image_channels, 64), nn.ReLU(inplace=True), + TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64), + TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64), + TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64), + conv(64, TAEHV.latent_channels), + ) + n_f = [256, 128, 64, 64] + self.frames_to_trim = 2**sum(decoder_time_upscale) - 1 + self.decoder = nn.Sequential( + Clamp(), conv(TAEHV.latent_channels, n_f[0]), nn.ReLU(inplace=True), + MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False), + MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False), + MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False), + nn.ReLU(inplace=True), conv(n_f[3], TAEHV.image_channels), + ) + if state_dict is not None: + self.load_state_dict(self.patch_tgrow_layers(state_dict)) + + def patch_tgrow_layers(self, sd): + """Patch TGrow layers to use a smaller kernel if needed. 
+ + Args: + sd: state dict to patch + """ + new_sd = self.state_dict() + for i, layer in enumerate(self.decoder): + if isinstance(layer, TGrow): + key = f"decoder.{i}.conv.weight" + if sd[key].shape[0] > new_sd[key].shape[0]: + # take the last-timestep output channels + sd[key] = sd[key][-new_sd[key].shape[0]:] + return sd + + def encode_video(self, x, parallel=True, show_progress_bar=True): + """Encode a sequence of frames. + + Args: + x: input NTCHW RGB (C=3) tensor with values in [0, 1]. + parallel: if True, all frames will be processed at once. + (this is faster but may require more memory). + if False, frames will be processed sequentially. + Returns NTCHW latent tensor with ~Gaussian values. + """ + return apply_model_with_memblocks(self.encoder, x, parallel, show_progress_bar) + + def decode_video(self, x, parallel=True, show_progress_bar=True): + """Decode a sequence of frames. + + Args: + x: input NTCHW latent (C=12) tensor with ~Gaussian values. + parallel: if True, all frames will be processed at once. + (this is faster but may require more memory). + if False, frames will be processed sequentially. + Returns NTCHW RGB tensor with ~[0, 1] values. + """ + x = apply_model_with_memblocks(self.decoder, x, False, show_progress_bar) + return x[:, self.frames_to_trim:] + + def forward(self, x): + return self.c(x) diff --git a/blissful_tuner/upscaler.py b/blissful_tuner/upscaler.py new file mode 100644 index 0000000000000000000000000000000000000000..41aa33e58e0ef263e3790388c57fea6e416a9688 --- /dev/null +++ b/blissful_tuner/upscaler.py @@ -0,0 +1,165 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Video Upscaler for Blissful Tuner Extension + +License: Apache 2.0 +Created on Wed Apr 23 10:19:19 2025 +@author: blyss +""" + +from typing import List +import torch +import numpy as np +from tqdm import tqdm +from rich.traceback import install as install_rich_tracebacks +from swinir.network_swinir import SwinIR +from spandrel import ImageModelDescriptor, ModelLoader +from video_processing_common import BlissfulVideoProcessor, set_seed, setup_parser_video_common +from utils import setup_compute_context, load_torch_file, BlissfulLogger +logger = BlissfulLogger(__name__, "#8e00ed") +install_rich_tracebacks() + + +def upscale_frames_swin( + model: torch.nn.Module, + frames: List[np.ndarray], + VideoProcessor: BlissfulVideoProcessor +) -> List[np.ndarray]: + """ + Upscale a list of RGB frames using a compiled SwinIR model. + + Args: + model: Loaded SwinIR upsampler. + frames: List of H×W×3 float32 RGB arrays in [0,1]. + device: torch device (cpu or cuda). + dtype: torch.dtype to use for computation. + + Returns: + List of upscaled H'×W'×3 uint8 BGR frames. 
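The TAEHV class above works on NTCHW tensors throughout: `encode_video` maps RGB frames in [0, 1] to a latent with `TAEHV.latent_channels` (16) channels, while `decode_video` always runs the decoder sequentially and drops the first `frames_to_trim` frames introduced by temporal upsampling. A minimal shape-check sketch, assuming the module is importable as `blissful_tuner.taehv`, using random weights (`state_dict=None`) and a frame count divisible by 4 so the temporal pooling divides evenly:

```python
import torch
from blissful_tuner.taehv import TAEHV  # import path assumed from this patch

# Random-weight instance (state_dict=None skips checkpoint loading); this checks shapes only.
tae = TAEHV(state_dict=None).eval()

video = torch.rand(1, 8, 3, 64, 64)  # N, T, C=3 RGB frames in [0, 1]
with torch.no_grad():
    latent = tae.encode_video(video, parallel=True, show_progress_bar=False)
    frames = tae.decode_video(latent, show_progress_bar=False)

print(latent.shape)  # (1, 2, 16, 8, 8): time /4, space /8, C = TAEHV.latent_channels
print(frames.shape)  # (1, 5, 3, 64, 64): time x4 minus frames_to_trim (3), space x8
```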
+ """ + window_size = 8 + for img in tqdm(frames, desc="Upscaling SwinIR"): + # Mark step for CUDA graph capture if enabled + torch.compiler.cudagraph_mark_step_begin() + + # Convert HWC RGB → CHW tensor + tensor = VideoProcessor.np_image_to_tensor(img) + + # Pad to window multiple + _, _, h, w = tensor.shape + h_pad = ((h + window_size - 1) // window_size) * window_size - h + w_pad = ((w + window_size - 1) // window_size) * window_size - w + tensor = torch.cat([tensor, torch.flip(tensor, [2])], 2)[:, :, : h + h_pad, :] + tensor = torch.cat([tensor, torch.flip(tensor, [3])], 3)[:, :, :, : w + w_pad] + + # Inference + with torch.no_grad(): + out = model(tensor) + + # Post-process: NCHW → HWC BGR uint8 + VideoProcessor.write_np_or_tensor_to_png(out) + + +def load_swin_model( + model_path: str, + device: torch.device, + dtype: torch.dtype, +) -> torch.nn.Module: + """ + Instantiate and load weights into a SwinIR model. + + Args: + model_path: Path to checkpoint (.pth or safetensors). + device: torch device. + dtype: torch dtype. + Returns: + SwinIR model in eval() on device and dtype. + """ + logger.info(f"Loading SwinIR model ({dtype})…") + model = SwinIR( + upscale=4, + in_chans=3, + img_size=64, + window_size=8, + img_range=1.0, + depths=[6] * 9, + embed_dim=240, + num_heads=[8] * 9, + mlp_ratio=2, + upsampler='nearest+conv', + resi_connection='3conv', + ) + ckpt = load_torch_file(model_path) + key = 'params_ema' if 'params_ema' in ckpt else None + model.load_state_dict(ckpt[key] if key else ckpt, strict=True) + model.to(device, dtype).eval() + return model + + +def load_esrgan_model( + model_path: str, + device: torch.device, + dtype: torch.dtype, +) -> torch.nn.Module: + """ + Load an ESRGAN (or RRDBNet) style model via Spandrel loader. + + Args: + model_path: Path to ESRGAN checkpoint. + device: torch device. + dtype: torch dtype. + Returns: + Model ready for inference. + """ + logger.info(f"Loading ESRGAN model ({dtype})…") + descriptor = ModelLoader().load_from_file(model_path) + assert isinstance(descriptor, ImageModelDescriptor) + model = descriptor.model.eval().to(device, dtype) + return model + + +def main() -> None: + """ + Parse CLI args, load input, model, and run upscaling pipeline. 
+ """ + parser = setup_parser_video_common(description="Video upscaling using SwinIR or ESRGAN models") + parser.add_argument( + "--scale", type=float, default=2, + help="Final scale multiplier for output resolution" + ) + parser.add_argument( + "--mode", choices=["swinir", "esrgan"], default="swinir", + help="Model architecture to use" + ) + args = parser.parse_args() + args.mode = args.mode.lower() + # Map string → torch.dtype + device, dtype = setup_compute_context(None, args.dtype) + VideoProcessor = BlissfulVideoProcessor(device, dtype) + VideoProcessor.prepare_files_and_path(args.input, args.output, args.mode.upper()) + + frames, fps, w, h = VideoProcessor.load_frames(make_rgb=True) + set_seed(args.seed) + # Load and run model + if args.mode == "swinir": + model = load_swin_model(args.model, device, dtype) + upscale_frames_swin(model, frames, VideoProcessor) + else: + model = load_esrgan_model(args.model, device, dtype) + logger.info("Processing with ESRGAN...") + for frame in tqdm(frames, desc="Upscaling ESRGAN"): + inp = VideoProcessor.np_image_to_tensor(frame) + with torch.no_grad(): + sr = model(inp) + VideoProcessor.write_np_or_tensor_to_png(sr) + + # Write video + logger.info("Encoding output video...") + out_w, out_h = int(w * args.scale), int(h * args.scale) + VideoProcessor.write_buffered_frames_to_output(fps, args.keep_pngs, (out_w, out_h)) + logger.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/blissful_tuner/utils.py b/blissful_tuner/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ae4501eb082dad83b03b4c4f48f502bc79f0cd5e --- /dev/null +++ b/blissful_tuner/utils.py @@ -0,0 +1,324 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Utility functions for Blissful Tuner extension +License: Apache 2.0 +Created on Sat Apr 12 14:09:37 2025 + +@author: blyss +""" +import argparse +import hashlib +import torch +import safetensors +from typing import List, Union, Dict, Tuple, Optional +import logging +from rich.logging import RichHandler + + +# Adapted from ComfyUI +def load_torch_file( + ckpt: str, + safe_load: Optional[bool] = True, + device: Optional[Union[str, torch.device]] = None, + return_metadata: Optional[bool] = False +) -> Union[ + Dict[str, torch.Tensor], + Tuple[Dict[str, torch.Tensor], Optional[Dict[str, str]]] +]: + if device is None: + device = torch.device("cpu") + metadata = None + if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"): + try: + with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f: + sd = {} + for k in f.keys(): + sd[k] = f.get_tensor(k) + if return_metadata: + metadata = f.metadata() + except Exception as e: + if len(e.args) > 0: + message = e.args[0] + if "HeaderTooLarge" in message: + raise ValueError("{}\n\nFile path: {}\n\nThe safetensors file is corrupt or invalid. Make sure this is actually a safetensors file and not a ckpt or pt or other filetype.".format(message, ckpt)) + if "MetadataIncompleteBuffer" in message: + raise ValueError("{}\n\nFile path: {}\n\nThe safetensors file is corrupt/incomplete. 
Check the file size and make sure you have copied/downloaded it correctly.".format(message, ckpt)) + raise e + else: + + pl_sd = torch.load(ckpt, map_location=device, weights_only=safe_load) + + if "state_dict" in pl_sd: + sd = pl_sd["state_dict"] + else: + if len(pl_sd) == 1: + key = list(pl_sd.keys())[0] + sd = pl_sd[key] + if not isinstance(sd, dict): + sd = pl_sd + else: + sd = pl_sd + return (sd, metadata) if return_metadata else sd + + +def add_noise_to_reference_video( + image: torch.Tensor, + ratio: Optional[float] = None +) -> torch.Tensor: + """ + Add Gaussian noise (scaled by `ratio`) to an image or batch of images. + Supports: + • Single image: (C, H, W) + • Batch of images: (B, C, H, W) + Any pixel exactly == –1 will have zero noise (mask value). + """ + if ratio is None or ratio == 0.0: + return image + + dims = image.ndim + if dims == 3: + # Single image -> make it a batch of 1 + image = image.unsqueeze(0) # -> (1, C, H, W) + squeeze_back = True + elif dims == 4: + squeeze_back = False + else: + raise ValueError( + f"add_noise_to_reference_video() expected 3D or 4D tensor, got {dims}D" + ) + + # image is now (B, C, H, W) + B, C, H, W = image.shape + # make a (B,) sigma array, all = ratio + sigma = image.new_ones((B,)) * ratio + # sample noise and scale by sigma + noise = torch.randn_like(image) * sigma.view(B, 1, 1, 1) + # zero out noise wherever the original was -1 + noise = torch.where(image == -1, torch.zeros_like(image), noise) + + out = image + noise + return out.squeeze(0) if squeeze_back else out + + +# Below here, Blyss wrote it! +class BlissfulLogger: + def __init__(self, logging_source: str, log_color: str, do_announce: Optional[bool] = False): + logging_source = f"{logging_source}" + self.logging_source = "{:<8}".format(logging_source) + self.log_color = log_color + self.logger = logging.getLogger(self.logging_source) + self.logger.setLevel(logging.DEBUG) + + self.handler = RichHandler( + show_time=False, + show_level=True, + show_path=True, + rich_tracebacks=True, + markup=True + ) + + formatter = logging.Formatter( + f"[{self.log_color} bold]%(name)s[/] | %(message)s [dim](%(funcName)s)[/]" + ) + + self.handler.setFormatter(formatter) + self.logger.addHandler(self.handler) + if do_announce: + self.logger.info("Set up logging!") + + def set_color(self, new_color): + self.log_color = new_color + formatter = logging.Formatter( + f"[{self.log_color} bold]%(name)s[/] | %(message)s [dim](%(funcName)s)[/]" + ) + self.handler.setFormatter(formatter) + + def set_name(self, new_name): + self.logging_source = "{:<8}".format(new_name) + self.logger = logging.getLogger(self.logging_source) + self.logger.setLevel(logging.DEBUG) + + # Remove any existing handlers (just in case) + if not self.logger.hasHandlers(): + self.logger.addHandler(self.handler) + else: + self.logger.handlers.clear() + self.logger.addHandler(self.handler) + + def info(self, msg): + self.logger.info(msg, stacklevel=2) + + def debug(self, msg): + self.logger.debug(msg, stacklevel=2) + + def warning(self, msg, levelmod=0): + self.logger.warning(msg, stacklevel=2 + levelmod) + + def warn(self, msg): + self.logger.warning(msg, stacklevel=2) + + def error(self, msg): + self.logger.error(msg, stacklevel=2) + + def critical(self, msg): + self.logger.critical(msg, stacklevel=2) + + def setLevel(self, level): + self.logger.set_level(level) + + +def parse_scheduled_cfg(schedule: str, infer_steps: int, guidance_scale: int) -> List[int]: + """ + Parse a schedule string like "1-10,20,!5,e~3" into a sorted list of 
steps. + + - "start-end" includes all steps in [start, end] + - "e~n" includes every nth step (n, 2n, ...) up to infer_steps + - "x" includes the single step x + - Prefix "!" on any token to exclude those steps instead of including them. + - Postfix ":float" e.g. ":6.0" to any step or range to specify a guidance_scale override for that step + + Raises argparse.ArgumentTypeError on malformed tokens or out-of-range steps. + """ + excluded = set() + guidance_scale_dict = {} + + for raw in schedule.split(","): + token = raw.strip() + if not token: + continue # skip empty tokens + + # exclusion if it starts with "!" + if token.startswith("!"): + target = "exclude" + token = token[1:] + else: + target = "include" + + weight = guidance_scale + if ":" in token: + token, float_part = token.rsplit(":", 1) + weight = float(float_part) + + # modulus syntax: e.g. "e~3" + if token.startswith("e~"): + num_str = token[2:] + try: + n = int(num_str) + except ValueError: + raise argparse.ArgumentTypeError(f"Invalid modulus in '{raw}'") + if n < 1: + raise argparse.ArgumentTypeError(f"Modulus must be ≥ 1 in '{raw}'") + + steps = range(n, infer_steps + 1, n) + + # range syntax: e.g. "5-10" + elif "-" in token: + parts = token.split("-") + if len(parts) != 2: + raise argparse.ArgumentTypeError(f"Malformed range '{raw}'") + start_str, end_str = parts + try: + start = int(start_str) + end = int(end_str) + except ValueError: + raise argparse.ArgumentTypeError(f"Non‑integer in range '{raw}'") + if start < 1 or end < 1: + raise argparse.ArgumentTypeError(f"Steps must be ≥ 1 in '{raw}'") + if start > end: + raise argparse.ArgumentTypeError(f"Start > end in '{raw}'") + if end > infer_steps: + raise argparse.ArgumentTypeError(f"End > infer_steps ({infer_steps}) in '{raw}'") + + steps = range(start, end + 1) + + # single‑step syntax: e.g. 
"7" + else: + try: + step = int(token) + except ValueError: + raise argparse.ArgumentTypeError(f"Invalid token '{raw}'") + if step < 1 or step > infer_steps: + raise argparse.ArgumentTypeError(f"Step {step} out of range 1–{infer_steps} in '{raw}'") + + steps = [step] + + # apply include/exclude + if target == "include": + for step in steps: + guidance_scale_dict[step] = weight + else: + excluded.update(steps) + + for step in excluded: + guidance_scale_dict.pop(step, None) + return guidance_scale_dict + + +def setup_compute_context(device: Optional[Union[torch.device, str]] = None, dtype: Optional[Union[torch.dtype, str]] = None) -> Tuple[torch.device, torch.dtype]: + dtype_mapping = { + "fp16": torch.float16, + "float16": torch.float16, + "bf16": torch.bfloat16, + "bfloat16": torch.bfloat16, + "fp32": torch.float32, + "float32": torch.float32, + "fp8": torch.float8_e4m3fn, + "float8": torch.float8_e4m3fn + } + if device is None: + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda") + elif torch.mps.is_available(): + device = torch.device("mps") + elif isinstance(device, str): + device = torch.device(device) + + if dtype is None: + dtype = torch.float32 + elif isinstance(dtype, str): + if dtype not in dtype_mapping: + raise ValueError(f"Unknown dtype string '{dtype}'") + dtype = dtype_mapping[dtype] + + torch.set_float32_matmul_precision('high') + if dtype == torch.float16 or dtype == torch.bfloat16: + if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"): + torch.backends.cuda.matmul.allow_fp16_accumulation = True + print("FP16 accumulation enabled.") + return device, dtype + + +def string_to_seed(s: str, bits: int = 63) -> int: + """ + Turn any string into a reproducible integer in [0, 2**bits) with a hash and some other logic. + + Args: + s: Input string + bits: Number of bits for the final seed (PyTorch accepts up to 63 safely, numpy likes 32) + Returns: + A non-negative int < 2**bits + """ + digest = hashlib.sha256(s.encode("utf-8")).digest() + crypto = int.from_bytes(digest, byteorder="big") + mask = (1 << bits) - 1 + algo = 0 + for i, char in enumerate(s): + char_val = ord(char) + if i % 2 == 0: + algo *= char_val + elif i % 3 == 0: + algo -= char_val + elif i % 5 == 0: + algo /= char_val + else: + algo += char_val + seed = (abs(crypto - int(algo))) & mask + return seed + + +def error_out(error, message): + logger = BlissfulLogger(__name__, "#8e00ed") + logger.warning(message, levelmod=1) + raise error(message) diff --git a/blissful_tuner/video_processing_common.py b/blissful_tuner/video_processing_common.py new file mode 100644 index 0000000000000000000000000000000000000000..7d7bd7b97e5766b2d2cf4f7630d69cb170cbe45e --- /dev/null +++ b/blissful_tuner/video_processing_common.py @@ -0,0 +1,435 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Common video processing utilities for Blissful Tuner extension. 
+ +License: Apache-2.0 +Created on Thu Apr 24 11:29:37 2025 +Author: Blyss +""" +import argparse +import glob +import os +import random +import shutil +import subprocess +from pathlib import Path +from typing import List, Tuple, Union, Optional +from einops import rearrange +import torchvision +from rich_argparse import RichHelpFormatter +from PIL import Image, UnidentifiedImageError +import cv2 +import numpy as np +import torch +try: + from blissful_tuner.utils import BlissfulLogger, string_to_seed +except ImportError: # This is needed so we can import either within blissful_tuner directory or base musubi directory + from utils import BlissfulLogger, string_to_seed + + +logger = BlissfulLogger(__name__, "#8e00ed") + + +def set_seed(seed: Union[int, str] = None) -> int: + """ + Sets the random seed for reproducibility. + """ + if seed is None: + seed = random.getrandbits(32) + else: + try: + seed = int(seed) + except ValueError: + seed = string_to_seed(seed, bits=32) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + return seed + + +def setup_parser_video_common(description: Optional[str] = None) -> argparse.ArgumentParser: + "Common function for setting up the parser for GIMM-VFI, upscaler, and face fix" + parser = argparse.ArgumentParser(description=description, formatter_class=RichHelpFormatter) + parser.add_argument("--model", required=True, help="Path to the model(directory for GIMM-VFI, .safetensors otherwise)") + parser.add_argument("--input", required=True, help="Input video/image to process") + parser.add_argument("--dtype", type=str, default="fp32", help="Datatype to use") + parser.add_argument( + "--output", type=str, default=None, + help="Output file path, default is same path as input. Extension may be changed to match chosen settings!" + ) + parser.add_argument("--seed", type=str, default=None, help="Seed for reproducibility") + parser.add_argument("--keep_pngs", action="store_true", help="Also keep individual frames as PNGs") + parser.add_argument( + "--codec", choices=["prores", "h264", "h265"], default="prores", + help="Codec to use, choose from 'prores', 'h264', or 'h265'. Ignored for images." + ) + parser.add_argument( + "--container", choices=["mkv", "mp4"], default="mkv", + help="Container format to use, choose from 'mkv' or 'mp4'. Note prores can only go in MKV! Ignored for images." + ) + return parser + + +class BlissfulVideoProcessor: + """ + Manager for working with images and video in generative AI workloads + """ + + def __init__(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None: + """ + Initialize with a target device and dtype for tensor operations. + + Args: + device: torch.device (e.g. cuda or cpu). + dtype: torch.dtype (e.g. torch.float32, torch.float16). 
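`set_seed` above accepts an integer, a numeric string, or any free-form string; non-numeric strings fall back to `string_to_seed`, which hashes the text down to a 32-bit value, so the same phrase always reproduces the same RNG state for `random`, NumPy, and torch. A quick sketch, assuming the package layout of this patch and its dependencies installed:

```python
from blissful_tuner.video_processing_common import set_seed  # path assumed from this patch

seed_a = set_seed("1234")             # numeric strings are used as-is -> 1234
seed_b = set_seed("a cat in a hat")   # free text is hashed to a 32-bit seed
seed_c = set_seed("a cat in a hat")

assert seed_b == seed_c               # same text, same seed, same RNG state
print(seed_a, seed_b)
```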
+ """ + self.device = device if device is not None else torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + self.dtype = dtype if dtype is not None else torch.float32 + self.png_idx = 0 + self.frame_dir = "" + self.input_file_path = "" + self.output_file_path = "" + self.output_directory = "" + self.new_ext = ".mkv" + self.codec = "prores" + + def prepare_files_and_path( + self, + input_file_path: Optional[str] = None, + output_file_path: Optional[str] = None, + modifier: Optional[str] = "", + codec: Optional[str] = "prores", + container: Optional[str] = "mkv", + ) -> Tuple[str, str]: + """ + Determine and confirm input/output paths, generating a default output + name if none provided, and set up the frames directory path and codec/container. + + Args: + input_file_path: Path to the source video. + output_file_path: Desired output path or None to auto-generate. + modifier: Suffix to append to the basename when auto-generating. + codec: The video codec to use(ignored for images) + container: The container format to use(ignored for images) + + Returns: + A tuple of (input_file_path, output_file_path). + """ + def _is_image_file(path: Path) -> bool: + try: + with Image.open(path) as img: + img.verify() + return True + except (UnidentifiedImageError, OSError): + return False + if codec is not None: + if codec.lower() in ["prores", "h264", "h265"]: + self.codec = codec.lower() + else: + raise ValueError("Invalid codec requested {codec}! Expected 'prores', 'h264', or 'h265'!") + if container is not None: + if container.lower() == "mkv": + self.new_ext = ".mkv" + elif container.lower() == "mp4": + if self.codec != "prores": + self.new_ext = ".mp4" + else: + logger.warning("Prores can only be written into an mkv but mp4 was passed! Selecting mkv and continuing...") + else: + raise ValueError("Invalid container format {container}! Expected 'mkv' or 'mp4'!") + if input_file_path is not None: + basename = os.path.basename(input_file_path) + name, _ = os.path.splitext(basename) + output_dir = os.path.dirname(input_file_path) + is_image = _is_image_file(input_file_path) + if is_image: + self.new_ext = ".png" + self.codec = "png" + elif output_file_path is not None: + output_dir = os.path.dirname(output_file_path) + else: + raise ValueError("At least one of input_file_path or output_file_path must be provided!") + + if not output_file_path: + output_file_path = os.path.join(output_dir, f"{name}_{modifier}{self.new_ext}") + o_basename = os.path.basename(output_file_path) + o_name, o_ext = os.path.splitext(o_basename) + o_output_dir = os.path.dirname(output_file_path) + if o_ext != self.new_ext: + logger.warning(f"Extension '{o_ext[-3:]}' not valid for output! Updating to '{self.new_ext[-3:]}'...") + output_file_path = os.path.join(o_output_dir, f"{o_name}{self.new_ext}") + + if os.path.exists(output_file_path): + choice = input(f"{output_file_path} exists. F for 'fix' by appending _! 
Overwrite?[y/N/f]: ").strip().lower() + if choice == 'f': + base = o_name + while os.path.exists(output_file_path): + base += '_' + output_file_path = os.path.join(o_output_dir, f"{base}{self.new_ext}") + elif choice != 'y': + logger.info("Aborted.") + exit() + + self.input_file_path = input_file_path + self.output_file_path = output_file_path + self.output_directory = output_dir + self.frame_dir = os.path.join(self.output_directory, 'frames') + if os.path.exists(self.frame_dir): + while os.path.exists(self.frame_dir): + self.frame_dir += "_" + + logger.info(f"Output will be saved to: {self.output_file_path} using {self.codec}!") + return self.input_file_path, self.output_file_path + + def np_image_to_tensor( + self, + image: Union[np.ndarray, List[np.ndarray]] + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """ + Convert a single H×W×3 numpy image or list of images (RGB uint8 or float32) + into torch tensors of shape 1×3×H×W in [0,1], on the configured device and dtype. + + Args: + image: An RGB image array or list of arrays. + + Returns: + A torch.Tensor or list of torch.Tensors. + """ + def _convert(img: np.ndarray) -> torch.Tensor: + arr = img.astype(np.float32) / 255.0 + tensor = torch.from_numpy(arr.transpose(2, 0, 1)) + return tensor.unsqueeze(0).to(self.device, self.dtype) + + if isinstance(image, np.ndarray): + return _convert(image) + return [_convert(img) for img in image] + + def tensor_to_np_image( + self, + tensor: Union[torch.Tensor, List[torch.Tensor]], + rescale: bool = False + ) -> Union[np.ndarray, List[np.ndarray]]: + """ + Convert a 1×3×H×W or 3×H×W torch tensor (RGB float in [0,1] or [-1,1]) + into H×W×3 uint8 BGR images suitable for OpenCV (and do rescale if needed). + + Args: + tensor: A torch.Tensor or list of torch.Tensors. + rescale: If True, assumes the tensor is in [-1,1] and remaps to [0,1]. + Returns: + A numpy BGR image or list of images. + """ + def _convert(t: torch.Tensor) -> np.ndarray: + # 1) Bring to CPU, float, clamp + t = t.detach().cpu().float() + # 2) Optional range shift from [-1,1] to [0,1] + if rescale: + t = (t + 1.0) / 2.0 + t = t.clamp(0.0, 1.0) + + # 3) Normalize shape to [1,3,H,W] + if t.ndim == 3: # [3,H,W] + t = t.unsqueeze(0) # -> [1,3,H,W] + elif t.ndim != 4 or t.shape[1] != 3: + raise ValueError(f"Unexpected tensor shape: {tuple(t.shape)}") + + # 4) Squeeze batch, permute to H×W×C, scale to 0–255 + t = t.squeeze(0) # [3,H,W] + img = (t.permute(1, 2, 0).numpy() * 255.0).round().astype(np.uint8) # [H,W,3] + + # 5) Flip RGB→BGR for OpenCV + return img[..., ::-1] + + if isinstance(tensor, torch.Tensor): + return _convert(tensor) + return [_convert(t) for t in tensor] + + def load_frames( + self, + make_rgb: Optional[bool] = False + ) -> Tuple[List[np.ndarray], float, int, int]: + """ + Load all frames from the input video/image as uint8 BGR or RGB numpy arrays. + + Args: + make_rgb: If True, convert frames to RGB. + + Returns: + frames: List of H×W×3 image arrays. + fps: Frame rate of the video. + width: Original width. + height: Original height. 
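`np_image_to_tensor` and `tensor_to_np_image` above are intended as a round trip: HxWx3 RGB uint8 in, 1x3xHxW float in [0, 1] on the configured device/dtype, and back out as HxWx3 uint8 BGR ready for OpenCV. A minimal CPU sketch of that round trip:

```python
import numpy as np
import torch
from blissful_tuner.video_processing_common import BlissfulVideoProcessor

vp = BlissfulVideoProcessor(torch.device("cpu"), torch.float32)

rgb = (np.random.rand(64, 48, 3) * 255).astype(np.uint8)   # H, W, 3 RGB input
t = vp.np_image_to_tensor(rgb)                             # 1, 3, H, W in [0, 1]
bgr = vp.tensor_to_np_image(t)                             # H, W, 3 uint8, BGR for cv2

print(t.shape, t.min().item(), t.max().item())
# The round trip only swaps channel order, so flipping BGR back recovers the input.
assert np.array_equal(bgr[..., ::-1], rgb)
```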
+ """ + cap = cv2.VideoCapture(self.input_file_path) + fps = cap.get(cv2.CAP_PROP_FPS) + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + frames: List[np.ndarray] = [] + + while True: + ret, frame = cap.read() + if not ret: + break + if make_rgb: + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frames.append(frame) + + cap.release() + return frames, fps, width, height + + def write_np_or_tensor_to_png( + self, + img: Union[np.ndarray, torch.Tensor] + ) -> None: + """ + Write a single frame (numpy BGR or tensor) to the frames directory as PNG. + + Args: + img: A BGR uint8 image array or a tensor to convert. + """ + if isinstance(img, torch.Tensor): + img = self.tensor_to_np_image(img) + if self.png_idx == 0: + os.makedirs(self.frame_dir, exist_ok=False) + path = os.path.join(self.frame_dir, f"{self.png_idx:06d}.png") + cv2.imwrite(path, img) + self.png_idx += 1 + + def write_np_images_to_output( + self, + imgs: List[np.ndarray], + fps: Optional[float] = 1, + keep_frames: Optional[bool] = False, + rescale: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Dump a list of BGR frames as PNGs + + Args: + imgs: List of H×W×3 uint8 BGR frames. + fps: Output frame rate. + rescale: To resize the output + keep_frames: If True, do not delete PNGs afterward. + """ + os.makedirs(self.frame_dir, exist_ok=False) + for idx, img in enumerate(imgs): + path = os.path.join(self.frame_dir, f"{idx:06d}.png") + cv2.imwrite(path, img) + self.write_buffered_frames_to_output(fps, keep_frames, rescale) + + def write_buffered_frames_to_output( + self, + fps: Optional[float] = 1, + keep_frames: Optional[bool] = False, + rescale: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Encode the PNG sequence in the frames directory to a video via ffmpeg, + or—if there's only one frame—just write out an (optionally-rescaled) PNG. + """ + # 1) get all the PNGs + pattern = os.path.join(self.frame_dir, "*.png") + png_paths = sorted(glob.glob(pattern)) + + # 2) single-image case + if len(png_paths) == 1: + src = png_paths[0] + + if rescale is None: + # just copy the original + shutil.copy(src, self.output_file_path) + else: + # PIL approach: open, resize, save + width, height = rescale + with Image.open(src) as img: + # LANCZOS gives a high-quality down/upscale + img = img.resize((width, height), Image.LANCZOS) + img.save(self.output_file_path) + else: + # 3) multi‐frame → video + codec_args = self._get_ffmpeg_codec_args() + cmd = [ + "ffmpeg", "-framerate", str(fps), + "-i", os.path.join(self.frame_dir, "%06d.png"), + ] + codec_args + + if rescale is not None: + w, h = rescale + cmd += ["-vf", f"scale={w}:{h}"] + + # overwrite without prompt + cmd += ["-y", self.output_file_path] + + subprocess.run(cmd, check=True) + if not keep_frames: + shutil.rmtree(self.frame_dir, ignore_errors=True) + + def _get_ffmpeg_codec_args(self) -> List[str]: + """ + Return the ffmpeg args for codec/quality based on self.codec. 
+ """ + if self.codec == "prores": + # prores_ks profile 3 + broadcast-safe colors + return [ + "-c:v", "prores_ks", + "-profile:v", "3", + "-pix_fmt", "yuv422p10le", + "-colorspace", "1", + "-color_primaries", "1", + "-color_trc", "1", + ] + if self.codec == "h264": + # libx264 + return [ + "-c:v", "libx264", + "-preset", "slow", + "-crf", "16", + "-pix_fmt", "yuv420p", + ] + if self.codec == "h265": + # libx265 + return [ + "-c:v", "libx265", + "-preset", "slow", + "-crf", "16", + "-pix_fmt", "yuv420p", + ] + raise ValueError(f"Unsupported codec: {self.codec}") + + +def save_videos_grid_advanced( + videos: torch.Tensor, + output_video: str, + codec: str, + container: str, + rescale: bool = False, + fps: int = 24, + n_rows: int = 1, + keep_frames: bool = False +): + "Function for saving Musubi Tuner outputs with more codec and container types" + + # 1) rearrange so we iterate over time + videos = rearrange(videos, "b c t h w -> t b c h w") + + VideoProcessor = BlissfulVideoProcessor() + VideoProcessor.prepare_files_and_path( + input_file_path=None, + output_file_path=output_video, + codec=codec, + container=container + ) + + outputs = [] + for video in videos: + # 2) tile frames into one grid [C, H, W] + grid = torchvision.utils.make_grid(video, nrow=n_rows) + # 3) convert to an OpenCV-ready numpy array + np_img = VideoProcessor.tensor_to_np_image(grid, rescale=rescale) + outputs.append(np_img) + + # 4) write them out + VideoProcessor.write_np_images_to_output(outputs, fps, keep_frames) diff --git a/blissful_tuner/widgets.py b/blissful_tuner/widgets.py new file mode 100644 index 0000000000000000000000000000000000000000..ba59f9a9b8a3c69182ee6f90a350b6f37dfb0086 --- /dev/null +++ b/blissful_tuner/widgets.py @@ -0,0 +1,344 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Wed Mar 12 13:59:54 2025 + +@author: blyss +""" +from PySide6.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QDial, QLabel, QLineEdit, QDialog, QPushButton, QFileDialog, QFormLayout, QCheckBox, QTextEdit +from PySide6.QtCore import Qt, Signal, QEvent, QLineF +from PySide6.QtGui import QPainter, QMouseEvent, QIntValidator +from blissful_tuner.blissful_settings import BlissfulSettings + + +class PromptWidget(QWidget): + def __init__(self, global_settings): + super().__init__() + self.global_settings = global_settings + + main_layout = QVBoxLayout(self) + + prompt_label = QLabel("Prompt:") + self.prompt_edit = QTextEdit() + self.prompt_edit.setText(self.global_settings.prompt) + self.prompt_edit.textChanged.connect(self.on_prompt_text_changed) + + main_layout.addWidget(prompt_label) + main_layout.addWidget(self.prompt_edit) + + def on_prompt_text_changed(self): + text = self.prompt_edit.toPlainText() + self.global_settings.update("prompt", text) + + +class SettingsDialog(QDialog): + def __init__(self, parent=None): + super().__init__(parent) + self.global_settings = BlissfulSettings() + self.setWindowTitle("Settings") + self.resize(500, 300) + main_layout = QVBoxLayout(self) + + # QFormLayout with right-aligned labels. 
+ form_layout = QFormLayout() + form_layout.setLabelAlignment(Qt.AlignRight) + self.path_edits = {} + + settings = ["Transformer", "Text Encoder 1", "Text Encoder 2", "VAE", "LoRAs"] + + self.attr_mapping = { + "Transformer": "transformer_path", + "Text Encoder 1": "text_encoder_1_path", + "Text Encoder 2": "text_encoder_2_path", + "VAE": "vae_path", + "LoRAs": "lora_path" + } + + for setting in settings: + container = QWidget() + container_layout = QHBoxLayout(container) + container_layout.setContentsMargins(0, 0, 0, 0) + + line_edit = QLineEdit() + line_edit.setMinimumWidth(300) + current_value = getattr(self.global_settings, self.attr_mapping[setting], "") + line_edit.setText(current_value) + + browse_button = QPushButton("Browse") + browse_button.clicked.connect(lambda checked, le=line_edit, s=setting: self.open_file_or_folder(le, s)) + + container_layout.addWidget(line_edit) + container_layout.addWidget(browse_button) + + self.path_edits[setting] = line_edit + form_layout.addRow(QLabel(f"{setting}:"), container) + + main_layout.addLayout(form_layout) + + button_layout = QHBoxLayout() + button_layout.addStretch() + ok_button = QPushButton("OK") + cancel_button = QPushButton("Cancel") + ok_button.clicked.connect(self.accept) + cancel_button.clicked.connect(self.reject) + button_layout.addWidget(ok_button) + button_layout.addWidget(cancel_button) + main_layout.addLayout(button_layout) + + def open_file_or_folder(self, line_edit, setting): + """ + Opens a file dialog for settings other than LoRAs, and a folder dialog for LoRAs. + """ + if setting == "LoRAs": + folder = QFileDialog.getExistingDirectory(self, "Select Folder") + if folder: + line_edit.setText(folder) + else: + file, _ = QFileDialog.getOpenFileName(self, "Select File") + if file: + line_edit.setText(file) + + def accept(self): + for key, value in self.path_edits.items(): + setattr(self.global_settings, self.attr_mapping[key], value.text()) + super().accept() + + +class SeedWidget(QWidget): + def __init__(self, global_settings): + super().__init__() + self.global_settings = global_settings + + main_layout = QVBoxLayout(self) + self.line_edit = QLineEdit(str(self.global_settings.seed)) + self.line_edit.setFixedWidth(80) + self.line_edit.setValidator(QIntValidator(-999999999, 999999999, self)) + self.line_edit.textChanged.connect(lambda value: self.global_settings.update("seed", value)) + main_layout.addWidget(QLabel("Seed"), alignment=Qt.AlignCenter) + main_layout.addWidget(self.line_edit, alignment=Qt.AlignCenter) + + checkbox_layout = QHBoxLayout() + self.checkbox = QCheckBox("Random") + self.checkbox.toggled.connect(self.on_checkbox_toggled) + checkbox_layout.addWidget(self.checkbox, alignment=Qt.AlignCenter) + + main_layout.addLayout(checkbox_layout) + + def on_checkbox_toggled(self, checked): + self.line_edit.setText("-1") + self.line_edit.setDisabled(checked) + self.global_settings.seed = -1 + + +class BlissfulDial(QDial): + """Wraps the QDial so we can write the value to it's center like a boss(or like someone who cares about UI design)""" + + def paintEvent(self, event): + super().paintEvent(event) + qp = QPainter(self) + qp.setRenderHint(QPainter.RenderHint.Antialiasing) + rect = self.rect() + value = self.value() + qp.drawText(rect, Qt.AlignmentFlag.AlignCenter, str(value)) + + +class ValueDial(QWidget): + """A QDial that displays values on it's notches, ported from Qt5 version here https://stackoverflow.com/questions/63698714/how-to-show-markings-on-qdial-in-pyqt5-python""" + _dialProperties = ('minimum', 'maximum', 
'value', 'singleStep', 'pageStep', + 'notchesVisible', 'tracking', 'wrapping', + 'invertedAppearance', 'invertedControls', 'orientation') + _inPadding = 3 + _outPadding = 2 + valueChanged = Signal(int) + + def __init__(self, *args, **kwargs): + # Remove properties used as keyword arguments for the dial. + dialArgs = {k: v for k, v in kwargs.items() if k in self._dialProperties} + for k in dialArgs.keys(): + kwargs.pop(k) + super().__init__(*args, **kwargs) + layout = QVBoxLayout(self) + self.dial = BlissfulDial(self, **dialArgs) + layout.addWidget(self.dial) + self.dial.valueChanged.connect(self.valueChanged) + # Make the dial the focus proxy (so that it captures focus *and* key events) + self.setFocusProxy(self.dial) + + # Simple "monkey patching" to access dial functions. + self.value = self.dial.value + self.setValue = self.dial.setValue + self.minimum = self.dial.minimum + self.maximum = self.dial.maximum + self.wrapping = self.dial.wrapping + self.notchesVisible = self.dial.notchesVisible + self.setNotchesVisible = self.dial.setNotchesVisible + self.setNotchTarget = self.dial.setNotchTarget + self.notchSize = self.dial.notchSize + self.invertedAppearance = self.dial.invertedAppearance + self.setInvertedAppearance = self.dial.setInvertedAppearance + + self.updateSize() + + def inPadding(self): + return self._inPadding + + def setInPadding(self, padding): + self._inPadding = max(0, padding) + self.updateSize() + + def outPadding(self): + return self._outPadding + + def setOutPadding(self, padding): + self._outPadding = max(0, padding) + self.updateSize() + + def setMinimum(self, minimum): + self.dial.setMinimum(minimum) + self.updateSize() + + def setMaximum(self, maximum): + self.dial.setMaximum(maximum) + self.updateSize() + + def setWrapping(self, wrapping): + self.dial.setWrapping(wrapping) + self.updateSize() + + def updateSize(self): + # Update margins so that the value strings always have enough space. + fm = self.fontMetrics() + minWidth = max(fm.horizontalAdvance(str(v)) for v in range(self.minimum(), self.maximum() + 1)) + self.offset = max(minWidth, fm.height()) / 2 + margin = self.offset + self._inPadding + self._outPadding + self.layout().setContentsMargins(margin, margin, margin, margin) + + def translateMouseEvent(self, event): + # Translate mouse events to the dial. + return QMouseEvent( + event.type(), + self.dial.mapFrom(self, event.pos()), + event.button(), + event.buttons(), + event.modifiers() + ) + + def changeEvent(self, event): + if event.type() == QEvent.Type.FontChange: + self.updateSize() + + def mousePressEvent(self, event): + self.dial.mousePressEvent(self.translateMouseEvent(event)) + + def mouseMoveEvent(self, event): + self.dial.mouseMoveEvent(self.translateMouseEvent(event)) + + def mouseReleaseEvent(self, event): + self.dial.mouseReleaseEvent(self.translateMouseEvent(event)) + + def paintEvent(self, event): + radius = min(self.width(), self.height()) / 2 + radius -= (self.offset / 2 + self._outPadding) + invert = -1 if self.invertedAppearance() else 1 + if self.wrapping(): + angleRange = 360 + startAngle = 270 + rangeOffset = 0 + else: + angleRange = 300 + startAngle = 240 if invert > 0 else 300 + rangeOffset = 1 + fm = self.fontMetrics() + + # Reference line for positioning the text. 
+ reference = QLineF.fromPolar(radius, 0).translated(self.rect().center()) + fullRange = self.maximum() - self.minimum() + textRect = self.rect() + + qp = QPainter(self) + qp.setRenderHints(QPainter.RenderHint.Antialiasing) + label_interval = 4 # Print every Nth numberal label + for p in range(0, fullRange + rangeOffset, self.notchSize() * label_interval): + value = self.minimum() + p + if invert < 0: + value -= 1 + if value < self.minimum(): + continue + angle = p / fullRange * angleRange * invert + reference.setAngle(startAngle - angle) + textRect.setSize(fm.size(Qt.TextFlag.TextSingleLine, str(value))) + textRect.moveCenter(reference.p2().toPoint()) + qp.drawText(textRect, Qt.AlignmentFlag.AlignCenter, str(value)) + + +class ResolutionWidget(QWidget): + """Custom widget for specifying resolution and framerate with validation for the former""" + + def __init__(self, global_settings): + super().__init__() + self.global_settings = global_settings + self.initUI() + + def initUI(self): + layout = QHBoxLayout(self) + + self.width_input = QLineEdit() + self.width_input.setText(str(self.global_settings.resolution_x)) + self.width_input.setFixedWidth(60) + self.width_input.setValidator(QIntValidator(1, 9999, self)) + self.check_divisible(self.width_input) + self.width_input.textChanged.connect(lambda: self.check_divisible(self.width_input)) + self.width_input.textChanged.connect(lambda value: self.global_settings.update("resolution_x", value)) + + x_label = QLabel("x") + + self.height_input = QLineEdit() + self.height_input.setText(str(self.global_settings.resolution_y)) + self.height_input.setFixedWidth(60) + self.height_input.setValidator(QIntValidator(1, 9999, self)) + self.check_divisible(self.height_input) + self.height_input.textChanged.connect(lambda: self.check_divisible(self.height_input)) + self.height_input.textChanged.connect(lambda value: self.global_settings.update("resolution_y", value)) + + at_label = QLabel("@") + + self.fps_input = QLineEdit() + self.fps_input.setText(str(self.global_settings.fps)) + self.fps_input.setFixedWidth(25) + self.fps_input.setValidator(QIntValidator(1, 200, self)) + self.fps_input.textChanged.connect(lambda value: self.global_settings.update("fps", value)) + + fps_label = QLabel("fps") + + layout.addWidget(self.width_input, alignment=Qt.AlignLeft) + layout.addWidget(x_label, alignment=Qt.AlignLeft) + layout.addSpacing(-30) + layout.addWidget(self.height_input, alignment=Qt.AlignLeft) + layout.addWidget(at_label, alignment=Qt.AlignLeft) + layout.addSpacing(-30) + layout.addWidget(self.fps_input, alignment=Qt.AlignLeft) + layout.addSpacing(-20) + layout.addWidget(fps_label, alignment=Qt.AlignLeft) + + def check_divisible(self, line_edit: QLineEdit): + """ + Check if the number in the QLineEdit is divisible by 8. + If yes, set the background color to light green. + If not, set it to light red (light coral). 
+ """ + text = line_edit.text().strip() + if text: + try: + value = int(text) + if value % 8 == 0: + # Divisible by 8 -> light green background + line_edit.setStyleSheet("background-color: darkgreen;") + else: + # Not divisible by 8 -> light red background + line_edit.setStyleSheet("background-color: lightcoral;") + except ValueError: + # If conversion fails, mark as invalid (red background) + line_edit.setStyleSheet("background-color: lightcoral;") + else: + # Empty input, remove background color + line_edit.setStyleSheet("") diff --git a/cache_latents.py b/cache_latents.py new file mode 100644 index 0000000000000000000000000000000000000000..6f94853e5a7a865e23af7cd9239d5afab1bbad2d --- /dev/null +++ b/cache_latents.py @@ -0,0 +1,269 @@ +import argparse +import os +import glob +from typing import Optional, Union + +import numpy as np +import torch +from tqdm import tqdm + +from dataset import config_utils +from dataset.config_utils import BlueprintGenerator, ConfigSanitizer +from PIL import Image + +import logging + +from dataset.image_video_dataset import BaseDataset, ItemInfo, save_latent_cache +from hunyuan_model.vae import load_vae +from hunyuan_model.autoencoder_kl_causal_3d import AutoencoderKLCausal3D +from utils.model_utils import str_to_dtype + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +def show_image(image: Union[list[Union[Image.Image, np.ndarray], Union[Image.Image, np.ndarray]]]) -> int: + import cv2 + + imgs = ( + [image] + if (isinstance(image, np.ndarray) and len(image.shape) == 3) or isinstance(image, Image.Image) + else [image[0], image[-1]] + ) + if len(imgs) > 1: + print(f"Number of images: {len(image)}") + for i, img in enumerate(imgs): + if len(imgs) > 1: + print(f"{'First' if i == 0 else 'Last'} image: {img.shape}") + else: + print(f"Image: {img.shape}") + cv2_img = np.array(img) if isinstance(img, Image.Image) else img + cv2_img = cv2.cvtColor(cv2_img, cv2.COLOR_RGB2BGR) + cv2.imshow("image", cv2_img) + k = cv2.waitKey(0) + cv2.destroyAllWindows() + if k == ord("q") or k == ord("d"): + return k + return k + + +def show_console( + image: Union[list[Union[Image.Image, np.ndarray], Union[Image.Image, np.ndarray]]], + width: int, + back: str, + interactive: bool = False, +) -> int: + from ascii_magic import from_pillow_image, Back + + back = None + if back is not None: + back = getattr(Back, back.upper()) + + k = None + imgs = ( + [image] + if (isinstance(image, np.ndarray) and len(image.shape) == 3) or isinstance(image, Image.Image) + else [image[0], image[-1]] + ) + if len(imgs) > 1: + print(f"Number of images: {len(image)}") + for i, img in enumerate(imgs): + if len(imgs) > 1: + print(f"{'First' if i == 0 else 'Last'} image: {img.shape}") + else: + print(f"Image: {img.shape}") + pil_img = img if isinstance(img, Image.Image) else Image.fromarray(img) + ascii_img = from_pillow_image(pil_img) + ascii_img.to_terminal(columns=width, back=back) + + if interactive: + k = input("Press q to quit, d to next dataset, other key to next: ") + if k == "q" or k == "d": + return ord(k) + + if not interactive: + return ord(" ") + return ord(k) if k else ord(" ") + + +def show_datasets( + datasets: list[BaseDataset], debug_mode: str, console_width: int, console_back: str, console_num_images: Optional[int] +): + print(f"d: next dataset, q: quit") + + num_workers = max(1, os.cpu_count() - 1) + for i, dataset in enumerate(datasets): + print(f"Dataset [{i}]") + batch_index = 0 + num_images_to_show = console_num_images + k = None + for key, batch in 
dataset.retrieve_latent_cache_batches(num_workers): + print(f"bucket resolution: {key}, count: {len(batch)}") + for j, item_info in enumerate(batch): + item_info: ItemInfo + print(f"{batch_index}-{j}: {item_info}") + if debug_mode == "image": + k = show_image(item_info.content) + elif debug_mode == "console": + k = show_console(item_info.content, console_width, console_back, console_num_images is None) + if num_images_to_show is not None: + num_images_to_show -= 1 + if num_images_to_show == 0: + k = ord("d") # next dataset + + if k == ord("q"): + return + elif k == ord("d"): + break + if k == ord("d"): + break + batch_index += 1 + + +def encode_and_save_batch(vae: AutoencoderKLCausal3D, batch: list[ItemInfo]): + contents = torch.stack([torch.from_numpy(item.content) for item in batch]) + if len(contents.shape) == 4: + contents = contents.unsqueeze(1) # B, H, W, C -> B, F, H, W, C + + contents = contents.permute(0, 4, 1, 2, 3).contiguous() # B, C, F, H, W + contents = contents.to(vae.device, dtype=vae.dtype) + contents = contents / 127.5 - 1.0 # normalize to [-1, 1] + + h, w = contents.shape[3], contents.shape[4] + if h < 8 or w < 8: + item = batch[0] # other items should have the same size + raise ValueError(f"Image or video size too small: {item.item_key} and {len(batch) - 1} more, size: {item.original_size}") + + # print(f"encode batch: {contents.shape}") + with torch.no_grad(): + latent = vae.encode(contents).latent_dist.sample() + # latent = latent * vae.config.scaling_factor + + # # debug: decode and save + # with torch.no_grad(): + # latent_to_decode = latent / vae.config.scaling_factor + # images = vae.decode(latent_to_decode, return_dict=False)[0] + # images = (images / 2 + 0.5).clamp(0, 1) + # images = images.cpu().float().numpy() + # images = (images * 255).astype(np.uint8) + # images = images.transpose(0, 2, 3, 4, 1) # B, C, F, H, W -> B, F, H, W, C + # for b in range(images.shape[0]): + # for f in range(images.shape[1]): + # fln = os.path.splitext(os.path.basename(batch[b].item_key))[0] + # img = Image.fromarray(images[b, f]) + # img.save(f"./logs/decode_{fln}_{b}_{f:03d}.jpg") + + for item, l in zip(batch, latent): + # print(f"save latent cache: {item.latent_cache_path}, latent shape: {l.shape}") + save_latent_cache(item, l) + + +def main(args): + device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device(device) + + # Load dataset config + blueprint_generator = BlueprintGenerator(ConfigSanitizer()) + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_utils.load_user_config(args.dataset_config) + blueprint = blueprint_generator.generate(user_config, args) + train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group) + + datasets = train_dataset_group.datasets + + if args.debug_mode is not None: + show_datasets(datasets, args.debug_mode, args.console_width, args.console_back, args.console_num_images) + return + + assert args.vae is not None, "vae checkpoint is required" + + # Load VAE model: HunyuanVideo VAE model is float16 + vae_dtype = torch.float16 if args.vae_dtype is None else str_to_dtype(args.vae_dtype) + vae, _, s_ratio, t_ratio = load_vae(vae_dtype=vae_dtype, device=device, vae_path=args.vae) + vae.eval() + logger.info(f"Loaded VAE: {vae.config}, dtype: {vae.dtype}") + + if args.vae_chunk_size is not None: + vae.set_chunk_size_for_causal_conv_3d(args.vae_chunk_size) + logger.info(f"Set chunk_size to {args.vae_chunk_size} for 
CausalConv3d in VAE") + if args.vae_spatial_tile_sample_min_size is not None: + vae.enable_spatial_tiling(True) + vae.tile_sample_min_size = args.vae_spatial_tile_sample_min_size + vae.tile_latent_min_size = args.vae_spatial_tile_sample_min_size // 8 + elif args.vae_tiling: + vae.enable_spatial_tiling(True) + + # Encode images + num_workers = args.num_workers if args.num_workers is not None else max(1, os.cpu_count() - 1) + for i, dataset in enumerate(datasets): + logger.info(f"Encoding dataset [{i}]") + all_latent_cache_paths = [] + for _, batch in tqdm(dataset.retrieve_latent_cache_batches(num_workers)): + all_latent_cache_paths.extend([item.latent_cache_path for item in batch]) + + if args.skip_existing: + filtered_batch = [item for item in batch if not os.path.exists(item.latent_cache_path)] + if len(filtered_batch) == 0: + continue + batch = filtered_batch + + bs = args.batch_size if args.batch_size is not None else len(batch) + for i in range(0, len(batch), bs): + encode_and_save_batch(vae, batch[i : i + bs]) + + # normalize paths + all_latent_cache_paths = [os.path.normpath(p) for p in all_latent_cache_paths] + all_latent_cache_paths = set(all_latent_cache_paths) + + # remove old cache files not in the dataset + all_cache_files = dataset.get_all_latent_cache_files() + for cache_file in all_cache_files: + if os.path.normpath(cache_file) not in all_latent_cache_paths: + if args.keep_cache: + logger.info(f"Keep cache file not in the dataset: {cache_file}") + else: + os.remove(cache_file) + logger.info(f"Removed old cache file: {cache_file}") + + +def setup_parser(): + parser = argparse.ArgumentParser() + + parser.add_argument("--dataset_config", type=str, required=True, help="path to dataset config .toml file") + parser.add_argument("--vae", type=str, required=False, default=None, help="path to vae checkpoint") + parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is float16") + parser.add_argument( + "--vae_tiling", + action="store_true", + help="enable spatial tiling for VAE, default is False. If vae_spatial_tile_sample_min_size is set, this is automatically enabled", + ) + parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE") + parser.add_argument( + "--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256" + ) + parser.add_argument("--device", type=str, default=None, help="device to use, default is cuda if available") + parser.add_argument( + "--batch_size", type=int, default=None, help="batch size, override dataset config if dataset batch size > this" + ) + parser.add_argument("--num_workers", type=int, default=None, help="number of workers for dataset. 
default is cpu count-1") + parser.add_argument("--skip_existing", action="store_true", help="skip existing cache files") + parser.add_argument("--keep_cache", action="store_true", help="keep cache files not in dataset") + parser.add_argument("--debug_mode", type=str, default=None, choices=["image", "console"], help="debug mode") + parser.add_argument("--console_width", type=int, default=80, help="debug mode: console width") + parser.add_argument( + "--console_back", type=str, default=None, help="debug mode: console background color, one of ascii_magic.Back" + ) + parser.add_argument( + "--console_num_images", + type=int, + default=None, + help="debug mode: not interactive, number of images to show for each dataset", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + main(args) diff --git a/cache_text_encoder_outputs.py b/cache_text_encoder_outputs.py new file mode 100644 index 0000000000000000000000000000000000000000..0496e12468afe9079bd4e6d72934802b5d180039 --- /dev/null +++ b/cache_text_encoder_outputs.py @@ -0,0 +1,166 @@ +import argparse +import os +from typing import Optional, Union + +import numpy as np +import torch +from tqdm import tqdm + +from dataset import config_utils +from dataset.config_utils import BlueprintGenerator, ConfigSanitizer +import accelerate + +from dataset.image_video_dataset import ItemInfo, save_text_encoder_output_cache +from hunyuan_model import text_encoder as text_encoder_module +from hunyuan_model.text_encoder import TextEncoder + +import logging + +from utils.model_utils import str_to_dtype + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +def encode_prompt(text_encoder: TextEncoder, prompt: Union[str, list[str]]): + data_type = "video" # video only, image is not supported + text_inputs = text_encoder.text2tokens(prompt, data_type=data_type) + + with torch.no_grad(): + prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type) + + return prompt_outputs.hidden_state, prompt_outputs.attention_mask + + +def encode_and_save_batch( + text_encoder: TextEncoder, batch: list[ItemInfo], is_llm: bool, accelerator: Optional[accelerate.Accelerator] +): + prompts = [item.caption for item in batch] + # print(prompts) + + # encode prompt + if accelerator is not None: + with accelerator.autocast(): + prompt_embeds, prompt_mask = encode_prompt(text_encoder, prompts) + else: + prompt_embeds, prompt_mask = encode_prompt(text_encoder, prompts) + + # # convert to fp16 if needed + # if prompt_embeds.dtype == torch.float32 and text_encoder.dtype != torch.float32: + # prompt_embeds = prompt_embeds.to(text_encoder.dtype) + + # save prompt cache + for item, embed, mask in zip(batch, prompt_embeds, prompt_mask): + save_text_encoder_output_cache(item, embed, mask, is_llm) + + +def main(args): + device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device(device) + + # Load dataset config + blueprint_generator = BlueprintGenerator(ConfigSanitizer()) + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_utils.load_user_config(args.dataset_config) + blueprint = blueprint_generator.generate(user_config, args) + train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group) + + datasets = train_dataset_group.datasets + + # define accelerator for fp8 inference + accelerator = None + if args.fp8_llm: + accelerator = 
accelerate.Accelerator(mixed_precision="fp16") + + # define encode function + num_workers = args.num_workers if args.num_workers is not None else max(1, os.cpu_count() - 1) + + all_cache_files_for_dataset = [] # exisiting cache files + all_cache_paths_for_dataset = [] # all cache paths in the dataset + for dataset in datasets: + all_cache_files = [os.path.normpath(file) for file in dataset.get_all_text_encoder_output_cache_files()] + all_cache_files = set(all_cache_files) + all_cache_files_for_dataset.append(all_cache_files) + + all_cache_paths_for_dataset.append(set()) + + def encode_for_text_encoder(text_encoder: TextEncoder, is_llm: bool): + for i, dataset in enumerate(datasets): + logger.info(f"Encoding dataset [{i}]") + all_cache_files = all_cache_files_for_dataset[i] + all_cache_paths = all_cache_paths_for_dataset[i] + for batch in tqdm(dataset.retrieve_text_encoder_output_cache_batches(num_workers)): + # update cache files (it's ok if we update it multiple times) + all_cache_paths.update([os.path.normpath(item.text_encoder_output_cache_path) for item in batch]) + + # skip existing cache files + if args.skip_existing: + filtered_batch = [ + item for item in batch if not os.path.normpath(item.text_encoder_output_cache_path) in all_cache_files + ] + # print(f"Filtered {len(batch) - len(filtered_batch)} existing cache files") + if len(filtered_batch) == 0: + continue + batch = filtered_batch + + bs = args.batch_size if args.batch_size is not None else len(batch) + for i in range(0, len(batch), bs): + encode_and_save_batch(text_encoder, batch[i : i + bs], is_llm, accelerator) + + # Load Text Encoder 1 + text_encoder_dtype = torch.float16 if args.text_encoder_dtype is None else str_to_dtype(args.text_encoder_dtype) + logger.info(f"loading text encoder 1: {args.text_encoder1}") + text_encoder_1 = text_encoder_module.load_text_encoder_1(args.text_encoder1, device, args.fp8_llm, text_encoder_dtype) + text_encoder_1.to(device=device) + + # Encode with Text Encoder 1 + logger.info("Encoding with Text Encoder 1") + encode_for_text_encoder(text_encoder_1, is_llm=True) + del text_encoder_1 + + # Load Text Encoder 2 + logger.info(f"loading text encoder 2: {args.text_encoder2}") + text_encoder_2 = text_encoder_module.load_text_encoder_2(args.text_encoder2, device, text_encoder_dtype) + text_encoder_2.to(device=device) + + # Encode with Text Encoder 2 + logger.info("Encoding with Text Encoder 2") + encode_for_text_encoder(text_encoder_2, is_llm=False) + del text_encoder_2 + + # remove cache files not in dataset + for i, dataset in enumerate(datasets): + all_cache_files = all_cache_files_for_dataset[i] + all_cache_paths = all_cache_paths_for_dataset[i] + for cache_file in all_cache_files: + if cache_file not in all_cache_paths: + if args.keep_cache: + logger.info(f"Keep cache file not in the dataset: {cache_file}") + else: + os.remove(cache_file) + logger.info(f"Removed old cache file: {cache_file}") + + +def setup_parser(): + parser = argparse.ArgumentParser() + + parser.add_argument("--dataset_config", type=str, required=True, help="path to dataset config .toml file") + parser.add_argument("--text_encoder1", type=str, required=True, help="Text Encoder 1 directory") + parser.add_argument("--text_encoder2", type=str, required=True, help="Text Encoder 2 directory") + parser.add_argument("--device", type=str, default=None, help="device to use, default is cuda if available") + parser.add_argument("--text_encoder_dtype", type=str, default=None, help="data type for Text Encoder, default is float16") + 
parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)") + parser.add_argument( + "--batch_size", type=int, default=None, help="batch size, override dataset config if dataset batch size > this" + ) + parser.add_argument("--num_workers", type=int, default=None, help="number of workers for dataset. default is cpu count-1") + parser.add_argument("--skip_existing", action="store_true", help="skip existing cache files") + parser.add_argument("--keep_cache", action="store_true", help="keep cache files not in dataset") + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + main(args) diff --git a/change.txt b/change.txt new file mode 100644 index 0000000000000000000000000000000000000000..52b8dd9def6388514dd294107048a98efbd01086 --- /dev/null +++ b/change.txt @@ -0,0 +1,270 @@ +# Update the sample_solver choices to include 'vanilla' +@@ -1035,7 +1035,7 @@ with gr.Blocks( + + with gr.Row(): + wanx_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") +- wanx_sample_solver = gr.Radio(choices=["unipc", "dpm++"], label="Sample Solver", value="unipc") ++ wanx_sample_solver = gr.Radio(choices=["unipc", "dpm++", "vanilla"], label="Sample Solver", value="unipc") + wanx_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") + wanx_block_swap = gr.Slider(minimum=0, maximum=39, step=1, label="Block Swap to Save VRAM", value=0) + wanx_fp8 = gr.Checkbox(label="Use FP8", value=True) + +# Add exclude_single_blocks checkbox for WanX-i2v +@@ -1035,6 +1035,7 @@ with gr.Blocks( + + with gr.Row(): + wanx_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") + wanx_sample_solver = gr.Radio(choices=["unipc", "dpm++", "vanilla"], label="Sample Solver", value="unipc") ++ wanx_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) + wanx_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") + wanx_block_swap = gr.Slider(minimum=0, maximum=39, step=1, label="Block Swap to Save VRAM", value=0) + wanx_fp8 = gr.Checkbox(label="Use FP8", value=True) + +# Add LoRA support to WanX-i2v tab +@@ -979,7 +979,27 @@ with gr.Blocks( + ) + wanx_send_to_v2v_btn = gr.Button("Send Selected to Video2Video") + ++ # Add LoRA section for WanX-i2v similar to other tabs ++ wanx_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") ++ wanx_lora_weights = [] ++ wanx_lora_multipliers = [] ++ for i in range(4): ++ with gr.Column(): ++ wanx_lora_weights.append(gr.Dropdown( ++ label=f"LoRA {i+1}", ++ choices=get_lora_options(), ++ value="None", ++ allow_custom_value=True, ++ interactive=True ++ )) ++ wanx_lora_multipliers.append(gr.Slider( ++ label=f"Multiplier", ++ minimum=0.0, ++ maximum=2.0, ++ step=0.05, ++ value=1.0 ++ )) ++ + with gr.Row(): + wanx_seed = gr.Number(label="Seed (use -1 for random)", value=-1) + wanx_task = gr.Dropdown( +@@ -992,6 +1012,7 @@ with gr.Blocks( + wanx_vae_path = gr.Textbox(label="VAE Path", value="wan/Wan2.1_VAE.pth") + wanx_t5_path = gr.Textbox(label="T5 Path", value="wan/models_t5_umt5-xxl-enc-bf16.pth") + wanx_clip_path = gr.Textbox(label="CLIP Path", value="wan/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth") ++ wanx_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") + wanx_save_path = gr.Textbox(label="Save Path", value="outputs") + +# Update WanX-t2v 
sample solver choices +@@ -1099,7 +1099,7 @@ with gr.Blocks( + + with gr.Row(): + wanx_t2v_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") +- wanx_t2v_sample_solver = gr.Radio(choices=["unipc", "dpm++"], label="Sample Solver", value="unipc") ++ wanx_t2v_sample_solver = gr.Radio(choices=["unipc", "dpm++", "vanilla"], label="Sample Solver", value="unipc") + wanx_t2v_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") + wanx_t2v_block_swap = gr.Slider(minimum=0, maximum=39, step=1, label="Block Swap to Save VRAM", value=0, + info="Max 39 for 14B model, 29 for 1.3B model") + +# Add exclude_single_blocks checkbox for WanX-t2v +@@ -1099,6 +1099,7 @@ with gr.Blocks( + + with gr.Row(): + wanx_t2v_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") + wanx_t2v_sample_solver = gr.Radio(choices=["unipc", "dpm++", "vanilla"], label="Sample Solver", value="unipc") ++ wanx_t2v_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) + wanx_t2v_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") + wanx_t2v_block_swap = gr.Slider(minimum=0, maximum=39, step=1, label="Block Swap to Save VRAM", value=0, + info="Max 39 for 14B model, 29 for 1.3B model") + +# Add LoRA support to WanX-t2v tab +@@ -1063,7 +1064,27 @@ with gr.Blocks( + ) + wanx_t2v_send_to_v2v_btn = gr.Button("Send Selected to Video2Video") + ++ # Add LoRA section for WanX-t2v ++ wanx_t2v_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") ++ wanx_t2v_lora_weights = [] ++ wanx_t2v_lora_multipliers = [] ++ for i in range(4): ++ with gr.Column(): ++ wanx_t2v_lora_weights.append(gr.Dropdown( ++ label=f"LoRA {i+1}", ++ choices=get_lora_options(), ++ value="None", ++ allow_custom_value=True, ++ interactive=True ++ )) ++ wanx_t2v_lora_multipliers.append(gr.Slider( ++ label=f"Multiplier", ++ minimum=0.0, ++ maximum=2.0, ++ step=0.05, ++ value=1.0 ++ )) ++ + with gr.Row(): + wanx_t2v_seed = gr.Number(label="Seed (use -1 for random)", value=-1) + wanx_t2v_task = gr.Dropdown( +@@ -1077,6 +1098,7 @@ with gr.Blocks( + wanx_t2v_vae_path = gr.Textbox(label="VAE Path", value="wan/Wan2.1_VAE.pth") + wanx_t2v_t5_path = gr.Textbox(label="T5 Path", value="wan/models_t5_umt5-xxl-enc-bf16.pth") + wanx_t2v_clip_path = gr.Textbox(label="CLIP Path", visible=False, value="") ++ wanx_t2v_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") + wanx_t2v_save_path = gr.Textbox(label="Save Path", value="outputs") + +# Update wanx_generate_video function to include LoRA and exclude_single_blocks +@@ -2051,6 +2073,15 @@ def wanx_generate_video( + save_path, + output_type, + sample_solver, ++ exclude_single_blocks, + attn_mode, + block_swap, + fp8, +- fp8_t5 ++ fp8_t5, ++ lora_folder, ++ lora1="None", ++ lora2="None", ++ lora3="None", ++ lora4="None", ++ lora1_multiplier=1.0, ++ lora2_multiplier=1.0, ++ lora3_multiplier=1.0, ++ lora4_multiplier=1.0 + ) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: + """Generate video with WanX model (supports both i2v and t2v)""" + global stop_event +@@ -2107,6 +2138,20 @@ def wanx_generate_video( + + if fp8_t5: + command.append("--fp8_t5") ++ ++ if exclude_single_blocks: ++ command.append("--exclude_single_blocks") ++ ++ # Add LoRA weights and multipliers if provided ++ valid_loras = [] ++ for weight, mult in zip([lora1, lora2, lora3, lora4], ++ 
[lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier]): ++ if weight and weight != "None": ++ valid_loras.append((os.path.join(lora_folder, weight), mult)) ++ if valid_loras: ++ weights = [weight for weight, _ in valid_loras] ++ multipliers = [str(mult) for _, mult in valid_loras] ++ command.extend(["--lora_weight"] + weights) ++ command.extend(["--lora_multiplier"] + multipliers) + + print(f"Running: {' '.join(command)}") + +# Update wanx_generate_video_batch function +@@ -2176,9 +2221,19 @@ def wanx_generate_video_batch( + save_path, + output_type, + sample_solver, ++ exclude_single_blocks, + attn_mode, + block_swap, + fp8, +- fp8_t5, ++ fp8_t5, ++ lora_folder, ++ lora1="None", ++ lora2="None", ++ lora3="None", ++ lora4="None", ++ lora1_multiplier=1.0, ++ lora2_multiplier=1.0, ++ lora3_multiplier=1.0, ++ lora4_multiplier=1.0, + batch_size=1, + input_image=None # Make input_image optional and place it at the end + ) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: +@@ -2201,9 +2256,19 @@ def wanx_generate_video_batch( + save_path, + output_type, + sample_solver, ++ exclude_single_blocks, + attn_mode, + block_swap, + fp8, +- fp8_t5 ++ fp8_t5, ++ lora_folder, ++ lora1, ++ lora2, ++ lora3, ++ lora4, ++ lora1_multiplier, ++ lora2_multiplier, ++ lora3_multiplier, ++ lora4_multiplier + ), + outputs=[wanx_output, wanx_batch_progress, wanx_progress_text], + queue=True + +# Update WanX-i2v generate button click handler +@@ -2423,6 +2488,15 @@ def wanx_generate_btn.click( + wanx_save_path, + wanx_output_type, + wanx_sample_solver, ++ wanx_exclude_single_blocks, + wanx_attn_mode, + wanx_block_swap, + wanx_fp8, +- wanx_fp8_t5, ++ wanx_fp8_t5, ++ wanx_lora_folder, ++ *wanx_lora_weights, ++ *wanx_lora_multipliers, + wanx_batch_size, + wanx_input # Include the image input for this tab + ], + outputs=[wanx_output, wanx_batch_progress, wanx_progress_text], + queue=True + ) ++ ++ # Add refresh button handler for WanX-i2v tab ++ wanx_refresh_outputs = [] ++ for i in range(4): ++ wanx_refresh_outputs.extend([wanx_lora_weights[i], wanx_lora_multipliers[i]]) ++ ++ wanx_refresh_btn.click( ++ fn=update_lora_dropdowns, ++ inputs=[wanx_lora_folder] + wanx_lora_weights + wanx_lora_multipliers, ++ outputs=wanx_refresh_outputs ++ ) + +# Update WanX-t2v generate button click handler +@@ -2470,9 +2544,19 @@ def wanx_t2v_generate_btn.click( + wanx_t2v_save_path, + wanx_t2v_output_type, + wanx_t2v_sample_solver, ++ wanx_t2v_exclude_single_blocks, + wanx_t2v_attn_mode, + wanx_t2v_block_swap, + wanx_t2v_fp8, +- wanx_t2v_fp8_t5, ++ wanx_t2v_fp8_t5, ++ wanx_t2v_lora_folder, ++ *wanx_t2v_lora_weights, ++ *wanx_t2v_lora_multipliers, + wanx_t2v_batch_size, + # input_image is now optional and not included here + ], + outputs=[wanx_t2v_output, wanx_t2v_batch_progress, wanx_t2v_progress_text], + queue=True + ) ++ ++ # Add refresh button handler for WanX-t2v tab ++ wanx_t2v_refresh_outputs = [] ++ for i in range(4): ++ wanx_t2v_refresh_outputs.extend([wanx_t2v_lora_weights[i], wanx_t2v_lora_multipliers[i]]) ++ ++ wanx_t2v_refresh_btn.click( ++ fn=update_lora_dropdowns, ++ inputs=[wanx_t2v_lora_folder] + wanx_t2v_lora_weights + wanx_t2v_lora_multipliers, ++ outputs=wanx_t2v_refresh_outputs ++ ) \ No newline at end of file diff --git a/compare_safetensors_weights.py b/compare_safetensors_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..bcb59101e8acf64d2bea9e7c3f4311b0fb918513 --- /dev/null +++ b/compare_safetensors_weights.py @@ -0,0 +1,408 @@ +import argparse +import 
logging +import gc +import math +from pathlib import Path +from typing import Dict, Set, Tuple, List, Any + +import torch +from safetensors import safe_open +from tqdm import tqdm +import numpy as np + +# Configure logging +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) + +# Use MPS if available on Mac, otherwise CUDA or CPU +if torch.backends.mps.is_available(): + DEFAULT_DEVICE = "mps" +elif torch.cuda.is_available(): + DEFAULT_DEVICE = "cuda" +else: + DEFAULT_DEVICE = "cpu" + +def get_tensor_keys(filepath: Path) -> Set[str]: + """Gets all tensor keys from a safetensors file without loading tensors.""" + keys = set() + try: + with safe_open(filepath, framework="pt", device="cpu") as f: + keys = set(f.keys()) + logging.debug(f"Found {len(keys)} keys in {filepath.name}") + return keys + except Exception as e: + logging.error(f"Error opening or reading keys from {filepath}: {e}") + raise + +def compare_tensors( + key: str, file1: Path, file2: Path, device: torch.device, atol: float +) -> Tuple[bool, float, float, float]: + """ + Loads and compares a single tensor from two files. + + Args: + key: The tensor key to compare. + file1: Path to the first safetensors file. + file2: Path to the second safetensors file. + device: The torch device to use for comparison. + atol: Absolute tolerance for torch.allclose check. + + Returns: + Tuple containing: + - is_close: Boolean indicating if tensors are close within tolerance. + - mean_abs_diff: Mean absolute difference. + - max_abs_diff: Maximum absolute difference. + - cosine_sim: Cosine similarity (-2.0 if not applicable/error). + """ + # Initialize variables to handle potential early returns + t1, t2, diff = None, None, None + mean_abs_diff = float('nan') + max_abs_diff = float('nan') + cosine_sim = -2.0 # Use -2.0 to indicate not computed or error + is_close = False + + try: + # Use safe_open for lazy loading + with safe_open(file1, framework="pt", device="cpu") as f1, \ + safe_open(file2, framework="pt", device="cpu") as f2: + + if key not in f1.keys(): + logging.warning(f"Key '{key}' missing in Model 1 ({file1.name}). Skipping comparison for this key.") + # No need to return here, let finally block handle cleanup if t2 was loaded + elif key not in f2.keys(): + logging.warning(f"Key '{key}' missing in Model 2 ({file2.name}). Skipping comparison for this key.") + # Load t1 to ensure it's deleted in finally if needed + t1 = f1.get_tensor(key) + else: + # Both keys exist, proceed with loading + t1 = f1.get_tensor(key) + t2 = f2.get_tensor(key) + + # --- Basic Checks --- + if t1.shape != t2.shape: + logging.warning( + f"Shape mismatch for key '{key}': {t1.shape} vs {t2.shape}. Cannot compare." + ) + # Return values indicating mismatch; t1/t2 will be cleaned up by finally + return False, float('nan'), float('nan'), -2.0 # Use NaN/special value for mismatch + + if t1.dtype != t2.dtype: + logging.warning( + f"Dtype mismatch for key '{key}': {t1.dtype} vs {t2.dtype}. Will attempt cast for comparison." 
+ ) + # Attempt comparison anyway, might fail or give less meaningful results + try: + t2 = t2.to(t1.dtype) + except Exception as cast_e: + logging.error(f"Could not cast tensor '{key}' for comparison: {cast_e}") + # Return values indicating error; t1/t2 will be cleaned up by finally + return False, float('nan'), float('nan'), -2.0 + + # --- Move to device for computation --- + try: + # Move original tensors (or casted t2) + t1_dev = t1.to(device) + t2_dev = t2.to(device) + except Exception as move_e: + logging.error(f"Could not move tensor '{key}' to device '{device}': {move_e}. Trying CPU.") + device = torch.device('cpu') + t1_dev = t1.to(device) + t2_dev = t2.to(device) + + + # --- Comparison Metrics --- + with torch.no_grad(): + # Use float32 for difference calculation stability + diff = torch.abs(t1_dev.float() - t2_dev.float()) # Assign diff here + mean_abs_diff = torch.mean(diff).item() + max_abs_diff = torch.max(diff).item() + + # torch.allclose check + is_close = torch.allclose(t1_dev, t2_dev, atol=atol, rtol=0) # rtol=0 for FP16 comparison mostly depends on atol + + # Cosine Similarity (avoid for scalars, ensure vectors are flat) + if t1_dev.numel() > 1: + try: + # Ensure tensors are flat and float for cosine sim + cos_sim_val = torch.nn.functional.cosine_similarity( + t1_dev.flatten().float(), t2_dev.flatten().float(), dim=0 + ).item() + # Handle potential NaN/Inf from zero vectors etc. + cosine_sim = cos_sim_val if math.isfinite(cos_sim_val) else -1.0 + except Exception as cs_err: + logging.warning(f"Could not compute cosine similarity for '{key}': {cs_err}") + cosine_sim = -1.0 # Indicate computation error + elif t1_dev.numel() == 1: + cosine_sim = 1.0 if torch.allclose(t1_dev, t2_dev) else 0.0 # Define for scalars + + # Clean up device tensors explicitly after use + del t1_dev, t2_dev + + + except Exception as e: + logging.error(f"Unhandled error comparing tensor '{key}': {e}", exc_info=True) + # Return default failure values + return False, float('nan'), float('nan'), -2.0 + finally: + # --- Modified Finally Block --- + # Clear potential tensor references + if t1 is not None: + del t1 + if t2 is not None: + del t2 + if diff is not None: # Now 'diff' might be defined or not + del diff + + # Aggressive garbage collection and cache clearing + gc.collect() + if device.type == 'cuda': + torch.cuda.empty_cache() + elif device.type == 'mps': + try: # Newer pytorch versions have empty_cache for mps + torch.mps.empty_cache() + except AttributeError: + pass # Ignore if not available + + # Return the calculated values if comparison was successful + return is_close, mean_abs_diff, max_abs_diff, cosine_sim + + +def compare_models(file1_path: Path, file2_path: Path, device_str: str, atol: float, top_n_diff: int): + """ + Compares two safetensors models weight by weight. + + Args: + file1_path: Path to the first model file. + file2_path: Path to the second model file. + device_str: Device string ('cpu', 'cuda', 'mps'). + atol: Absolute tolerance for closeness check. + top_n_diff: Number of most different tensors to report. + """ + if not file1_path.is_file(): + logging.error(f"File not found: {file1_path}") + return + if not file2_path.is_file(): + logging.error(f"File not found: {file2_path}") + return + + try: + device = torch.device(device_str) + logging.info(f"Using device: {device}") + except Exception as e: + logging.warning(f"Could not select device '{device_str}', falling back to CPU. 
Error: {e}") + device = torch.device("cpu") + + logging.info(f"Comparing Model 1: {file1_path.name}") + logging.info(f" Model 2: {file2_path.name}") + logging.info(f"Absolute tolerance (atol) for closeness: {atol}") + + try: + keys1 = get_tensor_keys(file1_path) + keys2 = get_tensor_keys(file2_path) + except Exception: + return # Error already logged by get_tensor_keys + + common_keys = sorted(list(keys1.intersection(keys2))) + unique_keys1 = sorted(list(keys1 - keys2)) + unique_keys2 = sorted(list(keys2 - keys1)) + + logging.info(f"Found {len(common_keys)} common tensor keys.") + if unique_keys1: + logging.warning(f"{len(unique_keys1)} keys unique to Model 1 ({file1_path.name}): {unique_keys1[:10]}{'...' if len(unique_keys1) > 10 else ''}") + if unique_keys2: + logging.warning(f"{len(unique_keys2)} keys unique to Model 2 ({file2_path.name}): {unique_keys2[:10]}{'...' if len(unique_keys2) > 10 else ''}") + + if not common_keys: + logging.error("No common keys found between models. Cannot compare.") + return + + results: List[Dict[str, Any]] = [] + close_count = 0 + compared_count = 0 # Track how many comparisons were attempted + valid_comparisons = 0 # Track successful comparisons with numerical results + mismatched_shape_keys = [] + comparison_error_keys = [] + + all_mean_abs_diffs = [] + all_max_abs_diffs = [] + all_cosine_sims = [] + + logging.info("Starting tensor comparison...") + for key in tqdm(common_keys, desc="Comparing Tensors"): + compared_count += 1 + is_close, mean_ad, max_ad, cos_sim = compare_tensors( + key, file1_path, file2_path, device, atol + ) + + # Check for comparison failure (NaN or -2) + if math.isnan(mean_ad) or math.isnan(max_ad) or cos_sim == -2.0: + # Check if it was specifically a shape mismatch (common case) + # Re-check shapes briefly - less efficient but simple for logging + try: + with safe_open(file1_path, framework="pt", device="cpu") as f1, \ + safe_open(file2_path, framework="pt", device="cpu") as f2: + t1_shape = f1.get_shape(key) + t2_shape = f2.get_shape(key) + if t1_shape != t2_shape: + mismatched_shape_keys.append(key) + else: + comparison_error_keys.append(key) # Other error + except Exception: + comparison_error_keys.append(key) # Error getting shapes or other issue + + logging.debug(f"Skipping results aggregation for key '{key}' due to comparison errors/mismatch.") + continue # Skip adding results for this key + + # If we reach here, comparison was numerically successful + valid_comparisons += 1 + if is_close: + close_count += 1 + all_mean_abs_diffs.append(mean_ad) + all_max_abs_diffs.append(max_ad) + # Store cosine similarity if validly computed (-1 means computation issue like 0 vector) + if cos_sim >= -1.0: # Allow -1 (error during calc) but not -2 (no calc attempted/major error) + all_cosine_sims.append(cos_sim) + + results.append({ + "key": key, + "is_close": is_close, + "mean_abs_diff": mean_ad, + "max_abs_diff": max_ad, + "cosine_sim": cos_sim + }) + + + # --- Summary --- + logging.info("\n--- Comparison Summary ---") + logging.info(f"Attempted comparison for {compared_count} common keys.") + if mismatched_shape_keys: + logging.warning(f"Found {len(mismatched_shape_keys)} keys with mismatched shapes (skipped): {mismatched_shape_keys[:5]}{'...' if len(mismatched_shape_keys) > 5 else ''}") + if comparison_error_keys: + logging.error(f"Encountered errors during comparison for {len(comparison_error_keys)} keys (skipped): {comparison_error_keys[:5]}{'...' 
if len(comparison_error_keys) > 5 else ''}") + + + if valid_comparisons == 0: + logging.error("No tensors could be validly compared numerically (check for shape mismatches or errors).") + return + + logging.info(f"Successfully compared {valid_comparisons} tensors numerically.") + logging.info(f"Tensors within tolerance (atol={atol}): {close_count} / {valid_comparisons} ({close_count/valid_comparisons:.2%})") + + # Calculate overall stats only on valid comparisons + avg_mean_ad = np.mean(all_mean_abs_diffs) if all_mean_abs_diffs else float('nan') + avg_max_ad = np.mean(all_max_abs_diffs) if all_max_abs_diffs else float('nan') + overall_max_ad = np.max(all_max_abs_diffs) if all_max_abs_diffs else float('nan') + overall_max_ad_key = max(results, key=lambda x: x.get('max_abs_diff', -float('inf')))['key'] if results else 'N/A' + + # Filter out -1 cosine sims before calculating stats if desired, or include them + valid_cosine_sims = [cs for cs in all_cosine_sims if cs >= 0] # Only positive sims for avg/min + avg_cosine_sim = np.mean(valid_cosine_sims) if valid_cosine_sims else float('nan') + min_cosine_sim = np.min(valid_cosine_sims) if valid_cosine_sims else float('nan') + + + logging.info(f"Average Mean Absolute Difference (MAD): {avg_mean_ad:.6g}") + logging.info(f"Average Max Absolute Difference: {avg_max_ad:.6g}") + logging.info(f"Overall Maximum Absolute Difference: {overall_max_ad:.6g} (found in tensor '{overall_max_ad_key}')") + logging.info(f"Average Cosine Similarity (valid>=0): {avg_cosine_sim:.6f}" if not math.isnan(avg_cosine_sim) else "Average Cosine Similarity (valid>=0): N/A") + logging.info(f"Minimum Cosine Similarity (valid>=0): {min_cosine_sim:.6f}" if not math.isnan(min_cosine_sim) else "Minimum Cosine Similarity (valid>=0): N/A") + + + # --- Top Differences --- + # Sort by max absolute difference descending (handle potential missing keys) + results.sort(key=lambda x: x.get("max_abs_diff", -float('inf')), reverse=True) + + logging.info(f"\n--- Top {min(top_n_diff, len(results))} Tensors by Max Absolute Difference (Numerically Compared Only) ---") + for i in range(min(top_n_diff, len(results))): + res = results[i] + # Ensure keys exist before accessing + key = res.get('key', 'ERROR_MISSING_KEY') + max_ad_val = res.get('max_abs_diff', float('nan')) + mean_ad_val = res.get('mean_abs_diff', float('nan')) + cos_sim_val = res.get('cosine_sim', float('nan')) + close_val = res.get('is_close', 'N/A') + + logging.info( + f"{i+1}. Key: {key:<50} " + f"MaxAD: {max_ad_val:.6g} | " + f"MeanAD: {mean_ad_val:.6g} | " + f"CosSim: {cos_sim_val:.4f} | " + f"Close: {close_val}" + ) + + # --- Interpretation for LoRA --- + logging.info("\n--- LoRA Compatibility Interpretation ---") + # Prioritize architectural differences + if unique_keys1 or unique_keys2 or mismatched_shape_keys: + logging.error("Models have architectural differences (unique keys or mismatched shapes found). Direct LoRA swapping is NOT recommended.") + if unique_keys1 or unique_keys2: + logging.warning(" - Different sets of weights exist.") + if mismatched_shape_keys: + logging.warning(f" - Mismatched shapes found for keys like: {mismatched_shape_keys[0]}") + elif comparison_error_keys: + logging.warning("Some tensors could not be compared due to errors (other than shape mismatch). Check logs. LoRA compatibility might be affected.") + else: + # Assess based on numerical differences if architecture matches + logging.info("Models appear to have the same architecture (matching keys and shapes). 
Assessing numerical similarity:") + if avg_mean_ad < 1e-5 and overall_max_ad < 1e-3: + logging.info(" -> Differences are very small. Models appear highly similar. High LoRA compatibility expected.") + elif avg_mean_ad < 1e-4 and overall_max_ad < 5e-3: + logging.info(" -> Differences are small. Models appear quite similar. Good LoRA compatibility expected.") + elif avg_mean_ad < 1e-3 and overall_max_ad < 1e-2: + logging.info(" -> Moderate differences detected. LoRAs might work but performance could vary, especially if targeting layers with larger differences.") + else: + logging.warning(" -> Significant numerical differences detected (Average MAD > 1e-3 or Overall MaxAD > 0.01). LoRA compatibility is questionable. Performance may degrade even with matching architecture.") + + if not math.isnan(min_cosine_sim) and min_cosine_sim < 0.98: # Stricter threshold for matching architecture + logging.warning(f" -> Some tensors have lower cosine similarity (min >= 0: {min_cosine_sim:.4f}), indicating potential directional differences. This could affect LoRA.") + + +def main(): + parser = argparse.ArgumentParser( + description="Compare weights between two safetensors model files.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "model1_path", type=str, help="Path to the first .safetensors model file." + ) + parser.add_argument( + "model2_path", type=str, help="Path to the second .safetensors model file." + ) + parser.add_argument( + "--device", + type=str, + default=DEFAULT_DEVICE, + choices=["cpu", "cuda", "mps"], + help="Device to use for tensor comparisons ('cuda'/'mps' recommended if available).", + ) + parser.add_argument( + "--atol", + type=float, + default=1e-4, # A reasonable default for FP16 comparison + help="Absolute tolerance (atol) for torch.allclose check to consider tensors 'close'.", + ) + parser.add_argument( + "--top_n_diff", + type=int, + default=10, + help="Report details for the top N tensors with the largest maximum absolute difference.", + ) + parser.add_argument( + "-v", "--verbose", action="store_true", help="Enable debug logging." 
+ ) + + args = parser.parse_args() + + if args.verbose: + logging.getLogger().setLevel(logging.DEBUG) + + compare_models( + Path(args.model1_path), + Path(args.model2_path), + args.device, + args.atol, + args.top_n_diff, + ) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/convert_hunyuan_to_framepack.py b/convert_hunyuan_to_framepack.py new file mode 100644 index 0000000000000000000000000000000000000000..8aa9eb3e5caa613e90a1e9b9d74cfb40950f5826 --- /dev/null +++ b/convert_hunyuan_to_framepack.py @@ -0,0 +1,203 @@ +# convert_lora.py +import argparse +import os +import re +import torch +from safetensors.torch import load_file, save_file +import logging + +# Configure logging similar to the utility file +logger = logging.getLogger(__name__) +# Avoid re-configuring if basicConfig was already called by the imported module +if not logging.root.handlers: + logging.basicConfig(level=logging.INFO) + +# Assuming framepack_lora_inf_utils.py is in the same directory +try: + from framepack_lora_inf_utils import ( + convert_hunyuan_to_framepack, + convert_from_diffusion_pipe_or_something, + ) +except ImportError: + logger.error("Error: Could not import conversion functions from framepack_lora_inf_utils.") + logger.error("Please make sure framepack_lora_inf_utils.py is in the same directory as this script.") + exit(1) + +def main(): + """ + Main function to parse arguments and perform the LoRA conversion, + detecting the input format (Hunyuan or Diffusion Pipe-like). + """ + parser = argparse.ArgumentParser(description="Convert various LoRA formats to FramePack format.") + parser.add_argument( + "--input_lora", + type=str, + required=True, + help="Path to the input LoRA .safetensors file (Hunyuan, Diffusion Pipe-like, or Musubi).", + ) + parser.add_argument( + "--output_lora", + type=str, + required=True, + help="Path to save the converted FramePack LoRA .safetensors file.", + ) + + args = parser.parse_args() + + input_file = args.input_lora + output_file = args.output_lora + + # Validate input file + if not os.path.exists(input_file): + logger.error(f"Input file not found: {input_file}") + exit(1) + if not input_file.lower().endswith(".safetensors"): + logger.warning(f"Input file '{input_file}' does not end with .safetensors. Proceeding anyway.") + + # Ensure output directory exists + output_dir = os.path.dirname(output_file) + if output_dir and not os.path.exists(output_dir): + try: + os.makedirs(output_dir, exist_ok=True) + logger.info(f"Created output directory: {output_dir}") + except OSError as e: + logger.error(f"Error creating output directory {output_dir}: {e}") + exit(1) + + # Ensure output file ends with .safetensors + if not output_file.lower().endswith(".safetensors"): + output_file += ".safetensors" + logger.warning(f"Output file name did not end with .safetensors. Appended: {output_file}") + + logger.info(f"Loading input LoRA file: {input_file}") + loaded_lora_sd = None + try: + # Load the state dictionary from the input .safetensors file + loaded_lora_sd = load_file(input_file) + logger.info(f"Input LoRA loaded successfully. 
Found {len(loaded_lora_sd)} keys.") + except Exception as e: + logger.error(f"Error loading input LoRA file {input_file}: {e}") + exit(1) + + # --- Determine LoRA format and apply conversion(s) --- + # Following the logic flow from merge_lora_to_state_dict + + processed_lora_sd = None # This will hold the SD after potential conversions + lora_keys = list(loaded_lora_sd.keys()) if loaded_lora_sd else [] + + if not lora_keys: + logger.error("Input LoRA file was empty or failed to load keys.") + exit(1) + + # 1. Check if it's Musubi Tuner format (first key starts with "lora_unet_") + is_musubi = lora_keys[0].startswith("lora_unet_") + if is_musubi: + logger.info("Detected Musubi Tuner format based on first key.") + # Keep the original SD for now, as Musubi format might still contain Hunyuan patterns + current_lora_sd_to_check = loaded_lora_sd + else: + # 2. If not Musubi (based on first key), check for Diffusion Pipe format + diffusion_pipe_pattern_found = False + transformer_prefixes = ["diffusion_model", "transformer"] + lora_suffix_A_or_B_found = False + + # Find the first key with .lora_A or .lora_B and check its prefix + for key in lora_keys: + if ".lora_A." in key or ".lora_B." in key: + lora_suffix_A_or_B_found = True + pfx = key.split(".")[0] + if pfx in transformer_prefixes: + diffusion_pipe_pattern_found = True + break # Found the required pattern + + if diffusion_pipe_pattern_found: + logger.info(f"Detected Diffusion Pipe (?) format based on keys like '{pfx}.*.lora_A/B.'. Attempting conversion...") + target_prefix_for_diffusers_conversion = "lora_unet_" + try: + # Apply the Diffusion Pipe conversion + current_lora_sd_to_check = convert_from_diffusion_pipe_or_something(loaded_lora_sd, target_prefix_for_diffusers_conversion) + logger.info("Converted from Diffusion Pipe format.") + except Exception as e: + logger.error(f"Error during Diffusion Pipe conversion: {e}", exc_info=True) # Log traceback + current_lora_sd_to_check = None # Conversion failed, treat as unprocessable + else: + # If not Musubi and not Diffusion Pipe, the format is unrecognized initially + logger.warning(f"LoRA file format not recognized based on common patterns (Musubi, Diffusion Pipe-like). Checking for Hunyuan anyway...") + current_lora_sd_to_check = loaded_lora_sd # Keep the original SD to check for Hunyuan keys next + + + # 3. Check for Hunyuan pattern (double_blocks/single_blocks) in the *current* state dict + if current_lora_sd_to_check is not None: + keys_for_hunyuan_check = list(current_lora_sd_to_check.keys()) + is_hunyuan_pattern_found = any("double_blocks" in key or "single_blocks" in key for key in keys_for_hunyuan_check) + + if is_hunyuan_pattern_found: + logger.info("Detected HunyuanVideo format based on keys (double_blocks/single_blocks). Attempting conversion...") + try: + # Apply the Hunyuan conversion + processed_lora_sd = convert_hunyuan_to_framepack(current_lora_sd_to_check) + logger.info("Converted from HunyuanVideo format.") + except Exception as e: + logger.error(f"Error during HunyuanVideo conversion: {e}", exc_info=True) # Log traceback + processed_lora_sd = None # Conversion failed, treat as unprocessable + else: + # If Hunyuan pattern is not found, the current_lora_sd_to_check is the final state + # (either the original Musubi SD, or the SD converted from Diffusion Pipe). 
+ processed_lora_sd = current_lora_sd_to_check + if not is_musubi and not diffusion_pipe_pattern_found: + # If we reached here and neither Musubi nor Diffusion Pipe patterns were initially found, + # and no Hunyuan pattern was found either, then the format is truly unrecognized. + logger.warning("Input LoRA does not match Musubi, Diffusion Pipe-like, or Hunyuan patterns.") + # Log keys to help debugging unrecognized formats + logger.info(f"Input LoRA keys start with: {lora_keys[:20]}...") # Show first few keys + processed_lora_sd = None # Mark as unprocessable + + + # --- Final check and saving --- + if processed_lora_sd is None or not processed_lora_sd: + logger.error("Could not convert the input LoRA file to a recognized FramePack format.") + logger.error("Input file format not recognized or conversion failed.") + # Log keys if conversion didn't happen due to format not matching + if loaded_lora_sd is not None: + logger.info(f"Input LoRA keys start with: {lora_keys[:20]}...") # Show first few keys + exit(1) # Exit if conversion failed or no data resulted + + logger.info(f"Conversion complete. Converted state dictionary contains {len(processed_lora_sd)} keys.") + logger.info(f"Saving converted LoRA file to: {output_file}") + + # --- WORKAROUND for older safetensors versions that don't support allow_shared=True --- + # The conversion functions might create shared tensors in the dictionary. + # Older safetensors versions require tensors to be distinct objects for save_file. + # We create a deep copy of tensors to satisfy this requirement. + # The recommended fix is to upgrade safetensors (pip install --upgrade safetensors) + # and use allow_shared=True in save_file. + logger.info("Checking for shared tensors and creating copies for saving (workaround for older safetensors)...") + processed_lora_sd_copy = {} + for key, tensor in processed_lora_sd.items(): + if isinstance(tensor, torch.Tensor): + # Create a new tensor with copied data, detached from any computation graph + processed_lora_sd_copy[key] = tensor.clone().detach() + else: + # Keep non-tensor items (like alpha which might be a scalar number) as is + processed_lora_sd_copy[key] = tensor + logger.info("Deep copy complete.") + # --- END OF WORKAROUND --- + + + try: + # Save using the deep-copied dictionary. + # This works with older safetensors versions (pre-0.3.0) + # If you upgraded safetensors to 0.3.0 or later, you could use: + # save_file(processed_lora_sd, output_file, allow_shared=True) + save_file(processed_lora_sd_copy, output_file) + + logger.info(f"Successfully saved converted LoRA to {output_file}") + except Exception as e: + # Note: If you still get a shared memory error here, it implies the deep copy + # workaround didn't fully resolve it for your specific setup, OR the error + # is coming from a different source. Upgrading safetensors is then highly recommended. 
+ logger.error(f"Error saving converted LoRA file {output_file}: {e}", exc_info=True) # Log traceback + exit(1) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/convert_lora.py b/convert_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..6b50a4217e05b982942a52487690a7c4aa8d8ac0 --- /dev/null +++ b/convert_lora.py @@ -0,0 +1,135 @@ +import argparse + +import torch +from safetensors.torch import load_file, save_file +from safetensors import safe_open +from utils import model_utils + +import logging + + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +def convert_from_diffusers(prefix, weights_sd): + # convert from diffusers(?) to default LoRA + # Diffusers format: {"diffusion_model.module.name.lora_A.weight": weight, "diffusion_model.module.name.lora_B.weight": weight, ...} + # default LoRA format: {"prefix_module_name.lora_down.weight": weight, "prefix_module_name.lora_up.weight": weight, ...} + + # note: Diffusers has no alpha, so alpha is set to rank + new_weights_sd = {} + lora_dims = {} + for key, weight in weights_sd.items(): + diffusers_prefix, key_body = key.split(".", 1) + if diffusers_prefix != "diffusion_model" and diffusers_prefix != "transformer": + logger.warning(f"unexpected key: {key} in diffusers format") + continue + + new_key = f"{prefix}{key_body}".replace(".", "_").replace("_lora_A_", ".lora_down.").replace("_lora_B_", ".lora_up.") + new_weights_sd[new_key] = weight + + lora_name = new_key.split(".")[0] # before first dot + if lora_name not in lora_dims and "lora_down" in new_key: + lora_dims[lora_name] = weight.shape[0] + + # add alpha with rank + for lora_name, dim in lora_dims.items(): + new_weights_sd[f"{lora_name}.alpha"] = torch.tensor(dim) + + return new_weights_sd + + +def convert_to_diffusers(prefix, weights_sd): + # convert from default LoRA to diffusers + + # get alphas + lora_alphas = {} + for key, weight in weights_sd.items(): + if key.startswith(prefix): + lora_name = key.split(".", 1)[0] # before first dot + if lora_name not in lora_alphas and "alpha" in key: + lora_alphas[lora_name] = weight + + new_weights_sd = {} + for key, weight in weights_sd.items(): + if key.startswith(prefix): + if "alpha" in key: + continue + + lora_name = key.split(".", 1)[0] # before first dot + + module_name = lora_name[len(prefix) :] # remove "lora_unet_" + module_name = module_name.replace("_", ".") # replace "_" with "." + if ".cross.attn." in module_name or ".self.attn." 
in module_name: + # Wan2.1 lora name to module name: ugly but works + module_name = module_name.replace("cross.attn", "cross_attn") # fix cross attn + module_name = module_name.replace("self.attn", "self_attn") # fix self attn + else: + # HunyuanVideo lora name to module name: ugly but works + module_name = module_name.replace("double.blocks.", "double_blocks.") # fix double blocks + module_name = module_name.replace("single.blocks.", "single_blocks.") # fix single blocks + module_name = module_name.replace("img.", "img_") # fix img + module_name = module_name.replace("txt.", "txt_") # fix txt + module_name = module_name.replace("attn.", "attn_") # fix attn + + diffusers_prefix = "diffusion_model" + if "lora_down" in key: + new_key = f"{diffusers_prefix}.{module_name}.lora_A.weight" + dim = weight.shape[0] + elif "lora_up" in key: + new_key = f"{diffusers_prefix}.{module_name}.lora_B.weight" + dim = weight.shape[1] + else: + logger.warning(f"unexpected key: {key} in default LoRA format") + continue + + # scale weight by alpha + if lora_name in lora_alphas: + # we scale both down and up, so scale is sqrt + scale = lora_alphas[lora_name] / dim + scale = scale.sqrt() + weight = weight * scale + else: + logger.warning(f"missing alpha for {lora_name}") + + new_weights_sd[new_key] = weight + + return new_weights_sd + + +def convert(input_file, output_file, target_format): + logger.info(f"loading {input_file}") + weights_sd = load_file(input_file) + with safe_open(input_file, framework="pt") as f: + metadata = f.metadata() + + logger.info(f"converting to {target_format}") + prefix = "lora_unet_" + if target_format == "default": + new_weights_sd = convert_from_diffusers(prefix, weights_sd) + metadata = metadata or {} + model_utils.precalculate_safetensors_hashes(new_weights_sd, metadata) + elif target_format == "other": + new_weights_sd = convert_to_diffusers(prefix, weights_sd) + else: + raise ValueError(f"unknown target format: {target_format}") + + logger.info(f"saving to {output_file}") + save_file(new_weights_sd, output_file, metadata=metadata) + + logger.info("done") + + +def parse_args(): + parser = argparse.ArgumentParser(description="Convert LoRA weights between default and other formats") + parser.add_argument("--input", type=str, required=True, help="input model file") + parser.add_argument("--output", type=str, required=True, help="output model file") + parser.add_argument("--target", type=str, required=True, choices=["other", "default"], help="target format") + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = parse_args() + convert(args.input, args.output, args.target) diff --git a/createnoise.py b/createnoise.py new file mode 100644 index 0000000000000000000000000000000000000000..fa40fae49fb141908ba1f09980b74592535d6c91 --- /dev/null +++ b/createnoise.py @@ -0,0 +1,118 @@ +import os +import numpy as np +from PIL import Image +import ffmpeg +import argparse + +def add_noise(image: np.ndarray, noise_level: float) -> np.ndarray: + """ + Add Gaussian noise to an image with increasing intensity. 
+ + Args: + image: Input image as numpy array (0-255) + noise_level: Amount of noise to add (0-1) + """ + # Convert to float for calculations + img_float = image.astype(float) + + # Generate noise + noise = np.random.normal(0, 255 * noise_level, image.shape) + + # Add noise to image + noisy_image = img_float + noise + + # Clip values to valid range + noisy_image = np.clip(noisy_image, 0, 255) + + return noisy_image.astype(np.uint8) + +def create_noise_sequence(input_image_path: str, output_folder: str, num_frames: int = 201): + """ + Create a sequence of increasingly noisy images. + + Args: + input_image_path: Path to the input image + output_folder: Folder to save the sequence + num_frames: Number of frames to generate + """ + # Create output folder if it doesn't exist + os.makedirs(output_folder, exist_ok=True) + + # Load the image + image = Image.open(input_image_path) + image_array = np.array(image) + + # Calculate noise levels + # First 5 frames are clean, then noise increases to 0.25 + noise_levels = np.zeros(num_frames) + if num_frames > 5: + noise_levels[5:] = np.linspace(0, 0.50, num_frames - 5) + + # Generate and save frames + for i, noise_level in enumerate(noise_levels): + # First 5 frames are the original image + if i < 5: + noisy_image = image_array + else: + noisy_image = add_noise(image_array, noise_level) + + # Save frame + output_path = os.path.join(output_folder, f"frame_{i:03d}.png") + Image.fromarray(noisy_image).save(output_path) + + # Print progress + if (i + 1) % 20 == 0: + print(f"Generated {i + 1}/{num_frames} frames") + +def create_video(image_folder: str, output_path: str, fps: int = 24): + """ + Create a video from a sequence of images using ffmpeg command line. + + Args: + image_folder: Folder containing the image sequence + output_path: Path for the output video + fps: Frames per second for the video + """ + input_pattern = os.path.join(image_folder, 'frame_%03d.png') + + # Construct ffmpeg command + cmd = [ + 'ffmpeg', + '-y', # Overwrite output file if it exists + '-framerate', str(fps), + '-i', input_pattern, + '-c:v', 'libx264', + '-pix_fmt', 'yuv420p', + '-preset', 'medium', + '-crf', '23', + output_path + ] + + # Run ffmpeg command + import subprocess + try: + subprocess.run(cmd, check=True, capture_output=True) + except subprocess.CalledProcessError as e: + print(f"Error creating video: {e.stderr.decode()}") + raise + +def main(): + parser = argparse.ArgumentParser(description="Generate a sequence of increasingly noisy images and create a video") + parser.add_argument("input_image", help="Path to input image") + parser.add_argument("--output_folder", default="noise_sequence", help="Output folder for image sequence") + parser.add_argument("--output_video", default="noise_sequence.mp4", help="Output video path") + parser.add_argument("--fps", type=int, default=24, help="Frames per second for the video") + parser.add_argument("--frames", type=int, default=201, help="Number of frames to generate") + + args = parser.parse_args() + + print("Generating noise sequence...") + create_noise_sequence(args.input_image, args.output_folder, args.frames) + + print("Creating video...") + create_video(args.output_folder, args.output_video, args.fps) + + print(f"Video saved to {args.output_video}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/dataset/__init__.py b/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/dataset/config_utils.py 
b/dataset/config_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..229b12373d569c3438cba23be101054b631634af --- /dev/null +++ b/dataset/config_utils.py @@ -0,0 +1,371 @@ +import argparse +from dataclasses import ( + asdict, + dataclass, +) +import functools +import random +from textwrap import dedent, indent +import json +from pathlib import Path + +# from toolz import curry +from typing import Dict, List, Optional, Sequence, Tuple, Union + +import toml +import voluptuous +from voluptuous import Any, ExactSequence, MultipleInvalid, Object, Schema + +from .image_video_dataset import DatasetGroup, ImageDataset, VideoDataset + +import logging + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +@dataclass +class BaseDatasetParams: + resolution: Tuple[int, int] = (960, 544) + enable_bucket: bool = False + bucket_no_upscale: bool = False + caption_extension: Optional[str] = None + batch_size: int = 1 + num_repeats: int = 1 + cache_directory: Optional[str] = None + debug_dataset: bool = False + + +@dataclass +class ImageDatasetParams(BaseDatasetParams): + image_directory: Optional[str] = None + image_jsonl_file: Optional[str] = None + + +@dataclass +class VideoDatasetParams(BaseDatasetParams): + video_directory: Optional[str] = None + video_jsonl_file: Optional[str] = None + target_frames: Sequence[int] = (1,) + frame_extraction: Optional[str] = "head" + frame_stride: Optional[int] = 1 + frame_sample: Optional[int] = 1 + + +@dataclass +class DatasetBlueprint: + is_image_dataset: bool + params: Union[ImageDatasetParams, VideoDatasetParams] + + +@dataclass +class DatasetGroupBlueprint: + datasets: Sequence[DatasetBlueprint] + + +@dataclass +class Blueprint: + dataset_group: DatasetGroupBlueprint + + +class ConfigSanitizer: + # @curry + @staticmethod + def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple: + Schema(ExactSequence([klass, klass]))(value) + return tuple(value) + + # @curry + @staticmethod + def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple: + Schema(Any(klass, ExactSequence([klass, klass])))(value) + try: + Schema(klass)(value) + return (value, value) + except: + return ConfigSanitizer.__validate_and_convert_twodim(klass, value) + + # datasets schema + DATASET_ASCENDABLE_SCHEMA = { + "caption_extension": str, + "batch_size": int, + "num_repeats": int, + "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), + "enable_bucket": bool, + "bucket_no_upscale": bool, + } + IMAGE_DATASET_DISTINCT_SCHEMA = { + "image_directory": str, + "image_jsonl_file": str, + "cache_directory": str, + } + VIDEO_DATASET_DISTINCT_SCHEMA = { + "video_directory": str, + "video_jsonl_file": str, + "target_frames": [int], + "frame_extraction": str, + "frame_stride": int, + "frame_sample": int, + "cache_directory": str, + } + + # options handled by argparse but not handled by user config + ARGPARSE_SPECIFIC_SCHEMA = { + "debug_dataset": bool, + } + + def __init__(self) -> None: + self.image_dataset_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.IMAGE_DATASET_DISTINCT_SCHEMA, + ) + self.video_dataset_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.VIDEO_DATASET_DISTINCT_SCHEMA, + ) + + def validate_flex_dataset(dataset_config: dict): + if "target_frames" in dataset_config: + return Schema(self.video_dataset_schema)(dataset_config) + else: + return Schema(self.image_dataset_schema)(dataset_config) + + self.dataset_schema = 
validate_flex_dataset + + self.general_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + ) + self.user_config_validator = Schema( + { + "general": self.general_schema, + "datasets": [self.dataset_schema], + } + ) + self.argparse_schema = self.__merge_dict( + self.ARGPARSE_SPECIFIC_SCHEMA, + ) + self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA) + + def sanitize_user_config(self, user_config: dict) -> dict: + try: + return self.user_config_validator(user_config) + except MultipleInvalid: + # TODO: clarify the error message + logger.error("Invalid user config / ユーザ設定の形式が正しくないようです") + raise + + # NOTE: In nature, argument parser result is not needed to be sanitize + # However this will help us to detect program bug + def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace: + try: + return self.argparse_config_validator(argparse_namespace) + except MultipleInvalid: + # XXX: this should be a bug + logger.error( + "Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。" + ) + raise + + # NOTE: value would be overwritten by latter dict if there is already the same key + @staticmethod + def __merge_dict(*dict_list: dict) -> dict: + merged = {} + for schema in dict_list: + # merged |= schema + for k, v in schema.items(): + merged[k] = v + return merged + + +class BlueprintGenerator: + BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {} + + def __init__(self, sanitizer: ConfigSanitizer): + self.sanitizer = sanitizer + + # runtime_params is for parameters which is only configurable on runtime, such as tokenizer + def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint: + sanitized_user_config = self.sanitizer.sanitize_user_config(user_config) + sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace) + + argparse_config = {k: v for k, v in vars(sanitized_argparse_namespace).items() if v is not None} + general_config = sanitized_user_config.get("general", {}) + + dataset_blueprints = [] + for dataset_config in sanitized_user_config.get("datasets", []): + is_image_dataset = "target_frames" not in dataset_config + if is_image_dataset: + dataset_params_klass = ImageDatasetParams + else: + dataset_params_klass = VideoDatasetParams + + params = self.generate_params_by_fallbacks( + dataset_params_klass, [dataset_config, general_config, argparse_config, runtime_params] + ) + dataset_blueprints.append(DatasetBlueprint(is_image_dataset, params)) + + dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints) + + return Blueprint(dataset_group_blueprint) + + @staticmethod + def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]): + name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME + search_value = BlueprintGenerator.search_value + default_params = asdict(param_klass()) + param_names = default_params.keys() + + params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names} + + return param_klass(**params) + + @staticmethod + def search_value(key: str, fallbacks: Sequence[dict], default_value=None): + for cand in fallbacks: + value = cand.get(key) + if value is not None: + return value + + return default_value + + +# if training is True, it will return a dataset group for training, otherwise for caching +def generate_dataset_group_by_blueprint(dataset_group_blueprint: 
DatasetGroupBlueprint, training: bool = False) -> DatasetGroup: + datasets: List[Union[ImageDataset, VideoDataset]] = [] + + for dataset_blueprint in dataset_group_blueprint.datasets: + if dataset_blueprint.is_image_dataset: + dataset_klass = ImageDataset + else: + dataset_klass = VideoDataset + + dataset = dataset_klass(**asdict(dataset_blueprint.params)) + datasets.append(dataset) + + # assertion + cache_directories = [dataset.cache_directory for dataset in datasets] + num_of_unique_cache_directories = len(set(cache_directories)) + if num_of_unique_cache_directories != len(cache_directories): + raise ValueError( + "cache directory should be unique for each dataset (note that cache directory is image/video directory if not specified)" + + " / cache directory は各データセットごとに異なる必要があります(指定されていない場合はimage/video directoryが使われるので注意)" + ) + + # print info + info = "" + for i, dataset in enumerate(datasets): + is_image_dataset = isinstance(dataset, ImageDataset) + info += dedent( + f"""\ + [Dataset {i}] + is_image_dataset: {is_image_dataset} + resolution: {dataset.resolution} + batch_size: {dataset.batch_size} + num_repeats: {dataset.num_repeats} + caption_extension: "{dataset.caption_extension}" + enable_bucket: {dataset.enable_bucket} + bucket_no_upscale: {dataset.bucket_no_upscale} + cache_directory: "{dataset.cache_directory}" + debug_dataset: {dataset.debug_dataset} + """ + ) + + if is_image_dataset: + info += indent( + dedent( + f"""\ + image_directory: "{dataset.image_directory}" + image_jsonl_file: "{dataset.image_jsonl_file}" + \n""" + ), + " ", + ) + else: + info += indent( + dedent( + f"""\ + video_directory: "{dataset.video_directory}" + video_jsonl_file: "{dataset.video_jsonl_file}" + target_frames: {dataset.target_frames} + frame_extraction: {dataset.frame_extraction} + frame_stride: {dataset.frame_stride} + frame_sample: {dataset.frame_sample} + \n""" + ), + " ", + ) + logger.info(f"{info}") + + # make buckets first because it determines the length of dataset + # and set the same seed for all datasets + seed = random.randint(0, 2**31) # actual seed is seed + epoch_no + for i, dataset in enumerate(datasets): + # logger.info(f"[Dataset {i}]") + dataset.set_seed(seed) + if training: + dataset.prepare_for_training() + + return DatasetGroup(datasets) + + +def load_user_config(file: str) -> dict: + file: Path = Path(file) + if not file.is_file(): + raise ValueError(f"file not found / ファイルが見つかりません: {file}") + + if file.name.lower().endswith(".json"): + try: + with open(file, "r", encoding="utf-8") as f: + config = json.load(f) + except Exception: + logger.error( + f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}" + ) + raise + elif file.name.lower().endswith(".toml"): + try: + config = toml.load(file) + except Exception: + logger.error( + f"Error on parsing TOML config file. Please check the format. 
/ TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}" + ) + raise + else: + raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}") + + return config + + +# for config test +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("dataset_config") + config_args, remain = parser.parse_known_args() + + parser = argparse.ArgumentParser() + parser.add_argument("--debug_dataset", action="store_true") + argparse_namespace = parser.parse_args(remain) + + logger.info("[argparse_namespace]") + logger.info(f"{vars(argparse_namespace)}") + + user_config = load_user_config(config_args.dataset_config) + + logger.info("") + logger.info("[user_config]") + logger.info(f"{user_config}") + + sanitizer = ConfigSanitizer() + sanitized_user_config = sanitizer.sanitize_user_config(user_config) + + logger.info("") + logger.info("[sanitized_user_config]") + logger.info(f"{sanitized_user_config}") + + blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace) + + logger.info("") + logger.info("[blueprint]") + logger.info(f"{blueprint}") + + dataset_group = generate_dataset_group_by_blueprint(blueprint.dataset_group) diff --git a/dataset/dataset_config.md b/dataset/dataset_config.md new file mode 100644 index 0000000000000000000000000000000000000000..436e1bc195a433f9daa1e5e7715bcbfedcbc331e --- /dev/null +++ b/dataset/dataset_config.md @@ -0,0 +1,387 @@ +> 📝 Click on the language section to expand / 言語をクリックして展開 + +## Dataset Configuration + +
+English + +Please create a TOML file for dataset configuration. + +Image and video datasets are supported. The configuration file can include multiple datasets, either image or video datasets, with caption text files or metadata JSONL files. + +The cache directory must be different for each dataset. +
+ +
+日本語 + +データセットの設定を行うためのTOMLファイルを作成してください。 + +画像データセットと動画データセットがサポートされています。設定ファイルには、画像または動画データセットを複数含めることができます。キャプションテキストファイルまたはメタデータJSONLファイルを使用できます。 + +キャッシュディレクトリは、各データセットごとに異なるディレクトリである必要があります。 +
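+
+For reference, the TOML file described here is consumed by the dataset configuration loader added in this diff. The sketch below condenses that flow from the loader's config-test entry point; the import path and the `dataset_config.toml` filename are assumptions for illustration, not something this document fixes.
+
+```python
+# Sketch only: mirrors the __main__ config test of the configuration loader in this diff.
+# The module path below is an assumption; adjust it to where the loader lives in your checkout.
+import argparse
+
+from dataset.config_utils import (  # assumed import path
+    BlueprintGenerator,
+    ConfigSanitizer,
+    generate_dataset_group_by_blueprint,
+    load_user_config,
+)
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--debug_dataset", action="store_true")
+args = parser.parse_args([])  # empty argv here; the training/caching scripts pass their own arguments
+
+user_config = load_user_config("dataset_config.toml")  # the TOML file described in this document (placeholder name)
+blueprint = BlueprintGenerator(ConfigSanitizer()).generate(user_config, args)
+dataset_group = generate_dataset_group_by_blueprint(blueprint.dataset_group)
+print(f"datasets: {len(dataset_group.datasets)}")
+```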
+ +### Sample for Image Dataset with Caption Text Files + +```toml +# resolution, caption_extension, batch_size, num_repeats, enable_bucket, bucket_no_upscale should be set in either general or datasets +# otherwise, the default values will be used for each item + +# general configurations +[general] +resolution = [960, 544] +caption_extension = ".txt" +batch_size = 1 +enable_bucket = true +bucket_no_upscale = false + +[[datasets]] +image_directory = "/path/to/image_dir" +cache_directory = "/path/to/cache_directory" +num_repeats = 1 # optional, default is 1. Number of times to repeat the dataset. Useful to balance the multiple datasets with different sizes. + +# other datasets can be added here. each dataset can have different configurations +``` + +
+English + +`cache_directory` is optional, default is None to use the same directory as the image directory. However, we recommend to set the cache directory to avoid accidental sharing of the cache files between different datasets. + +`num_repeats` is also available. It is optional, default is 1 (no repeat). It repeats the images (or videos) that many times to expand the dataset. For example, if `num_repeats = 2` and there are 20 images in the dataset, each image will be duplicated twice (with the same caption) to have a total of 40 images. It is useful to balance the multiple datasets with different sizes. + +
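+
+As a quick sanity check for the balancing described above, each dataset contributes roughly `number of files × num_repeats` items per epoch. A tiny illustrative sketch (the file counts and names are made up):
+
+```python
+# Hypothetical example: choose num_repeats so that two differently sized
+# datasets contribute a similar number of items per epoch.
+datasets = [
+    {"name": "style_images", "num_files": 20, "num_repeats": 2},      # 20 * 2 = 40
+    {"name": "character_images", "num_files": 40, "num_repeats": 1},  # 40 * 1 = 40
+]
+
+for ds in datasets:
+    effective = ds["num_files"] * ds["num_repeats"]
+    print(f'{ds["name"]}: {effective} items per epoch')
+```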
+ +
+日本語 + +`cache_directory` はオプションです。デフォルトは画像ディレクトリと同じディレクトリに設定されます。ただし、異なるデータセット間でキャッシュファイルが共有されるのを防ぐために、明示的に別のキャッシュディレクトリを設定することをお勧めします。 + +`num_repeats` はオプションで、デフォルトは 1 です(繰り返しなし)。画像(や動画)を、その回数だけ単純に繰り返してデータセットを拡張します。たとえば`num_repeats = 2`としたとき、画像20枚のデータセットなら、各画像が2枚ずつ(同一のキャプションで)計40枚存在した場合と同じになります。異なるデータ数のデータセット間でバランスを取るために使用可能です。 + +resolution, caption_extension, batch_size, num_repeats, enable_bucket, bucket_no_upscale は general または datasets のどちらかに設定してください。省略時は各項目のデフォルト値が使用されます。 + +`[[datasets]]`以下を追加することで、他のデータセットを追加できます。各データセットには異なる設定を持てます。 +
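+
+With caption text files, each image is expected to have a caption file next to it with the same basename and the configured `caption_extension`. The pairing rule, condensed from the image datasource code added in this diff (the directory layout below is hypothetical):
+
+```python
+# Condensed from ImageDirectoryDatasource in this diff: with caption_extension = ".txt",
+# the caption for /path/to/image_dir/img001.jpg is read from /path/to/image_dir/img001.txt.
+import os
+
+def caption_path_for(image_path: str, caption_extension: str = ".txt") -> str:
+    return os.path.splitext(image_path)[0] + caption_extension
+
+print(caption_path_for("/path/to/image_dir/img001.jpg"))  # -> /path/to/image_dir/img001.txt
+```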
+ +### Sample for Image Dataset with Metadata JSONL File + +```toml +# resolution, batch_size, num_repeats, enable_bucket, bucket_no_upscale should be set in either general or datasets +# caption_extension is not required for metadata jsonl file +# cache_directory is required for each dataset with metadata jsonl file + +# general configurations +[general] +resolution = [960, 544] +batch_size = 1 +enable_bucket = true +bucket_no_upscale = false + +[[datasets]] +image_jsonl_file = "/path/to/metadata.jsonl" +cache_directory = "/path/to/cache_directory" # required for metadata jsonl file +num_repeats = 1 # optional, default is 1. Same as above. + +# other datasets can be added here. each dataset can have different configurations +``` + +JSONL file format for metadata: + +```json +{"image_path": "/path/to/image1.jpg", "caption": "A caption for image1"} +{"image_path": "/path/to/image2.jpg", "caption": "A caption for image2"} +``` + +
+日本語 + +resolution, batch_size, num_repeats, enable_bucket, bucket_no_upscale は general または datasets のどちらかに設定してください。省略時は各項目のデフォルト値が使用されます。 + +metadata jsonl ファイルを使用する場合、caption_extension は必要ありません。また、cache_directory は必須です。 + +キャプションによるデータセットと同様に、複数のデータセットを追加できます。各データセットには異なる設定を持てます。 +
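+
+If you already have images with caption text files, converting them to the metadata JSONL format shown above is straightforward. A hedged sketch follows; the paths and the `*.jpg`/`.txt` extensions are placeholders, and the repository itself does not ship such a converter.
+
+```python
+# Hypothetical helper: write one {"image_path": ..., "caption": ...} JSON object
+# per line, which is the format expected by image_jsonl_file.
+import glob
+import json
+import os
+
+image_dir = "/path/to/image_dir"          # placeholder
+output_jsonl = "/path/to/metadata.jsonl"  # placeholder
+
+with open(output_jsonl, "w", encoding="utf-8") as out:
+    for image_path in sorted(glob.glob(os.path.join(image_dir, "*.jpg"))):
+        caption_file = os.path.splitext(image_path)[0] + ".txt"
+        with open(caption_file, "r", encoding="utf-8") as f:
+            caption = f.read().strip()
+        out.write(json.dumps({"image_path": image_path, "caption": caption}, ensure_ascii=False) + "\n")
+```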
+ + +### Sample for Video Dataset with Caption Text Files + +```toml +# resolution, caption_extension, target_frames, frame_extraction, frame_stride, frame_sample, +# batch_size, num_repeats, enable_bucket, bucket_no_upscale should be set in either general or datasets +# num_repeats is also available for video dataset, example is not shown here + +# general configurations +[general] +resolution = [960, 544] +caption_extension = ".txt" +batch_size = 1 +enable_bucket = true +bucket_no_upscale = false + +[[datasets]] +video_directory = "/path/to/video_dir" +cache_directory = "/path/to/cache_directory" # recommended to set cache directory +target_frames = [1, 25, 45] +frame_extraction = "head" + +# other datasets can be added here. each dataset can have different configurations +``` + +
+日本語 + +resolution, caption_extension, target_frames, frame_extraction, frame_stride, frame_sample, batch_size, num_repeats, enable_bucket, bucket_no_upscale は general または datasets のどちらかに設定してください。 + +他の注意事項は画像データセットと同様です。 +
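+
+The specification at the end of this document notes that every `target_frames` entry must have the form N*4+1 (1, 5, 9, ..., 25, ..., 45, ...). A small, purely illustrative check for a candidate list before training:
+
+```python
+# Check that each target_frames value is of the form N*4+1, as noted in the
+# specification below (e.g. 1, 25, 45 are valid; 24 is not).
+def check_target_frames(target_frames):
+    bad = [f for f in target_frames if f < 1 or (f - 1) % 4 != 0]
+    if bad:
+        raise ValueError(f"target_frames must be N*4+1, got invalid values: {bad}")
+
+check_target_frames([1, 25, 45])   # ok
+# check_target_frames([24])        # would raise ValueError
+```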
+ +### Sample for Video Dataset with Metadata JSONL File + +```toml +# resolution, target_frames, frame_extraction, frame_stride, frame_sample, +# batch_size, num_repeats, enable_bucket, bucket_no_upscale should be set in either general or datasets +# caption_extension is not required for metadata jsonl file +# cache_directory is required for each dataset with metadata jsonl file + +# general configurations +[general] +resolution = [960, 544] +batch_size = 1 +enable_bucket = true +bucket_no_upscale = false + +[[datasets]] +video_jsonl_file = "/path/to/metadata.jsonl" +target_frames = [1, 25, 45] +frame_extraction = "head" +cache_directory = "/path/to/cache_directory_head" + +# same metadata jsonl file can be used for multiple datasets +[[datasets]] +video_jsonl_file = "/path/to/metadata.jsonl" +target_frames = [1] +frame_stride = 10 +cache_directory = "/path/to/cache_directory_stride" + +# other datasets can be added here. each dataset can have different configurations +``` + +JSONL file format for metadata: + +```json +{"video_path": "/path/to/video1.mp4", "caption": "A caption for video1"} +{"video_path": "/path/to/video2.mp4", "caption": "A caption for video2"} +``` + +
+日本語 + +resolution, target_frames, frame_extraction, frame_stride, frame_sample, batch_size, num_repeats, enable_bucket, bucket_no_upscale は general または datasets のどちらかに設定してください。 + +metadata jsonl ファイルを使用する場合、caption_extension は必要ありません。また、cache_directory は必須です。 + +他の注意事項は今までのデータセットと同様です。 +
+ +### frame_extraction Options + +
+English + +- `head`: Extract the first N frames from the video. +- `chunk`: Extract frames by splitting the video into chunks of N frames. +- `slide`: Extract frames from the video with a stride of `frame_stride`. +- `uniform`: Extract `frame_sample` samples uniformly from the video. + +For example, consider a video with 40 frames. The following diagrams illustrate each extraction: +
+ +
+日本語 + +- `head`: 動画から最初のNフレームを抽出します。 +- `chunk`: 動画をNフレームずつに分割してフレームを抽出します。 +- `slide`: `frame_stride`に指定したフレームごとに動画からNフレームを抽出します。 +- `uniform`: 動画から一定間隔で、`frame_sample`個のNフレームを抽出します。 + +例えば、40フレームの動画を例とした抽出について、以下の図で説明します。 +
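+
+Before the diagrams, here is a condensed Python re-statement of the same clip selection logic, taken from the video dataset code added in this diff; it can be handy for checking programmatically which clips a given setting will produce. The function name is ours, not part of the repository.
+
+```python
+# Condensed from VideoDataset.retrieve_latent_cache_batches in this diff:
+# returns the (start_frame, num_frames) clips extracted from one video.
+import numpy as np
+
+def list_clips(frame_count, target_frames, frame_extraction="head", frame_stride=1, frame_sample=1):
+    clips = []
+    for target in target_frames:
+        if frame_count < target:
+            continue  # video too short for this target length
+        if frame_extraction == "head":
+            clips.append((0, target))
+        elif frame_extraction == "chunk":
+            clips.extend((i, target) for i in range(0, frame_count, target) if i + target <= frame_count)
+        elif frame_extraction == "slide":
+            clips.extend((i, target) for i in range(0, frame_count - target + 1, frame_stride))
+        elif frame_extraction == "uniform":
+            clips.extend((int(i), target) for i in np.linspace(0, frame_count - target, frame_sample, dtype=int))
+        else:
+            raise ValueError(f"unknown frame_extraction: {frame_extraction}")
+    return clips
+
+# The 40-frame example used in the diagrams below:
+print(list_clips(40, [1, 13, 25], "slide", frame_stride=10))
+```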
+ +``` +Original Video, 40 frames: x = frame, o = no frame +oooooooooooooooooooooooooooooooooooooooo + +head, target_frames = [1, 13, 25] -> extract head frames: +xooooooooooooooooooooooooooooooooooooooo +xxxxxxxxxxxxxooooooooooooooooooooooooooo +xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo + +chunk, target_frames = [13, 25] -> extract frames by splitting into chunks, into 13 and 25 frames: +xxxxxxxxxxxxxooooooooooooooooooooooooooo +oooooooooooooxxxxxxxxxxxxxoooooooooooooo +ooooooooooooooooooooooooooxxxxxxxxxxxxxo +xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo + +NOTE: Please do not include 1 in target_frames if you are using the frame_extraction "chunk". It will make the all frames to be extracted. +注: frame_extraction "chunk" を使用する場合、target_frames に 1 を含めないでください。全てのフレームが抽出されてしまいます。 + +slide, target_frames = [1, 13, 25], frame_stride = 10 -> extract N frames with a stride of 10: +xooooooooooooooooooooooooooooooooooooooo +ooooooooooxooooooooooooooooooooooooooooo +ooooooooooooooooooooxooooooooooooooooooo +ooooooooooooooooooooooooooooooxooooooooo +xxxxxxxxxxxxxooooooooooooooooooooooooooo +ooooooooooxxxxxxxxxxxxxooooooooooooooooo +ooooooooooooooooooooxxxxxxxxxxxxxooooooo +xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo +ooooooooooxxxxxxxxxxxxxxxxxxxxxxxxxooooo + +uniform, target_frames =[1, 13, 25], frame_sample = 4 -> extract `frame_sample` samples uniformly, N frames each: +xooooooooooooooooooooooooooooooooooooooo +oooooooooooooxoooooooooooooooooooooooooo +oooooooooooooooooooooooooxoooooooooooooo +ooooooooooooooooooooooooooooooooooooooox +xxxxxxxxxxxxxooooooooooooooooooooooooooo +oooooooooxxxxxxxxxxxxxoooooooooooooooooo +ooooooooooooooooooxxxxxxxxxxxxxooooooooo +oooooooooooooooooooooooooooxxxxxxxxxxxxx +xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo +oooooxxxxxxxxxxxxxxxxxxxxxxxxxoooooooooo +ooooooooooxxxxxxxxxxxxxxxxxxxxxxxxxooooo +oooooooooooooooxxxxxxxxxxxxxxxxxxxxxxxxx +``` + +## Specifications + +```toml +# general configurations +[general] +resolution = [960, 544] # optional, [W, H], default is None. This is the default resolution for all datasets +caption_extension = ".txt" # optional, default is None. This is the default caption extension for all datasets +batch_size = 1 # optional, default is 1. This is the default batch size for all datasets +num_repeats = 1 # optional, default is 1. Number of times to repeat the dataset. Useful to balance the multiple datasets with different sizes. +enable_bucket = true # optional, default is false. Enable bucketing for datasets +bucket_no_upscale = false # optional, default is false. Disable upscaling for bucketing. Ignored if enable_bucket is false + +### Image Dataset + +# sample image dataset with caption text files +[[datasets]] +image_directory = "/path/to/image_dir" +caption_extension = ".txt" # required for caption text files, if general caption extension is not set +resolution = [960, 544] # required if general resolution is not set +batch_size = 4 # optional, overwrite the default batch size +num_repeats = 1 # optional, overwrite the default num_repeats +enable_bucket = false # optional, overwrite the default bucketing setting +bucket_no_upscale = true # optional, overwrite the default bucketing setting +cache_directory = "/path/to/cache_directory" # optional, default is None to use the same directory as the image directory. 
NOTE: caching is always enabled + +# sample image dataset with metadata **jsonl** file +[[datasets]] +image_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of image files and captions +resolution = [960, 544] # required if general resolution is not set +cache_directory = "/path/to/cache_directory" # required for metadata jsonl file +# caption_extension is not required for metadata jsonl file +# batch_size, num_repeats, enable_bucket, bucket_no_upscale are also available for metadata jsonl file + +### Video Dataset + +# sample video dataset with caption text files +[[datasets]] +video_directory = "/path/to/video_dir" +caption_extension = ".txt" # required for caption text files, if general caption extension is not set +resolution = [960, 544] # required if general resolution is not set + +target_frames = [1, 25, 79] # required for video dataset. list of video lengths to extract frames. each element must be N*4+1 (N=0,1,2,...) + +# NOTE: Please do not include 1 in target_frames if you are using the frame_extraction "chunk". It will make the all frames to be extracted. + +frame_extraction = "head" # optional, "head" or "chunk", "slide", "uniform". Default is "head" +frame_stride = 1 # optional, default is 1, available for "slide" frame extraction +frame_sample = 4 # optional, default is 1 (same as "head"), available for "uniform" frame extraction +# batch_size, num_repeats, enable_bucket, bucket_no_upscale, cache_directory are also available for video dataset + +# sample video dataset with metadata jsonl file +[[datasets]] +video_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of video files and captions + +target_frames = [1, 79] + +cache_directory = "/path/to/cache_directory" # required for metadata jsonl file +# frame_extraction, frame_stride, frame_sample are also available for metadata jsonl file +``` + + + +The metadata with .json file will be supported in the near future. 
+ + + + \ No newline at end of file diff --git a/dataset/image_video_dataset.py b/dataset/image_video_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..44c0296a785ac25b068bb589b9acf937df0bd112 --- /dev/null +++ b/dataset/image_video_dataset.py @@ -0,0 +1,1321 @@ +from concurrent.futures import ThreadPoolExecutor +import glob +import json +import math +import os +import random +import time +from typing import Optional, Sequence, Tuple, Union + +import numpy as np +import torch +from safetensors.torch import save_file, load_file +from safetensors import safe_open +from PIL import Image +import cv2 +import av + +from utils import safetensors_utils +from utils.model_utils import dtype_to_str + +import logging + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"] + +try: + import pillow_avif + + IMAGE_EXTENSIONS.extend([".avif", ".AVIF"]) +except: + pass + +# JPEG-XL on Linux +try: + from jxlpy import JXLImagePlugin + + IMAGE_EXTENSIONS.extend([".jxl", ".JXL"]) +except: + pass + +# JPEG-XL on Windows +try: + import pillow_jxl + + IMAGE_EXTENSIONS.extend([".jxl", ".JXL"]) +except: + pass + +VIDEO_EXTENSIONS = [ + ".mp4", + ".webm", + ".avi", + ".mkv", + ".mov", + ".flv", + ".wmv", + ".m4v", + ".mpg", + ".mpeg", + ".MP4", + ".WEBM", + ".AVI", + ".MKV", + ".MOV", + ".FLV", + ".WMV", + ".M4V", + ".MPG", + ".MPEG", +] # some of them are not tested + +ARCHITECTURE_HUNYUAN_VIDEO = "hv" + + +def glob_images(directory, base="*"): + img_paths = [] + for ext in IMAGE_EXTENSIONS: + if base == "*": + img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext))) + else: + img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext)))) + img_paths = list(set(img_paths)) # remove duplicates + img_paths.sort() + return img_paths + + +def glob_videos(directory, base="*"): + video_paths = [] + for ext in VIDEO_EXTENSIONS: + if base == "*": + video_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext))) + else: + video_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext)))) + video_paths = list(set(video_paths)) # remove duplicates + video_paths.sort() + return video_paths + + +def divisible_by(num: int, divisor: int) -> int: + return num - num % divisor + + +def resize_image_to_bucket(image: Union[Image.Image, np.ndarray], bucket_reso: tuple[int, int]) -> np.ndarray: + """ + Resize the image to the bucket resolution. 
+ """ + is_pil_image = isinstance(image, Image.Image) + if is_pil_image: + image_width, image_height = image.size + else: + image_height, image_width = image.shape[:2] + + if bucket_reso == (image_width, image_height): + return np.array(image) if is_pil_image else image + + bucket_width, bucket_height = bucket_reso + if bucket_width == image_width or bucket_height == image_height: + image = np.array(image) if is_pil_image else image + else: + # resize the image to the bucket resolution to match the short side + scale_width = bucket_width / image_width + scale_height = bucket_height / image_height + scale = max(scale_width, scale_height) + image_width = int(image_width * scale + 0.5) + image_height = int(image_height * scale + 0.5) + + if scale > 1: + image = Image.fromarray(image) if not is_pil_image else image + image = image.resize((image_width, image_height), Image.LANCZOS) + image = np.array(image) + else: + image = np.array(image) if is_pil_image else image + image = cv2.resize(image, (image_width, image_height), interpolation=cv2.INTER_AREA) + + # crop the image to the bucket resolution + crop_left = (image_width - bucket_width) // 2 + crop_top = (image_height - bucket_height) // 2 + image = image[crop_top : crop_top + bucket_height, crop_left : crop_left + bucket_width] + return image + + +class ItemInfo: + def __init__( + self, + item_key: str, + caption: str, + original_size: tuple[int, int], + bucket_size: Optional[Union[tuple[int, int], tuple[int, int, int]]] = None, + frame_count: Optional[int] = None, + content: Optional[np.ndarray] = None, + latent_cache_path: Optional[str] = None, + ) -> None: + self.item_key = item_key + self.caption = caption + self.original_size = original_size + self.bucket_size = bucket_size + self.frame_count = frame_count + self.content = content + self.latent_cache_path = latent_cache_path + self.text_encoder_output_cache_path: Optional[str] = None + + def __str__(self) -> str: + return ( + f"ItemInfo(item_key={self.item_key}, caption={self.caption}, " + + f"original_size={self.original_size}, bucket_size={self.bucket_size}, " + + f"frame_count={self.frame_count}, latent_cache_path={self.latent_cache_path})" + ) + + +def save_latent_cache(item_info: ItemInfo, latent: torch.Tensor): + assert latent.dim() == 4, "latent should be 4D tensor (frame, channel, height, width)" + + # NaN check and show warning, replace NaN with 0 + if torch.isnan(latent).any(): + logger.warning(f"latent tensor has NaN: {item_info.item_key}, replace NaN with 0") + latent[torch.isnan(latent)] = 0 + + metadata = { + "architecture": "hunyuan_video", + "width": f"{item_info.original_size[0]}", + "height": f"{item_info.original_size[1]}", + "format_version": "1.0.0", + } + if item_info.frame_count is not None: + metadata["frame_count"] = f"{item_info.frame_count}" + + _, F, H, W = latent.shape + dtype_str = dtype_to_str(latent.dtype) + sd = {f"latents_{F}x{H}x{W}_{dtype_str}": latent.detach().cpu()} + + latent_dir = os.path.dirname(item_info.latent_cache_path) + os.makedirs(latent_dir, exist_ok=True) + + save_file(sd, item_info.latent_cache_path, metadata=metadata) + + +def save_text_encoder_output_cache(item_info: ItemInfo, embed: torch.Tensor, mask: Optional[torch.Tensor], is_llm: bool): + assert ( + embed.dim() == 1 or embed.dim() == 2 + ), f"embed should be 2D tensor (feature, hidden_size) or (hidden_size,), got {embed.shape}" + assert mask is None or mask.dim() == 1, f"mask should be 1D tensor (feature), got {mask.shape}" + + # NaN check and show warning, replace NaN with 0 + 
if torch.isnan(embed).any(): + logger.warning(f"embed tensor has NaN: {item_info.item_key}, replace NaN with 0") + embed[torch.isnan(embed)] = 0 + + metadata = { + "architecture": "hunyuan_video", + "caption1": item_info.caption, + "format_version": "1.0.0", + } + + sd = {} + if os.path.exists(item_info.text_encoder_output_cache_path): + # load existing cache and update metadata + with safetensors_utils.MemoryEfficientSafeOpen(item_info.text_encoder_output_cache_path) as f: + existing_metadata = f.metadata() + for key in f.keys(): + sd[key] = f.get_tensor(key) + + assert existing_metadata["architecture"] == metadata["architecture"], "architecture mismatch" + if existing_metadata["caption1"] != metadata["caption1"]: + logger.warning(f"caption mismatch: existing={existing_metadata['caption1']}, new={metadata['caption1']}, overwrite") + # TODO verify format_version + + existing_metadata.pop("caption1", None) + existing_metadata.pop("format_version", None) + metadata.update(existing_metadata) # copy existing metadata + else: + text_encoder_output_dir = os.path.dirname(item_info.text_encoder_output_cache_path) + os.makedirs(text_encoder_output_dir, exist_ok=True) + + dtype_str = dtype_to_str(embed.dtype) + text_encoder_type = "llm" if is_llm else "clipL" + sd[f"{text_encoder_type}_{dtype_str}"] = embed.detach().cpu() + if mask is not None: + sd[f"{text_encoder_type}_mask"] = mask.detach().cpu() + + safetensors_utils.mem_eff_save_file(sd, item_info.text_encoder_output_cache_path, metadata=metadata) + + +class BucketSelector: + RESOLUTION_STEPS_HUNYUAN = 16 + + def __init__(self, resolution: Tuple[int, int], enable_bucket: bool = True, no_upscale: bool = False): + self.resolution = resolution + self.bucket_area = resolution[0] * resolution[1] + self.reso_steps = BucketSelector.RESOLUTION_STEPS_HUNYUAN + + if not enable_bucket: + # only define one bucket + self.bucket_resolutions = [resolution] + self.no_upscale = False + else: + # prepare bucket resolution + self.no_upscale = no_upscale + sqrt_size = int(math.sqrt(self.bucket_area)) + min_size = divisible_by(sqrt_size // 2, self.reso_steps) + self.bucket_resolutions = [] + for w in range(min_size, sqrt_size + self.reso_steps, self.reso_steps): + h = divisible_by(self.bucket_area // w, self.reso_steps) + self.bucket_resolutions.append((w, h)) + self.bucket_resolutions.append((h, w)) + + self.bucket_resolutions = list(set(self.bucket_resolutions)) + self.bucket_resolutions.sort() + + # calculate aspect ratio to find the nearest resolution + self.aspect_ratios = np.array([w / h for w, h in self.bucket_resolutions]) + + def get_bucket_resolution(self, image_size: tuple[int, int]) -> tuple[int, int]: + """ + return the bucket resolution for the given image size, (width, height) + """ + area = image_size[0] * image_size[1] + if self.no_upscale and area <= self.bucket_area: + w, h = image_size + w = divisible_by(w, self.reso_steps) + h = divisible_by(h, self.reso_steps) + return w, h + + aspect_ratio = image_size[0] / image_size[1] + ar_errors = self.aspect_ratios - aspect_ratio + bucket_id = np.abs(ar_errors).argmin() + return self.bucket_resolutions[bucket_id] + + +def load_video( + video_path: str, + start_frame: Optional[int] = None, + end_frame: Optional[int] = None, + bucket_selector: Optional[BucketSelector] = None, + bucket_reso: Optional[tuple[int, int]] = None, +) -> list[np.ndarray]: + """ + bucket_reso: if given, resize the video to the bucket resolution, (width, height) + """ + container = av.open(video_path) + video = [] + for i, frame in 
enumerate(container.decode(video=0)): + if start_frame is not None and i < start_frame: + continue + if end_frame is not None and i >= end_frame: + break + frame = frame.to_image() + + if bucket_selector is not None and bucket_reso is None: + bucket_reso = bucket_selector.get_bucket_resolution(frame.size) + + if bucket_reso is not None: + frame = resize_image_to_bucket(frame, bucket_reso) + else: + frame = np.array(frame) + + video.append(frame) + container.close() + return video + + +class BucketBatchManager: + + def __init__(self, bucketed_item_info: dict[tuple[int, int], list[ItemInfo]], batch_size: int): + self.batch_size = batch_size + self.buckets = bucketed_item_info + self.bucket_resos = list(self.buckets.keys()) + self.bucket_resos.sort() + + self.bucket_batch_indices = [] + for bucket_reso in self.bucket_resos: + bucket = self.buckets[bucket_reso] + num_batches = math.ceil(len(bucket) / self.batch_size) + for i in range(num_batches): + self.bucket_batch_indices.append((bucket_reso, i)) + + self.shuffle() + + def show_bucket_info(self): + for bucket_reso in self.bucket_resos: + bucket = self.buckets[bucket_reso] + logger.info(f"bucket: {bucket_reso}, count: {len(bucket)}") + + logger.info(f"total batches: {len(self)}") + + def shuffle(self): + for bucket in self.buckets.values(): + random.shuffle(bucket) + random.shuffle(self.bucket_batch_indices) + + def __len__(self): + return len(self.bucket_batch_indices) + + def __getitem__(self, idx): + bucket_reso, batch_idx = self.bucket_batch_indices[idx] + bucket = self.buckets[bucket_reso] + start = batch_idx * self.batch_size + end = min(start + self.batch_size, len(bucket)) + + latents = [] + llm_embeds = [] + llm_masks = [] + clip_l_embeds = [] + for item_info in bucket[start:end]: + sd = load_file(item_info.latent_cache_path) + latent = None + for key in sd.keys(): + if key.startswith("latents_"): + latent = sd[key] + break + latents.append(latent) + + sd = load_file(item_info.text_encoder_output_cache_path) + llm_embed = llm_mask = clip_l_embed = None + for key in sd.keys(): + if key.startswith("llm_mask"): + llm_mask = sd[key] + elif key.startswith("llm_"): + llm_embed = sd[key] + elif key.startswith("clipL_mask"): + pass + elif key.startswith("clipL_"): + clip_l_embed = sd[key] + llm_embeds.append(llm_embed) + llm_masks.append(llm_mask) + clip_l_embeds.append(clip_l_embed) + + latents = torch.stack(latents) + llm_embeds = torch.stack(llm_embeds) + llm_masks = torch.stack(llm_masks) + clip_l_embeds = torch.stack(clip_l_embeds) + + return latents, llm_embeds, llm_masks, clip_l_embeds + + +class ContentDatasource: + def __init__(self): + self.caption_only = False + + def set_caption_only(self, caption_only: bool): + self.caption_only = caption_only + + def is_indexable(self): + return False + + def get_caption(self, idx: int) -> tuple[str, str]: + """ + Returns caption. May not be called if is_indexable() returns False. + """ + raise NotImplementedError + + def __len__(self): + raise NotImplementedError + + def __iter__(self): + raise NotImplementedError + + def __next__(self): + raise NotImplementedError + + +class ImageDatasource(ContentDatasource): + def __init__(self): + super().__init__() + + def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]: + """ + Returns image data as a tuple of image path, image, and caption for the given index. + Key must be unique and valid as a file name. + May not be called if is_indexable() returns False. 
+ """ + raise NotImplementedError + + +class ImageDirectoryDatasource(ImageDatasource): + def __init__(self, image_directory: str, caption_extension: Optional[str] = None): + super().__init__() + self.image_directory = image_directory + self.caption_extension = caption_extension + self.current_idx = 0 + + # glob images + logger.info(f"glob images in {self.image_directory}") + self.image_paths = glob_images(self.image_directory) + logger.info(f"found {len(self.image_paths)} images") + + def is_indexable(self): + return True + + def __len__(self): + return len(self.image_paths) + + def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]: + image_path = self.image_paths[idx] + image = Image.open(image_path).convert("RGB") + + _, caption = self.get_caption(idx) + + return image_path, image, caption + + def get_caption(self, idx: int) -> tuple[str, str]: + image_path = self.image_paths[idx] + caption_path = os.path.splitext(image_path)[0] + self.caption_extension if self.caption_extension else "" + with open(caption_path, "r", encoding="utf-8") as f: + caption = f.read().strip() + return image_path, caption + + def __iter__(self): + self.current_idx = 0 + return self + + def __next__(self) -> callable: + """ + Returns a fetcher function that returns image data. + """ + if self.current_idx >= len(self.image_paths): + raise StopIteration + + if self.caption_only: + + def create_caption_fetcher(index): + return lambda: self.get_caption(index) + + fetcher = create_caption_fetcher(self.current_idx) + else: + + def create_image_fetcher(index): + return lambda: self.get_image_data(index) + + fetcher = create_image_fetcher(self.current_idx) + + self.current_idx += 1 + return fetcher + + +class ImageJsonlDatasource(ImageDatasource): + def __init__(self, image_jsonl_file: str): + super().__init__() + self.image_jsonl_file = image_jsonl_file + self.current_idx = 0 + + # load jsonl + logger.info(f"load image jsonl from {self.image_jsonl_file}") + self.data = [] + with open(self.image_jsonl_file, "r", encoding="utf-8") as f: + for line in f: + try: + data = json.loads(line) + except json.JSONDecodeError: + logger.error(f"failed to load json: {line} @ {self.image_jsonl_file}") + raise + self.data.append(data) + logger.info(f"loaded {len(self.data)} images") + + def is_indexable(self): + return True + + def __len__(self): + return len(self.data) + + def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]: + data = self.data[idx] + image_path = data["image_path"] + image = Image.open(image_path).convert("RGB") + + caption = data["caption"] + + return image_path, image, caption + + def get_caption(self, idx: int) -> tuple[str, str]: + data = self.data[idx] + image_path = data["image_path"] + caption = data["caption"] + return image_path, caption + + def __iter__(self): + self.current_idx = 0 + return self + + def __next__(self) -> callable: + if self.current_idx >= len(self.data): + raise StopIteration + + if self.caption_only: + + def create_caption_fetcher(index): + return lambda: self.get_caption(index) + + fetcher = create_caption_fetcher(self.current_idx) + + else: + + def create_fetcher(index): + return lambda: self.get_image_data(index) + + fetcher = create_fetcher(self.current_idx) + + self.current_idx += 1 + return fetcher + + +class VideoDatasource(ContentDatasource): + def __init__(self): + super().__init__() + + # None means all frames + self.start_frame = None + self.end_frame = None + + self.bucket_selector = None + + def __len__(self): + raise NotImplementedError + + def 
get_video_data_from_path( + self, + video_path: str, + start_frame: Optional[int] = None, + end_frame: Optional[int] = None, + bucket_selector: Optional[BucketSelector] = None, + ) -> tuple[str, list[Image.Image], str]: + # this method can resize the video if bucket_selector is given to reduce the memory usage + + start_frame = start_frame if start_frame is not None else self.start_frame + end_frame = end_frame if end_frame is not None else self.end_frame + bucket_selector = bucket_selector if bucket_selector is not None else self.bucket_selector + + video = load_video(video_path, start_frame, end_frame, bucket_selector) + return video + + def set_start_and_end_frame(self, start_frame: Optional[int], end_frame: Optional[int]): + self.start_frame = start_frame + self.end_frame = end_frame + + def set_bucket_selector(self, bucket_selector: BucketSelector): + self.bucket_selector = bucket_selector + + def __iter__(self): + raise NotImplementedError + + def __next__(self): + raise NotImplementedError + + +class VideoDirectoryDatasource(VideoDatasource): + def __init__(self, video_directory: str, caption_extension: Optional[str] = None): + super().__init__() + self.video_directory = video_directory + self.caption_extension = caption_extension + self.current_idx = 0 + + # glob images + logger.info(f"glob images in {self.video_directory}") + self.video_paths = glob_videos(self.video_directory) + logger.info(f"found {len(self.video_paths)} videos") + + def is_indexable(self): + return True + + def __len__(self): + return len(self.video_paths) + + def get_video_data( + self, + idx: int, + start_frame: Optional[int] = None, + end_frame: Optional[int] = None, + bucket_selector: Optional[BucketSelector] = None, + ) -> tuple[str, list[Image.Image], str]: + video_path = self.video_paths[idx] + video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector) + + _, caption = self.get_caption(idx) + + return video_path, video, caption + + def get_caption(self, idx: int) -> tuple[str, str]: + video_path = self.video_paths[idx] + caption_path = os.path.splitext(video_path)[0] + self.caption_extension if self.caption_extension else "" + with open(caption_path, "r", encoding="utf-8") as f: + caption = f.read().strip() + return video_path, caption + + def __iter__(self): + self.current_idx = 0 + return self + + def __next__(self): + if self.current_idx >= len(self.video_paths): + raise StopIteration + + if self.caption_only: + + def create_caption_fetcher(index): + return lambda: self.get_caption(index) + + fetcher = create_caption_fetcher(self.current_idx) + + else: + + def create_fetcher(index): + return lambda: self.get_video_data(index) + + fetcher = create_fetcher(self.current_idx) + + self.current_idx += 1 + return fetcher + + +class VideoJsonlDatasource(VideoDatasource): + def __init__(self, video_jsonl_file: str): + super().__init__() + self.video_jsonl_file = video_jsonl_file + self.current_idx = 0 + + # load jsonl + logger.info(f"load video jsonl from {self.video_jsonl_file}") + self.data = [] + with open(self.video_jsonl_file, "r", encoding="utf-8") as f: + for line in f: + data = json.loads(line) + self.data.append(data) + logger.info(f"loaded {len(self.data)} videos") + + def is_indexable(self): + return True + + def __len__(self): + return len(self.data) + + def get_video_data( + self, + idx: int, + start_frame: Optional[int] = None, + end_frame: Optional[int] = None, + bucket_selector: Optional[BucketSelector] = None, + ) -> tuple[str, list[Image.Image], str]: + data = 
self.data[idx] + video_path = data["video_path"] + video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector) + + caption = data["caption"] + + return video_path, video, caption + + def get_caption(self, idx: int) -> tuple[str, str]: + data = self.data[idx] + video_path = data["video_path"] + caption = data["caption"] + return video_path, caption + + def __iter__(self): + self.current_idx = 0 + return self + + def __next__(self): + if self.current_idx >= len(self.data): + raise StopIteration + + if self.caption_only: + + def create_caption_fetcher(index): + return lambda: self.get_caption(index) + + fetcher = create_caption_fetcher(self.current_idx) + + else: + + def create_fetcher(index): + return lambda: self.get_video_data(index) + + fetcher = create_fetcher(self.current_idx) + + self.current_idx += 1 + return fetcher + + +class BaseDataset(torch.utils.data.Dataset): + def __init__( + self, + resolution: Tuple[int, int] = (960, 544), + caption_extension: Optional[str] = None, + batch_size: int = 1, + num_repeats: int = 1, + enable_bucket: bool = False, + bucket_no_upscale: bool = False, + cache_directory: Optional[str] = None, + debug_dataset: bool = False, + ): + self.resolution = resolution + self.caption_extension = caption_extension + self.batch_size = batch_size + self.num_repeats = num_repeats + self.enable_bucket = enable_bucket + self.bucket_no_upscale = bucket_no_upscale + self.cache_directory = cache_directory + self.debug_dataset = debug_dataset + self.seed = None + self.current_epoch = 0 + + if not self.enable_bucket: + self.bucket_no_upscale = False + + def get_metadata(self) -> dict: + metadata = { + "resolution": self.resolution, + "caption_extension": self.caption_extension, + "batch_size_per_device": self.batch_size, + "num_repeats": self.num_repeats, + "enable_bucket": bool(self.enable_bucket), + "bucket_no_upscale": bool(self.bucket_no_upscale), + } + return metadata + + def get_all_latent_cache_files(self): + return glob.glob(os.path.join(self.cache_directory, f"*_{ARCHITECTURE_HUNYUAN_VIDEO}.safetensors")) + + def get_all_text_encoder_output_cache_files(self): + return glob.glob(os.path.join(self.cache_directory, f"*_{ARCHITECTURE_HUNYUAN_VIDEO}_te.safetensors")) + + def get_latent_cache_path(self, item_info: ItemInfo) -> str: + """ + Returns the cache path for the latent tensor. + + item_info: ItemInfo object + + Returns: + str: cache path + + cache_path is based on the item_key and the resolution. 
+ """ + w, h = item_info.original_size + basename = os.path.splitext(os.path.basename(item_info.item_key))[0] + assert self.cache_directory is not None, "cache_directory is required / cache_directoryは必須です" + return os.path.join(self.cache_directory, f"{basename}_{w:04d}x{h:04d}_{ARCHITECTURE_HUNYUAN_VIDEO}.safetensors") + + def get_text_encoder_output_cache_path(self, item_info: ItemInfo) -> str: + basename = os.path.splitext(os.path.basename(item_info.item_key))[0] + assert self.cache_directory is not None, "cache_directory is required / cache_directoryは必須です" + return os.path.join(self.cache_directory, f"{basename}_{ARCHITECTURE_HUNYUAN_VIDEO}_te.safetensors") + + def retrieve_latent_cache_batches(self, num_workers: int): + raise NotImplementedError + + def retrieve_text_encoder_output_cache_batches(self, num_workers: int): + raise NotImplementedError + + def prepare_for_training(self): + pass + + def set_seed(self, seed: int): + self.seed = seed + + def set_current_epoch(self, epoch): + if not self.current_epoch == epoch: # shuffle buckets when epoch is incremented + if epoch > self.current_epoch: + logger.info("epoch is incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch)) + num_epochs = epoch - self.current_epoch + for _ in range(num_epochs): + self.current_epoch += 1 + self.shuffle_buckets() + # self.current_epoch seem to be set to 0 again in the next epoch. it may be caused by skipped_dataloader? + else: + logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch)) + self.current_epoch = epoch + + def set_current_step(self, step): + self.current_step = step + + def set_max_train_steps(self, max_train_steps): + self.max_train_steps = max_train_steps + + def shuffle_buckets(self): + raise NotImplementedError + + def __len__(self): + return NotImplementedError + + def __getitem__(self, idx): + raise NotImplementedError + + def _default_retrieve_text_encoder_output_cache_batches(self, datasource: ContentDatasource, batch_size: int, num_workers: int): + datasource.set_caption_only(True) + executor = ThreadPoolExecutor(max_workers=num_workers) + + data: list[ItemInfo] = [] + futures = [] + + def aggregate_future(consume_all: bool = False): + while len(futures) >= num_workers or (consume_all and len(futures) > 0): + completed_futures = [future for future in futures if future.done()] + if len(completed_futures) == 0: + if len(futures) >= num_workers or consume_all: # to avoid adding too many futures + time.sleep(0.1) + continue + else: + break # submit batch if possible + + for future in completed_futures: + item_key, caption = future.result() + item_info = ItemInfo(item_key, caption, (0, 0), (0, 0)) + item_info.text_encoder_output_cache_path = self.get_text_encoder_output_cache_path(item_info) + data.append(item_info) + + futures.remove(future) + + def submit_batch(flush: bool = False): + nonlocal data + if len(data) >= batch_size or (len(data) > 0 and flush): + batch = data[0:batch_size] + if len(data) > batch_size: + data = data[batch_size:] + else: + data = [] + return batch + return None + + for fetch_op in datasource: + future = executor.submit(fetch_op) + futures.append(future) + aggregate_future() + while True: + batch = submit_batch() + if batch is None: + break + yield batch + + aggregate_future(consume_all=True) + while True: + batch = submit_batch(flush=True) + if batch is None: + break + yield batch + + executor.shutdown() + + +class ImageDataset(BaseDataset): + def __init__( + self, + resolution: Tuple[int, 
int], + caption_extension: Optional[str], + batch_size: int, + num_repeats: int, + enable_bucket: bool, + bucket_no_upscale: bool, + image_directory: Optional[str] = None, + image_jsonl_file: Optional[str] = None, + cache_directory: Optional[str] = None, + debug_dataset: bool = False, + ): + super(ImageDataset, self).__init__( + resolution, caption_extension, batch_size, num_repeats, enable_bucket, bucket_no_upscale, cache_directory, debug_dataset + ) + self.image_directory = image_directory + self.image_jsonl_file = image_jsonl_file + if image_directory is not None: + self.datasource = ImageDirectoryDatasource(image_directory, caption_extension) + elif image_jsonl_file is not None: + self.datasource = ImageJsonlDatasource(image_jsonl_file) + else: + raise ValueError("image_directory or image_jsonl_file must be specified") + + if self.cache_directory is None: + self.cache_directory = self.image_directory + + self.batch_manager = None + self.num_train_items = 0 + + def get_metadata(self): + metadata = super().get_metadata() + if self.image_directory is not None: + metadata["image_directory"] = os.path.basename(self.image_directory) + if self.image_jsonl_file is not None: + metadata["image_jsonl_file"] = os.path.basename(self.image_jsonl_file) + return metadata + + def get_total_image_count(self): + return len(self.datasource) if self.datasource.is_indexable() else None + + def retrieve_latent_cache_batches(self, num_workers: int): + buckset_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale) + executor = ThreadPoolExecutor(max_workers=num_workers) + + batches: dict[tuple[int, int], list[ItemInfo]] = {} # (width, height) -> [ItemInfo] + futures = [] + + # aggregate futures and sort by bucket resolution + def aggregate_future(consume_all: bool = False): + while len(futures) >= num_workers or (consume_all and len(futures) > 0): + completed_futures = [future for future in futures if future.done()] + if len(completed_futures) == 0: + if len(futures) >= num_workers or consume_all: # to avoid adding too many futures + time.sleep(0.1) + continue + else: + break # submit batch if possible + + for future in completed_futures: + original_size, item_key, image, caption = future.result() + bucket_height, bucket_width = image.shape[:2] + bucket_reso = (bucket_width, bucket_height) + + item_info = ItemInfo(item_key, caption, original_size, bucket_reso, content=image) + item_info.latent_cache_path = self.get_latent_cache_path(item_info) + + if bucket_reso not in batches: + batches[bucket_reso] = [] + batches[bucket_reso].append(item_info) + + futures.remove(future) + + # submit batch if some bucket has enough items + def submit_batch(flush: bool = False): + for key in batches: + if len(batches[key]) >= self.batch_size or flush: + batch = batches[key][0 : self.batch_size] + if len(batches[key]) > self.batch_size: + batches[key] = batches[key][self.batch_size :] + else: + del batches[key] + return key, batch + return None, None + + for fetch_op in self.datasource: + + # fetch and resize image in a separate thread + def fetch_and_resize(op: callable) -> tuple[tuple[int, int], str, Image.Image, str]: + image_key, image, caption = op() + image: Image.Image + image_size = image.size + + bucket_reso = buckset_selector.get_bucket_resolution(image_size) + image = resize_image_to_bucket(image, bucket_reso) + return image_size, image_key, image, caption + + future = executor.submit(fetch_and_resize, fetch_op) + futures.append(future) + aggregate_future() + while True: + key, batch = 
submit_batch() + if key is None: + break + yield key, batch + + aggregate_future(consume_all=True) + while True: + key, batch = submit_batch(flush=True) + if key is None: + break + yield key, batch + + executor.shutdown() + + def retrieve_text_encoder_output_cache_batches(self, num_workers: int): + return self._default_retrieve_text_encoder_output_cache_batches(self.datasource, self.batch_size, num_workers) + + def prepare_for_training(self): + bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale) + + # glob cache files + latent_cache_files = glob.glob(os.path.join(self.cache_directory, f"*_{ARCHITECTURE_HUNYUAN_VIDEO}.safetensors")) + + # assign cache files to item info + bucketed_item_info: dict[tuple[int, int], list[ItemInfo]] = {} # (width, height) -> [ItemInfo] + for cache_file in latent_cache_files: + tokens = os.path.basename(cache_file).split("_") + + image_size = tokens[-2] # 0000x0000 + image_width, image_height = map(int, image_size.split("x")) + image_size = (image_width, image_height) + + item_key = "_".join(tokens[:-2]) + text_encoder_output_cache_file = os.path.join( + self.cache_directory, f"{item_key}_{ARCHITECTURE_HUNYUAN_VIDEO}_te.safetensors" + ) + if not os.path.exists(text_encoder_output_cache_file): + logger.warning(f"Text encoder output cache file not found: {text_encoder_output_cache_file}") + continue + + bucket_reso = bucket_selector.get_bucket_resolution(image_size) + item_info = ItemInfo(item_key, "", image_size, bucket_reso, latent_cache_path=cache_file) + item_info.text_encoder_output_cache_path = text_encoder_output_cache_file + + bucket = bucketed_item_info.get(bucket_reso, []) + for _ in range(self.num_repeats): + bucket.append(item_info) + bucketed_item_info[bucket_reso] = bucket + + # prepare batch manager + self.batch_manager = BucketBatchManager(bucketed_item_info, self.batch_size) + self.batch_manager.show_bucket_info() + + self.num_train_items = sum([len(bucket) for bucket in bucketed_item_info.values()]) + + def shuffle_buckets(self): + # set random seed for this epoch + random.seed(self.seed + self.current_epoch) + self.batch_manager.shuffle() + + def __len__(self): + if self.batch_manager is None: + return 100 # dummy value + return len(self.batch_manager) + + def __getitem__(self, idx): + return self.batch_manager[idx] + + +class VideoDataset(BaseDataset): + def __init__( + self, + resolution: Tuple[int, int], + caption_extension: Optional[str], + batch_size: int, + num_repeats: int, + enable_bucket: bool, + bucket_no_upscale: bool, + frame_extraction: Optional[str] = "head", + frame_stride: Optional[int] = 1, + frame_sample: Optional[int] = 1, + target_frames: Optional[list[int]] = None, + video_directory: Optional[str] = None, + video_jsonl_file: Optional[str] = None, + cache_directory: Optional[str] = None, + debug_dataset: bool = False, + ): + super(VideoDataset, self).__init__( + resolution, caption_extension, batch_size, num_repeats, enable_bucket, bucket_no_upscale, cache_directory, debug_dataset + ) + self.video_directory = video_directory + self.video_jsonl_file = video_jsonl_file + self.target_frames = target_frames + self.frame_extraction = frame_extraction + self.frame_stride = frame_stride + self.frame_sample = frame_sample + + if video_directory is not None: + self.datasource = VideoDirectoryDatasource(video_directory, caption_extension) + elif video_jsonl_file is not None: + self.datasource = VideoJsonlDatasource(video_jsonl_file) + + if self.frame_extraction == "uniform" and 
self.frame_sample == 1: + self.frame_extraction = "head" + logger.warning("frame_sample is set to 1 for frame_extraction=uniform. frame_extraction is changed to head.") + if self.frame_extraction == "head": + # head extraction. we can limit the number of frames to be extracted + self.datasource.set_start_and_end_frame(0, max(self.target_frames)) + + if self.cache_directory is None: + self.cache_directory = self.video_directory + + self.batch_manager = None + self.num_train_items = 0 + + def get_metadata(self): + metadata = super().get_metadata() + if self.video_directory is not None: + metadata["video_directory"] = os.path.basename(self.video_directory) + if self.video_jsonl_file is not None: + metadata["video_jsonl_file"] = os.path.basename(self.video_jsonl_file) + metadata["frame_extraction"] = self.frame_extraction + metadata["frame_stride"] = self.frame_stride + metadata["frame_sample"] = self.frame_sample + metadata["target_frames"] = self.target_frames + return metadata + + def retrieve_latent_cache_batches(self, num_workers: int): + buckset_selector = BucketSelector(self.resolution) + self.datasource.set_bucket_selector(buckset_selector) + + executor = ThreadPoolExecutor(max_workers=num_workers) + + # key: (width, height, frame_count), value: [ItemInfo] + batches: dict[tuple[int, int, int], list[ItemInfo]] = {} + futures = [] + + def aggregate_future(consume_all: bool = False): + while len(futures) >= num_workers or (consume_all and len(futures) > 0): + completed_futures = [future for future in futures if future.done()] + if len(completed_futures) == 0: + if len(futures) >= num_workers or consume_all: # to avoid adding too many futures + time.sleep(0.1) + continue + else: + break # submit batch if possible + + for future in completed_futures: + original_frame_size, video_key, video, caption = future.result() + + frame_count = len(video) + video = np.stack(video, axis=0) + height, width = video.shape[1:3] + bucket_reso = (width, height) # already resized + + crop_pos_and_frames = [] + if self.frame_extraction == "head": + for target_frame in self.target_frames: + if frame_count >= target_frame: + crop_pos_and_frames.append((0, target_frame)) + elif self.frame_extraction == "chunk": + # split by target_frames + for target_frame in self.target_frames: + for i in range(0, frame_count, target_frame): + if i + target_frame <= frame_count: + crop_pos_and_frames.append((i, target_frame)) + elif self.frame_extraction == "slide": + # slide window + for target_frame in self.target_frames: + if frame_count >= target_frame: + for i in range(0, frame_count - target_frame + 1, self.frame_stride): + crop_pos_and_frames.append((i, target_frame)) + elif self.frame_extraction == "uniform": + # select N frames uniformly + for target_frame in self.target_frames: + if frame_count >= target_frame: + frame_indices = np.linspace(0, frame_count - target_frame, self.frame_sample, dtype=int) + for i in frame_indices: + crop_pos_and_frames.append((i, target_frame)) + else: + raise ValueError(f"frame_extraction {self.frame_extraction} is not supported") + + for crop_pos, target_frame in crop_pos_and_frames: + cropped_video = video[crop_pos : crop_pos + target_frame] + body, ext = os.path.splitext(video_key) + item_key = f"{body}_{crop_pos:05d}-{target_frame:03d}{ext}" + batch_key = (*bucket_reso, target_frame) # bucket_reso with frame_count + + item_info = ItemInfo( + item_key, caption, original_frame_size, batch_key, frame_count=target_frame, content=cropped_video + ) + item_info.latent_cache_path = 
self.get_latent_cache_path(item_info) + + batch = batches.get(batch_key, []) + batch.append(item_info) + batches[batch_key] = batch + + futures.remove(future) + + def submit_batch(flush: bool = False): + for key in batches: + if len(batches[key]) >= self.batch_size or flush: + batch = batches[key][0 : self.batch_size] + if len(batches[key]) > self.batch_size: + batches[key] = batches[key][self.batch_size :] + else: + del batches[key] + return key, batch + return None, None + + for operator in self.datasource: + + def fetch_and_resize(op: callable) -> tuple[tuple[int, int], str, list[np.ndarray], str]: + video_key, video, caption = op() + video: list[np.ndarray] + frame_size = (video[0].shape[1], video[0].shape[0]) + + # resize if necessary + bucket_reso = buckset_selector.get_bucket_resolution(frame_size) + video = [resize_image_to_bucket(frame, bucket_reso) for frame in video] + + return frame_size, video_key, video, caption + + future = executor.submit(fetch_and_resize, operator) + futures.append(future) + aggregate_future() + while True: + key, batch = submit_batch() + if key is None: + break + yield key, batch + + aggregate_future(consume_all=True) + while True: + key, batch = submit_batch(flush=True) + if key is None: + break + yield key, batch + + executor.shutdown() + + def retrieve_text_encoder_output_cache_batches(self, num_workers: int): + return self._default_retrieve_text_encoder_output_cache_batches(self.datasource, self.batch_size, num_workers) + + def prepare_for_training(self): + bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale) + + # glob cache files + latent_cache_files = glob.glob(os.path.join(self.cache_directory, f"*_{ARCHITECTURE_HUNYUAN_VIDEO}.safetensors")) + + # assign cache files to item info + bucketed_item_info: dict[tuple[int, int, int], list[ItemInfo]] = {} # (width, height, frame_count) -> [ItemInfo] + for cache_file in latent_cache_files: + tokens = os.path.basename(cache_file).split("_") + + image_size = tokens[-2] # 0000x0000 + image_width, image_height = map(int, image_size.split("x")) + image_size = (image_width, image_height) + + frame_pos, frame_count = tokens[-3].split("-") + frame_pos, frame_count = int(frame_pos), int(frame_count) + + item_key = "_".join(tokens[:-3]) + text_encoder_output_cache_file = os.path.join( + self.cache_directory, f"{item_key}_{ARCHITECTURE_HUNYUAN_VIDEO}_te.safetensors" + ) + if not os.path.exists(text_encoder_output_cache_file): + logger.warning(f"Text encoder output cache file not found: {text_encoder_output_cache_file}") + continue + + bucket_reso = bucket_selector.get_bucket_resolution(image_size) + bucket_reso = (*bucket_reso, frame_count) + item_info = ItemInfo(item_key, "", image_size, bucket_reso, frame_count=frame_count, latent_cache_path=cache_file) + item_info.text_encoder_output_cache_path = text_encoder_output_cache_file + + bucket = bucketed_item_info.get(bucket_reso, []) + for _ in range(self.num_repeats): + bucket.append(item_info) + bucketed_item_info[bucket_reso] = bucket + + # prepare batch manager + self.batch_manager = BucketBatchManager(bucketed_item_info, self.batch_size) + self.batch_manager.show_bucket_info() + + self.num_train_items = sum([len(bucket) for bucket in bucketed_item_info.values()]) + + def shuffle_buckets(self): + # set random seed for this epoch + random.seed(self.seed + self.current_epoch) + self.batch_manager.shuffle() + + def __len__(self): + if self.batch_manager is None: + return 100 # dummy value + return len(self.batch_manager) + + 
def __getitem__(self, idx): + return self.batch_manager[idx] + + +class DatasetGroup(torch.utils.data.ConcatDataset): + def __init__(self, datasets: Sequence[Union[ImageDataset, VideoDataset]]): + super().__init__(datasets) + self.datasets: list[Union[ImageDataset, VideoDataset]] = datasets + self.num_train_items = 0 + for dataset in self.datasets: + self.num_train_items += dataset.num_train_items + + def set_current_epoch(self, epoch): + for dataset in self.datasets: + dataset.set_current_epoch(epoch) + + def set_current_step(self, step): + for dataset in self.datasets: + dataset.set_current_step(step) + + def set_max_train_steps(self, max_train_steps): + for dataset in self.datasets: + dataset.set_max_train_steps(max_train_steps) diff --git a/diffusers_helper/bucket_tools.py b/diffusers_helper/bucket_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..1d531642945e5951214e8a7bd6fbd39d824664a5 --- /dev/null +++ b/diffusers_helper/bucket_tools.py @@ -0,0 +1,157 @@ +# Base configuration for scaling bucket options +_BASE_RESOLUTION = 640 +_BASE_BUCKET_OPTIONS = [ + (416, 960), (448, 864), (480, 832), (512, 768), (544, 704), + (576, 672), (608, 640), (640, 608), (672, 576), (704, 544), + (768, 512), (832, 480), (864, 448), (960, 416), +] + +# Cache for generated bucket options to avoid redundant calculations +_generated_bucket_cache = {} + +def _round_to_multiple(number, multiple): + """Rounds a number to the nearest multiple of a given number.""" + if multiple == 0: + # Default behavior: round to nearest int. Could also raise an error. + return int(round(number)) + return int(multiple * round(float(number) / multiple)) + +def _adjust_resolution(resolution, divisor=32): + """ + Adjusts a given resolution to the nearest multiple of 'divisor'. + If the input resolution is positive but rounds to 0 (e.g., resolution=10, divisor=32), + it's adjusted to 'divisor'. + If the input resolution is non-positive (<=0), it defaults to 'divisor'. + """ + if resolution <= 0: + return divisor # Default to minimum valid resolution for non-positive inputs + + adjusted = _round_to_multiple(resolution, divisor) + + # If resolution was positive but _round_to_multiple resulted in 0 + # (e.g. input 10 for divisor 32 rounds to 0), ensure it's at least the divisor. + if adjusted == 0: + return divisor + return adjusted + +def generate_scaled_buckets(target_resolution_input, + base_resolution=_BASE_RESOLUTION, + base_options=_BASE_BUCKET_OPTIONS, + divisor=32): + """ + Generates scaled bucket options for a target resolution. + + The target_resolution_input is first adjusted to the nearest multiple of 'divisor'. + Bucket dimensions are scaled from 'base_options' (which are for 'base_resolution') + to the adjusted target resolution. These scaled dimensions are then rounded to the + nearest multiple of 'divisor' and ensured to be at least 'divisor'. + + Args: + target_resolution_input (int): The desired target resolution. + base_resolution (int): The resolution for which 'base_options' are defined. + base_options (list of tuples): A list of (height, width) tuples for 'base_resolution'. + divisor (int): The number that resolutions and bucket dimensions should be multiples of. + + Returns: + list of tuples: Scaled and adjusted bucket options (height, width). 
+ """ + # Adjust the target resolution for scaling + actual_target_resolution = _adjust_resolution(target_resolution_input, divisor) + + if actual_target_resolution in _generated_bucket_cache: + return _generated_bucket_cache[actual_target_resolution] + + # Optimization: If adjusted target resolution matches base resolution. + # This assumes base_options are already compliant with the divisor. + # (Our _BASE_BUCKET_OPTIONS are multiples of 32, so this is fine for divisor=32). + if actual_target_resolution == base_resolution: + options_to_return = list(base_options) # Return a copy + _generated_bucket_cache[actual_target_resolution] = options_to_return + return options_to_return + + scaled_options = [] + seen_options = set() # To handle potential duplicates after rounding + + # Prevent division by zero if base_resolution is 0 (though _BASE_RESOLUTION is 640). + if base_resolution == 0: + # Fallback: return a single square bucket of the target resolution. + # This case should not be hit with current constants. + default_bucket = (actual_target_resolution, actual_target_resolution) + _generated_bucket_cache[actual_target_resolution] = [default_bucket] + return [default_bucket] + + scale_factor = float(actual_target_resolution) / base_resolution + + for base_h, base_w in base_options: + scaled_h_float = base_h * scale_factor + scaled_w_float = base_w * scale_factor + + scaled_h = _round_to_multiple(scaled_h_float, divisor) + scaled_w = _round_to_multiple(scaled_w_float, divisor) + + # Ensure minimum dimension is at least the divisor + scaled_h = max(scaled_h, divisor) + scaled_w = max(scaled_w, divisor) + + bucket_tuple = (scaled_h, scaled_w) + if bucket_tuple not in seen_options: + scaled_options.append(bucket_tuple) + seen_options.add(bucket_tuple) + + # If base_options was empty (not the case for internal use but could be if called externally), + # scaled_options would be empty. Provide a default bucket in such a scenario. + # actual_target_resolution is guaranteed to be >= divisor by _adjust_resolution. + if not scaled_options: + default_bucket = (actual_target_resolution, actual_target_resolution) + scaled_options.append(default_bucket) + + _generated_bucket_cache[actual_target_resolution] = scaled_options + return scaled_options + +def find_nearest_bucket(h, w, resolution=640): + """ + Finds the nearest bucket for a given height (h) and width (w) + at a specified target resolution. + + The 'resolution' parameter is the user's intended target resolution. + This function will: + 1. Adjust this resolution to the nearest multiple of 32 (minimum 32). + 2. Generate a list of bucket options (height, width pairs) by scaling + predefined base options (for 640px) to this adjusted resolution. + All generated bucket dimensions will also be multiples of 32 and at least 32. + 3. Find the bucket from this generated list that is "nearest" to the + aspect ratio of the input h, w. The nearness metric is + abs(input_h * bucket_w - input_w * bucket_h). + + Args: + h (int): The height of the image/item. + w (int): The width of the image/item. + resolution (int): The target resolution for which to find buckets. + Defaults to 640. + + Returns: + tuple: A (bucket_h, bucket_w) tuple representing the best bucket found. + """ + # generate_scaled_buckets handles the adjustment of 'resolution' internally + # and uses a divisor of 32 by default for its calculations. + # The problem statement implies a fixed divisor of 32 for this tool. 
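    # Worked example (illustrative, not from the original file): with h=720, w=1280
    # and the default resolution=640, the base buckets are used unchanged, and the
    # metric abs(h * bucket_w - w * bucket_h) is smallest for (480, 832):
    # |720*832 - 1280*480| = 15360, versus e.g. |720*864 - 1280*448| = 48640 for
    # (448, 864), so (480, 832) is returned.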
+ current_bucket_options = generate_scaled_buckets(resolution, divisor=32) + + # Failsafe: If generate_scaled_buckets somehow returned an empty list (e.g., if _BASE_BUCKET_OPTIONS was empty), + # provide a default bucket based on the adjusted resolution. + if not current_bucket_options: + adjusted_res_for_fallback = _adjust_resolution(resolution, 32) + return (adjusted_res_for_fallback, adjusted_res_for_fallback) + + min_metric = float('inf') + best_bucket = None + # Since current_bucket_options is guaranteed to be non-empty by the check above (or by generate_scaled_buckets's own logic + # when _BASE_BUCKET_OPTIONS is populated), best_bucket will be assigned in the loop. + + for (bucket_h, bucket_w) in current_bucket_options: + metric = abs(h * bucket_w - w * bucket_h) + if metric <= min_metric: # Using "<=" preserves original behavior (last encountered wins on ties) + min_metric = metric + best_bucket = (bucket_h, bucket_w) + + return best_bucket \ No newline at end of file diff --git a/diffusers_helper/clip_vision.py b/diffusers_helper/clip_vision.py new file mode 100644 index 0000000000000000000000000000000000000000..aaf40dbf1b4ef975640e0ad0d5a7792652d79334 --- /dev/null +++ b/diffusers_helper/clip_vision.py @@ -0,0 +1,12 @@ +import numpy as np + + +def hf_clip_vision_encode(image, feature_extractor, image_encoder): + assert isinstance(image, np.ndarray) + assert image.ndim == 3 and image.shape[2] == 3 + assert image.dtype == np.uint8 + + preprocessed = feature_extractor.preprocess(images=image, return_tensors="pt").to(device=image_encoder.device, dtype=image_encoder.dtype) + image_encoder_output = image_encoder(**preprocessed) + + return image_encoder_output diff --git a/diffusers_helper/dit_common.py b/diffusers_helper/dit_common.py new file mode 100644 index 0000000000000000000000000000000000000000..f02e7b012bff0b3b0fce9136d29fee4a1d49e45e --- /dev/null +++ b/diffusers_helper/dit_common.py @@ -0,0 +1,53 @@ +import torch +import accelerate.accelerator + +from diffusers.models.normalization import RMSNorm, LayerNorm, FP32LayerNorm, AdaLayerNormContinuous + + +accelerate.accelerator.convert_outputs_to_fp32 = lambda x: x + + +def LayerNorm_forward(self, x): + return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps).to(x) + + +LayerNorm.forward = LayerNorm_forward +torch.nn.LayerNorm.forward = LayerNorm_forward + + +def FP32LayerNorm_forward(self, x): + origin_dtype = x.dtype + return torch.nn.functional.layer_norm( + x.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ).to(origin_dtype) + + +FP32LayerNorm.forward = FP32LayerNorm_forward + + +def RMSNorm_forward(self, hidden_states): + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + + if self.weight is None: + return hidden_states.to(input_dtype) + + return hidden_states.to(input_dtype) * self.weight.to(input_dtype) + + +RMSNorm.forward = RMSNorm_forward + + +def AdaLayerNormContinuous_forward(self, x, conditioning_embedding): + emb = self.linear(self.silu(conditioning_embedding)) + scale, shift = emb.chunk(2, dim=1) + x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] + return x + + +AdaLayerNormContinuous.forward = AdaLayerNormContinuous_forward diff --git a/diffusers_helper/gradio/progress_bar.py b/diffusers_helper/gradio/progress_bar.py new 
file mode 100644 index 0000000000000000000000000000000000000000..2cc612163a171cef37d67d991d729ba9fec066db --- /dev/null +++ b/diffusers_helper/gradio/progress_bar.py @@ -0,0 +1,86 @@ +progress_html = ''' +
+<div class="loader-container">
+  <div class="loader"></div>
+  <div class="progress-container">
+    <progress value="*number*" max="100"></progress>
+  </div>
+  <span>*text*</span>
+</div>
+''' + +css = ''' +.loader-container { + display: flex; /* Use flex to align items horizontally */ + align-items: center; /* Center items vertically within the container */ + white-space: nowrap; /* Prevent line breaks within the container */ +} + +.loader { + border: 8px solid #f3f3f3; /* Light grey */ + border-top: 8px solid #3498db; /* Blue */ + border-radius: 50%; + width: 30px; + height: 30px; + animation: spin 2s linear infinite; +} + +@keyframes spin { + 0% { transform: rotate(0deg); } + 100% { transform: rotate(360deg); } +} + +/* Style the progress bar */ +progress { + appearance: none; /* Remove default styling */ + height: 20px; /* Set the height of the progress bar */ + border-radius: 5px; /* Round the corners of the progress bar */ + background-color: #f3f3f3; /* Light grey background */ + width: 100%; + vertical-align: middle !important; +} + +/* Style the progress bar container */ +.progress-container { + margin-left: 20px; + margin-right: 20px; + flex-grow: 1; /* Allow the progress container to take up remaining space */ +} + +/* Set the color of the progress bar fill */ +progress::-webkit-progress-value { + background-color: #3498db; /* Blue color for the fill */ +} + +progress::-moz-progress-bar { + background-color: #3498db; /* Blue color for the fill in Firefox */ +} + +/* Style the text on the progress bar */ +progress::after { + content: attr(value '%'); /* Display the progress value followed by '%' */ + position: absolute; + top: 50%; + left: 50%; + transform: translate(-50%, -50%); + color: white; /* Set text color */ + font-size: 14px; /* Set font size */ +} + +/* Style other texts */ +.loader-container > span { + margin-left: 5px; /* Add spacing between the progress bar and the text */ +} + +.no-generating-animation > .generating { + display: none !important; +} + +''' + + +def make_progress_bar_html(number, text): + return progress_html.replace('*number*', str(number)).replace('*text*', text) + + +def make_progress_bar_css(): + return css diff --git a/diffusers_helper/hf_login.py b/diffusers_helper/hf_login.py new file mode 100644 index 0000000000000000000000000000000000000000..b039db24378b0419e69ee97042f88e96460766ef --- /dev/null +++ b/diffusers_helper/hf_login.py @@ -0,0 +1,21 @@ +import os + + +def login(token): + from huggingface_hub import login + import time + + while True: + try: + login(token) + print('HF login ok.') + break + except Exception as e: + print(f'HF login failed: {e}. 
Retrying') + time.sleep(0.5) + + +hf_token = os.environ.get('HF_TOKEN', None) + +if hf_token is not None: + login(hf_token) diff --git a/diffusers_helper/hunyuan.py b/diffusers_helper/hunyuan.py new file mode 100644 index 0000000000000000000000000000000000000000..8f5c8561f5701f201c3b22c182924e3b819e63bf --- /dev/null +++ b/diffusers_helper/hunyuan.py @@ -0,0 +1,111 @@ +import torch + +from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import DEFAULT_PROMPT_TEMPLATE +from diffusers_helper.utils import crop_or_pad_yield_mask + + +@torch.no_grad() +def encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2, max_length=256): + assert isinstance(prompt, str) + + prompt = [prompt] + + # LLAMA + + prompt_llama = [DEFAULT_PROMPT_TEMPLATE["template"].format(p) for p in prompt] + crop_start = DEFAULT_PROMPT_TEMPLATE["crop_start"] + + llama_inputs = tokenizer( + prompt_llama, + padding="max_length", + max_length=max_length + crop_start, + truncation=True, + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=True, + ) + + llama_input_ids = llama_inputs.input_ids.to(text_encoder.device) + llama_attention_mask = llama_inputs.attention_mask.to(text_encoder.device) + llama_attention_length = int(llama_attention_mask.sum()) + + llama_outputs = text_encoder( + input_ids=llama_input_ids, + attention_mask=llama_attention_mask, + output_hidden_states=True, + ) + + llama_vec = llama_outputs.hidden_states[-3][:, crop_start:llama_attention_length] + # llama_vec_remaining = llama_outputs.hidden_states[-3][:, llama_attention_length:] + llama_attention_mask = llama_attention_mask[:, crop_start:llama_attention_length] + + assert torch.all(llama_attention_mask.bool()) + + # CLIP + + clip_l_input_ids = tokenizer_2( + prompt, + padding="max_length", + max_length=77, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ).input_ids + clip_l_pooler = text_encoder_2(clip_l_input_ids.to(text_encoder_2.device), output_hidden_states=False).pooler_output + + return llama_vec, clip_l_pooler + + +@torch.no_grad() +def vae_decode_fake(latents): + latent_rgb_factors = [ + [-0.0395, -0.0331, 0.0445], + [0.0696, 0.0795, 0.0518], + [0.0135, -0.0945, -0.0282], + [0.0108, -0.0250, -0.0765], + [-0.0209, 0.0032, 0.0224], + [-0.0804, -0.0254, -0.0639], + [-0.0991, 0.0271, -0.0669], + [-0.0646, -0.0422, -0.0400], + [-0.0696, -0.0595, -0.0894], + [-0.0799, -0.0208, -0.0375], + [0.1166, 0.1627, 0.0962], + [0.1165, 0.0432, 0.0407], + [-0.2315, -0.1920, -0.1355], + [-0.0270, 0.0401, -0.0821], + [-0.0616, -0.0997, -0.0727], + [0.0249, -0.0469, -0.1703] + ] # From comfyui + + latent_rgb_factors_bias = [0.0259, -0.0192, -0.0761] + + weight = torch.tensor(latent_rgb_factors, device=latents.device, dtype=latents.dtype).transpose(0, 1)[:, :, None, None, None] + bias = torch.tensor(latent_rgb_factors_bias, device=latents.device, dtype=latents.dtype) + + images = torch.nn.functional.conv3d(latents, weight, bias=bias, stride=1, padding=0, dilation=1, groups=1) + images = images.clamp(0.0, 1.0) + + return images + + +@torch.no_grad() +def vae_decode(latents, vae, image_mode=False): + latents = latents / vae.config.scaling_factor + + if not image_mode: + image = vae.decode(latents.to(device=vae.device, dtype=vae.dtype)).sample + else: + latents = latents.to(device=vae.device, dtype=vae.dtype).unbind(2) + image = [vae.decode(l.unsqueeze(2)).sample for l in latents] + image = torch.cat(image, dim=2) + + return image + + 
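`vae_decode_fake` above approximates decoding with a fixed linear projection over the 16 latent channels (a 1x1x1 3D convolution), which allows cheap progress previews without running the full VAE decoder. A minimal usage sketch; the latent shape here is an illustrative assumption, not a value required by the function:

```python
import torch
from diffusers_helper.hunyuan import vae_decode_fake

# Illustrative latent shape: (batch, channels=16, frames, downsampled height, downsampled width).
latents = torch.randn(1, 16, 9, 64, 64)

preview = vae_decode_fake(latents)            # -> (1, 3, 9, 64, 64), clamped to [0, 1]
preview_uint8 = (preview * 255.0).to(torch.uint8)
print(preview_uint8.shape)
```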
+@torch.no_grad() +def vae_encode(image, vae): + latents = vae.encode(image.to(device=vae.device, dtype=vae.dtype)).latent_dist.sample() + latents = latents * vae.config.scaling_factor + return latents diff --git a/diffusers_helper/k_diffusion/uni_pc_fm.py b/diffusers_helper/k_diffusion/uni_pc_fm.py new file mode 100644 index 0000000000000000000000000000000000000000..b5763532a04fc81317b773c59c9878f213abe841 --- /dev/null +++ b/diffusers_helper/k_diffusion/uni_pc_fm.py @@ -0,0 +1,141 @@ +# Better Flow Matching UniPC by Lvmin Zhang +# (c) 2025 +# CC BY-SA 4.0 +# Attribution-ShareAlike 4.0 International Licence + + +import torch + +from tqdm.auto import trange + + +def expand_dims(v, dims): + return v[(...,) + (None,) * (dims - 1)] + + +class FlowMatchUniPC: + def __init__(self, model, extra_args, variant='bh1'): + self.model = model + self.variant = variant + self.extra_args = extra_args + + def model_fn(self, x, t): + return self.model(x, t, **self.extra_args) + + def update_fn(self, x, model_prev_list, t_prev_list, t, order): + assert order <= len(model_prev_list) + dims = x.dim() + + t_prev_0 = t_prev_list[-1] + lambda_prev_0 = - torch.log(t_prev_0) + lambda_t = - torch.log(t) + model_prev_0 = model_prev_list[-1] + + h = lambda_t - lambda_prev_0 + + rks = [] + D1s = [] + for i in range(1, order): + t_prev_i = t_prev_list[-(i + 1)] + model_prev_i = model_prev_list[-(i + 1)] + lambda_prev_i = - torch.log(t_prev_i) + rk = ((lambda_prev_i - lambda_prev_0) / h)[0] + rks.append(rk) + D1s.append((model_prev_i - model_prev_0) / rk) + + rks.append(1.) + rks = torch.tensor(rks, device=x.device) + + R = [] + b = [] + + hh = -h[0] + h_phi_1 = torch.expm1(hh) + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.variant == 'bh1': + B_h = hh + elif self.variant == 'bh2': + B_h = torch.expm1(hh) + else: + raise NotImplementedError('Bad variant!') + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= (i + 1) + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=x.device) + + use_predictor = len(D1s) > 0 + + if use_predictor: + D1s = torch.stack(D1s, dim=1) + if order == 2: + rhos_p = torch.tensor([0.5], device=b.device) + else: + rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]) + else: + D1s = None + rhos_p = None + + if order == 1: + rhos_c = torch.tensor([0.5], device=b.device) + else: + rhos_c = torch.linalg.solve(R, b) + + x_t_ = expand_dims(t / t_prev_0, dims) * x - expand_dims(h_phi_1, dims) * model_prev_0 + + if use_predictor: + pred_res = torch.tensordot(D1s, rhos_p, dims=([1], [0])) + else: + pred_res = 0 + + x_t = x_t_ - expand_dims(B_h, dims) * pred_res + model_t = self.model_fn(x_t, t) + + if D1s is not None: + corr_res = torch.tensordot(D1s, rhos_c[:-1], dims=([1], [0])) + else: + corr_res = 0 + + D1_t = (model_t - model_prev_0) + x_t = x_t_ - expand_dims(B_h, dims) * (corr_res + rhos_c[-1] * D1_t) + + return x_t, model_t + + def sample(self, x, sigmas, callback=None, disable_pbar=False): + order = min(3, len(sigmas) - 2) + model_prev_list, t_prev_list = [], [] + for i in trange(len(sigmas) - 1, disable=disable_pbar): + vec_t = sigmas[i].expand(x.shape[0]) + + if i == 0: + model_prev_list = [self.model_fn(x, vec_t)] + t_prev_list = [vec_t] + elif i < order: + init_order = i + x, model_x = self.update_fn(x, model_prev_list, t_prev_list, vec_t, init_order) + model_prev_list.append(model_x) + t_prev_list.append(vec_t) + else: + x, model_x = self.update_fn(x, model_prev_list, 
t_prev_list, vec_t, order) + model_prev_list.append(model_x) + t_prev_list.append(vec_t) + + model_prev_list = model_prev_list[-order:] + t_prev_list = t_prev_list[-order:] + + if callback is not None: + callback({'x': x, 'i': i, 'denoised': model_prev_list[-1]}) + + return model_prev_list[-1] + + +def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=False, variant='bh1'): + assert variant in ['bh1', 'bh2'] + return FlowMatchUniPC(model, extra_args=extra_args, variant=variant).sample(noise, sigmas=sigmas, callback=callback, disable_pbar=disable) diff --git a/diffusers_helper/k_diffusion/wrapper.py b/diffusers_helper/k_diffusion/wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..cc420da4db1134deca30648077923021b35f82d1 --- /dev/null +++ b/diffusers_helper/k_diffusion/wrapper.py @@ -0,0 +1,51 @@ +import torch + + +def append_dims(x, target_dims): + return x[(...,) + (None,) * (target_dims - x.ndim)] + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=1.0): + if guidance_rescale == 0: + return noise_cfg + + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + noise_cfg = guidance_rescale * noise_pred_rescaled + (1.0 - guidance_rescale) * noise_cfg + return noise_cfg + + +def fm_wrapper(transformer, t_scale=1000.0): + def k_model(x, sigma, **extra_args): + dtype = extra_args['dtype'] + cfg_scale = extra_args['cfg_scale'] + cfg_rescale = extra_args['cfg_rescale'] + concat_latent = extra_args['concat_latent'] + + original_dtype = x.dtype + sigma = sigma.float() + + x = x.to(dtype) + timestep = (sigma * t_scale).to(dtype) + + if concat_latent is None: + hidden_states = x + else: + hidden_states = torch.cat([x, concat_latent.to(x)], dim=1) + + pred_positive = transformer(hidden_states=hidden_states, timestep=timestep, return_dict=False, **extra_args['positive'])[0].float() + + if cfg_scale == 1.0: + pred_negative = torch.zeros_like(pred_positive) + else: + pred_negative = transformer(hidden_states=hidden_states, timestep=timestep, return_dict=False, **extra_args['negative'])[0].float() + + pred_cfg = pred_negative + cfg_scale * (pred_positive - pred_negative) + pred = rescale_noise_cfg(pred_cfg, pred_positive, guidance_rescale=cfg_rescale) + + x0 = x.float() - pred.float() * append_dims(sigma, x.ndim) + + return x0.to(dtype=original_dtype) + + return k_model diff --git a/diffusers_helper/load_lora.py b/diffusers_helper/load_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..8dafc1fe6f8b8cc82e1201ab8cf6eb4326462a26 --- /dev/null +++ b/diffusers_helper/load_lora.py @@ -0,0 +1,39 @@ +from pathlib import Path +from typing import Optional +from diffusers.loaders.lora_pipeline import _fetch_state_dict +from diffusers.loaders.lora_conversion_utils import _convert_hunyuan_video_lora_to_diffusers + +def load_lora(transformer, lora_path: Path, weight_name: Optional[str] = "pytorch_lora_weights.safetensors", diffuser_lora: bool = False): + """ + Load LoRA weights into the transformer model. + + Args: + transformer: The transformer model to which LoRA weights will be applied. + lora_path (Path): Path to the LoRA weights file. + weight_name (Optional[str]): Name of the weight to load. 
+ + """ + + state_dict = _fetch_state_dict( + lora_path, + weight_name, + True, + True, + None, + None, + None, + None, + None, + None, + None, + None) + + + if not diffuser_lora: + print("Not a diffusers lora, assuming Hunyuan.") + state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict) + + transformer.load_lora_adapter(state_dict, network_alphas=None) + print("LoRA weights loaded successfully.") + return transformer + \ No newline at end of file diff --git a/diffusers_helper/memory.py b/diffusers_helper/memory.py new file mode 100644 index 0000000000000000000000000000000000000000..3380c538a185b0cbd07657ea475d0f5a0aeb17d3 --- /dev/null +++ b/diffusers_helper/memory.py @@ -0,0 +1,134 @@ +# By lllyasviel + + +import torch + + +cpu = torch.device('cpu') +gpu = torch.device(f'cuda:{torch.cuda.current_device()}') +gpu_complete_modules = [] + + +class DynamicSwapInstaller: + @staticmethod + def _install_module(module: torch.nn.Module, **kwargs): + original_class = module.__class__ + module.__dict__['forge_backup_original_class'] = original_class + + def hacked_get_attr(self, name: str): + if '_parameters' in self.__dict__: + _parameters = self.__dict__['_parameters'] + if name in _parameters: + p = _parameters[name] + if p is None: + return None + if p.__class__ == torch.nn.Parameter: + return torch.nn.Parameter(p.to(**kwargs), requires_grad=p.requires_grad) + else: + return p.to(**kwargs) + if '_buffers' in self.__dict__: + _buffers = self.__dict__['_buffers'] + if name in _buffers: + return _buffers[name].to(**kwargs) + return super(original_class, self).__getattr__(name) + + module.__class__ = type('DynamicSwap_' + original_class.__name__, (original_class,), { + '__getattr__': hacked_get_attr, + }) + + return + + @staticmethod + def _uninstall_module(module: torch.nn.Module): + if 'forge_backup_original_class' in module.__dict__: + module.__class__ = module.__dict__.pop('forge_backup_original_class') + return + + @staticmethod + def install_model(model: torch.nn.Module, **kwargs): + for m in model.modules(): + DynamicSwapInstaller._install_module(m, **kwargs) + return + + @staticmethod + def uninstall_model(model: torch.nn.Module): + for m in model.modules(): + DynamicSwapInstaller._uninstall_module(m) + return + + +def fake_diffusers_current_device(model: torch.nn.Module, target_device: torch.device): + if hasattr(model, 'scale_shift_table'): + model.scale_shift_table.data = model.scale_shift_table.data.to(target_device) + return + + for k, p in model.named_modules(): + if hasattr(p, 'weight'): + p.to(target_device) + return + + +def get_cuda_free_memory_gb(device=None): + if device is None: + device = gpu + + memory_stats = torch.cuda.memory_stats(device) + bytes_active = memory_stats['active_bytes.all.current'] + bytes_reserved = memory_stats['reserved_bytes.all.current'] + bytes_free_cuda, _ = torch.cuda.mem_get_info(device) + bytes_inactive_reserved = bytes_reserved - bytes_active + bytes_total_available = bytes_free_cuda + bytes_inactive_reserved + return bytes_total_available / (1024 ** 3) + + +def move_model_to_device_with_memory_preservation(model, target_device, preserved_memory_gb=0): + print(f'Moving {model.__class__.__name__} to {target_device} with preserved memory: {preserved_memory_gb} GB') + + for m in model.modules(): + if get_cuda_free_memory_gb(target_device) <= preserved_memory_gb: + torch.cuda.empty_cache() + return + + if hasattr(m, 'weight'): + m.to(device=target_device) + + model.to(device=target_device) + torch.cuda.empty_cache() + return + + +def 
offload_model_from_device_for_memory_preservation(model, target_device, preserved_memory_gb=0): + print(f'Offloading {model.__class__.__name__} from {target_device} to preserve memory: {preserved_memory_gb} GB') + + for m in model.modules(): + if get_cuda_free_memory_gb(target_device) >= preserved_memory_gb: + torch.cuda.empty_cache() + return + + if hasattr(m, 'weight'): + m.to(device=cpu) + + model.to(device=cpu) + torch.cuda.empty_cache() + return + + +def unload_complete_models(*args): + for m in gpu_complete_modules + list(args): + m.to(device=cpu) + print(f'Unloaded {m.__class__.__name__} as complete.') + + gpu_complete_modules.clear() + torch.cuda.empty_cache() + return + + +def load_model_as_complete(model, target_device, unload=True): + if unload: + unload_complete_models() + + model.to(device=target_device) + print(f'Loaded {model.__class__.__name__} to {target_device} as complete.') + + gpu_complete_modules.append(model) + return diff --git a/diffusers_helper/models/hunyuan_video_packed.py b/diffusers_helper/models/hunyuan_video_packed.py new file mode 100644 index 0000000000000000000000000000000000000000..1cb42abcb52c20a457e50aa066ec1d4a89c6d57f --- /dev/null +++ b/diffusers_helper/models/hunyuan_video_packed.py @@ -0,0 +1,1035 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import einops +import torch.nn as nn +import numpy as np + +from diffusers.loaders import FromOriginalModelMixin +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import PeftAdapterMixin +from diffusers.utils import logging +from diffusers.models.attention import FeedForward +from diffusers.models.attention_processor import Attention +from diffusers.models.embeddings import TimestepEmbedding, Timesteps, PixArtAlphaTextProjection +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers_helper.dit_common import LayerNorm +from diffusers_helper.utils import zero_module + + +enabled_backends = [] + +if torch.backends.cuda.flash_sdp_enabled(): + enabled_backends.append("flash") +if torch.backends.cuda.math_sdp_enabled(): + enabled_backends.append("math") +if torch.backends.cuda.mem_efficient_sdp_enabled(): + enabled_backends.append("mem_efficient") +if torch.backends.cuda.cudnn_sdp_enabled(): + enabled_backends.append("cudnn") + +print("Currently enabled native sdp backends:", enabled_backends) + +try: + # raise NotImplementedError + from xformers.ops import memory_efficient_attention as xformers_attn_func + print('Xformers is installed!') +except: + print('Xformers is not installed!') + xformers_attn_func = None + +try: + # raise NotImplementedError + from flash_attn import flash_attn_varlen_func, flash_attn_func + print('Flash Attn is installed!') +except: + print('Flash Attn is not installed!') + flash_attn_varlen_func = None + flash_attn_func = None + +try: + # raise NotImplementedError + from sageattention import sageattn_varlen, sageattn + print('Sage Attn is installed!') +except: + print('Sage Attn is not installed!') + sageattn_varlen = None + sageattn = None + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def pad_for_3d_conv(x, kernel_size): + b, c, t, h, w = x.shape + pt, ph, pw = kernel_size + pad_t = (pt - (t % pt)) % pt + pad_h = (ph - (h % ph)) % ph + pad_w = (pw - (w % pw)) % pw + return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode='replicate') + + +def center_down_sample_3d(x, 
kernel_size): + # pt, ph, pw = kernel_size + # cp = (pt * ph * pw) // 2 + # xp = einops.rearrange(x, 'b c (t pt) (h ph) (w pw) -> (pt ph pw) b c t h w', pt=pt, ph=ph, pw=pw) + # xc = xp[cp] + # return xc + return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size) + + +def get_cu_seqlens(text_mask, img_len): + batch_size = text_mask.shape[0] + text_len = text_mask.sum(dim=1) + max_len = text_mask.shape[1] + img_len + + cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda") + + for i in range(batch_size): + s = text_len[i] + img_len + s1 = i * max_len + s + s2 = (i + 1) * max_len + cu_seqlens[2 * i + 1] = s1 + cu_seqlens[2 * i + 2] = s2 + + return cu_seqlens + + +def apply_rotary_emb_transposed(x, freqs_cis): + cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1) + x_real, x_imag = x.unflatten(-1, (-1, 2)).unbind(-1) + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + out = x.float() * cos + x_rotated.float() * sin + out = out.to(x) + return out + + +def attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv): + if cu_seqlens_q is None and cu_seqlens_kv is None and max_seqlen_q is None and max_seqlen_kv is None: + if sageattn is not None: + x = sageattn(q, k, v, tensor_layout='NHD') + return x + + if flash_attn_func is not None: + x = flash_attn_func(q, k, v) + return x + + if xformers_attn_func is not None: + x = xformers_attn_func(q, k, v) + return x + + x = torch.nn.functional.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2) + return x + + B, L, H, C = q.shape + + q = q.flatten(0, 1) + k = k.flatten(0, 1) + v = v.flatten(0, 1) + + if sageattn_varlen is not None: + x = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv) + elif flash_attn_varlen_func is not None: + x = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv) + else: + raise NotImplementedError('No Attn Installed!') + + x = x.unflatten(0, (B, L)) + + return x + + +class HunyuanAttnProcessorFlashAttnDouble: + def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb): + cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask + + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + query = apply_rotary_emb_transposed(query, image_rotary_emb) + key = apply_rotary_emb_transposed(key, image_rotary_emb) + + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + + encoder_query = encoder_query.unflatten(2, (attn.heads, -1)) + encoder_key = encoder_key.unflatten(2, (attn.heads, -1)) + encoder_value = encoder_value.unflatten(2, (attn.heads, -1)) + + encoder_query = attn.norm_added_q(encoder_query) + encoder_key = attn.norm_added_k(encoder_key) + + query = torch.cat([query, encoder_query], dim=1) + key = torch.cat([key, encoder_key], dim=1) + value = torch.cat([value, encoder_value], dim=1) + + hidden_states = attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv) + hidden_states = hidden_states.flatten(-2) + + txt_length = encoder_hidden_states.shape[1] + hidden_states, 
encoder_hidden_states = hidden_states[:, :-txt_length], hidden_states[:, -txt_length:] + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + + +class HunyuanAttnProcessorFlashAttnSingle: + def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb): + cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask + + hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) + + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + txt_length = encoder_hidden_states.shape[1] + + query = torch.cat([apply_rotary_emb_transposed(query[:, :-txt_length], image_rotary_emb), query[:, -txt_length:]], dim=1) + key = torch.cat([apply_rotary_emb_transposed(key[:, :-txt_length], image_rotary_emb), key[:, -txt_length:]], dim=1) + + hidden_states = attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv) + hidden_states = hidden_states.flatten(-2) + + hidden_states, encoder_hidden_states = hidden_states[:, :-txt_length], hidden_states[:, -txt_length:] + + return hidden_states, encoder_hidden_states + + +class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module): + def __init__(self, embedding_dim, pooled_projection_dim): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu") + + def forward(self, timestep, guidance, pooled_projection): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) + + guidance_proj = self.time_proj(guidance) + guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype)) + + time_guidance_emb = timesteps_emb + guidance_emb + + pooled_projections = self.text_embedder(pooled_projection) + conditioning = time_guidance_emb + pooled_projections + + return conditioning + + +class CombinedTimestepTextProjEmbeddings(nn.Module): + def __init__(self, embedding_dim, pooled_projection_dim): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu") + + def forward(self, timestep, pooled_projection): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) + + pooled_projections = self.text_embedder(pooled_projection) + + conditioning = timesteps_emb + pooled_projections + + return conditioning + + +class HunyuanVideoAdaNorm(nn.Module): + def __init__(self, in_features: int, out_features: Optional[int] = None) -> None: + super().__init__() + + out_features = out_features or 2 * in_features + self.linear = nn.Linear(in_features, 
out_features) + self.nonlinearity = nn.SiLU() + + def forward( + self, temb: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + temb = self.linear(self.nonlinearity(temb)) + gate_msa, gate_mlp = temb.chunk(2, dim=-1) + gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1) + return gate_msa, gate_mlp + + +class HunyuanVideoIndividualTokenRefinerBlock(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + mlp_width_ratio: str = 4.0, + mlp_drop_rate: float = 0.0, + attention_bias: bool = True, + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + + self.norm1 = LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) + self.attn = Attention( + query_dim=hidden_size, + cross_attention_dim=None, + heads=num_attention_heads, + dim_head=attention_head_dim, + bias=attention_bias, + ) + + self.norm2 = LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) + self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate) + + self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + norm_hidden_states = self.norm1(hidden_states) + + attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=None, + attention_mask=attention_mask, + ) + + gate_msa, gate_mlp = self.norm_out(temb) + hidden_states = hidden_states + attn_output * gate_msa + + ff_output = self.ff(self.norm2(hidden_states)) + hidden_states = hidden_states + ff_output * gate_mlp + + return hidden_states + + +class HunyuanVideoIndividualTokenRefiner(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + num_layers: int, + mlp_width_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + attention_bias: bool = True, + ) -> None: + super().__init__() + + self.refiner_blocks = nn.ModuleList( + [ + HunyuanVideoIndividualTokenRefinerBlock( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_width_ratio=mlp_width_ratio, + mlp_drop_rate=mlp_drop_rate, + attention_bias=attention_bias, + ) + for _ in range(num_layers) + ] + ) + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> None: + self_attn_mask = None + if attention_mask is not None: + batch_size = attention_mask.shape[0] + seq_len = attention_mask.shape[1] + attention_mask = attention_mask.to(hidden_states.device).bool() + self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1) + self_attn_mask_2 = self_attn_mask_1.transpose(2, 3) + self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool() + self_attn_mask[:, :, :, 0] = True + + for block in self.refiner_blocks: + hidden_states = block(hidden_states, temb, self_attn_mask) + + return hidden_states + + +class HunyuanVideoTokenRefiner(nn.Module): + def __init__( + self, + in_channels: int, + num_attention_heads: int, + attention_head_dim: int, + num_layers: int, + mlp_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + attention_bias: bool = True, + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + + self.time_text_embed = CombinedTimestepTextProjEmbeddings( + embedding_dim=hidden_size, pooled_projection_dim=in_channels + ) + self.proj_in = 
nn.Linear(in_channels, hidden_size, bias=True) + self.token_refiner = HunyuanVideoIndividualTokenRefiner( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + num_layers=num_layers, + mlp_width_ratio=mlp_ratio, + mlp_drop_rate=mlp_drop_rate, + attention_bias=attention_bias, + ) + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + attention_mask: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + if attention_mask is None: + pooled_projections = hidden_states.mean(dim=1) + else: + original_dtype = hidden_states.dtype + mask_float = attention_mask.float().unsqueeze(-1) + pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1) + pooled_projections = pooled_projections.to(original_dtype) + + temb = self.time_text_embed(timestep, pooled_projections) + hidden_states = self.proj_in(hidden_states) + hidden_states = self.token_refiner(hidden_states, temb, attention_mask) + + return hidden_states + + +class HunyuanVideoRotaryPosEmbed(nn.Module): + def __init__(self, rope_dim, theta): + super().__init__() + self.DT, self.DY, self.DX = rope_dim + self.theta = theta + + @torch.no_grad() + def get_frequency(self, dim, pos): + T, H, W = pos.shape + freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device)[: (dim // 2)] / dim)) + freqs = torch.outer(freqs, pos.reshape(-1)).unflatten(-1, (T, H, W)).repeat_interleave(2, dim=0) + return freqs.cos(), freqs.sin() + + @torch.no_grad() + def forward_inner(self, frame_indices, height, width, device): + GT, GY, GX = torch.meshgrid( + frame_indices.to(device=device, dtype=torch.float32), + torch.arange(0, height, device=device, dtype=torch.float32), + torch.arange(0, width, device=device, dtype=torch.float32), + indexing="ij" + ) + + FCT, FST = self.get_frequency(self.DT, GT) + FCY, FSY = self.get_frequency(self.DY, GY) + FCX, FSX = self.get_frequency(self.DX, GX) + + result = torch.cat([FCT, FCY, FCX, FST, FSY, FSX], dim=0) + + return result.to(device) + + @torch.no_grad() + def forward(self, frame_indices, height, width, device): + frame_indices = frame_indices.unbind(0) + results = [self.forward_inner(f, height, width, device) for f in frame_indices] + results = torch.stack(results, dim=0) + return results + + +class AdaLayerNormZero(nn.Module): + def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True): + super().__init__() + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias) + if norm_type == "layer_norm": + self.norm = LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) + else: + raise ValueError(f"unknown norm_type {norm_type}") + + def forward( + self, + x: torch.Tensor, + emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + emb = emb.unsqueeze(-2) + emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=-1) + x = self.norm(x) * (1 + scale_msa) + shift_msa + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp + + +class AdaLayerNormZeroSingle(nn.Module): + def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias) + if norm_type == "layer_norm": + self.norm = LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) + else: + raise ValueError(f"unknown norm_type {norm_type}") + + def 
forward( + self, + x: torch.Tensor, + emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + emb = emb.unsqueeze(-2) + emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=-1) + x = self.norm(x) * (1 + scale_msa) + shift_msa + return x, gate_msa + + +class AdaLayerNormContinuous(nn.Module): + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int, + elementwise_affine=True, + eps=1e-5, + bias=True, + norm_type="layer_norm", + ): + super().__init__() + self.silu = nn.SiLU() + self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias) + if norm_type == "layer_norm": + self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias) + else: + raise ValueError(f"unknown norm_type {norm_type}") + + def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: + emb = emb.unsqueeze(-2) + emb = self.linear(self.silu(emb)) + scale, shift = emb.chunk(2, dim=-1) + x = self.norm(x) * (1 + scale) + shift + return x + + +class HunyuanVideoSingleTransformerBlock(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 4.0, + qk_norm: str = "rms_norm", + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + mlp_dim = int(hidden_size * mlp_ratio) + + self.attn = Attention( + query_dim=hidden_size, + cross_attention_dim=None, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=hidden_size, + bias=True, + processor=HunyuanAttnProcessorFlashAttnSingle(), + qk_norm=qk_norm, + eps=1e-6, + pre_only=True, + ) + + self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm") + self.proj_mlp = nn.Linear(hidden_size, mlp_dim) + self.act_mlp = nn.GELU(approximate="tanh") + self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.shape[1] + hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) + + residual = hidden_states + + # 1. Input normalization + norm_hidden_states, gate = self.norm(hidden_states, emb=temb) + mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) + + norm_hidden_states, norm_encoder_hidden_states = ( + norm_hidden_states[:, :-text_seq_length, :], + norm_hidden_states[:, -text_seq_length:, :], + ) + + # 2. Attention + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=image_rotary_emb, + ) + attn_output = torch.cat([attn_output, context_attn_output], dim=1) + + # 3. 
Modulation and residual connection + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) + hidden_states = gate * self.proj_out(hidden_states) + hidden_states = hidden_states + residual + + hidden_states, encoder_hidden_states = ( + hidden_states[:, :-text_seq_length, :], + hidden_states[:, -text_seq_length:, :], + ) + return hidden_states, encoder_hidden_states + + +class HunyuanVideoTransformerBlock(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float, + qk_norm: str = "rms_norm", + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + + self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm") + self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm") + + self.attn = Attention( + query_dim=hidden_size, + cross_attention_dim=None, + added_kv_proj_dim=hidden_size, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=hidden_size, + context_pre_only=False, + bias=True, + processor=HunyuanAttnProcessorFlashAttnDouble(), + qk_norm=qk_norm, + eps=1e-6, + ) + + self.norm2 = LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate") + + self.norm2_context = LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # 1. Input normalization + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(encoder_hidden_states, emb=temb) + + # 2. Joint attention + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=freqs_cis, + ) + + # 3. Modulation and residual connection + hidden_states = hidden_states + attn_output * gate_msa + encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa + + norm_hidden_states = self.norm2(hidden_states) + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp + + # 4. 
Feed-forward + ff_output = self.ff(norm_hidden_states) + context_ff_output = self.ff_context(norm_encoder_hidden_states) + + hidden_states = hidden_states + gate_mlp * ff_output + encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output + + return hidden_states, encoder_hidden_states + + +class ClipVisionProjection(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + self.up = nn.Linear(in_channels, out_channels * 3) + self.down = nn.Linear(out_channels * 3, out_channels) + + def forward(self, x): + projected_x = self.down(nn.functional.silu(self.up(x))) + return projected_x + + +class HunyuanVideoPatchEmbed(nn.Module): + def __init__(self, patch_size, in_chans, embed_dim): + super().__init__() + self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + +class HunyuanVideoPatchEmbedForCleanLatents(nn.Module): + def __init__(self, inner_dim): + super().__init__() + self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)) + self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4)) + self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) + + @torch.no_grad() + def initialize_weight_from_another_conv3d(self, another_layer): + weight = another_layer.weight.detach().clone() + bias = another_layer.bias.detach().clone() + + sd = { + 'proj.weight': weight.clone(), + 'proj.bias': bias.clone(), + 'proj_2x.weight': einops.repeat(weight, 'b c t h w -> b c (t tk) (h hk) (w wk)', tk=2, hk=2, wk=2) / 8.0, + 'proj_2x.bias': bias.clone(), + 'proj_4x.weight': einops.repeat(weight, 'b c t h w -> b c (t tk) (h hk) (w wk)', tk=4, hk=4, wk=4) / 64.0, + 'proj_4x.bias': bias.clone(), + } + + sd = {k: v.clone() for k, v in sd.items()} + + self.load_state_dict(sd) + return + + +class HunyuanVideoTransformer3DModelPacked(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): + @register_to_config + def __init__( + self, + in_channels: int = 16, + out_channels: int = 16, + num_attention_heads: int = 24, + attention_head_dim: int = 128, + num_layers: int = 20, + num_single_layers: int = 40, + num_refiner_layers: int = 2, + mlp_ratio: float = 4.0, + patch_size: int = 2, + patch_size_t: int = 1, + qk_norm: str = "rms_norm", + guidance_embeds: bool = True, + text_embed_dim: int = 4096, + pooled_projection_dim: int = 768, + rope_theta: float = 256.0, + rope_axes_dim: Tuple[int] = (16, 56, 56), + has_image_proj=False, + image_proj_dim=1152, + has_clean_x_embedder=False, + ) -> None: + super().__init__() + + inner_dim = num_attention_heads * attention_head_dim + out_channels = out_channels or in_channels + + # 1. Latent and condition embedders + self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim) + self.context_embedder = HunyuanVideoTokenRefiner( + text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers + ) + self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim) + + self.clean_x_embedder = None + self.image_projection = None + + # 2. RoPE + self.rope = HunyuanVideoRotaryPosEmbed(rope_axes_dim, rope_theta) + + # 3. Dual stream transformer blocks + self.transformer_blocks = nn.ModuleList( + [ + HunyuanVideoTransformerBlock( + num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm + ) + for _ in range(num_layers) + ] + ) + + # 4. 
Single stream transformer blocks + self.single_transformer_blocks = nn.ModuleList( + [ + HunyuanVideoSingleTransformerBlock( + num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm + ) + for _ in range(num_single_layers) + ] + ) + + # 5. Output projection + self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels) + + self.inner_dim = inner_dim + self.use_gradient_checkpointing = False + self.enable_teacache = False + + if has_image_proj: + self.install_image_projection(image_proj_dim) + + if has_clean_x_embedder: + self.install_clean_x_embedder() + + self.high_quality_fp32_output_for_inference = False + + def install_image_projection(self, in_channels): + self.image_projection = ClipVisionProjection(in_channels=in_channels, out_channels=self.inner_dim) + self.config['has_image_proj'] = True + self.config['image_proj_dim'] = in_channels + + def install_clean_x_embedder(self): + self.clean_x_embedder = HunyuanVideoPatchEmbedForCleanLatents(self.inner_dim) + self.config['has_clean_x_embedder'] = True + + def enable_gradient_checkpointing(self): + self.use_gradient_checkpointing = True + print('self.use_gradient_checkpointing = True') + + def disable_gradient_checkpointing(self): + self.use_gradient_checkpointing = False + print('self.use_gradient_checkpointing = False') + + def initialize_teacache(self, enable_teacache=True, num_steps=25, rel_l1_thresh=0.15): + self.enable_teacache = enable_teacache + self.cnt = 0 + self.num_steps = num_steps + self.rel_l1_thresh = rel_l1_thresh # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = None + self.previous_residual = None + self.teacache_rescale_func = np.poly1d([7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]) + + def gradient_checkpointing_method(self, block, *args): + if self.use_gradient_checkpointing: + result = torch.utils.checkpoint.checkpoint(block, *args, use_reentrant=False) + else: + result = block(*args) + return result + + def process_input_hidden_states( + self, + latents, latent_indices=None, + clean_latents=None, clean_latent_indices=None, + clean_latents_2x=None, clean_latent_2x_indices=None, + clean_latents_4x=None, clean_latent_4x_indices=None + ): + hidden_states = self.gradient_checkpointing_method(self.x_embedder.proj, latents) + B, C, T, H, W = hidden_states.shape + + if latent_indices is None: + latent_indices = torch.arange(0, T).unsqueeze(0).expand(B, -1) + + hidden_states = hidden_states.flatten(2).transpose(1, 2) + + rope_freqs = self.rope(frame_indices=latent_indices, height=H, width=W, device=hidden_states.device) + rope_freqs = rope_freqs.flatten(2).transpose(1, 2) + + if clean_latents is not None and clean_latent_indices is not None: + clean_latents = clean_latents.to(hidden_states) + clean_latents = self.gradient_checkpointing_method(self.clean_x_embedder.proj, clean_latents) + clean_latents = clean_latents.flatten(2).transpose(1, 2) + + clean_latent_rope_freqs = self.rope(frame_indices=clean_latent_indices, height=H, width=W, device=clean_latents.device) + clean_latent_rope_freqs = clean_latent_rope_freqs.flatten(2).transpose(1, 2) + + hidden_states = torch.cat([clean_latents, hidden_states], dim=1) + rope_freqs = torch.cat([clean_latent_rope_freqs, rope_freqs], dim=1) + + if clean_latents_2x is not None and clean_latent_2x_indices is not None: + 
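            # The 2x and 4x clean latents are additional context frames embedded with
            # larger 3D patch kernels (proj_2x / proj_4x), so longer history fits into
            # fewer tokens; their RoPE frequencies are padded and average-pooled to the
            # coarser token grid, then both tokens and frequencies are concatenated in
            # front of the current sequence, mirroring the full-resolution clean latents
            # handled above.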
clean_latents_2x = clean_latents_2x.to(hidden_states) + clean_latents_2x = pad_for_3d_conv(clean_latents_2x, (2, 4, 4)) + clean_latents_2x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_2x, clean_latents_2x) + clean_latents_2x = clean_latents_2x.flatten(2).transpose(1, 2) + + clean_latent_2x_rope_freqs = self.rope(frame_indices=clean_latent_2x_indices, height=H, width=W, device=clean_latents_2x.device) + clean_latent_2x_rope_freqs = pad_for_3d_conv(clean_latent_2x_rope_freqs, (2, 2, 2)) + clean_latent_2x_rope_freqs = center_down_sample_3d(clean_latent_2x_rope_freqs, (2, 2, 2)) + clean_latent_2x_rope_freqs = clean_latent_2x_rope_freqs.flatten(2).transpose(1, 2) + + hidden_states = torch.cat([clean_latents_2x, hidden_states], dim=1) + rope_freqs = torch.cat([clean_latent_2x_rope_freqs, rope_freqs], dim=1) + + if clean_latents_4x is not None and clean_latent_4x_indices is not None: + clean_latents_4x = clean_latents_4x.to(hidden_states) + clean_latents_4x = pad_for_3d_conv(clean_latents_4x, (4, 8, 8)) + clean_latents_4x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_4x, clean_latents_4x) + clean_latents_4x = clean_latents_4x.flatten(2).transpose(1, 2) + + clean_latent_4x_rope_freqs = self.rope(frame_indices=clean_latent_4x_indices, height=H, width=W, device=clean_latents_4x.device) + clean_latent_4x_rope_freqs = pad_for_3d_conv(clean_latent_4x_rope_freqs, (4, 4, 4)) + clean_latent_4x_rope_freqs = center_down_sample_3d(clean_latent_4x_rope_freqs, (4, 4, 4)) + clean_latent_4x_rope_freqs = clean_latent_4x_rope_freqs.flatten(2).transpose(1, 2) + + hidden_states = torch.cat([clean_latents_4x, hidden_states], dim=1) + rope_freqs = torch.cat([clean_latent_4x_rope_freqs, rope_freqs], dim=1) + + return hidden_states, rope_freqs + + def forward( + self, + hidden_states, timestep, encoder_hidden_states, encoder_attention_mask, pooled_projections, guidance, + latent_indices=None, + clean_latents=None, clean_latent_indices=None, + clean_latents_2x=None, clean_latent_2x_indices=None, + clean_latents_4x=None, clean_latent_4x_indices=None, + image_embeddings=None, + attention_kwargs=None, return_dict=True + ): + + if attention_kwargs is None: + attention_kwargs = {} + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p, p_t = self.config['patch_size'], self.config['patch_size_t'] + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p + post_patch_width = width // p + original_context_length = post_patch_num_frames * post_patch_height * post_patch_width + + hidden_states, rope_freqs = self.process_input_hidden_states(hidden_states, latent_indices, clean_latents, clean_latent_indices, clean_latents_2x, clean_latent_2x_indices, clean_latents_4x, clean_latent_4x_indices) + + temb = self.gradient_checkpointing_method(self.time_text_embed, timestep, guidance, pooled_projections) + encoder_hidden_states = self.gradient_checkpointing_method(self.context_embedder, encoder_hidden_states, timestep, encoder_attention_mask) + + if self.image_projection is not None: + assert image_embeddings is not None, 'You must use image embeddings!' 
+ extra_encoder_hidden_states = self.gradient_checkpointing_method(self.image_projection, image_embeddings) + extra_attention_mask = torch.ones((batch_size, extra_encoder_hidden_states.shape[1]), dtype=encoder_attention_mask.dtype, device=encoder_attention_mask.device) + + # must cat before (not after) encoder_hidden_states, due to attn masking + encoder_hidden_states = torch.cat([extra_encoder_hidden_states, encoder_hidden_states], dim=1) + encoder_attention_mask = torch.cat([extra_attention_mask, encoder_attention_mask], dim=1) + + if batch_size == 1: + # When batch size is 1, we do not need any masks or var-len funcs since cropping is mathematically same to what we want + # If they are not same, then their impls are wrong. Ours are always the correct one. + text_len = encoder_attention_mask.sum().item() + encoder_hidden_states = encoder_hidden_states[:, :text_len] + attention_mask = None, None, None, None + else: + img_seq_len = hidden_states.shape[1] + txt_seq_len = encoder_hidden_states.shape[1] + + cu_seqlens_q = get_cu_seqlens(encoder_attention_mask, img_seq_len) + cu_seqlens_kv = cu_seqlens_q + max_seqlen_q = img_seq_len + txt_seq_len + max_seqlen_kv = max_seqlen_q + + attention_mask = cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv + + if self.enable_teacache: + modulated_inp = self.transformer_blocks[0].norm1(hidden_states, emb=temb)[0] + + if self.cnt == 0 or self.cnt == self.num_steps-1: + should_calc = True + self.accumulated_rel_l1_distance = 0 + else: + curr_rel_l1 = ((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item() + self.accumulated_rel_l1_distance += self.teacache_rescale_func(curr_rel_l1) + should_calc = self.accumulated_rel_l1_distance >= self.rel_l1_thresh + + if should_calc: + self.accumulated_rel_l1_distance = 0 + + self.previous_modulated_input = modulated_inp + self.cnt += 1 + + if self.cnt == self.num_steps: + self.cnt = 0 + + if not should_calc: + hidden_states = hidden_states + self.previous_residual + else: + ori_hidden_states = hidden_states.clone() + + for block_id, block in enumerate(self.transformer_blocks): + hidden_states, encoder_hidden_states = self.gradient_checkpointing_method( + block, + hidden_states, + encoder_hidden_states, + temb, + attention_mask, + rope_freqs + ) + + for block_id, block in enumerate(self.single_transformer_blocks): + hidden_states, encoder_hidden_states = self.gradient_checkpointing_method( + block, + hidden_states, + encoder_hidden_states, + temb, + attention_mask, + rope_freqs + ) + + self.previous_residual = hidden_states - ori_hidden_states + else: + for block_id, block in enumerate(self.transformer_blocks): + hidden_states, encoder_hidden_states = self.gradient_checkpointing_method( + block, + hidden_states, + encoder_hidden_states, + temb, + attention_mask, + rope_freqs + ) + + for block_id, block in enumerate(self.single_transformer_blocks): + hidden_states, encoder_hidden_states = self.gradient_checkpointing_method( + block, + hidden_states, + encoder_hidden_states, + temb, + attention_mask, + rope_freqs + ) + + hidden_states = self.gradient_checkpointing_method(self.norm_out, hidden_states, temb) + + hidden_states = hidden_states[:, -original_context_length:, :] + + if self.high_quality_fp32_output_for_inference: + hidden_states = hidden_states.to(dtype=torch.float32) + if self.proj_out.weight.dtype != torch.float32: + self.proj_out.to(dtype=torch.float32) + + hidden_states = self.gradient_checkpointing_method(self.proj_out, hidden_states) 
+ + hidden_states = einops.rearrange(hidden_states, 'b (t h w) (c pt ph pw) -> b c (t pt) (h ph) (w pw)', + t=post_patch_num_frames, h=post_patch_height, w=post_patch_width, + pt=p_t, ph=p, pw=p) + + if return_dict: + return Transformer2DModelOutput(sample=hidden_states) + + return hidden_states, diff --git a/diffusers_helper/pipelines/k_diffusion_hunyuan.py b/diffusers_helper/pipelines/k_diffusion_hunyuan.py new file mode 100644 index 0000000000000000000000000000000000000000..d72b44b859c0042af1e227612edd76fa85880548 --- /dev/null +++ b/diffusers_helper/pipelines/k_diffusion_hunyuan.py @@ -0,0 +1,120 @@ +import torch +import math + +from diffusers_helper.k_diffusion.uni_pc_fm import sample_unipc +from diffusers_helper.k_diffusion.wrapper import fm_wrapper +from diffusers_helper.utils import repeat_to_batch_size + + +def flux_time_shift(t, mu=1.15, sigma=1.0): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def calculate_flux_mu(context_length, x1=256, y1=0.5, x2=4096, y2=1.15, exp_max=7.0): + k = (y2 - y1) / (x2 - x1) + b = y1 - k * x1 + mu = k * context_length + b + mu = min(mu, math.log(exp_max)) + return mu + + +def get_flux_sigmas_from_mu(n, mu): + sigmas = torch.linspace(1, 0, steps=n + 1) + sigmas = flux_time_shift(sigmas, mu=mu) + return sigmas + + +@torch.inference_mode() +def sample_hunyuan( + transformer, + sampler='unipc', + initial_latent=None, + concat_latent=None, + strength=1.0, + width=512, + height=512, + frames=16, + real_guidance_scale=1.0, + distilled_guidance_scale=6.0, + guidance_rescale=0.0, + shift=None, + num_inference_steps=25, + batch_size=None, + generator=None, + prompt_embeds=None, + prompt_embeds_mask=None, + prompt_poolers=None, + negative_prompt_embeds=None, + negative_prompt_embeds_mask=None, + negative_prompt_poolers=None, + dtype=torch.bfloat16, + device=None, + negative_kwargs=None, + callback=None, + **kwargs, +): + device = device or transformer.device + + if batch_size is None: + batch_size = int(prompt_embeds.shape[0]) + + latents = torch.randn((batch_size, 16, (frames + 3) // 4, height // 8, width // 8), generator=generator, device=generator.device).to(device=device, dtype=torch.float32) + + B, C, T, H, W = latents.shape + seq_length = T * H * W // 4 + + if shift is None: + mu = calculate_flux_mu(seq_length, exp_max=7.0) + else: + mu = math.log(shift) + + sigmas = get_flux_sigmas_from_mu(num_inference_steps, mu).to(device) + + k_model = fm_wrapper(transformer) + + if initial_latent is not None: + sigmas = sigmas * strength + first_sigma = sigmas[0].to(device=device, dtype=torch.float32) + initial_latent = initial_latent.to(device=device, dtype=torch.float32) + latents = initial_latent.float() * (1.0 - first_sigma) + latents.float() * first_sigma + + if concat_latent is not None: + concat_latent = concat_latent.to(latents) + + distilled_guidance = torch.tensor([distilled_guidance_scale * 1000.0] * batch_size).to(device=device, dtype=dtype) + + prompt_embeds = repeat_to_batch_size(prompt_embeds, batch_size) + prompt_embeds_mask = repeat_to_batch_size(prompt_embeds_mask, batch_size) + prompt_poolers = repeat_to_batch_size(prompt_poolers, batch_size) + negative_prompt_embeds = repeat_to_batch_size(negative_prompt_embeds, batch_size) + negative_prompt_embeds_mask = repeat_to_batch_size(negative_prompt_embeds_mask, batch_size) + negative_prompt_poolers = repeat_to_batch_size(negative_prompt_poolers, batch_size) + concat_latent = repeat_to_batch_size(concat_latent, batch_size) + + sampler_kwargs = dict( + dtype=dtype, + 
cfg_scale=real_guidance_scale, + cfg_rescale=guidance_rescale, + concat_latent=concat_latent, + positive=dict( + pooled_projections=prompt_poolers, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_embeds_mask, + guidance=distilled_guidance, + **kwargs, + ), + negative=dict( + pooled_projections=negative_prompt_poolers, + encoder_hidden_states=negative_prompt_embeds, + encoder_attention_mask=negative_prompt_embeds_mask, + guidance=distilled_guidance, + **(kwargs if negative_kwargs is None else {**kwargs, **negative_kwargs}), + ) + ) + + if sampler == 'unipc': + results = sample_unipc(k_model, latents, sigmas, extra_args=sampler_kwargs, disable=False, callback=callback) + else: + raise NotImplementedError(f'Sampler {sampler} is not supported.') + + return results diff --git a/diffusers_helper/thread_utils.py b/diffusers_helper/thread_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..144fdad6a218b10e77944e927ea350bb84b559a1 --- /dev/null +++ b/diffusers_helper/thread_utils.py @@ -0,0 +1,76 @@ +import time + +from threading import Thread, Lock + + +class Listener: + task_queue = [] + lock = Lock() + thread = None + + @classmethod + def _process_tasks(cls): + while True: + task = None + with cls.lock: + if cls.task_queue: + task = cls.task_queue.pop(0) + + if task is None: + time.sleep(0.001) + continue + + func, args, kwargs = task + try: + func(*args, **kwargs) + except Exception as e: + print(f"Error in listener thread: {e}") + + @classmethod + def add_task(cls, func, *args, **kwargs): + with cls.lock: + cls.task_queue.append((func, args, kwargs)) + + if cls.thread is None: + cls.thread = Thread(target=cls._process_tasks, daemon=True) + cls.thread.start() + + +def async_run(func, *args, **kwargs): + Listener.add_task(func, *args, **kwargs) + + +class FIFOQueue: + def __init__(self): + self.queue = [] + self.lock = Lock() + + def push(self, item): + with self.lock: + self.queue.append(item) + + def pop(self): + with self.lock: + if self.queue: + return self.queue.pop(0) + return None + + def top(self): + with self.lock: + if self.queue: + return self.queue[0] + return None + + def next(self): + while True: + with self.lock: + if self.queue: + return self.queue.pop(0) + + time.sleep(0.001) + + +class AsyncStream: + def __init__(self): + self.input_queue = FIFOQueue() + self.output_queue = FIFOQueue() diff --git a/diffusers_helper/utils.py b/diffusers_helper/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fea3a9d8f60747be78b915b918e92ae6f85289e4 --- /dev/null +++ b/diffusers_helper/utils.py @@ -0,0 +1,613 @@ +import os +import cv2 +import json +import random +import glob +import torch +import einops +import numpy as np +import datetime +import torchvision + +import safetensors.torch as sf +from PIL import Image + + +def min_resize(x, m): + if x.shape[0] < x.shape[1]: + s0 = m + s1 = int(float(m) / float(x.shape[0]) * float(x.shape[1])) + else: + s0 = int(float(m) / float(x.shape[1]) * float(x.shape[0])) + s1 = m + new_max = max(s1, s0) + raw_max = max(x.shape[0], x.shape[1]) + if new_max < raw_max: + interpolation = cv2.INTER_AREA + else: + interpolation = cv2.INTER_LANCZOS4 + y = cv2.resize(x, (s1, s0), interpolation=interpolation) + return y + + +def d_resize(x, y): + H, W, C = y.shape + new_min = min(H, W) + raw_min = min(x.shape[0], x.shape[1]) + if new_min < raw_min: + interpolation = cv2.INTER_AREA + else: + interpolation = cv2.INTER_LANCZOS4 + y = cv2.resize(x, (W, H), interpolation=interpolation) + return y 
+ + +def resize_and_center_crop(image, target_width, target_height): + if target_height == image.shape[0] and target_width == image.shape[1]: + return image + + pil_image = Image.fromarray(image) + original_width, original_height = pil_image.size + scale_factor = max(target_width / original_width, target_height / original_height) + resized_width = int(round(original_width * scale_factor)) + resized_height = int(round(original_height * scale_factor)) + resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS) + left = (resized_width - target_width) / 2 + top = (resized_height - target_height) / 2 + right = (resized_width + target_width) / 2 + bottom = (resized_height + target_height) / 2 + cropped_image = resized_image.crop((left, top, right, bottom)) + return np.array(cropped_image) + + +def resize_and_center_crop_pytorch(image, target_width, target_height): + B, C, H, W = image.shape + + if H == target_height and W == target_width: + return image + + scale_factor = max(target_width / W, target_height / H) + resized_width = int(round(W * scale_factor)) + resized_height = int(round(H * scale_factor)) + + resized = torch.nn.functional.interpolate(image, size=(resized_height, resized_width), mode='bilinear', align_corners=False) + + top = (resized_height - target_height) // 2 + left = (resized_width - target_width) // 2 + cropped = resized[:, :, top:top + target_height, left:left + target_width] + + return cropped + + +def resize_without_crop(image, target_width, target_height): + if target_height == image.shape[0] and target_width == image.shape[1]: + return image + + pil_image = Image.fromarray(image) + resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS) + return np.array(resized_image) + + +def just_crop(image, w, h): + if h == image.shape[0] and w == image.shape[1]: + return image + + original_height, original_width = image.shape[:2] + k = min(original_height / h, original_width / w) + new_width = int(round(w * k)) + new_height = int(round(h * k)) + x_start = (original_width - new_width) // 2 + y_start = (original_height - new_height) // 2 + cropped_image = image[y_start:y_start + new_height, x_start:x_start + new_width] + return cropped_image + + +def write_to_json(data, file_path): + temp_file_path = file_path + ".tmp" + with open(temp_file_path, 'wt', encoding='utf-8') as temp_file: + json.dump(data, temp_file, indent=4) + os.replace(temp_file_path, file_path) + return + + +def read_from_json(file_path): + with open(file_path, 'rt', encoding='utf-8') as file: + data = json.load(file) + return data + + +def get_active_parameters(m): + return {k: v for k, v in m.named_parameters() if v.requires_grad} + + +def cast_training_params(m, dtype=torch.float32): + result = {} + for n, param in m.named_parameters(): + if param.requires_grad: + param.data = param.to(dtype) + result[n] = param + return result + + +def separate_lora_AB(parameters, B_patterns=None): + parameters_normal = {} + parameters_B = {} + + if B_patterns is None: + B_patterns = ['.lora_B.', '__zero__'] + + for k, v in parameters.items(): + if any(B_pattern in k for B_pattern in B_patterns): + parameters_B[k] = v + else: + parameters_normal[k] = v + + return parameters_normal, parameters_B + + +def set_attr_recursive(obj, attr, value): + attrs = attr.split(".") + for name in attrs[:-1]: + obj = getattr(obj, name) + setattr(obj, attrs[-1], value) + return + + +def print_tensor_list_size(tensors): + total_size = 0 + total_elements = 0 + + if isinstance(tensors, dict): + tensors = 
tensors.values() + + for tensor in tensors: + total_size += tensor.nelement() * tensor.element_size() + total_elements += tensor.nelement() + + total_size_MB = total_size / (1024 ** 2) + total_elements_B = total_elements / 1e9 + + print(f"Total number of tensors: {len(tensors)}") + print(f"Total size of tensors: {total_size_MB:.2f} MB") + print(f"Total number of parameters: {total_elements_B:.3f} billion") + return + + +@torch.no_grad() +def batch_mixture(a, b=None, probability_a=0.5, mask_a=None): + batch_size = a.size(0) + + if b is None: + b = torch.zeros_like(a) + + if mask_a is None: + mask_a = torch.rand(batch_size) < probability_a + + mask_a = mask_a.to(a.device) + mask_a = mask_a.reshape((batch_size,) + (1,) * (a.dim() - 1)) + result = torch.where(mask_a, a, b) + return result + + +@torch.no_grad() +def zero_module(module): + for p in module.parameters(): + p.detach().zero_() + return module + + +@torch.no_grad() +def supress_lower_channels(m, k, alpha=0.01): + data = m.weight.data.clone() + + assert int(data.shape[1]) >= k + + data[:, :k] = data[:, :k] * alpha + m.weight.data = data.contiguous().clone() + return m + + +def freeze_module(m): + if not hasattr(m, '_forward_inside_frozen_module'): + m._forward_inside_frozen_module = m.forward + m.requires_grad_(False) + m.forward = torch.no_grad()(m.forward) + return m + + +def get_latest_safetensors(folder_path): + safetensors_files = glob.glob(os.path.join(folder_path, '*.safetensors')) + + if not safetensors_files: + raise ValueError('No file to resume!') + + latest_file = max(safetensors_files, key=os.path.getmtime) + latest_file = os.path.abspath(os.path.realpath(latest_file)) + return latest_file + + +def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=32): + tags = tags_str.split(', ') + tags = random.sample(tags, k=min(random.randint(min_length, max_length), len(tags))) + prompt = ', '.join(tags) + return prompt + + +def interpolate_numbers(a, b, n, round_to_int=False, gamma=1.0): + numbers = a + (b - a) * (np.linspace(0, 1, n) ** gamma) + if round_to_int: + numbers = np.round(numbers).astype(int) + return numbers.tolist() + + +def uniform_random_by_intervals(inclusive, exclusive, n, round_to_int=False): + edges = np.linspace(0, 1, n + 1) + points = np.random.uniform(edges[:-1], edges[1:]) + numbers = inclusive + (exclusive - inclusive) * points + if round_to_int: + numbers = np.round(numbers).astype(int) + return numbers.tolist() + + +def soft_append_bcthw(history, current, overlap=0): + if overlap <= 0: + return torch.cat([history, current], dim=2) + + assert history.shape[2] >= overlap, f"History length ({history.shape[2]}) must be >= overlap ({overlap})" + assert current.shape[2] >= overlap, f"Current length ({current.shape[2]}) must be >= overlap ({overlap})" + + weights = torch.linspace(1, 0, overlap, dtype=history.dtype, device=history.device).view(1, 1, -1, 1, 1) + blended = weights * history[:, :, -overlap:] + (1 - weights) * current[:, :, :overlap] + output = torch.cat([history[:, :, :-overlap], blended, current[:, :, overlap:]], dim=2) + + return output.to(history) + + +def save_bcthw_as_mp4(x, output_filename, fps=10, crf=0): + b, c, t, h, w = x.shape + + per_row = b + for p in [6, 5, 4, 3, 2]: + if b % p == 0: + per_row = p + break + + os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True) + x = torch.clamp(x.float(), -1., 1.) 
* 127.5 + 127.5 + x = x.detach().cpu().to(torch.uint8) + x = einops.rearrange(x, '(m n) c t h w -> t (m h) (n w) c', n=per_row) + torchvision.io.write_video(output_filename, x, fps=int(fps), video_codec='libx264', options={'crf': str(int(crf))}) + return x + + +def save_bcthw_as_png(x, output_filename): + os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True) + x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5 + x = x.detach().cpu().to(torch.uint8) + x = einops.rearrange(x, 'b c t h w -> c (b h) (t w)') + torchvision.io.write_png(x, output_filename) + return output_filename + + +def save_bchw_as_png(x, output_filename): + os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True) + x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5 + x = x.detach().cpu().to(torch.uint8) + x = einops.rearrange(x, 'b c h w -> c h (b w)') + torchvision.io.write_png(x, output_filename) + return output_filename + + +def add_tensors_with_padding(tensor1, tensor2): + if tensor1.shape == tensor2.shape: + return tensor1 + tensor2 + + shape1 = tensor1.shape + shape2 = tensor2.shape + + new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2)) + + padded_tensor1 = torch.zeros(new_shape) + padded_tensor2 = torch.zeros(new_shape) + + padded_tensor1[tuple(slice(0, s) for s in shape1)] = tensor1 + padded_tensor2[tuple(slice(0, s) for s in shape2)] = tensor2 + + result = padded_tensor1 + padded_tensor2 + return result + + +def print_free_mem(): + torch.cuda.empty_cache() + free_mem, total_mem = torch.cuda.mem_get_info(0) + free_mem_mb = free_mem / (1024 ** 2) + total_mem_mb = total_mem / (1024 ** 2) + print(f"Free memory: {free_mem_mb:.2f} MB") + print(f"Total memory: {total_mem_mb:.2f} MB") + return + + +def print_gpu_parameters(device, state_dict, log_count=1): + summary = {"device": device, "keys_count": len(state_dict)} + + logged_params = {} + for i, (key, tensor) in enumerate(state_dict.items()): + if i >= log_count: + break + logged_params[key] = tensor.flatten()[:3].tolist() + + summary["params"] = logged_params + + print(str(summary)) + return + + +def visualize_txt_as_img(width, height, text, font_path='font/DejaVuSans.ttf', size=18): + from PIL import Image, ImageDraw, ImageFont + + txt = Image.new("RGB", (width, height), color="white") + draw = ImageDraw.Draw(txt) + font = ImageFont.truetype(font_path, size=size) + + if text == '': + return np.array(txt) + + # Split text into lines that fit within the image width + lines = [] + words = text.split() + current_line = words[0] + + for word in words[1:]: + line_with_word = f"{current_line} {word}" + if draw.textbbox((0, 0), line_with_word, font=font)[2] <= width: + current_line = line_with_word + else: + lines.append(current_line) + current_line = word + + lines.append(current_line) + + # Draw the text line by line + y = 0 + line_height = draw.textbbox((0, 0), "A", font=font)[3] + + for line in lines: + if y + line_height > height: + break # stop drawing if the next line will be outside the image + draw.text((0, y), line, fill="black", font=font) + y += line_height + + return np.array(txt) + + +def blue_mark(x): + x = x.copy() + c = x[:, :, 2] + b = cv2.blur(c, (9, 9)) + x[:, :, 2] = ((c - b) * 16.0 + b).clip(-1, 1) + return x + + +def green_mark(x): + x = x.copy() + x[:, :, 2] = -1 + x[:, :, 0] = -1 + return x + + +def frame_mark(x): + x = x.copy() + x[:64] = -1 + x[-64:] = -1 + x[:, :8] = 1 + x[:, -8:] = 1 + return x + + +@torch.inference_mode() +def pytorch2numpy(imgs): + 
results = [] + for x in imgs: + y = x.movedim(0, -1) + y = y * 127.5 + 127.5 + y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8) + results.append(y) + return results + + +@torch.inference_mode() +def numpy2pytorch(imgs): + h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0 + h = h.movedim(-1, 1) + return h + + +@torch.no_grad() +def duplicate_prefix_to_suffix(x, count, zero_out=False): + if zero_out: + return torch.cat([x, torch.zeros_like(x[:count])], dim=0) + else: + return torch.cat([x, x[:count]], dim=0) + + +def weighted_mse(a, b, weight): + return torch.mean(weight.float() * (a.float() - b.float()) ** 2) + + +def clamped_linear_interpolation(x, x_min, y_min, x_max, y_max, sigma=1.0): + x = (x - x_min) / (x_max - x_min) + x = max(0.0, min(x, 1.0)) + x = x ** sigma + return y_min + x * (y_max - y_min) + + +def expand_to_dims(x, target_dims): + return x.view(*x.shape, *([1] * max(0, target_dims - x.dim()))) + + +def repeat_to_batch_size(tensor: torch.Tensor, batch_size: int): + if tensor is None: + return None + + first_dim = tensor.shape[0] + + if first_dim == batch_size: + return tensor + + if batch_size % first_dim != 0: + raise ValueError(f"Cannot evenly repeat first dim {first_dim} to match batch_size {batch_size}.") + + repeat_times = batch_size // first_dim + + return tensor.repeat(repeat_times, *[1] * (tensor.dim() - 1)) + + +def dim5(x): + return expand_to_dims(x, 5) + + +def dim4(x): + return expand_to_dims(x, 4) + + +def dim3(x): + return expand_to_dims(x, 3) + + +def crop_or_pad_yield_mask(x, length): + B, F, C = x.shape + device = x.device + dtype = x.dtype + + if F < length: + y = torch.zeros((B, length, C), dtype=dtype, device=device) + mask = torch.zeros((B, length), dtype=torch.bool, device=device) + y[:, :F, :] = x + mask[:, :F] = True + return y, mask + + return x[:, :length, :], torch.ones((B, length), dtype=torch.bool, device=device) + + +def extend_dim(x, dim, minimal_length, zero_pad=False): + original_length = int(x.shape[dim]) + + if original_length >= minimal_length: + return x + + if zero_pad: + padding_shape = list(x.shape) + padding_shape[dim] = minimal_length - original_length + padding = torch.zeros(padding_shape, dtype=x.dtype, device=x.device) + else: + idx = (slice(None),) * dim + (slice(-1, None),) + (slice(None),) * (len(x.shape) - dim - 1) + last_element = x[idx] + padding = last_element.repeat_interleave(minimal_length - original_length, dim=dim) + + return torch.cat([x, padding], dim=dim) + + +def lazy_positional_encoding(t, repeats=None): + if not isinstance(t, list): + t = [t] + + from diffusers.models.embeddings import get_timestep_embedding + + te = torch.tensor(t) + te = get_timestep_embedding(timesteps=te, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=1.0) + + if repeats is None: + return te + + te = te[:, None, :].expand(-1, repeats, -1) + + return te + + +def state_dict_offset_merge(A, B, C=None): + result = {} + keys = A.keys() + + for key in keys: + A_value = A[key] + B_value = B[key].to(A_value) + + if C is None: + result[key] = A_value + B_value + else: + C_value = C[key].to(A_value) + result[key] = A_value + B_value - C_value + + return result + + +def state_dict_weighted_merge(state_dicts, weights): + if len(state_dicts) != len(weights): + raise ValueError("Number of state dictionaries must match number of weights") + + if not state_dicts: + return {} + + total_weight = sum(weights) + + if total_weight == 0: + raise ValueError("Sum of weights cannot be zero") + + 
normalized_weights = [w / total_weight for w in weights] + + keys = state_dicts[0].keys() + result = {} + + for key in keys: + result[key] = state_dicts[0][key] * normalized_weights[0] + + for i in range(1, len(state_dicts)): + state_dict_value = state_dicts[i][key].to(result[key]) + result[key] += state_dict_value * normalized_weights[i] + + return result + + +def group_files_by_folder(all_files): + grouped_files = {} + + for file in all_files: + folder_name = os.path.basename(os.path.dirname(file)) + if folder_name not in grouped_files: + grouped_files[folder_name] = [] + grouped_files[folder_name].append(file) + + list_of_lists = list(grouped_files.values()) + return list_of_lists + + +def generate_timestamp(): + now = datetime.datetime.now() + timestamp = now.strftime('%y%m%d_%H%M%S') + milliseconds = f"{int(now.microsecond / 1000):03d}" + random_number = random.randint(0, 9999) + return f"{timestamp}_{milliseconds}_{random_number}" + + +def write_PIL_image_with_png_info(image, metadata, path): + from PIL.PngImagePlugin import PngInfo + + png_info = PngInfo() + for key, value in metadata.items(): + png_info.add_text(key, value) + + image.save(path, "PNG", pnginfo=png_info) + return image + + +def torch_safe_save(content, path): + torch.save(content, path + '_tmp') + os.replace(path + '_tmp', path) + return path + + +def move_optimizer_to_device(optimizer, device): + for state in optimizer.state.values(): + for k, v in state.items(): + if isinstance(v, torch.Tensor): + state[k] = v.to(device) diff --git a/docs/advanced_config.md b/docs/advanced_config.md new file mode 100644 index 0000000000000000000000000000000000000000..467a75d0b40c061b9fc1b61ca19f0065aa44d4b0 --- /dev/null +++ b/docs/advanced_config.md @@ -0,0 +1,316 @@ +> 📝 Click on the language section to expand / 言語をクリックして展開 + +# Advanced configuration / 高度な設定 + +## Table of contents / 目次 + +- [How to specify `network_args`](#how-to-specify-network_args--network_argsの指定方法) +- [LoRA+](#lora) +- [Select the target modules of LoRA](#select-the-target-modules-of-lora--loraの対象モジュールを選択する) +- [Save and view logs in TensorBoard format](#save-and-view-logs-in-tensorboard-format--tensorboard形式のログの保存と参照) +- [Save and view logs in wandb](#save-and-view-logs-in-wandb--wandbでログの保存と参照) +- [FP8 weight optimization for models](#fp8-weight-optimization-for-models--モデルの重みのfp8への最適化) +- [PyTorch Dynamo optimization for model training](#pytorch-dynamo-optimization-for-model-training--モデルの学習におけるpytorch-dynamoの最適化) + +## How to specify `network_args` / `network_args`の指定方法 + +The `--network_args` option is an option for specifying detailed arguments to LoRA. Specify the arguments in the form of `key=value` in `--network_args`. + +
+日本語 +`--network_args`オプションは、LoRAへの詳細な引数を指定するためのオプションです。`--network_args`には、`key=value`の形式で引数を指定します。 +
+ +### Example / 記述例 + +If you specify it on the command line, write as follows. / コマンドラインで指定する場合は以下のように記述します。 + +```bash +accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16 hv_train_network.py --dit ... + --network_module networks.lora --network_dim 32 + --network_args "key1=value1" "key2=value2" ... +``` + +If you specify it in the configuration file, write as follows. / 設定ファイルで指定する場合は以下のように記述します。 + +```toml +network_args = ["key1=value1", "key2=value2", ...] +``` + +If you specify `"verbose=True"`, detailed information of LoRA will be displayed. / `"verbose=True"`を指定するとLoRAの詳細な情報が表示されます。 + +```bash +--network_args "verbose=True" "key1=value1" "key2=value2" ... +``` + +## LoRA+ + +LoRA+ is a method to improve the training speed by increasing the learning rate of the UP side (LoRA-B) of LoRA. Specify the multiplier for the learning rate. The original paper recommends 16, but adjust as needed. It seems to be good to start from around 4. For details, please refer to the [related PR of sd-scripts](https://github.com/kohya-ss/sd-scripts/pull/1233). + +Specify `loraplus_lr_ratio` with `--network_args`. + +
+日本語 + +LoRA+は、LoRAのUP側(LoRA-B)の学習率を上げることで学習速度を向上させる手法です。学習率に対する倍率を指定します。元論文では16を推奨していますが、必要に応じて調整してください。4程度から始めるとよいようです。詳細は[sd-scriptsの関連PR](https://github.com/kohya-ss/sd-scripts/pull/1233)を参照してください。 + +`--network_args`で`loraplus_lr_ratio`を指定します。 +
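+To make the effect of `loraplus_lr_ratio` concrete, here is a minimal sketch of the idea in plain PyTorch: the LoRA "B" (up) parameters go into their own optimizer parameter group with a boosted learning rate. This is only an illustration, not the optimizer setup used by this repository; the `.lora_B.` name check is an assumption borrowed from the `separate_lora_AB` helper added in this PR, and real parameter names depend on the network module.
+
+```python
+import torch
+
+def loraplus_param_groups(network, base_lr, loraplus_lr_ratio=4.0):
+    # LoRA "A" (down) weights keep the base learning rate,
+    # LoRA "B" (up) weights get base_lr * loraplus_lr_ratio.
+    group_a, group_b = [], []
+    for name, param in network.named_parameters():
+        if not param.requires_grad:
+            continue
+        (group_b if ".lora_B." in name else group_a).append(param)
+    return [
+        {"params": group_a, "lr": base_lr},
+        {"params": group_b, "lr": base_lr * loraplus_lr_ratio},
+    ]
+
+# e.g. optimizer = torch.optim.AdamW(loraplus_param_groups(network, base_lr=2e-4, loraplus_lr_ratio=4.0))
+```
+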
+ +### Example / 記述例 + +```bash +accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16 hv_train_network.py --dit ... + --network_module networks.lora --network_dim 32 --network_args "loraplus_lr_ratio=4" ... +``` + +## Select the target modules of LoRA / LoRAの対象モジュールを選択する + +*This feature is highly experimental and the specification may change. / この機能は特に実験的なもので、仕様は変更される可能性があります。* + +By specifying `exclude_patterns` and `include_patterns` with `--network_args`, you can select the target modules of LoRA. + +`exclude_patterns` excludes modules that match the specified pattern. `include_patterns` targets only modules that match the specified pattern. + +Specify the values as a list. For example, `"exclude_patterns=[r'.*single_blocks.*', r'.*double_blocks\.[0-9]\..*']"`. + +The pattern is a regular expression for the module name. The module name is in the form of `double_blocks.0.img_mod.linear` or `single_blocks.39.modulation.linear`. The regular expression is not a partial match but a complete match. + +The patterns are applied in the order of `exclude_patterns`→`include_patterns`. By default, the Linear layers of `img_mod`, `txt_mod`, and `modulation` of double blocks and single blocks are excluded. + +(`.*(img_mod|txt_mod|modulation).*` is specified.) + +
+日本語 + +`--network_args`で`exclude_patterns`と`include_patterns`を指定することで、LoRAの対象モジュールを選択することができます。 + +`exclude_patterns`は、指定したパターンに一致するモジュールを除外します。`include_patterns`は、指定したパターンに一致するモジュールのみを対象とします。 + +値は、リストで指定します。`"exclude_patterns=[r'.*single_blocks.*', r'.*double_blocks\.[0-9]\..*']"`のようになります。 + +パターンは、モジュール名に対する正規表現です。モジュール名は、たとえば`double_blocks.0.img_mod.linear`や`single_blocks.39.modulation.linear`のような形式です。正規表現は部分一致ではなく完全一致です。 + +パターンは、`exclude_patterns`→`include_patterns`の順で適用されます。デフォルトは、double blocksとsingle blocksのLinear層のうち、`img_mod`、`txt_mod`、`modulation`が除外されています。 + +(`.*(img_mod|txt_mod|modulation).*`が指定されています。) +
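+Because the full-match rule is easy to trip over, here is a small, self-contained check showing how a pattern behaves against module names like the ones quoted above. It only illustrates `re.fullmatch` semantics and the default exclusion pattern; the actual selection logic lives in the network module and applies `exclude_patterns` first, then `include_patterns`.
+
+```python
+import re
+
+name = "single_blocks.39.modulation.linear"
+
+# Patterns must cover the whole module name (full match, not partial match).
+print(bool(re.fullmatch(r".*single_blocks.*", name)))                 # True
+print(bool(re.fullmatch(r"single_blocks", name)))                     # False: no leading/trailing .*
+
+# The default exclusion mentioned above:
+print(bool(re.fullmatch(r".*(img_mod|txt_mod|modulation).*", name)))  # True -> excluded by default
+print(bool(re.fullmatch(r".*(img_mod|txt_mod|modulation).*", "double_blocks.0.img_attn_qkv")))  # False
+```
+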
+ +### Example / 記述例 + +Only the modules of double blocks / double blocksのモジュールのみを対象とする場合: + +```bash +--network_args "exclude_patterns=[r'.*single_blocks.*']" +``` + +Only the modules of single blocks from the 10th / single blocksの10番目以降のLinearモジュールのみを対象とする場合: + +```bash +--network_args "exclude_patterns=[r'.*']" "include_patterns=[r'.*single_blocks\.\d{2}\.linear.*']" +``` + +## Save and view logs in TensorBoard format / TensorBoard形式のログの保存と参照 + +Specify the folder to save the logs with the `--logging_dir` option. Logs in TensorBoard format will be saved. + +For example, if you specify `--logging_dir=logs`, a `logs` folder will be created in the working folder, and logs will be saved in the date folder inside it. + +Also, if you specify the `--log_prefix` option, the specified string will be added before the date. For example, use `--logging_dir=logs --log_prefix=lora_setting1_` for identification. + +To view logs in TensorBoard, open another command prompt and activate the virtual environment. Then enter the following in the working folder. + +```powershell +tensorboard --logdir=logs +``` + +(tensorboard installation is required.) + +Then open a browser and access http://localhost:6006/ to display it. + +
+日本語 +`--logging_dir`オプションにログ保存先フォルダを指定してください。TensorBoard形式のログが保存されます。 + +たとえば`--logging_dir=logs`と指定すると、作業フォルダにlogsフォルダが作成され、その中の日時フォルダにログが保存されます。 + +また`--log_prefix`オプションを指定すると、日時の前に指定した文字列が追加されます。`--logging_dir=logs --log_prefix=lora_setting1_`などとして識別用にお使いください。 + +TensorBoardでログを確認するには、別のコマンドプロンプトを開き、仮想環境を有効にしてから、作業フォルダで以下のように入力します。 + +```powershell +tensorboard --logdir=logs +``` + +(tensorboardのインストールが必要です。) + +その後ブラウザを開き、http://localhost:6006/ へアクセスすると表示されます。 +
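+If you prefer to inspect the logged scalars programmatically rather than through the TensorBoard UI, the `tensorboard` package required above also exposes an event reader. The run directory below is hypothetical (use whichever date-named folder appears under your `--logging_dir`), and the available tag names depend on the training script.
+
+```python
+from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
+
+ea = EventAccumulator("logs/your_run_directory")  # hypothetical: one of the date folders under --logging_dir
+ea.Reload()
+print(ea.Tags()["scalars"])          # list the scalar tags that were logged
+for event in ea.Scalars(ea.Tags()["scalars"][0]):
+    print(event.step, event.value)
+```
+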
+ +## Save and view logs in wandb / wandbでログの保存と参照 + +`--log_with wandb` option is available to save logs in wandb format. `tensorboard` or `all` is also available. The default is `tensorboard`. + +Specify the project name with `--log_tracker_name` when using wandb. + +
+日本語 +`--log_with wandb`オプションを指定するとwandb形式でログを保存することができます。`tensorboard`や`all`も指定可能です。デフォルトは`tensorboard`です。 + +wandbを使用する場合は、`--log_tracker_name`でプロジェクト名を指定してください。 +
+ +## FP8 weight optimization for models / モデルの重みのFP8への最適化 + +The `--fp8_scaled` option is available to quantize the weights of the model to FP8 (E4M3) format with appropriate scaling. This reduces the VRAM usage while maintaining precision. Important weights are kept in FP16/BF16/FP32 format. + +The model weights must be in fp16 or bf16. Weights that have been pre-converted to float8_e4m3 cannot be used. + +Wan2.1 inference and training are supported. + +Specify the `--fp8_scaled` option in addition to the `--fp8` option during inference. + +Specify the `--fp8_scaled` option in addition to the `--fp8_base` option during training. + +Acknowledgments: This feature is based on the [implementation](https://github.com/Tencent/HunyuanVideo/blob/7df4a45c7e424a3f6cd7d653a7ff1f60cddc1eb1/hyvideo/modules/fp8_optimization.py) of [HunyuanVideo](https://github.com/Tencent/HunyuanVideo). The selection of high-precision modules is based on the [implementation](https://github.com/tdrussell/diffusion-pipe/blob/407c04fdae1c9ab5e67b54d33bef62c3e0a8dbc7/models/wan.py) of [diffusion-pipe](https://github.com/tdrussell/diffusion-pipe). I would like to thank these repositories. + +
+日本語 +重みを単純にFP8へcastするのではなく、適切なスケーリングでFP8形式に量子化することで、精度を維持しつつVRAM使用量を削減します。また、重要な重みはFP16/BF16/FP32形式で保持します。 + +モデルの重みは、fp16またはbf16が必要です。あらかじめfloat8_e4m3に変換された重みは使用できません。 + +Wan2.1の推論、学習のみ対応しています。 + +推論時は`--fp8`オプションに加えて `--fp8_scaled`オプションを指定してください。 + +学習時は`--fp8_base`オプションに加えて `--fp8_scaled`オプションを指定してください。 + +謝辞:この機能は、[HunyuanVideo](https://github.com/Tencent/HunyuanVideo)の[実装](https://github.com/Tencent/HunyuanVideo/blob/7df4a45c7e424a3f6cd7d653a7ff1f60cddc1eb1/hyvideo/modules/fp8_optimization.py)を参考にしました。また、高精度モジュールの選択においては[diffusion-pipe](https://github.com/tdrussell/diffusion-pipe)の[実装](https://github.com/tdrussell/diffusion-pipe/blob/407c04fdae1c9ab5e67b54d33bef62c3e0a8dbc7/models/wan.py)を参考にしました。これらのリポジトリに感謝します。 + +
+ +### Key features and implementation details / 主な特徴と実装の詳細 + +- Implements FP8 (E4M3) weight quantization for Linear layers +- Reduces VRAM requirements by storing weights in 8 bits (usage is slightly higher than with the existing `--fp8` / `--fp8_base` options) +- Quantizes weights to FP8 format with appropriate scaling instead of a simple cast to FP8 +- Maintains computational precision by dequantizing to the original precision (FP16/BF16/FP32) during the forward pass +- Preserves important weights in FP16/BF16/FP32 format + +The implementation: + +1. Quantizes weights to FP8 format with appropriate scaling +2. Replaces the original weights with the FP8-quantized weights and stores the scale factors in the model state dict +3. Applies monkey patching to Linear layers for transparent dequantization during computation +
+日本語 + +- Linear層のFP8(E4M3)重み量子化を実装 +- 8ビットの重みを使用することでVRAM使用量を削減(既存の`--fp8` `--fp8_base` オプションに比べて微増) +- 単純なFP8へのcastではなく、適切な値でスケールして重みをFP8形式に量子化 +- forward時に元の精度(FP16/BF16/FP32)に逆量子化して計算精度を維持 +- 精度が重要な重みはFP16/BF16/FP32のまま保持 + +実装: + +1. 精度を維持できる適切な倍率で重みをFP8形式に量子化 +2. 重みをFP8量子化重みに置き換え、倍率をモデルのstate dictに保存 +3. Linear層にmonkey patchingすることでモデルを変更せずに逆量子化 +
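+The following is a minimal per-tensor sketch of the scaled-quantization idea described above, written against plain PyTorch (`torch.float8_e4m3fn`, available in recent PyTorch releases). It is not the implementation used by this repository: the real code chooses which modules stay in high precision, stores the scales in the model state dict, and patches the existing Linear layers in place rather than wrapping them.
+
+```python
+import torch
+
+def quantize_weight_fp8_e4m3(weight: torch.Tensor):
+    # Per-tensor scale so the largest magnitude maps near the E4M3 maximum (448).
+    scale = weight.abs().max().clamp(min=1e-12) / 448.0
+    qweight = (weight / scale).to(torch.float8_e4m3fn)
+    return qweight, scale
+
+class ScaledFP8Linear(torch.nn.Module):
+    def __init__(self, linear: torch.nn.Linear):
+        super().__init__()
+        qweight, scale = quantize_weight_fp8_e4m3(linear.weight.data)
+        self.register_buffer("qweight", qweight)  # 8-bit storage
+        self.register_buffer("scale", scale)      # the scale ends up in the state dict
+        self.bias = linear.bias
+
+    def forward(self, x):
+        # Dequantize back to the activation dtype before the matmul.
+        w = self.qweight.to(x.dtype) * self.scale.to(x.dtype)
+        return torch.nn.functional.linear(x, w, self.bias)
+
+# layer = ScaledFP8Linear(torch.nn.Linear(64, 64).to(torch.bfloat16))
+# y = layer(torch.randn(2, 64, dtype=torch.bfloat16))
+```
+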
+ + ## PyTorch Dynamo optimization for model training / モデルの学習におけるPyTorch Dynamoの最適化 + +The PyTorch Dynamo options are now available to optimize the training process. PyTorch Dynamo is a Python-level JIT compiler designed to make unmodified PyTorch programs faster by using TorchInductor, a deep learning compiler. This integration allows for potential speedups in training while maintaining model accuracy. + +[PR #215](https://github.com/kohya-ss/musubi-tuner/pull/215) added this feature. + +Specify the `--dynamo_backend` option to enable Dynamo optimization with one of the available backends from the `DynamoBackend` enum. + +Additional options allow for fine-tuning the Dynamo behavior: +- `--dynamo_mode`: Controls the optimization strategy +- `--dynamo_fullgraph`: Enables fullgraph mode for potentially better optimization +- `--dynamo_dynamic`: Enables dynamic shape handling + +The `--dynamo_dynamic` option has been reported to have many problems based on the validation in PR #215. + +### Available options: + +``` +--dynamo_backend {NO, INDUCTOR, NVFUSER, CUDAGRAPHS, CUDAGRAPHS_FALLBACK, etc.} + Specifies the Dynamo backend to use (default is NO, which disables Dynamo) + +--dynamo_mode {default, reduce-overhead, max-autotune} + Specifies the optimization mode (default is 'default') + - 'default': Standard optimization + - 'reduce-overhead': Focuses on reducing compilation overhead + - 'max-autotune': Performs extensive autotuning for potentially better performance + +--dynamo_fullgraph + Flag to enable fullgraph mode, which attempts to capture and optimize the entire model graph + +--dynamo_dynamic + Flag to enable dynamic shape handling for models with variable input shapes +``` + +### Usage example: + +```bash +python train_video_model.py --dynamo_backend INDUCTOR --dynamo_mode default +``` + +For more aggressive optimization: +```bash +python train_video_model.py --dynamo_backend INDUCTOR --dynamo_mode max-autotune --dynamo_fullgraph +``` + +Note: The best combination of options may depend on your specific model and hardware. Experimentation may be necessary to find the optimal configuration. + +
+日本語 +PyTorch Dynamoオプションが学習プロセスを最適化するために追加されました。PyTorch Dynamoは、TorchInductor(ディープラーニングコンパイラ)を使用して、変更を加えることなくPyTorchプログラムを高速化するためのPythonレベルのJITコンパイラです。この統合により、モデルの精度を維持しながら学習の高速化が期待できます。 + +[PR #215](https://github.com/kohya-ss/musubi-tuner/pull/215) で追加されました。 + +`--dynamo_backend`オプションを指定して、`DynamoBackend`列挙型から利用可能なバックエンドの一つを選択することで、Dynamo最適化を有効にします。 + +追加のオプションにより、Dynamoの動作を微調整できます: +- `--dynamo_mode`:最適化戦略を制御します +- `--dynamo_fullgraph`:より良い最適化の可能性のためにフルグラフモードを有効にします +- `--dynamo_dynamic`:動的形状処理を有効にします + +PR #215での検証によると、`--dynamo_dynamic`には問題が多いことが報告されています。 + +__利用可能なオプション:__ + +``` +--dynamo_backend {NO, INDUCTOR, NVFUSER, CUDAGRAPHS, CUDAGRAPHS_FALLBACK, など} + 使用するDynamoバックエンドを指定します(デフォルトはNOで、Dynamoを無効にします) + +--dynamo_mode {default, reduce-overhead, max-autotune} + 最適化モードを指定します(デフォルトは 'default') + - 'default':標準的な最適化 + - 'reduce-overhead':コンパイルのオーバーヘッド削減に焦点を当てる + - 'max-autotune':より良いパフォーマンスのために広範な自動調整を実行 + +--dynamo_fullgraph + フルグラフモードを有効にするフラグ。モデルグラフ全体をキャプチャして最適化しようとします + +--dynamo_dynamic + 可変入力形状を持つモデルのための動的形状処理を有効にするフラグ +``` + +__使用例:__ + +```bash +python train_video_model.py --dynamo_backend INDUCTOR --dynamo_mode default +``` + +より積極的な最適化の場合: +```bash +python train_video_model.py --dynamo_backend INDUCTOR --dynamo_mode max-autotune --dynamo_fullgraph +``` + +注意:最適なオプションの組み合わせは、特定のモデルとハードウェアに依存する場合があります。最適な構成を見つけるために実験が必要かもしれません。 +
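+For reference, the options above correspond roughly to the arguments of `torch.compile` (the `DynamoBackend` enum comes from Accelerate). The snippet below only illustrates that mapping on a stand-in module; it is not how the training script wires this up internally.
+
+```python
+import torch
+
+# --dynamo_backend INDUCTOR  -> backend="inductor"
+# --dynamo_mode max-autotune -> mode="max-autotune"
+# --dynamo_fullgraph         -> fullgraph=True
+# --dynamo_dynamic           -> dynamic=True (reported to be problematic, see above)
+model = torch.nn.Sequential(torch.nn.Linear(8, 32), torch.nn.GELU(), torch.nn.Linear(32, 8))
+compiled = torch.compile(model, backend="inductor", mode="max-autotune", fullgraph=True, dynamic=False)
+print(compiled(torch.randn(4, 8)).shape)
+```
+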
diff --git a/docs/framepack.md b/docs/framepack.md new file mode 100644 index 0000000000000000000000000000000000000000..c55f63eb721cdf280e4ee79d2df3a13ccdf9590c --- /dev/null +++ b/docs/framepack.md @@ -0,0 +1,282 @@ +# FramePack + +## Overview / 概要 + +This document describes the usage of the [FramePack](https://github.com/lllyasviel/FramePack) architecture within the Musubi Tuner framework. FramePack is a novel video generation architecture developed by lllyasviel. + +Key differences from HunyuanVideo: +- FramePack only supports Image-to-Video (I2V) generation. Text-to-Video (T2V) is not supported. +- It utilizes a different DiT model architecture and requires an additional Image Encoder. VAE is same as HunyuanVideo. Text Encoders seem to be the same as HunyuanVideo but we employ the original FramePack method to utilize them. +- Caching and training scripts are specific to FramePack (`fpack_*.py`). +- Due to its progressive generation nature, VRAM usage can be significantly lower, especially for longer videos, compared to other architectures. + +This feature is experimental. + +
+日本語 +このドキュメントは、Musubi Tunerフレームワーク内での[FramePack](https://github.com/lllyasviel/FramePack) アーキテクチャの使用法について説明しています。FramePackは、lllyasviel氏によって開発された新しいビデオ生成アーキテクチャです。 + +HunyuanVideoとの主な違いは次のとおりです。 +- FramePackは、画像からビデオ(I2V)生成のみをサポートしています。テキストからビデオ(T2V)はサポートされていません。 +- 異なるDiTモデルアーキテクチャを使用し、追加の画像エンコーダーが必要です。VAEはHunyuanVideoと同じです。テキストエンコーダーはHunyuanVideoと同じと思われますが、FramePack公式と同じ方法で推論を行っています。 +- キャッシングと学習スクリプトはFramePack専用(`fpack_*.py`)です。 +- 1セクションずつ生成するため、他のアーキテクチャと比較して、特に長いビデオの場合、VRAM使用量が大幅に少なくなる可能性があります。 + +この機能は実験的なものです。 +
+ +## Download the model / モデルのダウンロード + +You need to download the DiT, VAE, Text Encoder 1 (LLaMA), Text Encoder 2 (CLIP), and Image Encoder (SigLIP) models specifically for FramePack. Several download options are available for each component. + +### DiT Model + +Choose one of the following methods: + +1. **From lllyasviel's Hugging Face repo:** Download the three `.safetensors` files (starting with `diffusion_pytorch_model-00001-of-00003.safetensors`) from [lllyasviel/FramePackI2V_HY](https://huggingface.co/lllyasviel/FramePackI2V_HY). Specify the path to the first file (`...-00001-of-00003.safetensors`) as the `--dit` argument. +2. **From local FramePack installation:** If you have cloned and run the official FramePack repository, the model might be downloaded locally. Specify the path to the snapshot directory, e.g., `path/to/FramePack/hf_download/hub/models--lllyasviel--FramePackI2V_HY/snapshots/`. +3. **From Kijai's Hugging Face repo:** Download the single file `FramePackI2V_HY_bf16.safetensors` from [Kijai/HunyuanVideo_comfy](https://huggingface.co/Kijai/HunyuanVideo_comfy/blob/main/FramePackI2V_HY_bf16.safetensors). Specify the path to this file as the `--dit` argument. + +### VAE Model + +Choose one of the following methods: + +1. **Use official HunyuanVideo VAE:** Follow the instructions in the main [README.md](../README.md#model-download). +2. **From hunyuanvideo-community Hugging Face repo:** Download `vae/diffusion_pytorch_model.safetensors` from [hunyuanvideo-community/HunyuanVideo](https://huggingface.co/hunyuanvideo-community/HunyuanVideo). +3. **From local FramePack installation:** If you have cloned and run the official FramePack repository, the VAE might be downloaded locally within the HunyuanVideo community model snapshot. Specify the path to the snapshot directory, e.g., `path/to/FramePack/hf_download/hub/models--hunyuanvideo-community--HunyuanVideo/snapshots/`. + +### Text Encoder 1 (LLaMA) Model + +Choose one of the following methods: + +1. **From Comfy-Org Hugging Face repo:** Download `split_files/text_encoders/llava_llama3_fp16.safetensors` from [Comfy-Org/HunyuanVideo_repackaged](https://huggingface.co/Comfy-Org/HunyuanVideo_repackaged). +2. **From hunyuanvideo-community Hugging Face repo:** Download the four `.safetensors` files (starting with `text_encoder/model-00001-of-00004.safetensors`) from [hunyuanvideo-community/HunyuanVideo](https://huggingface.co/hunyuanvideo-community/HunyuanVideo). Specify the path to the first file (`...-00001-of-00004.safetensors`) as the `--text_encoder1` argument. +3. **From local FramePack installation:** (Same as VAE) Specify the path to the HunyuanVideo community model snapshot directory, e.g., `path/to/FramePack/hf_download/hub/models--hunyuanvideo-community--HunyuanVideo/snapshots/`. + +### Text Encoder 2 (CLIP) Model + +Choose one of the following methods: + +1. **From Comfy-Org Hugging Face repo:** Download `split_files/text_encoders/clip_l.safetensors` from [Comfy-Org/HunyuanVideo_repackaged](https://huggingface.co/Comfy-Org/HunyuanVideo_repackaged). +2. **From hunyuanvideo-community Hugging Face repo:** Download `text_encoder_2/model.safetensors` from [hunyuanvideo-community/HunyuanVideo](https://huggingface.co/hunyuanvideo-community/HunyuanVideo). +3. **From local FramePack installation:** (Same as VAE) Specify the path to the HunyuanVideo community model snapshot directory, e.g., `path/to/FramePack/hf_download/hub/models--hunyuanvideo-community--HunyuanVideo/snapshots/`. 
+ +### Image Encoder (SigLIP) Model + +Choose one of the following methods: + +1. **From Comfy-Org Hugging Face repo:** Download `sigclip_vision_patch14_384.safetensors` from [Comfy-Org/sigclip_vision_384](https://huggingface.co/Comfy-Org/sigclip_vision_384). +2. **From lllyasviel's Hugging Face repo:** Download `image_encoder/model.safetensors` from [lllyasviel/flux_redux_bfl](https://huggingface.co/lllyasviel/flux_redux_bfl). +3. **From local FramePack installation:** If you have cloned and run the official FramePack repository, the model might be downloaded locally. Specify the path to the snapshot directory, e.g., `path/to/FramePack/hf_download/hub/models--lllyasviel--flux_redux_bfl/snapshots/`. + +
+日本語 + +DiT、VAE、テキストエンコーダー1(LLaMA)、テキストエンコーダー2(CLIP)、および画像エンコーダー(SigLIP)モデルは複数の方法でダウンロードできます。英語の説明を参考にして、ダウンロードしてください。 + +FramePack公式のリポジトリをクローンして実行した場合、モデルはローカルにダウンロードされている可能性があります。スナップショットディレクトリへのパスを指定してください。例:`path/to/FramePack/hf_download/hub/models--lllyasviel--flux_redux_bfl/snapshots/` + +HunyuanVideoの推論をComfyUIですでに行っている場合、いくつかのモデルはすでにダウンロードされている可能性があります。 +
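+If you prefer to script the downloads instead of using a browser, the single-file variants listed above can be fetched with `huggingface_hub`. The repository and file names below are exactly the ones mentioned in this section; the multi-part DiT/Text Encoder variants and the local FramePack installation route work just as well.
+
+```python
+from huggingface_hub import hf_hub_download
+
+dit = hf_hub_download("Kijai/HunyuanVideo_comfy", "FramePackI2V_HY_bf16.safetensors")
+vae = hf_hub_download("hunyuanvideo-community/HunyuanVideo", "vae/diffusion_pytorch_model.safetensors")
+text_encoder1 = hf_hub_download("Comfy-Org/HunyuanVideo_repackaged", "split_files/text_encoders/llava_llama3_fp16.safetensors")
+text_encoder2 = hf_hub_download("Comfy-Org/HunyuanVideo_repackaged", "split_files/text_encoders/clip_l.safetensors")
+image_encoder = hf_hub_download("Comfy-Org/sigclip_vision_384", "sigclip_vision_patch14_384.safetensors")
+print(dit, vae, text_encoder1, text_encoder2, image_encoder, sep="\n")
+```
+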
+ +## Pre-caching / 事前キャッシング + +The default resolution for FramePack is 640x640. See [the source code](../frame_pack/bucket_tools.py) for the default resolution of each bucket. + +The dataset for training must be a video dataset. Image datasets are not supported. You can train on videos of any length. Specify `frame_extraction` as `full` and set `max_frames` to a sufficiently large value. However, if the video is too long, you may run out of VRAM during VAE encoding. + +### Latent Pre-caching / latentの事前キャッシング + +Latent pre-caching uses a dedicated script for FramePack. You **must** provide the Image Encoder model. + +```bash +python fpack_cache_latents.py \ + --dataset_config path/to/toml \ + --vae path/to/vae_model.safetensors \ + --image_encoder path/to/image_encoder_model.safetensors \ + --vae_chunk_size 32 --vae_spatial_tile_sample_min_size 128 +``` + +Key differences from HunyuanVideo caching: +- Uses `fpack_cache_latents.py`. +- Requires the `--image_encoder` argument pointing to the downloaded SigLIP model. +- You can use the `--latent_window_size` argument (default 9) which defines the size of the latent sections FramePack processes (omitted in the example). This value should typically not be changed unless you understand the implications. +- The script generates multiple cache files per video, each corresponding to a different section, with the section index appended to the filename (e.g., `..._frame_pos-0000-count_...` becomes `..._frame_pos-0000-0000-count_...`, `..._frame_pos-0000-0001-count_...`, etc.). +- Image embeddings are calculated using the Image Encoder and stored in the cache files alongside the latents. + +For VRAM savings during VAE decoding, consider using `--vae_chunk_size` and `--vae_spatial_tile_sample_min_size`. If VRAM is overflowing and using shared memory, it is recommended to set `--vae_chunk_size` to 16 or 8, and `--vae_spatial_tile_sample_min_size` to 64 or 32. + +
+日本語 +FramePackのデフォルト解像度は640x640です。各バケットのデフォルト解像度については、[ソースコード](../frame_pack/bucket_tools.py)を参照してください。 + +画像データセットでの学習は行えません。また動画の長さによらず学習可能です。 `frame_extraction` に `full` を指定して、`max_frames` に十分に大きな値を指定してください。ただし、あまりにも長いとVAEのencodeでVRAMが不足する可能性があります。 + +latentの事前キャッシングはFramePack専用のスクリプトを使用します。画像エンコーダーモデルを指定する必要があります。 + +HunyuanVideoのキャッシングとの主な違いは次のとおりです。 +- `fpack_cache_latents.py`を使用します。 +- ダウンロードしたSigLIPモデルを指す`--image_encoder`引数が必要です。 +- `--latent_window_size`引数(デフォルト9)を指定できます(例では省略)。これは、FramePackが処理するlatentセクションのサイズを定義します。この値は、影響を理解していない限り、通常変更しないでください。 +- スクリプトは、各ビデオに対して複数のキャッシュファイルを生成します。各ファイルは異なるセクションに対応し、セクションインデックスがファイル名に追加されます(例:`..._frame_pos-0000-count_...`は`..._frame_pos-0000-0000-count_...`、`..._frame_pos-0000-0001-count_...`などになります)。 +- 画像埋め込みは画像エンコーダーを使用して計算され、latentとともにキャッシュファイルに保存されます。 + +VAEのdecode時のVRAM節約のために、`--vae_chunk_size`と`--vae_spatial_tile_sample_min_size`を使用することを検討してください。VRAMがあふれて共有メモリを使用している場合には、`--vae_chunk_size`を16、8などに、`--vae_spatial_tile_sample_min_size`を64、32などに変更することをお勧めします。 +
+ +### Text Encoder Output Pre-caching / テキストエンコーダー出力の事前キャッシング + +Text encoder output pre-caching also uses a dedicated script. + +```bash +python fpack_cache_text_encoder_outputs.py \ + --dataset_config path/to/toml \ + --text_encoder1 path/to/text_encoder1 \ + --text_encoder2 path/to/text_encoder2 \ + --batch_size 16 +``` + +Key differences from HunyuanVideo caching: +- Uses `fpack_cache_text_encoder_outputs.py`. +- Requires both `--text_encoder1` (LLaMA) and `--text_encoder2` (CLIP) arguments. +- Uses `--fp8_llm` option to run the LLaMA Text Encoder 1 in fp8 mode for VRAM savings (similar to `--fp8_t5` in Wan2.1). +- Saves LLaMA embeddings, attention mask, and CLIP pooler output to the cache file. + +
+日本語 +テキストエンコーダー出力の事前キャッシングも専用のスクリプトを使用します。 + +HunyuanVideoのキャッシングとの主な違いは次のとおりです。 +- `fpack_cache_text_encoder_outputs.py`を使用します。 +- LLaMAとCLIPの両方の引数が必要です。 +- LLaMAテキストエンコーダー1をfp8モードで実行するための`--fp8_llm`オプションを使用します(Wan2.1の`--fp8_t5`に似ています)。 +- LLaMAの埋め込み、アテンションマスク、CLIPのプーラー出力をキャッシュファイルに保存します。 + +
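+To sanity-check what a given cache file actually stores (latents plus image embeddings for the latent caches; LLaMA embeddings, attention mask, and CLIP pooler output for the text-encoder caches), you can list its contents with `safetensors`. This assumes the caches are safetensors files, as elsewhere in this repository, and the path below is a hypothetical example of a generated cache file.
+
+```python
+from safetensors import safe_open
+
+cache_file = "path/to/cache/your_item_cache_file.safetensors"  # hypothetical path
+with safe_open(cache_file, framework="pt", device="cpu") as f:
+    for key in f.keys():
+        tensor = f.get_tensor(key)
+        print(key, tuple(tensor.shape), tensor.dtype)
+```
+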
+ + +## Training / 学習 + +### Training + +Training uses a dedicated script `fpack_train_network.py`. Remember FramePack only supports I2V training. + +```bash +accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16 fpack_train_network.py \ + --dit path/to/dit_model \ + --vae path/to/vae_model.safetensors \ + --text_encoder1 path/to/text_encoder1 \ + --text_encoder2 path/to/text_encoder2 \ + --image_encoder path/to/image_encoder_model.safetensors \ + --dataset_config path/to/toml \ + --sdpa --mixed_precision bf16 \ + --optimizer_type adamw8bit --learning_rate 2e-4 --gradient_checkpointing \ + --timestep_sampling shift --weighting_scheme none --discrete_flow_shift 3.0 \ + --max_data_loader_n_workers 2 --persistent_data_loader_workers \ + --network_module networks.lora_framepack --network_dim 32 \ + --max_train_epochs 16 --save_every_n_epochs 1 --seed 42 \ + --output_dir path/to/output_dir --output_name name-of-lora +``` + +If you use the command prompt (Windows, not PowerShell), you may need to write them in a single line, or use `^` at the end of each line to continue the command. + +The maximum value for `--blocks_to_swap` is 36. The default resolution for FramePack is 640x640, which requires around 17GB of VRAM. If you run out of VRAM, consider lowering the dataset resolution. + +Key differences from HunyuanVideo training: +- Uses `fpack_train_network.py`. +- **Requires** specifying `--vae`, `--text_encoder1`, `--text_encoder2`, and `--image_encoder`. +- **Requires** specifying `--network_module networks.lora_framepack`. +- Optional `--latent_window_size` argument (default 9, should match caching). +- Memory saving options like `--fp8_base` (for DiT) and `--fp8_llm` (for Text Encoder 1) are available. `--fp8_scaled` is recommended when using `--fp8_base` for DiT. +- `--vae_chunk_size` and `--vae_spatial_tile_sample_min_size` options are available for the VAE to prevent out-of-memory during sampling (similar to caching). +- `--gradient_checkpointing` is available for memory savings. + + +Training settings (learning rate, optimizers, etc.) are experimental. Feedback is welcome. + +
+日本語 +FramePackの学習は専用のスクリプト`fpack_train_network.py`を使用します。FramePackはI2V学習のみをサポートしています。 + +コマンド記述例は英語版を参考にしてください。WindowsでPowerShellではなくコマンドプロンプトを使用している場合、コマンドを1行で記述するか、各行の末尾に`^`を付けてコマンドを続ける必要があります。 + +`--blocks_to_swap`の最大値は36です。FramePackのデフォルト解像度(640x640)では、17GB程度のVRAMが必要です。VRAM容量が不足する場合は、データセットの解像度を下げてください。 + +HunyuanVideoの学習との主な違いは次のとおりです。 +- `fpack_train_network.py`を使用します。 +- `--vae`、`--text_encoder1`、`--text_encoder2`、`--image_encoder`を指定する必要があります。 +- `--network_module networks.lora_framepack`を指定する必要があります。 +- 必要に応じて`--latent_window_size`引数(デフォルト9)を指定できます(キャッシング時と一致させる必要があります)。 +- `--fp8_base`(DiT用)や`--fp8_llm`(テキストエンコーダー1用)などのメモリ節約オプションが利用可能です。`--fp8_base`指定時は、`--fp8_scaled`を使用することをお勧めします。 +- サンプル生成時にメモリ不足を防ぐため、VAE用の`--vae_chunk_size`、`--vae_spatial_tile_sample_min_size`オプションが利用可能です(キャッシング時と同様)。 +- メモリ節約のために`--gradient_checkpointing`が利用可能です。 + +
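+As a side note on `--discrete_flow_shift`: the value warps where the sampled timesteps concentrate. The mapping below is the `flux_time_shift` function added in this PR's pipeline code (`diffusers_helper/pipelines/k_diffusion_hunyuan.py`) with `mu = log(shift)`; the training-side shift sampling is assumed to follow the same curve, so treat the numbers only as intuition for why larger shifts emphasize high-noise timesteps.
+
+```python
+import math
+
+def flux_time_shift(t, mu=1.15, sigma=1.0):
+    # Same formula as in diffusers_helper/pipelines/k_diffusion_hunyuan.py
+    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+shift = 3.0           # e.g. --discrete_flow_shift 3.0 from the command above
+mu = math.log(shift)
+for t in (0.25, 0.5, 0.75):
+    print(t, "->", round(flux_time_shift(t, mu=mu), 3))
+# 0.25 -> 0.5, 0.5 -> 0.75, 0.75 -> 0.9: the schedule is pushed toward the noisy end
+```
+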
+ +## Inference + +Inference uses a dedicated script `fpack_generate_video.py`. + +```bash +python fpack_generate_video.py \ + --dit path/to/dit_model \ + --vae path/to/vae_model.safetensors \ + --text_encoder1 path/to/text_encoder1 \ + --text_encoder2 path/to/text_encoder2 \ + --image_encoder path/to/image_encoder_model.safetensors \ + --image_path path/to/start_image.jpg \ + --prompt "A cat walks on the grass, realistic style." \ + --video_size 512 768 --video_seconds 5 --fps 30 --infer_steps 25 \ + --attn_mode sdpa --fp8_scaled \ + --vae_chunk_size 32 --vae_spatial_tile_sample_min_size 128 \ + --save_path path/to/save/dir --output_type both \ + --seed 1234 --lora_multiplier 1.0 --lora_weight path/to/lora.safetensors +``` + + +Key differences from HunyuanVideo inference: +- Uses `fpack_generate_video.py`. +- **Requires** specifying `--vae`, `--text_encoder1`, `--text_encoder2`, and `--image_encoder`. +- **Requires** specifying `--image_path` for the starting frame. +- **Requires** specifying `--video_seconds` (length of the video in seconds). +- `--video_size` is the size of the generated video, height and width are specified in that order. +- Optional `--latent_window_size` argument (default 9, should match caching and training). +- `--fp8_scaled` option is available for DiT to reduce memory usage. Quality may be slightly lower. `--fp8_llm` option is available to reduce memory usage of Text Encoder 1. `--fp8` alone is also an option for DiT but `--fp8_scaled` potentially offers better quality. +- LoRA loading options (`--lora_weight`, `--lora_multiplier`, `--include_patterns`, `--exclude_patterns`) are available. `--lycoris` is also supported. +- `--embedded_cfg_scale` (default 10.0) controls the distilled guidance scale. +- `--guidance_scale` (default 1.0) controls the standard classifier-free guidance scale. **Changing this from 1.0 is generally not recommended for the base FramePack model.** +- `--guidance_rescale` (default 0.0) is available but typically not needed. +- `--bulk_decode` option can decode all frames at once, potentially faster but uses more VRAM during decoding. `--vae_chunk_size` and `--vae_spatial_tile_sample_min_size` options are recommended to prevent out-of-memory errors. +- `--sample_solver` (default `unipc`) is available but only `unipc` is implemented. +- `--save_merged_model` option is available to save the DiT model after merging LoRA weights. Inference is skipped if this is specified. +- Batch and interactive modes (`--from_file`, `--interactive`) are **not yet implemented** for FramePack generation. + +Other options like `--video_size`, `--fps`, `--infer_steps`, `--save_path`, `--output_type`, `--seed`, `--attn_mode`, `--blocks_to_swap`, `--vae_chunk_size`, `--vae_spatial_tile_sample_min_size` function similarly to HunyuanVideo/Wan2.1 where applicable. + +The maximum value for `--blocks_to_swap` is 38. +
+日本語 + +FramePackの推論は専用のスクリプト`fpack_generate_video.py`を使用します。コマンド記述例は英語版を参考にしてください。 + +HunyuanVideoの推論との主な違いは次のとおりです。 +- `fpack_generate_video.py`を使用します。 +- `--vae`、`--text_encoder1`、`--text_encoder2`、`--image_encoder`を指定する必要があります。 +- `--image_path`を指定する必要があります。 +- `--video_seconds`を指定する必要があります(秒単位でのビデオの長さを指定)。 +- `--video_size`は生成するビデオのサイズで、高さと幅をその順番で指定します。 +- 必要に応じて`--latent_window_size`引数(デフォルト9)を指定できます(キャッシング時、学習時と一致させる必要があります)。 +- DiTのメモリ使用量を削減するために、`--fp8_scaled`オプションを指定可能です。品質はやや低下する可能性があります。またText Encoder 1のメモリ使用量を削減するために、`--fp8_llm`オプションを指定可能です。DiT用に`--fp8`単独のオプションも用意されていますが、`--fp8_scaled`の方が品質が良い可能性があります。 +- LoRAの読み込みオプション(`--lora_weight`、`--lora_multiplier`、`--include_patterns`、`--exclude_patterns`)が利用可能です。LyCORISもサポートされています。 +- `--embedded_cfg_scale`(デフォルト10.0)は、蒸留されたガイダンススケールを制御します。通常は変更しないでください。 +- `--guidance_scale`(デフォルト1.0)は、標準の分類器フリーガイダンススケールを制御します。**FramePackモデルのベースモデルでは、通常1.0から変更しないことをお勧めします。** +- `--guidance_rescale`(デフォルト0.0)も利用可能ですが、通常は必要ありません。 +- `--bulk_decode`オプションは、すべてのフレームを一度にデコードできるオプションです。高速ですが、デコード中にVRAMを多く使用します。VRAM不足エラーを防ぐために、`--vae_chunk_size`と`--vae_spatial_tile_sample_min_size`オプションを指定することをお勧めします。 +- `--sample_solver`(デフォルト`unipc`)は利用可能ですが、`unipc`のみが実装されています。 +- `--save_merged_model`オプションは、LoRAの重みをマージした後にDiTモデルを保存するためのオプションです。これを指定すると推論はスキップされます。 +- バッチモードとインタラクティブモード(`--from_file`、`--interactive`)はFramePack生成には**まだ実装されていません**。 + +`--video_size`、`--fps`、`--infer_steps`、`--save_path`、`--output_type`、`--seed`、`--attn_mode`、`--blocks_to_swap`、`--vae_chunk_size`、`--vae_spatial_tile_sample_min_size`などの他のオプションは、HunyuanVideo/Wan2.1と同様に機能します。 + +`--blocks_to_swap`の最大値は38です。 +
\ No newline at end of file diff --git a/docs/sampling_during_training.md b/docs/sampling_during_training.md new file mode 100644 index 0000000000000000000000000000000000000000..e937b36662b979c816490f82ca5759ef7eba13de --- /dev/null +++ b/docs/sampling_during_training.md @@ -0,0 +1,90 @@ +> 📝 Click on the language section to expand / 言語をクリックして展開 + +# Sampling during training / 学習中のサンプル画像生成 + +By preparing a prompt file, you can generate sample images during training. + +Please be aware that it consumes a considerable amount of VRAM, so be careful when generating sample images for videos with a large number of frames. Also, since it takes time to generate, adjust the frequency of sample image generation as needed. + +
+日本語 + +プロンプトファイルを用意することで、学習中にサンプル画像を生成することができます。 + +VRAMをそれなりに消費しますので、特にフレーム数が多い動画を生成する場合は注意してください。また生成には時間がかかりますので、サンプル画像生成の頻度は適宜調整してください。 +
+
+## How to use / 使い方
+
+### Command line options for training with sampling / サンプル画像生成に関連する学習時のコマンドラインオプション
+
+Example of command line options for training with sampling / 記述例:
+
+```bash
+--vae path/to/ckpts/hunyuan-video-t2v-720p/vae/pytorch_model.pt
+--vae_chunk_size 32 --vae_spatial_tile_sample_min_size 128
+--text_encoder1 path/to/ckpts/text_encoder
+--text_encoder2 path/to/ckpts/text_encoder_2
+--sample_prompts /path/to/prompt_file.txt
+--sample_every_n_epochs 1 --sample_every_n_steps 1000 --sample_at_first
+```
+
+`--vae`, `--vae_chunk_size`, `--vae_spatial_tile_sample_min_size`, `--text_encoder1`, `--text_encoder2` are the same as when generating images, so please refer to [here](/README.md#inference) for details. `--fp8_llm` can also be specified.
+
+`--sample_prompts` specifies the path to the prompt file used for sample image generation. Details are described below.
+
+`--sample_every_n_epochs` specifies how often to generate sample images in epochs, and `--sample_every_n_steps` specifies how often to generate sample images in steps.
+
+Specify `--sample_at_first` to generate sample images at the start of training.
+
+Sample images and videos are saved in the `sample` directory inside the directory specified by `--output_dir`. Still images are saved as `.png` and videos as `.mp4`.
+
+日本語 + +`--vae`、`--vae_chunk_size`、`--vae_spatial_tile_sample_min_size`、`--text_encoder1`、`--text_encoder2`は、画像生成時と同様ですので、詳細は[こちら](/README.ja.md#推論)を参照してください。`--fp8_llm`も指定可能です。 + +`--sample_prompts`は、サンプル画像生成に使用するプロンプトファイルのパスを指定します。詳細は後述します。 + +`--sample_every_n_epochs`は、何エポックごとにサンプル画像を生成するかを、`--sample_every_n_steps`は、何ステップごとにサンプル画像を生成するかを指定します。 + +`--sample_at_first`は、学習開始時にサンプル画像を生成する場合に指定します。 + +サンプル画像、動画は、`--output_dir`で指定したディレクトリ内の、`sample`ディレクトリに保存されます。静止画の場合は`.png`、動画の場合は`.mp4`で保存されます。 +
+
+### Prompt file / プロンプトファイル
+
+The prompt file is a text file that contains the prompts for generating sample images. An example is shown below. / プロンプトファイルは、サンプル画像生成のためのプロンプトを記述したテキストファイルです。例は以下の通りです。
+
+```
+# prompt 1: for generating a cat video
+A cat walks on the grass, realistic style. --w 640 --h 480 --f 25 --d 1 --s 20
+
+# prompt 2: for generating a dog image
+A dog runs on the beach, realistic style. --w 960 --h 544 --f 1 --d 2 --s 20
+```
+
+A line starting with `#` is a comment.
+
+* `--w` specifies the width of the generated image or video. The default is 256.
+* `--h` specifies the height. The default is 256.
+* `--f` specifies the number of frames. The default is 1, which generates a still image.
+* `--d` specifies the seed. The default is random.
+* `--s` specifies the number of steps in generation. The default is 20.
+* `--g` specifies the guidance scale. The default is 6.0, which is the default guidance scale for HunyuanVideo inference.
+* `--fs` specifies the discrete flow shift. The default is 14.5, which corresponds to 20 steps. In the HunyuanVideo paper, 7.0 is recommended for 50 steps, and 17.0 is recommended for fewer than 20 steps (e.g. 10).
+
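+The inline options follow a simple `--key value` convention appended to the prompt text. As a rough illustration only (this is not the repository's actual parser, just a sketch of the format), such a line can be split into the prompt and its options like this:
+
+```python
+# Minimal sketch of parsing one prompt-file line into its prompt text and inline options.
+import re
+
+def parse_prompt_line(line: str) -> dict:
+    # Defaults documented above; values are kept as strings in this sketch.
+    options = {"w": "256", "h": "256", "f": "1", "s": "20"}
+    parts = re.split(r"\s--(\w+)\s+", line.strip())
+    options["prompt"] = parts[0].strip()
+    for key, value in zip(parts[1::2], parts[2::2]):
+        options[key] = value.strip()
+    return options
+
+print(parse_prompt_line("A cat walks on the grass, realistic style. --w 640 --h 480 --f 25 --d 1 --s 20"))
+```
+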
+日本語 + +`#` で始まる行はコメントです。 + +* `--w` 生成画像、動画の幅を指定します。省略時は256です。 +* `--h` 高さを指定します。省略時は256です。 +* `--f` フレーム数を指定します。省略時は1で、静止画を生成します。 +* `--d` シードを指定します。省略時はランダムです。 +* `--s` 生成におけるステップ数を指定します。省略時は20です。 +* `--g` guidance scaleを指定します。省略時は6.0で、HunyuanVideoの推論時のデフォルト値です。 +* `--fs` discrete flow shiftを指定します。省略時は14.5で、ステップ数20の場合に対応した値です。HunyuanVideoの論文では、ステップ数50の場合は7.0、ステップ数20未満(10など)で17.0が推奨されています。 +
\ No newline at end of file
diff --git a/docs/wan.md b/docs/wan.md
new file mode 100644
index 0000000000000000000000000000000000000000..27a457a3977cbfe90641000c5b41f39d011979d8
--- /dev/null
+++ b/docs/wan.md
@@ -0,0 +1,531 @@
+> 📝 Click on the language section to expand / 言語をクリックして展開
+
+# Wan 2.1
+
+## Overview / 概要
+
+This is an unofficial training and inference script for [Wan2.1](https://github.com/Wan-Video/Wan2.1). The features are as follows.
+
+- fp8 support and memory reduction by block swap: inference of a 720x1280, 81-frame video with 24GB VRAM, and training on 720x1280 images with 24GB VRAM
+- Inference without installing Flash attention (using PyTorch's scaled dot product attention)
+- Supports xformers and Sage attention
+
+This feature is experimental.
+
+日本語 +[Wan2.1](https://github.com/Wan-Video/Wan2.1) の非公式の学習および推論スクリプトです。 + +以下の特徴があります。 + +- fp8対応およびblock swapによる省メモリ化:720x1280x81framesの動画を24GB VRAMで推論可能、720x1280の画像での学習が24GB VRAMで可能 +- Flash attentionのインストールなしでの実行(PyTorchのscaled dot product attentionを使用) +- xformersおよびSage attention対応 + +この機能は実験的なものです。 +
+
+## Download the model / モデルのダウンロード
+
+Download the T5 `models_t5_umt5-xxl-enc-bf16.pth` and CLIP `models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` from the following page: https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P/tree/main
+
+Download the VAE `Wan2.1_VAE.pth` from the above page, or download `split_files/vae/wan_2.1_vae.safetensors` from the following page: https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/vae
+
+Download the DiT weights from the following page: https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/diffusion_models
+
+Wan2.1 Fun Control model weights can be downloaded from [here](https://huggingface.co/alibaba-pai/Wan2.1-Fun-14B-Control). Navigate to the page for each weight and download it. The Fun Control model seems to support not only T2V but also I2V tasks.
+
+Please select the appropriate weights according to T2V or I2V, resolution, model size, etc.
+
+`fp16` and `bf16` models can be used, and `fp8_e4m3fn` models can be used if `--fp8` (or `--fp8_base`) is specified without `--fp8_scaled`. **Please note that `fp8_scaled` checkpoints are not supported, even when `--fp8_scaled` is specified.**
+
+(Thanks to Comfy-Org for providing the repackaged weights.)
+
+### Model support matrix / モデルサポートマトリックス
+
+* columns: training dtype (列:学習時のデータ型)
+* rows: model dtype (行:モデルのデータ型)
+
+| model \ training |bf16|fp16|--fp8_base|--fp8_base & --fp8_scaled|
+|--|--|--|--|--|
+|bf16|✓|--|✓|✓|
+|fp16|--|✓|✓|✓|
+|fp8_e4m3fn|--|--|✓|--|
+|fp8_scaled|--|--|--|--|
+
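+The matrix can also be read programmatically. The following is a small illustrative helper (not part of the repository) that encodes the same compatibility rules:
+
+```python
+# Illustrative only: encode the support matrix above and check whether a checkpoint
+# dtype can be trained with a given precision setting.
+SUPPORTED_TRAINING = {
+    "bf16":       {"bf16", "--fp8_base", "--fp8_base & --fp8_scaled"},
+    "fp16":       {"fp16", "--fp8_base", "--fp8_base & --fp8_scaled"},
+    "fp8_e4m3fn": {"--fp8_base"},
+    "fp8_scaled": set(),  # pre-scaled fp8 checkpoints are not supported at all
+}
+
+def is_supported(model_dtype: str, training_setting: str) -> bool:
+    return training_setting in SUPPORTED_TRAINING.get(model_dtype, set())
+
+print(is_supported("fp8_e4m3fn", "--fp8_base"))                 # True
+print(is_supported("fp8_e4m3fn", "--fp8_base & --fp8_scaled"))  # False
+```
+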
+日本語 +T5 `models_t5_umt5-xxl-enc-bf16.pth` およびCLIP `models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` を、次のページからダウンロードしてください:https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P/tree/main + +VAEは上のページから `Wan2.1_VAE.pth` をダウンロードするか、次のページから `split_files/vae/wan_2.1_vae.safetensors` をダウンロードしてください:https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/vae + +DiTの重みを次のページからダウンロードしてください:https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/diffusion_models + +Wan2.1 Fun Controlモデルの重みは、[こちら](https://huggingface.co/alibaba-pai/Wan2.1-Fun-14B-Control)から、それぞれの重みのページに遷移し、ダウンロードしてください。Fun ControlモデルはT2VだけでなくI2Vタスクにも対応しているようです。 + +T2VやI2V、解像度、モデルサイズなどにより適切な重みを選択してください。 + +`fp16` および `bf16` モデルを使用できます。また、`--fp8` (または`--fp8_base`)を指定し`--fp8_scaled`を指定をしないときには `fp8_e4m3fn` モデルを使用できます。**`fp8_scaled` モデルはいずれの場合もサポートされていませんのでご注意ください。** + +(repackaged版の重みを提供してくださっているComfy-Orgに感謝いたします。) +
+ +## Pre-caching / 事前キャッシュ + +### Latent Pre-caching + +Latent pre-caching is almost the same as in HunyuanVideo. Create the cache using the following command: + +```bash +python wan_cache_latents.py --dataset_config path/to/toml --vae path/to/wan_2.1_vae.safetensors +``` + +If you train I2V models, add `--clip path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` to specify the CLIP model. If not specified, the training will raise an error. + +If you're running low on VRAM, specify `--vae_cache_cpu` to use the CPU for the VAE internal cache, which will reduce VRAM usage somewhat. + +The control video settings are required for training the Fun-Control model. Please refer to [Dataset Settings](/dataset/dataset_config.md#sample-for-video-dataset-with-control-images) for details. + +
+日本語 +latentの事前キャッシングはHunyuanVideoとほぼ同じです。上のコマンド例を使用してキャッシュを作成してください。 + +I2Vモデルを学習する場合は、`--clip path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` を追加してCLIPモデルを指定してください。指定しないと学習時にエラーが発生します。 + +VRAMが不足している場合は、`--vae_cache_cpu` を指定するとVAEの内部キャッシュにCPUを使うことで、使用VRAMを多少削減できます。 + +Fun-Controlモデルを学習する場合は、制御用動画の設定が必要です。[データセット設定](/dataset/dataset_config.md#sample-for-video-dataset-with-control-images)を参照してください。 +
+ +### Text Encoder Output Pre-caching + +Text encoder output pre-caching is also almost the same as in HunyuanVideo. Create the cache using the following command: + +```bash +python wan_cache_text_encoder_outputs.py --dataset_config path/to/toml --t5 path/to/models_t5_umt5-xxl-enc-bf16.pth --batch_size 16 +``` + +Adjust `--batch_size` according to your available VRAM. + +For systems with limited VRAM (less than ~16GB), use `--fp8_t5` to run the T5 in fp8 mode. + +
+日本語 +テキストエンコーダ出力の事前キャッシングもHunyuanVideoとほぼ同じです。上のコマンド例を使用してキャッシュを作成してください。 + +使用可能なVRAMに合わせて `--batch_size` を調整してください。 + +VRAMが限られているシステム(約16GB未満)の場合は、T5をfp8モードで実行するために `--fp8_t5` を使用してください。 +
+ +## Training / 学習 + +### Training + +Start training using the following command (input as a single line): + +```bash +accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16 wan_train_network.py + --task t2v-1.3B + --dit path/to/wan2.1_xxx_bf16.safetensors + --dataset_config path/to/toml --sdpa --mixed_precision bf16 --fp8_base + --optimizer_type adamw8bit --learning_rate 2e-4 --gradient_checkpointing + --max_data_loader_n_workers 2 --persistent_data_loader_workers + --network_module networks.lora_wan --network_dim 32 + --timestep_sampling shift --discrete_flow_shift 3.0 + --max_train_epochs 16 --save_every_n_epochs 1 --seed 42 + --output_dir path/to/output_dir --output_name name-of-lora +``` +The above is an example. The appropriate values for `timestep_sampling` and `discrete_flow_shift` need to be determined by experimentation. + +For additional options, use `python wan_train_network.py --help` (note that many options are unverified). + +`--task` is one of `t2v-1.3B`, `t2v-14B`, `i2v-14B`, `t2i-14B` (for Wan2.1 official models), `t2v-1.3B-FC`, `t2v-14B-FC`, and `i2v-14B-FC` (for Wan2.1 Fun Control model). Specify the DiT weights for the task with `--dit`. + +Don't forget to specify `--network_module networks.lora_wan`. + +Other options are mostly the same as `hv_train_network.py`. + +Use `convert_lora.py` for converting the LoRA weights after training, as in HunyuanVideo. + +
+日本語 +`timestep_sampling`や`discrete_flow_shift`は一例です。どのような値が適切かは実験が必要です。 + +その他のオプションについては `python wan_train_network.py --help` を使用してください(多くのオプションは未検証です)。 + +`--task` には `t2v-1.3B`, `t2v-14B`, `i2v-14B`, `t2i-14B` (これらはWan2.1公式モデル)、`t2v-1.3B-FC`, `t2v-14B-FC`, `i2v-14B-FC`(Wan2.1-Fun Controlモデル)を指定します。`--dit`に、taskに応じたDiTの重みを指定してください。 + + `--network_module` に `networks.lora_wan` を指定することを忘れないでください。 + +その他のオプションは、ほぼ`hv_train_network.py`と同様です。 + +学習後のLoRAの重みの変換は、HunyuanVideoと同様に`convert_lora.py`を使用してください。 +
+
+### Command line options for training with sampling / サンプル画像生成に関連する学習時のコマンドラインオプション
+
+Example of command line options for training with sampling / 記述例:
+
+```bash
+--vae path/to/wan_2.1_vae.safetensors
+--t5 path/to/models_t5_umt5-xxl-enc-bf16.pth
+--sample_prompts /path/to/prompt_file.txt
+--sample_every_n_epochs 1 --sample_every_n_steps 1000 --sample_at_first
+```
+
+Each option works the same as for inference and as in HunyuanVideo. Please refer to [here](/docs/sampling_during_training.md) for details.
+
+If you train I2V models, add `--clip path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` to specify the CLIP model.
+
+You can specify the initial image, the negative prompt, and the control video (for Wan2.1-Fun-Control) in the prompt file. Please refer to [here](/docs/sampling_during_training.md#prompt-file--プロンプトファイル).
+
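+As an illustration only (assuming the prompt file accepts the same inline image and negative-prompt flags, `--i` and `--n`, as the batch mode described later in this document), an I2V sampling prompt might look like:
+
+```
+A cat walks on the grass, realistic style. --w 832 --h 480 --f 81 --d 42 --s 20 --i path/to/start.jpg --n low quality, blurry
+```
+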
+日本語 +各オプションは推論時、およびHunyuanVideoの場合と同様です。[こちら](/docs/sampling_during_training.md)を参照してください。 + +I2Vモデルを学習する場合は、`--clip path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` を追加してCLIPモデルを指定してください。 + +プロンプトファイルで、初期画像やネガティブプロンプト、制御動画(Wan2.1-Fun-Control用)等を指定できます。[こちら](/docs/sampling_during_training.md#prompt-file--プロンプトファイル)を参照してください。 +
+
+
+## Inference / 推論
+
+### Inference Options Comparison / 推論オプション比較
+
+#### Speed Comparison (Faster → Slower) / 速度比較(速い→遅い)
+*Note: Results may vary depending on GPU type*
+
+fp8_fast > bf16/fp16 (no block swap) > fp8 > fp8_scaled > bf16/fp16 (block swap)
+
+#### Quality Comparison (Higher → Lower) / 品質比較(高→低)
+
+bf16/fp16 > fp8_scaled > fp8 >> fp8_fast
+
+### T2V Inference / T2V推論
+
+The following is an example of T2V inference (input as a single line):
+
+```bash
+python wan_generate_video.py --fp8 --task t2v-1.3B --video_size 832 480 --video_length 81 --infer_steps 20
+--prompt "prompt for the video" --save_path path/to/save.mp4 --output_type both
+--dit path/to/wan2.1_t2v_1.3B_bf16_etc.safetensors --vae path/to/wan_2.1_vae.safetensors
+--t5 path/to/models_t5_umt5-xxl-enc-bf16.pth
+--attn_mode torch
+```
+
+`--task` is one of `t2v-1.3B`, `t2v-14B`, `i2v-14B`, `t2i-14B` (these are Wan2.1 official models), `t2v-1.3B-FC`, `t2v-14B-FC` and `i2v-14B-FC` (for the Wan2.1-Fun Control model).
+
+`--attn_mode` is `torch`, `sdpa` (same as `torch`), `xformers`, `sageattn`, `flash2`, `flash` (same as `flash2`) or `flash3`. `torch` is the default. The other options require the corresponding library to be installed. `flash3` (Flash attention 3) is not tested.
+
+Specifying `--fp8` runs DiT in fp8 mode. fp8 can significantly reduce memory consumption but may impact output quality.
+
+`--fp8_scaled` can be specified in addition to `--fp8` to run the model with fp8 weight optimization. This slightly increases memory consumption and reduces speed, but improves output quality. See [here](advanced_config.md#fp8-weight-optimization-for-models--モデルの重みのfp8への最適化) for details.
+
+The `--fp8_fast` option is also available for faster inference on RTX 40x0 GPUs. It requires the `--fp8_scaled` option. **It seems to degrade the output quality.**
+
+`--fp8_t5` can be used to run the T5 model in fp8 format. This option reduces memory usage for the T5 model.
+
+`--negative_prompt` can be used to specify a negative prompt. If omitted, the default negative prompt is used.
+
+`--flow_shift` can be used to specify the flow shift (default 3.0 for I2V with 480p, 5.0 for others).
+
+`--guidance_scale` can be used to specify the guidance scale for classifier-free guidance (default 5.0).
+
+`--blocks_to_swap` is the number of blocks to swap during inference. The default value is None (no block swap). The maximum value is 39 for the 14B model and 29 for the 1.3B model.
+
+`--vae_cache_cpu` enables the VAE cache in main memory. This reduces VRAM usage slightly but processing is slower.
+
+`--compile` enables torch.compile. See [here](/README.md#inference) for details.
+
+`--trim_tail_frames` can be used to trim the tail frames when saving. The default is 0.
+
+`--cfg_skip_mode` specifies the mode for skipping CFG in different steps. The default is `none` (all steps). `--cfg_apply_ratio` specifies the ratio of steps where CFG is applied. See below for details.
+
+`--include_patterns` and `--exclude_patterns` can be used to specify which LoRA modules to apply or exclude. If not specified, all modules are applied by default. These options accept regular expressions.
+
+`--include_patterns` specifies the modules to be applied, and `--exclude_patterns` specifies the modules to be excluded. The regular expression is matched against the LoRA key name, and include takes precedence.
+
+The key name to be searched is in sd-scripts format (`lora_unet_<module name with dots replaced by underscores>`). For example, `lora_unet_blocks_9_cross_attn_k`.
+
+For example, if you specify `--exclude_patterns "blocks_[23]\d_"`, modules containing `blocks_20` to `blocks_39` are excluded. If you specify `--include_patterns "cross_attn" --exclude_patterns "blocks_(0|1|2|3|4)_"`, LoRA is applied to modules that contain `cross_attn` and do not contain `blocks_0` to `blocks_4`.
+
+If you specify multiple LoRA weights, specify the patterns with one argument per weight. For example: `--include_patterns "cross_attn" ".*" --exclude_patterns "dummy_do_not_exclude" "blocks_(0|1|2|3|4)"`. `".*"` is a regex that matches everything. `dummy_do_not_exclude` is a dummy regex that does not match anything.
+
+`--cpu_noise` generates the initial noise on the CPU. This may produce the same results as ComfyUI for the same seed (depending on other settings).
+
+If you are using the Fun Control model, specify the control video with `--control_path`. You can specify a video file or a folder containing multiple image files. The number of frames in the video file (or the number of images) should be at least the number specified in `--video_length` (plus 1 frame if you specify `--end_image_path`).
+
+Please try to match the aspect ratio of the control video with the aspect ratio specified in `--video_size` (because the bucketing logic is reused, the result may deviate slightly from the initial image in I2V).
+
+Other options are the same as in `hv_generate_video.py` (some options are not supported; please check the help).
+
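+As a quick illustration of how these regular expressions match sd-scripts style key names (plain Python `re` usage, not the script's internal filtering logic):
+
+```python
+# Check which example patterns from above match typical sd-scripts style LoRA key names.
+import re
+
+keys = [
+    "lora_unet_blocks_3_cross_attn_k",
+    "lora_unet_blocks_9_cross_attn_k",
+    "lora_unet_blocks_25_self_attn_q",
+]
+include = re.compile(r"cross_attn")
+exclude = re.compile(r"blocks_[23]\d_")  # matches blocks_20 ... blocks_39
+
+for key in keys:
+    print(f"{key}: include={bool(include.search(key))} exclude={bool(exclude.search(key))}")
+```
+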
+日本語 +`--task` には `t2v-1.3B`, `t2v-14B`, `i2v-14B`, `t2i-14B` (これらはWan2.1公式モデル)、`t2v-1.3B-FC`, `t2v-14B-FC`, `i2v-14B-FC`(Wan2.1-Fun Controlモデル)を指定します。 + +`--attn_mode` には `torch`, `sdpa`(`torch`と同じ)、`xformers`, `sageattn`, `flash2`, `flash`(`flash2`と同じ), `flash3` のいずれかを指定します。デフォルトは `torch` です。その他のオプションを使用する場合は、対応するライブラリをインストールする必要があります。`flash3`(Flash attention 3)は未テストです。 + +`--fp8` を指定するとDiTモデルをfp8形式で実行します。fp8はメモリ消費を大幅に削減できますが、出力品質に影響を与える可能性があります。 + +`--fp8_scaled` を `--fp8` と併用すると、fp8への重み量子化を行います。メモリ消費と速度はわずかに悪化しますが、出力品質が向上します。詳しくは[こちら](advanced_config.md#fp8-weight-optimization-for-models--モデルの重みのfp8への最適化)を参照してください。 + +`--fp8_fast` オプションはRTX 40x0 GPUでの高速推論に使用されるオプションです。このオプションは `--fp8_scaled` オプションが必要です。**出力品質が劣化するようです。** + +`--fp8_t5` を指定するとT5モデルをfp8形式で実行します。T5モデル呼び出し時のメモリ使用量を削減します。 + +`--negative_prompt` でネガティブプロンプトを指定できます。省略した場合はデフォルトのネガティブプロンプトが使用されます。 + +`--flow_shift` でflow shiftを指定できます(480pのI2Vの場合はデフォルト3.0、それ以外は5.0)。 + +`--guidance_scale` でclassifier free guianceのガイダンススケールを指定できます(デフォルト5.0)。 + +`--blocks_to_swap` は推論時のblock swapの数です。デフォルト値はNone(block swapなし)です。最大値は14Bモデルの場合39、1.3Bモデルの場合29です。 + +`--vae_cache_cpu` を有効にすると、VAEのキャッシュをメインメモリに保持します。VRAM使用量が多少減りますが、処理は遅くなります。 + +`--compile`でtorch.compileを有効にします。詳細については[こちら](/README.md#inference)を参照してください。 + +`--trim_tail_frames` で保存時に末尾のフレームをトリミングできます。デフォルトは0です。 + +`--cfg_skip_mode` は異なるステップでCFGをスキップするモードを指定します。デフォルトは `none`(全ステップ)。`--cfg_apply_ratio` はCFGが適用されるステップの割合を指定します。詳細は後述します。 + +LoRAのどのモジュールを適用するかを、`--include_patterns`と`--exclude_patterns`で指定できます(未指定時・デフォルトは全モジュール適用されます +)。これらのオプションには、正規表現を指定します。`--include_patterns`は適用するモジュール、`--exclude_patterns`は適用しないモジュールを指定します。正規表現がLoRAのキー名に含まれるかどうかで判断され、includeが優先されます。 + +検索対象となるキー名は sd-scripts 形式(`lora_unet_<モジュール名のドットを_に置換したもの>`)です。例:`lora_unet_blocks_9_cross_attn_k` + +たとえば `--exclude_patterns "blocks_[23]\d_"`のみを指定すると、`blocks_20`から`blocks_39`を含むモジュールが除外されます。`--include_patterns "cross_attn" --exclude_patterns "blocks_(0|1|2|3|4)_"`のようにincludeとexcludeを指定すると、`cross_attn`を含むモジュールで、かつ`blocks_0`から`blocks_4`を含まないモジュールにLoRAが適用されます。 + +複数のLoRAの重みを指定する場合は、複数個の引数で指定してください。例:`--include_patterns "cross_attn" ".*" --exclude_patterns "dummy_do_not_exclude" "blocks_(0|1|2|3|4)"` `".*"`は全てにマッチする正規表現です。`dummy_do_not_exclude`は何にもマッチしないダミーの正規表現です。 + +`--cpu_noise`を指定すると初期ノイズをCPUで生成します。これにより同一seed時の結果がComfyUIと同じになる可能性があります(他の設定にもよります)。 + +Fun Controlモデルを使用する場合は、`--control_path`で制御用の映像を指定します。動画ファイル、または複数枚の画像ファイルを含んだフォルダを指定できます。動画ファイルのフレーム数(または画像の枚数)は、`--video_length`で指定したフレーム数以上にしてください(後述の`--end_image_path`を指定した場合は、さらに+1フレーム)。 + +制御用の映像のアスペクト比は、`--video_size`で指定したアスペクト比とできるかぎり合わせてください(bucketingの処理を流用しているためI2Vの初期画像とズレる場合があります)。 + +その他のオプションは `hv_generate_video.py` と同じです(一部のオプションはサポートされていないため、ヘルプを確認してください)。 +
+
+#### CFG Skip Mode / CFGスキップモード
+
+These options allow you to balance generation speed against prompt accuracy. More skipped steps result in faster generation with potential quality degradation.
+
+Setting `--cfg_apply_ratio` to 0.5 speeds up the denoising loop by up to about 25%.
+
+`--cfg_skip_mode` specifies one of the following modes:
+
+- `early`: Skips CFG in early steps for faster generation, applying guidance mainly in later refinement steps
+- `late`: Skips CFG in later steps, applying guidance during initial structure formation
+- `middle`: Skips CFG in middle steps, applying guidance in both early and later steps
+- `early_late`: Skips CFG in both early and late steps, applying it only in middle steps
+- `alternate`: Applies CFG in alternating steps based on the specified ratio
+- `none`: Applies CFG at all steps (default)
+
+`--cfg_apply_ratio` specifies a value from 0.0 to 1.0 controlling the proportion of steps where CFG is applied. For example, setting 0.5 means CFG will be applied in only 50% of the steps.
+
+If num_steps is 10, the following table shows the steps where CFG is applied based on the `--cfg_skip_mode` option (A means CFG is applied, S means it is skipped, `--cfg_apply_ratio` is 0.6):
+
+| skip mode | CFG apply pattern |
+|---|---|
+| early | SSSSAAAAAA |
+| late | AAAAAASSSS |
+| middle | AAASSSSAAA |
+| early_late | SSAAAAAASS |
+| alternate | SASASAASAS |
+
+The appropriate settings are unknown, but you may want to try `late` or `early_late` mode with a ratio of around 0.3 to 0.5.
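+
+The block-shaped modes in the table can be reproduced with a short sketch (assumed from the table above, not the actual implementation; the `alternate` mode, which interleaves applied and skipped steps, is omitted):
+
+```python
+# Build the list of steps where CFG is applied for a given skip mode and apply ratio.
+def cfg_apply_steps(num_steps: int, mode: str, apply_ratio: float) -> list[bool]:
+    apply_count = round(num_steps * apply_ratio)
+    skip_count = num_steps - apply_count
+    if mode == "none":
+        return [True] * num_steps
+    if mode == "early":        # skip the first steps
+        return [False] * skip_count + [True] * apply_count
+    if mode == "late":         # skip the last steps
+        return [True] * apply_count + [False] * skip_count
+    if mode == "middle":       # skip the middle steps
+        head = apply_count // 2
+        return [True] * head + [False] * skip_count + [True] * (apply_count - head)
+    if mode == "early_late":   # skip both ends
+        head = skip_count // 2
+        return [False] * head + [True] * apply_count + [False] * (skip_count - head)
+    raise ValueError(f"unsupported mode in this sketch: {mode}")
+
+pattern = cfg_apply_steps(10, "early_late", 0.6)
+print("".join("A" if a else "S" for a in pattern))  # SSAAAAAASS
+```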
+日本語 +これらのオプションは、生成速度とプロンプトの精度のバランスを取ることができます。スキップされるステップが多いほど、生成速度が速くなりますが、品質が低下する可能性があります。 + +ratioに0.5を指定することで、デノイジングのループが最大25%程度、高速化されます。 + +`--cfg_skip_mode` は次のモードのいずれかを指定します: + +- `early`:初期のステップでCFGをスキップして、主に終盤の精細化のステップで適用します +- `late`:終盤のステップでCFGをスキップし、初期の構造が決まる段階で適用します +- `middle`:中間のステップでCFGをスキップし、初期と終盤のステップの両方で適用します +- `early_late`:初期と終盤のステップの両方でCFGをスキップし、中間のステップのみ適用します +- `alternate`:指定された割合に基づいてCFGを適用します + +`--cfg_apply_ratio` は、CFGが適用されるステップの割合を0.0から1.0の値で指定します。たとえば、0.5に設定すると、CFGはステップの50%のみで適用されます。 + +具体的なパターンは上のテーブルを参照してください。 + +適切な設定は不明ですが、モードは`late`または`early_late`、ratioは0.3~0.5程度から試してみると良いかもしれません。 +
+
+#### Skip Layer Guidance
+
+Skip Layer Guidance is a feature that uses the output of the model with some blocks skipped as the unconditional output of classifier-free guidance. It was originally proposed for [SD 3.5](https://github.com/comfyanonymous/ComfyUI/pull/5404) and was first applied to Wan2.1 by Wan2GP in [this PR](https://github.com/deepbeepmeep/Wan2GP/pull/61). It may improve the quality of generated videos.
+
+The implementation of SD 3.5 is [here](https://github.com/Stability-AI/sd3.5/blob/main/sd3_impls.py), and the implementation of Wan2GP (the PR mentioned above) differs in some details. This inference script allows you to choose between the two methods.
+
+*The SD3.5 method computes the SLG output in addition to the cond and uncond outputs (which slows down inference). The Wan2GP method uses only the cond and SLG outputs.*
+
+The following arguments are available:
+
+- `--slg_mode`: Specifies the SLG mode. `original` for the SD 3.5 method, `uncond` for the Wan2GP method. Default is None (no SLG).
+- `--slg_layers`: Specifies the indices of the blocks (layers) to skip in SLG, separated by commas. Example: `--slg_layers 4,5,6`. Default is empty (no skip). If this option is not specified, `--slg_mode` is ignored.
+- `--slg_scale`: Specifies the scale of SLG when `original` is used. Default is 3.0.
+- `--slg_start`: Specifies the start of SLG application as a ratio of the inference steps, from 0.0 to 1.0. Default is 0.0 (applied from the beginning).
+- `--slg_end`: Specifies the end of SLG application as a ratio of the inference steps, from 0.0 to 1.0. Default is 0.3 (applied up to 30% from the beginning).
+
+Appropriate settings are unknown, but you may want to try `original` mode with a scale of around 3.0, a start ratio of 0.0, and an end ratio of 0.5, skipping layers 4, 5, and 6.
+
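+As a concrete starting point, taken directly from the suggestion above, the corresponding flags would be:
+
+```bash
+--slg_mode original --slg_layers 4,5,6 --slg_scale 3.0 --slg_start 0.0 --slg_end 0.5
+```
+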
+日本語 +Skip Layer Guidanceは、一部のblockをスキップしたモデル出力をclassifier free guidanceのunconditional出力に使用する機能です。元々は[SD 3.5](https://github.com/comfyanonymous/ComfyUI/pull/5404)で提案されたもので、Wan2.1には[Wan2GPのこちらのPR](https://github.com/deepbeepmeep/Wan2GP/pull/61)で初めて適用されました。生成動画の品質が向上する可能性があります。 + +SD 3.5の実装は[こちら](https://github.com/Stability-AI/sd3.5/blob/main/sd3_impls.py)で、Wan2GPの実装(前述のPR)は一部仕様が異なります。この推論スクリプトでは両者の方式を選択できるようになっています。 + +※SD3.5方式はcondとuncondに加えてslg outputを適用します(速度が低下します)。Wan2GP方式はcondとslg outputのみを使用します。 + +以下の引数があります。 + +- `--slg_mode`:SLGのモードを指定します。`original`でSD 3.5の方式、`uncond`でWan2GPの方式です。デフォルトはNoneで、SLGを使用しません。 +- `--slg_layers`:SLGでスキップするblock (layer)のインデクスをカンマ区切りで指定します。例:`--slg_layers 4,5,6`。デフォルトは空(スキップしない)です。このオプションを指定しないと`--slg_mode`は無視されます。 +- `--slg_scale`:`original`のときのSLGのスケールを指定します。デフォルトは3.0です。 +- `--slg_start`:推論ステップのSLG適用開始ステップを0.0から1.0の割合で指定します。デフォルトは0.0です(最初から適用)。 +- `--slg_end`:推論ステップのSLG適用終了ステップを0.0から1.0の割合で指定します。デフォルトは0.3です(最初から30%まで適用)。 + +適切な設定は不明ですが、`original`モードでスケールを3.0程度、開始割合を0.0、終了割合を0.5程度に設定し、4, 5, 6のlayerをスキップする設定から始めると良いかもしれません。 +
+
+### I2V Inference / I2V推論
+
+The following is an example of I2V inference (input as a single line):
+
+```bash
+python wan_generate_video.py --fp8 --task i2v-14B --video_size 832 480 --video_length 81 --infer_steps 20
+--prompt "prompt for the video" --save_path path/to/save.mp4 --output_type both
+--dit path/to/wan2.1_i2v_480p_14B_bf16_etc.safetensors --vae path/to/wan_2.1_vae.safetensors
+--t5 path/to/models_t5_umt5-xxl-enc-bf16.pth --clip path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth
+--attn_mode torch --image_path path/to/image.jpg
+```
+
+Add `--clip` to specify the CLIP model. `--image_path` is the path to the image to be used as the initial frame.
+
+`--end_image_path` can be used to specify the end image. This option is experimental. When it is specified, the saved video will be slightly longer than the specified number of frames and will contain some noise, so it is recommended to specify `--trim_tail_frames 3` to trim the tail frames.
+
+You can also use the Fun Control model for I2V inference. Specify the control video with `--control_path`.
+
+Other options are the same as for T2V inference.
+
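+For example, the following fragment (with placeholder paths) could be appended to the I2V command above to generate toward a specified end image and trim the noisy tail frames:
+
+```bash
+--end_image_path path/to/end_image.jpg --trim_tail_frames 3
+```
+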
+日本語 +`--clip` を追加してCLIPモデルを指定します。`--image_path` は初期フレームとして使用する画像のパスです。 + +`--end_image_path` で終了画像を指定できます。このオプションは実験的なものです。このオプションを指定すると、保存される動画が指定フレーム数よりもやや多くなり、かつノイズが乗るため、`--trim_tail_frames 3` などを指定して末尾のフレームをトリミングすることをお勧めします。 + +I2V推論でもFun Controlモデルが使用できます。`--control_path` で制御用の映像を指定します。 + +その他のオプションはT2V推論と同じです。 +
+
+### New Batch and Interactive Modes / 新しいバッチモードとインタラクティブモード
+
+In addition to single video generation, Wan 2.1 now supports batch generation from a file and interactive prompt input:
+
+#### Batch Mode from File / ファイルからのバッチモード
+
+Generate multiple videos from prompts stored in a text file:
+
+```bash
+python wan_generate_video.py --from_file prompts.txt --task t2v-14B
+--dit path/to/model.safetensors --vae path/to/vae.safetensors
+--t5 path/to/t5_model.pth --save_path output_directory
+```
+
+The prompts file format:
+- One prompt per line
+- Empty lines and lines starting with # are ignored (comments)
+- Each line can include prompt-specific parameters using command-line style format:
+
+```
+A beautiful sunset over mountains --w 832 --h 480 --f 81 --d 42 --s 20
+A busy city street at night --w 480 --h 832 --g 7.5 --n low quality, blurry
+```
+
+Supported inline parameters (if omitted, the default values from the command line are used):
+- `--w`: Width
+- `--h`: Height
+- `--f`: Frame count
+- `--d`: Seed
+- `--s`: Inference steps
+- `--g` or `--l`: Guidance scale
+- `--fs`: Flow shift
+- `--i`: Image path (for I2V)
+- `--cn`: Control path (for Fun Control)
+- `--n`: Negative prompt
+
+In batch mode, models are loaded once and reused for all prompts, significantly reducing overall generation time compared to multiple single runs.
+
+#### Interactive Mode / インタラクティブモード
+
+Interactive command-line interface for entering prompts:
+
+```bash
+python wan_generate_video.py --interactive --task t2v-14B
+--dit path/to/model.safetensors --vae path/to/vae.safetensors
+--t5 path/to/t5_model.pth --save_path output_directory
+```
+
+In interactive mode:
+- Enter prompts directly at the command line
+- Use the same inline parameter format as batch mode
+- Use Ctrl+D (or Ctrl+Z on Windows) to exit
+- Models remain loaded between generations for efficiency
+
+日本語 +単一動画の生成に加えて、Wan 2.1は現在、ファイルからのバッチ生成とインタラクティブなプロンプト入力をサポートしています。 + +#### ファイルからのバッチモード + +テキストファイルに保存されたプロンプトから複数の動画を生成します: + +```bash +python wan_generate_video.py --from_file prompts.txt --task t2v-14B +--dit path/to/model.safetensors --vae path/to/vae.safetensors +--t5 path/to/t5_model.pth --save_path output_directory +``` + +プロンプトファイルの形式: +- 1行に1つのプロンプト +- 空行や#で始まる行は無視されます(コメント) +- 各行にはコマンドライン形式でプロンプト固有のパラメータを含めることができます: + +サポートされているインラインパラメータ(省略した場合、コマンドラインのデフォルト値が使用されます) +- `--w`: 幅 +- `--h`: 高さ +- `--f`: フレーム数 +- `--d`: シード +- `--s`: 推論ステップ +- `--g` または `--l`: ガイダンススケール +- `--fs`: フローシフト +- `--i`: 画像パス(I2V用) +- `--cn`: コントロールパス(Fun Control用) +- `--n`: ネガティブプロンプト + +バッチモードでは、モデルは一度だけロードされ、すべてのプロンプトで再利用されるため、複数回の単一実行と比較して全体的な生成時間が大幅に改善されます。 + +#### インタラクティブモード + +プロンプトを入力するためのインタラクティブなコマンドラインインターフェース: + +```bash +python wan_generate_video.py --interactive --task t2v-14B +--dit path/to/model.safetensors --vae path/to/vae.safetensors +--t5 path/to/t5_model.pth --save_path output_directory +``` + +インタラクティブモードでは: +- コマンドラインで直接プロンプトを入力 +- バッチモードと同じインラインパラメータ形式を使用 +- 終了するには Ctrl+D (Windowsでは Ctrl+Z) を使用 +- 効率のため、モデルは生成間で読み込まれたままになります +
+ diff --git a/endframe.py b/endframe.py new file mode 100644 index 0000000000000000000000000000000000000000..76228cc2c25e677dec83f4f887a1cd4a701894e9 --- /dev/null +++ b/endframe.py @@ -0,0 +1,544 @@ +from diffusers_helper.hf_login import login + +import os +import random + +os.environ['HF_HOME'] = os.path.abspath(os.path.realpath(os.path.join(os.path.dirname(__file__), './hf_download'))) + +import gradio as gr +import torch +import traceback +import einops +import safetensors.torch as sf +import numpy as np +import argparse +import math + +from PIL import Image +from diffusers import AutoencoderKLHunyuanVideo +from transformers import LlamaModel, CLIPTextModel, LlamaTokenizerFast, CLIPTokenizer +from diffusers_helper.hunyuan import encode_prompt_conds, vae_decode, vae_encode, vae_decode_fake +from diffusers_helper.utils import save_bcthw_as_mp4, crop_or_pad_yield_mask, soft_append_bcthw, resize_and_center_crop, state_dict_weighted_merge, state_dict_offset_merge, generate_timestamp +from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked +from diffusers_helper.pipelines.k_diffusion_hunyuan import sample_hunyuan +from diffusers_helper.memory import cpu, gpu, get_cuda_free_memory_gb, move_model_to_device_with_memory_preservation, offload_model_from_device_for_memory_preservation, fake_diffusers_current_device, DynamicSwapInstaller, unload_complete_models, load_model_as_complete +from diffusers_helper.thread_utils import AsyncStream, async_run +from diffusers_helper.gradio.progress_bar import make_progress_bar_css, make_progress_bar_html +from transformers import SiglipImageProcessor, SiglipVisionModel +from diffusers_helper.clip_vision import hf_clip_vision_encode +from diffusers_helper.bucket_tools import find_nearest_bucket + + +parser = argparse.ArgumentParser() +parser.add_argument('--share', action='store_true') +parser.add_argument("--server", type=str, default='127.0.0.1') +parser.add_argument("--port", type=int, default=8001) +args = parser.parse_args() + +print(args) + +free_mem_gb = get_cuda_free_memory_gb(gpu) +high_vram = free_mem_gb > 60 + +print(f'Free VRAM {free_mem_gb} GB') +print(f'High-VRAM Mode: {high_vram}') + +text_encoder = LlamaModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='text_encoder', torch_dtype=torch.float16).cpu() +text_encoder_2 = CLIPTextModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='text_encoder_2', torch_dtype=torch.float16).cpu() +tokenizer = LlamaTokenizerFast.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='tokenizer') +tokenizer_2 = CLIPTokenizer.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='tokenizer_2') +vae = AutoencoderKLHunyuanVideo.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='vae', torch_dtype=torch.float16).cpu() + +feature_extractor = SiglipImageProcessor.from_pretrained("lllyasviel/flux_redux_bfl", subfolder='feature_extractor') +image_encoder = SiglipVisionModel.from_pretrained("lllyasviel/flux_redux_bfl", subfolder='image_encoder', torch_dtype=torch.float16).cpu() + +transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained('lllyasviel/FramePackI2V_HY', torch_dtype=torch.bfloat16).cpu() + +vae.eval() +text_encoder.eval() +text_encoder_2.eval() +image_encoder.eval() +transformer.eval() + +if not high_vram: + vae.enable_slicing() + vae.enable_tiling() + +transformer.high_quality_fp32_output_for_inference = True +print('transformer.high_quality_fp32_output_for_inference = True') + 
+transformer.to(dtype=torch.bfloat16) +vae.to(dtype=torch.float16) +image_encoder.to(dtype=torch.float16) +text_encoder.to(dtype=torch.float16) +text_encoder_2.to(dtype=torch.float16) + +vae.requires_grad_(False) +text_encoder.requires_grad_(False) +text_encoder_2.requires_grad_(False) +image_encoder.requires_grad_(False) +transformer.requires_grad_(False) + +if not high_vram: + # DynamicSwapInstaller is same as huggingface's enable_sequential_offload but 3x faster + DynamicSwapInstaller.install_model(transformer, device=gpu) + DynamicSwapInstaller.install_model(text_encoder, device=gpu) +else: + text_encoder.to(gpu) + text_encoder_2.to(gpu) + image_encoder.to(gpu) + vae.to(gpu) + transformer.to(gpu) + +stream = AsyncStream() + +outputs_folder = './outputs/' +os.makedirs(outputs_folder, exist_ok=True) + + +@torch.no_grad() +def worker(input_image, end_frame, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, save_section_frames, section_settings=None): + total_latent_sections = (total_second_length * 30) / (latent_window_size * 4) + total_latent_sections = int(max(round(total_latent_sections), 1)) + + job_id = generate_timestamp() + + stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Starting ...')))) + + try: + # セクション設定の前処理 + def get_section_settings_map(section_settings): + """ + section_settings: DataFrame List of formats [[number, image, prompt], ...] → {section number: (image, prompt)}dict + """ + result = {} + if section_settings is not None: + for row in section_settings: + if row and row[0] is not None: + sec_num = int(row[0]) + img = row[1] + prm = row[2] if len(row) > 2 else "" + result[sec_num] = (img, prm) + return result + + section_map = get_section_settings_map(section_settings) + section_numbers_sorted = sorted(section_map.keys()) if section_map else [] + + def get_section_info(i_section): + """ + i_section: int + section_map: {Section number: (Image, prompt)} + If there is no specification, the next section, if not None + """ + if not section_map: + return None, None, None + # i_section以降で最初に見つかる設定 + for sec in range(i_section, max(section_numbers_sorted)+1): + if sec in section_map: + img, prm = section_map[sec] + return sec, img, prm + return None, None, None + + # Clean GPU + if not high_vram: + unload_complete_models( + text_encoder, text_encoder_2, image_encoder, vae, transformer + ) + + # Text encoding + + stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Text encoding ...')))) + + if not high_vram: + fake_diffusers_current_device(text_encoder, gpu) # since we only encode one text - that is one model move and one encode, offload is same time consumption since it is also one load and one encode. 
+ load_model_as_complete(text_encoder_2, target_device=gpu) + + llama_vec, clip_l_pooler = encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2) + + if cfg == 1: + llama_vec_n, clip_l_pooler_n = torch.zeros_like(llama_vec), torch.zeros_like(clip_l_pooler) + else: + llama_vec_n, clip_l_pooler_n = encode_prompt_conds(n_prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2) + + llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512) + llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(llama_vec_n, length=512) + + # Processing input image + + stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Image processing ...')))) + + def preprocess_image(img): + H, W, C = img.shape + height, width = find_nearest_bucket(H, W, resolution=640) + img_np = resize_and_center_crop(img, target_width=width, target_height=height) + img_pt = torch.from_numpy(img_np).float() / 127.5 - 1 + img_pt = img_pt.permute(2, 0, 1)[None, :, None] + return img_np, img_pt, height, width + + input_image_np, input_image_pt, height, width = preprocess_image(input_image) + Image.fromarray(input_image_np).save(os.path.join(outputs_folder, f'{job_id}.png')) + + # VAE encoding + + stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'VAE encoding ...')))) + + if not high_vram: + load_model_as_complete(vae, target_device=gpu) + + start_latent = vae_encode(input_image_pt, vae) + # end_frameも同じタイミングでencode + if end_frame is not None: + end_frame_np, end_frame_pt, _, _ = preprocess_image(end_frame) + end_frame_latent = vae_encode(end_frame_pt, vae) + else: + end_frame_latent = None + + # create section_latents here + section_latents = None + if section_map: + section_latents = {} + for sec_num, (img, prm) in section_map.items(): + if img is not None: + # 画像をVAE encode + img_np, img_pt, _, _ = preprocess_image(img) + section_latents[sec_num] = vae_encode(img_pt, vae) + + # CLIP Vision + + stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'CLIP Vision encoding ...')))) + + if not high_vram: + load_model_as_complete(image_encoder, target_device=gpu) + + image_encoder_output = hf_clip_vision_encode(input_image_np, feature_extractor, image_encoder) + image_encoder_last_hidden_state = image_encoder_output.last_hidden_state + + # Dtype + + llama_vec = llama_vec.to(transformer.dtype) + llama_vec_n = llama_vec_n.to(transformer.dtype) + clip_l_pooler = clip_l_pooler.to(transformer.dtype) + clip_l_pooler_n = clip_l_pooler_n.to(transformer.dtype) + image_encoder_last_hidden_state = image_encoder_last_hidden_state.to(transformer.dtype) + + # Sampling + + stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Start sampling ...')))) + + rnd = torch.Generator("cpu").manual_seed(seed) + num_frames = latent_window_size * 4 - 3 + + history_latents = torch.zeros(size=(1, 16, 1 + 2 + 16, height // 8, width // 8), dtype=torch.float32).cpu() + history_pixels = None + total_generated_latent_frames = 0 + + latent_paddings = reversed(range(total_latent_sections)) + + if total_latent_sections > 4: + # In theory the latent_paddings should follow the above sequence, but it seems that duplicating some + # items looks better than expanding it when total_latent_sections > 4 + # One can try to remove below trick and just + # use `latent_paddings = list(reversed(range(total_latent_sections)))` to compare + latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0] + + for i_section, latent_padding in 
enumerate(latent_paddings): + is_first_section = i_section == 0 + is_last_section = latent_padding == 0 + use_end_latent = is_last_section and end_frame is not None + latent_padding_size = latent_padding * latent_window_size + # set current_latent here + # セクションごとのlatentを使う場合 + if section_map and section_latents is not None and len(section_latents) > 0: + # i_section以上で最小のsection_latentsキーを探す + valid_keys = [k for k in section_latents.keys() if k >= i_section] + if valid_keys: + use_key = min(valid_keys) + current_latent = section_latents[use_key] + print(f"[section_latent] section {i_section}: use section {use_key} latent (section_map keys: {list(section_latents.keys())})") + print(f"[section_latent] current_latent id: {id(current_latent)}, min: {current_latent.min().item():.4f}, max: {current_latent.max().item():.4f}, mean: {current_latent.mean().item():.4f}") + else: + current_latent = start_latent + print(f"[section_latent] section {i_section}: use start_latent (no section_latent >= {i_section})") + print(f"[section_latent] current_latent id: {id(current_latent)}, min: {current_latent.min().item():.4f}, max: {current_latent.max().item():.4f}, mean: {current_latent.mean().item():.4f}") + else: + current_latent = start_latent + print(f"[section_latent] section {i_section}: use start_latent (no section_latents)") + print(f"[section_latent] current_latent id: {id(current_latent)}, min: {current_latent.min().item():.4f}, max: {current_latent.max().item():.4f}, mean: {current_latent.mean().item():.4f}") + + if is_first_section and end_frame_latent is not None: + history_latents[:, :, 0:1, :, :] = end_frame_latent + + if stream.input_queue.top() == 'end': + stream.output_queue.push(('end', None)) + return + + print(f'latent_padding_size = {latent_padding_size}, is_last_section = {is_last_section}') + + indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, 1, 2, 16])).unsqueeze(0) + clean_latent_indices_pre, blank_indices, latent_indices, clean_latent_indices_post, clean_latent_2x_indices, clean_latent_4x_indices = indices.split([1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1) + clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1) + + clean_latents_pre = current_latent.to(history_latents) + clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, :1 + 2 + 16, :, :].split([1, 2, 16], dim=2) + clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2) + + if not high_vram: + unload_complete_models() + move_model_to_device_with_memory_preservation(transformer, target_device=gpu, preserved_memory_gb=gpu_memory_preservation) + + if use_teacache: + transformer.initialize_teacache(enable_teacache=True, num_steps=steps) + else: + transformer.initialize_teacache(enable_teacache=False) + + def callback(d): + preview = d['denoised'] + preview = vae_decode_fake(preview) + + preview = (preview * 255.0).detach().cpu().numpy().clip(0, 255).astype(np.uint8) + preview = einops.rearrange(preview, 'b c t h w -> (b h) (t w) c') + + if stream.input_queue.top() == 'end': + stream.output_queue.push(('end', None)) + raise KeyboardInterrupt('User ends the task.') + + current_step = d['i'] + 1 + percentage = int(100.0 * current_step / steps) + hint = f'Sampling {current_step}/{steps}' + desc = f'Total generated frames: {int(max(0, total_generated_latent_frames * 4 - 3))}, Video length: {max(0, (total_generated_latent_frames * 4 - 3) / 30) :.2f} seconds (FPS-30). The video is being extended now ...' 
+ stream.output_queue.push(('progress', (preview, desc, make_progress_bar_html(percentage, hint)))) + return + + generated_latents = sample_hunyuan( + transformer=transformer, + sampler='unipc', + width=width, + height=height, + frames=num_frames, + real_guidance_scale=cfg, + distilled_guidance_scale=gs, + guidance_rescale=rs, + # shift=3.0, + num_inference_steps=steps, + generator=rnd, + prompt_embeds=llama_vec, + prompt_embeds_mask=llama_attention_mask, + prompt_poolers=clip_l_pooler, + negative_prompt_embeds=llama_vec_n, + negative_prompt_embeds_mask=llama_attention_mask_n, + negative_prompt_poolers=clip_l_pooler_n, + device=gpu, + dtype=torch.bfloat16, + image_embeddings=image_encoder_last_hidden_state, + latent_indices=latent_indices, + clean_latents=clean_latents, + clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + callback=callback, + ) + + if is_last_section: + generated_latents = torch.cat([start_latent.to(generated_latents), generated_latents], dim=2) + + total_generated_latent_frames += int(generated_latents.shape[2]) + history_latents = torch.cat([generated_latents.to(history_latents), history_latents], dim=2) + + if not high_vram: + offload_model_from_device_for_memory_preservation(transformer, target_device=gpu, preserved_memory_gb=8) + load_model_as_complete(vae, target_device=gpu) + + real_history_latents = history_latents[:, :, :total_generated_latent_frames, :, :] + + if history_pixels is None: + history_pixels = vae_decode(real_history_latents, vae).cpu() + else: + section_latent_frames = (latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2) + overlapped_frames = latent_window_size * 4 - 3 + + current_pixels = vae_decode(real_history_latents[:, :, :section_latent_frames], vae).cpu() + history_pixels = soft_append_bcthw(current_pixels, history_pixels, overlapped_frames) + + # Save the final frame of each section as a still image (with section numbers). + if save_section_frames and history_pixels is not None: + try: + if i_section == 0 or current_pixels is None: + # The first section is history_pixels the end of + last_frame = history_pixels[0, :, -1, :, :] + else: + # From the second section onward, current_pixels the end of + last_frame = current_pixels[0, :, -1, :, :] + last_frame = einops.rearrange(last_frame, 'c h w -> h w c') + last_frame = last_frame.cpu().numpy() + last_frame = np.clip((last_frame * 127.5 + 127.5), 0, 255).astype(np.uint8) + last_frame = resize_and_center_crop(last_frame, target_width=width, target_height=height) + if is_first_section and end_frame is None: + Image.fromarray(last_frame).save(os.path.join(outputs_folder, f'{job_id}_{i_section}_end.png')) + else: + Image.fromarray(last_frame).save(os.path.join(outputs_folder, f'{job_id}_{i_section}.png')) + except Exception as e: + print(f"[WARN] セクション{ i_section }最終フレーム画像保存時にエラー: {e}") + + if not high_vram: + unload_complete_models() + + output_filename = os.path.join(outputs_folder, f'{job_id}_{total_generated_latent_frames}.mp4') + + save_bcthw_as_mp4(history_pixels, output_filename, fps=30) + + print(f'Decoded. 
Current latent shape {real_history_latents.shape}; pixel shape {history_pixels.shape}') + + stream.output_queue.push(('file', output_filename)) + + if is_last_section: + break + except: + traceback.print_exc() + + if not high_vram: + unload_complete_models( + text_encoder, text_encoder_2, image_encoder, vae, transformer + ) + + stream.output_queue.push(('end', None)) + return + + +def process(input_image, end_frame, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, use_random_seed, save_section_frames, section_settings): + global stream + assert input_image is not None, 'No input image!' + + if use_random_seed: + seed = random.randint(0, 2**32 - 1) + # Update the seed field of the UI with random values. + yield None, None, '', '', gr.update(interactive=False), gr.update(interactive=True), gr.update(value=seed) + else: + yield None, None, '', '', gr.update(interactive=False), gr.update(interactive=True), gr.update() + + stream = AsyncStream() + + async_run(worker, input_image, end_frame, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, save_section_frames, section_settings) + + output_filename = None + + while True: + flag, data = stream.output_queue.next() + + if flag == 'file': + output_filename = data + yield output_filename, gr.update(), gr.update(), gr.update(), gr.update(interactive=False), gr.update(interactive=True), gr.update() + + if flag == 'progress': + preview, desc, html = data + yield gr.update(), gr.update(visible=True, value=preview), desc, html, gr.update(interactive=False), gr.update(interactive=True), gr.update() + + if flag == 'end': + yield output_filename, gr.update(visible=False), gr.update(), '', gr.update(interactive=True), gr.update(interactive=False), gr.update() + break + + +def end_process(): + stream.input_queue.push('end') + + +quick_prompts = [ + 'The girl dances gracefully, with clear movements, full of charm.', + 'A character doing some simple body movements.', +] +quick_prompts = [[x] for x in quick_prompts] + + +css = make_progress_bar_css() +block = gr.Blocks(css=css).queue() +with block: + gr.Markdown('# FramePack') + with gr.Row(): + with gr.Column(): + input_image = gr.Image(sources='upload', type="numpy", label="Image", height=320) + end_frame = gr.Image(sources='upload', type="numpy", label="Final Frame (Optional)", height=320) + prompt = gr.Textbox(label="Prompt", value='', lines=8) + + with gr.Row(): + start_button = gr.Button(value="Start Generation") + end_button = gr.Button(value="End Generation", interactive=False) + + with gr.Row(): + example_quick_prompts = gr.Dataset(samples=quick_prompts, label='Quick List', samples_per_page=1000, components=[prompt]) + example_quick_prompts.click(lambda x: x[0], inputs=[example_quick_prompts], outputs=prompt, show_progress=False, queue=False) + + with gr.Group(): + use_teacache = gr.Checkbox(label='Use TeaCache', value=True, info='Faster speed, but often makes hands and fingers slightly worse.') + + # Use Random Initial value of the seed + use_random_seed_default = True + seed_default = random.randint(0, 2**32 - 1) if use_random_seed_default else 31337 + + use_random_seed = gr.Checkbox(label="Use Random Seed", value=use_random_seed_default) + + n_prompt = gr.Textbox(label="Negative Prompt", value="", visible=False) # Not used + seed = gr.Number(label="Seed", value=seed_default, precision=0) + + def set_random_seed(is_checked): + if is_checked: + return 
random.randint(0, 2**32 - 1) + else: + return gr.update() + use_random_seed.change(fn=set_random_seed, inputs=use_random_seed, outputs=seed) + + total_second_length = gr.Slider(label="Total Video Length (Seconds)", minimum=1, maximum=120, value=5, step=1) + latent_window_size = gr.Slider(label="Latent Window Size", minimum=1, maximum=33, value=9, step=1, visible=False) # Should not change + steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=25, step=1, info='Changing this value is not recommended.') + + cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=1.0, step=0.01, visible=False) # Should not change + gs = gr.Slider(label="Distilled CFG Scale", minimum=1.0, maximum=32.0, value=10.0, step=0.01, info='Changing this value is not recommended.') + rs = gr.Slider(label="CFG Re-Scale", minimum=0.0, maximum=1.0, value=0.0, step=0.01, visible=False) # Should not change + + gpu_memory_preservation = gr.Slider(label="GPU Inference Preserved Memory (GB) (larger means slower)", minimum=6, maximum=128, value=6, step=0.1, info="Set this number to a larger value if you encounter OOM. Larger value causes slower speed.") + + # Added a checkbox to save still images for each section (default ON) + save_section_frames = gr.Checkbox(label="Save still images for each section", value=True, info="Save the final frame of each section as a still image (default ON)") + + # Section settings (Change from DataFrame to individual input fields) + section_number_inputs = [] + section_image_inputs = [] + section_prompt_inputs = [] # Keep it as an empty list. + with gr.Group(): + gr.Markdown("### Section Settings. The section number counts from the end of the video. (Optional. If not specified, the usual Image/prompt will be used.)") + for i in range(3): + with gr.Row(): + section_number = gr.Number(label=f"Section number{i+1}", value=None, precision=0) + section_image = gr.Image(label=f"Keyframe image{i+1}", sources="upload", type="numpy", height=200) + section_number_inputs.append(section_number) + section_image_inputs.append(section_image) + # section_settings compiles the values of the three input fields into a list. + def collect_section_settings(*args): + # args: [num1, img1, num2, img2, ...] + return [[args[i], args[i+1], ""] for i in range(0, len(args), 2)] + section_settings = gr.State([[None, None, ""] for _ in range(3)]) + section_inputs = [] + for i in range(3): + section_inputs.extend([section_number_inputs[i], section_image_inputs[i]]) + # Store the summed section_inputs in the section_settings State. + def update_section_settings(*args): + return collect_section_settings(*args) + # Update the section_settings state when section_inputs changes. 
+ for inp in section_inputs: + inp.change(fn=update_section_settings, inputs=section_inputs, outputs=section_settings) + + with gr.Column(): + result_video = gr.Video(label="Finished Frames", autoplay=True, show_share_button=False, height=512, loop=True) + progress_desc = gr.Markdown('', elem_classes='no-generating-animation') + progress_bar = gr.HTML('', elem_classes='no-generating-animation') + preview_image = gr.Image(label="Next Latents", height=200, visible=False) + ips = [input_image, end_frame, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, use_random_seed, save_section_frames, section_settings] + start_button.click(fn=process, inputs=ips, outputs=[result_video, preview_image, progress_desc, progress_bar, start_button, end_button, seed]) + end_button.click(fn=end_process) + + +block.launch( + server_name=args.server, + server_port=args.port, + share=args.share, +) \ No newline at end of file diff --git a/extendvideo.py b/extendvideo.py new file mode 100644 index 0000000000000000000000000000000000000000..4c1d7c34b0f4605ae08b2ffd422c2e1e87e2c870 --- /dev/null +++ b/extendvideo.py @@ -0,0 +1,1941 @@ +import argparse +from datetime import datetime +import gc +import random +import os +import re +import time +import math +from typing import Tuple, Optional, List, Union, Any +from pathlib import Path # Added for glob_images in V2V + +import torch +import accelerate +from accelerate import Accelerator +from safetensors.torch import load_file, save_file +from safetensors import safe_open +from PIL import Image +import cv2 # Added for V2V video loading/resizing +import numpy as np # Added for V2V video processing +import torchvision.transforms.functional as TF +from tqdm import tqdm + +from networks import lora_wan +from utils.safetensors_utils import mem_eff_save_file, load_safetensors +from wan.configs import WAN_CONFIGS, SUPPORTED_SIZES +import wan +from wan.modules.model import WanModel, load_wan_model, detect_wan_sd_dtype +from wan.modules.vae import WanVAE +from wan.modules.t5 import T5EncoderModel +from wan.modules.clip import CLIPModel +from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler +from wan.utils.fm_solvers import FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps +from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler + +try: + from lycoris.kohya import create_network_from_weights +except: + pass + +from utils.model_utils import str_to_dtype +from utils.device_utils import clean_memory_on_device +# Original load_video/load_images are still needed for Fun-Control / image loading +from hv_generate_video import save_images_grid, save_videos_grid, synchronize_device, load_images as hv_load_images, load_video as hv_load_video + +import logging + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +def parse_args() -> argparse.Namespace: + """parse command line arguments""" + parser = argparse.ArgumentParser(description="Wan 2.1 inference script") + + # WAN arguments + parser.add_argument("--ckpt_dir", type=str, default=None, help="The path to the checkpoint directory (Wan 2.1 official).") + parser.add_argument("--task", type=str, default="t2v-14B", choices=list(WAN_CONFIGS.keys()), help="The task to run.") + parser.add_argument( + "--sample_solver", type=str, default="unipc", choices=["unipc", "dpm++", "vanilla"], help="The solver used to sample." 
+ ) + + parser.add_argument("--dit", type=str, default=None, help="DiT checkpoint path") + parser.add_argument("--vae", type=str, default=None, help="VAE checkpoint path") + parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is bfloat16") + parser.add_argument("--vae_cache_cpu", action="store_true", help="cache features in VAE on CPU") + parser.add_argument("--t5", type=str, default=None, help="text encoder (T5) checkpoint path") + parser.add_argument("--clip", type=str, default=None, help="text encoder (CLIP) checkpoint path") + # LoRA + parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path") + parser.add_argument("--lora_multiplier", type=float, nargs="*", default=1.0, help="LoRA multiplier") + parser.add_argument("--include_patterns", type=str, nargs="*", default=None, help="LoRA module include patterns") + parser.add_argument("--exclude_patterns", type=str, nargs="*", default=None, help="LoRA module exclude patterns") + parser.add_argument( + "--save_merged_model", + type=str, + default=None, + help="Save merged model to path. If specified, no inference will be performed.", + ) + + # inference + parser.add_argument("--prompt", type=str, required=True, help="prompt for generation (describe the continuation for extension)") + parser.add_argument( + "--negative_prompt", + type=str, + default=None, + help="negative prompt for generation, use default negative prompt if not specified", + ) + parser.add_argument("--video_size", type=int, nargs=2, default=[256, 256], help="video size, height and width") + parser.add_argument("--video_length", type=int, default=None, help="Total video length (input+generated) for diffusion processing. Default depends on task/mode.") + parser.add_argument("--fps", type=int, default=16, help="video fps, Default is 16") + parser.add_argument("--infer_steps", type=int, default=None, help="number of inference steps") + parser.add_argument("--save_path", type=str, required=True, help="path to save generated video") + parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.") + parser.add_argument( + "--cpu_noise", action="store_true", help="Use CPU to generate noise (compatible with ComfyUI). Default is False." + ) + parser.add_argument( + "--guidance_scale", + type=float, + default=5.0, + help="Guidance scale for classifier free guidance. Default is 5.0.", + ) + + # Modes (mutually exclusive) + parser.add_argument("--video_path", type=str, default=None, help="path to video for video2video inference (standard Wan V2V)") + parser.add_argument("--image_path", type=str, default=None, help="path to image for image2video inference") + parser.add_argument("--extend_video", type=str, default=None, help="path to video for extending it using initial frames") + + # Mode specific args + parser.add_argument("--strength", type=float, default=0.75, help="Strength for video2video inference (0.0-1.0)") + parser.add_argument("--end_image_path", type=str, default=None, help="path to end image for image2video or extension inference") + parser.add_argument("--num_input_frames", type=int, default=4, help="Number of frames from start of --extend_video to use as input (min 1)") + parser.add_argument("--extend_length", type=int, default=None, help="Number of frames to generate *after* the input frames for --extend_video. 
Default makes total length match task default (e.g., 81).") + + + # Fun-Control argument (distinct from V2V/I2V/Extend) + parser.add_argument( + "--control_strength", + type=float, + default=1.0, + help="Strength of control video influence for Fun-Control (1.0 = normal)", + ) + parser.add_argument( + "--control_path", + type=str, + default=None, + help="path to control video for inference with controlnet (Fun-Control model only). video file or directory with images", + ) + parser.add_argument("--trim_tail_frames", type=int, default=0, help="trim tail N frames from the video before saving") + parser.add_argument( + "--cfg_skip_mode", + type=str, + default="none", + choices=["early", "late", "middle", "early_late", "alternate", "none"], + help="CFG skip mode. each mode skips different parts of the CFG. " + " early: initial steps, late: later steps, middle: middle steps, early_late: both early and late, alternate: alternate, none: no skip (default)", + ) + parser.add_argument( + "--cfg_apply_ratio", + type=float, + default=None, + help="The ratio of steps to apply CFG (0.0 to 1.0). Default is None (apply all steps).", + ) + parser.add_argument( + "--slg_layers", type=str, default=None, help="Skip block (layer) indices for SLG (Skip Layer Guidance), comma separated" + ) + parser.add_argument( + "--slg_scale", + type=float, + default=3.0, + help="scale for SLG classifier free guidance. Default is 3.0. Ignored if slg_mode is None or uncond", + ) + parser.add_argument("--slg_start", type=float, default=0.0, help="start ratio for inference steps for SLG. Default is 0.0.") + parser.add_argument("--slg_end", type=float, default=0.3, help="end ratio for inference steps for SLG. Default is 0.3.") + parser.add_argument( + "--slg_mode", + type=str, + default=None, + choices=["original", "uncond"], + help="SLG mode. original: same as SD3, uncond: replace uncond pred with SLG pred", + ) + + # Flow Matching + parser.add_argument( + "--flow_shift", + type=float, + default=None, + help="Shift factor for flow matching schedulers. Default depends on task.", + ) + + parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model") + parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT, only for fp8") + parser.add_argument("--fp8_fast", action="store_true", help="Enable fast FP8 arithmetic (RTX 4XXX+), only for fp8_scaled") + parser.add_argument("--fp8_t5", action="store_true", help="use fp8 for Text Encoder model") + parser.add_argument( + "--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU" + ) + parser.add_argument( + "--attn_mode", + type=str, + default="torch", + choices=["flash", "flash2", "flash3", "torch", "sageattn", "xformers", "sdpa"], + help="attention mode", + ) + parser.add_argument("--blocks_to_swap", type=int, default=0, help="number of blocks to swap in the model") + parser.add_argument( + "--output_type", type=str, default="video", choices=["video", "images", "latent", "both"], help="output type" + ) + parser.add_argument("--no_metadata", action="store_true", help="do not save metadata") + parser.add_argument("--latent_path", type=str, nargs="*", default=None, help="path to latent for decode. 
no inference") + parser.add_argument("--lycoris", action="store_true", help="use lycoris for inference") + parser.add_argument("--compile", action="store_true", help="Enable torch.compile") + parser.add_argument( + "--compile_args", + nargs=4, + metavar=("BACKEND", "MODE", "DYNAMIC", "FULLGRAPH"), + default=["inductor", "max-autotune-no-cudagraphs", "False", "False"], + help="Torch.compile settings", + ) + + args = parser.parse_args() + + assert (args.latent_path is None or len(args.latent_path) == 0) or ( + args.output_type == "images" or args.output_type == "video" + ), "latent_path is only supported for images or video output" + + # --- Mode Exclusivity Checks --- + modes = [args.video_path, args.image_path, args.extend_video, args.control_path] + num_modes_set = sum(1 for mode in modes if mode is not None) + + if num_modes_set > 1: + active_modes = [] + if args.video_path: active_modes.append("--video_path (V2V)") + if args.image_path: active_modes.append("--image_path (I2V)") + if args.extend_video: active_modes.append("--extend_video (Extend)") + if args.control_path: active_modes.append("--control_path (Fun-Control)") + # Allow Fun-Control + another mode conceptually, but the script logic needs adjustment + if not (num_modes_set == 2 and args.control_path is not None): + raise ValueError(f"Only one operation mode can be specified. Found: {', '.join(active_modes)}") + # Special case: Fun-Control can technically be combined, but let's check task compatibility + if args.control_path is not None and not WAN_CONFIGS[args.task].is_fun_control: + raise ValueError("--control_path is provided, but the selected task does not support Fun-Control.") + + # --- Specific Mode Validations --- + if args.extend_video is not None: + if args.num_input_frames < 1: + raise ValueError("--num_input_frames must be at least 1 for video extension.") + if "t2v" in args.task: + logger.warning("--extend_video provided, but task is t2v. Using I2V-like conditioning.") + # We'll set video_length later based on num_input_frames and extend_length + + if args.image_path is not None: + logger.warning("--image_path is provided. This is standard single-frame I2V.") + if "t2v" in args.task: + logger.warning("--image_path provided, but task is t2v. Using I2V conditioning.") + + if args.video_path is not None: + logger.info("Running in V2V mode.") + # V2V length is determined later if not specified + + if args.control_path is not None and not WAN_CONFIGS[args.task].is_fun_control: + raise ValueError("--control_path is provided, but the selected task does not support Fun-Control.") + + return args + + +def get_task_defaults(task: str, size: Optional[Tuple[int, int]] = None, is_extend_mode: bool = False) -> Tuple[int, float, int, bool]: + """Return default values for each task + + Args: + task: task name (t2v, t2i, i2v etc.) 
+        size: size of the video (width, height)
+        is_extend_mode: whether we are in video extension mode
+
+    Returns:
+        Tuple[int, float, int, bool]: (infer_steps, flow_shift, video_length, needs_clip)
+    """
+    width, height = size if size else (0, 0)
+
+    # I2V and Extend mode share similar defaults
+    is_i2v_like = "i2v" in task or is_extend_mode
+
+    if "t2i" in task:
+        return 50, 5.0, 1, False
+    elif is_i2v_like:
+        flow_shift = 3.0 if (width == 832 and height == 480) or (width == 480 and height == 832) else 5.0
+        return 40, flow_shift, 81, True # Default total length 81
+    else: # t2v or default
+        return 50, 5.0, 81, False # Default total length 81
+
+
+def setup_args(args: argparse.Namespace) -> argparse.Namespace:
+    """Validate and set default values for optional arguments
+
+    Args:
+        args: command line arguments
+
+    Returns:
+        argparse.Namespace: updated arguments
+    """
+    is_extend_mode = args.extend_video is not None
+
+    # Get default values for the task
+    default_infer_steps, default_flow_shift, default_video_length, _ = get_task_defaults(args.task, tuple(args.video_size), is_extend_mode)
+
+    # Apply default values to unset arguments
+    if args.infer_steps is None:
+        args.infer_steps = default_infer_steps
+    if args.flow_shift is None:
+        args.flow_shift = default_flow_shift
+
+    # --- Video Length Handling ---
+    if is_extend_mode:
+        if args.extend_length is None:
+            # Calculate extend_length to reach the default total length
+            args.extend_length = max(1, default_video_length - args.num_input_frames)
+            logger.info(f"Defaulting --extend_length to {args.extend_length} to reach total length {default_video_length}")
+        # Set the total video_length for processing
+        args.video_length = args.num_input_frames + args.extend_length
+        if args.video_length <= args.num_input_frames:
+            raise ValueError(f"Total video length ({args.video_length}) must be greater than input frames ({args.num_input_frames}). Increase --extend_length.")
+    elif args.video_length is None and args.video_path is None: # T2V, I2V (not extend)
+        args.video_length = default_video_length
+    elif args.video_length is None and args.video_path is not None: # V2V auto-detect
+        pass # Delay setting default if V2V and length not specified
+    elif args.video_length is not None: # User specified length
+        pass
+
+    # Force video_length to 1 for t2i tasks
+    if "t2i" in args.task:
+        assert args.video_length == 1, f"video_length should be 1 for task {args.task}"
+
+    # parse slg_layers
+    if args.slg_layers is not None:
+        args.slg_layers = list(map(int, args.slg_layers.split(",")))
+
+    return args
+
+
+def check_inputs(args: argparse.Namespace) -> Tuple[int, int, Optional[int]]:
+    """Validate video size and potentially length (if not V2V auto-detect)
+
+    Args:
+        args: command line arguments
+
+    Returns:
+        Tuple[int, int, Optional[int]]: (height, width, video_length)
+    """
+    height = args.video_size[0]
+    width = args.video_size[1]
+    size = f"{width}*{height}"
+
+    is_extend_mode = args.extend_video is not None
+    is_v2v_mode = args.video_path is not None
+
+    # Check supported sizes unless it's V2V/Extend (input video dictates size) or FunControl
+    if not is_v2v_mode and not is_extend_mode and not WAN_CONFIGS[args.task].is_fun_control:
+        if size not in SUPPORTED_SIZES[args.task]:
+            logger.warning(f"Size {size} is not supported for task {args.task}. 
Supported sizes are {SUPPORTED_SIZES[args.task]}.") + + video_length = args.video_length # Might be None if V2V auto-detect + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + return height, width, video_length + + +def calculate_dimensions(video_size: Tuple[int, int], video_length: int, config) -> Tuple[Tuple[int, int, int, int], int]: + """calculate dimensions for the generation + + Args: + video_size: video frame size (height, width) + video_length: number of frames in the video being processed + config: model configuration + + Returns: + Tuple[Tuple[int, int, int, int], int]: + ((channels, frames, height, width), seq_len) + """ + height, width = video_size + frames = video_length + + # calculate latent space dimensions + lat_f = (frames - 1) // config.vae_stride[0] + 1 + lat_h = height // config.vae_stride[1] + lat_w = width // config.vae_stride[2] + + # calculate sequence length + seq_len = math.ceil((lat_h * lat_w) / (config.patch_size[1] * config.patch_size[2]) * lat_f) + + return ((16, lat_f, lat_h, lat_w), seq_len) + + +# Modified function (replace the original) +def load_vae(args: argparse.Namespace, config, device: torch.device, dtype: torch.dtype) -> WanVAE: + """load VAE model with robust path handling + + Args: + args: command line arguments + config: model configuration + device: device to use + dtype: data type for the model + + Returns: + WanVAE: loaded VAE model + """ + vae_override_path = args.vae + vae_filename = config.vae_checkpoint # Get expected filename, e.g., "Wan2.1_VAE.pth" + # Assume models are in 'wan' dir relative to script if not otherwise specified + vae_base_dir = "wan" + + final_vae_path = None + + # 1. Check if args.vae is a valid *existing file path* + if vae_override_path and isinstance(vae_override_path, str) and \ + (vae_override_path.endswith(".pth") or vae_override_path.endswith(".safetensors")) and \ + os.path.isfile(vae_override_path): + final_vae_path = vae_override_path + logger.info(f"Using VAE override path from --vae: {final_vae_path}") + + # 2. If override is invalid or not provided, construct default path + if final_vae_path is None: + constructed_path = os.path.join(vae_base_dir, vae_filename) + if os.path.isfile(constructed_path): + final_vae_path = constructed_path + logger.info(f"Constructed default VAE path: {final_vae_path}") + if vae_override_path: + logger.warning(f"Ignoring potentially invalid --vae argument: {vae_override_path}") + else: + # 3. Fallback using ckpt_dir if provided and default construction failed + if args.ckpt_dir: + fallback_path = os.path.join(args.ckpt_dir, vae_filename) + if os.path.isfile(fallback_path): + final_vae_path = fallback_path + logger.info(f"Using VAE path from --ckpt_dir fallback: {final_vae_path}") + else: + # If all attempts fail, raise error + raise FileNotFoundError(f"Cannot find VAE. Checked override '{vae_override_path}', constructed '{constructed_path}', and fallback '{fallback_path}'") + else: + raise FileNotFoundError(f"Cannot find VAE. Checked override '{vae_override_path}' and constructed '{constructed_path}'. 
No --ckpt_dir provided for fallback.") + + # At this point, final_vae_path should be valid + logger.info(f"Loading VAE model from final path: {final_vae_path}") + cache_device = torch.device("cpu") if args.vae_cache_cpu else None + vae = WanVAE(vae_path=final_vae_path, device=device, dtype=dtype, cache_device=cache_device) + return vae + + +def load_text_encoder(args: argparse.Namespace, config, device: torch.device) -> T5EncoderModel: + """load text encoder (T5) model + + Args: + args: command line arguments + config: model configuration + device: device to use + + Returns: + T5EncoderModel: loaded text encoder model + """ + checkpoint_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.t5_checkpoint) + tokenizer_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.t5_tokenizer) + + text_encoder = T5EncoderModel( + text_len=config.text_len, + dtype=config.t5_dtype, + device=device, + checkpoint_path=checkpoint_path, + tokenizer_path=tokenizer_path, + weight_path=args.t5, + fp8=args.fp8_t5, + ) + + return text_encoder + + +def load_clip_model(args: argparse.Namespace, config, device: torch.device) -> CLIPModel: + """load CLIP model (for I2V / Extend only) + + Args: + args: command line arguments + config: model configuration + device: device to use + + Returns: + CLIPModel: loaded CLIP model + """ + checkpoint_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.clip_checkpoint) + tokenizer_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.clip_tokenizer) + + clip = CLIPModel( + dtype=config.clip_dtype, + device=device, + checkpoint_path=checkpoint_path, + tokenizer_path=tokenizer_path, + weight_path=args.clip, + ) + + return clip + + +def load_dit_model( + args: argparse.Namespace, + config, + device: torch.device, + dit_dtype: torch.dtype, + dit_weight_dtype: Optional[torch.dtype] = None, + is_i2v_like: bool = False, # Combined flag for I2V and Extend modes +) -> WanModel: + """load DiT model + + Args: + args: command line arguments + config: model configuration + device: device to use + dit_dtype: data type for the model + dit_weight_dtype: data type for the model weights. 
None for as-is + is_i2v_like: I2V or Extend mode (might affect some model config details) + + Returns: + WanModel: loaded DiT model + """ + loading_device = "cpu" + if args.blocks_to_swap == 0 and args.lora_weight is None and not args.fp8_scaled: + loading_device = device + + loading_weight_dtype = dit_weight_dtype + if args.fp8_scaled or args.lora_weight is not None: + loading_weight_dtype = dit_dtype # load as-is + + # do not fp8 optimize because we will merge LoRA weights + # Pass the is_i2v_like flag if the underlying loading function uses it + model = load_wan_model(config, device, args.dit, args.attn_mode, False, loading_device, loading_weight_dtype, is_i2v_like) + + return model + + +def merge_lora_weights(model: WanModel, args: argparse.Namespace, device: torch.device) -> None: + """merge LoRA weights to the model + + Args: + model: DiT model + args: command line arguments + device: device to use + """ + if args.lora_weight is None or len(args.lora_weight) == 0: + return + + for i, lora_weight in enumerate(args.lora_weight): + if args.lora_multiplier is not None and len(args.lora_multiplier) > i: + lora_multiplier = args.lora_multiplier[i] + else: + lora_multiplier = 1.0 + + logger.info(f"Loading LoRA weights from {lora_weight} with multiplier {lora_multiplier}") + weights_sd = load_file(lora_weight) + + # apply include/exclude patterns + original_key_count = len(weights_sd.keys()) + if args.include_patterns is not None and len(args.include_patterns) > i: + include_pattern = args.include_patterns[i] + regex_include = re.compile(include_pattern) + weights_sd = {k: v for k, v in weights_sd.items() if regex_include.search(k)} + logger.info(f"Filtered keys with include pattern {include_pattern}: {original_key_count} -> {len(weights_sd.keys())}") + if args.exclude_patterns is not None and len(args.exclude_patterns) > i: + original_key_count_ex = len(weights_sd.keys()) + exclude_pattern = args.exclude_patterns[i] + regex_exclude = re.compile(exclude_pattern) + weights_sd = {k: v for k, v in weights_sd.items() if not regex_exclude.search(k)} + logger.info( + f"Filtered keys with exclude pattern {exclude_pattern}: {original_key_count_ex} -> {len(weights_sd.keys())}" + ) + if len(weights_sd) != original_key_count: + remaining_keys = list(set([k.split(".", 1)[0] for k in weights_sd.keys()])) + remaining_keys.sort() + logger.info(f"Remaining LoRA modules after filtering: {remaining_keys}") + if len(weights_sd) == 0: + logger.warning(f"No keys left after filtering.") + + if args.lycoris: + lycoris_net, _ = create_network_from_weights( + multiplier=lora_multiplier, + file=None, + weights_sd=weights_sd, + unet=model, + text_encoder=None, + vae=None, + for_inference=True, + ) + lycoris_net.merge_to(None, model, weights_sd, dtype=None, device=device) + else: + network = lora_wan.create_arch_network_from_weights(lora_multiplier, weights_sd, unet=model, for_inference=True) + network.merge_to(None, model, weights_sd, device=device, non_blocking=True) + + synchronize_device(device) + logger.info("LoRA weights loaded") + + # save model here before casting to dit_weight_dtype + if args.save_merged_model: + logger.info(f"Saving merged model to {args.save_merged_model}") + mem_eff_save_file(model.state_dict(), args.save_merged_model) # save_file needs a lot of memory + logger.info("Merged model saved") + + +def optimize_model( + model: WanModel, args: argparse.Namespace, device: torch.device, dit_dtype: torch.dtype, dit_weight_dtype: torch.dtype +) -> None: + """optimize the model (FP8 conversion, device 
move etc.) + + Args: + model: dit model + args: command line arguments + device: device to use + dit_dtype: dtype for the model + dit_weight_dtype: dtype for the model weights + """ + if args.fp8_scaled: + # load state dict as-is and optimize to fp8 + state_dict = model.state_dict() + + # if no blocks to swap, we can move the weights to GPU after optimization on GPU (omit redundant CPU->GPU copy) + move_to_device = args.blocks_to_swap == 0 # if blocks_to_swap > 0, we will keep the model on CPU + state_dict = model.fp8_optimization(state_dict, device, move_to_device, use_scaled_mm=args.fp8_fast) + + info = model.load_state_dict(state_dict, strict=True, assign=True) + logger.info(f"Loaded FP8 optimized weights: {info}") + + if args.blocks_to_swap == 0: + model.to(device) # make sure all parameters are on the right device (e.g. RoPE etc.) + else: + # simple cast to dit_dtype + target_dtype = None # load as-is (dit_weight_dtype == dtype of the weights in state_dict) + target_device = None + + if dit_weight_dtype is not None: # in case of args.fp8 and not args.fp8_scaled + logger.info(f"Convert model to {dit_weight_dtype}") + target_dtype = dit_weight_dtype + + if args.blocks_to_swap == 0: + logger.info(f"Move model to device: {device}") + target_device = device + + model.to(target_device, target_dtype) # move and cast at the same time. this reduces redundant copy operations + + if args.compile: + compile_backend, compile_mode, compile_dynamic, compile_fullgraph = args.compile_args + logger.info( + f"Torch Compiling[Backend: {compile_backend}; Mode: {compile_mode}; Dynamic: {compile_dynamic}; Fullgraph: {compile_fullgraph}]" + ) + torch._dynamo.config.cache_size_limit = 32 + for i in range(len(model.blocks)): + model.blocks[i] = torch.compile( + model.blocks[i], + backend=compile_backend, + mode=compile_mode, + dynamic=compile_dynamic.lower() in "true", + fullgraph=compile_fullgraph.lower() in "true", + ) + + if args.blocks_to_swap > 0: + logger.info(f"Enable swap {args.blocks_to_swap} blocks to CPU from device: {device}") + model.enable_block_swap(args.blocks_to_swap, device, supports_backward=False) + model.move_to_device_except_swap_blocks(device) + model.prepare_block_swap_before_forward() + else: + # make sure the model is on the right device + model.to(device) + + model.eval().requires_grad_(False) + clean_memory_on_device(device) + + +def prepare_t2v_inputs( + args: argparse.Namespace, config, accelerator: Accelerator, device: torch.device, vae: Optional[WanVAE] = None +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]: + """Prepare inputs for T2V (including Fun-Control variation) + + Args: + args: command line arguments + config: model configuration + accelerator: Accelerator instance + device: device to use + vae: VAE model, required only for Fun-Control + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]: + (noise, context, context_null, (arg_c, arg_null)) + """ + # Prepare inputs for T2V + # calculate dimensions and sequence length + height, width = args.video_size + # T2V/FunControl length should be set by setup_args + frames = args.video_length + if frames is None: + raise ValueError("video_length must be determined before calling prepare_t2v_inputs") + + (_, lat_f, lat_h, lat_w), seq_len = calculate_dimensions(args.video_size, frames, config) + target_shape = (16, lat_f, lat_h, lat_w) # Latent channel dim is 16 + + # configure negative prompt + n_prompt = args.negative_prompt if args.negative_prompt else config.sample_neg_prompt 
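As an aside on the shapes computed just above: `calculate_dimensions` turns the pixel-space size and frame count into the latent `target_shape` and the DiT token count `seq_len`. A minimal, self-contained sketch of that arithmetic follows; the `VAE_STRIDE` and `PATCH_SIZE` values are illustrative assumptions (the real values come from `WAN_CONFIGS[args.task]`):

```python
import math

# Assumed example values for illustration only; the real numbers are read
# from the task config (config.vae_stride / config.patch_size).
VAE_STRIDE = (4, 8, 8)  # (temporal, height, width) downscaling of the VAE
PATCH_SIZE = (1, 2, 2)  # DiT patchification (temporal, height, width)


def latent_dims(height: int, width: int, frames: int):
    """Mirror of calculate_dimensions(): latent shape and token sequence length."""
    lat_f = (frames - 1) // VAE_STRIDE[0] + 1
    lat_h = height // VAE_STRIDE[1]
    lat_w = width // VAE_STRIDE[2]
    seq_len = math.ceil((lat_h * lat_w) / (PATCH_SIZE[1] * PATCH_SIZE[2]) * lat_f)
    return (16, lat_f, lat_h, lat_w), seq_len


if __name__ == "__main__":
    shape, seq_len = latent_dims(height=480, width=832, frames=81)
    print(shape, seq_len)  # (16, 21, 60, 104) 32760 with the assumed strides
```

With these assumed strides, an 832x480 request of 81 frames yields a `(16, 21, 60, 104)` latent and 32,760 tokens, which is the kind of `seq_len` handed to the model alongside the noise of shape `target_shape`.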
+ + # set seed + seed = args.seed # Seed should be set in generate() + if not args.cpu_noise: + seed_g = torch.Generator(device=device) + seed_g.manual_seed(seed) + else: + # ComfyUI compatible noise + seed_g = torch.manual_seed(seed) + + # load text encoder + text_encoder = load_text_encoder(args, config, device) + text_encoder.model.to(device) + + # encode prompt + with torch.no_grad(): + if args.fp8_t5: + with torch.amp.autocast(device_type=device.type, dtype=config.t5_dtype): + context = text_encoder([args.prompt], device) + context_null = text_encoder([n_prompt], device) + else: + context = text_encoder([args.prompt], device) + context_null = text_encoder([n_prompt], device) + + # free text encoder and clean memory + del text_encoder + clean_memory_on_device(device) + + # Fun-Control: encode control video to latent space + y = None + if config.is_fun_control and args.control_path: + if vae is None: + raise ValueError("VAE must be provided for Fun-Control input preparation.") + logger.info(f"Encoding control video for Fun-Control") + control_video = load_control_video(args.control_path, frames, height, width).to(device) + vae.to_device(device) + with accelerator.autocast(), torch.no_grad(): + y = vae.encode([control_video])[0] # Encode video + y = y * args.control_strength # Apply strength + vae.to_device("cpu" if args.vae_cache_cpu else "cpu") # Move VAE back + clean_memory_on_device(device) + logger.info(f"Fun-Control conditioning 'y' shape: {y.shape}") + + # generate noise + noise = torch.randn(target_shape, dtype=torch.float32, generator=seed_g, device=device if not args.cpu_noise else "cpu") + noise = noise.to(device) + + # prepare model input arguments + arg_c = {"context": context, "seq_len": seq_len} + arg_null = {"context": context_null, "seq_len": seq_len} + if y is not None: # Add 'y' only if Fun-Control generated it + arg_c["y"] = [y] + arg_null["y"] = [y] + + return noise, context, context_null, (arg_c, arg_null) + + +def load_video_frames(video_path: str, num_frames: int, target_reso: Tuple[int, int]) -> Tuple[List[np.ndarray], torch.Tensor]: + """Load the first N frames from a video, resize, return numpy list and normalized tensor. + + Args: + video_path (str): Path to the video file. + num_frames (int): Number of frames to load from the start. + target_reso (Tuple[int, int]): Target resolution (height, width). + + Returns: + Tuple[List[np.ndarray], torch.Tensor]: + - List of numpy arrays (frames) in HWC, RGB, uint8 format. + - Tensor of shape [C, F, H, W], float32, range [0, 1]. 
+ """ + logger.info(f"Loading first {num_frames} frames from {video_path}, target reso {target_reso}") + target_h, target_w = target_reso + + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + raise ValueError(f"Failed to open video file: {video_path}") + + # Get total frame count and check if enough frames exist + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + if total_frames < num_frames: + cap.release() + raise ValueError(f"Video has only {total_frames} frames, but {num_frames} were requested for input.") + + # Read frames + frames_np = [] + for i in range(num_frames): + ret, frame = cap.read() + if not ret: + logger.warning(f"Could only read {len(frames_np)} frames out of {num_frames} requested from {video_path}.") + break + + # Convert BGR to RGB + frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + + # Resize + current_h, current_w = frame_rgb.shape[:2] + interpolation = cv2.INTER_AREA if target_h * target_w < current_h * current_w else cv2.INTER_LANCZOS4 + frame_resized = cv2.resize(frame_rgb, (target_w, target_h), interpolation=interpolation) + + frames_np.append(frame_resized) + + cap.release() + + if len(frames_np) != num_frames: + raise RuntimeError(f"Failed to load the required {num_frames} frames.") + + # Convert list of numpy arrays to tensor [F, H, W, C] -> [C, F, H, W], range [0, 1] + frames_tensor = torch.from_numpy(np.stack(frames_np, axis=0)).permute(0, 3, 1, 2).float() / 255.0 + frames_tensor = frames_tensor.permute(1, 0, 2, 3) # [C, F, H, W] + + logger.info(f"Loaded {len(frames_np)} input frames. Tensor shape: {frames_tensor.shape}") + + # Return both the original numpy frames (for saving later) and the normalized tensor + return frames_np, frames_tensor + + +# Combined function for I2V and Extend modes +def prepare_i2v_or_extend_inputs( + args: argparse.Namespace, config, accelerator: Accelerator, device: torch.device, vae: WanVAE, + input_frames_tensor: Optional[torch.Tensor] = None # Required for Extend mode +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]: + """Prepare inputs for I2V (single image) or Extend (multiple frames).""" + if vae is None: + raise ValueError("VAE must be provided for I2V/Extend input preparation.") + + is_extend_mode = input_frames_tensor is not None + is_i2v_mode = args.image_path is not None + + # --- Get Dimensions and Frame Counts --- + height, width = args.video_size + frames = args.video_length # Total frames for diffusion process + if frames is None: + raise ValueError("video_length must be set before calling prepare_i2v_or_extend_inputs") + + num_input_frames = 0 + if is_extend_mode: + num_input_frames = args.num_input_frames + if num_input_frames >= frames: + raise ValueError(f"Number of input frames ({num_input_frames}) must be less than total video length ({frames})") + elif is_i2v_mode: + num_input_frames = 1 + + # --- Load Input Image(s) / Frames --- + img_tensor_for_clip = None # Representative tensor for CLIP + img_tensor_for_vae = None # Tensor containing all input frames/image for VAE + + if is_extend_mode: + # Input frames tensor already provided (normalized [0,1]) + img_tensor_for_vae = input_frames_tensor.to(device) + # Use first frame for CLIP + img_tensor_for_clip = img_tensor_for_vae[:, 0:1, :, :] # [C, 1, H, W] + logger.info(f"Preparing inputs for Extend mode with {num_input_frames} input frames.") + + elif is_i2v_mode: + # Load single image + img = Image.open(args.image_path).convert("RGB") + img_cv2 = np.array(img) + interpolation = cv2.INTER_AREA if height 
< img_cv2.shape[0] else cv2.INTER_CUBIC + img_resized_np = cv2.resize(img_cv2, (width, height), interpolation=interpolation) + # Normalized [0,1], shape [C, H, W] + img_tensor_single = TF.to_tensor(img_resized_np).to(device) + # Add frame dimension -> [C, 1, H, W] + img_tensor_for_vae = img_tensor_single.unsqueeze(1) + img_tensor_for_clip = img_tensor_for_vae + logger.info("Preparing inputs for standard I2V mode.") + + else: + raise ValueError("Neither extend_video nor image_path provided for I2V/Extend preparation.") + + # --- Optional End Frame --- + has_end_image = args.end_image_path is not None + end_img_tensor_vae = None # Normalized [-1, 1], shape [C, 1, H, W] + if has_end_image: + end_img = Image.open(args.end_image_path).convert("RGB") + end_img_cv2 = np.array(end_img) + interpolation_end = cv2.INTER_AREA if height < end_img_cv2.shape[0] else cv2.INTER_CUBIC + end_img_resized_np = cv2.resize(end_img_cv2, (width, height), interpolation=interpolation_end) + # Normalized [0,1], shape [C, H, W] -> [C, 1, H, W] + end_img_tensor_load = TF.to_tensor(end_img_resized_np).unsqueeze(1).to(device) + end_img_tensor_vae = (end_img_tensor_load * 2.0 - 1.0) # Scale to [-1, 1] for VAE + logger.info(f"Loaded end image: {args.end_image_path}") + + # --- Calculate Latent Dimensions --- + lat_f = (frames - 1) // config.vae_stride[0] + 1 # Total latent frames + lat_h = height // config.vae_stride[1] + lat_w = width // config.vae_stride[2] + # Latent frames corresponding to the input pixel frames + lat_input_f = (num_input_frames - 1) // config.vae_stride[0] + 1 + + max_seq_len = math.ceil((lat_f + (1 if has_end_image else 0)) * lat_h * lat_w / (config.patch_size[1] * config.patch_size[2])) + logger.info(f"Target latent shape: ({lat_f}, {lat_h}, {lat_w}), Input latent frames: {lat_input_f}, Seq len: {max_seq_len}") + + # --- Set Seed --- + seed = args.seed + seed_g = torch.Generator(device=device) if not args.cpu_noise else torch.manual_seed(seed) + if not args.cpu_noise: + seed_g.manual_seed(seed) + + # --- Generate Noise --- + # Noise for the *entire* processing duration (including input frame slots) + noise = torch.randn( + 16, lat_f + (1 if has_end_image else 0), lat_h, lat_w, + dtype=torch.float32, generator=seed_g, device=device if not args.cpu_noise else "cpu" + ).to(device) + + # --- Text Encoding --- + n_prompt = args.negative_prompt if args.negative_prompt else config.sample_neg_prompt + text_encoder = load_text_encoder(args, config, device) + text_encoder.model.to(device) + with torch.no_grad(): + if args.fp8_t5: + with torch.amp.autocast(device_type=device.type, dtype=config.t5_dtype): + context = text_encoder([args.prompt], device) + context_null = text_encoder([n_prompt], device) + else: + context = text_encoder([args.prompt], device) + context_null = text_encoder([n_prompt], device) + del text_encoder + clean_memory_on_device(device) + + # --- CLIP Encoding --- + clip = load_clip_model(args, config, device) + clip.model.to(device) + with torch.amp.autocast(device_type=device.type, dtype=torch.float16), torch.no_grad(): + # Input needs to be [-1, 1], shape [C, 1, H, W] (or maybe [C, F, H, W] if model supports?) 
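Before the CLIP encoding step below, the `[0, 1]` pixel tensor has to be rescaled to `[-1, 1]`. A small sketch of that rescale, done out of place: in Extend mode `img_tensor_for_clip` is a slice (a view) of `img_tensor_for_vae`, and in-place ops on a view also modify the base tensor that is VAE-encoded later, so cloning first keeps the two paths independent (the helper name is illustrative, not part of the script):

```python
import torch


def to_clip_range(frames_01: torch.Tensor) -> torch.Tensor:
    """Rescale [0, 1] pixels to [-1, 1] without touching the input in place."""
    # clone() so the tensor shared with the VAE path stays in [0, 1]
    return frames_01.clone().sub_(0.5).div_(0.5)  # equivalent to 2 * x - 1


if __name__ == "__main__":
    x = torch.rand(3, 1, 480, 832)  # [C, 1, H, W] in [0, 1]
    y = to_clip_range(x)
    assert torch.allclose(y, 2 * x - 1)
    assert x.min() >= 0  # original tensor is unchanged
```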
+ # Assuming visual encoder takes one frame: use the representative clip tensor + clip_input = img_tensor_for_clip.sub_(0.5).div_(0.5) # Scale [0,1] -> [-1,1] + clip_context = clip.visual([clip_input]) # Pass as list [tensor] + del clip + clean_memory_on_device(device) + + # --- VAE Encoding for Conditioning Tensor 'y' --- + vae.to_device(device) + y_latent_part = torch.zeros(config.latent_channels, lat_f + (1 if has_end_image else 0), lat_h, lat_w, device=device, dtype=vae.dtype) + + with accelerator.autocast(), torch.no_grad(): + # Encode the input frames/image (scale [0,1] -> [-1,1]) + input_frames_vae = (img_tensor_for_vae * 2.0 - 1.0).to(dtype=vae.dtype) # [-1, 1] + # Pad with zeros if needed to match VAE chunking? Assume encode handles variable length for now. + encoded_input_latents = vae.encode([input_frames_vae])[0] # [C', F_in', H', W'] + actual_encoded_input_f = encoded_input_latents.shape[1] + if actual_encoded_input_f > lat_input_f: + logger.warning(f"VAE encoded {actual_encoded_input_f} frames, expected {lat_input_f}. Truncating.") + encoded_input_latents = encoded_input_latents[:, :lat_input_f, :, :] + elif actual_encoded_input_f < lat_input_f: + logger.warning(f"VAE encoded {actual_encoded_input_f} frames, expected {lat_input_f}. Padding needed for mask.") + # This case shouldn't happen if lat_input_f calculation is correct, but handle defensively + + # Place encoded input latents into the full y tensor + y_latent_part[:, :actual_encoded_input_f, :, :] = encoded_input_latents + + # Encode end image if present + if has_end_image and end_img_tensor_vae is not None: + encoded_end_latent = vae.encode([end_img_tensor_vae.to(dtype=vae.dtype)])[0] # [C', 1, H', W'] + y_latent_part[:, -1:, :, :] = encoded_end_latent # Place at the end + + # --- Create Mask --- + msk = torch.zeros(4, lat_f + (1 if has_end_image else 0), lat_h, lat_w, device=device, dtype=vae.dtype) + msk[:, :lat_input_f, :, :] = 1 # Mask the input frames + if has_end_image: + msk[:, -1:, :, :] = 1 # Mask the end frame + + # --- Combine Mask and Latent Part for 'y' --- + y = torch.cat([msk, y_latent_part], dim=0) # Shape [4+C', F_total', H', W'] + logger.info(f"Constructed conditioning 'y' tensor shape: {y.shape}") + + # --- Fun-Control Integration (Optional, might need adjustment for Extend mode) --- + if config.is_fun_control and args.control_path: + logger.warning("Fun-Control with Extend mode is experimental. Control signal might conflict with input frames.") + control_video = load_control_video(args.control_path, frames + (1 if has_end_image else 0), height, width).to(device) + with accelerator.autocast(), torch.no_grad(): + control_latent = vae.encode([control_video])[0] # Encode control video + control_latent = control_latent * args.control_strength # Apply strength + + # How to combine? Replace y? Add? For now, let's assume control replaces the VAE part of y + y = torch.cat([msk, control_latent], dim=0) # Overwrite latent part with control + logger.info(f"Replaced latent part of 'y' with Fun-Control latent. New 'y' shape: {y.shape}") + + + vae.to_device("cpu" if args.vae_cache_cpu else "cpu") # Move VAE back + clean_memory_on_device(device) + + # --- Prepare Model Input Dictionaries --- + arg_c = { + "context": [context[0]], # Needs list format? 
Check model forward + "clip_fea": clip_context, + "seq_len": max_seq_len, + "y": [y], # Pass conditioning tensor y + } + + arg_null = { + "context": context_null, + "clip_fea": clip_context, + "seq_len": max_seq_len, + "y": [y], # Pass conditioning tensor y + } + + return noise, context, context_null, y, (arg_c, arg_null) + + +# --- V2V Helper Functions --- + +def load_video(video_path, start_frame=0, num_frames=None, bucket_reso=(256, 256)): + """Load video frames and resize them to the target resolution for V2V. + + Args: + video_path (str): Path to the video file + start_frame (int): First frame to load (0-indexed) + num_frames (int, optional): Number of frames to load. If None, load all frames from start_frame. + bucket_reso (tuple): Target resolution (height, width) + + Returns: + list: List of numpy arrays containing video frames in RGB format, resized. + int: Actual number of frames loaded. + """ + logger.info(f"Loading video for V2V from {video_path}, target reso {bucket_reso}, frames {start_frame}-{start_frame+num_frames if num_frames else 'end'}") + + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + raise ValueError(f"Failed to open video file: {video_path}") + + # Get total frame count and FPS + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + fps = cap.get(cv2.CAP_PROP_FPS) + logger.info(f"Input video has {total_frames} frames, {fps} FPS") + + # Calculate how many frames to load + if num_frames is None: + frames_to_load = total_frames - start_frame + else: + # Make sure we don't try to load more frames than exist + frames_to_load = min(num_frames, total_frames - start_frame) + + if frames_to_load <= 0: + cap.release() + logger.warning(f"No frames to load (start_frame={start_frame}, num_frames={num_frames}, total_frames={total_frames})") + return [], 0 + + # Skip to start frame + if start_frame > 0: + cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame) + + # Read frames + frames = [] + target_h, target_w = bucket_reso + for i in range(frames_to_load): + ret, frame = cap.read() + if not ret: + logger.warning(f"Could only read {len(frames)} frames out of {frames_to_load} requested.") + break + + # Convert from BGR to RGB + frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + + # Resize the frame + current_h, current_w = frame_rgb.shape[:2] + interpolation = cv2.INTER_AREA if target_h * target_w < current_h * current_w else cv2.INTER_LANCZOS4 + frame_resized = cv2.resize(frame_rgb, (target_w, target_h), interpolation=interpolation) + + frames.append(frame_resized) + + cap.release() + actual_frames_loaded = len(frames) + logger.info(f"Successfully loaded and resized {actual_frames_loaded} frames for V2V.") + + return frames, actual_frames_loaded + + +def encode_video_to_latents(video_tensor: torch.Tensor, vae: WanVAE, device: torch.device, vae_dtype: torch.dtype, args: argparse.Namespace) -> torch.Tensor: + """Encode video tensor to latent space using VAE for V2V. + + Args: + video_tensor (torch.Tensor): Video tensor with shape [B, C, F, H, W], values in [-1, 1]. + vae (WanVAE): VAE model instance. + device (torch.device): Device to perform encoding on. + vae_dtype (torch.dtype): Target dtype for the output latents. + args (argparse.Namespace): Command line arguments (needed for vae_cache_cpu). + + Returns: + torch.Tensor: Encoded latents with shape [B, C', F', H', W']. 
+ """ + if vae is None: + raise ValueError("VAE must be provided for video encoding.") + + logger.info(f"Encoding video tensor to latents: input shape {video_tensor.shape}") + + # Ensure VAE is on the correct device + vae.to_device(device) + + # Prepare video tensor: move to device, ensure correct dtype + video_tensor = video_tensor.to(device=device, dtype=vae.dtype) # Use VAE's dtype + + # WanVAE expects input as a list of [C, F, H, W] tensors (no batch dim) + latents_list = [] + batch_size = video_tensor.shape[0] + for i in range(batch_size): + video_single = video_tensor[i] # Shape [C, F, H, W] + with torch.no_grad(), torch.autocast(device_type=device.type, dtype=vae.dtype): + encoded_latent = vae.encode([video_single])[0] # Returns tensor [C', F', H', W'] + latents_list.append(encoded_latent) + + # Stack results back into a batch + latents = torch.stack(latents_list, dim=0) # Shape [B, C', F', H', W'] + + # Move VAE back to CPU (or cache device) + vae_target_device = torch.device("cpu") if not args.vae_cache_cpu else torch.device("cpu") + if args.vae_cache_cpu: logger.info("Moving VAE to CPU for caching.") + else: logger.info("Moving VAE to CPU after encoding.") + vae.to_device(vae_target_device) + clean_memory_on_device(device) + + # Convert latents to the desired final dtype (e.g., bfloat16 for DiT) + latents = latents.to(dtype=vae_dtype) # Use the target vae_dtype passed to function + logger.info(f"Encoded video latents shape: {latents.shape}, dtype: {latents.dtype}") + + return latents + + +def prepare_v2v_inputs(args: argparse.Namespace, config, accelerator: Accelerator, device: torch.device, video_latents: torch.Tensor): + """Prepare inputs for Video2Video inference based on encoded video latents. + + Args: + args (argparse.Namespace): Command line arguments. + config: Model configuration. + accelerator: Accelerator instance. + device (torch.device): Device to use. + video_latents (torch.Tensor): Encoded latent representation of input video [B, C', F', H', W']. + + Returns: + Tuple containing noise, context, context_null, (arg_c, arg_null). + """ + # Get dimensions directly from the video latents + if len(video_latents.shape) != 5: + raise ValueError(f"Expected video_latents to have 5 dimensions [B, C, F, H, W], but got shape {video_latents.shape}") + + batch_size, latent_channels, lat_f, lat_h, lat_w = video_latents.shape + if batch_size != 1: + logger.warning(f"V2V input preparation currently assumes batch size 1, but got {batch_size}. 
Using first item.") + video_latents = video_latents[0:1] # Keep batch dim + + # Calculate target shape and sequence length based on actual latent dimensions + target_shape = video_latents.shape[1:] # [C', F', H', W'] + (_, _, _), seq_len = calculate_dimensions((args.video_size[0], args.video_size[1]), args.video_length, config) # Use original args to get seq_len + # (_, _, _), seq_len = calculate_dimensions((lat_h * config.vae_stride[1], lat_w * config.vae_stride[2]), (lat_f-1)*config.vae_stride[0]+1, config) # Recalculate seq_len from latent dims + + logger.info(f"V2V derived latent shape: {target_shape}, seq_len: {seq_len}") + + # Configure negative prompt + n_prompt = args.negative_prompt if args.negative_prompt else config.sample_neg_prompt + + # Set seed (already set in generate(), just need generator) + seed = args.seed + if not args.cpu_noise: + seed_g = torch.Generator(device=device) + seed_g.manual_seed(seed) + else: + seed_g = torch.manual_seed(seed) + + # Load text encoder + text_encoder = load_text_encoder(args, config, device) + text_encoder.model.to(device) + + # Encode prompt + with torch.no_grad(): + if args.fp8_t5: + with torch.amp.autocast(device_type=device.type, dtype=config.t5_dtype): + context = text_encoder([args.prompt], device) + context_null = text_encoder([n_prompt], device) + else: + context = text_encoder([args.prompt], device) + context_null = text_encoder([n_prompt], device) + + # Free text encoder and clean memory + del text_encoder + clean_memory_on_device(device) + + # Generate noise with the same shape as video_latents (including batch dimension) + noise = torch.randn( + video_latents.shape, # [B, C', F', H', W'] + dtype=torch.float32, + device=device if not args.cpu_noise else "cpu", + generator=seed_g + ) + noise = noise.to(device) # Ensure noise is on the target device + + # Prepare model input arguments (context needs to match batch size of latents) + arg_c = {"context": context, "seq_len": seq_len} + arg_null = {"context": context_null, "seq_len": seq_len} + + # V2V does not use 'y' or 'clip_fea' in the standard Wan model case + + return noise, context, context_null, (arg_c, arg_null) + + +# --- End V2V Helper Functions --- + +def load_control_video(control_path: str, frames: int, height: int, width: int) -> torch.Tensor: + """load control video to pixel space for Fun-Control model + + Args: + control_path: path to control video + frames: number of frames in the video + height: height of the video + width: width of the video + + Returns: + torch.Tensor: control video tensor, CFHW, range [-1, 1] + """ + logger.info(f"Load control video for Fun-Control from {control_path}") + + # Use the original helper from hv_generate_video for consistency + if os.path.isfile(control_path): + # Use hv_load_video which returns list of numpy arrays (HWC, 0-255) + # NOTE: hv_load_video takes (W, H) for bucket_reso! + video_frames_np = hv_load_video(control_path, 0, frames, bucket_reso=(width, height)) + elif os.path.isdir(control_path): + # Use hv_load_images which returns list of numpy arrays (HWC, 0-255) + # NOTE: hv_load_images takes (W, H) for bucket_reso! + video_frames_np = hv_load_images(control_path, frames, bucket_reso=(width, height)) + else: + raise FileNotFoundError(f"Control path not found: {control_path}") + + if not video_frames_np: + raise ValueError(f"No frames loaded from control path: {control_path}") + if len(video_frames_np) < frames: + logger.warning(f"Control video has {len(video_frames_np)} frames, less than requested {frames}. 
Using available frames and repeating last.") + # Repeat last frame to match length + last_frame = video_frames_np[-1] + video_frames_np.extend([last_frame] * (frames - len(video_frames_np))) + + # Stack and convert to tensor: F, H, W, C (0-255) -> F, C, H, W (-1 to 1) + video_frames_np = np.stack(video_frames_np, axis=0) + video_tensor = torch.from_numpy(video_frames_np).permute(0, 3, 1, 2).float() / 127.5 - 1.0 # Normalize to [-1, 1] + + # Permute to C, F, H, W + video_tensor = video_tensor.permute(1, 0, 2, 3) + logger.info(f"Loaded Fun-Control video tensor shape: {video_tensor.shape}") + + return video_tensor + +def setup_scheduler(args: argparse.Namespace, config, device: torch.device) -> Tuple[Any, torch.Tensor]: + """setup scheduler for sampling + + Args: + args: command line arguments + config: model configuration + device: device to use + + Returns: + Tuple[Any, torch.Tensor]: (scheduler, timesteps) + """ + if args.sample_solver == "unipc": + scheduler = FlowUniPCMultistepScheduler(num_train_timesteps=config.num_train_timesteps, shift=1, use_dynamic_shifting=False) + scheduler.set_timesteps(args.infer_steps, device=device, shift=args.flow_shift) + timesteps = scheduler.timesteps + elif args.sample_solver == "dpm++": + scheduler = FlowDPMSolverMultistepScheduler( + num_train_timesteps=config.num_train_timesteps, shift=1, use_dynamic_shifting=False + ) + sampling_sigmas = get_sampling_sigmas(args.infer_steps, args.flow_shift) + timesteps, _ = retrieve_timesteps(scheduler, device=device, sigmas=sampling_sigmas) + elif args.sample_solver == "vanilla": + scheduler = FlowMatchDiscreteScheduler(num_train_timesteps=config.num_train_timesteps, shift=args.flow_shift) + scheduler.set_timesteps(args.infer_steps, device=device) + timesteps = scheduler.timesteps + + # FlowMatchDiscreteScheduler does not support generator argument in step method + org_step = scheduler.step + + def step_wrapper( + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor, + return_dict: bool = True, + generator=None, # Add generator argument here + ): + # Call original step, ignoring generator if it doesn't accept it + try: + # Try calling with generator if the underlying class was updated + return org_step(model_output, timestep, sample, return_dict=return_dict, generator=generator) + except TypeError: + # Fallback to calling without generator + # logger.warning("Scheduler step does not support generator argument, proceeding without it.") # Reduce noise + return org_step(model_output, timestep, sample, return_dict=return_dict) + + + scheduler.step = step_wrapper + else: + raise NotImplementedError(f"Unsupported solver: {args.sample_solver}") + + logger.info(f"Using scheduler: {args.sample_solver}, timesteps shape: {timesteps.shape}") + return scheduler, timesteps + + +def run_sampling( + model: WanModel, + noise: torch.Tensor, # This might be pure noise (T2V/I2V/Extend) or mixed noise+latent (V2V) + scheduler: Any, + timesteps: torch.Tensor, # Might be a subset for V2V + args: argparse.Namespace, + inputs: Tuple[dict, dict], # (arg_c, arg_null) + device: torch.device, + seed_g: torch.Generator, + accelerator: Accelerator, + use_cpu_offload: bool = True, +) -> torch.Tensor: + """run sampling loop (Denoising) + Args: + model: dit model + noise: initial latent state (pure noise or mixed noise/video latent) + scheduler: scheduler for sampling + timesteps: time steps for sampling (can be subset for V2V) + args: command line arguments + inputs: model input dictionaries (arg_c, arg_null) 
containing context etc. + device: device to use + seed_g: random generator + accelerator: Accelerator instance + use_cpu_offload: Whether to offload tensors to CPU during processing + Returns: + torch.Tensor: generated latent + """ + arg_c, arg_null = inputs + + # Ensure inputs (context, y, etc.) are correctly formatted (e.g., lists if model expects list input) + # Example: ensure context is list [tensor] if model expects list + if isinstance(arg_c.get("context"), torch.Tensor): + arg_c["context"] = [arg_c["context"]] + if isinstance(arg_null.get("context"), torch.Tensor): + arg_null["context"] = [arg_null["context"]] + # Similar checks/conversions for other keys like 'y' if needed based on WanModel.forward signature + + + latent = noise # Initialize latent state [B, C, F, H, W] + latent_storage_device = device if not use_cpu_offload else "cpu" + latent = latent.to(latent_storage_device) # Move initial state to storage device + + # cfg skip logic + apply_cfg_array = [] + num_timesteps = len(timesteps) + + if args.cfg_skip_mode != "none" and args.cfg_apply_ratio is not None: + # Calculate thresholds based on cfg_apply_ratio + apply_steps = int(num_timesteps * args.cfg_apply_ratio) + + if args.cfg_skip_mode == "early": + start_index = num_timesteps - apply_steps; end_index = num_timesteps + elif args.cfg_skip_mode == "late": + start_index = 0; end_index = apply_steps + elif args.cfg_skip_mode == "early_late": + start_index = (num_timesteps - apply_steps) // 2; end_index = start_index + apply_steps + elif args.cfg_skip_mode == "middle": + skip_steps = num_timesteps - apply_steps + middle_start = (num_timesteps - skip_steps) // 2; middle_end = middle_start + skip_steps + else: # Includes "alternate" - handled inside loop + start_index = 0; end_index = num_timesteps # Default range for alternate + + w = 0.0 # For alternate mode + for step_idx in range(num_timesteps): + apply = True # Default + if args.cfg_skip_mode == "alternate": + w += args.cfg_apply_ratio; apply = w >= 1.0 + if apply: w -= 1.0 + elif args.cfg_skip_mode == "middle": + apply = not (step_idx >= middle_start and step_idx < middle_end) + elif args.cfg_skip_mode != "none": # early, late, early_late + apply = step_idx >= start_index and step_idx < end_index + + apply_cfg_array.append(apply) + + pattern = ["A" if apply else "S" for apply in apply_cfg_array] + pattern = "".join(pattern) + logger.info(f"CFG skip mode: {args.cfg_skip_mode}, apply ratio: {args.cfg_apply_ratio}, steps: {num_timesteps}, pattern: {pattern}") + else: + # Apply CFG on all steps + apply_cfg_array = [True] * num_timesteps + + # SLG (Skip Layer Guidance) setup + apply_slg_global = args.slg_layers is not None and args.slg_mode is not None + slg_start_step = int(args.slg_start * num_timesteps) + slg_end_step = int(args.slg_end * num_timesteps) + + logger.info(f"Starting sampling loop for {num_timesteps} steps.") + for i, t in enumerate(tqdm(timesteps)): + # Prepare input for the model (move latent to compute device) + # Latent should be [B, C, F, H, W] + # Model expects latent input 'x' as list: [tensor] + latent_on_device = latent.to(device) + latent_model_input_list = [latent_on_device] # Wrap in list + timestep = torch.stack([t]).to(device) # Ensure timestep is a tensor on device + + with accelerator.autocast(), torch.no_grad(): + # 1. Predict conditional noise estimate + noise_pred_cond = model(x=latent_model_input_list, t=timestep, **arg_c)[0] + noise_pred_cond = noise_pred_cond.to(latent_storage_device) + + # 2. 
Predict unconditional noise estimate (potentially with SLG) + apply_cfg = apply_cfg_array[i] + if apply_cfg: + apply_slg_step = apply_slg_global and (i >= slg_start_step and i < slg_end_step) + slg_indices_for_call = args.slg_layers if apply_slg_step else None + uncond_input_args = arg_null + + if apply_slg_step and args.slg_mode == "original": + # Standard uncond prediction first + noise_pred_uncond = model(x=latent_model_input_list, t=timestep, **uncond_input_args)[0].to(latent_storage_device) + # SLG prediction (skipping layers in uncond) + skip_layer_out = model(x=latent_model_input_list, t=timestep, skip_block_indices=slg_indices_for_call, **uncond_input_args)[0].to(latent_storage_device) + # Combine: scaled = uncond + scale * (cond - uncond) + slg_scale * (cond - skip) + noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond) + noise_pred = noise_pred + args.slg_scale * (noise_pred_cond - skip_layer_out) + + elif apply_slg_step and args.slg_mode == "uncond": + # SLG prediction (skipping layers in uncond) replaces standard uncond + noise_pred_uncond = model(x=latent_model_input_list, t=timestep, skip_block_indices=slg_indices_for_call, **uncond_input_args)[0].to(latent_storage_device) + # Combine: scaled = slg_uncond + scale * (cond - slg_uncond) + noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond) + + else: + # Regular CFG (no SLG or SLG not active this step) + noise_pred_uncond = model(x=latent_model_input_list, t=timestep, **uncond_input_args)[0].to(latent_storage_device) + # Combine: scaled = uncond + scale * (cond - uncond) + noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond) + else: + # CFG is skipped for this step, use conditional prediction directly + noise_pred = noise_pred_cond + + # 3. Compute previous sample state with the scheduler + # Scheduler expects noise_pred [B, C, F, H, W] and latent [B, C, F, H, W] + scheduler_output = scheduler.step( + noise_pred.to(device), # Ensure noise_pred is on compute device + t, + latent_on_device, # Pass the tensor directly + return_dict=False, + generator=seed_g # Pass generator + ) + prev_latent = scheduler_output[0] # Get the new latent state [B, C, F, H, W] + + # 4. Update latent state (move back to storage device) + latent = prev_latent.to(latent_storage_device) + + # Return the final denoised latent (should be on storage device) + logger.info("Sampling loop finished.") + return latent + + +def generate(args: argparse.Namespace) -> Tuple[Optional[torch.Tensor], Optional[List[np.ndarray]]]: + """main function for generation pipeline (T2V, I2V, V2V, Extend) + + Args: + args: command line arguments + + Returns: + Tuple[Optional[torch.Tensor], Optional[List[np.ndarray]]]: + - generated latent tensor [B, C, F, H, W], or None if error/skipped. + - list of original input frames (numpy HWC RGB uint8) if in Extend mode, else None. 
+ """ + device = torch.device(args.device) + cfg = WAN_CONFIGS[args.task] + + # --- Determine Mode --- + is_extend_mode = args.extend_video is not None + is_i2v_mode = args.image_path is not None and not is_extend_mode + is_v2v_mode = args.video_path is not None + is_fun_control = args.control_path is not None and cfg.is_fun_control # Can overlap + is_t2v_mode = not is_extend_mode and not is_i2v_mode and not is_v2v_mode and not is_fun_control + + mode_str = ("Extend" if is_extend_mode else + "I2V" if is_i2v_mode else + "V2V" if is_v2v_mode else + "T2V" + ("+FunControl" if is_fun_control else "")) + if is_fun_control and not is_t2v_mode: # If funcontrol combined with other modes + mode_str += "+FunControl" + logger.info(f"Running in {mode_str} mode") + + # --- Data Types --- + dit_dtype = detect_wan_sd_dtype(args.dit) if args.dit is not None else torch.bfloat16 + if dit_dtype.itemsize == 1: + dit_dtype = torch.bfloat16 + if args.fp8_scaled: raise ValueError("Cannot use --fp8_scaled with pre-quantized FP8 weights.") + dit_weight_dtype = None + elif args.fp8_scaled: dit_weight_dtype = None + elif args.fp8: dit_weight_dtype = torch.float8_e4m3fn + else: dit_weight_dtype = dit_dtype + + vae_dtype = str_to_dtype(args.vae_dtype) if args.vae_dtype is not None else (torch.bfloat16 if dit_dtype == torch.bfloat16 else torch.float16) + logger.info( + f"Using device: {device}, DiT compute: {dit_dtype}, DiT weight: {dit_weight_dtype or 'Mixed (FP8 Scaled)' if args.fp8_scaled else dit_dtype}, VAE: {vae_dtype}, T5 FP8: {args.fp8_t5}" + ) + + # --- Accelerator --- + mixed_precision = "bf16" if dit_dtype == torch.bfloat16 else "fp16" + accelerator = accelerate.Accelerator(mixed_precision=mixed_precision) + + # --- Seed --- + seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1) + args.seed = seed + logger.info(f"Using seed: {seed}") + + # --- Load VAE (if needed for input processing) --- + vae = None + needs_vae_early = is_extend_mode or is_i2v_mode or is_v2v_mode or is_fun_control + if needs_vae_early: + vae = load_vae(args, cfg, device, vae_dtype) + + # --- Prepare Inputs --- + noise = None + context = None + context_null = None + inputs = None + video_latents = None # For V2V mixing + original_input_frames_np = None # For Extend mode saving + + if is_extend_mode: + # 1. Load initial frames (numpy list and normalized tensor) + original_input_frames_np, input_frames_tensor = load_video_frames( + args.extend_video, args.num_input_frames, tuple(args.video_size) + ) + # 2. Prepare inputs using the loaded frames tensor + noise, context, context_null, _, inputs = prepare_i2v_or_extend_inputs( + args, cfg, accelerator, device, vae, input_frames_tensor=input_frames_tensor + ) + del input_frames_tensor # Free memory + clean_memory_on_device(device) + + elif is_i2v_mode: + # Prepare I2V inputs (single image) + noise, context, context_null, _, inputs = prepare_i2v_or_extend_inputs( + args, cfg, accelerator, device, vae + ) + + elif is_v2v_mode: + # 1. 
Load and prepare video + video_frames_np, actual_frames_loaded = load_video( + args.video_path, start_frame=0, num_frames=args.video_length, bucket_reso=tuple(args.video_size) + ) + if actual_frames_loaded == 0: raise ValueError(f"Could not load frames from video: {args.video_path}") + if args.video_length is None or actual_frames_loaded < args.video_length: + logger.info(f"Updating video_length based on loaded V2V frames: {actual_frames_loaded}") + args.video_length = actual_frames_loaded + height, width, video_length = check_inputs(args) # Re-check + + # Convert frames np [F,H,W,C] uint8 -> tensor [1,C,F,H,W] float32 [-1, 1] + video_tensor = torch.from_numpy(np.stack(video_frames_np, axis=0)) + video_tensor = video_tensor.permute(0, 3, 1, 2).float() # F,C,H,W + video_tensor = video_tensor.permute(1, 0, 2, 3).unsqueeze(0) # 1,C,F,H,W + video_tensor = video_tensor / 127.5 - 1.0 # Normalize to [-1, 1] + + # 2. Encode video to latents (pass vae_dtype for DiT compatibility) + video_latents = encode_video_to_latents(video_tensor, vae, device, vae_dtype, args) + del video_tensor, video_frames_np + clean_memory_on_device(device) + + # 3. Prepare V2V inputs (noise, context, etc.) + noise, context, context_null, inputs = prepare_v2v_inputs(args, cfg, accelerator, device, video_latents) + + elif is_t2v_mode or is_fun_control: # Should handle T2V+FunControl here + # Prepare T2V inputs (passes VAE if is_fun_control) + if args.video_length is None: + raise ValueError("video_length must be specified for T2V/Fun-Control.") + noise, context, context_null, inputs = prepare_t2v_inputs(args, cfg, accelerator, device, vae if is_fun_control else None) + + # At this point, VAE should be on CPU/cache unless still needed for decoding + + # --- Load DiT Model --- + is_i2v_like = is_i2v_mode or is_extend_mode + model = load_dit_model(args, cfg, device, dit_dtype, dit_weight_dtype, is_i2v_like) + + # --- Merge LoRA --- + if args.lora_weight is not None and len(args.lora_weight) > 0: + merge_lora_weights(model, args, device) + if args.save_merged_model: + logger.info("Merged model saved. 
Exiting without generation.") + return None, None + + # --- Optimize Model --- + optimize_model(model, args, device, dit_dtype, dit_weight_dtype) + + # --- Setup Scheduler & Timesteps --- + scheduler, timesteps = setup_scheduler(args, cfg, device) + + # --- Prepare for Sampling --- + seed_g = torch.Generator(device=device) + seed_g.manual_seed(seed) + + latent = noise # Start with noise (correctly shaped for T2V/I2V/Extend) + + # --- V2V Strength Adjustment --- + if is_v2v_mode and args.strength < 1.0: + if video_latents is None: raise RuntimeError("video_latents not available for V2V strength.") + num_inference_steps = max(1, int(args.infer_steps * args.strength)) + logger.info(f"V2V Strength: {args.strength}, adjusting inference steps to {num_inference_steps}") + t_start_idx = len(timesteps) - num_inference_steps + if t_start_idx < 0: t_start_idx = 0 + t_start = timesteps[t_start_idx] + # Use scheduler.add_noise for proper mixing + video_latents = video_latents.to(device=noise.device, dtype=noise.dtype) + latent = scheduler.add_noise(video_latents, noise, t_start.unsqueeze(0).expand(noise.shape[0])) # Add noise based on start time + latent = latent.to(noise.dtype) # Ensure correct dtype after add_noise + logger.info(f"Mixed noise and video latents using scheduler.add_noise at timestep {t_start.item():.1f}") + timesteps = timesteps[t_start_idx:] # Use subset of timesteps + logger.info(f"Using last {len(timesteps)} timesteps for V2V sampling.") + else: + logger.info(f"Using full {len(timesteps)} timesteps for sampling.") + # Latent remains the initial noise (already handles I2V/Extend via 'y' conditioning) + + + # --- Run Sampling Loop --- + logger.info("Starting denoising sampling loop...") + final_latent = run_sampling( + model, latent, scheduler, timesteps, args, inputs, device, seed_g, accelerator, + use_cpu_offload=(args.blocks_to_swap > 0) + ) + + # --- Cleanup --- + del model, scheduler, context, context_null, inputs + if video_latents is not None: del video_latents + synchronize_device(device) + if args.blocks_to_swap > 0: + logger.info("Waiting 5 seconds for block swap cleanup...") + time.sleep(5) + gc.collect() + clean_memory_on_device(device) + + # Store VAE instance for decoding + args._vae = vae + + # Return latent [B, C, F, H, W] and original frames if extending + if len(final_latent.shape) == 4: final_latent = final_latent.unsqueeze(0) + return final_latent, original_input_frames_np + + +def decode_latent(latent: torch.Tensor, args: argparse.Namespace, cfg) -> torch.Tensor: + """decode latent tensor to video frames + + Args: + latent: latent tensor [B, C, F, H, W] + args: command line arguments (contains _vae instance) + cfg: model configuration + + Returns: + torch.Tensor: decoded video tensor [B, C, F, H, W], range [0, 1], on CPU + """ + device = torch.device(args.device) + vae = None + if hasattr(args, "_vae") and args._vae is not None: + vae = args._vae + logger.info("Using VAE instance from generation pipeline for decoding.") + else: + logger.info("Loading VAE for decoding...") + vae_dtype_decode = str_to_dtype(args.vae_dtype) if args.vae_dtype is not None else torch.bfloat16 # Default bfloat16 if not specified + vae = load_vae(args, cfg, device, vae_dtype_decode) + args._vae = vae + + vae.to_device(device) + logger.info(f"Decoding video from latents: shape {latent.shape}, dtype {latent.dtype}") + latent_decode = latent.to(device=device, dtype=vae.dtype) + + videos = None + with torch.autocast(device_type=device.type, dtype=vae.dtype), torch.no_grad(): + # Assuming 
vae.decode handles batch tensor [B, C, F, H, W] and returns list of [C, F, H, W] + decoded_list = vae.decode(latent_decode) + if decoded_list and len(decoded_list) > 0: + videos = torch.stack(decoded_list, dim=0) # Stack list back into batch: B, C, F, H, W + else: + raise RuntimeError("VAE decoding failed or returned empty list.") + + vae.to_device("cpu" if args.vae_cache_cpu else "cpu") # Move back VAE + clean_memory_on_device(device) + logger.info(f"Decoded video shape: {videos.shape}") + + # Post-processing: scale [-1, 1] -> [0, 1], clamp, move to CPU float32 + videos = (videos + 1.0) / 2.0 + videos = torch.clamp(videos, 0.0, 1.0) + video_final = videos.cpu().to(torch.float32) + + # Apply trim tail frames *after* decoding + if args.trim_tail_frames > 0: + logger.info(f"Trimming last {args.trim_tail_frames} frames from decoded video.") + video_final = video_final[:, :, : -args.trim_tail_frames, :, :] + + logger.info(f"Decoding complete. Final video tensor shape: {video_final.shape}") + return video_final + + +def save_output( + video_tensor: torch.Tensor, # Full decoded video [B, C, F, H, W], range [0, 1] + args: argparse.Namespace, + original_base_names: Optional[List[str]] = None, + latent_to_save: Optional[torch.Tensor] = None, # Full latent [B, C, F, H, W] + original_input_frames_np: Optional[List[np.ndarray]] = None # For Extend mode +) -> None: + """save output video, images, or latent, handling concatenation for Extend mode""" + save_path = args.save_path + os.makedirs(save_path, exist_ok=True) + time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S") + seed = args.seed + is_extend_mode = original_input_frames_np is not None + + # --- Determine Final Video Tensor for Saving --- + video_to_save = video_tensor # Default: save the full decoded tensor + final_video_length = video_tensor.shape[2] + final_height = video_tensor.shape[3] + final_width = video_tensor.shape[4] + + if is_extend_mode: + logger.info("Processing output for Extend mode: concatenating original frames with generated frames.") + num_original_frames = len(original_input_frames_np) + + # 1. Prepare original frames tensor: list[HWC uint8] -> tensor[B, C, N, H, W] float32 [0,1] + original_frames_np_stacked = np.stack(original_input_frames_np, axis=0) # [N, H, W, C] + original_frames_tensor = torch.from_numpy(original_frames_np_stacked).permute(0, 3, 1, 2).float() / 255.0 # [N, C, H, W] + original_frames_tensor = original_frames_tensor.permute(1, 0, 2, 3).unsqueeze(0) # [1, C, N, H, W] + original_frames_tensor = original_frames_tensor.to(video_tensor.device, dtype=video_tensor.dtype) # Match decoded tensor attributes + + # 2. Extract the generated part from the decoded tensor + # The decoded tensor includes reconstructed input frames + generated frames + # We only want the part *after* the input frames. + if video_tensor.shape[2] <= num_original_frames: + logger.error(f"Decoded video length ({video_tensor.shape[2]}) is not longer than original frames ({num_original_frames}). Cannot extract generated part.") + # Fallback to saving the full decoded video? Or raise error? + # Let's save the full decoded video for inspection + logger.warning("Saving the full decoded video instead of concatenating.") + else: + generated_part_tensor = video_tensor[:, :, num_original_frames:, :, :] # [B, C, M, H, W] + + # 3. 
Concatenate original pixel tensor + generated pixel tensor
+            video_to_save = torch.cat((original_frames_tensor, generated_part_tensor), dim=2) # Concat along Frame dimension
+            final_video_length = video_to_save.shape[2] # Update final length
+            logger.info(f"Concatenated original {num_original_frames} frames with generated {generated_part_tensor.shape[2]} frames. Final shape: {video_to_save.shape}")
+
+    # --- Determine Base Filename ---
+    base_name = f"{time_flag}_{seed}"
+    if original_base_names:
+        base_name += f"_{original_base_names[0]}" # Use original name if from latent
+    elif args.extend_video:
+        input_video_name = os.path.splitext(os.path.basename(args.extend_video))[0]
+        base_name += f"_ext_{input_video_name}"
+    elif args.image_path:
+        input_image_name = os.path.splitext(os.path.basename(args.image_path))[0]
+        base_name += f"_i2v_{input_image_name}"
+    elif args.video_path:
+        input_video_name = os.path.splitext(os.path.basename(args.video_path))[0]
+        base_name += f"_v2v_{input_video_name}"
+    # Add prompt hint? Might be too long
+    # prompt_hint = "".join(filter(str.isalnum, args.prompt))[:20]
+    # base_name += f"_{prompt_hint}"
+
+
+    # --- Save Latent ---
+    if (args.output_type == "latent" or args.output_type == "both") and latent_to_save is not None:
+        latent_path = os.path.join(save_path, f"{base_name}_latent.safetensors")
+        logger.info(f"Saving latent tensor shape: {latent_to_save.shape}") # Save the full latent
+        metadata = {}
+        if not args.no_metadata:
+            # save_output() does not receive cfg, so resolve it here for the checkpoint names below
+            cfg = WAN_CONFIGS[args.task]
+            # Get metadata from final saved video dimensions
+            metadata = {
+                "prompt": f"{args.prompt}", "negative_prompt": f"{args.negative_prompt or ''}",
+                "seeds": f"{seed}", "height": f"{final_height}", "width": f"{final_width}",
+                "video_length": f"{final_video_length}", # Length of the *saved* video/latent
+                "infer_steps": f"{args.infer_steps}", "guidance_scale": f"{args.guidance_scale}",
+                "flow_shift": f"{args.flow_shift}", "task": f"{args.task}",
+                "dit_model": f"{args.dit or (os.path.join(args.ckpt_dir, cfg.dit_checkpoint) if args.ckpt_dir else 'N/A')}",
+                "vae_model": f"{args.vae or (os.path.join(args.ckpt_dir, cfg.vae_checkpoint) if args.ckpt_dir else 'N/A')}",
+                "mode": ("Extend" if is_extend_mode else "I2V" if args.image_path else "V2V" if args.video_path else "T2V"),
+            }
+            if is_extend_mode:
+                metadata["extend_video"] = f"{os.path.basename(args.extend_video)}"
+                metadata["num_input_frames"] = f"{args.num_input_frames}"
+                metadata["extend_length"] = f"{args.extend_length}" # Generated part length
+                metadata["total_processed_length"] = f"{latent_to_save.shape[2]}" # Latent length
+            # Add other mode details... (V2V strength, I2V image, etc.)
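+            # Illustrative example of the resulting metadata (all values are stored as strings; these are hypothetical):
+            #   {"prompt": "...", "seeds": "1234", "height": "480", "width": "832",
+            #    "video_length": "81", "mode": "V2V", "v2v_strength": "0.75", "lora_weights": "my_lora.safetensors"}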
+ if args.video_path: metadata["v2v_strength"] = f"{args.strength}" + if args.image_path: metadata["i2v_image"] = f"{os.path.basename(args.image_path)}" + if args.end_image_path: metadata["end_image"] = f"{os.path.basename(args.end_image_path)}" + if args.control_path: metadata["funcontrol_video"] = f"{os.path.basename(args.control_path)}" + if args.lora_weight: + metadata["lora_weights"] = ", ".join([os.path.basename(p) for p in args.lora_weight]) + metadata["lora_multipliers"] = ", ".join(map(str, args.lora_multiplier)) + + sd = {"latent": latent_to_save.cpu()} + try: + save_file(sd, latent_path, metadata=metadata) + logger.info(f"Latent saved to: {latent_path}") + except Exception as e: + logger.error(f"Failed to save latent file: {e}") + + + # --- Save Video or Images --- + if args.output_type == "video" or args.output_type == "both": + video_path = os.path.join(save_path, f"{base_name}.mp4") + # save_videos_grid expects [B, T, H, W, C], input is [B, C, T, H, W] range [0, 1] + try: + # Ensure tensor is on CPU for saving function + save_videos_grid(video_to_save.cpu(), video_path, fps=args.fps, rescale=False) + logger.info(f"Video saved to: {video_path}") + except Exception as e: + logger.error(f"Failed to save video file: {e}") + logger.error(f"Video tensor info: shape={video_to_save.shape}, dtype={video_to_save.dtype}, min={video_to_save.min()}, max={video_to_save.max()}") + + elif args.output_type == "images": + image_save_dir = os.path.join(save_path, base_name) + os.makedirs(image_save_dir, exist_ok=True) + # save_images_grid expects [B, T, H, W, C] + try: + save_images_grid(video_to_save.cpu(), image_save_dir, "frame", rescale=False, save_individually=True) + logger.info(f"Image frames saved to directory: {image_save_dir}") + except Exception as e: + logger.error(f"Failed to save image files: {e}") + + +def main(): + # --- Argument Parsing & Setup --- + args = parse_args() + + latents_mode = args.latent_path is not None and len(args.latent_path) > 0 + device_str = args.device if args.device is not None else ("cuda" if torch.cuda.is_available() else "cpu") + args.device = torch.device(device_str) + logger.info(f"Using device: {args.device}") + + generated_latent = None + original_input_frames_np = None # Store original frames for extend mode + cfg = WAN_CONFIGS[args.task] + height, width, video_length = None, None, None + original_base_names = None # For naming output when loading latents + + if not latents_mode: + # --- Generation Mode --- + logger.info("Running in Generation Mode") + args = setup_args(args) # Sets defaults, calculates video_length for extend mode + height, width, video_length = check_inputs(args) # Validate final dimensions + args.video_size = [height, width] + args.video_length = video_length # Ensure video_length is stored in args for processing + + mode_str = ("Extend" if args.extend_video else + "I2V" if args.image_path else + "V2V" if args.video_path else + "T2V" + ("+FunControl" if args.control_path else "")) + if args.control_path and not (args.extend_video or args.image_path or args.video_path): + pass # Already handled above + elif args.control_path: + mode_str += "+FunControl" + + logger.info(f"Mode: {mode_str}") + logger.info( + f"Settings: video size: {height}x{width}, processed length: {video_length} frames, fps: {args.fps}, " + f"infer_steps: {args.infer_steps}, guidance: {args.guidance_scale}, flow_shift: {args.flow_shift}" + ) + if args.extend_video: + logger.info(f" Extend details: Input video: {args.extend_video}, Input frames: 
{args.num_input_frames}, Generated frames: {args.extend_length}") + + # Core generation pipeline - returns latent and potentially original frames + generated_latent, original_input_frames_np = generate(args) + + if args.save_merged_model: + logger.info("Exiting after saving merged model.") + return + if generated_latent is None: + logger.error("Generation failed or was skipped, exiting.") + return + + # Get dimensions from the *generated latent* for logging/metadata consistency + _, _, lat_f, lat_h, lat_w = generated_latent.shape + processed_pixel_height = lat_h * cfg.vae_stride[1] + processed_pixel_width = lat_w * cfg.vae_stride[2] + processed_pixel_frames = (lat_f - 1) * cfg.vae_stride[0] + 1 + logger.info(f"Generation complete. Processed latent shape: {generated_latent.shape} -> Approx Pixel Video: {processed_pixel_height}x{processed_pixel_width}@{processed_pixel_frames}") + # Note: Final saved dimensions might differ slightly due to concatenation in Extend mode + + else: + # --- Latents Mode --- + logger.info("Running in Latent Loading Mode") + original_base_names = [] + latents_list = [] + seeds = [] + metadata = {} + + if len(args.latent_path) > 1: + logger.warning("Loading multiple latent files is not fully supported. Using first file's info.") + + latent_path = args.latent_path[0] + original_base_names.append(os.path.splitext(os.path.basename(latent_path))[0]) + loaded_latent = None + seed = args.seed if args.seed is not None else 0 + + try: + if os.path.splitext(latent_path)[1] != ".safetensors": + logger.warning("Loading non-safetensors latent file. Metadata might be missing.") + loaded_latent = torch.load(latent_path, map_location="cpu") + if isinstance(loaded_latent, dict): + if "latent" in loaded_latent: loaded_latent = loaded_latent["latent"] + elif "state_dict" in loaded_latent: raise ValueError("Loaded file appears to be a model checkpoint.") + else: + first_key = next(iter(loaded_latent)); loaded_latent = loaded_latent[first_key] + else: + loaded_latent = load_file(latent_path, device="cpu")["latent"] + with safe_open(latent_path, framework="pt", device="cpu") as f: metadata = f.metadata() or {} + logger.info(f"Loaded metadata: {metadata}") + # Restore args from metadata if available + if "seeds" in metadata: seed = int(metadata["seeds"]) + if "prompt" in metadata: args.prompt = metadata["prompt"] + if "negative_prompt" in metadata: args.negative_prompt = metadata["negative_prompt"] + # Use metadata dimensions if available, otherwise infer later + if "height" in metadata and "width" in metadata: + height = int(metadata["height"]); width = int(metadata["width"]) + args.video_size = [height, width] + if "video_length" in metadata: # This is the length of the *saved* video/latent + video_length = int(metadata["video_length"]) + args.video_length = video_length # Store the length of the latent data + # Restore other relevant args... + if "guidance_scale" in metadata: args.guidance_scale = float(metadata["guidance_scale"]) + if "infer_steps" in metadata: args.infer_steps = int(metadata["infer_steps"]) + if "flow_shift" in metadata: args.flow_shift = float(metadata["flow_shift"]) + if "mode" in metadata and metadata["mode"] == "Extend": + if "num_input_frames" in metadata: args.num_input_frames = int(metadata["num_input_frames"]) + # Cannot reliably get original frames from latent, so concatenation won't work right + + seeds.append(seed) + latents_list.append(loaded_latent) + logger.info(f"Loaded latent from {latent_path}. 
Shape: {loaded_latent.shape}, dtype: {loaded_latent.dtype}") + + except Exception as e: + logger.error(f"Failed to load latent file {latent_path}: {e}") + return + + if not latents_list: logger.error("No latent tensors loaded."); return + + generated_latent = torch.stack(latents_list, dim=0) # [B, C, F, H, W] + if len(generated_latent.shape) != 5: raise ValueError(f"Loaded latent shape error: {generated_latent.shape}") + + args.seed = seeds[0] + # Infer pixel dimensions from latent if not fully set by metadata + if height is None or width is None or video_length is None: + logger.warning("Dimensions not fully found in metadata, inferring from latent shape.") + _, _, lat_f, lat_h, lat_w = generated_latent.shape + height = lat_h * cfg.vae_stride[1]; width = lat_w * cfg.vae_stride[2] + video_length = (lat_f - 1) * cfg.vae_stride[0] + 1 # This is the length corresponding to the latent + logger.info(f"Inferred pixel dimensions from latent: {height}x{width}@{video_length}") + args.video_size = [height, width]; args.video_length = video_length + + # --- Decode and Save --- + if generated_latent is not None: + # Decode latent to video tensor [B, C, F, H, W], range [0, 1] + # Note: args.video_length might be different from latent's frame dim if trimmed during decode + decoded_video = decode_latent(generated_latent, args, cfg) + + # Save output (handles Extend mode concatenation inside) + save_output( + decoded_video, args, + original_base_names=original_base_names, + latent_to_save=generated_latent if (args.output_type in ["latent", "both"]) else None, + original_input_frames_np=original_input_frames_np # Pass original frames if in Extend mode + ) + else: + logger.error("No latent available for decoding and saving.") + + logger.info("Done!") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/f1_video_cli_local.py b/f1_video_cli_local.py new file mode 100644 index 0000000000000000000000000000000000000000..b68c42183ef78155c94f2a49803024d877fa614e --- /dev/null +++ b/f1_video_cli_local.py @@ -0,0 +1,778 @@ +import os +import torch +import traceback +import einops +import numpy as np +import argparse +import math +import decord +from tqdm import tqdm +import pathlib +from datetime import datetime +import imageio_ffmpeg +import tempfile +import shutil +import subprocess +import sys + +from PIL import Image + +# --- Imports from fpack_generate_video.py's ecosystem --- +from frame_pack.hunyuan_video_packed import load_packed_model +from frame_pack.framepack_utils import ( + load_vae, + load_text_encoder1, + load_text_encoder2, + load_image_encoders +) +from frame_pack.hunyuan import encode_prompt_conds, vae_decode, vae_encode +from frame_pack.utils import crop_or_pad_yield_mask, soft_append_bcthw, resize_and_center_crop, generate_timestamp +from frame_pack.k_diffusion_hunyuan import sample_hunyuan +from frame_pack.clip_vision import hf_clip_vision_encode +from frame_pack.bucket_tools import find_nearest_bucket +from diffusers_helper.utils import save_bcthw_as_mp4 +from diffusers_helper.memory import cpu, gpu, get_cuda_free_memory_gb, \ + move_model_to_device_with_memory_preservation, \ + offload_model_from_device_for_memory_preservation, \ + fake_diffusers_current_device, DynamicSwapInstaller, \ + unload_complete_models, load_model_as_complete + +from networks import lora_framepack +try: + from lycoris.kohya import create_network_from_weights +except ImportError: + pass +from base_wan_generate_video import merge_lora_weights + + +# --- Global Model Variables --- 
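+# These module-level handles are populated once in the __main__ block below and shared by the
+# worker functions via `global`. In low-VRAM mode the models are shuttled between CPU and GPU on
+# demand (load_model_as_complete / unload_complete_models / DynamicSwapInstaller) instead of
+# staying resident on the GPU.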
+text_encoder = None +text_encoder_2 = None +tokenizer = None +tokenizer_2 = None +vae = None +feature_extractor = None +image_encoder = None +transformer = None + +high_vram = False +free_mem_gb = 0.0 + +outputs_folder = './outputs/' # Default, can be overridden by --output_dir + +@torch.no_grad() +def image_encode(image_np, target_width, target_height, vae_model, image_encoder_model, feature_extractor_model, device="cuda"): + global high_vram + print("Processing single image for encoding (e.g., start_guidance_image)...") + try: + print(f"Using target resolution for image encoding: {target_width}x{target_height}") + + processed_image_np = resize_and_center_crop(image_np, target_width=target_width, target_height=target_height) + + image_pt = torch.from_numpy(processed_image_np).float() / 127.5 - 1.0 + image_pt = image_pt.permute(2, 0, 1).unsqueeze(0).unsqueeze(2) + + target_vae_device = device + if not high_vram: load_model_as_complete(vae_model, target_device=target_vae_device) + else: vae_model.to(target_vae_device) + image_pt_device = image_pt.to(target_vae_device) + + latent = vae_encode(image_pt_device, vae_model).cpu() + print(f"Single image VAE output shape (latent): {latent.shape}") + + if not high_vram: unload_complete_models(vae_model) + + target_img_enc_device = device + if not high_vram: load_model_as_complete(image_encoder_model, target_device=target_img_enc_device) + else: image_encoder_model.to(target_img_enc_device) + + clip_embedding_output = hf_clip_vision_encode(processed_image_np, feature_extractor_model, image_encoder_model) + clip_embedding = clip_embedding_output.last_hidden_state.cpu() + print(f"Single image CLIP embedding shape: {clip_embedding.shape}") + + if not high_vram: unload_complete_models(image_encoder_model) + + if device == "cuda": + torch.cuda.empty_cache() + + return latent, clip_embedding + + except Exception as e: + print(f"Error in image_encode: {str(e)}") + traceback.print_exc() + raise + +@torch.no_grad() +def video_encode(video_path, resolution, no_resize, vae_model, vae_batch_size=16, device="cuda", width=None, height=None): + video_path = str(pathlib.Path(video_path).resolve()) + print(f"Processing video for encoding: {video_path}") + + if device == "cuda" and not torch.cuda.is_available(): + print("CUDA is not available, falling back to CPU for video_encode") + device = "cpu" + + try: + print("Initializing VideoReader...") + vr = decord.VideoReader(video_path) + fps = vr.get_avg_fps() + if fps == 0: + print("Warning: VideoReader reported FPS as 0. Attempting to get it via OpenCV.") + import cv2 + cap = cv2.VideoCapture(video_path) + fps_cv = cap.get(cv2.CAP_PROP_FPS) + cap.release() + if fps_cv > 0: + fps = fps_cv + print(f"Using FPS from OpenCV: {fps}") + else: + raise ValueError("Failed to determine FPS for the input video.") + + num_real_frames = len(vr) + print(f"Video loaded: {num_real_frames} frames, FPS: {fps}") + + latent_size_factor = 4 + num_frames = (num_real_frames // latent_size_factor) * latent_size_factor + if num_frames != num_real_frames: + print(f"Truncating video from {num_real_frames} to {num_frames} frames for latent size compatibility") + + if num_frames == 0: + raise ValueError(f"Video too short ({num_real_frames} frames) or becomes 0 after truncation. 
Needs at least {latent_size_factor} frames.") + num_real_frames = num_frames + + print("Reading video frames...") + frames_np_all = vr.get_batch(range(num_real_frames)).asnumpy() + print(f"Frames read: {frames_np_all.shape}") + + native_height, native_width = frames_np_all.shape[1], frames_np_all.shape[2] + print(f"Native video resolution: {native_width}x{native_height}") + + target_h_arg = native_height if height is None else height + target_w_arg = native_width if width is None else width + + if not no_resize: + actual_target_height, actual_target_width = find_nearest_bucket(target_h_arg, target_w_arg, resolution=resolution) + print(f"Adjusted resolution for VAE encoding: {actual_target_width}x{actual_target_height}") + else: + actual_target_width = (native_width // 8) * 8 + actual_target_height = (native_height // 8) * 8 + if actual_target_width != native_width or actual_target_height != native_height: + print(f"Using native resolution, adjusted to be divisible by 8: {actual_target_width}x{actual_target_height}") + else: + print(f"Using native resolution without resizing: {actual_target_width}x{actual_target_height}") + + processed_frames_list = [] + for frame_idx in range(frames_np_all.shape[0]): + frame = frames_np_all[frame_idx] + frame_resized_np = resize_and_center_crop(frame, target_width=actual_target_width, target_height=actual_target_height) + processed_frames_list.append(frame_resized_np) + + processed_frames_np_stack = np.stack(processed_frames_list) + print(f"Frames preprocessed: {processed_frames_np_stack.shape}") + + input_image_np_for_clip = processed_frames_np_stack[0] + + print("Converting frames to tensor...") + frames_pt = torch.from_numpy(processed_frames_np_stack).float() / 127.5 - 1.0 + frames_pt = frames_pt.permute(0, 3, 1, 2) + frames_pt = frames_pt.unsqueeze(0).permute(0, 2, 1, 3, 4) + print(f"Tensor shape for VAE: {frames_pt.shape}") + + input_video_pixels_cpu = frames_pt.clone().cpu() + + print(f"Moving VAE and tensor to device: {device}") + vae_model.to(device) + frames_pt = frames_pt.to(device) + + print(f"Encoding input video frames with VAE (batch size: {vae_batch_size})") + all_latents_list = [] + vae_model.eval() + with torch.no_grad(): + for i in tqdm(range(0, frames_pt.shape[2], vae_batch_size), desc="VAE Encoding Video Frames", mininterval=0.1): + batch_frames_pt = frames_pt[:, :, i:i + vae_batch_size] + try: + batch_latents = vae_encode(batch_frames_pt, vae_model) + all_latents_list.append(batch_latents.cpu()) + except RuntimeError as e: + print(f"Error during VAE encoding: {str(e)}") + if "out of memory" in str(e).lower() and device == "cuda": + print("CUDA out of memory during VAE encoding. 
Try reducing --vae_batch_size or use CPU for VAE.") + raise + + history_latents_cpu = torch.cat(all_latents_list, dim=2) + print(f"History latents shape (original video): {history_latents_cpu.shape}") + + start_latent_cpu = history_latents_cpu[:, :, :1].clone() + print(f"Start latent shape (for conditioning): {start_latent_cpu.shape}") + + if device == "cuda": + vae_model.to(cpu) + torch.cuda.empty_cache() + print("VAE moved back to CPU, CUDA cache cleared") + + return start_latent_cpu, input_image_np_for_clip, history_latents_cpu, fps, actual_target_height, actual_target_width, input_video_pixels_cpu + + except Exception as e: + print(f"Error in video_encode: {str(e)}") + traceback.print_exc() + raise + +def set_mp4_comments_imageio_ffmpeg(input_file, comments): + try: + ffmpeg_path = imageio_ffmpeg.get_ffmpeg_exe() + if not os.path.exists(input_file): + print(f"Error: Input file {input_file} does not exist") + return False + temp_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False).name + command = [ + ffmpeg_path, '-i', input_file, '-metadata', f'comment={comments}', + '-c:v', 'copy', '-c:a', 'copy', '-y', temp_file + ] + result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False) + if result.returncode == 0: + shutil.move(temp_file, input_file) + print(f"Successfully added comments to {input_file}") + return True + else: + if os.path.exists(temp_file): os.remove(temp_file) + print(f"Error: FFmpeg failed with message:\n{result.stderr}") + return False + except Exception as e: + if 'temp_file' in locals() and os.path.exists(temp_file): os.remove(temp_file) + print(f"Error saving prompt to video metadata, ffmpeg may be required: "+str(e)) + return False + +@torch.no_grad() +def do_extension_work( + input_video_path, prompt, n_prompt, seed, + resolution_max_dim, + additional_second_length, + latent_window_size, steps, cfg, gs, rs, + gpu_memory_preservation, use_teacache, no_resize, mp4_crf, + num_clean_frames, vae_batch_size, + extension_only +): + global high_vram, text_encoder, text_encoder_2, tokenizer, tokenizer_2, vae, feature_extractor, image_encoder, transformer, args + + print('--- Starting Video Extension Work (with optional Start Guidance Image) ---') + + try: + if not high_vram: + unload_complete_models(text_encoder, text_encoder_2, image_encoder, vae, transformer) + + print('Text encoding for extension...') + target_text_enc_device = str(gpu if torch.cuda.is_available() else cpu) + if not high_vram: + if text_encoder: fake_diffusers_current_device(text_encoder, target_text_enc_device) + if text_encoder_2: load_model_as_complete(text_encoder_2, target_device=target_text_enc_device) + else: + if text_encoder: text_encoder.to(target_text_enc_device) + if text_encoder_2: text_encoder_2.to(target_text_enc_device) + + llama_vec_gpu, clip_l_pooler_gpu = encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2) + if cfg == 1.0: + llama_vec_n_gpu, clip_l_pooler_n_gpu = torch.zeros_like(llama_vec_gpu), torch.zeros_like(clip_l_pooler_gpu) + else: + llama_vec_n_gpu, clip_l_pooler_n_gpu = encode_prompt_conds(n_prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2) + + llama_vec_padded_cpu, llama_attention_mask_cpu = crop_or_pad_yield_mask(llama_vec_gpu.cpu(), length=512) + llama_vec_n_padded_cpu, llama_attention_mask_n_cpu = crop_or_pad_yield_mask(llama_vec_n_gpu.cpu(), length=512) + clip_l_pooler_cpu = clip_l_pooler_gpu.cpu() + clip_l_pooler_n_cpu = clip_l_pooler_n_gpu.cpu() + + if not high_vram: 
unload_complete_models(text_encoder_2) + + + print('Encoding input video for extension base...') + video_encode_device = str(gpu if torch.cuda.is_available() else cpu) + start_latent_input_video_cpu, input_image_np_for_clip, video_latents_history_cpu, fps, height, width, _ = video_encode( + input_video_path, resolution_max_dim, no_resize, vae, vae_batch_size=vae_batch_size, device=video_encode_device + ) + if fps <= 0: + raise ValueError("FPS from input video is 0 or invalid. Cannot proceed with extension.") + + guidance_latent_cpu = None + guidance_clip_embedding_cpu = None + + if args.start_guidance_image: + print(f"Encoding provided start guidance image from: {args.start_guidance_image}") + try: + guidance_pil = Image.open(args.start_guidance_image).convert("RGB") + guidance_np = np.array(guidance_pil) + + guidance_latent_cpu, guidance_clip_embedding_cpu = image_encode( + guidance_np, target_width=width, target_height=height, + vae_model=vae, image_encoder_model=image_encoder, + feature_extractor_model=feature_extractor, device=video_encode_device + ) + print("Start guidance image encoded successfully.") + except Exception as e_img_enc: + print(f"Warning: Could not encode start_guidance_image: {e_img_enc}. Proceeding without it.") + guidance_latent_cpu = None + guidance_clip_embedding_cpu = None + + print('CLIP Vision encoding for input video (first frame)...') + target_img_enc_device = str(gpu if torch.cuda.is_available() else cpu) + image_encoder_was_already_on_gpu = False + if image_encoder is not None and hasattr(image_encoder, 'device') and image_encoder.device.type == 'cuda': + image_encoder_was_already_on_gpu = True + + if not image_encoder_was_already_on_gpu: + if not high_vram: + if image_encoder: load_model_as_complete(image_encoder, target_device=target_img_enc_device) + else: + if image_encoder: image_encoder.to(target_img_enc_device) + + input_video_first_frame_clip_output = hf_clip_vision_encode(input_image_np_for_clip, feature_extractor, image_encoder) + input_video_first_frame_clip_embedding_cpu = input_video_first_frame_clip_output.last_hidden_state.cpu() + + final_clip_embedding_for_sampling_cpu = input_video_first_frame_clip_embedding_cpu.clone() + if guidance_clip_embedding_cpu is not None and args.start_guidance_image_clip_weight > 0: + print(f"Blending input video's first frame CLIP with guidance image CLIP (weight: {args.start_guidance_image_clip_weight})") + final_clip_embedding_for_sampling_cpu = \ + (1.0 - args.start_guidance_image_clip_weight) * input_video_first_frame_clip_embedding_cpu + \ + args.start_guidance_image_clip_weight * guidance_clip_embedding_cpu + elif guidance_clip_embedding_cpu is not None and args.start_guidance_image_clip_weight == 0: + print("Guidance image provided, but weight is 0. 
Using input video's first frame CLIP only.") + else: + print("Using input video's first frame CLIP embedding for image conditioning (no guidance image or weight is 0).") + + if not image_encoder_was_already_on_gpu: + if not high_vram and image_encoder: unload_complete_models(image_encoder) + + + target_transformer_device = str(gpu if torch.cuda.is_available() else cpu) + if not high_vram: + if transformer: move_model_to_device_with_memory_preservation(transformer, target_device=target_transformer_device, preserved_memory_gb=gpu_memory_preservation) + else: + if transformer: transformer.to(target_transformer_device) + + cond_device = transformer.device + cond_dtype = transformer.dtype + + llama_vec = llama_vec_padded_cpu.to(device=cond_device, dtype=cond_dtype) + llama_attention_mask = llama_attention_mask_cpu.to(device=cond_device) + llama_vec_n = llama_vec_n_padded_cpu.to(device=cond_device, dtype=cond_dtype) + llama_attention_mask_n = llama_attention_mask_n_cpu.to(device=cond_device) + clip_l_pooler = clip_l_pooler_cpu.to(device=cond_device, dtype=cond_dtype) + clip_l_pooler_n = clip_l_pooler_n_cpu.to(device=cond_device, dtype=cond_dtype) + + image_embeddings_for_sampling_loop = final_clip_embedding_for_sampling_cpu.to(device=cond_device, dtype=cond_dtype) + + start_latent_from_input_video_gpu = start_latent_input_video_cpu.to(device=cond_device, dtype=torch.float32) + + + num_output_pixel_frames_per_section = latent_window_size * 4 + if num_output_pixel_frames_per_section == 0: + raise ValueError("latent_window_size * 4 is zero, cannot calculate total_extension_latent_sections.") + total_extension_latent_sections = int(max(round((additional_second_length * fps) / num_output_pixel_frames_per_section), 1)) + + print(f"Input video FPS: {fps}, Target additional length: {additional_second_length}s") + print(f"Generating {total_extension_latent_sections} new sections for extension (approx {total_extension_latent_sections * num_output_pixel_frames_per_section / fps:.2f}s).") + + job_id_base = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + \ + f"_framepackf1-vidEXT_{width}x{height}_{additional_second_length:.1f}s_seed{seed}_s{steps}_gs{gs}_cfg{cfg}" + + job_id = job_id_base + if extension_only: + job_id += "_extonly" + print("Extension-only mode enabled. 
Filenames will reflect this.") + + rnd = torch.Generator("cpu").manual_seed(seed) + + history_latents_combined_cpu = video_latents_history_cpu.clone() + + print("Decoding original input video content for appending...") + target_vae_device_for_initial_decode = str(gpu if torch.cuda.is_available() else cpu) + if not high_vram: + if vae: load_model_as_complete(vae, target_device=target_vae_device_for_initial_decode) + else: + if vae: vae.to(target_vae_device_for_initial_decode) + + initial_video_pixels_cpu = vae_decode(video_latents_history_cpu.to(target_vae_device_for_initial_decode), vae).cpu() + if extension_only: + history_pixels_decoded_cpu = None + print("Extension only mode: Intermediate and final videos will contain only the generated extension.") + else: + history_pixels_decoded_cpu = initial_video_pixels_cpu.clone() + print("Normal mode: Intermediate and final videos will contain input video + extension.") + + if not high_vram and vae: unload_complete_models(vae) + + total_current_pixel_frames_count = history_pixels_decoded_cpu.shape[2] if history_pixels_decoded_cpu is not None else 0 + previous_video_path_for_cleanup = None + + initial_guidance_clip_weight = args.start_guidance_image_clip_weight + num_guidance_fade_sections = min(3, total_extension_latent_sections) + + + for section_index in range(total_extension_latent_sections): + print(f"--- F1 Extension: Seed {seed}: Section {section_index + 1}/{total_extension_latent_sections} ---") + + if transformer: transformer.initialize_teacache(enable_teacache=use_teacache, num_steps=steps if use_teacache else 0) + + progress_bar_sampler = tqdm(total=steps, desc=f"Sampling Extension Section {section_index+1}/{total_extension_latent_sections}", file=sys.stdout) + + def sampler_callback_cli(d): + progress_bar_sampler.update(1) + + available_latents_count_cpu = history_latents_combined_cpu.shape[2] + pixel_frames_to_generate_this_step = latent_window_size * 4 - 3 + adjusted_latent_frames_for_output = (pixel_frames_to_generate_this_step + 3) // 4 + + base_effective_clean_frames = max(0, args.num_clean_frames -1) if args.num_clean_frames > 1 else 0 + + effective_clean_frames_count_section = base_effective_clean_frames + effective_clean_frames_count_section = min(effective_clean_frames_count_section, max(0, available_latents_count_cpu - 1 - (2 if available_latents_count_cpu > 3 else 0) )) + + num_2x_frames_count_section = min(2, max(0, available_latents_count_cpu - effective_clean_frames_count_section -1)) + num_4x_frames_count_section = min(16, max(0, available_latents_count_cpu - effective_clean_frames_count_section - num_2x_frames_count_section -1)) + + if section_index == 0 and args.use_guidance_image_as_first_latent and guidance_latent_cpu is not None: + print("First section with guidance VAE: Forcing 0 historical clean/2x/4x frames from input video.") + effective_clean_frames_count_section = 0 + num_2x_frames_count_section = 0 + num_4x_frames_count_section = 0 + + print(f"Section {section_index+1}: Effective Context Counts: 1x={effective_clean_frames_count_section}, 2x={num_2x_frames_count_section}, 4x={num_4x_frames_count_section}") + + total_context_latents_count = num_4x_frames_count_section + num_2x_frames_count_section + effective_clean_frames_count_section + total_context_latents_count = min(total_context_latents_count, available_latents_count_cpu) + + indices_tensor_gpu = torch.arange(0, sum([ + 1, + num_4x_frames_count_section, + num_2x_frames_count_section, + effective_clean_frames_count_section, + 
adjusted_latent_frames_for_output + ])).unsqueeze(0).to(cond_device) + + clean_latent_indices_start_gpu, \ + clean_latent_4x_indices_gpu, \ + clean_latent_2x_indices_gpu, \ + clean_latent_1x_indices_gpu, \ + latent_indices_for_denoising_gpu = indices_tensor_gpu.split( + [1, num_4x_frames_count_section, num_2x_frames_count_section, effective_clean_frames_count_section, adjusted_latent_frames_for_output], dim=1 + ) + clean_latent_indices_combined_gpu = torch.cat([clean_latent_indices_start_gpu, clean_latent_1x_indices_gpu], dim=1) + + context_latents_for_split_cpu = history_latents_combined_cpu[:, :, -total_context_latents_count:, :, :] if total_context_latents_count > 0 else torch.empty((1,history_latents_combined_cpu.shape[1],0,height//8,width//8), dtype=torch.float32) + + clean_latents_4x_gpu_data = torch.empty((1,history_latents_combined_cpu.shape[1],0,height//8,width//8), device=cond_device, dtype=torch.float32) + clean_latents_2x_gpu_data = torch.empty((1,history_latents_combined_cpu.shape[1],0,height//8,width//8), device=cond_device, dtype=torch.float32) + clean_latents_1x_gpu_data = torch.empty((1,history_latents_combined_cpu.shape[1],0,height//8,width//8), device=cond_device, dtype=torch.float32) + + current_offset_in_context_cpu = 0 + if num_4x_frames_count_section > 0 and total_context_latents_count > 0 and current_offset_in_context_cpu < context_latents_for_split_cpu.shape[2]: + slice_end = min(current_offset_in_context_cpu + num_4x_frames_count_section, context_latents_for_split_cpu.shape[2]) + clean_latents_4x_gpu_data = context_latents_for_split_cpu[:, :, current_offset_in_context_cpu:slice_end].to(device=cond_device, dtype=torch.float32) + current_offset_in_context_cpu += clean_latents_4x_gpu_data.shape[2] + + if num_2x_frames_count_section > 0 and total_context_latents_count > 0 and current_offset_in_context_cpu < context_latents_for_split_cpu.shape[2]: + slice_end = min(current_offset_in_context_cpu + num_2x_frames_count_section, context_latents_for_split_cpu.shape[2]) + clean_latents_2x_gpu_data = context_latents_for_split_cpu[:, :, current_offset_in_context_cpu:slice_end].to(device=cond_device, dtype=torch.float32) + current_offset_in_context_cpu += clean_latents_2x_gpu_data.shape[2] + + if effective_clean_frames_count_section > 0 and total_context_latents_count > 0 and current_offset_in_context_cpu < context_latents_for_split_cpu.shape[2]: + slice_end = min(current_offset_in_context_cpu + effective_clean_frames_count_section, context_latents_for_split_cpu.shape[2]) + clean_latents_1x_gpu_data = context_latents_for_split_cpu[:, :, current_offset_in_context_cpu:slice_end].to(device=cond_device, dtype=torch.float32) + + actual_start_latent_for_clean_latents_gpu = start_latent_from_input_video_gpu + if section_index == 0 and args.use_guidance_image_as_first_latent and guidance_latent_cpu is not None: + print("Using guidance image VAE latent as the start_latent for the first generated segment.") + actual_start_latent_for_clean_latents_gpu = guidance_latent_cpu.to(device=cond_device, dtype=torch.float32) + elif section_index == 0: + print("Using input video's first VAE latent as start_latent for first generated segment.") + + clean_latents_for_sampler_gpu = torch.cat([actual_start_latent_for_clean_latents_gpu, clean_latents_1x_gpu_data], dim=2) + + current_guidance_clip_weight = 0.0 + if guidance_clip_embedding_cpu is not None and initial_guidance_clip_weight > 0: + if section_index < num_guidance_fade_sections: + current_guidance_clip_weight = initial_guidance_clip_weight * 
(1.0 - (section_index / float(num_guidance_fade_sections))) + print(f"Section {section_index+1}: Current guidance CLIP weight: {current_guidance_clip_weight:.2f}") + else: + current_guidance_clip_weight = 0.0 + print(f"Section {section_index+1}: Guidance CLIP weight faded to 0.") + + if current_guidance_clip_weight > 0 and guidance_clip_embedding_cpu is not None : + current_image_embeddings_for_sampling_cpu = \ + (1.0 - current_guidance_clip_weight) * input_video_first_frame_clip_embedding_cpu + \ + current_guidance_clip_weight * guidance_clip_embedding_cpu + else: + current_image_embeddings_for_sampling_cpu = input_video_first_frame_clip_embedding_cpu.clone() + + current_image_embeddings_for_sampling_gpu = current_image_embeddings_for_sampling_cpu.to(device=cond_device, dtype=cond_dtype) + + generated_latents_gpu_step = sample_hunyuan( + transformer=transformer, sampler='unipc', width=width, height=height, + frames=pixel_frames_to_generate_this_step, + real_guidance_scale=cfg, distilled_guidance_scale=gs, guidance_rescale=rs, + num_inference_steps=steps, generator=rnd, + prompt_embeds=llama_vec, prompt_embeds_mask=llama_attention_mask, prompt_poolers=clip_l_pooler, + negative_prompt_embeds=llama_vec_n, negative_prompt_embeds_mask=llama_attention_mask_n, negative_prompt_poolers=clip_l_pooler_n, + device=cond_device, dtype=cond_dtype, + image_embeddings=current_image_embeddings_for_sampling_gpu, + latent_indices=latent_indices_for_denoising_gpu, + clean_latents=clean_latents_for_sampler_gpu, + clean_latent_indices=clean_latent_indices_combined_gpu, + clean_latents_2x=clean_latents_2x_gpu_data if num_2x_frames_count_section > 0 else None, + clean_latent_2x_indices=clean_latent_2x_indices_gpu if num_2x_frames_count_section > 0 else None, + clean_latents_4x=clean_latents_4x_gpu_data if num_4x_frames_count_section > 0 else None, + clean_latent_4x_indices=clean_latent_4x_indices_gpu if num_4x_frames_count_section > 0 else None, + callback=sampler_callback_cli, + ) + if progress_bar_sampler: progress_bar_sampler.close() + + history_latents_combined_cpu = torch.cat([history_latents_combined_cpu, generated_latents_gpu_step.cpu()], dim=2) + + target_vae_device = str(gpu if torch.cuda.is_available() else cpu) + if not high_vram: + if transformer: offload_model_from_device_for_memory_preservation(transformer, target_device=target_transformer_device, preserved_memory_gb=gpu_memory_preservation) + if vae: load_model_as_complete(vae, target_device=target_vae_device) + else: + if vae: vae.to(target_vae_device) + + num_latents_for_stitch_decode = latent_window_size * 2 + num_latents_for_stitch_decode = min(num_latents_for_stitch_decode, history_latents_combined_cpu.shape[2]) + latents_for_current_part_decode_gpu = history_latents_combined_cpu[:, :, -num_latents_for_stitch_decode:].to(target_vae_device) + + pixels_for_current_part_decoded_cpu = vae_decode( + latents_for_current_part_decode_gpu, + vae + ).cpu() + + if extension_only and history_pixels_decoded_cpu is None: + history_pixels_decoded_cpu = pixels_for_current_part_decoded_cpu + else: + overlap_for_soft_append = latent_window_size * 4 - 3 + overlap_for_soft_append = min(overlap_for_soft_append, history_pixels_decoded_cpu.shape[2], pixels_for_current_part_decoded_cpu.shape[2]) + + if overlap_for_soft_append <= 0: + history_pixels_decoded_cpu = torch.cat([history_pixels_decoded_cpu, pixels_for_current_part_decoded_cpu], dim=2) + else: + history_pixels_decoded_cpu = soft_append_bcthw( + history_pixels_decoded_cpu, + 
pixels_for_current_part_decoded_cpu, + overlap=overlap_for_soft_append + ) + + total_current_pixel_frames_count = history_pixels_decoded_cpu.shape[2] + + if not high_vram: + if vae: unload_complete_models(vae) + if transformer and not (section_index == total_extension_latent_sections - 1): + move_model_to_device_with_memory_preservation(transformer, target_device=target_transformer_device, preserved_memory_gb=gpu_memory_preservation) + + current_output_filename = os.path.join(outputs_folder, f'{job_id}_part{section_index + 1}_totalframes{history_pixels_decoded_cpu.shape[2]}.mp4') + save_bcthw_as_mp4(history_pixels_decoded_cpu, current_output_filename, fps=fps, crf=mp4_crf) + print(f"MP4 Preview for section {section_index + 1} saved: {current_output_filename}") + set_mp4_comments_imageio_ffmpeg(current_output_filename, f"Prompt: {prompt} | Neg: {n_prompt} | Seed: {seed}"); + + if previous_video_path_for_cleanup is not None and os.path.exists(previous_video_path_for_cleanup): + try: + os.remove(previous_video_path_for_cleanup) + print(f"Cleaned up previous part: {previous_video_path_for_cleanup}") + except Exception as e_del: + print(f"Error deleting previous partial video {previous_video_path_for_cleanup}: {e_del}") + previous_video_path_for_cleanup = current_output_filename + + final_video_path_for_item = previous_video_path_for_cleanup + if extension_only: + print(f"Final extension-only video for seed {seed} saved as: {final_video_path_for_item}") + else: + print(f"Final video for seed {seed} (extension) saved as: {final_video_path_for_item}") + + except Exception as e_outer: + traceback.print_exc() + print(f"Error during extension generation: {e_outer}") + + finally: + if not high_vram: + unload_complete_models(text_encoder, text_encoder_2, image_encoder, vae, transformer) + print("--- Extension work cycle finished. 
---") + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description="FramePack F1 Video Extension CLI") + + parser.add_argument('--input_video', type=str, required=True, help='Path to the input video file for extension.') + parser.add_argument('--prompt', type=str, required=True, help='Prompt for video generation.') + parser.add_argument('--n_prompt', type=str, default="", help='Negative prompt.') + parser.add_argument('--seed', type=int, default=31337, help='Seed for generation.') + parser.add_argument('--resolution_max_dim', type=int, default=640, help='Target resolution (max width or height for bucket search).') + parser.add_argument('--total_second_length', type=float, default=5.0, help='Additional video length to generate (seconds).') + parser.add_argument('--latent_window_size', type=int, default=9, help='Latent window size (frames).') + parser.add_argument('--steps', type=int, default=25, help='Number of inference steps.') + parser.add_argument('--cfg', type=float, default=1.0, help='CFG Scale (Classifier Free Guidance).') + parser.add_argument('--gs', type=float, default=3.0, help='Distilled CFG Scale (Embedded CFG).') + parser.add_argument('--rs', type=float, default=0.0, help='CFG Re-Scale (usually 0.0).') + parser.add_argument('--gpu_memory_preservation', type=float, default=6.0, help='GPU memory to preserve (GB) for low VRAM mode.') + parser.add_argument('--use_teacache', action='store_true', default=False, help='Enable TeaCache.') + parser.add_argument('--no_resize', action='store_true', default=False, help='Force original video resolution for input video encoding.') + parser.add_argument('--mp4_crf', type=int, default=16, help='MP4 CRF value (0-51, lower is better quality).') + parser.add_argument('--num_clean_frames', type=int, default=5, help='Number of 1x context frames from input video history for DiT conditioning.') + parser.add_argument('--vae_batch_size', type=int, default=-1, help='VAE batch size for input video encoding. 
Default: auto based on VRAM.') + parser.add_argument('--output_dir', type=str, default='./outputs/', help="Directory to save output videos.") + + parser.add_argument('--dit', type=str, required=True, help="Path to local DiT model weights file or directory.") + parser.add_argument('--vae', type=str, required=True, help="Path to local VAE model weights file or directory.") + parser.add_argument('--text_encoder1', type=str, required=True, help="Path to Text Encoder 1 (Llama) WEIGHT FILE.") + parser.add_argument('--text_encoder2', type=str, required=True, help="Path to Text Encoder 2 (CLIP) WEIGHT FILE.") + parser.add_argument('--image_encoder', type=str, required=True, help="Path to Image Encoder (SigLIP) WEIGHT FILE.") + + parser.add_argument('--attn_mode', type=str, default="torch", help="Attention mode for DiT.") + parser.add_argument('--fp8_llm', action='store_true', help="Use fp8 for Text Encoder 1 (Llama).") + parser.add_argument("--vae_chunk_size", type=int, default=None, help="Chunk size for CausalConv3d in VAE.") + parser.add_argument("--vae_spatial_tile_sample_min_size", type=int, default=None, help="Spatial tile sample min size for VAE.") + + parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path(s).") + parser.add_argument("--lora_multiplier", type=float, nargs="*", default=[1.0], help="LoRA multiplier(s).") + parser.add_argument("--include_patterns", type=str, nargs="*", default=None, help="LoRA module include patterns.") + parser.add_argument("--exclude_patterns", type=str, nargs="*", default=None, help="LoRA module exclude patterns.") + parser.add_argument('--extension_only', action='store_true', help="Save only the extension video without the input video attached.") + parser.add_argument('--start_guidance_image', type=str, default=None, + help='Optional path to an image to guide the start of the generated extension.') + parser.add_argument('--start_guidance_image_clip_weight', type=float, default=0.75, + help='Weight for the start_guidance_image CLIP embedding (0.0 to 1.0). Default 0.75. Blends with input video\'s first frame CLIP.') + parser.add_argument('--use_guidance_image_as_first_latent', action='store_true', default=False, + help='If true, use the VAE latent of the start_guidance_image as the initial conditioning latent for the first generated segment.') + + args = parser.parse_args() + + current_device_str = str(gpu if torch.cuda.is_available() else cpu) + args.device = current_device_str + + for model_arg_name in ['dit', 'vae', 'text_encoder1', 'text_encoder2', 'image_encoder']: + path_val = getattr(args, model_arg_name) + if not os.path.exists(path_val): + parser.error(f"Path for --{model_arg_name} not found: {path_val}") + + outputs_folder = args.output_dir + os.makedirs(outputs_folder, exist_ok=True) + print(f"Outputting extensions to: {outputs_folder}") + + free_mem_gb = get_cuda_free_memory_gb(gpu if torch.cuda.is_available() else None) + high_vram = free_mem_gb > 100 + print(f'Free VRAM {free_mem_gb:.2f} GB. 
High-VRAM Mode: {high_vram}') + + if args.vae_batch_size == -1: + if free_mem_gb >= 18: args.vae_batch_size = 64 + elif free_mem_gb >= 10: args.vae_batch_size = 32 + else: args.vae_batch_size = 16 + print(f"Auto-set VAE batch size to: {args.vae_batch_size}") + + print("Loading models for extension...") + loading_device_str = str(cpu) + + transformer = load_packed_model( + device=loading_device_str, + dit_path=args.dit, + attn_mode=args.attn_mode, + loading_device=loading_device_str + ) + print("DiT loaded.") + + if args.lora_weight is not None and len(args.lora_weight) > 0: + print("Merging LoRA weights for extension...") + if len(args.lora_multiplier) == 1 and len(args.lora_weight) > 1: + args.lora_multiplier = args.lora_multiplier * len(args.lora_weight) + elif len(args.lora_multiplier) != len(args.lora_weight): + parser.error(f"Number of LoRA weights ({len(args.lora_weight)}) and multipliers ({len(args.lora_multiplier)}) must match, or provide a single multiplier.") + + try: + if not hasattr(args, 'lycoris'): + args.lycoris = False + if not hasattr(args, 'save_merged_model'): + args.save_merged_model = None + current_device_for_lora = torch.device(loading_device_str) + + + merge_lora_weights( + lora_framepack, + transformer, + args, + current_device_for_lora + ) + print("LoRA weights merged successfully using the same call structure as fpack_generate_video.py.") + + except Exception as e_lora: + print(f"Error merging LoRA weights: {e_lora}") + traceback.print_exc() + + vae = load_vae( + vae_path=args.vae, + vae_chunk_size=args.vae_chunk_size, + vae_spatial_tile_sample_min_size=args.vae_spatial_tile_sample_min_size, + device=loading_device_str + ) + print("VAE loaded.") + + tokenizer, text_encoder = load_text_encoder1(args, device=loading_device_str) + print("Text Encoder 1 and Tokenizer 1 loaded.") + tokenizer_2, text_encoder_2 = load_text_encoder2(args) + print("Text Encoder 2 and Tokenizer 2 loaded.") + feature_extractor, image_encoder = load_image_encoders(args) + print("Image Encoder and Feature Extractor loaded.") + + all_models_list = [transformer, vae, text_encoder, text_encoder_2, image_encoder] + for model_obj in all_models_list: + if model_obj is not None: + model_obj.eval().requires_grad_(False) + + if transformer: transformer.to(dtype=torch.bfloat16) + if vae: vae.to(dtype=torch.float16) + if image_encoder: image_encoder.to(dtype=torch.float16) + if text_encoder: text_encoder.to(dtype=torch.float16) + if text_encoder_2: text_encoder_2.to(dtype=torch.float16) + + if transformer: + transformer.high_quality_fp32_output_for_inference = True + print('Transformer: high_quality_fp32_output_for_inference = True') + + if vae and not high_vram: + vae.enable_slicing() + vae.enable_tiling() + + target_gpu_device_str = str(gpu if torch.cuda.is_available() else cpu) + if not high_vram and torch.cuda.is_available(): + print("Low VRAM mode: Setting up dynamic swapping for DiT and Text Encoder 1.") + if transformer: DynamicSwapInstaller.install_model(transformer, device=target_gpu_device_str) + if text_encoder: DynamicSwapInstaller.install_model(text_encoder, device=target_gpu_device_str) + if vae: vae.to(cpu) + if text_encoder_2: text_encoder_2.to(cpu) + if image_encoder: image_encoder.to(cpu) + elif torch.cuda.is_available(): + print(f"High VRAM mode: Moving all models to {target_gpu_device_str}.") + for model_obj in all_models_list: + if model_obj is not None: model_obj.to(target_gpu_device_str) + else: + print("Running on CPU. 
Models remain on CPU.") + + print("All models loaded and configured for extension.") + + actual_gs_cli = args.gs + if args.cfg > 1.0: + actual_gs_cli = 1.0 + print(f"CFG > 1.0 detected ({args.cfg}), overriding GS to 1.0 from {args.gs}.") + + do_extension_work( + input_video_path=args.input_video, + prompt=args.prompt, + n_prompt=args.n_prompt, + seed=args.seed, + resolution_max_dim=args.resolution_max_dim, + additional_second_length=args.total_second_length, + latent_window_size=args.latent_window_size, + steps=args.steps, + cfg=args.cfg, + gs=actual_gs_cli, + rs=args.rs, + gpu_memory_preservation=args.gpu_memory_preservation, + use_teacache=args.use_teacache, + no_resize=args.no_resize, + mp4_crf=args.mp4_crf, + num_clean_frames=args.num_clean_frames, + vae_batch_size=args.vae_batch_size, + extension_only=args.extension_only + ) + + print("Video extension process completed.") \ No newline at end of file diff --git a/f_video_end_cli_local.py b/f_video_end_cli_local.py new file mode 100644 index 0000000000000000000000000000000000000000..f4d52d63d01773f563e9e844c2acbcc59c07d44b --- /dev/null +++ b/f_video_end_cli_local.py @@ -0,0 +1,854 @@ +import os +import torch +import traceback +import einops +import numpy as np +import argparse +import math +import decord +from tqdm import tqdm +import pathlib +from datetime import datetime +import imageio_ffmpeg +import tempfile +import shutil +import subprocess +import sys + +from PIL import Image +try: + from frame_pack.hunyuan_video_packed import load_packed_model + from frame_pack.framepack_utils import ( + load_vae, + load_text_encoder1, + load_text_encoder2, + load_image_encoders + ) + from frame_pack.hunyuan import encode_prompt_conds, vae_decode, vae_encode # vae_decode_fake might be needed for previews if added + from frame_pack.utils import crop_or_pad_yield_mask, soft_append_bcthw, resize_and_center_crop, generate_timestamp + from frame_pack.k_diffusion_hunyuan import sample_hunyuan + from frame_pack.clip_vision import hf_clip_vision_encode + from frame_pack.bucket_tools import find_nearest_bucket + from diffusers_helper.utils import save_bcthw_as_mp4 # from a common helper library + from diffusers_helper.memory import cpu, gpu, get_cuda_free_memory_gb, \ + move_model_to_device_with_memory_preservation, \ + offload_model_from_device_for_memory_preservation, \ + fake_diffusers_current_device, DynamicSwapInstaller, \ + unload_complete_models, load_model_as_complete + # For LoRA + from networks import lora_framepack + try: + from lycoris.kohya import create_network_from_weights + except ImportError: + pass # Lycoris optional + from base_wan_generate_video import merge_lora_weights # Assuming this is accessible +except ImportError as e: + print(f"Error importing FramePack related modules: {e}. 
Ensure they are in PYTHONPATH.") + sys.exit(1) + + +# --- Global Model Variables --- +text_encoder = None +text_encoder_2 = None +tokenizer = None +tokenizer_2 = None +vae = None +feature_extractor = None +image_encoder = None +transformer = None + +high_vram = False +free_mem_gb = 0.0 + +outputs_folder = './outputs/' # Default, can be overridden by --output_dir + +@torch.no_grad() +def video_encode(video_path, resolution, no_resize, vae_model, vae_batch_size=16, device="cuda", width=None, height=None): + video_path = str(pathlib.Path(video_path).resolve()) + print(f"Processing video for encoding: {video_path}") + + if device == "cuda" and not torch.cuda.is_available(): + print("CUDA is not available, falling back to CPU for video_encode") + device = "cpu" + + try: + print("Initializing VideoReader...") + vr = decord.VideoReader(video_path) + fps = vr.get_avg_fps() + if fps == 0: + print("Warning: VideoReader reported FPS as 0. Attempting to get it via OpenCV.") + import cv2 + cap = cv2.VideoCapture(video_path) + fps_cv = cap.get(cv2.CAP_PROP_FPS) + cap.release() + if fps_cv > 0: + fps = fps_cv + print(f"Using FPS from OpenCV: {fps}") + else: + # Fallback FPS if all else fails + fps = 25 + print(f"Failed to determine FPS for the input video. Defaulting to {fps} FPS.") + + + num_real_frames = len(vr) + print(f"Video loaded: {num_real_frames} frames, FPS: {fps}") + + latent_size_factor = 4 # Hunyuan VAE downsamples by 8, but generation often uses 4x frame groups + num_frames = (num_real_frames // latent_size_factor) * latent_size_factor + if num_frames != num_real_frames: + print(f"Truncating video from {num_real_frames} to {num_frames} frames for latent size compatibility (multiple of {latent_size_factor})") + + if num_frames == 0: + raise ValueError(f"Video too short ({num_real_frames} frames) or becomes 0 after truncation. 
Needs at least {latent_size_factor} frames.") + num_real_frames = num_frames + + print("Reading video frames...") + frames_np_all = vr.get_batch(range(num_real_frames)).asnumpy() + print(f"Frames read: {frames_np_all.shape}") + + native_height, native_width = frames_np_all.shape[1], frames_np_all.shape[2] + print(f"Native video resolution: {native_width}x{native_height}") + + target_h_arg = native_height if height is None else height + target_w_arg = native_width if width is None else width + + if not no_resize: + actual_target_height, actual_target_width = find_nearest_bucket(target_h_arg, target_w_arg, resolution=resolution) + print(f"Adjusted resolution for VAE encoding: {actual_target_width}x{actual_target_height}") + else: + actual_target_width = (native_width // 8) * 8 + actual_target_height = (native_height // 8) * 8 + if actual_target_width != native_width or actual_target_height != native_height: + print(f"Using native resolution, adjusted to be divisible by 8: {actual_target_width}x{actual_target_height}") + else: + print(f"Using native resolution without resizing: {actual_target_width}x{actual_target_height}") + + processed_frames_list = [] + for frame_idx in range(frames_np_all.shape[0]): + frame = frames_np_all[frame_idx] + frame_resized_np = resize_and_center_crop(frame, target_width=actual_target_width, target_height=actual_target_height) + processed_frames_list.append(frame_resized_np) + + processed_frames_np_stack = np.stack(processed_frames_list) + print(f"Frames preprocessed: {processed_frames_np_stack.shape}") + + input_image_np_for_clip_first = processed_frames_np_stack[0] + input_image_np_for_clip_last = processed_frames_np_stack[-1] + + + print("Converting frames to tensor...") + frames_pt = torch.from_numpy(processed_frames_np_stack).float() / 127.5 - 1.0 + frames_pt = frames_pt.permute(0, 3, 1, 2) # B, H, W, C -> B, C, H, W + frames_pt = frames_pt.unsqueeze(0).permute(0, 2, 1, 3, 4) # B, C, H, W -> 1, C, B, H, W (as VAE expects 1,C,F,H,W) + print(f"Tensor shape for VAE: {frames_pt.shape}") + + input_video_pixels_cpu = frames_pt.clone().cpu() + + print(f"Moving VAE and tensor to device: {device}") + vae_model.to(device) + frames_pt = frames_pt.to(device) + + print(f"Encoding input video frames with VAE (batch size: {vae_batch_size})") + all_latents_list = [] + vae_model.eval() + with torch.no_grad(): + for i in tqdm(range(0, frames_pt.shape[2], vae_batch_size), desc="VAE Encoding Video Frames", mininterval=0.1): + batch_frames_pt = frames_pt[:, :, i:i + vae_batch_size] + try: + batch_latents = vae_encode(batch_frames_pt, vae_model) + all_latents_list.append(batch_latents.cpu()) + except RuntimeError as e: + print(f"Error during VAE encoding: {str(e)}") + if "out of memory" in str(e).lower() and device == "cuda": + print("CUDA out of memory during VAE encoding. 
Try reducing --vae_batch_size or use CPU for VAE.") + raise + + history_latents_cpu = torch.cat(all_latents_list, dim=2) + print(f"History latents shape (original video): {history_latents_cpu.shape}") + + start_latent_cpu = history_latents_cpu[:, :, :1].clone() + end_of_input_video_latent_cpu = history_latents_cpu[:, :, -1:].clone() + print(f"Start latent shape (for conditioning): {start_latent_cpu.shape}") + print(f"End of input video latent shape: {end_of_input_video_latent_cpu.shape}") + + + if device == "cuda": + vae_model.to(cpu) # Move VAE back to CPU + torch.cuda.empty_cache() + print("VAE moved back to CPU, CUDA cache cleared") + + return (start_latent_cpu, input_image_np_for_clip_first, + history_latents_cpu, fps, + actual_target_height, actual_target_width, + input_video_pixels_cpu, + end_of_input_video_latent_cpu, input_image_np_for_clip_last) + + except Exception as e: + print(f"Error in video_encode: {str(e)}") + traceback.print_exc() + raise + +@torch.no_grad() +def image_encode(image_np, target_width, target_height, vae_model, image_encoder_model, feature_extractor_model, device="cuda"): + """ + Encode a single image into a latent and compute its CLIP vision embedding. + """ + global high_vram # Use global high_vram status + print("Processing single image for encoding (e.g., end_frame)...") + try: + print(f"Using target resolution for image encoding: {target_width}x{target_height}") + + processed_image_np = resize_and_center_crop(image_np, target_width=target_width, target_height=target_height) + + image_pt = torch.from_numpy(processed_image_np).float() / 127.5 - 1.0 + image_pt = image_pt.permute(2, 0, 1).unsqueeze(0).unsqueeze(2) # N C F H W (N=1, F=1) + + target_vae_device = device + if not high_vram: load_model_as_complete(vae_model, target_device=target_vae_device) + else: vae_model.to(target_vae_device) + image_pt_device = image_pt.to(target_vae_device) + + latent = vae_encode(image_pt_device, vae_model).cpu() # Encode and move to CPU + print(f"Single image VAE output shape (latent): {latent.shape}") + + if not high_vram: unload_complete_models(vae_model) # Offload VAE if low VRAM + + target_img_enc_device = device + if not high_vram: load_model_as_complete(image_encoder_model, target_device=target_img_enc_device) + else: image_encoder_model.to(target_img_enc_device) + + clip_embedding_output = hf_clip_vision_encode(processed_image_np, feature_extractor_model, image_encoder_model) + clip_embedding = clip_embedding_output.last_hidden_state.cpu() # Encode and move to CPU + print(f"Single image CLIP embedding shape: {clip_embedding.shape}") + + if not high_vram: unload_complete_models(image_encoder_model) # Offload image encoder if low VRAM + + if device == "cuda": + torch.cuda.empty_cache() + # print("CUDA cache cleared after single image_encode") + + return latent, clip_embedding, processed_image_np + + except Exception as e: + print(f"Error in image_encode: {str(e)}") + traceback.print_exc() + raise + +def set_mp4_comments_imageio_ffmpeg(input_file, comments): + try: + ffmpeg_path = imageio_ffmpeg.get_ffmpeg_exe() + if not os.path.exists(input_file): + print(f"Error: Input file {input_file} does not exist") + return False + temp_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False).name + command = [ + ffmpeg_path, '-i', input_file, '-metadata', f'comment={comments}', + '-c:v', 'copy', '-c:a', 'copy', '-y', temp_file + ] + result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False) + if result.returncode == 0: + 
shutil.move(temp_file, input_file) + print(f"Successfully added comments to {input_file}") + return True + else: + if os.path.exists(temp_file): os.remove(temp_file) + print(f"Error: FFmpeg failed with message:\n{result.stderr}") + return False + except Exception as e: + if 'temp_file' in locals() and os.path.exists(temp_file): os.remove(temp_file) + print(f"Error saving prompt to video metadata, ffmpeg may be required: "+str(e)) + return False + +@torch.no_grad() +def do_generation_work( + input_video_path, prompt, n_prompt, seed, + end_frame_path, end_frame_weight, # New arguments + resolution_max_dim, + additional_second_length, + latent_window_size, steps, cfg, gs, rs, + gpu_memory_preservation, use_teacache, no_resize, mp4_crf, + num_clean_frames, vae_batch_size, + extension_only +): + global high_vram, text_encoder, text_encoder_2, tokenizer, tokenizer_2, vae, feature_extractor, image_encoder, transformer, args + + print('--- Starting Video Generation (with End Frame support) ---') + + try: + # --- Text Encoding --- + print('Text encoding...') + target_text_enc_device = str(gpu if torch.cuda.is_available() else cpu) + if not high_vram: + if text_encoder: fake_diffusers_current_device(text_encoder, target_text_enc_device) # DynamicSwapInstaller for text_encoder + if text_encoder_2: load_model_as_complete(text_encoder_2, target_device=target_text_enc_device) + else: + if text_encoder: text_encoder.to(target_text_enc_device) + if text_encoder_2: text_encoder_2.to(target_text_enc_device) + + llama_vec_gpu, clip_l_pooler_gpu = encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2) + if cfg == 1.0: # Note: Original FramePack usually uses gs, cfg=1 means gs is active + llama_vec_n_gpu, clip_l_pooler_n_gpu = torch.zeros_like(llama_vec_gpu), torch.zeros_like(clip_l_pooler_gpu) + else: # If cfg > 1.0, it implies standard CFG, so n_prompt is used. gs should be 1.0 in this case. 
+ llama_vec_n_gpu, clip_l_pooler_n_gpu = encode_prompt_conds(n_prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2) + + # Store on CPU + llama_vec_padded_cpu, llama_attention_mask_cpu = crop_or_pad_yield_mask(llama_vec_gpu.cpu(), length=512) + llama_vec_n_padded_cpu, llama_attention_mask_n_cpu = crop_or_pad_yield_mask(llama_vec_n_gpu.cpu(), length=512) + clip_l_pooler_cpu = clip_l_pooler_gpu.cpu() + clip_l_pooler_n_cpu = clip_l_pooler_n_gpu.cpu() + + if not high_vram: unload_complete_models(text_encoder_2) # text_encoder is managed by DynamicSwap + + # --- Video and End Frame Encoding --- + print('Encoding input video...') + video_encode_device = str(gpu if torch.cuda.is_available() else cpu) + (start_latent_input_cpu, input_image_np_first, + video_latents_history_cpu, fps, height, width, + input_video_pixels_cpu, + end_of_input_video_latent_cpu, input_image_np_last) = video_encode( + input_video_path, resolution_max_dim, no_resize, vae, + vae_batch_size=vae_batch_size, device=video_encode_device, + width=None, height=None # video_encode will use resolution_max_dim + ) + if fps <= 0: raise ValueError("FPS from input video is 0 or invalid.") + + end_latent_from_file_cpu, end_clip_embedding_from_file_cpu = None, None + if end_frame_path: + print(f"Encoding provided end frame from: {end_frame_path}") + end_frame_pil = Image.open(end_frame_path).convert("RGB") + end_frame_np = np.array(end_frame_pil) + end_latent_from_file_cpu, end_clip_embedding_from_file_cpu, _ = image_encode( + end_frame_np, target_width=width, target_height=height, + vae_model=vae, image_encoder_model=image_encoder, + feature_extractor_model=feature_extractor, device=video_encode_device + ) + + # --- CLIP Vision Encoding for first and last frames of input video --- + print('CLIP Vision encoding for input video frames...') + target_img_enc_device = str(gpu if torch.cuda.is_available() else cpu) + if not high_vram: load_model_as_complete(image_encoder, target_device=target_img_enc_device) + else: image_encoder.to(target_img_enc_device) + + # For original FramePack, image_embeddings in sample_hunyuan often comes from the *start* image. + # Script 2 uses end_of_input_video_embedding or a blend with the explicit end_frame. + # We will follow script 2 for conditioning. 
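+        # Note on the blend applied below: the image conditioning handed to the sampler is a linear
+        # interpolation in CLIP-vision embedding space,
+        #   final_emb = (1 - end_frame_weight) * emb(last input frame) + end_frame_weight * emb(end_frame image),
+        # so --end_frame_weight 1.0 conditions purely on the supplied end frame, while 0.0 ignores its
+        # embedding (the end-frame latent itself is still used separately for clean_latents_post).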
+ # start_clip_embedding_cpu = hf_clip_vision_encode(input_image_np_first, feature_extractor, image_encoder).last_hidden_state.cpu() + end_of_input_video_clip_embedding_cpu = hf_clip_vision_encode(input_image_np_last, feature_extractor, image_encoder).last_hidden_state.cpu() + + if not high_vram: unload_complete_models(image_encoder) + + # Determine final image embedding for sampling loop + if end_clip_embedding_from_file_cpu is not None: + print(f"Blending end-of-input-video embedding with provided end_frame embedding (weight: {end_frame_weight})") + final_clip_embedding_for_sampling_cpu = \ + (1.0 - end_frame_weight) * end_of_input_video_clip_embedding_cpu + \ + end_frame_weight * end_clip_embedding_from_file_cpu + else: + print("Using end-of-input-video's last frame embedding for image conditioning.") + final_clip_embedding_for_sampling_cpu = end_of_input_video_clip_embedding_cpu.clone() + + # --- Prepare for Sampling Loop --- + target_transformer_device = str(gpu if torch.cuda.is_available() else cpu) + if not high_vram: + if transformer: move_model_to_device_with_memory_preservation(transformer, target_device=target_transformer_device, preserved_memory_gb=gpu_memory_preservation) + else: + if transformer: transformer.to(target_transformer_device) + + cond_device = transformer.device + cond_dtype = transformer.dtype + + # Move conditioning tensors to transformer's device and dtype + llama_vec = llama_vec_padded_cpu.to(device=cond_device, dtype=cond_dtype) + llama_attention_mask = llama_attention_mask_cpu.to(device=cond_device) # Mask is usually bool/int + clip_l_pooler = clip_l_pooler_cpu.to(device=cond_device, dtype=cond_dtype) + llama_vec_n = llama_vec_n_padded_cpu.to(device=cond_device, dtype=cond_dtype) + llama_attention_mask_n = llama_attention_mask_n_cpu.to(device=cond_device) + clip_l_pooler_n = clip_l_pooler_n_cpu.to(device=cond_device, dtype=cond_dtype) + + # This is the image embedding that will be used in the sampling loop + image_embeddings_for_sampling_loop = final_clip_embedding_for_sampling_cpu.to(device=cond_device, dtype=cond_dtype) + + # start_latent_for_initial_cond_gpu is the first frame of input video, used for clean_latents_pre + # However, script 2 uses `video_latents[:, :, -min(effective_clean_frames, video_latents.shape[2]):]` for clean_latents_pre. + # And `start_latent` for sample_hunyuan's `clean_latents` is `torch.cat([start_latent.to(history_latents), clean_latents_1x], dim=2)` + # For backward generation, the "start_latent" concept for `sample_hunyuan`'s `clean_latents` argument + # is often the *last frame of the input video* when generating the chunk closest to the input video. + # Let's use end_of_input_video_latent_cpu for this role when appropriate. 
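+        # Worked example for the section count computed just below (assuming a 30 fps input and the
+        # default latent_window_size of 9):
+        #   num_output_pixel_frames_per_section = 9 * 4 = 36 pixel frames per section,
+        #   so a 5.0 s extension needs round((5.0 * 30) / 36) = 4 sections (~4.8 s of new video).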
+ + num_output_pixel_frames_per_section = latent_window_size * 4 # Not -3 here, as this is for total section calc + if num_output_pixel_frames_per_section == 0: + raise ValueError("latent_window_size * 4 is zero, cannot calculate total_extension_latent_sections.") + total_extension_latent_sections = int(max(round((additional_second_length * fps) / num_output_pixel_frames_per_section), 1)) + + print(f"Input video FPS: {fps}, Target additional length: {additional_second_length}s") + print(f"Generating {total_extension_latent_sections} new sections for extension (approx {total_extension_latent_sections * num_output_pixel_frames_per_section / fps:.2f}s).") + + job_id_base = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + \ + f"_framepack-vidEndFrm_{width}x{height}_{additional_second_length:.1f}s_seed{seed}_s{steps}_gs{gs}_cfg{cfg}" + + job_id = job_id_base + if args.extension_only: # <<< Access args directly + job_id += "_extonly" + print("Extension-only mode enabled. Filenames will reflect this.") + + rnd = torch.Generator("cpu").manual_seed(seed) + + # Initialize history for generated latents (starts empty or with end_latent_from_file) + if end_latent_from_file_cpu is not None: + # This assumes end_latent_from_file_cpu is [1,C,1,H,W], we might need more frames if it's a seed + # Script 2's logic for clean_latents_post when is_end_of_video seems to use just 1 frame. + history_latents_generated_cpu = end_latent_from_file_cpu.clone() + else: + channels_dim = video_latents_history_cpu.shape[1] # Get from input video latents + latent_h, latent_w = height // 8, width // 8 + history_latents_generated_cpu = torch.empty((1, channels_dim, 0, latent_h, latent_w), dtype=torch.float32, device='cpu') + + # Initialize history for decoded pixels (starts empty) + history_pixels_decoded_cpu = None + + total_generated_latent_frames_count = history_latents_generated_cpu.shape[2] + previous_video_path_for_cleanup = None + + # Backward generation loop (from demo_gradio_video+endframe.py) + latent_paddings = list(reversed(range(total_extension_latent_sections))) + if total_extension_latent_sections > 4: # Heuristic from script 2 + latent_paddings = [3] + [2] * (total_extension_latent_sections - 3) + [1, 0] + + for loop_idx, latent_padding_val in enumerate(latent_paddings): + current_section_num_from_end = loop_idx + 1 + is_start_of_extension = (latent_padding_val == 0) # This is the chunk closest to input video + is_end_of_extension = (latent_padding_val == latent_paddings[0]) # This is the chunk furthest from input video + + print(f"--- Generating Extension: Seed {seed}: Section {current_section_num_from_end}/{total_extension_latent_sections} (backward), padding={latent_padding_val} ---") + + if transformer: transformer.initialize_teacache(enable_teacache=use_teacache, num_steps=steps if use_teacache else 0) + progress_bar_sampler = tqdm(total=steps, desc=f"Sampling Extension Section {current_section_num_from_end}/{total_extension_latent_sections}", file=sys.stdout, dynamic_ncols=True) + def sampler_callback_cli(d): progress_bar_sampler.update(1) + + # Context frame calculation (from demo_gradio_video+endframe.py worker) + # `available_frames` for context refers to previously *generated* frames or input video frames + # For `clean_latents_pre`, it's always from `video_latents_history_cpu` + # For `clean_latents_post`, `_2x`, `_4x`, it's from `history_latents_generated_cpu` + + effective_clean_frames_count = max(0, num_clean_frames - 1) if num_clean_frames > 1 else 1 + + # For clean_latents_pre (from input video) + 
# If is_start_of_extension, we might want stronger anchoring to input video. Script 2 uses full `effective_clean_frames_count`. + clean_latent_pre_frames_num = effective_clean_frames_count + if is_start_of_extension: # Closest to input video + clean_latent_pre_frames_num = 1 # Script 2 uses 1 to avoid jumpcuts from input video when generating chunk closest to it. + + # For clean_latents_post, _2x, _4x (from previously generated extension chunks) + available_generated_latents = history_latents_generated_cpu.shape[2] + + # `post_frames_num` is for clean_latents_post + post_frames_num = 1 if is_end_of_extension and end_latent_from_file_cpu is not None else effective_clean_frames_count + if is_end_of_extension and end_latent_from_file_cpu is not None: post_frames_num = 1 # script 2 detail for end_latent + + num_2x_frames_count = min(2, max(0, available_generated_latents - post_frames_num -1)) + num_4x_frames_count = min(16, max(0, available_generated_latents - post_frames_num - num_2x_frames_count)) + + # Latent indexing for sample_hunyuan (from script 2) + latent_padding_size_for_indices = latent_padding_val * latent_window_size + pixel_frames_to_generate_this_step = latent_window_size * 4 - 3 + + indices_tensor_gpu = torch.arange(0, + clean_latent_pre_frames_num + + latent_padding_size_for_indices + + latent_window_size + # Note: script 2 uses latent_window_size here for `latent_indices` count + post_frames_num + + num_2x_frames_count + + num_4x_frames_count + ).unsqueeze(0).to(cond_device) + + (clean_latent_indices_pre_gpu, + blank_indices_gpu, # For padding + latent_indices_for_denoising_gpu, # For new generation + clean_latent_indices_post_gpu, + clean_latent_2x_indices_gpu, + clean_latent_4x_indices_gpu + ) = indices_tensor_gpu.split( + [clean_latent_pre_frames_num, latent_padding_size_for_indices, latent_window_size, + post_frames_num, num_2x_frames_count, num_4x_frames_count], dim=1 + ) + clean_latent_indices_combined_gpu = torch.cat([clean_latent_indices_pre_gpu, clean_latent_indices_post_gpu], dim=1) + + # Prepare conditioning latents + # clean_latents_pre_cpu: from end of input video + actual_pre_frames_to_take = min(clean_latent_pre_frames_num, video_latents_history_cpu.shape[2]) + clean_latents_pre_cpu = video_latents_history_cpu[:, :, -actual_pre_frames_to_take:].clone() + if clean_latents_pre_cpu.shape[2] < clean_latent_pre_frames_num and clean_latents_pre_cpu.shape[2] > 0: # Pad if necessary + repeats = math.ceil(clean_latent_pre_frames_num / clean_latents_pre_cpu.shape[2]) + clean_latents_pre_cpu = clean_latents_pre_cpu.repeat(1,1,repeats,1,1)[:,:,:clean_latent_pre_frames_num] + elif clean_latents_pre_cpu.shape[2] == 0 and clean_latent_pre_frames_num > 0: # Should not happen if video_latents_history_cpu is valid + clean_latents_pre_cpu = torch.zeros((1,channels_dim,clean_latent_pre_frames_num,latent_h,latent_w),dtype=torch.float32) + + + # clean_latents_post_cpu, _2x_cpu, _4x_cpu: from start of `history_latents_generated_cpu` + current_offset_in_generated = 0 + + # Post frames + actual_post_frames_to_take = min(post_frames_num, history_latents_generated_cpu.shape[2]) + if is_end_of_extension and end_latent_from_file_cpu is not None: + clean_latents_post_cpu = end_latent_from_file_cpu.clone() # Should be [1,C,1,H,W] + else: + clean_latents_post_cpu = history_latents_generated_cpu[:,:, current_offset_in_generated : current_offset_in_generated + actual_post_frames_to_take].clone() + current_offset_in_generated += clean_latents_post_cpu.shape[2] + + if 
clean_latents_post_cpu.shape[2] < post_frames_num and clean_latents_post_cpu.shape[2] > 0: # Pad + repeats = math.ceil(post_frames_num / clean_latents_post_cpu.shape[2]) + clean_latents_post_cpu = clean_latents_post_cpu.repeat(1,1,repeats,1,1)[:,:,:post_frames_num] + elif clean_latents_post_cpu.shape[2] == 0 and post_frames_num > 0: # Fill with zeros if no history and no end_latent + clean_latents_post_cpu = torch.zeros((1,channels_dim,post_frames_num,latent_h,latent_w),dtype=torch.float32) + + # 2x frames + actual_2x_frames_to_take = min(num_2x_frames_count, history_latents_generated_cpu.shape[2] - current_offset_in_generated) + clean_latents_2x_cpu = history_latents_generated_cpu[:,:, current_offset_in_generated : current_offset_in_generated + actual_2x_frames_to_take].clone() + current_offset_in_generated += clean_latents_2x_cpu.shape[2] + if clean_latents_2x_cpu.shape[2] < num_2x_frames_count and clean_latents_2x_cpu.shape[2] > 0: # Pad + repeats = math.ceil(num_2x_frames_count / clean_latents_2x_cpu.shape[2]) + clean_latents_2x_cpu = clean_latents_2x_cpu.repeat(1,1,repeats,1,1)[:,:,:num_2x_frames_count] + elif clean_latents_2x_cpu.shape[2] == 0 and num_2x_frames_count > 0: + clean_latents_2x_cpu = torch.zeros((1,channels_dim,num_2x_frames_count,latent_h,latent_w),dtype=torch.float32) + + # 4x frames + actual_4x_frames_to_take = min(num_4x_frames_count, history_latents_generated_cpu.shape[2] - current_offset_in_generated) + clean_latents_4x_cpu = history_latents_generated_cpu[:,:, current_offset_in_generated : current_offset_in_generated + actual_4x_frames_to_take].clone() + if clean_latents_4x_cpu.shape[2] < num_4x_frames_count and clean_latents_4x_cpu.shape[2] > 0: # Pad + repeats = math.ceil(num_4x_frames_count / clean_latents_4x_cpu.shape[2]) + clean_latents_4x_cpu = clean_latents_4x_cpu.repeat(1,1,repeats,1,1)[:,:,:num_4x_frames_count] + elif clean_latents_4x_cpu.shape[2] == 0 and num_4x_frames_count > 0: + clean_latents_4x_cpu = torch.zeros((1,channels_dim,num_4x_frames_count,latent_h,latent_w),dtype=torch.float32) + +# Combine pre and post for `clean_latents` argument + clean_latents_for_sampler_gpu = torch.cat([ + clean_latents_pre_cpu.to(device=cond_device, dtype=torch.float32), + clean_latents_post_cpu.to(device=cond_device, dtype=torch.float32) + ], dim=2) + + # Ensure 2x and 4x latents are None if their frame counts are 0 + # The k_diffusion_hunyuan.sample_hunyuan and the DiT should handle None for these if indices are also empty. + clean_latents_2x_gpu = None + if num_2x_frames_count > 0 and clean_latents_2x_cpu.shape[2] > 0: + clean_latents_2x_gpu = clean_latents_2x_cpu.to(device=cond_device, dtype=torch.float32) + elif num_2x_frames_count > 0 and clean_latents_2x_cpu.shape[2] == 0: # Should have been filled with zeros if count > 0 + print(f"Warning: num_2x_frames_count is {num_2x_frames_count} but clean_latents_2x_cpu is empty. Defaulting to None.") + + + clean_latents_4x_gpu = None + if num_4x_frames_count > 0 and clean_latents_4x_cpu.shape[2] > 0: + clean_latents_4x_gpu = clean_latents_4x_cpu.to(device=cond_device, dtype=torch.float32) + elif num_4x_frames_count > 0 and clean_latents_4x_cpu.shape[2] == 0: + print(f"Warning: num_4x_frames_count is {num_4x_frames_count} but clean_latents_4x_cpu is empty. Defaulting to None.") + + # Also, ensure indices are None or empty if counts are zero. + # The split logic already ensures this if the split size is 0. + # clean_latent_2x_indices_gpu will be shape (B, 0) if num_2x_frames_count is 0. 
+ # The DiT model should correctly interpret an empty indices tensor or None for the corresponding latent. + generated_latents_gpu_step = sample_hunyuan( + transformer=transformer, sampler='unipc', width=width, height=height, + frames=pixel_frames_to_generate_this_step, # Num frames for current chunk + real_guidance_scale=cfg, distilled_guidance_scale=gs, guidance_rescale=rs, + num_inference_steps=steps, generator=rnd, + prompt_embeds=llama_vec, prompt_embeds_mask=llama_attention_mask, prompt_poolers=clip_l_pooler, + negative_prompt_embeds=llama_vec_n, negative_prompt_embeds_mask=llama_attention_mask_n, negative_prompt_poolers=clip_l_pooler_n, + device=cond_device, dtype=cond_dtype, + image_embeddings=image_embeddings_for_sampling_loop, # Use the blended/final one + latent_indices=latent_indices_for_denoising_gpu, + clean_latents=clean_latents_for_sampler_gpu, + clean_latent_indices=clean_latent_indices_combined_gpu, + clean_latents_2x=clean_latents_2x_gpu, # Can be None + clean_latent_2x_indices=clean_latent_2x_indices_gpu if num_2x_frames_count > 0 else None, # Pass None if count is 0 + clean_latents_4x=clean_latents_4x_gpu, # Can be None + clean_latent_4x_indices=clean_latent_4x_indices_gpu if num_4x_frames_count > 0 else None, # Pass None if count is 0 + callback=sampler_callback_cli, + ) + if progress_bar_sampler: progress_bar_sampler.close() + + # If this was the chunk closest to input video, prepend the last frame of input video for smoother transition + if is_start_of_extension: + generated_latents_gpu_step = torch.cat([ + end_of_input_video_latent_cpu.to(generated_latents_gpu_step), # Use actual last frame latent + generated_latents_gpu_step + ], dim=2) + + # Prepend generated latents to history + history_latents_generated_cpu = torch.cat([generated_latents_gpu_step.cpu(), history_latents_generated_cpu], dim=2) + total_generated_latent_frames_count = history_latents_generated_cpu.shape[2] + + # --- Decode and Append Pixels --- + target_vae_device = str(gpu if torch.cuda.is_available() else cpu) + if not high_vram: + if transformer: offload_model_from_device_for_memory_preservation(transformer, target_device=target_transformer_device, preserved_memory_gb=gpu_memory_preservation) + if vae: load_model_as_complete(vae, target_device=target_vae_device) + else: + if vae: vae.to(target_vae_device) + + # Decode the newly generated part (or a relevant segment for stitching) + # Script 2 decodes `real_history_latents[:, :, :section_latent_frames]` + # section_latent_frames = (latent_window_size * 2 + 1) if is_start_of_video else (latent_window_size * 2) + num_latents_to_decode_for_stitch = (latent_window_size * 2 + 1) if is_start_of_extension else (latent_window_size * 2) + num_latents_to_decode_for_stitch = min(num_latents_to_decode_for_stitch, history_latents_generated_cpu.shape[2]) + + latents_for_current_decode_gpu = history_latents_generated_cpu[:, :, :num_latents_to_decode_for_stitch].to(target_vae_device) + + pixels_for_current_part_decoded_cpu = vae_decode(latents_for_current_decode_gpu, vae).cpu() + + # Soft append pixels (current_pixels, history_pixels, overlap) + overlap_for_soft_append = latent_window_size * 4 - 3 + + if history_pixels_decoded_cpu is None: + history_pixels_decoded_cpu = pixels_for_current_part_decoded_cpu + else: + overlap_actual = min(overlap_for_soft_append, history_pixels_decoded_cpu.shape[2], pixels_for_current_part_decoded_cpu.shape[2]) + if overlap_actual <=0: # Should not happen with proper windowing + history_pixels_decoded_cpu = 
torch.cat([pixels_for_current_part_decoded_cpu, history_pixels_decoded_cpu], dim=2) # Simple prepend + else: + history_pixels_decoded_cpu = soft_append_bcthw( + pixels_for_current_part_decoded_cpu, # Current (prepended) + history_pixels_decoded_cpu, # History + overlap=overlap_actual + ) + + if not high_vram: + if vae: unload_complete_models(vae) + if transformer and not is_start_of_extension : # Reload transformer for next iter + move_model_to_device_with_memory_preservation(transformer, target_device=target_transformer_device, preserved_memory_gb=gpu_memory_preservation) + + # Save intermediate video + current_output_filename = os.path.join(outputs_folder, f'{job_id}_part{current_section_num_from_end}_totalframes{history_pixels_decoded_cpu.shape[2]}.mp4') + save_bcthw_as_mp4(history_pixels_decoded_cpu, current_output_filename, fps=fps, crf=mp4_crf) + print(f"MP4 Preview for section {current_section_num_from_end} saved: {current_output_filename}") + set_mp4_comments_imageio_ffmpeg(current_output_filename, f"Prompt: {prompt} | Neg: {n_prompt} | Seed: {seed}"); + + if previous_video_path_for_cleanup is not None and os.path.exists(previous_video_path_for_cleanup): + try: os.remove(previous_video_path_for_cleanup) + except Exception as e_del: print(f"Error deleting {previous_video_path_for_cleanup}: {e_del}") + previous_video_path_for_cleanup = current_output_filename + + if is_start_of_extension: # Last iteration of backward loop + break + + # --- Final Video Assembly --- + if args.extension_only: # <<< Access args directly + print("Saving only the generated extension...") + # history_pixels_decoded_cpu already contains only the generated extension due to backward generation + # and how it's accumulated. + video_to_save_cpu = history_pixels_decoded_cpu + final_output_filename_suffix = "_extension_only_final.mp4" + final_log_message = "Final extension-only video saved:" + else: + print("Appending generated extension to the input video...") + # input_video_pixels_cpu is (1, C, F_in, H, W) + # history_pixels_decoded_cpu is (1, C, F_ext, H, W) + video_to_save_cpu = torch.cat([input_video_pixels_cpu, history_pixels_decoded_cpu], dim=2) + final_output_filename_suffix = "_final.mp4" + final_log_message = "Final extended video saved:" + + final_output_filename = os.path.join(outputs_folder, f'{job_id}{final_output_filename_suffix}') # job_id already has _extonly if needed + save_bcthw_as_mp4(video_to_save_cpu, final_output_filename, fps=fps, crf=mp4_crf) + print(f"{final_log_message} {final_output_filename}") + set_mp4_comments_imageio_ffmpeg(final_output_filename, f"Prompt: {prompt} | Neg: {n_prompt} | Seed: {seed}"); + + if previous_video_path_for_cleanup is not None and os.path.exists(previous_video_path_for_cleanup) and previous_video_path_for_cleanup != final_output_filename: + try: os.remove(previous_video_path_for_cleanup) + except Exception as e_del: print(f"Error deleting last part: {e_del}") + + except Exception as e_outer: + traceback.print_exc() + print(f"Error during generation: {e_outer}") + finally: + if not high_vram: + unload_complete_models(text_encoder, text_encoder_2, image_encoder, vae, transformer) + print("--- Generation work cycle finished. 
---") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description="FramePack Video Generation CLI (with End Frame)") + + # Inputs + parser.add_argument('--input_video', type=str, required=True, help='Path to the input video file.') + parser.add_argument('--prompt', type=str, required=True, help='Prompt for video generation.') + parser.add_argument('--n_prompt', type=str, default="", help='Negative prompt.') + parser.add_argument('--end_frame', type=str, default=None, help='Optional path to an image to guide the end of the video.') + parser.add_argument('--end_frame_weight', type=float, default=1.0, help='Weight for the end_frame image conditioning (0.0 to 1.0). Default 1.0.') + + # Generation parameters + parser.add_argument('--seed', type=int, default=31337, help='Seed for generation.') + parser.add_argument('--resolution_max_dim', type=int, default=640, help='Target resolution (max width or height for bucket search).') + parser.add_argument('--total_second_length', type=float, default=5.0, help='Additional video length to generate (seconds).') + parser.add_argument('--latent_window_size', type=int, default=9, help='Latent window size (frames for DiT). Orignal FramePack default is 9.') + parser.add_argument('--steps', type=int, default=25, help='Number of inference steps.') + parser.add_argument('--cfg', type=float, default=1.0, help='CFG Scale. If > 1.0, n_prompt is used and gs is set to 1.0. Default 1.0 (for distilled guidance).') + parser.add_argument('--gs', type=float, default=10.0, help='Distilled CFG Scale (Embedded CFG for Original FramePack). Default 10.0.') # Original default + parser.add_argument('--rs', type=float, default=0.0, help='CFG Re-Scale (usually 0.0).') + parser.add_argument('--num_clean_frames', type=int, default=5, help='Number of 1x context frames for DiT conditioning. Script2 default 5.') + + # Technical parameters + parser.add_argument('--gpu_memory_preservation', type=float, default=6.0, help='GPU memory to preserve (GB) for low VRAM mode.') + parser.add_argument('--use_teacache', action='store_true', default=False, help='Enable TeaCache (if DiT supports it).') + parser.add_argument('--no_resize', action='store_true', default=False, help='Force original video resolution for input video encoding (VAE).') + parser.add_argument('--mp4_crf', type=int, default=16, help='MP4 CRF value (0-51, lower is better quality).') + parser.add_argument('--vae_batch_size', type=int, default=-1, help='VAE batch size for input video encoding. 
Default: auto based on VRAM.') + parser.add_argument('--output_dir', type=str, default='./outputs/', help="Directory to save output videos.") + + # Model paths + parser.add_argument('--dit', type=str, required=True, help="Path to local DiT model weights file or directory (e.g., for lllyasviel/FramePackI2V_HY).") + parser.add_argument('--vae', type=str, required=True, help="Path to local VAE model weights file or directory.") + parser.add_argument('--text_encoder1', type=str, required=True, help="Path to Text Encoder 1 (Llama) WEIGHT FILE.") + parser.add_argument('--text_encoder2', type=str, required=True, help="Path to Text Encoder 2 (CLIP) WEIGHT FILE.") + parser.add_argument('--image_encoder', type=str, required=True, help="Path to Image Encoder (SigLIP) WEIGHT FILE.") + + # Advanced model settings + parser.add_argument('--attn_mode', type=str, default="torch", help="Attention mode for DiT (torch, flash, xformers, etc.).") + parser.add_argument('--fp8_llm', action='store_true', help="Use fp8 for Text Encoder 1 (Llama).") # from fpack_generate_video + parser.add_argument("--vae_chunk_size", type=int, default=None, help="Chunk size for CausalConv3d in VAE.") + parser.add_argument("--vae_spatial_tile_sample_min_size", type=int, default=None, help="Spatial tile sample min size for VAE.") + + # LoRA + parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path(s).") + parser.add_argument("--lora_multiplier", type=float, nargs="*", default=[1.0], help="LoRA multiplier(s).") + parser.add_argument("--include_patterns", type=str, nargs="*", default=None, help="LoRA module include patterns.") + parser.add_argument("--exclude_patterns", type=str, nargs="*", default=None, help="LoRA module exclude patterns.") + parser.add_argument('--extension_only', action='store_true', help="Save only the extension video without the input video attached.") + + args = parser.parse_args() + + current_device_str = str(gpu if torch.cuda.is_available() else cpu) + args.device = current_device_str + + for model_arg_name in ['dit', 'vae', 'text_encoder1', 'text_encoder2', 'image_encoder']: + path_val = getattr(args, model_arg_name) + if not os.path.exists(path_val): + parser.error(f"Path for --{model_arg_name} not found: {path_val}") + + outputs_folder = args.output_dir + os.makedirs(outputs_folder, exist_ok=True) + print(f"Outputting videos to: {outputs_folder}") + + free_mem_gb = get_cuda_free_memory_gb(gpu if torch.cuda.is_available() else None) + # Adjusted high_vram threshold, can be tuned + high_vram = free_mem_gb > 30 # Example: 30GB+ for "high_vram" + print(f'Free VRAM {free_mem_gb:.2f} GB. 
High-VRAM Mode: {high_vram}') + + if args.vae_batch_size == -1: + if free_mem_gb >= 18: args.vae_batch_size = 64 + elif free_mem_gb >= 10: args.vae_batch_size = 32 + else: args.vae_batch_size = 16 + print(f"Auto-set VAE batch size to: {args.vae_batch_size}") + + print("Loading models...") + loading_device_str = str(cpu) # Load to CPU first + + transformer = load_packed_model( + device=loading_device_str, + dit_path=args.dit, + attn_mode=args.attn_mode, + loading_device=loading_device_str + ) + print("DiT loaded.") + + if args.lora_weight is not None and len(args.lora_weight) > 0: + print("Merging LoRA weights...") + if len(args.lora_multiplier) == 1 and len(args.lora_weight) > 1: + args.lora_multiplier = args.lora_multiplier * len(args.lora_weight) + elif len(args.lora_multiplier) != len(args.lora_weight): + parser.error(f"Number of LoRA weights ({len(args.lora_weight)}) and multipliers ({len(args.lora_multiplier)}) must match, or provide a single multiplier.") + + try: + # Mimic fpack_generate_video.py's LoRA args structure if needed by merge_lora_weights + if not hasattr(args, 'lycoris'): args.lycoris = False + if not hasattr(args, 'save_merged_model'): args.save_merged_model = None + + current_device_for_lora = torch.device(loading_device_str) + merge_lora_weights(lora_framepack, transformer, args, current_device_for_lora) + print("LoRA weights merged successfully.") + except Exception as e_lora: + print(f"Error merging LoRA weights: {e_lora}") + traceback.print_exc() + + vae = load_vae( + vae_path=args.vae, + vae_chunk_size=args.vae_chunk_size, + vae_spatial_tile_sample_min_size=args.vae_spatial_tile_sample_min_size, + device=loading_device_str + ) + print("VAE loaded.") + + # For text_encoder loading, fpack_generate_video.py uses args.fp8_llm for text_encoder1 + # The f1_video_cli_local.py passes `args` directly. We'll do the same. 
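+    # Loading note: the DiT, VAE and Text Encoder 1 are explicitly loaded on CPU (loading_device_str);
+    # dtype casting and GPU placement / DynamicSwap installation for all models happen further below,
+    # depending on the detected high_vram mode.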
+ tokenizer, text_encoder = load_text_encoder1(args, device=loading_device_str) + print("Text Encoder 1 and Tokenizer 1 loaded.") + tokenizer_2, text_encoder_2 = load_text_encoder2(args) + print("Text Encoder 2 and Tokenizer 2 loaded.") + feature_extractor, image_encoder = load_image_encoders(args) + print("Image Encoder and Feature Extractor loaded.") + + all_models_list = [transformer, vae, text_encoder, text_encoder_2, image_encoder] + for model_obj in all_models_list: + if model_obj is not None: + model_obj.eval().requires_grad_(False) + + # Set dtypes (Original FramePack typically bfloat16 for DiT, float16 for others) + if transformer: transformer.to(dtype=torch.bfloat16) + if vae: vae.to(dtype=torch.float16) + if image_encoder: image_encoder.to(dtype=torch.float16) + if text_encoder: text_encoder.to(dtype=torch.float16) # Or bfloat16 if fp8_llm implies that + if text_encoder_2: text_encoder_2.to(dtype=torch.float16) + + if transformer: + transformer.high_quality_fp32_output_for_inference = True # Common setting + print('Transformer: high_quality_fp32_output_for_inference = True') + + if vae and not high_vram: + vae.enable_slicing() + vae.enable_tiling() + + target_gpu_device_str = str(gpu if torch.cuda.is_available() else cpu) + if not high_vram and torch.cuda.is_available(): + print("Low VRAM mode: Setting up dynamic swapping for DiT and Text Encoder 1.") + if transformer: DynamicSwapInstaller.install_model(transformer, device=target_gpu_device_str) + if text_encoder: DynamicSwapInstaller.install_model(text_encoder, device=target_gpu_device_str) + # Other models (VAE, TE2, ImgEnc) will be loaded/offloaded as needed by `load_model_as_complete` / `unload_complete_models` + if vae: vae.to(cpu) + if text_encoder_2: text_encoder_2.to(cpu) + if image_encoder: image_encoder.to(cpu) + elif torch.cuda.is_available(): + print(f"High VRAM mode: Moving all models to {target_gpu_device_str}.") + for model_obj in all_models_list: + if model_obj is not None: model_obj.to(target_gpu_device_str) + else: + print("Running on CPU. Models remain on CPU.") + + print("All models loaded and configured.") + + # Adjust gs if cfg > 1.0 (standard CFG mode) + actual_gs_cli = args.gs + if args.cfg > 1.0: + actual_gs_cli = 1.0 # For standard CFG, distilled guidance is turned off + print(f"CFG > 1.0 detected ({args.cfg}), this implies standard CFG. 
Overriding GS to 1.0 from {args.gs}.") + + do_generation_work( + input_video_path=args.input_video, + prompt=args.prompt, + n_prompt=args.n_prompt, + seed=args.seed, + end_frame_path=args.end_frame, + end_frame_weight=args.end_frame_weight, + resolution_max_dim=args.resolution_max_dim, + additional_second_length=args.total_second_length, + latent_window_size=args.latent_window_size, + steps=args.steps, + cfg=args.cfg, + gs=actual_gs_cli, + rs=args.rs, + gpu_memory_preservation=args.gpu_memory_preservation, + use_teacache=args.use_teacache, + no_resize=args.no_resize, + mp4_crf=args.mp4_crf, + num_clean_frames=args.num_clean_frames, + vae_batch_size=args.vae_batch_size, + extension_only=args.extension_only + ) + + print("Video generation process completed.") \ No newline at end of file diff --git a/fpack_generate_video.py b/fpack_generate_video.py new file mode 100644 index 0000000000000000000000000000000000000000..b052d3e992d582df8241eebf8a22a19fee0da6b1 --- /dev/null +++ b/fpack_generate_video.py @@ -0,0 +1,1689 @@ +import argparse +from datetime import datetime +import gc +import json +import random +import os +import re +import time +import math +import copy +from typing import Tuple, Optional, List, Union, Any, Dict +from rich.traceback import install as install_rich_tracebacks +import torch +from safetensors.torch import load_file, save_file +from safetensors import safe_open +from PIL import Image +import cv2 +import numpy as np +import torchvision.transforms.functional as TF +from transformers import LlamaModel +from tqdm import tqdm +from rich_argparse import RichHelpFormatter +from networks import lora_framepack +from hunyuan_model.autoencoder_kl_causal_3d import AutoencoderKLCausal3D +from frame_pack import hunyuan +from frame_pack.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked, load_packed_model +from frame_pack.utils import crop_or_pad_yield_mask, resize_and_center_crop, soft_append_bcthw +from frame_pack.bucket_tools import find_nearest_bucket +from frame_pack.clip_vision import hf_clip_vision_encode +from frame_pack.k_diffusion_hunyuan import sample_hunyuan +from dataset import image_video_dataset + +try: + from lycoris.kohya import create_network_from_weights +except: + pass + +from utils.device_utils import clean_memory_on_device +from base_hv_generate_video import save_images_grid, save_videos_grid, synchronize_device +from base_wan_generate_video import merge_lora_weights +from frame_pack.framepack_utils import load_vae, load_text_encoder1, load_text_encoder2, load_image_encoders +from dataset.image_video_dataset import load_video +from blissful_tuner.blissful_args import add_blissful_args, parse_blissful_args +from blissful_tuner.video_processing_common import save_videos_grid_advanced +from blissful_tuner.latent_preview import LatentPreviewer +import logging +from diffusers_helper.utils import save_bcthw_as_mp4 + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +class GenerationSettings: + def __init__(self, device: torch.device, dit_weight_dtype: Optional[torch.dtype] = None): + self.device = device + self.dit_weight_dtype = dit_weight_dtype + + +def parse_args() -> argparse.Namespace: + """parse command line arguments""" + install_rich_tracebacks() + parser = argparse.ArgumentParser(description="Framepack inference script", formatter_class=RichHelpFormatter) + + # WAN arguments + # parser.add_argument("--ckpt_dir", type=str, default=None, help="The path to the checkpoint directory (Wan 2.1 official).") + 
parser.add_argument("--is_f1", action="store_true", help="Use the FramePack F1 model specific logic.") + parser.add_argument( + "--sample_solver", type=str, default="unipc", choices=["unipc", "dpm++", "vanilla"], help="The solver used to sample." + ) + + parser.add_argument("--dit", type=str, default=None, help="DiT directory or path. Overrides --model_version if specified.") + parser.add_argument( + "--model_version", type=str, default="original", choices=["original", "f1"], help="Select the FramePack model version to use ('original' or 'f1'). Ignored if --dit is specified." + ) + parser.add_argument("--vae", type=str, default=None, help="VAE directory or path") + parser.add_argument("--text_encoder1", type=str, required=True, help="Text Encoder 1 directory or path") + parser.add_argument("--text_encoder2", type=str, required=True, help="Text Encoder 2 directory or path") + parser.add_argument("--image_encoder", type=str, required=True, help="Image Encoder directory or path") + # LoRA + parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path") + parser.add_argument("--lora_multiplier", type=float, nargs="*", default=1.0, help="LoRA multiplier") + parser.add_argument("--include_patterns", type=str, nargs="*", default=None, help="LoRA module include patterns") + parser.add_argument("--exclude_patterns", type=str, nargs="*", default=None, help="LoRA module exclude patterns") + parser.add_argument( + "--save_merged_model", + type=str, + default=None, + help="Save merged model to path. If specified, no inference will be performed.", + ) + + # inference + parser.add_argument( + "--prompt", + type=str, + default=None, + help="prompt for generation. If `;;;` is used, it will be split into sections. Example: `section_index:prompt` or " + "`section_index:prompt;;;section_index:prompt;;;...`, section_index can be `0` or `-1` or `0-2`, `-1` means last section, `0-2` means from 0 to 2 (inclusive).", + ) + parser.add_argument( + "--negative_prompt", + type=str, + default=None, + help="negative prompt for generation, default is empty string. should not change.", + ) + parser.add_argument("--video_size", type=int, nargs=2, default=[256, 256], help="video size, height and width") + parser.add_argument("--video_seconds", type=float, default=5.0, help="video length, Default is 5.0 seconds") + parser.add_argument("--fps", type=int, default=30, help="video fps, Default is 30") + parser.add_argument("--infer_steps", type=int, default=25, help="number of inference steps, Default is 25") + parser.add_argument("--save_path", type=str, required=True, help="path to save generated video") + parser.add_argument("--seed", type=str, default=None, help="Seed for evaluation.") + # parser.add_argument( + # "--cpu_noise", action="store_true", help="Use CPU to generate noise (compatible with ComfyUI). Default is False." + # ) + parser.add_argument("--latent_window_size", type=int, default=9, help="latent window size, default is 9. should not change.") + parser.add_argument( + "--embedded_cfg_scale", type=float, default=10.0, help="Embeded CFG scale (distilled CFG Scale), default is 10.0" + ) + parser.add_argument( + "--guidance_scale", + type=float, + default=1.0, + help="Guidance scale for classifier free guidance. Default is 1.0, should not change.", + ) + parser.add_argument("--guidance_rescale", type=float, default=0.0, help="CFG Re-scale, default is 0.0. 
Should not change.") + # parser.add_argument("--video_path", type=str, default=None, help="path to video for video2video inference") + parser.add_argument( + "--image_path", + type=str, + default=None, + help="path to image for image2video inference. If `;;;` is used, it will be used as section images. The notation is same as `--prompt`.", + ) + parser.add_argument("--end_image_path", type=str, default=None, help="path to end image for image2video inference") + # parser.add_argument( + # "--control_path", + # type=str, + # default=None, + # help="path to control video for inference with controlnet. video file or directory with images", + # ) + # parser.add_argument("--trim_tail_frames", type=int, default=0, help="trim tail N frames from the video before saving") + + # # Flow Matching + # parser.add_argument( + # "--flow_shift", + # type=float, + # default=None, + # help="Shift factor for flow matching schedulers. Default depends on task.", + # ) + + parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model") + parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT, only for fp8") + parser.add_argument("--fp8_fast", action="store_true", help="Enable fast FP8 arithmetic (RTX 4XXX+), only for fp8_scaled mode and can degrade quality slightly but offers noticeable speedup") + parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)") + parser.add_argument( + "--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU" + ) + parser.add_argument( + "--attn_mode", + type=str, + default="torch", + choices=["flash", "torch", "sageattn", "xformers", "sdpa"], # "flash2", "flash3", + help="attention mode", + ) + parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE") + parser.add_argument( + "--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256" + ) + parser.add_argument("--bulk_decode", action="store_true", help="decode all frames at once") + parser.add_argument("--blocks_to_swap", type=int, default=0, help="number of blocks to swap in the model") + parser.add_argument( + "--output_type", type=str, default="video", choices=["video", "images", "latent", "both"], help="output type" + ) + parser.add_argument("--no_metadata", action="store_true", help="do not save metadata") + parser.add_argument("--latent_path", type=str, nargs="*", default=None, help="path to latent for decode. 
no inference") + parser.add_argument("--lycoris", action="store_true", help="use lycoris for inference") + parser.add_argument("--compile", action="store_true", help="Enable torch.compile") + parser.add_argument( + "--compile_args", + nargs=4, + metavar=("BACKEND", "MODE", "DYNAMIC", "FULLGRAPH"), + default=["inductor", "max-autotune-no-cudagraphs", "False", "False"], + help="Torch.compile settings", + ) + + # New arguments for batch and interactive modes + parser.add_argument("--from_file", type=str, default=None, help="Read prompts from a file") + parser.add_argument("--interactive", action="store_true", help="Interactive mode: read prompts from console") + + #parser.add_argument("--preview_latent_every", type=int, default=None, help="Preview latent every N sections") + parser.add_argument("--preview_suffix", type=str, default=None, help="Unique suffix for preview files to avoid conflicts in concurrent runs.") + parser.add_argument("--full_preview", action="store_true", help="Save full intermediate video previews instead of latent previews.") + + # TeaCache arguments + parser.add_argument("--use_teacache", action="store_true", help="Enable TeaCache for faster generation.") + parser.add_argument("--teacache_steps", type=int, default=25, help="Number of steps for TeaCache initialization (should match --infer_steps).") + parser.add_argument("--teacache_thresh", type=float, default=0.15, help="Relative L1 distance threshold for TeaCache skipping.") + + parser.add_argument( + "--video_sections", + type=int, + default=None, + help="number of video sections, Default is None (auto calculate from video seconds). Overrides --video_seconds if set.", + ) + + parser = add_blissful_args(parser) + args = parser.parse_args() + args = parse_blissful_args(args) + + # Validate arguments + if args.from_file and args.interactive: + raise ValueError("Cannot use both --from_file and --interactive at the same time") + + if args.prompt is None and not args.from_file and not args.interactive: + raise ValueError("Either --prompt, --from_file or --interactive must be specified") + + return args + + +def parse_prompt_line(line: str) -> Dict[str, Any]: + """Parse a prompt line into a dictionary of argument overrides + + Args: + line: Prompt line with options + + Returns: + Dict[str, Any]: Dictionary of argument overrides + """ + # TODO common function with hv_train_network.line_to_prompt_dict + parts = line.split(" --") + prompt = parts[0].strip() + + # Create dictionary of overrides + overrides = {"prompt": prompt} + + for part in parts[1:]: + if not part.strip(): + continue + option_parts = part.split(" ", 1) + option = option_parts[0].strip() + value = option_parts[1].strip() if len(option_parts) > 1 else "" + + # Map options to argument names + if option == "w": + overrides["video_size_width"] = int(value) + elif option == "h": + overrides["video_size_height"] = int(value) + elif option == "f": + overrides["video_seconds"] = float(value) + elif option == "d": + overrides["seed"] = int(value) + elif option == "s": + overrides["infer_steps"] = int(value) + elif option == "g" or option == "l": + overrides["guidance_scale"] = float(value) + # elif option == "fs": + # overrides["flow_shift"] = float(value) + elif option == "i": + overrides["image_path"] = value + elif option == "cn": + overrides["control_path"] = value + elif option == "n": + overrides["negative_prompt"] = value + + return overrides + + +def apply_overrides(args: argparse.Namespace, overrides: Dict[str, Any]) -> argparse.Namespace: + """Apply 
overrides to args + + Args: + args: Original arguments + overrides: Dictionary of overrides + + Returns: + argparse.Namespace: New arguments with overrides applied + """ + args_copy = copy.deepcopy(args) + + for key, value in overrides.items(): + if key == "video_size_width": + args_copy.video_size[1] = value + elif key == "video_size_height": + args_copy.video_size[0] = value + else: + setattr(args_copy, key, value) + + return args_copy + + +def check_inputs(args: argparse.Namespace) -> Tuple[int, int, float]: + """Validate video size and length + + Args: + args: command line arguments + + Returns: + Tuple[int, int, float]: (height, width, video_seconds) + """ + height = args.video_size[0] + width = args.video_size[1] + + if args.video_sections is not None: + video_seconds = (args.video_sections * (args.latent_window_size * 4) + 1) / args.fps + logger.info(f"--video_sections is set to {args.video_sections}. Calculated video_seconds: {video_seconds:.2f}s") + args.video_seconds = video_seconds + else: + video_seconds = args.video_seconds + + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + return height, width, video_seconds + + +# region DiT model + + +def get_dit_dtype(args: argparse.Namespace) -> torch.dtype: + dit_dtype = torch.bfloat16 + if args.precision == "fp16": + dit_dtype = torch.float16 + elif args.precision == "fp32": + dit_dtype = torch.float32 + return dit_dtype + + +def load_dit_model(args: argparse.Namespace, device: torch.device) -> HunyuanVideoTransformer3DModelPacked: + """load DiT model + + Args: + args: command line arguments + device: device to use + + Returns: + HunyuanVideoTransformer3DModelPacked: DiT model + """ + loading_device = "cpu" + # Adjust loading device logic based on F1 requirements if necessary + if args.blocks_to_swap == 0 and not args.fp8_scaled and args.lora_weight is None: + loading_device = device + + # F1 model expects bfloat16 according to demo + # However, load_packed_model might handle dtype internally based on checkpoint. + # Let's keep the call as is for now. + logger.info(f"Loading DiT model (Class: HunyuanVideoTransformer3DModelPacked) for {'F1' if args.is_f1 else 'Standard'} mode.") + model = load_packed_model( + device=device, + dit_path=args.dit, + attn_mode=args.attn_mode, + loading_device=loading_device, + # Pass fp8_scaled and split_attn if load_packed_model supports them directly + # fp8_scaled=args.fp8_scaled, # Assuming load_packed_model handles this + # split_attn=False, # F1 demo doesn't use split_attn + ) + return model + + +def optimize_model(model: HunyuanVideoTransformer3DModelPacked, args: argparse.Namespace, device: torch.device) -> None: + """optimize the model (FP8 conversion, device move etc.) 
+
+    Args:
+        model: dit model
+        args: command line arguments
+        device: device to use
+    """
+    if args.fp8_scaled:
+        # load state dict as-is and optimize to fp8
+        state_dict = model.state_dict()
+
+        # if no blocks to swap, we can move the weights to GPU after optimization on GPU (omit redundant CPU->GPU copy)
+        move_to_device = args.blocks_to_swap == 0  # if blocks_to_swap > 0, we will keep the model on CPU
+        state_dict = model.fp8_optimization(state_dict, device, move_to_device, use_scaled_mm=args.fp8_fast)
+
+        info = model.load_state_dict(state_dict, strict=True, assign=True)
+        logger.info(f"Loaded FP8 optimized weights: {info}")
+
+        if args.blocks_to_swap == 0:
+            model.to(device)  # make sure all parameters are on the right device (e.g. RoPE etc.)
+    else:
+        # simple cast to dit_dtype
+        target_dtype = None  # load as-is (dit_weight_dtype == dtype of the weights in state_dict)
+        target_device = None
+
+        if args.fp8:
+            target_dtype = torch.float8_e4m3fn
+
+        if args.blocks_to_swap == 0:
+            logger.info(f"Move model to device: {device}")
+            target_device = device
+
+        if target_device is not None and target_dtype is not None:
+            model.to(target_device, target_dtype)  # move and cast at the same time. this reduces redundant copy operations
+
+    if args.compile:
+        compile_backend, compile_mode, compile_dynamic, compile_fullgraph = args.compile_args
+        logger.info(
+            f"Torch Compiling[Backend: {compile_backend}; Mode: {compile_mode}; Dynamic: {compile_dynamic}; Fullgraph: {compile_fullgraph}]"
+        )
+        torch._dynamo.config.cache_size_limit = 32
+        for i in range(len(model.transformer_blocks)):
+            model.transformer_blocks[i] = torch.compile(
+                model.transformer_blocks[i],
+                backend=compile_backend,
+                mode=compile_mode,
+                dynamic=compile_dynamic.lower() == "true",
+                fullgraph=compile_fullgraph.lower() == "true",
+            )
+
+    if args.blocks_to_swap > 0:
+        logger.info(f"Enable swap {args.blocks_to_swap} blocks to CPU from device: {device}")
+        model.enable_block_swap(args.blocks_to_swap, device, supports_backward=False)
+        model.move_to_device_except_swap_blocks(device)
+        model.prepare_block_swap_before_forward()
+    else:
+        # make sure the model is on the right device
+        model.to(device)
+
+    model.eval().requires_grad_(False)
+    clean_memory_on_device(device)
+
+
+# endregion
+
+
+def decode_latent(
+    latent_window_size: int,
+    total_latent_sections: int,
+    bulk_decode: bool,
+    vae: AutoencoderKLCausal3D,
+    latent: torch.Tensor,
+    device: torch.device,
+) -> torch.Tensor:
+    logger.info(f"Decoding video...")
+    if latent.ndim == 4:
+        latent = latent.unsqueeze(0)  # add batch dimension
+
+    vae.to(device)
+    if not bulk_decode:
+        # latent_window_size: default is 9
+        # total_latent_sections = (args.video_seconds * 30) / (latent_window_size * 4)
+        # total_latent_sections = int(max(round(total_latent_sections), 1))
+        num_frames = latent_window_size * 4 - 3
+
+        latents_to_decode = []
+        latent_frame_index = 0
+        for i in range(total_latent_sections - 1, -1, -1):
+            is_last_section = i == total_latent_sections - 1
+            generated_latent_frames = (num_frames + 3) // 4 + (1 if is_last_section else 0)
+            section_latent_frames = (latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2)
+
+            section_latent = latent[:, :, latent_frame_index : latent_frame_index + section_latent_frames, :, :]
+            latents_to_decode.append(section_latent)
+
+            latent_frame_index += generated_latent_frames
+
+        latents_to_decode = latents_to_decode[::-1]  # reverse the order of latents to decode
+
+        history_pixels = None
+
for latent in tqdm(latents_to_decode): + if history_pixels is None: + history_pixels = hunyuan.vae_decode(latent, vae).cpu() + else: + overlapped_frames = latent_window_size * 4 - 3 + current_pixels = hunyuan.vae_decode(latent, vae).cpu() + history_pixels = soft_append_bcthw(current_pixels, history_pixels, overlapped_frames) + clean_memory_on_device(device) + else: + # bulk decode + logger.info(f"Bulk decoding") + history_pixels = hunyuan.vae_decode(latent, vae).cpu() + vae.to("cpu") + + logger.info(f"Decoded. Pixel shape {history_pixels.shape}") + return history_pixels[0] # remove batch dimension + + +def prepare_i2v_inputs( + args: argparse.Namespace, + device: torch.device, + vae: AutoencoderKLCausal3D, + encoded_context: Optional[Dict] = None, + encoded_context_n: Optional[Dict] = None, +) -> Tuple[int, int, float, dict, dict, dict, torch.Tensor]: # Adjusted return type annotation + """Prepare inputs for I2V + + Args: + args: command line arguments + device: device to use + vae: VAE model, used for image encoding + encoded_context: Pre-encoded text context + encoded_context_n: Pre-encoded negative text context + + Returns: + Tuple[int, int, float, dict, dict, dict, torch.Tensor]: + (height, width, video_seconds, context, context_null, context_img, end_latent) + """ + + def parse_section_strings(input_string: str) -> dict[int, str]: + section_strings = {} + if not input_string: # Handle empty input string + return {0: ""} + if ";;;" in input_string: + split_section_strings = input_string.split(";;;") + for section_str in split_section_strings: + if ":" not in section_str: + start = end = 0 + section_str_val = section_str.strip() + else: + index_str, section_str_val = section_str.split(":", 1) + index_str = index_str.strip() + section_str_val = section_str_val.strip() + + m = re.match(r"^(-?\d+)(-\d+)?$", index_str) + if m: + start = int(m.group(1)) + end = int(m.group(2)[1:]) if m.group(2) is not None else start + else: + start = end = 0 # Default to 0 if index format is invalid + + # Handle negative indices relative to a hypothetical 'last section' (-1) + # This part is tricky without knowing the total sections beforehand. + # For now, treat negative indices directly. A better approach might involve + # resolving them later in the generation loop. 
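+                    # Illustrative example of the syntax parsed above (a sketch added for clarity;
+                    # the prompts are hypothetical): "0:a cat;;;1-2:a dog;;;-1:a closing shot"
+                    # yields {0: "a cat", 1: "a dog", 2: "a dog", -1: "a closing shot"}, i.e.
+                    # ";;;" separates entries and "N:" or "N-M:" selects the target section(s).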
+ for i in range(start, end + 1): + section_strings[i] = section_str_val + else: + # If no section specifiers, assume section 0 + section_strings[0] = input_string.strip() + + + # Ensure section 0 exists if any sections are defined + if section_strings and 0 not in section_strings: + indices = list(section_strings.keys()) + # Prefer smallest non-negative index, otherwise smallest negative index + try: + first_positive_index = min(i for i in indices if i >= 0) + section_index = first_positive_index + except ValueError: # No non-negative indices + section_index = min(indices) if indices else 0 # Fallback to 0 if empty + + if section_index in section_strings: + section_strings[0] = section_strings[section_index] + elif section_strings: # If section_index wasn't valid somehow, pick first available + section_strings[0] = next(iter(section_strings.values())) + else: # If section_strings was empty initially + section_strings[0] = "" # Default empty prompt + + # If still no section 0 (e.g., empty input string initially) + if 0 not in section_strings: + section_strings[0] = "" + + return section_strings + + # prepare image preprocessing function + def preprocess_image(image_path: str, target_height: int, target_width: int, is_f1: bool): # is_f1 is kept for signature, but not used differently here + image = Image.open(image_path).convert("RGB") + image_np = np.array(image) # PIL to numpy, HWC + + # Consistent image preprocessing for both F1 and standard mode, + # using target_height/target_width which come from args.video_size + image_np = image_video_dataset.resize_image_to_bucket(image_np, (target_width, target_height)) + processed_height, processed_width = image_np.shape[0], image_np.shape[1] # Get actual size after resize + + image_tensor = torch.from_numpy(image_np).float() / 127.5 - 1.0 # -1 to 1.0, HWC + image_tensor = image_tensor.permute(2, 0, 1)[None, :, None] # HWC -> CHW -> NCFHW, N=1, C=3, F=1 + return image_tensor, image_np, processed_height, processed_width + + # Initial height/width check. These dimensions will be used for image processing and generation. + height, width, video_seconds = check_inputs(args) + logger.info(f"Video dimensions for processing and generation set to: {height}x{width} (from --video_size or default).") + + section_image_paths = parse_section_strings(args.image_path) + + section_images = {} + first_image_processed = False + for index, image_path in section_image_paths.items(): + + img_tensor, img_np, proc_h, proc_w = preprocess_image(image_path, height, width, args.is_f1) + section_images[index] = (img_tensor, img_np) + if not first_image_processed and image_path: + default_video_size_used = (args.video_size[0] == 256 and args.video_size[1] == 256) # Check if default was used + if default_video_size_used and (proc_h != height or proc_w != width): + logger.info(f"Video dimensions updated to {proc_h}x{proc_w} based on first image processing (as default --video_size was used).") + height, width = proc_h, proc_w + args.video_size = [height, width] # Update args for consistency for downstream logging/metadata. + elif not default_video_size_used and (proc_h != height or proc_w != width): + logger.warning(f"User specified --video_size {height}x{width}, but first image processed to {proc_h}x{proc_w}. " + f"Generation will use {height}x{width}. 
Conditioning image aspect might differ.") + first_image_processed = True + + + # Process end image if provided + if args.end_image_path is not None: + end_img_tensor, end_img_np, _, _ = preprocess_image(args.end_image_path, height, width, args.is_f1) + else: + end_img_tensor, end_img_np = None, None + + # configure negative prompt + n_prompt = args.negative_prompt if args.negative_prompt else "" + + if encoded_context is None or encoded_context_n is None: # Regenerate if either is missing + # parse section prompts + section_prompts = parse_section_strings(args.prompt) + + # load text encoder + # Assuming load_text_encoder1/2 are compatible + tokenizer1, text_encoder1 = load_text_encoder1(args, args.fp8_llm, device) + tokenizer2, text_encoder2 = load_text_encoder2(args) + text_encoder2.to(device) + + logger.info(f"Encoding prompts...") + llama_vecs = {} + llama_attention_masks = {} + clip_l_poolers = {} + # Use a common dtype for text encoders if possible, respecting fp8 flag + text_encoder_dtype = torch.float8_e4m3fn if args.fp8_llm else torch.float16 # text_encoder1.dtype + + # Pre-allocate negative prompt tensors only if needed + llama_vec_n, clip_l_pooler_n = None, None + llama_attention_mask_n = None + + # Encode positive prompts first + with torch.autocast(device_type=device.type, dtype=text_encoder_dtype), torch.no_grad(): + for index, prompt in section_prompts.items(): + # Ensure prompt is not empty before encoding + current_prompt = prompt if prompt else "" # Use empty string if prompt is None or empty + llama_vec, clip_l_pooler = hunyuan.encode_prompt_conds(current_prompt, text_encoder1, text_encoder2, tokenizer1, tokenizer2) + + # Pad/crop and store + llama_vec_padded, llama_attention_mask = crop_or_pad_yield_mask(llama_vec.cpu(), length=512) # Move to CPU before padding + + llama_vecs[index] = llama_vec_padded + llama_attention_masks[index] = llama_attention_mask + clip_l_poolers[index] = clip_l_pooler.cpu() # Move to CPU + + # Use the encoding of section 0 as fallback for negative if needed + if index == 0 and args.guidance_scale == 1.0: + llama_vec_n = torch.zeros_like(llama_vec_padded) + llama_attention_mask_n = torch.zeros_like(llama_attention_mask) + clip_l_pooler_n = torch.zeros_like(clip_l_poolers[0]) + + # Encode negative prompt if needed + if args.guidance_scale != 1.0: + with torch.autocast(device_type=device.type, dtype=text_encoder_dtype), torch.no_grad(): + current_n_prompt = n_prompt if n_prompt else "" + llama_vec_n_raw, clip_l_pooler_n_raw = hunyuan.encode_prompt_conds( + current_n_prompt, text_encoder1, text_encoder2, tokenizer1, tokenizer2 + ) + llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(llama_vec_n_raw.cpu(), length=512) # Move to CPU + clip_l_pooler_n = clip_l_pooler_n_raw.cpu() # Move to CPU + + + # Check if negative prompt was generated (handles guidance_scale=1.0 case) + if llama_vec_n is None: + logger.warning("Negative prompt tensors not generated (likely guidance_scale=1.0). 
Using zeros.") + # Assuming section 0 exists and was processed + llama_vec_n = torch.zeros_like(llama_vecs[0]) + llama_attention_mask_n = torch.zeros_like(llama_attention_masks[0]) + clip_l_pooler_n = torch.zeros_like(clip_l_poolers[0]) + + + # free text encoder and clean memory + del text_encoder1, text_encoder2, tokenizer1, tokenizer2 + clean_memory_on_device(device) + + # load image encoder (Handles SigLIP via framepack_utils) + feature_extractor, image_encoder = load_image_encoders(args) + image_encoder.to(device) + + # encode image with image encoder + logger.info(f"Encoding images with {'SigLIP' if args.is_f1 else 'Image Encoder'}...") + section_image_encoder_last_hidden_states = {} + img_encoder_dtype = image_encoder.dtype # Get dtype from loaded model + end_image_embedding_for_f1 = None # Initialize for F1 end image + with torch.autocast(device_type=device.type, dtype=img_encoder_dtype), torch.no_grad(): + for index, (img_tensor, img_np) in section_images.items(): + # Use hf_clip_vision_encode (works for SigLIP too) + image_encoder_output = hf_clip_vision_encode(img_np, feature_extractor, image_encoder) + image_encoder_last_hidden_state = image_encoder_output.last_hidden_state.cpu() # Move to CPU + section_image_encoder_last_hidden_states[index] = image_encoder_last_hidden_state + + if args.is_f1 and end_img_np is not None: # end_img_np is from args.end_image_path + logger.info("F1 Mode: Encoding end image for potential conditioning.") + end_image_encoder_output_f1 = hf_clip_vision_encode(end_img_np, feature_extractor, image_encoder) + end_image_embedding_for_f1 = end_image_encoder_output_f1.last_hidden_state.cpu() + + # free image encoder and clean memory + del image_encoder, feature_extractor + clean_memory_on_device(device) + + # --- Store encoded contexts for potential reuse --- + # Positive context (bundle per unique prompt string if needed, or just section 0) + # For simplicity, let's assume we only cache based on args.prompt for now + encoded_context = { + "llama_vecs": llama_vecs, + "llama_attention_masks": llama_attention_masks, + "clip_l_poolers": clip_l_poolers, + "image_encoder_last_hidden_states": section_image_encoder_last_hidden_states # Store all section states + } + # Negative context + encoded_context_n = { + "llama_vec": llama_vec_n, + "llama_attention_mask": llama_attention_mask_n, + "clip_l_pooler": clip_l_pooler_n, + } + # --- End context caching --- + + else: + # Use pre-encoded context + logger.info("Using pre-encoded context.") + llama_vecs = encoded_context["llama_vecs"] + llama_attention_masks = encoded_context["llama_attention_masks"] + clip_l_poolers = encoded_context["clip_l_poolers"] + section_image_encoder_last_hidden_states = encoded_context["image_encoder_last_hidden_states"] # Retrieve all sections + llama_vec_n = encoded_context_n["llama_vec"] + llama_attention_mask_n = encoded_context_n["llama_attention_mask"] + clip_l_pooler_n = encoded_context_n["clip_l_pooler"] + # Need to re-parse section prompts if using cached context + section_prompts = parse_section_strings(args.prompt) + + + # VAE encoding + logger.info(f"Encoding image(s) to latent space...") + vae.to(device) + vae_dtype = vae.dtype # Get VAE dtype + + section_start_latents = {} + with torch.autocast(device_type=device.type, dtype=vae_dtype), torch.no_grad(): + for index, (img_tensor, img_np) in section_images.items(): + start_latent = hunyuan.vae_encode(img_tensor, vae).cpu() # Move to CPU + section_start_latents[index] = start_latent + + end_latent = 
hunyuan.vae_encode(end_img_tensor, vae).cpu() if end_img_tensor is not None else None # Move to CPU + + vae.to("cpu") # move VAE to CPU to save memory + clean_memory_on_device(device) + + # prepare model input arguments + arg_c = {} # Positive text conditioning per section + arg_c_img = {} # Positive image conditioning per section + + # Ensure section_prompts is available (parsed earlier) + if 'section_prompts' not in locals(): + section_prompts = parse_section_strings(args.prompt) + + # Populate positive text args + for index in llama_vecs.keys(): + # Get corresponding prompt, defaulting to empty string if index missing + prompt_text = section_prompts.get(index, "") + + arg_c_i = { + "llama_vec": llama_vecs[index], + "llama_attention_mask": llama_attention_masks[index], + "clip_l_pooler": clip_l_poolers[index], + "prompt": prompt_text, # Include the actual prompt text + } + arg_c[index] = arg_c_i + + # Populate negative text args (only one needed) + arg_null = { + "llama_vec": llama_vec_n, + "llama_attention_mask": llama_attention_mask_n, + "clip_l_pooler": clip_l_pooler_n, + "prompt": n_prompt, # Include negative prompt text + } + + # Populate positive image args + for index in section_start_latents.keys(): # Use latents keys as reference + # Check if corresponding hidden state exists, fallback to section 0 if needed + image_encoder_last_hidden_state = section_image_encoder_last_hidden_states.get(index, section_image_encoder_last_hidden_states.get(0)) + if image_encoder_last_hidden_state is None and section_image_encoder_last_hidden_states: + # Absolute fallback if index and 0 are missing but others exist + image_encoder_last_hidden_state = next(iter(section_image_encoder_last_hidden_states.values())) + elif image_encoder_last_hidden_state is None: + raise ValueError(f"Cannot find image encoder state for section {index} or fallback section 0.") + + + arg_c_img_i = { + "image_encoder_last_hidden_state": image_encoder_last_hidden_state, + "start_latent": section_start_latents[index] + } + arg_c_img[index] = arg_c_img_i + + # Ensure fallback section 0 exists in arg_c and arg_c_img if needed later + if 0 not in arg_c and arg_c: + arg_c[0] = next(iter(arg_c.values())) + if 0 not in arg_c_img and arg_c_img: + arg_c_img[0] = next(iter(arg_c_img.values())) + + # Final check for minimal context existence + if not arg_c or not arg_c_img: + raise ValueError("Failed to prepare conditioning arguments. 
Check prompts and image paths.") + + + return height, width, video_seconds, arg_c, arg_null, arg_c_img, end_latent, end_image_embedding_for_f1 + + +# def setup_scheduler(args: argparse.Namespace, config, device: torch.device) -> Tuple[Any, torch.Tensor]: +# """setup scheduler for sampling + +# Args: +# args: command line arguments +# config: model configuration +# device: device to use + +# Returns: +# Tuple[Any, torch.Tensor]: (scheduler, timesteps) +# """ +# if args.sample_solver == "unipc": +# scheduler = FlowUniPCMultistepScheduler(num_train_timesteps=config.num_train_timesteps, shift=1, use_dynamic_shifting=False) +# scheduler.set_timesteps(args.infer_steps, device=device, shift=args.flow_shift) +# timesteps = scheduler.timesteps +# elif args.sample_solver == "dpm++": +# scheduler = FlowDPMSolverMultistepScheduler( +# num_train_timesteps=config.num_train_timesteps, shift=1, use_dynamic_shifting=False +# ) +# sampling_sigmas = get_sampling_sigmas(args.infer_steps, args.flow_shift) +# timesteps, _ = retrieve_timesteps(scheduler, device=device, sigmas=sampling_sigmas) +# elif args.sample_solver == "vanilla": +# scheduler = FlowMatchDiscreteScheduler(num_train_timesteps=config.num_train_timesteps, shift=args.flow_shift) +# scheduler.set_timesteps(args.infer_steps, device=device) +# timesteps = scheduler.timesteps + +# # FlowMatchDiscreteScheduler does not support generator argument in step method +# org_step = scheduler.step + +# def step_wrapper( +# model_output: torch.Tensor, +# timestep: Union[int, torch.Tensor], +# sample: torch.Tensor, +# return_dict: bool = True, +# generator=None, +# ): +# return org_step(model_output, timestep, sample, return_dict=return_dict) + +# scheduler.step = step_wrapper +# else: +# raise NotImplementedError("Unsupported solver.") + +# return scheduler, timesteps + + +# In fpack_generate_video.py + +def generate(args: argparse.Namespace, gen_settings: GenerationSettings, shared_models: Optional[Dict] = None) -> Tuple[AutoencoderKLCausal3D, torch.Tensor]: # Return VAE too + """main function for generation + + Args: + args: command line arguments + gen_settings: Generation settings object + shared_models: dictionary containing pre-loaded models and encoded data + + Returns: + Tuple[AutoencoderKLCausal3D, torch.Tensor]: vae, generated latent + """ + device, dit_weight_dtype = (gen_settings.device, gen_settings.dit_weight_dtype) + + # prepare seed + seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1) + # Ensure seed is integer + if isinstance(seed, str): + try: + seed = int(seed) + except ValueError: + logger.warning(f"Invalid seed string: {seed}. Generating random seed.") + seed = random.randint(0, 2**32 - 1) + elif not isinstance(seed, int): + logger.warning(f"Invalid seed type: {type(seed)}. 
Generating random seed.")
+        seed = random.randint(0, 2**32 - 1)
+
+    args.seed = seed  # set seed to args for saving
+
+    vae = None  # Initialize VAE
+
+    # Check if we have shared models
+    if shared_models is not None:
+        # Use shared models and encoded data
+        vae = shared_models.get("vae")
+        model = shared_models.get("model")
+
+        # --- Retrieve cached context ---
+        # Try to get context based on the full prompt string first
+        prompt_key = args.prompt if args.prompt else ""
+        n_prompt_key = args.negative_prompt if args.negative_prompt else ""
+
+        encoded_context = shared_models.get("encoded_contexts", {}).get(prompt_key)
+        encoded_context_n = shared_models.get("encoded_contexts", {}).get(n_prompt_key)
+
+        # If not found, maybe the cache uses a simpler key (like just section 0?) - needs alignment with prepare_i2v_inputs caching logic
+        # For now, assume prepare_i2v_inputs handles regeneration if cache miss
+        if encoded_context is None or encoded_context_n is None:
+            logger.info("Cached context not found or incomplete, preparing inputs.")
+            # Need VAE for preparation if regenerating context
+            if vae is None:
+                vae = load_vae(args.vae, args.vae_chunk_size, args.vae_spatial_tile_sample_min_size, device)
+            # prepare_i2v_inputs returns eight values, so end_image_embedding_for_f1 must be unpacked here as well
+            height, width, video_seconds, context, context_null, context_img, end_latent, end_image_embedding_for_f1 = prepare_i2v_inputs(
+                args, device, vae  # Pass VAE here
+            )
+            # Store newly generated context back? (Requires shared_models to be mutable and handled carefully)
+            # shared_models["encoded_contexts"][prompt_key] = context  # Simplified example
+            # shared_models["encoded_contexts"][n_prompt_key] = context_null  # Simplified example
+        else:
+            logger.info("Using cached context from shared models.")
+            # Need VAE if decoding later, load if not present
+            if vae is None:
+                vae = load_vae(args.vae, args.vae_chunk_size, args.vae_spatial_tile_sample_min_size, device)
+            height, width, video_seconds, context, context_null, context_img, end_latent, end_image_embedding_for_f1 = prepare_i2v_inputs(
+                args, device, vae, encoded_context, encoded_context_n
+            )
+        # --- End context retrieval ---
+
+    else:
+        # prepare inputs without shared models
+        vae = load_vae(args.vae, args.vae_chunk_size, args.vae_spatial_tile_sample_min_size, device)
+        height, width, video_seconds, context, context_null, context_img, end_latent, end_image_embedding_for_f1 = prepare_i2v_inputs(args, device, vae)
+        # load DiT model
+        model = load_dit_model(args, device)  # Handles F1 class loading implicitly
+
+    # merge LoRA weights
+    if args.lora_weight is not None and len(args.lora_weight) > 0:
+        # Ensure merge_lora_weights can handle HunyuanVideoTransformer3DModelPacked
+        # It might need adjustments depending on its implementation.
+        logger.info("Merging LoRA weights...")
+        # Assuming lora_framepack is the correct network type definition
+        # Make sure merge_lora_weights exists and is imported
+        try:
+            from base_wan_generate_video import merge_lora_weights  # Example import path
+            merge_lora_weights(lora_framepack, model, args, device)
+        except ImportError:
+            logger.error("merge_lora_weights function not found. 
Skipping LoRA merge.") + except Exception as e: + logger.error(f"Error merging LoRA weights: {e}") + + # if we only want to save the model, we can skip the rest + if args.save_merged_model: + # Implement saving logic here if merge_lora_weights doesn't handle it + logger.info(f"Saving merged model to {args.save_merged_model} and exiting.") + # Example: save_model(model, args.save_merged_model) + return None, None # Indicate no generation occurred + + + # optimize model: fp8 conversion, block swap etc. + optimize_model(model, args, device) + if args.use_teacache: + logger.info(f"Initializing TeaCache: steps={args.teacache_steps}, threshold={args.teacache_thresh}") + # The model's initialize_teacache expects num_steps and rel_l1_thresh + model.initialize_teacache( + enable_teacache=True, + num_steps=args.teacache_steps, + rel_l1_thresh=args.teacache_thresh + ) + else: + logger.info("TeaCache is disabled.") + # Ensure it's explicitly disabled in the model too, just in case + model.initialize_teacache(enable_teacache=False) + + # --- Sampling --- + latent_window_size = args.latent_window_size # default is 9 (consistent with F1 demo) + + if args.video_sections is not None: + total_latent_sections = args.video_sections + logger.info(f"Using --video_sections: {total_latent_sections} sections.") + else: + total_latent_sections = (video_seconds * args.fps) / (latent_window_size * 4) + total_latent_sections = int(max(round(total_latent_sections), 1)) + logger.info(f"Calculated total_latent_sections from video_seconds: {total_latent_sections} sections.") + + # set random generator + seed_g = torch.Generator(device="cpu") # Keep noise on CPU initially + seed_g.manual_seed(seed) + + # F1 expects frames = latent_window_size * 4 - 3 + # Our script's default decode uses latent_window_size * 4 - 3 overlap + # Let's calculate F1 frames per section explicitly + f1_frames_per_section = latent_window_size * 4 - 3 + + logger.info( + f"Mode: {'F1' if args.is_f1 else 'Standard'}, " + f"Video size: {height}x{width}@{video_seconds:.2f}s, fps: {args.fps}, num sections: {total_latent_sections}, " + f"infer_steps: {args.infer_steps}, frames per generation step: {f1_frames_per_section}" + ) + + # Determine compute dtype based on model/args + compute_dtype = model.dtype if hasattr(model, 'dtype') else torch.bfloat16 # Default for F1 + if args.fp8 or args.fp8_scaled: + # FP8 might still use bfloat16/float16 for some operations + logger.info("FP8 enabled, using bfloat16 for intermediate computations.") + compute_dtype = torch.bfloat16 # Or potentially float16 depending on model/ops + logger.info(f"Using compute dtype: {compute_dtype}") + + +# --- F1 Model Specific Sampling Logic --- + if args.is_f1: # Renamed from args.f1 in simpler script to args.is_f1 + logger.info("Starting F1 model sampling process.") + + logger.info(f"F1 Mode: Using video dimensions {height}x{width} for latent operations and generation.") + history_latents = torch.zeros((1, 16, 19, height // 8, width // 8), dtype=torch.float32, device='cpu') + + start_latent_0 = context_img.get(0, {}).get("start_latent") + if start_latent_0 is None: + raise ValueError("Cannot find start_latent for section 0 in context_img.") + + if start_latent_0.shape[3] != (height // 8) or start_latent_0.shape[4] != (width // 8): + logger.error(f"Mismatch between start_latent_0 dimensions ({start_latent_0.shape[3]}x{start_latent_0.shape[4]}) " + f"and history_latents dimensions ({height//8}x{width//8}). 
This should not happen with current logic.") + + history_latents = torch.cat([history_latents, start_latent_0.cpu().float()], dim=2) + + history_pixels_for_preview_f1_cpu = None + if args.full_preview and args.preview_latent_every is not None: + if vae is None: + logger.error("VAE not available for initial F1 preview setup.") + else: + logger.info("F1 Full Preview: Decoding initial start_latent for preview history.") + vae.to(device) + initial_latent_for_preview = start_latent_0.to(device, dtype=vae.dtype if hasattr(vae, 'dtype') else torch.float16) + # Assuming vae_decode returns BCTHW or CTHW. Ensure BCTHW for history_pixels. + decoded_initial = hunyuan.vae_decode(initial_latent_for_preview, vae).cpu() + if decoded_initial.ndim == 4: # CTHW + history_pixels_for_preview_f1_cpu = decoded_initial.unsqueeze(0) + elif decoded_initial.ndim == 5: # BCTHW + history_pixels_for_preview_f1_cpu = decoded_initial + else: + logger.error(f"Unexpected dimensions from initial VAE decode: {decoded_initial.shape}") + vae.to("cpu") + clean_memory_on_device(device) + + total_generated_latent_frames = 1 # Account for the initial start_latent_0 in history_latents + + if args.preview_latent_every and not args.full_preview: + previewer = LatentPreviewer(args, vae, None, gen_settings.device, compute_dtype, model_type="framepack") + else: + previewer = None + + for section_index in range(total_latent_sections): + logger.info(f"--- F1 Section {section_index + 1} / {total_latent_sections} ---") + f1_split_sizes = [1, 16, 2, 1, args.latent_window_size] + f1_indices = torch.arange(0, sum(f1_split_sizes)).unsqueeze(0).to(device) + ( + f1_clean_latent_indices_start, + f1_clean_latent_4x_indices, + f1_clean_latent_2x_indices, + f1_clean_latent_1x_indices, + f1_latent_indices, + ) = f1_indices.split(f1_split_sizes, dim=1) + f1_clean_latent_indices = torch.cat([f1_clean_latent_indices_start, f1_clean_latent_1x_indices], dim=1) + + current_image_context_section_idx = section_index if section_index in context_img else 0 + current_start_latent = context_img[current_image_context_section_idx]["start_latent"].to(device, dtype=torch.float32) + + current_history_for_f1_clean = history_latents[:, :, -sum([16, 2, 1]):, :, :].to(device, dtype=torch.float32) + f1_clean_latents_4x, f1_clean_latents_2x, f1_clean_latents_1x = current_history_for_f1_clean.split([16, 2, 1], dim=2) + + f1_clean_latents_combined = torch.cat([current_start_latent, f1_clean_latents_1x], dim=2) + + context_section_idx = section_index if section_index in context else 0 + llama_vec = context[context_section_idx]["llama_vec"].to(device, dtype=compute_dtype) + llama_attention_mask = context[context_section_idx]["llama_attention_mask"].to(device) + clip_l_pooler = context[context_section_idx]["clip_l_pooler"].to(device, dtype=compute_dtype) + image_encoder_last_hidden_state = context_img[current_image_context_section_idx]["image_encoder_last_hidden_state"].to(device, dtype=compute_dtype) + llama_vec_n = context_null["llama_vec"].to(device, dtype=compute_dtype) + llama_attention_mask_n = context_null["llama_attention_mask"].to(device) + clip_l_pooler_n = context_null["clip_l_pooler"].to(device, dtype=compute_dtype) + + # generated_latents_step is on GPU after sample_hunyuan + generated_latents_step = sample_hunyuan( + transformer=model, sampler=args.sample_solver, width=width, height=height, + frames=f1_frames_per_section, real_guidance_scale=args.guidance_scale, + distilled_guidance_scale=args.embedded_cfg_scale, guidance_rescale=args.guidance_rescale, + 
num_inference_steps=args.infer_steps, generator=seed_g, + prompt_embeds=llama_vec, prompt_embeds_mask=llama_attention_mask, prompt_poolers=clip_l_pooler, + negative_prompt_embeds=llama_vec_n, negative_prompt_embeds_mask=llama_attention_mask_n, negative_prompt_poolers=clip_l_pooler_n, + device=device, dtype=compute_dtype, image_embeddings=image_encoder_last_hidden_state, + latent_indices=f1_latent_indices, clean_latents=f1_clean_latents_combined, clean_latent_indices=f1_clean_latent_indices, + clean_latents_2x=f1_clean_latents_2x, clean_latent_2x_indices=f1_clean_latent_2x_indices, + clean_latents_4x=f1_clean_latents_4x, clean_latent_4x_indices=f1_clean_latent_4x_indices, + ) + + newly_generated_latent_frames_count_this_step = int(generated_latents_step.shape[2]) + history_latents = torch.cat([history_latents, generated_latents_step.cpu().float()], dim=2) + total_generated_latent_frames += newly_generated_latent_frames_count_this_step + + if args.preview_latent_every is not None and (section_index + 1) % args.preview_latent_every == 0: + if args.full_preview: + logger.info(f"Saving full F1 preview at section {section_index + 1}") + if vae is None: + logger.error("VAE not available for full F1 preview.") + else: + preview_filename_full = os.path.join(args.save_path, f"latent_preview_{args.preview_suffix if args.preview_suffix else section_index + 1}.mp4") + + latents_this_step_for_decode = generated_latents_step.to(device, dtype=vae.dtype if hasattr(vae, 'dtype') else torch.float16) + + vae.to(device) + pixels_this_step_decoded_cpu = hunyuan.vae_decode(latents_this_step_for_decode, vae).cpu() + vae.to("cpu") + + if pixels_this_step_decoded_cpu.ndim == 4: + pixels_this_step_decoded_cpu = pixels_this_step_decoded_cpu.unsqueeze(0) + + if history_pixels_for_preview_f1_cpu is None: + history_pixels_for_preview_f1_cpu = pixels_this_step_decoded_cpu + else: + overlap_pixels = args.latent_window_size * 4 - 3 + history_pixels_for_preview_f1_cpu = soft_append_bcthw( + history_pixels_for_preview_f1_cpu, + pixels_this_step_decoded_cpu, + overlap=overlap_pixels + ) + + save_bcthw_as_mp4(history_pixels_for_preview_f1_cpu, preview_filename_full, fps=args.fps, crf=getattr(args, 'mp4_crf', 16)) + logger.info(f"Full F1 preview saved to {preview_filename_full}") + + del latents_this_step_for_decode, pixels_this_step_decoded_cpu + clean_memory_on_device(device) + elif previewer is not None: + logger.info(f"Previewing latents at F1 section {section_index + 1}") + preview_latents_f1_for_pv = history_latents[:, :, -total_generated_latent_frames:, :, :].to(gen_settings.device) + previewer.preview(preview_latents_f1_for_pv, section_index, preview_suffix=args.preview_suffix) + del preview_latents_f1_for_pv + clean_memory_on_device(gen_settings.device) + + del generated_latents_step, current_history_for_f1_clean, f1_clean_latents_combined + del f1_clean_latents_1x, f1_clean_latents_2x, f1_clean_latents_4x, current_start_latent + del llama_vec, llama_attention_mask, clip_l_pooler, image_encoder_last_hidden_state + del llama_vec_n, llama_attention_mask_n, clip_l_pooler_n + clean_memory_on_device(device) + + real_history_latents = history_latents[:, :, -total_generated_latent_frames:, :, :] + # No resizing needed as generation happened at target dimensions. 
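+        # Rough bookkeeping sketch for the loop above (assuming latent_window_size=9 and that
+        # sample_hunyuan yields latent_window_size latent frames per call, as frames=9*4-3=33 implies):
+        # after N sections, real_history_latents holds 1 + 9*N latent frames, where the leading 1 is
+        # the encoded start image that was concatenated into history_latents before the loop.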
+ + # --- Standard Model Sampling Logic --- + else: # Standard mode + logger.info("Starting standard model sampling process.") + history_latents = torch.zeros((1, 16, 1 + 2 + 16, height // 8, width // 8), dtype=torch.float32, device='cpu') + if end_latent is not None: + logger.info(f"Using end image: {args.end_image_path}") + history_latents[:, :, 0:1] = end_latent.cpu().float() + + total_generated_latent_frames = 0 + + history_pixels_for_preview_std_cpu = None # Initialize pixel history + # For standard mode (backward generation), the first chunk generated is the "end" of the video. + # If end_latent is provided and previews are on, we should decode it to start the preview history. + if args.full_preview and args.preview_latent_every is not None and end_latent is not None: + if vae is None: + logger.error("VAE not available for initial Standard mode preview setup with end_latent.") + else: + logger.info("Standard Full Preview: Decoding initial end_latent for preview history.") + vae.to(device) + initial_latent_for_preview = end_latent.to(device, dtype=vae.dtype if hasattr(vae, 'dtype') else torch.float16) + decoded_initial = hunyuan.vae_decode(initial_latent_for_preview, vae).cpu() + if decoded_initial.ndim == 4: # CTHW + history_pixels_for_preview_std_cpu = decoded_initial.unsqueeze(0) + elif decoded_initial.ndim == 5: # BCTHW + history_pixels_for_preview_std_cpu = decoded_initial + else: + logger.error(f"Unexpected dimensions from initial VAE decode for end_latent: {decoded_initial.shape}") + vae.to("cpu") + clean_memory_on_device(device) + + + latent_paddings = list(reversed(range(total_latent_sections))) + if total_latent_sections > 4: + logger.info("Using F1-style latent padding heuristic for > 4 sections.") + latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0] + + if args.preview_latent_every and not args.full_preview: + previewer = LatentPreviewer(args, vae, None, gen_settings.device, compute_dtype, model_type="framepack") + else: + previewer = None + + for section_index_reverse, latent_padding in enumerate(latent_paddings): + section_index = total_latent_sections - 1 - section_index_reverse + section_index_from_last = -(section_index_reverse + 1) + logger.info(f"--- Standard Section {section_index + 1} / {total_latent_sections} (Reverse Index {section_index_reverse}, Padding {latent_padding}) ---") + + is_last_section = latent_padding == 0 + latent_padding_size = latent_padding * latent_window_size + + apply_section_image = False + if section_index_from_last in context_img: + image_index = section_index_from_last + if not is_last_section: apply_section_image = True + elif section_index in context_img: + image_index = section_index + if not is_last_section: apply_section_image = True + else: + image_index = 0 + + start_latent_section = context_img[image_index]["start_latent"].to(device, dtype=torch.float32) + if apply_section_image: + latent_padding_size = 0 + logger.info(f"Applying experimental section image, forcing latent_padding_size = 0") + + split_sizes_std = [1, latent_padding_size, latent_window_size, 1, 2, 16] + indices_std = torch.arange(0, sum(split_sizes_std)).unsqueeze(0).to(device) + ( + clean_latent_indices_pre, blank_indices, latent_indices, + clean_latent_indices_post, clean_latent_2x_indices, clean_latent_4x_indices, + ) = indices_std.split(split_sizes_std, dim=1) + clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1) + + current_history_std = history_latents[:, :, :19].to(device, dtype=torch.float32) + 
clean_latents_post, clean_latents_2x, clean_latents_4x = current_history_std.split([1, 2, 16], dim=2) + clean_latents = torch.cat([start_latent_section, clean_latents_post], dim=2) + + if section_index_from_last in context: prompt_index = section_index_from_last + elif section_index in context: prompt_index = section_index + else: prompt_index = 0 + context_for_index = context[prompt_index] + logger.info(f"Using prompt from section {prompt_index}: '{context_for_index['prompt'][:100]}...'") + + llama_vec = context_for_index["llama_vec"].to(device, dtype=compute_dtype) + llama_attention_mask = context_for_index["llama_attention_mask"].to(device) + clip_l_pooler = context_for_index["clip_l_pooler"].to(device, dtype=compute_dtype) + image_encoder_last_hidden_state = context_img[image_index]["image_encoder_last_hidden_state"].to(device, dtype=compute_dtype) + llama_vec_n = context_null["llama_vec"].to(device, dtype=compute_dtype) + llama_attention_mask_n = context_null["llama_attention_mask"].to(device) + clip_l_pooler_n = context_null["clip_l_pooler"].to(device, dtype=compute_dtype) + + sampler_to_use = args.sample_solver + guidance_scale_to_use = args.guidance_scale + embedded_cfg_scale_to_use = args.embedded_cfg_scale + guidance_rescale_to_use = args.guidance_rescale + + # generated_latents_step is on GPU after sample_hunyuan + generated_latents_step_gpu = sample_hunyuan( + transformer=model, sampler=sampler_to_use, width=width, height=height, + frames=f1_frames_per_section, real_guidance_scale=guidance_scale_to_use, + distilled_guidance_scale=embedded_cfg_scale_to_use, guidance_rescale=guidance_rescale_to_use, + num_inference_steps=args.infer_steps, generator=seed_g, + prompt_embeds=llama_vec, prompt_embeds_mask=llama_attention_mask, prompt_poolers=clip_l_pooler, + negative_prompt_embeds=llama_vec_n, negative_prompt_embeds_mask=llama_attention_mask_n, negative_prompt_poolers=clip_l_pooler_n, + device=device, dtype=compute_dtype, image_embeddings=image_encoder_last_hidden_state, + latent_indices=latent_indices, clean_latents=clean_latents, clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, clean_latent_4x_indices=clean_latent_4x_indices, + ) + + # Move to CPU for history accumulation and potential preview decode + generated_latents_step = generated_latents_step_gpu.cpu().float() + + if is_last_section: # This is the first iteration in reverse, corresponds to earliest part of generated video + logger.info("Standard Mode: Last section (first in reverse loop), prepending start_latent_section for this chunk.") + generated_latents_step = torch.cat([start_latent_section.cpu().float(), generated_latents_step], dim=2) + + current_step_latents_cpu = generated_latents_step.clone() # This is what was generated/prepended in this step + + total_generated_latent_frames += int(generated_latents_step.shape[2]) + history_latents = torch.cat([generated_latents_step, history_latents], dim=2) # Prepend to full latent history + + real_history_latents = history_latents[:, :, :total_generated_latent_frames, :, :] + + if args.preview_latent_every is not None and (section_index_reverse + 1) % args.preview_latent_every == 0: + if args.full_preview: + logger.info(f"Saving full preview at standard section {section_index + 1} (Reverse Index {section_index_reverse})") + if vae is None: + logger.error("VAE not available for full standard preview.") + else: + preview_filename_full_std = 
os.path.join(args.save_path, f"latent_preview_{args.preview_suffix if args.preview_suffix else section_index_reverse + 1}.mp4") + + latents_this_step_for_decode = current_step_latents_cpu.to(device, dtype=vae.dtype if hasattr(vae, 'dtype') else torch.float16) + + vae.to(device) + pixels_this_step_decoded_cpu = hunyuan.vae_decode(latents_this_step_for_decode, vae).cpu() + vae.to("cpu") + + if pixels_this_step_decoded_cpu.ndim == 4: + pixels_this_step_decoded_cpu = pixels_this_step_decoded_cpu.unsqueeze(0) + + if history_pixels_for_preview_std_cpu is None: + history_pixels_for_preview_std_cpu = pixels_this_step_decoded_cpu + else: + overlap_pixels = args.latent_window_size * 4 - 3 + # Standard mode prepends, so new pixels are first arg for soft_append + history_pixels_for_preview_std_cpu = soft_append_bcthw( + pixels_this_step_decoded_cpu, + history_pixels_for_preview_std_cpu, + overlap=overlap_pixels + ) + + save_bcthw_as_mp4(history_pixels_for_preview_std_cpu, preview_filename_full_std, fps=args.fps, crf=getattr(args, 'mp4_crf', 16)) + logger.info(f"Full standard preview saved to {preview_filename_full_std}") + del latents_this_step_for_decode, pixels_this_step_decoded_cpu + clean_memory_on_device(device) + elif previewer is not None: + logger.info(f"Previewing latents at standard section {section_index + 1} (Reverse Index {section_index_reverse})") + preview_latents_std_for_pv = real_history_latents.to(gen_settings.device) + previewer.preview(preview_latents_std_for_pv, section_index, preview_suffix=args.preview_suffix) + del preview_latents_std_for_pv + clean_memory_on_device(gen_settings.device) + + logger.info(f"Section {section_index + 1} finished. Total latent frames: {total_generated_latent_frames}. History shape: {history_latents.shape}") + + del generated_latents_step, current_history_std, clean_latents, clean_latents_post, clean_latents_2x, clean_latents_4x + del llama_vec, llama_attention_mask, clip_l_pooler, image_encoder_last_hidden_state, start_latent_section + del llama_vec_n, llama_attention_mask_n, clip_l_pooler_n + # Explicitly delete the GPU tensor if it was created + if 'generated_latents_step_gpu' in locals(): del generated_latents_step_gpu + clean_memory_on_device(device) + + gc.collect() + clean_memory_on_device(device) + + # Return the final generated latents (CPU tensor) and the VAE + # The shape should be (B, C, T_total, H, W) + logger.info(f"Generation complete. 
Final latent shape: {real_history_latents.shape}") + return vae, real_history_latents # Return VAE along with latents + + +def save_latent(latent: torch.Tensor, args: argparse.Namespace, height: int, width: int, original_base_name: Optional[str] = None) -> str: # Add original_base_name + """Save latent to file + + Args: + latent: Latent tensor (CTHW expected) + args: command line arguments + height: height of frame + width: width of frame + original_base_name: Optional base name from loaded file + + Returns: + str: Path to saved latent file + """ + save_path = args.save_path + os.makedirs(save_path, exist_ok=True) + time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S") + + seed = args.seed + original_name = "" if original_base_name is None else f"_{original_base_name}" # Use provided base name + video_seconds = args.video_seconds + latent_path = f"{save_path}/{time_flag}_{seed}{original_name}_latent.safetensors" # Add original name to file + + # Ensure latent is on CPU before saving + latent = latent.detach().cpu() + + if args.no_metadata: + metadata = None + else: + # (Metadata creation remains the same) + metadata = { + "seeds": f"{seed}", + "prompt": f"{args.prompt}", + "height": f"{height}", + "width": f"{width}", + "video_seconds": f"{video_seconds}", + "infer_steps": f"{args.infer_steps}", + "guidance_scale": f"{args.guidance_scale}", + "latent_window_size": f"{args.latent_window_size}", + "embedded_cfg_scale": f"{args.embedded_cfg_scale}", + "guidance_rescale": f"{args.guidance_rescale}", + "sample_solver": f"{args.sample_solver}", + # "latent_window_size": f"{args.latent_window_size}", # Duplicate key + "fps": f"{args.fps}", + "is_f1": f"{args.is_f1}", # Add F1 flag to metadata + } + if args.negative_prompt is not None: + metadata["negative_prompt"] = f"{args.negative_prompt}" + # Add other relevant args like LoRA, compile settings, etc. 
if desired + + sd = {"latent": latent.contiguous()} + save_file(sd, latent_path, metadata=metadata) + logger.info(f"Latent saved to: {latent_path}") + + return latent_path + + +def save_video( + video: torch.Tensor, args: argparse.Namespace, original_base_name: Optional[str] = None, latent_frames: Optional[int] = None +) -> str: + """Save video to file + + Args: + video: Video tensor + args: command line arguments + original_base_name: Original base name (if latents are loaded from files) + + Returns: + str: Path to saved video file + """ + save_path = args.save_path + os.makedirs(save_path, exist_ok=True) + time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S") + + seed = args.seed + original_name = "" if original_base_name is None else f"_{original_base_name}" + latent_frames = "" if latent_frames is None else f"_{latent_frames}" + video_path = f"{save_path}/{time_flag}_{seed}{original_name}{latent_frames}.mp4" + + video = video.unsqueeze(0) + if args.codec is not None: + save_videos_grid_advanced(video, video_path, args.codec, args.container, rescale=True, fps=args.fps, keep_frames=args.keep_pngs) + else: + save_videos_grid(video, video_path, fps=args.fps, rescale=True) + logger.info(f"Video saved to: {video_path}") + + return video_path + + +def save_images(sample: torch.Tensor, args: argparse.Namespace, original_base_name: Optional[str] = None) -> str: + """Save images to directory + + Args: + sample: Video tensor + args: command line arguments + original_base_name: Original base name (if latents are loaded from files) + + Returns: + str: Path to saved images directory + """ + save_path = args.save_path + os.makedirs(save_path, exist_ok=True) + time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S") + + seed = args.seed + original_name = "" if original_base_name is None else f"_{original_base_name}" + image_name = f"{time_flag}_{seed}{original_name}" + sample = sample.unsqueeze(0) + save_images_grid(sample, save_path, image_name, rescale=True) + logger.info(f"Sample images saved to: {save_path}/{image_name}") + + return f"{save_path}/{image_name}" + + +# In fpack_generate_video.py + +def save_output( + args: argparse.Namespace, + vae: AutoencoderKLCausal3D, + latent: torch.Tensor, + device: torch.device, + original_base_names: Optional[List[str]] = None, +) -> None: + """save output + + Args: + args: command line arguments + vae: VAE model + latent: latent tensor (should be BCTHW or CTHW) + device: device to use + original_base_names: original base names (if latents are loaded from files) + """ + if latent.ndim == 4: # Add batch dim if missing (CTHW -> BCTHW) + latent = latent.unsqueeze(0) + elif latent.ndim != 5: + raise ValueError(f"Unexpected latent dimensions: {latent.ndim}. Expected 4 or 5.") + + # Latent shape is BCTHW + batch_size, channels, latent_frames, latent_height, latent_width = latent.shape + height = latent_height * 8 + width = latent_width * 8 + logger.info(f"Saving output. 
Latent shape: {latent.shape}; Target pixel shape: {height}x{width}") + + if args.output_type == "latent" or args.output_type == "both": + # save latent (use first name if multiple originals) + base_name = original_base_names[0] if original_base_names else None + save_latent(latent[0], args, height, width, original_base_name=base_name) # Save first batch item if B > 1 + if args.output_type == "latent": + return + + if args.video_sections is not None: + total_latent_sections = args.video_sections + else: + total_latent_sections = (args.video_seconds * args.fps) / (args.latent_window_size * 4) + total_latent_sections = int(max(round(total_latent_sections), 1)) + + logger.info(f"Decoding using total_latent_sections = {total_latent_sections} (derived from {'--video_sections' if args.video_sections is not None else '--video_seconds'}).") + + # Decode (handle potential batch > 1?) + # decode_latent expects BCTHW or CTHW, and returns CTHW + # Currently process only the first item in the batch for saving video/images + video = decode_latent(args.latent_window_size, total_latent_sections, args.bulk_decode, vae, latent[0], device) + + if args.output_type == "video" or args.output_type == "both": + # save video + original_name = original_base_names[0] if original_base_names else None + save_video(video, args, original_name, latent_frames=latent_frames) # Pass latent frames count + + elif args.output_type == "images": + # save images + original_name = original_base_names[0] if original_base_names else None + save_images(video, args, original_name) + + +def preprocess_prompts_for_batch(prompt_lines: List[str], base_args: argparse.Namespace) -> List[Dict]: + """Process multiple prompts for batch mode + + Args: + prompt_lines: List of prompt lines + base_args: Base command line arguments + + Returns: + List[Dict]: List of prompt data dictionaries + """ + prompts_data = [] + + for line in prompt_lines: + line = line.strip() + if not line or line.startswith("#"): # Skip empty lines and comments + continue + + # Parse prompt line and create override dictionary + prompt_data = parse_prompt_line(line) + logger.info(f"Parsed prompt data: {prompt_data}") + prompts_data.append(prompt_data) + + return prompts_data + + +def get_generation_settings(args: argparse.Namespace) -> GenerationSettings: + device = torch.device(args.device) + + dit_weight_dtype = None # default + if args.fp8_scaled: + dit_weight_dtype = None # various precision weights, so don't cast to specific dtype + elif args.fp8: + dit_weight_dtype = torch.float8_e4m3fn + + logger.info(f"Using device: {device}, DiT weight weight precision: {dit_weight_dtype}") + + gen_settings = GenerationSettings(device=device, dit_weight_dtype=dit_weight_dtype) + return gen_settings + + +# In fpack_generate_video.py + +def main(): + # Parse arguments + args = parse_args() + + # Check if latents are provided + latents_mode = args.latent_path is not None and len(args.latent_path) > 0 + + # Set device + device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device(device) + logger.info(f"Using device: {device}") + args.device = device # Ensure args has the final device + + if latents_mode: + # --- Latent Decode Mode --- + # (Keep existing logic, but maybe add F1 flag reading from metadata?) 
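+        # Example of invoking this decode-only path (illustrative; the flag names assume the usual
+        # argparse dest-to-option mapping for args.latent_path / args.vae / args.output_type / args.save_path):
+        #   python fpack_generate_video.py --latent_path outputs/20250101-000000_1234_latent.safetensors \
+        #       --vae <path/to/vae> --output_type video --save_path outputs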
+ original_base_names = [] + latents_list = [] + seeds = [] + is_f1_from_metadata = False # Default + + # Allow only one latent file for simplicity now + if len(args.latent_path) > 1: + logger.warning("Loading multiple latents is not fully supported for metadata consistency. Using first latent's metadata.") + + for i, latent_path in enumerate(args.latent_path): + logger.info(f"Loading latent from: {latent_path}") + base_name = os.path.splitext(os.path.basename(latent_path))[0] + original_base_names.append(base_name) + seed = 0 # Default seed + + if not latent_path.lower().endswith(".safetensors"): + logger.warning(f"Loading from non-safetensors file {latent_path}. Metadata might be missing.") + latents = torch.load(latent_path, map_location="cpu") + if isinstance(latents, dict) and "latent" in latents: # Handle potential dict structure + latents = latents["latent"] + else: + try: + # Load latent tensor + loaded_data = load_file(latent_path, device="cpu") # Load to CPU + latents = loaded_data["latent"] + + # Load metadata + metadata = {} + with safe_open(latent_path, framework="pt", device="cpu") as f: + metadata = f.metadata() + if metadata is None: + metadata = {} + logger.info(f"Loaded metadata: {metadata}") + + # Apply metadata only from the first file for consistency + if i == 0: + if "seeds" in metadata: + try: + seed = int(metadata["seeds"]) + except ValueError: + logger.warning(f"Could not parse seed from metadata: {metadata['seeds']}") + if "height" in metadata and "width" in metadata: + try: + height = int(metadata["height"]) + width = int(metadata["width"]) + args.video_size = [height, width] + logger.info(f"Set video size from metadata: {height}x{width}") + except ValueError: + logger.warning(f"Could not parse height/width from metadata.") + if "video_seconds" in metadata: + try: + args.video_seconds = float(metadata["video_seconds"]) + logger.info(f"Set video seconds from metadata: {args.video_seconds}") + except ValueError: + logger.warning(f"Could not parse video_seconds from metadata.") + if "fps" in metadata: + try: + args.fps = int(metadata["fps"]) + logger.info(f"Set fps from metadata: {args.fps}") + except ValueError: + logger.warning(f"Could not parse fps from metadata.") + if "is_f1" in metadata: + is_f1_from_metadata = metadata["is_f1"].lower() == 'true' + if args.is_f1 != is_f1_from_metadata: + logger.warning(f"Metadata indicates is_f1={is_f1_from_metadata}, overriding command line argument --is_f1={args.is_f1}") + args.is_f1 = is_f1_from_metadata + + + except Exception as e: + logger.error(f"Error loading safetensors file {latent_path}: {e}") + continue # Skip this file + + # Use seed from first file for all if multiple latents are somehow processed + if i == 0: + args.seed = seed + seeds.append(seed) # Store all seeds read + + logger.info(f"Loaded latent shape: {latents.shape}") + + if latents.ndim == 5: # [BCTHW] + if latents.shape[0] > 1: + logger.warning("Latent file contains batch size > 1. Using only the first item.") + latents = latents[0] # Use first item -> [CTHW] + elif latents.ndim != 4: + logger.error(f"Unexpected latent dimension {latents.ndim} in {latent_path}. Skipping.") + continue + + latents_list.append(latents) + + if not latents_list: + logger.error("No valid latents loaded. Exiting.") + return + + # Stack latents into a batch if multiple were loaded (BCTHW) + # Note: Saving output currently only processes the first batch item. 
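+        # At this point each entry in latents_list is CTHW (any batch dimension was stripped above),
+        # so stacking on dim=0 produces the BCTHW layout that save_output() accepts.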
+ latent_batch = torch.stack(latents_list, dim=0) + + # Load VAE needed for decoding + vae = load_vae(args.vae, args.vae_chunk_size, args.vae_spatial_tile_sample_min_size, device) + # Call save_output with the batch + save_output(args, vae, latent_batch, device, original_base_names) + + elif args.from_file: + # Batch mode from file (Not Implemented) + logger.error("Batch mode (--from_file) is not implemented yet.") + # with open(args.from_file, "r", encoding="utf-8") as f: + # prompt_lines = f.readlines() + # prompts_data = preprocess_prompts_for_batch(prompt_lines, args) + # process_batch_prompts(prompts_data, args) # Needs implementation + raise NotImplementedError("Batch mode is not implemented yet.") + + elif args.interactive: + # Interactive mode (Not Implemented) + logger.error("Interactive mode (--interactive) is not implemented yet.") + # process_interactive(args) # Needs implementation + raise NotImplementedError("Interactive mode is not implemented yet.") + + else: + # --- Single prompt mode (original behavior + F1 support) --- + gen_settings = get_generation_settings(args) + + # Generate returns (vae, latent) + vae, latent = generate(args, gen_settings) # VAE might be loaded inside generate + + if latent is None: # Handle cases like --save_merged_model + logger.info("Generation did not produce latents (e.g., --save_merged_model used). Exiting.") + return + + # Ensure VAE is available (it should be returned by generate) + if vae is None: + logger.error("VAE not available after generation. Cannot save output.") + return + + # Save output expects BCTHW or CTHW, generate returns BCTHW + # save_output handles the batch dimension internally now. + save_output(args, vae, latent, device) + + # Clean up VAE if it was loaded here + del vae + gc.collect() + clean_memory_on_device(device) + + + logger.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/frame_pack/__init__.py b/frame_pack/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/frame_pack/bucket_tools.py b/frame_pack/bucket_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..1d531642945e5951214e8a7bd6fbd39d824664a5 --- /dev/null +++ b/frame_pack/bucket_tools.py @@ -0,0 +1,157 @@ +# Base configuration for scaling bucket options +_BASE_RESOLUTION = 640 +_BASE_BUCKET_OPTIONS = [ + (416, 960), (448, 864), (480, 832), (512, 768), (544, 704), + (576, 672), (608, 640), (640, 608), (672, 576), (704, 544), + (768, 512), (832, 480), (864, 448), (960, 416), +] + +# Cache for generated bucket options to avoid redundant calculations +_generated_bucket_cache = {} + +def _round_to_multiple(number, multiple): + """Rounds a number to the nearest multiple of a given number.""" + if multiple == 0: + # Default behavior: round to nearest int. Could also raise an error. + return int(round(number)) + return int(multiple * round(float(number) / multiple)) + +def _adjust_resolution(resolution, divisor=32): + """ + Adjusts a given resolution to the nearest multiple of 'divisor'. + If the input resolution is positive but rounds to 0 (e.g., resolution=10, divisor=32), + it's adjusted to 'divisor'. + If the input resolution is non-positive (<=0), it defaults to 'divisor'. + """ + if resolution <= 0: + return divisor # Default to minimum valid resolution for non-positive inputs + + adjusted = _round_to_multiple(resolution, divisor) + + # If resolution was positive but _round_to_multiple resulted in 0 + # (e.g. 
input 10 for divisor 32 rounds to 0), ensure it's at least the divisor. + if adjusted == 0: + return divisor + return adjusted + +def generate_scaled_buckets(target_resolution_input, + base_resolution=_BASE_RESOLUTION, + base_options=_BASE_BUCKET_OPTIONS, + divisor=32): + """ + Generates scaled bucket options for a target resolution. + + The target_resolution_input is first adjusted to the nearest multiple of 'divisor'. + Bucket dimensions are scaled from 'base_options' (which are for 'base_resolution') + to the adjusted target resolution. These scaled dimensions are then rounded to the + nearest multiple of 'divisor' and ensured to be at least 'divisor'. + + Args: + target_resolution_input (int): The desired target resolution. + base_resolution (int): The resolution for which 'base_options' are defined. + base_options (list of tuples): A list of (height, width) tuples for 'base_resolution'. + divisor (int): The number that resolutions and bucket dimensions should be multiples of. + + Returns: + list of tuples: Scaled and adjusted bucket options (height, width). + """ + # Adjust the target resolution for scaling + actual_target_resolution = _adjust_resolution(target_resolution_input, divisor) + + if actual_target_resolution in _generated_bucket_cache: + return _generated_bucket_cache[actual_target_resolution] + + # Optimization: If adjusted target resolution matches base resolution. + # This assumes base_options are already compliant with the divisor. + # (Our _BASE_BUCKET_OPTIONS are multiples of 32, so this is fine for divisor=32). + if actual_target_resolution == base_resolution: + options_to_return = list(base_options) # Return a copy + _generated_bucket_cache[actual_target_resolution] = options_to_return + return options_to_return + + scaled_options = [] + seen_options = set() # To handle potential duplicates after rounding + + # Prevent division by zero if base_resolution is 0 (though _BASE_RESOLUTION is 640). + if base_resolution == 0: + # Fallback: return a single square bucket of the target resolution. + # This case should not be hit with current constants. + default_bucket = (actual_target_resolution, actual_target_resolution) + _generated_bucket_cache[actual_target_resolution] = [default_bucket] + return [default_bucket] + + scale_factor = float(actual_target_resolution) / base_resolution + + for base_h, base_w in base_options: + scaled_h_float = base_h * scale_factor + scaled_w_float = base_w * scale_factor + + scaled_h = _round_to_multiple(scaled_h_float, divisor) + scaled_w = _round_to_multiple(scaled_w_float, divisor) + + # Ensure minimum dimension is at least the divisor + scaled_h = max(scaled_h, divisor) + scaled_w = max(scaled_w, divisor) + + bucket_tuple = (scaled_h, scaled_w) + if bucket_tuple not in seen_options: + scaled_options.append(bucket_tuple) + seen_options.add(bucket_tuple) + + # If base_options was empty (not the case for internal use but could be if called externally), + # scaled_options would be empty. Provide a default bucket in such a scenario. + # actual_target_resolution is guaranteed to be >= divisor by _adjust_resolution. + if not scaled_options: + default_bucket = (actual_target_resolution, actual_target_resolution) + scaled_options.append(default_bucket) + + _generated_bucket_cache[actual_target_resolution] = scaled_options + return scaled_options + +def find_nearest_bucket(h, w, resolution=640): + """ + Finds the nearest bucket for a given height (h) and width (w) + at a specified target resolution. 
+ + The 'resolution' parameter is the user's intended target resolution. + This function will: + 1. Adjust this resolution to the nearest multiple of 32 (minimum 32). + 2. Generate a list of bucket options (height, width pairs) by scaling + predefined base options (for 640px) to this adjusted resolution. + All generated bucket dimensions will also be multiples of 32 and at least 32. + 3. Find the bucket from this generated list that is "nearest" to the + aspect ratio of the input h, w. The nearness metric is + abs(input_h * bucket_w - input_w * bucket_h). + + Args: + h (int): The height of the image/item. + w (int): The width of the image/item. + resolution (int): The target resolution for which to find buckets. + Defaults to 640. + + Returns: + tuple: A (bucket_h, bucket_w) tuple representing the best bucket found. + """ + # generate_scaled_buckets handles the adjustment of 'resolution' internally + # and uses a divisor of 32 by default for its calculations. + # The problem statement implies a fixed divisor of 32 for this tool. + current_bucket_options = generate_scaled_buckets(resolution, divisor=32) + + # Failsafe: If generate_scaled_buckets somehow returned an empty list (e.g., if _BASE_BUCKET_OPTIONS was empty), + # provide a default bucket based on the adjusted resolution. + if not current_bucket_options: + adjusted_res_for_fallback = _adjust_resolution(resolution, 32) + return (adjusted_res_for_fallback, adjusted_res_for_fallback) + + min_metric = float('inf') + best_bucket = None + # Since current_bucket_options is guaranteed to be non-empty by the check above (or by generate_scaled_buckets's own logic + # when _BASE_BUCKET_OPTIONS is populated), best_bucket will be assigned in the loop. + + for (bucket_h, bucket_w) in current_bucket_options: + metric = abs(h * bucket_w - w * bucket_h) + if metric <= min_metric: # Using "<=" preserves original behavior (last encountered wins on ties) + min_metric = metric + best_bucket = (bucket_h, bucket_w) + + return best_bucket \ No newline at end of file diff --git a/frame_pack/clip_vision.py b/frame_pack/clip_vision.py new file mode 100644 index 0000000000000000000000000000000000000000..1c919296b23084ac00e3e4440657d368df1ee86e --- /dev/null +++ b/frame_pack/clip_vision.py @@ -0,0 +1,14 @@ +import numpy as np + + +def hf_clip_vision_encode(image, feature_extractor, image_encoder): + assert isinstance(image, np.ndarray) + assert image.ndim == 3 and image.shape[2] == 3 + assert image.dtype == np.uint8 + + preprocessed = feature_extractor.preprocess(images=image, return_tensors="pt").to( + device=image_encoder.device, dtype=image_encoder.dtype + ) + image_encoder_output = image_encoder(**preprocessed) + + return image_encoder_output diff --git a/frame_pack/framepack_utils.py b/frame_pack/framepack_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a57364aa6bb0d67c0491b532646daa2fa6d36fc1 --- /dev/null +++ b/frame_pack/framepack_utils.py @@ -0,0 +1,273 @@ +import os +import logging +from types import SimpleNamespace +from typing import Optional, Union + +import accelerate +from accelerate import Accelerator, init_empty_weights +import torch +from safetensors.torch import load_file +from transformers import ( + LlamaTokenizerFast, + LlamaConfig, + LlamaModel, + CLIPTokenizer, + CLIPTextModel, + CLIPConfig, + SiglipImageProcessor, + SiglipVisionModel, + SiglipVisionConfig, +) + +from utils.safetensors_utils import load_split_weights +from hunyuan_model.vae import load_vae as hunyuan_load_vae + +import logging + +logger = 
logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +def load_vae( + vae_path: str, vae_chunk_size: Optional[int], vae_spatial_tile_sample_min_size: Optional[int], device: Union[str, torch.device] +): + # single file and directory (contains 'vae') support + if os.path.isdir(vae_path): + vae_path = os.path.join(vae_path, "vae", "diffusion_pytorch_model.safetensors") + else: + vae_path = vae_path + + vae_dtype = torch.float16 # if vae_dtype is None else str_to_dtype(vae_dtype) + vae, _, s_ratio, t_ratio = hunyuan_load_vae(vae_dtype=vae_dtype, device=device, vae_path=vae_path) + vae.eval() + # vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio} + + # set chunk_size to CausalConv3d recursively + chunk_size = vae_chunk_size + if chunk_size is not None: + vae.set_chunk_size_for_causal_conv_3d(chunk_size) + logger.info(f"Set chunk_size to {chunk_size} for CausalConv3d") + + if vae_spatial_tile_sample_min_size is not None: + vae.enable_spatial_tiling(True) + vae.tile_sample_min_size = vae_spatial_tile_sample_min_size + vae.tile_latent_min_size = vae_spatial_tile_sample_min_size // 8 + logger.info(f"Enabled spatial tiling with min size {vae_spatial_tile_sample_min_size}") + # elif vae_tiling: + else: + vae.enable_spatial_tiling(True) + + return vae + + +# region Text Encoders + +# Text Encoder configs are copied from HunyuanVideo repo + +LLAMA_CONFIG = { + "architectures": ["LlamaModel"], + "attention_bias": False, + "attention_dropout": 0.0, + "bos_token_id": 128000, + "eos_token_id": 128001, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 14336, + "max_position_embeddings": 8192, + "mlp_bias": False, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 8, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": None, + "rope_theta": 500000.0, + "tie_word_embeddings": False, + "torch_dtype": "float16", + "transformers_version": "4.46.3", + "use_cache": True, + "vocab_size": 128320, +} + +CLIP_CONFIG = { + # "_name_or_path": "/raid/aryan/llava-llama-3-8b-v1_1-extracted/text_encoder_2", + "architectures": ["CLIPTextModel"], + "attention_dropout": 0.0, + "bos_token_id": 0, + "dropout": 0.0, + "eos_token_id": 2, + "hidden_act": "quick_gelu", + "hidden_size": 768, + "initializer_factor": 1.0, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-05, + "max_position_embeddings": 77, + "model_type": "clip_text_model", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "pad_token_id": 1, + "projection_dim": 768, + "torch_dtype": "float16", + "transformers_version": "4.48.0.dev0", + "vocab_size": 49408, +} + + +def load_text_encoder1( + args, fp8_llm: Optional[bool] = False, device: Optional[Union[str, torch.device]] = None +) -> tuple[LlamaTokenizerFast, LlamaModel]: + # single file, split file and directory (contains 'text_encoder') support + logger.info(f"Loading text encoder 1 tokenizer") + tokenizer1 = LlamaTokenizerFast.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder="tokenizer") + + logger.info(f"Loading text encoder 1 from {args.text_encoder1}") + if os.path.isdir(args.text_encoder1): + # load from directory, configs are in the directory + text_encoder1 = LlamaModel.from_pretrained(args.text_encoder1, subfolder="text_encoder", torch_dtype=torch.float16) + else: + # load from file, we create the model with the appropriate config + config = LlamaConfig(**LLAMA_CONFIG) + with init_empty_weights(): + 
text_encoder1 = LlamaModel._from_config(config, torch_dtype=torch.float16) + + state_dict = load_split_weights(args.text_encoder1) + + # support weights from ComfyUI + if "model.embed_tokens.weight" in state_dict: + for key in list(state_dict.keys()): + if key.startswith("model."): + new_key = key.replace("model.", "") + state_dict[new_key] = state_dict[key] + del state_dict[key] + if "tokenizer" in state_dict: + state_dict.pop("tokenizer") + if "lm_head.weight" in state_dict: + state_dict.pop("lm_head.weight") + + # # support weights from ComfyUI + # if "tokenizer" in state_dict: + # state_dict.pop("tokenizer") + + text_encoder1.load_state_dict(state_dict, strict=True, assign=True) + + if fp8_llm: + org_dtype = text_encoder1.dtype + logger.info(f"Moving and casting text encoder to {device} and torch.float8_e4m3fn") + text_encoder1.to(device=device, dtype=torch.float8_e4m3fn) + + # prepare LLM for fp8 + def prepare_fp8(llama_model: LlamaModel, target_dtype): + def forward_hook(module): + def forward(hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + module.variance_epsilon) + return module.weight.to(input_dtype) * hidden_states.to(input_dtype) + + return forward + + for module in llama_model.modules(): + if module.__class__.__name__ in ["Embedding"]: + # print("set", module.__class__.__name__, "to", target_dtype) + module.to(target_dtype) + if module.__class__.__name__ in ["LlamaRMSNorm"]: + # print("set", module.__class__.__name__, "hooks") + module.forward = forward_hook(module) + + prepare_fp8(text_encoder1, org_dtype) + else: + text_encoder1.to(device) + + text_encoder1.eval() + return tokenizer1, text_encoder1 + + +def load_text_encoder2(args) -> tuple[CLIPTokenizer, CLIPTextModel]: + # single file and directory (contains 'text_encoder_2') support + logger.info(f"Loading text encoder 2 tokenizer") + tokenizer2 = CLIPTokenizer.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder="tokenizer_2") + + logger.info(f"Loading text encoder 2 from {args.text_encoder2}") + if os.path.isdir(args.text_encoder2): + # load from directory, configs are in the directory + text_encoder2 = CLIPTextModel.from_pretrained(args.text_encoder2, subfolder="text_encoder_2", torch_dtype=torch.float16) + else: + # we only have one file, so we can load it directly + config = CLIPConfig(**CLIP_CONFIG) + with init_empty_weights(): + text_encoder2 = CLIPTextModel._from_config(config, torch_dtype=torch.float16) + + state_dict = load_file(args.text_encoder2) + + text_encoder2.load_state_dict(state_dict, strict=True, assign=True) + + text_encoder2.eval() + return tokenizer2, text_encoder2 + + +# endregion + +# region image encoder + +# Siglip configs are copied from FramePack repo +FEATURE_EXTRACTOR_CONFIG = { + "do_convert_rgb": None, + "do_normalize": True, + "do_rescale": True, + "do_resize": True, + "image_mean": [0.5, 0.5, 0.5], + "image_processor_type": "SiglipImageProcessor", + "image_std": [0.5, 0.5, 0.5], + "processor_class": "SiglipProcessor", + "resample": 3, + "rescale_factor": 0.00392156862745098, + "size": {"height": 384, "width": 384}, +} +IMAGE_ENCODER_CONFIG = { + "_name_or_path": "/home/lvmin/.cache/huggingface/hub/models--black-forest-labs--FLUX.1-Redux-dev/snapshots/1282f955f706b5240161278f2ef261d2a29ad649/image_encoder", + "architectures": ["SiglipVisionModel"], + "attention_dropout": 0.0, + "hidden_act": "gelu_pytorch_tanh", 
+ "hidden_size": 1152, + "image_size": 384, + "intermediate_size": 4304, + "layer_norm_eps": 1e-06, + "model_type": "siglip_vision_model", + "num_attention_heads": 16, + "num_channels": 3, + "num_hidden_layers": 27, + "patch_size": 14, + "torch_dtype": "bfloat16", + "transformers_version": "4.46.2", +} + + +def load_image_encoders(args): + logger.info(f"Loading image encoder feature extractor") + feature_extractor = SiglipImageProcessor(**FEATURE_EXTRACTOR_CONFIG) + + # single file, split file and directory (contains 'image_encoder') support + logger.info(f"Loading image encoder from {args.image_encoder}") + if os.path.isdir(args.image_encoder): + # load from directory, configs are in the directory + image_encoder = SiglipVisionModel.from_pretrained(args.image_encoder, subfolder="image_encoder", torch_dtype=torch.float16) + else: + # load from file, we create the model with the appropriate config + config = SiglipVisionConfig(**IMAGE_ENCODER_CONFIG) + with init_empty_weights(): + image_encoder = SiglipVisionModel._from_config(config, torch_dtype=torch.float16) + + state_dict = load_file(args.image_encoder) + + image_encoder.load_state_dict(state_dict, strict=True, assign=True) + + image_encoder.eval() + return feature_extractor, image_encoder + + +# endregion diff --git a/frame_pack/hunyuan.py b/frame_pack/hunyuan.py new file mode 100644 index 0000000000000000000000000000000000000000..5f8f9b446c50bdd65867a9147f40dfebd0de5fb2 --- /dev/null +++ b/frame_pack/hunyuan.py @@ -0,0 +1,116 @@ +# original code: https://github.com/lllyasviel/FramePack +# original license: Apache-2.0 + +import torch + +# from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import DEFAULT_PROMPT_TEMPLATE +# from diffusers_helper.utils import crop_or_pad_yield_mask +from hunyuan_model.autoencoder_kl_causal_3d import AutoencoderKLCausal3D +from hunyuan_model.text_encoder import PROMPT_TEMPLATE + + +@torch.no_grad() +def encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2, max_length=256): + assert isinstance(prompt, str) + + prompt = [prompt] + + # LLAMA + + prompt_llama = [PROMPT_TEMPLATE["dit-llm-encode-video"]["template"].format(p) for p in prompt] + crop_start = PROMPT_TEMPLATE["dit-llm-encode-video"]["crop_start"] + + llama_inputs = tokenizer( + prompt_llama, + padding="max_length", + max_length=max_length + crop_start, + truncation=True, + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=True, + ) + + llama_input_ids = llama_inputs.input_ids.to(text_encoder.device) + llama_attention_mask = llama_inputs.attention_mask.to(text_encoder.device) + llama_attention_length = int(llama_attention_mask.sum()) + + llama_outputs = text_encoder( + input_ids=llama_input_ids, + attention_mask=llama_attention_mask, + output_hidden_states=True, + ) + + llama_vec = llama_outputs.hidden_states[-3][:, crop_start:llama_attention_length] + # llama_vec_remaining = llama_outputs.hidden_states[-3][:, llama_attention_length:] + llama_attention_mask = llama_attention_mask[:, crop_start:llama_attention_length] + + assert torch.all(llama_attention_mask.bool()) + + # CLIP + + clip_l_input_ids = tokenizer_2( + prompt, + padding="max_length", + max_length=77, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ).input_ids + clip_l_pooler = text_encoder_2(clip_l_input_ids.to(text_encoder_2.device), output_hidden_states=False).pooler_output + + return llama_vec, clip_l_pooler + + +@torch.no_grad() +def 
vae_decode_fake(latents): + latent_rgb_factors = [ + [-0.0395, -0.0331, 0.0445], + [0.0696, 0.0795, 0.0518], + [0.0135, -0.0945, -0.0282], + [0.0108, -0.0250, -0.0765], + [-0.0209, 0.0032, 0.0224], + [-0.0804, -0.0254, -0.0639], + [-0.0991, 0.0271, -0.0669], + [-0.0646, -0.0422, -0.0400], + [-0.0696, -0.0595, -0.0894], + [-0.0799, -0.0208, -0.0375], + [0.1166, 0.1627, 0.0962], + [0.1165, 0.0432, 0.0407], + [-0.2315, -0.1920, -0.1355], + [-0.0270, 0.0401, -0.0821], + [-0.0616, -0.0997, -0.0727], + [0.0249, -0.0469, -0.1703], + ] # From comfyui + + latent_rgb_factors_bias = [0.0259, -0.0192, -0.0761] + + weight = torch.tensor(latent_rgb_factors, device=latents.device, dtype=latents.dtype).transpose(0, 1)[:, :, None, None, None] + bias = torch.tensor(latent_rgb_factors_bias, device=latents.device, dtype=latents.dtype) + + images = torch.nn.functional.conv3d(latents, weight, bias=bias, stride=1, padding=0, dilation=1, groups=1) + images = images.clamp(0.0, 1.0) + + return images + + +@torch.no_grad() +def vae_decode(latents, vae, image_mode=False) -> torch.Tensor: + latents = latents / vae.config.scaling_factor + + if not image_mode: + image = vae.decode(latents.to(device=vae.device, dtype=vae.dtype)).sample + else: + latents = latents.to(device=vae.device, dtype=vae.dtype).unbind(2) + image = [vae.decode(l.unsqueeze(2)).sample for l in latents] + image = torch.cat(image, dim=2) + + return image + + +@torch.no_grad() +def vae_encode(image, vae: AutoencoderKLCausal3D) -> torch.Tensor: + latents = vae.encode(image.to(device=vae.device, dtype=vae.dtype)).latent_dist.sample() + latents = latents * vae.config.scaling_factor + return latents diff --git a/frame_pack/hunyuan_video_packed.py b/frame_pack/hunyuan_video_packed.py new file mode 100644 index 0000000000000000000000000000000000000000..f3a860ea524c2d3e204a1f80f8874c57892401ff --- /dev/null +++ b/frame_pack/hunyuan_video_packed.py @@ -0,0 +1,2049 @@ +# original code: https://github.com/lllyasviel/FramePack +# original license: Apache-2.0 + +import glob +import math +import numbers +import os +from types import SimpleNamespace +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import einops +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +from modules.custom_offloading_utils import ModelOffloader +from utils.safetensors_utils import load_split_weights +from modules.fp8_optimization_utils import apply_fp8_monkey_patch, optimize_state_dict_with_fp8 +from accelerate import init_empty_weights + +try: + # raise NotImplementedError + from xformers.ops import memory_efficient_attention as xformers_attn_func + + print("Xformers is installed!") +except: + print("Xformers is not installed!") + xformers_attn_func = None + +try: + # raise NotImplementedError + from flash_attn import flash_attn_varlen_func, flash_attn_func + + print("Flash Attn is installed!") +except: + print("Flash Attn is not installed!") + flash_attn_varlen_func = None + flash_attn_func = None + +try: + # raise NotImplementedError + from sageattention import sageattn_varlen, sageattn + + print("Sage Attn is installed!") +except: + print("Sage Attn is not installed!") + sageattn_varlen = None + sageattn = None + + +import logging + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + +# region diffusers + +# copied from diffusers with some modifications to minimize dependencies +# original code: https://github.com/huggingface/diffusers/ +# original license: Apache-2.0 + +ACT2CLS = { + "swish": nn.SiLU, + 
"silu": nn.SiLU, + "mish": nn.Mish, + "gelu": nn.GELU, + "relu": nn.ReLU, +} + + +def get_activation(act_fn: str) -> nn.Module: + """Helper function to get activation function from string. + + Args: + act_fn (str): Name of activation function. + + Returns: + nn.Module: Activation function. + """ + + act_fn = act_fn.lower() + if act_fn in ACT2CLS: + return ACT2CLS[act_fn]() + else: + raise ValueError(f"activation function {act_fn} not found in ACT2FN mapping {list(ACT2CLS.keys())}") + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + Args + timesteps (torch.Tensor): + a 1-D Tensor of N indices, one per batch element. These may be fractional. + embedding_dim (int): + the dimension of the output. + flip_sin_to_cos (bool): + Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False) + downscale_freq_shift (float): + Controls the delta between frequencies between dimensions + scale (float): + Scaling factor applied to the embeddings. + max_period (int): + Controls the maximum frequency of the embeddings + Returns + torch.Tensor: an [N x dim] Tensor of positional embeddings. + """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class TimestepEmbedding(nn.Module): + def __init__( + self, + in_channels: int, + time_embed_dim: int, + act_fn: str = "silu", + out_dim: int = None, + post_act_fn: Optional[str] = None, + cond_proj_dim=None, + sample_proj_bias=True, + ): + super().__init__() + + self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias) + + if cond_proj_dim is not None: + self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) + else: + self.cond_proj = None + + self.act = get_activation(act_fn) + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias) + + if post_act_fn is None: + self.post_act = None + else: + self.post_act = get_activation(post_act_fn) + + def forward(self, sample, condition=None): + if condition is not None: + sample = sample + self.cond_proj(condition) + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + + if self.post_act is not None: + sample = self.post_act(sample) + return sample + + +class Timesteps(nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = 
downscale_freq_shift + self.scale = scale + + def forward(self, timesteps): + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + scale=self.scale, + ) + return t_emb + + +class FP32SiLU(nn.Module): + r""" + SiLU activation function with input upcasted to torch.float32. + """ + + def __init__(self): + super().__init__() + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + return F.silu(inputs.float(), inplace=False).to(inputs.dtype) + + +class GELU(nn.Module): + r""" + GELU activation function with tanh approximation support with `approximate="tanh"`. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + self.approximate = approximate + + def gelu(self, gate: torch.Tensor) -> torch.Tensor: + # if gate.device.type == "mps" and is_torch_version("<", "2.0.0"): + # # fp16 gelu not supported on mps before torch 2.0 + # return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype) + return F.gelu(gate, approximate=self.approximate) + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states + + +class PixArtAlphaTextProjection(nn.Module): + """ + Projects caption embeddings. Also handles dropout for classifier-free guidance. + + Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py + """ + + def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"): + super().__init__() + if out_features is None: + out_features = hidden_size + self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True) + if act_fn == "gelu_tanh": + self.act_1 = nn.GELU(approximate="tanh") + elif act_fn == "silu": + self.act_1 = nn.SiLU() + elif act_fn == "silu_fp32": + self.act_1 = FP32SiLU() + else: + raise ValueError(f"Unknown activation function: {act_fn}") + self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True) + + def forward(self, caption): + hidden_states = self.linear_1(caption) + hidden_states = self.act_1(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class LayerNormFramePack(nn.LayerNorm): + # casting to dtype of input tensor is added + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps).to(x) + + +class FP32LayerNormFramePack(nn.LayerNorm): + def forward(self, x: torch.Tensor) -> torch.Tensor: + origin_dtype = x.dtype + return torch.nn.functional.layer_norm( + x.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ).to(origin_dtype) + + +class RMSNormFramePack(nn.Module): + r""" + RMS Norm as introduced in https://arxiv.org/abs/1910.07467 by Zhang et al. + + Args: + dim (`int`): Number of dimensions to use for `weights`. Only effective when `elementwise_affine` is True. 
+ eps (`float`): Small value to use when calculating the reciprocal of the square-root. + elementwise_affine (`bool`, defaults to `True`): + Boolean flag to denote if affine transformation should be applied. + bias (`bool`, defaults to False): If also training the `bias` param. + """ + + def __init__(self, dim, eps: float, elementwise_affine: bool = True, bias: bool = False): + super().__init__() + + self.eps = eps + self.elementwise_affine = elementwise_affine + + if isinstance(dim, numbers.Integral): + dim = (dim,) + + self.dim = torch.Size(dim) + + self.weight = None + self.bias = None + + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim)) + if bias: + self.bias = nn.Parameter(torch.zeros(dim)) + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + + if self.weight is None: + return hidden_states.to(input_dtype) + + return hidden_states.to(input_dtype) * self.weight.to(input_dtype) + + +class AdaLayerNormContinuousFramePack(nn.Module): + r""" + Adaptive normalization layer with a norm layer (layer_norm or rms_norm). + + Args: + embedding_dim (`int`): Embedding dimension to use during projection. + conditioning_embedding_dim (`int`): Dimension of the input condition. + elementwise_affine (`bool`, defaults to `True`): + Boolean flag to denote if affine transformation should be applied. + eps (`float`, defaults to 1e-5): Epsilon factor. + bias (`bias`, defaults to `True`): Boolean flag to denote if bias should be use. + norm_type (`str`, defaults to `"layer_norm"`): + Normalization layer to use. Values supported: "layer_norm", "rms_norm". + """ + + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int, + # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters + # because the output is immediately scaled and shifted by the projected conditioning embeddings. + # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters. + # However, this is how it was implemented in the original code, and it's rather likely you should + # set `elementwise_affine` to False. + elementwise_affine=True, + eps=1e-5, + bias=True, + norm_type="layer_norm", + ): + super().__init__() + self.silu = nn.SiLU() + self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias) + if norm_type == "layer_norm": + self.norm = LayerNormFramePack(embedding_dim, eps, elementwise_affine, bias) + elif norm_type == "rms_norm": + self.norm = RMSNormFramePack(embedding_dim, eps, elementwise_affine) + else: + raise ValueError(f"unknown norm_type {norm_type}") + + def forward(self, x, conditioning_embedding): + emb = self.linear(self.silu(conditioning_embedding)) + scale, shift = emb.chunk(2, dim=1) + x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] + return x + + +class LinearActivation(nn.Module): + def __init__(self, dim_in: int, dim_out: int, bias: bool = True, activation: str = "silu"): + super().__init__() + + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + self.activation = get_activation(activation) + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + return self.activation(hidden_states) + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. 
If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + inner_dim=None, + bias: bool = True, + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + # if activation_fn == "gelu": + # act_fn = GELU(dim, inner_dim, bias=bias) + if activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) + # elif activation_fn == "geglu": + # act_fn = GEGLU(dim, inner_dim, bias=bias) + # elif activation_fn == "geglu-approximate": + # act_fn = ApproximateGELU(dim, inner_dim, bias=bias) + # elif activation_fn == "swiglu": + # act_fn = SwiGLU(dim, inner_dim, bias=bias) + elif activation_fn == "linear-silu": + act_fn = LinearActivation(dim, inner_dim, bias=bias, activation="silu") + else: + raise ValueError(f"Unknown activation function: {activation_fn}") + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(nn.Linear(inner_dim, dim_out, bias=bias)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + # deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + # deprecate("scale", "1.0.0", deprecation_message) + raise ValueError("scale is not supported in this version. Please remove it.") + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +# @maybe_allow_in_graph +class Attention(nn.Module): + r""" + Minimal copy of Attention class from diffusers. 
+ """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + bias: bool = False, + qk_norm: Optional[str] = None, + added_kv_proj_dim: Optional[int] = None, + eps: float = 1e-5, + processor: Optional[any] = None, + out_dim: int = None, + context_pre_only=None, + pre_only=False, + ): + super().__init__() + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.inner_kv_dim = self.inner_dim # if kv_heads is None else dim_head * kv_heads + self.query_dim = query_dim + self.use_bias = bias + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.out_dim = out_dim if out_dim is not None else query_dim + self.out_context_dim = query_dim + self.context_pre_only = context_pre_only + self.pre_only = pre_only + + self.scale = dim_head**-0.5 + self.heads = out_dim // dim_head if out_dim is not None else heads + + self.added_kv_proj_dim = added_kv_proj_dim + + if qk_norm is None: + self.norm_q = None + self.norm_k = None + elif qk_norm == "rms_norm": + self.norm_q = RMSNormFramePack(dim_head, eps=eps) + self.norm_k = RMSNormFramePack(dim_head, eps=eps) + else: + raise ValueError( + f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'fp32_layer_norm', 'layer_norm_across_heads', 'rms_norm', 'rms_norm_across_heads', 'l2'." + ) + + self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + + self.added_proj_bias = True # added_proj_bias + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=True) + self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=True) + if self.context_pre_only is not None: + self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True) + else: + self.add_q_proj = None + self.add_k_proj = None + self.add_v_proj = None + + if not self.pre_only: + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=True)) + # self.to_out.append(nn.Dropout(dropout)) + self.to_out.append(nn.Identity()) # dropout=0.0 + else: + self.to_out = None + + if self.context_pre_only is not None and not self.context_pre_only: + self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=True) + else: + self.to_add_out = None + + if qk_norm is not None and added_kv_proj_dim is not None: + if qk_norm == "rms_norm": + self.norm_added_q = RMSNormFramePack(dim_head, eps=eps) + self.norm_added_k = RMSNormFramePack(dim_head, eps=eps) + else: + raise ValueError(f"unknown qk_norm: {qk_norm}. 
Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`") + else: + self.norm_added_q = None + self.norm_added_k = None + + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + if processor is None: + processor = AttnProcessor2_0() + self.set_processor(processor) + + def set_processor(self, processor: any) -> None: + self.processor = processor + + def get_processor(self) -> any: + return self.processor + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **cross_attention_kwargs, + ) -> torch.Tensor: + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + def prepare_attention_mask( + self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3 + ) -> torch.Tensor: + r""" + Prepare the attention mask for the attention computation. + + Args: + attention_mask (`torch.Tensor`): + The attention mask to prepare. + target_length (`int`): + The target length of the attention mask. This is the length of the attention mask after padding. + batch_size (`int`): + The batch size, which is used to repeat the attention mask. + out_dim (`int`, *optional*, defaults to `3`): + The output dimension of the attention mask. Can be either `3` or `4`. + + Returns: + `torch.Tensor`: The prepared attention mask. + """ + head_size = self.heads + if attention_mask is None: + return attention_mask + + current_length: int = attention_mask.shape[-1] + if current_length != target_length: + if attention_mask.device.type == "mps": + # HACK: MPS: Does not support padding by greater than dimension of input tensor. + # Instead, we can manually construct the padding tensor. + padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length) + padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) + attention_mask = torch.cat([attention_mask, padding], dim=2) + else: + # TODO: for pipelines such as stable-diffusion, padding cross-attn mask: + # we want to instead pad by (0, remaining_length), where remaining_length is: + # remaining_length: int = target_length - current_length + # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = attention_mask.repeat_interleave(head_size, dim=0, output_size=attention_mask.shape[0] * head_size) + elif out_dim == 4: + attention_mask = attention_mask.unsqueeze(1) + attention_mask = attention_mask.repeat_interleave(head_size, dim=1, output_size=attention_mask.shape[1] * head_size) + + return attention_mask + + +class AttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 
+ """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + query = attn.to_q(hidden_states) + query_dtype = query.dtype # store dtype before potentially deleting query + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False) + del query, key, value, attention_mask # free memory + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query_dtype) # use stored dtype + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + return hidden_states + + +# endregion diffusers + + +def pad_for_3d_conv(x, kernel_size): + b, c, t, h, w = x.shape + pt, ph, pw = kernel_size + pad_t = (pt - (t % pt)) % pt + pad_h = (ph - (h % ph)) % ph + pad_w = (pw - (w % pw)) % pw + return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode="replicate") + + +def center_down_sample_3d(x, kernel_size): + # pt, ph, pw = kernel_size + # cp = (pt * ph * pw) // 2 + # xp = einops.rearrange(x, 'b c (t pt) (h ph) (w pw) -> (pt ph pw) b c t h w', pt=pt, ph=ph, pw=pw) + # xc = xp[cp] + # return xc + return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size) + + +def get_cu_seqlens(text_mask, img_len): + batch_size = text_mask.shape[0] + text_len = text_mask.sum(dim=1) + max_len = text_mask.shape[1] + img_len + + cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device=text_mask.device) # ensure device match + + for i in range(batch_size): + s = text_len[i] + img_len + s1 = i * max_len + s + s2 = (i + 1) * max_len + cu_seqlens[2 * i + 1] = s1 + cu_seqlens[2 * i + 2] = s2 + + return cu_seqlens + + 
+def apply_rotary_emb_transposed(x, freqs_cis): + cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1) + del freqs_cis + x_real, x_imag = x.unflatten(-1, (-1, 2)).unbind(-1) + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + del x_real, x_imag + return (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + +def attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, attn_mode=None, split_attn=False): + if cu_seqlens_q is None and cu_seqlens_kv is None and max_seqlen_q is None and max_seqlen_kv is None: + if attn_mode == "sageattn" or attn_mode is None and sageattn is not None: + x = sageattn(q, k, v, tensor_layout="NHD") + return x + + if attn_mode == "flash" or attn_mode is None and flash_attn_func is not None: + x = flash_attn_func(q, k, v) + return x + + if attn_mode == "xformers" or attn_mode is None and xformers_attn_func is not None: + x = xformers_attn_func(q, k, v) + return x + + x = torch.nn.functional.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose( + 1, 2 + ) + return x + if split_attn: + if attn_mode == "sageattn" or attn_mode is None and sageattn is not None: + x = torch.empty_like(q) + for i in range(q.size(0)): + x[i : i + 1] = sageattn(q[i : i + 1], k[i : i + 1], v[i : i + 1], tensor_layout="NHD") + return x + + if attn_mode == "flash" or attn_mode is None and flash_attn_func is not None: + x = torch.empty_like(q) + for i in range(q.size(0)): + x[i : i + 1] = flash_attn_func(q[i : i + 1], k[i : i + 1], v[i : i + 1]) + return x + + if attn_mode == "xformers" or attn_mode is None and xformers_attn_func is not None: + x = torch.empty_like(q) + for i in range(q.size(0)): + x[i : i + 1] = xformers_attn_func(q[i : i + 1], k[i : i + 1], v[i : i + 1]) + return x + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + x = torch.empty_like(q) + for i in range(q.size(0)): + x[i : i + 1] = torch.nn.functional.scaled_dot_product_attention(q[i : i + 1], k[i : i + 1], v[i : i + 1]) + x = x.transpose(1, 2) + return x + + batch_size = q.shape[0] + q = q.view(q.shape[0] * q.shape[1], *q.shape[2:]) + k = k.view(k.shape[0] * k.shape[1], *k.shape[2:]) + v = v.view(v.shape[0] * v.shape[1], *v.shape[2:]) + if attn_mode == "sageattn" or attn_mode is None and sageattn_varlen is not None: + x = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv) + del q, k, v # free memory + elif attn_mode == "flash" or attn_mode is None and flash_attn_varlen_func is not None: + x = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv) + del q, k, v # free memory + else: + raise NotImplementedError("No Attn Installed or batch_size > 1 is not supported in this configuration. 
Try `--split_attn`.") + x = x.view(batch_size, max_seqlen_q, *x.shape[2:]) + return x + + +class HunyuanAttnProcessorFlashAttnDouble: + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states, + attention_mask, + image_rotary_emb, + attn_mode: Optional[str] = None, + split_attn: Optional[bool] = False, + ): + cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask + + # Project image latents + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + del hidden_states # free memory + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + query = apply_rotary_emb_transposed(query, image_rotary_emb) + key = apply_rotary_emb_transposed(key, image_rotary_emb) + del image_rotary_emb # free memory + + # Project context (text/encoder) embeddings + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + txt_length = encoder_hidden_states.shape[1] # store length before deleting + del encoder_hidden_states # free memory + + encoder_query = encoder_query.unflatten(2, (attn.heads, -1)) + encoder_key = encoder_key.unflatten(2, (attn.heads, -1)) + encoder_value = encoder_value.unflatten(2, (attn.heads, -1)) + + encoder_query = attn.norm_added_q(encoder_query) + encoder_key = attn.norm_added_k(encoder_key) + + # Concatenate image and context q, k, v + query = torch.cat([query, encoder_query], dim=1) + key = torch.cat([key, encoder_key], dim=1) + value = torch.cat([value, encoder_value], dim=1) + del encoder_query, encoder_key, encoder_value # free memory + + hidden_states_attn = attn_varlen_func( + query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, attn_mode=attn_mode, split_attn=split_attn + ) + del query, key, value # free memory + hidden_states_attn = hidden_states_attn.flatten(-2) + + hidden_states, encoder_hidden_states = hidden_states_attn[:, :-txt_length], hidden_states_attn[:, -txt_length:] + del hidden_states_attn # free memory + + # Apply output projections + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) # Dropout/Identity + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + + +class HunyuanAttnProcessorFlashAttnSingle: + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states, + attention_mask, + image_rotary_emb, + attn_mode: Optional[str] = None, + split_attn: Optional[bool] = False, + ): + cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask + txt_length = encoder_hidden_states.shape[1] # Store text length + + # Concatenate image and context inputs + hidden_states_cat = torch.cat([hidden_states, encoder_hidden_states], dim=1) + del hidden_states, encoder_hidden_states # free memory + + # Project concatenated inputs + query = attn.to_q(hidden_states_cat) + key = attn.to_k(hidden_states_cat) + value = attn.to_v(hidden_states_cat) + del hidden_states_cat # free memory + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + query = torch.cat([apply_rotary_emb_transposed(query[:, :-txt_length], image_rotary_emb), query[:, -txt_length:]], 
dim=1) + key = torch.cat([apply_rotary_emb_transposed(key[:, :-txt_length], image_rotary_emb), key[:, -txt_length:]], dim=1) + del image_rotary_emb # free memory + + hidden_states = attn_varlen_func( + query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, attn_mode=attn_mode, split_attn=split_attn + ) + del query, key, value # free memory + hidden_states = hidden_states.flatten(-2) + + hidden_states, encoder_hidden_states = hidden_states[:, :-txt_length], hidden_states[:, -txt_length:] + + return hidden_states, encoder_hidden_states + + +class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module): + def __init__(self, embedding_dim, pooled_projection_dim): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu") + + def forward(self, timestep, guidance, pooled_projection): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) + + guidance_proj = self.time_proj(guidance) + guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype)) + + time_guidance_emb = timesteps_emb + guidance_emb + + pooled_projections = self.text_embedder(pooled_projection) + conditioning = time_guidance_emb + pooled_projections + + return conditioning + + +class CombinedTimestepTextProjEmbeddings(nn.Module): + def __init__(self, embedding_dim, pooled_projection_dim): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu") + + def forward(self, timestep, pooled_projection): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) + + pooled_projections = self.text_embedder(pooled_projection) + + conditioning = timesteps_emb + pooled_projections + + return conditioning + + +class HunyuanVideoAdaNorm(nn.Module): + def __init__(self, in_features: int, out_features: Optional[int] = None) -> None: + super().__init__() + + out_features = out_features or 2 * in_features + self.linear = nn.Linear(in_features, out_features) + self.nonlinearity = nn.SiLU() + + def forward(self, temb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + temb = self.linear(self.nonlinearity(temb)) + gate_msa, gate_mlp = temb.chunk(2, dim=-1) + gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1) + return gate_msa, gate_mlp + + +class HunyuanVideoIndividualTokenRefinerBlock(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + mlp_width_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + attention_bias: bool = True, + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + + self.norm1 = LayerNormFramePack(hidden_size, elementwise_affine=True, eps=1e-6) + self.attn = Attention( + query_dim=hidden_size, + cross_attention_dim=None, + heads=num_attention_heads, + dim_head=attention_head_dim, + bias=attention_bias, + ) + + 
self.norm2 = LayerNormFramePack(hidden_size, elementwise_affine=True, eps=1e-6) + self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate) + + self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + norm_hidden_states = self.norm1(hidden_states) + + # Self-attention + attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=None, + attention_mask=attention_mask, + ) + del norm_hidden_states # free memory + + gate_msa, gate_mlp = self.norm_out(temb) + hidden_states = hidden_states + attn_output * gate_msa + del attn_output, gate_msa # free memory + + ff_output = self.ff(self.norm2(hidden_states)) + hidden_states = hidden_states + ff_output * gate_mlp + del ff_output, gate_mlp # free memory + + return hidden_states + + +class HunyuanVideoIndividualTokenRefiner(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + num_layers: int, + mlp_width_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + attention_bias: bool = True, + ) -> None: + super().__init__() + + self.refiner_blocks = nn.ModuleList( + [ + HunyuanVideoIndividualTokenRefinerBlock( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_width_ratio=mlp_width_ratio, + mlp_drop_rate=mlp_drop_rate, + attention_bias=attention_bias, + ) + for _ in range(num_layers) + ] + ) + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + self_attn_mask = None + if attention_mask is not None: + batch_size = attention_mask.shape[0] + seq_len = attention_mask.shape[1] + attention_mask = attention_mask.to(hidden_states.device).bool() + self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1) + self_attn_mask_2 = self_attn_mask_1.transpose(2, 3) + self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool() + self_attn_mask[:, :, :, 0] = True + + for block in self.refiner_blocks: + hidden_states = block(hidden_states, temb, self_attn_mask) + + return hidden_states + + +class HunyuanVideoTokenRefiner(nn.Module): + def __init__( + self, + in_channels: int, + num_attention_heads: int, + attention_head_dim: int, + num_layers: int, + mlp_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + attention_bias: bool = True, + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + + self.time_text_embed = CombinedTimestepTextProjEmbeddings(embedding_dim=hidden_size, pooled_projection_dim=in_channels) + self.proj_in = nn.Linear(in_channels, hidden_size, bias=True) + self.token_refiner = HunyuanVideoIndividualTokenRefiner( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + num_layers=num_layers, + mlp_width_ratio=mlp_ratio, + mlp_drop_rate=mlp_drop_rate, + attention_bias=attention_bias, + ) + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + attention_mask: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + if attention_mask is None: + pooled_projections = hidden_states.mean(dim=1) + else: + original_dtype = hidden_states.dtype + mask_float = attention_mask.float().unsqueeze(-1) + pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1) + pooled_projections = 
pooled_projections.to(original_dtype) + + temb = self.time_text_embed(timestep, pooled_projections) + del pooled_projections # free memory + + hidden_states = self.proj_in(hidden_states) + hidden_states = self.token_refiner(hidden_states, temb, attention_mask) + del temb, attention_mask # free memory + + return hidden_states + + +class HunyuanVideoRotaryPosEmbed(nn.Module): + def __init__(self, rope_dim, theta): + super().__init__() + self.DT, self.DY, self.DX = rope_dim + self.theta = theta + + @torch.no_grad() + def get_frequency(self, dim, pos): + T, H, W = pos.shape + freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device)[: (dim // 2)] / dim)) + freqs = torch.outer(freqs, pos.reshape(-1)).unflatten(-1, (T, H, W)).repeat_interleave(2, dim=0) + return freqs.cos(), freqs.sin() + + @torch.no_grad() + def forward_inner(self, frame_indices, height, width, device): + GT, GY, GX = torch.meshgrid( + frame_indices.to(device=device, dtype=torch.float32), + torch.arange(0, height, device=device, dtype=torch.float32), + torch.arange(0, width, device=device, dtype=torch.float32), + indexing="ij", + ) + + FCT, FST = self.get_frequency(self.DT, GT) + del GT # free memory + FCY, FSY = self.get_frequency(self.DY, GY) + del GY # free memory + FCX, FSX = self.get_frequency(self.DX, GX) + del GX # free memory + + result = torch.cat([FCT, FCY, FCX, FST, FSY, FSX], dim=0) + del FCT, FCY, FCX, FST, FSY, FSX # free memory + + # Return result already on the correct device + return result # Shape (2 * total_dim / 2, T, H, W) -> (total_dim, T, H, W) + + @torch.no_grad() + def forward(self, frame_indices, height, width, device): + frame_indices = frame_indices.unbind(0) + results = [self.forward_inner(f, height, width, device) for f in frame_indices] + results = torch.stack(results, dim=0) + return results + + +class AdaLayerNormZero(nn.Module): + def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True): + super().__init__() + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias) + if norm_type == "layer_norm": + self.norm = LayerNormFramePack(embedding_dim, elementwise_affine=False, eps=1e-6) + else: + raise ValueError(f"unknown norm_type {norm_type}") + + def forward( + self, x: torch.Tensor, emb: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + emb = emb.unsqueeze(-2) + emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=-1) + x = self.norm(x) * (1 + scale_msa) + shift_msa + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp + + +class AdaLayerNormZeroSingle(nn.Module): + def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias) + if norm_type == "layer_norm": + self.norm = LayerNormFramePack(embedding_dim, elementwise_affine=False, eps=1e-6) + else: + raise ValueError(f"unknown norm_type {norm_type}") + + def forward( + self, + x: torch.Tensor, + emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + emb = emb.unsqueeze(-2) + emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=-1) + x = self.norm(x) * (1 + scale_msa) + shift_msa + return x, gate_msa + + +class AdaLayerNormContinuous(nn.Module): + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int, + 
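+        # The remaining keyword arguments configure the underlying LayerNormFramePack. Unlike the
+        # AdaLayerNormZero variants above, this "continuous" variant only produces a scale and shift
+        # (no gate) from the conditioning embedding; it is used as the final output norm of the
+        # packed transformer.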
elementwise_affine=True, + eps=1e-5, + bias=True, + norm_type="layer_norm", + ): + super().__init__() + self.silu = nn.SiLU() + self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias) + if norm_type == "layer_norm": + self.norm = LayerNormFramePack(embedding_dim, eps, elementwise_affine, bias) + else: + raise ValueError(f"unknown norm_type {norm_type}") + + def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: + emb = emb.unsqueeze(-2) + emb = self.linear(self.silu(emb)) + scale, shift = emb.chunk(2, dim=-1) + del emb # free memory + x = self.norm(x) * (1 + scale) + shift + return x + + +class HunyuanVideoSingleTransformerBlock(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 4.0, + qk_norm: str = "rms_norm", + attn_mode: Optional[str] = None, + split_attn: Optional[bool] = False, + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + mlp_dim = int(hidden_size * mlp_ratio) + self.attn_mode = attn_mode + self.split_attn = split_attn + + # Attention layer (pre_only=True means no output projection in Attention module itself) + self.attn = Attention( + query_dim=hidden_size, + cross_attention_dim=None, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=hidden_size, + bias=True, + processor=HunyuanAttnProcessorFlashAttnSingle(), + qk_norm=qk_norm, + eps=1e-6, + pre_only=True, # Crucial: Attn processor will return raw attention output + ) + + self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm") + self.proj_mlp = nn.Linear(hidden_size, mlp_dim) + self.act_mlp = nn.GELU(approximate="tanh") + self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.shape[1] + hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) + del encoder_hidden_states # free memory + + residual = hidden_states + + # 1. Input normalization + norm_hidden_states, gate = self.norm(hidden_states, emb=temb) + mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) + + norm_hidden_states, norm_encoder_hidden_states = ( + norm_hidden_states[:, :-text_seq_length, :], + norm_hidden_states[:, -text_seq_length:, :], + ) + + # 2. Attention + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=image_rotary_emb, + attn_mode=self.attn_mode, + split_attn=self.split_attn, + ) + attn_output = torch.cat([attn_output, context_attn_output], dim=1) + del norm_hidden_states, norm_encoder_hidden_states, context_attn_output # free memory + del image_rotary_emb + + # 3. 
Modulation and residual connection + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) + del attn_output, mlp_hidden_states # free memory + hidden_states = gate * self.proj_out(hidden_states) + hidden_states = hidden_states + residual + + hidden_states, encoder_hidden_states = ( + hidden_states[:, :-text_seq_length, :], + hidden_states[:, -text_seq_length:, :], + ) + return hidden_states, encoder_hidden_states + + +class HunyuanVideoTransformerBlock(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float, + qk_norm: str = "rms_norm", + attn_mode: Optional[str] = None, + split_attn: Optional[bool] = False, + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + self.attn_mode = attn_mode + self.split_attn = split_attn + + self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm") + self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm") + + self.attn = Attention( + query_dim=hidden_size, + cross_attention_dim=None, + added_kv_proj_dim=hidden_size, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=hidden_size, + context_pre_only=False, + bias=True, + processor=HunyuanAttnProcessorFlashAttnDouble(), + qk_norm=qk_norm, + eps=1e-6, + ) + + self.norm2 = LayerNormFramePack(hidden_size, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate") + + self.norm2_context = LayerNormFramePack(hidden_size, elementwise_affine=False, eps=1e-6) + self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # 1. Input normalization + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + + # 2. Joint attention + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=freqs_cis, + attn_mode=self.attn_mode, + split_attn=self.split_attn, + ) + del norm_hidden_states, norm_encoder_hidden_states, freqs_cis # free memory + + # 3. Modulation and residual connection + hidden_states = hidden_states + attn_output * gate_msa + del attn_output, gate_msa # free memory + encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa + del context_attn_output, c_gate_msa # free memory + + norm_hidden_states = self.norm2(hidden_states) + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + del shift_mlp, scale_mlp # free memory + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp + del c_shift_mlp, c_scale_mlp # free memory + + # 4. 
Feed-forward + ff_output = self.ff(norm_hidden_states) + del norm_hidden_states # free memory + context_ff_output = self.ff_context(norm_encoder_hidden_states) + del norm_encoder_hidden_states # free memory + + hidden_states = hidden_states + gate_mlp * ff_output + del ff_output, gate_mlp # free memory + encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output + del context_ff_output, c_gate_mlp # free memory + + return hidden_states, encoder_hidden_states + + +class ClipVisionProjection(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + self.up = nn.Linear(in_channels, out_channels * 3) + self.down = nn.Linear(out_channels * 3, out_channels) + + def forward(self, x): + projected_x = self.down(nn.functional.silu(self.up(x))) + return projected_x + + +class HunyuanVideoPatchEmbed(nn.Module): + def __init__(self, patch_size, in_chans, embed_dim): + super().__init__() + self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + +class HunyuanVideoPatchEmbedForCleanLatents(nn.Module): + def __init__(self, inner_dim): + super().__init__() + self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)) + self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4)) + self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) + + @torch.no_grad() + def initialize_weight_from_another_conv3d(self, another_layer): + weight = another_layer.weight.detach().clone() + bias = another_layer.bias.detach().clone() + + sd = { + "proj.weight": weight.clone(), + "proj.bias": bias.clone(), + "proj_2x.weight": einops.repeat(weight, "b c t h w -> b c (t tk) (h hk) (w wk)", tk=2, hk=2, wk=2) / 8.0, + "proj_2x.bias": bias.clone(), + "proj_4x.weight": einops.repeat(weight, "b c t h w -> b c (t tk) (h hk) (w wk)", tk=4, hk=4, wk=4) / 64.0, + "proj_4x.bias": bias.clone(), + } + + sd = {k: v.clone() for k, v in sd.items()} + + self.load_state_dict(sd) + return + + +class HunyuanVideoTransformer3DModelPacked(nn.Module): # (PreTrainedModelMixin, GenerationMixin, + # ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): + # @register_to_config + def __init__( + self, + in_channels: int = 16, + out_channels: int = 16, + num_attention_heads: int = 24, + attention_head_dim: int = 128, + num_layers: int = 20, + num_single_layers: int = 40, + num_refiner_layers: int = 2, + mlp_ratio: float = 4.0, + patch_size: int = 2, + patch_size_t: int = 1, + qk_norm: str = "rms_norm", + guidance_embeds: bool = True, + text_embed_dim: int = 4096, + pooled_projection_dim: int = 768, + rope_theta: float = 256.0, + rope_axes_dim: Tuple[int] = (16, 56, 56), + has_image_proj=False, + image_proj_dim=1152, + has_clean_x_embedder=False, + attn_mode: Optional[str] = None, + split_attn: Optional[bool] = False, + ) -> None: + super().__init__() + + inner_dim = num_attention_heads * attention_head_dim + out_channels = out_channels or in_channels + self.config_patch_size = patch_size + self.config_patch_size_t = patch_size_t + + # 1. Latent and condition embedders + self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim) + self.context_embedder = HunyuanVideoTokenRefiner( + text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers + ) + self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim) + + self.clean_x_embedder = None + self.image_projection = None + + # 2. 
RoPE + self.rope = HunyuanVideoRotaryPosEmbed(rope_axes_dim, rope_theta) + + # 3. Dual stream transformer blocks + self.transformer_blocks = nn.ModuleList( + [ + HunyuanVideoTransformerBlock( + num_attention_heads, + attention_head_dim, + mlp_ratio=mlp_ratio, + qk_norm=qk_norm, + attn_mode=attn_mode, + split_attn=split_attn, + ) + for _ in range(num_layers) + ] + ) + + # 4. Single stream transformer blocks + self.single_transformer_blocks = nn.ModuleList( + [ + HunyuanVideoSingleTransformerBlock( + num_attention_heads, + attention_head_dim, + mlp_ratio=mlp_ratio, + qk_norm=qk_norm, + attn_mode=attn_mode, + split_attn=split_attn, + ) + for _ in range(num_single_layers) + ] + ) + + # 5. Output projection + self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels) + + self.inner_dim = inner_dim + self.use_gradient_checkpointing = False + self.enable_teacache = False + + # if has_image_proj: + # self.install_image_projection(image_proj_dim) + self.image_projection = ClipVisionProjection(in_channels=image_proj_dim, out_channels=self.inner_dim) + # self.config["has_image_proj"] = True + # self.config["image_proj_dim"] = in_channels + + # if has_clean_x_embedder: + # self.install_clean_x_embedder() + self.clean_x_embedder = HunyuanVideoPatchEmbedForCleanLatents(self.inner_dim) + # self.config["has_clean_x_embedder"] = True + + self.high_quality_fp32_output_for_inference = True # False # change default to True + + # Block swapping attributes (initialized to None) + self.blocks_to_swap = None + self.offloader_double = None + self.offloader_single = None + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def enable_gradient_checkpointing(self): + self.use_gradient_checkpointing = True + print("Gradient checkpointing enabled for HunyuanVideoTransformer3DModelPacked.") # Logging + + def disable_gradient_checkpointing(self): + self.use_gradient_checkpointing = False + print("Gradient checkpointing disabled for HunyuanVideoTransformer3DModelPacked.") # Logging + + def initialize_teacache(self, enable_teacache=True, num_steps=25, rel_l1_thresh=0.15): + self.enable_teacache = enable_teacache + self.cnt = 0 + self.num_steps = num_steps + self.rel_l1_thresh = rel_l1_thresh # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = None + self.previous_residual = None + self.teacache_rescale_func = np.poly1d([7.33226126e02, -4.01131952e02, 6.75869174e01, -3.14987800e00, 9.61237896e-02]) + if enable_teacache: + print(f"TeaCache enabled: num_steps={num_steps}, rel_l1_thresh={rel_l1_thresh}") + else: + print("TeaCache disabled.") + + def gradient_checkpointing_method(self, block, *args): + if self.use_gradient_checkpointing: + result = torch.utils.checkpoint.checkpoint(block, *args, use_reentrant=False) + else: + result = block(*args) + return result + + def enable_block_swap(self, num_blocks: int, device: torch.device, supports_backward: bool): + self.blocks_to_swap = num_blocks + self.num_double_blocks = len(self.transformer_blocks) + self.num_single_blocks = len(self.single_transformer_blocks) + double_blocks_to_swap = num_blocks // 2 + single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 + 1 + + assert double_blocks_to_swap <= self.num_double_blocks - 1 and single_blocks_to_swap <= self.num_single_blocks - 1, 
( + f"Cannot swap more than {self.num_double_blocks - 1} double blocks and {self.num_single_blocks - 1} single blocks. " + f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks." + ) + + self.offloader_double = ModelOffloader( + "double", + self.transformer_blocks, + self.num_double_blocks, + double_blocks_to_swap, + supports_backward, + device, + # debug=True # Optional debugging + ) + self.offloader_single = ModelOffloader( + "single", + self.single_transformer_blocks, + self.num_single_blocks, + single_blocks_to_swap, + supports_backward, + device, # , debug=True + ) + print( + f"HunyuanVideoTransformer3DModelPacked: Block swap enabled. Swapping {num_blocks} blocks, " + + f"double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}, supports_backward: {supports_backward}." + ) + + def switch_block_swap_for_inference(self): + if self.blocks_to_swap and self.blocks_to_swap > 0: + self.offloader_double.set_forward_only(True) + self.offloader_single.set_forward_only(True) + self.prepare_block_swap_before_forward() + print(f"HunyuanVideoTransformer3DModelPacked: Block swap set to forward only.") + + def switch_block_swap_for_training(self): + if self.blocks_to_swap and self.blocks_to_swap > 0: + self.offloader_double.set_forward_only(False) + self.offloader_single.set_forward_only(False) + self.prepare_block_swap_before_forward() + print(f"HunyuanVideoTransformer3DModelPacked: Block swap set to forward and backward.") + + def move_to_device_except_swap_blocks(self, device: torch.device): + # assume model is on cpu. do not move blocks to device to reduce temporary memory usage + if self.blocks_to_swap: + saved_double_blocks = self.transformer_blocks + saved_single_blocks = self.single_transformer_blocks + self.transformer_blocks = None + self.single_transformer_blocks = None + + self.to(device) + + if self.blocks_to_swap: + self.transformer_blocks = saved_double_blocks + self.single_transformer_blocks = saved_single_blocks + + def prepare_block_swap_before_forward(self): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + self.offloader_double.prepare_block_devices_before_forward(self.transformer_blocks) + self.offloader_single.prepare_block_devices_before_forward(self.single_transformer_blocks) + + def process_input_hidden_states( + self, + latents, + latent_indices=None, + clean_latents=None, + clean_latent_indices=None, + clean_latents_2x=None, + clean_latent_2x_indices=None, + clean_latents_4x=None, + clean_latent_4x_indices=None, + ): + hidden_states = self.gradient_checkpointing_method(self.x_embedder.proj, latents) + B, C, T, H, W = hidden_states.shape + + if latent_indices is None: + latent_indices = torch.arange(0, T).unsqueeze(0).expand(B, -1) + + hidden_states = hidden_states.flatten(2).transpose(1, 2) + + rope_freqs = self.rope(frame_indices=latent_indices, height=H, width=W, device=hidden_states.device) + rope_freqs = rope_freqs.flatten(2).transpose(1, 2) + + if clean_latents is not None and clean_latent_indices is not None: + clean_latents = clean_latents.to(hidden_states) + clean_latents = self.gradient_checkpointing_method(self.clean_x_embedder.proj, clean_latents) + clean_latents = clean_latents.flatten(2).transpose(1, 2) + + clean_latent_rope_freqs = self.rope(frame_indices=clean_latent_indices, height=H, width=W, device=clean_latents.device) + clean_latent_rope_freqs = clean_latent_rope_freqs.flatten(2).transpose(1, 2) + + hidden_states = torch.cat([clean_latents, hidden_states], dim=1) + 
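+            # Prepend the clean-latent (history) tokens in front of the noisy tokens, and below keep
+            # their RoPE frequencies concatenated in the same order so every token retains its own
+            # 3D position. Illustrative sizes only, assuming a 640x640 video (80x80 latents), 9 noisy
+            # latent frames and 1 clean frame with patch size (1, 2, 2):
+            #   noisy tokens:  9 * 40 * 40 = 14400
+            #   clean tokens:  1 * 40 * 40 =  1600   -> concatenated sequence length 16000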
rope_freqs = torch.cat([clean_latent_rope_freqs, rope_freqs], dim=1) + + if clean_latents_2x is not None and clean_latent_2x_indices is not None: + clean_latents_2x = clean_latents_2x.to(hidden_states) + clean_latents_2x = pad_for_3d_conv(clean_latents_2x, (2, 4, 4)) + clean_latents_2x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_2x, clean_latents_2x) + clean_latents_2x = clean_latents_2x.flatten(2).transpose(1, 2) + + clean_latent_2x_rope_freqs = self.rope( + frame_indices=clean_latent_2x_indices, height=H, width=W, device=clean_latents_2x.device + ) + clean_latent_2x_rope_freqs = pad_for_3d_conv(clean_latent_2x_rope_freqs, (2, 2, 2)) + clean_latent_2x_rope_freqs = center_down_sample_3d(clean_latent_2x_rope_freqs, (2, 2, 2)) + clean_latent_2x_rope_freqs = clean_latent_2x_rope_freqs.flatten(2).transpose(1, 2) + + hidden_states = torch.cat([clean_latents_2x, hidden_states], dim=1) + rope_freqs = torch.cat([clean_latent_2x_rope_freqs, rope_freqs], dim=1) + + if clean_latents_4x is not None and clean_latent_4x_indices is not None: + clean_latents_4x = clean_latents_4x.to(hidden_states) + clean_latents_4x = pad_for_3d_conv(clean_latents_4x, (4, 8, 8)) + clean_latents_4x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_4x, clean_latents_4x) + clean_latents_4x = clean_latents_4x.flatten(2).transpose(1, 2) + + clean_latent_4x_rope_freqs = self.rope( + frame_indices=clean_latent_4x_indices, height=H, width=W, device=clean_latents_4x.device + ) + clean_latent_4x_rope_freqs = pad_for_3d_conv(clean_latent_4x_rope_freqs, (4, 4, 4)) + clean_latent_4x_rope_freqs = center_down_sample_3d(clean_latent_4x_rope_freqs, (4, 4, 4)) + clean_latent_4x_rope_freqs = clean_latent_4x_rope_freqs.flatten(2).transpose(1, 2) + + hidden_states = torch.cat([clean_latents_4x, hidden_states], dim=1) + rope_freqs = torch.cat([clean_latent_4x_rope_freqs, rope_freqs], dim=1) + + return hidden_states, rope_freqs + + def forward( + self, + hidden_states, + timestep, + encoder_hidden_states, + encoder_attention_mask, + pooled_projections, + guidance, + latent_indices=None, + clean_latents=None, + clean_latent_indices=None, + clean_latents_2x=None, + clean_latent_2x_indices=None, + clean_latents_4x=None, + clean_latent_4x_indices=None, + image_embeddings=None, + attention_kwargs=None, + return_dict=True, + ): + + if attention_kwargs is None: + attention_kwargs = {} + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p, p_t = self.config_patch_size, self.config_patch_size_t + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p + post_patch_width = width // p + original_context_length = post_patch_num_frames * post_patch_height * post_patch_width + + hidden_states, rope_freqs = self.process_input_hidden_states( + hidden_states, + latent_indices, + clean_latents, + clean_latent_indices, + clean_latents_2x, + clean_latent_2x_indices, + clean_latents_4x, + clean_latent_4x_indices, + ) + del ( + latent_indices, + clean_latents, + clean_latent_indices, + clean_latents_2x, + clean_latent_2x_indices, + clean_latents_4x, + clean_latent_4x_indices, + ) # free memory + + temb = self.gradient_checkpointing_method(self.time_text_embed, timestep, guidance, pooled_projections) + encoder_hidden_states = self.gradient_checkpointing_method( + self.context_embedder, encoder_hidden_states, timestep, encoder_attention_mask + ) + + if self.image_projection is not None: + assert image_embeddings is not None, "You must use image embeddings!" 
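+            # Project the image-encoder embeddings (image_proj_dim-dimensional, 1152 here, matching
+            # the SigLIP vision encoder loaded by the generation script) into the transformer width
+            # via ClipVisionProjection (Linear -> SiLU -> Linear), then prepend them to the text
+            # tokens. As noted below, they must be concatenated *before* the text tokens so the
+            # attention-mask construction stays consistent.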
+ extra_encoder_hidden_states = self.gradient_checkpointing_method(self.image_projection, image_embeddings) + extra_attention_mask = torch.ones( + (batch_size, extra_encoder_hidden_states.shape[1]), + dtype=encoder_attention_mask.dtype, + device=encoder_attention_mask.device, + ) + + # must cat before (not after) encoder_hidden_states, due to attn masking + encoder_hidden_states = torch.cat([extra_encoder_hidden_states, encoder_hidden_states], dim=1) + encoder_attention_mask = torch.cat([extra_attention_mask, encoder_attention_mask], dim=1) + del extra_encoder_hidden_states, extra_attention_mask # free memory + + with torch.no_grad(): + if batch_size == 1: + # When batch size is 1, we do not need any masks or var-len funcs since cropping is mathematically same to what we want + # If they are not same, then their impls are wrong. Ours are always the correct one. + text_len = encoder_attention_mask.sum().item() + encoder_hidden_states = encoder_hidden_states[:, :text_len] + attention_mask = None, None, None, None + else: + img_seq_len = hidden_states.shape[1] + txt_seq_len = encoder_hidden_states.shape[1] + + cu_seqlens_q = get_cu_seqlens(encoder_attention_mask, img_seq_len) + cu_seqlens_kv = cu_seqlens_q + max_seqlen_q = img_seq_len + txt_seq_len + max_seqlen_kv = max_seqlen_q + + attention_mask = cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv + del cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv # free memory + del encoder_attention_mask # free memory + + if self.enable_teacache: + modulated_inp = self.transformer_blocks[0].norm1(hidden_states, emb=temb)[0] + + if self.cnt == 0 or self.cnt == self.num_steps - 1: + should_calc = True + self.accumulated_rel_l1_distance = 0 + else: + # Ensure both tensors are on the same device before comparison + prev_input = self.previous_modulated_input.to(modulated_inp.device) + curr_rel_l1 = ( + ((modulated_inp - prev_input).abs().mean() / prev_input.abs().mean()) + .cpu() + .item() + ) + self.accumulated_rel_l1_distance += self.teacache_rescale_func(curr_rel_l1) + should_calc = self.accumulated_rel_l1_distance >= self.rel_l1_thresh + + if should_calc: + self.accumulated_rel_l1_distance = 0 + + # Explicitly store the tensor on the current device + self.previous_modulated_input = modulated_inp.detach().clone() + self.cnt += 1 + + if self.cnt == self.num_steps: + self.cnt = 0 + + if not should_calc: + # Ensure residual is on the same device as hidden_states + hidden_states = hidden_states + self.previous_residual.to(hidden_states.device) + else: + ori_hidden_states = hidden_states.clone() + + # --- BEFORE --- + # for block_id, block in enumerate(self.transformer_blocks): + # hidden_states, encoder_hidden_states = self.gradient_checkpointing_method( + # block, hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs + # ) + # + # for block_id, block in enumerate(self.single_transformer_blocks): + # hidden_states, encoder_hidden_states = self.gradient_checkpointing_method( + # block, hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs + # ) + # --- AFTER --- + for block_id, block in enumerate(self.transformer_blocks): + if self.blocks_to_swap: # Add block swap logic here + self.offloader_double.wait_for_block(block_id) + + hidden_states, encoder_hidden_states = self.gradient_checkpointing_method( + block, hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs + ) + + if self.blocks_to_swap: # Add block swap logic here + 
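+                        # Block swapping: hand the finished block back to the offloader so it can be
+                        # moved off the compute device in the background while the next swapped block
+                        # is prefetched; wait_for_block() above stalls only until the block needed
+                        # next is actually resident.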
self.offloader_double.submit_move_blocks_forward(self.transformer_blocks, block_id) + + for block_id, block in enumerate(self.single_transformer_blocks): + if self.blocks_to_swap: # Add block swap logic here + self.offloader_single.wait_for_block(block_id) + + hidden_states, encoder_hidden_states = self.gradient_checkpointing_method( + block, hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs + ) + + if self.blocks_to_swap: # Add block swap logic here + self.offloader_single.submit_move_blocks_forward(self.single_transformer_blocks, block_id) + # --- END MODIFICATION --- + + # Store residual on the same device + self.previous_residual = (hidden_states - ori_hidden_states).detach().clone() + del ori_hidden_states + else: + for block_id, block in enumerate(self.transformer_blocks): + if self.blocks_to_swap: + self.offloader_double.wait_for_block(block_id) + + hidden_states, encoder_hidden_states = self.gradient_checkpointing_method( + block, hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs + ) + + if self.blocks_to_swap: + self.offloader_double.submit_move_blocks_forward(self.transformer_blocks, block_id) + + for block_id, block in enumerate(self.single_transformer_blocks): + if self.blocks_to_swap: + self.offloader_single.wait_for_block(block_id) + + hidden_states, encoder_hidden_states = self.gradient_checkpointing_method( + block, hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs + ) + + if self.blocks_to_swap: + self.offloader_single.submit_move_blocks_forward(self.single_transformer_blocks, block_id) + + del attention_mask, rope_freqs # free memory + del encoder_hidden_states # free memory + + hidden_states = self.gradient_checkpointing_method(self.norm_out, hidden_states, temb) + + hidden_states = hidden_states[:, -original_context_length:, :] + + if self.high_quality_fp32_output_for_inference: + hidden_states = hidden_states.to(dtype=torch.float32) + if self.proj_out.weight.dtype != torch.float32: + self.proj_out.to(dtype=torch.float32) + + hidden_states = self.gradient_checkpointing_method(self.proj_out, hidden_states) + + hidden_states = einops.rearrange( + hidden_states, + "b (t h w) (c pt ph pw) -> b c (t pt) (h ph) (w pw)", + t=post_patch_num_frames, + h=post_patch_height, + w=post_patch_width, + pt=p_t, + ph=p, + pw=p, + ) + + if return_dict: + # return Transformer2DModelOutput(sample=hidden_states) + return SimpleNamespace(sample=hidden_states) + + return (hidden_states,) + + def fp8_optimization( + self, state_dict: dict[str, torch.Tensor], device: torch.device, move_to_device: bool, use_scaled_mm: bool = False + ) -> dict[str, torch.Tensor]: # Return type hint added + """ + Optimize the model state_dict with fp8. + + Args: + state_dict (dict[str, torch.Tensor]): + The state_dict of the model. + device (torch.device): + The device to calculate the weight. + move_to_device (bool): + Whether to move the weight to the device after optimization. + use_scaled_mm (bool): + Whether to use scaled matrix multiplication for FP8. 
+ """ + TARGET_KEYS = ["transformer_blocks", "single_transformer_blocks"] + EXCLUDE_KEYS = ["norm"] # Exclude norm layers (e.g., LayerNorm, RMSNorm) from FP8 + + # inplace optimization + state_dict = optimize_state_dict_with_fp8(state_dict, device, TARGET_KEYS, EXCLUDE_KEYS, move_to_device=move_to_device) + + # apply monkey patching + apply_fp8_monkey_patch(self, state_dict, use_scaled_mm=use_scaled_mm) + + return state_dict + + +def create_hunyuan_video_transformer_3d_model(attn_mode: str, split_attn: bool = False) -> HunyuanVideoTransformer3DModelPacked: + with init_empty_weights(): + logger.info(f"Creating HunyuanVideoTransformer3DModelPacked") + model = HunyuanVideoTransformer3DModelPacked( + attention_head_dim=128, + guidance_embeds=True, + has_clean_x_embedder=True, + has_image_proj=True, + image_proj_dim=1152, + in_channels=16, + mlp_ratio=4.0, + num_attention_heads=24, + num_layers=20, + num_refiner_layers=2, + num_single_layers=40, + out_channels=16, + patch_size=2, + patch_size_t=1, + pooled_projection_dim=768, + qk_norm="rms_norm", + rope_axes_dim=(16, 56, 56), + rope_theta=256.0, + text_embed_dim=4096, + attn_mode=attn_mode, + split_attn=split_attn, + ) + return model + + +def load_packed_model( + device: Union[str, torch.device], + dit_path: str, + attn_mode: str, + loading_device: Union[str, torch.device], + fp8_scaled: bool = False, + split_attn: bool = False, +) -> HunyuanVideoTransformer3DModelPacked: + # TODO support split_attn + device = torch.device(device) + loading_device = torch.device(loading_device) + + if os.path.isdir(dit_path): + # we don't support from_pretrained for now, so loading safetensors directly + safetensor_files = glob.glob(os.path.join(dit_path, "*.safetensors")) + if len(safetensor_files) == 0: + raise ValueError(f"Cannot find safetensors file in {dit_path}") + # sort by name and take the first one + safetensor_files.sort() + dit_path = safetensor_files[0] + + model = create_hunyuan_video_transformer_3d_model(attn_mode, split_attn=split_attn) + + # if fp8_scaled, load model weights to CPU to reduce VRAM usage. Otherwise, load to the specified device (CPU for block swap or CUDA for others) + dit_loading_device = torch.device("cpu") if fp8_scaled else loading_device + logger.info(f"Loading DiT model from {dit_path}, device={dit_loading_device}") + + # load model weights with the specified dtype or as is + sd = load_split_weights(dit_path, device=dit_loading_device, disable_mmap=True) + + if fp8_scaled: + # fp8 optimization: calculate on CUDA, move back to CPU if loading_device is CPU (block swap) + logger.info(f"Optimizing model weights to fp8. 
This may take a while.") + sd = model.fp8_optimization(sd, device, move_to_device=loading_device.type == "cpu") + + if loading_device.type != "cpu": + # make sure all the model weights are on the loading_device + logger.info(f"Moving weights to {loading_device}") + for key in sd.keys(): + sd[key] = sd[key].to(loading_device) + + info = model.load_state_dict(sd, strict=True, assign=True) + logger.info(f"Loaded DiT model from {dit_path}, info={info}") + + return model diff --git a/frame_pack/k_diffusion_hunyuan.py b/frame_pack/k_diffusion_hunyuan.py new file mode 100644 index 0000000000000000000000000000000000000000..60524eae0d6c9571ee90164ce520ba41cd8d3d20 --- /dev/null +++ b/frame_pack/k_diffusion_hunyuan.py @@ -0,0 +1,128 @@ +# original code: https://github.com/lllyasviel/FramePack +# original license: Apache-2.0 + +import torch +import math + +# from diffusers_helper.k_diffusion.uni_pc_fm import sample_unipc +# from diffusers_helper.k_diffusion.wrapper import fm_wrapper +# from diffusers_helper.utils import repeat_to_batch_size +from frame_pack.uni_pc_fm import sample_unipc +from frame_pack.wrapper import fm_wrapper +from frame_pack.utils import repeat_to_batch_size + + +def flux_time_shift(t, mu=1.15, sigma=1.0): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def calculate_flux_mu(context_length, x1=256, y1=0.5, x2=4096, y2=1.15, exp_max=7.0): + k = (y2 - y1) / (x2 - x1) + b = y1 - k * x1 + mu = k * context_length + b + mu = min(mu, math.log(exp_max)) + return mu + + +def get_flux_sigmas_from_mu(n, mu): + sigmas = torch.linspace(1, 0, steps=n + 1) + sigmas = flux_time_shift(sigmas, mu=mu) + return sigmas + + +# @torch.inference_mode() +def sample_hunyuan( + transformer, + sampler="unipc", + initial_latent=None, + concat_latent=None, + strength=1.0, + width=512, + height=512, + frames=16, + real_guidance_scale=1.0, + distilled_guidance_scale=6.0, + guidance_rescale=0.0, + shift=None, + num_inference_steps=25, + batch_size=None, + generator=None, + prompt_embeds=None, + prompt_embeds_mask=None, + prompt_poolers=None, + negative_prompt_embeds=None, + negative_prompt_embeds_mask=None, + negative_prompt_poolers=None, + dtype=torch.bfloat16, + device=None, + negative_kwargs=None, + callback=None, + **kwargs, +): + device = device or transformer.device + + if batch_size is None: + batch_size = int(prompt_embeds.shape[0]) + + latents = torch.randn( + (batch_size, 16, (frames + 3) // 4, height // 8, width // 8), generator=generator, device=generator.device + ).to(device=device, dtype=torch.float32) + + B, C, T, H, W = latents.shape + seq_length = T * H * W // 4 # 9*80*80//4 = 14400 + + if shift is None: + mu = calculate_flux_mu(seq_length, exp_max=7.0) # 1.9459... if seq_len is large, mu is clipped. 
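+        # calculate_flux_mu interpolates mu linearly in the token count (0.5 at 256 tokens, 1.15 at
+        # 4096 tokens) and clips it at log(exp_max). Rough illustrative arithmetic for the
+        # 14400-token case above: 0.5 + (1.15 - 0.5) / (4096 - 256) * (14400 - 256) ~= 2.89, which
+        # exceeds log(7) ~= 1.946, so mu is clipped. flux_time_shift then warps the linear sigma grid
+        # toward the high-noise end:
+        #   t -> exp(mu) / (exp(mu) + (1 / t - 1))   (with the default exponent of 1.0)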
+ else: + mu = math.log(shift) + + sigmas = get_flux_sigmas_from_mu(num_inference_steps, mu).to(device) + + k_model = fm_wrapper(transformer) + + if initial_latent is not None: + sigmas = sigmas * strength + first_sigma = sigmas[0].to(device=device, dtype=torch.float32) + initial_latent = initial_latent.to(device=device, dtype=torch.float32) + latents = initial_latent.float() * (1.0 - first_sigma) + latents.float() * first_sigma + + if concat_latent is not None: + concat_latent = concat_latent.to(latents) + + distilled_guidance = torch.tensor([distilled_guidance_scale * 1000.0] * batch_size).to(device=device, dtype=dtype) + + prompt_embeds = repeat_to_batch_size(prompt_embeds, batch_size) + prompt_embeds_mask = repeat_to_batch_size(prompt_embeds_mask, batch_size) + prompt_poolers = repeat_to_batch_size(prompt_poolers, batch_size) + negative_prompt_embeds = repeat_to_batch_size(negative_prompt_embeds, batch_size) + negative_prompt_embeds_mask = repeat_to_batch_size(negative_prompt_embeds_mask, batch_size) + negative_prompt_poolers = repeat_to_batch_size(negative_prompt_poolers, batch_size) + concat_latent = repeat_to_batch_size(concat_latent, batch_size) + + sampler_kwargs = dict( + dtype=dtype, + cfg_scale=real_guidance_scale, + cfg_rescale=guidance_rescale, + concat_latent=concat_latent, + positive=dict( + pooled_projections=prompt_poolers, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_embeds_mask, + guidance=distilled_guidance, + **kwargs, + ), + negative=dict( + pooled_projections=negative_prompt_poolers, + encoder_hidden_states=negative_prompt_embeds, + encoder_attention_mask=negative_prompt_embeds_mask, + guidance=distilled_guidance, + **(kwargs if negative_kwargs is None else {**kwargs, **negative_kwargs}), + ), + ) + + if sampler == "unipc": + results = sample_unipc(k_model, latents, sigmas, extra_args=sampler_kwargs, disable=False, callback=callback) + else: + raise NotImplementedError(f"Sampler {sampler} is not supported.") + + return results diff --git a/frame_pack/uni_pc_fm.py b/frame_pack/uni_pc_fm.py new file mode 100644 index 0000000000000000000000000000000000000000..43a198f9f1c408b8c84b47a675c871aaf71bc418 --- /dev/null +++ b/frame_pack/uni_pc_fm.py @@ -0,0 +1,142 @@ +# Better Flow Matching UniPC by Lvmin Zhang +# (c) 2025 +# CC BY-SA 4.0 +# Attribution-ShareAlike 4.0 International Licence + + +import torch + +from tqdm.auto import trange + + +def expand_dims(v, dims): + return v[(...,) + (None,) * (dims - 1)] + + +class FlowMatchUniPC: + def __init__(self, model, extra_args, variant='bh1'): + self.model = model + self.variant = variant + self.extra_args = extra_args + + def model_fn(self, x, t): + return self.model(x, t, **self.extra_args) + + def update_fn(self, x, model_prev_list, t_prev_list, t, order): + assert order <= len(model_prev_list) + dims = x.dim() + + t_prev_0 = t_prev_list[-1] + lambda_prev_0 = - torch.log(t_prev_0) + lambda_t = - torch.log(t) + model_prev_0 = model_prev_list[-1] + + h = lambda_t - lambda_prev_0 + + rks = [] + D1s = [] + for i in range(1, order): + t_prev_i = t_prev_list[-(i + 1)] + model_prev_i = model_prev_list[-(i + 1)] + lambda_prev_i = - torch.log(t_prev_i) + rk = ((lambda_prev_i - lambda_prev_0) / h)[0] + rks.append(rk) + D1s.append((model_prev_i - model_prev_0) / rk) + + rks.append(1.) 
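+        # UniPC update: rks holds the ratios (lambda_prev_i - lambda_prev_0) / h of earlier steps
+        # (plus 1.0 for the current step), and D1s the scaled differences between earlier model
+        # outputs. R and b below form the small linear system whose solutions rhos_p / rhos_c give
+        # the predictor and corrector combination weights; the corrector additionally uses the fresh
+        # model evaluation model_t at the predicted point.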
+ rks = torch.tensor(rks, device=x.device) + + R = [] + b = [] + + hh = -h[0] + h_phi_1 = torch.expm1(hh) + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.variant == 'bh1': + B_h = hh + elif self.variant == 'bh2': + B_h = torch.expm1(hh) + else: + raise NotImplementedError('Bad variant!') + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= (i + 1) + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=x.device) + + use_predictor = len(D1s) > 0 + + if use_predictor: + D1s = torch.stack(D1s, dim=1) + if order == 2: + rhos_p = torch.tensor([0.5], device=b.device) + else: + rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]) + else: + D1s = None + rhos_p = None + + if order == 1: + rhos_c = torch.tensor([0.5], device=b.device) + else: + rhos_c = torch.linalg.solve(R, b) + + x_t_ = expand_dims(t / t_prev_0, dims) * x - expand_dims(h_phi_1, dims) * model_prev_0 + + if use_predictor: + pred_res = torch.tensordot(D1s, rhos_p, dims=([1], [0])) + else: + pred_res = 0 + + x_t = x_t_ - expand_dims(B_h, dims) * pred_res + model_t = self.model_fn(x_t, t) + + if D1s is not None: + corr_res = torch.tensordot(D1s, rhos_c[:-1], dims=([1], [0])) + else: + corr_res = 0 + + D1_t = (model_t - model_prev_0) + x_t = x_t_ - expand_dims(B_h, dims) * (corr_res + rhos_c[-1] * D1_t) + + return x_t, model_t + + def sample(self, x, sigmas, callback=None, disable_pbar=False): + order = min(3, len(sigmas) - 2) + model_prev_list, t_prev_list = [], [] + for i in trange(len(sigmas) - 1, disable=disable_pbar): + vec_t = sigmas[i].expand(x.shape[0]) + + with torch.no_grad(): + if i == 0: + model_prev_list = [self.model_fn(x, vec_t)] + t_prev_list = [vec_t] + elif i < order: + init_order = i + x, model_x = self.update_fn(x, model_prev_list, t_prev_list, vec_t, init_order) + model_prev_list.append(model_x) + t_prev_list.append(vec_t) + else: + x, model_x = self.update_fn(x, model_prev_list, t_prev_list, vec_t, order) + model_prev_list.append(model_x) + t_prev_list.append(vec_t) + + model_prev_list = model_prev_list[-order:] + t_prev_list = t_prev_list[-order:] + + if callback is not None: + callback({'x': x, 'i': i, 'denoised': model_prev_list[-1]}) + + return model_prev_list[-1] + + +def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=False, variant='bh1'): + assert variant in ['bh1', 'bh2'] + return FlowMatchUniPC(model, extra_args=extra_args, variant=variant).sample(noise, sigmas=sigmas, callback=callback, disable_pbar=disable) diff --git a/frame_pack/utils.py b/frame_pack/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f69bd5fea25b2e8a6c80a774c0bd8eeb5926d0a7 --- /dev/null +++ b/frame_pack/utils.py @@ -0,0 +1,617 @@ +import os +import cv2 +import json +import random +import glob +import torch +import einops +import numpy as np +import datetime +import torchvision + +import safetensors.torch as sf +from PIL import Image + + +def min_resize(x, m): + if x.shape[0] < x.shape[1]: + s0 = m + s1 = int(float(m) / float(x.shape[0]) * float(x.shape[1])) + else: + s0 = int(float(m) / float(x.shape[1]) * float(x.shape[0])) + s1 = m + new_max = max(s1, s0) + raw_max = max(x.shape[0], x.shape[1]) + if new_max < raw_max: + interpolation = cv2.INTER_AREA + else: + interpolation = cv2.INTER_LANCZOS4 + y = cv2.resize(x, (s1, s0), interpolation=interpolation) + return y + + +def d_resize(x, y): + H, W, C = y.shape + new_min = min(H, W) + raw_min = min(x.shape[0], 
x.shape[1]) + if new_min < raw_min: + interpolation = cv2.INTER_AREA + else: + interpolation = cv2.INTER_LANCZOS4 + y = cv2.resize(x, (W, H), interpolation=interpolation) + return y + + +def resize_and_center_crop(image, target_width, target_height): + if target_height == image.shape[0] and target_width == image.shape[1]: + return image + + pil_image = Image.fromarray(image) + original_width, original_height = pil_image.size + scale_factor = max(target_width / original_width, target_height / original_height) + resized_width = int(round(original_width * scale_factor)) + resized_height = int(round(original_height * scale_factor)) + resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS) + left = (resized_width - target_width) / 2 + top = (resized_height - target_height) / 2 + right = (resized_width + target_width) / 2 + bottom = (resized_height + target_height) / 2 + cropped_image = resized_image.crop((left, top, right, bottom)) + return np.array(cropped_image) + + +def resize_and_center_crop_pytorch(image, target_width, target_height): + B, C, H, W = image.shape + + if H == target_height and W == target_width: + return image + + scale_factor = max(target_width / W, target_height / H) + resized_width = int(round(W * scale_factor)) + resized_height = int(round(H * scale_factor)) + + resized = torch.nn.functional.interpolate(image, size=(resized_height, resized_width), mode="bilinear", align_corners=False) + + top = (resized_height - target_height) // 2 + left = (resized_width - target_width) // 2 + cropped = resized[:, :, top : top + target_height, left : left + target_width] + + return cropped + + +def resize_without_crop(image, target_width, target_height): + if target_height == image.shape[0] and target_width == image.shape[1]: + return image + + pil_image = Image.fromarray(image) + resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS) + return np.array(resized_image) + + +def just_crop(image, w, h): + if h == image.shape[0] and w == image.shape[1]: + return image + + original_height, original_width = image.shape[:2] + k = min(original_height / h, original_width / w) + new_width = int(round(w * k)) + new_height = int(round(h * k)) + x_start = (original_width - new_width) // 2 + y_start = (original_height - new_height) // 2 + cropped_image = image[y_start : y_start + new_height, x_start : x_start + new_width] + return cropped_image + + +def write_to_json(data, file_path): + temp_file_path = file_path + ".tmp" + with open(temp_file_path, "wt", encoding="utf-8") as temp_file: + json.dump(data, temp_file, indent=4) + os.replace(temp_file_path, file_path) + return + + +def read_from_json(file_path): + with open(file_path, "rt", encoding="utf-8") as file: + data = json.load(file) + return data + + +def get_active_parameters(m): + return {k: v for k, v in m.named_parameters() if v.requires_grad} + + +def cast_training_params(m, dtype=torch.float32): + result = {} + for n, param in m.named_parameters(): + if param.requires_grad: + param.data = param.to(dtype) + result[n] = param + return result + + +def separate_lora_AB(parameters, B_patterns=None): + parameters_normal = {} + parameters_B = {} + + if B_patterns is None: + B_patterns = [".lora_B.", "__zero__"] + + for k, v in parameters.items(): + if any(B_pattern in k for B_pattern in B_patterns): + parameters_B[k] = v + else: + parameters_normal[k] = v + + return parameters_normal, parameters_B + + +def set_attr_recursive(obj, attr, value): + attrs = attr.split(".") + for name in attrs[:-1]: + obj 
= getattr(obj, name) + setattr(obj, attrs[-1], value) + return + + +def print_tensor_list_size(tensors): + total_size = 0 + total_elements = 0 + + if isinstance(tensors, dict): + tensors = tensors.values() + + for tensor in tensors: + total_size += tensor.nelement() * tensor.element_size() + total_elements += tensor.nelement() + + total_size_MB = total_size / (1024**2) + total_elements_B = total_elements / 1e9 + + print(f"Total number of tensors: {len(tensors)}") + print(f"Total size of tensors: {total_size_MB:.2f} MB") + print(f"Total number of parameters: {total_elements_B:.3f} billion") + return + + +@torch.no_grad() +def batch_mixture(a, b=None, probability_a=0.5, mask_a=None): + batch_size = a.size(0) + + if b is None: + b = torch.zeros_like(a) + + if mask_a is None: + mask_a = torch.rand(batch_size) < probability_a + + mask_a = mask_a.to(a.device) + mask_a = mask_a.reshape((batch_size,) + (1,) * (a.dim() - 1)) + result = torch.where(mask_a, a, b) + return result + + +@torch.no_grad() +def zero_module(module): + for p in module.parameters(): + p.detach().zero_() + return module + + +@torch.no_grad() +def supress_lower_channels(m, k, alpha=0.01): + data = m.weight.data.clone() + + assert int(data.shape[1]) >= k + + data[:, :k] = data[:, :k] * alpha + m.weight.data = data.contiguous().clone() + return m + + +def freeze_module(m): + if not hasattr(m, "_forward_inside_frozen_module"): + m._forward_inside_frozen_module = m.forward + m.requires_grad_(False) + m.forward = torch.no_grad()(m.forward) + return m + + +def get_latest_safetensors(folder_path): + safetensors_files = glob.glob(os.path.join(folder_path, "*.safetensors")) + + if not safetensors_files: + raise ValueError("No file to resume!") + + latest_file = max(safetensors_files, key=os.path.getmtime) + latest_file = os.path.abspath(os.path.realpath(latest_file)) + return latest_file + + +def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=32): + tags = tags_str.split(", ") + tags = random.sample(tags, k=min(random.randint(min_length, max_length), len(tags))) + prompt = ", ".join(tags) + return prompt + + +def interpolate_numbers(a, b, n, round_to_int=False, gamma=1.0): + numbers = a + (b - a) * (np.linspace(0, 1, n) ** gamma) + if round_to_int: + numbers = np.round(numbers).astype(int) + return numbers.tolist() + + +def uniform_random_by_intervals(inclusive, exclusive, n, round_to_int=False): + edges = np.linspace(0, 1, n + 1) + points = np.random.uniform(edges[:-1], edges[1:]) + numbers = inclusive + (exclusive - inclusive) * points + if round_to_int: + numbers = np.round(numbers).astype(int) + return numbers.tolist() + + +def soft_append_bcthw(history, current, overlap=0): + if overlap <= 0: + return torch.cat([history, current], dim=2) + + assert history.shape[2] >= overlap, f"History length ({history.shape[2]}) must be >= overlap ({overlap})" + assert current.shape[2] >= overlap, f"Current length ({current.shape[2]}) must be >= overlap ({overlap})" + + weights = torch.linspace(1, 0, overlap, dtype=history.dtype, device=history.device).view(1, 1, -1, 1, 1) + blended = weights * history[:, :, -overlap:] + (1 - weights) * current[:, :, :overlap] + output = torch.cat([history[:, :, :-overlap], blended, current[:, :, overlap:]], dim=2) + + return output.to(history) + + +def save_bcthw_as_mp4(x, output_filename, fps=10): + b, c, t, h, w = x.shape + + per_row = b + for p in [6, 5, 4, 3, 2]: + if b % p == 0: + per_row = p + break + + os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), 
exist_ok=True) + x = torch.clamp(x.float(), -1.0, 1.0) * 127.5 + 127.5 + x = x.detach().cpu().to(torch.uint8) + x = einops.rearrange(x, "(m n) c t h w -> t (m h) (n w) c", n=per_row) + torchvision.io.write_video(output_filename, x, fps=fps, video_codec="libx264", options={"crf": "0"}) + + # write tensor as .pt file + torch.save(x, output_filename.replace(".mp4", ".pt")) + + return x + + +def save_bcthw_as_png(x, output_filename): + os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True) + x = torch.clamp(x.float(), -1.0, 1.0) * 127.5 + 127.5 + x = x.detach().cpu().to(torch.uint8) + x = einops.rearrange(x, "b c t h w -> c (b h) (t w)") + torchvision.io.write_png(x, output_filename) + return output_filename + + +def save_bchw_as_png(x, output_filename): + os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True) + x = torch.clamp(x.float(), -1.0, 1.0) * 127.5 + 127.5 + x = x.detach().cpu().to(torch.uint8) + x = einops.rearrange(x, "b c h w -> c h (b w)") + torchvision.io.write_png(x, output_filename) + return output_filename + + +def add_tensors_with_padding(tensor1, tensor2): + if tensor1.shape == tensor2.shape: + return tensor1 + tensor2 + + shape1 = tensor1.shape + shape2 = tensor2.shape + + new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2)) + + padded_tensor1 = torch.zeros(new_shape) + padded_tensor2 = torch.zeros(new_shape) + + padded_tensor1[tuple(slice(0, s) for s in shape1)] = tensor1 + padded_tensor2[tuple(slice(0, s) for s in shape2)] = tensor2 + + result = padded_tensor1 + padded_tensor2 + return result + + +def print_free_mem(): + torch.cuda.empty_cache() + free_mem, total_mem = torch.cuda.mem_get_info(0) + free_mem_mb = free_mem / (1024**2) + total_mem_mb = total_mem / (1024**2) + print(f"Free memory: {free_mem_mb:.2f} MB") + print(f"Total memory: {total_mem_mb:.2f} MB") + return + + +def print_gpu_parameters(device, state_dict, log_count=1): + summary = {"device": device, "keys_count": len(state_dict)} + + logged_params = {} + for i, (key, tensor) in enumerate(state_dict.items()): + if i >= log_count: + break + logged_params[key] = tensor.flatten()[:3].tolist() + + summary["params"] = logged_params + + print(str(summary)) + return + + +def visualize_txt_as_img(width, height, text, font_path="font/DejaVuSans.ttf", size=18): + from PIL import Image, ImageDraw, ImageFont + + txt = Image.new("RGB", (width, height), color="white") + draw = ImageDraw.Draw(txt) + font = ImageFont.truetype(font_path, size=size) + + if text == "": + return np.array(txt) + + # Split text into lines that fit within the image width + lines = [] + words = text.split() + current_line = words[0] + + for word in words[1:]: + line_with_word = f"{current_line} {word}" + if draw.textbbox((0, 0), line_with_word, font=font)[2] <= width: + current_line = line_with_word + else: + lines.append(current_line) + current_line = word + + lines.append(current_line) + + # Draw the text line by line + y = 0 + line_height = draw.textbbox((0, 0), "A", font=font)[3] + + for line in lines: + if y + line_height > height: + break # stop drawing if the next line will be outside the image + draw.text((0, y), line, fill="black", font=font) + y += line_height + + return np.array(txt) + + +def blue_mark(x): + x = x.copy() + c = x[:, :, 2] + b = cv2.blur(c, (9, 9)) + x[:, :, 2] = ((c - b) * 16.0 + b).clip(-1, 1) + return x + + +def green_mark(x): + x = x.copy() + x[:, :, 2] = -1 + x[:, :, 0] = -1 + return x + + +def frame_mark(x): + x = x.copy() + 
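+    # Debug/visualization helper: paints the top and bottom 64 rows black (-1) and the left/right
+    # 8-pixel borders white (+1) in [-1, 1] image space, presumably to make frame boundaries easy
+    # to spot in tiled previews.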
x[:64] = -1 + x[-64:] = -1 + x[:, :8] = 1 + x[:, -8:] = 1 + return x + + +@torch.inference_mode() +def pytorch2numpy(imgs): + results = [] + for x in imgs: + y = x.movedim(0, -1) + y = y * 127.5 + 127.5 + y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8) + results.append(y) + return results + + +@torch.inference_mode() +def numpy2pytorch(imgs): + h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0 + h = h.movedim(-1, 1) + return h + + +@torch.no_grad() +def duplicate_prefix_to_suffix(x, count, zero_out=False): + if zero_out: + return torch.cat([x, torch.zeros_like(x[:count])], dim=0) + else: + return torch.cat([x, x[:count]], dim=0) + + +def weighted_mse(a, b, weight): + return torch.mean(weight.float() * (a.float() - b.float()) ** 2) + + +def clamped_linear_interpolation(x, x_min, y_min, x_max, y_max, sigma=1.0): + x = (x - x_min) / (x_max - x_min) + x = max(0.0, min(x, 1.0)) + x = x**sigma + return y_min + x * (y_max - y_min) + + +def expand_to_dims(x, target_dims): + return x.view(*x.shape, *([1] * max(0, target_dims - x.dim()))) + + +def repeat_to_batch_size(tensor: torch.Tensor, batch_size: int): + if tensor is None: + return None + + first_dim = tensor.shape[0] + + if first_dim == batch_size: + return tensor + + if batch_size % first_dim != 0: + raise ValueError(f"Cannot evenly repeat first dim {first_dim} to match batch_size {batch_size}.") + + repeat_times = batch_size // first_dim + + return tensor.repeat(repeat_times, *[1] * (tensor.dim() - 1)) + + +def dim5(x): + return expand_to_dims(x, 5) + + +def dim4(x): + return expand_to_dims(x, 4) + + +def dim3(x): + return expand_to_dims(x, 3) + + +def crop_or_pad_yield_mask(x, length): + B, F, C = x.shape + device = x.device + dtype = x.dtype + + if F < length: + y = torch.zeros((B, length, C), dtype=dtype, device=device) + mask = torch.zeros((B, length), dtype=torch.bool, device=device) + y[:, :F, :] = x + mask[:, :F] = True + return y, mask + + return x[:, :length, :], torch.ones((B, length), dtype=torch.bool, device=device) + + +def extend_dim(x, dim, minimal_length, zero_pad=False): + original_length = int(x.shape[dim]) + + if original_length >= minimal_length: + return x + + if zero_pad: + padding_shape = list(x.shape) + padding_shape[dim] = minimal_length - original_length + padding = torch.zeros(padding_shape, dtype=x.dtype, device=x.device) + else: + idx = (slice(None),) * dim + (slice(-1, None),) + (slice(None),) * (len(x.shape) - dim - 1) + last_element = x[idx] + padding = last_element.repeat_interleave(minimal_length - original_length, dim=dim) + + return torch.cat([x, padding], dim=dim) + + +def lazy_positional_encoding(t, repeats=None): + if not isinstance(t, list): + t = [t] + + from diffusers.models.embeddings import get_timestep_embedding + + te = torch.tensor(t) + te = get_timestep_embedding(timesteps=te, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=1.0) + + if repeats is None: + return te + + te = te[:, None, :].expand(-1, repeats, -1) + + return te + + +def state_dict_offset_merge(A, B, C=None): + result = {} + keys = A.keys() + + for key in keys: + A_value = A[key] + B_value = B[key].to(A_value) + + if C is None: + result[key] = A_value + B_value + else: + C_value = C[key].to(A_value) + result[key] = A_value + B_value - C_value + + return result + + +def state_dict_weighted_merge(state_dicts, weights): + if len(state_dicts) != len(weights): + raise ValueError("Number of state dictionaries must match number of weights") + + if not state_dicts: + return {} + 
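+    # Weighted merge: normalize the weights to sum to 1, then accumulate each key as a weighted
+    # combination of the corresponding tensors. Tensors from later state dicts are cast/moved onto
+    # the dtype and device of the first one via .to(result[key]).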
+ total_weight = sum(weights) + + if total_weight == 0: + raise ValueError("Sum of weights cannot be zero") + + normalized_weights = [w / total_weight for w in weights] + + keys = state_dicts[0].keys() + result = {} + + for key in keys: + result[key] = state_dicts[0][key] * normalized_weights[0] + + for i in range(1, len(state_dicts)): + state_dict_value = state_dicts[i][key].to(result[key]) + result[key] += state_dict_value * normalized_weights[i] + + return result + + +def group_files_by_folder(all_files): + grouped_files = {} + + for file in all_files: + folder_name = os.path.basename(os.path.dirname(file)) + if folder_name not in grouped_files: + grouped_files[folder_name] = [] + grouped_files[folder_name].append(file) + + list_of_lists = list(grouped_files.values()) + return list_of_lists + + +def generate_timestamp(): + now = datetime.datetime.now() + timestamp = now.strftime("%y%m%d_%H%M%S") + milliseconds = f"{int(now.microsecond / 1000):03d}" + random_number = random.randint(0, 9999) + return f"{timestamp}_{milliseconds}_{random_number}" + + +def write_PIL_image_with_png_info(image, metadata, path): + from PIL.PngImagePlugin import PngInfo + + png_info = PngInfo() + for key, value in metadata.items(): + png_info.add_text(key, value) + + image.save(path, "PNG", pnginfo=png_info) + return image + + +def torch_safe_save(content, path): + torch.save(content, path + "_tmp") + os.replace(path + "_tmp", path) + return path + + +def move_optimizer_to_device(optimizer, device): + for state in optimizer.state.values(): + for k, v in state.items(): + if isinstance(v, torch.Tensor): + state[k] = v.to(device) diff --git a/frame_pack/wrapper.py b/frame_pack/wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..cc420da4db1134deca30648077923021b35f82d1 --- /dev/null +++ b/frame_pack/wrapper.py @@ -0,0 +1,51 @@ +import torch + + +def append_dims(x, target_dims): + return x[(...,) + (None,) * (target_dims - x.ndim)] + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=1.0): + if guidance_rescale == 0: + return noise_cfg + + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + noise_cfg = guidance_rescale * noise_pred_rescaled + (1.0 - guidance_rescale) * noise_cfg + return noise_cfg + + +def fm_wrapper(transformer, t_scale=1000.0): + def k_model(x, sigma, **extra_args): + dtype = extra_args['dtype'] + cfg_scale = extra_args['cfg_scale'] + cfg_rescale = extra_args['cfg_rescale'] + concat_latent = extra_args['concat_latent'] + + original_dtype = x.dtype + sigma = sigma.float() + + x = x.to(dtype) + timestep = (sigma * t_scale).to(dtype) + + if concat_latent is None: + hidden_states = x + else: + hidden_states = torch.cat([x, concat_latent.to(x)], dim=1) + + pred_positive = transformer(hidden_states=hidden_states, timestep=timestep, return_dict=False, **extra_args['positive'])[0].float() + + if cfg_scale == 1.0: + pred_negative = torch.zeros_like(pred_positive) + else: + pred_negative = transformer(hidden_states=hidden_states, timestep=timestep, return_dict=False, **extra_args['negative'])[0].float() + + pred_cfg = pred_negative + cfg_scale * (pred_positive - pred_negative) + pred = rescale_noise_cfg(pred_cfg, pred_positive, guidance_rescale=cfg_rescale) + + x0 = x.float() - pred.float() * append_dims(sigma, x.ndim) + + return x0.to(dtype=original_dtype) + + return k_model diff --git 
a/framepack_generate_video.py b/framepack_generate_video.py new file mode 100644 index 0000000000000000000000000000000000000000..3c72d9e6dae8d97f0426d77216e3a6a7aca106bb --- /dev/null +++ b/framepack_generate_video.py @@ -0,0 +1,958 @@ +# Combined and Corrected Script +#!/usr/bin/env python3 + +import argparse +import os +import sys +import time +import random +import traceback +from datetime import datetime +from pathlib import Path +import re # For parsing section args + +import einops +import numpy as np +import torch +import av # For saving video (used by save_bcthw_as_mp4) +from PIL import Image +from tqdm import tqdm +import cv2 + + +# --- Dependencies from diffusers_helper --- +# Ensure this library is installed or in the PYTHONPATH +try: + # from diffusers_helper.hf_login import login # Not strictly needed for inference if models public/cached + from diffusers_helper.hunyuan import encode_prompt_conds, vae_decode, vae_encode #, vae_decode_fake # vae_decode_fake not used here + from diffusers_helper.utils import (save_bcthw_as_mp4, crop_or_pad_yield_mask, soft_append_bcthw, + resize_and_center_crop, generate_timestamp) + from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked + from diffusers_helper.pipelines.k_diffusion_hunyuan import sample_hunyuan + from diffusers_helper.memory import (cpu, gpu, get_cuda_free_memory_gb, move_model_to_device_with_memory_preservation, + offload_model_from_device_for_memory_preservation, fake_diffusers_current_device, + DynamicSwapInstaller, unload_complete_models, load_model_as_complete) + from diffusers_helper.clip_vision import hf_clip_vision_encode + from diffusers_helper.bucket_tools import find_nearest_bucket#, bucket_options # bucket_options no longer needed here +except ImportError: + print("Error: Could not import modules from 'diffusers_helper'.") + print("Please ensure the 'diffusers_helper' library is installed and accessible.") + print("You might need to clone the repository and add it to your PYTHONPATH.") + sys.exit(1) +# --- End Dependencies --- + +from diffusers import AutoencoderKLHunyuanVideo +from transformers import LlamaModel, CLIPTextModel, LlamaTokenizerFast, CLIPTokenizer +from transformers import SiglipImageProcessor, SiglipVisionModel + +# --- Constants --- +DIMENSION_MULTIPLE = 16 # VAE and model constraints often require divisibility by 8 or 16. 16 is safer. +SECTION_ARG_PATTERN = re.compile(r"^(\d+):([^:]+)(?::(.*))?$") # Regex for section arg: number:image_path[:prompt] + +def parse_section_args(section_strings): + """ Parses the --section arguments into a dictionary. """ + section_data = {} + if not section_strings: + return section_data + for section_str in section_strings: + match = SECTION_ARG_PATTERN.match(section_str) + if not match: + print(f"Warning: Invalid section format: '{section_str}'. Expected 'number:image_path[:prompt]'. Skipping.") + continue + section_index_str, image_path, prompt_text = match.groups() + section_index = int(section_index_str) + prompt_text = prompt_text if prompt_text else None + if not os.path.exists(image_path): + print(f"Warning: Image path for section {section_index} ('{image_path}') not found. Skipping section.") + continue + if section_index in section_data: + print(f"Warning: Duplicate section index {section_index}. 
Overwriting previous entry.") + section_data[section_index] = (image_path, prompt_text) + print(f"Parsed section {section_index}: Image='{image_path}', Prompt='{prompt_text}'") + return section_data + + +def parse_args(): + parser = argparse.ArgumentParser(description="FramePack HunyuanVideo inference script (CLI version with Advanced End Frame & Section Control)") + + # --- Model Paths --- + parser.add_argument('--transformer_path', type=str, default='lllyasviel/FramePackI2V_HY', help="Path to the FramePack Transformer model") + parser.add_argument('--vae_path', type=str, default='hunyuanvideo-community/HunyuanVideo', help="Path to the VAE model directory") + parser.add_argument('--text_encoder_path', type=str, default='hunyuanvideo-community/HunyuanVideo', help="Path to the Llama text encoder directory") + parser.add_argument('--text_encoder_2_path', type=str, default='hunyuanvideo-community/HunyuanVideo', help="Path to the CLIP text encoder directory") + parser.add_argument('--image_encoder_path', type=str, default='lllyasviel/flux_redux_bfl', help="Path to the SigLIP image encoder directory") + parser.add_argument('--hf_home', type=str, default='./hf_download', help="Directory to download/cache Hugging Face models") + + # --- Input --- + parser.add_argument("--input_image", type=str, required=True, help="Path to the input image (start frame)") + parser.add_argument("--end_frame", type=str, default=None, help="Path to the optional end frame image (video end)") + parser.add_argument("--prompt", type=str, required=True, help="Default prompt for generation") + parser.add_argument("--negative_prompt", type=str, default="", help="Negative prompt for generation") + # <<< START: Modified Arguments for End Frame >>> + parser.add_argument("--end_frame_weight", type=float, default=0.3, help="End frame influence weight (0.0-1.0) for blending modes ('half', 'progressive'). Higher blends more end frame *conditioning latent*.") # Default lowered further + parser.add_argument("--end_frame_influence", type=str, default="last", + choices=["last", "half", "progressive", "bookend"], + help="How to use the global end frame: 'last' (uses end frame for initial context only, no latent blending), 'half' (blends start/end conditioning latents for second half of video), 'progressive' (gradually blends conditioning latents from end to start), 'bookend' (uses end frame conditioning latent ONLY for first generated section IF no section keyframe set, no blending otherwise). All modes use start image embedding.") # Help text updated + # <<< END: Modified Arguments for End Frame >>> + # <<< START: New Arguments for Section Control >>> + parser.add_argument("--section", type=str, action='append', + help="Define a keyframe section. Format: 'index:image_path[:prompt]'. Index 0 is the last generated section (video start), 1 is second last, etc. Repeat for multiple sections. Example: --section 0:path/to/start_like.png:'A sunrise' --section 2:path/to/mid.png") + # <<< END: New Arguments for Section Control >>> + + # --- Output Resolution (Choose ONE method) --- + parser.add_argument("--target_resolution", type=int, default=None, help=f"Target resolution for the longer side for automatic aspect ratio calculation (bucketing). Used if --width and --height are not specified. Must be positive and ideally divisible by {DIMENSION_MULTIPLE}.") + parser.add_argument("--width", type=int, default=None, help=f"Explicit target width for the output video. Overrides --target_resolution. 
Must be positive and ideally divisible by {DIMENSION_MULTIPLE}.") + parser.add_argument("--height", type=int, default=None, help=f"Explicit target height for the output video. Overrides --target_resolution. Must be positive and ideally divisible by {DIMENSION_MULTIPLE}.") + + # --- Output --- + parser.add_argument("--save_path", type=str, required=True, help="Directory to save the generated video") + parser.add_argument("--save_intermediate_sections", action='store_true', help="Save the video after each section is generated and decoded.") + parser.add_argument("--save_section_final_frames", action='store_true', help="Save the final decoded frame of each generated section as a PNG image.") + + + # --- Generation Parameters (Matching Gradio Demo Defaults where applicable) --- + parser.add_argument("--seed", type=int, default=None, help="Seed for generation. Random if not set.") + parser.add_argument("--total_second_length", type=float, default=5.0, help="Total desired video length in seconds") + parser.add_argument("--fps", type=int, default=30, help="Frames per second for the output video") + parser.add_argument("--steps", type=int, default=25, help="Number of inference steps (changing not recommended)") + parser.add_argument("--distilled_guidance_scale", "--gs", type=float, default=10.0, help="Distilled CFG Scale (gs)") + parser.add_argument("--cfg", type=float, default=1.0, help="Classifier-Free Guidance Scale (fixed at 1.0 for FramePack usually)") + parser.add_argument("--rs", type=float, default=0.0, help="CFG Rescale (fixed at 0.0 for FramePack usually)") + parser.add_argument("--latent_window_size", type=int, default=9, help="Latent window size (changing not recommended)") + + # --- Performance / Memory --- + parser.add_argument('--high_vram', action='store_true', help="Force high VRAM mode (loads all models to GPU)") + parser.add_argument('--low_vram', action='store_true', help="Force low VRAM mode (uses dynamic swapping)") + parser.add_argument("--gpu_memory_preservation", type=float, default=6.0, help="GPU memory (GB) to preserve when offloading (low VRAM mode)") + parser.add_argument('--use_teacache', action='store_true', default=True, help="Use TeaCache optimization (default: True)") + parser.add_argument('--no_teacache', action='store_false', dest='use_teacache', help="Disable TeaCache optimization") + parser.add_argument("--device", type=str, default=None, help="Device to use (e.g., 'cuda', 'cpu'). Auto-detects if None.") + + args = parser.parse_args() + + # --- Argument Validation --- + if args.seed is None: + args.seed = random.randint(0, 2**32 - 1) + print(f"Generated random seed: {args.seed}") + + if args.width is not None and args.height is not None: + if args.width <= 0 or args.height <= 0: + print(f"Error: Explicit --width ({args.width}) and --height ({args.height}) must be positive.") + sys.exit(1) + if args.target_resolution is not None: + print("Warning: Both --width/--height and --target_resolution specified. Using explicit --width and --height.") + args.target_resolution = None + elif args.target_resolution is not None: + if args.target_resolution <= 0: + print(f"Error: --target_resolution ({args.target_resolution}) must be positive.") + sys.exit(1) + if args.width is not None or args.height is not None: + print("Error: Cannot specify --target_resolution with only one of --width or --height. Provide both or neither.") + sys.exit(1) + else: + print(f"Warning: No resolution specified. 
Defaulting to --target_resolution 640.") + args.target_resolution = 640 + + if args.end_frame_weight < 0.0 or args.end_frame_weight > 1.0: + print(f"Error: --end_frame_weight must be between 0.0 and 1.0 (got {args.end_frame_weight}).") + sys.exit(1) + + if args.width is not None and args.width % DIMENSION_MULTIPLE != 0: + print(f"Warning: Specified --width ({args.width}) is not divisible by {DIMENSION_MULTIPLE}. It will be rounded down.") + if args.height is not None and args.height % DIMENSION_MULTIPLE != 0: + print(f"Warning: Specified --height ({args.height}) is not divisible by {DIMENSION_MULTIPLE}. It will be rounded down.") + if args.target_resolution is not None and args.target_resolution % DIMENSION_MULTIPLE != 0: + print(f"Warning: Specified --target_resolution ({args.target_resolution}) is not divisible by {DIMENSION_MULTIPLE}. The calculated dimensions will be rounded down.") + + if args.end_frame and not os.path.exists(args.end_frame): + print(f"Error: End frame image not found at '{args.end_frame}'.") + sys.exit(1) + + args.section_data = parse_section_args(args.section) + + os.environ['HF_HOME'] = os.path.abspath(os.path.realpath(args.hf_home)) + os.makedirs(os.environ['HF_HOME'], exist_ok=True) + + return args + + +def load_models(args): + """Loads all necessary models.""" + print("Loading models...") + if args.device: + device = torch.device(args.device) + else: + device = torch.device(gpu if torch.cuda.is_available() else cpu) + print(f"Using device: {device}") + + print(" Loading Text Encoder 1 (Llama)...") + text_encoder = LlamaModel.from_pretrained(args.text_encoder_path, subfolder='text_encoder', torch_dtype=torch.float16).cpu() + print(" Loading Text Encoder 2 (CLIP)...") + text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, subfolder='text_encoder_2', torch_dtype=torch.float16).cpu() + print(" Loading Tokenizer 1 (Llama)...") + tokenizer = LlamaTokenizerFast.from_pretrained(args.text_encoder_path, subfolder='tokenizer') + print(" Loading Tokenizer 2 (CLIP)...") + tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path, subfolder='tokenizer_2') + print(" Loading VAE...") + vae = AutoencoderKLHunyuanVideo.from_pretrained(args.vae_path, subfolder='vae', torch_dtype=torch.float16).cpu() + print(" Loading Image Feature Extractor (SigLIP)...") + feature_extractor = SiglipImageProcessor.from_pretrained(args.image_encoder_path, subfolder='feature_extractor') + print(" Loading Image Encoder (SigLIP)...") + image_encoder = SiglipVisionModel.from_pretrained(args.image_encoder_path, subfolder='image_encoder', torch_dtype=torch.float16).cpu() + print(" Loading Transformer (FramePack)...") + transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained(args.transformer_path, torch_dtype=torch.bfloat16).cpu() + + vae.eval() + text_encoder.eval() + text_encoder_2.eval() + image_encoder.eval() + transformer.eval() + + transformer.high_quality_fp32_output_for_inference = True + print('transformer.high_quality_fp32_output_for_inference = True') + + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + text_encoder_2.requires_grad_(False) + image_encoder.requires_grad_(False) + transformer.requires_grad_(False) + + print("Models loaded.") + return { + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "vae": vae, + "feature_extractor": feature_extractor, + "image_encoder": image_encoder, + "transformer": transformer, + "device": device + } + +def 
adjust_to_multiple(value, multiple): + """Rounds down value to the nearest multiple.""" + return (value // multiple) * multiple + +def mix_latents(latent_a, latent_b, weight_b): + """Mix two latents with the specified weight for latent_b.""" + if latent_a is None: return latent_b + if latent_b is None: return latent_a + + target_device = latent_a.device + target_dtype = latent_a.dtype + if latent_b.device != target_device: + latent_b = latent_b.to(target_device) + if latent_b.dtype != target_dtype: + latent_b = latent_b.to(dtype=target_dtype) + + if isinstance(weight_b, torch.Tensor): + weight_b = weight_b.item() + + weight_b = max(0.0, min(1.0, weight_b)) + + if weight_b == 0.0: + return latent_a + elif weight_b == 1.0: + return latent_b + else: + return (1.0 - weight_b) * latent_a + weight_b * latent_b + +def mix_embeddings(embed_a, embed_b, weight_b): + """Mix two embedding tensors (like CLIP image embeddings) with the specified weight for embed_b.""" + if embed_a is None: return embed_b + if embed_b is None: return embed_a + + target_device = embed_a.device + target_dtype = embed_a.dtype + if embed_b.device != target_device: + embed_b = embed_b.to(target_device) + if embed_b.dtype != target_dtype: + embed_b = embed_b.to(dtype=target_dtype) + + if isinstance(weight_b, torch.Tensor): + weight_b = weight_b.item() + + weight_b = max(0.0, min(1.0, weight_b)) + + if weight_b == 0.0: + return embed_a + elif weight_b == 1.0: + return embed_b + else: + return (1.0 - weight_b) * embed_a + weight_b * embed_b + + +def preprocess_image_for_generation(image_path, target_width, target_height, job_id, output_dir, frame_name="input"): + """Loads, processes, and saves a single image.""" + try: + image = Image.open(image_path).convert('RGB') + image_np = np.array(image) + except Exception as e: + print(f"Error loading image '{image_path}': {e}") + raise + + H_orig, W_orig, _ = image_np.shape + print(f" {frame_name.capitalize()} image loaded ({W_orig}x{H_orig}): '{image_path}'") + + image_resized_np = resize_and_center_crop(image_np, target_width=target_width, target_height=target_height) + try: + Image.fromarray(image_resized_np).save(output_dir / f'{job_id}_{frame_name}_resized_{target_width}x{target_height}.png') + except Exception as e: + print(f"Warning: Could not save resized image preview for {frame_name}: {e}") + + image_pt = torch.from_numpy(image_resized_np).float() / 127.5 - 1.0 + image_pt = image_pt.permute(2, 0, 1)[None, :, None] # B=1, C=3, T=1, H, W + print(f" {frame_name.capitalize()} image processed to tensor shape: {image_pt.shape}") + + return image_np, image_resized_np, image_pt + + +@torch.no_grad() +def generate_video(args, models): + """Generates the video using the loaded models and arguments.""" + + # Unpack models + text_encoder = models["text_encoder"] + text_encoder_2 = models["text_encoder_2"] + tokenizer = models["tokenizer"] + tokenizer_2 = models["tokenizer_2"] + vae = models["vae"] + feature_extractor = models["feature_extractor"] + image_encoder = models["image_encoder"] + transformer = models["transformer"] + device = models["device"] + + # --- Determine Memory Mode --- + if args.high_vram and args.low_vram: + print("Warning: Both --high_vram and --low_vram specified. 
Defaulting to auto-detection.") + force_high_vram = force_low_vram = False + else: + force_high_vram = args.high_vram + force_low_vram = args.low_vram + + if force_high_vram: + high_vram = True + elif force_low_vram: + high_vram = False + else: + free_mem_gb = get_cuda_free_memory_gb(device) if device.type == 'cuda' else 0 + high_vram = free_mem_gb > 60 + print(f'Auto-detected Free VRAM {free_mem_gb:.2f} GB -> High-VRAM Mode: {high_vram}') + + # --- Configure Models based on VRAM mode --- + if not high_vram: + print("Configuring for Low VRAM mode...") + vae.enable_slicing() + vae.enable_tiling() + print(" Installing DynamicSwap for Transformer...") + DynamicSwapInstaller.install_model(transformer, device=device) + print(" Installing DynamicSwap for Text Encoder 1...") + DynamicSwapInstaller.install_model(text_encoder, device=device) + print("Unloading models from GPU (Low VRAM setup)...") + unload_complete_models(text_encoder, text_encoder_2, image_encoder, vae, transformer) + else: + print("Configuring for High VRAM mode (moving models to GPU)...") + text_encoder.to(device) + text_encoder_2.to(device) + image_encoder.to(device) + vae.to(device) + transformer.to(device) + print(" Models moved to GPU.") + + # --- Prepare Inputs --- + print("Preparing inputs...") + prompt = args.prompt + n_prompt = args.negative_prompt + seed = args.seed + total_second_length = args.total_second_length + latent_window_size = args.latent_window_size + steps = args.steps + cfg = args.cfg + gs = args.distilled_guidance_scale + rs = args.rs + gpu_memory_preservation = args.gpu_memory_preservation + use_teacache = args.use_teacache + fps = args.fps + end_frame_path = args.end_frame + end_frame_influence = args.end_frame_influence + end_frame_weight = args.end_frame_weight + section_data = args.section_data + save_intermediate = args.save_intermediate_sections + save_section_frames = args.save_section_final_frames + + total_latent_sections = (total_second_length * 30) / (latent_window_size * 4) + total_latent_sections = int(max(round(total_latent_sections), 1)) + print(f"Calculated total latent sections: {total_latent_sections}") + + job_id = generate_timestamp() + f"_seed{seed}" + output_dir = Path(args.save_path) + output_dir.mkdir(parents=True, exist_ok=True) + final_video_path = None + + # --- Section Preprocessing Storage --- + section_latents = {} + section_image_embeddings = {} # Still store, might be useful later + section_prompt_embeddings = {} + + try: + # --- Text Encoding (Global Prompts) --- + print("Encoding global text prompts...") + if not high_vram: + print(" Low VRAM mode: Loading Text Encoders to GPU...") + fake_diffusers_current_device(text_encoder, device) + load_model_as_complete(text_encoder_2, target_device=device) + print(" Text Encoders loaded.") + + global_llama_vec, global_clip_l_pooler = encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2) + + if cfg == 1.0: + print(" CFG scale is 1.0, using zero negative embeddings.") + global_llama_vec_n, global_clip_l_pooler_n = torch.zeros_like(global_llama_vec), torch.zeros_like(global_clip_l_pooler) + else: + print(f" Encoding negative prompt: '{n_prompt}'") + global_llama_vec_n, global_clip_l_pooler_n = encode_prompt_conds(n_prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2) + + global_llama_vec, global_llama_attention_mask = crop_or_pad_yield_mask(global_llama_vec, length=512) + global_llama_vec_n, global_llama_attention_mask_n = crop_or_pad_yield_mask(global_llama_vec_n, length=512) + print(" Global text 
encoded and processed.") + + # --- Section Text Encoding --- + if section_data: + print("Encoding section-specific prompts...") + for section_index, (img_path, prompt_text) in section_data.items(): + if prompt_text: + print(f" Encoding prompt for section {section_index}: '{prompt_text}'") + sec_llama_vec, sec_clip_pooler = encode_prompt_conds(prompt_text, text_encoder, text_encoder_2, tokenizer, tokenizer_2) + sec_llama_vec, _ = crop_or_pad_yield_mask(sec_llama_vec, length=512) + section_prompt_embeddings[section_index] = ( + sec_llama_vec.cpu().to(transformer.dtype), + sec_clip_pooler.cpu().to(transformer.dtype) + ) + print(f" Section {section_index} prompt encoded and stored on CPU.") + else: + print(f" Section {section_index} has no specific prompt, will use global prompt.") + + if not high_vram: + print(" Low VRAM mode: Unloading Text Encoders from GPU...") + unload_complete_models(text_encoder_2) + print(" Text Encoder 2 unloaded.") + + # --- Input Image Processing & Dimension Calculation --- + print("Processing input image and determining dimensions...") + try: + input_image_np_orig, _, _ = preprocess_image_for_generation( + args.input_image, 1, 1, job_id, output_dir, "temp_input_orig" + ) + except Exception as e: + print(f"Error loading input image '{args.input_image}' for dimension check: {e}") + raise + H_orig, W_orig, _ = input_image_np_orig.shape + print(f" Input image original size: {W_orig}x{H_orig}") + + if args.width is not None and args.height is not None: + target_w, target_h = args.width, args.height + print(f" Using explicit target dimensions: {target_w}x{target_h}") + elif args.target_resolution is not None: + print(f" Calculating dimensions based on target resolution for longer side: {args.target_resolution}") + target_h, target_w = find_nearest_bucket(H_orig, W_orig, resolution=args.target_resolution) + print(f" Calculated dimensions (before adjustment): {target_w}x{target_h}") + else: + raise ValueError("Internal Error: Resolution determination failed.") + + final_w = adjust_to_multiple(target_w, DIMENSION_MULTIPLE) + final_h = adjust_to_multiple(target_h, DIMENSION_MULTIPLE) + + if final_w <= 0 or final_h <= 0: + print(f"Error: Calculated dimensions ({target_w}x{target_h}) resulted in non-positive dimensions after adjusting to be divisible by {DIMENSION_MULTIPLE} ({final_w}x{final_h}).") + raise ValueError("Adjusted dimensions are invalid.") + + if final_w != target_w or final_h != target_h: + print(f"Warning: Adjusted dimensions from {target_w}x{target_h} to {final_w}x{final_h} to be divisible by {DIMENSION_MULTIPLE}.") + else: + print(f" Final dimensions confirmed: {final_w}x{final_h}") + + width, height = final_w, final_h + + if width * height > 1024 * 1024: + print(f"Warning: Target resolution {width}x{height} is large. 
Ensure you have sufficient VRAM.") + + _, input_image_resized_np, input_image_pt = preprocess_image_for_generation( + args.input_image, width, height, job_id, output_dir, "input" + ) + + end_frame_resized_np = None + end_frame_pt = None + if end_frame_path: + _, end_frame_resized_np, end_frame_pt = preprocess_image_for_generation( + end_frame_path, width, height, job_id, output_dir, "end" + ) + + section_images_resized_np = {} + section_images_pt = {} + if section_data: + print("Processing section keyframe images...") + for section_index, (img_path, _) in section_data.items(): + _, sec_resized_np, sec_pt = preprocess_image_for_generation( + img_path, width, height, job_id, output_dir, f"section{section_index}" + ) + section_images_resized_np[section_index] = sec_resized_np + section_images_pt[section_index] = sec_pt + + # --- VAE Encoding --- + print("VAE encoding initial frame...") + if not high_vram: + print(" Low VRAM mode: Loading VAE to GPU...") + load_model_as_complete(vae, target_device=device) + print(" VAE loaded.") + + input_image_pt_dev = input_image_pt.to(device=device, dtype=vae.dtype) + start_latent = vae_encode(input_image_pt_dev, vae) # GPU, vae.dtype + print(f" Initial latent shape: {start_latent.shape}") + print(f" Start latent stats - Min: {start_latent.min().item():.4f}, Max: {start_latent.max().item():.4f}, Mean: {start_latent.mean().item():.4f}") + + end_frame_latent = None + if end_frame_pt is not None: + print("VAE encoding end frame...") + end_frame_pt_dev = end_frame_pt.to(device=device, dtype=vae.dtype) + end_frame_latent = vae_encode(end_frame_pt_dev, vae) # GPU, vae.dtype + print(f" End frame latent shape: {end_frame_latent.shape}") + print(f" End frame latent stats - Min: {end_frame_latent.min().item():.4f}, Max: {end_frame_latent.max().item():.4f}, Mean: {end_frame_latent.mean().item():.4f}") + if end_frame_latent.shape != start_latent.shape: + print(f"Warning: End frame latent shape mismatch. Reshaping.") + try: + end_frame_latent = end_frame_latent.reshape(start_latent.shape) + except Exception as reshape_err: + print(f"Error reshaping end frame latent: {reshape_err}. Disabling end frame.") + end_frame_latent = None + + if section_images_pt: + print("VAE encoding section keyframes...") + for section_index, sec_pt in section_images_pt.items(): + sec_pt_dev = sec_pt.to(device=device, dtype=vae.dtype) + sec_latent = vae_encode(sec_pt_dev, vae) # GPU, vae.dtype + print(f" Section {section_index} latent shape: {sec_latent.shape}") + if sec_latent.shape != start_latent.shape: + print(f" Warning: Section {section_index} latent shape mismatch. Reshaping.") + try: + sec_latent = sec_latent.reshape(start_latent.shape) + except Exception as reshape_err: + print(f" Error reshaping section {section_index} latent: {reshape_err}. 
Skipping section latent.") + continue + # Store on CPU as float32 for context/blending later + section_latents[section_index] = sec_latent.cpu().float() + print(f" Section {section_index} latent encoded and stored on CPU.") + + if not high_vram: + print(" Low VRAM mode: Unloading VAE from GPU...") + unload_complete_models(vae) + print(" VAE unloaded.") + + # Move essential latents to CPU as float32 for context/blending + start_latent = start_latent.cpu().float() + if end_frame_latent is not None: + end_frame_latent = end_frame_latent.cpu().float() + + # --- CLIP Vision Encoding --- + print("CLIP Vision encoding image(s)...") + if not high_vram: + print(" Low VRAM mode: Loading Image Encoder to GPU...") + load_model_as_complete(image_encoder, target_device=device) + print(" Image Encoder loaded.") + + # Encode start frame - WILL BE USED CONSISTENTLY for image_embeddings + image_encoder_output = hf_clip_vision_encode(input_image_resized_np, feature_extractor, image_encoder) + start_image_embedding = image_encoder_output.last_hidden_state # GPU, image_encoder.dtype + print(f" Start image embedding shape: {start_image_embedding.shape}") + + # Encode end frame (if provided) - Only needed if extending later + # end_frame_embedding = None # Not needed for this strategy + # if end_frame_resized_np is not None: + # pass # Skip encoding for now + + # Encode section frames (if provided) - Store for potential future use + if section_images_resized_np: + print("CLIP Vision encoding section keyframes (storing on CPU)...") + for section_index, sec_resized_np in section_images_resized_np.items(): + sec_output = hf_clip_vision_encode(sec_resized_np, feature_extractor, image_encoder) + sec_embedding = sec_output.last_hidden_state + section_image_embeddings[section_index] = sec_embedding.cpu().to(transformer.dtype) + print(f" Section {section_index} embedding shape: {sec_embedding.shape}. 
Stored on CPU.") + + if not high_vram: + print(" Low VRAM mode: Unloading Image Encoder from GPU...") + unload_complete_models(image_encoder) + print(" Image Encoder unloaded.") + + # Move start image embedding to CPU (transformer dtype) + target_dtype = transformer.dtype + start_image_embedding = start_image_embedding.cpu().to(target_dtype) + + # --- Prepare Global Embeddings for Transformer (CPU, transformer.dtype) --- + print("Preparing global embeddings for Transformer...") + global_llama_vec = global_llama_vec.cpu().to(target_dtype) + global_llama_vec_n = global_llama_vec_n.cpu().to(target_dtype) + global_clip_l_pooler = global_clip_l_pooler.cpu().to(target_dtype) + global_clip_l_pooler_n = global_clip_l_pooler_n.cpu().to(target_dtype) + print(f" Global Embeddings prepared on CPU with dtype {target_dtype}.") + + # --- Sampling Setup --- + print("Setting up sampling...") + rnd = torch.Generator(cpu).manual_seed(seed) + num_frames = latent_window_size * 4 - 3 + print(f" Latent frames per sampling step (num_frames input): {num_frames}") + + latent_c, latent_h, latent_w = start_latent.shape[1], start_latent.shape[3], start_latent.shape[4] + context_latents = torch.zeros(size=(1, latent_c, 1 + 2 + 16, latent_h, latent_w), dtype=torch.float32).cpu() + + accumulated_generated_latents = None + history_pixels = None + + latent_paddings = list(reversed(range(total_latent_sections))) + if total_latent_sections > 4: + latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0] + print(f" Using adjusted padding sequence for >4 sections: {latent_paddings}") + else: + print(f" Using standard padding sequence: {latent_paddings}") + + # --- [MODIFIED] Restore Initial Context Initialization --- + if end_frame_latent is not None: + print(" Initializing context buffer's first slot with end frame latent.") + context_latents[:, :, 0:1, :, :] = end_frame_latent.cpu().float() # Ensure float32 CPU + else: + print(" No end frame latent available. Initial context remains zeros.") + # --- End Modified Context Initialization --- + + # --- Main Sampling Loop (Generates Backward: End -> Start) --- + start_time = time.time() + num_loops = len(latent_paddings) + + for i_loop, latent_padding in enumerate(latent_paddings): + section_start_time = time.time() + current_section_index_from_end = latent_padding + is_first_generation_step = (i_loop == 0) + is_last_generation_step = (latent_padding == 0) + + print(f"\n--- Starting Generation Step {i_loop+1}/{num_loops} (Section Index from End: {current_section_index_from_end}, First Step: {is_first_generation_step}, Last Step: {is_last_generation_step}) ---") + latent_padding_size = latent_padding * latent_window_size + print(f' Padding size (latent frames): {latent_padding_size}, Window size (latent frames): {latent_window_size}') + + # --- Select Conditioning Inputs for this Section --- + + # 1. 
Conditioning Latent (`clean_latents_pre`) - Calculate Blend + # Determine the base latent (start or section-specific) + base_conditioning_latent = start_latent # Default to start (float32 CPU) + if current_section_index_from_end in section_latents: + base_conditioning_latent = section_latents[current_section_index_from_end] # Use section if available (float32 CPU) + print(f" Using SECTION {current_section_index_from_end} latent as base conditioning latent.") + else: + print(f" Using START frame latent as base conditioning latent.") + + # Apply 'bookend' override to the base latent for the first step only + if end_frame_influence == "bookend" and is_first_generation_step and end_frame_latent is not None: + if current_section_index_from_end not in section_latents: + base_conditioning_latent = end_frame_latent # float32 CPU + print(" Applying 'bookend': Overriding base conditioning latent with END frame latent for first step.") + + # Blend the base conditioning latent with the end frame latent based on mode/weight + current_conditioning_latent = base_conditioning_latent # Initialize with base + current_end_frame_latent_weight = 0.0 + if end_frame_latent is not None: # Only blend if end frame exists + if end_frame_influence == 'progressive': + progress = i_loop / max(1, num_loops - 1) + current_end_frame_latent_weight = args.end_frame_weight * (1.0 - progress) + elif end_frame_influence == 'half': + if i_loop < num_loops / 2: + current_end_frame_latent_weight = args.end_frame_weight + # For 'last' and 'bookend', weight remains 0, no blending needed + + current_end_frame_latent_weight = max(0.0, min(1.0, current_end_frame_latent_weight)) + + if current_end_frame_latent_weight > 1e-4: # Mix only if weight is significant + print(f" Blending Conditioning Latent: Base<-{1.0-current_end_frame_latent_weight:.3f} | End->{current_end_frame_latent_weight:.3f} (Mode: {end_frame_influence})") + # Ensure both inputs to mix_latents are float32 CPU + current_conditioning_latent = mix_latents(base_conditioning_latent.cpu().float(), + end_frame_latent.cpu().float(), + current_end_frame_latent_weight) + #else: + # print(f" Using BASE conditioning latent (Mode: {end_frame_influence}, Blend Weight near zero).") # Can be verbose + #else: + # print(f" Using BASE conditioning latent (No end frame specified for blending).") # Can be verbose + + + # 2. Image Embedding - Use Fixed Start Embedding + current_image_embedding = start_image_embedding # transformer.dtype CPU + print(f" Using fixed START frame image embedding.") + + + # 3. 
Text Embedding (Select section or global) + if current_section_index_from_end in section_prompt_embeddings: + current_llama_vec, current_clip_pooler = section_prompt_embeddings[current_section_index_from_end] + print(f" Using SECTION {current_section_index_from_end} prompt embeddings.") + else: + current_llama_vec = global_llama_vec + current_clip_pooler = global_clip_l_pooler + print(f" Using GLOBAL prompt embeddings.") + + current_llama_vec_n = global_llama_vec_n + current_clip_pooler_n = global_clip_l_pooler_n + current_llama_attention_mask = global_llama_attention_mask + current_llama_attention_mask_n = global_llama_attention_mask_n + + # --- Prepare Sampler Inputs --- + indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, 1, 2, 16])).unsqueeze(0) + clean_latent_indices_pre, blank_indices, latent_indices, clean_latent_indices_post, clean_latent_2x_indices, clean_latent_4x_indices = \ + indices.split([1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1) + clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1) + + # Prepare conditioning latents (float32 CPU) + clean_latents_pre = current_conditioning_latent # Use the potentially blended one + clean_latents_post, clean_latents_2x, clean_latents_4x = \ + context_latents[:, :, :1 + 2 + 16, :, :].split([1, 2, 16], dim=2) + clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2) + print(f" Final Conditioning shapes (CPU): clean={clean_latents.shape}, 2x={clean_latents_2x.shape}, 4x={clean_latents_4x.shape}") + print(f" Clean Latents Pre stats - Min: {clean_latents_pre.min().item():.4f}, Max: {clean_latents_pre.max().item():.4f}, Mean: {clean_latents_pre.mean().item():.4f}") + + + # Load Transformer (Low VRAM) + if not high_vram: + print(" Moving Transformer to GPU...") + unload_complete_models() + move_model_to_device_with_memory_preservation(transformer, target_device=device, preserved_memory_gb=gpu_memory_preservation) + fake_diffusers_current_device(text_encoder, device) + + # Configure TeaCache + if use_teacache: + transformer.initialize_teacache(enable_teacache=True, num_steps=steps) + print(" TeaCache enabled.") + else: + transformer.initialize_teacache(enable_teacache=False) + print(" TeaCache disabled.") + + # --- Run Sampling --- + print(f" Starting sampling ({steps} steps) for {num_frames} latent frames...") + sampling_step_start_time = time.time() + + pbar = tqdm(total=steps, desc=f" Section {current_section_index_from_end} Sampling", leave=False) + def callback(d): + pbar.update(1) + return + + current_sampler_device = transformer.device + current_text_encoder_device = text_encoder.device if not high_vram else device + + # Move tensors to device just before sampling + _prompt_embeds = current_llama_vec.to(current_text_encoder_device) + _prompt_embeds_mask = current_llama_attention_mask.to(current_text_encoder_device) + _prompt_poolers = current_clip_pooler.to(current_sampler_device) + _negative_prompt_embeds = current_llama_vec_n.to(current_text_encoder_device) + _negative_prompt_embeds_mask = current_llama_attention_mask_n.to(current_text_encoder_device) + _negative_prompt_poolers = current_clip_pooler_n.to(current_sampler_device) + _image_embeddings = current_image_embedding.to(current_sampler_device) # Fixed start embedding + _latent_indices = latent_indices.to(current_sampler_device) + # Pass conditioning latents (now potentially blended) to sampler + _clean_latents = clean_latents.to(current_sampler_device, dtype=transformer.dtype) + 
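+ # The remaining index/conditioning tensors follow the split defined earlier in this
+ # step: 1 "pre" conditioning frame, `latent_padding_size` blank frames,
+ # `latent_window_size` frames to denoise, then 1 "post" frame plus 2 and 16 latent
+ # frames of progressively downsampled history context.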
_clean_latent_indices = clean_latent_indices.to(current_sampler_device) + _clean_latents_2x = clean_latents_2x.to(current_sampler_device, dtype=transformer.dtype) + _clean_latent_2x_indices = clean_latent_2x_indices.to(current_sampler_device) + _clean_latents_4x = clean_latents_4x.to(current_sampler_device, dtype=transformer.dtype) + _clean_latent_4x_indices = clean_latent_4x_indices.to(current_sampler_device) + + generated_latents_gpu = sample_hunyuan( + transformer=transformer, + sampler='unipc', + width=width, + height=height, + frames=num_frames, + real_guidance_scale=cfg, + distilled_guidance_scale=gs, + guidance_rescale=rs, + num_inference_steps=steps, + generator=rnd, + prompt_embeds=_prompt_embeds, + prompt_embeds_mask=_prompt_embeds_mask, + prompt_poolers=_prompt_poolers, + negative_prompt_embeds=_negative_prompt_embeds, + negative_prompt_embeds_mask=_negative_prompt_embeds_mask, + negative_prompt_poolers=_negative_prompt_poolers, + device=current_sampler_device, + dtype=transformer.dtype, + image_embeddings=_image_embeddings, # Using fixed start embedding + latent_indices=_latent_indices, + clean_latents=_clean_latents, # Using potentially blended latents + clean_latent_indices=_clean_latent_indices, + clean_latents_2x=_clean_latents_2x, + clean_latent_2x_indices=_clean_latent_2x_indices, + clean_latents_4x=_clean_latents_4x, + clean_latent_4x_indices=_clean_latent_4x_indices, + callback=callback, + ) + pbar.close() + sampling_step_end_time = time.time() + print(f" Sampling finished in {sampling_step_end_time - sampling_step_start_time:.2f} seconds.") + print(f" Raw generated latent shape for this step: {generated_latents_gpu.shape}") + print(f" Generated latents stats (GPU) - Min: {generated_latents_gpu.min().item():.4f}, Max: {generated_latents_gpu.max().item():.4f}, Mean: {generated_latents_gpu.mean().item():.4f}") + + # Move generated latents to CPU as float32 + generated_latents_cpu = generated_latents_gpu.cpu().float() + del generated_latents_gpu, _prompt_embeds, _prompt_embeds_mask, _prompt_poolers, _negative_prompt_embeds, _negative_prompt_embeds_mask, _negative_prompt_poolers + del _image_embeddings, _latent_indices, _clean_latents, _clean_latent_indices, _clean_latents_2x, _clean_latent_2x_indices, _clean_latents_4x, _clean_latent_4x_indices + if device.type == 'cuda': torch.cuda.empty_cache() + + # Offload Transformer and TE1 (Low VRAM) + if not high_vram: + print(" Low VRAM mode: Offloading Transformer and Text Encoder from GPU...") + offload_model_from_device_for_memory_preservation(transformer, target_device=device, preserved_memory_gb=gpu_memory_preservation) + offload_model_from_device_for_memory_preservation(text_encoder, target_device=device, preserved_memory_gb=gpu_memory_preservation) + print(" Transformer and Text Encoder offloaded.") + + # --- History/Context Update --- + if is_last_generation_step: + print(" Last generation step: Prepending start frame latent to generated latents.") + generated_latents_cpu = torch.cat([start_latent.cpu().float(), generated_latents_cpu], dim=2) + print(f" Shape after prepending start latent: {generated_latents_cpu.shape}") + + context_latents = torch.cat([generated_latents_cpu, context_latents], dim=2) + print(f" Context buffer updated. 
New shape: {context_latents.shape}") + + # Accumulate the generated latents for the final video output + if accumulated_generated_latents is None: + accumulated_generated_latents = generated_latents_cpu + else: + accumulated_generated_latents = torch.cat([generated_latents_cpu, accumulated_generated_latents], dim=2) + + current_total_latent_frames = accumulated_generated_latents.shape[2] + print(f" Accumulated generated latents updated. Total latent frames: {current_total_latent_frames}") + print(f" Accumulated latents stats - Min: {accumulated_generated_latents.min().item():.4f}, Max: {accumulated_generated_latents.max().item():.4f}, Mean: {accumulated_generated_latents.mean().item():.4f}") + + # --- VAE Decoding & Merging --- + print(" Decoding generated latents and merging video...") + decode_start_time = time.time() + + if not high_vram: + print(" Moving VAE to GPU...") + offload_model_from_device_for_memory_preservation(transformer, target_device=device, preserved_memory_gb=gpu_memory_preservation) + unload_complete_models(text_encoder, text_encoder_2, image_encoder) + load_model_as_complete(vae, target_device=device) + print(" VAE loaded.") + + print(f" Decoding current section's latents (shape: {generated_latents_cpu.shape}) for append.") + latents_to_decode_for_append = generated_latents_cpu.to(device=device, dtype=vae.dtype) + current_pixels = vae_decode(latents_to_decode_for_append, vae).cpu().float() # Decode and move to CPU float32 + print(f" Decoded pixels for append shape: {current_pixels.shape}") + del latents_to_decode_for_append + if device.type == 'cuda': torch.cuda.empty_cache() + + if history_pixels is None: + history_pixels = current_pixels + print(f" Initialized history_pixels shape: {history_pixels.shape}") + else: + append_overlap = 3 + print(f" Appending section with pixel overlap: {append_overlap}") + history_pixels = soft_append_bcthw(current_pixels, history_pixels, overlap=append_overlap) + print(f" Appended. 
New total pixel shape: {history_pixels.shape}") + + if not high_vram: + print(" Low VRAM mode: Unloading VAE from GPU...") + unload_complete_models(vae) + print(" VAE unloaded.") + + decode_end_time = time.time() + print(f" Decoding and merging finished in {decode_end_time - decode_start_time:.2f} seconds.") + + # --- Save Intermediate/Section Output --- + current_num_pixel_frames = history_pixels.shape[2] + + if save_section_frames: + try: + first_frame_index = 0 # Index 0 of the newly decoded chunk is the first frame generated in this step + frame_to_save = current_pixels[0, :, first_frame_index, :, :] + frame_to_save = einops.rearrange(frame_to_save, 'c h w -> h w c') + frame_to_save_np = frame_to_save.cpu().numpy() + frame_to_save_np = np.clip((frame_to_save_np * 127.5 + 127.5), 0, 255).astype(np.uint8) + section_frame_filename = output_dir / f'{job_id}_section_start_frame_idx{current_section_index_from_end}.png' # Renamed for clarity + Image.fromarray(frame_to_save_np).save(section_frame_filename) + print(f" Saved first generated pixel frame of section {current_section_index_from_end} (from decoded chunk) to: {section_frame_filename}") + except Exception as e: + print(f" [WARN] Error saving section {current_section_index_from_end} start frame image: {e}") + + if save_intermediate or is_last_generation_step: + output_filename = output_dir / f'{job_id}_step{i_loop+1}_idx{current_section_index_from_end}_frames{current_num_pixel_frames}_{width}x{height}.mp4' + print(f" Saving {'intermediate' if not is_last_generation_step else 'final'} video ({current_num_pixel_frames} frames) to: {output_filename}") + try: + save_bcthw_as_mp4(history_pixels.float(), str(output_filename), fps=int(fps)) + print(f" Saved video using save_bcthw_as_mp4") + if not is_last_generation_step: + print(f"INTERMEDIATE_VIDEO_PATH:{output_filename}") + final_video_path = str(output_filename) + except Exception as e: + print(f" Error saving video using save_bcthw_as_mp4: {e}") + traceback.print_exc() + # Fallback save attempt + try: + first_frame_img = history_pixels.float()[0, :, 0].permute(1, 2, 0).cpu().numpy() + first_frame_img = (first_frame_img * 127.5 + 127.5).clip(0, 255).astype(np.uint8) + frame_path = str(output_filename).replace('.mp4', '_first_frame_ERROR.png') + Image.fromarray(first_frame_img).save(frame_path) + print(f" Saved first frame as image to {frame_path} due to video saving error.") + except Exception as frame_err: + print(f" Could not save first frame either: {frame_err}") + + section_end_time = time.time() + print(f"--- Generation Step {i_loop+1} finished in {section_end_time - section_start_time:.2f} seconds ---") + + if is_last_generation_step: + print("\nFinal generation step completed.") + break + + # --- Final Video Saved During Last Step --- + if final_video_path and os.path.exists(final_video_path): + print(f"\nSuccessfully generated: {final_video_path}") + print(f"ACTUAL_FINAL_PATH:{final_video_path}") + return final_video_path + else: + print("\nError: Final video path not found or not saved correctly.") + return None + + except Exception as e: + print("\n--- ERROR DURING GENERATION ---") + traceback.print_exc() + print("-----------------------------") + if 'history_pixels' in locals() and history_pixels is not None and history_pixels.shape[2] > 0: + partial_output_name = output_dir / f"{job_id}_partial_ERROR_{history_pixels.shape[2]}_frames_{width}x{height}.mp4" + print(f"Attempting to save partial video to: {partial_output_name}") + try: + save_bcthw_as_mp4(history_pixels.float(), 
str(partial_output_name), fps=fps) + print(f"ACTUAL_FINAL_PATH:{partial_output_name}") + return str(partial_output_name) + except Exception as save_err: + print(f"Error saving partial video during error handling: {save_err}") + traceback.print_exc() + + print("Status: Error occurred, no video saved.") + return None + + finally: + print("Performing final model cleanup...") + try: + unload_complete_models(text_encoder, text_encoder_2, image_encoder, vae, transformer) + except Exception as e: + print(f"Error during final model unload: {e}") + pass + if device.type == 'cuda': + torch.cuda.empty_cache() + print("CUDA cache cleared.") + +def main(): + args = parse_args() + models = load_models(args) + final_path = generate_video(args, models) + if final_path: + print(f"\nVideo generation finished. Final path: {final_path}") + sys.exit(0) + else: + print("\nVideo generation failed.") + sys.exit(1) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/framepack_lora_inf_utils.py b/framepack_lora_inf_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ae5d96155b0076308fb05bf5b411118fb6a324b0 --- /dev/null +++ b/framepack_lora_inf_utils.py @@ -0,0 +1,336 @@ +import os +import re +from typing import Optional +import torch +from safetensors.torch import load_file +from tqdm import tqdm + +import logging + +from utils.safetensors_utils import MemoryEfficientSafeOpen + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + +from modules.fp8_optimization_utils import optimize_state_dict_with_fp8_on_the_fly + + +def merge_lora_to_state_dict( + model_file: str, + lora_files: Optional[list[str]], + multipliers: Optional[list[float]], + fp8_optimization: bool, + device: torch.device, + move_to_device: bool = False, +) -> dict[str, torch.Tensor]: + """ + Merge LoRA weights into the state dict of a model. + """ + + # if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix + basename = os.path.basename(model_file) + match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename) + if match: + prefix = basename[: match.start(2)] + count = int(match.group(3)) + model_files = [os.path.normpath(model_file)] + for i in range(count): + file_name = f"{prefix}{i+1:05d}-of-{count:05d}.safetensors" + file_path = os.path.join(os.path.dirname(model_file), file_name) + file_path = os.path.normpath(file_path) + if os.path.exists(file_path) and file_path not in model_files: + model_files.append(file_path) + logger.info(f"Loading split weights: {model_files}") + else: + model_files = [os.path.normpath(model_file)] + + list_of_lora_sd = [] + if lora_files is not None: + for lora_file in lora_files: + # Load LoRA safetensors file + lora_sd = load_file(lora_file) + + # Check the format of the LoRA file + keys = list(lora_sd.keys()) + if keys[0].startswith("lora_unet_"): + logging.info(f"Musubi Tuner LoRA detected") + + else: + transformer_prefixes = ["diffusion_model", "transformer"] # to ignore Text Encoder modules + lora_suffix = None + prefix = None + for key in keys: + if lora_suffix is None and "lora_A" in key: + lora_suffix = "lora_A" + if prefix is None: + pfx = key.split(".")[0] + if pfx in transformer_prefixes: + prefix = pfx + if lora_suffix is not None and prefix is not None: + break + + if lora_suffix == "lora_A" and prefix is not None: + logging.info(f"Diffusion-pipe (?) 
LoRA detected") + lora_sd = convert_from_diffusion_pipe_or_something(lora_sd, "lora_unet_") + + else: + logging.info(f"LoRA file format not recognized: {os.path.basename(lora_file)}") + lora_sd = None + + if lora_sd is not None: + # Check LoRA is for FramePack or for HunyuanVideo + is_hunyuan = False + for key in lora_sd.keys(): + if "double_blocks" in key or "single_blocks" in key: + is_hunyuan = True + break + if is_hunyuan: + logging.info("HunyuanVideo LoRA detected, converting to FramePack format") + lora_sd = convert_hunyuan_to_framepack(lora_sd) + + if lora_sd is not None: + list_of_lora_sd.append(lora_sd) + + if len(list_of_lora_sd) == 0: + # no LoRA files found, just load the model + return load_safetensors_with_fp8_optimization(model_files, fp8_optimization, device, move_to_device, weight_hook=None) + + return load_safetensors_with_lora_and_fp8(model_files, list_of_lora_sd, multipliers, fp8_optimization, device, move_to_device) + + +def convert_from_diffusion_pipe_or_something(lora_sd: dict[str, torch.Tensor], prefix: str) -> dict[str, torch.Tensor]: + """ + Convert LoRA weights to the format used by the diffusion pipeline to Musubi Tuner. + Copy from Musubi Tuner repo. + """ + # convert from diffusers(?) to default LoRA + # Diffusers format: {"diffusion_model.module.name.lora_A.weight": weight, "diffusion_model.module.name.lora_B.weight": weight, ...} + # default LoRA format: {"prefix_module_name.lora_down.weight": weight, "prefix_module_name.lora_up.weight": weight, ...} + + # note: Diffusers has no alpha, so alpha is set to rank + new_weights_sd = {} + lora_dims = {} + for key, weight in lora_sd.items(): + diffusers_prefix, key_body = key.split(".", 1) + if diffusers_prefix != "diffusion_model" and diffusers_prefix != "transformer": + print(f"unexpected key: {key} in diffusers format") + continue + + new_key = f"{prefix}{key_body}".replace(".", "_").replace("_lora_A_", ".lora_down.").replace("_lora_B_", ".lora_up.") + new_weights_sd[new_key] = weight + + lora_name = new_key.split(".")[0] # before first dot + if lora_name not in lora_dims and "lora_down" in new_key: + lora_dims[lora_name] = weight.shape[0] + + # add alpha with rank + for lora_name, dim in lora_dims.items(): + new_weights_sd[f"{lora_name}.alpha"] = torch.tensor(dim) + + return new_weights_sd + + +def load_safetensors_with_lora_and_fp8( + model_files: list[str], + list_of_lora_sd: list[dict[str, torch.Tensor]], + multipliers: Optional[list[float]], + fp8_optimization: bool, + device: torch.device, + move_to_device: bool = False, +) -> dict[str, torch.Tensor]: + """ + Merge LoRA weights into the state dict of a model with fp8 optimization if needed. + """ + if multipliers is None: + multipliers = [1.0] * len(list_of_lora_sd) + if len(multipliers) > len(list_of_lora_sd): + multipliers = multipliers[: len(list_of_lora_sd)] + if len(multipliers) < len(list_of_lora_sd): + multipliers += [1.0] * (len(list_of_lora_sd) - len(multipliers)) + multipliers = [float(m) for m in multipliers] + + list_of_lora_weight_keys = [] + for lora_sd in list_of_lora_sd: + lora_weight_keys = set(lora_sd.keys()) + list_of_lora_weight_keys.append(lora_weight_keys) + + # Merge LoRA weights into the state dict + print(f"Merging LoRA weights into state dict on the fly. 
multipliers: {multipliers}") + + # make hook for LoRA merging + def weight_hook(model_weight_key, model_weight): + nonlocal list_of_lora_weight_keys, list_of_lora_sd, multipliers + + if not model_weight_key.endswith(".weight"): + return model_weight + + original_device = model_weight.device + if original_device != device: + model_weight = model_weight.to(device) # to make calculation faster + + for lora_weight_keys, lora_sd, multiplier in zip(list_of_lora_weight_keys, list_of_lora_sd, multipliers): + # check if this weight has LoRA weights + lora_name = model_weight_key.rsplit(".", 1)[0] # remove trailing ".weight" + lora_name = "lora_unet_" + lora_name.replace(".", "_") + down_key = lora_name + ".lora_down.weight" + up_key = lora_name + ".lora_up.weight" + alpha_key = lora_name + ".alpha" + if down_key not in lora_weight_keys or up_key not in lora_weight_keys: + return model_weight + + # get LoRA weights + down_weight = lora_sd[down_key] + up_weight = lora_sd[up_key] + + dim = down_weight.size()[0] + alpha = lora_sd.get(alpha_key, dim) + scale = alpha / dim + + down_weight = down_weight.to(device) + up_weight = up_weight.to(device) + + # W <- W + U * D + if len(model_weight.size()) == 2: + # linear + if len(up_weight.size()) == 4: # use linear projection mismatch + up_weight = up_weight.squeeze(3).squeeze(2) + down_weight = down_weight.squeeze(3).squeeze(2) + model_weight = model_weight + multiplier * (up_weight @ down_weight) * scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + model_weight = ( + model_weight + + multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) + model_weight = model_weight + multiplier * conved * scale + + # remove LoRA keys from set + lora_weight_keys.remove(down_key) + lora_weight_keys.remove(up_key) + if alpha_key in lora_weight_keys: + lora_weight_keys.remove(alpha_key) + + model_weight = model_weight.to(original_device) # move back to original device + return model_weight + + state_dict = load_safetensors_with_fp8_optimization( + model_files, fp8_optimization, device, move_to_device, weight_hook=weight_hook + ) + + for lora_weight_keys in list_of_lora_weight_keys: + if len(lora_weight_keys) > 0: + # if there are still LoRA keys left, it means they are not used in the model + # this is a warning, not an error + logger.warning(f"Warning: {len(lora_weight_keys)} LoRA keys not used in the model: {lora_weight_keys}") + + return state_dict + + +def convert_hunyuan_to_framepack(lora_sd: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """ + Convert HunyuanVideo LoRA weights to FramePack format. 
+ """ + new_lora_sd = {} + for key, weight in lora_sd.items(): + if "double_blocks" in key: + key = key.replace("double_blocks", "transformer_blocks") + key = key.replace("img_mod_linear", "norm1_linear") + key = key.replace("img_attn_qkv", "attn_to_QKV") # split later + key = key.replace("img_attn_proj", "attn_to_out_0") + key = key.replace("img_mlp_fc1", "ff_net_0_proj") + key = key.replace("img_mlp_fc2", "ff_net_2") + key = key.replace("txt_mod_linear", "norm1_context_linear") + key = key.replace("txt_attn_qkv", "attn_add_QKV_proj") # split later + key = key.replace("txt_attn_proj", "attn_to_add_out") + key = key.replace("txt_mlp_fc1", "ff_context_net_0_proj") + key = key.replace("txt_mlp_fc2", "ff_context_net_2") + elif "single_blocks" in key: + key = key.replace("single_blocks", "single_transformer_blocks") + key = key.replace("linear1", "attn_to_QKVM") # split later + key = key.replace("linear2", "proj_out") + key = key.replace("modulation_linear", "norm_linear") + else: + print(f"Unsupported module name: {key}, only double_blocks and single_blocks are supported") + continue + + if "QKVM" in key: + # split QKVM into Q, K, V, M + key_q = key.replace("QKVM", "q") + key_k = key.replace("QKVM", "k") + key_v = key.replace("QKVM", "v") + key_m = key.replace("attn_to_QKVM", "proj_mlp") + if "_down" in key or "alpha" in key: + # copy QKVM weight or alpha to Q, K, V, M + assert "alpha" in key or weight.size(1) == 3072, f"QKVM weight size mismatch: {key}. {weight.size()}" + new_lora_sd[key_q] = weight + new_lora_sd[key_k] = weight + new_lora_sd[key_v] = weight + new_lora_sd[key_m] = weight + elif "_up" in key: + # split QKVM weight into Q, K, V, M + assert weight.size(0) == 21504, f"QKVM weight size mismatch: {key}. {weight.size()}" + new_lora_sd[key_q] = weight[:3072] + new_lora_sd[key_k] = weight[3072 : 3072 * 2] + new_lora_sd[key_v] = weight[3072 * 2 : 3072 * 3] + new_lora_sd[key_m] = weight[3072 * 3 :] # 21504 - 3072 * 3 = 12288 + else: + print(f"Unsupported module name: {key}") + continue + elif "QKV" in key: + # split QKV into Q, K, V + key_q = key.replace("QKV", "q") + key_k = key.replace("QKV", "k") + key_v = key.replace("QKV", "v") + if "_down" in key or "alpha" in key: + # copy QKV weight or alpha to Q, K, V + assert "alpha" in key or weight.size(1) == 3072, f"QKV weight size mismatch: {key}. {weight.size()}" + new_lora_sd[key_q] = weight + new_lora_sd[key_k] = weight + new_lora_sd[key_v] = weight + elif "_up" in key: + # split QKV weight into Q, K, V + assert weight.size(0) == 3072 * 3, f"QKV weight size mismatch: {key}. {weight.size()}" + new_lora_sd[key_q] = weight[:3072] + new_lora_sd[key_k] = weight[3072 : 3072 * 2] + new_lora_sd[key_v] = weight[3072 * 2 :] + else: + print(f"Unsupported module name: {key}") + continue + else: + # no split needed + new_lora_sd[key] = weight + + return new_lora_sd + + +def load_safetensors_with_fp8_optimization( + model_files: list[str], fp8_optimization: bool, device: torch.device, move_to_device: bool, weight_hook: callable = None +) -> dict[str, torch.Tensor]: + """ + Load state dict from safetensors files and merge LoRA weights into the state dict with fp8 optimization if needed. 
+ """ + if fp8_optimization: + TARGET_KEYS = ["transformer_blocks", "single_transformer_blocks"] + EXCLUDE_KEYS = ["norm"] # Exclude norm layers (e.g., LayerNorm, RMSNorm) from FP8 + state_dict = optimize_state_dict_with_fp8_on_the_fly( + model_files, device, TARGET_KEYS, EXCLUDE_KEYS, move_to_device=move_to_device, weight_hook=weight_hook + ) + else: + state_dict = {} + for model_file in model_files: + with MemoryEfficientSafeOpen(model_file) as f: + for key in tqdm(f.keys(), desc=f"Loading {model_file}", leave=False): + value = f.get_tensor(key) + if weight_hook is not None: + value = weight_hook(key, value) + if move_to_device: + value = value.to(device) + state_dict[key] = value + + return state_dict diff --git a/funconvert_lora.py b/funconvert_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..d480d62f04ed11aa311627afcfbccc805e702bbc --- /dev/null +++ b/funconvert_lora.py @@ -0,0 +1,186 @@ +# convert_lora_i2v_to_fc.py +import torch +import safetensors.torch +import safetensors # Need this for safe_open +import argparse +import os +import re # Regular expressions might be useful for more complex key parsing if needed + +# !!! IMPORTANT: Updated based on the output of analyze_wan_models.py !!! +# The base layer name identified with shape mismatch. +# Check your LoRA file's keys if they use a different prefix (e.g., 'transformer.') +# Assuming the base name identified in LoRA keys matches this. +BASE_LAYERS_TO_SKIP_LORA = { + "patch_embedding", # The layer name from the analysis output + # Add other layers here ONLY if the analysis revealed more mismatches +} +# !!! END IMPORTANT SECTION !!! + +def get_base_layer_name(lora_key: str, prefixes = ["lora_transformer_", "lora_unet_"]): + """ + Attempts to extract the base model layer name from a LoRA key. + Handles common prefixes and suffixes. Adjust prefixes if needed. + + Example: "lora_transformer_patch_embedding_down.weight" -> "patch_embedding" + "lora_transformer_blocks_0_attn_qkv.alpha" -> "blocks.0.attn.qkv" + + Args: + lora_key (str): The key from the LoRA state dictionary. + prefixes (list[str]): A list of potential prefixes used in LoRA keys. + + Returns: + str: The inferred base model layer name. + """ + cleaned_key = lora_key + + # Remove known prefixes + for prefix in prefixes: + if cleaned_key.startswith(prefix): + cleaned_key = cleaned_key[len(prefix):] + break # Assume only one prefix matches + + # Remove known suffixes + # Order matters slightly if one suffix is part of another; list longer ones first if needed + known_suffixes = [ + ".lora_up.weight", + ".lora_down.weight", + "_lora_up.weight", # Include underscore variants just in case + "_lora_down.weight", + ".alpha" + ] + for suffix in known_suffixes: + if cleaned_key.endswith(suffix): + cleaned_key = cleaned_key[:-len(suffix)] + break + + # Replace underscores used by some training scripts with periods for consistency + # if the original model uses periods (like typical PyTorch modules). + # Adjust this logic if the base model itself uses underscores extensively. + cleaned_key = cleaned_key.replace("_", ".") + + # Specific fix for the target layer if prefix/suffix removal was incomplete or ambiguous + # This is somewhat heuristic and might need adjustment based on exact LoRA key naming. 
+ if cleaned_key.startswith("patch.embedding"): # Handle case where prefix removal was incomplete + # Map potential variants back to the canonical name found in analysis + cleaned_key = "patch_embedding" + elif cleaned_key == "patch.embedding.weight": # If suffix removal left .weight attached somehow + cleaned_key = "patch_embedding" + # Add elif clauses here if other specific key mappings are needed + + + return cleaned_key + + +def convert_lora(source_lora_path: str, target_lora_path: str): + """ + Converts an i2v_14B LoRA to be compatible with i2v_14B_FC by + removing LoRA weights associated with layers that have incompatible shapes. + + Args: + source_lora_path (str): Path to the input LoRA file (.safetensors). + target_lora_path (str): Path to save the converted LoRA file (.safetensors). + """ + print(f"Loading source LoRA from: {source_lora_path}") + if not os.path.exists(source_lora_path): + print(f"Error: Source file not found: {source_lora_path}") + return + + try: + # Load tensors and metadata using safe_open for better handling + source_lora_state_dict = {} + metadata = {} + with safetensors.safe_open(source_lora_path, framework="pt", device="cpu") as f: + metadata = f.metadata() # Get metadata if it exists + if metadata is None: # Ensure metadata is a dict even if empty + metadata = {} + for key in f.keys(): + source_lora_state_dict[key] = f.get_tensor(key) # Load tensors + + print(f"Successfully loaded {len(source_lora_state_dict)} tensors.") + if metadata: + print(f"Found metadata: {metadata}") + else: + print("No metadata found.") + + except Exception as e: + print(f"Error loading LoRA file: {e}") + import traceback + traceback.print_exc() + return + + target_lora_state_dict = {} + skipped_keys = [] + kept_keys = [] + base_name_map = {} # Store mapping for reporting + + print(f"\nConverting LoRA weights...") + print(f"Will skip LoRA weights targeting these base layers: {BASE_LAYERS_TO_SKIP_LORA}") + + # Iterate through the loaded tensors + for key, tensor in source_lora_state_dict.items(): + # Use the helper function to extract the base layer name + base_layer_name = get_base_layer_name(key) + base_name_map[key] = base_layer_name # Store for reporting purposes + + # Check if the identified base layer name should be skipped + if base_layer_name in BASE_LAYERS_TO_SKIP_LORA: + skipped_keys.append(key) + else: + # Keep the tensor if its base layer is not in the skip list + target_lora_state_dict[key] = tensor + kept_keys.append(key) + + # --- Reporting --- + print(f"\nConversion Summary:") + print(f" - Total Tensors in Source: {len(source_lora_state_dict)}") + print(f" - Kept {len(kept_keys)} LoRA weight tensors.") + print(f" - Skipped {len(skipped_keys)} LoRA weight tensors (due to incompatible base layer shape):") + + if skipped_keys: + max_print = 15 # Show a few more skipped keys if desired + skipped_sorted = sorted(skipped_keys) # Sort for consistent output order + for i, key in enumerate(skipped_sorted): + base_name = base_name_map.get(key, "N/A") # Get the identified base name + print(f" - {key} (Base Layer Identified: {base_name})") + if i >= max_print -1 and len(skipped_keys) > max_print: + print(f" ... 
and {len(skipped_keys) - max_print} more.") + break + else: + print(" None") + + # --- Saving --- + print(f"\nSaving converted LoRA ({len(target_lora_state_dict)} tensors) to: {target_lora_path}") + try: + # Save the filtered state dictionary with the original metadata + safetensors.torch.save_file(target_lora_state_dict, target_lora_path, metadata=metadata) + print("Conversion successful!") + except Exception as e: + print(f"Error saving converted LoRA file: {e}") + + +if __name__ == "__main__": + # Setup argument parser + parser = argparse.ArgumentParser( + description="Convert Wan i2v_14B LoRA to i2v_14B_FC LoRA by removing incompatible patch_embedding weights.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument("source_lora", type=str, help="Path to the source i2v_14B LoRA file (.safetensors).") + parser.add_argument("target_lora", type=str, help="Path to save the converted i2v_14B_FC LoRA file (.safetensors).") + + # Parse arguments + args = parser.parse_args() + + # --- Input Validation --- + if not os.path.exists(args.source_lora): + print(f"Error: Source LoRA file not found at '{args.source_lora}'") + elif not args.source_lora.lower().endswith(".safetensors"): + print(f"Warning: Source file '{args.source_lora}' does not have a .safetensors extension.") + elif args.source_lora == args.target_lora: + print(f"Error: Source and target paths cannot be the same ('{args.source_lora}'). Choose a different target path.") + elif os.path.exists(args.target_lora): + print(f"Warning: Target file '{args.target_lora}' already exists and will be overwritten.") + # Optionally add a --force flag or prompt user here + convert_lora(args.source_lora, args.target_lora) + else: + # Run the conversion if basic checks pass + convert_lora(args.source_lora, args.target_lora) \ No newline at end of file diff --git a/h1111.py b/h1111.py new file mode 100644 index 0000000000000000000000000000000000000000..3c9559ceddd7b008cebb7fa86373bde35ee5ca4f --- /dev/null +++ b/h1111.py @@ -0,0 +1,8858 @@ +import gradio as gr +from gradio import update as gr_update +import subprocess +import threading +import time +import re +import os +import random +import tiktoken +import sys +import ffmpeg +from typing import List, Tuple, Optional, Generator, Dict, Any +import json +from gradio import themes +from gradio.themes.utils import colors +import subprocess +from PIL import Image +import math +import cv2 +import glob +import shutil +from pathlib import Path +import logging +from datetime import datetime +from tqdm import tqdm +from diffusers_helper.bucket_tools import find_nearest_bucket +import time + + +# Add global stop event +stop_event = threading.Event() +skip_event = threading.Event() +logger = logging.getLogger(__name__) + +def refresh_lora_dropdowns_simple(lora_folder: str) -> List[gr.update]: + """Refreshes LoRA choices, always defaulting the selection to 'None'.""" + new_choices = get_lora_options(lora_folder) + results = [] + print(f"Refreshing LoRA dropdowns. 
Found choices: {new_choices}") # Debug print + for i in range(4): # Update all 4 slots + results.extend([ + gr.update(choices=new_choices, value="None"), # Always reset value to None + gr.update(value=1.0) # Reset multiplier + ]) + return results + +def process_framepack_extension_video( + input_video: str, + prompt: str, + negative_prompt: str, + seed: int, + batch_count: int, + fpe_use_normal_framepack: bool, + fpe_end_frame: Optional[str], + fpe_end_frame_weight: float, + resolution_max_dim: int, + total_second_length: float, + latent_window_size: int, + steps: int, + cfg_scale: float, # Maps to --cfg + distilled_guidance_scale: float, # Maps to --gs + # rs_scale: float, # --rs, usually 0.0, can be fixed or advanced option + gpu_memory_preservation: float, + use_teacache: bool, + no_resize: bool, + mp4_crf: int, + num_clean_frames: int, + vae_batch_size: int, + save_path: str, # Maps to --output_dir + # Model Paths + fpe_transformer_path: str, # DiT + fpe_vae_path: str, + fpe_text_encoder_path: str, # TE1 + fpe_text_encoder_2_path: str, # TE2 + fpe_image_encoder_path: str, + # Advanced performance + fpe_attn_mode: str, + fpe_fp8_llm: bool, + fpe_vae_chunk_size: Optional[int], + fpe_vae_spatial_tile_sample_min_size: Optional[int], + # LoRAs + fpe_lora_folder: str, + fpe_lora_weight_1: str, fpe_lora_mult_1: float, + fpe_lora_weight_2: str, fpe_lora_mult_2: float, + fpe_lora_weight_3: str, fpe_lora_mult_3: float, + fpe_lora_weight_4: str, fpe_lora_mult_4: float, + # Preview + fpe_enable_preview: bool, + fpe_preview_interval: int, # This arg is not used by f1_video_cli_local.py + fpe_extension_only: bool, + fpe_start_guidance_image: Optional[str], + fpe_start_guidance_image_clip_weight: float, + fpe_use_guidance_image_as_first_latent: bool, + *args: Any # For future expansion or unmapped params, not strictly needed here +) -> Generator[Tuple[List[Tuple[str, str]], Optional[str], str, str], None, None]: + global stop_event, skip_event + stop_event.clear() + skip_event.clear() # Assuming skip_event might be used for batch items + + if not input_video or not os.path.exists(input_video): + yield [], None, "Error: Input video for extension not found.", "" + return + + if not save_path or not save_path.strip(): + save_path = "outputs/framepack_extensions" # Default save path for extensions + os.makedirs(save_path, exist_ok=True) + + # Prepare LoRA arguments + lora_weights_paths = [] + lora_multipliers_values = [] + lora_params_ui = [ + (fpe_lora_weight_1, fpe_lora_mult_1), (fpe_lora_weight_2, fpe_lora_mult_2), + (fpe_lora_weight_3, fpe_lora_mult_3), (fpe_lora_weight_4, fpe_lora_mult_4) + ] + if fpe_lora_folder and os.path.exists(fpe_lora_folder): + for weight_name, mult_val in lora_params_ui: + if weight_name and weight_name != "None": + lora_path = os.path.join(fpe_lora_folder, weight_name) + if os.path.exists(lora_path): + lora_weights_paths.append(lora_path) + lora_multipliers_values.append(str(mult_val)) + else: + print(f"Warning: LoRA file not found: {lora_path}") + + all_generated_videos = [] + script_to_use = "f_video_end_cli_local.py" if fpe_use_normal_framepack else "f1_video_cli_local.py" + model_type_str = "Normal FramePack" if fpe_use_normal_framepack else "FramePack F1" + print(f"Using {model_type_str} model for extension via script: {script_to_use}") + + for i in range(batch_count): + if stop_event.is_set(): + yield all_generated_videos, None, "Generation stopped by user.", "" + return + skip_event.clear() + + current_seed_val = seed + if seed == -1: + current_seed_val = 
random.randint(0, 2**32 - 1) + elif batch_count > 1: + current_seed_val = seed + i + + # This run_id is not directly used for preview file naming by f1_video_cli_local.py + # as it constructs its own job_id based filenames for section previews. + # run_id = f"{int(time.time())}_{random.randint(1000, 9999)}_ext_s{current_seed_val}" + + current_preview_yield_path = None + last_preview_section_processed = -1 + + status_text = f"Processing Extension {i + 1}/{batch_count} (Seed: {current_seed_val})" + progress_text = "Preparing extension subprocess..." + yield all_generated_videos, current_preview_yield_path, status_text, progress_text + + command = [ + sys.executable, script_to_use, + "--input_video", str(input_video), + "--prompt", str(prompt), + "--n_prompt", str(negative_prompt), + "--seed", str(current_seed_val), + "--resolution_max_dim", str(resolution_max_dim), + "--total_second_length", str(total_second_length), # Script uses this for *additional* length + "--latent_window_size", str(latent_window_size), + "--steps", str(steps), + "--cfg", str(cfg_scale), + "--gs", str(distilled_guidance_scale), + "--rs", "0.0", + "--gpu_memory_preservation", str(gpu_memory_preservation), + "--mp4_crf", str(mp4_crf), + "--num_clean_frames", str(num_clean_frames), + "--vae_batch_size", str(vae_batch_size), + "--output_dir", str(save_path), + "--dit", str(fpe_transformer_path), "--vae", str(fpe_vae_path), + "--text_encoder1", str(fpe_text_encoder_path), "--text_encoder2", str(fpe_text_encoder_2_path), + "--image_encoder", str(fpe_image_encoder_path), + "--attn_mode", str(fpe_attn_mode), + ] + if use_teacache: command.append("--use_teacache") + if no_resize: command.append("--no_resize") + if fpe_fp8_llm: command.append("--fp8_llm") # Though F1 script might not use this + if fpe_vae_chunk_size is not None and fpe_vae_chunk_size > 0: + command.extend(["--vae_chunk_size", str(fpe_vae_chunk_size)]) + if fpe_vae_spatial_tile_sample_min_size is not None and fpe_vae_spatial_tile_sample_min_size > 0: + command.extend(["--vae_spatial_tile_sample_min_size", str(fpe_vae_spatial_tile_sample_min_size)]) + + if lora_weights_paths: + command.extend(["--lora_weight"] + lora_weights_paths) + command.extend(["--lora_multiplier"] + lora_multipliers_values) + if fpe_extension_only: + command.append("--extension_only") + # Script-specific arguments + if fpe_use_normal_framepack: + if fpe_fp8_llm: # Normal FP script uses this + command.append("--fp8_llm") + if fpe_end_frame and os.path.exists(fpe_end_frame): + command.extend(["--end_frame", str(fpe_end_frame)]) + command.extend(["--end_frame_weight", str(fpe_end_frame_weight)]) + else: + if fpe_start_guidance_image and os.path.exists(fpe_start_guidance_image): + command.extend(["--start_guidance_image", str(fpe_start_guidance_image)]) + command.extend(["--start_guidance_image_clip_weight", str(fpe_start_guidance_image_clip_weight)]) + if fpe_use_guidance_image_as_first_latent: + command.append("--use_guidance_image_as_first_latent") + + env = os.environ.copy() + env["PYTHONUNBUFFERED"] = "1" + print(f"Running FramePack-Extension Command: {' '.join(command)}") + + process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, env=env, bufsize=1, universal_newlines=True) + + # Regex patterns based on script + if fpe_use_normal_framepack: # This means f_video_end_cli_local.py was used + final_video_path_regex = re.compile(r"Final (?:extended video saved:|extension-only video saved:) (.*\.mp4)") + # Regex for "--- Generating Extension: ... 
Section X / Y (backward) ---" + fpe_section_progress_regex = re.compile(r"--- Generating Extension: .*?: Section\s+(\d+)\s*/\s*(\d+)\s+\(backward\)") + tqdm_cli_progress_regex = re.compile(r"Sampling Extension Section .*?:\s*(\d+)%\|.*?\|\s*(\d+/\d+)\s*\[([^<]+)<([^,]+),") + else: # F1 script (f1_video_cli_local.py) was used + final_video_path_regex = re.compile(r"Final (?:extension-only )?video for seed \d+.*? saved as: (.*\.mp4)") + fpe_section_progress_regex = re.compile(r"--- F1 Extension: .*?: Section (\d+)\s*/\s*(\d+) ---") + tqdm_cli_progress_regex = re.compile(r"Sampling Extension Section .*?:\s*(\d+)%\|.*?\|\s*(\d+/\d+)\s*\[([^<]+)<([^,]+),") + fpe_preview_saved_regex = re.compile(r"MP4 Preview for section (\d+) saved: (.*\.mp4)") + + current_video_file_for_item = None + current_section_being_processed = 0 + total_sections_from_log = 0 + + for line in iter(process.stdout.readline, ''): + if stop_event.is_set(): + try: + process.terminate() + process.wait(timeout=5) + except subprocess.TimeoutExpired: + process.kill(); process.wait() + except Exception as e: print(f"Error terminating FPE subprocess: {e}") + yield all_generated_videos, None, "Generation stopped by user.", "" + return + if skip_event.is_set() and batch_count > 1: + print(f"Skip signal received for FPE batch item {i+1}. Terminating subprocess...") + try: + process.terminate() + process.wait(timeout=5) + except subprocess.TimeoutExpired: + process.kill(); process.wait() + except Exception as e: print(f"Error terminating FPE subprocess during skip: {e}") + skip_event.clear() + yield all_generated_videos, current_preview_yield_path, f"Skipping FPE item {i+1}/{batch_count}...", "" + break + + line_strip = line.strip() + if not line_strip: + continue + print(f"FPE_SUBPROCESS: {line_strip}") + + progress_text_update = line_strip + + section_match = fpe_section_progress_regex.search(line_strip) + tqdm_match_cli = tqdm_cli_progress_regex.search(line_strip) + final_video_match = final_video_path_regex.search(line_strip) + preview_saved_match = fpe_preview_saved_regex.search(line_strip) + + if preview_saved_match and fpe_enable_preview: + saved_section_num = int(preview_saved_match.group(1)) + preview_mp4_path_from_log = preview_saved_match.group(2).strip() + if os.path.exists(preview_mp4_path_from_log) and saved_section_num > last_preview_section_processed: + current_preview_yield_path = preview_mp4_path_from_log # Yield clean path + last_preview_section_processed = saved_section_num + print(f"DEBUG FPE: MP4 Preview updated from log - {current_preview_yield_path}") + # This log usually comes *after* the section info, so status might already be updated + + if section_match: + current_section_being_processed = int(section_match.group(1)) + total_sections_from_log = int(section_match.group(2)) + status_text = f"Extending Video {i + 1}/{batch_count} (Seed: {current_seed_val}) - Section {current_section_being_processed}/{total_sections_from_log}" + progress_text_update = f"Starting Section {current_section_being_processed}..." + # Fallback logic for preview (if enabled and explicit log was missed) + # This is less likely to be needed if fpe_preview_saved_regex is robust + if fpe_enable_preview and current_section_being_processed > 1: + section_to_check_for_preview = current_section_being_processed - 1 + if section_to_check_for_preview > last_preview_section_processed: + # Construct the expected preview filename based on f1_video_cli_local.py's naming + # It uses a job_id that includes seed, resolution, etc. 
We don't know the exact job_id here. + # Relying on "MP4 Preview for section X saved:" log is more reliable. + # For a fallback, we could glob for *partX*.mp4, but that's risky. + # For now, this fallback is removed as the primary log line should be sufficient. + pass + + + elif tqdm_match_cli: + percentage = tqdm_match_cli.group(1) + steps_iter_total = tqdm_match_cli.group(2) + time_elapsed = tqdm_match_cli.group(3).strip() + time_remaining = tqdm_match_cli.group(4).strip() + # Ensure total_sections_from_log is not zero before using in f-string + total_sections_display = total_sections_from_log if total_sections_from_log > 0 else "?" + progress_text_update = f"Section {current_section_being_processed}/{total_sections_display} - Step {steps_iter_total} ({percentage}%) | ETA: {time_remaining}" + status_text = f"Extending Video {i + 1}/{batch_count} (Seed: {current_seed_val}) - Sampling Section {current_section_being_processed}" + + elif final_video_match: + found_video_path = final_video_match.group(1).strip() + if os.path.exists(found_video_path): + current_video_file_for_item = found_video_path + progress_text_update = f"Finalizing: {os.path.basename(current_video_file_for_item)}" + status_text = f"Extension {i + 1}/{batch_count} (Seed: {current_seed_val}) - Saved" + else: + print(f"Warning FPE: Final video path from log not found: {found_video_path}") + + yield all_generated_videos, current_preview_yield_path, status_text, progress_text_update + + process.stdout.close() + return_code = process.wait() + + if return_code == 0 and current_video_file_for_item and os.path.exists(current_video_file_for_item): + all_generated_videos.append((current_video_file_for_item, f"Extended - Seed: {current_seed_val}")) + status_text = f"Extension {i + 1}/{batch_count} (Seed: {current_seed_val}) - Completed and Added" + progress_text = f"Saved: {os.path.basename(current_video_file_for_item)}" + yield all_generated_videos.copy(), None, status_text, progress_text # Clear preview after item completion + elif return_code != 0: + status_text = f"Extension {i + 1}/{batch_count} (Seed: {current_seed_val}) - Failed (Code: {return_code})" + progress_text = f"Subprocess failed. Check console for errors from f1_video_cli_local.py" + yield all_generated_videos.copy(), None, status_text, progress_text + else: # rc == 0 but no video path + status_text = f"Extension {i + 1}/{batch_count} (Seed: {current_seed_val}) - Finished, but no video file confirmed." + progress_text = "Check console logs from f1_video_cli_local.py for the saved path." + yield all_generated_videos.copy(), None, status_text, progress_text + + # The F1 script already cleans up its intermediate _partX files. + # No need for unique_preview_suffix based cleanup here for FPE. + + yield all_generated_videos, None, "FramePack-Extension Batch complete.", "" + +def set_random_seed(): + """Returns -1 to set the seed input to random.""" + return -1 + +def get_step_from_preview_path(path): # Helper function + # Extracts step number from preview filenames like latent_preview_step_005.mp4 + # or for framepack: latent_preview_section_002.mp4 (assuming sections for framepack) + # Let's adjust for potential FramePack naming convention (using 'section' instead of 'step') + base = os.path.basename(path) + match_step = re.search(r"step_(\d+)", base) + if match_step: + return int(match_step.group(1)) + match_section = re.search(r"section_(\d+)", base) # Check for FramePack section naming + if match_section: + # Maybe treat sections differently? Or just return the number? 
Let's return number. + return int(match_section.group(1)) + return -1 # Default if no number found + +def process_framepack_video( + prompt: str, + negative_prompt: str, + input_image: str, # Start image path + input_end_frame: Optional[str], # End image path + end_frame_influence: str, + end_frame_weight: float, + transformer_path: str, + vae_path: str, + text_encoder_path: str, + text_encoder_2_path: str, + image_encoder_path: str, + target_resolution: Optional[int], + framepack_width: Optional[int], + framepack_height: Optional[int], + original_dims_str: str, # This comes from framepack_original_dims state + total_second_length: float, + framepack_video_sections: Optional[int], + fps: int, + seed: int, + steps: int, + distilled_guidance_scale: float, + cfg: float, + rs: float, + sample_solver: str, + latent_window_size: int, + fp8: bool, + fp8_scaled: bool, + fp8_llm: bool, + blocks_to_swap: int, + bulk_decode: bool, + attn_mode: str, + vae_chunk_size: Optional[int], + vae_spatial_tile_sample_min_size: Optional[int], + device: Optional[str], + use_teacache: bool, + teacache_steps: int, + teacache_thresh: float, + batch_size: int, + save_path: str, + lora_folder: str, + enable_preview: bool, + preview_every_n_sections: int, + use_full_video_preview: bool, + is_f1: bool, + use_random_folder: bool, + input_folder_path: str, + *args: Any +) -> Generator[Tuple[List[Tuple[str, str]], Optional[str], str, str], None, None]: + """Generate video using fpack_generate_video.py""" + global stop_event + stop_event.clear() + + if not save_path or not save_path.strip(): + print("Warning: save_path was empty, defaulting to 'outputs'") + save_path = "outputs" + + num_section_controls = 4 + num_loras = 4 + secs_end = num_section_controls + prompts_end = secs_end + num_section_controls + images_end = prompts_end + num_section_controls + lora_weights_end = images_end + num_loras + lora_mults_end = lora_weights_end + num_loras + + framepack_secs = args[0:secs_end] + framepack_sec_prompts = args[secs_end:prompts_end] + framepack_sec_images = args[prompts_end:images_end] + lora_weights_list = list(args[images_end:lora_weights_end]) + lora_multipliers_list = list(args[lora_weights_end:lora_mults_end]) + + if not use_random_folder and not input_image and not any(img for img in framepack_sec_images if img): + yield [], None, "Error: Input start image or at least one section image override is required when not using folder mode.", "" + return + + if use_random_folder and (not input_folder_path or not os.path.isdir(input_folder_path)): + yield [], None, f"Error: Random image folder path '{input_folder_path}' is invalid or not a directory.", "" + return + + section_prompts_parts = [] + section_images_parts = [] + index_pattern = re.compile(r"^\d+(-\d+)?$") + + for idx_str, sec_prompt, sec_image in zip(framepack_secs, framepack_sec_prompts, framepack_sec_images): + if not idx_str or not isinstance(idx_str, str) or not index_pattern.match(idx_str.strip()): + if idx_str and idx_str.strip(): + print(f"Warning: Invalid section index/range format '{idx_str}'. 
Skipping.") + continue + current_idx_str = idx_str.strip() + if sec_prompt and sec_prompt.strip(): + section_prompts_parts.append(f"{current_idx_str}:{sec_prompt.strip()}") + if sec_image and os.path.exists(sec_image): + section_images_parts.append(f"{current_idx_str}:{sec_image}") + + final_prompt_arg = prompt + if section_prompts_parts: + final_prompt_arg = ";;;".join(section_prompts_parts) + print(f"Using section prompt overrides: {final_prompt_arg}") + + final_image_path_arg = None + if section_images_parts: + final_image_path_arg = ";;;".join(section_images_parts) + print(f"Using section image overrides for --image_path: {final_image_path_arg}") + elif input_image: + final_image_path_arg = input_image + print(f"Using base input image for --image_path: {final_image_path_arg}") + + # These are batch-wide defaults if not overridden by folder mode + target res per item. + batch_wide_final_height, batch_wide_final_width = None, None + + if framepack_width is not None and framepack_width > 0 and framepack_height is not None and framepack_height > 0: + if framepack_width % 8 != 0 or framepack_height % 8 != 0: + yield [], None, "Error: Explicit Width and Height must be divisible by 8.", "" + return + batch_wide_final_height = int(framepack_height) + batch_wide_final_width = int(framepack_width) + print(f"Using explicit dimensions for all items: H={batch_wide_final_height}, W={batch_wide_final_width}") + elif target_resolution is not None and target_resolution > 0 and not use_random_folder: + # This case applies if: + # 1. Target resolution is set. + # 2. We are NOT in random folder mode (so aspect ratio from UI image is reliable). + if not original_dims_str: # original_dims_str comes from the UI input image + yield [], None, "Error: Target Resolution selected (not in folder mode), but no UI input image provided for aspect ratio.", "" + return + try: + orig_w, orig_h = map(int, original_dims_str.split('x')) + if orig_w <= 0 or orig_h <= 0: + yield [], None, "Error: Invalid original dimensions stored from UI image.", "" + return + bucket_dims = find_nearest_bucket(orig_h, orig_w, resolution=target_resolution) + if bucket_dims: + batch_wide_final_height, batch_wide_final_width = bucket_dims + print(f"Using Target Resolution {target_resolution} with UI image aspect. Batch-wide bucket: H={batch_wide_final_height}, W={batch_wide_final_width}") + else: + yield [], None, f"Error: Could not find bucket for Target Res {target_resolution} and UI image aspect.", "" + return + except Exception as e: + yield [], None, f"Error calculating bucket dimensions from UI image: {e}", "" + return + elif use_random_folder and target_resolution is not None and target_resolution > 0: + # Folder mode with target resolution: resolution will be determined per item. + # batch_wide_final_height and batch_wide_final_width remain None. + print(f"Folder mode with Target Resolution {target_resolution}. Resolution will be determined per item.") + elif not (framepack_width is not None and framepack_width > 0 and framepack_height is not None and framepack_height > 0) and \ + not (target_resolution is not None and target_resolution > 0): + # This is the fallback if no resolution strategy is active for the batch. + yield [], None, "Error: Resolution required. 
Please provide Target Resolution OR valid Width and Height (divisible by 8).", "" + return + + all_videos = [] + if framepack_video_sections is not None and framepack_video_sections > 0: + total_sections_estimate = framepack_video_sections + print(f"Using user-defined total sections for UI: {total_sections_estimate}") + else: + total_sections_estimate_float = (total_second_length * fps) / (latent_window_size * 4) + total_sections_estimate = int(max(round(total_sections_estimate_float), 1)) + print(f"Calculated total sections for UI from duration: {total_sections_estimate}") + progress_text = f"Starting FramePack generation batch ({total_sections_estimate} estimated sections per video)..." + status_text = "Preparing batch..." + yield all_videos, None, status_text, progress_text + + valid_loras_paths = [] + valid_loras_mults = [] + if lora_folder and os.path.exists(lora_folder): + for weight_name, mult in zip(lora_weights_list, lora_multipliers_list): + if weight_name and weight_name != "None": + if os.path.isabs(weight_name): + lora_path = weight_name + else: + lora_path = os.path.join(lora_folder, weight_name) + if os.path.exists(lora_path): + valid_loras_paths.append(lora_path) + valid_loras_mults.append(str(mult)) + else: + print(f"Warning: LoRA file not found: {lora_path}") + + for i in range(batch_size): # <<< START OF THE BATCH LOOP >>> + if stop_event.is_set(): + yield all_videos, None, "Generation stopped by user.", "" + return + skip_event.clear() + + last_preview_mtime = 0 + + run_id = f"{int(time.time())}_{random.randint(1000, 9999)}" + unique_preview_suffix = f"fpack_{run_id}" + preview_base_path = os.path.join(save_path, f"latent_preview_{unique_preview_suffix}") + preview_mp4_path = preview_base_path + ".mp4" + preview_png_path = preview_base_path + ".png" + + current_seed = seed + if seed == -1: current_seed = random.randint(0, 2**32 - 1) + elif batch_size > 1: current_seed = seed + i + + status_text = f"Generating video {i + 1} of {batch_size} (Seed: {current_seed})" + progress_text_update = f"Item {i+1}/{batch_size}: Preparing..." # Renamed progress_text to progress_text_update for clarity + current_video_path = None + current_preview_yield_path = None + current_input_image_for_item = input_image + current_original_dims_str_for_item = original_dims_str # Use batch-wide original_dims_str initially + + if use_random_folder: + progress_text_update = f"Item {i+1}/{batch_size}: Selecting random image..." + yield all_videos.copy(), current_preview_yield_path, status_text, progress_text_update + + random_image_path, random_status = get_random_image_from_folder(input_folder_path) + if random_image_path is None: + error_msg = f"Error for item {i+1}/{batch_size}: {random_status}. Skipping." + print(error_msg) + yield all_videos.copy(), None, status_text, error_msg + continue + + current_input_image_for_item = random_image_path + progress_text_update = f"Item {i+1}/{batch_size}: Using random image: {os.path.basename(random_image_path)}" + print(progress_text_update) + yield all_videos.copy(), current_preview_yield_path, status_text, progress_text_update + + # Derive original_dims_str_for_item from the random image if using target resolution + # and explicit UI W/H were not provided. 
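For folder mode, the per-item sizing performed below can be summarized as a small sketch. It uses the same find_nearest_bucket helper imported at the top of this file; the function name here is illustrative only:

```python
from typing import Tuple

from PIL import Image
from diffusers_helper.bucket_tools import find_nearest_bucket


def bucket_dims_for_image(image_path: str, target_resolution: int) -> Tuple[int, int]:
    # Read the source size only to preserve aspect ratio; the bucket itself is
    # chosen by find_nearest_bucket, which returns (height, width) just like
    # the batch-wide path above unpacks it.
    with Image.open(image_path) as img:
        orig_w, orig_h = img.size
    height, width = find_nearest_bucket(orig_h, orig_w, resolution=target_resolution)
    return height, width
```

The branch that follows performs the same steps inline, with additional error handling for images that cannot be opened.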
+ if target_resolution is not None and target_resolution > 0 and \ + not (framepack_width is not None and framepack_width > 0 and framepack_height is not None and framepack_height > 0): + try: + img_for_dims = Image.open(random_image_path) + rand_w, rand_h = img_for_dims.size + current_original_dims_str_for_item = f"{rand_w}x{rand_h}" + print(f"Folder mode item {i+1}: Using random image dims {current_original_dims_str_for_item} for target resolution bucketing.") + except Exception as e: + error_msg = f"Error getting dims for random image {random_image_path}: {e}. Skipping item {i+1}." + print(error_msg) + yield all_videos.copy(), None, status_text, error_msg + continue + + final_image_path_arg_for_item = None + if section_images_parts: + final_image_path_arg_for_item = ";;;".join(section_images_parts) + if current_input_image_for_item: + has_section_0_override = any(part.strip().startswith("0:") for part in section_images_parts) + if not has_section_0_override: + final_image_path_arg_for_item = f"0:{current_input_image_for_item};;;{final_image_path_arg_for_item}" + print(f"Using section image overrides (potentially with prepended base) for --image_path (item {i+1}): {final_image_path_arg_for_item}") + elif current_input_image_for_item: + final_image_path_arg_for_item = current_input_image_for_item + print(f"Using {'random' if use_random_folder else 'base'} input image as the primary for --image_path (item {i+1}): {final_image_path_arg_for_item}") + + if final_image_path_arg_for_item is None: + yield [], None, f"Error for item {i+1}: No valid start image could be determined. Ensure an image is provided.", "" + continue + + final_height_for_item, final_width_for_item = None, None + + # 1. Use batch-wide dimensions if they were set (from explicit UI W/H or target_res + UI image) + if batch_wide_final_height is not None and batch_wide_final_width is not None: + final_height_for_item = batch_wide_final_height + final_width_for_item = batch_wide_final_width + print(f"Item {i+1}: Using batch-wide dimensions: H={final_height_for_item}, W={final_width_for_item}") + # 2. Else, if using target resolution (this implies folder mode, as other cases were handled above) + elif target_resolution is not None and target_resolution > 0: + if not current_original_dims_str_for_item: # This should now be populated for folder mode + yield [], None, f"Error for item {i+1}: Target Resolution selected, but no original dimensions available for aspect ratio.", "" + continue + try: + orig_w_item, orig_h_item = map(int, current_original_dims_str_for_item.split('x')) + if orig_w_item <= 0 or orig_h_item <= 0: + yield [], None, f"Error for item {i+1}: Invalid original dimensions '{current_original_dims_str_for_item}'.", "" + continue + bucket_dims_item = find_nearest_bucket(orig_h_item, orig_w_item, resolution=target_resolution) + if bucket_dims_item: + final_height_for_item, final_width_for_item = bucket_dims_item + print(f"Item {i+1}: Using Target Resolution {target_resolution} with item-specific aspect from '{current_original_dims_str_for_item}'. 
Bucket: H={final_height_for_item}, W={final_width_for_item}") + else: + yield [], None, f"Error for item {i+1}: Could not find bucket for Target Res {target_resolution} and aspect {current_original_dims_str_for_item}.", "" + continue + except Exception as e_res: + yield [], None, f"Error calculating bucket dimensions for item {i+1} ({current_original_dims_str_for_item}): {e_res}", "" + continue + else: + # This case should ideally not be hit if the initial batch-wide resolution checks were thorough. + # It implies no explicit W/H, no target_res, or some other unhandled state. + yield [], None, f"Error for item {i+1}: Failed to determine resolution strategy for the item.", "" + continue # Skip this item + + if final_height_for_item is None or final_width_for_item is None: # Final check for the item + yield [], None, f"Error for item {i+1}: Final resolution could not be determined for this item.", "" + continue + + # Update status text with the preparing subprocess message + yield all_videos.copy(), current_preview_yield_path, status_text, progress_text_update # Use progress_text_update + + env = os.environ.copy() + env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") + env["PYTHONIOENCODING"] = "utf-8" + clear_cuda_cache() + + command = [ + sys.executable, "fpack_generate_video.py", + "--text_encoder1", text_encoder_path, "--text_encoder2", text_encoder_2_path, + "--image_encoder", image_encoder_path, + *(["--image_path", final_image_path_arg_for_item] if final_image_path_arg_for_item else []), + "--save_path", save_path, "--prompt", final_prompt_arg, + "--video_size", str(final_height_for_item), str(final_width_for_item), + *(["--video_sections", str(framepack_video_sections)] if framepack_video_sections is not None and framepack_video_sections > 0 else ["--video_seconds", str(total_second_length)]), + "--infer_steps", str(steps), "--seed", str(current_seed), + "--embedded_cfg_scale", str(distilled_guidance_scale), + "--guidance_scale", str(cfg), "--guidance_rescale", str(rs), + "--latent_window_size", str(latent_window_size), + "--sample_solver", sample_solver, "--output_type", "video", "--attn_mode", attn_mode + ] + if is_f1: command.append("--is_f1") + if transformer_path and os.path.exists(transformer_path): command.extend(["--dit", transformer_path.strip()]) + if vae_path and os.path.exists(vae_path): command.extend(["--vae", vae_path.strip()]) + if negative_prompt and negative_prompt.strip(): command.extend(["--negative_prompt", negative_prompt.strip()]) + if input_end_frame and os.path.exists(input_end_frame): command.extend(["--end_image_path", input_end_frame]) + if fp8: command.append("--fp8") + if fp8 and fp8_scaled: command.append("--fp8_scaled") + if fp8_llm: command.append("--fp8_llm") + if bulk_decode: command.append("--bulk_decode") + if blocks_to_swap > 0: command.extend(["--blocks_to_swap", str(blocks_to_swap)]) + if vae_chunk_size is not None and vae_chunk_size > 0: command.extend(["--vae_chunk_size", str(vae_chunk_size)]) + if vae_spatial_tile_sample_min_size is not None and vae_spatial_tile_sample_min_size > 0: command.extend(["--vae_spatial_tile_sample_min_size", str(vae_spatial_tile_sample_min_size)]) + if device and device.strip(): command.extend(["--device", device.strip()]) + if valid_loras_paths: + command.extend(["--lora_weight"] + valid_loras_paths) + command.extend(["--lora_multiplier"] + valid_loras_mults) + if enable_preview and preview_every_n_sections > 0: + command.extend(["--preview_latent_every", 
str(preview_every_n_sections)]) + command.extend(["--preview_suffix", unique_preview_suffix]) + if use_full_video_preview: # Check if full preview is requested + command.append("--full_preview") + print(f"DEBUG: Enabling FULL VIDEO preview every {preview_every_n_sections} sections with suffix {unique_preview_suffix}.") + else: + print(f"DEBUG: Enabling latent preview every {preview_every_n_sections} sections with suffix {unique_preview_suffix}.") + if use_teacache: + command.append("--use_teacache") + command.extend(["--teacache_steps", str(teacache_steps)]) + command.extend(["--teacache_thresh", str(teacache_thresh)]) + + command_str = [str(c) for c in command] + print(f"Running FramePack Command: {' '.join(command_str)}") + + p = subprocess.Popen( + command_str, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, + env=env, text=True, encoding='utf-8', errors='replace', bufsize=1 + ) + current_phase = "Preparing" + actual_total_sections = None + display_section_num = 1 + + while True: + if stop_event.is_set(): + try: + p.terminate() + p.wait(timeout=5) + except subprocess.TimeoutExpired: + p.kill(); p.wait() + except Exception as e: + print(f"Error terminating subprocess: {e}") + yield all_videos.copy(), None, "Generation stopped by user.", "" + return + if skip_event.is_set(): + print(f"Skip signal received for batch item {i+1}. Terminating subprocess...") + try: + p.terminate() + p.wait(timeout=5) + except subprocess.TimeoutExpired: + p.kill(); p.wait() + except Exception as e: + print(f"Error terminating subprocess during skip: {e}") + skip_event.clear() + yield all_videos.copy(), current_preview_yield_path, f"Skipping item {i+1}/{batch_size}...", "" + break + + line = p.stdout.readline() + if not line: + if p.poll() is not None: break + time.sleep(0.01); continue + + line = line.strip() + if not line: continue + print(f"SUBPROCESS: {line}") + + section_match = re.search(r"---.*?Section\s+(\d+)\s*/\s*(\d+)(?:\s+|$|\()", line) + tqdm_match = re.search(r'(\d+)\%\|.+\| (\d+)/(\d+) \[(\d{2}:\d{2})<(\d{2}:\d{2})', line) + phase_changed = False # Initialize phase_changed inside the loop + + # Default progress_text_update to the current line for general logging + progress_text_update = line # This was defined outside the loop before, moved inside + + if section_match: + current_section_num_display = int(section_match.group(1)) + total_sections_from_log = int(section_match.group(2)) + display_section_num = current_section_num_display + if actual_total_sections != total_sections_from_log: + actual_total_sections = total_sections_from_log + print(f"Detected/Updated actual total sections: {actual_total_sections}") + new_phase = f"Generating Section {display_section_num}" + if current_phase != new_phase: + current_phase = new_phase + phase_changed = True + progress_text_update = f"Item {i+1}/{batch_size} | Section {display_section_num}/{actual_total_sections} | Preparing..." 
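The progress parsing above depends entirely on tqdm's default bar format; the same pattern can be exercised in isolation as a quick self-contained check (the sample line below is synthetic):

```python
import re

# Same pattern used above to pull step progress out of the subprocess log.
TQDM_RE = re.compile(r'(\d+)\%\|.+\| (\d+)/(\d+) \[(\d{2}:\d{2})<(\d{2}:\d{2})')

sample = "30%|███       | 6/20 [00:42<01:38,  7.04s/it]"
m = TQDM_RE.search(sample)
if m:
    percent, step, total_steps, elapsed, remaining = m.groups()
    print(f"Step {step}/{total_steps} ({percent}%) | Elapsed: {elapsed}, Remaining: {remaining}")
```

The tqdm branch below uses the same capture groups to build the per-step progress text shown in the UI.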
+ status_text = f"Generating video {i + 1} of {batch_size} (Seed: {current_seed}) - {current_phase}" + elif tqdm_match: + percentage = int(tqdm_match.group(1)) + current_step = int(tqdm_match.group(2)) + total_steps = int(tqdm_match.group(3)) + time_elapsed = tqdm_match.group(4) + time_remaining = tqdm_match.group(5) + current_total_for_display = actual_total_sections if actual_total_sections is not None else total_sections_estimate + section_str = f"Section {display_section_num}/{current_total_for_display}" + progress_text_update = f"Item {i+1}/{batch_size} | {section_str} | Step {current_step}/{total_steps} ({percentage}%) | Elapsed: {time_elapsed}, Remaining: {time_remaining}" + denoising_phase = f"Denoising Section {display_section_num}" + if current_phase != denoising_phase: + current_phase = denoising_phase + phase_changed = True + status_text = f"Generating video {i + 1} of {batch_size} (Seed: {current_seed}) - {current_phase}" + elif "Decoding video..." in line: + if current_phase != "Decoding Video": + current_phase = "Decoding Video" + phase_changed = True + progress_text_update = f"Item {i+1}/{batch_size} | {current_phase}..." + status_text = f"Generating video {i + 1} of {batch_size} (Seed: {current_seed}) - {current_phase}" + elif "INFO:__main__:Video saved to:" in line: + match = re.search(r"Video saved to:\s*(.*\.mp4)", line) + if match: + found_video_path = match.group(1).strip() + if os.path.exists(found_video_path): + current_video_path = found_video_path + # Don't add to all_videos here, add after subprocess completion + else: + print(f"Warning: Parsed video path does not exist: {found_video_path}") + status_text = f"Video {i+1}/{batch_size} Saved (Seed: {current_seed})" + progress_text_update = f"Saved: {os.path.basename(found_video_path) if found_video_path else 'Unknown Path'}" + current_phase = "Saved" + phase_changed = True + else: + print(f"Warning: Could not parse video path from INFO line: {line}") + elif "ERROR" in line.upper() or "TRACEBACK" in line.upper(): + status_text = f"Item {i+1}/{batch_size}: Error Detected (Check Console)" + progress_text_update = line + if current_phase != "Error": + current_phase = "Error" + phase_changed = True + elif phase_changed and current_phase not in ["Saved", "Error"]: + status_text = f"Generating video {i + 1} of {batch_size} (Seed: {current_seed}) - {current_phase}" + + preview_updated = False + current_mtime_check = 0 # Renamed from current_mtime to avoid conflict + found_preview_path_check = None # Renamed + + if enable_preview: + if os.path.exists(preview_mp4_path): + current_mtime_check = os.path.getmtime(preview_mp4_path) + found_preview_path_check = preview_mp4_path + elif os.path.exists(preview_png_path): + current_mtime_check = os.path.getmtime(preview_png_path) + found_preview_path_check = preview_png_path + + if found_preview_path_check and current_mtime_check > last_preview_mtime: + print(f"DEBUG: Preview file updated: {found_preview_path_check} (mtime: {current_mtime_check})") + current_preview_yield_path = found_preview_path_check + last_preview_mtime = current_mtime_check + preview_updated = True + + yield all_videos.copy(), current_preview_yield_path, status_text, progress_text_update + + p.stdout.close(); rc = p.wait() + clear_cuda_cache(); time.sleep(0.1) + + if rc == 0 and current_video_path and os.path.exists(current_video_path): + all_videos.append((current_video_path, f"Seed: {current_seed}")) # Add video here + parameters = { + "prompt": prompt, "negative_prompt": negative_prompt, + "input_image": 
os.path.basename(current_input_image_for_item) if current_input_image_for_item else None, + "section_controls": [ + {"index": s, "prompt_override": p_override, "image_override": os.path.basename(img_override) if img_override else None} + for s, p_override, img_override in zip(framepack_secs, framepack_sec_prompts, framepack_sec_images) + if (p_override and p_override.strip()) or img_override + ], + "final_prompt_arg": final_prompt_arg, + "final_image_path_arg": final_image_path_arg_for_item, # Use item-specific image path + "input_end_frame": os.path.basename(input_end_frame) if input_end_frame else None, + "transformer_path": transformer_path, "vae_path": vae_path, + "text_encoder_path": text_encoder_path, "text_encoder_2_path": text_encoder_2_path, + "image_encoder_path": image_encoder_path, + "video_width": final_width_for_item, "video_height": final_height_for_item, + "video_seconds": total_second_length, "fps": fps, "seed": current_seed, + "infer_steps": steps, "embedded_cfg_scale": distilled_guidance_scale, + "guidance_scale": cfg, "guidance_rescale": rs, "sample_solver": sample_solver, + "latent_window_size": latent_window_size, + "fp8": fp8, "fp8_scaled": fp8_scaled, "fp8_llm": fp8_llm, + "blocks_to_swap": blocks_to_swap, "bulk_decode": bulk_decode, "attn_mode": attn_mode, + "vae_chunk_size": vae_chunk_size, "vae_spatial_tile_sample_min_size": vae_spatial_tile_sample_min_size, + "device": device, + "lora_weights": [os.path.basename(p) for p in valid_loras_paths], + "lora_multipliers": [float(m) for m in valid_loras_mults], + "original_dims_str": current_original_dims_str_for_item, + "target_resolution": target_resolution, + "is_f1": is_f1 + } + try: + add_metadata_to_video(current_video_path, parameters) + print(f"Added metadata to {current_video_path}") + except Exception as meta_err: + print(f"Warning: Failed to add metadata to {current_video_path}: {meta_err}") + status_text = f"Item {i+1}/{batch_size} Completed (Seed: {current_seed})" + progress_text_update = f"Video saved: {os.path.basename(current_video_path)}" + current_preview_yield_path = None # Clear preview for next item + yield all_videos.copy(), current_preview_yield_path, status_text, progress_text_update + elif rc != 0: + status_text = f"Item {i+1}/{batch_size} Failed (Seed: {current_seed}, Code: {rc})" + progress_text_update = f"Subprocess failed. Check console logs." + current_preview_yield_path = None # Clear preview + yield all_videos.copy(), current_preview_yield_path, status_text, progress_text_update + else: + status_text = f"Item {i+1}/{batch_size} Finished (Seed: {current_seed}), but no video file confirmed." + progress_text_update = "Check console logs for the saved path." 
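add_metadata_to_video used above is defined elsewhere in this file. Purely as an illustration of the general technique (the helper below is hypothetical, not the actual implementation), generation parameters can be attached to a finished MP4 without re-encoding by rewriting the container with an ffmpeg comment tag:

```python
import json
import os
import subprocess


def attach_parameters(video_path: str, parameters: dict) -> None:
    """Hypothetical sketch: store the parameter dict in the MP4 'comment' tag."""
    tmp_path = video_path + ".tmp.mp4"
    subprocess.run(
        [
            "ffmpeg", "-y", "-i", video_path,
            "-c", "copy",  # copy streams, no re-encode
            "-metadata", f"comment={json.dumps(parameters)}",
            tmp_path,
        ],
        check=True,
    )
    os.replace(tmp_path, video_path)
```

Because the streams are copied rather than re-encoded, this kind of rewrite is fast and lossless.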
+ current_preview_yield_path = None # Clear preview + yield all_videos.copy(), current_preview_yield_path, status_text, progress_text_update + + # Cleanup preview files for the completed item to avoid them being picked up by next item + if enable_preview: + for prev_file in [preview_mp4_path, preview_png_path]: + if os.path.exists(prev_file): + try: + os.remove(prev_file) + print(f"Cleaned up preview file: {prev_file}") + except Exception as e_clean: + print(f"Warning: Could not remove preview file {prev_file}: {e_clean}") + + time.sleep(0.2) + + yield all_videos, None, "FramePack Batch complete", "" + +def calculate_framepack_width(height, original_dims): + """Calculate FramePack width based on height maintaining aspect ratio (divisible by 32)""" + if not original_dims or height is None: + return gr.update() + try: + # Ensure height is an integer and divisible by 32 + height = int(height) + if height <= 0 : return gr.update() + height = (height // 32) * 32 # <-- Use 32 + height = max(64, height) # Min height (64 is divisible by 32) + + orig_w, orig_h = map(int, original_dims.split('x')) + if orig_h == 0: return gr.update() + aspect_ratio = orig_w / orig_h + # Calculate new width, rounding to the nearest multiple of 32 + new_width = round((height * aspect_ratio) / 32) * 32 # <-- Round and use 32 + return gr.update(value=max(64, new_width)) # Ensure minimum size (also divisible by 32) + + except Exception as e: + print(f"Error calculating width: {e}") + return gr.update() + +def calculate_framepack_height(width, original_dims): + """Calculate FramePack height based on width maintaining aspect ratio (divisible by 32)""" + if not original_dims or width is None: + return gr.update() + try: + # Ensure width is an integer and divisible by 32 + width = int(width) + if width <= 0: return gr.update() + width = (width // 32) * 32 # <-- Use 32 + width = max(64, width) # Min width (64 is divisible by 32) + + orig_w, orig_h = map(int, original_dims.split('x')) + if orig_w == 0: return gr.update() + aspect_ratio = orig_w / orig_h + # Calculate new height, rounding to the nearest multiple of 32 + new_height = round((width / aspect_ratio) / 32) * 32 # <-- Round and use 32 + return gr.update(value=max(64, new_height)) # Ensure minimum size (also divisible by 32) + except Exception as e: + print(f"Error calculating height: {e}") + return gr.update() + +def update_framepack_from_scale(scale, original_dims): + """Update FramePack dimensions based on scale percentage (divisible by 32)""" + if not original_dims: + return gr.update(), gr.update(), gr.update() + try: + scale = float(scale) if scale is not None else 100.0 + if scale <= 0: scale = 100.0 + + orig_w, orig_h = map(int, original_dims.split('x')) + scale_factor = scale / 100.0 + + # Calculate and round to the nearest multiple of 32 + new_w = round((orig_w * scale_factor) / 32) * 32 # <-- Round and use 32 + new_h = round((orig_h * scale_factor) / 32) * 32 # <-- Round and use 32 + + # Ensure minimum size (must be multiple of 32) + new_w = max(64, new_w) # 64 is divisible by 32 + new_h = max(64, new_h) + + # Clear target resolution if using scale slider for explicit dims + return gr.update(value=new_w), gr.update(value=new_h), gr.update(value=None) + except Exception as e: + print(f"Error updating from scale: {e}") + return gr.update(), gr.update(), gr.update() + +def process_i2v_single_video( + prompt: str, + image_path: str, + width: int, + height: int, + batch_size: int, + video_length: int, + fps: int, + infer_steps: int, + seed: int, + dit_folder: 
str, + model: str, + vae: str, + te1: str, + te2: str, + clip_vision_path: str, + save_path: str, + flow_shift: float, + cfg_scale: float, # embedded_cfg_scale + guidance_scale: float, # main CFG + output_type: str, + attn_mode: str, + block_swap: int, + exclude_single_blocks: bool, + use_split_attn: bool, + lora_folder: str, + vae_chunk_size: int, + vae_spatial_tile_min: int, + # --- Explicit LoRA args instead of *lora_params --- + lora1: str = "None", + lora2: str = "None", + lora3: str = "None", + lora4: str = "None", + lora1_multiplier: float = 1.0, + lora2_multiplier: float = 1.0, + lora3_multiplier: float = 1.0, + lora4_multiplier: float = 1.0, + # --- End LoRA args --- + negative_prompt: Optional[str] = None, + use_fp8: bool = False, + fp8_llm: bool = False +) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: + """Generate a single video using hv_i2v_generate_video.py""" + global stop_event + + # ... (Keep existing argument validation and env setup) ... + if stop_event.is_set(): + yield [], "", "" + return + + # Argument validation + if not image_path or not os.path.exists(image_path): + yield [], "Error: Input image not found", f"Cannot find image: {image_path}" + return + # Check clip vision path only if needed (Hunyuan-I2V, not SkyReels-I2V based on script name) + is_hunyuan_i2v = "mp_rank_00_model_states_i2v" in model # Heuristic check + if is_hunyuan_i2v and (not clip_vision_path or not os.path.exists(clip_vision_path)): + yield [], "Error: CLIP Vision model not found", f"Cannot find file: {clip_vision_path}" + return + + if os.path.isabs(model): + model_path = model + else: + model_path = os.path.normpath(os.path.join(dit_folder, model)) + + env = os.environ.copy() + env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") + env["PYTHONIOENCODING"] = "utf-8" + + if seed == -1: + current_seed = random.randint(0, 2**32 - 1) + else: + current_seed = seed + + clear_cuda_cache() + + command = [ + sys.executable, + "hv_i2v_generate_video.py", # <<< Use the new script + "--dit", model_path, + "--vae", vae, + "--text_encoder1", te1, + "--text_encoder2", te2, + # Add clip vision path only if it's likely the Hunyuan I2V model + *(["--clip_vision_path", clip_vision_path] if is_hunyuan_i2v else []), + "--prompt", prompt, + "--video_size", str(height), str(width), + "--video_length", str(video_length), + "--fps", str(fps), + "--infer_steps", str(infer_steps), + "--save_path", save_path, + "--seed", str(current_seed), + "--flow_shift", str(flow_shift), + "--embedded_cfg_scale", str(cfg_scale), + "--guidance_scale", str(guidance_scale), + "--output_type", output_type, + "--attn_mode", attn_mode, + "--blocks_to_swap", str(block_swap), + "--image_path", image_path + ] + + if negative_prompt: + command.extend(["--negative_prompt", negative_prompt]) + + if use_fp8: + command.append("--fp8") + if fp8_llm: + command.append("--fp8_llm") + + if exclude_single_blocks: + command.append("--exclude_single_blocks") + if use_split_attn: + command.append("--split_attn") + + if vae_chunk_size > 0: + command.extend(["--vae_chunk_size", str(vae_chunk_size)]) + if vae_spatial_tile_min > 0: + command.extend(["--vae_spatial_tile_sample_min_size", str(vae_spatial_tile_min)]) + + # --- Updated LoRA handling using named arguments --- + lora_weights_list = [lora1, lora2, lora3, lora4] + lora_multipliers_list = [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier] + valid_loras = [] + for weight, mult in zip(lora_weights_list, lora_multipliers_list): + if 
weight and weight != "None": + lora_file_path = os.path.join(lora_folder, weight) + if os.path.exists(lora_file_path): + valid_loras.append((lora_file_path, mult)) + else: + print(f"Warning: LoRA file not found: {lora_file_path}") + + if valid_loras: + weights = [weight for weight, _ in valid_loras] + multipliers = [str(mult) for _, mult in valid_loras] + command.extend(["--lora_weight"] + weights) + command.extend(["--lora_multiplier"] + multipliers) + # --- End Updated LoRA handling --- + + # ... (Keep subprocess execution, output collection, and metadata saving logic) ... + command_str = [str(c) for c in command] # Ensure all args are strings + print(f"Running Command (I2V): {' '.join(command_str)}") + + p = subprocess.Popen( + command_str, # Use stringified command + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=env, + text=True, + encoding='utf-8', + errors='replace', + bufsize=1 + ) + + videos = [] + + while True: + if stop_event.is_set(): + p.terminate() + p.wait() + yield videos, current_previews, "Generation stopped by user.", "" + return + + line = p.stdout.readline() + if not line: + if p.poll() is not None: + break + continue + + print(line, end='') # Print progress to console + if '|' in line and '%' in line and '[' in line and ']' in line: + yield videos.copy(), f"Processing (seed: {current_seed})", line.strip() + + p.stdout.close() + p.wait() + + clear_cuda_cache() + time.sleep(0.5) + + # Collect generated video + save_path_abs = os.path.abspath(save_path) + generated_video_path = None + if os.path.exists(save_path_abs): + all_videos_files = sorted( + [f for f in os.listdir(save_path_abs) if f.endswith('.mp4')], + key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)), + reverse=True + ) + # Try to find the video matching the seed + matching_videos = [v for v in all_videos_files if f"_{current_seed}" in v] + if matching_videos: + generated_video_path = os.path.join(save_path_abs, matching_videos[0]) + + if generated_video_path: + # Collect parameters for metadata (adjust as needed for i2v specifics) + parameters = { + "prompt": prompt, + "width": width, + "height": height, + "video_length": video_length, + "fps": fps, + "infer_steps": infer_steps, + "seed": current_seed, + "model": model, + "vae": vae, + "te1": te1, + "te2": te2, + "clip_vision_path": clip_vision_path, + "save_path": save_path, + "flow_shift": flow_shift, + "embedded_cfg_scale": cfg_scale, + "guidance_scale": guidance_scale, + "output_type": output_type, + "attn_mode": attn_mode, + "block_swap": block_swap, + "lora_weights": list(lora_weights_list), # Save the list + "lora_multipliers": list(lora_multipliers_list), # Save the list + "input_image": image_path, + "negative_prompt": negative_prompt if negative_prompt else None, + "vae_chunk_size": vae_chunk_size, + "vae_spatial_tile_min": vae_spatial_tile_min, + "use_fp8_dit": use_fp8, + "use_fp8_llm": fp8_llm + } + add_metadata_to_video(generated_video_path, parameters) + videos.append((str(generated_video_path), f"Seed: {current_seed}")) + yield videos, f"Completed (seed: {current_seed})", "" + else: + yield [], f"Failed (seed: {current_seed})", "Could not find generated video file." 
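+
+# NOTE: process_i2v_single_video is a generator; callers are expected to iterate it and forward
+# each (gallery_items, status_text, progress_text) tuple to the UI, roughly like:
+#
+#     for videos, status, progress in process_i2v_single_video(prompt="...", image_path="input.png", ...):
+#         ...update the Gradio gallery / status / progress components...
+#
+# This is only an illustrative sketch; process_i2v_batch below is the real consumer and passes the
+# full argument list.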
+ + +def process_i2v_batch( + prompt: str, + image_path: str, + width: int, + height: int, + batch_size: int, + video_length: int, + fps: int, + infer_steps: int, + seed: int, + dit_folder: str, + model: str, + vae: str, + te1: str, + te2: str, + clip_vision_path: str, # Added + save_path: str, + flow_shift: float, + cfg_scale: float, # embedded_cfg_scale + guidance_scale: float, # main CFG + output_type: str, + attn_mode: str, + block_swap: int, + exclude_single_blocks: bool, + use_split_attn: bool, + lora_folder: str, + vae_chunk_size: int, # Added + vae_spatial_tile_min: int, # Added + negative_prompt: Optional[str] = None, # Added + use_fp8: bool = False, # Added + fp8_llm: bool = False, # Added + *lora_params # Captures LoRA weights and multipliers +) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: + """Process a batch of videos using the new I2V script""" + global stop_event + stop_event.clear() + + all_videos = [] + progress_text = "Starting I2V generation..." + yield [], "Preparing...", progress_text + + # Extract LoRA weights and multipliers once + num_lora_weights = 4 + lora_weights_list = lora_params[:num_lora_weights] + lora_multipliers_list = lora_params[num_lora_weights:num_lora_weights*2] + + for i in range(batch_size): + if stop_event.is_set(): + yield all_videos, "Generation stopped by user.", "" + return + + current_seed = seed + if seed == -1: + current_seed = random.randint(0, 2**32 - 1) + elif batch_size > 1: + current_seed = seed + i + + batch_text = f"Generating video {i + 1} of {batch_size} (I2V)" + yield all_videos.copy(), batch_text, progress_text + + # Call the single video processing function + single_gen = process_i2v_single_video( + prompt=prompt, + image_path=image_path, + width=width, + height=height, + batch_size=batch_size, + video_length=video_length, + fps=fps, + infer_steps=infer_steps, + seed=current_seed, + dit_folder=dit_folder, + model=model, + vae=vae, + te1=te1, + te2=te2, + clip_vision_path=clip_vision_path, + save_path=save_path, + flow_shift=flow_shift, + cfg_scale=cfg_scale, + guidance_scale=guidance_scale, + output_type=output_type, + attn_mode=attn_mode, + block_swap=block_swap, + exclude_single_blocks=exclude_single_blocks, + use_split_attn=use_split_attn, + lora_folder=lora_folder, + vae_chunk_size=vae_chunk_size, + vae_spatial_tile_min=vae_spatial_tile_min, + # --- Pass LoRA params by keyword --- + lora1=lora_weights_list[0], + lora2=lora_weights_list[1], + lora3=lora_weights_list[2], + lora4=lora_weights_list[3], + lora1_multiplier=lora_multipliers_list[0], + lora2_multiplier=lora_multipliers_list[1], + lora3_multiplier=lora_multipliers_list[2], + lora4_multiplier=lora_multipliers_list[3], + # --- End LoRA keyword args --- + negative_prompt=negative_prompt, + use_fp8=use_fp8, + fp8_llm=fp8_llm + ) + + # Yield progress updates from the single generator + try: + for videos, status, progress in single_gen: + if videos: + # Only add the latest video from this specific generation + new_video = videos[-1] + if new_video not in all_videos: + all_videos.append(new_video) + yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress + except Exception as e: + yield all_videos.copy(), f"Error in batch {i+1}: {e}", "" + print(f"Error during single I2V generation: {e}") # Log error + + # Optional small delay between batch items + time.sleep(0.1) + + yield all_videos, "I2V Batch complete", "" + + +def wanx_extend_video_wrapper( + prompt, negative_prompt, input_image, base_video_path, + width, height, video_length, fps, 
infer_steps, + flow_shift, guidance_scale, seed, + task, dit_folder, dit_path, vae_path, t5_path, clip_path, # <--- Parameters received here + save_path, output_type, sample_solver, exclude_single_blocks, + attn_mode, block_swap, fp8, fp8_scaled, fp8_t5, lora_folder, + slg_layers="", slg_start=0.0, slg_end=1.0, + lora1="None", lora2="None", lora3="None", lora4="None", + lora1_multiplier=1.0, lora2_multiplier=1.0, lora3_multiplier=1.0, lora4_multiplier=1.0, + enable_cfg_skip=False, cfg_skip_mode="none", cfg_apply_ratio=0.7 +): + """Direct wrapper that bypasses the problematic wanx_generate_video function""" + global stop_event + + # All videos generated + all_videos = [] + + # Debug prints to understand what we're getting + print(f"DEBUG - Received parameters in wanx_extend_video_wrapper:") + print(f" task: {task}") + print(f" dit_folder: {dit_folder}") # <<< Should be the folder path ('wan') + print(f" dit_path: {dit_path}") # <<< Should be the model filename + print(f" vae_path: {vae_path}") # <<< Should be the VAE path + print(f" t5_path: {t5_path}") # <<< Should be the T5 path + print(f" clip_path: {clip_path}") # <<< Should be the CLIP path + print(f" output_type: {output_type}") + print(f" sample_solver: {sample_solver}") + print(f" attn_mode: {attn_mode}") + print(f" block_swap: {block_swap}") + + # Get current seed + current_seed = seed + if seed == -1: + current_seed = random.randint(0, 2**32 - 1) + + # --- START CRITICAL FIX --- + # Detect if parameters are swapped based on the pattern observed in the error log + # Check if dit_path looks like a VAE path (contains "VAE" or ends with .pth) + # AND dit_folder looks like a model filename (ends with .safetensors or .pt) + params_swapped = False + if dit_path and dit_folder and \ + (("VAE" in dit_path or dit_path.endswith(".pth")) and \ + (dit_folder.endswith(".safetensors") or dit_folder.endswith(".pt"))): + params_swapped = True + print("WARNING: Parameters appear to be swapped in extend workflow. 
Applying correction...") + + # Correct the parameters based on the observed swap + actual_model_filename = dit_folder # Original dit_folder was the filename + actual_vae_path = dit_path # Original dit_path was the VAE path + actual_t5_path = vae_path # Original vae_path was the T5 path + actual_clip_path = t5_path # Original t5_path was the CLIP path + + # Assign corrected values back to expected variable names for the rest of the function + dit_path = actual_model_filename + vae_path = actual_vae_path + t5_path = actual_t5_path + clip_path = actual_clip_path + dit_folder = "wan" # Assume default 'wan' folder if swapped + + print(f" Corrected dit_folder: {dit_folder}") + print(f" Corrected dit_path (model filename): {dit_path}") + print(f" Corrected vae_path: {vae_path}") + print(f" Corrected t5_path: {t5_path}") + print(f" Corrected clip_path: {clip_path}") + + # Construct the full model path using the potentially corrected dit_folder and dit_path + actual_model_path = os.path.join(dit_folder, dit_path) if not os.path.isabs(dit_path) else dit_path + print(f" Using actual_model_path for --dit: {actual_model_path}") + # --- END CRITICAL FIX --- + + # Prepare environment + env = os.environ.copy() + env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") + env["PYTHONIOENCODING"] = "utf-8" + + # Clear CUDA cache + clear_cuda_cache() + + # Validate and fix parameters + # Fix output_type - must be one of: video, images, latent, both + valid_output_types = ["video", "images", "latent", "both"] + actual_output_type = "video" if output_type not in valid_output_types else output_type + + # Fix sample_solver - must be one of: unipc, dpm++, vanilla + valid_sample_solvers = ["unipc", "dpm++", "vanilla"] + actual_sample_solver = "unipc" if sample_solver not in valid_sample_solvers else sample_solver + + # Fix attn_mode - must be one of: sdpa, flash, sageattn, xformers, torch + valid_attn_modes = ["sdpa", "flash", "sageattn", "xformers", "torch"] + actual_attn_mode = "sdpa" if attn_mode not in valid_attn_modes else attn_mode + + # Fix block_swap - must be an integer + try: + actual_block_swap = int(block_swap) + except (ValueError, TypeError): + actual_block_swap = 0 + + # Build command array with explicit string conversions for EVERY parameter + command = [ + sys.executable, + "wan_generate_video.py", + "--task", str(task), + "--prompt", str(prompt), + "--video_size", str(height), str(width), + "--video_length", str(video_length), + "--fps", str(fps), + "--infer_steps", str(infer_steps), + "--save_path", str(save_path), + "--seed", str(current_seed), + "--flow_shift", str(flow_shift), + "--guidance_scale", str(guidance_scale), + "--output_type", actual_output_type, + "--sample_solver", actual_sample_solver, + "--attn_mode", actual_attn_mode, + "--blocks_to_swap", str(actual_block_swap), + # Use the corrected model path and other paths + "--dit", str(actual_model_path), # <<< Use corrected full model path + "--vae", str(vae_path), # <<< Use potentially corrected vae_path + "--t5", str(t5_path) # <<< Use potentially corrected t5_path + ] + + # Add image path and clip model path if needed + if input_image: + command.extend(["--image_path", str(input_image)]) + # Use the potentially corrected clip_path + if clip_path and clip_path != "outputs" and "output" not in clip_path: + command.extend(["--clip", str(clip_path)]) # <<< Use potentially corrected clip_path + + # Add negative prompt + if negative_prompt: + command.extend(["--negative_prompt", str(negative_prompt)]) + + # Handle 
boolean flags - keep original values + if fp8: + command.append("--fp8") + + if fp8_scaled: + command.append("--fp8_scaled") + + if fp8_t5: + command.append("--fp8_t5") + + # Add SLG parameters + try: + # Ensure slg_layers is treated as a string before splitting + slg_layers_str = str(slg_layers) if slg_layers is not None else "" + if slg_layers_str and slg_layers_str.strip() and slg_layers_str.lower() != "none": + slg_list = [] + for layer in slg_layers_str.split(","): + layer = layer.strip() + if layer.isdigit(): # Only add if it's a valid integer + slg_list.append(int(layer)) + if slg_list: # Only add if we have valid layers + command.extend(["--slg_layers", ",".join(map(str, slg_list))]) + + # Only add slg_start and slg_end if we have valid slg_layers + if slg_start is not None: + try: + slg_start_float = float(slg_start) + if slg_start_float >= 0: + command.extend(["--slg_start", str(slg_start_float)]) + except (ValueError, TypeError): pass # Ignore if conversion fails + if slg_end is not None: + try: + slg_end_float = float(slg_end) + if slg_end_float <= 1.0: + command.extend(["--slg_end", str(slg_end_float)]) + except (ValueError, TypeError): pass # Ignore if conversion fails + except Exception as e: # Catch potential errors during processing + print(f"Warning: Error processing SLG parameters: {e}") + pass + + # Handle LoRA weights and multipliers + valid_loras = [] + if lora_folder and isinstance(lora_folder, str): + for weight, mult in zip([lora1, lora2, lora3, lora4], + [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier]): + # Skip None or empty values + if not weight or str(weight).lower() == "none": + continue + + # Construct path and check existence + full_path = os.path.join(str(lora_folder), str(weight)) + if not os.path.exists(full_path): + print(f"LoRA file not found: {full_path}") + continue + + # Add valid LoRA + valid_loras.append((full_path, str(mult))) + + if valid_loras: + weights = [w for w, _ in valid_loras] + multipliers = [m for _, m in valid_loras] + command.extend(["--lora_weight"] + weights) + command.extend(["--lora_multiplier"] + multipliers) + + # Final conversion to ensure all elements are strings + command_str = [str(item) for item in command] + + print(f"Running Command (wanx_extend_video_wrapper): {' '.join(command_str)}") + + # Process execution + p = subprocess.Popen( + command_str, # Use stringified command + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=env, + text=True, + encoding='utf-8', + errors='replace', + bufsize=1 + ) + + videos = [] # Store the generated (non-extended) video first + + # Process stdout in real time + while True: + if stop_event.is_set(): + p.terminate() + p.wait() + yield [], "", "Generation stopped by user." 
+ return + + line = p.stdout.readline() + if not line: + if p.poll() is not None: + break + continue + + print(line, end='') + if '|' in line and '%' in line and '[' in line and ']' in line: + # Yield empty list during processing, actual video is collected later + yield [], f"Processing (seed: {current_seed})", line.strip() + + p.stdout.close() + return_code = p.wait() # Get return code + + # Clean CUDA cache and wait + clear_cuda_cache() + time.sleep(0.5) + + # Check return code + if return_code != 0: + print(f"❌ Error: wan_generate_video.py exited with code {return_code}") + yield [], f"Failed (seed: {current_seed})", f"Subprocess failed with code {return_code}" + return + + # Find the *newly generated* video first + generated_video_path = None + save_path_abs = os.path.abspath(save_path) + if os.path.exists(save_path_abs): + # Find the most recent mp4 containing the seed + all_mp4_files = glob.glob(os.path.join(save_path_abs, f"*_{current_seed}*.mp4")) + if all_mp4_files: + generated_video_path = max(all_mp4_files, key=os.path.getmtime) + print(f"Found newly generated video: {generated_video_path}") + + # Add metadata to the generated video before potential concatenation + parameters = { + "prompt": prompt, "negative_prompt": negative_prompt, "input_image": input_image, + "width": width, "height": height, "video_length": video_length, "fps": fps, + "infer_steps": infer_steps, "flow_shift": flow_shift, "guidance_scale": guidance_scale, + "seed": current_seed, "task": task, "dit_path": actual_model_path, # Store the actual path used + "vae_path": vae_path, "t5_path": t5_path, "clip_path": clip_path, + "save_path": save_path, "output_type": actual_output_type, "sample_solver": actual_sample_solver, + "exclude_single_blocks": exclude_single_blocks, "attn_mode": actual_attn_mode, + "block_swap": actual_block_swap, "fp8": fp8, "fp8_scaled": fp8_scaled, "fp8_t5": fp8_t5, + "lora_weights": [lora1, lora2, lora3, lora4], + "lora_multipliers": [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier], + "slg_layers": slg_layers, "slg_start": slg_start, "slg_end": slg_end, + "is_extension_source": True # Flag this as the source for an extension + } + add_metadata_to_video(generated_video_path, parameters) + # videos.append((str(generated_video_path), f"Generated segment (Seed: {current_seed})")) # Optionally yield segment + else: + print(f"Could not find generated video segment for seed {current_seed} in {save_path_abs}") + + # Stop here if no new video segment was generated + if not generated_video_path: + yield [], f"Failed (seed: {current_seed})", "Could not find generated video segment." 
+ return + + # Now concatenate with base video if we have the new segment and a base_video_path + if generated_video_path and base_video_path and os.path.exists(base_video_path): + try: + print(f"Extending base video: {base_video_path}") + + # Create unique output filename for the *extended* video + timestamp = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S") + output_filename = f"extended_{timestamp}_seed{current_seed}_{Path(base_video_path).stem}.mp4" + output_path = os.path.join(save_path_abs, output_filename) + + # Create a temporary file list for ffmpeg concatenation + list_file = os.path.join(save_path_abs, f"temp_concat_list_{current_seed}.txt") + with open(list_file, "w") as f: + f.write(f"file '{os.path.abspath(base_video_path)}'\n") + f.write(f"file '{os.path.abspath(generated_video_path)}'\n") # Use the newly generated segment + + print(f"Concatenating: {base_video_path} + {generated_video_path} -> {output_path}") + + # Run ffmpeg concatenation command + concat_command = [ + "ffmpeg", + "-f", "concat", + "-safe", "0", # Allow relative paths if needed, but we use absolute + "-i", list_file, + "-c", "copy", # Fast concatenation without re-encoding + "-y", # Overwrite output if exists + output_path + ] + + # Convert all command parts to strings + concat_command_str = [str(item) for item in concat_command] + + print(f"Running FFmpeg command: {' '.join(concat_command_str)}") + concat_result = subprocess.run(concat_command_str, check=False, capture_output=True, text=True) # Don't check=True initially + + # Clean up temporary list file + if os.path.exists(list_file): + try: + os.remove(list_file) + except OSError as e: + print(f"Warning: Could not remove temp list file {list_file}: {e}") + + + # Check if concatenation was successful + if concat_result.returncode == 0 and os.path.exists(output_path): + # Optionally, add metadata to the *extended* video as well + extended_parameters = parameters.copy() + extended_parameters["is_extension_source"] = False + extended_parameters["base_video"] = os.path.basename(base_video_path) + add_metadata_to_video(output_path, extended_parameters) + + extended_video_gallery_item = [(output_path, f"Extended (Seed: {current_seed})")] + print(f"✅ Successfully created extended video: {output_path}") + yield extended_video_gallery_item, "Extended video created successfully", "" + return # Success! + else: + print(f"❌ Failed to create extended video at {output_path}") + print(f"FFmpeg stderr: {concat_result.stderr}") + # Yield the generated segment if concatenation failed + yield [(generated_video_path, f"Generated segment (Seed: {current_seed})")], "Generated segment (extension failed)", f"FFmpeg failed: {concat_result.stderr[:200]}..." 
+ return + + except Exception as e: + print(f"❌ Error during concatenation: {str(e)}") + # Yield the generated segment if concatenation failed + yield [(generated_video_path, f"Generated segment (Seed: {current_seed})")], "Generated segment (extension error)", f"Error: {str(e)}" + return + + # If we got here, base_video_path was likely None or didn't exist, but generation succeeded + yield [(generated_video_path, f"Generated segment (Seed: {current_seed})")], "Generated segment (no base video provided)", "" + +def wanx_v2v_generate_video( + prompt, + negative_prompt, + input_video, + width, + height, + video_length, + fps, + infer_steps, + flow_shift, + guidance_scale, + strength, + seed, + task, + dit_folder, + dit_path, + vae_path, + t5_path, + save_path, + output_type, + sample_solver, + exclude_single_blocks, + attn_mode, + block_swap, + fp8, + fp8_scaled, + fp8_t5, + lora_folder, + slg_layers, + slg_start, + slg_end, + lora1="None", + lora2="None", + lora3="None", + lora4="None", + lora1_multiplier=1.0, + lora2_multiplier=1.0, + lora3_multiplier=1.0, + lora4_multiplier=1.0, + enable_cfg_skip=False, + cfg_skip_mode="none", + cfg_apply_ratio=0.7, +) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: + """Generate video with WanX model in video-to-video mode""" + global stop_event + + # Convert values safely to float or None + try: + slg_start_float = float(slg_start) if slg_start is not None and str(slg_start).lower() != "none" else None + except (ValueError, TypeError): + slg_start_float = None + print(f"Warning: Could not convert slg_start '{slg_start}' to float") + + try: + slg_end_float = float(slg_end) if slg_end is not None and str(slg_end).lower() != "none" else None + except (ValueError, TypeError): + slg_end_float = None + print(f"Warning: Could not convert slg_end '{slg_end}' to float") + + print(f"slg_start_float: {slg_start_float}, slg_end_float: {slg_end_float}") + + if stop_event.is_set(): + yield [], "", "" + return + + # Check if we need input video (required for v2v) + if not input_video: + yield [], "Error: No input video provided", "Please provide an input video for video-to-video generation" + return + + if seed == -1: + current_seed = random.randint(0, 2**32 - 1) + else: + current_seed = seed + + # Prepare environment + env = os.environ.copy() + env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") + env["PYTHONIOENCODING"] = "utf-8" + + clear_cuda_cache() + + # Construct full dit_path including folder - this is the fix + full_dit_path = os.path.join(dit_folder, dit_path) if not os.path.isabs(dit_path) else dit_path + + command = [ + sys.executable, + "wan_generate_video.py", + "--task", task, + "--prompt", prompt, + "--video_size", str(height), str(width), + "--video_length", str(video_length), + "--fps", str(fps), + "--infer_steps", str(infer_steps), + "--save_path", save_path, + "--seed", str(current_seed), + "--flow_shift", str(flow_shift), + "--guidance_scale", str(guidance_scale), + "--output_type", output_type, + "--attn_mode", attn_mode, + "--blocks_to_swap", str(block_swap), + "--dit", full_dit_path, # Use full_dit_path instead of dit_path + "--vae", vae_path, + "--t5", t5_path, + "--sample_solver", sample_solver, + "--video_path", input_video, # This is the key for v2v mode + "--strength", str(strength) # Strength parameter for v2v + ] + if enable_cfg_skip and cfg_skip_mode != "none": + command.extend([ + "--cfg_skip_mode", cfg_skip_mode, + "--cfg_apply_ratio", str(cfg_apply_ratio) + ]) + # Handle SLG parameters + 
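+    # slg_layers is passed as a comma-separated string of block indices (e.g. "9,10");
+    # slg_start / slg_end appear to be 0.0-1.0 fractions of the sampling schedule, judging by
+    # the range checks below, and are forwarded to wan_generate_video.py via the --slg_* flags.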
if slg_layers and str(slg_layers).strip() and slg_layers.lower() != "none": + try: + # Parse SLG layers + layer_list = [int(x) for x in str(slg_layers).split(",")] + if layer_list: # Only proceed if we have valid layer values + command.extend(["--slg_layers", ",".join(map(str, layer_list))]) + + # Only add slg_start and slg_end if we have valid slg_layers + try: + if slg_start_float is not None and slg_start_float >= 0: + command.extend(["--slg_start", str(slg_start_float)]) + if slg_end_float is not None and slg_end_float <= 1.0: + command.extend(["--slg_end", str(slg_end_float)]) + except ValueError as e: + print(f"Invalid SLG timing values: {str(e)}") + except ValueError as e: + print(f"Invalid SLG layers format: {slg_layers} - {str(e)}") + + if negative_prompt: + command.extend(["--negative_prompt", negative_prompt]) + + if fp8: + command.append("--fp8") + + if fp8_scaled: + command.append("--fp8_scaled") + + if fp8_t5: + command.append("--fp8_t5") + + if exclude_single_blocks: + command.append("--exclude_single_blocks") + + # Handle LoRA weights and multipliers + lora_weights = [lora1, lora2, lora3, lora4] + lora_multipliers = [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier] + + valid_loras = [] + for weight, mult in zip(lora_weights, lora_multipliers): + if weight and weight != "None": + full_path = os.path.join(lora_folder, weight) + if not os.path.exists(full_path): + print(f"LoRA file not found: {full_path}") + continue + valid_loras.append((full_path, mult)) + + if valid_loras: + weights = [w for w, _ in valid_loras] + multipliers = [str(m) for _, m in valid_loras] + command.extend(["--lora_weight"] + weights) + command.extend(["--lora_multiplier"] + multipliers) + + print(f"Running: {' '.join(command)}") + + p = subprocess.Popen( + command, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=env, + text=True, + encoding='utf-8', + errors='replace', + bufsize=1 + ) + + videos = [] + + while True: + if stop_event.is_set(): + p.terminate() + p.wait() + yield [], "", "Generation stopped by user." 
+ return + + line = p.stdout.readline() + if not line: + if p.poll() is not None: + break + continue + + print(line, end='') + if '|' in line and '%' in line and '[' in line and ']' in line: + yield videos.copy(), f"Processing (seed: {current_seed})", line.strip() + + p.stdout.close() + p.wait() + + clear_cuda_cache() + time.sleep(0.5) + + # Collect generated video + save_path_abs = os.path.abspath(save_path) + if os.path.exists(save_path_abs): + all_videos = sorted( + [f for f in os.listdir(save_path_abs) if f.endswith('.mp4')], + key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)), + reverse=True + ) + matching_videos = [v for v in all_videos if f"_{current_seed}" in v] + if matching_videos: + video_path = os.path.join(save_path_abs, matching_videos[0]) + + # Collect parameters for metadata + parameters = { + "prompt": prompt, + "width": width, + "height": height, + "video_length": video_length, + "fps": fps, + "infer_steps": infer_steps, + "seed": current_seed, + "task": task, + "flow_shift": flow_shift, + "guidance_scale": guidance_scale, + "output_type": output_type, + "attn_mode": attn_mode, + "block_swap": block_swap, + "input_video": input_video, + "strength": strength, + "lora_weights": [lora1, lora2, lora3, lora4], + "lora_multipliers": [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier], + "dit_path": full_dit_path, # Store the full path in metadata + "vae_path": vae_path, + "t5_path": t5_path, + "negative_prompt": negative_prompt if negative_prompt else None, + "sample_solver": sample_solver + } + + add_metadata_to_video(video_path, parameters) + videos.append((str(video_path), f"Seed: {current_seed}")) + + yield videos, f"Completed (seed: {current_seed})", "" + +def wanx_v2v_batch_handler( + prompt, + negative_prompt, + input_video, + width, + height, + video_length, + fps, + infer_steps, + flow_shift, + guidance_scale, + strength, + seed, + batch_size, + task, + dit_folder, # folder path + dit_path, # model filename + vae_path, + t5_path, + save_path, + output_type, + sample_solver, + exclude_single_blocks, + attn_mode, + block_swap, + fp8, + fp8_scaled, + fp8_t5, + lora_folder, + slg_layers: str, + slg_start: Optional[str], + slg_end: Optional[str], + enable_cfg_skip: bool, + cfg_skip_mode: str, + cfg_apply_ratio: float, + *lora_params +): + """Handle batch generation for WanX v2v""" + global stop_event + stop_event.clear() + + # Extract LoRA parameters + num_lora_weights = 4 + lora_weights = lora_params[:num_lora_weights] + lora_multipliers = lora_params[num_lora_weights:num_lora_weights*2] + + all_videos = [] + progress_text = "Starting generation..." 
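+    # lora_params arrives as one flat tuple from the Gradio click handler: the first four
+    # entries are LoRA filenames and the next four their multipliers (see the slicing above).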
+ yield [], "Preparing...", progress_text + + # Process each item in the batch + for i in range(batch_size): + if stop_event.is_set(): + yield all_videos, "Generation stopped by user", "" + return + + # Calculate seed for this batch item + current_seed = seed + if seed == -1: + current_seed = random.randint(0, 2**32 - 1) + elif batch_size > 1: + current_seed = seed + i + + batch_text = f"Generating video {i + 1} of {batch_size}" + yield all_videos.copy(), batch_text, progress_text + + # Generate a single video + for videos, status, progress in wanx_v2v_generate_video( + prompt, + negative_prompt, + input_video, + width, + height, + video_length, + fps, + infer_steps, + flow_shift, + guidance_scale, + strength, + current_seed, + task, + dit_folder, # Pass folder path + dit_path, # Pass model filename + vae_path, + t5_path, + save_path, + output_type, + sample_solver, + exclude_single_blocks, + attn_mode, + block_swap, + fp8, + fp8_scaled, + fp8_t5, + lora_folder, + slg_layers, + slg_start, + slg_end, + *lora_weights, + *lora_multipliers, + enable_cfg_skip, + cfg_skip_mode, + cfg_apply_ratio, + ): + if videos: + all_videos.extend(videos) + yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress + + # Clear CUDA cache between generations + clear_cuda_cache() + time.sleep(0.5) + + yield all_videos, "Batch complete", "" + +def update_wanx_v2v_dimensions(video): + """Update dimensions from uploaded video""" + if video is None: + return "", gr.update(value=832), gr.update(value=480) + + cap = cv2.VideoCapture(video) + if not cap.isOpened(): + return "Error opening video", gr.update(), gr.update() + + w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + cap.release() + + # Make dimensions divisible by 32 + w = (w // 32) * 32 + h = (h // 32) * 32 + + return f"{w}x{h}", w, h + +def send_wanx_v2v_to_hunyuan_v2v( + gallery: list, + prompt: str, + selected_index: int, + width: int, + height: int, + video_length: int, + fps: int, + infer_steps: int, + seed: int, + flow_shift: float, + guidance_scale: float, + negative_prompt: str +) -> Tuple: + """Send the selected WanX v2v video to Hunyuan v2v tab""" + if gallery is None or not gallery: + return (None, "", width, height, video_length, fps, infer_steps, seed, + flow_shift, guidance_scale, negative_prompt) + + # If no selection made but we have videos, use the first one + if selected_index is None and len(gallery) > 0: + selected_index = 0 + + if selected_index is None or selected_index >= len(gallery): + return (None, "", width, height, video_length, fps, infer_steps, seed, + flow_shift, guidance_scale, negative_prompt) + + selected_item = gallery[selected_index] + + # Handle different gallery item formats + if isinstance(selected_item, tuple): + video_path = selected_item[0] + elif isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + else: + video_path = selected_item + + # Clean up path for Video component + if isinstance(video_path, tuple): + video_path = video_path[0] + + # Make sure it's a string + video_path = str(video_path) + + return (video_path, prompt, width, height, video_length, fps, infer_steps, seed, + flow_shift, guidance_scale, negative_prompt) + +def handle_wanx_v2v_gallery_select(evt: gr.SelectData) -> int: + """Track selected index when gallery item is clicked""" + return evt.index + +def variance_of_laplacian(image): + """ + Compute the variance of the Laplacian of the image. + Higher variance indicates a sharper image. 
+ """ + return cv2.Laplacian(image, cv2.CV_64F).var() + +def extract_sharpest_frame(video_path, frames_to_check=30): + """ + Extract the sharpest frame from the last N frames of the video. + + Args: + video_path (str): Path to the video file + frames_to_check (int): Number of frames from the end to check + + Returns: + tuple: (temp_image_path, frame_number, sharpness_score) + """ + print(f"\n=== Extracting sharpest frame from the last {frames_to_check} frames ===") + print(f"Input video path: {video_path}") + + if not video_path or not os.path.exists(video_path): + print("❌ Error: Video file does not exist") + return None, None, None + + try: + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + print("❌ Error: Failed to open video file") + return None, None, None + + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + fps = cap.get(cv2.CAP_PROP_FPS) + print(f"Total frames detected: {total_frames}, FPS: {fps:.2f}") + + if total_frames < 1: + print("❌ Error: Video contains 0 frames") + return None, None, None + + # Determine how many frames to check (the last N frames) + if frames_to_check > total_frames: + frames_to_check = total_frames + start_frame = 0 + else: + start_frame = total_frames - frames_to_check + + print(f"Checking frames {start_frame} to {total_frames-1}") + + # Find the sharpest frame + sharpest_frame = None + max_sharpness = -1 + sharpest_frame_number = -1 + + # Set starting position + cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame) + + # Process frames with a progress bar + with tqdm(total=frames_to_check, desc="Finding sharpest frame") as pbar: + frame_idx = start_frame + while frame_idx < total_frames: + ret, frame = cap.read() + if not ret: + break + + # Convert to grayscale and calculate sharpness + gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + sharpness = variance_of_laplacian(gray) + + # Update if this is the sharpest frame so far + if sharpness > max_sharpness: + max_sharpness = sharpness + sharpest_frame = frame.copy() + sharpest_frame_number = frame_idx + + frame_idx += 1 + pbar.update(1) + + cap.release() + + if sharpest_frame is None: + print("❌ Error: Failed to find a sharp frame") + return None, None, None + + # Prepare output path + temp_dir = os.path.abspath("temp_frames") + os.makedirs(temp_dir, exist_ok=True) + temp_path = os.path.join(temp_dir, f"sharpest_frame_{os.path.basename(video_path)}.png") + print(f"Saving frame to: {temp_path}") + + # Write and verify + if not cv2.imwrite(temp_path, sharpest_frame): + print("❌ Error: Failed to write frame to file") + return None, None, None + + if not os.path.exists(temp_path): + print("❌ Error: Output file not created") + return None, None, None + + # Calculate frame time in seconds + frame_time = sharpest_frame_number / fps + + print(f"✅ Extracted sharpest frame: {sharpest_frame_number} (at {frame_time:.2f}s) with sharpness {max_sharpness:.2f}") + return temp_path, sharpest_frame_number, max_sharpness + + except Exception as e: + print(f"❌ Unexpected error: {str(e)}") + return None, None, None + finally: + if 'cap' in locals(): + cap.release() + +def trim_video_to_frame(video_path, frame_number, output_dir="outputs"): + """ + Trim video up to the specified frame and save as a new video. 
+ + Args: + video_path (str): Path to the video file + frame_number (int): Frame number to trim to + output_dir (str): Directory to save the trimmed video + + Returns: + str: Path to the trimmed video file + """ + print(f"\n=== Trimming video to frame {frame_number} ===") + if not video_path or not os.path.exists(video_path): + print("❌ Error: Video file does not exist") + return None + + try: + # Get video information + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + print("❌ Error: Failed to open video file") + return None + + fps = cap.get(cv2.CAP_PROP_FPS) + cap.release() + + # Calculate time in seconds + time_seconds = frame_number / fps + + # Create output directory if it doesn't exist + os.makedirs(output_dir, exist_ok=True) + + # Generate output filename + timestamp = f"{int(time_seconds)}s" + base_name = Path(video_path).stem + output_file = os.path.join(output_dir, f"{base_name}_trimmed_to_{timestamp}.mp4") + + # Use ffmpeg to trim the video + ( + ffmpeg + .input(video_path) + .output(output_file, to=time_seconds, c="copy") + .global_args('-y') # Overwrite output files + .run(quiet=True) + ) + + if not os.path.exists(output_file): + print("❌ Error: Failed to create trimmed video") + return None + + print(f"✅ Successfully trimmed video to {time_seconds:.2f}s: {output_file}") + return output_file + + except Exception as e: + print(f"❌ Error trimming video: {str(e)}") + return None + +def send_sharpest_frame_handler(gallery, selected_idx, frames_to_check=30): + """ + Extract the sharpest frame from the last N frames of the selected video + + Args: + gallery: Gradio gallery component with videos + selected_idx: Index of the selected video + frames_to_check: Number of frames from the end to check + + Returns: + tuple: (image_path, video_path, frame_number, sharpness) + """ + if gallery is None or not gallery: + return None, None, None, "No videos in gallery" + + if selected_idx is None and len(gallery) == 1: + selected_idx = 0 + + if selected_idx is None or selected_idx >= len(gallery): + return None, None, None, "No video selected" + + # Get the video path + item = gallery[selected_idx] + if isinstance(item, tuple): + video_path = item[0] + elif isinstance(item, dict): + video_path = item.get('name') or item.get('data') + else: + video_path = str(item) + + # Extract the sharpest frame + image_path, frame_number, sharpness = extract_sharpest_frame(video_path, frames_to_check) + + if image_path is None: + return None, None, None, "Failed to extract sharpest frame" + + return image_path, video_path, frame_number, f"Extracted frame {frame_number} with sharpness {sharpness:.2f}" + +def trim_and_prepare_for_extension(video_path, frame_number, save_path="outputs"): + """ + Trim the video to the specified frame and prepare for extension. 
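+
+    Typically called after send_sharpest_frame_handler has picked a frame; the trimmed clip
+    is then used as the base video for the extend workflow.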
+ + Args: + video_path: Path to the video file + frame_number: Frame number to trim to + save_path: Directory to save the trimmed video + + Returns: + tuple: (trimmed_video_path, status_message) + """ + if not video_path or not os.path.exists(video_path): + return None, "No video selected or video file does not exist" + + if frame_number is None: + return None, "No frame number provided, please extract sharpest frame first" + + # Trim the video + trimmed_video = trim_video_to_frame(video_path, frame_number, save_path) + + if trimmed_video is None: + return None, "Failed to trim video" + + return trimmed_video, f"Video trimmed to frame {frame_number} and ready for extension" + +def send_last_frame_handler(gallery, selected_idx): + """Handle sending last frame to input with better error handling""" + if gallery is None or not gallery: + return None, None + + if selected_idx is None and len(gallery) == 1: + selected_idx = 0 + + if selected_idx is None or selected_idx >= len(gallery): + return None, None + + # Get the frame and video path + frame = handle_last_frame_transfer(gallery, selected_idx) + video_path = None + + if selected_idx < len(gallery): + item = gallery[selected_idx] + video_path = parse_video_path(item) + + return frame, video_path + +def extract_last_frame(video_path: str) -> Optional[str]: + """Extract last frame from video and return temporary image path with error handling""" + print(f"\n=== Starting frame extraction ===") + print(f"Input video path: {video_path}") + + if not video_path or not os.path.exists(video_path): + print("❌ Error: Video file does not exist") + return None + + try: + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + print("❌ Error: Failed to open video file") + return None + + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + print(f"Total frames detected: {total_frames}") + + if total_frames < 1: + print("❌ Error: Video contains 0 frames") + return None + + # Extract last frame + cap.set(cv2.CAP_PROP_POS_FRAMES, total_frames - 1) + success, frame = cap.read() + + if not success or frame is None: + print("❌ Error: Failed to read last frame") + return None + + # Prepare output path + temp_dir = os.path.abspath("temp_frames") + os.makedirs(temp_dir, exist_ok=True) + temp_path = os.path.join(temp_dir, f"last_frame_{os.path.basename(video_path)}.png") + print(f"Saving frame to: {temp_path}") + + # Write and verify + if not cv2.imwrite(temp_path, frame): + print("❌ Error: Failed to write frame to file") + return None + + if not os.path.exists(temp_path): + print("❌ Error: Output file not created") + return None + + print("✅ Frame extraction successful") + return temp_path + + except Exception as e: + print(f"❌ Unexpected error: {str(e)}") + return None + finally: + if 'cap' in locals(): + cap.release() + +def handle_last_frame_transfer(gallery: list, selected_idx: int) -> Optional[str]: + """Improved frame transfer with video input validation""" + try: + if gallery is None or not gallery: + raise ValueError("No videos generated yet") + + if selected_idx is None: + # Auto-select last generated video if batch_size=1 + if len(gallery) == 1: + selected_idx = 0 + else: + raise ValueError("Please select a video first") + + if selected_idx >= len(gallery): + raise ValueError("Invalid selection index") + + item = gallery[selected_idx] + + # Video file existence check + video_path = parse_video_path(item) + if not os.path.exists(video_path): + raise FileNotFoundError(f"Video file missing: {video_path}") + + return extract_last_frame(video_path) 
+ + except Exception as e: + print(f"Frame transfer failed: {str(e)}") + return None + +def parse_video_path(item) -> str: + """Parse different gallery item formats""" + if isinstance(item, tuple): + return item[0] + elif isinstance(item, dict): + return item.get('name') or item.get('data') + return str(item) + +def get_random_image_from_folder(folder_path): + """Get a random image from the specified folder""" + if not os.path.isdir(folder_path): + return None, f"Error: {folder_path} is not a valid directory" + + # Get all image files in the folder + image_files = [] + for ext in ('*.jpg', '*.jpeg', '*.png', '*.bmp', '*.webp'): + image_files.extend(glob.glob(os.path.join(folder_path, ext))) + for ext in ('*.JPG', '*.JPEG', '*.PNG', '*.BMP', '*.WEBP'): + image_files.extend(glob.glob(os.path.join(folder_path, ext))) + + if not image_files: + return None, f"Error: No image files found in {folder_path}" + + # Select a random image + random_image = random.choice(image_files) + return random_image, f"Selected: {os.path.basename(random_image)}" + +def resize_image_keeping_aspect_ratio(image_path, max_width, max_height): + """Resize image keeping aspect ratio and ensuring dimensions are divisible by 16""" + try: + img = Image.open(image_path) + width, height = img.size + + # Calculate aspect ratio + aspect_ratio = width / height + + # Calculate new dimensions while maintaining aspect ratio + if width > height: + new_width = min(max_width, width) + new_height = int(new_width / aspect_ratio) + else: + new_height = min(max_height, height) + new_width = int(new_height * aspect_ratio) + + # Make dimensions divisible by 16 + new_width = math.floor(new_width / 16) * 16 + new_height = math.floor(new_height / 16) * 16 + + # Ensure minimum size + new_width = max(16, new_width) + new_height = max(16, new_height) + + # Resize image + resized_img = img.resize((new_width, new_height), Image.LANCZOS) + + # Save to temporary file + temp_path = f"temp_resized_{os.path.basename(image_path)}" + resized_img.save(temp_path) + + return temp_path, (new_width, new_height) + except Exception as e: + return None, f"Error: {str(e)}" +# Function to process a batch of images from a folder +def batch_handler( + use_random, + prompt, negative_prompt, + width, height, + video_length, fps, infer_steps, + seed, flow_shift, guidance_scale, embedded_cfg_scale, + batch_size, input_folder_path, + dit_folder, model, vae, te1, te2, save_path, output_type, attn_mode, + block_swap, exclude_single_blocks, use_split_attn, use_fp8, split_uncond, + lora_folder, *lora_params +): + """Handle both folder-based batch processing and regular batch processing""" + global stop_event + + # Check if this is a SkyReels model that needs special handling + is_skyreels = "skyreels" in model.lower() + is_skyreels_i2v = is_skyreels and "i2v" in model.lower() + + if use_random: + # Random image from folder mode + stop_event.clear() + + all_videos = [] + progress_text = "Starting generation..." 
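+        # Random-image mode: each batch item picks a random image from input_folder_path,
+        # resizes it to fit the requested size (dimensions snapped down to multiples of 16),
+        # and then runs a single generation with that image.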
+ yield [], "Preparing...", progress_text + + for i in range(batch_size): + if stop_event.is_set(): + break + + batch_text = f"Generating video {i + 1} of {batch_size}" + yield all_videos.copy(), batch_text, progress_text + + # Get random image from folder + random_image, status = get_random_image_from_folder(input_folder_path) + if random_image is None: + yield all_videos, f"Error in batch {i+1}: {status}", "" + continue + + # Resize image + resized_image, size_info = resize_image_keeping_aspect_ratio(random_image, width, height) + if resized_image is None: + yield all_videos, f"Error resizing image in batch {i+1}: {size_info}", "" + continue + + # If we have dimensions, update them + local_width, local_height = width, height + if isinstance(size_info, tuple): + local_width, local_height = size_info + progress_text = f"Using image: {os.path.basename(random_image)} - Resized to {local_width}x{local_height}" + else: + progress_text = f"Using image: {os.path.basename(random_image)}" + + yield all_videos.copy(), batch_text, progress_text + + # Calculate seed for this batch item + current_seed = seed + if seed == -1: + current_seed = random.randint(0, 2**32 - 1) + elif batch_size > 1: + current_seed = seed + i + + # Process the image + # For SkyReels models, we need to create a command with dit_in_channels=32 + if is_skyreels_i2v: + env = os.environ.copy() + env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") + env["PYTHONIOENCODING"] = "utf-8" + + model_path = os.path.join(dit_folder, model) if not os.path.isabs(model) else model + + # Extract parameters from lora_params + num_lora_weights = 4 + lora_weights = lora_params[:num_lora_weights] + lora_multipliers = lora_params[num_lora_weights:num_lora_weights*2] + + cmd = [ + sys.executable, + "hv_generate_video.py", + "--dit", model_path, + "--vae", vae, + "--text_encoder1", te1, + "--text_encoder2", te2, + "--prompt", prompt, + "--video_size", str(local_height), str(local_width), + "--video_length", str(video_length), + "--fps", str(fps), + "--infer_steps", str(infer_steps), + "--save_path", save_path, + "--seed", str(current_seed), + "--flow_shift", str(flow_shift), + "--embedded_cfg_scale", str(embedded_cfg_scale), + "--output_type", output_type, + "--attn_mode", attn_mode, + "--blocks_to_swap", str(block_swap), + "--fp8_llm", + "--vae_chunk_size", "32", + "--vae_spatial_tile_sample_min_size", "128", + "--dit_in_channels", "32", # This is crucial for SkyReels i2v + "--image_path", resized_image # Pass the image directly + ] + + if use_fp8: + cmd.append("--fp8") + + if split_uncond: + cmd.append("--split_uncond") + + if use_split_attn: + cmd.append("--split_attn") + + if exclude_single_blocks: + cmd.append("--exclude_single_blocks") + + if negative_prompt: + cmd.extend(["--negative_prompt", negative_prompt]) + + if guidance_scale is not None: + cmd.extend(["--guidance_scale", str(guidance_scale)]) + + # Add LoRA weights and multipliers if provided + valid_loras = [] + for weight, mult in zip(lora_weights, lora_multipliers): + if weight and weight != "None": + valid_loras.append((os.path.join(lora_folder, weight), mult)) + + if valid_loras: + weights = [weight for weight, _ in valid_loras] + multipliers = [str(mult) for _, mult in valid_loras] + cmd.extend(["--lora_weight"] + weights) + cmd.extend(["--lora_multiplier"] + multipliers) + + print(f"Running command: {' '.join(cmd)}") + + # Run the process + p = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=env, + text=True, + 
encoding='utf-8', + errors='replace', + bufsize=1 + ) + + while True: + if stop_event.is_set(): + p.terminate() + p.wait() + yield all_videos, "Generation stopped by user.", "" + return + + line = p.stdout.readline() + if not line: + if p.poll() is not None: + break + continue + + print(line, end='') + if '|' in line and '%' in line and '[' in line and ']' in line: + yield all_videos.copy(), f"Processing video {i+1} (seed: {current_seed})", line.strip() + + p.stdout.close() + p.wait() + + # Collect generated video + save_path_abs = os.path.abspath(save_path) + if os.path.exists(save_path_abs): + all_videos_files = sorted( + [f for f in os.listdir(save_path_abs) if f.endswith('.mp4')], + key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)), + reverse=True + ) + matching_videos = [v for v in all_videos_files if f"_{current_seed}" in v] + if matching_videos: + video_path = os.path.join(save_path_abs, matching_videos[0]) + all_videos.append((str(video_path), f"Seed: {current_seed}")) + else: + # For non-SkyReels models, use the regular process_single_video function + num_lora_weights = 4 + lora_weights = lora_params[:num_lora_weights] + lora_multipliers = lora_params[num_lora_weights:num_lora_weights*2] + + single_video_args = [ + prompt, local_width, local_height, 1, video_length, fps, infer_steps, + current_seed, dit_folder, model, vae, te1, te2, save_path, flow_shift, embedded_cfg_scale, + output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, + lora_folder + ] + single_video_args.extend(lora_weights) + single_video_args.extend(lora_multipliers) + single_video_args.extend([None, resized_image, None, negative_prompt, embedded_cfg_scale, split_uncond, guidance_scale, use_fp8]) + + for videos, status, progress in process_single_video(*single_video_args): + if videos: + all_videos.extend(videos) + yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress + + # Clean up temporary file + try: + if os.path.exists(resized_image): + os.remove(resized_image) + except: + pass + + # Clear CUDA cache between generations + clear_cuda_cache() + time.sleep(0.5) + + yield all_videos, "Batch complete", "" + else: + # Regular image input - this is the part we need to fix + # When a SkyReels I2V model is used, we need to use the direct command approach + # with dit_in_channels=32 explicitly specified, just like in the folder processing branch + if is_skyreels_i2v: + stop_event.clear() + + all_videos = [] + progress_text = "Starting generation..." 
+ yield [], "Preparing...", progress_text + + # Extract lora parameters + num_lora_weights = 4 + lora_weights = lora_params[:num_lora_weights] + lora_multipliers = lora_params[num_lora_weights:num_lora_weights*2] + extra_args = list(lora_params[num_lora_weights*2:]) if len(lora_params) > num_lora_weights*2 else [] + + # Print extra_args for debugging + print(f"Extra args: {extra_args}") + + # Get input image path from extra args - this is where we need to fix + # In skyreels_generate_btn.click, we're passing skyreels_input which + # should be the image path + image_path = None + if len(extra_args) > 0 and extra_args[0] is not None: + image_path = extra_args[0] + print(f"Image path found in extra_args[0]: {image_path}") + + # If we still don't have an image path, this is a problem + if not image_path: + # Let's try to debug what's happening - in the future, you can remove these + # debug prints once everything works correctly + print("No image path found in extra_args[0]") + print(f"Full lora_params: {lora_params}") + yield [], "Error: No input image provided", "An input image is required for SkyReels I2V models" + return + + for i in range(batch_size): + if stop_event.is_set(): + yield all_videos, "Generation stopped by user", "" + return + + # Calculate seed for this batch item + current_seed = seed + if seed == -1: + current_seed = random.randint(0, 2**32 - 1) + elif batch_size > 1: + current_seed = seed + i + + batch_text = f"Generating video {i + 1} of {batch_size}" + yield all_videos.copy(), batch_text, progress_text + + # Set up environment + env = os.environ.copy() + env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") + env["PYTHONIOENCODING"] = "utf-8" + + model_path = os.path.join(dit_folder, model) if not os.path.isabs(model) else model + + # Build the command with dit_in_channels=32 + cmd = [ + sys.executable, + "hv_generate_video.py", + "--dit", model_path, + "--vae", vae, + "--text_encoder1", te1, + "--text_encoder2", te2, + "--prompt", prompt, + "--video_size", str(height), str(width), + "--video_length", str(video_length), + "--fps", str(fps), + "--infer_steps", str(infer_steps), + "--save_path", save_path, + "--seed", str(current_seed), + "--flow_shift", str(flow_shift), + "--embedded_cfg_scale", str(embedded_cfg_scale), + "--output_type", output_type, + "--attn_mode", attn_mode, + "--blocks_to_swap", str(block_swap), + "--fp8_llm", + "--vae_chunk_size", "32", + "--vae_spatial_tile_sample_min_size", "128", + "--dit_in_channels", "32", # This is crucial for SkyReels i2v + "--image_path", image_path + ] + + if use_fp8: + cmd.append("--fp8") + + if split_uncond: + cmd.append("--split_uncond") + + if use_split_attn: + cmd.append("--split_attn") + + if exclude_single_blocks: + cmd.append("--exclude_single_blocks") + + if negative_prompt: + cmd.extend(["--negative_prompt", negative_prompt]) + + if guidance_scale is not None: + cmd.extend(["--guidance_scale", str(guidance_scale)]) + + # Add LoRA weights and multipliers if provided + valid_loras = [] + for weight, mult in zip(lora_weights, lora_multipliers): + if weight and weight != "None": + valid_loras.append((os.path.join(lora_folder, weight), mult)) + + if valid_loras: + weights = [weight for weight, _ in valid_loras] + multipliers = [str(mult) for _, mult in valid_loras] + cmd.extend(["--lora_weight"] + weights) + cmd.extend(["--lora_multiplier"] + multipliers) + + print(f"Running command: {' '.join(cmd)}") + + # Run the process + p = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + 
stderr=subprocess.STDOUT, + env=env, + text=True, + encoding='utf-8', + errors='replace', + bufsize=1 + ) + + while True: + if stop_event.is_set(): + p.terminate() + p.wait() + yield all_videos, "Generation stopped by user.", "" + return + + line = p.stdout.readline() + if not line: + if p.poll() is not None: + break + continue + + print(line, end='') + if '|' in line and '%' in line and '[' in line and ']' in line: + yield all_videos.copy(), f"Processing (seed: {current_seed})", line.strip() + + p.stdout.close() + p.wait() + + # Collect generated video + save_path_abs = os.path.abspath(save_path) + if os.path.exists(save_path_abs): + all_videos_files = sorted( + [f for f in os.listdir(save_path_abs) if f.endswith('.mp4')], + key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)), + reverse=True + ) + matching_videos = [v for v in all_videos_files if f"_{current_seed}" in v] + if matching_videos: + video_path = os.path.join(save_path_abs, matching_videos[0]) + all_videos.append((str(video_path), f"Seed: {current_seed}")) + + # Clear CUDA cache between generations + clear_cuda_cache() + time.sleep(0.5) + + yield all_videos, "Batch complete", "" + else: + # For regular non-SkyReels models, use the original process_batch function + regular_args = [ + prompt, width, height, batch_size, video_length, fps, infer_steps, + seed, dit_folder, model, vae, te1, te2, save_path, flow_shift, guidance_scale, + output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, + lora_folder + ] + yield from process_batch(*(regular_args + list(lora_params))) + +def get_dit_models(dit_folder: str) -> List[str]: + """Get list of available DiT models in the specified folder""" + if not os.path.exists(dit_folder): + return ["mp_rank_00_model_states.pt"] + models = [f for f in os.listdir(dit_folder) if f.endswith('.pt') or f.endswith('.safetensors')] + models.sort(key=str.lower) + return models if models else ["mp_rank_00_model_states.pt"] + +def update_dit_and_lora_dropdowns(dit_folder: str, lora_folder: str, *current_values) -> List[gr.update]: + """Update both DiT and LoRA dropdowns""" + # Get model lists + dit_models = get_dit_models(dit_folder) + lora_choices = get_lora_options(lora_folder) + + # Current values processing + dit_value = current_values[0] + if dit_value not in dit_models: + dit_value = dit_models[0] if dit_models else None + + weights = current_values[1:5] + multipliers = current_values[5:9] + + results = [gr.update(choices=dit_models, value=dit_value)] + + # Add LoRA updates + for i in range(4): + weight = weights[i] if i < len(weights) else "None" + multiplier = multipliers[i] if i < len(multipliers) else 1.0 + if weight not in lora_choices: + weight = "None" + results.extend([ + gr.update(choices=lora_choices, value=weight), + gr.update(value=multiplier) + ]) + + return results + +def extract_video_metadata(video_path: str) -> Dict: + """Extract metadata from video file using ffprobe.""" + cmd = [ + 'ffprobe', + '-v', 'quiet', + '-print_format', 'json', + '-show_format', + video_path + ] + + try: + result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True) + metadata = json.loads(result.stdout.decode('utf-8')) + if 'format' in metadata and 'tags' in metadata['format']: + comment = metadata['format']['tags'].get('comment', '{}') + return json.loads(comment) + return {} + except Exception as e: + print(f"Metadata extraction failed: {str(e)}") + return {} + +def create_parameter_transfer_map(metadata: Dict, target_tab: str) -> Dict: + """Map 
metadata parameters to Gradio components for different tabs"""
+    mapping = {
+        'common': {
+            'prompt': ('prompt', 'v2v_prompt', 'wanx_v2v_prompt'),  # Add WanX-v2v mapping
+            'width': ('width', 'v2v_width', 'wanx_v2v_width'),
+            'height': ('height', 'v2v_height', 'wanx_v2v_height'),
+            'batch_size': ('batch_size', 'v2v_batch_size', 'wanx_v2v_batch_size'),
+            'video_length': ('video_length', 'v2v_video_length', 'wanx_v2v_video_length'),
+            'fps': ('fps', 'v2v_fps', 'wanx_v2v_fps'),
+            'infer_steps': ('infer_steps', 'v2v_infer_steps', 'wanx_v2v_infer_steps'),
+            'seed': ('seed', 'v2v_seed', 'wanx_v2v_seed'),
+            'flow_shift': ('flow_shift', 'v2v_flow_shift', 'wanx_v2v_flow_shift'),
+            'guidance_scale': ('cfg_scale', 'v2v_cfg_scale', 'wanx_v2v_guidance_scale'),
+            'negative_prompt': ('negative_prompt', 'v2v_negative_prompt', 'wanx_v2v_negative_prompt'),
+            'strength': ('strength', 'v2v_strength', 'wanx_v2v_strength')
+        },
+        'lora': {
+            'lora_weights': [(f'lora{i+1}', f'v2v_lora_weights[{i}]', f'wanx_v2v_lora_weights[{i}]') for i in range(4)],
+            'lora_multipliers': [(f'lora{i+1}_multiplier', f'v2v_lora_multipliers[{i}]', f'wanx_v2v_lora_multipliers[{i}]') for i in range(4)]
+        }
+    }
+
+    results = {}
+    for param, value in metadata.items():
+        # Handle common parameters
+        if param in mapping['common']:
+            target_idx = 0 if target_tab == 't2v' else 1 if target_tab == 'v2v' else 2
+            if target_idx < len(mapping['common'][param]):
+                target = mapping['common'][param][target_idx]
+                results[target] = value
+
+        # Handle LoRA parameters
+        if param == 'lora_weights':
+            for i, weight in enumerate(value[:4]):
+                target_idx = 0 if target_tab == 't2v' else 1 if target_tab == 'v2v' else 2
+                if target_idx < len(mapping['lora']['lora_weights'][i]):
+                    target = mapping['lora']['lora_weights'][i][target_idx]
+                    results[target] = weight
+
+        if param == 'lora_multipliers':
+            for i, mult in enumerate(value[:4]):
+                target_idx = 0 if target_tab == 't2v' else 1 if target_tab == 'v2v' else 2
+                if target_idx < len(mapping['lora']['lora_multipliers'][i]):
+                    target = mapping['lora']['lora_multipliers'][i][target_idx]
+                    results[target] = float(mult)
+
+    return results
+
+def add_metadata_to_video(video_path: str, parameters: dict) -> None:
+    """Add generation parameters to video metadata using ffmpeg."""
+    # Add Fun-Control information before serializing so it is included in the
+    # embedded metadata
+    task = parameters.get("task", "")
+    if task.endswith("-FC"):
+        parameters["fun_control"] = True
+        # Store the control path in metadata if available
+        if "control_path" in parameters:
+            parameters["control_video"] = os.path.basename(parameters["control_path"])
+
+    # Convert parameters to JSON string
+    params_json = json.dumps(parameters, indent=2)
+
+    # Temporary output path
+    temp_path = video_path.replace(".mp4", "_temp.mp4")
+
+    # FFmpeg command to add metadata without re-encoding
+    cmd = [
+        'ffmpeg',
+        '-i', video_path,
+        '-metadata', f'comment={params_json}',
+        '-codec', 'copy',
+        temp_path
+    ]
+
+    try:
+        # Execute FFmpeg command
+        subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        # Replace original file with the metadata-enhanced version
+        os.replace(temp_path, video_path)
+    except subprocess.CalledProcessError as e:
+        print(f"Failed to add metadata: {e.stderr.decode()}")
+        if os.path.exists(temp_path):
+            os.remove(temp_path)
+    except Exception as e:
+        print(f"Error: {str(e)}")
+
+def count_prompt_tokens(prompt: str) -> int:
+    enc = tiktoken.get_encoding("cl100k_base")
+    tokens = enc.encode(prompt)
+    return 
len(tokens) + + +def get_lora_options(lora_folder: str = "lora") -> List[str]: + if not os.path.exists(lora_folder): + return ["None"] + lora_files = [f for f in os.listdir(lora_folder) if f.endswith('.safetensors') or f.endswith('.pt')] + lora_files.sort(key=str.lower) + return ["None"] + lora_files + +def update_lora_dropdowns(lora_folder: str, *current_values) -> List[gr.update]: + new_choices = get_lora_options(lora_folder) + weights = current_values[:4] + multipliers = current_values[4:8] + + results = [] + for i in range(4): + weight = weights[i] if i < len(weights) else "None" + multiplier = multipliers[i] if i < len(multipliers) else 1.0 + if weight not in new_choices: + weight = "None" + results.extend([ + gr.update(choices=new_choices, value=weight), + gr.update(value=multiplier) + ]) + + return results + +def send_to_v2v(evt: gr.SelectData, gallery: list, prompt: str, selected_index: gr.State) -> Tuple[Optional[str], str, int]: + """Transfer selected video and prompt to Video2Video tab""" + if not gallery or evt.index >= len(gallery): + return None, "", selected_index.value + + selected_item = gallery[evt.index] + + # Handle different gallery item formats + if isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + elif isinstance(selected_item, (tuple, list)): + video_path = selected_item[0] + else: + video_path = selected_item + + # Final cleanup for Gradio Video component + if isinstance(video_path, tuple): + video_path = video_path[0] + + # Update the selected index + selected_index.value = evt.index + + return str(video_path), prompt, evt.index + +def send_selected_to_v2v(gallery: list, prompt: str, selected_index: gr.State) -> Tuple[Optional[str], str]: + """Send the currently selected video to V2V tab""" + if not gallery or selected_index.value is None or selected_index.value >= len(gallery): + return None, "" + + selected_item = gallery[selected_index.value] + + # Handle different gallery item formats + if isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + elif isinstance(selected_item, (tuple, list)): + video_path = selected_item[0] + else: + video_path = selected_item + + # Final cleanup for Gradio Video component + if isinstance(video_path, tuple): + video_path = video_path[0] + + return str(video_path), prompt + +def clear_cuda_cache(): + """Clear CUDA cache if available""" + import torch + if torch.cuda.is_available(): + torch.cuda.empty_cache() + # Optional: synchronize to ensure cache is cleared + torch.cuda.synchronize() + +def wanx_batch_handler( + use_random, + prompt, + negative_prompt, + width, + height, + video_length, + fps, + infer_steps, + flow_shift, + guidance_scale, + seed, + batch_size, + input_folder_path, + wanx_input_end, + task, + dit_folder, + dit_path, + vae_path, + t5_path, + clip_path, + save_path, + output_type, + sample_solver, + exclude_single_blocks, + attn_mode, + block_swap, + fp8, + fp8_scaled, + fp8_t5, + lora_folder, + slg_layers: str, + slg_start: Optional[str], + slg_end: Optional[str], + enable_cfg_skip: bool, + cfg_skip_mode: str, + cfg_apply_ratio: float, + enable_preview: bool, + preview_steps: int, + *lora_params, # <-- DO NOT ADD NAMED ARGS AFTER THIS! 
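+    # Assumed positional layout of *lora_params (based on the unpacking below):
+    # 4 LoRA weight filenames, 4 LoRA multipliers, then input_file, control_video,
+    # control_strength, control_start and control_end appended by the UI wiring.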
+): + """Handle both folder-based batch processing and regular processing for all WanX tabs""" + global stop_event + + # Convert None strings to actual None + slg_layers = None if slg_layers == "None" else slg_layers + slg_start = None if slg_start == "None" else slg_start + slg_end = None if slg_end == "None" else slg_end + + # Construct full dit_path including folder + full_dit_path = os.path.join(dit_folder, dit_path) if not os.path.isabs(dit_path) else dit_path + # Clean up LoRA params to proper format + clean_lora_params = [] + for param in lora_params: + # Convert None strings to "None" for consistency + if param is None or str(param).lower() == "none": + clean_lora_params.append("None") + else: + clean_lora_params.append(str(param)) + + # Extract LoRA weights and multipliers + num_lora_weights = 4 + lora_weights = clean_lora_params[:num_lora_weights] + lora_multipliers = [] + for mult in clean_lora_params[num_lora_weights:num_lora_weights*2]: + try: + lora_multipliers.append(float(mult)) + except (ValueError, TypeError): + lora_multipliers.append(1.0) + while len(lora_weights) < 4: + lora_weights.append("None") + while len(lora_multipliers) < 4: + lora_multipliers.append(1.0) + + # Now extract trailing params: input_file, control_video, control_strength, control_start, control_end + remaining_params = clean_lora_params[num_lora_weights*2:] + input_file = remaining_params[0] if len(remaining_params) > 0 else None + control_video = remaining_params[1] if len(remaining_params) > 1 else None + try: + control_strength = float(remaining_params[2]) if len(remaining_params) > 2 else 1.0 + except Exception: + control_strength = 1.0 + try: + control_start = float(remaining_params[3]) if len(remaining_params) > 3 else 0.0 + except Exception: + control_start = 0.0 + try: + control_end = float(remaining_params[4]) if len(remaining_params) > 4 else 1.0 + except Exception: + control_end = 1.0 + + yield [], [], "Preparing batch...", "" # Clear main and preview galleries + + if use_random: + stop_event.clear() + all_videos = [] + all_previews = [] # Keep track of previews from the last successful item? Or clear each time? Let's clear. + progress_text = "Starting generation..." + yield [], [], "Preparing...", progress_text # Clear galleries again just in case + batch_size = int(batch_size) + for i in range(batch_size): + if stop_event.is_set(): + yield all_videos, [], "Generation stopped by user", "" # Yield empty previews on stop + return + + # --- Clear previews for this item --- + current_previews_for_item = [] + yield all_videos.copy(), current_previews_for_item, f"Generating video {i + 1} of {batch_size}", progress_text # Yield cleared previews + + # ... (Keep existing random image logic: get random, resize) ... 
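+            # get_random_image_from_folder() and resize_image_keeping_aspect_ratio() are
+            # assumed to be helpers defined elsewhere in this script: the first is expected
+            # to return (image_path, status_message), the second (resized_path, (w, h)) on
+            # success or (None, error_text) on failure, which is what the checks below rely on.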
+            random_image, status = get_random_image_from_folder(input_folder_path)
+            if random_image is None:
+                yield all_videos, current_previews_for_item, f"Error in batch {i+1}: {status}", ""
+                continue  # Skip to next batch item on error
+
+            resized_image, size_info = resize_image_keeping_aspect_ratio(random_image, width, height)
+            if resized_image is None:
+                yield all_videos, current_previews_for_item, f"Error resizing image in batch {i+1}: {size_info}", ""
+                # The original image in the input folder is intentionally left untouched here;
+                # only temporary "temp_resized" files are removed after a successful run.
+                continue  # Skip to next batch item on error
+
+            local_width, local_height = width, height
+            if isinstance(size_info, tuple):
+                local_width, local_height = size_info
+            progress_text = f"Using image: {os.path.basename(random_image)} - Resized to {local_width}x{local_height}"
+            yield all_videos.copy(), current_previews_for_item, f"Generating video {i + 1} of {batch_size}", progress_text
+
+            current_seed = seed
+            if seed == -1:
+                current_seed = random.randint(0, 2**32 - 1)
+            elif batch_size > 1:
+                current_seed = seed + i
+
+            # Call wanx_generate_video for this batch item and accumulate its result
+            newly_generated_video = None  # Track the video generated *in this iteration*
+            last_status_for_item = f"Generating video {i+1}/{batch_size}"  # Keep track of last status
+            last_progress_for_item = progress_text  # Keep track of last progress line
+
+            # Inner loop iterates through the generator for ONE batch item
+            for videos_update, previews_update, status, progress in wanx_generate_video(
+                prompt, negative_prompt, resized_image, local_width, local_height,
+                video_length, fps, infer_steps, flow_shift, guidance_scale, current_seed,
+                wanx_input_end,  # Pass the argument
+                task, dit_folder, full_dit_path, vae_path, t5_path, clip_path, save_path,
+                output_type, sample_solver, exclude_single_blocks, attn_mode, block_swap,
+                fp8, fp8_scaled, fp8_t5, lora_folder,
+                slg_layers, slg_start, slg_end,
+                lora_weights[0], lora_weights[1], lora_weights[2], lora_weights[3],
+                lora_multipliers[0], lora_multipliers[1], lora_multipliers[2], lora_multipliers[3],
+                enable_cfg_skip, cfg_skip_mode, cfg_apply_ratio,
+                None, 1.0, 0.0, 1.0,  # Placeholders for control video args in random mode
+                enable_preview=enable_preview,
+                preview_steps=preview_steps
+            ):
+                # Store the latest video info from this *specific* generator run
+                if videos_update:
+                    # wanx_generate_video yields the *full* list it knows about,
+                    # so we take the last item assuming it's the new one.
+ newly_generated_video = videos_update[-1] + + current_previews_for_item = previews_update # Update previews for *this* item + last_status_for_item = f"Batch {i+1}/{batch_size}: {status}" # Store last status + last_progress_for_item = progress # Store last progress line + # Yield the *current cumulative* list during progress updates + yield all_videos.copy(), current_previews_for_item, last_status_for_item, last_progress_for_item + + # --- After the inner loop finishes for item 'i' --- + # Now, add the video generated in this iteration to the main list + if newly_generated_video and newly_generated_video not in all_videos: + all_videos.append(newly_generated_video) + print(f"DEBUG: Appended video {newly_generated_video[1] if isinstance(newly_generated_video, tuple) else 'unknown'} to all_videos (Total: {len(all_videos)})") + # Yield the updated cumulative list *immediately* after appending + yield all_videos.copy(), current_previews_for_item, last_status_for_item, last_progress_for_item + elif not newly_generated_video: + print(f"DEBUG: No new video generated or yielded by wanx_generate_video for batch item {i+1}.") + + + # --- Cleanup for item 'i' (Correctly indented) --- + try: + # Only remove the temporary resized image + if os.path.exists(resized_image) and "temp_resized" in resized_image: + os.remove(resized_image) + print(f"DEBUG: Removed temporary resized image: {resized_image}") + except Exception as e: + print(f"Warning: Could not remove temp image {resized_image}: {e}") + clear_cuda_cache() + time.sleep(0.5) + # --- End Cleanup for item 'i' --- + + # --- After the outer loop (all batch items processed) --- + yield all_videos, [], "Batch complete", "" # Yield empty previews at the end + else: + # ... (Keep existing checks for non-random mode: input file, control video) ... + batch_size = int(batch_size) + if not input_file and "i2v" in task: + yield [], [], "Error: No input image provided", "An input image is required for I2V models" + return + if "-FC" in task and not control_video: + yield [], [], "Error: No control video provided", "A control video is required for Fun-Control models" + return + + if batch_size > 1: + stop_event.clear() + all_videos = [] + all_previews = [] # Clear previews at start of batch + progress_text = "Starting generation..." 
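+            # Non-random mode with batch_size > 1: the same input image / control video is
+            # reused for every item and only the seed changes between iterations.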
+ yield [], [], "Preparing...", progress_text # Clear galleries + + for i in range(batch_size): + if stop_event.is_set(): + yield all_videos, [], "Generation stopped by user", "" # Yield empty previews + return + + # --- Clear previews for this item --- + current_previews_for_item = [] + yield all_videos.copy(), current_previews_for_item, f"Generating video {i+1}/{batch_size}", progress_text + + current_seed = seed + if seed == -1: current_seed = random.randint(0, 2**32 - 1) + elif batch_size > 1: current_seed = seed + i + batch_text = f"Generating video {i+1}/{batch_size} (seed: {current_seed})" + yield all_videos.copy(), current_previews_for_item, batch_text, progress_text # Update status + + # --- Corrected call to wanx_generate_video with accumulation --- + newly_generated_video = None # Track the video generated *in this iteration* + last_status_for_item = f"Generating video {i+1}/{batch_size}" # Keep track of last status + last_progress_for_item = progress_text # Keep track of last progress line + + # Inner loop iterates through the generator for ONE batch item + for videos_update, previews_update, status, progress in wanx_generate_video( + prompt, negative_prompt, input_file, width, height, + video_length, fps, infer_steps, flow_shift, guidance_scale, current_seed, + wanx_input_end, # Pass the argument + task, dit_folder, full_dit_path, vae_path, t5_path, clip_path, save_path, + output_type, sample_solver, exclude_single_blocks, attn_mode, block_swap, + fp8, fp8_scaled, fp8_t5, lora_folder, + slg_layers, slg_start, slg_end, + lora_weights[0], lora_weights[1], lora_weights[2], lora_weights[3], + lora_multipliers[0], lora_multipliers[1], lora_multipliers[2], lora_multipliers[3], + enable_cfg_skip, cfg_skip_mode, cfg_apply_ratio, + control_video, control_strength, control_start, control_end, + # --- Pass preview args --- + enable_preview=enable_preview, + preview_steps=preview_steps + ): + # Store the latest video info from this *specific* generator run + if videos_update: + # wanx_generate_video yields the *full* list it knows about, + # so we take the last item assuming it's the new one. 
+ newly_generated_video = videos_update[-1] + + current_previews_for_item = previews_update # Update previews for *this* item + last_status_for_item = f"Batch {i+1}/{batch_size}: {status}" # Store last status + last_progress_for_item = progress # Store last progress line + # Yield the *current cumulative* list during progress updates + yield all_videos.copy(), current_previews_for_item, last_status_for_item, last_progress_for_item + + # --- After the inner loop finishes for item 'i' --- + # Now, add the video generated in this iteration to the main list + if newly_generated_video and newly_generated_video not in all_videos: + all_videos.append(newly_generated_video) + print(f"DEBUG: Appended video {newly_generated_video[1] if isinstance(newly_generated_video, tuple) else 'unknown'} to all_videos (Total: {len(all_videos)})") + # Yield the updated cumulative list *immediately* after appending + yield all_videos.copy(), current_previews_for_item, last_status_for_item, last_progress_for_item + elif not newly_generated_video: + print(f"DEBUG: No new video generated or yielded by wanx_generate_video for batch item {i+1}.") + # --- End modified call --- + + clear_cuda_cache() + time.sleep(0.5) + yield all_videos, [], "Batch complete", "" # Yield empty previews at the end + else: # Single generation (batch_size = 1) + stop_event.clear() + # --- Modified call to wanx_generate_video (yield from) --- + # Add preview args directly + yield from wanx_generate_video( + prompt, negative_prompt, input_file, width, height, + video_length, fps, infer_steps, flow_shift, guidance_scale, seed, + wanx_input_end, # Pass the argument + task, dit_folder, full_dit_path, vae_path, t5_path, clip_path, save_path, + output_type, sample_solver, exclude_single_blocks, attn_mode, block_swap, + fp8, fp8_scaled, fp8_t5, lora_folder, + slg_layers, slg_start, slg_end, + lora_weights[0], lora_weights[1], lora_weights[2], lora_weights[3], + lora_multipliers[0], lora_multipliers[1], lora_multipliers[2], lora_multipliers[3], + enable_cfg_skip, cfg_skip_mode, cfg_apply_ratio, + control_video, control_strength, control_start, control_end, + # --- Pass preview args --- + enable_preview=enable_preview, + preview_steps=preview_steps + ) + +def process_single_video( + prompt: str, + width: int, + height: int, + batch_size: int, + video_length: int, + fps: int, + infer_steps: int, + seed: int, + dit_folder: str, + model: str, + vae: str, + te1: str, + te2: str, + save_path: str, + flow_shift: float, + cfg_scale: float, + output_type: str, + attn_mode: str, + block_swap: int, + exclude_single_blocks: bool, + use_split_attn: bool, + lora_folder: str, + lora1: str = "", + lora2: str = "", + lora3: str = "", + lora4: str = "", + lora1_multiplier: float = 1.0, + lora2_multiplier: float = 1.0, + lora3_multiplier: float = 1.0, + lora4_multiplier: float = 1.0, + video_path: Optional[str] = None, + image_path: Optional[str] = None, + strength: Optional[float] = None, + negative_prompt: Optional[str] = None, + embedded_cfg_scale: Optional[float] = None, + split_uncond: Optional[bool] = None, + guidance_scale: Optional[float] = None, + use_fp8: bool = True +) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: + """Generate a single video with the given parameters""" + global stop_event + + if stop_event.is_set(): + yield [], "", "" + return + + # Determine if this is a SkyReels model and what type + is_skyreels = "skyreels" in model.lower() + is_skyreels_i2v = is_skyreels and "i2v" in model.lower() + is_skyreels_t2v = is_skyreels and 
"t2v" in model.lower() + + if is_skyreels: + # Force certain parameters for SkyReels + if negative_prompt is None: + negative_prompt = "" + if embedded_cfg_scale is None: + embedded_cfg_scale = 1.0 # Force to 1.0 for SkyReels + if split_uncond is None: + split_uncond = True + if guidance_scale is None: + guidance_scale = cfg_scale # Use cfg_scale as guidance_scale if not provided + + # Determine the input channels based on model type + if is_skyreels_i2v: + dit_in_channels = 32 # SkyReels I2V uses 32 channels + else: + dit_in_channels = 16 # SkyReels T2V uses 16 channels (same as regular models) + else: + dit_in_channels = 16 # Regular Hunyuan models use 16 channels + embedded_cfg_scale = cfg_scale + + if os.path.isabs(model): + model_path = model + else: + model_path = os.path.normpath(os.path.join(dit_folder, model)) + + env = os.environ.copy() + env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") + env["PYTHONIOENCODING"] = "utf-8" + env["BATCH_RUN_ID"] = f"{time.time()}" + + if seed == -1: + current_seed = random.randint(0, 2**32 - 1) + else: + batch_id = int(env.get("BATCH_RUN_ID", "0").split('.')[-1]) + if batch_size > 1: # Only modify seed for batch generation + current_seed = (seed + batch_id * 100003) % (2**32) + else: + current_seed = seed + + clear_cuda_cache() + + command = [ + sys.executable, + "hv_generate_video.py", + "--dit", model_path, + "--vae", vae, + "--text_encoder1", te1, + "--text_encoder2", te2, + "--prompt", prompt, + "--video_size", str(height), str(width), + "--video_length", str(video_length), + "--fps", str(fps), + "--infer_steps", str(infer_steps), + "--save_path", save_path, + "--seed", str(current_seed), + "--flow_shift", str(flow_shift), + "--embedded_cfg_scale", str(cfg_scale), + "--output_type", output_type, + "--attn_mode", attn_mode, + "--blocks_to_swap", str(block_swap), + "--fp8_llm", + "--vae_chunk_size", "32", + "--vae_spatial_tile_sample_min_size", "128" + ] + + if use_fp8: + command.append("--fp8") + + # Add negative prompt and embedded cfg scale for SkyReels + if is_skyreels: + command.extend(["--dit_in_channels", str(dit_in_channels)]) + command.extend(["--guidance_scale", str(guidance_scale)]) + + if negative_prompt: + command.extend(["--negative_prompt", negative_prompt]) + if split_uncond: + command.append("--split_uncond") + + # Add LoRA weights and multipliers if provided + valid_loras = [] + for weight, mult in zip([lora1, lora2, lora3, lora4], + [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier]): + if weight and weight != "None": + valid_loras.append((os.path.join(lora_folder, weight), mult)) + if valid_loras: + weights = [weight for weight, _ in valid_loras] + multipliers = [str(mult) for _, mult in valid_loras] + command.extend(["--lora_weight"] + weights) + command.extend(["--lora_multiplier"] + multipliers) + + if exclude_single_blocks: + command.append("--exclude_single_blocks") + if use_split_attn: + command.append("--split_attn") + + # Handle input paths + if video_path: + command.extend(["--video_path", video_path]) + if strength is not None: + command.extend(["--strength", str(strength)]) + elif image_path: + command.extend(["--image_path", image_path]) + # Only add strength parameter for non-SkyReels I2V models + # SkyReels I2V doesn't use strength parameter for image-to-video generation + if strength is not None and not is_skyreels_i2v: + command.extend(["--strength", str(strength)]) + + print(f"{command}") + + p = subprocess.Popen( + command, + stdout=subprocess.PIPE, + 
stderr=subprocess.STDOUT,
+        env=env,
+        text=True,
+        encoding='utf-8',
+        errors='replace',
+        bufsize=1
+    )
+
+    videos = []
+
+    while True:
+        if stop_event.is_set():
+            p.terminate()
+            p.wait()
+            yield [], "", "Generation stopped by user."
+            return
+
+        line = p.stdout.readline()
+        if not line:
+            if p.poll() is not None:
+                break
+            continue
+
+        print(line, end='')
+        if '|' in line and '%' in line and '[' in line and ']' in line:
+            yield videos.copy(), f"Processing (seed: {current_seed})", line.strip()
+
+    p.stdout.close()
+    p.wait()
+
+    clear_cuda_cache()
+    time.sleep(0.5)
+
+    # Collect generated video
+    save_path_abs = os.path.abspath(save_path)
+    if os.path.exists(save_path_abs):
+        all_videos = sorted(
+            [f for f in os.listdir(save_path_abs) if f.endswith('.mp4')],
+            key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)),
+            reverse=True
+        )
+        matching_videos = [v for v in all_videos if f"_{current_seed}" in v]
+        if matching_videos:
+            # Use a dedicated name so the video_path argument (the optional input
+            # video) is not shadowed when the metadata below is collected
+            generated_video_path = os.path.join(save_path_abs, matching_videos[0])
+
+            # Collect parameters for metadata
+            parameters = {
+                "prompt": prompt,
+                "width": width,
+                "height": height,
+                "video_length": video_length,
+                "fps": fps,
+                "infer_steps": infer_steps,
+                "seed": current_seed,
+                "model": model,
+                "vae": vae,
+                "te1": te1,
+                "te2": te2,
+                "save_path": save_path,
+                "flow_shift": flow_shift,
+                "cfg_scale": cfg_scale,
+                "output_type": output_type,
+                "attn_mode": attn_mode,
+                "block_swap": block_swap,
+                "lora_weights": [lora1, lora2, lora3, lora4],
+                "lora_multipliers": [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier],
+                "input_video": video_path if video_path else None,
+                "input_image": image_path if image_path else None,
+                "strength": strength,
+                "negative_prompt": negative_prompt if is_skyreels else None,
+                "embedded_cfg_scale": embedded_cfg_scale if is_skyreels else None
+            }
+
+            add_metadata_to_video(generated_video_path, parameters)
+            videos.append((str(generated_video_path), f"Seed: {current_seed}"))
+
+    yield videos, f"Completed (seed: {current_seed})", ""
+
+def process_batch(
+    prompt: str,
+    width: int,
+    height: int,
+    batch_size: int,
+    video_length: int,
+    fps: int,
+    infer_steps: int,
+    seed: int,
+    dit_folder: str,
+    model: str,
+    vae: str,
+    te1: str,
+    te2: str,
+    save_path: str,
+    flow_shift: float,
+    cfg_scale: float,
+    output_type: str,
+    attn_mode: str,
+    block_swap: int,
+    exclude_single_blocks: bool,
+    use_split_attn: bool,
+    lora_folder: str,
+    *args
+) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]:
+    """Process a batch of videos using Gradio's queue"""
+    global stop_event
+    stop_event.clear()
+
+    all_videos = []
+    progress_text = "Starting generation..."
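+    # Expected layout of *args (assumption, matching the unpacking below): four LoRA
+    # weight names, four LoRA multipliers, then the extra args: input path, strength,
+    # negative prompt, guidance/cfg scale, split_uncond, with the use_fp8 flag last.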
+ yield [], "Preparing...", progress_text + + # Extract additional arguments + num_lora_weights = 4 + lora_weights = args[:num_lora_weights] + lora_multipliers = args[num_lora_weights:num_lora_weights*2] + extra_args = args[num_lora_weights*2:] + + # Determine if this is a SkyReels model and what type + is_skyreels = "skyreels" in model.lower() + is_skyreels_i2v = is_skyreels and "i2v" in model.lower() + is_skyreels_t2v = is_skyreels and "t2v" in model.lower() + + # Handle input paths and additional parameters + input_path = extra_args[0] if extra_args else None + strength = float(extra_args[1]) if len(extra_args) > 1 else None + + # Get use_fp8 flag (it should be the last parameter) + use_fp8 = bool(extra_args[-1]) if extra_args and len(extra_args) >= 3 else True + + # Get SkyReels specific parameters if applicable + if is_skyreels: + # Always set embedded_cfg_scale to 1.0 for SkyReels models + embedded_cfg_scale = 1.0 + + negative_prompt = str(extra_args[2]) if len(extra_args) > 2 and extra_args[2] is not None else "" + # Use cfg_scale for guidance_scale parameter + guidance_scale = float(extra_args[3]) if len(extra_args) > 3 and extra_args[3] is not None else cfg_scale + split_uncond = True if len(extra_args) > 4 and extra_args[4] else False + else: + negative_prompt = str(extra_args[2]) if len(extra_args) > 2 and extra_args[2] is not None else None + guidance_scale = cfg_scale + embedded_cfg_scale = cfg_scale + split_uncond = bool(extra_args[4]) if len(extra_args) > 4 else None + + for i in range(batch_size): + if stop_event.is_set(): + break + + batch_text = f"Generating video {i + 1} of {batch_size}" + yield all_videos.copy(), batch_text, progress_text + + # Handle different input types + video_path = None + image_path = None + + if input_path: + # Check if it's an image file (common image extensions) + is_image = False + lower_path = input_path.lower() + image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp') + is_image = any(lower_path.endswith(ext) for ext in image_extensions) + + # Only use image_path for SkyReels I2V models and actual image files + if is_skyreels_i2v and is_image: + image_path = input_path + else: + video_path = input_path + + # Prepare arguments for process_single_video + single_video_args = [ + prompt, width, height, batch_size, video_length, fps, infer_steps, + seed, dit_folder, model, vae, te1, te2, save_path, flow_shift, cfg_scale, + output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, + lora_folder + ] + single_video_args.extend(lora_weights) + single_video_args.extend(lora_multipliers) + single_video_args.extend([video_path, image_path, strength, negative_prompt, embedded_cfg_scale, split_uncond, guidance_scale, use_fp8]) + + for videos, status, progress in process_single_video(*single_video_args): + if videos: + all_videos.extend(videos) + yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress + + yield all_videos, "Batch complete", "" + +def update_wanx_image_dimensions(image): + """Update dimensions from uploaded image""" + if image is None: + return "", gr.update(value=832), gr.update(value=480) + img = Image.open(image) + w, h = img.size + w = (w // 32) * 32 + h = (h // 32) * 32 + return f"{w}x{h}", w, h + +def calculate_wanx_width(height, original_dims): + """Calculate width based on height maintaining aspect ratio""" + if not original_dims: + return gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + aspect_ratio = orig_w / orig_h + new_width = math.floor((height * aspect_ratio) 
/ 32) * 32 + return gr.update(value=new_width) + +def calculate_wanx_height(width, original_dims): + """Calculate height based on width maintaining aspect ratio""" + if not original_dims: + return gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + aspect_ratio = orig_w / orig_h + new_height = math.floor((width / aspect_ratio) / 32) * 32 + return gr.update(value=new_height) + +def update_wanx_from_scale(scale, original_dims): + """Update dimensions based on scale percentage""" + if not original_dims: + return gr.update(), gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + new_w = math.floor((orig_w * scale / 100) / 32) * 32 + new_h = math.floor((orig_h * scale / 100) / 32) * 32 + return gr.update(value=new_w), gr.update(value=new_h) + +def recommend_wanx_flow_shift(width, height): + """Get recommended flow shift value based on dimensions""" + recommended_shift = 3.0 if (width == 832 and height == 480) or (width == 480 and height == 832) else 5.0 + return gr.update(value=recommended_shift) + +def handle_wanx_gallery_select(evt: gr.SelectData, gallery) -> tuple: + """Track selected index and video path when gallery item is clicked""" + if gallery is None: + return None, None + + if evt.index >= len(gallery): + return None, None + + selected_item = gallery[evt.index] + video_path = None + + # Extract the video path based on the item type + if isinstance(selected_item, tuple): + video_path = selected_item[0] + elif isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + else: + video_path = selected_item + + return evt.index, video_path + +def get_step_from_preview_path(path): + match = re.search(r"step_(\d+)_", os.path.basename(path)) + return int(match.group(1)) if match else -1 + +def wanx_generate_video( + prompt, + negative_prompt, + input_image, + width, + height, + video_length, + fps, + infer_steps, + flow_shift, + guidance_scale, + seed, + wanx_input_end, + task, + dit_folder, + dit_path, + vae_path, + t5_path, + clip_path, + save_path, + output_type, + sample_solver, + exclude_single_blocks, + attn_mode, + block_swap, + fp8, + fp8_scaled, + fp8_t5, + lora_folder, + slg_layers, + slg_start, + slg_end, + lora1="None", + lora2="None", + lora3="None", + lora4="None", + lora1_multiplier=1.0, + lora2_multiplier=1.0, + lora3_multiplier=1.0, + lora4_multiplier=1.0, + enable_cfg_skip=False, + cfg_skip_mode="none", + cfg_apply_ratio=0.7, + control_video=None, + control_strength=1.0, + control_start=0.0, + control_end=1.0, + enable_preview: bool = False, + preview_steps: int = 5 +) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: + """Generate video with WanX model (supports both i2v, t2v and Fun-Control)""" + global stop_event + + current_previews = [] + yield [], current_previews, "Preparing...", "" # Yield empty previews + + # Fix 1: Ensure lora_folder is a string + lora_folder = str(lora_folder) if lora_folder else "lora" + + # Debug prints + print(f"DEBUG - LoRA params: {lora1}, {lora2}, {lora3}, {lora4}") + print(f"DEBUG - LoRA multipliers: {lora1_multiplier}, {lora2_multiplier}, {lora3_multiplier}, {lora4_multiplier}") + print(f"DEBUG - LoRA folder: {lora_folder}") + + # Convert values safely to float or None + try: + slg_start_float = float(slg_start) if slg_start is not None and str(slg_start).lower() != "none" else None + except (ValueError, TypeError): + slg_start_float = None + print(f"Warning: Could not convert slg_start '{slg_start}' to float") + + try: + slg_end_float = 
float(slg_end) if slg_end is not None and str(slg_end).lower() != "none" else None + except (ValueError, TypeError): + slg_end_float = None + print(f"Warning: Could not convert slg_end '{slg_end}' to float") + + print(f"slg_start_float: {slg_start_float}, slg_end_float: {slg_end_float}") + + if stop_event.is_set(): + yield [], [], "", "" # Yield empty previews + return + + run_id = f"{int(time.time())}_{random.randint(1000, 9999)}" + unique_preview_suffix = f"wanx_{run_id}" # Add prefix for clarity + # --- Construct unique preview paths --- + preview_base_path = os.path.join(save_path, f"latent_preview_{unique_preview_suffix}") + preview_mp4_path = preview_base_path + ".mp4" + preview_png_path = preview_base_path + ".png" + + # Check if this is a Fun-Control task + is_fun_control = "-FC" in task and control_video is not None + if is_fun_control: + print(f"DEBUG - Using Fun-Control mode with control video: {control_video}") + # Verify control video is provided + if not control_video: + yield [], "Error: No control video provided", "Fun-Control requires a control video" + return + + # Verify needed files exist + for path_name, path in [ + ("DIT", dit_path), + ("VAE", vae_path), + ("T5", t5_path), + ("CLIP", clip_path) + ]: + if not os.path.exists(path): + yield [], f"Error: {path_name} model not found", f"Model file doesn't exist: {path}" + return + + # Get current seed or use provided seed + current_seed = seed + if seed == -1: + current_seed = random.randint(0, 2**32 - 1) + + # Check if we need input image (required for i2v, not for t2v) + if "i2v" in task and not input_image: + yield [], "Error: No input image provided", "Please provide an input image for image-to-video generation" + return + + # Check for Fun-Control requirements + if is_fun_control and not control_video: + yield [], "Error: No control video provided", "Please provide a control video for Fun-Control generation" + return + + # Prepare environment + env = os.environ.copy() + env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") + env["PYTHONIOENCODING"] = "utf-8" + + clear_cuda_cache() + + # Fix 2: Create command array with all string values + command = [ + sys.executable, + "wan_generate_video.py", + "--task", str(task), + "--prompt", str(prompt), + "--video_size", str(height), str(width), + "--video_length", str(video_length), + "--fps", str(fps), + "--infer_steps", str(infer_steps), + "--save_path", str(save_path), + "--seed", str(current_seed), + "--flow_shift", str(flow_shift), + "--guidance_scale", str(guidance_scale), + "--output_type", str(output_type), + "--attn_mode", str(attn_mode), + "--blocks_to_swap", str(block_swap), + "--dit", str(dit_path), + "--vae", str(vae_path), + "--t5", str(t5_path), + "--sample_solver", str(sample_solver) + ] + + # Fix 3: Only add boolean flags if they're True + if enable_preview and preview_steps > 0: + command.extend(["--preview", str(preview_steps)]) + # --- ADDED: Pass the unique suffix --- + command.extend(["--preview_suffix", unique_preview_suffix]) + # --- End Pass Suffix --- + print(f"DEBUG - Enabling preview every {preview_steps} steps with suffix {unique_preview_suffix}.") + + if enable_cfg_skip and cfg_skip_mode != "none": + command.extend([ + "--cfg_skip_mode", str(cfg_skip_mode), + "--cfg_apply_ratio", str(cfg_apply_ratio) + ]) + + if wanx_input_end and wanx_input_end != "none" and os.path.exists(str(wanx_input_end)): + command.extend(["--end_image_path", str(wanx_input_end)]) + command.extend(["--trim_tail_frames", "3"]) + + # Handle 
Fun-Control (control video path) + if is_fun_control and control_video: + command.extend(["--control_path", str(control_video)]) + command.extend(["--control_weight", str(control_strength)]) + command.extend(["--control_start", str(control_start)]) + command.extend(["--control_end", str(control_end)]) + + # Handle SLG parameters + if slg_layers and str(slg_layers).strip() and str(slg_layers).lower() != "none": + try: + # Make sure slg_layers is parsed as a list of integers + slg_list = [] + for layer in str(slg_layers).split(","): + layer = layer.strip() + if layer.isdigit(): # Only add if it's a valid integer + slg_list.append(int(layer)) + if slg_list: # Only add if we have valid layers + command.extend(["--slg_layers", ",".join(map(str, slg_list))]) + + # Only add slg_start and slg_end if we have valid slg_layers + try: + if slg_start_float is not None and slg_start_float >= 0: + command.extend(["--slg_start", str(slg_start_float)]) + if slg_end_float is not None and slg_end_float <= 1.0: + command.extend(["--slg_end", str(slg_end_float)]) + except ValueError as e: + print(f"Invalid SLG timing values: {str(e)}") + except ValueError as e: + print(f"Invalid SLG layers format: {slg_layers} - {str(e)}") + + + # Add image path only for i2v task and if input image is provided + if "i2v" in task and input_image: + command.extend(["--image_path", str(input_image)]) + command.extend(["--clip", str(clip_path)]) # CLIP is needed for i2v and Fun-Control + + # Add video path for v2v task + if "v2v" in task and input_image: + command.extend(["--video_path", str(input_image)]) + # Add strength parameter for video-to-video + if isinstance(guidance_scale, (int, float)) and guidance_scale > 0: + command.extend(["--strength", str(guidance_scale)]) + + if negative_prompt: + command.extend(["--negative_prompt", str(negative_prompt)]) + + # Add boolean flags correctly + if fp8: + command.append("--fp8") + + if fp8_scaled: + command.append("--fp8_scaled") + + if fp8_t5: + command.append("--fp8_t5") + + if exclude_single_blocks: + command.append("--exclude_single_blocks") + + # Handle LoRA weights and multipliers + lora_weights = [lora1, lora2, lora3, lora4] + lora_multipliers = [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier] + + valid_loras = [] + for weight, mult in zip(lora_weights, lora_multipliers): + # Skip None, empty, or "None" values + if weight is None or not str(weight) or str(weight).lower() == "none": + continue + + # Ensure weight is a string + weight_str = str(weight) + + # Construct full path and verify file exists + full_path = os.path.join(lora_folder, weight_str) + if not os.path.exists(full_path): + print(f"LoRA file not found: {full_path}") + continue + + # Add valid LoRA to the list + valid_loras.append((full_path, mult)) + + # Only add LoRA parameters if we have valid LoRAs + if valid_loras: + weights = [w for w, _ in valid_loras] + multipliers = [str(m) for _, m in valid_loras] + command.extend(["--lora_weight"] + weights) + command.extend(["--lora_multiplier"] + multipliers) + + # Make sure every item in command is a string + command = [str(item) for item in command] + + print(f"Running: {' '.join(command)}") + + p = subprocess.Popen( + command, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=env, + text=True, + encoding='utf-8', + errors='replace', + bufsize=1 + ) + + videos = [] + processed_preview_files = set() # Keep track of previews already yielded - REMAINS THE SAME IN UI FUNCTION + # --- Reset preview state for this run --- + 
current_preview_yield_path = None + last_preview_mtime = 0 + + current_phase = "Preparing" # Add phase tracking like FramePack + while True: + if stop_event.is_set(): + try: + p.terminate() + p.wait(timeout=5) + except subprocess.TimeoutExpired: + p.kill() + p.wait() + except Exception as e: + print(f"Error terminating subprocess: {e}") + yield [], [], "Generation stopped by user.", "" # Yield empty previews + return + + line = p.stdout.readline() + if not line: + if p.poll() is not None: + break + time.sleep(0.01); continue + + line = line.strip() + if not line: continue + print(f"WANX SUBPROCESS: {line}") # Log subprocess output + + # --- Adopt FramePack's Parsing Logic --- + status_text = f"Processing (seed: {current_seed})" # Default status + progress_text_update = line # Default progress + + # Check for TQDM progress using regex + tqdm_match = re.search(r'(\d+)\%\|.+\| (\d+)/(\d+) \[(\d{2}:\d{2})<(\d{2}:\d{2})', line) + + if tqdm_match: + percentage = int(tqdm_match.group(1)) + current_step = int(tqdm_match.group(2)) + total_steps = int(tqdm_match.group(3)) + time_elapsed = tqdm_match.group(4) + time_remaining = tqdm_match.group(5) + + current_phase = f"Denoising Step {current_step}/{total_steps}" # Update phase + + # Format progress text like FramePack for JS compatibility + progress_text_update = f"Step {current_step}/{total_steps} ({percentage}%) | Elapsed: {time_elapsed}, Remaining: {time_remaining}" + status_text = f"Generating (seed: {current_seed}) - {current_phase}" + + elif "ERROR" in line.upper() or "TRACEBACK" in line.upper(): + status_text = f"Error (seed: {current_seed})" + progress_text_update = line # Show error line + current_phase = "Error" + + # Add more phases if needed (e.g., "Decoding", "Saving") by checking logs + elif "Decoding video..." in line: # Placeholder check + current_phase = "Decoding Video" + status_text = f"Generating (seed: {current_seed}) - {current_phase}" + progress_text_update = "Decoding video..." 
+ + elif "Video saved to:" in line: # Placeholder check + current_phase = "Saved" + status_text = f"Completed (seed: {current_seed})" + progress_text_update = line # Show the save line + # Add any other status parsing if needed + preview_updated = False + current_mtime = 0 + found_preview_path = None + + if enable_preview: + # --- MODIFIED: Check unique paths --- + if os.path.exists(preview_mp4_path): + current_mtime = os.path.getmtime(preview_mp4_path) + found_preview_path = preview_mp4_path + elif os.path.exists(preview_png_path): + current_mtime = os.path.getmtime(preview_png_path) + found_preview_path = preview_png_path + # --- End Modified Check --- + + if found_preview_path and current_mtime > last_preview_mtime: + print(f"DEBUG: Preview file updated: {found_preview_path} (mtime: {current_mtime})") + # Yield the clean path (already unique) + current_preview_yield_path = found_preview_path # No cache buster needed + last_preview_mtime = current_mtime + preview_updated = True + # --- End Preview Check --- + + # --- YIELD --- + # Yield progress and potentially updated unique preview path + preview_list_for_yield = [current_preview_yield_path] if current_preview_yield_path else [] + # Yield progress and potentially updated unique preview path list + yield videos.copy(), preview_list_for_yield, status_text, progress_text_update + + p.stdout.close() + rc = p.wait() + + clear_cuda_cache() + time.sleep(0.5) + + # --- Collect final generated video --- + generated_video_path = None + if rc == 0: # Only look for video if process succeeded + save_path_abs = os.path.abspath(save_path) + if os.path.exists(save_path_abs): + # Find the most recent mp4 containing the seed + all_mp4_files = glob.glob(os.path.join(save_path_abs, f"*_{current_seed}*.mp4")) + # Exclude files in the 'previews' subdirectory + all_mp4_files = [f for f in all_mp4_files if "previews" not in os.path.dirname(f)] + + if all_mp4_files: + # Find the *absolute* most recent one, as multiple might match seed in edge cases + generated_video_path = max(all_mp4_files, key=os.path.getmtime) + print(f"Found newly generated video: {generated_video_path}") + + # Add metadata (assuming add_metadata_to_video exists and works) + parameters = { + "prompt": prompt, "negative_prompt": negative_prompt, + "input_image": input_image if "i2v" in task else None, + "width": width, "height": height, "video_length": video_length, "fps": fps, + "infer_steps": infer_steps, "flow_shift": flow_shift, "guidance_scale": guidance_scale, + "seed": current_seed, "task": task, "dit_path": dit_path, + "vae_path": vae_path, "t5_path": t5_path, "clip_path": clip_path if "i2v" in task or is_fun_control else None, + "save_path": save_path, "output_type": output_type, "sample_solver": sample_solver, + "exclude_single_blocks": exclude_single_blocks, "attn_mode": attn_mode, + "block_swap": block_swap, "fp8": fp8, "fp8_scaled": fp8_scaled, "fp8_t5": fp8_t5, + "lora_weights": [lora1, lora2, lora3, lora4], + "lora_multipliers": [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier], + "slg_layers": slg_layers, "slg_start": slg_start, "slg_end": slg_end, + "enable_cfg_skip": enable_cfg_skip, "cfg_skip_mode": cfg_skip_mode, "cfg_apply_ratio": cfg_apply_ratio, + "control_video": control_video if is_fun_control else None, + "control_strength": control_strength if is_fun_control else None, + "control_start": control_start if is_fun_control else None, + "control_end": control_end if is_fun_control else None, + } + try: + 
add_metadata_to_video(generated_video_path, parameters) + except NameError: + print("Warning: add_metadata_to_video function not found. Skipping metadata.") + except Exception as meta_err: + print(f"Warning: Failed to add metadata: {meta_err}") + + # Append to the final video list + videos.append((str(generated_video_path), f"Seed: {current_seed}")) + else: + print(f"Subprocess finished successfully (rc=0), but could not find generated video for seed {current_seed} in {save_path_abs}") + +# --- Final Yield --- + final_status = f"Completed (seed: {current_seed})" if rc == 0 and generated_video_path else f"Failed (seed: {current_seed}, rc={rc})" + final_progress = f"Video saved: {os.path.basename(generated_video_path)}" if rc == 0 and generated_video_path else f"Subprocess failed with exit code {rc}" + + # Check for the preview file one last time for the final update (using unique path) + # --- MODIFIED Final Preview Check and List Creation --- + final_preview_path = None + # --- Use the UNIQUE paths defined earlier in the function --- + if os.path.exists(preview_mp4_path): + final_preview_path = os.path.abspath(preview_mp4_path) + elif os.path.exists(preview_png_path): + final_preview_path = os.path.abspath(preview_png_path) + # --- End path checking --- + + final_preview_list_for_yield = [final_preview_path] if final_preview_path else [] + # --- End Modified --- + + yield videos, final_preview_list_for_yield, final_status, final_progress + +def send_wanx_to_v2v( + gallery: list, + prompt: str, + selected_index: int, + width: int, + height: int, + video_length: int, + fps: int, + infer_steps: int, + seed: int, + flow_shift: float, + guidance_scale: float, + negative_prompt: str +) -> Tuple: + """Send the selected WanX video to Video2Video tab""" + if gallery is None or not gallery: + return (None, "", width, height, video_length, fps, infer_steps, seed, + flow_shift, guidance_scale, negative_prompt) + + # If no selection made but we have videos, use the first one + if selected_index is None and len(gallery) > 0: + selected_index = 0 + + if selected_index is None or selected_index >= len(gallery): + return (None, "", width, height, video_length, fps, infer_steps, seed, + flow_shift, guidance_scale, negative_prompt) + + selected_item = gallery[selected_index] + + # Handle different gallery item formats + if isinstance(selected_item, tuple): + video_path = selected_item[0] + elif isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + else: + video_path = selected_item + + # Clean up path for Video component + if isinstance(video_path, tuple): + video_path = video_path[0] + + # Make sure it's a string + video_path = str(video_path) + + return (video_path, prompt, width, height, video_length, fps, infer_steps, seed, + flow_shift, guidance_scale, negative_prompt) + +def wanx_generate_video_batch( + prompt, + negative_prompt, + width, + height, + video_length, + fps, + infer_steps, + flow_shift, + guidance_scale, + seed, + task, + dit_path, + vae_path, + t5_path, + clip_path, + save_path, + output_type, + sample_solver, + exclude_single_blocks, + attn_mode, + block_swap, + fp8, + fp8_scaled, + fp8_t5, + lora_folder, + slg_layers: int, + slg_start: Optional[str], + slg_end: Optional[str], + lora1="None", + lora2="None", + lora3="None", + lora4="None", + lora1_multiplier=1.0, + lora2_multiplier=1.0, + lora3_multiplier=1.0, + lora4_multiplier=1.0, + batch_size=1, + input_image=None # Make input_image optional and place it at the end +) -> 
Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: + """Generate videos with WanX with support for batches""" + slg_start = None if slg_start == 'None' or slg_start is None else slg_start + slg_end = None if slg_end == 'None' or slg_end is None else slg_end + + # Now safely convert to float if not None + slg_start_float = float(slg_start) if slg_start is not None and isinstance(slg_start, (str, int, float)) else None + slg_end_float = float(slg_end) if slg_end is not None and isinstance(slg_end, (str, int, float)) else None + print(f"slg_start_float: {slg_start_float}, slg_end_float: {slg_end_float}") + global stop_event + stop_event.clear() + + all_videos = [] + progress_text = "Starting generation..." + yield [], "Preparing...", progress_text + + # Process each item in the batch + for i in range(batch_size): + if stop_event.is_set(): + yield all_videos, "Generation stopped by user", "" + return + + # Calculate seed for this batch item + current_seed = seed + if seed == -1: + current_seed = random.randint(0, 2**32 - 1) + elif batch_size > 1: + current_seed = seed + i + + batch_text = f"Generating video {i + 1} of {batch_size}" + yield all_videos.copy(), batch_text, progress_text + + # Generate a single video using the existing function + for videos, status, progress in wanx_generate_video( + prompt, + negative_prompt, + input_image, + width, + height, + video_length, + fps, + infer_steps, + flow_shift, + guidance_scale, + current_seed, + task, + dit_path, + vae_path, + t5_path, + clip_path, + save_path, + output_type, + sample_solver, + exclude_single_blocks, + attn_mode, + block_swap, + fp8, + fp8_scaled, + fp8_t5, + lora_folder, + slg_layers, + slg_start, + slg_end, + lora1, + lora2, + lora3, + lora4, + lora1_multiplier, + lora2_multiplier, + lora3_multiplier, + lora4_multiplier + ): + if videos: + all_videos.extend(videos) + yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress + + yield all_videos, "Batch complete", "" + +def update_wanx_t2v_dimensions(size): + """Update width and height based on selected size""" + width, height = map(int, size.split('*')) + return gr.update(value=width), gr.update(value=height) + +def handle_wanx_t2v_gallery_select(evt: gr.SelectData) -> int: + """Track selected index when gallery item is clicked""" + return evt.index + +def send_wanx_t2v_to_v2v( + gallery, prompt, selected_index, width, height, video_length, + fps, infer_steps, seed, flow_shift, guidance_scale, negative_prompt +) -> Tuple: + """Send the selected WanX T2V video to Video2Video tab""" + if not gallery or selected_index is None or selected_index >= len(gallery): + return (None, "", width, height, video_length, fps, infer_steps, seed, + flow_shift, guidance_scale, negative_prompt) + + selected_item = gallery[selected_index] + + if isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + elif isinstance(selected_item, (tuple, list)): + video_path = selected_item[0] + else: + video_path = selected_item + + if isinstance(video_path, tuple): + video_path = video_path[0] + + return (str(video_path), prompt, width, height, video_length, fps, infer_steps, seed, + flow_shift, guidance_scale, negative_prompt) + +def prepare_for_batch_extension(input_img, base_video, batch_size): + """Prepare inputs for batch video extension""" + if input_img is None: + return None, None, batch_size, "No input image found", "" + + if base_video is None: + return input_img, None, batch_size, "No base video selected for extension", "" 
+ + return input_img, base_video, batch_size, "Preparing batch extension...", f"Will create {batch_size} variations of extended video" + +def concat_batch_videos(base_video_path, generated_videos, save_path, original_video_path=None): + """Concatenate multiple generated videos with the base video""" + if not base_video_path: + return [], "No base video provided" + + if not generated_videos or len(generated_videos) == 0: + return [], "No new videos generated" + + # Create output directory if it doesn't exist + os.makedirs(save_path, exist_ok=True) + + # Track all extended videos + extended_videos = [] + + # For each generated video, create an extended version + for i, video_item in enumerate(generated_videos): + try: + # Extract video path from gallery item + if isinstance(video_item, tuple): + new_video_path = video_item[0] + seed_info = video_item[1] if len(video_item) > 1 else "" + elif isinstance(video_item, dict): + new_video_path = video_item.get("name", video_item.get("data", None)) + seed_info = "" + else: + new_video_path = video_item + seed_info = "" + + if not new_video_path or not os.path.exists(new_video_path): + print(f"Skipping missing video: {new_video_path}") + continue + + # Create unique output filename + timestamp = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S") + # Extract seed from seed_info if available + seed_match = re.search(r"Seed: (\d+)", seed_info) + seed_part = f"_seed{seed_match.group(1)}" if seed_match else f"_{i}" + + output_filename = f"extended_{timestamp}{seed_part}_{Path(base_video_path).stem}.mp4" + output_path = os.path.join(save_path, output_filename) + + # Create a temporary file list for ffmpeg + list_file = os.path.join(save_path, f"temp_list_{i}.txt") + with open(list_file, "w") as f: + f.write(f"file '{os.path.abspath(base_video_path)}'\n") + f.write(f"file '{os.path.abspath(new_video_path)}'\n") + + # Run ffmpeg concatenation + command = [ + "ffmpeg", + "-f", "concat", + "-safe", "0", + "-i", list_file, + "-c", "copy", + output_path + ] + + subprocess.run(command, check=True, capture_output=True) + + # Clean up temporary file + if os.path.exists(list_file): + os.remove(list_file) + + # Add to extended videos list if successful + if os.path.exists(output_path): + seed_display = f"Extended {seed_info}" if seed_info else f"Extended video #{i+1}" + extended_videos.append((output_path, seed_display)) + + except Exception as e: + print(f"Error processing video {i}: {str(e)}") + + if not extended_videos: + return [], "Failed to create any extended videos" + + return extended_videos, f"Successfully created {len(extended_videos)} extended videos" + +def wanx_extend_single_video( + prompt, negative_prompt, input_image, base_video_path, + width, height, video_length, fps, infer_steps, + flow_shift, guidance_scale, seed, + task, dit_path, vae_path, t5_path, clip_path, + save_path, output_type, sample_solver, exclude_single_blocks, + attn_mode, block_swap, fp8, fp8_scaled, fp8_t5, lora_folder, + slg_layers="", slg_start=0.0, slg_end=1.0, + lora1="None", lora2="None", lora3="None", lora4="None", + lora1_multiplier=1.0, lora2_multiplier=1.0, lora3_multiplier=1.0, lora4_multiplier=1.0 +): + """Generate a single video and concatenate with base video""" + # First, generate the video with proper parameter handling + all_videos = [] + + # Sanitize lora parameters + lora_weights = [str(lora1) if lora1 is not None else "None", + str(lora2) if lora2 is not None else "None", + str(lora3) if lora3 is not None else "None", + str(lora4) if lora4 is not None 
else "None"] + + # Convert multipliers to float + try: + lora_multipliers = [float(lora1_multiplier), float(lora2_multiplier), + float(lora3_multiplier), float(lora4_multiplier)] + except (ValueError, TypeError): + # Fallback to defaults if conversion fails + lora_multipliers = [1.0, 1.0, 1.0, 1.0] + + # Debug print + print(f"Sanitized LoRA weights: {lora_weights}") + print(f"Sanitized LoRA multipliers: {lora_multipliers}") + + # Generate video + for videos, status, progress in wanx_generate_video( + prompt, negative_prompt, input_image, width, height, + video_length, fps, infer_steps, flow_shift, guidance_scale, + seed, task, dit_path, vae_path, t5_path, clip_path, + save_path, output_type, sample_solver, exclude_single_blocks, + attn_mode, block_swap, fp8, fp8_scaled, fp8_t5, lora_folder, + slg_layers, slg_start, slg_end, + lora_weights[0], lora_weights[1], lora_weights[2], lora_weights[3], + lora_multipliers[0], lora_multipliers[1], lora_multipliers[2], lora_multipliers[3], + enable_cfg_skip=False, + cfg_skip_mode="none", + cfg_apply_ratio=0.7 + ): + + # Keep track of generated videos + if videos: + all_videos = videos + + # Forward progress updates + yield all_videos, status, progress + + # Now concatenate with base video if we have something + if all_videos and base_video_path and os.path.exists(base_video_path): + try: + print(f"Extending base video: {base_video_path}") + + # Create unique output filename + timestamp = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S") + output_filename = f"extended_{timestamp}_seed{seed}_{Path(base_video_path).stem}.mp4" + output_path = os.path.join(save_path, output_filename) + + # Extract the path from the gallery item + new_video_path = all_videos[0][0] if isinstance(all_videos[0], tuple) else all_videos[0] + + # Create a temporary file list for ffmpeg + list_file = os.path.join(save_path, f"temp_list_{seed}.txt") + with open(list_file, "w") as f: + f.write(f"file '{os.path.abspath(base_video_path)}'\n") + f.write(f"file '{os.path.abspath(new_video_path)}'\n") + + print(f"Concatenating: {base_video_path} + {new_video_path}") + + # Run ffmpeg concatenation + command = [ + "ffmpeg", + "-f", "concat", + "-safe", "0", + "-i", list_file, + "-c", "copy", + "-y", + output_path + ] + + subprocess.run(command, check=True, capture_output=True) + + # Clean up temporary file + if os.path.exists(list_file): + os.remove(list_file) + + # Return the extended video if successful + if os.path.exists(output_path): + extended_video = [(output_path, f"Extended (Seed: {seed})")] + print(f"Successfully created extended video: {output_path}") + yield extended_video, "Extended video created successfully", "" + return + else: + print(f"Failed to create extended video at {output_path}") + except Exception as e: + print(f"Error creating extended video: {str(e)}") + + # If we got here, something went wrong with the concatenation + yield all_videos, "Generated video (extension failed)", "" + +def process_batch_extension( + prompt, negative_prompt, input_image, base_video, + width, height, video_length, fps, infer_steps, + flow_shift, guidance_scale, seed, batch_size, + task, dit_folder, dit_path, vae_path, t5_path, clip_path, # <<< Added dit_folder + save_path, output_type, sample_solver, exclude_single_blocks, + attn_mode, block_swap, fp8, fp8_scaled, fp8_t5, lora_folder, + slg_layers, slg_start, slg_end, + lora1="None", lora2="None", lora3="None", lora4="None", + lora1_multiplier=1.0, lora2_multiplier=1.0, lora3_multiplier=1.0, lora4_multiplier=1.0 +): + 
"""Process a batch of video extensions one at a time""" + global stop_event + stop_event.clear() + + all_extended_videos = [] # Store successfully extended videos + progress_text = "Starting video extension batch..." + yield [], progress_text, "" # Initial yield + + try: + # Ensure batch_size is treated as an integer + batch_size = int(batch_size) + except (ValueError, TypeError): + batch_size = 1 + print("Warning: Invalid batch_size, defaulting to 1.") + + # Ensure base_video exists + if not base_video or not os.path.exists(base_video): + yield [], "Error: Base video not found", f"Cannot find video at {base_video}" + return + + # Process each batch item independently + for i in range(batch_size): + if stop_event.is_set(): + yield all_extended_videos, "Extension stopped by user", "" + return + + # Calculate seed for this batch item + current_seed = seed + if seed == -1: + current_seed = random.randint(0, 2**32 - 1) + elif batch_size > 1: + current_seed = seed + i + + batch_text = f"Processing extension {i+1}/{batch_size} (seed: {current_seed})" + yield all_extended_videos, batch_text, progress_text # Update progress + + # Use the direct wrapper with correct parameter order, including dit_folder + generation_iterator = wanx_extend_video_wrapper( + prompt=prompt, negative_prompt=negative_prompt, input_image=input_image, base_video_path=base_video, + width=width, height=height, video_length=video_length, fps=fps, infer_steps=infer_steps, + flow_shift=flow_shift, guidance_scale=guidance_scale, seed=current_seed, + task=task, + dit_folder=dit_folder, # <<< Pass the folder path + dit_path=dit_path, # <<< Pass the model filename + vae_path=vae_path, + t5_path=t5_path, + clip_path=clip_path, + save_path=save_path, output_type=output_type, sample_solver=sample_solver, + exclude_single_blocks=exclude_single_blocks, attn_mode=attn_mode, block_swap=block_swap, + fp8=fp8, fp8_scaled=fp8_scaled, fp8_t5=fp8_t5, lora_folder=lora_folder, + slg_layers=slg_layers, slg_start=slg_start, slg_end=slg_end, + lora1=lora1, lora2=lora2, lora3=lora3, lora4=lora4, + lora1_multiplier=lora1_multiplier, lora2_multiplier=lora2_multiplier, + lora3_multiplier=lora3_multiplier, lora4_multiplier=lora4_multiplier + ) + + # Iterate through the generator for this single extension + final_videos_for_item = [] + final_status_for_item = "Unknown status" + final_progress_for_item = "" + try: + for videos, status, progress in generation_iterator: + # Forward progress information immediately + yield all_extended_videos, f"Batch {i+1}/{batch_size}: {status}", progress + + # Store the latest state for this item + final_videos_for_item = videos + final_status_for_item = status + final_progress_for_item = progress + + # After the loop for one item finishes, check the result + if final_videos_for_item: + # Check if the video is actually an extended one + is_extended = any("Extended" in (v[1] if isinstance(v, tuple) else "") for v in final_videos_for_item) + if is_extended: + all_extended_videos.extend(final_videos_for_item) + print(f"Added extended video to collection (total: {len(all_extended_videos)})") + else: + # It was just the generated segment, maybe log this? 
+ print(f"Video segment generated for batch {i+1} but extension failed or wasn't performed.") + else: + print(f"No video returned for batch item {i+1}.") + + + except Exception as e: + print(f"Error during single extension processing (batch {i+1}): {e}") + yield all_extended_videos, f"Error in batch {i+1}: {e}", "" + + + # Clean CUDA cache between generations + clear_cuda_cache() + time.sleep(0.5) + + # Final yield after the loop + yield all_extended_videos, "Batch extension complete", "" + +def handle_extend_generation(base_video_path: str, new_videos: list, save_path: str, current_gallery: list) -> tuple: + """Combine generated video with base video and update gallery""" + if not base_video_path: + return current_gallery, "Extend failed: No base video provided" + + if not new_videos: + return current_gallery, "Extend failed: No new video generated" + + # Ensure save path exists + os.makedirs(save_path, exist_ok=True) + + # Get the first video from new_videos (gallery item) + new_video_path = new_videos[0][0] if isinstance(new_videos[0], tuple) else new_videos[0] + + # Create a unique output filename + timestamp = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S") + output_filename = f"extended_{timestamp}_{Path(base_video_path).stem}.mp4" + output_path = str(Path(save_path) / output_filename) + + try: + # Concatenate the videos using ffmpeg + ( + ffmpeg + .input(base_video_path) + .concat( + ffmpeg.input(new_video_path) + ) + .output(output_path) + .run(overwrite_output=True, quiet=True) + ) + + # Create a new gallery entry with the combined video + updated_gallery = [(output_path, f"Extended video: {Path(output_path).stem}")] + + return updated_gallery, f"Successfully extended video to {Path(output_path).name}" + except Exception as e: + print(f"Error extending video: {str(e)}") + return current_gallery, f"Failed to extend video: {str(e)}" + +# UI setup +with gr.Blocks( + theme=themes.Default( + primary_hue=colors.Color( + name="custom", + c50="#E6F0FF", + c100="#CCE0FF", + c200="#99C1FF", + c300="#66A3FF", + c400="#3384FF", + c500="#0060df", # This is your main color + c600="#0052C2", + c700="#003D91", + c800="#002961", + c900="#001430", + c950="#000A18" + ) + ), + css=""" + .gallery-item:first-child { border: 2px solid #4CAF50 !important; } + .gallery-item:first-child:hover { border-color: #45a049 !important; } + .green-btn { + background: linear-gradient(to bottom right, #2ecc71, #27ae60) !important; + color: white !important; + border: none !important; + } + .green-btn:hover { + background: linear-gradient(to bottom right, #27ae60, #219651) !important; + } + .refresh-btn { + max-width: 40px !important; + min-width: 40px !important; + height: 40px !important; + border-radius: 50% !important; + padding: 0 !important; + display: flex !important; + align-items: center !important; + justify-content: center !important; + } + .light-blue-btn { + background: linear-gradient(to bottom right, #AEC6CF, #9AB8C4) !important; /* Light blue gradient */ + color: #333 !important; /* Darker text for readability */ + border: 1px solid #9AB8C4 !important; /* Subtle border */ + } + .light-blue-btn:hover { + background: linear-gradient(to bottom right, #9AB8C4, #8AA9B5) !important; /* Slightly darker on hover */ + border-color: #8AA9B5 !important; + } + """, + +) as demo: + # Add state for tracking selected video indices in both tabs + selected_index = gr.State(value=None) # For Text to Video + v2v_selected_index = gr.State(value=None) # For Video to Video + params_state = gr.State() #New 
addition + i2v_selected_index = gr.State(value=None) + skyreels_selected_index = gr.State(value=None) + wanx_i2v_selected_index = gr.State(value=None) + extended_videos = gr.State(value=[]) + wanx_base_video = gr.State(value=None) + wanx_sharpest_frame_number = gr.State(value=None) + wanx_sharpest_frame_path = gr.State(value=None) + wanx_trimmed_video_path = gr.State(value=None) + wanx_v2v_selected_index = gr.State(value=None) + wanx_t2v_selected_index = gr.State(value=None) + framepack_selected_index = gr.State(value=None) + framepack_original_dims = gr.State(value="") + fpe_selected_index = gr.State(value=None) + demo.load(None, None, None, js=""" + () => { + document.title = 'H1111'; + + function updateTitle(text) { + if (text && text.trim()) { + // Regex for the FramePack format: "Item ... (...)% | ... Remaining: HH:MM" + const framepackMatch = text.match(/.*?\((\d+)%\).*?Remaining:\s*(\d{2}:\d{2})/); + // Regex for standard tqdm format (like WanX uses) + const tqdmMatch = text.match(/(\d+)%\|.*\[.*<(\d{2}:\d{2})/); // Adjusted slightly for robustness + + if (framepackMatch) { + // Handle FramePack format + const percentage = framepackMatch[1]; + const timeRemaining = framepackMatch[2]; + document.title = `[${percentage}% ETA: ${timeRemaining}] - H1111`; + } else if (tqdmMatch) { // <<< ADDED ELSE IF for standard tqdm + // Handle standard tqdm format + const percentage = tqdmMatch[1]; + const timeRemaining = tqdmMatch[2]; + document.title = `[${percentage}% ETA: ${timeRemaining}] - H1111`; + } else { + // Optional: Reset title if neither format matches? + // document.title = 'H1111'; + } + } + } + + setTimeout(() => { + // This selector should still find all relevant progress textareas + const progressElements = document.querySelectorAll('textarea.scroll-hide'); + progressElements.forEach(element => { + if (element) { + new MutationObserver(() => { + updateTitle(element.value); + }).observe(element, { + attributes: true, + childList: true, + characterData: true + }); + } + }); + }, 1000); + } + """) + + with gr.Tabs() as tabs: + + #FRAME PACK TAB + with gr.Tab(id=10, label="FramePack") as framepack_tab: + + with gr.Row(): + with gr.Column(scale=4): + framepack_prompt = gr.Textbox( + scale=3, label="Prompt (Supports sections: index:prompt;;;index:prompt)", + value="cinematic video of a cat wizard casting a spell", lines=3, + info="Use '0:prompt;;;-1:prompt' or '0-2:prompt;;;3:prompt'. Index total sections -1 is last section." 
+ ) + framepack_negative_prompt = gr.Textbox(scale=3, label="Negative Prompt", value="", lines=3) + with gr.Column(scale=1): + framepack_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) + framepack_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) + framepack_is_f1 = gr.Checkbox(label="🏎️ Use F1 Model", value=False, + info="Switches to the F1 model (different DiT path and logic).") + with gr.Column(scale=2): + framepack_batch_progress = gr.Textbox(label="Status", interactive=False, value="") + framepack_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") + with gr.Row(): + framepack_generate_btn = gr.Button("Generate FramePack Video", elem_classes="green-btn") + framepack_stop_btn = gr.Button("Stop Generation", variant="stop") + + # Main Content + with gr.Row(): + # --- Left Column --- + with gr.Column(): + framepack_input_image = gr.Image(label="Input Image (Video Start)", type="filepath") + with gr.Row(): + framepack_use_random_folder = gr.Checkbox(label="Use Random Images from Folder", value=False, + info="If checked, 'Input Image (Video Start)' is hidden. Each batch item uses a random image from the folder.") + framepack_input_folder_path = gr.Textbox( + label="Image Folder Path", + placeholder="Path to folder containing images for batch processing", + visible=False # Initially hidden + ) + with gr.Row(visible=False) as framepack_folder_options_row: # Parent Row for folder options + framepack_validate_folder_btn = gr.Button("Validate Folder") + framepack_folder_status_text = gr.Textbox( + label="Folder Status", + placeholder="Validation status will appear here", + interactive=False + ) + with gr.Accordion("Optional End Frame Control (normal model only)", open=False): + framepack_input_end_frame = gr.Image(label="End Frame Image (Video End)", type="filepath", scale=1) + framepack_end_frame_influence = gr.Dropdown( + label="End Frame Influence Mode", + choices=["last", "half", "progressive", "bookend"], + value="last", + info="How the end frame affects generation (if provided)", + visible=False + ) + framepack_end_frame_weight = gr.Slider( + minimum=0.0, maximum=1.0, step=0.05, value=0.5, # Default changed from 0.3 + label="End Frame Weight", + info="Influence strength of the end frame (if provided)", + visible=False + ) + + gr.Markdown("### Resolution Options (Choose One)") + framepack_target_resolution = gr.Number( + label="Option 1: Target Resolution (Uses Buckets)", + value=640, minimum=0, maximum=1280, step=32, + info="Target bucket size (e.g., 640 for 640x640). Uses input image aspect ratio. Final size divisible by 32.", + interactive=True + ) + with gr.Accordion("Option 2: Explicit Resolution (Overrides Option 1)", open=False): + framepack_scale_slider = gr.Slider( + minimum=1, maximum=200, value=100, step=1, label="Scale % (UI Only)" + ) + with gr.Row(): + framepack_width = gr.Number( + label="Width", value=None, minimum=0, step=32, + info="Must be divisible by 32.", interactive=True + ) + framepack_calc_height_btn = gr.Button("→") + framepack_calc_width_btn = gr.Button("←") + framepack_height = gr.Number( + label="Height", value=None, minimum=0, step=32, + info="Must be divisible by 32.", interactive=True + ) + framepack_total_second_length = gr.Slider(minimum=1.0, maximum=120.0, step=0.5, label="Total Video Length (seconds)", value=5.0) + framepack_video_sections = gr.Number( + label="Total Video Sections (Overrides seconds if > 0)", + value=None, step=1, + info="Specify exact number of sections. 
If set, 'Total Video Length (seconds)' is ignored by the backend." + ) + framepack_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Output FPS", value=30) + with gr.Row(): + framepack_seed = gr.Number(label="Seed (-1 for random)", value=-1) + framepack_random_seed =gr.Button("🎲️") + framepack_steps = gr.Slider(minimum=1, maximum=100, step=1, label="Steps", value=25, interactive=True) # Moved here + + # --- Right Column --- + with gr.Column(): + framepack_output = gr.Gallery( + label="Generated Videos (Click to select)", + columns=[2], rows=[1], + object_fit="contain", height="auto", show_label=True, + elem_id="gallery_framepack", allow_preview=True, preview=True + ) + with gr.Accordion("Latent Preview (During Generation)", open=True): + with gr.Row(): + framepack_enable_preview = gr.Checkbox(label="Enable Latent Preview", value=True) + framepack_use_full_video_preview = gr.Checkbox(label="Use Full Video Previews (slower)", value=False) + with gr.Row(): + framepack_preview_every_n_sections = gr.Slider( + minimum=1, maximum=50, step=1, value=1, + label="Preview Every N Sections", + info="Generates previews during the sampling loop." + ) + framepack_preview_output = gr.Video( # Changed from Gallery to Video + label="Latest Preview", height=300, + interactive=False, # Not interactive for display + elem_id="framepack_preview_video" + ) + framepack_skip_btn = gr.Button("Skip Batch Item", elem_classes="light-blue-btn") + with gr.Group(): + with gr.Row(): + framepack_refresh_lora_btn = gr.Button("🔄 LoRA", elem_classes="refresh-btn") # Specific LoRA refresh + framepack_lora_folder = gr.Textbox(label="LoRa Folder", value="lora", scale=4) + framepack_lora_weights = [] + framepack_lora_multipliers = [] + for i in range(4): # Assuming max 4 LoRAs like other tabs + with gr.Row(): + framepack_lora_weights.append(gr.Dropdown( + label=f"LoRA {i+1}", choices=get_lora_options("lora"), + value="None", allow_custom_value=False, interactive=True, scale=2 + )) + framepack_lora_multipliers.append(gr.Slider( + label=f"Multiplier", minimum=0.0, maximum=2.0, step=0.05, value=1.0, scale=1, interactive=True + )) + # Fixed Generation Parameters Section + with gr.Accordion("Generation Parameters", open=True): + with gr.Row(): + framepack_distilled_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.1, label="Distilled Guidance Scale (embedded_cfg_scale)", value=10.0, interactive=True) + framepack_guidance_scale = gr.Slider(minimum=1.0, maximum=10.0, step=0.1, label="Guidance Scale (CFG)", value=1.0, interactive=True, info="Default 1.0 (no CFG), backend recommends not changing.") + with gr.Row(): + framepack_guidance_rescale = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="CFG Rescale (rs)", value=0.0, interactive=True, info="Default 0.0, backend recommends not changing.") + framepack_latent_window_size = gr.Number(label="Latent Window Size", value=9, interactive=True, info="Default 9") + framepack_sample_solver = gr.Dropdown(label="Sample Solver", choices=["unipc", "dpm++", "vanilla"], value="unipc", interactive=True) + + with gr.Accordion("Advanced Section Control (Optional)", open=False): + gr.Markdown( + "Define specific prompts and starting images for different sections of the video. " + "For the index you can input a range or a single index. A 5 second default video has 4 sections. 
The first section is 0 and the last is 3" + ) + # --- Define section controls explicitly --- + with gr.Row(): + with gr.Column(scale=1): + with gr.Group(): + gr.Markdown("**--- Control Slot 1 ---**") + with gr.Row(): + + framepack_sec_1 = gr.Textbox(label="Index/Range", value="0", placeholder="e.g., 0 or 0-1", interactive=True) + framepack_sec_prompt_1 = gr.Textbox(label="Prompt Override", lines=2, placeholder="Overrides base prompt for these sections") + framepack_sec_image_1 = gr.Image(label="Start Image Override", type="filepath", scale=1) + with gr.Column(scale=1): + with gr.Group(): + gr.Markdown("**--- Control Slot 2 ---**") + with gr.Row(): + + framepack_sec_2 = gr.Textbox(label="Index/Range", value="1", placeholder="e.g., 2 or 2-3", interactive=True) + framepack_sec_prompt_2 = gr.Textbox(label="Prompt Override", lines=2) + framepack_sec_image_2 = gr.Image(label="Start Image Override", type="filepath", scale=1) + with gr.Row(): + with gr.Column(scale=1): + with gr.Group(): + gr.Markdown("**--- Control Slot 3 ---**") + with gr.Row(): + + framepack_sec_3 = gr.Textbox(label="Index/Range", value="2", placeholder="e.g., 4 or 4-5", interactive=True) + framepack_sec_prompt_3 = gr.Textbox(label="Prompt Override", lines=2) + framepack_sec_image_3 = gr.Image(label="Start Image Override", type="filepath", scale=1) + with gr.Column(scale=1): + with gr.Group(): + gr.Markdown("**--- Control Slot 4 ---**") + with gr.Row(): + framepack_sec_4 = gr.Textbox(label="Index/Range", value="3", placeholder="e.g., 6 or 6-7", interactive=True) + framepack_sec_prompt_4 = gr.Textbox(label="Prompt Override", lines=2) + framepack_sec_image_4 = gr.Image(label="Start Image Override", type="filepath", scale=1) + + # Group section control components for easier passing to functions (remains the same) + framepack_secs = [framepack_sec_1, framepack_sec_2, framepack_sec_3, framepack_sec_4] + framepack_sec_prompts = [framepack_sec_prompt_1, framepack_sec_prompt_2, framepack_sec_prompt_3, framepack_sec_prompt_4] + framepack_sec_images = [framepack_sec_image_1, framepack_sec_image_2, framepack_sec_image_3, framepack_sec_image_4] + + # Performance/Memory Accordion - Updated + with gr.Accordion("Performance / Memory", open=True): + with gr.Row(): + framepack_fp8 = gr.Checkbox(label="Use FP8 DiT", value=False, info="Enable FP8 precision for the main Transformer model.") + framepack_fp8_llm = gr.Checkbox(label="Use FP8 LLM (Text Encoder 1)", value=False, info="Enable FP8 for the Llama text encoder.", visible=False) + framepack_fp8_scaled = gr.Checkbox(label="Use Scaled FP8 DiT", value=False, info="Requires FP8 DiT. 
Use scaled math (potential quality improvement).") + framepack_blocks_to_swap = gr.Slider(minimum=0, maximum=39, step=1, label="Blocks to Swap (to Save VRAM, 0=disable)", value=26, + info="Higher values = less VRAM usage but slower generation") + framepack_bulk_decode = gr.Checkbox(label="Bulk Decode Frames (Faster Decode, Higher VRAM)", value=False, info="Decode all frames at once instead of section by section.") + with gr.Row(): + framepack_attn_mode = gr.Dropdown( + label="Attention Mode", + choices=["torch", "sdpa", "flash", "xformers", "sageattn"], # Added choices from script + value="sdpa", # Defaulting to sdpa + interactive=True + ) + framepack_vae_chunk_size = gr.Number(label="VAE Chunk Size (CausalConv3d)", value=32, step=1, minimum=0, info="0 or None=disable (Default: None)") + framepack_vae_spatial_tile_sample_min_size = gr.Number(label="VAE Spatial Tile Min Size", value=128, step=16, minimum=0, info="0 or None=disable (Default: None)") + framepack_device = gr.Textbox(label="Device Override (optional)", placeholder="e.g., cuda:0, cpu") + with gr.Row(): + framepack_use_teacache = gr.Checkbox(label="Use TeaCache", value=False, info="Enable TeaCache for faster generation (shits hands).") + framepack_teacache_steps = gr.Number(label="TeaCache Init Steps", value=25, step=1, minimum=1, info="Steps for TeaCache init (match Inference Steps)") + framepack_teacache_thresh = gr.Slider(label="TeaCache Threshold", minimum=0.0, maximum=1.0, step=0.01, value=0.15, info="Relative L1 distance threshold for skipping.") + + with gr.Accordion("Model Paths / Advanced", open=False): + with gr.Row(): + framepack_transformer_path = gr.Textbox(label="Transformer Path (DiT)", value="hunyuan/FramePackI2V_HY_bf16.safetensors", interactive=True) + framepack_vae_path = gr.Textbox(label="VAE Path", value="hunyuan/pytorch_model.pt") + with gr.Row(): + framepack_text_encoder_path = gr.Textbox(label="Text Encoder 1 (Llama) Path *Required*", value="hunyuan/llava_llama3_fp16.safetensors") + framepack_text_encoder_2_path = gr.Textbox(label="Text Encoder 2 (CLIP) Path *Required*", value="hunyuan/clip_l.safetensors") + with gr.Row(): + framepack_image_encoder_path = gr.Textbox(label="Image Encoder (SigLIP) Path *Required*", value="hunyuan/model.safetensors") + framepack_save_path = gr.Textbox(label="Save Path *Required*", value="outputs") +### FRAMEPACK EXTENSION + with gr.Tab(id=11, label="FramePack-Extension") as framepack_extension_tab: + with gr.Row(): + with gr.Column(scale=4): + fpe_prompt = gr.Textbox( + scale=3, label="Prompt", + value="cinematic video of a cat wizard casting a spell, epic action scene", lines=3 + ) + fpe_negative_prompt = gr.Textbox(scale=3, label="Negative Prompt", value="", lines=3) + with gr.Column(scale=1): + fpe_use_normal_framepack = gr.Checkbox(label="Use Normal FramePack Model", value=False, info="Uses og model supports end frame. 
Default is F1 model.") + fpe_batch_count = gr.Number(label="Batch Count", value=1, minimum=1, step=1) + with gr.Column(scale=2): + fpe_batch_progress = gr.Textbox(label="Status", interactive=False, value="") + fpe_progress_text = gr.Textbox(label="Progress", interactive=False, lines=1, elem_id="fpe_progress_text") # Unique elem_id + + with gr.Row(): + fpe_generate_btn = gr.Button("Generate Extended Video", elem_classes="green-btn") + fpe_stop_btn = gr.Button("Stop Generation", variant="stop") + + with gr.Row(): + with gr.Column(): # Left column for inputs + fpe_input_video = gr.Video(label="Input Video for Extension", sources=['upload'], height=300) + with gr.Accordion("Optional End Frame (for Normal FramePack Model)", open=False, visible=False) as fpe_end_frame_accordion: + fpe_end_frame = gr.Image(label="End Frame for Extension", type="filepath") + fpe_end_frame_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=1.0, label="End Frame Weight") + with gr.Accordion("Optional Start Guidance Image (for F1 Model Extension)", open=False, visible=True) as fpe_start_guidance_accordion: # Initially hidden + fpe_start_guidance_image = gr.Image(label="Start Guidance Image for Extension", type="filepath") + fpe_start_guidance_image_clip_weight = gr.Slider( + minimum=0.0, maximum=2.0, step=0.05, value=0.75, + label="Start Guidance Image CLIP Weight", + info="Blend weight for the guidance image's CLIP embedding with input video's first frame CLIP." + ) + fpe_use_guidance_image_as_first_latent = gr.Checkbox( + label="Use Guidance Image as First Latent", value=False, + info="If checked, the VAE latent of the guidance image will be used as the initial conditioning for the first generated segment. Turn down context frames when using this" + ) + gr.Markdown("### Core Generation Parameters") + with gr.Row(): + fpe_seed = gr.Number(label="Seed (-1 for random)", value=-1) + # fpe_random_seed_btn = gr.Button("🎲️") # Optional: Add random seed button + + fpe_resolution_max_dim = gr.Number(label="Resolution (Max Dimension)", value=640, step=32, info="Target max width/height for bucket.") + fpe_total_second_length = gr.Slider(minimum=1.0, maximum=120.0, step=0.5, label="Additional Video Length (seconds)", value=5.0) + fpe_latent_window_size = gr.Slider(minimum=9, maximum=33, step=1, label="Latent Window Size", value=9, info="Default 9 for F1 model.") + fpe_steps = gr.Slider(minimum=1, maximum=100, step=1, label="Inference Steps", value=25) + + with gr.Row(): + fpe_cfg_scale = gr.Slider(minimum=1.0, maximum=32.0, step=0.1, label="CFG Scale", value=1.0, info="Usually 1.0 for F1 (no external CFG).") + fpe_distilled_guidance_scale = gr.Slider(minimum=1.0, maximum=32.0, step=0.1, label="Distilled Guidance (GS)", value=3.0) + # fpe_rs_scale = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="CFG Rescale (RS)", value=0.0, visible=False) + with gr.Row(): + with gr.Accordion("Advanced & Performance", open=True): + fpe_gpu_memory_preservation = gr.Slider(label="GPU Memory Preserve (GB)", minimum=1.0, maximum=16.0, value=6.0, step=0.1) + fpe_use_teacache = gr.Checkbox(label="Use TeaCache", value=False) + fpe_no_resize = gr.Checkbox(label="Force Original Video Resolution (No Resize)", value=False) + fpe_extension_only = gr.Checkbox(label="Save Extension Only", value=False, info="If checked, only the newly generated extension part of the video will be saved.") + fpe_mp4_crf = gr.Slider(label="MP4 CRF (Quality)", minimum=0, maximum=51, value=1, step=1, info="Lower is better quality, larger file.") + 
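
The FramePack tabs let you specify length either in seconds or in sections ("A 5 second default video has 4 sections"), and this extension tab exposes the same `Latent Window Size` control next to `Additional Video Length (seconds)`, but the mapping between the two is never spelled out. A minimal sketch of the arithmetic, assuming the common FramePack convention that one section covers roughly `latent_window_size * 4` output frames at the 30 fps default used elsewhere in these tabs (the backend's exact rounding may differ), is:

```python
def estimate_sections(total_seconds: float, latent_window_size: int = 9, fps: int = 30) -> int:
    """Hypothetical helper: rough number of generation sections for a requested length.

    Assumes each latent covers 4 output frames, so one section spans about
    latent_window_size * 4 frames; the backend's own rounding may differ.
    """
    return max(1, round((total_seconds * fps) / (latent_window_size * 4)))

print(estimate_sections(5.0))   # ~4 sections, matching the "5 second video has 4 sections" note
print(estimate_sections(10.0))  # ~8 sections with the default window size of 9
```
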
fpe_num_clean_frames = gr.Slider(label="Context Frames (1x from Input)", minimum=1, maximum=10, value=5, step=1) + fpe_vae_batch_size = gr.Slider(label="VAE Batch Size (Input Video Encoding)", minimum=4, maximum=128, value=72, step=4) + + fpe_attn_mode = gr.Dropdown(label="Attention Mode (DiT)", choices=["torch", "sdpa", "flash", "xformers", "sageattn"], value="torch") + fpe_fp8_llm = gr.Checkbox(label="Use FP8 LLM (Text Encoder 1)", value=False, visible=False) + fpe_vae_chunk_size = gr.Number(label="VAE Chunk Size (CausalConv3d)", value=32, step=1, minimum=0, info="0 or None=disable") + fpe_vae_spatial_tile_sample_min_size = gr.Number(label="VAE Spatial Tile Min Size", value=128, step=16, minimum=0, info="0 or None=disable") + + + with gr.Column(): # Right column for outputs and advanced settings + fpe_output_gallery = gr.Gallery( + label="Generated Extended Videos", columns=[1], rows=[1], # Show one main video at a time + object_fit="contain", height=480, show_label=True, + elem_id="gallery_framepack_extension", allow_preview=True, preview=True + ) + with gr.Accordion("Live Preview (During Generation)", open=True): + with gr.Row(): + fpe_enable_preview = gr.Checkbox(label="Enable Live Preview", value=True, visible=False) + fpe_preview_interval = gr.Slider( + minimum=1, maximum=50, step=1, value=5, + label="Preview Every N Steps", + info="Saves a PNG preview during sampling.", + visible=False + ) + fpe_preview_output_component = gr.Video( # Changed to Video for MP4 previews + label="Latest Section Preview", height=300, + interactive=False, elem_id="fpe_preview_video" + ) + # fpe_skip_btn = gr.Button("Skip Batch Item", elem_classes="light-blue-btn") # Optional + gr.Markdown("### LoRA Configuration") + with gr.Row(): + fpe_refresh_lora_btn = gr.Button("🔄 LoRA", elem_classes="refresh-btn") + fpe_lora_folder = gr.Textbox(label="LoRA Folder", value="lora", scale=4) + fpe_lora_weights_ui = [] + fpe_lora_multipliers_ui = [] + for i in range(4): + with gr.Row(): + fpe_lora_weights_ui.append(gr.Dropdown( + label=f"LoRA {i+1}", choices=get_lora_options("lora"), + value="None", allow_custom_value=False, interactive=True, scale=2 + )) + fpe_lora_multipliers_ui.append(gr.Slider( + label=f"Multiplier", minimum=0.0, maximum=2.0, step=0.05, value=1.0, scale=1, interactive=True + )) + with gr.Row(): + with gr.Accordion("Model Paths (FramePack-Extension)", open=False): + fpe_transformer_path = gr.Textbox(label="DiT Path (F1 Model)", value="hunyuan/FramePack_F1_I2V_HY_20250503.safetensors") # Default to F1 + fpe_vae_path = gr.Textbox(label="VAE Path", value="hunyuan/pytorch_model.pt") + fpe_text_encoder_path = gr.Textbox(label="Text Encoder 1 (Llama)", value="hunyuan/llava_llama3_fp16.safetensors") + fpe_text_encoder_2_path = gr.Textbox(label="Text Encoder 2 (CLIP)", value="hunyuan/clip_l.safetensors") + fpe_image_encoder_path = gr.Textbox(label="Image Encoder (SigLIP)", value="hunyuan/model.safetensors") + fpe_save_path = gr.Textbox(label="Save Path (Output Directory)", value="outputs/framepack_extensions") + + # Text to Video Tab + with gr.Tab(id=1, label="Hunyuan-t2v"): + with gr.Row(): + with gr.Column(scale=4): + prompt = gr.Textbox(scale=3, label="Enter your prompt", value="POV video of a cat chasing a frob.", lines=5) + + with gr.Column(scale=1): + token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) + batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) + + with gr.Column(scale=2): + batch_progress = gr.Textbox(label="", visible=True, 
elem_id="batch_progress") + progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") + + with gr.Row(): + generate_btn = gr.Button("Generate Video", elem_classes="green-btn") + stop_btn = gr.Button("Stop Generation", variant="stop") + + with gr.Row(): + with gr.Column(): + + t2v_width = gr.Slider(minimum=64, maximum=1536, step=16, value=544, label="Video Width") + t2v_height = gr.Slider(minimum=64, maximum=1536, step=16, value=544, label="Video Height") + video_length = gr.Slider(minimum=1, maximum=201, step=1, label="Video Length in Frames", value=25, elem_id="my_special_slider") + fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=24, elem_id="my_special_slider") + infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30, elem_id="my_special_slider") + flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=11.0, elem_id="my_special_slider") + cfg_scale = gr.Slider(minimum=0.0, maximum=14.0, step=0.1, label="cfg Scale", value=7.0, elem_id="my_special_slider") + + with gr.Column(): + + with gr.Row(): + video_output = gr.Gallery( + label="Generated Videos (Click to select)", + columns=[2], + rows=[2], + object_fit="contain", + height="auto", + show_label=True, + elem_id="gallery", + allow_preview=True, + preview=True + ) + with gr.Row():send_t2v_to_v2v_btn = gr.Button("Send Selected to Video2Video") + + with gr.Row(): + refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") + lora_weights = [] + lora_multipliers = [] + for i in range(4): + with gr.Column(): + lora_weights.append(gr.Dropdown( + label=f"LoRA {i+1}", + choices=get_lora_options(), + value="None", + allow_custom_value=True, + interactive=True + )) + lora_multipliers.append(gr.Slider( + label=f"Multiplier", + minimum=0.0, + maximum=2.0, + step=0.05, + value=1.0 + )) + with gr.Row(): + exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) + seed = gr.Number(label="Seed (use -1 for random)", value=-1) + dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") + model = gr.Dropdown( + label="DiT Model", + choices=get_dit_models("hunyuan"), + value="mp_rank_00_model_states.pt", + allow_custom_value=True, + interactive=True + ) + vae = gr.Textbox(label="vae", value="hunyuan/pytorch_model.pt") + te1 = gr.Textbox(label="te1", value="hunyuan/llava_llama3_fp16.safetensors") + te2 = gr.Textbox(label="te2", value="hunyuan/clip_l.safetensors") + save_path = gr.Textbox(label="Save Path", value="outputs") + with gr.Row(): + lora_folder = gr.Textbox(label="LoRA Folder", value="lora") + output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") + use_split_attn = gr.Checkbox(label="Use Split Attention", value=False) + use_fp8 = gr.Checkbox(label="Use FP8 (faster but lower precision)", value=True) + attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") + block_swap = gr.Slider(minimum=0, maximum=36, step=1, label="Block Swap to Save Vram", value=0) + + #Image to Video Tab + with gr.Tab(label="Hunyuan-i2v") as i2v_tab: # Keep tab name consistent if needed elsewhere + # ... (Keep existing Rows for prompt, batch size, progress) ... 
+ with gr.Row(): + with gr.Column(scale=4): + i2v_prompt = gr.Textbox(scale=3, label="Enter your prompt", value="POV video of a cat chasing a frob.", lines=5) + + with gr.Column(scale=1): + i2v_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) + i2v_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) + + with gr.Column(scale=2): + i2v_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress_i2v") # Unique elem_id + i2v_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text_i2v") # Unique elem_id + + with gr.Row(): + i2v_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") + i2v_stop_btn = gr.Button("Stop Generation", variant="stop") + + + with gr.Row(): + with gr.Column(): + i2v_input = gr.Image(label="Input Image", type="filepath") + # REMOVED i2v_strength slider, as hv_i2v_generate_video.py doesn't seem to use it based on the sample command + # i2v_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.75, label="Denoise Strength") + scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale % (UI Only - affects W/H)") # Clarified UI only + original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) + # Width and height inputs + with gr.Row(): + # Renamed width/height to avoid potential conflicts if they weren't already prefixed + i2v_width = gr.Number(label="New Width", value=720, step=16) # Default from sample + calc_height_btn = gr.Button("→") + calc_width_btn = gr.Button("←") + i2v_height = gr.Number(label="New Height", value=720, step=16) # Default from sample + i2v_video_length = gr.Slider(minimum=1, maximum=201, step=1, label="Video Length in Frames", value=49) # Default from sample + i2v_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=24) # Default from sample + i2v_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30) # Default from sample + i2v_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=17.0) # Default from sample + i2v_cfg_scale = gr.Slider(minimum=0.0, maximum=14.0, step=0.1, label="Embedded CFG Scale", value=7.0) # Default from sample + i2v_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.1, label="Guidance Scale (CFG)", value=1.0) # Default from sample (usually 1.0 for no CFG) + + with gr.Column(): + i2v_output = gr.Gallery( + label="Generated Videos (Click to select)", + columns=[2], + rows=[2], + object_fit="contain", + height="auto", + show_label=True, + elem_id="gallery_i2v", # Unique elem_id + allow_preview=True, + preview=True + ) + i2v_send_to_v2v_btn = gr.Button("Send Selected to Hunyuan-v2v") # Keep sending to original V2V + + # Add LoRA section for Image2Video + i2v_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") + i2v_lora_weights = [] + i2v_lora_multipliers = [] + for i in range(4): + with gr.Column(): + i2v_lora_weights.append(gr.Dropdown( + label=f"LoRA {i+1}", + choices=get_lora_options(), + value="None", + allow_custom_value=True, + interactive=True + )) + i2v_lora_multipliers.append(gr.Slider( + label=f"Multiplier", + minimum=0.0, + maximum=2.0, + step=0.05, + value=1.0 + )) + + with gr.Row(): + i2v_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) + i2v_seed = gr.Number(label="Seed (use -1 for random)", value=-1) + i2v_dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") + i2v_model = gr.Dropdown( + label="DiT 
Model", + choices=get_dit_models("hunyuan"), + value="mp_rank_00_model_states_i2v.pt", # Default from sample + allow_custom_value=True, + interactive=True + ) + i2v_vae = gr.Textbox(label="VAE Path", value="hunyuan/pytorch_model.pt") # Default from sample + i2v_te1 = gr.Textbox(label="Text Encoder 1 Path", value="hunyuan/llava_llama3_fp16.safetensors") # Default from sample + i2v_te2 = gr.Textbox(label="Text Encoder 2 Path", value="hunyuan/clip_l.safetensors") # Default from sample + i2v_clip_vision_path = gr.Textbox(label="CLIP Vision Path", value="hunyuan/llava_llama3_vision.safetensors") # Default from sample + i2v_save_path = gr.Textbox(label="Save Path", value="outputs") # Default from sample + with gr.Row(): + i2v_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") + i2v_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") # Default from sample + i2v_use_split_attn = gr.Checkbox(label="Use Split Attention", value=False) # Not in sample, keep default False + i2v_use_fp8 = gr.Checkbox(label="Use FP8 DiT", value=False) # Not in sample, keep default False + i2v_fp8_llm = gr.Checkbox(label="Use FP8 LLM", value=False) # Not in sample, keep default False + i2v_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") # Default from sample + i2v_block_swap = gr.Slider(minimum=0, maximum=36, step=1, label="Block Swap to Save Vram", value=30) # Default from sample + # Add VAE tiling options like sample command + i2v_vae_chunk_size = gr.Number(label="VAE Chunk Size", value=32, step=1, info="For CausalConv3d, set 0 to disable") + i2v_vae_spatial_tile_min = gr.Number(label="VAE Spatial Tile Min Size", value=128, step=16, info="Set 0 to disable spatial tiling") + + # Video to Video Tab + with gr.Tab(id=2, label="Hunyuan v2v") as v2v_tab: + with gr.Row(): + with gr.Column(scale=4): + v2v_prompt = gr.Textbox(scale=3, label="Enter your prompt", value="POV video of a cat chasing a frob.", lines=5) + v2v_negative_prompt = gr.Textbox( + scale=3, + label="Negative Prompt (for SkyReels models)", + value="Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion", + lines=3 + ) + + with gr.Column(scale=1): + v2v_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) + v2v_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) + + with gr.Column(scale=2): + v2v_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") + v2v_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") + + with gr.Row(): + v2v_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") + v2v_stop_btn = gr.Button("Stop Generation", variant="stop") + + with gr.Row(): + with gr.Column(): + v2v_input = gr.Video(label="Input Video", format="mp4") + v2v_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.75, label="Denoise Strength") + v2v_scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale %") + v2v_original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) + + # Width and Height Inputs + with gr.Row(): + v2v_width = gr.Number(label="New Width", value=544, step=16) + v2v_calc_height_btn = gr.Button("→") + v2v_calc_width_btn = gr.Button("←") + v2v_height = gr.Number(label="New Height", value=544, step=16) + v2v_video_length = gr.Slider(minimum=1, maximum=201, step=1, 
label="Video Length in Frames", value=25) + v2v_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=24) + v2v_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30) + v2v_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=11.0) + v2v_cfg_scale = gr.Slider(minimum=0.0, maximum=14.0, step=0.1, label="cfg scale", value=7.0) + with gr.Column(): + v2v_output = gr.Gallery( + label="Generated Videos", + columns=[1], + rows=[1], + object_fit="contain", + height="auto" + ) + v2v_send_to_input_btn = gr.Button("Send Selected to Input") # New button + v2v_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") + v2v_lora_weights = [] + v2v_lora_multipliers = [] + for i in range(4): + with gr.Column(): + v2v_lora_weights.append(gr.Dropdown( + label=f"LoRA {i+1}", + choices=get_lora_options(), + value="None", + allow_custom_value=True, + interactive=True + )) + v2v_lora_multipliers.append(gr.Slider( + label=f"Multiplier", + minimum=0.0, + maximum=2.0, + step=0.05, + value=1.0 + )) + + with gr.Row(): + v2v_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) + v2v_seed = gr.Number(label="Seed (use -1 for random)", value=-1) + v2v_dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") + v2v_model = gr.Dropdown( + label="DiT Model", + choices=get_dit_models("hunyuan"), + value="mp_rank_00_model_states.pt", + allow_custom_value=True, + interactive=True + ) + v2v_vae = gr.Textbox(label="vae", value="hunyuan/pytorch_model.pt") + v2v_te1 = gr.Textbox(label="te1", value="hunyuan/llava_llama3_fp16.safetensors") + v2v_te2 = gr.Textbox(label="te2", value="hunyuan/clip_l.safetensors") + v2v_save_path = gr.Textbox(label="Save Path", value="outputs") + with gr.Row(): + v2v_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") + v2v_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") + v2v_use_split_attn = gr.Checkbox(label="Use Split Attention", value=False) + v2v_use_fp8 = gr.Checkbox(label="Use FP8 (faster but lower precision)", value=True) + v2v_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") + v2v_block_swap = gr.Slider(minimum=0, maximum=36, step=1, label="Block Swap to Save Vram", value=0) + v2v_split_uncond = gr.Checkbox(label="Split Unconditional (for SkyReels)", value=True) + +### SKYREELS + + with gr.Tab(label="SkyReels-i2v") as skyreels_tab: + with gr.Row(): + with gr.Column(scale=4): + skyreels_prompt = gr.Textbox( + scale=3, + label="Enter your prompt", + value="A person walking on a beach at sunset", + lines=5 + ) + skyreels_negative_prompt = gr.Textbox( + scale=3, + label="Negative Prompt", + value="Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion", + lines=3 + ) + + with gr.Column(scale=1): + skyreels_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) + skyreels_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) + + with gr.Column(scale=2): + skyreels_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") + skyreels_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") + + with gr.Row(): + skyreels_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") + skyreels_stop_btn = gr.Button("Stop Generation", variant="stop") + + with gr.Row(): + 
with gr.Column(): + skyreels_input = gr.Image(label="Input Image (optional)", type="filepath") + with gr.Row(): + skyreels_use_random_folder = gr.Checkbox(label="Use Random Images from Folder", value=False) + skyreels_input_folder = gr.Textbox( + label="Image Folder Path", + placeholder="Path to folder containing images", + visible=False + ) + skyreels_folder_status = gr.Textbox( + label="Folder Status", + placeholder="Status will appear here", + interactive=False, + visible=False + ) + skyreels_validate_folder_btn = gr.Button("Validate Folder", visible=False) + skyreels_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.75, label="Denoise Strength") + + # Scale slider as percentage + skyreels_scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale %") + skyreels_original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) + + # Width and height inputs + with gr.Row(): + skyreels_width = gr.Number(label="New Width", value=544, step=16) + skyreels_calc_height_btn = gr.Button("→") + skyreels_calc_width_btn = gr.Button("←") + skyreels_height = gr.Number(label="New Height", value=544, step=16) + + skyreels_video_length = gr.Slider(minimum=1, maximum=201, step=1, label="Video Length in Frames", value=25) + skyreels_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=24) + skyreels_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30) + skyreels_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=11.0) + skyreels_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.1, label="Guidance Scale", value=6.0) + skyreels_embedded_cfg_scale = gr.Slider(minimum=0.0, maximum=10.0, step=0.1, label="Embedded CFG Scale", value=1.0) + + with gr.Column(): + skyreels_output = gr.Gallery( + label="Generated Videos (Click to select)", + columns=[2], + rows=[2], + object_fit="contain", + height="auto", + show_label=True, + elem_id="gallery", + allow_preview=True, + preview=True + ) + skyreels_send_to_v2v_btn = gr.Button("Send Selected to Video2Video") + + # Add LoRA section for SKYREELS + skyreels_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") + skyreels_lora_weights = [] + skyreels_lora_multipliers = [] + for i in range(4): + with gr.Column(): + skyreels_lora_weights.append(gr.Dropdown( + label=f"LoRA {i+1}", + choices=get_lora_options(), + value="None", + allow_custom_value=True, + interactive=True + )) + skyreels_lora_multipliers.append(gr.Slider( + label=f"Multiplier", + minimum=0.0, + maximum=2.0, + step=0.05, + value=1.0 + )) + with gr.Row(): + skyreels_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) + skyreels_seed = gr.Number(label="Seed (use -1 for random)", value=-1) + skyreels_dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") + skyreels_model = gr.Dropdown( + label="DiT Model", + choices=get_dit_models("skyreels"), + value="skyreels_hunyuan_i2v_bf16.safetensors", + allow_custom_value=True, + interactive=True + ) + skyreels_vae = gr.Textbox(label="vae", value="hunyuan/pytorch_model.pt") + skyreels_te1 = gr.Textbox(label="te1", value="hunyuan/llava_llama3_fp16.safetensors") + skyreels_te2 = gr.Textbox(label="te2", value="hunyuan/clip_l.safetensors") + skyreels_save_path = gr.Textbox(label="Save Path", value="outputs") + + with gr.Row(): + skyreels_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") + skyreels_output_type = gr.Radio(choices=["video", "images", 
"latent", "both"], label="Output Type", value="video") + skyreels_use_split_attn = gr.Checkbox(label="Use Split Attention", value=False) + skyreels_use_fp8 = gr.Checkbox(label="Use FP8 (faster but lower precision)", value=True) + skyreels_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") + skyreels_block_swap = gr.Slider(minimum=0, maximum=36, step=1, label="Block Swap to Save Vram", value=0) + skyreels_split_uncond = gr.Checkbox(label="Split Unconditional", value=True) + + # WanX Image to Video Tab + with gr.Tab(id=4, label="WanX-i2v") as wanx_i2v_tab: + with gr.Row(): + with gr.Column(scale=4): + wanx_prompt = gr.Textbox( + scale=3, + label="Enter your prompt", + value="A person walking on a beach at sunset", + lines=5 + ) + wanx_negative_prompt = gr.Textbox( + scale=3, + label="Negative Prompt", + value="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + lines=3, + ) + + with gr.Column(scale=1): + wanx_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) + wanx_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) + + with gr.Column(scale=2): + wanx_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") + wanx_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") + + with gr.Row(): + wanx_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") + wanx_stop_btn = gr.Button("Stop Generation", variant="stop") + + with gr.Row(): + with gr.Column(): + wanx_input = gr.Image(label="Input Image", type="filepath") + with gr.Row(): + wanx_use_random_folder = gr.Checkbox(label="Use Random Images from Folder", value=False) + wanx_input_folder = gr.Textbox( + label="Image Folder Path", + placeholder="Path to folder containing images", + visible=False + ) + wanx_folder_status = gr.Textbox( + label="Folder Status", + placeholder="Status will appear here", + interactive=False, + visible=False + ) + wanx_validate_folder_btn = gr.Button("Validate Folder", visible=False) + with gr.Row(): + wanx_use_end_image = gr.Checkbox(label="use ending image", value=False) + wanx_input_end = gr.Image(label="End Image", type="filepath", visible=False) + wanx_trim_frames = gr.Checkbox(label="trim last 3 frames", value=True, visible=False, interactive=True) + + with gr.Row(): + wanx_use_fun_control = gr.Checkbox(label="Use Fun-Control Model", value=False) + wanx_control_video = gr.Video(label="Control Video for Fun-Control", visible=False, format="mp4") + wanx_control_strength = gr.Slider(minimum=0.1, maximum=2.0, step=0.05, value=1.0, + label="Control Strength", visible=False, + info="Adjust influence of control video (1.0 = normal)") + wanx_control_start = gr.Slider( + minimum=0.0, + maximum=1.0, + step=0.01, + value=0.0, + label="Control Start (Fun-Control fade-in)", + visible=False, + info="When (0-1) in the timeline control influence is full after fade-in" + ) + wanx_control_end = gr.Slider( + minimum=0.0, + maximum=1.0, + step=0.01, + value=1.0, + label="Control End (Fun-Control fade-out start)", + visible=False, + info="When (0-1) in the timeline control starts to fade out" + ) + wanx_scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale %") + wanx_original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) + + # Width and height display + with gr.Row(): + wanx_width = gr.Number(label="Width", 
value=832, interactive=True) + wanx_calc_height_btn = gr.Button("→") + wanx_calc_width_btn = gr.Button("←") + wanx_height = gr.Number(label="Height", value=480, interactive=True) + wanx_recommend_flow_btn = gr.Button("Recommend Flow Shift", size="sm") + + wanx_video_length = gr.Slider(minimum=1, maximum=401, step=4, label="Video Length in Frames", value=81) + wanx_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=16) + wanx_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=20) + wanx_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=3.0, + info="Recommended: 3.0 for 480p, 5.0 for others") + wanx_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.5, label="Guidance Scale", value=5.0) + + with gr.Column(): + wanx_output = gr.Gallery( + label="Generated Videos (Click to select)", + columns=[2], + rows=[2], + object_fit="contain", + height="auto", + show_label=True, + elem_id="gallery", + allow_preview=True, + preview=True + ) + with gr.Accordion("Latent Preview (During Generation)", open=True): + wanx_enable_preview = gr.Checkbox(label="Enable Latent Preview", value=True) + wanx_preview_steps = gr.Slider(minimum=1, maximum=50, step=1, value=5, + label="Preview Every N Steps", info="Generates previews during the sampling loop.") + wanx_preview_output = gr.Gallery( + label="Latent Previews", columns=4, rows=2, object_fit="contain", height=300, + allow_preview=True, preview=True, show_label=True, elem_id="wanx_preview_gallery" + ) + wanx_send_to_v2v_btn = gr.Button("Send Selected to Hunyuan-v2v") + wanx_i2v_send_to_wanx_v2v_btn = gr.Button("Send Selected to WanX-v2v") + wanx_send_last_frame_btn = gr.Button("Send Last Frame to Input") + wanx_extend_btn = gr.Button("Extend Video") + wanx_frames_to_check = gr.Slider(minimum=1, maximum=100, step=1, value=30, + label="Frames to Check from End", + info="Number of frames from the end to check for sharpness") + wanx_send_sharpest_frame_btn = gr.Button("Extract Sharpest Frame") + wanx_trim_and_extend_btn = gr.Button("Trim Video & Prepare for Extension") + wanx_sharpest_frame_status = gr.Textbox(label="Status", interactive=False) + + # Add a new button for directly extending with the trimmed video + wanx_extend_with_trimmed_btn = gr.Button("Extend with Trimmed Video") + + # Add LoRA section for WanX-i2v similar to other tabs + wanx_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") + wanx_lora_weights = [] + wanx_lora_multipliers = [] + for i in range(4): + with gr.Column(): + wanx_lora_weights.append(gr.Dropdown( + label=f"LoRA {i+1}", + choices=get_lora_options(), + value="None", + allow_custom_value=True, + interactive=True + )) + wanx_lora_multipliers.append(gr.Slider( + label=f"Multiplier", + minimum=0.0, + maximum=2.0, + step=0.05, + value=1.0 + )) + + with gr.Row(): + wanx_seed = gr.Number(label="Seed (use -1 for random)", value=-1) + # Update the wanx_task dropdown choices to include Fun-Control options + wanx_task = gr.Dropdown( + label="Task", + choices=["i2v-14B", "i2v-14B-FC", "i2v-14B-FC-1.1", "t2v-14B", "t2v-1.3B", "t2v-14B-FC", "t2v-1.3B-FC", "i2v-1.3B-new"], + value="i2v-14B", + info="Select model type. 
*-FC options enable Fun-Control features" + ) + wanx_dit_folder = gr.Textbox(label="DiT Model Folder", value="wan") + wanx_dit_path = gr.Dropdown( + label="DiT Model", + choices=get_dit_models("wan"), # Use the existing function to get available models + value="wan2.1_i2v_720p_14B_fp16.safetensors", + allow_custom_value=True, + interactive=True + ) + wanx_vae_path = gr.Textbox(label="VAE Path", value="wan/Wan2.1_VAE.pth") + wanx_t5_path = gr.Textbox(label="T5 Path", value="wan/models_t5_umt5-xxl-enc-bf16.pth") + wanx_clip_path = gr.Textbox(label="CLIP Path", value="wan/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth") + wanx_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") + wanx_save_path = gr.Textbox(label="Save Path", value="outputs") + + with gr.Row(): + wanx_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") + wanx_sample_solver = gr.Radio(choices=["unipc", "dpm++", "vanilla"], label="Sample Solver", value="unipc") + wanx_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) + wanx_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") + wanx_block_swap = gr.Slider(minimum=0, maximum=39, step=1, label="Block Swap to Save VRAM", value=0) + + with gr.Column(): + wanx_fp8 = gr.Checkbox(label="Use FP8", value=True) + wanx_fp8_scaled = gr.Checkbox(label="Use Scaled FP8", value=False, info="For mixing fp16/bf16 and fp8 weights") + wanx_fp8_t5 = gr.Checkbox(label="Use FP8 for T5", value=False) + + # Add new row for Skip Layer Guidance options + with gr.Row(): + wanx_slg_layers = gr.Textbox(label="SLG Layers", value="", placeholder="Comma-separated layer indices, e.g. 1,5,10", info="Layers to skip for guidance") + wanx_slg_start = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="SLG Start", value=0.0, info="When to start skipping layers (% of total steps)") + wanx_slg_end = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="SLG End", value=1.0, info="When to stop skipping layers (% of total steps)") + + with gr.Row(): + wanx_enable_cfg_skip = gr.Checkbox(label="Enable CFG Skip (similar to teacache)", value=False) + with gr.Column(visible=False) as wanx_cfg_skip_options: + wanx_cfg_skip_mode = gr.Radio( + choices=["early", "late", "middle", "early_late", "alternate", "none"], + label="CFG Skip Mode", + value="none", + info="Controls which steps to apply CFG on" + ) + wanx_cfg_apply_ratio = gr.Slider( + minimum=0.0, maximum=1.0, step=0.05, value=0.7, + label="CFG Apply Ratio", + info="Ratio of steps to apply CFG (0.0-1.0). 
Lower values = faster, but less accurate" + ) + + #WanX-t2v Tab + + # WanX Text to Video Tab + with gr.Tab(id=5, label="WanX-t2v") as wanx_t2v_tab: + with gr.Row(): + with gr.Column(scale=4): + wanx_t2v_prompt = gr.Textbox( + scale=3, + label="Enter your prompt", + value="A person walking on a beach at sunset", + lines=5 + ) + wanx_t2v_negative_prompt = gr.Textbox( + scale=3, + label="Negative Prompt", + value="", + lines=3, + info="Leave empty to use default negative prompt" + ) + + with gr.Column(scale=1): + wanx_t2v_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) + wanx_t2v_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) + + with gr.Column(scale=2): + wanx_t2v_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") + wanx_t2v_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") + + with gr.Row(): + wanx_t2v_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") + wanx_t2v_stop_btn = gr.Button("Stop Generation", variant="stop") + + with gr.Row(): + with gr.Column(): + with gr.Row(): + wanx_t2v_width = gr.Number(label="Width", value=832, interactive=True, info="Should be divisible by 32") + wanx_t2v_height = gr.Number(label="Height", value=480, interactive=True, info="Should be divisible by 32") + wanx_t2v_recommend_flow_btn = gr.Button("Recommend Flow Shift", size="sm") + + wanx_t2v_video_length = gr.Slider(minimum=1, maximum=201, step=4, label="Video Length in Frames", value=81) + wanx_t2v_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=16) + wanx_t2v_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=20) + wanx_t2v_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=5.0, + info="Recommended: 3.0 for I2V with 480p, 5.0 for others") + wanx_t2v_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.1, label="Guidance Scale", value=5.0) + + with gr.Column(): + wanx_t2v_output = gr.Gallery( + label="Generated Videos (Click to select)", + columns=[2], + rows=[2], + object_fit="contain", + height="auto", + show_label=True, + elem_id="gallery", + allow_preview=True, + preview=True + ) + with gr.Accordion("Latent Preview (During Generation)", open=False): + wanx_t2v_enable_preview = gr.Checkbox(label="Enable Latent Preview", value=False) + wanx_t2v_preview_steps = gr.Slider(minimum=1, maximum=50, step=1, value=5, + label="Preview Every N Steps", info="Generates previews during the sampling loop.") + wanx_t2v_preview_output = gr.Gallery( + label="Latent Previews", columns=4, rows=2, object_fit="contain", height=300, + allow_preview=True, preview=True, show_label=True, elem_id="wanx_t2v_preview_gallery" + ) + wanx_t2v_send_to_v2v_btn = gr.Button("Send Selected to Hunyuan v2v") + wanx_t2v_send_to_wanx_v2v_btn = gr.Button("Send Selected to WanX-v2v") + + # Add LoRA section for WanX-t2v + wanx_t2v_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") + wanx_t2v_lora_weights = [] + wanx_t2v_lora_multipliers = [] + for i in range(4): + with gr.Column(): + wanx_t2v_lora_weights.append(gr.Dropdown( + label=f"LoRA {i+1}", + choices=get_lora_options(), + value="None", + allow_custom_value=True, + interactive=True + )) + wanx_t2v_lora_multipliers.append(gr.Slider( + label=f"Multiplier", + minimum=0.0, + maximum=2.0, + step=0.05, + value=1.0 + )) + + with gr.Row(): + wanx_t2v_seed = gr.Number(label="Seed (use -1 for random)", value=-1) + wanx_t2v_task = gr.Dropdown( + label="Task", 
+ choices=["t2v-1.3B", "t2v-14B", "t2i-14B"], + value="t2v-14B", + info="Select model size: t2v-1.3B is faster, t2v-14B has higher quality" + ) + wanx_t2v_dit_path = gr.Dropdown( + label="DiT Model", + choices=get_dit_models("wan"), + value="wan2.1_t2v_14B_fp16.safetensors", + allow_custom_value=True, + interactive=True + ) + wanx_t2v_vae_path = gr.Textbox(label="VAE Path", value="wan/Wan2.1_VAE.pth") + wanx_t2v_t5_path = gr.Textbox(label="T5 Path", value="wan/models_t5_umt5-xxl-enc-bf16.pth") + wanx_t2v_clip_path = gr.Textbox(label="CLIP Path", visible=False, value="") + wanx_t2v_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") + wanx_t2v_save_path = gr.Textbox(label="Save Path", value="outputs") + + with gr.Row(): + wanx_t2v_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") + wanx_t2v_sample_solver = gr.Radio(choices=["unipc", "dpm++", "vanilla"], label="Sample Solver", value="unipc") + wanx_t2v_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) + wanx_t2v_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") + wanx_t2v_block_swap = gr.Slider(minimum=0, maximum=39, step=1, label="Block Swap to Save VRAM", value=0, + info="Max 39 for 14B model, 29 for 1.3B model") + + with gr.Column(): + wanx_t2v_fp8 = gr.Checkbox(label="Use FP8", value=True) + wanx_t2v_fp8_scaled = gr.Checkbox(label="Use Scaled FP8", value=False, + info="For mixing fp16/bf16 and fp8 weights") + wanx_t2v_fp8_t5 = gr.Checkbox(label="Use FP8 for T5", value=False) + + # Add new row for Skip Layer Guidance options + with gr.Row(): + wanx_t2v_slg_layers = gr.Textbox(label="SLG Layers", value="", placeholder="Comma-separated layer indices, e.g. 1,5,10", info="Layers to skip for guidance") + wanx_t2v_slg_start = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="SLG Start", value=0.0, info="When to start skipping layers (% of total steps)") + wanx_t2v_slg_end = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="SLG End", value=1.0, info="When to stop skipping layers (% of total steps)") + wanx_t2v_use_random_folder = gr.Checkbox(visible=False, value=False, label="Use Random Images") + wanx_t2v_input_folder = gr.Textbox(visible=False, value="", label="Image Folder") + wanx_t2v_input_end = gr.Textbox(visible=False, value="none", label="End Frame") + + with gr.Row(): + wanx_t2v_enable_cfg_skip = gr.Checkbox(label="Enable CFG Skip (similar to teacache)", value=False) + with gr.Column(visible=False) as wanx_t2v_cfg_skip_options: + wanx_t2v_cfg_skip_mode = gr.Radio( + choices=["early", "late", "middle", "early_late", "alternate", "none"], + label="CFG Skip Mode", + value="none", + info="Controls which steps to apply CFG on" + ) + wanx_t2v_cfg_apply_ratio = gr.Slider( + minimum=0.0, maximum=1.0, step=0.05, value=0.7, + label="CFG Apply Ratio", + info="Ratio of steps to apply CFG (0.0-1.0). 
Lower values = faster, but less accurate" + ) + + #WanX-v2v Tab + with gr.Tab(id=6, label="WanX-v2v") as wanx_v2v_tab: + with gr.Row(): + with gr.Column(scale=4): + wanx_v2v_prompt = gr.Textbox( + scale=3, + label="Enter your prompt", + value="A person walking on a beach at sunset", + lines=5 + ) + wanx_v2v_negative_prompt = gr.Textbox( + scale=3, + label="Negative Prompt", + value="", + lines=3, + info="Leave empty to use default negative prompt" + ) + + with gr.Column(scale=1): + wanx_v2v_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) + wanx_v2v_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) + + with gr.Column(scale=2): + wanx_v2v_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") + wanx_v2v_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") + + with gr.Row(): + wanx_v2v_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") + wanx_v2v_stop_btn = gr.Button("Stop Generation", variant="stop") + + with gr.Row(): + with gr.Column(): + wanx_v2v_input = gr.Video(label="Input Video", format="mp4") + wanx_v2v_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.75, label="Denoise Strength", + info="0 = keep original, 1 = full generation") + wanx_v2v_scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale %") + wanx_v2v_original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) + + # Width and Height Inputs + with gr.Row(): + wanx_v2v_width = gr.Number(label="New Width", value=832, step=32) + wanx_v2v_calc_height_btn = gr.Button("→") + wanx_v2v_calc_width_btn = gr.Button("←") + wanx_v2v_height = gr.Number(label="New Height", value=480, step=32) + wanx_v2v_recommend_flow_btn = gr.Button("Recommend Flow Shift", size="sm") + + wanx_v2v_video_length = gr.Slider(minimum=1, maximum=201, step=4, label="Video Length in Frames", value=81) + wanx_v2v_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=16) + wanx_v2v_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=40) + wanx_v2v_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=5.0, + info="Recommended: 3.0 for 480p, 5.0 for others") + wanx_v2v_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.1, label="Guidance Scale", value=5.0) + + with gr.Column(): + wanx_v2v_output = gr.Gallery( + label="Generated Videos (Click to select)", + columns=[2], + rows=[2], + object_fit="contain", + height="auto", + show_label=True, + elem_id="gallery", + allow_preview=True, + preview=True + ) + wanx_v2v_send_to_v2v_btn = gr.Button("Send Selected to Hunyuan-v2v") + + # Add LoRA section for WanX-v2v + wanx_v2v_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") + wanx_v2v_lora_weights = [] + wanx_v2v_lora_multipliers = [] + for i in range(4): + with gr.Column(): + wanx_v2v_lora_weights.append(gr.Dropdown( + label=f"LoRA {i+1}", + choices=get_lora_options(), + value="None", + allow_custom_value=True, + interactive=True + )) + wanx_v2v_lora_multipliers.append(gr.Slider( + label=f"Multiplier", + minimum=0.0, + maximum=2.0, + step=0.05, + value=1.0 + )) + + with gr.Row(): + wanx_v2v_seed = gr.Number(label="Seed (use -1 for random)", value=-1) + wanx_v2v_task = gr.Dropdown( + label="Task", + choices=["t2v-14B", "t2v-1.3B"], + value="t2v-14B", + info="Model size: t2v-1.3B is faster, t2v-14B has higher quality" + ) + wanx_v2v_dit_folder = gr.Textbox(label="DiT Model Folder", 
value="wan") + wanx_v2v_dit_path = gr.Dropdown( + label="DiT Model", + choices=get_dit_models("wan"), + value="wan2.1_t2v_14B_fp16.safetensors", + allow_custom_value=True, + interactive=True + ) + wanx_v2v_vae_path = gr.Textbox(label="VAE Path", value="wan/Wan2.1_VAE.pth") + wanx_v2v_t5_path = gr.Textbox(label="T5 Path", value="wan/models_t5_umt5-xxl-enc-bf16.pth") + wanx_v2v_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") + wanx_v2v_save_path = gr.Textbox(label="Save Path", value="outputs") + + with gr.Row(): + wanx_v2v_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") + wanx_v2v_sample_solver = gr.Radio(choices=["unipc", "dpm++", "vanilla"], label="Sample Solver", value="unipc") + wanx_v2v_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) + wanx_v2v_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") + wanx_v2v_block_swap = gr.Slider(minimum=0, maximum=39, step=1, label="Block Swap to Save VRAM", value=0, + info="Max 39 for 14B model, 29 for 1.3B model") + + with gr.Column(): + wanx_v2v_fp8 = gr.Checkbox(label="Use FP8", value=True) + wanx_v2v_fp8_scaled = gr.Checkbox(label="Use Scaled FP8", value=False, + info="For mixing fp16/bf16 and fp8 weights") + wanx_v2v_fp8_t5 = gr.Checkbox(label="Use FP8 for T5", value=False) + + # Add Skip Layer Guidance options + with gr.Row(): + wanx_v2v_slg_layers = gr.Textbox(label="SLG Layers", value="", placeholder="Comma-separated layer indices, e.g. 1,5,10", info="Layers to skip for guidance") + wanx_v2v_slg_start = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="SLG Start", value=0.0, info="When to start skipping layers (% of total steps)") + wanx_v2v_slg_end = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="SLG End", value=1.0, info="When to stop skipping layers (% of total steps)") + + with gr.Row(): + wanx_v2v_enable_cfg_skip = gr.Checkbox(label="Enable CFG Skip (similar to teacache)", value=False) + with gr.Column(visible=False) as wanx_v2v_cfg_skip_options: + wanx_v2v_cfg_skip_mode = gr.Radio( + choices=["early", "late", "middle", "early_late", "alternate", "none"], + label="CFG Skip Mode", + value="none", + info="Controls which steps to apply CFG on" + ) + wanx_v2v_cfg_apply_ratio = gr.Slider( + minimum=0.0, maximum=1.0, step=0.05, value=0.7, + label="CFG Apply Ratio", + info="Ratio of steps to apply CFG (0.0-1.0). 
Lower values = faster, but less accurate" + ) + + #Video Info Tab + with gr.Tab("Video Info") as video_info_tab: + with gr.Row(): + video_input = gr.Video(label="Upload Video", interactive=True) + metadata_output = gr.JSON(label="Generation Parameters") + + with gr.Row(): + send_to_fpe_btn = gr.Button("Send to FramePack-Extension", variant="primary") + send_to_t2v_btn = gr.Button("Send to Text2Video", variant="primary") + send_to_v2v_btn = gr.Button("Send to Video2Video", variant="primary") + with gr.Row(): + send_to_framepack_btn = gr.Button("Send to FramePack", variant="primary") + send_to_wanx_i2v_btn = gr.Button("Send to WanX-i2v", variant="primary") + send_to_wanx_t2v_btn = gr.Button("Send to WanX-t2v", variant="primary") + send_to_wanx_v2v_btn = gr.Button("Send to WanX-v2v", variant="primary") + + + with gr.Row(): + status = gr.Textbox(label="Status", interactive=False) + + #Convert lora tab + with gr.Tab("Convert LoRA") as convert_lora_tab: + def suggest_output_name(file_obj) -> str: + """Generate suggested output name from input file""" + if not file_obj: + return "" + # Get input filename without extension and add MUSUBI + base_name = os.path.splitext(os.path.basename(file_obj.name))[0] + return f"{base_name}_MUSUBI" + + def convert_lora(input_file, output_name: str, target_format: str) -> str: + """Convert LoRA file to specified format""" + try: + if input_file is None: + return "Error: No input file selected" + + # Ensure output directory exists + os.makedirs("lora", exist_ok=True) + + # Construct output path + output_path = os.path.join("lora", f"{output_name}.safetensors") + + # Determine which script to use based on target_format + if target_format == "Hunyuan to FramePack": + script_name = "convert_hunyuan_to_framepack.py" + cmd = [ + sys.executable, + script_name, + "--input", input_file.name, + "--output", output_path + ] + print(f"Using '{script_name}' to convert {input_file.name} to {output_path} for FramePack.") + else: # Existing logic for "default" and "other" + script_name = "convert_lora.py" + cmd = [ + sys.executable, + script_name, + "--input", input_file.name, + "--output", output_path, + "--target", target_format.lower() + ] + + print(f"Running conversion command: {' '.join(cmd)}") + + # Check if the selected script file exists + if not os.path.exists(script_name): + return f"Error: Conversion script '{script_name}' not found. Please ensure it's in the same directory as h1111.py." + + # Execute conversion + result = subprocess.run( + cmd, + capture_output=True, + text=True, + check=True + ) + + console_output = result.stdout if result.stdout else "" + if result.stderr: + console_output += f"\n--- Script STDERR ---\n{result.stderr}" + if not console_output.strip(): + console_output = "Conversion script completed with no output." + if os.path.exists(output_path): + console_output += f"\n[UI Info] Output file confirmed by h1111.py at: {output_path}" + else: + console_output += f"\n[UI Warning] Output file NOT found by h1111.py at expected location: {output_path}" + return console_output.strip() + except subprocess.CalledProcessError as e: + error_message = f"Conversion Script Error (Exit Code: {e.returncode}):\n" + if e.stdout and e.stdout.strip(): + error_message += f"--- Script STDOUT ---\n{e.stdout.strip()}\n" + if e.stderr and e.stderr.strip(): + error_message += f"--- Script STDERR ---\n{e.stderr.strip()}\n" + if not (e.stdout and e.stdout.strip()) and not (e.stderr and e.stderr.strip()): + error_message += "Script produced no output on STDOUT or STDERR." 
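+                # Note (descriptive comment): the combined stdout/stderr string built above is what gets
+                # returned to the Gradio status box, so conversion failures can be read from the UI as
+                # well as from the server console.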
+ + print(f"Subprocess error details logged to console. UI will show combined script output.") # Log for server console + return error_message.strip() + + + with gr.Row(): + input_file = gr.File(label="Input LoRA File", file_types=[".safetensors"]) + output_name = gr.Textbox(label="Output Name", placeholder="Output filename (without extension)") + format_radio = gr.Radio( + choices=["default", "other", "Hunyuan to FramePack"], # <-- Added new choice here + value="default", + label="Target Format", + info="Choose 'default' for H1111/MUSUBI format, 'other' for diffusion pipe format, or 'Hunyuan to FramePack' for FramePack compatibility." + ) + + with gr.Row(): + convert_btn = gr.Button("Convert LoRA", variant="primary") + status_output = gr.Textbox(label="Status", interactive=False) + + # Automatically update output name when file is selected + input_file.change( + fn=suggest_output_name, + inputs=[input_file], + outputs=[output_name] + ) + + # Handle conversion + convert_btn.click( + fn=convert_lora, + inputs=[input_file, output_name, format_radio], + outputs=status_output + ) + with gr.Tab("Model Merging") as model_merge_tab: + with gr.Row(): + with gr.Column(): + # Model selection + dit_model = gr.Dropdown( + label="Base DiT Model", + choices=["mp_rank_00_model_states.pt"], + value="mp_rank_00_model_states.pt", + allow_custom_value=True, + interactive=True + ) + merge_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") + with gr.Row(): + with gr.Column(): + # Output model name + output_model = gr.Textbox(label="Output Model Name", value="merged_model.safetensors") + exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) + merge_btn = gr.Button("Merge Models", variant="primary") + merge_status = gr.Textbox(label="Status", interactive=False) + with gr.Row(): + # LoRA selection section (similar to Text2Video) + merge_lora_weights = [] + merge_lora_multipliers = [] + for i in range(4): + with gr.Column(): + merge_lora_weights.append(gr.Dropdown( + label=f"LoRA {i+1}", + choices=get_lora_options(), + value="None", + allow_custom_value=True, + interactive=True + )) + merge_lora_multipliers.append(gr.Slider( + label=f"Multiplier", + minimum=0.0, + maximum=2.0, + step=0.05, + value=1.0 + )) + with gr.Row(): + merge_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") + dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") + + #Event handlers etc + +# Toggle visibility of End Frame controls and DiT path based on fpe_use_normal_framepack + def toggle_fpe_normal_framepack_options(use_normal_fp): + f1_dit_path = "hunyuan/FramePack_F1_I2V_HY_20250503.safetensors" + normal_fp_dit_path = "hunyuan/FramePackI2V_HY_bf16.safetensors" + + updated_dit_path = normal_fp_dit_path if use_normal_fp else f1_dit_path + + # Check if the target path exists and fallback if necessary + if not os.path.exists(updated_dit_path): + fallback_path = f1_dit_path if use_normal_fp and os.path.exists(f1_dit_path) else normal_fp_dit_path if not use_normal_fp and os.path.exists(normal_fp_dit_path) else None + if fallback_path and os.path.exists(fallback_path): + print(f"Warning: DiT path '{updated_dit_path}' not found. Falling back to '{fallback_path}'.") + updated_dit_path = fallback_path + else: # If preferred and fallback are missing, stick to the intended one and let later checks handle it. + print(f"Warning: DiT path '{updated_dit_path}' not found. 
No fallback available or fallback also missing.") + + return ( + gr.update(visible=use_normal_fp), # fpe_end_frame_accordion + gr.update(visible=not use_normal_fp), # fpe_start_guidance_accordion (NEW) + gr.update(value=updated_dit_path), # fpe_transformer_path + gr.update(visible=use_normal_fp) # fpe_fp8_llm + ) + + fpe_use_normal_framepack.change( + fn=toggle_fpe_normal_framepack_options, + inputs=[fpe_use_normal_framepack], + outputs=[ + fpe_end_frame_accordion, + fpe_start_guidance_accordion, # NEW output + fpe_transformer_path, + fpe_fp8_llm + ] + ) + + fpe_generate_btn.click( + fn=process_framepack_extension_video, + inputs=[ + fpe_input_video, fpe_prompt, fpe_negative_prompt, fpe_seed, fpe_batch_count, + fpe_use_normal_framepack, fpe_end_frame, fpe_end_frame_weight, + fpe_resolution_max_dim, fpe_total_second_length, fpe_latent_window_size, + fpe_steps, fpe_cfg_scale, fpe_distilled_guidance_scale, + fpe_gpu_memory_preservation, fpe_use_teacache, fpe_no_resize, fpe_mp4_crf, + fpe_num_clean_frames, fpe_vae_batch_size, fpe_save_path, + # Model Paths + fpe_transformer_path, fpe_vae_path, fpe_text_encoder_path, + fpe_text_encoder_2_path, fpe_image_encoder_path, + # Advanced + fpe_attn_mode, fpe_fp8_llm, fpe_vae_chunk_size, fpe_vae_spatial_tile_sample_min_size, + # LoRAs + fpe_lora_folder, + fpe_lora_weights_ui[0], fpe_lora_multipliers_ui[0], + fpe_lora_weights_ui[1], fpe_lora_multipliers_ui[1], + fpe_lora_weights_ui[2], fpe_lora_multipliers_ui[2], + fpe_lora_weights_ui[3], fpe_lora_multipliers_ui[3], + # Preview (UI state, not directly passed to scripts) + fpe_enable_preview, fpe_preview_interval, + fpe_extension_only, + fpe_start_guidance_image, + fpe_start_guidance_image_clip_weight, + fpe_use_guidance_image_as_first_latent, + ], + outputs=[ + fpe_output_gallery, + fpe_preview_output_component, + fpe_batch_progress, + fpe_progress_text + ], + queue=True + ) + + fpe_stop_btn.click(fn=lambda: stop_event.set(), queue=False) + + def handle_fpe_gallery_select(evt: gr.SelectData) -> int: + return evt.index + fpe_output_gallery.select(fn=handle_fpe_gallery_select, outputs=fpe_selected_index) + + fpe_lora_refresh_outputs_list = [] + for i in range(len(fpe_lora_weights_ui)): + fpe_lora_refresh_outputs_list.extend([fpe_lora_weights_ui[i], fpe_lora_multipliers_ui[i]]) + + fpe_refresh_lora_btn.click( + fn=refresh_lora_dropdowns_simple, + inputs=[fpe_lora_folder], + outputs=fpe_lora_refresh_outputs_list + ) + + def change_to_framepack_tab(): + return gr.Tabs(selected=10) # FramePack tab has id=10 + + def handle_send_to_framepack_tab(metadata: dict) -> Tuple[str, dict, str]: # Added str return type for state value + """Prepare parameters specifically for the FramePack tab.""" + if not metadata: + # Return default/empty values for status, params, and original_dims state + return "No parameters to send", {}, "" + + # Extract the value intended for the state here + original_dims_value = metadata.get("original_dims_str", "") + + # Return status message, the full metadata for params_state, and the specific value for framepack_original_dims state + return "Parameters ready for FramePack", metadata, original_dims_value + + send_to_framepack_btn.click( + fn=handle_send_to_framepack_tab, + inputs=[metadata_output], + outputs=[status, params_state, framepack_original_dims] # Add framepack_original_dims here + ).then( + # This lambda now prepares updates for UI components (32 items) + lambda params: ( + # Prepare the full list of 32 update values first + ( + # Fetch LoRA lists from params, default to empty 
lists if not found + (weights_from_meta := params.get("lora_weights", [])), + (mults_from_meta := params.get("lora_multipliers", [])), + # Create explicitly padded lists ensuring 4 elements + (padded_weights := (weights_from_meta + ["None"] * 4)[:4]), + (padded_mults := ([float(m) for m in mults_from_meta] + [1.0] * 4)[:4]), # Ensure multipliers are floats + + # Build the list of update values + [ + params.get("prompt", "cinematic video of a cat wizard casting a spell"), + params.get("negative_prompt", ""), + # Handle resolution: Prioritize explicit W/H if valid (divisible by 8), else use target_res, else default + gr_update(value=int(params["video_width"])) if params.get("video_width") and int(params.get("video_width", 0)) > 0 and int(params.get("video_width", 0)) % 8 == 0 else gr_update(value=None), + gr_update(value=int(params["video_height"])) if params.get("video_height") and int(params.get("video_height", 0)) > 0 and int(params.get("video_height", 0)) % 8 == 0 else gr_update(value=None), + # Use target resolution only if explicit width/height are *not* validly provided from metadata + gr_update(value=int(params.get("target_resolution"))) if not (params.get("video_width") and int(params.get("video_width", 0)) > 0 and int(params.get("video_width", 0)) % 8 == 0) and params.get("target_resolution") else gr_update(value=640), + params.get("video_seconds", 5.0), + params.get("fps", 30), + params.get("seed", -1), + params.get("infer_steps", 25), + params.get("embedded_cfg_scale", 10.0), # Distilled Guidance + params.get("guidance_scale", 1.0), # CFG + params.get("guidance_rescale", 0.0), # RS + params.get("sample_solver", "unipc"), + # Unpack the *padded* lists + *padded_weights, # 4 items + *padded_mults, # 4 items + # Performance/Memory + params.get("fp8", False), + params.get("fp8_scaled", False), + params.get("fp8_llm", False), + params.get("blocks_to_swap", 26), + params.get("bulk_decode", False), + params.get("attn_mode", "sdpa"), + params.get("vae_chunk_size", 32), + params.get("vae_spatial_tile_sample_min_size", 128), + params.get("device", ""), + # End Frame Blending Params - Use UI defaults + params.get("end_frame_influence", "last"), + params.get("end_frame_weight", 0.5), + params.get("is_f1", False) + ] + )[-1] # Return the list of values we just built + ) if params else [gr.update()] * 32, + inputs=params_state, # Read parameters from state + outputs=[ + # Map to FramePack components (UI only - 32 components) + framepack_prompt, + framepack_negative_prompt, + framepack_width, # Will be updated or set to None + framepack_height, # Will be updated or set to None + framepack_target_resolution, # Will be updated or set to None/default + framepack_total_second_length, + framepack_fps, + framepack_seed, + framepack_steps, + framepack_distilled_guidance_scale, + framepack_guidance_scale, + framepack_guidance_rescale, + framepack_sample_solver, + # LoRAs (unpacking the lists - 8 components total) + *framepack_lora_weights, # 4 components + *framepack_lora_multipliers, # 4 components + # Performance/Memory + framepack_fp8, + framepack_fp8_scaled, + framepack_fp8_llm, + framepack_blocks_to_swap, + framepack_bulk_decode, + framepack_attn_mode, + framepack_vae_chunk_size, + framepack_vae_spatial_tile_sample_min_size, + framepack_device, + # Map to new UI components + framepack_end_frame_influence, + framepack_end_frame_weight, + framepack_is_f1 + ] + ).then( + fn=change_to_framepack_tab, # Switch to the FramePack tab + inputs=None, + outputs=[tabs] + ) + # Connect FramePack Generate button 
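+    # FramePack resolution handling (descriptive comment for the handlers below): an uploaded image
+    # stores its raw "WxH" string, clears the explicit Width/Height fields, and leaves Target
+    # Resolution (default 640) in charge. Setting an explicit Width/Height clears Target Resolution,
+    # and vice versa, so only one sizing mode is active at a time
+    # (see clear_target_res_on_explicit_change / clear_explicit_res_on_target_change).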
+ def update_framepack_image_dimensions(image): + """Update FramePack dimensions from uploaded image, store raw dims, set default target res""" + if image is None: + return "", gr.update(value=None), gr.update(value=None), gr.update(value=640) # Reset W/H, default target res + try: + img = Image.open(image) + w, h = img.size + original_dims_str = f"{w}x{h}" # Store raw WxH + target_res_default = 640 + # Return original dims string, clear explicit W/H, set default target res + return original_dims_str, gr.update(value=None), gr.update(value=None), gr.update(value=target_res_default) + except Exception as e: + print(f"Error reading image dimensions: {e}") + return "", gr.update(value=None), gr.update(value=None), gr.update(value=640) # Fallback + + framepack_input_image.change( + fn=update_framepack_image_dimensions, + inputs=[framepack_input_image], + outputs=[framepack_original_dims, framepack_width, framepack_height, framepack_target_resolution] + ) + + framepack_prompt.change(fn=count_prompt_tokens, inputs=framepack_prompt, outputs=framepack_token_counter) + # If explicit width/height is set (and valid), clear target resolution + def clear_target_res_on_explicit_change(val): + return gr.update(value=None) if val is not None and val > 0 else gr.update() + + framepack_scale_slider.change( + fn=update_framepack_from_scale, + inputs=[framepack_scale_slider, framepack_original_dims], + outputs=[framepack_width, framepack_height, framepack_target_resolution] # Also clears target res + ) + + framepack_calc_width_btn.click( + fn=calculate_framepack_width, + inputs=[framepack_height, framepack_original_dims], + outputs=[framepack_width] + ).then( + fn=clear_target_res_on_explicit_change, # Clear target res if width is manually set + inputs=[framepack_width], + outputs=[framepack_target_resolution] + ) + + framepack_calc_height_btn.click( + fn=calculate_framepack_height, + inputs=[framepack_width, framepack_original_dims], + outputs=[framepack_height] + ).then( + fn=clear_target_res_on_explicit_change, # Clear target res if height is manually set + inputs=[framepack_height], + outputs=[framepack_target_resolution] + ) + + framepack_width.change( + fn=clear_target_res_on_explicit_change, + inputs=[framepack_width], + outputs=[framepack_target_resolution] + ) + framepack_height.change( + fn=clear_target_res_on_explicit_change, + inputs=[framepack_height], + outputs=[framepack_target_resolution] + ) + + # If target resolution is set (and valid), clear explicit width/height + def clear_explicit_res_on_target_change(target_res): + return (gr.update(value=None), gr.update(value=None)) if target_res is not None and target_res > 0 else (gr.update(), gr.update()) + + framepack_target_resolution.change( + fn=clear_explicit_res_on_target_change, + inputs=[framepack_target_resolution], + outputs=[framepack_width, framepack_height] + ) + framepack_use_random_folder.change( + fn=lambda use_folder_mode: ( + gr.update(visible=use_folder_mode), # framepack_input_folder_path + gr.update(visible=use_folder_mode), # framepack_folder_options_row (which contains validate button and status) + gr.update(visible=not use_folder_mode) # framepack_input_image + ), + inputs=[framepack_use_random_folder], + outputs=[framepack_input_folder_path, framepack_folder_options_row, framepack_input_image] + ) + + # Validate folder button handler + framepack_validate_folder_btn.click( + fn=lambda folder: get_random_image_from_folder(folder)[1], # Reuse existing helper + inputs=[framepack_input_folder_path], + 
outputs=[framepack_folder_status_text] + ) + def toggle_f1_model_path(is_f1): + f1_path = "hunyuan/FramePack_F1_I2V_HY_20250503.safetensors" + standard_path = "hunyuan/FramePackI2V_HY_bf16.safetensors" + target_path = f1_path if is_f1 else standard_path + + # Check if the target path exists + if not os.path.exists(target_path): + print(f"Warning: F1 model path '{target_path}' not found. Falling back to standard path.") + # Optionally fall back or just update with the non-existent path + # Let's fall back to standard if F1 is missing, but keep standard if standard is missing (error handled later) + if is_f1 and os.path.exists(standard_path): + print(f"Falling back to standard path: {standard_path}") + return gr.update(value=standard_path) + elif is_f1: + print(f"F1 path missing and standard path also missing. Cannot automatically switch.") + # Return the intended (missing) path, error will be caught later + return gr.update(value=target_path) + else: # Standard path is missing + print(f"Warning: Standard path '{standard_path}' not found.") + return gr.update(value=target_path) # Return the missing standard path + + print(f"Switching DiT path to: {target_path}") + return gr.update(value=target_path) + + framepack_is_f1.change( + fn=toggle_f1_model_path, + inputs=[framepack_is_f1], + outputs=[framepack_transformer_path] + ) + + framepack_generate_btn.click( + fn=process_framepack_video, + inputs=[ + framepack_prompt, framepack_negative_prompt, framepack_input_image, + framepack_input_end_frame, framepack_end_frame_influence, framepack_end_frame_weight, + framepack_transformer_path, framepack_vae_path, framepack_text_encoder_path, + framepack_text_encoder_2_path, framepack_image_encoder_path, + framepack_target_resolution, framepack_width, framepack_height, framepack_original_dims, + framepack_total_second_length, framepack_video_sections, framepack_fps, framepack_seed, framepack_steps, + framepack_distilled_guidance_scale, framepack_guidance_scale, framepack_guidance_rescale, + framepack_sample_solver, framepack_latent_window_size, + framepack_fp8, framepack_fp8_scaled, framepack_fp8_llm, + framepack_blocks_to_swap, framepack_bulk_decode, framepack_attn_mode, + framepack_vae_chunk_size, framepack_vae_spatial_tile_sample_min_size, + framepack_device, + framepack_use_teacache, + framepack_teacache_steps, + framepack_teacache_thresh, + framepack_batch_size, framepack_save_path, + framepack_lora_folder, + framepack_enable_preview, + framepack_preview_every_n_sections, + framepack_use_full_video_preview, + framepack_is_f1, + framepack_use_random_folder, + framepack_input_folder_path, + *framepack_secs, *framepack_sec_prompts, *framepack_sec_images, + *framepack_lora_weights, *framepack_lora_multipliers + ], + outputs=[ + framepack_output, # Main gallery + framepack_preview_output, # Preview video player + framepack_batch_progress, # Status text + framepack_progress_text # Progress text + ], + queue=True + ) + + framepack_random_seed.click( + fn=set_random_seed, + inputs=None, + outputs=[framepack_seed] + ) + # Connect FramePack Stop button + framepack_stop_btn.click(fn=lambda: stop_event.set(), queue=False) + + # Connect FramePack Gallery selection + def handle_framepack_gallery_select(evt: gr.SelectData) -> int: + return evt.index + + framepack_output.select( + fn=handle_framepack_gallery_select, + outputs=framepack_selected_index + ) + + # FramePack LoRA Refresh Button Handler + framepack_lora_refresh_outputs = [] + for i in range(len(framepack_lora_weights)): + 
framepack_lora_refresh_outputs.extend([framepack_lora_weights[i], framepack_lora_multipliers[i]]) + + framepack_refresh_lora_btn.click( + fn=refresh_lora_dropdowns_simple, # Use the new simplified function + inputs=[framepack_lora_folder], # Only needs the folder path as input + outputs=framepack_lora_refresh_outputs # Still outputs updates to all 8 components + ) + def trigger_skip(): + """Sets the skip event and returns a status message.""" + print("FramePack Skip button clicked, setting skip_event.") + skip_event.set() + return "Skip signal sent..." + + framepack_skip_btn.click( + fn=trigger_skip, + inputs=None, + outputs=[framepack_batch_progress], # Update status text + queue=False # Send signal immediately + ) + + def toggle_fun_control(use_fun_control): + """Toggle control video visibility and update task suffix""" + # Only update visibility, don't try to set paths + return gr.update(visible=use_fun_control) + + def update_task_for_funcontrol(use_fun_control, current_task): + """Add or remove -FC suffix from task based on checkbox""" + if use_fun_control: + if not current_task.endswith("-FC"): + if "i2v" in current_task: + return "i2v-14B-FC" + elif "t2v" in current_task: + return "t2v-14B-FC" + return current_task + else: + if current_task.endswith("-FC"): + return current_task.replace("-FC", "") + return current_task + + wanx_use_fun_control.change( + fn=lambda x: (gr.update(visible=x), gr.update(visible=x), gr.update(visible=x), gr.update(visible=x)), + inputs=[wanx_use_fun_control], + outputs=[wanx_control_video, wanx_control_strength, wanx_control_start, wanx_control_end] + ) + + # Make task change update checkbox state + def update_from_task(task): + """Update Fun-Control checkbox and control video visibility based on task""" + is_fun_control = "-FC" in task + return gr.update(value=is_fun_control), gr.update(visible=is_fun_control) + + wanx_task.change( + fn=update_from_task, + inputs=[wanx_task], + outputs=[wanx_use_fun_control, wanx_control_video] + ) + wanx_enable_cfg_skip.change( + fn=lambda x: gr.update(visible=x), + inputs=[wanx_enable_cfg_skip], + outputs=[wanx_cfg_skip_options] + ) + + wanx_t2v_enable_cfg_skip.change( + fn=lambda x: gr.update(visible=x), + inputs=[wanx_t2v_enable_cfg_skip], + outputs=[wanx_t2v_cfg_skip_options] + ) + + wanx_v2v_enable_cfg_skip.change( + fn=lambda x: gr.update(visible=x), + inputs=[wanx_v2v_enable_cfg_skip], + outputs=[wanx_v2v_cfg_skip_options] + ) + + #WanX-v2v tab functions + wanx_v2v_prompt.change(fn=count_prompt_tokens, inputs=wanx_v2v_prompt, outputs=wanx_v2v_token_counter) + + # Stop button handler + wanx_v2v_stop_btn.click(fn=lambda: stop_event.set(), queue=False) + + # Video input handling + wanx_v2v_input.change( + fn=update_wanx_v2v_dimensions, + inputs=[wanx_v2v_input], + outputs=[wanx_v2v_original_dims, wanx_v2v_width, wanx_v2v_height] + ) + + # Flow shift recommendation button + wanx_v2v_recommend_flow_btn.click( + fn=recommend_wanx_flow_shift, + inputs=[wanx_v2v_width, wanx_v2v_height], + outputs=[wanx_v2v_flow_shift] + ) + + # Width/height calculation buttons + wanx_v2v_calc_width_btn.click( + fn=calculate_wanx_width, # Reuse function from WanX tabs + inputs=[wanx_v2v_height, wanx_v2v_original_dims], + outputs=[wanx_v2v_width] + ) + + wanx_v2v_calc_height_btn.click( + fn=calculate_wanx_height, # Reuse function from WanX tabs + inputs=[wanx_v2v_width, wanx_v2v_original_dims], + outputs=[wanx_v2v_height] + ) + + # Scale slider handling for adjusting dimensions + wanx_v2v_scale_slider.change( + fn=update_wanx_from_scale, # 
Reuse function from WanX tabs + inputs=[wanx_v2v_scale_slider, wanx_v2v_original_dims], + outputs=[wanx_v2v_width, wanx_v2v_height] + ) + + def change_to_wanx_v2v_tab(): + return gr.Tabs(selected=6) + + def send_wanx_t2v_to_v2v_input(gallery, selected_index): + """Send the selected WanX-t2v video to WanX-v2v input""" + if gallery is None or not gallery: + return None, None + + if selected_index is None and len(gallery) == 1: + selected_index = 0 + + if selected_index is None or selected_index >= len(gallery): + return None, None + + # Get the video path + item = gallery[selected_index] + video_path = parse_video_path(item) + + return video_path, "Video sent from WanX-t2v tab" + + wanx_t2v_send_to_wanx_v2v_btn.click( + fn=send_wanx_t2v_to_v2v_input, + inputs=[wanx_t2v_output, wanx_t2v_selected_index], + outputs=[wanx_v2v_input, wanx_v2v_batch_progress] + ).then( + fn=lambda prompt: prompt, + inputs=[wanx_t2v_prompt], + outputs=[wanx_v2v_prompt] + ).then( + fn=change_to_wanx_v2v_tab, + inputs=None, + outputs=[tabs] + ) + + # Send video from WanX-i2v to WanX-v2v + wanx_i2v_send_to_wanx_v2v_btn.click( + fn=send_wanx_t2v_to_v2v_input, # Reuse the same function + inputs=[wanx_output, wanx_i2v_selected_index], + outputs=[wanx_v2v_input, wanx_v2v_batch_progress] + ).then( + fn=lambda prompt: prompt, + inputs=[wanx_prompt], + outputs=[wanx_v2v_prompt] + ).then( + fn=change_to_wanx_v2v_tab, + inputs=None, + outputs=[tabs] + ) + + # Update model paths when task changes + def update_model_paths_for_task(task): + if "1.3B" in task: + return gr.update(value="wan/wan2.1_t2v_1.3B_fp16.safetensors") + else: + return gr.update(value="wan/wan2.1_t2v_14B_fp16.safetensors") + + wanx_v2v_task.change( + fn=update_model_paths_for_task, + inputs=[wanx_v2v_task], + outputs=[wanx_v2v_dit_path] + ) + + # Generate button handler + wanx_v2v_generate_btn.click( + fn=wanx_v2v_batch_handler, + inputs=[ + wanx_v2v_prompt, + wanx_v2v_negative_prompt, + wanx_v2v_input, + wanx_v2v_width, + wanx_v2v_height, + wanx_v2v_video_length, + wanx_v2v_fps, + wanx_v2v_infer_steps, + wanx_v2v_flow_shift, + wanx_v2v_guidance_scale, + wanx_v2v_strength, + wanx_v2v_seed, + wanx_v2v_batch_size, + wanx_v2v_task, + wanx_v2v_dit_folder, + wanx_v2v_dit_path, + wanx_v2v_vae_path, + wanx_v2v_t5_path, + wanx_v2v_save_path, + wanx_v2v_output_type, + wanx_v2v_sample_solver, + wanx_v2v_exclude_single_blocks, + wanx_v2v_attn_mode, + wanx_v2v_block_swap, + wanx_v2v_fp8, + wanx_v2v_fp8_scaled, + wanx_v2v_fp8_t5, + wanx_v2v_lora_folder, + wanx_v2v_slg_layers, + wanx_v2v_slg_start, + wanx_v2v_slg_end, + wanx_v2v_enable_cfg_skip, + wanx_v2v_cfg_skip_mode, + wanx_v2v_cfg_apply_ratio, + *wanx_v2v_lora_weights, + *wanx_v2v_lora_multipliers + ], + outputs=[wanx_v2v_output, wanx_v2v_batch_progress, wanx_v2v_progress_text], + queue=True + ).then( + fn=lambda batch_size: 0 if batch_size == 1 else None, + inputs=[wanx_v2v_batch_size], + outputs=wanx_v2v_selected_index + ) + + # Gallery selection handling + wanx_v2v_output.select( + fn=handle_wanx_v2v_gallery_select, + outputs=wanx_v2v_selected_index + ) + def change_to_tab_two(): + return gr.Tabs(selected=2) + + # Send to Hunyuan v2v tab + wanx_v2v_send_to_v2v_btn.click( + fn=send_wanx_v2v_to_hunyuan_v2v, + inputs=[ + wanx_v2v_output, + wanx_v2v_prompt, + wanx_v2v_selected_index, + wanx_v2v_width, + wanx_v2v_height, + wanx_v2v_video_length, + wanx_v2v_fps, + wanx_v2v_infer_steps, + wanx_v2v_seed, + wanx_v2v_flow_shift, + wanx_v2v_guidance_scale, + wanx_v2v_negative_prompt + ], + outputs=[ + v2v_input, + v2v_prompt, 
+ v2v_width, + v2v_height, + v2v_video_length, + v2v_fps, + v2v_infer_steps, + v2v_seed, + v2v_flow_shift, + v2v_cfg_scale, + v2v_negative_prompt + ] + ).then( + fn=change_to_tab_two, + inputs=None, + outputs=[tabs] + ) + + # Add refresh button handler for WanX-v2v tab + wanx_v2v_refresh_outputs = [wanx_v2v_dit_path] # This is one output + for i in range(4): + wanx_v2v_refresh_outputs.extend([wanx_v2v_lora_weights[i], wanx_v2v_lora_multipliers[i]]) # This adds 8 more outputs + + wanx_v2v_refresh_btn.click( + fn=update_dit_and_lora_dropdowns, # We need to use this function instead + inputs=[wanx_v2v_dit_folder, wanx_v2v_lora_folder, wanx_v2v_dit_path] + wanx_v2v_lora_weights + wanx_v2v_lora_multipliers, + outputs=wanx_v2v_refresh_outputs + ) + + # Add function to send videos from Video Info tab to WanX-v2v + def send_to_wanx_v2v(metadata: dict, video_path: str) -> Tuple[str, Dict, str]: + """Handle both parameters and video transfer from Video Info to WanX-v2v tab with debugging""" + if not video_path: + return "No video selected", {}, None + + # Print debug information + print(f"VIDEO INFO TO WANX-V2V TRANSFER:") + print(f"Original metadata: {metadata}") + print(f"Video path: {video_path}") + + # Special handling for WanX-v2v prompt fields + # Create a copy of metadata with explicit prompt fields + enhanced_metadata = metadata.copy() + if "prompt" in metadata: + enhanced_metadata["wanx_v2v_prompt"] = metadata["prompt"] + if "negative_prompt" in metadata: + enhanced_metadata["wanx_v2v_negative_prompt"] = metadata["negative_prompt"] + + print(f"Enhanced metadata: {enhanced_metadata}") + + status_msg, params = send_parameters_to_tab(enhanced_metadata, "wanx_v2v") + print(f"Mapped parameters: {params}") + + return f"Parameters ready for WanX-v2v (DEBUG INFO IN CONSOLE)", enhanced_metadata, video_path + + # Then, implement a proper handler to change to the WanX-v2v tab + def change_to_wanx_v2v_tab(): + return gr.Tabs(selected=6) # WanX-v2v is tab index 6 + + # Next, connect the button to the functions with proper parameter mapping + send_to_wanx_v2v_btn.click( + fn=lambda m, v: handle_send_to_wanx_tab(m, 'wanx_v2v', v), + inputs=[metadata_output, video_input], + outputs=[status, params_state, wanx_v2v_input] + ).then( + lambda params: [ + params.get("prompt", ""), + params.get("width", 832), + params.get("height", 480), + params.get("video_length", 81), + params.get("fps", 16), + params.get("infer_steps", 40), + params.get("seed", -1), + params.get("flow_shift", 5.0), + params.get("guidance_scale", 5.0), + params.get("attn_mode", "sdpa"), + params.get("block_swap", 0), + params.get("negative_prompt", ""), + params.get("strength", 0.75), + *[params.get("lora_weights", ["None"]*4)[i] if isinstance(params.get("lora_weights", []), list) and i < len(params.get("lora_weights", [])) else "None" for i in range(4)], + *[params.get("lora_multipliers", [1.0]*4)[i] if isinstance(params.get("lora_multipliers", []), list) and i < len(params.get("lora_multipliers", [])) else 1.0 for i in range(4)] + ] if params else [gr.update()]*21, + inputs=params_state, + outputs=[ + wanx_v2v_prompt, + wanx_v2v_width, + wanx_v2v_height, + wanx_v2v_video_length, + wanx_v2v_fps, + wanx_v2v_infer_steps, + wanx_v2v_seed, + wanx_v2v_flow_shift, + wanx_v2v_guidance_scale, + wanx_v2v_attn_mode, + wanx_v2v_block_swap, + wanx_v2v_negative_prompt, + wanx_v2v_strength, + *wanx_v2v_lora_weights, + *wanx_v2v_lora_multipliers + ] + ).then( + fn=change_to_wanx_v2v_tab, inputs=None, outputs=[tabs] + ) + + #Video Extension + 
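+    # WanX-i2v video extension workflow (descriptive comment for the handlers below):
+    #   1. "Send Last Frame to Input" copies the selected video's last frame into the image input and
+    #      remembers that video as the base to extend.
+    #   2. "Extend Video" runs process_batch_extension against that base video.
+    #   3. "Extract Sharpest Frame" / "Trim Video & Prepare for Extension" pick the sharpest frame near
+    #      the end and trim the base video to it; "Extend with Trimmed Video" then extends from there.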
wanx_send_last_frame_btn.click( + fn=send_last_frame_handler, + inputs=[wanx_output, wanx_i2v_selected_index], + outputs=[wanx_input, wanx_base_video] + ) + + wanx_extend_btn.click( + fn=prepare_for_batch_extension, + inputs=[wanx_input, wanx_base_video, wanx_batch_size], + outputs=[wanx_input, wanx_base_video, wanx_batch_size, wanx_batch_progress, wanx_progress_text] + ).then( + fn=lambda batch_size, base_video: + "Starting batch extension..." if base_video and batch_size > 0 else + "Error: Missing base video or invalid batch size", + inputs=[wanx_batch_size, wanx_base_video], + outputs=[wanx_batch_progress] + ).then( + # Process batch extension one at a time + fn=process_batch_extension, + inputs=[ + wanx_prompt, + wanx_negative_prompt, + wanx_input, # Input image (last frame) + wanx_base_video, # Base video to extend + wanx_width, + wanx_height, + wanx_video_length, + wanx_fps, + wanx_infer_steps, + wanx_flow_shift, + wanx_guidance_scale, + wanx_seed, + wanx_batch_size, + wanx_task, + wanx_dit_folder, # <<< Pass the folder path + wanx_dit_path, # <<< Pass the model filename + wanx_vae_path, + wanx_t5_path, + wanx_clip_path, + wanx_save_path, + wanx_output_type, + wanx_sample_solver, + wanx_exclude_single_blocks, + wanx_attn_mode, + wanx_block_swap, + wanx_fp8, + wanx_fp8_scaled, + wanx_fp8_t5, + wanx_lora_folder, + wanx_slg_layers, + wanx_slg_start, + wanx_slg_end, + # Pass LoRA weights and multipliers individually + wanx_lora_weights[0], + wanx_lora_weights[1], + wanx_lora_weights[2], + wanx_lora_weights[3], + wanx_lora_multipliers[0], + wanx_lora_multipliers[1], + wanx_lora_multipliers[2], + wanx_lora_multipliers[3] + ], + outputs=[wanx_output, wanx_batch_progress, wanx_progress_text] + ) + + # Extract and send sharpest frame to input + wanx_send_sharpest_frame_btn.click( + fn=send_sharpest_frame_handler, + inputs=[wanx_output, wanx_i2v_selected_index, wanx_frames_to_check], + outputs=[wanx_input, wanx_base_video, wanx_sharpest_frame_number, wanx_sharpest_frame_status] + ) + + # Trim video to sharpest frame and prepare for extension + wanx_trim_and_extend_btn.click( + fn=trim_and_prepare_for_extension, + inputs=[wanx_base_video, wanx_sharpest_frame_number, wanx_save_path], + outputs=[wanx_trimmed_video_path, wanx_sharpest_frame_status] + ).then( + fn=lambda path, status: (path, status if "Failed" in status else "Video trimmed successfully and ready for extension"), + inputs=[wanx_trimmed_video_path, wanx_sharpest_frame_status], + outputs=[wanx_base_video, wanx_sharpest_frame_status] + ) + + wanx_extend_with_trimmed_btn.click( + # Prepare step: Sets the base video to the trimmed video path + fn=prepare_for_batch_extension, + inputs=[wanx_input, wanx_trimmed_video_path, wanx_batch_size], # Use trimmed video path here + outputs=[wanx_input, wanx_base_video, wanx_batch_size, wanx_batch_progress, wanx_progress_text] # Update base_video state + ).then( + # Actual extension processing step + fn=process_batch_extension, + inputs=[ + wanx_prompt, + wanx_negative_prompt, + wanx_input, # Input image (sharpest frame) + wanx_trimmed_video_path, # Base video to extend (the trimmed one) + wanx_width, + wanx_height, + wanx_video_length, + wanx_fps, + wanx_infer_steps, + wanx_flow_shift, + wanx_guidance_scale, + wanx_seed, + wanx_batch_size, + wanx_task, + wanx_dit_folder, # <<< Pass the folder path + wanx_dit_path, # <<< Pass the model filename + wanx_vae_path, + wanx_t5_path, + wanx_clip_path, + wanx_save_path, + wanx_output_type, + wanx_sample_solver, + wanx_exclude_single_blocks, + wanx_attn_mode, 
+ wanx_block_swap, + wanx_fp8, + wanx_fp8_scaled, + wanx_fp8_t5, + wanx_lora_folder, + wanx_slg_layers, + wanx_slg_start, + wanx_slg_end, + # Pass LoRA weights and multipliers individually + wanx_lora_weights[0], + wanx_lora_weights[1], + wanx_lora_weights[2], + wanx_lora_weights[3], + wanx_lora_multipliers[0], + wanx_lora_multipliers[1], + wanx_lora_multipliers[2], + wanx_lora_multipliers[3] + ], + outputs=[wanx_output, wanx_batch_progress, wanx_progress_text] + ) + + #Video Info + def handle_send_to_wanx_tab(metadata, target_tab, video_path=None): + """Common handler for sending video parameters to WanX tabs""" + if not metadata: + return "No parameters to send", {}, None # Return three values + + # Tab names for clearer messages + tab_names = { + 'wanx_i2v': 'WanX-i2v', + 'wanx_t2v': 'WanX-t2v', + 'wanx_v2v': 'WanX-v2v' + } + + # Just pass through all parameters - we'll use them in the .then() function + return f"Parameters ready for {tab_names.get(target_tab, target_tab)}", metadata, video_path + + def change_to_wanx_i2v_tab(): + return gr.Tabs(selected=4) # WanX-i2v tab index + + def change_to_wanx_t2v_tab(): + return gr.Tabs(selected=5) # WanX-t2v tab index + + + send_to_wanx_i2v_btn.click( + fn=lambda m: ("Parameters ready for WanX-i2v", m), + inputs=[metadata_output], + outputs=[status, params_state] + ).then( + # Reusing the same pattern as other tab transfers with LoRA handling + lambda params: [ + params.get("prompt", ""), + params.get("width", 832), + params.get("height", 480), + params.get("video_length", 81), + params.get("fps", 16), + params.get("infer_steps", 40), + params.get("seed", -1), + params.get("flow_shift", 3.0), + params.get("guidance_scale", 5.0), + params.get("attn_mode", "sdpa"), + params.get("block_swap", 0), + params.get("task", "i2v-14B"), + params.get("negative_prompt", ""), + *[params.get("lora_weights", ["None"]*4)[i] if isinstance(params.get("lora_weights", []), list) and i < len(params.get("lora_weights", [])) else "None" for i in range(4)], + *[params.get("lora_multipliers", [1.0]*4)[i] if isinstance(params.get("lora_multipliers", []), list) and i < len(params.get("lora_multipliers", [])) else 1.0 for i in range(4)] + ] if params else [gr.update()]*20, + inputs=params_state, + outputs=[ + wanx_prompt, wanx_width, wanx_height, wanx_video_length, + wanx_fps, wanx_infer_steps, wanx_seed, wanx_flow_shift, + wanx_guidance_scale, wanx_attn_mode, wanx_block_swap, + wanx_task, wanx_negative_prompt, + *wanx_lora_weights, + *wanx_lora_multipliers + ] + ).then( + fn=change_to_wanx_i2v_tab, + inputs=None, + outputs=[tabs] + ) + + # 3. 
Update the WanX-t2v button handler + send_to_wanx_t2v_btn.click( + fn=lambda m: handle_send_to_wanx_tab(m, 'wanx_t2v'), + inputs=[metadata_output], + outputs=[status, params_state] + ).then( + lambda params: [ + params.get("prompt", ""), + params.get("width", 832), + params.get("height", 480), + params.get("video_length", 81), + params.get("fps", 16), + params.get("infer_steps", 50), + params.get("seed", -1), + params.get("flow_shift", 5.0), + params.get("guidance_scale", 5.0), + params.get("attn_mode", "sdpa"), + params.get("block_swap", 0), + params.get("negative_prompt", ""), + *[params.get("lora_weights", ["None"]*4)[i] if isinstance(params.get("lora_weights", []), list) and i < len(params.get("lora_weights", [])) else "None" for i in range(4)], + *[params.get("lora_multipliers", [1.0]*4)[i] if isinstance(params.get("lora_multipliers", []), list) and i < len(params.get("lora_multipliers", [])) else 1.0 for i in range(4)] + ] if params else [gr.update()]*20, + inputs=params_state, + outputs=[ + wanx_t2v_prompt, + wanx_t2v_width, + wanx_t2v_height, + wanx_t2v_video_length, + wanx_t2v_fps, + wanx_t2v_infer_steps, + wanx_t2v_seed, + wanx_t2v_flow_shift, + wanx_t2v_guidance_scale, + wanx_t2v_attn_mode, + wanx_t2v_block_swap, + wanx_t2v_negative_prompt, + *wanx_t2v_lora_weights, + *wanx_t2v_lora_multipliers + ] + ).then( + fn=change_to_wanx_t2v_tab, inputs=None, outputs=[tabs] + ) + # FramePack-Extension send-to logic + def handle_send_to_fpe_tab(metadata: dict, video_path: str) -> Tuple[str, Dict, str]: + """Prepare parameters and video path for the FramePack-Extension tab.""" + if not video_path: + return "No video selected to send to FramePack-Extension", {}, None + + # If metadata is empty, provide a message but still allow video transfer + status_msg = "Parameters ready for FramePack-Extension." + if not metadata: + status_msg = "Video sent to FramePack-Extension (no parameters found in metadata)." 
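+            # Even without metadata the video path is still forwarded, so the FPE tab receives the clip;
+            # the downstream .then() sees an empty params dict and leaves the FPE controls unchanged.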
+ metadata = {} # Ensure metadata is a dict + + return status_msg, metadata, video_path + + def change_to_fpe_tab(): + return gr.Tabs(selected=11) # FramePack-Extension tab has id=11 + + send_to_fpe_btn.click( + fn=handle_send_to_fpe_tab, + inputs=[metadata_output, video_input], + outputs=[status, params_state, fpe_input_video] # status, state for params, and video input for FPE + ).then( + lambda params: ( + ( + (is_f1_from_meta := params.get("is_f1", True)), # Default to F1 if not specified + (use_normal_fp_val := not is_f1_from_meta), # fpe_use_normal_framepack is opposite of is_f1 + + # Determine resolution_max_dim + (target_res_meta := params.get("target_resolution")), + (video_w_meta := params.get("video_width")), + (video_h_meta := params.get("video_height")), + ( + res_max_dim_val := int(target_res_meta) if target_res_meta and int(target_res_meta) > 0 + else max(int(video_w_meta), int(video_h_meta)) if video_w_meta and video_h_meta and int(video_w_meta) > 0 and int(video_h_meta) > 0 + else 640 # Default + ), + # LoRA handling + (weights_from_meta := params.get("lora_weights", [])), + (mults_from_meta := params.get("lora_multipliers", [])), + (padded_weights := (weights_from_meta + ["None"] * 4)[:4]), + (padded_mults := ([float(m) if isinstance(m, (int, float, str)) and str(m).replace('.', '', 1).isdigit() else 1.0 for m in mults_from_meta] + [1.0] * 4)[:4]), + + [ + params.get("prompt", "cinematic video of a cat wizard casting a spell"), + params.get("negative_prompt", ""), + params.get("seed", -1), + use_normal_fp_val, + # fpe_end_frame and fpe_end_frame_weight are typically not in generic metadata, use defaults + gr_update(value=None), # fpe_end_frame (Image) + gr_update(value=1.0), # fpe_end_frame_weight + res_max_dim_val, + params.get("video_seconds", params.get("total_second_length", 5.0)), # Map from FramePack's video_seconds + params.get("latent_window_size", 9), + params.get("infer_steps", params.get("steps", 25)), # Map from FramePack's infer_steps + params.get("guidance_scale", params.get("cfg_scale", 1.0)), # Map from FramePack's guidance_scale to fpe_cfg_scale + params.get("embedded_cfg_scale", params.get("distilled_guidance_scale", 3.0)), # Map from FramePack's embedded_cfg_scale + # Model Paths - use FPE defaults or specific paths from metadata if available + # The DiT path is now primarily handled by the fpe_use_normal_framepack.change event + params.get("transformer_path", "hunyuan/FramePack_F1_I2V_HY_20250503.safetensors"), # Placeholder, will be overridden + params.get("vae_path", "hunyuan/pytorch_model.pt"), + params.get("text_encoder_path", "hunyuan/llava_llama3_fp16.safetensors"), + params.get("text_encoder_2_path", "hunyuan/clip_l.safetensors"), + params.get("image_encoder_path", "hunyuan/model.safetensors"), + # Advanced performance + params.get("attn_mode", "torch"), + params.get("fp8_llm", False), # This will be correctly set by fpe_use_normal_framepack.change + params.get("vae_chunk_size", 32), + params.get("vae_spatial_tile_sample_min_size", 128), + # LoRAs + *padded_weights, + *padded_mults, + ] + )[-1] # Return the list of values + ) if params else [gr.update()] * (18 + 8), # 18 direct params + 4 lora weights + 4 lora mults + inputs=params_state, + outputs=[ + fpe_prompt, fpe_negative_prompt, fpe_seed, + fpe_use_normal_framepack, # This will trigger its own .change event + fpe_end_frame, fpe_end_frame_weight, # These are UI only if fpe_use_normal_framepack is True + fpe_resolution_max_dim, fpe_total_second_length, fpe_latent_window_size, + fpe_steps, 
fpe_cfg_scale, fpe_distilled_guidance_scale, + # Model Paths + fpe_transformer_path, # Will be set by fpe_use_normal_framepack.change + fpe_vae_path, fpe_text_encoder_path, fpe_text_encoder_2_path, fpe_image_encoder_path, + # Advanced + fpe_attn_mode, fpe_fp8_llm, # fpe_fp8_llm also set by fpe_use_normal_framepack.change + fpe_vae_chunk_size, fpe_vae_spatial_tile_sample_min_size, + # LoRAs + *fpe_lora_weights_ui, *fpe_lora_multipliers_ui, + ] + ).then( + fn=change_to_fpe_tab, + inputs=None, + outputs=[tabs] + ) + #text to video + def change_to_tab_one(): + return gr.Tabs(selected=1) #This will navigate + #video to video + + def change_to_skyreels_tab(): + return gr.Tabs(selected=3) + + #SKYREELS TAB!!! + # Add state management for dimensions + def sync_skyreels_dimensions(width, height): + return gr.update(value=width), gr.update(value=height) + + # Add this function to update the LoRA dropdowns in the SKYREELS tab + def update_skyreels_lora_dropdowns(lora_folder: str, *current_values) -> List[gr.update]: + new_choices = get_lora_options(lora_folder) + weights = current_values[:4] + multipliers = current_values[4:8] + + results = [] + for i in range(4): + weight = weights[i] if i < len(weights) else "None" + multiplier = multipliers[i] if i < len(multipliers) else 1.0 + if weight not in new_choices: + weight = "None" + results.extend([ + gr.update(choices=new_choices, value=weight), + gr.update(value=multiplier) + ]) + + return results + + # Add this function to update the models dropdown in the SKYREELS tab + def update_skyreels_model_dropdown(dit_folder: str) -> Dict: + models = get_dit_models(dit_folder) + return gr.update(choices=models, value=models[0] if models else None) + + # Add event handler for model dropdown refresh + skyreels_dit_folder.change( + fn=update_skyreels_model_dropdown, + inputs=[skyreels_dit_folder], + outputs=[skyreels_model] + ) + + # Add handlers for the refresh button + skyreels_refresh_btn.click( + fn=update_skyreels_lora_dropdowns, + inputs=[skyreels_lora_folder] + skyreels_lora_weights + skyreels_lora_multipliers, + outputs=[drop for _ in range(4) for drop in [skyreels_lora_weights[_], skyreels_lora_multipliers[_]]] + ) + # Skyreels dimension handling + def calculate_skyreels_width(height, original_dims): + if not original_dims: + return gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + aspect_ratio = orig_w / orig_h + new_width = math.floor((height * aspect_ratio) / 16) * 16 + return gr.update(value=new_width) + + def calculate_skyreels_height(width, original_dims): + if not original_dims: + return gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + aspect_ratio = orig_w / orig_h + new_height = math.floor((width / aspect_ratio) / 16) * 16 + return gr.update(value=new_height) + + def update_skyreels_from_scale(scale, original_dims): + if not original_dims: + return gr.update(), gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + new_w = math.floor((orig_w * scale / 100) / 16) * 16 + new_h = math.floor((orig_h * scale / 100) / 16) * 16 + return gr.update(value=new_w), gr.update(value=new_h) + + def update_skyreels_dimensions(image): + if image is None: + return "", gr.update(value=544), gr.update(value=544) + img = Image.open(image) + w, h = img.size + w = (w // 16) * 16 + h = (h // 16) * 16 + return f"{w}x{h}", w, h + + def handle_skyreels_gallery_select(evt: gr.SelectData) -> int: + return evt.index + + def send_skyreels_to_v2v( + gallery: list, + prompt: str, + selected_index: int, + width: int, + height: 
int, + video_length: int, + fps: int, + infer_steps: int, + seed: int, + flow_shift: float, + cfg_scale: float, + lora1: str, + lora2: str, + lora3: str, + lora4: str, + lora1_multiplier: float, + lora2_multiplier: float, + lora3_multiplier: float, + lora4_multiplier: float, + negative_prompt: str = "" # Add this parameter + ) -> Tuple: + if not gallery or selected_index is None or selected_index >= len(gallery): + return (None, "", width, height, video_length, fps, infer_steps, seed, + flow_shift, cfg_scale, lora1, lora2, lora3, lora4, + lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier, + negative_prompt) # Add negative_prompt to return + + selected_item = gallery[selected_index] + + if isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + elif isinstance(selected_item, (tuple, list)): + video_path = selected_item[0] + else: + video_path = selected_item + + if isinstance(video_path, tuple): + video_path = video_path[0] + + return (str(video_path), prompt, width, height, video_length, fps, infer_steps, seed, + flow_shift, cfg_scale, lora1, lora2, lora3, lora4, + lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier, + negative_prompt) # Add negative_prompt to return + + # Add event handlers for the SKYREELS tab + skyreels_prompt.change(fn=count_prompt_tokens, inputs=skyreels_prompt, outputs=skyreels_token_counter) + skyreels_stop_btn.click(fn=lambda: stop_event.set(), queue=False) + + # Image input handling + skyreels_input.change( + fn=update_skyreels_dimensions, + inputs=[skyreels_input], + outputs=[skyreels_original_dims, skyreels_width, skyreels_height] + ) + + skyreels_scale_slider.change( + fn=update_skyreels_from_scale, + inputs=[skyreels_scale_slider, skyreels_original_dims], + outputs=[skyreels_width, skyreels_height] + ) + + skyreels_calc_width_btn.click( + fn=calculate_skyreels_width, + inputs=[skyreels_height, skyreels_original_dims], + outputs=[skyreels_width] + ) + + skyreels_calc_height_btn.click( + fn=calculate_skyreels_height, + inputs=[skyreels_width, skyreels_original_dims], + outputs=[skyreels_height] + ) + + # Handle checkbox visibility toggling + skyreels_use_random_folder.change( + fn=lambda x: (gr.update(visible=x), gr.update(visible=x), gr.update(visible=not x)), + inputs=[skyreels_use_random_folder], + outputs=[skyreels_input_folder, skyreels_folder_status, skyreels_input] + ) + + # Validate folder button click handler + skyreels_validate_folder_btn.click( + fn=lambda folder: get_random_image_from_folder(folder)[1], + inputs=[skyreels_input_folder], + outputs=[skyreels_folder_status] + ) + + skyreels_use_random_folder.change( + fn=lambda x: gr.update(visible=x), + inputs=[skyreels_use_random_folder], + outputs=[skyreels_validate_folder_btn] + ) + + # Modify the skyreels_generate_btn.click event handler to use process_random_image_batch when folder mode is on + skyreels_generate_btn.click( + fn=batch_handler, + inputs=[ + skyreels_use_random_folder, + # Rest of the arguments + skyreels_prompt, + skyreels_negative_prompt, + skyreels_width, + skyreels_height, + skyreels_video_length, + skyreels_fps, + skyreels_infer_steps, + skyreels_seed, + skyreels_flow_shift, + skyreels_guidance_scale, + skyreels_embedded_cfg_scale, + skyreels_batch_size, + skyreels_input_folder, + skyreels_dit_folder, + skyreels_model, + skyreels_vae, + skyreels_te1, + skyreels_te2, + skyreels_save_path, + skyreels_output_type, + skyreels_attn_mode, + skyreels_block_swap, + 
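+            # Note: Gradio forwards these inputs to batch_handler positionally, so the order
+            # here must match batch_handler's parameter order (the unpacked LoRA weights and
+            # multipliers and the input image path come last).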
skyreels_exclude_single_blocks, + skyreels_use_split_attn, + skyreels_use_fp8, + skyreels_split_uncond, + skyreels_lora_folder, + *skyreels_lora_weights, + *skyreels_lora_multipliers, + skyreels_input # Add the input image path + ], + outputs=[skyreels_output, skyreels_batch_progress, skyreels_progress_text], + queue=True + ).then( + fn=lambda batch_size: 0 if batch_size == 1 else None, + inputs=[skyreels_batch_size], + outputs=skyreels_selected_index + ) + + # Gallery selection handling + skyreels_output.select( + fn=handle_skyreels_gallery_select, + outputs=skyreels_selected_index + ) + + # Send to Video2Video handler + skyreels_send_to_v2v_btn.click( + fn=send_skyreels_to_v2v, + inputs=[ + skyreels_output, skyreels_prompt, skyreels_selected_index, + skyreels_width, skyreels_height, skyreels_video_length, + skyreels_fps, skyreels_infer_steps, skyreels_seed, + skyreels_flow_shift, skyreels_guidance_scale + ] + skyreels_lora_weights + skyreels_lora_multipliers + [skyreels_negative_prompt], # This is ok because skyreels_negative_prompt is a Gradio component + outputs=[ + v2v_input, v2v_prompt, v2v_width, v2v_height, + v2v_video_length, v2v_fps, v2v_infer_steps, + v2v_seed, v2v_flow_shift, v2v_cfg_scale + ] + v2v_lora_weights + v2v_lora_multipliers + [v2v_negative_prompt] + ).then( + fn=change_to_tab_two, + inputs=None, + outputs=[tabs] + ) + + # Refresh button handler + skyreels_refresh_outputs = [skyreels_model] + for i in range(4): + skyreels_refresh_outputs.extend([skyreels_lora_weights[i], skyreels_lora_multipliers[i]]) + + skyreels_refresh_btn.click( + fn=update_dit_and_lora_dropdowns, + inputs=[skyreels_dit_folder, skyreels_lora_folder, skyreels_model] + skyreels_lora_weights + skyreels_lora_multipliers, + outputs=skyreels_refresh_outputs + ) + + def calculate_v2v_width(height, original_dims): + if not original_dims: + return gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + aspect_ratio = orig_w / orig_h + new_width = math.floor((height * aspect_ratio) / 16) * 16 # Ensure divisible by 16 + return gr.update(value=new_width) + + def calculate_v2v_height(width, original_dims): + if not original_dims: + return gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + aspect_ratio = orig_w / orig_h + new_height = math.floor((width / aspect_ratio) / 16) * 16 # Ensure divisible by 16 + return gr.update(value=new_height) + + def update_v2v_from_scale(scale, original_dims): + if not original_dims: + return gr.update(), gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + new_w = math.floor((orig_w * scale / 100) / 16) * 16 # Ensure divisible by 16 + new_h = math.floor((orig_h * scale / 100) / 16) * 16 # Ensure divisible by 16 + return gr.update(value=new_w), gr.update(value=new_h) + + def update_v2v_dimensions(video): + if video is None: + return "", gr.update(value=544), gr.update(value=544) + cap = cv2.VideoCapture(video) + w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + cap.release() + # Make dimensions divisible by 16 + w = (w // 16) * 16 + h = (h // 16) * 16 + return f"{w}x{h}", w, h + + # Event Handlers for Video to Video Tab + v2v_input.change( + fn=update_v2v_dimensions, + inputs=[v2v_input], + outputs=[v2v_original_dims, v2v_width, v2v_height] + ) + + v2v_scale_slider.change( + fn=update_v2v_from_scale, + inputs=[v2v_scale_slider, v2v_original_dims], + outputs=[v2v_width, v2v_height] + ) + + v2v_calc_width_btn.click( + fn=calculate_v2v_width, + inputs=[v2v_height, v2v_original_dims], + 
outputs=[v2v_width] + ) + + v2v_calc_height_btn.click( + fn=calculate_v2v_height, + inputs=[v2v_width, v2v_original_dims], + outputs=[v2v_height] + ) + + ##Image 2 video dimension logic + def calculate_width(height, original_dims): + if not original_dims: + return gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + aspect_ratio = orig_w / orig_h + new_width = math.floor((height * aspect_ratio) / 16) * 16 # Changed from 8 to 16 + return gr.update(value=new_width) + + def calculate_height(width, original_dims): + if not original_dims: + return gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + aspect_ratio = orig_w / orig_h + new_height = math.floor((width / aspect_ratio) / 16) * 16 # Changed from 8 to 16 + return gr.update(value=new_height) + + def update_from_scale(scale, original_dims): + if not original_dims: + return gr.update(), gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + new_w = math.floor((orig_w * scale / 100) / 16) * 16 # Changed from 8 to 16 + new_h = math.floor((orig_h * scale / 100) / 16) * 16 # Changed from 8 to 16 + return gr.update(value=new_w), gr.update(value=new_h) + + def update_dimensions(image): + if image is None: + return "", gr.update(value=544), gr.update(value=544) + img = Image.open(image) + w, h = img.size + # Make dimensions divisible by 16 + w = (w // 16) * 16 # Changed from 8 to 16 + h = (h // 16) * 16 # Changed from 8 to 16 + return f"{w}x{h}", w, h + i2v_input.change( + fn=update_dimensions, + inputs=[i2v_input], + outputs=[original_dims, i2v_width, i2v_height] # Update correct components + ) + + scale_slider.change( + fn=update_from_scale, + inputs=[scale_slider, original_dims], + outputs=[i2v_width, i2v_height] # Update correct components + ) + + calc_width_btn.click( + fn=calculate_width, + inputs=[i2v_height, original_dims], # Update correct components + outputs=[i2v_width] + ) + + calc_height_btn.click( + fn=calculate_height, + inputs=[i2v_width, original_dims], # Update correct components + outputs=[i2v_height] + ) + + # Function to get available DiT models + def get_dit_models(dit_folder: str) -> List[str]: + if not os.path.exists(dit_folder): + return ["mp_rank_00_model_states.pt"] + models = [f for f in os.listdir(dit_folder) if f.endswith('.pt') or f.endswith('.safetensors')] + models.sort(key=str.lower) + return models if models else ["mp_rank_00_model_states.pt"] + + # Function to perform model merging + def merge_models( + dit_folder: str, + dit_model: str, + output_model: str, + exclude_single_blocks: bool, + merge_lora_folder: str, + *lora_params # Will contain both weights and multipliers + ) -> str: + try: + # Separate weights and multipliers + num_loras = len(lora_params) // 2 + weights = list(lora_params[:num_loras]) + multipliers = list(lora_params[num_loras:]) + + # Filter out "None" selections + valid_loras = [] + for weight, mult in zip(weights, multipliers): + if weight and weight != "None": + valid_loras.append((os.path.join(merge_lora_folder, weight), mult)) + + if not valid_loras: + return "No LoRA models selected for merging" + + # Create output path in the dit folder + os.makedirs(dit_folder, exist_ok=True) + output_path = os.path.join(dit_folder, output_model) + + # Prepare command + cmd = [ + sys.executable, + "merge_lora.py", + "--dit", os.path.join(dit_folder, dit_model), + "--save_merged_model", output_path + ] + + # Add LoRA weights and multipliers + weights = [weight for weight, _ in valid_loras] + multipliers = [str(mult) for _, mult in valid_loras] + 
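+            # Illustrative shape of the final command (file names below are placeholders):
+            #   <python> merge_lora.py --dit <dit_folder>/<dit_model> \
+            #       --save_merged_model <dit_folder>/<output_model> \
+            #       --lora_weight loraA.safetensors loraB.safetensors \
+            #       --lora_multiplier 1.0 0.8 [--exclude_single_blocks]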
cmd.extend(["--lora_weight"] + weights) + cmd.extend(["--lora_multiplier"] + multipliers) + + if exclude_single_blocks: + cmd.append("--exclude_single_blocks") + + # Execute merge operation + result = subprocess.run( + cmd, + capture_output=True, + text=True, + check=True + ) + + if os.path.exists(output_path): + return f"Successfully merged model and saved to {output_path}" + else: + return "Error: Output file not created" + + except subprocess.CalledProcessError as e: + return f"Error during merging: {e.stderr}" + except Exception as e: + return f"Error: {str(e)}" + + # Update DiT model dropdown + def update_dit_dropdown(dit_folder: str) -> Dict: + models = get_dit_models(dit_folder) + return gr.update(choices=models, value=models[0] if models else None) + + # Connect events + merge_btn.click( + fn=merge_models, + inputs=[ + dit_folder, + dit_model, + output_model, + exclude_single_blocks, + merge_lora_folder, + *merge_lora_weights, + *merge_lora_multipliers + ], + outputs=merge_status + ) + + # Refresh buttons for both DiT and LoRA dropdowns + merge_refresh_btn.click( + fn=lambda f: update_dit_dropdown(f), + inputs=[dit_folder], + outputs=[dit_model] + ) + + # LoRA refresh handling + merge_refresh_outputs = [] + for i in range(4): + merge_refresh_outputs.extend([merge_lora_weights[i], merge_lora_multipliers[i]]) + + merge_refresh_btn.click( + fn=update_lora_dropdowns, + inputs=[merge_lora_folder] + merge_lora_weights + merge_lora_multipliers, + outputs=merge_refresh_outputs + ) + # Event handlers + prompt.change(fn=count_prompt_tokens, inputs=prompt, outputs=token_counter) + v2v_prompt.change(fn=count_prompt_tokens, inputs=v2v_prompt, outputs=v2v_token_counter) + stop_btn.click(fn=lambda: stop_event.set(), queue=False) + v2v_stop_btn.click(fn=lambda: stop_event.set(), queue=False) + + #Image_to_Video + def image_to_video(image_path, output_path, width, height, frames=240): # Add width, height parameters + img = Image.open(image_path) + + # Resize to the specified dimensions + img_resized = img.resize((width, height), Image.LANCZOS) + temp_image_path = os.path.join(os.path.dirname(output_path), "temp_resized_image.png") + img_resized.save(temp_image_path) + + # Rest of function remains the same + frame_rate = 24 + duration = frames / frame_rate + command = [ + "ffmpeg", "-loop", "1", "-i", temp_image_path, "-c:v", "libx264", + "-t", str(duration), "-pix_fmt", "yuv420p", + "-vf", f"fps={frame_rate}", output_path + ] + + try: + subprocess.run(command, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + print(f"Video saved to {output_path}") + return True + except subprocess.CalledProcessError as e: + print(f"An error occurred while creating the video: {e}") + return False + finally: + # Clean up the temporary image file + if os.path.exists(temp_image_path): + os.remove(temp_image_path) + img.close() # Make sure to close the image file explicitly + + def generate_from_image( + image_path, + prompt, width, height, video_length, fps, infer_steps, + seed, model, vae, te1, te2, save_path, flow_shift, cfg_scale, + output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, + lora_folder, strength, batch_size, *lora_params + ): + """Generate video from input image with progressive updates""" + global stop_event + stop_event.clear() + + # Create temporary video path + temp_video_path = os.path.join(save_path, f"temp_{os.path.basename(image_path)}.mp4") + + try: + # Convert image to video + if not image_to_video(image_path, temp_video_path, width, height, 
frames=video_length): + yield [], "Failed to create temporary video", "Error in video creation" + return + + # Ensure video is fully written before proceeding + time.sleep(1) + if not os.path.exists(temp_video_path) or os.path.getsize(temp_video_path) == 0: + yield [], "Failed to create temporary video", "Temporary video file is empty or missing" + return + + # Get video dimensions + try: + probe = ffmpeg.probe(temp_video_path) + video_stream = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None) + if video_stream is None: + raise ValueError("No video stream found") + width = int(video_stream['width']) + height = int(video_stream['height']) + except Exception as e: + yield [], f"Error reading video dimensions: {str(e)}", "Video processing error" + return + + # Generate the video using the temporary file + try: + generator = process_single_video( + prompt, width, height, batch_size, video_length, fps, infer_steps, + seed, model, vae, te1, te2, save_path, flow_shift, cfg_scale, + output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, + lora_folder, *lora_params, video_path=temp_video_path, strength=strength + ) + + # Forward all generator updates + for videos, batch_text, progress_text in generator: + yield videos, batch_text, progress_text + + except Exception as e: + yield [], f"Error in video generation: {str(e)}", "Generation error" + return + + except Exception as e: + yield [], f"Unexpected error: {str(e)}", "Error occurred" + return + + finally: + # Clean up temporary file + try: + if os.path.exists(temp_video_path): + os.remove(temp_video_path) + except Exception: + pass # Ignore cleanup errors + + + # Add event handlers + i2v_prompt.change(fn=count_prompt_tokens, inputs=i2v_prompt, outputs=i2v_token_counter) + i2v_stop_btn.click(fn=lambda: stop_event.set(), queue=False) + + def handle_i2v_gallery_select(evt: gr.SelectData) -> int: + """Track selected index when I2V gallery item is clicked""" + return evt.index + + def send_i2v_to_v2v( + gallery: list, + prompt: str, + selected_index: int, + width: int, + height: int, + video_length: int, + fps: int, + infer_steps: int, + seed: int, + flow_shift: float, + cfg_scale: float, + lora1: str, + lora2: str, + lora3: str, + lora4: str, + lora1_multiplier: float, + lora2_multiplier: float, + lora3_multiplier: float, + lora4_multiplier: float + ) -> Tuple[Optional[str], str, int, int, int, int, int, int, float, float, str, str, str, str, float, float, float, float]: + """Send the selected video and parameters from Image2Video tab to Video2Video tab""" + if not gallery or selected_index is None or selected_index >= len(gallery): + return None, "", width, height, video_length, fps, infer_steps, seed, flow_shift, cfg_scale, \ + lora1, lora2, lora3, lora4, lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier + + selected_item = gallery[selected_index] + + # Handle different gallery item formats + if isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + elif isinstance(selected_item, (tuple, list)): + video_path = selected_item[0] + else: + video_path = selected_item + + # Final cleanup for Gradio Video component + if isinstance(video_path, tuple): + video_path = video_path[0] + + # Use the original width and height without doubling + return (str(video_path), prompt, width, height, video_length, fps, infer_steps, seed, + flow_shift, cfg_scale, lora1, lora2, lora3, lora4, + lora1_multiplier, lora2_multiplier, 
lora3_multiplier, lora4_multiplier) + + # Generate button handler for h-basic-i2v + i2v_generate_btn.click( + fn=process_i2v_batch, # <<< Use the new batch function + inputs=[ + i2v_prompt, + i2v_input, # Image path + i2v_width, + i2v_height, + i2v_batch_size, + i2v_video_length, + i2v_fps, + i2v_infer_steps, + i2v_seed, + i2v_dit_folder, + i2v_model, + i2v_vae, + i2v_te1, + i2v_te2, + i2v_clip_vision_path, + i2v_save_path, + i2v_flow_shift, + i2v_cfg_scale, # embedded_cfg_scale + i2v_guidance_scale, # main CFG scale + i2v_output_type, + i2v_attn_mode, + i2v_block_swap, + i2v_exclude_single_blocks, + i2v_use_split_attn, + i2v_lora_folder, + i2v_vae_chunk_size, + i2v_vae_spatial_tile_min, + # --- Add negative prompt component if you have one --- + # i2v_negative_prompt, # Uncomment if you added this textbox + # --- If no negative prompt textbox, pass None or "": --- + gr.Textbox(value="", visible=False), # Placeholder if no UI element + # --- End negative prompt handling --- + i2v_use_fp8, + i2v_fp8_llm, + *i2v_lora_weights, # Pass LoRA weights components + *i2v_lora_multipliers # Pass LoRA multipliers components + ], + outputs=[i2v_output, i2v_batch_progress, i2v_progress_text], + queue=True + ).then( + fn=lambda batch_size: 0 if batch_size == 1 else None, + inputs=[i2v_batch_size], + outputs=i2v_selected_index + ) + # Send to Video2Video + i2v_output.select( + fn=handle_i2v_gallery_select, + outputs=i2v_selected_index + ) + + i2v_send_to_v2v_btn.click( + fn=send_i2v_to_v2v, # Function definition needs careful review/update if args changed + inputs=[ + i2v_output, i2v_prompt, i2v_selected_index, + i2v_width, i2v_height, # <<< Use i2v width/height + i2v_video_length, i2v_fps, i2v_infer_steps, + i2v_seed, i2v_flow_shift, i2v_cfg_scale # <<< Use i2v cfg_scale (embedded) + ] + i2v_lora_weights + i2v_lora_multipliers, # <<< Use i2v LoRAs + outputs=[ + v2v_input, v2v_prompt, + v2v_width, v2v_height, # Target V2V components + v2v_video_length, v2v_fps, v2v_infer_steps, + v2v_seed, v2v_flow_shift, v2v_cfg_scale # Target V2V components + ] + v2v_lora_weights + v2v_lora_multipliers # Target V2V LoRAs + ).then( + fn=change_to_tab_two, inputs=None, outputs=[tabs] + ) + #Video Info + def clean_video_path(video_path) -> str: + """Extract clean video path from Gradio's various return formats""" + print(f"Input video_path: {video_path}, type: {type(video_path)}") + if isinstance(video_path, dict): + path = video_path.get("name", "") + elif isinstance(video_path, (tuple, list)): + path = video_path[0] + elif isinstance(video_path, str): + path = video_path + else: + path = "" + print(f"Cleaned path: {path}") + return path + def handle_video_upload(video_path: str) -> Dict: + """Handle video upload and metadata extraction""" + if not video_path: + return {}, "No video uploaded" + + metadata = extract_video_metadata(video_path) + if not metadata: + return {}, "No metadata found in video" + + return metadata, "Metadata extracted successfully" + + def get_video_info(video_path: str) -> dict: + try: + probe = ffmpeg.probe(video_path) + video_info = next(stream for stream in probe['streams'] if stream['codec_type'] == 'video') + + width = int(video_info['width']) + height = int(video_info['height']) + fps = eval(video_info['r_frame_rate']) # This converts '30/1' to 30.0 + + # Calculate total frames + duration = float(probe['format']['duration']) + total_frames = int(duration * fps) + + # Ensure video length does not exceed 201 frames + if total_frames > 201: + total_frames = 201 + duration = total_frames / fps # 
Adjust duration accordingly + + return { + 'width': width, + 'height': height, + 'fps': fps, + 'total_frames': total_frames, + 'duration': duration # Might be useful in some contexts + } + except Exception as e: + print(f"Error extracting video info: {e}") + return {} + + def extract_video_details(video_path: str) -> Tuple[dict, str]: + metadata = extract_video_metadata(video_path) + video_details = get_video_info(video_path) + + # Combine metadata with video details + for key, value in video_details.items(): + if key not in metadata: + metadata[key] = value + + # Ensure video length does not exceed 201 frames + if 'video_length' in metadata: + metadata['video_length'] = min(metadata['video_length'], 201) + else: + metadata['video_length'] = min(video_details.get('total_frames', 0), 201) + + # Return both the updated metadata and a status message + return metadata, "Video details extracted successfully" + + def send_parameters_to_tab(metadata: Dict, target_tab: str) -> Tuple[str, Dict]: + """Create parameter mapping for target tab""" + if not metadata: + return "No parameters to send", {} + + tab_name = "Text2Video" if target_tab == "t2v" else "Video2Video" + try: + mapping = create_parameter_transfer_map(metadata, target_tab) + return f"Parameters ready for {tab_name}", mapping + except Exception as e: + return f"Error: {str(e)}", {} + + video_input.upload( + fn=extract_video_details, + inputs=video_input, + outputs=[metadata_output, status] + ) + + send_to_t2v_btn.click( + fn=lambda m: send_parameters_to_tab(m, "t2v"), + inputs=metadata_output, + outputs=[status, params_state] + ).then( + fn=change_to_tab_one, inputs=None, outputs=[tabs] + ).then( + lambda params: [ + params.get("prompt", ""), + params.get("width", 544), # Parameter mapping is fine here + params.get("height", 544), # Parameter mapping is fine here + params.get("batch_size", 1), + params.get("video_length", 25), + params.get("fps", 24), + params.get("infer_steps", 30), + params.get("seed", -1), + params.get("model", "hunyuan/mp_rank_00_model_states.pt"), + params.get("vae", "hunyuan/pytorch_model.pt"), + params.get("te1", "hunyuan/llava_llama3_fp16.safetensors"), + params.get("te2", "hunyuan/clip_l.safetensors"), + params.get("save_path", "outputs"), + params.get("flow_shift", 11.0), + params.get("cfg_scale", 7.0), + params.get("output_type", "video"), + params.get("attn_mode", "sdpa"), + params.get("block_swap", "0"), + *[params.get(f"lora{i+1}", "") for i in range(4)], + *[params.get(f"lora{i+1}_multiplier", 1.0) for i in range(4)] + ] if params else [gr.update()]*26, # This lambda returns values based on param keys + inputs=params_state, + outputs=[prompt, t2v_width, t2v_height, batch_size, video_length, fps, infer_steps, seed, # <<< CORRECTED HERE: use t2v_width, t2v_height + model, vae, te1, te2, save_path, flow_shift, cfg_scale, + output_type, attn_mode, block_swap] + lora_weights + lora_multipliers + ) + # Text to Video generation + generate_btn.click( + fn=process_batch, + inputs=[ + prompt, t2v_width, t2v_height, batch_size, video_length, fps, infer_steps, + seed, dit_folder, model, vae, te1, te2, save_path, flow_shift, cfg_scale, + output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, + lora_folder, *lora_weights, *lora_multipliers, gr.Textbox(visible=False), gr.Number(visible=False), use_fp8 + ], + outputs=[video_output, batch_progress, progress_text], + queue=True + ).then( + fn=lambda batch_size: 0 if batch_size == 1 else None, + inputs=[batch_size], + outputs=selected_index + ) + + # 
Update gallery selection handling + def handle_gallery_select(evt: gr.SelectData) -> int: + return evt.index + + # Track selected index when gallery item is clicked + video_output.select( + fn=handle_gallery_select, + outputs=selected_index + ) + + # Track selected index when Video2Video gallery item is clicked + def handle_v2v_gallery_select(evt: gr.SelectData) -> int: + """Handle gallery selection without automatically updating the input""" + return evt.index + + # Update the gallery selection event + v2v_output.select( + fn=handle_v2v_gallery_select, + outputs=v2v_selected_index + ) + + # Send button handler with gallery selection + def handle_send_button( + gallery: list, + prompt: str, + idx: int, + width: int, + height: int, + batch_size: int, + video_length: int, + fps: int, + infer_steps: int, + seed: int, + flow_shift: float, + cfg_scale: float, + lora1: str, + lora2: str, + lora3: str, + lora4: str, + lora1_multiplier: float, + lora2_multiplier: float, + lora3_multiplier: float, + lora4_multiplier: float + ) -> tuple: + if not gallery or idx is None or idx >= len(gallery): + return (None, "", width, height, batch_size, video_length, fps, infer_steps, + seed, flow_shift, cfg_scale, + lora1, lora2, lora3, lora4, + lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier, + "") # Add empty string for negative_prompt in the return values + + # Auto-select first item if only one exists and no selection made + if idx is None and len(gallery) == 1: + idx = 0 + + selected_item = gallery[idx] + + # Handle different gallery item formats + if isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + elif isinstance(selected_item, (tuple, list)): + video_path = selected_item[0] + else: + video_path = selected_item + + # Final cleanup for Gradio Video component + if isinstance(video_path, tuple): + video_path = video_path[0] + + return ( + str(video_path), + prompt, + width, + height, + batch_size, + video_length, + fps, + infer_steps, + seed, + flow_shift, + cfg_scale, + lora1, + lora2, + lora3, + lora4, + lora1_multiplier, + lora2_multiplier, + lora3_multiplier, + lora4_multiplier, + "" # Add empty string for negative_prompt + ) + + send_t2v_to_v2v_btn.click( + fn=handle_send_button, + inputs=[ + video_output, prompt, selected_index, + t2v_width, t2v_height, batch_size, video_length, + fps, infer_steps, seed, flow_shift, cfg_scale + ] + lora_weights + lora_multipliers, # Remove the string here + outputs=[ + v2v_input, + v2v_prompt, + v2v_width, + v2v_height, + v2v_batch_size, + v2v_video_length, + v2v_fps, + v2v_infer_steps, + v2v_seed, + v2v_flow_shift, + v2v_cfg_scale + ] + v2v_lora_weights + v2v_lora_multipliers + [v2v_negative_prompt] + ).then( + fn=change_to_tab_two, inputs=None, outputs=[tabs] + ) + + def handle_send_to_v2v(metadata: dict, video_path: str) -> Tuple[str, dict, str]: + """Handle both parameters and video transfer""" + status_msg, params = send_parameters_to_tab(metadata, "v2v") + return status_msg, params, video_path + + def handle_info_to_v2v(metadata: dict, video_path: str) -> Tuple[str, Dict, str]: + """Handle both parameters and video transfer from Video Info to V2V tab""" + if not video_path: + return "No video selected", {}, None + + status_msg, params = send_parameters_to_tab(metadata, "v2v") + # Just return the path directly + return status_msg, params, video_path + + # Send button click handler + send_to_v2v_btn.click( + fn=handle_info_to_v2v, + inputs=[metadata_output, video_input], + 
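+        # The chained .then() calls that follow populate the Video2Video fields from
+        # params_state, print a debug line, and finally switch the UI to the V2V tab.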
outputs=[status, params_state, v2v_input] + ).then( + lambda params: [ + params.get("v2v_prompt", ""), + params.get("v2v_width", 544), + params.get("v2v_height", 544), + params.get("v2v_batch_size", 1), + params.get("v2v_video_length", 25), + params.get("v2v_fps", 24), + params.get("v2v_infer_steps", 30), + params.get("v2v_seed", -1), + params.get("v2v_model", "hunyuan/mp_rank_00_model_states.pt"), + params.get("v2v_vae", "hunyuan/pytorch_model.pt"), + params.get("v2v_te1", "hunyuan/llava_llama3_fp16.safetensors"), + params.get("v2v_te2", "hunyuan/clip_l.safetensors"), + params.get("v2v_save_path", "outputs"), + params.get("v2v_flow_shift", 11.0), + params.get("v2v_cfg_scale", 7.0), + params.get("v2v_output_type", "video"), + params.get("v2v_attn_mode", "sdpa"), + params.get("v2v_block_swap", "0"), + *[params.get(f"v2v_lora_weights[{i}]", "") for i in range(4)], + *[params.get(f"v2v_lora_multipliers[{i}]", 1.0) for i in range(4)] + ] if params else [gr.update()] * 26, + inputs=params_state, + outputs=[ + v2v_prompt, v2v_width, v2v_height, v2v_batch_size, v2v_video_length, + v2v_fps, v2v_infer_steps, v2v_seed, v2v_model, v2v_vae, v2v_te1, + v2v_te2, v2v_save_path, v2v_flow_shift, v2v_cfg_scale, v2v_output_type, + v2v_attn_mode, v2v_block_swap + ] + v2v_lora_weights + v2v_lora_multipliers + ).then( + lambda: print(f"Tabs object: {tabs}"), # Debug print + outputs=None + ).then( + fn=change_to_tab_two, inputs=None, outputs=[tabs] + ) + + # Handler for sending selected video from Video2Video gallery to input + def handle_v2v_send_button(gallery: list, prompt: str, idx: int) -> Tuple[Optional[str], str]: + """Send the currently selected video in V2V gallery to V2V input""" + if not gallery or idx is None or idx >= len(gallery): + return None, "" + + selected_item = gallery[idx] + video_path = None + + # Handle different gallery item formats + if isinstance(selected_item, tuple): + video_path = selected_item[0] # Gallery returns (path, caption) + elif isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + elif isinstance(selected_item, str): + video_path = selected_item + + if not video_path: + return None, "" + + # Check if the file exists and is accessible + if not os.path.exists(video_path): + print(f"Warning: Video file not found at {video_path}") + return None, "" + + return video_path, prompt + + v2v_send_to_input_btn.click( + fn=handle_v2v_send_button, + inputs=[v2v_output, v2v_prompt, v2v_selected_index], + outputs=[v2v_input, v2v_prompt] + ).then( + lambda: gr.update(visible=True), # Ensure the video input is visible + outputs=v2v_input + ) + + # Video to Video generation + v2v_generate_btn.click( + fn=process_batch, + inputs=[ + v2v_prompt, v2v_width, v2v_height, v2v_batch_size, v2v_video_length, + v2v_fps, v2v_infer_steps, v2v_seed, v2v_dit_folder, v2v_model, v2v_vae, v2v_te1, v2v_te2, + v2v_save_path, v2v_flow_shift, v2v_cfg_scale, v2v_output_type, v2v_attn_mode, + v2v_block_swap, v2v_exclude_single_blocks, v2v_use_split_attn, v2v_lora_folder, + *v2v_lora_weights, *v2v_lora_multipliers, v2v_input, v2v_strength, + v2v_negative_prompt, v2v_cfg_scale, v2v_split_uncond, v2v_use_fp8 + ], + outputs=[v2v_output, v2v_batch_progress, v2v_progress_text], + queue=True + ).then( + fn=lambda batch_size: 0 if batch_size == 1 else None, + inputs=[v2v_batch_size], + outputs=v2v_selected_index + ) + refresh_outputs = [model] # Add model dropdown to outputs + for i in range(4): + refresh_outputs.extend([lora_weights[i], lora_multipliers[i]]) + + 
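+    # Illustrative sketch (hypothetical helper, not referenced elsewhere in this file):
+    # the refresh handlers in this file all rebuild the same "model dropdown plus
+    # interleaved LoRA weight/multiplier components" outputs list by hand; a helper
+    # like this could express that pattern once.
+    def _example_build_refresh_outputs(model_component, weight_components, multiplier_components):
+        """Return [model, weight1, mult1, weight2, mult2, ...] for a refresh handler's outputs."""
+        outputs = [model_component]
+        for weight, multiplier in zip(weight_components, multiplier_components):
+            outputs.extend([weight, multiplier])
+        return outputs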
refresh_btn.click( + fn=update_dit_and_lora_dropdowns, + inputs=[dit_folder, lora_folder, model] + lora_weights + lora_multipliers, + outputs=refresh_outputs + ) + # Image2Video refresh + i2v_refresh_outputs = [i2v_model] # Add model dropdown to outputs + for i in range(4): + i2v_refresh_outputs.extend([i2v_lora_weights[i], i2v_lora_multipliers[i]]) + + i2v_refresh_btn.click( + fn=update_dit_and_lora_dropdowns, + inputs=[i2v_dit_folder, i2v_lora_folder, i2v_model] + i2v_lora_weights + i2v_lora_multipliers, + outputs=i2v_refresh_outputs + ) + + # Video2Video refresh + v2v_refresh_outputs = [v2v_model] # Add model dropdown to outputs + for i in range(4): + v2v_refresh_outputs.extend([v2v_lora_weights[i], v2v_lora_multipliers[i]]) + + v2v_refresh_btn.click( + fn=update_dit_and_lora_dropdowns, + inputs=[v2v_dit_folder, v2v_lora_folder, v2v_model] + v2v_lora_weights + v2v_lora_multipliers, + outputs=v2v_refresh_outputs + ) + + # WanX-i2v tab connections + wanx_prompt.change(fn=count_prompt_tokens, inputs=wanx_prompt, outputs=wanx_token_counter) + wanx_stop_btn.click(fn=lambda: stop_event.set(), queue=False) + + # Image input handling for WanX-i2v + wanx_input.change( + fn=update_wanx_image_dimensions, + inputs=[wanx_input], + outputs=[wanx_original_dims, wanx_width, wanx_height] + ) + + # Scale slider handling for WanX-i2v + wanx_scale_slider.change( + fn=update_wanx_from_scale, + inputs=[wanx_scale_slider, wanx_original_dims], + outputs=[wanx_width, wanx_height] + ) + + # Width/height calculation buttons for WanX-i2v + wanx_calc_width_btn.click( + fn=calculate_wanx_width, + inputs=[wanx_height, wanx_original_dims], + outputs=[wanx_width] + ) + + wanx_calc_height_btn.click( + fn=calculate_wanx_height, + inputs=[wanx_width, wanx_original_dims], + outputs=[wanx_height] + ) + # Add visibility toggle for the folder input components + wanx_use_random_folder.change( + fn=lambda x: (gr.update(visible=x), gr.update(visible=x), gr.update(visible=x), gr.update(visible=not x)), + inputs=[wanx_use_random_folder], + outputs=[wanx_input_folder, wanx_folder_status, wanx_validate_folder_btn, wanx_input] + ) + def toggle_end_image(use_end_image): + return ( + gr.update(visible=use_end_image, interactive=use_end_image), # wanx_input_end + gr.update(visible=False) # wanx_trim_frames + ) + wanx_use_end_image.change( + fn=toggle_end_image, + inputs=[wanx_use_end_image], + outputs=[wanx_input_end, wanx_trim_frames] + ) + # Validate folder button handler + wanx_validate_folder_btn.click( + fn=lambda folder: get_random_image_from_folder(folder)[1], + inputs=[wanx_input_folder], + outputs=[wanx_folder_status] + ) + + # Flow shift recommendation buttons + wanx_recommend_flow_btn.click( + fn=recommend_wanx_flow_shift, + inputs=[wanx_width, wanx_height], + outputs=[wanx_flow_shift] + ) + + wanx_t2v_recommend_flow_btn.click( + fn=recommend_wanx_flow_shift, + inputs=[wanx_t2v_width, wanx_t2v_height], + outputs=[wanx_t2v_flow_shift] + ) + + # Generate button handler + wanx_generate_btn.click( + fn=wanx_batch_handler, + inputs=[ + wanx_use_random_folder, + wanx_prompt, + wanx_negative_prompt, + wanx_width, + wanx_height, + wanx_video_length, + wanx_fps, + wanx_infer_steps, + wanx_flow_shift, + wanx_guidance_scale, + wanx_seed, + wanx_batch_size, + wanx_input_folder, + wanx_input_end, # Make sure this is passed + wanx_task, + wanx_dit_folder, + wanx_dit_path, + wanx_vae_path, + wanx_t5_path, + wanx_clip_path, + wanx_save_path, + wanx_output_type, + wanx_sample_solver, + wanx_exclude_single_blocks, + wanx_attn_mode, + 
wanx_block_swap, + wanx_fp8, + wanx_fp8_scaled, + wanx_fp8_t5, + wanx_lora_folder, + wanx_slg_layers, + wanx_slg_start, + wanx_slg_end, + wanx_enable_cfg_skip, + wanx_cfg_skip_mode, + wanx_cfg_apply_ratio, + # --- ADDED PREVIEW INPUTS --- + wanx_enable_preview, + wanx_preview_steps, + # --- END ADDED --- + *wanx_lora_weights, + *wanx_lora_multipliers, + wanx_input, # Input image (used as input_file in handler) + wanx_control_video, # Control video + wanx_control_strength, + wanx_control_start, + wanx_control_end, + ], + outputs=[ + wanx_output, # Main video gallery + wanx_preview_output, # ADDED: Preview gallery + wanx_batch_progress, # Status text + wanx_progress_text # Progress text + ], # Now 4 outputs + queue=True + ).then( + fn=lambda batch_size: 0 if batch_size == 1 else None, + inputs=[wanx_batch_size], + outputs=wanx_i2v_selected_index + ) + + # Add refresh button handler for WanX-i2v tab + wanx_refresh_outputs = [wanx_dit_path] # Add model dropdown to outputs + for i in range(4): + wanx_refresh_outputs.extend([wanx_lora_weights[i], wanx_lora_multipliers[i]]) + + wanx_refresh_btn.click( + fn=update_dit_and_lora_dropdowns, # This function already exists and handles both updates + inputs=[wanx_dit_folder, wanx_lora_folder, wanx_dit_path] + wanx_lora_weights + wanx_lora_multipliers, + outputs=wanx_refresh_outputs + ) + wanx_dit_folder.change( + fn=update_dit_dropdown, + inputs=[wanx_dit_folder], + outputs=[wanx_dit_path] + ) + + wanx_dit_folder.change( + fn=update_dit_dropdown, + inputs=[wanx_dit_folder], + outputs=[wanx_t2v_dit_path] + ) + + wanx_dit_folder.change( + fn=update_dit_dropdown, + inputs=[wanx_dit_folder], + outputs=[wanx_v2v_dit_path] + ) + + # Gallery selection handling + wanx_output.select( + fn=handle_wanx_gallery_select, + inputs=[wanx_output], + outputs=[wanx_i2v_selected_index, wanx_base_video] + ) + + # Send to Video2Video handler + wanx_send_to_v2v_btn.click( + fn=send_wanx_to_v2v, + inputs=[ + wanx_output, # Gallery with videos + wanx_prompt, # Prompt text + wanx_i2v_selected_index, # Use the correct selected index state + wanx_width, + wanx_height, + wanx_video_length, + wanx_fps, + wanx_infer_steps, + wanx_seed, + wanx_flow_shift, + wanx_guidance_scale, + wanx_negative_prompt + ], + outputs=[ + v2v_input, # Video input in V2V tab + v2v_prompt, # Prompt in V2V tab + v2v_width, + v2v_height, + v2v_video_length, + v2v_fps, + v2v_infer_steps, + v2v_seed, + v2v_flow_shift, + v2v_cfg_scale, + v2v_negative_prompt + ] + ).then( + fn=change_to_tab_two, # Function to switch to Video2Video tab + inputs=None, + outputs=[tabs] + ) + # Connect prompt token counter + wanx_t2v_prompt.change(fn=count_prompt_tokens, inputs=wanx_t2v_prompt, outputs=wanx_t2v_token_counter) + + # Stop button handler + wanx_t2v_stop_btn.click(fn=lambda: stop_event.set(), queue=False) + + # Flow shift recommendation button + wanx_t2v_recommend_flow_btn.click( + fn=recommend_wanx_flow_shift, + inputs=[wanx_t2v_width, wanx_t2v_height], + outputs=[wanx_t2v_flow_shift] + ) + + # Task change handler to update CLIP visibility and path + def update_clip_visibility(task): + is_i2v = "i2v" in task + return gr.update(visible=is_i2v) + + wanx_t2v_task.change( + fn=update_clip_visibility, + inputs=[wanx_t2v_task], + outputs=[wanx_t2v_clip_path] + ) + + # Generate button handler for T2V + wanx_t2v_generate_btn.click( + fn=wanx_batch_handler, + inputs=[ + wanx_t2v_use_random_folder, # use_random + wanx_t2v_prompt, # prompt + wanx_t2v_negative_prompt, # negative_prompt + wanx_t2v_width, # width + wanx_t2v_height, 
# height + wanx_t2v_video_length, # video_length + wanx_t2v_fps, # fps + wanx_t2v_infer_steps, # infer_steps + wanx_t2v_flow_shift, # flow_shift + wanx_t2v_guidance_scale, # guidance_scale + wanx_t2v_seed, # seed + wanx_t2v_batch_size, # batch_size + wanx_t2v_input_folder, # input_folder_path + wanx_t2v_input_end, # wanx_input_end + wanx_t2v_task, # task + wanx_dit_folder, # dit_folder (shared) + wanx_t2v_dit_path, # dit_path + wanx_t2v_vae_path, # vae_path + wanx_t2v_t5_path, # t5_path + wanx_t2v_clip_path, # clip_path (often None for t2v) + wanx_t2v_save_path, # save_path + wanx_t2v_output_type, # output_type + wanx_t2v_sample_solver, # sample_solver + wanx_t2v_exclude_single_blocks, # exclude_single_blocks + wanx_t2v_attn_mode, # attn_mode + wanx_t2v_block_swap, # block_swap + wanx_t2v_fp8, # fp8 + wanx_t2v_fp8_scaled, # fp8_scaled + wanx_t2v_fp8_t5, # fp8_t5 + wanx_t2v_lora_folder, # lora_folder + wanx_t2v_slg_layers, # slg_layers + wanx_t2v_slg_start, # slg_start + wanx_t2v_slg_end, # slg_end + wanx_t2v_enable_cfg_skip, # enable_cfg_skip + wanx_t2v_cfg_skip_mode, # cfg_skip_mode + wanx_t2v_cfg_apply_ratio, # cfg_apply_ratio + # --- ADDED PREVIEW INPUTS --- + wanx_t2v_enable_preview, + wanx_t2v_preview_steps, + # --- END ADDED --- + *wanx_t2v_lora_weights, # *lora_params (weights) + *wanx_t2v_lora_multipliers, # *lora_params (multipliers) + # --- ADDED Placeholders for trailing args expected by wanx_batch_handler --- + gr.File(value=None, visible=False), # Placeholder for input_file (None for T2V) + gr.Video(value=None, visible=False), # Placeholder for control_video (None for T2V) + gr.Number(value=1.0, visible=False), # Placeholder for control_strength + gr.Number(value=0.0, visible=False), # Placeholder for control_start + gr.Number(value=1.0, visible=False), # Placeholder for control_end + # --- END Placeholders --- + ], + outputs=[ + wanx_t2v_output, # Main video gallery + wanx_t2v_preview_output, # ADDED: Preview gallery + wanx_t2v_batch_progress, # Status text + wanx_t2v_progress_text # Progress text + ], # Now 4 outputs + queue=True + ).then( + fn=lambda batch_size: 0 if batch_size == 1 else None, + inputs=[wanx_t2v_batch_size], + outputs=wanx_t2v_selected_index + ) + + # Add refresh button handler for WanX-t2v tab + wanx_t2v_refresh_outputs = [wanx_t2v_dit_path] # This is one output + for i in range(4): + wanx_t2v_refresh_outputs.extend([wanx_t2v_lora_weights[i], wanx_t2v_lora_multipliers[i]]) # This adds 8 more outputs + + wanx_t2v_refresh_btn.click( + fn=update_dit_and_lora_dropdowns, # Change to this function instead + inputs=[wanx_dit_folder, wanx_t2v_lora_folder, wanx_t2v_dit_path] + wanx_t2v_lora_weights + wanx_t2v_lora_multipliers, + outputs=wanx_t2v_refresh_outputs + ) + + # Gallery selection handling + wanx_t2v_output.select( + fn=handle_wanx_t2v_gallery_select, + outputs=wanx_t2v_selected_index + ) + + # Send to Video2Video handler + wanx_t2v_send_to_v2v_btn.click( + fn=send_wanx_t2v_to_v2v, + inputs=[ + wanx_t2v_output, + wanx_t2v_prompt, + wanx_t2v_selected_index, + wanx_t2v_width, + wanx_t2v_height, + wanx_t2v_video_length, + wanx_t2v_fps, + wanx_t2v_infer_steps, + wanx_t2v_seed, + wanx_t2v_flow_shift, + wanx_t2v_guidance_scale, + wanx_t2v_negative_prompt + ], + outputs=[ + v2v_input, + v2v_prompt, + v2v_width, + v2v_height, + v2v_video_length, + v2v_fps, + v2v_infer_steps, + v2v_seed, + v2v_flow_shift, + v2v_cfg_scale, + v2v_negative_prompt + ] + ).then( + fn=change_to_tab_two, + inputs=None, + outputs=[tabs] + ) +if __name__ == "__main__": + # Make sure 
'outputs' directory exists + os.makedirs("outputs", exist_ok=True) + # Optional: Clean temp_frames directory on startup + #if os.path.exists("temp_frames"): + # try: shutil.rmtree("temp_frames") + # except OSError as e: print(f"Error removing temp_frames: {e}") + os.makedirs("temp_frames", exist_ok=True) + +demo.queue().launch(server_name="0.0.0.0", share=False) \ No newline at end of file diff --git a/h1111s.py b/h1111s.py new file mode 100644 index 0000000000000000000000000000000000000000..7dafb5f9e1c0821fc41034239bea45f1dce14807 --- /dev/null +++ b/h1111s.py @@ -0,0 +1,8860 @@ +import gradio as gr +from gradio import update as gr_update +import subprocess +import threading +import time +import re +import os +import random +import tiktoken +import sys +import ffmpeg +from typing import List, Tuple, Optional, Generator, Dict, Any +import json +from gradio import themes +from gradio.themes.utils import colors +import subprocess +from PIL import Image +import math +import cv2 +import glob +import shutil +from pathlib import Path +import logging +from datetime import datetime +from tqdm import tqdm +from diffusers_helper.bucket_tools import find_nearest_bucket +import time +import argparse + + +# Add global stop event +stop_event = threading.Event() +skip_event = threading.Event() +logger = logging.getLogger(__name__) + +def refresh_lora_dropdowns_simple(lora_folder: str) -> List[gr.update]: + """Refreshes LoRA choices, always defaulting the selection to 'None'.""" + new_choices = get_lora_options(lora_folder) + results = [] + print(f"Refreshing LoRA dropdowns. Found choices: {new_choices}") # Debug print + for i in range(4): # Update all 4 slots + results.extend([ + gr.update(choices=new_choices, value="None"), # Always reset value to None + gr.update(value=1.0) # Reset multiplier + ]) + return results + +def process_framepack_extension_video( + input_video: str, + prompt: str, + negative_prompt: str, + seed: int, + batch_count: int, + fpe_use_normal_framepack: bool, + fpe_end_frame: Optional[str], + fpe_end_frame_weight: float, + resolution_max_dim: int, + total_second_length: float, + latent_window_size: int, + steps: int, + cfg_scale: float, # Maps to --cfg + distilled_guidance_scale: float, # Maps to --gs + # rs_scale: float, # --rs, usually 0.0, can be fixed or advanced option + gpu_memory_preservation: float, + use_teacache: bool, + no_resize: bool, + mp4_crf: int, + num_clean_frames: int, + vae_batch_size: int, + save_path: str, # Maps to --output_dir + # Model Paths + fpe_transformer_path: str, # DiT + fpe_vae_path: str, + fpe_text_encoder_path: str, # TE1 + fpe_text_encoder_2_path: str, # TE2 + fpe_image_encoder_path: str, + # Advanced performance + fpe_attn_mode: str, + fpe_fp8_llm: bool, + fpe_vae_chunk_size: Optional[int], + fpe_vae_spatial_tile_sample_min_size: Optional[int], + # LoRAs + fpe_lora_folder: str, + fpe_lora_weight_1: str, fpe_lora_mult_1: float, + fpe_lora_weight_2: str, fpe_lora_mult_2: float, + fpe_lora_weight_3: str, fpe_lora_mult_3: float, + fpe_lora_weight_4: str, fpe_lora_mult_4: float, + # Preview + fpe_enable_preview: bool, + fpe_preview_interval: int, # This arg is not used by f1_video_cli_local.py + fpe_extension_only: bool, + fpe_start_guidance_image: Optional[str], + fpe_start_guidance_image_clip_weight: float, + fpe_use_guidance_image_as_first_latent: bool, + *args: Any # For future expansion or unmapped params, not strictly needed here +) -> Generator[Tuple[List[Tuple[str, str]], Optional[str], str, str], None, None]: + global stop_event, skip_event + 
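+    # Each yield in this function follows the annotated generator contract:
+    # (gallery_items, preview_video_path_or_None, status_text, progress_text),
+    # which lets the Gradio click handler stream gallery, preview, and progress
+    # updates while the external extension script is still running.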
stop_event.clear() + skip_event.clear() # Assuming skip_event might be used for batch items + + if not input_video or not os.path.exists(input_video): + yield [], None, "Error: Input video for extension not found.", "" + return + + if not save_path or not save_path.strip(): + save_path = "outputs/framepack_extensions" # Default save path for extensions + os.makedirs(save_path, exist_ok=True) + + # Prepare LoRA arguments + lora_weights_paths = [] + lora_multipliers_values = [] + lora_params_ui = [ + (fpe_lora_weight_1, fpe_lora_mult_1), (fpe_lora_weight_2, fpe_lora_mult_2), + (fpe_lora_weight_3, fpe_lora_mult_3), (fpe_lora_weight_4, fpe_lora_mult_4) + ] + if fpe_lora_folder and os.path.exists(fpe_lora_folder): + for weight_name, mult_val in lora_params_ui: + if weight_name and weight_name != "None": + lora_path = os.path.join(fpe_lora_folder, weight_name) + if os.path.exists(lora_path): + lora_weights_paths.append(lora_path) + lora_multipliers_values.append(str(mult_val)) + else: + print(f"Warning: LoRA file not found: {lora_path}") + + all_generated_videos = [] + script_to_use = "f_video_end_cli_local.py" if fpe_use_normal_framepack else "f1_video_cli_local.py" + model_type_str = "Normal FramePack" if fpe_use_normal_framepack else "FramePack F1" + print(f"Using {model_type_str} model for extension via script: {script_to_use}") + + for i in range(batch_count): + if stop_event.is_set(): + yield all_generated_videos, None, "Generation stopped by user.", "" + return + skip_event.clear() + + current_seed_val = seed + if seed == -1: + current_seed_val = random.randint(0, 2**32 - 1) + elif batch_count > 1: + current_seed_val = seed + i + + # This run_id is not directly used for preview file naming by f1_video_cli_local.py + # as it constructs its own job_id based filenames for section previews. + # run_id = f"{int(time.time())}_{random.randint(1000, 9999)}_ext_s{current_seed_val}" + + current_preview_yield_path = None + last_preview_section_processed = -1 + + status_text = f"Processing Extension {i + 1}/{batch_count} (Seed: {current_seed_val})" + progress_text = "Preparing extension subprocess..." 
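+            # Yield once before launching the subprocess so the UI shows the "preparing"
+            # status immediately; the command assembled below stringifies every numeric
+            # value because subprocess argv entries must be strings.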
+ yield all_generated_videos, current_preview_yield_path, status_text, progress_text + + command = [ + sys.executable, script_to_use, + "--input_video", str(input_video), + "--prompt", str(prompt), + "--n_prompt", str(negative_prompt), + "--seed", str(current_seed_val), + "--resolution_max_dim", str(resolution_max_dim), + "--total_second_length", str(total_second_length), # Script uses this for *additional* length + "--latent_window_size", str(latent_window_size), + "--steps", str(steps), + "--cfg", str(cfg_scale), + "--gs", str(distilled_guidance_scale), + "--rs", "0.0", + "--gpu_memory_preservation", str(gpu_memory_preservation), + "--mp4_crf", str(mp4_crf), + "--num_clean_frames", str(num_clean_frames), + "--vae_batch_size", str(vae_batch_size), + "--output_dir", str(save_path), + "--dit", str(fpe_transformer_path), "--vae", str(fpe_vae_path), + "--text_encoder1", str(fpe_text_encoder_path), "--text_encoder2", str(fpe_text_encoder_2_path), + "--image_encoder", str(fpe_image_encoder_path), + "--attn_mode", str(fpe_attn_mode), + ] + if use_teacache: command.append("--use_teacache") + if no_resize: command.append("--no_resize") + if fpe_fp8_llm: command.append("--fp8_llm") # Though F1 script might not use this + if fpe_vae_chunk_size is not None and fpe_vae_chunk_size > 0: + command.extend(["--vae_chunk_size", str(fpe_vae_chunk_size)]) + if fpe_vae_spatial_tile_sample_min_size is not None and fpe_vae_spatial_tile_sample_min_size > 0: + command.extend(["--vae_spatial_tile_sample_min_size", str(fpe_vae_spatial_tile_sample_min_size)]) + + if lora_weights_paths: + command.extend(["--lora_weight"] + lora_weights_paths) + command.extend(["--lora_multiplier"] + lora_multipliers_values) + if fpe_extension_only: + command.append("--extension_only") + # Script-specific arguments + if fpe_use_normal_framepack: + if fpe_fp8_llm: # Normal FP script uses this + command.append("--fp8_llm") + if fpe_end_frame and os.path.exists(fpe_end_frame): + command.extend(["--end_frame", str(fpe_end_frame)]) + command.extend(["--end_frame_weight", str(fpe_end_frame_weight)]) + else: + if fpe_start_guidance_image and os.path.exists(fpe_start_guidance_image): + command.extend(["--start_guidance_image", str(fpe_start_guidance_image)]) + command.extend(["--start_guidance_image_clip_weight", str(fpe_start_guidance_image_clip_weight)]) + if fpe_use_guidance_image_as_first_latent: + command.append("--use_guidance_image_as_first_latent") + + env = os.environ.copy() + env["PYTHONUNBUFFERED"] = "1" + print(f"Running FramePack-Extension Command: {' '.join(command)}") + + process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, env=env, bufsize=1, universal_newlines=True) + + # Regex patterns based on script + if fpe_use_normal_framepack: # This means f_video_end_cli_local.py was used + final_video_path_regex = re.compile(r"Final (?:extended video saved:|extension-only video saved:) (.*\.mp4)") + # Regex for "--- Generating Extension: ... Section X / Y (backward) ---" + fpe_section_progress_regex = re.compile(r"--- Generating Extension: .*?: Section\s+(\d+)\s*/\s*(\d+)\s+\(backward\)") + tqdm_cli_progress_regex = re.compile(r"Sampling Extension Section .*?:\s*(\d+)%\|.*?\|\s*(\d+/\d+)\s*\[([^<]+)<([^,]+),") + else: # F1 script (f1_video_cli_local.py) was used + final_video_path_regex = re.compile(r"Final (?:extension-only )?video for seed \d+.*? 
saved as: (.*\.mp4)") + fpe_section_progress_regex = re.compile(r"--- F1 Extension: .*?: Section (\d+)\s*/\s*(\d+) ---") + tqdm_cli_progress_regex = re.compile(r"Sampling Extension Section .*?:\s*(\d+)%\|.*?\|\s*(\d+/\d+)\s*\[([^<]+)<([^,]+),") + fpe_preview_saved_regex = re.compile(r"MP4 Preview for section (\d+) saved: (.*\.mp4)") + + current_video_file_for_item = None + current_section_being_processed = 0 + total_sections_from_log = 0 + + for line in iter(process.stdout.readline, ''): + if stop_event.is_set(): + try: + process.terminate() + process.wait(timeout=5) + except subprocess.TimeoutExpired: + process.kill(); process.wait() + except Exception as e: print(f"Error terminating FPE subprocess: {e}") + yield all_generated_videos, None, "Generation stopped by user.", "" + return + if skip_event.is_set() and batch_count > 1: + print(f"Skip signal received for FPE batch item {i+1}. Terminating subprocess...") + try: + process.terminate() + process.wait(timeout=5) + except subprocess.TimeoutExpired: + process.kill(); process.wait() + except Exception as e: print(f"Error terminating FPE subprocess during skip: {e}") + skip_event.clear() + yield all_generated_videos, current_preview_yield_path, f"Skipping FPE item {i+1}/{batch_count}...", "" + break + + line_strip = line.strip() + if not line_strip: + continue + print(f"FPE_SUBPROCESS: {line_strip}") + + progress_text_update = line_strip + + section_match = fpe_section_progress_regex.search(line_strip) + tqdm_match_cli = tqdm_cli_progress_regex.search(line_strip) + final_video_match = final_video_path_regex.search(line_strip) + preview_saved_match = fpe_preview_saved_regex.search(line_strip) + + if preview_saved_match and fpe_enable_preview: + saved_section_num = int(preview_saved_match.group(1)) + preview_mp4_path_from_log = preview_saved_match.group(2).strip() + if os.path.exists(preview_mp4_path_from_log) and saved_section_num > last_preview_section_processed: + current_preview_yield_path = preview_mp4_path_from_log # Yield clean path + last_preview_section_processed = saved_section_num + print(f"DEBUG FPE: MP4 Preview updated from log - {current_preview_yield_path}") + # This log usually comes *after* the section info, so status might already be updated + + if section_match: + current_section_being_processed = int(section_match.group(1)) + total_sections_from_log = int(section_match.group(2)) + status_text = f"Extending Video {i + 1}/{batch_count} (Seed: {current_seed_val}) - Section {current_section_being_processed}/{total_sections_from_log}" + progress_text_update = f"Starting Section {current_section_being_processed}..." + # Fallback logic for preview (if enabled and explicit log was missed) + # This is less likely to be needed if fpe_preview_saved_regex is robust + if fpe_enable_preview and current_section_being_processed > 1: + section_to_check_for_preview = current_section_being_processed - 1 + if section_to_check_for_preview > last_preview_section_processed: + # Construct the expected preview filename based on f1_video_cli_local.py's naming + # It uses a job_id that includes seed, resolution, etc. We don't know the exact job_id here. + # Relying on "MP4 Preview for section X saved:" log is more reliable. + # For a fallback, we could glob for *partX*.mp4, but that's risky. + # For now, this fallback is removed as the primary log line should be sufficient. 
+ pass + + + elif tqdm_match_cli: + percentage = tqdm_match_cli.group(1) + steps_iter_total = tqdm_match_cli.group(2) + time_elapsed = tqdm_match_cli.group(3).strip() + time_remaining = tqdm_match_cli.group(4).strip() + # Ensure total_sections_from_log is not zero before using in f-string + total_sections_display = total_sections_from_log if total_sections_from_log > 0 else "?" + progress_text_update = f"Section {current_section_being_processed}/{total_sections_display} - Step {steps_iter_total} ({percentage}%) | ETA: {time_remaining}" + status_text = f"Extending Video {i + 1}/{batch_count} (Seed: {current_seed_val}) - Sampling Section {current_section_being_processed}" + + elif final_video_match: + found_video_path = final_video_match.group(1).strip() + if os.path.exists(found_video_path): + current_video_file_for_item = found_video_path + progress_text_update = f"Finalizing: {os.path.basename(current_video_file_for_item)}" + status_text = f"Extension {i + 1}/{batch_count} (Seed: {current_seed_val}) - Saved" + else: + print(f"Warning FPE: Final video path from log not found: {found_video_path}") + + yield all_generated_videos, current_preview_yield_path, status_text, progress_text_update + + process.stdout.close() + return_code = process.wait() + + if return_code == 0 and current_video_file_for_item and os.path.exists(current_video_file_for_item): + all_generated_videos.append((current_video_file_for_item, f"Extended - Seed: {current_seed_val}")) + status_text = f"Extension {i + 1}/{batch_count} (Seed: {current_seed_val}) - Completed and Added" + progress_text = f"Saved: {os.path.basename(current_video_file_for_item)}" + yield all_generated_videos.copy(), None, status_text, progress_text # Clear preview after item completion + elif return_code != 0: + status_text = f"Extension {i + 1}/{batch_count} (Seed: {current_seed_val}) - Failed (Code: {return_code})" + progress_text = f"Subprocess failed. Check console for errors from f1_video_cli_local.py" + yield all_generated_videos.copy(), None, status_text, progress_text + else: # rc == 0 but no video path + status_text = f"Extension {i + 1}/{batch_count} (Seed: {current_seed_val}) - Finished, but no video file confirmed." + progress_text = "Check console logs from f1_video_cli_local.py for the saved path." + yield all_generated_videos.copy(), None, status_text, progress_text + + # The F1 script already cleans up its intermediate _partX files. + # No need for unique_preview_suffix based cleanup here for FPE. + + yield all_generated_videos, None, "FramePack-Extension Batch complete.", "" + +def set_random_seed(): + """Returns -1 to set the seed input to random.""" + return -1 + +def get_step_from_preview_path(path): # Helper function + # Extracts step number from preview filenames like latent_preview_step_005.mp4 + # or for framepack: latent_preview_section_002.mp4 (assuming sections for framepack) + # Let's adjust for potential FramePack naming convention (using 'section' instead of 'step') + base = os.path.basename(path) + match_step = re.search(r"step_(\d+)", base) + if match_step: + return int(match_step.group(1)) + match_section = re.search(r"section_(\d+)", base) # Check for FramePack section naming + if match_section: + # Maybe treat sections differently? Or just return the number? Let's return number. 
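+        # e.g. "latent_preview_section_002.mp4" -> 2, mirroring the step_(\d+) branch above.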
+ return int(match_section.group(1)) + return -1 # Default if no number found + +def process_framepack_video( + prompt: str, + negative_prompt: str, + input_image: str, # Start image path + input_end_frame: Optional[str], # End image path + end_frame_influence: str, + end_frame_weight: float, + transformer_path: str, + vae_path: str, + text_encoder_path: str, + text_encoder_2_path: str, + image_encoder_path: str, + target_resolution: Optional[int], + framepack_width: Optional[int], + framepack_height: Optional[int], + original_dims_str: str, # This comes from framepack_original_dims state + total_second_length: float, + framepack_video_sections: Optional[int], + fps: int, + seed: int, + steps: int, + distilled_guidance_scale: float, + cfg: float, + rs: float, + sample_solver: str, + latent_window_size: int, + fp8: bool, + fp8_scaled: bool, + fp8_llm: bool, + blocks_to_swap: int, + bulk_decode: bool, + attn_mode: str, + vae_chunk_size: Optional[int], + vae_spatial_tile_sample_min_size: Optional[int], + device: Optional[str], + use_teacache: bool, + teacache_steps: int, + teacache_thresh: float, + batch_size: int, + save_path: str, + lora_folder: str, + enable_preview: bool, + preview_every_n_sections: int, + is_f1: bool, + use_random_folder: bool, + input_folder_path: str, + *args: Any +) -> Generator[Tuple[List[Tuple[str, str]], Optional[str], str, str], None, None]: + """Generate video using fpack_generate_video.py""" + global stop_event + stop_event.clear() + + if not save_path or not save_path.strip(): + print("Warning: save_path was empty, defaulting to 'outputs'") + save_path = "outputs" + + num_section_controls = 4 + num_loras = 4 + secs_end = num_section_controls + prompts_end = secs_end + num_section_controls + images_end = prompts_end + num_section_controls + lora_weights_end = images_end + num_loras + lora_mults_end = lora_weights_end + num_loras + + framepack_secs = args[0:secs_end] + framepack_sec_prompts = args[secs_end:prompts_end] + framepack_sec_images = args[prompts_end:images_end] + lora_weights_list = list(args[images_end:lora_weights_end]) + lora_multipliers_list = list(args[lora_weights_end:lora_mults_end]) + + if not use_random_folder and not input_image and not any(img for img in framepack_sec_images if img): + yield [], None, "Error: Input start image or at least one section image override is required when not using folder mode.", "" + return + + if use_random_folder and (not input_folder_path or not os.path.isdir(input_folder_path)): + yield [], None, f"Error: Random image folder path '{input_folder_path}' is invalid or not a directory.", "" + return + + section_prompts_parts = [] + section_images_parts = [] + index_pattern = re.compile(r"^\d+(-\d+)?$") + + for idx_str, sec_prompt, sec_image in zip(framepack_secs, framepack_sec_prompts, framepack_sec_images): + if not idx_str or not isinstance(idx_str, str) or not index_pattern.match(idx_str.strip()): + if idx_str and idx_str.strip(): + print(f"Warning: Invalid section index/range format '{idx_str}'. 
Skipping.") + continue + current_idx_str = idx_str.strip() + if sec_prompt and sec_prompt.strip(): + section_prompts_parts.append(f"{current_idx_str}:{sec_prompt.strip()}") + if sec_image and os.path.exists(sec_image): + section_images_parts.append(f"{current_idx_str}:{sec_image}") + + final_prompt_arg = prompt + if section_prompts_parts: + final_prompt_arg = ";;;".join(section_prompts_parts) + print(f"Using section prompt overrides: {final_prompt_arg}") + + final_image_path_arg = None + if section_images_parts: + final_image_path_arg = ";;;".join(section_images_parts) + print(f"Using section image overrides for --image_path: {final_image_path_arg}") + elif input_image: + final_image_path_arg = input_image + print(f"Using base input image for --image_path: {final_image_path_arg}") + + # These are batch-wide defaults if not overridden by folder mode + target res per item. + batch_wide_final_height, batch_wide_final_width = None, None + + if framepack_width is not None and framepack_width > 0 and framepack_height is not None and framepack_height > 0: + if framepack_width % 8 != 0 or framepack_height % 8 != 0: + yield [], None, "Error: Explicit Width and Height must be divisible by 8.", "" + return + batch_wide_final_height = int(framepack_height) + batch_wide_final_width = int(framepack_width) + print(f"Using explicit dimensions for all items: H={batch_wide_final_height}, W={batch_wide_final_width}") + elif target_resolution is not None and target_resolution > 0 and not use_random_folder: + # This case applies if: + # 1. Target resolution is set. + # 2. We are NOT in random folder mode (so aspect ratio from UI image is reliable). + if not original_dims_str: # original_dims_str comes from the UI input image + yield [], None, "Error: Target Resolution selected (not in folder mode), but no UI input image provided for aspect ratio.", "" + return + try: + orig_w, orig_h = map(int, original_dims_str.split('x')) + if orig_w <= 0 or orig_h <= 0: + yield [], None, "Error: Invalid original dimensions stored from UI image.", "" + return + bucket_dims = find_nearest_bucket(orig_h, orig_w, resolution=target_resolution) + if bucket_dims: + batch_wide_final_height, batch_wide_final_width = bucket_dims + print(f"Using Target Resolution {target_resolution} with UI image aspect. Batch-wide bucket: H={batch_wide_final_height}, W={batch_wide_final_width}") + else: + yield [], None, f"Error: Could not find bucket for Target Res {target_resolution} and UI image aspect.", "" + return + except Exception as e: + yield [], None, f"Error calculating bucket dimensions from UI image: {e}", "" + return + elif use_random_folder and target_resolution is not None and target_resolution > 0: + # Folder mode with target resolution: resolution will be determined per item. + # batch_wide_final_height and batch_wide_final_width remain None. + print(f"Folder mode with Target Resolution {target_resolution}. Resolution will be determined per item.") + elif not (framepack_width is not None and framepack_width > 0 and framepack_height is not None and framepack_height > 0) and \ + not (target_resolution is not None and target_resolution > 0): + # This is the fallback if no resolution strategy is active for the batch. + yield [], None, "Error: Resolution required. 
Please provide Target Resolution OR valid Width and Height (divisible by 8).", "" + return + + all_videos = [] + if framepack_video_sections is not None and framepack_video_sections > 0: + total_sections_estimate = framepack_video_sections + print(f"Using user-defined total sections for UI: {total_sections_estimate}") + else: + total_sections_estimate_float = (total_second_length * fps) / (latent_window_size * 4) + total_sections_estimate = int(max(round(total_sections_estimate_float), 1)) + print(f"Calculated total sections for UI from duration: {total_sections_estimate}") + progress_text = f"Starting FramePack generation batch ({total_sections_estimate} estimated sections per video)..." + status_text = "Preparing batch..." + yield all_videos, None, status_text, progress_text + + valid_loras_paths = [] + valid_loras_mults = [] + if lora_folder and os.path.exists(lora_folder): + for weight_name, mult in zip(lora_weights_list, lora_multipliers_list): + if weight_name and weight_name != "None": + if os.path.isabs(weight_name): + lora_path = weight_name + else: + lora_path = os.path.join(lora_folder, weight_name) + if os.path.exists(lora_path): + valid_loras_paths.append(lora_path) + valid_loras_mults.append(str(mult)) + else: + print(f"Warning: LoRA file not found: {lora_path}") + + for i in range(batch_size): # <<< START OF THE BATCH LOOP >>> + if stop_event.is_set(): + yield all_videos, None, "Generation stopped by user.", "" + return + skip_event.clear() + + last_preview_mtime = 0 + + run_id = f"{int(time.time())}_{random.randint(1000, 9999)}" + unique_preview_suffix = f"fpack_{run_id}" + preview_base_path = os.path.join(save_path, f"latent_preview_{unique_preview_suffix}") + preview_mp4_path = preview_base_path + ".mp4" + preview_png_path = preview_base_path + ".png" + + current_seed = seed + if seed == -1: current_seed = random.randint(0, 2**32 - 1) + elif batch_size > 1: current_seed = seed + i + + status_text = f"Generating video {i + 1} of {batch_size} (Seed: {current_seed})" + progress_text_update = f"Item {i+1}/{batch_size}: Preparing..." # Renamed progress_text to progress_text_update for clarity + current_video_path = None + current_preview_yield_path = None + current_input_image_for_item = input_image + current_original_dims_str_for_item = original_dims_str # Use batch-wide original_dims_str initially + + if use_random_folder: + progress_text_update = f"Item {i+1}/{batch_size}: Selecting random image..." + yield all_videos.copy(), current_preview_yield_path, status_text, progress_text_update + + random_image_path, random_status = get_random_image_from_folder(input_folder_path) + if random_image_path is None: + error_msg = f"Error for item {i+1}/{batch_size}: {random_status}. Skipping." + print(error_msg) + yield all_videos.copy(), None, status_text, error_msg + continue + + current_input_image_for_item = random_image_path + progress_text_update = f"Item {i+1}/{batch_size}: Using random image: {os.path.basename(random_image_path)}" + print(progress_text_update) + yield all_videos.copy(), current_preview_yield_path, status_text, progress_text_update + + # Derive original_dims_str_for_item from the random image if using target resolution + # and explicit UI W/H were not provided. 
+ if target_resolution is not None and target_resolution > 0 and \ + not (framepack_width is not None and framepack_width > 0 and framepack_height is not None and framepack_height > 0): + try: + img_for_dims = Image.open(random_image_path) + rand_w, rand_h = img_for_dims.size + current_original_dims_str_for_item = f"{rand_w}x{rand_h}" + print(f"Folder mode item {i+1}: Using random image dims {current_original_dims_str_for_item} for target resolution bucketing.") + except Exception as e: + error_msg = f"Error getting dims for random image {random_image_path}: {e}. Skipping item {i+1}." + print(error_msg) + yield all_videos.copy(), None, status_text, error_msg + continue + + final_image_path_arg_for_item = None + if section_images_parts: + final_image_path_arg_for_item = ";;;".join(section_images_parts) + if current_input_image_for_item: + has_section_0_override = any(part.strip().startswith("0:") for part in section_images_parts) + if not has_section_0_override: + final_image_path_arg_for_item = f"0:{current_input_image_for_item};;;{final_image_path_arg_for_item}" + print(f"Using section image overrides (potentially with prepended base) for --image_path (item {i+1}): {final_image_path_arg_for_item}") + elif current_input_image_for_item: + final_image_path_arg_for_item = current_input_image_for_item + print(f"Using {'random' if use_random_folder else 'base'} input image as the primary for --image_path (item {i+1}): {final_image_path_arg_for_item}") + + if final_image_path_arg_for_item is None: + yield [], None, f"Error for item {i+1}: No valid start image could be determined. Ensure an image is provided.", "" + continue + + final_height_for_item, final_width_for_item = None, None + + # 1. Use batch-wide dimensions if they were set (from explicit UI W/H or target_res + UI image) + if batch_wide_final_height is not None and batch_wide_final_width is not None: + final_height_for_item = batch_wide_final_height + final_width_for_item = batch_wide_final_width + print(f"Item {i+1}: Using batch-wide dimensions: H={final_height_for_item}, W={final_width_for_item}") + # 2. Else, if using target resolution (this implies folder mode, as other cases were handled above) + elif target_resolution is not None and target_resolution > 0: + if not current_original_dims_str_for_item: # This should now be populated for folder mode + yield [], None, f"Error for item {i+1}: Target Resolution selected, but no original dimensions available for aspect ratio.", "" + continue + try: + orig_w_item, orig_h_item = map(int, current_original_dims_str_for_item.split('x')) + if orig_w_item <= 0 or orig_h_item <= 0: + yield [], None, f"Error for item {i+1}: Invalid original dimensions '{current_original_dims_str_for_item}'.", "" + continue + bucket_dims_item = find_nearest_bucket(orig_h_item, orig_w_item, resolution=target_resolution) + if bucket_dims_item: + final_height_for_item, final_width_for_item = bucket_dims_item + print(f"Item {i+1}: Using Target Resolution {target_resolution} with item-specific aspect from '{current_original_dims_str_for_item}'. 
Bucket: H={final_height_for_item}, W={final_width_for_item}") + else: + yield [], None, f"Error for item {i+1}: Could not find bucket for Target Res {target_resolution} and aspect {current_original_dims_str_for_item}.", "" + continue + except Exception as e_res: + yield [], None, f"Error calculating bucket dimensions for item {i+1} ({current_original_dims_str_for_item}): {e_res}", "" + continue + else: + # This case should ideally not be hit if the initial batch-wide resolution checks were thorough. + # It implies no explicit W/H, no target_res, or some other unhandled state. + yield [], None, f"Error for item {i+1}: Failed to determine resolution strategy for the item.", "" + continue # Skip this item + + if final_height_for_item is None or final_width_for_item is None: # Final check for the item + yield [], None, f"Error for item {i+1}: Final resolution could not be determined for this item.", "" + continue + + # Update status text with the preparing subprocess message + yield all_videos.copy(), current_preview_yield_path, status_text, progress_text_update # Use progress_text_update + + env = os.environ.copy() + env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") + env["PYTHONIOENCODING"] = "utf-8" + clear_cuda_cache() + + command = [ + sys.executable, "fpack_generate_video.py", + "--text_encoder1", text_encoder_path, "--text_encoder2", text_encoder_2_path, + "--image_encoder", image_encoder_path, + *(["--image_path", final_image_path_arg_for_item] if final_image_path_arg_for_item else []), + "--save_path", save_path, "--prompt", final_prompt_arg, + "--video_size", str(final_height_for_item), str(final_width_for_item), + *(["--video_sections", str(framepack_video_sections)] if framepack_video_sections is not None and framepack_video_sections > 0 else ["--video_seconds", str(total_second_length)]), + "--infer_steps", str(steps), "--seed", str(current_seed), + "--embedded_cfg_scale", str(distilled_guidance_scale), + "--guidance_scale", str(cfg), "--guidance_rescale", str(rs), + "--latent_window_size", str(latent_window_size), + "--sample_solver", sample_solver, "--output_type", "video", "--attn_mode", attn_mode + ] + if is_f1: command.append("--is_f1") + if transformer_path and os.path.exists(transformer_path): command.extend(["--dit", transformer_path.strip()]) + if vae_path and os.path.exists(vae_path): command.extend(["--vae", vae_path.strip()]) + if negative_prompt and negative_prompt.strip(): command.extend(["--negative_prompt", negative_prompt.strip()]) + if input_end_frame and os.path.exists(input_end_frame): command.extend(["--end_image_path", input_end_frame]) + if fp8: command.append("--fp8") + if fp8 and fp8_scaled: command.append("--fp8_scaled") + if fp8_llm: command.append("--fp8_llm") + if bulk_decode: command.append("--bulk_decode") + if blocks_to_swap > 0: command.extend(["--blocks_to_swap", str(blocks_to_swap)]) + if vae_chunk_size is not None and vae_chunk_size > 0: command.extend(["--vae_chunk_size", str(vae_chunk_size)]) + if vae_spatial_tile_sample_min_size is not None and vae_spatial_tile_sample_min_size > 0: command.extend(["--vae_spatial_tile_sample_min_size", str(vae_spatial_tile_sample_min_size)]) + if device and device.strip(): command.extend(["--device", device.strip()]) + if valid_loras_paths: + command.extend(["--lora_weight"] + valid_loras_paths) + command.extend(["--lora_multiplier"] + valid_loras_mults) + if enable_preview and preview_every_n_sections > 0: + command.extend(["--preview_latent_every", 
str(preview_every_n_sections)]) + command.extend(["--preview_suffix", unique_preview_suffix]) + print(f"DEBUG: Enabling preview every {preview_every_n_sections} sections with suffix {unique_preview_suffix}.") + if use_teacache: + command.append("--use_teacache") + command.extend(["--teacache_steps", str(teacache_steps)]) + command.extend(["--teacache_thresh", str(teacache_thresh)]) + + command_str = [str(c) for c in command] + print(f"Running FramePack Command: {' '.join(command_str)}") + + p = subprocess.Popen( + command_str, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, + env=env, text=True, encoding='utf-8', errors='replace', bufsize=1 + ) + current_phase = "Preparing" + actual_total_sections = None + display_section_num = 1 + + while True: + if stop_event.is_set(): + try: + p.terminate() + p.wait(timeout=5) + except subprocess.TimeoutExpired: + p.kill(); p.wait() + except Exception as e: + print(f"Error terminating subprocess: {e}") + yield all_videos.copy(), None, "Generation stopped by user.", "" + return + if skip_event.is_set(): + print(f"Skip signal received for batch item {i+1}. Terminating subprocess...") + try: + p.terminate() + p.wait(timeout=5) + except subprocess.TimeoutExpired: + p.kill(); p.wait() + except Exception as e: + print(f"Error terminating subprocess during skip: {e}") + skip_event.clear() + yield all_videos.copy(), current_preview_yield_path, f"Skipping item {i+1}/{batch_size}...", "" + break + + line = p.stdout.readline() + if not line: + if p.poll() is not None: break + time.sleep(0.01); continue + + line = line.strip() + if not line: continue + print(f"SUBPROCESS: {line}") + + section_match = re.search(r"---.*?Section\s+(\d+)\s*/\s*(\d+)(?:\s+|$|\()", line) + tqdm_match = re.search(r'(\d+)\%\|.+\| (\d+)/(\d+) \[(\d{2}:\d{2})<(\d{2}:\d{2})', line) + phase_changed = False # Initialize phase_changed inside the loop + + # Default progress_text_update to the current line for general logging + progress_text_update = line # This was defined outside the loop before, moved inside + + if section_match: + current_section_num_display = int(section_match.group(1)) + total_sections_from_log = int(section_match.group(2)) + display_section_num = current_section_num_display + if actual_total_sections != total_sections_from_log: + actual_total_sections = total_sections_from_log + print(f"Detected/Updated actual total sections: {actual_total_sections}") + new_phase = f"Generating Section {display_section_num}" + if current_phase != new_phase: + current_phase = new_phase + phase_changed = True + progress_text_update = f"Item {i+1}/{batch_size} | Section {display_section_num}/{actual_total_sections} | Preparing..." 
+ status_text = f"Generating video {i + 1} of {batch_size} (Seed: {current_seed}) - {current_phase}" + elif tqdm_match: + percentage = int(tqdm_match.group(1)) + current_step = int(tqdm_match.group(2)) + total_steps = int(tqdm_match.group(3)) + time_elapsed = tqdm_match.group(4) + time_remaining = tqdm_match.group(5) + current_total_for_display = actual_total_sections if actual_total_sections is not None else total_sections_estimate + section_str = f"Section {display_section_num}/{current_total_for_display}" + progress_text_update = f"Item {i+1}/{batch_size} | {section_str} | Step {current_step}/{total_steps} ({percentage}%) | Elapsed: {time_elapsed}, Remaining: {time_remaining}" + denoising_phase = f"Denoising Section {display_section_num}" + if current_phase != denoising_phase: + current_phase = denoising_phase + phase_changed = True + status_text = f"Generating video {i + 1} of {batch_size} (Seed: {current_seed}) - {current_phase}" + elif "Decoding video..." in line: + if current_phase != "Decoding Video": + current_phase = "Decoding Video" + phase_changed = True + progress_text_update = f"Item {i+1}/{batch_size} | {current_phase}..." + status_text = f"Generating video {i + 1} of {batch_size} (Seed: {current_seed}) - {current_phase}" + elif "INFO:__main__:Video saved to:" in line: + match = re.search(r"Video saved to:\s*(.*\.mp4)", line) + if match: + found_video_path = match.group(1).strip() + if os.path.exists(found_video_path): + current_video_path = found_video_path + # Don't add to all_videos here, add after subprocess completion + else: + print(f"Warning: Parsed video path does not exist: {found_video_path}") + status_text = f"Video {i+1}/{batch_size} Saved (Seed: {current_seed})" + progress_text_update = f"Saved: {os.path.basename(found_video_path) if found_video_path else 'Unknown Path'}" + current_phase = "Saved" + phase_changed = True + else: + print(f"Warning: Could not parse video path from INFO line: {line}") + elif "ERROR" in line.upper() or "TRACEBACK" in line.upper(): + status_text = f"Item {i+1}/{batch_size}: Error Detected (Check Console)" + progress_text_update = line + if current_phase != "Error": + current_phase = "Error" + phase_changed = True + elif phase_changed and current_phase not in ["Saved", "Error"]: + status_text = f"Generating video {i + 1} of {batch_size} (Seed: {current_seed}) - {current_phase}" + + preview_updated = False + current_mtime_check = 0 # Renamed from current_mtime to avoid conflict + found_preview_path_check = None # Renamed + + if enable_preview: + if os.path.exists(preview_mp4_path): + current_mtime_check = os.path.getmtime(preview_mp4_path) + found_preview_path_check = preview_mp4_path + elif os.path.exists(preview_png_path): + current_mtime_check = os.path.getmtime(preview_png_path) + found_preview_path_check = preview_png_path + + if found_preview_path_check and current_mtime_check > last_preview_mtime: + print(f"DEBUG: Preview file updated: {found_preview_path_check} (mtime: {current_mtime_check})") + current_preview_yield_path = found_preview_path_check + last_preview_mtime = current_mtime_check + preview_updated = True + + yield all_videos.copy(), current_preview_yield_path, status_text, progress_text_update + + p.stdout.close(); rc = p.wait() + clear_cuda_cache(); time.sleep(0.1) + + if rc == 0 and current_video_path and os.path.exists(current_video_path): + all_videos.append((current_video_path, f"Seed: {current_seed}")) # Add video here + parameters = { + "prompt": prompt, "negative_prompt": negative_prompt, + "input_image": 
os.path.basename(current_input_image_for_item) if current_input_image_for_item else None, + "section_controls": [ + {"index": s, "prompt_override": p_override, "image_override": os.path.basename(img_override) if img_override else None} + for s, p_override, img_override in zip(framepack_secs, framepack_sec_prompts, framepack_sec_images) + if (p_override and p_override.strip()) or img_override + ], + "final_prompt_arg": final_prompt_arg, + "final_image_path_arg": final_image_path_arg_for_item, # Use item-specific image path + "input_end_frame": os.path.basename(input_end_frame) if input_end_frame else None, + "transformer_path": transformer_path, "vae_path": vae_path, + "text_encoder_path": text_encoder_path, "text_encoder_2_path": text_encoder_2_path, + "image_encoder_path": image_encoder_path, + "video_width": final_width_for_item, "video_height": final_height_for_item, + "video_seconds": total_second_length, "fps": fps, "seed": current_seed, + "infer_steps": steps, "embedded_cfg_scale": distilled_guidance_scale, + "guidance_scale": cfg, "guidance_rescale": rs, "sample_solver": sample_solver, + "latent_window_size": latent_window_size, + "fp8": fp8, "fp8_scaled": fp8_scaled, "fp8_llm": fp8_llm, + "blocks_to_swap": blocks_to_swap, "bulk_decode": bulk_decode, "attn_mode": attn_mode, + "vae_chunk_size": vae_chunk_size, "vae_spatial_tile_sample_min_size": vae_spatial_tile_sample_min_size, + "device": device, + "lora_weights": [os.path.basename(p) for p in valid_loras_paths], + "lora_multipliers": [float(m) for m in valid_loras_mults], + "original_dims_str": current_original_dims_str_for_item, + "target_resolution": target_resolution, + "is_f1": is_f1 + } + try: + add_metadata_to_video(current_video_path, parameters) + print(f"Added metadata to {current_video_path}") + except Exception as meta_err: + print(f"Warning: Failed to add metadata to {current_video_path}: {meta_err}") + status_text = f"Item {i+1}/{batch_size} Completed (Seed: {current_seed})" + progress_text_update = f"Video saved: {os.path.basename(current_video_path)}" + current_preview_yield_path = None # Clear preview for next item + yield all_videos.copy(), current_preview_yield_path, status_text, progress_text_update + elif rc != 0: + status_text = f"Item {i+1}/{batch_size} Failed (Seed: {current_seed}, Code: {rc})" + progress_text_update = f"Subprocess failed. Check console logs." + current_preview_yield_path = None # Clear preview + yield all_videos.copy(), current_preview_yield_path, status_text, progress_text_update + else: + status_text = f"Item {i+1}/{batch_size} Finished (Seed: {current_seed}), but no video file confirmed." + progress_text_update = "Check console logs for the saved path." 
+ current_preview_yield_path = None # Clear preview + yield all_videos.copy(), current_preview_yield_path, status_text, progress_text_update + + # Cleanup preview files for the completed item to avoid them being picked up by next item + if enable_preview: + for prev_file in [preview_mp4_path, preview_png_path]: + if os.path.exists(prev_file): + try: + os.remove(prev_file) + print(f"Cleaned up preview file: {prev_file}") + except Exception as e_clean: + print(f"Warning: Could not remove preview file {prev_file}: {e_clean}") + + time.sleep(0.2) + + yield all_videos, None, "FramePack Batch complete", "" + +def calculate_framepack_width(height, original_dims): + """Calculate FramePack width based on height maintaining aspect ratio (divisible by 32)""" + if not original_dims or height is None: + return gr.update() + try: + # Ensure height is an integer and divisible by 32 + height = int(height) + if height <= 0 : return gr.update() + height = (height // 32) * 32 # <-- Use 32 + height = max(64, height) # Min height (64 is divisible by 32) + + orig_w, orig_h = map(int, original_dims.split('x')) + if orig_h == 0: return gr.update() + aspect_ratio = orig_w / orig_h + # Calculate new width, rounding to the nearest multiple of 32 + new_width = round((height * aspect_ratio) / 32) * 32 # <-- Round and use 32 + return gr.update(value=max(64, new_width)) # Ensure minimum size (also divisible by 32) + + except Exception as e: + print(f"Error calculating width: {e}") + return gr.update() + +def calculate_framepack_height(width, original_dims): + """Calculate FramePack height based on width maintaining aspect ratio (divisible by 32)""" + if not original_dims or width is None: + return gr.update() + try: + # Ensure width is an integer and divisible by 32 + width = int(width) + if width <= 0: return gr.update() + width = (width // 32) * 32 # <-- Use 32 + width = max(64, width) # Min width (64 is divisible by 32) + + orig_w, orig_h = map(int, original_dims.split('x')) + if orig_w == 0: return gr.update() + aspect_ratio = orig_w / orig_h + # Calculate new height, rounding to the nearest multiple of 32 + new_height = round((width / aspect_ratio) / 32) * 32 # <-- Round and use 32 + return gr.update(value=max(64, new_height)) # Ensure minimum size (also divisible by 32) + except Exception as e: + print(f"Error calculating height: {e}") + return gr.update() + +def update_framepack_from_scale(scale, original_dims): + """Update FramePack dimensions based on scale percentage (divisible by 32)""" + if not original_dims: + return gr.update(), gr.update(), gr.update() + try: + scale = float(scale) if scale is not None else 100.0 + if scale <= 0: scale = 100.0 + + orig_w, orig_h = map(int, original_dims.split('x')) + scale_factor = scale / 100.0 + + # Calculate and round to the nearest multiple of 32 + new_w = round((orig_w * scale_factor) / 32) * 32 # <-- Round and use 32 + new_h = round((orig_h * scale_factor) / 32) * 32 # <-- Round and use 32 + + # Ensure minimum size (must be multiple of 32) + new_w = max(64, new_w) # 64 is divisible by 32 + new_h = max(64, new_h) + + # Clear target resolution if using scale slider for explicit dims + return gr.update(value=new_w), gr.update(value=new_h), gr.update(value=None) + except Exception as e: + print(f"Error updating from scale: {e}") + return gr.update(), gr.update(), gr.update() + +def process_i2v_single_video( + prompt: str, + image_path: str, + width: int, + height: int, + batch_size: int, + video_length: int, + fps: int, + infer_steps: int, + seed: int, + dit_folder: 
str, + model: str, + vae: str, + te1: str, + te2: str, + clip_vision_path: str, + save_path: str, + flow_shift: float, + cfg_scale: float, # embedded_cfg_scale + guidance_scale: float, # main CFG + output_type: str, + attn_mode: str, + block_swap: int, + exclude_single_blocks: bool, + use_split_attn: bool, + lora_folder: str, + vae_chunk_size: int, + vae_spatial_tile_min: int, + # --- Explicit LoRA args instead of *lora_params --- + lora1: str = "None", + lora2: str = "None", + lora3: str = "None", + lora4: str = "None", + lora1_multiplier: float = 1.0, + lora2_multiplier: float = 1.0, + lora3_multiplier: float = 1.0, + lora4_multiplier: float = 1.0, + # --- End LoRA args --- + negative_prompt: Optional[str] = None, + use_fp8: bool = False, + fp8_llm: bool = False +) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: + """Generate a single video using hv_i2v_generate_video.py""" + global stop_event + + # ... (Keep existing argument validation and env setup) ... + if stop_event.is_set(): + yield [], "", "" + return + + # Argument validation + if not image_path or not os.path.exists(image_path): + yield [], "Error: Input image not found", f"Cannot find image: {image_path}" + return + # Check clip vision path only if needed (Hunyuan-I2V, not SkyReels-I2V based on script name) + is_hunyuan_i2v = "mp_rank_00_model_states_i2v" in model # Heuristic check + if is_hunyuan_i2v and (not clip_vision_path or not os.path.exists(clip_vision_path)): + yield [], "Error: CLIP Vision model not found", f"Cannot find file: {clip_vision_path}" + return + + if os.path.isabs(model): + model_path = model + else: + model_path = os.path.normpath(os.path.join(dit_folder, model)) + + env = os.environ.copy() + env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") + env["PYTHONIOENCODING"] = "utf-8" + + if seed == -1: + current_seed = random.randint(0, 2**32 - 1) + else: + current_seed = seed + + clear_cuda_cache() + + command = [ + sys.executable, + "hv_i2v_generate_video.py", # <<< Use the new script + "--dit", model_path, + "--vae", vae, + "--text_encoder1", te1, + "--text_encoder2", te2, + # Add clip vision path only if it's likely the Hunyuan I2V model + *(["--clip_vision_path", clip_vision_path] if is_hunyuan_i2v else []), + "--prompt", prompt, + "--video_size", str(height), str(width), + "--video_length", str(video_length), + "--fps", str(fps), + "--infer_steps", str(infer_steps), + "--save_path", save_path, + "--seed", str(current_seed), + "--flow_shift", str(flow_shift), + "--embedded_cfg_scale", str(cfg_scale), + "--guidance_scale", str(guidance_scale), + "--output_type", output_type, + "--attn_mode", attn_mode, + "--blocks_to_swap", str(block_swap), + "--image_path", image_path + ] + + if negative_prompt: + command.extend(["--negative_prompt", negative_prompt]) + + if use_fp8: + command.append("--fp8") + if fp8_llm: + command.append("--fp8_llm") + + if exclude_single_blocks: + command.append("--exclude_single_blocks") + if use_split_attn: + command.append("--split_attn") + + if vae_chunk_size > 0: + command.extend(["--vae_chunk_size", str(vae_chunk_size)]) + if vae_spatial_tile_min > 0: + command.extend(["--vae_spatial_tile_sample_min_size", str(vae_spatial_tile_min)]) + + # --- Updated LoRA handling using named arguments --- + lora_weights_list = [lora1, lora2, lora3, lora4] + lora_multipliers_list = [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier] + valid_loras = [] + for weight, mult in zip(lora_weights_list, lora_multipliers_list): + if 
weight and weight != "None":
+                lora_file_path = os.path.join(lora_folder, weight)
+                if os.path.exists(lora_file_path):
+                    valid_loras.append((lora_file_path, mult))
+                else:
+                    print(f"Warning: LoRA file not found: {lora_file_path}")
+
+        if valid_loras:
+            weights = [weight for weight, _ in valid_loras]
+            multipliers = [str(mult) for _, mult in valid_loras]
+            command.extend(["--lora_weight"] + weights)
+            command.extend(["--lora_multiplier"] + multipliers)
+    # --- End Updated LoRA handling ---
+
+    command_str = [str(c) for c in command] # Ensure all args are strings
+    print(f"Running Command (I2V): {' '.join(command_str)}")
+
+    p = subprocess.Popen(
+        command_str, # Use stringified command
+        stdout=subprocess.PIPE,
+        stderr=subprocess.STDOUT,
+        env=env,
+        text=True,
+        encoding='utf-8',
+        errors='replace',
+        bufsize=1
+    )
+
+    videos = []
+
+    while True:
+        if stop_event.is_set():
+            p.terminate()
+            p.wait()
+            # Yield the (gallery, status, progress) triple expected by this generator
+            yield videos, "Generation stopped by user.", ""
+            return
+
+        line = p.stdout.readline()
+        if not line:
+            if p.poll() is not None:
+                break
+            continue
+
+        print(line, end='') # Print progress to console
+        if '|' in line and '%' in line and '[' in line and ']' in line:
+            yield videos.copy(), f"Processing (seed: {current_seed})", line.strip()
+
+    p.stdout.close()
+    p.wait()
+
+    clear_cuda_cache()
+    time.sleep(0.5)
+
+    # Collect generated video
+    save_path_abs = os.path.abspath(save_path)
+    generated_video_path = None
+    if os.path.exists(save_path_abs):
+        all_videos_files = sorted(
+            [f for f in os.listdir(save_path_abs) if f.endswith('.mp4')],
+            key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)),
+            reverse=True
+        )
+        # Try to find the video matching the seed
+        matching_videos = [v for v in all_videos_files if f"_{current_seed}" in v]
+        if matching_videos:
+            generated_video_path = os.path.join(save_path_abs, matching_videos[0])
+
+    if generated_video_path:
+        # Collect parameters for metadata (adjust as needed for i2v specifics)
+        parameters = {
+            "prompt": prompt,
+            "width": width,
+            "height": height,
+            "video_length": video_length,
+            "fps": fps,
+            "infer_steps": infer_steps,
+            "seed": current_seed,
+            "model": model,
+            "vae": vae,
+            "te1": te1,
+            "te2": te2,
+            "clip_vision_path": clip_vision_path,
+            "save_path": save_path,
+            "flow_shift": flow_shift,
+            "embedded_cfg_scale": cfg_scale,
+            "guidance_scale": guidance_scale,
+            "output_type": output_type,
+            "attn_mode": attn_mode,
+            "block_swap": block_swap,
+            "lora_weights": list(lora_weights_list), # Save the list
+            "lora_multipliers": list(lora_multipliers_list), # Save the list
+            "input_image": image_path,
+            "negative_prompt": negative_prompt if negative_prompt else None,
+            "vae_chunk_size": vae_chunk_size,
+            "vae_spatial_tile_min": vae_spatial_tile_min,
+            "use_fp8_dit": use_fp8,
+            "use_fp8_llm": fp8_llm
+        }
+        add_metadata_to_video(generated_video_path, parameters)
+        videos.append((str(generated_video_path), f"Seed: {current_seed}"))
+        yield videos, f"Completed (seed: {current_seed})", ""
+    else:
+        yield [], f"Failed (seed: {current_seed})", "Could not find generated video file."
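+
+# NOTE: clear_cuda_cache() and add_metadata_to_video(), used throughout these
+# generation wrappers, are defined or imported elsewhere in this module and are
+# not shown in this hunk. As a rough, hedged sketch only, clear_cuda_cache() is
+# assumed to amount to something like:
+#
+#     import gc
+#     import torch
+#
+#     def clear_cuda_cache():
+#         # Release Python-level references, then free cached CUDA allocations
+#         gc.collect()
+#         if torch.cuda.is_available():
+#             torch.cuda.empty_cache()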
+ + +def process_i2v_batch( + prompt: str, + image_path: str, + width: int, + height: int, + batch_size: int, + video_length: int, + fps: int, + infer_steps: int, + seed: int, + dit_folder: str, + model: str, + vae: str, + te1: str, + te2: str, + clip_vision_path: str, # Added + save_path: str, + flow_shift: float, + cfg_scale: float, # embedded_cfg_scale + guidance_scale: float, # main CFG + output_type: str, + attn_mode: str, + block_swap: int, + exclude_single_blocks: bool, + use_split_attn: bool, + lora_folder: str, + vae_chunk_size: int, # Added + vae_spatial_tile_min: int, # Added + negative_prompt: Optional[str] = None, # Added + use_fp8: bool = False, # Added + fp8_llm: bool = False, # Added + *lora_params # Captures LoRA weights and multipliers +) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: + """Process a batch of videos using the new I2V script""" + global stop_event + stop_event.clear() + + all_videos = [] + progress_text = "Starting I2V generation..." + yield [], "Preparing...", progress_text + + # Extract LoRA weights and multipliers once + num_lora_weights = 4 + lora_weights_list = lora_params[:num_lora_weights] + lora_multipliers_list = lora_params[num_lora_weights:num_lora_weights*2] + + for i in range(batch_size): + if stop_event.is_set(): + yield all_videos, "Generation stopped by user.", "" + return + + current_seed = seed + if seed == -1: + current_seed = random.randint(0, 2**32 - 1) + elif batch_size > 1: + current_seed = seed + i + + batch_text = f"Generating video {i + 1} of {batch_size} (I2V)" + yield all_videos.copy(), batch_text, progress_text + + # Call the single video processing function + single_gen = process_i2v_single_video( + prompt=prompt, + image_path=image_path, + width=width, + height=height, + batch_size=batch_size, + video_length=video_length, + fps=fps, + infer_steps=infer_steps, + seed=current_seed, + dit_folder=dit_folder, + model=model, + vae=vae, + te1=te1, + te2=te2, + clip_vision_path=clip_vision_path, + save_path=save_path, + flow_shift=flow_shift, + cfg_scale=cfg_scale, + guidance_scale=guidance_scale, + output_type=output_type, + attn_mode=attn_mode, + block_swap=block_swap, + exclude_single_blocks=exclude_single_blocks, + use_split_attn=use_split_attn, + lora_folder=lora_folder, + vae_chunk_size=vae_chunk_size, + vae_spatial_tile_min=vae_spatial_tile_min, + # --- Pass LoRA params by keyword --- + lora1=lora_weights_list[0], + lora2=lora_weights_list[1], + lora3=lora_weights_list[2], + lora4=lora_weights_list[3], + lora1_multiplier=lora_multipliers_list[0], + lora2_multiplier=lora_multipliers_list[1], + lora3_multiplier=lora_multipliers_list[2], + lora4_multiplier=lora_multipliers_list[3], + # --- End LoRA keyword args --- + negative_prompt=negative_prompt, + use_fp8=use_fp8, + fp8_llm=fp8_llm + ) + + # Yield progress updates from the single generator + try: + for videos, status, progress in single_gen: + if videos: + # Only add the latest video from this specific generation + new_video = videos[-1] + if new_video not in all_videos: + all_videos.append(new_video) + yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress + except Exception as e: + yield all_videos.copy(), f"Error in batch {i+1}: {e}", "" + print(f"Error during single I2V generation: {e}") # Log error + + # Optional small delay between batch items + time.sleep(0.1) + + yield all_videos, "I2V Batch complete", "" + + +def wanx_extend_video_wrapper( + prompt, negative_prompt, input_image, base_video_path, + width, height, video_length, fps, 
infer_steps, + flow_shift, guidance_scale, seed, + task, dit_folder, dit_path, vae_path, t5_path, clip_path, # <--- Parameters received here + save_path, output_type, sample_solver, exclude_single_blocks, + attn_mode, block_swap, fp8, fp8_scaled, fp8_t5, lora_folder, + slg_layers="", slg_start=0.0, slg_end=1.0, + lora1="None", lora2="None", lora3="None", lora4="None", + lora1_multiplier=1.0, lora2_multiplier=1.0, lora3_multiplier=1.0, lora4_multiplier=1.0, + enable_cfg_skip=False, cfg_skip_mode="none", cfg_apply_ratio=0.7 +): + """Direct wrapper that bypasses the problematic wanx_generate_video function""" + global stop_event + + # All videos generated + all_videos = [] + + # Debug prints to understand what we're getting + print(f"DEBUG - Received parameters in wanx_extend_video_wrapper:") + print(f" task: {task}") + print(f" dit_folder: {dit_folder}") # <<< Should be the folder path ('wan') + print(f" dit_path: {dit_path}") # <<< Should be the model filename + print(f" vae_path: {vae_path}") # <<< Should be the VAE path + print(f" t5_path: {t5_path}") # <<< Should be the T5 path + print(f" clip_path: {clip_path}") # <<< Should be the CLIP path + print(f" output_type: {output_type}") + print(f" sample_solver: {sample_solver}") + print(f" attn_mode: {attn_mode}") + print(f" block_swap: {block_swap}") + + # Get current seed + current_seed = seed + if seed == -1: + current_seed = random.randint(0, 2**32 - 1) + + # --- START CRITICAL FIX --- + # Detect if parameters are swapped based on the pattern observed in the error log + # Check if dit_path looks like a VAE path (contains "VAE" or ends with .pth) + # AND dit_folder looks like a model filename (ends with .safetensors or .pt) + params_swapped = False + if dit_path and dit_folder and \ + (("VAE" in dit_path or dit_path.endswith(".pth")) and \ + (dit_folder.endswith(".safetensors") or dit_folder.endswith(".pt"))): + params_swapped = True + print("WARNING: Parameters appear to be swapped in extend workflow. 
Applying correction...") + + # Correct the parameters based on the observed swap + actual_model_filename = dit_folder # Original dit_folder was the filename + actual_vae_path = dit_path # Original dit_path was the VAE path + actual_t5_path = vae_path # Original vae_path was the T5 path + actual_clip_path = t5_path # Original t5_path was the CLIP path + + # Assign corrected values back to expected variable names for the rest of the function + dit_path = actual_model_filename + vae_path = actual_vae_path + t5_path = actual_t5_path + clip_path = actual_clip_path + dit_folder = "wan" # Assume default 'wan' folder if swapped + + print(f" Corrected dit_folder: {dit_folder}") + print(f" Corrected dit_path (model filename): {dit_path}") + print(f" Corrected vae_path: {vae_path}") + print(f" Corrected t5_path: {t5_path}") + print(f" Corrected clip_path: {clip_path}") + + # Construct the full model path using the potentially corrected dit_folder and dit_path + actual_model_path = os.path.join(dit_folder, dit_path) if not os.path.isabs(dit_path) else dit_path + print(f" Using actual_model_path for --dit: {actual_model_path}") + # --- END CRITICAL FIX --- + + # Prepare environment + env = os.environ.copy() + env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") + env["PYTHONIOENCODING"] = "utf-8" + + # Clear CUDA cache + clear_cuda_cache() + + # Validate and fix parameters + # Fix output_type - must be one of: video, images, latent, both + valid_output_types = ["video", "images", "latent", "both"] + actual_output_type = "video" if output_type not in valid_output_types else output_type + + # Fix sample_solver - must be one of: unipc, dpm++, vanilla + valid_sample_solvers = ["unipc", "dpm++", "vanilla"] + actual_sample_solver = "unipc" if sample_solver not in valid_sample_solvers else sample_solver + + # Fix attn_mode - must be one of: sdpa, flash, sageattn, xformers, torch + valid_attn_modes = ["sdpa", "flash", "sageattn", "xformers", "torch"] + actual_attn_mode = "sdpa" if attn_mode not in valid_attn_modes else attn_mode + + # Fix block_swap - must be an integer + try: + actual_block_swap = int(block_swap) + except (ValueError, TypeError): + actual_block_swap = 0 + + # Build command array with explicit string conversions for EVERY parameter + command = [ + sys.executable, + "wan_generate_video.py", + "--task", str(task), + "--prompt", str(prompt), + "--video_size", str(height), str(width), + "--video_length", str(video_length), + "--fps", str(fps), + "--infer_steps", str(infer_steps), + "--save_path", str(save_path), + "--seed", str(current_seed), + "--flow_shift", str(flow_shift), + "--guidance_scale", str(guidance_scale), + "--output_type", actual_output_type, + "--sample_solver", actual_sample_solver, + "--attn_mode", actual_attn_mode, + "--blocks_to_swap", str(actual_block_swap), + # Use the corrected model path and other paths + "--dit", str(actual_model_path), # <<< Use corrected full model path + "--vae", str(vae_path), # <<< Use potentially corrected vae_path + "--t5", str(t5_path) # <<< Use potentially corrected t5_path + ] + + # Add image path and clip model path if needed + if input_image: + command.extend(["--image_path", str(input_image)]) + # Use the potentially corrected clip_path + if clip_path and clip_path != "outputs" and "output" not in clip_path: + command.extend(["--clip", str(clip_path)]) # <<< Use potentially corrected clip_path + + # Add negative prompt + if negative_prompt: + command.extend(["--negative_prompt", str(negative_prompt)]) + + # Handle 
boolean flags - keep original values + if fp8: + command.append("--fp8") + + if fp8_scaled: + command.append("--fp8_scaled") + + if fp8_t5: + command.append("--fp8_t5") + + # Add SLG parameters + try: + # Ensure slg_layers is treated as a string before splitting + slg_layers_str = str(slg_layers) if slg_layers is not None else "" + if slg_layers_str and slg_layers_str.strip() and slg_layers_str.lower() != "none": + slg_list = [] + for layer in slg_layers_str.split(","): + layer = layer.strip() + if layer.isdigit(): # Only add if it's a valid integer + slg_list.append(int(layer)) + if slg_list: # Only add if we have valid layers + command.extend(["--slg_layers", ",".join(map(str, slg_list))]) + + # Only add slg_start and slg_end if we have valid slg_layers + if slg_start is not None: + try: + slg_start_float = float(slg_start) + if slg_start_float >= 0: + command.extend(["--slg_start", str(slg_start_float)]) + except (ValueError, TypeError): pass # Ignore if conversion fails + if slg_end is not None: + try: + slg_end_float = float(slg_end) + if slg_end_float <= 1.0: + command.extend(["--slg_end", str(slg_end_float)]) + except (ValueError, TypeError): pass # Ignore if conversion fails + except Exception as e: # Catch potential errors during processing + print(f"Warning: Error processing SLG parameters: {e}") + pass + + # Handle LoRA weights and multipliers + valid_loras = [] + if lora_folder and isinstance(lora_folder, str): + for weight, mult in zip([lora1, lora2, lora3, lora4], + [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier]): + # Skip None or empty values + if not weight or str(weight).lower() == "none": + continue + + # Construct path and check existence + full_path = os.path.join(str(lora_folder), str(weight)) + if not os.path.exists(full_path): + print(f"LoRA file not found: {full_path}") + continue + + # Add valid LoRA + valid_loras.append((full_path, str(mult))) + + if valid_loras: + weights = [w for w, _ in valid_loras] + multipliers = [m for _, m in valid_loras] + command.extend(["--lora_weight"] + weights) + command.extend(["--lora_multiplier"] + multipliers) + + # Final conversion to ensure all elements are strings + command_str = [str(item) for item in command] + + print(f"Running Command (wanx_extend_video_wrapper): {' '.join(command_str)}") + + # Process execution + p = subprocess.Popen( + command_str, # Use stringified command + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=env, + text=True, + encoding='utf-8', + errors='replace', + bufsize=1 + ) + + videos = [] # Store the generated (non-extended) video first + + # Process stdout in real time + while True: + if stop_event.is_set(): + p.terminate() + p.wait() + yield [], "", "Generation stopped by user." 
+ return + + line = p.stdout.readline() + if not line: + if p.poll() is not None: + break + continue + + print(line, end='') + if '|' in line and '%' in line and '[' in line and ']' in line: + # Yield empty list during processing, actual video is collected later + yield [], f"Processing (seed: {current_seed})", line.strip() + + p.stdout.close() + return_code = p.wait() # Get return code + + # Clean CUDA cache and wait + clear_cuda_cache() + time.sleep(0.5) + + # Check return code + if return_code != 0: + print(f"❌ Error: wan_generate_video.py exited with code {return_code}") + yield [], f"Failed (seed: {current_seed})", f"Subprocess failed with code {return_code}" + return + + # Find the *newly generated* video first + generated_video_path = None + save_path_abs = os.path.abspath(save_path) + if os.path.exists(save_path_abs): + # Find the most recent mp4 containing the seed + all_mp4_files = glob.glob(os.path.join(save_path_abs, f"*_{current_seed}*.mp4")) + if all_mp4_files: + generated_video_path = max(all_mp4_files, key=os.path.getmtime) + print(f"Found newly generated video: {generated_video_path}") + + # Add metadata to the generated video before potential concatenation + parameters = { + "prompt": prompt, "negative_prompt": negative_prompt, "input_image": input_image, + "width": width, "height": height, "video_length": video_length, "fps": fps, + "infer_steps": infer_steps, "flow_shift": flow_shift, "guidance_scale": guidance_scale, + "seed": current_seed, "task": task, "dit_path": actual_model_path, # Store the actual path used + "vae_path": vae_path, "t5_path": t5_path, "clip_path": clip_path, + "save_path": save_path, "output_type": actual_output_type, "sample_solver": actual_sample_solver, + "exclude_single_blocks": exclude_single_blocks, "attn_mode": actual_attn_mode, + "block_swap": actual_block_swap, "fp8": fp8, "fp8_scaled": fp8_scaled, "fp8_t5": fp8_t5, + "lora_weights": [lora1, lora2, lora3, lora4], + "lora_multipliers": [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier], + "slg_layers": slg_layers, "slg_start": slg_start, "slg_end": slg_end, + "is_extension_source": True # Flag this as the source for an extension + } + add_metadata_to_video(generated_video_path, parameters) + # videos.append((str(generated_video_path), f"Generated segment (Seed: {current_seed})")) # Optionally yield segment + else: + print(f"Could not find generated video segment for seed {current_seed} in {save_path_abs}") + + # Stop here if no new video segment was generated + if not generated_video_path: + yield [], f"Failed (seed: {current_seed})", "Could not find generated video segment." 
+ return + + # Now concatenate with base video if we have the new segment and a base_video_path + if generated_video_path and base_video_path and os.path.exists(base_video_path): + try: + print(f"Extending base video: {base_video_path}") + + # Create unique output filename for the *extended* video + timestamp = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S") + output_filename = f"extended_{timestamp}_seed{current_seed}_{Path(base_video_path).stem}.mp4" + output_path = os.path.join(save_path_abs, output_filename) + + # Create a temporary file list for ffmpeg concatenation + list_file = os.path.join(save_path_abs, f"temp_concat_list_{current_seed}.txt") + with open(list_file, "w") as f: + f.write(f"file '{os.path.abspath(base_video_path)}'\n") + f.write(f"file '{os.path.abspath(generated_video_path)}'\n") # Use the newly generated segment + + print(f"Concatenating: {base_video_path} + {generated_video_path} -> {output_path}") + + # Run ffmpeg concatenation command + concat_command = [ + "ffmpeg", + "-f", "concat", + "-safe", "0", # Allow relative paths if needed, but we use absolute + "-i", list_file, + "-c", "copy", # Fast concatenation without re-encoding + "-y", # Overwrite output if exists + output_path + ] + + # Convert all command parts to strings + concat_command_str = [str(item) for item in concat_command] + + print(f"Running FFmpeg command: {' '.join(concat_command_str)}") + concat_result = subprocess.run(concat_command_str, check=False, capture_output=True, text=True) # Don't check=True initially + + # Clean up temporary list file + if os.path.exists(list_file): + try: + os.remove(list_file) + except OSError as e: + print(f"Warning: Could not remove temp list file {list_file}: {e}") + + + # Check if concatenation was successful + if concat_result.returncode == 0 and os.path.exists(output_path): + # Optionally, add metadata to the *extended* video as well + extended_parameters = parameters.copy() + extended_parameters["is_extension_source"] = False + extended_parameters["base_video"] = os.path.basename(base_video_path) + add_metadata_to_video(output_path, extended_parameters) + + extended_video_gallery_item = [(output_path, f"Extended (Seed: {current_seed})")] + print(f"✅ Successfully created extended video: {output_path}") + yield extended_video_gallery_item, "Extended video created successfully", "" + return # Success! + else: + print(f"❌ Failed to create extended video at {output_path}") + print(f"FFmpeg stderr: {concat_result.stderr}") + # Yield the generated segment if concatenation failed + yield [(generated_video_path, f"Generated segment (Seed: {current_seed})")], "Generated segment (extension failed)", f"FFmpeg failed: {concat_result.stderr[:200]}..." 
+ return + + except Exception as e: + print(f"❌ Error during concatenation: {str(e)}") + # Yield the generated segment if concatenation failed + yield [(generated_video_path, f"Generated segment (Seed: {current_seed})")], "Generated segment (extension error)", f"Error: {str(e)}" + return + + # If we got here, base_video_path was likely None or didn't exist, but generation succeeded + yield [(generated_video_path, f"Generated segment (Seed: {current_seed})")], "Generated segment (no base video provided)", "" + +def wanx_v2v_generate_video( + prompt, + negative_prompt, + input_video, + width, + height, + video_length, + fps, + infer_steps, + flow_shift, + guidance_scale, + strength, + seed, + task, + dit_folder, + dit_path, + vae_path, + t5_path, + save_path, + output_type, + sample_solver, + exclude_single_blocks, + attn_mode, + block_swap, + fp8, + fp8_scaled, + fp8_t5, + lora_folder, + slg_layers, + slg_start, + slg_end, + lora1="None", + lora2="None", + lora3="None", + lora4="None", + lora1_multiplier=1.0, + lora2_multiplier=1.0, + lora3_multiplier=1.0, + lora4_multiplier=1.0, + enable_cfg_skip=False, + cfg_skip_mode="none", + cfg_apply_ratio=0.7, +) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: + """Generate video with WanX model in video-to-video mode""" + global stop_event + + # Convert values safely to float or None + try: + slg_start_float = float(slg_start) if slg_start is not None and str(slg_start).lower() != "none" else None + except (ValueError, TypeError): + slg_start_float = None + print(f"Warning: Could not convert slg_start '{slg_start}' to float") + + try: + slg_end_float = float(slg_end) if slg_end is not None and str(slg_end).lower() != "none" else None + except (ValueError, TypeError): + slg_end_float = None + print(f"Warning: Could not convert slg_end '{slg_end}' to float") + + print(f"slg_start_float: {slg_start_float}, slg_end_float: {slg_end_float}") + + if stop_event.is_set(): + yield [], "", "" + return + + # Check if we need input video (required for v2v) + if not input_video: + yield [], "Error: No input video provided", "Please provide an input video for video-to-video generation" + return + + if seed == -1: + current_seed = random.randint(0, 2**32 - 1) + else: + current_seed = seed + + # Prepare environment + env = os.environ.copy() + env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") + env["PYTHONIOENCODING"] = "utf-8" + + clear_cuda_cache() + + # Construct full dit_path including folder - this is the fix + full_dit_path = os.path.join(dit_folder, dit_path) if not os.path.isabs(dit_path) else dit_path + + command = [ + sys.executable, + "wan_generate_video.py", + "--task", task, + "--prompt", prompt, + "--video_size", str(height), str(width), + "--video_length", str(video_length), + "--fps", str(fps), + "--infer_steps", str(infer_steps), + "--save_path", save_path, + "--seed", str(current_seed), + "--flow_shift", str(flow_shift), + "--guidance_scale", str(guidance_scale), + "--output_type", output_type, + "--attn_mode", attn_mode, + "--blocks_to_swap", str(block_swap), + "--dit", full_dit_path, # Use full_dit_path instead of dit_path + "--vae", vae_path, + "--t5", t5_path, + "--sample_solver", sample_solver, + "--video_path", input_video, # This is the key for v2v mode + "--strength", str(strength) # Strength parameter for v2v + ] + if enable_cfg_skip and cfg_skip_mode != "none": + command.extend([ + "--cfg_skip_mode", cfg_skip_mode, + "--cfg_apply_ratio", str(cfg_apply_ratio) + ]) + # Handle SLG parameters + 
if slg_layers and str(slg_layers).strip() and slg_layers.lower() != "none": + try: + # Parse SLG layers + layer_list = [int(x) for x in str(slg_layers).split(",")] + if layer_list: # Only proceed if we have valid layer values + command.extend(["--slg_layers", ",".join(map(str, layer_list))]) + + # Only add slg_start and slg_end if we have valid slg_layers + try: + if slg_start_float is not None and slg_start_float >= 0: + command.extend(["--slg_start", str(slg_start_float)]) + if slg_end_float is not None and slg_end_float <= 1.0: + command.extend(["--slg_end", str(slg_end_float)]) + except ValueError as e: + print(f"Invalid SLG timing values: {str(e)}") + except ValueError as e: + print(f"Invalid SLG layers format: {slg_layers} - {str(e)}") + + if negative_prompt: + command.extend(["--negative_prompt", negative_prompt]) + + if fp8: + command.append("--fp8") + + if fp8_scaled: + command.append("--fp8_scaled") + + if fp8_t5: + command.append("--fp8_t5") + + if exclude_single_blocks: + command.append("--exclude_single_blocks") + + # Handle LoRA weights and multipliers + lora_weights = [lora1, lora2, lora3, lora4] + lora_multipliers = [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier] + + valid_loras = [] + for weight, mult in zip(lora_weights, lora_multipliers): + if weight and weight != "None": + full_path = os.path.join(lora_folder, weight) + if not os.path.exists(full_path): + print(f"LoRA file not found: {full_path}") + continue + valid_loras.append((full_path, mult)) + + if valid_loras: + weights = [w for w, _ in valid_loras] + multipliers = [str(m) for _, m in valid_loras] + command.extend(["--lora_weight"] + weights) + command.extend(["--lora_multiplier"] + multipliers) + + print(f"Running: {' '.join(command)}") + + p = subprocess.Popen( + command, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=env, + text=True, + encoding='utf-8', + errors='replace', + bufsize=1 + ) + + videos = [] + + while True: + if stop_event.is_set(): + p.terminate() + p.wait() + yield [], "", "Generation stopped by user." 
+ return + + line = p.stdout.readline() + if not line: + if p.poll() is not None: + break + continue + + print(line, end='') + if '|' in line and '%' in line and '[' in line and ']' in line: + yield videos.copy(), f"Processing (seed: {current_seed})", line.strip() + + p.stdout.close() + p.wait() + + clear_cuda_cache() + time.sleep(0.5) + + # Collect generated video + save_path_abs = os.path.abspath(save_path) + if os.path.exists(save_path_abs): + all_videos = sorted( + [f for f in os.listdir(save_path_abs) if f.endswith('.mp4')], + key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)), + reverse=True + ) + matching_videos = [v for v in all_videos if f"_{current_seed}" in v] + if matching_videos: + video_path = os.path.join(save_path_abs, matching_videos[0]) + + # Collect parameters for metadata + parameters = { + "prompt": prompt, + "width": width, + "height": height, + "video_length": video_length, + "fps": fps, + "infer_steps": infer_steps, + "seed": current_seed, + "task": task, + "flow_shift": flow_shift, + "guidance_scale": guidance_scale, + "output_type": output_type, + "attn_mode": attn_mode, + "block_swap": block_swap, + "input_video": input_video, + "strength": strength, + "lora_weights": [lora1, lora2, lora3, lora4], + "lora_multipliers": [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier], + "dit_path": full_dit_path, # Store the full path in metadata + "vae_path": vae_path, + "t5_path": t5_path, + "negative_prompt": negative_prompt if negative_prompt else None, + "sample_solver": sample_solver + } + + add_metadata_to_video(video_path, parameters) + videos.append((str(video_path), f"Seed: {current_seed}")) + + yield videos, f"Completed (seed: {current_seed})", "" + +def wanx_v2v_batch_handler( + prompt, + negative_prompt, + input_video, + width, + height, + video_length, + fps, + infer_steps, + flow_shift, + guidance_scale, + strength, + seed, + batch_size, + task, + dit_folder, # folder path + dit_path, # model filename + vae_path, + t5_path, + save_path, + output_type, + sample_solver, + exclude_single_blocks, + attn_mode, + block_swap, + fp8, + fp8_scaled, + fp8_t5, + lora_folder, + slg_layers: str, + slg_start: Optional[str], + slg_end: Optional[str], + enable_cfg_skip: bool, + cfg_skip_mode: str, + cfg_apply_ratio: float, + *lora_params +): + """Handle batch generation for WanX v2v""" + global stop_event + stop_event.clear() + + # Extract LoRA parameters + num_lora_weights = 4 + lora_weights = lora_params[:num_lora_weights] + lora_multipliers = lora_params[num_lora_weights:num_lora_weights*2] + + all_videos = [] + progress_text = "Starting generation..." 
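+    # Seed policy for the batch loop below: seed == -1 draws a fresh random seed
+    # for every item, otherwise the base seed is offset by the item index,
+    # e.g. seed=1000 with batch_size=3 runs seeds 1000, 1001, 1002.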
+ yield [], "Preparing...", progress_text + + # Process each item in the batch + for i in range(batch_size): + if stop_event.is_set(): + yield all_videos, "Generation stopped by user", "" + return + + # Calculate seed for this batch item + current_seed = seed + if seed == -1: + current_seed = random.randint(0, 2**32 - 1) + elif batch_size > 1: + current_seed = seed + i + + batch_text = f"Generating video {i + 1} of {batch_size}" + yield all_videos.copy(), batch_text, progress_text + + # Generate a single video + for videos, status, progress in wanx_v2v_generate_video( + prompt, + negative_prompt, + input_video, + width, + height, + video_length, + fps, + infer_steps, + flow_shift, + guidance_scale, + strength, + current_seed, + task, + dit_folder, # Pass folder path + dit_path, # Pass model filename + vae_path, + t5_path, + save_path, + output_type, + sample_solver, + exclude_single_blocks, + attn_mode, + block_swap, + fp8, + fp8_scaled, + fp8_t5, + lora_folder, + slg_layers, + slg_start, + slg_end, + *lora_weights, + *lora_multipliers, + enable_cfg_skip, + cfg_skip_mode, + cfg_apply_ratio, + ): + if videos: + all_videos.extend(videos) + yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress + + # Clear CUDA cache between generations + clear_cuda_cache() + time.sleep(0.5) + + yield all_videos, "Batch complete", "" + +def update_wanx_v2v_dimensions(video): + """Update dimensions from uploaded video""" + if video is None: + return "", gr.update(value=832), gr.update(value=480) + + cap = cv2.VideoCapture(video) + if not cap.isOpened(): + return "Error opening video", gr.update(), gr.update() + + w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + cap.release() + + # Make dimensions divisible by 32 + w = (w // 32) * 32 + h = (h // 32) * 32 + + return f"{w}x{h}", w, h + +def send_wanx_v2v_to_hunyuan_v2v( + gallery: list, + prompt: str, + selected_index: int, + width: int, + height: int, + video_length: int, + fps: int, + infer_steps: int, + seed: int, + flow_shift: float, + guidance_scale: float, + negative_prompt: str +) -> Tuple: + """Send the selected WanX v2v video to Hunyuan v2v tab""" + if gallery is None or not gallery: + return (None, "", width, height, video_length, fps, infer_steps, seed, + flow_shift, guidance_scale, negative_prompt) + + # If no selection made but we have videos, use the first one + if selected_index is None and len(gallery) > 0: + selected_index = 0 + + if selected_index is None or selected_index >= len(gallery): + return (None, "", width, height, video_length, fps, infer_steps, seed, + flow_shift, guidance_scale, negative_prompt) + + selected_item = gallery[selected_index] + + # Handle different gallery item formats + if isinstance(selected_item, tuple): + video_path = selected_item[0] + elif isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + else: + video_path = selected_item + + # Clean up path for Video component + if isinstance(video_path, tuple): + video_path = video_path[0] + + # Make sure it's a string + video_path = str(video_path) + + return (video_path, prompt, width, height, video_length, fps, infer_steps, seed, + flow_shift, guidance_scale, negative_prompt) + +def handle_wanx_v2v_gallery_select(evt: gr.SelectData) -> int: + """Track selected index when gallery item is clicked""" + return evt.index + +def variance_of_laplacian(image): + """ + Compute the variance of the Laplacian of the image. + Higher variance indicates a sharper image. 
+ """ + return cv2.Laplacian(image, cv2.CV_64F).var() + +def extract_sharpest_frame(video_path, frames_to_check=30): + """ + Extract the sharpest frame from the last N frames of the video. + + Args: + video_path (str): Path to the video file + frames_to_check (int): Number of frames from the end to check + + Returns: + tuple: (temp_image_path, frame_number, sharpness_score) + """ + print(f"\n=== Extracting sharpest frame from the last {frames_to_check} frames ===") + print(f"Input video path: {video_path}") + + if not video_path or not os.path.exists(video_path): + print("❌ Error: Video file does not exist") + return None, None, None + + try: + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + print("❌ Error: Failed to open video file") + return None, None, None + + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + fps = cap.get(cv2.CAP_PROP_FPS) + print(f"Total frames detected: {total_frames}, FPS: {fps:.2f}") + + if total_frames < 1: + print("❌ Error: Video contains 0 frames") + return None, None, None + + # Determine how many frames to check (the last N frames) + if frames_to_check > total_frames: + frames_to_check = total_frames + start_frame = 0 + else: + start_frame = total_frames - frames_to_check + + print(f"Checking frames {start_frame} to {total_frames-1}") + + # Find the sharpest frame + sharpest_frame = None + max_sharpness = -1 + sharpest_frame_number = -1 + + # Set starting position + cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame) + + # Process frames with a progress bar + with tqdm(total=frames_to_check, desc="Finding sharpest frame") as pbar: + frame_idx = start_frame + while frame_idx < total_frames: + ret, frame = cap.read() + if not ret: + break + + # Convert to grayscale and calculate sharpness + gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + sharpness = variance_of_laplacian(gray) + + # Update if this is the sharpest frame so far + if sharpness > max_sharpness: + max_sharpness = sharpness + sharpest_frame = frame.copy() + sharpest_frame_number = frame_idx + + frame_idx += 1 + pbar.update(1) + + cap.release() + + if sharpest_frame is None: + print("❌ Error: Failed to find a sharp frame") + return None, None, None + + # Prepare output path + temp_dir = os.path.abspath("temp_frames") + os.makedirs(temp_dir, exist_ok=True) + temp_path = os.path.join(temp_dir, f"sharpest_frame_{os.path.basename(video_path)}.png") + print(f"Saving frame to: {temp_path}") + + # Write and verify + if not cv2.imwrite(temp_path, sharpest_frame): + print("❌ Error: Failed to write frame to file") + return None, None, None + + if not os.path.exists(temp_path): + print("❌ Error: Output file not created") + return None, None, None + + # Calculate frame time in seconds + frame_time = sharpest_frame_number / fps + + print(f"✅ Extracted sharpest frame: {sharpest_frame_number} (at {frame_time:.2f}s) with sharpness {max_sharpness:.2f}") + return temp_path, sharpest_frame_number, max_sharpness + + except Exception as e: + print(f"❌ Unexpected error: {str(e)}") + return None, None, None + finally: + if 'cap' in locals(): + cap.release() + +def trim_video_to_frame(video_path, frame_number, output_dir="outputs"): + """ + Trim video up to the specified frame and save as a new video. 
+ + Args: + video_path (str): Path to the video file + frame_number (int): Frame number to trim to + output_dir (str): Directory to save the trimmed video + + Returns: + str: Path to the trimmed video file + """ + print(f"\n=== Trimming video to frame {frame_number} ===") + if not video_path or not os.path.exists(video_path): + print("❌ Error: Video file does not exist") + return None + + try: + # Get video information + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + print("❌ Error: Failed to open video file") + return None + + fps = cap.get(cv2.CAP_PROP_FPS) + cap.release() + + # Calculate time in seconds + time_seconds = frame_number / fps + + # Create output directory if it doesn't exist + os.makedirs(output_dir, exist_ok=True) + + # Generate output filename + timestamp = f"{int(time_seconds)}s" + base_name = Path(video_path).stem + output_file = os.path.join(output_dir, f"{base_name}_trimmed_to_{timestamp}.mp4") + + # Use ffmpeg to trim the video + ( + ffmpeg + .input(video_path) + .output(output_file, to=time_seconds, c="copy") + .global_args('-y') # Overwrite output files + .run(quiet=True) + ) + + if not os.path.exists(output_file): + print("❌ Error: Failed to create trimmed video") + return None + + print(f"✅ Successfully trimmed video to {time_seconds:.2f}s: {output_file}") + return output_file + + except Exception as e: + print(f"❌ Error trimming video: {str(e)}") + return None + +def send_sharpest_frame_handler(gallery, selected_idx, frames_to_check=30): + """ + Extract the sharpest frame from the last N frames of the selected video + + Args: + gallery: Gradio gallery component with videos + selected_idx: Index of the selected video + frames_to_check: Number of frames from the end to check + + Returns: + tuple: (image_path, video_path, frame_number, sharpness) + """ + if gallery is None or not gallery: + return None, None, None, "No videos in gallery" + + if selected_idx is None and len(gallery) == 1: + selected_idx = 0 + + if selected_idx is None or selected_idx >= len(gallery): + return None, None, None, "No video selected" + + # Get the video path + item = gallery[selected_idx] + if isinstance(item, tuple): + video_path = item[0] + elif isinstance(item, dict): + video_path = item.get('name') or item.get('data') + else: + video_path = str(item) + + # Extract the sharpest frame + image_path, frame_number, sharpness = extract_sharpest_frame(video_path, frames_to_check) + + if image_path is None: + return None, None, None, "Failed to extract sharpest frame" + + return image_path, video_path, frame_number, f"Extracted frame {frame_number} with sharpness {sharpness:.2f}" + +def trim_and_prepare_for_extension(video_path, frame_number, save_path="outputs"): + """ + Trim the video to the specified frame and prepare for extension. 
+ + Args: + video_path: Path to the video file + frame_number: Frame number to trim to + save_path: Directory to save the trimmed video + + Returns: + tuple: (trimmed_video_path, status_message) + """ + if not video_path or not os.path.exists(video_path): + return None, "No video selected or video file does not exist" + + if frame_number is None: + return None, "No frame number provided, please extract sharpest frame first" + + # Trim the video + trimmed_video = trim_video_to_frame(video_path, frame_number, save_path) + + if trimmed_video is None: + return None, "Failed to trim video" + + return trimmed_video, f"Video trimmed to frame {frame_number} and ready for extension" + +def send_last_frame_handler(gallery, selected_idx): + """Handle sending last frame to input with better error handling""" + if gallery is None or not gallery: + return None, None + + if selected_idx is None and len(gallery) == 1: + selected_idx = 0 + + if selected_idx is None or selected_idx >= len(gallery): + return None, None + + # Get the frame and video path + frame = handle_last_frame_transfer(gallery, selected_idx) + video_path = None + + if selected_idx < len(gallery): + item = gallery[selected_idx] + video_path = parse_video_path(item) + + return frame, video_path + +def extract_last_frame(video_path: str) -> Optional[str]: + """Extract last frame from video and return temporary image path with error handling""" + print(f"\n=== Starting frame extraction ===") + print(f"Input video path: {video_path}") + + if not video_path or not os.path.exists(video_path): + print("❌ Error: Video file does not exist") + return None + + try: + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + print("❌ Error: Failed to open video file") + return None + + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + print(f"Total frames detected: {total_frames}") + + if total_frames < 1: + print("❌ Error: Video contains 0 frames") + return None + + # Extract last frame + cap.set(cv2.CAP_PROP_POS_FRAMES, total_frames - 1) + success, frame = cap.read() + + if not success or frame is None: + print("❌ Error: Failed to read last frame") + return None + + # Prepare output path + temp_dir = os.path.abspath("temp_frames") + os.makedirs(temp_dir, exist_ok=True) + temp_path = os.path.join(temp_dir, f"last_frame_{os.path.basename(video_path)}.png") + print(f"Saving frame to: {temp_path}") + + # Write and verify + if not cv2.imwrite(temp_path, frame): + print("❌ Error: Failed to write frame to file") + return None + + if not os.path.exists(temp_path): + print("❌ Error: Output file not created") + return None + + print("✅ Frame extraction successful") + return temp_path + + except Exception as e: + print(f"❌ Unexpected error: {str(e)}") + return None + finally: + if 'cap' in locals(): + cap.release() + +def handle_last_frame_transfer(gallery: list, selected_idx: int) -> Optional[str]: + """Improved frame transfer with video input validation""" + try: + if gallery is None or not gallery: + raise ValueError("No videos generated yet") + + if selected_idx is None: + # Auto-select last generated video if batch_size=1 + if len(gallery) == 1: + selected_idx = 0 + else: + raise ValueError("Please select a video first") + + if selected_idx >= len(gallery): + raise ValueError("Invalid selection index") + + item = gallery[selected_idx] + + # Video file existence check + video_path = parse_video_path(item) + if not os.path.exists(video_path): + raise FileNotFoundError(f"Video file missing: {video_path}") + + return extract_last_frame(video_path) 
+ + except Exception as e: + print(f"Frame transfer failed: {str(e)}") + return None + +def parse_video_path(item) -> str: + """Parse different gallery item formats""" + if isinstance(item, tuple): + return item[0] + elif isinstance(item, dict): + return item.get('name') or item.get('data') + return str(item) + +def get_random_image_from_folder(folder_path): + """Get a random image from the specified folder""" + if not os.path.isdir(folder_path): + return None, f"Error: {folder_path} is not a valid directory" + + # Get all image files in the folder + image_files = [] + for ext in ('*.jpg', '*.jpeg', '*.png', '*.bmp', '*.webp'): + image_files.extend(glob.glob(os.path.join(folder_path, ext))) + for ext in ('*.JPG', '*.JPEG', '*.PNG', '*.BMP', '*.WEBP'): + image_files.extend(glob.glob(os.path.join(folder_path, ext))) + + if not image_files: + return None, f"Error: No image files found in {folder_path}" + + # Select a random image + random_image = random.choice(image_files) + return random_image, f"Selected: {os.path.basename(random_image)}" + +def resize_image_keeping_aspect_ratio(image_path, max_width, max_height): + """Resize image keeping aspect ratio and ensuring dimensions are divisible by 16""" + try: + img = Image.open(image_path) + width, height = img.size + + # Calculate aspect ratio + aspect_ratio = width / height + + # Calculate new dimensions while maintaining aspect ratio + if width > height: + new_width = min(max_width, width) + new_height = int(new_width / aspect_ratio) + else: + new_height = min(max_height, height) + new_width = int(new_height * aspect_ratio) + + # Make dimensions divisible by 16 + new_width = math.floor(new_width / 16) * 16 + new_height = math.floor(new_height / 16) * 16 + + # Ensure minimum size + new_width = max(16, new_width) + new_height = max(16, new_height) + + # Resize image + resized_img = img.resize((new_width, new_height), Image.LANCZOS) + + # Save to temporary file + temp_path = f"temp_resized_{os.path.basename(image_path)}" + resized_img.save(temp_path) + + return temp_path, (new_width, new_height) + except Exception as e: + return None, f"Error: {str(e)}" +# Function to process a batch of images from a folder +def batch_handler( + use_random, + prompt, negative_prompt, + width, height, + video_length, fps, infer_steps, + seed, flow_shift, guidance_scale, embedded_cfg_scale, + batch_size, input_folder_path, + dit_folder, model, vae, te1, te2, save_path, output_type, attn_mode, + block_swap, exclude_single_blocks, use_split_attn, use_fp8, split_uncond, + lora_folder, *lora_params +): + """Handle both folder-based batch processing and regular batch processing""" + global stop_event + + # Check if this is a SkyReels model that needs special handling + is_skyreels = "skyreels" in model.lower() + is_skyreels_i2v = is_skyreels and "i2v" in model.lower() + + if use_random: + # Random image from folder mode + stop_event.clear() + + all_videos = [] + progress_text = "Starting generation..." 
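+        # Random-folder mode: each iteration below picks a random image from
+        # input_folder_path, resizes it to fit width x height (aspect ratio kept,
+        # sides rounded down to multiples of 16), then generates one video.
+        # SkyReels I2V models are driven through hv_generate_video.py directly so
+        # that --dit_in_channels 32 can be passed; other models are delegated to
+        # process_single_video().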
+ yield [], "Preparing...", progress_text + + for i in range(batch_size): + if stop_event.is_set(): + break + + batch_text = f"Generating video {i + 1} of {batch_size}" + yield all_videos.copy(), batch_text, progress_text + + # Get random image from folder + random_image, status = get_random_image_from_folder(input_folder_path) + if random_image is None: + yield all_videos, f"Error in batch {i+1}: {status}", "" + continue + + # Resize image + resized_image, size_info = resize_image_keeping_aspect_ratio(random_image, width, height) + if resized_image is None: + yield all_videos, f"Error resizing image in batch {i+1}: {size_info}", "" + continue + + # If we have dimensions, update them + local_width, local_height = width, height + if isinstance(size_info, tuple): + local_width, local_height = size_info + progress_text = f"Using image: {os.path.basename(random_image)} - Resized to {local_width}x{local_height}" + else: + progress_text = f"Using image: {os.path.basename(random_image)}" + + yield all_videos.copy(), batch_text, progress_text + + # Calculate seed for this batch item + current_seed = seed + if seed == -1: + current_seed = random.randint(0, 2**32 - 1) + elif batch_size > 1: + current_seed = seed + i + + # Process the image + # For SkyReels models, we need to create a command with dit_in_channels=32 + if is_skyreels_i2v: + env = os.environ.copy() + env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") + env["PYTHONIOENCODING"] = "utf-8" + + model_path = os.path.join(dit_folder, model) if not os.path.isabs(model) else model + + # Extract parameters from lora_params + num_lora_weights = 4 + lora_weights = lora_params[:num_lora_weights] + lora_multipliers = lora_params[num_lora_weights:num_lora_weights*2] + + cmd = [ + sys.executable, + "hv_generate_video.py", + "--dit", model_path, + "--vae", vae, + "--text_encoder1", te1, + "--text_encoder2", te2, + "--prompt", prompt, + "--video_size", str(local_height), str(local_width), + "--video_length", str(video_length), + "--fps", str(fps), + "--infer_steps", str(infer_steps), + "--save_path", save_path, + "--seed", str(current_seed), + "--flow_shift", str(flow_shift), + "--embedded_cfg_scale", str(embedded_cfg_scale), + "--output_type", output_type, + "--attn_mode", attn_mode, + "--blocks_to_swap", str(block_swap), + "--fp8_llm", + "--vae_chunk_size", "32", + "--vae_spatial_tile_sample_min_size", "128", + "--dit_in_channels", "32", # This is crucial for SkyReels i2v + "--image_path", resized_image # Pass the image directly + ] + + if use_fp8: + cmd.append("--fp8") + + if split_uncond: + cmd.append("--split_uncond") + + if use_split_attn: + cmd.append("--split_attn") + + if exclude_single_blocks: + cmd.append("--exclude_single_blocks") + + if negative_prompt: + cmd.extend(["--negative_prompt", negative_prompt]) + + if guidance_scale is not None: + cmd.extend(["--guidance_scale", str(guidance_scale)]) + + # Add LoRA weights and multipliers if provided + valid_loras = [] + for weight, mult in zip(lora_weights, lora_multipliers): + if weight and weight != "None": + valid_loras.append((os.path.join(lora_folder, weight), mult)) + + if valid_loras: + weights = [weight for weight, _ in valid_loras] + multipliers = [str(mult) for _, mult in valid_loras] + cmd.extend(["--lora_weight"] + weights) + cmd.extend(["--lora_multiplier"] + multipliers) + + print(f"Running command: {' '.join(cmd)}") + + # Run the process + p = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=env, + text=True, + 
encoding='utf-8', + errors='replace', + bufsize=1 + ) + + while True: + if stop_event.is_set(): + p.terminate() + p.wait() + yield all_videos, "Generation stopped by user.", "" + return + + line = p.stdout.readline() + if not line: + if p.poll() is not None: + break + continue + + print(line, end='') + if '|' in line and '%' in line and '[' in line and ']' in line: + yield all_videos.copy(), f"Processing video {i+1} (seed: {current_seed})", line.strip() + + p.stdout.close() + p.wait() + + # Collect generated video + save_path_abs = os.path.abspath(save_path) + if os.path.exists(save_path_abs): + all_videos_files = sorted( + [f for f in os.listdir(save_path_abs) if f.endswith('.mp4')], + key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)), + reverse=True + ) + matching_videos = [v for v in all_videos_files if f"_{current_seed}" in v] + if matching_videos: + video_path = os.path.join(save_path_abs, matching_videos[0]) + all_videos.append((str(video_path), f"Seed: {current_seed}")) + else: + # For non-SkyReels models, use the regular process_single_video function + num_lora_weights = 4 + lora_weights = lora_params[:num_lora_weights] + lora_multipliers = lora_params[num_lora_weights:num_lora_weights*2] + + single_video_args = [ + prompt, local_width, local_height, 1, video_length, fps, infer_steps, + current_seed, dit_folder, model, vae, te1, te2, save_path, flow_shift, embedded_cfg_scale, + output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, + lora_folder + ] + single_video_args.extend(lora_weights) + single_video_args.extend(lora_multipliers) + single_video_args.extend([None, resized_image, None, negative_prompt, embedded_cfg_scale, split_uncond, guidance_scale, use_fp8]) + + for videos, status, progress in process_single_video(*single_video_args): + if videos: + all_videos.extend(videos) + yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress + + # Clean up temporary file + try: + if os.path.exists(resized_image): + os.remove(resized_image) + except: + pass + + # Clear CUDA cache between generations + clear_cuda_cache() + time.sleep(0.5) + + yield all_videos, "Batch complete", "" + else: + # Regular image input - this is the part we need to fix + # When a SkyReels I2V model is used, we need to use the direct command approach + # with dit_in_channels=32 explicitly specified, just like in the folder processing branch + if is_skyreels_i2v: + stop_event.clear() + + all_videos = [] + progress_text = "Starting generation..." 
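+            # SkyReels I2V with a single input image: the image path arrives as the
+            # first extra value after the 8 LoRA entries in lora_params, and the
+            # model is invoked directly so --dit_in_channels 32 can be specified.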
+ yield [], "Preparing...", progress_text + + # Extract lora parameters + num_lora_weights = 4 + lora_weights = lora_params[:num_lora_weights] + lora_multipliers = lora_params[num_lora_weights:num_lora_weights*2] + extra_args = list(lora_params[num_lora_weights*2:]) if len(lora_params) > num_lora_weights*2 else [] + + # Print extra_args for debugging + print(f"Extra args: {extra_args}") + + # Get input image path from extra args - this is where we need to fix + # In skyreels_generate_btn.click, we're passing skyreels_input which + # should be the image path + image_path = None + if len(extra_args) > 0 and extra_args[0] is not None: + image_path = extra_args[0] + print(f"Image path found in extra_args[0]: {image_path}") + + # If we still don't have an image path, this is a problem + if not image_path: + # Let's try to debug what's happening - in the future, you can remove these + # debug prints once everything works correctly + print("No image path found in extra_args[0]") + print(f"Full lora_params: {lora_params}") + yield [], "Error: No input image provided", "An input image is required for SkyReels I2V models" + return + + for i in range(batch_size): + if stop_event.is_set(): + yield all_videos, "Generation stopped by user", "" + return + + # Calculate seed for this batch item + current_seed = seed + if seed == -1: + current_seed = random.randint(0, 2**32 - 1) + elif batch_size > 1: + current_seed = seed + i + + batch_text = f"Generating video {i + 1} of {batch_size}" + yield all_videos.copy(), batch_text, progress_text + + # Set up environment + env = os.environ.copy() + env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") + env["PYTHONIOENCODING"] = "utf-8" + + model_path = os.path.join(dit_folder, model) if not os.path.isabs(model) else model + + # Build the command with dit_in_channels=32 + cmd = [ + sys.executable, + "hv_generate_video.py", + "--dit", model_path, + "--vae", vae, + "--text_encoder1", te1, + "--text_encoder2", te2, + "--prompt", prompt, + "--video_size", str(height), str(width), + "--video_length", str(video_length), + "--fps", str(fps), + "--infer_steps", str(infer_steps), + "--save_path", save_path, + "--seed", str(current_seed), + "--flow_shift", str(flow_shift), + "--embedded_cfg_scale", str(embedded_cfg_scale), + "--output_type", output_type, + "--attn_mode", attn_mode, + "--blocks_to_swap", str(block_swap), + "--fp8_llm", + "--vae_chunk_size", "32", + "--vae_spatial_tile_sample_min_size", "128", + "--dit_in_channels", "32", # This is crucial for SkyReels i2v + "--image_path", image_path + ] + + if use_fp8: + cmd.append("--fp8") + + if split_uncond: + cmd.append("--split_uncond") + + if use_split_attn: + cmd.append("--split_attn") + + if exclude_single_blocks: + cmd.append("--exclude_single_blocks") + + if negative_prompt: + cmd.extend(["--negative_prompt", negative_prompt]) + + if guidance_scale is not None: + cmd.extend(["--guidance_scale", str(guidance_scale)]) + + # Add LoRA weights and multipliers if provided + valid_loras = [] + for weight, mult in zip(lora_weights, lora_multipliers): + if weight and weight != "None": + valid_loras.append((os.path.join(lora_folder, weight), mult)) + + if valid_loras: + weights = [weight for weight, _ in valid_loras] + multipliers = [str(mult) for _, mult in valid_loras] + cmd.extend(["--lora_weight"] + weights) + cmd.extend(["--lora_multiplier"] + multipliers) + + print(f"Running command: {' '.join(cmd)}") + + # Run the process + p = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + 
stderr=subprocess.STDOUT, + env=env, + text=True, + encoding='utf-8', + errors='replace', + bufsize=1 + ) + + while True: + if stop_event.is_set(): + p.terminate() + p.wait() + yield all_videos, "Generation stopped by user.", "" + return + + line = p.stdout.readline() + if not line: + if p.poll() is not None: + break + continue + + print(line, end='') + if '|' in line and '%' in line and '[' in line and ']' in line: + yield all_videos.copy(), f"Processing (seed: {current_seed})", line.strip() + + p.stdout.close() + p.wait() + + # Collect generated video + save_path_abs = os.path.abspath(save_path) + if os.path.exists(save_path_abs): + all_videos_files = sorted( + [f for f in os.listdir(save_path_abs) if f.endswith('.mp4')], + key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)), + reverse=True + ) + matching_videos = [v for v in all_videos_files if f"_{current_seed}" in v] + if matching_videos: + video_path = os.path.join(save_path_abs, matching_videos[0]) + all_videos.append((str(video_path), f"Seed: {current_seed}")) + + # Clear CUDA cache between generations + clear_cuda_cache() + time.sleep(0.5) + + yield all_videos, "Batch complete", "" + else: + # For regular non-SkyReels models, use the original process_batch function + regular_args = [ + prompt, width, height, batch_size, video_length, fps, infer_steps, + seed, dit_folder, model, vae, te1, te2, save_path, flow_shift, guidance_scale, + output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, + lora_folder + ] + yield from process_batch(*(regular_args + list(lora_params))) + +def get_dit_models(dit_folder: str) -> List[str]: + """Get list of available DiT models in the specified folder""" + if not os.path.exists(dit_folder): + return ["mp_rank_00_model_states.pt"] + models = [f for f in os.listdir(dit_folder) if f.endswith('.pt') or f.endswith('.safetensors')] + models.sort(key=str.lower) + return models if models else ["mp_rank_00_model_states.pt"] + +def update_dit_and_lora_dropdowns(dit_folder: str, lora_folder: str, *current_values) -> List[gr.update]: + """Update both DiT and LoRA dropdowns""" + # Get model lists + dit_models = get_dit_models(dit_folder) + lora_choices = get_lora_options(lora_folder) + + # Current values processing + dit_value = current_values[0] + if dit_value not in dit_models: + dit_value = dit_models[0] if dit_models else None + + weights = current_values[1:5] + multipliers = current_values[5:9] + + results = [gr.update(choices=dit_models, value=dit_value)] + + # Add LoRA updates + for i in range(4): + weight = weights[i] if i < len(weights) else "None" + multiplier = multipliers[i] if i < len(multipliers) else 1.0 + if weight not in lora_choices: + weight = "None" + results.extend([ + gr.update(choices=lora_choices, value=weight), + gr.update(value=multiplier) + ]) + + return results + +def extract_video_metadata(video_path: str) -> Dict: + """Extract metadata from video file using ffprobe.""" + cmd = [ + 'ffprobe', + '-v', 'quiet', + '-print_format', 'json', + '-show_format', + video_path + ] + + try: + result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True) + metadata = json.loads(result.stdout.decode('utf-8')) + if 'format' in metadata and 'tags' in metadata['format']: + comment = metadata['format']['tags'].get('comment', '{}') + return json.loads(comment) + return {} + except Exception as e: + print(f"Metadata extraction failed: {str(e)}") + return {} + +def create_parameter_transfer_map(metadata: Dict, target_tab: str) -> Dict: + """Map 
metadata parameters to Gradio components for different tabs"""
+    mapping = {
+        'common': {
+            'prompt': ('prompt', 'v2v_prompt', 'wanx_v2v_prompt'),  # Add WanX-v2v mapping
+            'width': ('width', 'v2v_width', 'wanx_v2v_width'),
+            'height': ('height', 'v2v_height', 'wanx_v2v_height'),
+            'batch_size': ('batch_size', 'v2v_batch_size', 'wanx_v2v_batch_size'),
+            'video_length': ('video_length', 'v2v_video_length', 'wanx_v2v_video_length'),
+            'fps': ('fps', 'v2v_fps', 'wanx_v2v_fps'),
+            'infer_steps': ('infer_steps', 'v2v_infer_steps', 'wanx_v2v_infer_steps'),
+            'seed': ('seed', 'v2v_seed', 'wanx_v2v_seed'),
+            'flow_shift': ('flow_shift', 'v2v_flow_shift', 'wanx_v2v_flow_shift'),
+            'guidance_scale': ('cfg_scale', 'v2v_cfg_scale', 'wanx_v2v_guidance_scale'),
+            'negative_prompt': ('negative_prompt', 'v2v_negative_prompt', 'wanx_v2v_negative_prompt'),
+            'strength': ('strength', 'v2v_strength', 'wanx_v2v_strength')
+        },
+        'lora': {
+            'lora_weights': [(f'lora{i+1}', f'v2v_lora_weights[{i}]', f'wanx_v2v_lora_weights[{i}]') for i in range(4)],
+            'lora_multipliers': [(f'lora{i+1}_multiplier', f'v2v_lora_multipliers[{i}]', f'wanx_v2v_lora_multipliers[{i}]') for i in range(4)]
+        }
+    }
+
+    results = {}
+    for param, value in metadata.items():
+        # Handle common parameters
+        if param in mapping['common']:
+            target_idx = 0 if target_tab == 't2v' else 1 if target_tab == 'v2v' else 2
+            if target_idx < len(mapping['common'][param]):
+                target = mapping['common'][param][target_idx]
+                results[target] = value
+
+        # Handle LoRA parameters
+        if param == 'lora_weights':
+            for i, weight in enumerate(value[:4]):
+                target_idx = 0 if target_tab == 't2v' else 1 if target_tab == 'v2v' else 2
+                if target_idx < len(mapping['lora']['lora_weights'][i]):
+                    target = mapping['lora']['lora_weights'][i][target_idx]
+                    results[target] = weight
+
+        if param == 'lora_multipliers':
+            for i, mult in enumerate(value[:4]):
+                target_idx = 0 if target_tab == 't2v' else 1 if target_tab == 'v2v' else 2
+                if target_idx < len(mapping['lora']['lora_multipliers'][i]):
+                    target = mapping['lora']['lora_multipliers'][i][target_idx]
+                    results[target] = float(mult)
+
+    return results
+
+def add_metadata_to_video(video_path: str, parameters: dict) -> None:
+    """Add generation parameters to video metadata using ffmpeg."""
+    import json
+    import subprocess
+
+    # Add Fun-Control information to the parameters *before* serializing them,
+    # so that it actually ends up in the embedded metadata
+    task = parameters.get("task", "")
+    if task.endswith("-FC"):
+        parameters["fun_control"] = True
+        # Store the control path in metadata if available
+        if "control_path" in parameters:
+            parameters["control_video"] = os.path.basename(parameters["control_path"])
+
+    # Convert parameters to JSON string
+    params_json = json.dumps(parameters, indent=2)
+
+    # Temporary output path
+    temp_path = video_path.replace(".mp4", "_temp.mp4")
+
+    # FFmpeg command to add metadata without re-encoding
+    cmd = [
+        'ffmpeg',
+        '-i', video_path,
+        '-metadata', f'comment={params_json}',
+        '-codec', 'copy',
+        temp_path
+    ]
+
+    try:
+        # Execute FFmpeg command
+        subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        # Replace original file with the metadata-enhanced version
+        os.replace(temp_path, video_path)
+    except subprocess.CalledProcessError as e:
+        print(f"Failed to add metadata: {e.stderr.decode()}")
+        if os.path.exists(temp_path):
+            os.remove(temp_path)
+    except Exception as e:
+        print(f"Error: {str(e)}")
+
+def count_prompt_tokens(prompt: str) -> int:
+    enc = tiktoken.get_encoding("cl100k_base")
+    tokens = enc.encode(prompt)
+    return 
len(tokens) + + +def get_lora_options(lora_folder: str = "lora") -> List[str]: + if not os.path.exists(lora_folder): + return ["None"] + lora_files = [f for f in os.listdir(lora_folder) if f.endswith('.safetensors') or f.endswith('.pt')] + lora_files.sort(key=str.lower) + return ["None"] + lora_files + +def update_lora_dropdowns(lora_folder: str, *current_values) -> List[gr.update]: + new_choices = get_lora_options(lora_folder) + weights = current_values[:4] + multipliers = current_values[4:8] + + results = [] + for i in range(4): + weight = weights[i] if i < len(weights) else "None" + multiplier = multipliers[i] if i < len(multipliers) else 1.0 + if weight not in new_choices: + weight = "None" + results.extend([ + gr.update(choices=new_choices, value=weight), + gr.update(value=multiplier) + ]) + + return results + +def send_to_v2v(evt: gr.SelectData, gallery: list, prompt: str, selected_index: gr.State) -> Tuple[Optional[str], str, int]: + """Transfer selected video and prompt to Video2Video tab""" + if not gallery or evt.index >= len(gallery): + return None, "", selected_index.value + + selected_item = gallery[evt.index] + + # Handle different gallery item formats + if isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + elif isinstance(selected_item, (tuple, list)): + video_path = selected_item[0] + else: + video_path = selected_item + + # Final cleanup for Gradio Video component + if isinstance(video_path, tuple): + video_path = video_path[0] + + # Update the selected index + selected_index.value = evt.index + + return str(video_path), prompt, evt.index + +def send_selected_to_v2v(gallery: list, prompt: str, selected_index: gr.State) -> Tuple[Optional[str], str]: + """Send the currently selected video to V2V tab""" + if not gallery or selected_index.value is None or selected_index.value >= len(gallery): + return None, "" + + selected_item = gallery[selected_index.value] + + # Handle different gallery item formats + if isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + elif isinstance(selected_item, (tuple, list)): + video_path = selected_item[0] + else: + video_path = selected_item + + # Final cleanup for Gradio Video component + if isinstance(video_path, tuple): + video_path = video_path[0] + + return str(video_path), prompt + +def clear_cuda_cache(): + """Clear CUDA cache if available""" + import torch + if torch.cuda.is_available(): + torch.cuda.empty_cache() + # Optional: synchronize to ensure cache is cleared + torch.cuda.synchronize() + +def wanx_batch_handler( + use_random, + prompt, + negative_prompt, + width, + height, + video_length, + fps, + infer_steps, + flow_shift, + guidance_scale, + seed, + batch_size, + input_folder_path, + wanx_input_end, + task, + dit_folder, + dit_path, + vae_path, + t5_path, + clip_path, + save_path, + output_type, + sample_solver, + exclude_single_blocks, + attn_mode, + block_swap, + fp8, + fp8_scaled, + fp8_t5, + lora_folder, + slg_layers: str, + slg_start: Optional[str], + slg_end: Optional[str], + enable_cfg_skip: bool, + cfg_skip_mode: str, + cfg_apply_ratio: float, + enable_preview: bool, + preview_steps: int, + *lora_params, # <-- DO NOT ADD NAMED ARGS AFTER THIS! 
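+    # *lora_params is expected to carry, in order: 4 LoRA weight names,
+    # 4 LoRA multipliers, then input_file, control_video, control_strength,
+    # control_start, control_end (see the unpacking at the top of the body).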
+): + """Handle both folder-based batch processing and regular processing for all WanX tabs""" + global stop_event + + # Convert None strings to actual None + slg_layers = None if slg_layers == "None" else slg_layers + slg_start = None if slg_start == "None" else slg_start + slg_end = None if slg_end == "None" else slg_end + + # Construct full dit_path including folder + full_dit_path = os.path.join(dit_folder, dit_path) if not os.path.isabs(dit_path) else dit_path + # Clean up LoRA params to proper format + clean_lora_params = [] + for param in lora_params: + # Convert None strings to "None" for consistency + if param is None or str(param).lower() == "none": + clean_lora_params.append("None") + else: + clean_lora_params.append(str(param)) + + # Extract LoRA weights and multipliers + num_lora_weights = 4 + lora_weights = clean_lora_params[:num_lora_weights] + lora_multipliers = [] + for mult in clean_lora_params[num_lora_weights:num_lora_weights*2]: + try: + lora_multipliers.append(float(mult)) + except (ValueError, TypeError): + lora_multipliers.append(1.0) + while len(lora_weights) < 4: + lora_weights.append("None") + while len(lora_multipliers) < 4: + lora_multipliers.append(1.0) + + # Now extract trailing params: input_file, control_video, control_strength, control_start, control_end + remaining_params = clean_lora_params[num_lora_weights*2:] + input_file = remaining_params[0] if len(remaining_params) > 0 else None + control_video = remaining_params[1] if len(remaining_params) > 1 else None + try: + control_strength = float(remaining_params[2]) if len(remaining_params) > 2 else 1.0 + except Exception: + control_strength = 1.0 + try: + control_start = float(remaining_params[3]) if len(remaining_params) > 3 else 0.0 + except Exception: + control_start = 0.0 + try: + control_end = float(remaining_params[4]) if len(remaining_params) > 4 else 1.0 + except Exception: + control_end = 1.0 + + yield [], [], "Preparing batch...", "" # Clear main and preview galleries + + if use_random: + stop_event.clear() + all_videos = [] + all_previews = [] # Keep track of previews from the last successful item? Or clear each time? Let's clear. + progress_text = "Starting generation..." + yield [], [], "Preparing...", progress_text # Clear galleries again just in case + batch_size = int(batch_size) + for i in range(batch_size): + if stop_event.is_set(): + yield all_videos, [], "Generation stopped by user", "" # Yield empty previews on stop + return + + # --- Clear previews for this item --- + current_previews_for_item = [] + yield all_videos.copy(), current_previews_for_item, f"Generating video {i + 1} of {batch_size}", progress_text # Yield cleared previews + + # ... (Keep existing random image logic: get random, resize) ... 
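+            # Pick a random source image and resize it so both sides fit within the
+            # requested width/height while staying multiples of 16; the resized copy
+            # is a temp file that is cleaned up after this batch item finishes.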
+ random_image, status = get_random_image_from_folder(input_folder_path) + if random_image is None: + yield all_videos, current_previews_for_item, f"Error in batch {i+1}: {status}", "" + continue # Skip to next batch item on error + + resized_image, size_info = resize_image_keeping_aspect_ratio(random_image, width, height) + if resized_image is None: + yield all_videos, current_previews_for_item, f"Error resizing image in batch {i+1}: {size_info}", "" + # Clean up the random image if resize failed but image exists + try: + if os.path.exists(random_image) and "temp_resized" not in random_image: # Avoid double delete if resize output existed + pass # Might not want to delete original random image here + except: pass + continue # Skip to next batch item on error + + local_width, local_height = width, height + if isinstance(size_info, tuple): local_width, local_height = size_info + progress_text = f"Using image: {os.path.basename(random_image)} - Resized to {local_width}x{local_height}" + yield all_videos.copy(), current_previews_for_item, f"Generating video {i + 1} of {batch_size}", progress_text + + current_seed = seed + if seed == -1: current_seed = random.randint(0, 2**32 - 1) + elif batch_size > 1: current_seed = seed + i + + # --- Corrected call to wanx_generate_video with accumulation --- + newly_generated_video = None # Track the video generated *in this iteration* + last_status_for_item = f"Generating video {i+1}/{batch_size}" # Keep track of last status + last_progress_for_item = progress_text # Keep track of last progress line + + # Inner loop iterates through the generator for ONE batch item + for videos_update, previews_update, status, progress in wanx_generate_video( + prompt, negative_prompt, resized_image, local_width, local_height, + video_length, fps, infer_steps, flow_shift, guidance_scale, current_seed, + wanx_input_end, # Pass the argument + task, dit_folder, full_dit_path, vae_path, t5_path, clip_path, save_path, + output_type, sample_solver, exclude_single_blocks, attn_mode, block_swap, + fp8, fp8_scaled, fp8_t5, lora_folder, + slg_layers, slg_start, slg_end, + lora_weights[0], lora_weights[1], lora_weights[2], lora_weights[3], + lora_multipliers[0], lora_multipliers[1], lora_multipliers[2], lora_multipliers[3], + enable_cfg_skip, cfg_skip_mode, cfg_apply_ratio, + None, 1.0, 0.0, 1.0, # Placeholders for control video args in random mode + enable_preview=enable_preview, + preview_steps=preview_steps + ): + # Store the latest video info from this *specific* generator run + if videos_update: + # wanx_generate_video yields the *full* list it knows about, + # so we take the last item assuming it's the new one. 
+ newly_generated_video = videos_update[-1] + + current_previews_for_item = previews_update # Update previews for *this* item + last_status_for_item = f"Batch {i+1}/{batch_size}: {status}" # Store last status + last_progress_for_item = progress # Store last progress line + # Yield the *current cumulative* list during progress updates + yield all_videos.copy(), current_previews_for_item, last_status_for_item, last_progress_for_item + + # --- After the inner loop finishes for item 'i' --- + # Now, add the video generated in this iteration to the main list + if newly_generated_video and newly_generated_video not in all_videos: + all_videos.append(newly_generated_video) + print(f"DEBUG: Appended video {newly_generated_video[1] if isinstance(newly_generated_video, tuple) else 'unknown'} to all_videos (Total: {len(all_videos)})") + # Yield the updated cumulative list *immediately* after appending + yield all_videos.copy(), current_previews_for_item, last_status_for_item, last_progress_for_item + elif not newly_generated_video: + print(f"DEBUG: No new video generated or yielded by wanx_generate_video for batch item {i+1}.") + + + # --- Cleanup for item 'i' (Correctly indented) --- + try: + # Only remove the temporary resized image + if os.path.exists(resized_image) and "temp_resized" in resized_image: + os.remove(resized_image) + print(f"DEBUG: Removed temporary resized image: {resized_image}") + except Exception as e: + print(f"Warning: Could not remove temp image {resized_image}: {e}") + clear_cuda_cache() + time.sleep(0.5) + # --- End Cleanup for item 'i' --- + + # --- After the outer loop (all batch items processed) --- + yield all_videos, [], "Batch complete", "" # Yield empty previews at the end + else: + # ... (Keep existing checks for non-random mode: input file, control video) ... + batch_size = int(batch_size) + if not input_file and "i2v" in task: + yield [], [], "Error: No input image provided", "An input image is required for I2V models" + return + if "-FC" in task and not control_video: + yield [], [], "Error: No control video provided", "A control video is required for Fun-Control models" + return + + if batch_size > 1: + stop_event.clear() + all_videos = [] + all_previews = [] # Clear previews at start of batch + progress_text = "Starting generation..." 
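+            # Each iteration below runs wanx_generate_video() once; that generator
+            # yields (videos, previews, status, progress) tuples, and the last video
+            # it reports is appended to all_videos so the gallery accumulates
+            # results across the whole batch.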
+ yield [], [], "Preparing...", progress_text # Clear galleries + + for i in range(batch_size): + if stop_event.is_set(): + yield all_videos, [], "Generation stopped by user", "" # Yield empty previews + return + + # --- Clear previews for this item --- + current_previews_for_item = [] + yield all_videos.copy(), current_previews_for_item, f"Generating video {i+1}/{batch_size}", progress_text + + current_seed = seed + if seed == -1: current_seed = random.randint(0, 2**32 - 1) + elif batch_size > 1: current_seed = seed + i + batch_text = f"Generating video {i+1}/{batch_size} (seed: {current_seed})" + yield all_videos.copy(), current_previews_for_item, batch_text, progress_text # Update status + + # --- Corrected call to wanx_generate_video with accumulation --- + newly_generated_video = None # Track the video generated *in this iteration* + last_status_for_item = f"Generating video {i+1}/{batch_size}" # Keep track of last status + last_progress_for_item = progress_text # Keep track of last progress line + + # Inner loop iterates through the generator for ONE batch item + for videos_update, previews_update, status, progress in wanx_generate_video( + prompt, negative_prompt, input_file, width, height, + video_length, fps, infer_steps, flow_shift, guidance_scale, current_seed, + wanx_input_end, # Pass the argument + task, dit_folder, full_dit_path, vae_path, t5_path, clip_path, save_path, + output_type, sample_solver, exclude_single_blocks, attn_mode, block_swap, + fp8, fp8_scaled, fp8_t5, lora_folder, + slg_layers, slg_start, slg_end, + lora_weights[0], lora_weights[1], lora_weights[2], lora_weights[3], + lora_multipliers[0], lora_multipliers[1], lora_multipliers[2], lora_multipliers[3], + enable_cfg_skip, cfg_skip_mode, cfg_apply_ratio, + control_video, control_strength, control_start, control_end, + # --- Pass preview args --- + enable_preview=enable_preview, + preview_steps=preview_steps + ): + # Store the latest video info from this *specific* generator run + if videos_update: + # wanx_generate_video yields the *full* list it knows about, + # so we take the last item assuming it's the new one. 
+ newly_generated_video = videos_update[-1] + + current_previews_for_item = previews_update # Update previews for *this* item + last_status_for_item = f"Batch {i+1}/{batch_size}: {status}" # Store last status + last_progress_for_item = progress # Store last progress line + # Yield the *current cumulative* list during progress updates + yield all_videos.copy(), current_previews_for_item, last_status_for_item, last_progress_for_item + + # --- After the inner loop finishes for item 'i' --- + # Now, add the video generated in this iteration to the main list + if newly_generated_video and newly_generated_video not in all_videos: + all_videos.append(newly_generated_video) + print(f"DEBUG: Appended video {newly_generated_video[1] if isinstance(newly_generated_video, tuple) else 'unknown'} to all_videos (Total: {len(all_videos)})") + # Yield the updated cumulative list *immediately* after appending + yield all_videos.copy(), current_previews_for_item, last_status_for_item, last_progress_for_item + elif not newly_generated_video: + print(f"DEBUG: No new video generated or yielded by wanx_generate_video for batch item {i+1}.") + # --- End modified call --- + + clear_cuda_cache() + time.sleep(0.5) + yield all_videos, [], "Batch complete", "" # Yield empty previews at the end + else: # Single generation (batch_size = 1) + stop_event.clear() + # --- Modified call to wanx_generate_video (yield from) --- + # Add preview args directly + yield from wanx_generate_video( + prompt, negative_prompt, input_file, width, height, + video_length, fps, infer_steps, flow_shift, guidance_scale, seed, + wanx_input_end, # Pass the argument + task, dit_folder, full_dit_path, vae_path, t5_path, clip_path, save_path, + output_type, sample_solver, exclude_single_blocks, attn_mode, block_swap, + fp8, fp8_scaled, fp8_t5, lora_folder, + slg_layers, slg_start, slg_end, + lora_weights[0], lora_weights[1], lora_weights[2], lora_weights[3], + lora_multipliers[0], lora_multipliers[1], lora_multipliers[2], lora_multipliers[3], + enable_cfg_skip, cfg_skip_mode, cfg_apply_ratio, + control_video, control_strength, control_start, control_end, + # --- Pass preview args --- + enable_preview=enable_preview, + preview_steps=preview_steps + ) + +def process_single_video( + prompt: str, + width: int, + height: int, + batch_size: int, + video_length: int, + fps: int, + infer_steps: int, + seed: int, + dit_folder: str, + model: str, + vae: str, + te1: str, + te2: str, + save_path: str, + flow_shift: float, + cfg_scale: float, + output_type: str, + attn_mode: str, + block_swap: int, + exclude_single_blocks: bool, + use_split_attn: bool, + lora_folder: str, + lora1: str = "", + lora2: str = "", + lora3: str = "", + lora4: str = "", + lora1_multiplier: float = 1.0, + lora2_multiplier: float = 1.0, + lora3_multiplier: float = 1.0, + lora4_multiplier: float = 1.0, + video_path: Optional[str] = None, + image_path: Optional[str] = None, + strength: Optional[float] = None, + negative_prompt: Optional[str] = None, + embedded_cfg_scale: Optional[float] = None, + split_uncond: Optional[bool] = None, + guidance_scale: Optional[float] = None, + use_fp8: bool = True +) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: + """Generate a single video with the given parameters""" + global stop_event + + if stop_event.is_set(): + yield [], "", "" + return + + # Determine if this is a SkyReels model and what type + is_skyreels = "skyreels" in model.lower() + is_skyreels_i2v = is_skyreels and "i2v" in model.lower() + is_skyreels_t2v = is_skyreels and 
"t2v" in model.lower() + + if is_skyreels: + # Force certain parameters for SkyReels + if negative_prompt is None: + negative_prompt = "" + if embedded_cfg_scale is None: + embedded_cfg_scale = 1.0 # Force to 1.0 for SkyReels + if split_uncond is None: + split_uncond = True + if guidance_scale is None: + guidance_scale = cfg_scale # Use cfg_scale as guidance_scale if not provided + + # Determine the input channels based on model type + if is_skyreels_i2v: + dit_in_channels = 32 # SkyReels I2V uses 32 channels + else: + dit_in_channels = 16 # SkyReels T2V uses 16 channels (same as regular models) + else: + dit_in_channels = 16 # Regular Hunyuan models use 16 channels + embedded_cfg_scale = cfg_scale + + if os.path.isabs(model): + model_path = model + else: + model_path = os.path.normpath(os.path.join(dit_folder, model)) + + env = os.environ.copy() + env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") + env["PYTHONIOENCODING"] = "utf-8" + env["BATCH_RUN_ID"] = f"{time.time()}" + + if seed == -1: + current_seed = random.randint(0, 2**32 - 1) + else: + batch_id = int(env.get("BATCH_RUN_ID", "0").split('.')[-1]) + if batch_size > 1: # Only modify seed for batch generation + current_seed = (seed + batch_id * 100003) % (2**32) + else: + current_seed = seed + + clear_cuda_cache() + + command = [ + sys.executable, + "hv_generate_video.py", + "--dit", model_path, + "--vae", vae, + "--text_encoder1", te1, + "--text_encoder2", te2, + "--prompt", prompt, + "--video_size", str(height), str(width), + "--video_length", str(video_length), + "--fps", str(fps), + "--infer_steps", str(infer_steps), + "--save_path", save_path, + "--seed", str(current_seed), + "--flow_shift", str(flow_shift), + "--embedded_cfg_scale", str(cfg_scale), + "--output_type", output_type, + "--attn_mode", attn_mode, + "--blocks_to_swap", str(block_swap), + "--fp8_llm", + "--vae_chunk_size", "32", + "--vae_spatial_tile_sample_min_size", "128" + ] + + if use_fp8: + command.append("--fp8") + + # Add negative prompt and embedded cfg scale for SkyReels + if is_skyreels: + command.extend(["--dit_in_channels", str(dit_in_channels)]) + command.extend(["--guidance_scale", str(guidance_scale)]) + + if negative_prompt: + command.extend(["--negative_prompt", negative_prompt]) + if split_uncond: + command.append("--split_uncond") + + # Add LoRA weights and multipliers if provided + valid_loras = [] + for weight, mult in zip([lora1, lora2, lora3, lora4], + [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier]): + if weight and weight != "None": + valid_loras.append((os.path.join(lora_folder, weight), mult)) + if valid_loras: + weights = [weight for weight, _ in valid_loras] + multipliers = [str(mult) for _, mult in valid_loras] + command.extend(["--lora_weight"] + weights) + command.extend(["--lora_multiplier"] + multipliers) + + if exclude_single_blocks: + command.append("--exclude_single_blocks") + if use_split_attn: + command.append("--split_attn") + + # Handle input paths + if video_path: + command.extend(["--video_path", video_path]) + if strength is not None: + command.extend(["--strength", str(strength)]) + elif image_path: + command.extend(["--image_path", image_path]) + # Only add strength parameter for non-SkyReels I2V models + # SkyReels I2V doesn't use strength parameter for image-to-video generation + if strength is not None and not is_skyreels_i2v: + command.extend(["--strength", str(strength)]) + + print(f"{command}") + + p = subprocess.Popen( + command, + stdout=subprocess.PIPE, + 
stderr=subprocess.STDOUT,
+        env=env,
+        text=True,
+        encoding='utf-8',
+        errors='replace',
+        bufsize=1
+    )
+
+    videos = []
+
+    while True:
+        if stop_event.is_set():
+            p.terminate()
+            p.wait()
+            yield [], "", "Generation stopped by user."
+            return
+
+        line = p.stdout.readline()
+        if not line:
+            if p.poll() is not None:
+                break
+            continue
+
+        print(line, end='')
+        if '|' in line and '%' in line and '[' in line and ']' in line:
+            yield videos.copy(), f"Processing (seed: {current_seed})", line.strip()
+
+    p.stdout.close()
+    p.wait()
+
+    clear_cuda_cache()
+    time.sleep(0.5)
+
+    # Collect the generated video (use a separate name so the v2v input
+    # `video_path` parameter is not shadowed and is recorded correctly below)
+    save_path_abs = os.path.abspath(save_path)
+    if os.path.exists(save_path_abs):
+        all_videos = sorted(
+            [f for f in os.listdir(save_path_abs) if f.endswith('.mp4')],
+            key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)),
+            reverse=True
+        )
+        matching_videos = [v for v in all_videos if f"_{current_seed}" in v]
+        if matching_videos:
+            output_video_path = os.path.join(save_path_abs, matching_videos[0])
+
+            # Collect parameters for metadata
+            parameters = {
+                "prompt": prompt,
+                "width": width,
+                "height": height,
+                "video_length": video_length,
+                "fps": fps,
+                "infer_steps": infer_steps,
+                "seed": current_seed,
+                "model": model,
+                "vae": vae,
+                "te1": te1,
+                "te2": te2,
+                "save_path": save_path,
+                "flow_shift": flow_shift,
+                "cfg_scale": cfg_scale,
+                "output_type": output_type,
+                "attn_mode": attn_mode,
+                "block_swap": block_swap,
+                "lora_weights": [lora1, lora2, lora3, lora4],
+                "lora_multipliers": [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier],
+                "input_video": video_path if video_path else None,
+                "input_image": image_path if image_path else None,
+                "strength": strength,
+                "negative_prompt": negative_prompt if is_skyreels else None,
+                "embedded_cfg_scale": embedded_cfg_scale if is_skyreels else None
+            }
+
+            add_metadata_to_video(output_video_path, parameters)
+            videos.append((str(output_video_path), f"Seed: {current_seed}"))
+
+    yield videos, f"Completed (seed: {current_seed})", ""
+
+def process_batch(
+    prompt: str,
+    width: int,
+    height: int,
+    batch_size: int,
+    video_length: int,
+    fps: int,
+    infer_steps: int,
+    seed: int,
+    dit_folder: str,
+    model: str,
+    vae: str,
+    te1: str,
+    te2: str,
+    save_path: str,
+    flow_shift: float,
+    cfg_scale: float,
+    output_type: str,
+    attn_mode: str,
+    block_swap: int,
+    exclude_single_blocks: bool,
+    use_split_attn: bool,
+    lora_folder: str,
+    *args
+) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]:
+    """Process a batch of videos using Gradio's queue"""
+    global stop_event
+    stop_event.clear()
+
+    all_videos = []
+    progress_text = "Starting generation..."
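+    # Layout of *args as packed by the UI: 4 LoRA weight names, 4 multipliers,
+    # then [input_path, strength, negative_prompt, guidance_scale, split_uncond,
+    # ..., use_fp8 last]; the tail is unpacked defensively below.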
+ yield [], "Preparing...", progress_text + + # Extract additional arguments + num_lora_weights = 4 + lora_weights = args[:num_lora_weights] + lora_multipliers = args[num_lora_weights:num_lora_weights*2] + extra_args = args[num_lora_weights*2:] + + # Determine if this is a SkyReels model and what type + is_skyreels = "skyreels" in model.lower() + is_skyreels_i2v = is_skyreels and "i2v" in model.lower() + is_skyreels_t2v = is_skyreels and "t2v" in model.lower() + + # Handle input paths and additional parameters + input_path = extra_args[0] if extra_args else None + strength = float(extra_args[1]) if len(extra_args) > 1 else None + + # Get use_fp8 flag (it should be the last parameter) + use_fp8 = bool(extra_args[-1]) if extra_args and len(extra_args) >= 3 else True + + # Get SkyReels specific parameters if applicable + if is_skyreels: + # Always set embedded_cfg_scale to 1.0 for SkyReels models + embedded_cfg_scale = 1.0 + + negative_prompt = str(extra_args[2]) if len(extra_args) > 2 and extra_args[2] is not None else "" + # Use cfg_scale for guidance_scale parameter + guidance_scale = float(extra_args[3]) if len(extra_args) > 3 and extra_args[3] is not None else cfg_scale + split_uncond = True if len(extra_args) > 4 and extra_args[4] else False + else: + negative_prompt = str(extra_args[2]) if len(extra_args) > 2 and extra_args[2] is not None else None + guidance_scale = cfg_scale + embedded_cfg_scale = cfg_scale + split_uncond = bool(extra_args[4]) if len(extra_args) > 4 else None + + for i in range(batch_size): + if stop_event.is_set(): + break + + batch_text = f"Generating video {i + 1} of {batch_size}" + yield all_videos.copy(), batch_text, progress_text + + # Handle different input types + video_path = None + image_path = None + + if input_path: + # Check if it's an image file (common image extensions) + is_image = False + lower_path = input_path.lower() + image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp') + is_image = any(lower_path.endswith(ext) for ext in image_extensions) + + # Only use image_path for SkyReels I2V models and actual image files + if is_skyreels_i2v and is_image: + image_path = input_path + else: + video_path = input_path + + # Prepare arguments for process_single_video + single_video_args = [ + prompt, width, height, batch_size, video_length, fps, infer_steps, + seed, dit_folder, model, vae, te1, te2, save_path, flow_shift, cfg_scale, + output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, + lora_folder + ] + single_video_args.extend(lora_weights) + single_video_args.extend(lora_multipliers) + single_video_args.extend([video_path, image_path, strength, negative_prompt, embedded_cfg_scale, split_uncond, guidance_scale, use_fp8]) + + for videos, status, progress in process_single_video(*single_video_args): + if videos: + all_videos.extend(videos) + yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress + + yield all_videos, "Batch complete", "" + +def update_wanx_image_dimensions(image): + """Update dimensions from uploaded image""" + if image is None: + return "", gr.update(value=832), gr.update(value=480) + img = Image.open(image) + w, h = img.size + w = (w // 32) * 32 + h = (h // 32) * 32 + return f"{w}x{h}", w, h + +def calculate_wanx_width(height, original_dims): + """Calculate width based on height maintaining aspect ratio""" + if not original_dims: + return gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + aspect_ratio = orig_w / orig_h + new_width = math.floor((height * aspect_ratio) 
/ 32) * 32 + return gr.update(value=new_width) + +def calculate_wanx_height(width, original_dims): + """Calculate height based on width maintaining aspect ratio""" + if not original_dims: + return gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + aspect_ratio = orig_w / orig_h + new_height = math.floor((width / aspect_ratio) / 32) * 32 + return gr.update(value=new_height) + +def update_wanx_from_scale(scale, original_dims): + """Update dimensions based on scale percentage""" + if not original_dims: + return gr.update(), gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + new_w = math.floor((orig_w * scale / 100) / 32) * 32 + new_h = math.floor((orig_h * scale / 100) / 32) * 32 + return gr.update(value=new_w), gr.update(value=new_h) + +def recommend_wanx_flow_shift(width, height): + """Get recommended flow shift value based on dimensions""" + recommended_shift = 3.0 if (width == 832 and height == 480) or (width == 480 and height == 832) else 5.0 + return gr.update(value=recommended_shift) + +def handle_wanx_gallery_select(evt: gr.SelectData, gallery) -> tuple: + """Track selected index and video path when gallery item is clicked""" + if gallery is None: + return None, None + + if evt.index >= len(gallery): + return None, None + + selected_item = gallery[evt.index] + video_path = None + + # Extract the video path based on the item type + if isinstance(selected_item, tuple): + video_path = selected_item[0] + elif isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + else: + video_path = selected_item + + return evt.index, video_path + +def get_step_from_preview_path(path): + match = re.search(r"step_(\d+)_", os.path.basename(path)) + return int(match.group(1)) if match else -1 + +def wanx_generate_video( + prompt, + negative_prompt, + input_image, + width, + height, + video_length, + fps, + infer_steps, + flow_shift, + guidance_scale, + seed, + wanx_input_end, + task, + dit_folder, + dit_path, + vae_path, + t5_path, + clip_path, + save_path, + output_type, + sample_solver, + exclude_single_blocks, + attn_mode, + block_swap, + fp8, + fp8_scaled, + fp8_t5, + lora_folder, + slg_layers, + slg_start, + slg_end, + lora1="None", + lora2="None", + lora3="None", + lora4="None", + lora1_multiplier=1.0, + lora2_multiplier=1.0, + lora3_multiplier=1.0, + lora4_multiplier=1.0, + enable_cfg_skip=False, + cfg_skip_mode="none", + cfg_apply_ratio=0.7, + control_video=None, + control_strength=1.0, + control_start=0.0, + control_end=1.0, + enable_preview: bool = False, + preview_steps: int = 5 +) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: + """Generate video with WanX model (supports both i2v, t2v and Fun-Control)""" + global stop_event + + current_previews = [] + yield [], current_previews, "Preparing...", "" # Yield empty previews + + # Fix 1: Ensure lora_folder is a string + lora_folder = str(lora_folder) if lora_folder else "lora" + + # Debug prints + print(f"DEBUG - LoRA params: {lora1}, {lora2}, {lora3}, {lora4}") + print(f"DEBUG - LoRA multipliers: {lora1_multiplier}, {lora2_multiplier}, {lora3_multiplier}, {lora4_multiplier}") + print(f"DEBUG - LoRA folder: {lora_folder}") + + # Convert values safely to float or None + try: + slg_start_float = float(slg_start) if slg_start is not None and str(slg_start).lower() != "none" else None + except (ValueError, TypeError): + slg_start_float = None + print(f"Warning: Could not convert slg_start '{slg_start}' to float") + + try: + slg_end_float = 
float(slg_end) if slg_end is not None and str(slg_end).lower() != "none" else None + except (ValueError, TypeError): + slg_end_float = None + print(f"Warning: Could not convert slg_end '{slg_end}' to float") + + print(f"slg_start_float: {slg_start_float}, slg_end_float: {slg_end_float}") + + if stop_event.is_set(): + yield [], [], "", "" # Yield empty previews + return + + run_id = f"{int(time.time())}_{random.randint(1000, 9999)}" + unique_preview_suffix = f"wanx_{run_id}" # Add prefix for clarity + # --- Construct unique preview paths --- + preview_base_path = os.path.join(save_path, f"latent_preview_{unique_preview_suffix}") + preview_mp4_path = preview_base_path + ".mp4" + preview_png_path = preview_base_path + ".png" + + # Check if this is a Fun-Control task + is_fun_control = "-FC" in task and control_video is not None + if is_fun_control: + print(f"DEBUG - Using Fun-Control mode with control video: {control_video}") + # Verify control video is provided + if not control_video: + yield [], "Error: No control video provided", "Fun-Control requires a control video" + return + + # Verify needed files exist + for path_name, path in [ + ("DIT", dit_path), + ("VAE", vae_path), + ("T5", t5_path), + ("CLIP", clip_path) + ]: + if not os.path.exists(path): + yield [], f"Error: {path_name} model not found", f"Model file doesn't exist: {path}" + return + + # Get current seed or use provided seed + current_seed = seed + if seed == -1: + current_seed = random.randint(0, 2**32 - 1) + + # Check if we need input image (required for i2v, not for t2v) + if "i2v" in task and not input_image: + yield [], "Error: No input image provided", "Please provide an input image for image-to-video generation" + return + + # Check for Fun-Control requirements + if is_fun_control and not control_video: + yield [], "Error: No control video provided", "Please provide a control video for Fun-Control generation" + return + + # Prepare environment + env = os.environ.copy() + env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") + env["PYTHONIOENCODING"] = "utf-8" + + clear_cuda_cache() + + # Fix 2: Create command array with all string values + command = [ + sys.executable, + "wan_generate_video.py", + "--task", str(task), + "--prompt", str(prompt), + "--video_size", str(height), str(width), + "--video_length", str(video_length), + "--fps", str(fps), + "--infer_steps", str(infer_steps), + "--save_path", str(save_path), + "--seed", str(current_seed), + "--flow_shift", str(flow_shift), + "--guidance_scale", str(guidance_scale), + "--output_type", str(output_type), + "--attn_mode", str(attn_mode), + "--blocks_to_swap", str(block_swap), + "--dit", str(dit_path), + "--vae", str(vae_path), + "--t5", str(t5_path), + "--sample_solver", str(sample_solver) + ] + + # Fix 3: Only add boolean flags if they're True + if enable_preview and preview_steps > 0: + command.extend(["--preview", str(preview_steps)]) + # --- ADDED: Pass the unique suffix --- + command.extend(["--preview_suffix", unique_preview_suffix]) + # --- End Pass Suffix --- + print(f"DEBUG - Enabling preview every {preview_steps} steps with suffix {unique_preview_suffix}.") + + if enable_cfg_skip and cfg_skip_mode != "none": + command.extend([ + "--cfg_skip_mode", str(cfg_skip_mode), + "--cfg_apply_ratio", str(cfg_apply_ratio) + ]) + + if wanx_input_end and wanx_input_end != "none" and os.path.exists(str(wanx_input_end)): + command.extend(["--end_image_path", str(wanx_input_end)]) + command.extend(["--trim_tail_frames", "3"]) + + # Handle 
Fun-Control (control video path) + if is_fun_control and control_video: + command.extend(["--control_path", str(control_video)]) + command.extend(["--control_weight", str(control_strength)]) + command.extend(["--control_start", str(control_start)]) + command.extend(["--control_end", str(control_end)]) + + # Handle SLG parameters + if slg_layers and str(slg_layers).strip() and str(slg_layers).lower() != "none": + try: + # Make sure slg_layers is parsed as a list of integers + slg_list = [] + for layer in str(slg_layers).split(","): + layer = layer.strip() + if layer.isdigit(): # Only add if it's a valid integer + slg_list.append(int(layer)) + if slg_list: # Only add if we have valid layers + command.extend(["--slg_layers", ",".join(map(str, slg_list))]) + + # Only add slg_start and slg_end if we have valid slg_layers + try: + if slg_start_float is not None and slg_start_float >= 0: + command.extend(["--slg_start", str(slg_start_float)]) + if slg_end_float is not None and slg_end_float <= 1.0: + command.extend(["--slg_end", str(slg_end_float)]) + except ValueError as e: + print(f"Invalid SLG timing values: {str(e)}") + except ValueError as e: + print(f"Invalid SLG layers format: {slg_layers} - {str(e)}") + + + # Add image path only for i2v task and if input image is provided + if "i2v" in task and input_image: + command.extend(["--image_path", str(input_image)]) + command.extend(["--clip", str(clip_path)]) # CLIP is needed for i2v and Fun-Control + + # Add video path for v2v task + if "v2v" in task and input_image: + command.extend(["--video_path", str(input_image)]) + # Add strength parameter for video-to-video + if isinstance(guidance_scale, (int, float)) and guidance_scale > 0: + command.extend(["--strength", str(guidance_scale)]) + + if negative_prompt: + command.extend(["--negative_prompt", str(negative_prompt)]) + + # Add boolean flags correctly + if fp8: + command.append("--fp8") + + if fp8_scaled: + command.append("--fp8_scaled") + + if fp8_t5: + command.append("--fp8_t5") + + if exclude_single_blocks: + command.append("--exclude_single_blocks") + + # Handle LoRA weights and multipliers + lora_weights = [lora1, lora2, lora3, lora4] + lora_multipliers = [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier] + + valid_loras = [] + for weight, mult in zip(lora_weights, lora_multipliers): + # Skip None, empty, or "None" values + if weight is None or not str(weight) or str(weight).lower() == "none": + continue + + # Ensure weight is a string + weight_str = str(weight) + + # Construct full path and verify file exists + full_path = os.path.join(lora_folder, weight_str) + if not os.path.exists(full_path): + print(f"LoRA file not found: {full_path}") + continue + + # Add valid LoRA to the list + valid_loras.append((full_path, mult)) + + # Only add LoRA parameters if we have valid LoRAs + if valid_loras: + weights = [w for w, _ in valid_loras] + multipliers = [str(m) for _, m in valid_loras] + command.extend(["--lora_weight"] + weights) + command.extend(["--lora_multiplier"] + multipliers) + + # Make sure every item in command is a string + command = [str(item) for item in command] + + print(f"Running: {' '.join(command)}") + + p = subprocess.Popen( + command, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=env, + text=True, + encoding='utf-8', + errors='replace', + bufsize=1 + ) + + videos = [] + processed_preview_files = set() # Keep track of previews already yielded - REMAINS THE SAME IN UI FUNCTION + # --- Reset preview state for this run --- + 
current_preview_yield_path = None + last_preview_mtime = 0 + + current_phase = "Preparing" # Add phase tracking like FramePack + while True: + if stop_event.is_set(): + try: + p.terminate() + p.wait(timeout=5) + except subprocess.TimeoutExpired: + p.kill() + p.wait() + except Exception as e: + print(f"Error terminating subprocess: {e}") + yield [], [], "Generation stopped by user.", "" # Yield empty previews + return + + line = p.stdout.readline() + if not line: + if p.poll() is not None: + break + time.sleep(0.01); continue + + line = line.strip() + if not line: continue + print(f"WANX SUBPROCESS: {line}") # Log subprocess output + + # --- Adopt FramePack's Parsing Logic --- + status_text = f"Processing (seed: {current_seed})" # Default status + progress_text_update = line # Default progress + + # Check for TQDM progress using regex + tqdm_match = re.search(r'(\d+)\%\|.+\| (\d+)/(\d+) \[(\d{2}:\d{2})<(\d{2}:\d{2})', line) + + if tqdm_match: + percentage = int(tqdm_match.group(1)) + current_step = int(tqdm_match.group(2)) + total_steps = int(tqdm_match.group(3)) + time_elapsed = tqdm_match.group(4) + time_remaining = tqdm_match.group(5) + + current_phase = f"Denoising Step {current_step}/{total_steps}" # Update phase + + # Format progress text like FramePack for JS compatibility + progress_text_update = f"Step {current_step}/{total_steps} ({percentage}%) | Elapsed: {time_elapsed}, Remaining: {time_remaining}" + status_text = f"Generating (seed: {current_seed}) - {current_phase}" + + elif "ERROR" in line.upper() or "TRACEBACK" in line.upper(): + status_text = f"Error (seed: {current_seed})" + progress_text_update = line # Show error line + current_phase = "Error" + + # Add more phases if needed (e.g., "Decoding", "Saving") by checking logs + elif "Decoding video..." in line: # Placeholder check + current_phase = "Decoding Video" + status_text = f"Generating (seed: {current_seed}) - {current_phase}" + progress_text_update = "Decoding video..." 
+ + elif "Video saved to:" in line: # Placeholder check + current_phase = "Saved" + status_text = f"Completed (seed: {current_seed})" + progress_text_update = line # Show the save line + # Add any other status parsing if needed + preview_updated = False + current_mtime = 0 + found_preview_path = None + + if enable_preview: + # --- MODIFIED: Check unique paths --- + if os.path.exists(preview_mp4_path): + current_mtime = os.path.getmtime(preview_mp4_path) + found_preview_path = preview_mp4_path + elif os.path.exists(preview_png_path): + current_mtime = os.path.getmtime(preview_png_path) + found_preview_path = preview_png_path + # --- End Modified Check --- + + if found_preview_path and current_mtime > last_preview_mtime: + print(f"DEBUG: Preview file updated: {found_preview_path} (mtime: {current_mtime})") + # Yield the clean path (already unique) + current_preview_yield_path = found_preview_path # No cache buster needed + last_preview_mtime = current_mtime + preview_updated = True + # --- End Preview Check --- + + # --- YIELD --- + # Yield progress and potentially updated unique preview path + preview_list_for_yield = [current_preview_yield_path] if current_preview_yield_path else [] + # Yield progress and potentially updated unique preview path list + yield videos.copy(), preview_list_for_yield, status_text, progress_text_update + + p.stdout.close() + rc = p.wait() + + clear_cuda_cache() + time.sleep(0.5) + + # --- Collect final generated video --- + generated_video_path = None + if rc == 0: # Only look for video if process succeeded + save_path_abs = os.path.abspath(save_path) + if os.path.exists(save_path_abs): + # Find the most recent mp4 containing the seed + all_mp4_files = glob.glob(os.path.join(save_path_abs, f"*_{current_seed}*.mp4")) + # Exclude files in the 'previews' subdirectory + all_mp4_files = [f for f in all_mp4_files if "previews" not in os.path.dirname(f)] + + if all_mp4_files: + # Find the *absolute* most recent one, as multiple might match seed in edge cases + generated_video_path = max(all_mp4_files, key=os.path.getmtime) + print(f"Found newly generated video: {generated_video_path}") + + # Add metadata (assuming add_metadata_to_video exists and works) + parameters = { + "prompt": prompt, "negative_prompt": negative_prompt, + "input_image": input_image if "i2v" in task else None, + "width": width, "height": height, "video_length": video_length, "fps": fps, + "infer_steps": infer_steps, "flow_shift": flow_shift, "guidance_scale": guidance_scale, + "seed": current_seed, "task": task, "dit_path": dit_path, + "vae_path": vae_path, "t5_path": t5_path, "clip_path": clip_path if "i2v" in task or is_fun_control else None, + "save_path": save_path, "output_type": output_type, "sample_solver": sample_solver, + "exclude_single_blocks": exclude_single_blocks, "attn_mode": attn_mode, + "block_swap": block_swap, "fp8": fp8, "fp8_scaled": fp8_scaled, "fp8_t5": fp8_t5, + "lora_weights": [lora1, lora2, lora3, lora4], + "lora_multipliers": [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier], + "slg_layers": slg_layers, "slg_start": slg_start, "slg_end": slg_end, + "enable_cfg_skip": enable_cfg_skip, "cfg_skip_mode": cfg_skip_mode, "cfg_apply_ratio": cfg_apply_ratio, + "control_video": control_video if is_fun_control else None, + "control_strength": control_strength if is_fun_control else None, + "control_start": control_start if is_fun_control else None, + "control_end": control_end if is_fun_control else None, + } + try: + 
add_metadata_to_video(generated_video_path, parameters) + except NameError: + print("Warning: add_metadata_to_video function not found. Skipping metadata.") + except Exception as meta_err: + print(f"Warning: Failed to add metadata: {meta_err}") + + # Append to the final video list + videos.append((str(generated_video_path), f"Seed: {current_seed}")) + else: + print(f"Subprocess finished successfully (rc=0), but could not find generated video for seed {current_seed} in {save_path_abs}") + +# --- Final Yield --- + final_status = f"Completed (seed: {current_seed})" if rc == 0 and generated_video_path else f"Failed (seed: {current_seed}, rc={rc})" + final_progress = f"Video saved: {os.path.basename(generated_video_path)}" if rc == 0 and generated_video_path else f"Subprocess failed with exit code {rc}" + + # Check for the preview file one last time for the final update (using unique path) + # --- MODIFIED Final Preview Check and List Creation --- + final_preview_path = None + # --- Use the UNIQUE paths defined earlier in the function --- + if os.path.exists(preview_mp4_path): + final_preview_path = os.path.abspath(preview_mp4_path) + elif os.path.exists(preview_png_path): + final_preview_path = os.path.abspath(preview_png_path) + # --- End path checking --- + + final_preview_list_for_yield = [final_preview_path] if final_preview_path else [] + # --- End Modified --- + + yield videos, final_preview_list_for_yield, final_status, final_progress + +def send_wanx_to_v2v( + gallery: list, + prompt: str, + selected_index: int, + width: int, + height: int, + video_length: int, + fps: int, + infer_steps: int, + seed: int, + flow_shift: float, + guidance_scale: float, + negative_prompt: str +) -> Tuple: + """Send the selected WanX video to Video2Video tab""" + if gallery is None or not gallery: + return (None, "", width, height, video_length, fps, infer_steps, seed, + flow_shift, guidance_scale, negative_prompt) + + # If no selection made but we have videos, use the first one + if selected_index is None and len(gallery) > 0: + selected_index = 0 + + if selected_index is None or selected_index >= len(gallery): + return (None, "", width, height, video_length, fps, infer_steps, seed, + flow_shift, guidance_scale, negative_prompt) + + selected_item = gallery[selected_index] + + # Handle different gallery item formats + if isinstance(selected_item, tuple): + video_path = selected_item[0] + elif isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + else: + video_path = selected_item + + # Clean up path for Video component + if isinstance(video_path, tuple): + video_path = video_path[0] + + # Make sure it's a string + video_path = str(video_path) + + return (video_path, prompt, width, height, video_length, fps, infer_steps, seed, + flow_shift, guidance_scale, negative_prompt) + +def wanx_generate_video_batch( + prompt, + negative_prompt, + width, + height, + video_length, + fps, + infer_steps, + flow_shift, + guidance_scale, + seed, + task, + dit_path, + vae_path, + t5_path, + clip_path, + save_path, + output_type, + sample_solver, + exclude_single_blocks, + attn_mode, + block_swap, + fp8, + fp8_scaled, + fp8_t5, + lora_folder, + slg_layers: int, + slg_start: Optional[str], + slg_end: Optional[str], + lora1="None", + lora2="None", + lora3="None", + lora4="None", + lora1_multiplier=1.0, + lora2_multiplier=1.0, + lora3_multiplier=1.0, + lora4_multiplier=1.0, + batch_size=1, + input_image=None # Make input_image optional and place it at the end +) -> 
Generator[Tuple[List[Tuple[str, str]], str, str], None, None]:
+    """Generate videos with WanX with support for batches"""
+    slg_start = None if slg_start == 'None' or slg_start is None else slg_start
+    slg_end = None if slg_end == 'None' or slg_end is None else slg_end
+
+    # Now safely convert to float if not None
+    slg_start_float = float(slg_start) if slg_start is not None and isinstance(slg_start, (str, int, float)) else None
+    slg_end_float = float(slg_end) if slg_end is not None and isinstance(slg_end, (str, int, float)) else None
+    print(f"slg_start_float: {slg_start_float}, slg_end_float: {slg_end_float}")
+    global stop_event
+    stop_event.clear()
+
+    all_videos = []
+    progress_text = "Starting generation..."
+    yield [], "Preparing...", progress_text
+
+    # Process each item in the batch
+    for i in range(batch_size):
+        if stop_event.is_set():
+            yield all_videos, "Generation stopped by user", ""
+            return
+
+        # Calculate seed for this batch item
+        current_seed = seed
+        if seed == -1:
+            current_seed = random.randint(0, 2**32 - 1)
+        elif batch_size > 1:
+            current_seed = seed + i
+
+        batch_text = f"Generating video {i + 1} of {batch_size}"
+        yield all_videos.copy(), batch_text, progress_text
+
+        # Generate a single video using the existing function.
+        # wanx_generate_video yields (videos, previews, status, progress);
+        # previews are not surfaced by this batch generator, so they are ignored here.
+        for videos, _previews, status, progress in wanx_generate_video(
+            prompt,
+            negative_prompt,
+            input_image,
+            width,
+            height,
+            video_length,
+            fps,
+            infer_steps,
+            flow_shift,
+            guidance_scale,
+            current_seed,
+            task,
+            dit_path,
+            vae_path,
+            t5_path,
+            clip_path,
+            save_path,
+            output_type,
+            sample_solver,
+            exclude_single_blocks,
+            attn_mode,
+            block_swap,
+            fp8,
+            fp8_scaled,
+            fp8_t5,
+            lora_folder,
+            slg_layers,
+            slg_start,
+            slg_end,
+            lora1,
+            lora2,
+            lora3,
+            lora4,
+            lora1_multiplier,
+            lora2_multiplier,
+            lora3_multiplier,
+            lora4_multiplier
+        ):
+            if videos:
+                all_videos.extend(videos)
+            yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress
+
+    yield all_videos, "Batch complete", ""
+
+def update_wanx_t2v_dimensions(size):
+    """Update width and height based on selected size"""
+    width, height = map(int, size.split('*'))
+    return gr.update(value=width), gr.update(value=height)
+
+def handle_wanx_t2v_gallery_select(evt: gr.SelectData) -> int:
+    """Track selected index when gallery item is clicked"""
+    return evt.index
+
+def send_wanx_t2v_to_v2v(
+    gallery, prompt, selected_index, width, height, video_length,
+    fps, infer_steps, seed, flow_shift, guidance_scale, negative_prompt
+) -> Tuple:
+    """Send the selected WanX T2V video to Video2Video tab"""
+    if not gallery or selected_index is None or selected_index >= len(gallery):
+        return (None, "", width, height, video_length, fps, infer_steps, seed,
+                flow_shift, guidance_scale, negative_prompt)
+
+    selected_item = gallery[selected_index]
+
+    if isinstance(selected_item, dict):
+        video_path = selected_item.get("name", selected_item.get("data", None))
+    elif isinstance(selected_item, (tuple, list)):
+        video_path = selected_item[0]
+    else:
+        video_path = selected_item
+
+    if isinstance(video_path, tuple):
+        video_path = video_path[0]
+
+    return (str(video_path), prompt, width, height, video_length, fps, infer_steps, seed,
+            flow_shift, guidance_scale, negative_prompt)
+
+def prepare_for_batch_extension(input_img, base_video, batch_size):
+    """Prepare inputs for batch video extension"""
+    if input_img is None:
+        return None, None, batch_size, "No input image found", ""
+
+    if base_video is None:
+        return input_img, None, batch_size, "No base video selected for extension", ""
+ + return input_img, base_video, batch_size, "Preparing batch extension...", f"Will create {batch_size} variations of extended video" + +def concat_batch_videos(base_video_path, generated_videos, save_path, original_video_path=None): + """Concatenate multiple generated videos with the base video""" + if not base_video_path: + return [], "No base video provided" + + if not generated_videos or len(generated_videos) == 0: + return [], "No new videos generated" + + # Create output directory if it doesn't exist + os.makedirs(save_path, exist_ok=True) + + # Track all extended videos + extended_videos = [] + + # For each generated video, create an extended version + for i, video_item in enumerate(generated_videos): + try: + # Extract video path from gallery item + if isinstance(video_item, tuple): + new_video_path = video_item[0] + seed_info = video_item[1] if len(video_item) > 1 else "" + elif isinstance(video_item, dict): + new_video_path = video_item.get("name", video_item.get("data", None)) + seed_info = "" + else: + new_video_path = video_item + seed_info = "" + + if not new_video_path or not os.path.exists(new_video_path): + print(f"Skipping missing video: {new_video_path}") + continue + + # Create unique output filename + timestamp = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S") + # Extract seed from seed_info if available + seed_match = re.search(r"Seed: (\d+)", seed_info) + seed_part = f"_seed{seed_match.group(1)}" if seed_match else f"_{i}" + + output_filename = f"extended_{timestamp}{seed_part}_{Path(base_video_path).stem}.mp4" + output_path = os.path.join(save_path, output_filename) + + # Create a temporary file list for ffmpeg + list_file = os.path.join(save_path, f"temp_list_{i}.txt") + with open(list_file, "w") as f: + f.write(f"file '{os.path.abspath(base_video_path)}'\n") + f.write(f"file '{os.path.abspath(new_video_path)}'\n") + + # Run ffmpeg concatenation + command = [ + "ffmpeg", + "-f", "concat", + "-safe", "0", + "-i", list_file, + "-c", "copy", + output_path + ] + + subprocess.run(command, check=True, capture_output=True) + + # Clean up temporary file + if os.path.exists(list_file): + os.remove(list_file) + + # Add to extended videos list if successful + if os.path.exists(output_path): + seed_display = f"Extended {seed_info}" if seed_info else f"Extended video #{i+1}" + extended_videos.append((output_path, seed_display)) + + except Exception as e: + print(f"Error processing video {i}: {str(e)}") + + if not extended_videos: + return [], "Failed to create any extended videos" + + return extended_videos, f"Successfully created {len(extended_videos)} extended videos" + +def wanx_extend_single_video( + prompt, negative_prompt, input_image, base_video_path, + width, height, video_length, fps, infer_steps, + flow_shift, guidance_scale, seed, + task, dit_path, vae_path, t5_path, clip_path, + save_path, output_type, sample_solver, exclude_single_blocks, + attn_mode, block_swap, fp8, fp8_scaled, fp8_t5, lora_folder, + slg_layers="", slg_start=0.0, slg_end=1.0, + lora1="None", lora2="None", lora3="None", lora4="None", + lora1_multiplier=1.0, lora2_multiplier=1.0, lora3_multiplier=1.0, lora4_multiplier=1.0 +): + """Generate a single video and concatenate with base video""" + # First, generate the video with proper parameter handling + all_videos = [] + + # Sanitize lora parameters + lora_weights = [str(lora1) if lora1 is not None else "None", + str(lora2) if lora2 is not None else "None", + str(lora3) if lora3 is not None else "None", + str(lora4) if lora4 is not None 
else "None"] + + # Convert multipliers to float + try: + lora_multipliers = [float(lora1_multiplier), float(lora2_multiplier), + float(lora3_multiplier), float(lora4_multiplier)] + except (ValueError, TypeError): + # Fallback to defaults if conversion fails + lora_multipliers = [1.0, 1.0, 1.0, 1.0] + + # Debug print + print(f"Sanitized LoRA weights: {lora_weights}") + print(f"Sanitized LoRA multipliers: {lora_multipliers}") + + # Generate video + for videos, status, progress in wanx_generate_video( + prompt, negative_prompt, input_image, width, height, + video_length, fps, infer_steps, flow_shift, guidance_scale, + seed, task, dit_path, vae_path, t5_path, clip_path, + save_path, output_type, sample_solver, exclude_single_blocks, + attn_mode, block_swap, fp8, fp8_scaled, fp8_t5, lora_folder, + slg_layers, slg_start, slg_end, + lora_weights[0], lora_weights[1], lora_weights[2], lora_weights[3], + lora_multipliers[0], lora_multipliers[1], lora_multipliers[2], lora_multipliers[3], + enable_cfg_skip=False, + cfg_skip_mode="none", + cfg_apply_ratio=0.7 + ): + + # Keep track of generated videos + if videos: + all_videos = videos + + # Forward progress updates + yield all_videos, status, progress + + # Now concatenate with base video if we have something + if all_videos and base_video_path and os.path.exists(base_video_path): + try: + print(f"Extending base video: {base_video_path}") + + # Create unique output filename + timestamp = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S") + output_filename = f"extended_{timestamp}_seed{seed}_{Path(base_video_path).stem}.mp4" + output_path = os.path.join(save_path, output_filename) + + # Extract the path from the gallery item + new_video_path = all_videos[0][0] if isinstance(all_videos[0], tuple) else all_videos[0] + + # Create a temporary file list for ffmpeg + list_file = os.path.join(save_path, f"temp_list_{seed}.txt") + with open(list_file, "w") as f: + f.write(f"file '{os.path.abspath(base_video_path)}'\n") + f.write(f"file '{os.path.abspath(new_video_path)}'\n") + + print(f"Concatenating: {base_video_path} + {new_video_path}") + + # Run ffmpeg concatenation + command = [ + "ffmpeg", + "-f", "concat", + "-safe", "0", + "-i", list_file, + "-c", "copy", + "-y", + output_path + ] + + subprocess.run(command, check=True, capture_output=True) + + # Clean up temporary file + if os.path.exists(list_file): + os.remove(list_file) + + # Return the extended video if successful + if os.path.exists(output_path): + extended_video = [(output_path, f"Extended (Seed: {seed})")] + print(f"Successfully created extended video: {output_path}") + yield extended_video, "Extended video created successfully", "" + return + else: + print(f"Failed to create extended video at {output_path}") + except Exception as e: + print(f"Error creating extended video: {str(e)}") + + # If we got here, something went wrong with the concatenation + yield all_videos, "Generated video (extension failed)", "" + +def process_batch_extension( + prompt, negative_prompt, input_image, base_video, + width, height, video_length, fps, infer_steps, + flow_shift, guidance_scale, seed, batch_size, + task, dit_folder, dit_path, vae_path, t5_path, clip_path, # <<< Added dit_folder + save_path, output_type, sample_solver, exclude_single_blocks, + attn_mode, block_swap, fp8, fp8_scaled, fp8_t5, lora_folder, + slg_layers, slg_start, slg_end, + lora1="None", lora2="None", lora3="None", lora4="None", + lora1_multiplier=1.0, lora2_multiplier=1.0, lora3_multiplier=1.0, lora4_multiplier=1.0 +): + 
"""Process a batch of video extensions one at a time""" + global stop_event + stop_event.clear() + + all_extended_videos = [] # Store successfully extended videos + progress_text = "Starting video extension batch..." + yield [], progress_text, "" # Initial yield + + try: + # Ensure batch_size is treated as an integer + batch_size = int(batch_size) + except (ValueError, TypeError): + batch_size = 1 + print("Warning: Invalid batch_size, defaulting to 1.") + + # Ensure base_video exists + if not base_video or not os.path.exists(base_video): + yield [], "Error: Base video not found", f"Cannot find video at {base_video}" + return + + # Process each batch item independently + for i in range(batch_size): + if stop_event.is_set(): + yield all_extended_videos, "Extension stopped by user", "" + return + + # Calculate seed for this batch item + current_seed = seed + if seed == -1: + current_seed = random.randint(0, 2**32 - 1) + elif batch_size > 1: + current_seed = seed + i + + batch_text = f"Processing extension {i+1}/{batch_size} (seed: {current_seed})" + yield all_extended_videos, batch_text, progress_text # Update progress + + # Use the direct wrapper with correct parameter order, including dit_folder + generation_iterator = wanx_extend_video_wrapper( + prompt=prompt, negative_prompt=negative_prompt, input_image=input_image, base_video_path=base_video, + width=width, height=height, video_length=video_length, fps=fps, infer_steps=infer_steps, + flow_shift=flow_shift, guidance_scale=guidance_scale, seed=current_seed, + task=task, + dit_folder=dit_folder, # <<< Pass the folder path + dit_path=dit_path, # <<< Pass the model filename + vae_path=vae_path, + t5_path=t5_path, + clip_path=clip_path, + save_path=save_path, output_type=output_type, sample_solver=sample_solver, + exclude_single_blocks=exclude_single_blocks, attn_mode=attn_mode, block_swap=block_swap, + fp8=fp8, fp8_scaled=fp8_scaled, fp8_t5=fp8_t5, lora_folder=lora_folder, + slg_layers=slg_layers, slg_start=slg_start, slg_end=slg_end, + lora1=lora1, lora2=lora2, lora3=lora3, lora4=lora4, + lora1_multiplier=lora1_multiplier, lora2_multiplier=lora2_multiplier, + lora3_multiplier=lora3_multiplier, lora4_multiplier=lora4_multiplier + ) + + # Iterate through the generator for this single extension + final_videos_for_item = [] + final_status_for_item = "Unknown status" + final_progress_for_item = "" + try: + for videos, status, progress in generation_iterator: + # Forward progress information immediately + yield all_extended_videos, f"Batch {i+1}/{batch_size}: {status}", progress + + # Store the latest state for this item + final_videos_for_item = videos + final_status_for_item = status + final_progress_for_item = progress + + # After the loop for one item finishes, check the result + if final_videos_for_item: + # Check if the video is actually an extended one + is_extended = any("Extended" in (v[1] if isinstance(v, tuple) else "") for v in final_videos_for_item) + if is_extended: + all_extended_videos.extend(final_videos_for_item) + print(f"Added extended video to collection (total: {len(all_extended_videos)})") + else: + # It was just the generated segment, maybe log this? 
+ print(f"Video segment generated for batch {i+1} but extension failed or wasn't performed.") + else: + print(f"No video returned for batch item {i+1}.") + + + except Exception as e: + print(f"Error during single extension processing (batch {i+1}): {e}") + yield all_extended_videos, f"Error in batch {i+1}: {e}", "" + + + # Clean CUDA cache between generations + clear_cuda_cache() + time.sleep(0.5) + + # Final yield after the loop + yield all_extended_videos, "Batch extension complete", "" + +def handle_extend_generation(base_video_path: str, new_videos: list, save_path: str, current_gallery: list) -> tuple: + """Combine generated video with base video and update gallery""" + if not base_video_path: + return current_gallery, "Extend failed: No base video provided" + + if not new_videos: + return current_gallery, "Extend failed: No new video generated" + + # Ensure save path exists + os.makedirs(save_path, exist_ok=True) + + # Get the first video from new_videos (gallery item) + new_video_path = new_videos[0][0] if isinstance(new_videos[0], tuple) else new_videos[0] + + # Create a unique output filename + timestamp = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S") + output_filename = f"extended_{timestamp}_{Path(base_video_path).stem}.mp4" + output_path = str(Path(save_path) / output_filename) + + try: + # Concatenate the videos using ffmpeg + ( + ffmpeg + .input(base_video_path) + .concat( + ffmpeg.input(new_video_path) + ) + .output(output_path) + .run(overwrite_output=True, quiet=True) + ) + + # Create a new gallery entry with the combined video + updated_gallery = [(output_path, f"Extended video: {Path(output_path).stem}")] + + return updated_gallery, f"Successfully extended video to {Path(output_path).name}" + except Exception as e: + print(f"Error extending video: {str(e)}") + return current_gallery, f"Failed to extend video: {str(e)}" + +# UI setup +with gr.Blocks( + theme=themes.Default( + primary_hue=colors.Color( + name="custom", + c50="#E6F0FF", + c100="#CCE0FF", + c200="#99C1FF", + c300="#66A3FF", + c400="#3384FF", + c500="#0060df", # This is your main color + c600="#0052C2", + c700="#003D91", + c800="#002961", + c900="#001430", + c950="#000A18" + ) + ), + css=""" + .gallery-item:first-child { border: 2px solid #4CAF50 !important; } + .gallery-item:first-child:hover { border-color: #45a049 !important; } + .green-btn { + background: linear-gradient(to bottom right, #2ecc71, #27ae60) !important; + color: white !important; + border: none !important; + } + .green-btn:hover { + background: linear-gradient(to bottom right, #27ae60, #219651) !important; + } + .refresh-btn { + max-width: 40px !important; + min-width: 40px !important; + height: 40px !important; + border-radius: 50% !important; + padding: 0 !important; + display: flex !important; + align-items: center !important; + justify-content: center !important; + } + .light-blue-btn { + background: linear-gradient(to bottom right, #AEC6CF, #9AB8C4) !important; /* Light blue gradient */ + color: #333 !important; /* Darker text for readability */ + border: 1px solid #9AB8C4 !important; /* Subtle border */ + } + .light-blue-btn:hover { + background: linear-gradient(to bottom right, #9AB8C4, #8AA9B5) !important; /* Slightly darker on hover */ + border-color: #8AA9B5 !important; + } + """, + +) as demo: + # Add state for tracking selected video indices in both tabs + selected_index = gr.State(value=None) # For Text to Video + v2v_selected_index = gr.State(value=None) # For Video to Video + params_state = gr.State() #New 
addition + i2v_selected_index = gr.State(value=None) + skyreels_selected_index = gr.State(value=None) + wanx_i2v_selected_index = gr.State(value=None) + extended_videos = gr.State(value=[]) + wanx_base_video = gr.State(value=None) + wanx_sharpest_frame_number = gr.State(value=None) + wanx_sharpest_frame_path = gr.State(value=None) + wanx_trimmed_video_path = gr.State(value=None) + wanx_v2v_selected_index = gr.State(value=None) + wanx_t2v_selected_index = gr.State(value=None) + framepack_selected_index = gr.State(value=None) + framepack_original_dims = gr.State(value="") + fpe_selected_index = gr.State(value=None) + demo.load(None, None, None, js=""" + () => { + document.title = 'H1111'; + + function updateTitle(text) { + if (text && text.trim()) { + // Regex for the FramePack format: "Item ... (...)% | ... Remaining: HH:MM" + const framepackMatch = text.match(/.*?\((\d+)%\).*?Remaining:\s*(\d{2}:\d{2})/); + // Regex for standard tqdm format (like WanX uses) + const tqdmMatch = text.match(/(\d+)%\|.*\[.*<(\d{2}:\d{2})/); // Adjusted slightly for robustness + + if (framepackMatch) { + // Handle FramePack format + const percentage = framepackMatch[1]; + const timeRemaining = framepackMatch[2]; + document.title = `[${percentage}% ETA: ${timeRemaining}] - H1111`; + } else if (tqdmMatch) { // <<< ADDED ELSE IF for standard tqdm + // Handle standard tqdm format + const percentage = tqdmMatch[1]; + const timeRemaining = tqdmMatch[2]; + document.title = `[${percentage}% ETA: ${timeRemaining}] - H1111`; + } else { + // Optional: Reset title if neither format matches? + // document.title = 'H1111'; + } + } + } + + setTimeout(() => { + // This selector should still find all relevant progress textareas + const progressElements = document.querySelectorAll('textarea.scroll-hide'); + progressElements.forEach(element => { + if (element) { + new MutationObserver(() => { + updateTitle(element.value); + }).observe(element, { + attributes: true, + childList: true, + characterData: true + }); + } + }); + }, 1000); + } + """) + + with gr.Tabs() as tabs: + + #FRAME PACK TAB + with gr.Tab(id=10, label="FramePack") as framepack_tab: + + with gr.Row(): + with gr.Column(scale=4): + framepack_prompt = gr.Textbox( + scale=3, label="Prompt (Supports sections: index:prompt;;;index:prompt)", + value="cinematic video of a cat wizard casting a spell", lines=3, + info="Use '0:prompt;;;-1:prompt' or '0-2:prompt;;;3:prompt'. Index total sections -1 is last section." 
+ ) + framepack_negative_prompt = gr.Textbox(scale=3, label="Negative Prompt", value="", lines=3) + with gr.Column(scale=1): + framepack_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) + framepack_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) + framepack_is_f1 = gr.Checkbox(label="🏎️ Use F1 Model", value=False, + info="Switches to the F1 model (different DiT path and logic).") + with gr.Column(scale=2): + framepack_batch_progress = gr.Textbox(label="Status", interactive=False, value="") + framepack_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") + with gr.Row(): + framepack_generate_btn = gr.Button("Generate FramePack Video", elem_classes="green-btn") + framepack_stop_btn = gr.Button("Stop Generation", variant="stop") + + # Main Content + with gr.Row(): + # --- Left Column --- + with gr.Column(): + framepack_input_image = gr.Image(label="Input Image (Video Start)", type="filepath") + with gr.Row(): + framepack_use_random_folder = gr.Checkbox(label="Use Random Images from Folder", value=False, + info="If checked, 'Input Image (Video Start)' is hidden. Each batch item uses a random image from the folder.") + framepack_input_folder_path = gr.Textbox( + label="Image Folder Path", + placeholder="Path to folder containing images for batch processing", + visible=False # Initially hidden + ) + with gr.Row(visible=False) as framepack_folder_options_row: # Parent Row for folder options + framepack_validate_folder_btn = gr.Button("Validate Folder") + framepack_folder_status_text = gr.Textbox( + label="Folder Status", + placeholder="Validation status will appear here", + interactive=False + ) + with gr.Accordion("Optional End Frame Control (normal model only)", open=False): + framepack_input_end_frame = gr.Image(label="End Frame Image (Video End)", type="filepath", scale=1) + framepack_end_frame_influence = gr.Dropdown( + label="End Frame Influence Mode", + choices=["last", "half", "progressive", "bookend"], + value="last", + info="How the end frame affects generation (if provided)", + visible=False + ) + framepack_end_frame_weight = gr.Slider( + minimum=0.0, maximum=1.0, step=0.05, value=0.5, # Default changed from 0.3 + label="End Frame Weight", + info="Influence strength of the end frame (if provided)", + visible=False + ) + + gr.Markdown("### Resolution Options (Choose One)") + framepack_target_resolution = gr.Number( + label="Option 1: Target Resolution (Uses Buckets)", + value=640, minimum=0, maximum=1280, step=32, + info="Target bucket size (e.g., 640 for 640x640). Uses input image aspect ratio. Final size divisible by 32.", + interactive=True + ) + with gr.Accordion("Option 2: Explicit Resolution (Overrides Option 1)", open=False): + framepack_scale_slider = gr.Slider( + minimum=1, maximum=200, value=100, step=1, label="Scale % (UI Only)" + ) + with gr.Row(): + framepack_width = gr.Number( + label="Width", value=None, minimum=0, step=32, + info="Must be divisible by 32.", interactive=True + ) + framepack_calc_height_btn = gr.Button("→") + framepack_calc_width_btn = gr.Button("←") + framepack_height = gr.Number( + label="Height", value=None, minimum=0, step=32, + info="Must be divisible by 32.", interactive=True + ) + framepack_total_second_length = gr.Slider(minimum=1.0, maximum=120.0, step=0.5, label="Total Video Length (seconds)", value=5.0) + framepack_video_sections = gr.Number( + label="Total Video Sections (Overrides seconds if > 0)", + value=None, step=1, + info="Specify exact number of sections. 
If set, 'Total Video Length (seconds)' is ignored by the backend." + ) + framepack_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Output FPS", value=30) + with gr.Row(): + framepack_seed = gr.Number(label="Seed (-1 for random)", value=-1) + framepack_random_seed =gr.Button("🎲️") + framepack_steps = gr.Slider(minimum=1, maximum=100, step=1, label="Steps", value=25, interactive=True) # Moved here + + # --- Right Column --- + with gr.Column(): + framepack_output = gr.Gallery( + label="Generated Videos (Click to select)", + columns=[2], rows=[1], + object_fit="contain", height="auto", show_label=True, + elem_id="gallery_framepack", allow_preview=True, preview=True + ) + with gr.Accordion("Latent Preview (During Generation)", open=True): + with gr.Row(): + framepack_enable_preview = gr.Checkbox(label="Enable Latent Preview", value=True) + framepack_preview_every_n_sections = gr.Slider( + minimum=1, maximum=50, step=1, value=1, + label="Preview Every N Sections", + info="Generates previews during the sampling loop." + ) + framepack_preview_output = gr.Video( # Changed from Gallery to Video + label="Latest Preview", height=300, + interactive=False, # Not interactive for display + elem_id="framepack_preview_video" + ) + framepack_skip_btn = gr.Button("Skip Batch Item", elem_classes="light-blue-btn") + with gr.Group(): + with gr.Row(): + framepack_refresh_lora_btn = gr.Button("🔄 LoRA", elem_classes="refresh-btn") # Specific LoRA refresh + framepack_lora_folder = gr.Textbox(label="LoRa Folder", value="lora", scale=4) + framepack_lora_weights = [] + framepack_lora_multipliers = [] + for i in range(4): # Assuming max 4 LoRAs like other tabs + with gr.Row(): + framepack_lora_weights.append(gr.Dropdown( + label=f"LoRA {i+1}", choices=get_lora_options("lora"), + value="None", allow_custom_value=False, interactive=True, scale=2 + )) + framepack_lora_multipliers.append(gr.Slider( + label=f"Multiplier", minimum=0.0, maximum=2.0, step=0.05, value=1.0, scale=1, interactive=True + )) + # Fixed Generation Parameters Section + with gr.Accordion("Generation Parameters", open=True): + with gr.Row(): + framepack_distilled_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.1, label="Distilled Guidance Scale (embedded_cfg_scale)", value=10.0, interactive=True) + framepack_guidance_scale = gr.Slider(minimum=1.0, maximum=10.0, step=0.1, label="Guidance Scale (CFG)", value=1.0, interactive=True, info="Default 1.0 (no CFG), backend recommends not changing.") + with gr.Row(): + framepack_guidance_rescale = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="CFG Rescale (rs)", value=0.0, interactive=True, info="Default 0.0, backend recommends not changing.") + framepack_latent_window_size = gr.Number(label="Latent Window Size", value=9, interactive=True, info="Default 9") + framepack_sample_solver = gr.Dropdown(label="Sample Solver", choices=["unipc", "dpm++", "vanilla"], value="unipc", interactive=True) + + with gr.Accordion("Advanced Section Control (Optional)", open=False): + gr.Markdown( + "Define specific prompts and starting images for different sections of the video. " + "For the index you can input a range or a single index. A 5 second default video has 4 sections. 
The first section is 0 and the last is 3" + ) + # --- Define section controls explicitly --- + with gr.Row(): + with gr.Column(scale=1): + with gr.Group(): + gr.Markdown("**--- Control Slot 1 ---**") + with gr.Row(): + + framepack_sec_1 = gr.Textbox(label="Index/Range", value="0", placeholder="e.g., 0 or 0-1", interactive=True) + framepack_sec_prompt_1 = gr.Textbox(label="Prompt Override", lines=2, placeholder="Overrides base prompt for these sections") + framepack_sec_image_1 = gr.Image(label="Start Image Override", type="filepath", scale=1) + with gr.Column(scale=1): + with gr.Group(): + gr.Markdown("**--- Control Slot 2 ---**") + with gr.Row(): + + framepack_sec_2 = gr.Textbox(label="Index/Range", value="1", placeholder="e.g., 2 or 2-3", interactive=True) + framepack_sec_prompt_2 = gr.Textbox(label="Prompt Override", lines=2) + framepack_sec_image_2 = gr.Image(label="Start Image Override", type="filepath", scale=1) + with gr.Row(): + with gr.Column(scale=1): + with gr.Group(): + gr.Markdown("**--- Control Slot 3 ---**") + with gr.Row(): + + framepack_sec_3 = gr.Textbox(label="Index/Range", value="2", placeholder="e.g., 4 or 4-5", interactive=True) + framepack_sec_prompt_3 = gr.Textbox(label="Prompt Override", lines=2) + framepack_sec_image_3 = gr.Image(label="Start Image Override", type="filepath", scale=1) + with gr.Column(scale=1): + with gr.Group(): + gr.Markdown("**--- Control Slot 4 ---**") + with gr.Row(): + framepack_sec_4 = gr.Textbox(label="Index/Range", value="3", placeholder="e.g., 6 or 6-7", interactive=True) + framepack_sec_prompt_4 = gr.Textbox(label="Prompt Override", lines=2) + framepack_sec_image_4 = gr.Image(label="Start Image Override", type="filepath", scale=1) + + # Group section control components for easier passing to functions (remains the same) + framepack_secs = [framepack_sec_1, framepack_sec_2, framepack_sec_3, framepack_sec_4] + framepack_sec_prompts = [framepack_sec_prompt_1, framepack_sec_prompt_2, framepack_sec_prompt_3, framepack_sec_prompt_4] + framepack_sec_images = [framepack_sec_image_1, framepack_sec_image_2, framepack_sec_image_3, framepack_sec_image_4] + + # Performance/Memory Accordion - Updated + with gr.Accordion("Performance / Memory", open=True): + with gr.Row(): + framepack_fp8 = gr.Checkbox(label="Use FP8 DiT", value=False, info="Enable FP8 precision for the main Transformer model.") + framepack_fp8_llm = gr.Checkbox(label="Use FP8 LLM (Text Encoder 1)", value=False, info="Enable FP8 for the Llama text encoder.", visible=False) + framepack_fp8_scaled = gr.Checkbox(label="Use Scaled FP8 DiT", value=False, info="Requires FP8 DiT. 
Use scaled math (potential quality improvement).") + framepack_blocks_to_swap = gr.Slider(minimum=0, maximum=39, step=1, label="Blocks to Swap (to Save VRAM, 0=disable)", value=26, + info="Higher values = less VRAM usage but slower generation") + framepack_bulk_decode = gr.Checkbox(label="Bulk Decode Frames (Faster Decode, Higher VRAM)", value=False, info="Decode all frames at once instead of section by section.") + with gr.Row(): + framepack_attn_mode = gr.Dropdown( + label="Attention Mode", + choices=["torch", "sdpa", "flash", "xformers", "sageattn"], # Added choices from script + value="sdpa", # Defaulting to sdpa + interactive=True + ) + framepack_vae_chunk_size = gr.Number(label="VAE Chunk Size (CausalConv3d)", value=32, step=1, minimum=0, info="0 or None=disable (Default: None)") + framepack_vae_spatial_tile_sample_min_size = gr.Number(label="VAE Spatial Tile Min Size", value=128, step=16, minimum=0, info="0 or None=disable (Default: None)") + framepack_device = gr.Textbox(label="Device Override (optional)", placeholder="e.g., cuda:0, cpu") + with gr.Row(): + framepack_use_teacache = gr.Checkbox(label="Use TeaCache", value=False, info="Enable TeaCache for faster generation (shits hands).") + framepack_teacache_steps = gr.Number(label="TeaCache Init Steps", value=25, step=1, minimum=1, info="Steps for TeaCache init (match Inference Steps)") + framepack_teacache_thresh = gr.Slider(label="TeaCache Threshold", minimum=0.0, maximum=1.0, step=0.01, value=0.15, info="Relative L1 distance threshold for skipping.") + + with gr.Accordion("Model Paths / Advanced", open=False): + with gr.Row(): + framepack_transformer_path = gr.Textbox(label="Transformer Path (DiT)", value="hunyuan/FramePackI2V_HY_bf16.safetensors", interactive=True) + framepack_vae_path = gr.Textbox(label="VAE Path", value="hunyuan/pytorch_model.pt") + with gr.Row(): + framepack_text_encoder_path = gr.Textbox(label="Text Encoder 1 (Llama) Path *Required*", value="hunyuan/llava_llama3_fp16.safetensors") + framepack_text_encoder_2_path = gr.Textbox(label="Text Encoder 2 (CLIP) Path *Required*", value="hunyuan/clip_l.safetensors") + with gr.Row(): + framepack_image_encoder_path = gr.Textbox(label="Image Encoder (SigLIP) Path *Required*", value="hunyuan/model.safetensors") + framepack_save_path = gr.Textbox(label="Save Path *Required*", value="outputs") +### FRAMEPACK EXTENSION + with gr.Tab(id=11, label="FramePack-Extension") as framepack_extension_tab: + with gr.Row(): + with gr.Column(scale=4): + fpe_prompt = gr.Textbox( + scale=3, label="Prompt", + value="cinematic video of a cat wizard casting a spell, epic action scene", lines=3 + ) + fpe_negative_prompt = gr.Textbox(scale=3, label="Negative Prompt", value="", lines=3) + with gr.Column(scale=1): + fpe_use_normal_framepack = gr.Checkbox(label="Use Normal FramePack Model", value=False, info="Uses og model supports end frame. 
Default is F1 model.") + fpe_batch_count = gr.Number(label="Batch Count", value=1, minimum=1, step=1) + with gr.Column(scale=2): + fpe_batch_progress = gr.Textbox(label="Status", interactive=False, value="") + fpe_progress_text = gr.Textbox(label="Progress", interactive=False, lines=1, elem_id="fpe_progress_text") # Unique elem_id + + with gr.Row(): + fpe_generate_btn = gr.Button("Generate Extended Video", elem_classes="green-btn") + fpe_stop_btn = gr.Button("Stop Generation", variant="stop") + + with gr.Row(): + with gr.Column(): # Left column for inputs + fpe_input_video = gr.Video(label="Input Video for Extension", sources=['upload'], height=300) + with gr.Accordion("Optional End Frame (for Normal FramePack Model)", open=False, visible=False) as fpe_end_frame_accordion: + fpe_end_frame = gr.Image(label="End Frame for Extension", type="filepath") + fpe_end_frame_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=1.0, label="End Frame Weight") + with gr.Accordion("Optional Start Guidance Image (for F1 Model Extension)", open=False, visible=True) as fpe_start_guidance_accordion: # Initially hidden + fpe_start_guidance_image = gr.Image(label="Start Guidance Image for Extension", type="filepath") + fpe_start_guidance_image_clip_weight = gr.Slider( + minimum=0.0, maximum=5.0, step=0.05, value=0.75, + label="Start Guidance Image CLIP Weight", + info="Blend weight for the guidance image's CLIP embedding with input video's first frame CLIP." + ) + fpe_use_guidance_image_as_first_latent = gr.Checkbox( + label="Use Guidance Image as First Latent", value=False, + info="If checked, the VAE latent of the guidance image will be used as the initial conditioning for the first generated segment. Turn down context frames when using this" + ) + gr.Markdown("### Core Generation Parameters") + with gr.Row(): + fpe_seed = gr.Number(label="Seed (-1 for random)", value=-1) + # fpe_random_seed_btn = gr.Button("🎲️") # Optional: Add random seed button + + fpe_resolution_max_dim = gr.Number(label="Resolution (Max Dimension)", value=640, step=32, info="Target max width/height for bucket.") + fpe_total_second_length = gr.Slider(minimum=1.0, maximum=120.0, step=0.5, label="Additional Video Length (seconds)", value=5.0) + fpe_latent_window_size = gr.Slider(minimum=9, maximum=33, step=1, label="Latent Window Size", value=9, info="Default 9 for F1 model.") + fpe_steps = gr.Slider(minimum=1, maximum=100, step=1, label="Inference Steps", value=25) + + with gr.Row(): + fpe_cfg_scale = gr.Slider(minimum=1.0, maximum=32.0, step=0.1, label="CFG Scale", value=1.0, info="Usually 1.0 for F1 (no external CFG).") + fpe_distilled_guidance_scale = gr.Slider(minimum=1.0, maximum=32.0, step=0.1, label="Distilled Guidance (GS)", value=3.0) + # fpe_rs_scale = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="CFG Rescale (RS)", value=0.0, visible=False) + with gr.Row(): + with gr.Accordion("Advanced & Performance", open=True): + fpe_gpu_memory_preservation = gr.Slider(label="GPU Memory Preserve (GB)", minimum=1.0, maximum=16.0, value=6.0, step=0.1) + fpe_use_teacache = gr.Checkbox(label="Use TeaCache", value=False) + fpe_no_resize = gr.Checkbox(label="Force Original Video Resolution (No Resize)", value=False) + fpe_extension_only = gr.Checkbox(label="Save Extension Only", value=False, info="If checked, only the newly generated extension part of the video will be saved.") + fpe_mp4_crf = gr.Slider(label="MP4 CRF (Quality)", minimum=0, maximum=51, value=1, step=1, info="Lower is better quality, larger file.") + 
fpe_num_clean_frames = gr.Slider(label="Context Frames (1x from Input)", minimum=1, maximum=10, value=5, step=1) + fpe_vae_batch_size = gr.Slider(label="VAE Batch Size (Input Video Encoding)", minimum=4, maximum=128, value=72, step=4) + + fpe_attn_mode = gr.Dropdown(label="Attention Mode (DiT)", choices=["torch", "sdpa", "flash", "xformers", "sageattn"], value="torch") + fpe_fp8_llm = gr.Checkbox(label="Use FP8 LLM (Text Encoder 1)", value=False, visible=False) + fpe_vae_chunk_size = gr.Number(label="VAE Chunk Size (CausalConv3d)", value=32, step=1, minimum=0, info="0 or None=disable") + fpe_vae_spatial_tile_sample_min_size = gr.Number(label="VAE Spatial Tile Min Size", value=128, step=16, minimum=0, info="0 or None=disable") + + + with gr.Column(): # Right column for outputs and advanced settings + fpe_output_gallery = gr.Gallery( + label="Generated Extended Videos", columns=[1], rows=[1], # Show one main video at a time + object_fit="contain", height=480, show_label=True, + elem_id="gallery_framepack_extension", allow_preview=True, preview=True + ) + with gr.Accordion("Live Preview (During Generation)", open=True): + with gr.Row(): + fpe_enable_preview = gr.Checkbox(label="Enable Live Preview", value=True, visible=False) + fpe_preview_interval = gr.Slider( + minimum=1, maximum=50, step=1, value=5, + label="Preview Every N Steps", + info="Saves a PNG preview during sampling.", + visible=False + ) + fpe_preview_output_component = gr.Video( # Changed to Video for MP4 previews + label="Latest Section Preview", height=300, + interactive=False, elem_id="fpe_preview_video" + ) + # fpe_skip_btn = gr.Button("Skip Batch Item", elem_classes="light-blue-btn") # Optional + gr.Markdown("### LoRA Configuration") + with gr.Row(): + fpe_refresh_lora_btn = gr.Button("🔄 LoRA", elem_classes="refresh-btn") + fpe_lora_folder = gr.Textbox(label="LoRA Folder", value="lora", scale=4) + fpe_lora_weights_ui = [] + fpe_lora_multipliers_ui = [] + for i in range(4): + with gr.Row(): + fpe_lora_weights_ui.append(gr.Dropdown( + label=f"LoRA {i+1}", choices=get_lora_options("lora"), + value="None", allow_custom_value=False, interactive=True, scale=2 + )) + fpe_lora_multipliers_ui.append(gr.Slider( + label=f"Multiplier", minimum=0.0, maximum=2.0, step=0.05, value=1.0, scale=1, interactive=True + )) + with gr.Row(): + with gr.Accordion("Model Paths (FramePack-Extension)", open=False): + fpe_transformer_path = gr.Textbox(label="DiT Path (F1 Model)", value="hunyuan/FramePack_F1_I2V_HY_20250503.safetensors") # Default to F1 + fpe_vae_path = gr.Textbox(label="VAE Path", value="hunyuan/pytorch_model.pt") + fpe_text_encoder_path = gr.Textbox(label="Text Encoder 1 (Llama)", value="hunyuan/llava_llama3_fp16.safetensors") + fpe_text_encoder_2_path = gr.Textbox(label="Text Encoder 2 (CLIP)", value="hunyuan/clip_l.safetensors") + fpe_image_encoder_path = gr.Textbox(label="Image Encoder (SigLIP)", value="hunyuan/model.safetensors") + fpe_save_path = gr.Textbox(label="Save Path (Output Directory)", value="outputs/framepack_extensions") + + # Text to Video Tab + with gr.Tab(id=1, label="Hunyuan-t2v"): + with gr.Row(): + with gr.Column(scale=4): + prompt = gr.Textbox(scale=3, label="Enter your prompt", value="POV video of a cat chasing a frob.", lines=5) + + with gr.Column(scale=1): + token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) + batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) + + with gr.Column(scale=2): + batch_progress = gr.Textbox(label="", visible=True, 
elem_id="batch_progress") + progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") + + with gr.Row(): + generate_btn = gr.Button("Generate Video", elem_classes="green-btn") + stop_btn = gr.Button("Stop Generation", variant="stop") + + with gr.Row(): + with gr.Column(): + + t2v_width = gr.Slider(minimum=64, maximum=1536, step=16, value=544, label="Video Width") + t2v_height = gr.Slider(minimum=64, maximum=1536, step=16, value=544, label="Video Height") + video_length = gr.Slider(minimum=1, maximum=201, step=1, label="Video Length in Frames", value=25, elem_id="my_special_slider") + fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=24, elem_id="my_special_slider") + infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30, elem_id="my_special_slider") + flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=11.0, elem_id="my_special_slider") + cfg_scale = gr.Slider(minimum=0.0, maximum=14.0, step=0.1, label="cfg Scale", value=7.0, elem_id="my_special_slider") + + with gr.Column(): + + with gr.Row(): + video_output = gr.Gallery( + label="Generated Videos (Click to select)", + columns=[2], + rows=[2], + object_fit="contain", + height="auto", + show_label=True, + elem_id="gallery", + allow_preview=True, + preview=True + ) + with gr.Row():send_t2v_to_v2v_btn = gr.Button("Send Selected to Video2Video") + + with gr.Row(): + refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") + lora_weights = [] + lora_multipliers = [] + for i in range(4): + with gr.Column(): + lora_weights.append(gr.Dropdown( + label=f"LoRA {i+1}", + choices=get_lora_options(), + value="None", + allow_custom_value=True, + interactive=True + )) + lora_multipliers.append(gr.Slider( + label=f"Multiplier", + minimum=0.0, + maximum=2.0, + step=0.05, + value=1.0 + )) + with gr.Row(): + exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) + seed = gr.Number(label="Seed (use -1 for random)", value=-1) + dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") + model = gr.Dropdown( + label="DiT Model", + choices=get_dit_models("hunyuan"), + value="mp_rank_00_model_states.pt", + allow_custom_value=True, + interactive=True + ) + vae = gr.Textbox(label="vae", value="hunyuan/pytorch_model.pt") + te1 = gr.Textbox(label="te1", value="hunyuan/llava_llama3_fp16.safetensors") + te2 = gr.Textbox(label="te2", value="hunyuan/clip_l.safetensors") + save_path = gr.Textbox(label="Save Path", value="outputs") + with gr.Row(): + lora_folder = gr.Textbox(label="LoRA Folder", value="lora") + output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") + use_split_attn = gr.Checkbox(label="Use Split Attention", value=False) + use_fp8 = gr.Checkbox(label="Use FP8 (faster but lower precision)", value=True) + attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") + block_swap = gr.Slider(minimum=0, maximum=36, step=1, label="Block Swap to Save Vram", value=0) + + #Image to Video Tab + with gr.Tab(label="Hunyuan-i2v") as i2v_tab: # Keep tab name consistent if needed elsewhere + # ... (Keep existing Rows for prompt, batch size, progress) ... 
+ with gr.Row(): + with gr.Column(scale=4): + i2v_prompt = gr.Textbox(scale=3, label="Enter your prompt", value="POV video of a cat chasing a frob.", lines=5) + + with gr.Column(scale=1): + i2v_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) + i2v_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) + + with gr.Column(scale=2): + i2v_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress_i2v") # Unique elem_id + i2v_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text_i2v") # Unique elem_id + + with gr.Row(): + i2v_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") + i2v_stop_btn = gr.Button("Stop Generation", variant="stop") + + + with gr.Row(): + with gr.Column(): + i2v_input = gr.Image(label="Input Image", type="filepath") + # REMOVED i2v_strength slider, as hv_i2v_generate_video.py doesn't seem to use it based on the sample command + # i2v_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.75, label="Denoise Strength") + scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale % (UI Only - affects W/H)") # Clarified UI only + original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) + # Width and height inputs + with gr.Row(): + # Renamed width/height to avoid potential conflicts if they weren't already prefixed + i2v_width = gr.Number(label="New Width", value=720, step=16) # Default from sample + calc_height_btn = gr.Button("→") + calc_width_btn = gr.Button("←") + i2v_height = gr.Number(label="New Height", value=720, step=16) # Default from sample + i2v_video_length = gr.Slider(minimum=1, maximum=201, step=1, label="Video Length in Frames", value=49) # Default from sample + i2v_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=24) # Default from sample + i2v_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30) # Default from sample + i2v_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=17.0) # Default from sample + i2v_cfg_scale = gr.Slider(minimum=0.0, maximum=14.0, step=0.1, label="Embedded CFG Scale", value=7.0) # Default from sample + i2v_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.1, label="Guidance Scale (CFG)", value=1.0) # Default from sample (usually 1.0 for no CFG) + + with gr.Column(): + i2v_output = gr.Gallery( + label="Generated Videos (Click to select)", + columns=[2], + rows=[2], + object_fit="contain", + height="auto", + show_label=True, + elem_id="gallery_i2v", # Unique elem_id + allow_preview=True, + preview=True + ) + i2v_send_to_v2v_btn = gr.Button("Send Selected to Hunyuan-v2v") # Keep sending to original V2V + + # Add LoRA section for Image2Video + i2v_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") + i2v_lora_weights = [] + i2v_lora_multipliers = [] + for i in range(4): + with gr.Column(): + i2v_lora_weights.append(gr.Dropdown( + label=f"LoRA {i+1}", + choices=get_lora_options(), + value="None", + allow_custom_value=True, + interactive=True + )) + i2v_lora_multipliers.append(gr.Slider( + label=f"Multiplier", + minimum=0.0, + maximum=2.0, + step=0.05, + value=1.0 + )) + + with gr.Row(): + i2v_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) + i2v_seed = gr.Number(label="Seed (use -1 for random)", value=-1) + i2v_dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") + i2v_model = gr.Dropdown( + label="DiT 
Model", + choices=get_dit_models("hunyuan"), + value="mp_rank_00_model_states_i2v.pt", # Default from sample + allow_custom_value=True, + interactive=True + ) + i2v_vae = gr.Textbox(label="VAE Path", value="hunyuan/pytorch_model.pt") # Default from sample + i2v_te1 = gr.Textbox(label="Text Encoder 1 Path", value="hunyuan/llava_llama3_fp16.safetensors") # Default from sample + i2v_te2 = gr.Textbox(label="Text Encoder 2 Path", value="hunyuan/clip_l.safetensors") # Default from sample + i2v_clip_vision_path = gr.Textbox(label="CLIP Vision Path", value="hunyuan/llava_llama3_vision.safetensors") # Default from sample + i2v_save_path = gr.Textbox(label="Save Path", value="outputs") # Default from sample + with gr.Row(): + i2v_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") + i2v_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") # Default from sample + i2v_use_split_attn = gr.Checkbox(label="Use Split Attention", value=False) # Not in sample, keep default False + i2v_use_fp8 = gr.Checkbox(label="Use FP8 DiT", value=False) # Not in sample, keep default False + i2v_fp8_llm = gr.Checkbox(label="Use FP8 LLM", value=False) # Not in sample, keep default False + i2v_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") # Default from sample + i2v_block_swap = gr.Slider(minimum=0, maximum=36, step=1, label="Block Swap to Save Vram", value=30) # Default from sample + # Add VAE tiling options like sample command + i2v_vae_chunk_size = gr.Number(label="VAE Chunk Size", value=32, step=1, info="For CausalConv3d, set 0 to disable") + i2v_vae_spatial_tile_min = gr.Number(label="VAE Spatial Tile Min Size", value=128, step=16, info="Set 0 to disable spatial tiling") + + # Video to Video Tab + with gr.Tab(id=2, label="Hunyuan v2v") as v2v_tab: + with gr.Row(): + with gr.Column(scale=4): + v2v_prompt = gr.Textbox(scale=3, label="Enter your prompt", value="POV video of a cat chasing a frob.", lines=5) + v2v_negative_prompt = gr.Textbox( + scale=3, + label="Negative Prompt (for SkyReels models)", + value="Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion", + lines=3 + ) + + with gr.Column(scale=1): + v2v_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) + v2v_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) + + with gr.Column(scale=2): + v2v_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") + v2v_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") + + with gr.Row(): + v2v_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") + v2v_stop_btn = gr.Button("Stop Generation", variant="stop") + + with gr.Row(): + with gr.Column(): + v2v_input = gr.Video(label="Input Video", format="mp4") + v2v_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.75, label="Denoise Strength") + v2v_scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale %") + v2v_original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) + + # Width and Height Inputs + with gr.Row(): + v2v_width = gr.Number(label="New Width", value=544, step=16) + v2v_calc_height_btn = gr.Button("→") + v2v_calc_width_btn = gr.Button("←") + v2v_height = gr.Number(label="New Height", value=544, step=16) + v2v_video_length = gr.Slider(minimum=1, maximum=201, step=1, 
label="Video Length in Frames", value=25) + v2v_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=24) + v2v_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30) + v2v_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=11.0) + v2v_cfg_scale = gr.Slider(minimum=0.0, maximum=14.0, step=0.1, label="cfg scale", value=7.0) + with gr.Column(): + v2v_output = gr.Gallery( + label="Generated Videos", + columns=[1], + rows=[1], + object_fit="contain", + height="auto" + ) + v2v_send_to_input_btn = gr.Button("Send Selected to Input") # New button + v2v_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") + v2v_lora_weights = [] + v2v_lora_multipliers = [] + for i in range(4): + with gr.Column(): + v2v_lora_weights.append(gr.Dropdown( + label=f"LoRA {i+1}", + choices=get_lora_options(), + value="None", + allow_custom_value=True, + interactive=True + )) + v2v_lora_multipliers.append(gr.Slider( + label=f"Multiplier", + minimum=0.0, + maximum=2.0, + step=0.05, + value=1.0 + )) + + with gr.Row(): + v2v_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) + v2v_seed = gr.Number(label="Seed (use -1 for random)", value=-1) + v2v_dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") + v2v_model = gr.Dropdown( + label="DiT Model", + choices=get_dit_models("hunyuan"), + value="mp_rank_00_model_states.pt", + allow_custom_value=True, + interactive=True + ) + v2v_vae = gr.Textbox(label="vae", value="hunyuan/pytorch_model.pt") + v2v_te1 = gr.Textbox(label="te1", value="hunyuan/llava_llama3_fp16.safetensors") + v2v_te2 = gr.Textbox(label="te2", value="hunyuan/clip_l.safetensors") + v2v_save_path = gr.Textbox(label="Save Path", value="outputs") + with gr.Row(): + v2v_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") + v2v_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") + v2v_use_split_attn = gr.Checkbox(label="Use Split Attention", value=False) + v2v_use_fp8 = gr.Checkbox(label="Use FP8 (faster but lower precision)", value=True) + v2v_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") + v2v_block_swap = gr.Slider(minimum=0, maximum=36, step=1, label="Block Swap to Save Vram", value=0) + v2v_split_uncond = gr.Checkbox(label="Split Unconditional (for SkyReels)", value=True) + +### SKYREELS + + with gr.Tab(label="SkyReels-i2v") as skyreels_tab: + with gr.Row(): + with gr.Column(scale=4): + skyreels_prompt = gr.Textbox( + scale=3, + label="Enter your prompt", + value="A person walking on a beach at sunset", + lines=5 + ) + skyreels_negative_prompt = gr.Textbox( + scale=3, + label="Negative Prompt", + value="Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion", + lines=3 + ) + + with gr.Column(scale=1): + skyreels_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) + skyreels_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) + + with gr.Column(scale=2): + skyreels_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") + skyreels_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") + + with gr.Row(): + skyreels_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") + skyreels_stop_btn = gr.Button("Stop Generation", variant="stop") + + with gr.Row(): + 
with gr.Column(): + skyreels_input = gr.Image(label="Input Image (optional)", type="filepath") + with gr.Row(): + skyreels_use_random_folder = gr.Checkbox(label="Use Random Images from Folder", value=False) + skyreels_input_folder = gr.Textbox( + label="Image Folder Path", + placeholder="Path to folder containing images", + visible=False + ) + skyreels_folder_status = gr.Textbox( + label="Folder Status", + placeholder="Status will appear here", + interactive=False, + visible=False + ) + skyreels_validate_folder_btn = gr.Button("Validate Folder", visible=False) + skyreels_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.75, label="Denoise Strength") + + # Scale slider as percentage + skyreels_scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale %") + skyreels_original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) + + # Width and height inputs + with gr.Row(): + skyreels_width = gr.Number(label="New Width", value=544, step=16) + skyreels_calc_height_btn = gr.Button("→") + skyreels_calc_width_btn = gr.Button("←") + skyreels_height = gr.Number(label="New Height", value=544, step=16) + + skyreels_video_length = gr.Slider(minimum=1, maximum=201, step=1, label="Video Length in Frames", value=25) + skyreels_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=24) + skyreels_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30) + skyreels_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=11.0) + skyreels_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.1, label="Guidance Scale", value=6.0) + skyreels_embedded_cfg_scale = gr.Slider(minimum=0.0, maximum=10.0, step=0.1, label="Embedded CFG Scale", value=1.0) + + with gr.Column(): + skyreels_output = gr.Gallery( + label="Generated Videos (Click to select)", + columns=[2], + rows=[2], + object_fit="contain", + height="auto", + show_label=True, + elem_id="gallery", + allow_preview=True, + preview=True + ) + skyreels_send_to_v2v_btn = gr.Button("Send Selected to Video2Video") + + # Add LoRA section for SKYREELS + skyreels_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") + skyreels_lora_weights = [] + skyreels_lora_multipliers = [] + for i in range(4): + with gr.Column(): + skyreels_lora_weights.append(gr.Dropdown( + label=f"LoRA {i+1}", + choices=get_lora_options(), + value="None", + allow_custom_value=True, + interactive=True + )) + skyreels_lora_multipliers.append(gr.Slider( + label=f"Multiplier", + minimum=0.0, + maximum=2.0, + step=0.05, + value=1.0 + )) + with gr.Row(): + skyreels_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) + skyreels_seed = gr.Number(label="Seed (use -1 for random)", value=-1) + skyreels_dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") + skyreels_model = gr.Dropdown( + label="DiT Model", + choices=get_dit_models("skyreels"), + value="skyreels_hunyuan_i2v_bf16.safetensors", + allow_custom_value=True, + interactive=True + ) + skyreels_vae = gr.Textbox(label="vae", value="hunyuan/pytorch_model.pt") + skyreels_te1 = gr.Textbox(label="te1", value="hunyuan/llava_llama3_fp16.safetensors") + skyreels_te2 = gr.Textbox(label="te2", value="hunyuan/clip_l.safetensors") + skyreels_save_path = gr.Textbox(label="Save Path", value="outputs") + + with gr.Row(): + skyreels_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") + skyreels_output_type = gr.Radio(choices=["video", "images", 
"latent", "both"], label="Output Type", value="video") + skyreels_use_split_attn = gr.Checkbox(label="Use Split Attention", value=False) + skyreels_use_fp8 = gr.Checkbox(label="Use FP8 (faster but lower precision)", value=True) + skyreels_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") + skyreels_block_swap = gr.Slider(minimum=0, maximum=36, step=1, label="Block Swap to Save Vram", value=0) + skyreels_split_uncond = gr.Checkbox(label="Split Unconditional", value=True) + + # WanX Image to Video Tab + with gr.Tab(id=4, label="WanX-i2v") as wanx_i2v_tab: + with gr.Row(): + with gr.Column(scale=4): + wanx_prompt = gr.Textbox( + scale=3, + label="Enter your prompt", + value="A person walking on a beach at sunset", + lines=5 + ) + wanx_negative_prompt = gr.Textbox( + scale=3, + label="Negative Prompt", + value="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + lines=3, + ) + + with gr.Column(scale=1): + wanx_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) + wanx_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) + + with gr.Column(scale=2): + wanx_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") + wanx_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") + + with gr.Row(): + wanx_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") + wanx_stop_btn = gr.Button("Stop Generation", variant="stop") + + with gr.Row(): + with gr.Column(): + wanx_input = gr.Image(label="Input Image", type="filepath") + with gr.Row(): + wanx_use_random_folder = gr.Checkbox(label="Use Random Images from Folder", value=False) + wanx_input_folder = gr.Textbox( + label="Image Folder Path", + placeholder="Path to folder containing images", + visible=False + ) + wanx_folder_status = gr.Textbox( + label="Folder Status", + placeholder="Status will appear here", + interactive=False, + visible=False + ) + wanx_validate_folder_btn = gr.Button("Validate Folder", visible=False) + with gr.Row(): + wanx_use_end_image = gr.Checkbox(label="use ending image", value=False) + wanx_input_end = gr.Image(label="End Image", type="filepath", visible=False) + wanx_trim_frames = gr.Checkbox(label="trim last 3 frames", value=True, visible=False, interactive=True) + + with gr.Row(): + wanx_use_fun_control = gr.Checkbox(label="Use Fun-Control Model", value=False) + wanx_control_video = gr.Video(label="Control Video for Fun-Control", visible=False, format="mp4") + wanx_control_strength = gr.Slider(minimum=0.1, maximum=2.0, step=0.05, value=1.0, + label="Control Strength", visible=False, + info="Adjust influence of control video (1.0 = normal)") + wanx_control_start = gr.Slider( + minimum=0.0, + maximum=1.0, + step=0.01, + value=0.0, + label="Control Start (Fun-Control fade-in)", + visible=False, + info="When (0-1) in the timeline control influence is full after fade-in" + ) + wanx_control_end = gr.Slider( + minimum=0.0, + maximum=1.0, + step=0.01, + value=1.0, + label="Control End (Fun-Control fade-out start)", + visible=False, + info="When (0-1) in the timeline control starts to fade out" + ) + wanx_scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale %") + wanx_original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) + + # Width and height display + with gr.Row(): + wanx_width = gr.Number(label="Width", 
value=832, interactive=True) + wanx_calc_height_btn = gr.Button("→") + wanx_calc_width_btn = gr.Button("←") + wanx_height = gr.Number(label="Height", value=480, interactive=True) + wanx_recommend_flow_btn = gr.Button("Recommend Flow Shift", size="sm") + + wanx_video_length = gr.Slider(minimum=1, maximum=401, step=4, label="Video Length in Frames", value=81) + wanx_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=16) + wanx_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=20) + wanx_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=3.0, + info="Recommended: 3.0 for 480p, 5.0 for others") + wanx_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.5, label="Guidance Scale", value=5.0) + + with gr.Column(): + wanx_output = gr.Gallery( + label="Generated Videos (Click to select)", + columns=[2], + rows=[2], + object_fit="contain", + height="auto", + show_label=True, + elem_id="gallery", + allow_preview=True, + preview=True + ) + with gr.Accordion("Latent Preview (During Generation)", open=True): + wanx_enable_preview = gr.Checkbox(label="Enable Latent Preview", value=True) + wanx_preview_steps = gr.Slider(minimum=1, maximum=50, step=1, value=5, + label="Preview Every N Steps", info="Generates previews during the sampling loop.") + wanx_preview_output = gr.Gallery( + label="Latent Previews", columns=4, rows=2, object_fit="contain", height=300, + allow_preview=True, preview=True, show_label=True, elem_id="wanx_preview_gallery" + ) + wanx_send_to_v2v_btn = gr.Button("Send Selected to Hunyuan-v2v") + wanx_i2v_send_to_wanx_v2v_btn = gr.Button("Send Selected to WanX-v2v") + wanx_send_last_frame_btn = gr.Button("Send Last Frame to Input") + wanx_extend_btn = gr.Button("Extend Video") + wanx_frames_to_check = gr.Slider(minimum=1, maximum=100, step=1, value=30, + label="Frames to Check from End", + info="Number of frames from the end to check for sharpness") + wanx_send_sharpest_frame_btn = gr.Button("Extract Sharpest Frame") + wanx_trim_and_extend_btn = gr.Button("Trim Video & Prepare for Extension") + wanx_sharpest_frame_status = gr.Textbox(label="Status", interactive=False) + + # Add a new button for directly extending with the trimmed video + wanx_extend_with_trimmed_btn = gr.Button("Extend with Trimmed Video") + + # Add LoRA section for WanX-i2v similar to other tabs + wanx_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") + wanx_lora_weights = [] + wanx_lora_multipliers = [] + for i in range(4): + with gr.Column(): + wanx_lora_weights.append(gr.Dropdown( + label=f"LoRA {i+1}", + choices=get_lora_options(), + value="None", + allow_custom_value=True, + interactive=True + )) + wanx_lora_multipliers.append(gr.Slider( + label=f"Multiplier", + minimum=0.0, + maximum=2.0, + step=0.05, + value=1.0 + )) + + with gr.Row(): + wanx_seed = gr.Number(label="Seed (use -1 for random)", value=-1) + # Update the wanx_task dropdown choices to include Fun-Control options + wanx_task = gr.Dropdown( + label="Task", + choices=["i2v-14B", "i2v-14B-FC", "i2v-14B-FC-1.1", "t2v-14B", "t2v-1.3B", "t2v-14B-FC", "t2v-1.3B-FC", "i2v-1.3B-new"], + value="i2v-14B", + info="Select model type. 
*-FC options enable Fun-Control features" + ) + wanx_dit_folder = gr.Textbox(label="DiT Model Folder", value="wan") + wanx_dit_path = gr.Dropdown( + label="DiT Model", + choices=get_dit_models("wan"), # Use the existing function to get available models + value="wan2.1_i2v_720p_14B_fp16.safetensors", + allow_custom_value=True, + interactive=True + ) + wanx_vae_path = gr.Textbox(label="VAE Path", value="wan/Wan2.1_VAE.pth") + wanx_t5_path = gr.Textbox(label="T5 Path", value="wan/models_t5_umt5-xxl-enc-bf16.pth") + wanx_clip_path = gr.Textbox(label="CLIP Path", value="wan/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth") + wanx_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") + wanx_save_path = gr.Textbox(label="Save Path", value="outputs") + + with gr.Row(): + wanx_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") + wanx_sample_solver = gr.Radio(choices=["unipc", "dpm++", "vanilla"], label="Sample Solver", value="unipc") + wanx_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) + wanx_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") + wanx_block_swap = gr.Slider(minimum=0, maximum=39, step=1, label="Block Swap to Save VRAM", value=0) + + with gr.Column(): + wanx_fp8 = gr.Checkbox(label="Use FP8", value=True) + wanx_fp8_scaled = gr.Checkbox(label="Use Scaled FP8", value=False, info="For mixing fp16/bf16 and fp8 weights") + wanx_fp8_t5 = gr.Checkbox(label="Use FP8 for T5", value=False) + + # Add new row for Skip Layer Guidance options + with gr.Row(): + wanx_slg_layers = gr.Textbox(label="SLG Layers", value="", placeholder="Comma-separated layer indices, e.g. 1,5,10", info="Layers to skip for guidance") + wanx_slg_start = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="SLG Start", value=0.0, info="When to start skipping layers (% of total steps)") + wanx_slg_end = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="SLG End", value=1.0, info="When to stop skipping layers (% of total steps)") + + with gr.Row(): + wanx_enable_cfg_skip = gr.Checkbox(label="Enable CFG Skip (similar to teacache)", value=False) + with gr.Column(visible=False) as wanx_cfg_skip_options: + wanx_cfg_skip_mode = gr.Radio( + choices=["early", "late", "middle", "early_late", "alternate", "none"], + label="CFG Skip Mode", + value="none", + info="Controls which steps to apply CFG on" + ) + wanx_cfg_apply_ratio = gr.Slider( + minimum=0.0, maximum=1.0, step=0.05, value=0.7, + label="CFG Apply Ratio", + info="Ratio of steps to apply CFG (0.0-1.0). 
Lower values = faster, but less accurate" + ) + + #WanX-t2v Tab + + # WanX Text to Video Tab + with gr.Tab(id=5, label="WanX-t2v") as wanx_t2v_tab: + with gr.Row(): + with gr.Column(scale=4): + wanx_t2v_prompt = gr.Textbox( + scale=3, + label="Enter your prompt", + value="A person walking on a beach at sunset", + lines=5 + ) + wanx_t2v_negative_prompt = gr.Textbox( + scale=3, + label="Negative Prompt", + value="", + lines=3, + info="Leave empty to use default negative prompt" + ) + + with gr.Column(scale=1): + wanx_t2v_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) + wanx_t2v_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) + + with gr.Column(scale=2): + wanx_t2v_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") + wanx_t2v_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") + + with gr.Row(): + wanx_t2v_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") + wanx_t2v_stop_btn = gr.Button("Stop Generation", variant="stop") + + with gr.Row(): + with gr.Column(): + with gr.Row(): + wanx_t2v_width = gr.Number(label="Width", value=832, interactive=True, info="Should be divisible by 32") + wanx_t2v_height = gr.Number(label="Height", value=480, interactive=True, info="Should be divisible by 32") + wanx_t2v_recommend_flow_btn = gr.Button("Recommend Flow Shift", size="sm") + + wanx_t2v_video_length = gr.Slider(minimum=1, maximum=201, step=4, label="Video Length in Frames", value=81) + wanx_t2v_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=16) + wanx_t2v_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=20) + wanx_t2v_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=5.0, + info="Recommended: 3.0 for I2V with 480p, 5.0 for others") + wanx_t2v_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.1, label="Guidance Scale", value=5.0) + + with gr.Column(): + wanx_t2v_output = gr.Gallery( + label="Generated Videos (Click to select)", + columns=[2], + rows=[2], + object_fit="contain", + height="auto", + show_label=True, + elem_id="gallery", + allow_preview=True, + preview=True + ) + with gr.Accordion("Latent Preview (During Generation)", open=False): + wanx_t2v_enable_preview = gr.Checkbox(label="Enable Latent Preview", value=False) + wanx_t2v_preview_steps = gr.Slider(minimum=1, maximum=50, step=1, value=5, + label="Preview Every N Steps", info="Generates previews during the sampling loop.") + wanx_t2v_preview_output = gr.Gallery( + label="Latent Previews", columns=4, rows=2, object_fit="contain", height=300, + allow_preview=True, preview=True, show_label=True, elem_id="wanx_t2v_preview_gallery" + ) + wanx_t2v_send_to_v2v_btn = gr.Button("Send Selected to Hunyuan v2v") + wanx_t2v_send_to_wanx_v2v_btn = gr.Button("Send Selected to WanX-v2v") + + # Add LoRA section for WanX-t2v + wanx_t2v_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") + wanx_t2v_lora_weights = [] + wanx_t2v_lora_multipliers = [] + for i in range(4): + with gr.Column(): + wanx_t2v_lora_weights.append(gr.Dropdown( + label=f"LoRA {i+1}", + choices=get_lora_options(), + value="None", + allow_custom_value=True, + interactive=True + )) + wanx_t2v_lora_multipliers.append(gr.Slider( + label=f"Multiplier", + minimum=0.0, + maximum=2.0, + step=0.05, + value=1.0 + )) + + with gr.Row(): + wanx_t2v_seed = gr.Number(label="Seed (use -1 for random)", value=-1) + wanx_t2v_task = gr.Dropdown( + label="Task", 
+ choices=["t2v-1.3B", "t2v-14B", "t2i-14B"], + value="t2v-14B", + info="Select model size: t2v-1.3B is faster, t2v-14B has higher quality" + ) + wanx_t2v_dit_path = gr.Dropdown( + label="DiT Model", + choices=get_dit_models("wan"), + value="wan2.1_t2v_14B_fp16.safetensors", + allow_custom_value=True, + interactive=True + ) + wanx_t2v_vae_path = gr.Textbox(label="VAE Path", value="wan/Wan2.1_VAE.pth") + wanx_t2v_t5_path = gr.Textbox(label="T5 Path", value="wan/models_t5_umt5-xxl-enc-bf16.pth") + wanx_t2v_clip_path = gr.Textbox(label="CLIP Path", visible=False, value="") + wanx_t2v_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") + wanx_t2v_save_path = gr.Textbox(label="Save Path", value="outputs") + + with gr.Row(): + wanx_t2v_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") + wanx_t2v_sample_solver = gr.Radio(choices=["unipc", "dpm++", "vanilla"], label="Sample Solver", value="unipc") + wanx_t2v_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) + wanx_t2v_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") + wanx_t2v_block_swap = gr.Slider(minimum=0, maximum=39, step=1, label="Block Swap to Save VRAM", value=0, + info="Max 39 for 14B model, 29 for 1.3B model") + + with gr.Column(): + wanx_t2v_fp8 = gr.Checkbox(label="Use FP8", value=True) + wanx_t2v_fp8_scaled = gr.Checkbox(label="Use Scaled FP8", value=False, + info="For mixing fp16/bf16 and fp8 weights") + wanx_t2v_fp8_t5 = gr.Checkbox(label="Use FP8 for T5", value=False) + + # Add new row for Skip Layer Guidance options + with gr.Row(): + wanx_t2v_slg_layers = gr.Textbox(label="SLG Layers", value="", placeholder="Comma-separated layer indices, e.g. 1,5,10", info="Layers to skip for guidance") + wanx_t2v_slg_start = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="SLG Start", value=0.0, info="When to start skipping layers (% of total steps)") + wanx_t2v_slg_end = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="SLG End", value=1.0, info="When to stop skipping layers (% of total steps)") + wanx_t2v_use_random_folder = gr.Checkbox(visible=False, value=False, label="Use Random Images") + wanx_t2v_input_folder = gr.Textbox(visible=False, value="", label="Image Folder") + wanx_t2v_input_end = gr.Textbox(visible=False, value="none", label="End Frame") + + with gr.Row(): + wanx_t2v_enable_cfg_skip = gr.Checkbox(label="Enable CFG Skip (similar to teacache)", value=False) + with gr.Column(visible=False) as wanx_t2v_cfg_skip_options: + wanx_t2v_cfg_skip_mode = gr.Radio( + choices=["early", "late", "middle", "early_late", "alternate", "none"], + label="CFG Skip Mode", + value="none", + info="Controls which steps to apply CFG on" + ) + wanx_t2v_cfg_apply_ratio = gr.Slider( + minimum=0.0, maximum=1.0, step=0.05, value=0.7, + label="CFG Apply Ratio", + info="Ratio of steps to apply CFG (0.0-1.0). 
Lower values = faster, but less accurate" + ) + + #WanX-v2v Tab + with gr.Tab(id=6, label="WanX-v2v") as wanx_v2v_tab: + with gr.Row(): + with gr.Column(scale=4): + wanx_v2v_prompt = gr.Textbox( + scale=3, + label="Enter your prompt", + value="A person walking on a beach at sunset", + lines=5 + ) + wanx_v2v_negative_prompt = gr.Textbox( + scale=3, + label="Negative Prompt", + value="", + lines=3, + info="Leave empty to use default negative prompt" + ) + + with gr.Column(scale=1): + wanx_v2v_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) + wanx_v2v_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) + + with gr.Column(scale=2): + wanx_v2v_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") + wanx_v2v_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") + + with gr.Row(): + wanx_v2v_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") + wanx_v2v_stop_btn = gr.Button("Stop Generation", variant="stop") + + with gr.Row(): + with gr.Column(): + wanx_v2v_input = gr.Video(label="Input Video", format="mp4") + wanx_v2v_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.75, label="Denoise Strength", + info="0 = keep original, 1 = full generation") + wanx_v2v_scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale %") + wanx_v2v_original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) + + # Width and Height Inputs + with gr.Row(): + wanx_v2v_width = gr.Number(label="New Width", value=832, step=32) + wanx_v2v_calc_height_btn = gr.Button("→") + wanx_v2v_calc_width_btn = gr.Button("←") + wanx_v2v_height = gr.Number(label="New Height", value=480, step=32) + wanx_v2v_recommend_flow_btn = gr.Button("Recommend Flow Shift", size="sm") + + wanx_v2v_video_length = gr.Slider(minimum=1, maximum=201, step=4, label="Video Length in Frames", value=81) + wanx_v2v_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=16) + wanx_v2v_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=40) + wanx_v2v_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=5.0, + info="Recommended: 3.0 for 480p, 5.0 for others") + wanx_v2v_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.1, label="Guidance Scale", value=5.0) + + with gr.Column(): + wanx_v2v_output = gr.Gallery( + label="Generated Videos (Click to select)", + columns=[2], + rows=[2], + object_fit="contain", + height="auto", + show_label=True, + elem_id="gallery", + allow_preview=True, + preview=True + ) + wanx_v2v_send_to_v2v_btn = gr.Button("Send Selected to Hunyuan-v2v") + + # Add LoRA section for WanX-v2v + wanx_v2v_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") + wanx_v2v_lora_weights = [] + wanx_v2v_lora_multipliers = [] + for i in range(4): + with gr.Column(): + wanx_v2v_lora_weights.append(gr.Dropdown( + label=f"LoRA {i+1}", + choices=get_lora_options(), + value="None", + allow_custom_value=True, + interactive=True + )) + wanx_v2v_lora_multipliers.append(gr.Slider( + label=f"Multiplier", + minimum=0.0, + maximum=2.0, + step=0.05, + value=1.0 + )) + + with gr.Row(): + wanx_v2v_seed = gr.Number(label="Seed (use -1 for random)", value=-1) + wanx_v2v_task = gr.Dropdown( + label="Task", + choices=["t2v-14B", "t2v-1.3B"], + value="t2v-14B", + info="Model size: t2v-1.3B is faster, t2v-14B has higher quality" + ) + wanx_v2v_dit_folder = gr.Textbox(label="DiT Model Folder", 
value="wan") + wanx_v2v_dit_path = gr.Dropdown( + label="DiT Model", + choices=get_dit_models("wan"), + value="wan2.1_t2v_14B_fp16.safetensors", + allow_custom_value=True, + interactive=True + ) + wanx_v2v_vae_path = gr.Textbox(label="VAE Path", value="wan/Wan2.1_VAE.pth") + wanx_v2v_t5_path = gr.Textbox(label="T5 Path", value="wan/models_t5_umt5-xxl-enc-bf16.pth") + wanx_v2v_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") + wanx_v2v_save_path = gr.Textbox(label="Save Path", value="outputs") + + with gr.Row(): + wanx_v2v_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") + wanx_v2v_sample_solver = gr.Radio(choices=["unipc", "dpm++", "vanilla"], label="Sample Solver", value="unipc") + wanx_v2v_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) + wanx_v2v_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") + wanx_v2v_block_swap = gr.Slider(minimum=0, maximum=39, step=1, label="Block Swap to Save VRAM", value=0, + info="Max 39 for 14B model, 29 for 1.3B model") + + with gr.Column(): + wanx_v2v_fp8 = gr.Checkbox(label="Use FP8", value=True) + wanx_v2v_fp8_scaled = gr.Checkbox(label="Use Scaled FP8", value=False, + info="For mixing fp16/bf16 and fp8 weights") + wanx_v2v_fp8_t5 = gr.Checkbox(label="Use FP8 for T5", value=False) + + # Add Skip Layer Guidance options + with gr.Row(): + wanx_v2v_slg_layers = gr.Textbox(label="SLG Layers", value="", placeholder="Comma-separated layer indices, e.g. 1,5,10", info="Layers to skip for guidance") + wanx_v2v_slg_start = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="SLG Start", value=0.0, info="When to start skipping layers (% of total steps)") + wanx_v2v_slg_end = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="SLG End", value=1.0, info="When to stop skipping layers (% of total steps)") + + with gr.Row(): + wanx_v2v_enable_cfg_skip = gr.Checkbox(label="Enable CFG Skip (similar to teacache)", value=False) + with gr.Column(visible=False) as wanx_v2v_cfg_skip_options: + wanx_v2v_cfg_skip_mode = gr.Radio( + choices=["early", "late", "middle", "early_late", "alternate", "none"], + label="CFG Skip Mode", + value="none", + info="Controls which steps to apply CFG on" + ) + wanx_v2v_cfg_apply_ratio = gr.Slider( + minimum=0.0, maximum=1.0, step=0.05, value=0.7, + label="CFG Apply Ratio", + info="Ratio of steps to apply CFG (0.0-1.0). 
Lower values = faster, but less accurate" + ) + + #Video Info Tab + with gr.Tab("Video Info") as video_info_tab: + with gr.Row(): + video_input = gr.Video(label="Upload Video", interactive=True) + metadata_output = gr.JSON(label="Generation Parameters") + + with gr.Row(): + send_to_fpe_btn = gr.Button("Send to FramePack-Extension", variant="primary") + send_to_t2v_btn = gr.Button("Send to Text2Video", variant="primary") + send_to_v2v_btn = gr.Button("Send to Video2Video", variant="primary") + with gr.Row(): + send_to_framepack_btn = gr.Button("Send to FramePack", variant="primary") + send_to_wanx_i2v_btn = gr.Button("Send to WanX-i2v", variant="primary") + send_to_wanx_t2v_btn = gr.Button("Send to WanX-t2v", variant="primary") + send_to_wanx_v2v_btn = gr.Button("Send to WanX-v2v", variant="primary") + + + with gr.Row(): + status = gr.Textbox(label="Status", interactive=False) + + #Convert lora tab + with gr.Tab("Convert LoRA") as convert_lora_tab: + def suggest_output_name(file_obj) -> str: + """Generate suggested output name from input file""" + if not file_obj: + return "" + # Get input filename without extension and add MUSUBI + base_name = os.path.splitext(os.path.basename(file_obj.name))[0] + return f"{base_name}_MUSUBI" + + def convert_lora(input_file, output_name: str, target_format: str) -> str: + """Convert LoRA file to specified format""" + try: + if input_file is None: + return "Error: No input file selected" + + # Ensure output directory exists + os.makedirs("lora", exist_ok=True) + + # Construct output path + output_path = os.path.join("lora", f"{output_name}.safetensors") + + # Determine which script to use based on target_format + if target_format == "Hunyuan to FramePack": + script_name = "convert_hunyuan_to_framepack.py" + cmd = [ + sys.executable, + script_name, + "--input", input_file.name, + "--output", output_path + ] + print(f"Using '{script_name}' to convert {input_file.name} to {output_path} for FramePack.") + else: # Existing logic for "default" and "other" + script_name = "convert_lora.py" + cmd = [ + sys.executable, + script_name, + "--input", input_file.name, + "--output", output_path, + "--target", target_format.lower() + ] + + print(f"Running conversion command: {' '.join(cmd)}") + + # Check if the selected script file exists + if not os.path.exists(script_name): + return f"Error: Conversion script '{script_name}' not found. Please ensure it's in the same directory as h1111.py." + + # Execute conversion + result = subprocess.run( + cmd, + capture_output=True, + text=True, + check=True + ) + + console_output = result.stdout if result.stdout else "" + if result.stderr: + console_output += f"\n--- Script STDERR ---\n{result.stderr}" + if not console_output.strip(): + console_output = "Conversion script completed with no output." + if os.path.exists(output_path): + console_output += f"\n[UI Info] Output file confirmed by h1111.py at: {output_path}" + else: + console_output += f"\n[UI Warning] Output file NOT found by h1111.py at expected location: {output_path}" + return console_output.strip() + except subprocess.CalledProcessError as e: + error_message = f"Conversion Script Error (Exit Code: {e.returncode}):\n" + if e.stdout and e.stdout.strip(): + error_message += f"--- Script STDOUT ---\n{e.stdout.strip()}\n" + if e.stderr and e.stderr.strip(): + error_message += f"--- Script STDERR ---\n{e.stderr.strip()}\n" + if not (e.stdout and e.stdout.strip()) and not (e.stderr and e.stderr.strip()): + error_message += "Script produced no output on STDOUT or STDERR." 
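+ # Log a brief note to the server console; the combined script output is returned to the UI status box below.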
+ + print(f"Subprocess error details logged to console. UI will show combined script output.") # Log for server console + return error_message.strip() + + + with gr.Row(): + input_file = gr.File(label="Input LoRA File", file_types=[".safetensors"]) + output_name = gr.Textbox(label="Output Name", placeholder="Output filename (without extension)") + format_radio = gr.Radio( + choices=["default", "other", "Hunyuan to FramePack"], # <-- Added new choice here + value="default", + label="Target Format", + info="Choose 'default' for H1111/MUSUBI format, 'other' for diffusion pipe format, or 'Hunyuan to FramePack' for FramePack compatibility." + ) + + with gr.Row(): + convert_btn = gr.Button("Convert LoRA", variant="primary") + status_output = gr.Textbox(label="Status", interactive=False) + + # Automatically update output name when file is selected + input_file.change( + fn=suggest_output_name, + inputs=[input_file], + outputs=[output_name] + ) + + # Handle conversion + convert_btn.click( + fn=convert_lora, + inputs=[input_file, output_name, format_radio], + outputs=status_output + ) + with gr.Tab("Model Merging") as model_merge_tab: + with gr.Row(): + with gr.Column(): + # Model selection + dit_model = gr.Dropdown( + label="Base DiT Model", + choices=["mp_rank_00_model_states.pt"], + value="mp_rank_00_model_states.pt", + allow_custom_value=True, + interactive=True + ) + merge_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") + with gr.Row(): + with gr.Column(): + # Output model name + output_model = gr.Textbox(label="Output Model Name", value="merged_model.safetensors") + exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) + merge_btn = gr.Button("Merge Models", variant="primary") + merge_status = gr.Textbox(label="Status", interactive=False) + with gr.Row(): + # LoRA selection section (similar to Text2Video) + merge_lora_weights = [] + merge_lora_multipliers = [] + for i in range(4): + with gr.Column(): + merge_lora_weights.append(gr.Dropdown( + label=f"LoRA {i+1}", + choices=get_lora_options(), + value="None", + allow_custom_value=True, + interactive=True + )) + merge_lora_multipliers.append(gr.Slider( + label=f"Multiplier", + minimum=0.0, + maximum=2.0, + step=0.05, + value=1.0 + )) + with gr.Row(): + merge_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") + dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") + + #Event handlers etc + +# Toggle visibility of End Frame controls and DiT path based on fpe_use_normal_framepack + def toggle_fpe_normal_framepack_options(use_normal_fp): + f1_dit_path = "hunyuan/FramePack_F1_I2V_HY_20250503.safetensors" + normal_fp_dit_path = "hunyuan/FramePackI2V_HY_bf16.safetensors" + + updated_dit_path = normal_fp_dit_path if use_normal_fp else f1_dit_path + + # Check if the target path exists and fallback if necessary + if not os.path.exists(updated_dit_path): + fallback_path = f1_dit_path if use_normal_fp and os.path.exists(f1_dit_path) else normal_fp_dit_path if not use_normal_fp and os.path.exists(normal_fp_dit_path) else None + if fallback_path and os.path.exists(fallback_path): + print(f"Warning: DiT path '{updated_dit_path}' not found. Falling back to '{fallback_path}'.") + updated_dit_path = fallback_path + else: # If preferred and fallback are missing, stick to the intended one and let later checks handle it. + print(f"Warning: DiT path '{updated_dit_path}' not found. 
No fallback available or fallback also missing.") + + return ( + gr.update(visible=use_normal_fp), # fpe_end_frame_accordion + gr.update(visible=not use_normal_fp), # fpe_start_guidance_accordion (NEW) + gr.update(value=updated_dit_path), # fpe_transformer_path + gr.update(visible=use_normal_fp) # fpe_fp8_llm + ) + + fpe_use_normal_framepack.change( + fn=toggle_fpe_normal_framepack_options, + inputs=[fpe_use_normal_framepack], + outputs=[ + fpe_end_frame_accordion, + fpe_start_guidance_accordion, # NEW output + fpe_transformer_path, + fpe_fp8_llm + ] + ) + + fpe_generate_btn.click( + fn=process_framepack_extension_video, + inputs=[ + fpe_input_video, fpe_prompt, fpe_negative_prompt, fpe_seed, fpe_batch_count, + fpe_use_normal_framepack, fpe_end_frame, fpe_end_frame_weight, + fpe_resolution_max_dim, fpe_total_second_length, fpe_latent_window_size, + fpe_steps, fpe_cfg_scale, fpe_distilled_guidance_scale, + fpe_gpu_memory_preservation, fpe_use_teacache, fpe_no_resize, fpe_mp4_crf, + fpe_num_clean_frames, fpe_vae_batch_size, fpe_save_path, + # Model Paths + fpe_transformer_path, fpe_vae_path, fpe_text_encoder_path, + fpe_text_encoder_2_path, fpe_image_encoder_path, + # Advanced + fpe_attn_mode, fpe_fp8_llm, fpe_vae_chunk_size, fpe_vae_spatial_tile_sample_min_size, + # LoRAs + fpe_lora_folder, + fpe_lora_weights_ui[0], fpe_lora_multipliers_ui[0], + fpe_lora_weights_ui[1], fpe_lora_multipliers_ui[1], + fpe_lora_weights_ui[2], fpe_lora_multipliers_ui[2], + fpe_lora_weights_ui[3], fpe_lora_multipliers_ui[3], + # Preview (UI state, not directly passed to scripts) + fpe_enable_preview, fpe_preview_interval, + fpe_extension_only, + fpe_start_guidance_image, + fpe_start_guidance_image_clip_weight, + fpe_use_guidance_image_as_first_latent, + ], + outputs=[ + fpe_output_gallery, + fpe_preview_output_component, + fpe_batch_progress, + fpe_progress_text + ], + queue=True + ) + + fpe_stop_btn.click(fn=lambda: stop_event.set(), queue=False) + + def handle_fpe_gallery_select(evt: gr.SelectData) -> int: + return evt.index + fpe_output_gallery.select(fn=handle_fpe_gallery_select, outputs=fpe_selected_index) + + fpe_lora_refresh_outputs_list = [] + for i in range(len(fpe_lora_weights_ui)): + fpe_lora_refresh_outputs_list.extend([fpe_lora_weights_ui[i], fpe_lora_multipliers_ui[i]]) + + fpe_refresh_lora_btn.click( + fn=refresh_lora_dropdowns_simple, + inputs=[fpe_lora_folder], + outputs=fpe_lora_refresh_outputs_list + ) + + def change_to_framepack_tab(): + return gr.Tabs(selected=10) # FramePack tab has id=10 + + def handle_send_to_framepack_tab(metadata: dict) -> Tuple[str, dict, str]: # Added str return type for state value + """Prepare parameters specifically for the FramePack tab.""" + if not metadata: + # Return default/empty values for status, params, and original_dims state + return "No parameters to send", {}, "" + + # Extract the value intended for the state here + original_dims_value = metadata.get("original_dims_str", "") + + # Return status message, the full metadata for params_state, and the specific value for framepack_original_dims state + return "Parameters ready for FramePack", metadata, original_dims_value + + send_to_framepack_btn.click( + fn=handle_send_to_framepack_tab, + inputs=[metadata_output], + outputs=[status, params_state, framepack_original_dims] # Add framepack_original_dims here + ).then( + # This lambda now prepares updates for UI components (32 items) + lambda params: ( + # Prepare the full list of 32 update values first + ( + # Fetch LoRA lists from params, default to empty 
lists if not found + (weights_from_meta := params.get("lora_weights", [])), + (mults_from_meta := params.get("lora_multipliers", [])), + # Create explicitly padded lists ensuring 4 elements + (padded_weights := (weights_from_meta + ["None"] * 4)[:4]), + (padded_mults := ([float(m) for m in mults_from_meta] + [1.0] * 4)[:4]), # Ensure multipliers are floats + + # Build the list of update values + [ + params.get("prompt", "cinematic video of a cat wizard casting a spell"), + params.get("negative_prompt", ""), + # Handle resolution: Prioritize explicit W/H if valid (divisible by 8), else use target_res, else default + gr_update(value=int(params["video_width"])) if params.get("video_width") and int(params.get("video_width", 0)) > 0 and int(params.get("video_width", 0)) % 8 == 0 else gr_update(value=None), + gr_update(value=int(params["video_height"])) if params.get("video_height") and int(params.get("video_height", 0)) > 0 and int(params.get("video_height", 0)) % 8 == 0 else gr_update(value=None), + # Use target resolution only if explicit width/height are *not* validly provided from metadata + gr_update(value=int(params.get("target_resolution"))) if not (params.get("video_width") and int(params.get("video_width", 0)) > 0 and int(params.get("video_width", 0)) % 8 == 0) and params.get("target_resolution") else gr_update(value=640), + params.get("video_seconds", 5.0), + params.get("fps", 30), + params.get("seed", -1), + params.get("infer_steps", 25), + params.get("embedded_cfg_scale", 10.0), # Distilled Guidance + params.get("guidance_scale", 1.0), # CFG + params.get("guidance_rescale", 0.0), # RS + params.get("sample_solver", "unipc"), + # Unpack the *padded* lists + *padded_weights, # 4 items + *padded_mults, # 4 items + # Performance/Memory + params.get("fp8", False), + params.get("fp8_scaled", False), + params.get("fp8_llm", False), + params.get("blocks_to_swap", 26), + params.get("bulk_decode", False), + params.get("attn_mode", "sdpa"), + params.get("vae_chunk_size", 32), + params.get("vae_spatial_tile_sample_min_size", 128), + params.get("device", ""), + # End Frame Blending Params - Use UI defaults + params.get("end_frame_influence", "last"), + params.get("end_frame_weight", 0.5), + params.get("is_f1", False) + ] + )[-1] # Return the list of values we just built + ) if params else [gr.update()] * 32, + inputs=params_state, # Read parameters from state + outputs=[ + # Map to FramePack components (UI only - 32 components) + framepack_prompt, + framepack_negative_prompt, + framepack_width, # Will be updated or set to None + framepack_height, # Will be updated or set to None + framepack_target_resolution, # Will be updated or set to None/default + framepack_total_second_length, + framepack_fps, + framepack_seed, + framepack_steps, + framepack_distilled_guidance_scale, + framepack_guidance_scale, + framepack_guidance_rescale, + framepack_sample_solver, + # LoRAs (unpacking the lists - 8 components total) + *framepack_lora_weights, # 4 components + *framepack_lora_multipliers, # 4 components + # Performance/Memory + framepack_fp8, + framepack_fp8_scaled, + framepack_fp8_llm, + framepack_blocks_to_swap, + framepack_bulk_decode, + framepack_attn_mode, + framepack_vae_chunk_size, + framepack_vae_spatial_tile_sample_min_size, + framepack_device, + # Map to new UI components + framepack_end_frame_influence, + framepack_end_frame_weight, + framepack_is_f1 + ] + ).then( + fn=change_to_framepack_tab, # Switch to the FramePack tab + inputs=None, + outputs=[tabs] + ) + # Connect FramePack Generate button 
+ def update_framepack_image_dimensions(image): + """Update FramePack dimensions from uploaded image, store raw dims, set default target res""" + if image is None: + return "", gr.update(value=None), gr.update(value=None), gr.update(value=640) # Reset W/H, default target res + try: + img = Image.open(image) + w, h = img.size + original_dims_str = f"{w}x{h}" # Store raw WxH + target_res_default = 640 + # Return original dims string, clear explicit W/H, set default target res + return original_dims_str, gr.update(value=None), gr.update(value=None), gr.update(value=target_res_default) + except Exception as e: + print(f"Error reading image dimensions: {e}") + return "", gr.update(value=None), gr.update(value=None), gr.update(value=640) # Fallback + + framepack_input_image.change( + fn=update_framepack_image_dimensions, + inputs=[framepack_input_image], + outputs=[framepack_original_dims, framepack_width, framepack_height, framepack_target_resolution] + ) + + framepack_prompt.change(fn=count_prompt_tokens, inputs=framepack_prompt, outputs=framepack_token_counter) + # If explicit width/height is set (and valid), clear target resolution + def clear_target_res_on_explicit_change(val): + return gr.update(value=None) if val is not None and val > 0 else gr.update() + + framepack_scale_slider.change( + fn=update_framepack_from_scale, + inputs=[framepack_scale_slider, framepack_original_dims], + outputs=[framepack_width, framepack_height, framepack_target_resolution] # Also clears target res + ) + + framepack_calc_width_btn.click( + fn=calculate_framepack_width, + inputs=[framepack_height, framepack_original_dims], + outputs=[framepack_width] + ).then( + fn=clear_target_res_on_explicit_change, # Clear target res if width is manually set + inputs=[framepack_width], + outputs=[framepack_target_resolution] + ) + + framepack_calc_height_btn.click( + fn=calculate_framepack_height, + inputs=[framepack_width, framepack_original_dims], + outputs=[framepack_height] + ).then( + fn=clear_target_res_on_explicit_change, # Clear target res if height is manually set + inputs=[framepack_height], + outputs=[framepack_target_resolution] + ) + + framepack_width.change( + fn=clear_target_res_on_explicit_change, + inputs=[framepack_width], + outputs=[framepack_target_resolution] + ) + framepack_height.change( + fn=clear_target_res_on_explicit_change, + inputs=[framepack_height], + outputs=[framepack_target_resolution] + ) + + # If target resolution is set (and valid), clear explicit width/height + def clear_explicit_res_on_target_change(target_res): + return (gr.update(value=None), gr.update(value=None)) if target_res is not None and target_res > 0 else (gr.update(), gr.update()) + + framepack_target_resolution.change( + fn=clear_explicit_res_on_target_change, + inputs=[framepack_target_resolution], + outputs=[framepack_width, framepack_height] + ) + framepack_use_random_folder.change( + fn=lambda use_folder_mode: ( + gr.update(visible=use_folder_mode), # framepack_input_folder_path + gr.update(visible=use_folder_mode), # framepack_folder_options_row (which contains validate button and status) + gr.update(visible=not use_folder_mode) # framepack_input_image + ), + inputs=[framepack_use_random_folder], + outputs=[framepack_input_folder_path, framepack_folder_options_row, framepack_input_image] + ) + + # Validate folder button handler + framepack_validate_folder_btn.click( + fn=lambda folder: get_random_image_from_folder(folder)[1], # Reuse existing helper + inputs=[framepack_input_folder_path], + 
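+ # get_random_image_from_folder() appears to return a (image_path, status_message)
+ # tuple; only the status message (index [1]) is surfaced here to validate the folder.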
outputs=[framepack_folder_status_text] + ) + def toggle_f1_model_path(is_f1): + f1_path = "hunyuan/FramePack_F1_I2V_HY_20250503.safetensors" + standard_path = "hunyuan/FramePackI2V_HY_bf16.safetensors" + target_path = f1_path if is_f1 else standard_path + + # Check if the target path exists + if not os.path.exists(target_path): + print(f"Warning: F1 model path '{target_path}' not found. Falling back to standard path.") + # Optionally fall back or just update with the non-existent path + # Let's fall back to standard if F1 is missing, but keep standard if standard is missing (error handled later) + if is_f1 and os.path.exists(standard_path): + print(f"Falling back to standard path: {standard_path}") + return gr.update(value=standard_path) + elif is_f1: + print(f"F1 path missing and standard path also missing. Cannot automatically switch.") + # Return the intended (missing) path, error will be caught later + return gr.update(value=target_path) + else: # Standard path is missing + print(f"Warning: Standard path '{standard_path}' not found.") + return gr.update(value=target_path) # Return the missing standard path + + print(f"Switching DiT path to: {target_path}") + return gr.update(value=target_path) + + framepack_is_f1.change( + fn=toggle_f1_model_path, + inputs=[framepack_is_f1], + outputs=[framepack_transformer_path] + ) + + framepack_generate_btn.click( + fn=process_framepack_video, + inputs=[ + framepack_prompt, framepack_negative_prompt, framepack_input_image, + framepack_input_end_frame, framepack_end_frame_influence, framepack_end_frame_weight, + framepack_transformer_path, framepack_vae_path, framepack_text_encoder_path, + framepack_text_encoder_2_path, framepack_image_encoder_path, + framepack_target_resolution, framepack_width, framepack_height, framepack_original_dims, + framepack_total_second_length, framepack_video_sections, framepack_fps, framepack_seed, framepack_steps, + framepack_distilled_guidance_scale, framepack_guidance_scale, framepack_guidance_rescale, + framepack_sample_solver, framepack_latent_window_size, + framepack_fp8, framepack_fp8_scaled, framepack_fp8_llm, + framepack_blocks_to_swap, framepack_bulk_decode, framepack_attn_mode, + framepack_vae_chunk_size, framepack_vae_spatial_tile_sample_min_size, + framepack_device, + framepack_use_teacache, + framepack_teacache_steps, + framepack_teacache_thresh, + framepack_batch_size, framepack_save_path, + framepack_lora_folder, + framepack_enable_preview, + framepack_preview_every_n_sections, + framepack_is_f1, + framepack_use_random_folder, + framepack_input_folder_path, + *framepack_secs, *framepack_sec_prompts, *framepack_sec_images, + *framepack_lora_weights, *framepack_lora_multipliers + ], + outputs=[ + framepack_output, # Main gallery + framepack_preview_output, # Preview video player + framepack_batch_progress, # Status text + framepack_progress_text # Progress text + ], + queue=True + ) + + framepack_random_seed.click( + fn=set_random_seed, + inputs=None, + outputs=[framepack_seed] + ) + # Connect FramePack Stop button + framepack_stop_btn.click(fn=lambda: stop_event.set(), queue=False) + + # Connect FramePack Gallery selection + def handle_framepack_gallery_select(evt: gr.SelectData) -> int: + return evt.index + + framepack_output.select( + fn=handle_framepack_gallery_select, + outputs=framepack_selected_index + ) + + # FramePack LoRA Refresh Button Handler + framepack_lora_refresh_outputs = [] + for i in range(len(framepack_lora_weights)): + framepack_lora_refresh_outputs.extend([framepack_lora_weights[i], 
framepack_lora_multipliers[i]]) + + framepack_refresh_lora_btn.click( + fn=refresh_lora_dropdowns_simple, # Use the new simplified function + inputs=[framepack_lora_folder], # Only needs the folder path as input + outputs=framepack_lora_refresh_outputs # Still outputs updates to all 8 components + ) + def trigger_skip(): + """Sets the skip event and returns a status message.""" + print("FramePack Skip button clicked, setting skip_event.") + skip_event.set() + return "Skip signal sent..." + + framepack_skip_btn.click( + fn=trigger_skip, + inputs=None, + outputs=[framepack_batch_progress], # Update status text + queue=False # Send signal immediately + ) + + def toggle_fun_control(use_fun_control): + """Toggle control video visibility and update task suffix""" + # Only update visibility, don't try to set paths + return gr.update(visible=use_fun_control) + + def update_task_for_funcontrol(use_fun_control, current_task): + """Add or remove -FC suffix from task based on checkbox""" + if use_fun_control: + if not current_task.endswith("-FC"): + if "i2v" in current_task: + return "i2v-14B-FC" + elif "t2v" in current_task: + return "t2v-14B-FC" + return current_task + else: + if current_task.endswith("-FC"): + return current_task.replace("-FC", "") + return current_task + + wanx_use_fun_control.change( + fn=lambda x: (gr.update(visible=x), gr.update(visible=x), gr.update(visible=x), gr.update(visible=x)), + inputs=[wanx_use_fun_control], + outputs=[wanx_control_video, wanx_control_strength, wanx_control_start, wanx_control_end] + ) + + # Make task change update checkbox state + def update_from_task(task): + """Update Fun-Control checkbox and control video visibility based on task""" + is_fun_control = "-FC" in task + return gr.update(value=is_fun_control), gr.update(visible=is_fun_control) + + wanx_task.change( + fn=update_from_task, + inputs=[wanx_task], + outputs=[wanx_use_fun_control, wanx_control_video] + ) + wanx_enable_cfg_skip.change( + fn=lambda x: gr.update(visible=x), + inputs=[wanx_enable_cfg_skip], + outputs=[wanx_cfg_skip_options] + ) + + wanx_t2v_enable_cfg_skip.change( + fn=lambda x: gr.update(visible=x), + inputs=[wanx_t2v_enable_cfg_skip], + outputs=[wanx_t2v_cfg_skip_options] + ) + + wanx_v2v_enable_cfg_skip.change( + fn=lambda x: gr.update(visible=x), + inputs=[wanx_v2v_enable_cfg_skip], + outputs=[wanx_v2v_cfg_skip_options] + ) + + #WanX-v2v tab functions + wanx_v2v_prompt.change(fn=count_prompt_tokens, inputs=wanx_v2v_prompt, outputs=wanx_v2v_token_counter) + + # Stop button handler + wanx_v2v_stop_btn.click(fn=lambda: stop_event.set(), queue=False) + + # Video input handling + wanx_v2v_input.change( + fn=update_wanx_v2v_dimensions, + inputs=[wanx_v2v_input], + outputs=[wanx_v2v_original_dims, wanx_v2v_width, wanx_v2v_height] + ) + + # Flow shift recommendation button + wanx_v2v_recommend_flow_btn.click( + fn=recommend_wanx_flow_shift, + inputs=[wanx_v2v_width, wanx_v2v_height], + outputs=[wanx_v2v_flow_shift] + ) + + # Width/height calculation buttons + wanx_v2v_calc_width_btn.click( + fn=calculate_wanx_width, # Reuse function from WanX tabs + inputs=[wanx_v2v_height, wanx_v2v_original_dims], + outputs=[wanx_v2v_width] + ) + + wanx_v2v_calc_height_btn.click( + fn=calculate_wanx_height, # Reuse function from WanX tabs + inputs=[wanx_v2v_width, wanx_v2v_original_dims], + outputs=[wanx_v2v_height] + ) + + # Scale slider handling for adjusting dimensions + wanx_v2v_scale_slider.change( + fn=update_wanx_from_scale, # Reuse function from WanX tabs + inputs=[wanx_v2v_scale_slider, 
wanx_v2v_original_dims], + outputs=[wanx_v2v_width, wanx_v2v_height] + ) + + def change_to_wanx_v2v_tab(): + return gr.Tabs(selected=6) + + def send_wanx_t2v_to_v2v_input(gallery, selected_index): + """Send the selected WanX-t2v video to WanX-v2v input""" + if gallery is None or not gallery: + return None, None + + if selected_index is None and len(gallery) == 1: + selected_index = 0 + + if selected_index is None or selected_index >= len(gallery): + return None, None + + # Get the video path + item = gallery[selected_index] + video_path = parse_video_path(item) + + return video_path, "Video sent from WanX-t2v tab" + + wanx_t2v_send_to_wanx_v2v_btn.click( + fn=send_wanx_t2v_to_v2v_input, + inputs=[wanx_t2v_output, wanx_t2v_selected_index], + outputs=[wanx_v2v_input, wanx_v2v_batch_progress] + ).then( + fn=lambda prompt: prompt, + inputs=[wanx_t2v_prompt], + outputs=[wanx_v2v_prompt] + ).then( + fn=change_to_wanx_v2v_tab, + inputs=None, + outputs=[tabs] + ) + + # Send video from WanX-i2v to WanX-v2v + wanx_i2v_send_to_wanx_v2v_btn.click( + fn=send_wanx_t2v_to_v2v_input, # Reuse the same function + inputs=[wanx_output, wanx_i2v_selected_index], + outputs=[wanx_v2v_input, wanx_v2v_batch_progress] + ).then( + fn=lambda prompt: prompt, + inputs=[wanx_prompt], + outputs=[wanx_v2v_prompt] + ).then( + fn=change_to_wanx_v2v_tab, + inputs=None, + outputs=[tabs] + ) + + # Update model paths when task changes + def update_model_paths_for_task(task): + if "1.3B" in task: + return gr.update(value="wan/wan2.1_t2v_1.3B_fp16.safetensors") + else: + return gr.update(value="wan/wan2.1_t2v_14B_fp16.safetensors") + + wanx_v2v_task.change( + fn=update_model_paths_for_task, + inputs=[wanx_v2v_task], + outputs=[wanx_v2v_dit_path] + ) + + # Generate button handler + wanx_v2v_generate_btn.click( + fn=wanx_v2v_batch_handler, + inputs=[ + wanx_v2v_prompt, + wanx_v2v_negative_prompt, + wanx_v2v_input, + wanx_v2v_width, + wanx_v2v_height, + wanx_v2v_video_length, + wanx_v2v_fps, + wanx_v2v_infer_steps, + wanx_v2v_flow_shift, + wanx_v2v_guidance_scale, + wanx_v2v_strength, + wanx_v2v_seed, + wanx_v2v_batch_size, + wanx_v2v_task, + wanx_v2v_dit_folder, + wanx_v2v_dit_path, + wanx_v2v_vae_path, + wanx_v2v_t5_path, + wanx_v2v_save_path, + wanx_v2v_output_type, + wanx_v2v_sample_solver, + wanx_v2v_exclude_single_blocks, + wanx_v2v_attn_mode, + wanx_v2v_block_swap, + wanx_v2v_fp8, + wanx_v2v_fp8_scaled, + wanx_v2v_fp8_t5, + wanx_v2v_lora_folder, + wanx_v2v_slg_layers, + wanx_v2v_slg_start, + wanx_v2v_slg_end, + wanx_v2v_enable_cfg_skip, + wanx_v2v_cfg_skip_mode, + wanx_v2v_cfg_apply_ratio, + *wanx_v2v_lora_weights, + *wanx_v2v_lora_multipliers + ], + outputs=[wanx_v2v_output, wanx_v2v_batch_progress, wanx_v2v_progress_text], + queue=True + ).then( + fn=lambda batch_size: 0 if batch_size == 1 else None, + inputs=[wanx_v2v_batch_size], + outputs=wanx_v2v_selected_index + ) + + # Gallery selection handling + wanx_v2v_output.select( + fn=handle_wanx_v2v_gallery_select, + outputs=wanx_v2v_selected_index + ) + def change_to_tab_two(): + return gr.Tabs(selected=2) + + # Send to Hunyuan v2v tab + wanx_v2v_send_to_v2v_btn.click( + fn=send_wanx_v2v_to_hunyuan_v2v, + inputs=[ + wanx_v2v_output, + wanx_v2v_prompt, + wanx_v2v_selected_index, + wanx_v2v_width, + wanx_v2v_height, + wanx_v2v_video_length, + wanx_v2v_fps, + wanx_v2v_infer_steps, + wanx_v2v_seed, + wanx_v2v_flow_shift, + wanx_v2v_guidance_scale, + wanx_v2v_negative_prompt + ], + outputs=[ + v2v_input, + v2v_prompt, + v2v_width, + v2v_height, + v2v_video_length, + v2v_fps, + 
v2v_infer_steps, + v2v_seed, + v2v_flow_shift, + v2v_cfg_scale, + v2v_negative_prompt + ] + ).then( + fn=change_to_tab_two, + inputs=None, + outputs=[tabs] + ) + + # Add refresh button handler for WanX-v2v tab + wanx_v2v_refresh_outputs = [wanx_v2v_dit_path] # This is one output + for i in range(4): + wanx_v2v_refresh_outputs.extend([wanx_v2v_lora_weights[i], wanx_v2v_lora_multipliers[i]]) # This adds 8 more outputs + + wanx_v2v_refresh_btn.click( + fn=update_dit_and_lora_dropdowns, # We need to use this function instead + inputs=[wanx_v2v_dit_folder, wanx_v2v_lora_folder, wanx_v2v_dit_path] + wanx_v2v_lora_weights + wanx_v2v_lora_multipliers, + outputs=wanx_v2v_refresh_outputs + ) + + # Add function to send videos from Video Info tab to WanX-v2v + def send_to_wanx_v2v(metadata: dict, video_path: str) -> Tuple[str, Dict, str]: + """Handle both parameters and video transfer from Video Info to WanX-v2v tab with debugging""" + if not video_path: + return "No video selected", {}, None + + # Print debug information + print(f"VIDEO INFO TO WANX-V2V TRANSFER:") + print(f"Original metadata: {metadata}") + print(f"Video path: {video_path}") + + # Special handling for WanX-v2v prompt fields + # Create a copy of metadata with explicit prompt fields + enhanced_metadata = metadata.copy() + if "prompt" in metadata: + enhanced_metadata["wanx_v2v_prompt"] = metadata["prompt"] + if "negative_prompt" in metadata: + enhanced_metadata["wanx_v2v_negative_prompt"] = metadata["negative_prompt"] + + print(f"Enhanced metadata: {enhanced_metadata}") + + status_msg, params = send_parameters_to_tab(enhanced_metadata, "wanx_v2v") + print(f"Mapped parameters: {params}") + + return f"Parameters ready for WanX-v2v (DEBUG INFO IN CONSOLE)", enhanced_metadata, video_path + + # Then, implement a proper handler to change to the WanX-v2v tab + def change_to_wanx_v2v_tab(): + return gr.Tabs(selected=6) # WanX-v2v is tab index 6 + + # Next, connect the button to the functions with proper parameter mapping + send_to_wanx_v2v_btn.click( + fn=lambda m, v: handle_send_to_wanx_tab(m, 'wanx_v2v', v), + inputs=[metadata_output, video_input], + outputs=[status, params_state, wanx_v2v_input] + ).then( + lambda params: [ + params.get("prompt", ""), + params.get("width", 832), + params.get("height", 480), + params.get("video_length", 81), + params.get("fps", 16), + params.get("infer_steps", 40), + params.get("seed", -1), + params.get("flow_shift", 5.0), + params.get("guidance_scale", 5.0), + params.get("attn_mode", "sdpa"), + params.get("block_swap", 0), + params.get("negative_prompt", ""), + params.get("strength", 0.75), + *[params.get("lora_weights", ["None"]*4)[i] if isinstance(params.get("lora_weights", []), list) and i < len(params.get("lora_weights", [])) else "None" for i in range(4)], + *[params.get("lora_multipliers", [1.0]*4)[i] if isinstance(params.get("lora_multipliers", []), list) and i < len(params.get("lora_multipliers", [])) else 1.0 for i in range(4)] + ] if params else [gr.update()]*21, + inputs=params_state, + outputs=[ + wanx_v2v_prompt, + wanx_v2v_width, + wanx_v2v_height, + wanx_v2v_video_length, + wanx_v2v_fps, + wanx_v2v_infer_steps, + wanx_v2v_seed, + wanx_v2v_flow_shift, + wanx_v2v_guidance_scale, + wanx_v2v_attn_mode, + wanx_v2v_block_swap, + wanx_v2v_negative_prompt, + wanx_v2v_strength, + *wanx_v2v_lora_weights, + *wanx_v2v_lora_multipliers + ] + ).then( + fn=change_to_wanx_v2v_tab, inputs=None, outputs=[tabs] + ) + + #Video Extension + wanx_send_last_frame_btn.click( + fn=send_last_frame_handler, + 
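+ # Video extension workflow: this handler places the last frame of the selected
+ # gallery video into the i2v image input and stores the clip as the base video,
+ # so the Extend button wired below can append newly generated frames to it.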
inputs=[wanx_output, wanx_i2v_selected_index], + outputs=[wanx_input, wanx_base_video] + ) + + wanx_extend_btn.click( + fn=prepare_for_batch_extension, + inputs=[wanx_input, wanx_base_video, wanx_batch_size], + outputs=[wanx_input, wanx_base_video, wanx_batch_size, wanx_batch_progress, wanx_progress_text] + ).then( + fn=lambda batch_size, base_video: + "Starting batch extension..." if base_video and batch_size > 0 else + "Error: Missing base video or invalid batch size", + inputs=[wanx_batch_size, wanx_base_video], + outputs=[wanx_batch_progress] + ).then( + # Process batch extension one at a time + fn=process_batch_extension, + inputs=[ + wanx_prompt, + wanx_negative_prompt, + wanx_input, # Input image (last frame) + wanx_base_video, # Base video to extend + wanx_width, + wanx_height, + wanx_video_length, + wanx_fps, + wanx_infer_steps, + wanx_flow_shift, + wanx_guidance_scale, + wanx_seed, + wanx_batch_size, + wanx_task, + wanx_dit_folder, # <<< Pass the folder path + wanx_dit_path, # <<< Pass the model filename + wanx_vae_path, + wanx_t5_path, + wanx_clip_path, + wanx_save_path, + wanx_output_type, + wanx_sample_solver, + wanx_exclude_single_blocks, + wanx_attn_mode, + wanx_block_swap, + wanx_fp8, + wanx_fp8_scaled, + wanx_fp8_t5, + wanx_lora_folder, + wanx_slg_layers, + wanx_slg_start, + wanx_slg_end, + # Pass LoRA weights and multipliers individually + wanx_lora_weights[0], + wanx_lora_weights[1], + wanx_lora_weights[2], + wanx_lora_weights[3], + wanx_lora_multipliers[0], + wanx_lora_multipliers[1], + wanx_lora_multipliers[2], + wanx_lora_multipliers[3] + ], + outputs=[wanx_output, wanx_batch_progress, wanx_progress_text] + ) + + # Extract and send sharpest frame to input + wanx_send_sharpest_frame_btn.click( + fn=send_sharpest_frame_handler, + inputs=[wanx_output, wanx_i2v_selected_index, wanx_frames_to_check], + outputs=[wanx_input, wanx_base_video, wanx_sharpest_frame_number, wanx_sharpest_frame_status] + ) + + # Trim video to sharpest frame and prepare for extension + wanx_trim_and_extend_btn.click( + fn=trim_and_prepare_for_extension, + inputs=[wanx_base_video, wanx_sharpest_frame_number, wanx_save_path], + outputs=[wanx_trimmed_video_path, wanx_sharpest_frame_status] + ).then( + fn=lambda path, status: (path, status if "Failed" in status else "Video trimmed successfully and ready for extension"), + inputs=[wanx_trimmed_video_path, wanx_sharpest_frame_status], + outputs=[wanx_base_video, wanx_sharpest_frame_status] + ) + + wanx_extend_with_trimmed_btn.click( + # Prepare step: Sets the base video to the trimmed video path + fn=prepare_for_batch_extension, + inputs=[wanx_input, wanx_trimmed_video_path, wanx_batch_size], # Use trimmed video path here + outputs=[wanx_input, wanx_base_video, wanx_batch_size, wanx_batch_progress, wanx_progress_text] # Update base_video state + ).then( + # Actual extension processing step + fn=process_batch_extension, + inputs=[ + wanx_prompt, + wanx_negative_prompt, + wanx_input, # Input image (sharpest frame) + wanx_trimmed_video_path, # Base video to extend (the trimmed one) + wanx_width, + wanx_height, + wanx_video_length, + wanx_fps, + wanx_infer_steps, + wanx_flow_shift, + wanx_guidance_scale, + wanx_seed, + wanx_batch_size, + wanx_task, + wanx_dit_folder, # <<< Pass the folder path + wanx_dit_path, # <<< Pass the model filename + wanx_vae_path, + wanx_t5_path, + wanx_clip_path, + wanx_save_path, + wanx_output_type, + wanx_sample_solver, + wanx_exclude_single_blocks, + wanx_attn_mode, + wanx_block_swap, + wanx_fp8, + wanx_fp8_scaled, + 
wanx_fp8_t5, + wanx_lora_folder, + wanx_slg_layers, + wanx_slg_start, + wanx_slg_end, + # Pass LoRA weights and multipliers individually + wanx_lora_weights[0], + wanx_lora_weights[1], + wanx_lora_weights[2], + wanx_lora_weights[3], + wanx_lora_multipliers[0], + wanx_lora_multipliers[1], + wanx_lora_multipliers[2], + wanx_lora_multipliers[3] + ], + outputs=[wanx_output, wanx_batch_progress, wanx_progress_text] + ) + + #Video Info + def handle_send_to_wanx_tab(metadata, target_tab, video_path=None): + """Common handler for sending video parameters to WanX tabs""" + if not metadata: + return "No parameters to send", {}, None # Return three values + + # Tab names for clearer messages + tab_names = { + 'wanx_i2v': 'WanX-i2v', + 'wanx_t2v': 'WanX-t2v', + 'wanx_v2v': 'WanX-v2v' + } + + # Just pass through all parameters - we'll use them in the .then() function + return f"Parameters ready for {tab_names.get(target_tab, target_tab)}", metadata, video_path + + def change_to_wanx_i2v_tab(): + return gr.Tabs(selected=4) # WanX-i2v tab index + + def change_to_wanx_t2v_tab(): + return gr.Tabs(selected=5) # WanX-t2v tab index + + + send_to_wanx_i2v_btn.click( + fn=lambda m: ("Parameters ready for WanX-i2v", m), + inputs=[metadata_output], + outputs=[status, params_state] + ).then( + # Reusing the same pattern as other tab transfers with LoRA handling + lambda params: [ + params.get("prompt", ""), + params.get("width", 832), + params.get("height", 480), + params.get("video_length", 81), + params.get("fps", 16), + params.get("infer_steps", 40), + params.get("seed", -1), + params.get("flow_shift", 3.0), + params.get("guidance_scale", 5.0), + params.get("attn_mode", "sdpa"), + params.get("block_swap", 0), + params.get("task", "i2v-14B"), + params.get("negative_prompt", ""), + *[params.get("lora_weights", ["None"]*4)[i] if isinstance(params.get("lora_weights", []), list) and i < len(params.get("lora_weights", [])) else "None" for i in range(4)], + *[params.get("lora_multipliers", [1.0]*4)[i] if isinstance(params.get("lora_multipliers", []), list) and i < len(params.get("lora_multipliers", [])) else 1.0 for i in range(4)] + ] if params else [gr.update()]*20, + inputs=params_state, + outputs=[ + wanx_prompt, wanx_width, wanx_height, wanx_video_length, + wanx_fps, wanx_infer_steps, wanx_seed, wanx_flow_shift, + wanx_guidance_scale, wanx_attn_mode, wanx_block_swap, + wanx_task, wanx_negative_prompt, + *wanx_lora_weights, + *wanx_lora_multipliers + ] + ).then( + fn=change_to_wanx_i2v_tab, + inputs=None, + outputs=[tabs] + ) + + # 3. 
Update the WanX-t2v button handler + send_to_wanx_t2v_btn.click( + fn=lambda m: handle_send_to_wanx_tab(m, 'wanx_t2v'), + inputs=[metadata_output], + outputs=[status, params_state] + ).then( + lambda params: [ + params.get("prompt", ""), + params.get("width", 832), + params.get("height", 480), + params.get("video_length", 81), + params.get("fps", 16), + params.get("infer_steps", 50), + params.get("seed", -1), + params.get("flow_shift", 5.0), + params.get("guidance_scale", 5.0), + params.get("attn_mode", "sdpa"), + params.get("block_swap", 0), + params.get("negative_prompt", ""), + *[params.get("lora_weights", ["None"]*4)[i] if isinstance(params.get("lora_weights", []), list) and i < len(params.get("lora_weights", [])) else "None" for i in range(4)], + *[params.get("lora_multipliers", [1.0]*4)[i] if isinstance(params.get("lora_multipliers", []), list) and i < len(params.get("lora_multipliers", [])) else 1.0 for i in range(4)] + ] if params else [gr.update()]*20, + inputs=params_state, + outputs=[ + wanx_t2v_prompt, + wanx_t2v_width, + wanx_t2v_height, + wanx_t2v_video_length, + wanx_t2v_fps, + wanx_t2v_infer_steps, + wanx_t2v_seed, + wanx_t2v_flow_shift, + wanx_t2v_guidance_scale, + wanx_t2v_attn_mode, + wanx_t2v_block_swap, + wanx_t2v_negative_prompt, + *wanx_t2v_lora_weights, + *wanx_t2v_lora_multipliers + ] + ).then( + fn=change_to_wanx_t2v_tab, inputs=None, outputs=[tabs] + ) + # FramePack-Extension send-to logic + def handle_send_to_fpe_tab(metadata: dict, video_path: str) -> Tuple[str, Dict, str]: + """Prepare parameters and video path for the FramePack-Extension tab.""" + if not video_path: + return "No video selected to send to FramePack-Extension", {}, None + + # If metadata is empty, provide a message but still allow video transfer + status_msg = "Parameters ready for FramePack-Extension." + if not metadata: + status_msg = "Video sent to FramePack-Extension (no parameters found in metadata)." 
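+ # An empty dict (set below) keeps the downstream .then() mapping safe: every
+ # value there is read via params.get(...) with a default, so missing metadata
+ # simply leaves the FramePack-Extension controls at their defaults.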
+ metadata = {} # Ensure metadata is a dict + + return status_msg, metadata, video_path + + def change_to_fpe_tab(): + return gr.Tabs(selected=11) # FramePack-Extension tab has id=11 + + send_to_fpe_btn.click( + fn=handle_send_to_fpe_tab, + inputs=[metadata_output, video_input], + outputs=[status, params_state, fpe_input_video] # status, state for params, and video input for FPE + ).then( + lambda params: ( + ( + (is_f1_from_meta := params.get("is_f1", True)), # Default to F1 if not specified + (use_normal_fp_val := not is_f1_from_meta), # fpe_use_normal_framepack is opposite of is_f1 + + # Determine resolution_max_dim + (target_res_meta := params.get("target_resolution")), + (video_w_meta := params.get("video_width")), + (video_h_meta := params.get("video_height")), + ( + res_max_dim_val := int(target_res_meta) if target_res_meta and int(target_res_meta) > 0 + else max(int(video_w_meta), int(video_h_meta)) if video_w_meta and video_h_meta and int(video_w_meta) > 0 and int(video_h_meta) > 0 + else 640 # Default + ), + # LoRA handling + (weights_from_meta := params.get("lora_weights", [])), + (mults_from_meta := params.get("lora_multipliers", [])), + (padded_weights := (weights_from_meta + ["None"] * 4)[:4]), + (padded_mults := ([float(m) if isinstance(m, (int, float, str)) and str(m).replace('.', '', 1).isdigit() else 1.0 for m in mults_from_meta] + [1.0] * 4)[:4]), + + [ + params.get("prompt", "cinematic video of a cat wizard casting a spell"), + params.get("negative_prompt", ""), + params.get("seed", -1), + use_normal_fp_val, + # fpe_end_frame and fpe_end_frame_weight are typically not in generic metadata, use defaults + gr_update(value=None), # fpe_end_frame (Image) + gr_update(value=1.0), # fpe_end_frame_weight + res_max_dim_val, + params.get("video_seconds", params.get("total_second_length", 5.0)), # Map from FramePack's video_seconds + params.get("latent_window_size", 9), + params.get("infer_steps", params.get("steps", 25)), # Map from FramePack's infer_steps + params.get("guidance_scale", params.get("cfg_scale", 1.0)), # Map from FramePack's guidance_scale to fpe_cfg_scale + params.get("embedded_cfg_scale", params.get("distilled_guidance_scale", 3.0)), # Map from FramePack's embedded_cfg_scale + # Model Paths - use FPE defaults or specific paths from metadata if available + # The DiT path is now primarily handled by the fpe_use_normal_framepack.change event + params.get("transformer_path", "hunyuan/FramePack_F1_I2V_HY_20250503.safetensors"), # Placeholder, will be overridden + params.get("vae_path", "hunyuan/pytorch_model.pt"), + params.get("text_encoder_path", "hunyuan/llava_llama3_fp16.safetensors"), + params.get("text_encoder_2_path", "hunyuan/clip_l.safetensors"), + params.get("image_encoder_path", "hunyuan/model.safetensors"), + # Advanced performance + params.get("attn_mode", "torch"), + params.get("fp8_llm", False), # This will be correctly set by fpe_use_normal_framepack.change + params.get("vae_chunk_size", 32), + params.get("vae_spatial_tile_sample_min_size", 128), + # LoRAs + *padded_weights, + *padded_mults, + ] + )[-1] # Return the list of values + ) if params else [gr.update()] * (18 + 8), # 18 direct params + 4 lora weights + 4 lora mults + inputs=params_state, + outputs=[ + fpe_prompt, fpe_negative_prompt, fpe_seed, + fpe_use_normal_framepack, # This will trigger its own .change event + fpe_end_frame, fpe_end_frame_weight, # These are UI only if fpe_use_normal_framepack is True + fpe_resolution_max_dim, fpe_total_second_length, fpe_latent_window_size, + fpe_steps, 
fpe_cfg_scale, fpe_distilled_guidance_scale, + # Model Paths + fpe_transformer_path, # Will be set by fpe_use_normal_framepack.change + fpe_vae_path, fpe_text_encoder_path, fpe_text_encoder_2_path, fpe_image_encoder_path, + # Advanced + fpe_attn_mode, fpe_fp8_llm, # fpe_fp8_llm also set by fpe_use_normal_framepack.change + fpe_vae_chunk_size, fpe_vae_spatial_tile_sample_min_size, + # LoRAs + *fpe_lora_weights_ui, *fpe_lora_multipliers_ui, + ] + ).then( + fn=change_to_fpe_tab, + inputs=None, + outputs=[tabs] + ) + #text to video + def change_to_tab_one(): + return gr.Tabs(selected=1) #This will navigate + #video to video + + def change_to_skyreels_tab(): + return gr.Tabs(selected=3) + + #SKYREELS TAB!!! + # Add state management for dimensions + def sync_skyreels_dimensions(width, height): + return gr.update(value=width), gr.update(value=height) + + # Add this function to update the LoRA dropdowns in the SKYREELS tab + def update_skyreels_lora_dropdowns(lora_folder: str, *current_values) -> List[gr.update]: + new_choices = get_lora_options(lora_folder) + weights = current_values[:4] + multipliers = current_values[4:8] + + results = [] + for i in range(4): + weight = weights[i] if i < len(weights) else "None" + multiplier = multipliers[i] if i < len(multipliers) else 1.0 + if weight not in new_choices: + weight = "None" + results.extend([ + gr.update(choices=new_choices, value=weight), + gr.update(value=multiplier) + ]) + + return results + + # Add this function to update the models dropdown in the SKYREELS tab + def update_skyreels_model_dropdown(dit_folder: str) -> Dict: + models = get_dit_models(dit_folder) + return gr.update(choices=models, value=models[0] if models else None) + + # Add event handler for model dropdown refresh + skyreels_dit_folder.change( + fn=update_skyreels_model_dropdown, + inputs=[skyreels_dit_folder], + outputs=[skyreels_model] + ) + + # Add handlers for the refresh button + skyreels_refresh_btn.click( + fn=update_skyreels_lora_dropdowns, + inputs=[skyreels_lora_folder] + skyreels_lora_weights + skyreels_lora_multipliers, + outputs=[drop for _ in range(4) for drop in [skyreels_lora_weights[_], skyreels_lora_multipliers[_]]] + ) + # Skyreels dimension handling + def calculate_skyreels_width(height, original_dims): + if not original_dims: + return gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + aspect_ratio = orig_w / orig_h + new_width = math.floor((height * aspect_ratio) / 16) * 16 + return gr.update(value=new_width) + + def calculate_skyreels_height(width, original_dims): + if not original_dims: + return gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + aspect_ratio = orig_w / orig_h + new_height = math.floor((width / aspect_ratio) / 16) * 16 + return gr.update(value=new_height) + + def update_skyreels_from_scale(scale, original_dims): + if not original_dims: + return gr.update(), gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + new_w = math.floor((orig_w * scale / 100) / 16) * 16 + new_h = math.floor((orig_h * scale / 100) / 16) * 16 + return gr.update(value=new_w), gr.update(value=new_h) + + def update_skyreels_dimensions(image): + if image is None: + return "", gr.update(value=544), gr.update(value=544) + img = Image.open(image) + w, h = img.size + w = (w // 16) * 16 + h = (h // 16) * 16 + return f"{w}x{h}", w, h + + def handle_skyreels_gallery_select(evt: gr.SelectData) -> int: + return evt.index + + def send_skyreels_to_v2v( + gallery: list, + prompt: str, + selected_index: int, + width: int, + height: 
int, + video_length: int, + fps: int, + infer_steps: int, + seed: int, + flow_shift: float, + cfg_scale: float, + lora1: str, + lora2: str, + lora3: str, + lora4: str, + lora1_multiplier: float, + lora2_multiplier: float, + lora3_multiplier: float, + lora4_multiplier: float, + negative_prompt: str = "" # Add this parameter + ) -> Tuple: + if not gallery or selected_index is None or selected_index >= len(gallery): + return (None, "", width, height, video_length, fps, infer_steps, seed, + flow_shift, cfg_scale, lora1, lora2, lora3, lora4, + lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier, + negative_prompt) # Add negative_prompt to return + + selected_item = gallery[selected_index] + + if isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + elif isinstance(selected_item, (tuple, list)): + video_path = selected_item[0] + else: + video_path = selected_item + + if isinstance(video_path, tuple): + video_path = video_path[0] + + return (str(video_path), prompt, width, height, video_length, fps, infer_steps, seed, + flow_shift, cfg_scale, lora1, lora2, lora3, lora4, + lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier, + negative_prompt) # Add negative_prompt to return + + # Add event handlers for the SKYREELS tab + skyreels_prompt.change(fn=count_prompt_tokens, inputs=skyreels_prompt, outputs=skyreels_token_counter) + skyreels_stop_btn.click(fn=lambda: stop_event.set(), queue=False) + + # Image input handling + skyreels_input.change( + fn=update_skyreels_dimensions, + inputs=[skyreels_input], + outputs=[skyreels_original_dims, skyreels_width, skyreels_height] + ) + + skyreels_scale_slider.change( + fn=update_skyreels_from_scale, + inputs=[skyreels_scale_slider, skyreels_original_dims], + outputs=[skyreels_width, skyreels_height] + ) + + skyreels_calc_width_btn.click( + fn=calculate_skyreels_width, + inputs=[skyreels_height, skyreels_original_dims], + outputs=[skyreels_width] + ) + + skyreels_calc_height_btn.click( + fn=calculate_skyreels_height, + inputs=[skyreels_width, skyreels_original_dims], + outputs=[skyreels_height] + ) + + # Handle checkbox visibility toggling + skyreels_use_random_folder.change( + fn=lambda x: (gr.update(visible=x), gr.update(visible=x), gr.update(visible=not x)), + inputs=[skyreels_use_random_folder], + outputs=[skyreels_input_folder, skyreels_folder_status, skyreels_input] + ) + + # Validate folder button click handler + skyreels_validate_folder_btn.click( + fn=lambda folder: get_random_image_from_folder(folder)[1], + inputs=[skyreels_input_folder], + outputs=[skyreels_folder_status] + ) + + skyreels_use_random_folder.change( + fn=lambda x: gr.update(visible=x), + inputs=[skyreels_use_random_folder], + outputs=[skyreels_validate_folder_btn] + ) + + # Modify the skyreels_generate_btn.click event handler to use process_random_image_batch when folder mode is on + skyreels_generate_btn.click( + fn=batch_handler, + inputs=[ + skyreels_use_random_folder, + # Rest of the arguments + skyreels_prompt, + skyreels_negative_prompt, + skyreels_width, + skyreels_height, + skyreels_video_length, + skyreels_fps, + skyreels_infer_steps, + skyreels_seed, + skyreels_flow_shift, + skyreels_guidance_scale, + skyreels_embedded_cfg_scale, + skyreels_batch_size, + skyreels_input_folder, + skyreels_dit_folder, + skyreels_model, + skyreels_vae, + skyreels_te1, + skyreels_te2, + skyreels_save_path, + skyreels_output_type, + skyreels_attn_mode, + skyreels_block_swap, + 
skyreels_exclude_single_blocks, + skyreels_use_split_attn, + skyreels_use_fp8, + skyreels_split_uncond, + skyreels_lora_folder, + *skyreels_lora_weights, + *skyreels_lora_multipliers, + skyreels_input # Add the input image path + ], + outputs=[skyreels_output, skyreels_batch_progress, skyreels_progress_text], + queue=True + ).then( + fn=lambda batch_size: 0 if batch_size == 1 else None, + inputs=[skyreels_batch_size], + outputs=skyreels_selected_index + ) + + # Gallery selection handling + skyreels_output.select( + fn=handle_skyreels_gallery_select, + outputs=skyreels_selected_index + ) + + # Send to Video2Video handler + skyreels_send_to_v2v_btn.click( + fn=send_skyreels_to_v2v, + inputs=[ + skyreels_output, skyreels_prompt, skyreels_selected_index, + skyreels_width, skyreels_height, skyreels_video_length, + skyreels_fps, skyreels_infer_steps, skyreels_seed, + skyreels_flow_shift, skyreels_guidance_scale + ] + skyreels_lora_weights + skyreels_lora_multipliers + [skyreels_negative_prompt], # This is ok because skyreels_negative_prompt is a Gradio component + outputs=[ + v2v_input, v2v_prompt, v2v_width, v2v_height, + v2v_video_length, v2v_fps, v2v_infer_steps, + v2v_seed, v2v_flow_shift, v2v_cfg_scale + ] + v2v_lora_weights + v2v_lora_multipliers + [v2v_negative_prompt] + ).then( + fn=change_to_tab_two, + inputs=None, + outputs=[tabs] + ) + + # Refresh button handler + skyreels_refresh_outputs = [skyreels_model] + for i in range(4): + skyreels_refresh_outputs.extend([skyreels_lora_weights[i], skyreels_lora_multipliers[i]]) + + skyreels_refresh_btn.click( + fn=update_dit_and_lora_dropdowns, + inputs=[skyreels_dit_folder, skyreels_lora_folder, skyreels_model] + skyreels_lora_weights + skyreels_lora_multipliers, + outputs=skyreels_refresh_outputs + ) + + def calculate_v2v_width(height, original_dims): + if not original_dims: + return gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + aspect_ratio = orig_w / orig_h + new_width = math.floor((height * aspect_ratio) / 16) * 16 # Ensure divisible by 16 + return gr.update(value=new_width) + + def calculate_v2v_height(width, original_dims): + if not original_dims: + return gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + aspect_ratio = orig_w / orig_h + new_height = math.floor((width / aspect_ratio) / 16) * 16 # Ensure divisible by 16 + return gr.update(value=new_height) + + def update_v2v_from_scale(scale, original_dims): + if not original_dims: + return gr.update(), gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + new_w = math.floor((orig_w * scale / 100) / 16) * 16 # Ensure divisible by 16 + new_h = math.floor((orig_h * scale / 100) / 16) * 16 # Ensure divisible by 16 + return gr.update(value=new_w), gr.update(value=new_h) + + def update_v2v_dimensions(video): + if video is None: + return "", gr.update(value=544), gr.update(value=544) + cap = cv2.VideoCapture(video) + w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + cap.release() + # Make dimensions divisible by 16 + w = (w // 16) * 16 + h = (h // 16) * 16 + return f"{w}x{h}", w, h + + # Event Handlers for Video to Video Tab + v2v_input.change( + fn=update_v2v_dimensions, + inputs=[v2v_input], + outputs=[v2v_original_dims, v2v_width, v2v_height] + ) + + v2v_scale_slider.change( + fn=update_v2v_from_scale, + inputs=[v2v_scale_slider, v2v_original_dims], + outputs=[v2v_width, v2v_height] + ) + + v2v_calc_width_btn.click( + fn=calculate_v2v_width, + inputs=[v2v_height, v2v_original_dims], + 
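+ # calculate_v2v_width() keeps the source aspect ratio and snaps the result to a
+ # multiple of 16; e.g. (hypothetical values) a 1920x1080 original at height 544
+ # gives floor((544 * 1920 / 1080) / 16) * 16 = 960.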
outputs=[v2v_width] + ) + + v2v_calc_height_btn.click( + fn=calculate_v2v_height, + inputs=[v2v_width, v2v_original_dims], + outputs=[v2v_height] + ) + + ##Image 2 video dimension logic + def calculate_width(height, original_dims): + if not original_dims: + return gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + aspect_ratio = orig_w / orig_h + new_width = math.floor((height * aspect_ratio) / 16) * 16 # Changed from 8 to 16 + return gr.update(value=new_width) + + def calculate_height(width, original_dims): + if not original_dims: + return gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + aspect_ratio = orig_w / orig_h + new_height = math.floor((width / aspect_ratio) / 16) * 16 # Changed from 8 to 16 + return gr.update(value=new_height) + + def update_from_scale(scale, original_dims): + if not original_dims: + return gr.update(), gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + new_w = math.floor((orig_w * scale / 100) / 16) * 16 # Changed from 8 to 16 + new_h = math.floor((orig_h * scale / 100) / 16) * 16 # Changed from 8 to 16 + return gr.update(value=new_w), gr.update(value=new_h) + + def update_dimensions(image): + if image is None: + return "", gr.update(value=544), gr.update(value=544) + img = Image.open(image) + w, h = img.size + # Make dimensions divisible by 16 + w = (w // 16) * 16 # Changed from 8 to 16 + h = (h // 16) * 16 # Changed from 8 to 16 + return f"{w}x{h}", w, h + i2v_input.change( + fn=update_dimensions, + inputs=[i2v_input], + outputs=[original_dims, i2v_width, i2v_height] # Update correct components + ) + + scale_slider.change( + fn=update_from_scale, + inputs=[scale_slider, original_dims], + outputs=[i2v_width, i2v_height] # Update correct components + ) + + calc_width_btn.click( + fn=calculate_width, + inputs=[i2v_height, original_dims], # Update correct components + outputs=[i2v_width] + ) + + calc_height_btn.click( + fn=calculate_height, + inputs=[i2v_width, original_dims], # Update correct components + outputs=[i2v_height] + ) + + # Function to get available DiT models + def get_dit_models(dit_folder: str) -> List[str]: + if not os.path.exists(dit_folder): + return ["mp_rank_00_model_states.pt"] + models = [f for f in os.listdir(dit_folder) if f.endswith('.pt') or f.endswith('.safetensors')] + models.sort(key=str.lower) + return models if models else ["mp_rank_00_model_states.pt"] + + # Function to perform model merging + def merge_models( + dit_folder: str, + dit_model: str, + output_model: str, + exclude_single_blocks: bool, + merge_lora_folder: str, + *lora_params # Will contain both weights and multipliers + ) -> str: + try: + # Separate weights and multipliers + num_loras = len(lora_params) // 2 + weights = list(lora_params[:num_loras]) + multipliers = list(lora_params[num_loras:]) + + # Filter out "None" selections + valid_loras = [] + for weight, mult in zip(weights, multipliers): + if weight and weight != "None": + valid_loras.append((os.path.join(merge_lora_folder, weight), mult)) + + if not valid_loras: + return "No LoRA models selected for merging" + + # Create output path in the dit folder + os.makedirs(dit_folder, exist_ok=True) + output_path = os.path.join(dit_folder, output_model) + + # Prepare command + cmd = [ + sys.executable, + "merge_lora.py", + "--dit", os.path.join(dit_folder, dit_model), + "--save_merged_model", output_path + ] + + # Add LoRA weights and multipliers + weights = [weight for weight, _ in valid_loras] + multipliers = [str(mult) for _, mult in valid_loras] + 
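+ # The flags below mirror merge_lora.py's CLI; the assembled command might look
+ # like (hypothetical paths):
+ #   python merge_lora.py --dit hunyuan/mp_rank_00_model_states.pt \
+ #     --save_merged_model hunyuan/merged.safetensors \
+ #     --lora_weight lora/a.safetensors lora/b.safetensors \
+ #     --lora_multiplier 1.0 0.8 --exclude_single_blocks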
cmd.extend(["--lora_weight"] + weights) + cmd.extend(["--lora_multiplier"] + multipliers) + + if exclude_single_blocks: + cmd.append("--exclude_single_blocks") + + # Execute merge operation + result = subprocess.run( + cmd, + capture_output=True, + text=True, + check=True + ) + + if os.path.exists(output_path): + return f"Successfully merged model and saved to {output_path}" + else: + return "Error: Output file not created" + + except subprocess.CalledProcessError as e: + return f"Error during merging: {e.stderr}" + except Exception as e: + return f"Error: {str(e)}" + + # Update DiT model dropdown + def update_dit_dropdown(dit_folder: str) -> Dict: + models = get_dit_models(dit_folder) + return gr.update(choices=models, value=models[0] if models else None) + + # Connect events + merge_btn.click( + fn=merge_models, + inputs=[ + dit_folder, + dit_model, + output_model, + exclude_single_blocks, + merge_lora_folder, + *merge_lora_weights, + *merge_lora_multipliers + ], + outputs=merge_status + ) + + # Refresh buttons for both DiT and LoRA dropdowns + merge_refresh_btn.click( + fn=lambda f: update_dit_dropdown(f), + inputs=[dit_folder], + outputs=[dit_model] + ) + + # LoRA refresh handling + merge_refresh_outputs = [] + for i in range(4): + merge_refresh_outputs.extend([merge_lora_weights[i], merge_lora_multipliers[i]]) + + merge_refresh_btn.click( + fn=update_lora_dropdowns, + inputs=[merge_lora_folder] + merge_lora_weights + merge_lora_multipliers, + outputs=merge_refresh_outputs + ) + # Event handlers + prompt.change(fn=count_prompt_tokens, inputs=prompt, outputs=token_counter) + v2v_prompt.change(fn=count_prompt_tokens, inputs=v2v_prompt, outputs=v2v_token_counter) + stop_btn.click(fn=lambda: stop_event.set(), queue=False) + v2v_stop_btn.click(fn=lambda: stop_event.set(), queue=False) + + #Image_to_Video + def image_to_video(image_path, output_path, width, height, frames=240): # Add width, height parameters + img = Image.open(image_path) + + # Resize to the specified dimensions + img_resized = img.resize((width, height), Image.LANCZOS) + temp_image_path = os.path.join(os.path.dirname(output_path), "temp_resized_image.png") + img_resized.save(temp_image_path) + + # Rest of function remains the same + frame_rate = 24 + duration = frames / frame_rate + command = [ + "ffmpeg", "-loop", "1", "-i", temp_image_path, "-c:v", "libx264", + "-t", str(duration), "-pix_fmt", "yuv420p", + "-vf", f"fps={frame_rate}", output_path + ] + + try: + subprocess.run(command, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + print(f"Video saved to {output_path}") + return True + except subprocess.CalledProcessError as e: + print(f"An error occurred while creating the video: {e}") + return False + finally: + # Clean up the temporary image file + if os.path.exists(temp_image_path): + os.remove(temp_image_path) + img.close() # Make sure to close the image file explicitly + + def generate_from_image( + image_path, + prompt, width, height, video_length, fps, infer_steps, + seed, model, vae, te1, te2, save_path, flow_shift, cfg_scale, + output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, + lora_folder, strength, batch_size, *lora_params + ): + """Generate video from input image with progressive updates""" + global stop_event + stop_event.clear() + + # Create temporary video path + temp_video_path = os.path.join(save_path, f"temp_{os.path.basename(image_path)}.mp4") + + try: + # Convert image to video + if not image_to_video(image_path, temp_video_path, width, height, 
frames=video_length): + yield [], "Failed to create temporary video", "Error in video creation" + return + + # Ensure video is fully written before proceeding + time.sleep(1) + if not os.path.exists(temp_video_path) or os.path.getsize(temp_video_path) == 0: + yield [], "Failed to create temporary video", "Temporary video file is empty or missing" + return + + # Get video dimensions + try: + probe = ffmpeg.probe(temp_video_path) + video_stream = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None) + if video_stream is None: + raise ValueError("No video stream found") + width = int(video_stream['width']) + height = int(video_stream['height']) + except Exception as e: + yield [], f"Error reading video dimensions: {str(e)}", "Video processing error" + return + + # Generate the video using the temporary file + try: + generator = process_single_video( + prompt, width, height, batch_size, video_length, fps, infer_steps, + seed, model, vae, te1, te2, save_path, flow_shift, cfg_scale, + output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, + lora_folder, *lora_params, video_path=temp_video_path, strength=strength + ) + + # Forward all generator updates + for videos, batch_text, progress_text in generator: + yield videos, batch_text, progress_text + + except Exception as e: + yield [], f"Error in video generation: {str(e)}", "Generation error" + return + + except Exception as e: + yield [], f"Unexpected error: {str(e)}", "Error occurred" + return + + finally: + # Clean up temporary file + try: + if os.path.exists(temp_video_path): + os.remove(temp_video_path) + except Exception: + pass # Ignore cleanup errors + + + # Add event handlers + i2v_prompt.change(fn=count_prompt_tokens, inputs=i2v_prompt, outputs=i2v_token_counter) + i2v_stop_btn.click(fn=lambda: stop_event.set(), queue=False) + + def handle_i2v_gallery_select(evt: gr.SelectData) -> int: + """Track selected index when I2V gallery item is clicked""" + return evt.index + + def send_i2v_to_v2v( + gallery: list, + prompt: str, + selected_index: int, + width: int, + height: int, + video_length: int, + fps: int, + infer_steps: int, + seed: int, + flow_shift: float, + cfg_scale: float, + lora1: str, + lora2: str, + lora3: str, + lora4: str, + lora1_multiplier: float, + lora2_multiplier: float, + lora3_multiplier: float, + lora4_multiplier: float + ) -> Tuple[Optional[str], str, int, int, int, int, int, int, float, float, str, str, str, str, float, float, float, float]: + """Send the selected video and parameters from Image2Video tab to Video2Video tab""" + if not gallery or selected_index is None or selected_index >= len(gallery): + return None, "", width, height, video_length, fps, infer_steps, seed, flow_shift, cfg_scale, \ + lora1, lora2, lora3, lora4, lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier + + selected_item = gallery[selected_index] + + # Handle different gallery item formats + if isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + elif isinstance(selected_item, (tuple, list)): + video_path = selected_item[0] + else: + video_path = selected_item + + # Final cleanup for Gradio Video component + if isinstance(video_path, tuple): + video_path = video_path[0] + + # Use the original width and height without doubling + return (str(video_path), prompt, width, height, video_length, fps, infer_steps, seed, + flow_shift, cfg_scale, lora1, lora2, lora3, lora4, + lora1_multiplier, lora2_multiplier, 
lora3_multiplier, lora4_multiplier) + + # Generate button handler for h-basic-i2v + i2v_generate_btn.click( + fn=process_i2v_batch, # <<< Use the new batch function + inputs=[ + i2v_prompt, + i2v_input, # Image path + i2v_width, + i2v_height, + i2v_batch_size, + i2v_video_length, + i2v_fps, + i2v_infer_steps, + i2v_seed, + i2v_dit_folder, + i2v_model, + i2v_vae, + i2v_te1, + i2v_te2, + i2v_clip_vision_path, + i2v_save_path, + i2v_flow_shift, + i2v_cfg_scale, # embedded_cfg_scale + i2v_guidance_scale, # main CFG scale + i2v_output_type, + i2v_attn_mode, + i2v_block_swap, + i2v_exclude_single_blocks, + i2v_use_split_attn, + i2v_lora_folder, + i2v_vae_chunk_size, + i2v_vae_spatial_tile_min, + # --- Add negative prompt component if you have one --- + # i2v_negative_prompt, # Uncomment if you added this textbox + # --- If no negative prompt textbox, pass None or "": --- + gr.Textbox(value="", visible=False), # Placeholder if no UI element + # --- End negative prompt handling --- + i2v_use_fp8, + i2v_fp8_llm, + *i2v_lora_weights, # Pass LoRA weights components + *i2v_lora_multipliers # Pass LoRA multipliers components + ], + outputs=[i2v_output, i2v_batch_progress, i2v_progress_text], + queue=True + ).then( + fn=lambda batch_size: 0 if batch_size == 1 else None, + inputs=[i2v_batch_size], + outputs=i2v_selected_index + ) + # Send to Video2Video + i2v_output.select( + fn=handle_i2v_gallery_select, + outputs=i2v_selected_index + ) + + i2v_send_to_v2v_btn.click( + fn=send_i2v_to_v2v, # Function definition needs careful review/update if args changed + inputs=[ + i2v_output, i2v_prompt, i2v_selected_index, + i2v_width, i2v_height, # <<< Use i2v width/height + i2v_video_length, i2v_fps, i2v_infer_steps, + i2v_seed, i2v_flow_shift, i2v_cfg_scale # <<< Use i2v cfg_scale (embedded) + ] + i2v_lora_weights + i2v_lora_multipliers, # <<< Use i2v LoRAs + outputs=[ + v2v_input, v2v_prompt, + v2v_width, v2v_height, # Target V2V components + v2v_video_length, v2v_fps, v2v_infer_steps, + v2v_seed, v2v_flow_shift, v2v_cfg_scale # Target V2V components + ] + v2v_lora_weights + v2v_lora_multipliers # Target V2V LoRAs + ).then( + fn=change_to_tab_two, inputs=None, outputs=[tabs] + ) + #Video Info + def clean_video_path(video_path) -> str: + """Extract clean video path from Gradio's various return formats""" + print(f"Input video_path: {video_path}, type: {type(video_path)}") + if isinstance(video_path, dict): + path = video_path.get("name", "") + elif isinstance(video_path, (tuple, list)): + path = video_path[0] + elif isinstance(video_path, str): + path = video_path + else: + path = "" + print(f"Cleaned path: {path}") + return path + def handle_video_upload(video_path: str) -> Dict: + """Handle video upload and metadata extraction""" + if not video_path: + return {}, "No video uploaded" + + metadata = extract_video_metadata(video_path) + if not metadata: + return {}, "No metadata found in video" + + return metadata, "Metadata extracted successfully" + + def get_video_info(video_path: str) -> dict: + try: + probe = ffmpeg.probe(video_path) + video_info = next(stream for stream in probe['streams'] if stream['codec_type'] == 'video') + + width = int(video_info['width']) + height = int(video_info['height']) + fps = eval(video_info['r_frame_rate']) # This converts '30/1' to 30.0 + + # Calculate total frames + duration = float(probe['format']['duration']) + total_frames = int(duration * fps) + + # Ensure video length does not exceed 201 frames + if total_frames > 201: + total_frames = 201 + duration = total_frames / fps # 
Adjust duration accordingly + + return { + 'width': width, + 'height': height, + 'fps': fps, + 'total_frames': total_frames, + 'duration': duration # Might be useful in some contexts + } + except Exception as e: + print(f"Error extracting video info: {e}") + return {} + + def extract_video_details(video_path: str) -> Tuple[dict, str]: + metadata = extract_video_metadata(video_path) + video_details = get_video_info(video_path) + + # Combine metadata with video details + for key, value in video_details.items(): + if key not in metadata: + metadata[key] = value + + # Ensure video length does not exceed 201 frames + if 'video_length' in metadata: + metadata['video_length'] = min(metadata['video_length'], 201) + else: + metadata['video_length'] = min(video_details.get('total_frames', 0), 201) + + # Return both the updated metadata and a status message + return metadata, "Video details extracted successfully" + + def send_parameters_to_tab(metadata: Dict, target_tab: str) -> Tuple[str, Dict]: + """Create parameter mapping for target tab""" + if not metadata: + return "No parameters to send", {} + + tab_name = "Text2Video" if target_tab == "t2v" else "Video2Video" + try: + mapping = create_parameter_transfer_map(metadata, target_tab) + return f"Parameters ready for {tab_name}", mapping + except Exception as e: + return f"Error: {str(e)}", {} + + video_input.upload( + fn=extract_video_details, + inputs=video_input, + outputs=[metadata_output, status] + ) + + send_to_t2v_btn.click( + fn=lambda m: send_parameters_to_tab(m, "t2v"), + inputs=metadata_output, + outputs=[status, params_state] + ).then( + fn=change_to_tab_one, inputs=None, outputs=[tabs] + ).then( + lambda params: [ + params.get("prompt", ""), + params.get("width", 544), # Parameter mapping is fine here + params.get("height", 544), # Parameter mapping is fine here + params.get("batch_size", 1), + params.get("video_length", 25), + params.get("fps", 24), + params.get("infer_steps", 30), + params.get("seed", -1), + params.get("model", "hunyuan/mp_rank_00_model_states.pt"), + params.get("vae", "hunyuan/pytorch_model.pt"), + params.get("te1", "hunyuan/llava_llama3_fp16.safetensors"), + params.get("te2", "hunyuan/clip_l.safetensors"), + params.get("save_path", "outputs"), + params.get("flow_shift", 11.0), + params.get("cfg_scale", 7.0), + params.get("output_type", "video"), + params.get("attn_mode", "sdpa"), + params.get("block_swap", "0"), + *[params.get(f"lora{i+1}", "") for i in range(4)], + *[params.get(f"lora{i+1}_multiplier", 1.0) for i in range(4)] + ] if params else [gr.update()]*26, # This lambda returns values based on param keys + inputs=params_state, + outputs=[prompt, t2v_width, t2v_height, batch_size, video_length, fps, infer_steps, seed, # <<< CORRECTED HERE: use t2v_width, t2v_height + model, vae, te1, te2, save_path, flow_shift, cfg_scale, + output_type, attn_mode, block_swap] + lora_weights + lora_multipliers + ) + # Text to Video generation + generate_btn.click( + fn=process_batch, + inputs=[ + prompt, t2v_width, t2v_height, batch_size, video_length, fps, infer_steps, + seed, dit_folder, model, vae, te1, te2, save_path, flow_shift, cfg_scale, + output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, + lora_folder, *lora_weights, *lora_multipliers, gr.Textbox(visible=False), gr.Number(visible=False), use_fp8 + ], + outputs=[video_output, batch_progress, progress_text], + queue=True + ).then( + fn=lambda batch_size: 0 if batch_size == 1 else None, + inputs=[batch_size], + outputs=selected_index + ) + + # 
Update gallery selection handling + def handle_gallery_select(evt: gr.SelectData) -> int: + return evt.index + + # Track selected index when gallery item is clicked + video_output.select( + fn=handle_gallery_select, + outputs=selected_index + ) + + # Track selected index when Video2Video gallery item is clicked + def handle_v2v_gallery_select(evt: gr.SelectData) -> int: + """Handle gallery selection without automatically updating the input""" + return evt.index + + # Update the gallery selection event + v2v_output.select( + fn=handle_v2v_gallery_select, + outputs=v2v_selected_index + ) + + # Send button handler with gallery selection + def handle_send_button( + gallery: list, + prompt: str, + idx: int, + width: int, + height: int, + batch_size: int, + video_length: int, + fps: int, + infer_steps: int, + seed: int, + flow_shift: float, + cfg_scale: float, + lora1: str, + lora2: str, + lora3: str, + lora4: str, + lora1_multiplier: float, + lora2_multiplier: float, + lora3_multiplier: float, + lora4_multiplier: float + ) -> tuple: + if not gallery or idx is None or idx >= len(gallery): + return (None, "", width, height, batch_size, video_length, fps, infer_steps, + seed, flow_shift, cfg_scale, + lora1, lora2, lora3, lora4, + lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier, + "") # Add empty string for negative_prompt in the return values + + # Auto-select first item if only one exists and no selection made + if idx is None and len(gallery) == 1: + idx = 0 + + selected_item = gallery[idx] + + # Handle different gallery item formats + if isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + elif isinstance(selected_item, (tuple, list)): + video_path = selected_item[0] + else: + video_path = selected_item + + # Final cleanup for Gradio Video component + if isinstance(video_path, tuple): + video_path = video_path[0] + + return ( + str(video_path), + prompt, + width, + height, + batch_size, + video_length, + fps, + infer_steps, + seed, + flow_shift, + cfg_scale, + lora1, + lora2, + lora3, + lora4, + lora1_multiplier, + lora2_multiplier, + lora3_multiplier, + lora4_multiplier, + "" # Add empty string for negative_prompt + ) + + send_t2v_to_v2v_btn.click( + fn=handle_send_button, + inputs=[ + video_output, prompt, selected_index, + t2v_width, t2v_height, batch_size, video_length, + fps, infer_steps, seed, flow_shift, cfg_scale + ] + lora_weights + lora_multipliers, # Remove the string here + outputs=[ + v2v_input, + v2v_prompt, + v2v_width, + v2v_height, + v2v_batch_size, + v2v_video_length, + v2v_fps, + v2v_infer_steps, + v2v_seed, + v2v_flow_shift, + v2v_cfg_scale + ] + v2v_lora_weights + v2v_lora_multipliers + [v2v_negative_prompt] + ).then( + fn=change_to_tab_two, inputs=None, outputs=[tabs] + ) + + def handle_send_to_v2v(metadata: dict, video_path: str) -> Tuple[str, dict, str]: + """Handle both parameters and video transfer""" + status_msg, params = send_parameters_to_tab(metadata, "v2v") + return status_msg, params, video_path + + def handle_info_to_v2v(metadata: dict, video_path: str) -> Tuple[str, Dict, str]: + """Handle both parameters and video transfer from Video Info to V2V tab""" + if not video_path: + return "No video selected", {}, None + + status_msg, params = send_parameters_to_tab(metadata, "v2v") + # Just return the path directly + return status_msg, params, video_path + + # Send button click handler + send_to_v2v_btn.click( + fn=handle_info_to_v2v, + inputs=[metadata_output, video_input], + 
outputs=[status, params_state, v2v_input] + ).then( + lambda params: [ + params.get("v2v_prompt", ""), + params.get("v2v_width", 544), + params.get("v2v_height", 544), + params.get("v2v_batch_size", 1), + params.get("v2v_video_length", 25), + params.get("v2v_fps", 24), + params.get("v2v_infer_steps", 30), + params.get("v2v_seed", -1), + params.get("v2v_model", "hunyuan/mp_rank_00_model_states.pt"), + params.get("v2v_vae", "hunyuan/pytorch_model.pt"), + params.get("v2v_te1", "hunyuan/llava_llama3_fp16.safetensors"), + params.get("v2v_te2", "hunyuan/clip_l.safetensors"), + params.get("v2v_save_path", "outputs"), + params.get("v2v_flow_shift", 11.0), + params.get("v2v_cfg_scale", 7.0), + params.get("v2v_output_type", "video"), + params.get("v2v_attn_mode", "sdpa"), + params.get("v2v_block_swap", "0"), + *[params.get(f"v2v_lora_weights[{i}]", "") for i in range(4)], + *[params.get(f"v2v_lora_multipliers[{i}]", 1.0) for i in range(4)] + ] if params else [gr.update()] * 26, + inputs=params_state, + outputs=[ + v2v_prompt, v2v_width, v2v_height, v2v_batch_size, v2v_video_length, + v2v_fps, v2v_infer_steps, v2v_seed, v2v_model, v2v_vae, v2v_te1, + v2v_te2, v2v_save_path, v2v_flow_shift, v2v_cfg_scale, v2v_output_type, + v2v_attn_mode, v2v_block_swap + ] + v2v_lora_weights + v2v_lora_multipliers + ).then( + lambda: print(f"Tabs object: {tabs}"), # Debug print + outputs=None + ).then( + fn=change_to_tab_two, inputs=None, outputs=[tabs] + ) + + # Handler for sending selected video from Video2Video gallery to input + def handle_v2v_send_button(gallery: list, prompt: str, idx: int) -> Tuple[Optional[str], str]: + """Send the currently selected video in V2V gallery to V2V input""" + if not gallery or idx is None or idx >= len(gallery): + return None, "" + + selected_item = gallery[idx] + video_path = None + + # Handle different gallery item formats + if isinstance(selected_item, tuple): + video_path = selected_item[0] # Gallery returns (path, caption) + elif isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + elif isinstance(selected_item, str): + video_path = selected_item + + if not video_path: + return None, "" + + # Check if the file exists and is accessible + if not os.path.exists(video_path): + print(f"Warning: Video file not found at {video_path}") + return None, "" + + return video_path, prompt + + v2v_send_to_input_btn.click( + fn=handle_v2v_send_button, + inputs=[v2v_output, v2v_prompt, v2v_selected_index], + outputs=[v2v_input, v2v_prompt] + ).then( + lambda: gr.update(visible=True), # Ensure the video input is visible + outputs=v2v_input + ) + + # Video to Video generation + v2v_generate_btn.click( + fn=process_batch, + inputs=[ + v2v_prompt, v2v_width, v2v_height, v2v_batch_size, v2v_video_length, + v2v_fps, v2v_infer_steps, v2v_seed, v2v_dit_folder, v2v_model, v2v_vae, v2v_te1, v2v_te2, + v2v_save_path, v2v_flow_shift, v2v_cfg_scale, v2v_output_type, v2v_attn_mode, + v2v_block_swap, v2v_exclude_single_blocks, v2v_use_split_attn, v2v_lora_folder, + *v2v_lora_weights, *v2v_lora_multipliers, v2v_input, v2v_strength, + v2v_negative_prompt, v2v_cfg_scale, v2v_split_uncond, v2v_use_fp8 + ], + outputs=[v2v_output, v2v_batch_progress, v2v_progress_text], + queue=True + ).then( + fn=lambda batch_size: 0 if batch_size == 1 else None, + inputs=[v2v_batch_size], + outputs=v2v_selected_index + ) + refresh_outputs = [model] # Add model dropdown to outputs + for i in range(4): + refresh_outputs.extend([lora_weights[i], lora_multipliers[i]]) + + 
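    # Editor's note (illustrative sketch, not part of the original diff): the refresh buttons below
    # rely on a purely positional contract: update_dit_and_lora_dropdowns returns 1 + 4 * 2 = 9
    # gr.update objects (the DiT dropdown, then a weight/multiplier pair for each of the four LoRA
    # slots), so every *_refresh_outputs list must hold exactly nine components in that order.
    # A minimal sanity check under that assumption:
    assert len(refresh_outputs) == 1 + 4 * 2, (
        "refresh_outputs must match the nine updates returned by update_dit_and_lora_dropdowns"
    )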
refresh_btn.click( + fn=update_dit_and_lora_dropdowns, + inputs=[dit_folder, lora_folder, model] + lora_weights + lora_multipliers, + outputs=refresh_outputs + ) + # Image2Video refresh + i2v_refresh_outputs = [i2v_model] # Add model dropdown to outputs + for i in range(4): + i2v_refresh_outputs.extend([i2v_lora_weights[i], i2v_lora_multipliers[i]]) + + i2v_refresh_btn.click( + fn=update_dit_and_lora_dropdowns, + inputs=[i2v_dit_folder, i2v_lora_folder, i2v_model] + i2v_lora_weights + i2v_lora_multipliers, + outputs=i2v_refresh_outputs + ) + + # Video2Video refresh + v2v_refresh_outputs = [v2v_model] # Add model dropdown to outputs + for i in range(4): + v2v_refresh_outputs.extend([v2v_lora_weights[i], v2v_lora_multipliers[i]]) + + v2v_refresh_btn.click( + fn=update_dit_and_lora_dropdowns, + inputs=[v2v_dit_folder, v2v_lora_folder, v2v_model] + v2v_lora_weights + v2v_lora_multipliers, + outputs=v2v_refresh_outputs + ) + + # WanX-i2v tab connections + wanx_prompt.change(fn=count_prompt_tokens, inputs=wanx_prompt, outputs=wanx_token_counter) + wanx_stop_btn.click(fn=lambda: stop_event.set(), queue=False) + + # Image input handling for WanX-i2v + wanx_input.change( + fn=update_wanx_image_dimensions, + inputs=[wanx_input], + outputs=[wanx_original_dims, wanx_width, wanx_height] + ) + + # Scale slider handling for WanX-i2v + wanx_scale_slider.change( + fn=update_wanx_from_scale, + inputs=[wanx_scale_slider, wanx_original_dims], + outputs=[wanx_width, wanx_height] + ) + + # Width/height calculation buttons for WanX-i2v + wanx_calc_width_btn.click( + fn=calculate_wanx_width, + inputs=[wanx_height, wanx_original_dims], + outputs=[wanx_width] + ) + + wanx_calc_height_btn.click( + fn=calculate_wanx_height, + inputs=[wanx_width, wanx_original_dims], + outputs=[wanx_height] + ) + # Add visibility toggle for the folder input components + wanx_use_random_folder.change( + fn=lambda x: (gr.update(visible=x), gr.update(visible=x), gr.update(visible=x), gr.update(visible=not x)), + inputs=[wanx_use_random_folder], + outputs=[wanx_input_folder, wanx_folder_status, wanx_validate_folder_btn, wanx_input] + ) + def toggle_end_image(use_end_image): + return ( + gr.update(visible=use_end_image, interactive=use_end_image), # wanx_input_end + gr.update(visible=False) # wanx_trim_frames + ) + wanx_use_end_image.change( + fn=toggle_end_image, + inputs=[wanx_use_end_image], + outputs=[wanx_input_end, wanx_trim_frames] + ) + # Validate folder button handler + wanx_validate_folder_btn.click( + fn=lambda folder: get_random_image_from_folder(folder)[1], + inputs=[wanx_input_folder], + outputs=[wanx_folder_status] + ) + + # Flow shift recommendation buttons + wanx_recommend_flow_btn.click( + fn=recommend_wanx_flow_shift, + inputs=[wanx_width, wanx_height], + outputs=[wanx_flow_shift] + ) + + wanx_t2v_recommend_flow_btn.click( + fn=recommend_wanx_flow_shift, + inputs=[wanx_t2v_width, wanx_t2v_height], + outputs=[wanx_t2v_flow_shift] + ) + + # Generate button handler + wanx_generate_btn.click( + fn=wanx_batch_handler, + inputs=[ + wanx_use_random_folder, + wanx_prompt, + wanx_negative_prompt, + wanx_width, + wanx_height, + wanx_video_length, + wanx_fps, + wanx_infer_steps, + wanx_flow_shift, + wanx_guidance_scale, + wanx_seed, + wanx_batch_size, + wanx_input_folder, + wanx_input_end, # Make sure this is passed + wanx_task, + wanx_dit_folder, + wanx_dit_path, + wanx_vae_path, + wanx_t5_path, + wanx_clip_path, + wanx_save_path, + wanx_output_type, + wanx_sample_solver, + wanx_exclude_single_blocks, + wanx_attn_mode, + 
wanx_block_swap, + wanx_fp8, + wanx_fp8_scaled, + wanx_fp8_t5, + wanx_lora_folder, + wanx_slg_layers, + wanx_slg_start, + wanx_slg_end, + wanx_enable_cfg_skip, + wanx_cfg_skip_mode, + wanx_cfg_apply_ratio, + # --- ADDED PREVIEW INPUTS --- + wanx_enable_preview, + wanx_preview_steps, + # --- END ADDED --- + *wanx_lora_weights, + *wanx_lora_multipliers, + wanx_input, # Input image (used as input_file in handler) + wanx_control_video, # Control video + wanx_control_strength, + wanx_control_start, + wanx_control_end, + ], + outputs=[ + wanx_output, # Main video gallery + wanx_preview_output, # ADDED: Preview gallery + wanx_batch_progress, # Status text + wanx_progress_text # Progress text + ], # Now 4 outputs + queue=True + ).then( + fn=lambda batch_size: 0 if batch_size == 1 else None, + inputs=[wanx_batch_size], + outputs=wanx_i2v_selected_index + ) + + # Add refresh button handler for WanX-i2v tab + wanx_refresh_outputs = [wanx_dit_path] # Add model dropdown to outputs + for i in range(4): + wanx_refresh_outputs.extend([wanx_lora_weights[i], wanx_lora_multipliers[i]]) + + wanx_refresh_btn.click( + fn=update_dit_and_lora_dropdowns, # This function already exists and handles both updates + inputs=[wanx_dit_folder, wanx_lora_folder, wanx_dit_path] + wanx_lora_weights + wanx_lora_multipliers, + outputs=wanx_refresh_outputs + ) + wanx_dit_folder.change( + fn=update_dit_dropdown, + inputs=[wanx_dit_folder], + outputs=[wanx_dit_path] + ) + + wanx_dit_folder.change( + fn=update_dit_dropdown, + inputs=[wanx_dit_folder], + outputs=[wanx_t2v_dit_path] + ) + + wanx_dit_folder.change( + fn=update_dit_dropdown, + inputs=[wanx_dit_folder], + outputs=[wanx_v2v_dit_path] + ) + + # Gallery selection handling + wanx_output.select( + fn=handle_wanx_gallery_select, + inputs=[wanx_output], + outputs=[wanx_i2v_selected_index, wanx_base_video] + ) + + # Send to Video2Video handler + wanx_send_to_v2v_btn.click( + fn=send_wanx_to_v2v, + inputs=[ + wanx_output, # Gallery with videos + wanx_prompt, # Prompt text + wanx_i2v_selected_index, # Use the correct selected index state + wanx_width, + wanx_height, + wanx_video_length, + wanx_fps, + wanx_infer_steps, + wanx_seed, + wanx_flow_shift, + wanx_guidance_scale, + wanx_negative_prompt + ], + outputs=[ + v2v_input, # Video input in V2V tab + v2v_prompt, # Prompt in V2V tab + v2v_width, + v2v_height, + v2v_video_length, + v2v_fps, + v2v_infer_steps, + v2v_seed, + v2v_flow_shift, + v2v_cfg_scale, + v2v_negative_prompt + ] + ).then( + fn=change_to_tab_two, # Function to switch to Video2Video tab + inputs=None, + outputs=[tabs] + ) + # Connect prompt token counter + wanx_t2v_prompt.change(fn=count_prompt_tokens, inputs=wanx_t2v_prompt, outputs=wanx_t2v_token_counter) + + # Stop button handler + wanx_t2v_stop_btn.click(fn=lambda: stop_event.set(), queue=False) + + # Flow shift recommendation button + wanx_t2v_recommend_flow_btn.click( + fn=recommend_wanx_flow_shift, + inputs=[wanx_t2v_width, wanx_t2v_height], + outputs=[wanx_t2v_flow_shift] + ) + + # Task change handler to update CLIP visibility and path + def update_clip_visibility(task): + is_i2v = "i2v" in task + return gr.update(visible=is_i2v) + + wanx_t2v_task.change( + fn=update_clip_visibility, + inputs=[wanx_t2v_task], + outputs=[wanx_t2v_clip_path] + ) + + # Generate button handler for T2V + wanx_t2v_generate_btn.click( + fn=wanx_batch_handler, + inputs=[ + wanx_t2v_use_random_folder, # use_random + wanx_t2v_prompt, # prompt + wanx_t2v_negative_prompt, # negative_prompt + wanx_t2v_width, # width + wanx_t2v_height, 
# height + wanx_t2v_video_length, # video_length + wanx_t2v_fps, # fps + wanx_t2v_infer_steps, # infer_steps + wanx_t2v_flow_shift, # flow_shift + wanx_t2v_guidance_scale, # guidance_scale + wanx_t2v_seed, # seed + wanx_t2v_batch_size, # batch_size + wanx_t2v_input_folder, # input_folder_path + wanx_t2v_input_end, # wanx_input_end + wanx_t2v_task, # task + wanx_dit_folder, # dit_folder (shared) + wanx_t2v_dit_path, # dit_path + wanx_t2v_vae_path, # vae_path + wanx_t2v_t5_path, # t5_path + wanx_t2v_clip_path, # clip_path (often None for t2v) + wanx_t2v_save_path, # save_path + wanx_t2v_output_type, # output_type + wanx_t2v_sample_solver, # sample_solver + wanx_t2v_exclude_single_blocks, # exclude_single_blocks + wanx_t2v_attn_mode, # attn_mode + wanx_t2v_block_swap, # block_swap + wanx_t2v_fp8, # fp8 + wanx_t2v_fp8_scaled, # fp8_scaled + wanx_t2v_fp8_t5, # fp8_t5 + wanx_t2v_lora_folder, # lora_folder + wanx_t2v_slg_layers, # slg_layers + wanx_t2v_slg_start, # slg_start + wanx_t2v_slg_end, # slg_end + wanx_t2v_enable_cfg_skip, # enable_cfg_skip + wanx_t2v_cfg_skip_mode, # cfg_skip_mode + wanx_t2v_cfg_apply_ratio, # cfg_apply_ratio + # --- ADDED PREVIEW INPUTS --- + wanx_t2v_enable_preview, + wanx_t2v_preview_steps, + # --- END ADDED --- + *wanx_t2v_lora_weights, # *lora_params (weights) + *wanx_t2v_lora_multipliers, # *lora_params (multipliers) + # --- ADDED Placeholders for trailing args expected by wanx_batch_handler --- + gr.File(value=None, visible=False), # Placeholder for input_file (None for T2V) + gr.Video(value=None, visible=False), # Placeholder for control_video (None for T2V) + gr.Number(value=1.0, visible=False), # Placeholder for control_strength + gr.Number(value=0.0, visible=False), # Placeholder for control_start + gr.Number(value=1.0, visible=False), # Placeholder for control_end + # --- END Placeholders --- + ], + outputs=[ + wanx_t2v_output, # Main video gallery + wanx_t2v_preview_output, # ADDED: Preview gallery + wanx_t2v_batch_progress, # Status text + wanx_t2v_progress_text # Progress text + ], # Now 4 outputs + queue=True + ).then( + fn=lambda batch_size: 0 if batch_size == 1 else None, + inputs=[wanx_t2v_batch_size], + outputs=wanx_t2v_selected_index + ) + + # Add refresh button handler for WanX-t2v tab + wanx_t2v_refresh_outputs = [wanx_t2v_dit_path] # This is one output + for i in range(4): + wanx_t2v_refresh_outputs.extend([wanx_t2v_lora_weights[i], wanx_t2v_lora_multipliers[i]]) # This adds 8 more outputs + + wanx_t2v_refresh_btn.click( + fn=update_dit_and_lora_dropdowns, # Change to this function instead + inputs=[wanx_dit_folder, wanx_t2v_lora_folder, wanx_t2v_dit_path] + wanx_t2v_lora_weights + wanx_t2v_lora_multipliers, + outputs=wanx_t2v_refresh_outputs + ) + + # Gallery selection handling + wanx_t2v_output.select( + fn=handle_wanx_t2v_gallery_select, + outputs=wanx_t2v_selected_index + ) + + # Send to Video2Video handler + wanx_t2v_send_to_v2v_btn.click( + fn=send_wanx_t2v_to_v2v, + inputs=[ + wanx_t2v_output, + wanx_t2v_prompt, + wanx_t2v_selected_index, + wanx_t2v_width, + wanx_t2v_height, + wanx_t2v_video_length, + wanx_t2v_fps, + wanx_t2v_infer_steps, + wanx_t2v_seed, + wanx_t2v_flow_shift, + wanx_t2v_guidance_scale, + wanx_t2v_negative_prompt + ], + outputs=[ + v2v_input, + v2v_prompt, + v2v_width, + v2v_height, + v2v_video_length, + v2v_fps, + v2v_infer_steps, + v2v_seed, + v2v_flow_shift, + v2v_cfg_scale, + v2v_negative_prompt + ] + ).then( + fn=change_to_tab_two, + inputs=None, + outputs=[tabs] + ) + +if __name__ == "__main__": + parser = 
argparse.ArgumentParser(description="Launch H1111 Gradio Interface.") + parser.add_argument( + "--share", + action="store_true", + help="Enable Gradio sharing link." + ) + args = parser.parse_args() + + # Make sure 'outputs' directory exists + os.makedirs("outputs", exist_ok=True) + # Optional: Clean temp_frames directory on startup + #if os.path.exists("temp_frames"): + # try: shutil.rmtree("temp_frames") + # except OSError as e: print(f"Error removing temp_frames: {e}") + os.makedirs("temp_frames", exist_ok=True) + + demo.queue().launch(server_name="0.0.0.0", share=args.share) \ No newline at end of file diff --git a/h2222.py b/h2222.py new file mode 100644 index 0000000000000000000000000000000000000000..d824c72c214aac747c23189f644fb7192ab8c42d --- /dev/null +++ b/h2222.py @@ -0,0 +1,3194 @@ +import gradio as gr +from gradio import update as gr_update +import subprocess +import threading +import time +import re +import os +import random +import tiktoken +import sys +import ffmpeg +from typing import List, Tuple, Optional, Generator, Dict +import json +from gradio import themes +from gradio.themes.utils import colors +import subprocess +from PIL import Image +import math +import cv2 + +# Add global stop event +stop_event = threading.Event() + +def get_dit_models(dit_folder: str) -> List[str]: + """Get list of available DiT models in the specified folder""" + if not os.path.exists(dit_folder): + return ["mp_rank_00_model_states.pt"] + models = [f for f in os.listdir(dit_folder) if f.endswith('.pt') or f.endswith('.safetensors')] + models.sort(key=str.lower) + return models if models else ["mp_rank_00_model_states.pt"] + +def update_dit_and_lora_dropdowns(dit_folder: str, lora_folder: str, *current_values) -> List[gr.update]: + """Update both DiT and LoRA dropdowns""" + # Get model lists + dit_models = get_dit_models(dit_folder) + lora_choices = get_lora_options(lora_folder) + + # Current values processing + dit_value = current_values[0] + if dit_value not in dit_models: + dit_value = dit_models[0] if dit_models else None + + weights = current_values[1:5] + multipliers = current_values[5:9] + + results = [gr.update(choices=dit_models, value=dit_value)] + + # Add LoRA updates + for i in range(4): + weight = weights[i] if i < len(weights) else "None" + multiplier = multipliers[i] if i < len(multipliers) else 1.0 + if weight not in lora_choices: + weight = "None" + results.extend([ + gr.update(choices=lora_choices, value=weight), + gr.update(value=multiplier) + ]) + + return results + +def extract_video_metadata(video_path: str) -> Dict: + """Extract metadata from video file using ffprobe.""" + cmd = [ + 'ffprobe', + '-v', 'quiet', + '-print_format', 'json', + '-show_format', + video_path + ] + + try: + result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True) + metadata = json.loads(result.stdout.decode('utf-8')) + if 'format' in metadata and 'tags' in metadata['format']: + comment = metadata['format']['tags'].get('comment', '{}') + return json.loads(comment) + return {} + except Exception as e: + print(f"Metadata extraction failed: {str(e)}") + return {} + +def create_parameter_transfer_map(metadata: Dict, target_tab: str) -> Dict: + """Map metadata parameters to Gradio components for different tabs""" + mapping = { + 'common': { + 'prompt': ('prompt', 'v2v_prompt'), + 'width': ('width', 'v2v_width'), + 'height': ('height', 'v2v_height'), + 'batch_size': ('batch_size', 'v2v_batch_size'), + 'video_length': ('video_length', 'v2v_video_length'), + 'fps': ('fps', 
'v2v_fps'), + 'infer_steps': ('infer_steps', 'v2v_infer_steps'), + 'seed': ('seed', 'v2v_seed'), + 'model': ('model', 'v2v_model'), + 'vae': ('vae', 'v2v_vae'), + 'te1': ('te1', 'v2v_te1'), + 'te2': ('te2', 'v2v_te2'), + 'save_path': ('save_path', 'v2v_save_path'), + 'flow_shift': ('flow_shift', 'v2v_flow_shift'), + 'cfg_scale': ('cfg_scale', 'v2v_cfg_scale'), + 'output_type': ('output_type', 'v2v_output_type'), + 'attn_mode': ('attn_mode', 'v2v_attn_mode'), + 'block_swap': ('block_swap', 'v2v_block_swap') + }, + 'lora': { + 'lora_weights': [(f'lora{i+1}', f'v2v_lora_weights[{i}]') for i in range(4)], + 'lora_multipliers': [(f'lora{i+1}_multiplier', f'v2v_lora_multipliers[{i}]') for i in range(4)] + } + } + + results = {} + for param, value in metadata.items(): + # Handle common parameters + if param in mapping['common']: + target = mapping['common'][param][0 if target_tab == 't2v' else 1] + results[target] = value + + # Handle LoRA parameters + if param == 'lora_weights': + for i, weight in enumerate(value[:4]): + target = mapping['lora']['lora_weights'][i][1 if target_tab == 'v2v' else 0] + results[target] = weight + + if param == 'lora_multipliers': + for i, mult in enumerate(value[:4]): + target = mapping['lora']['lora_multipliers'][i][1 if target_tab == 'v2v' else 0] + results[target] = float(mult) + + return results + +def add_metadata_to_video(video_path: str, parameters: dict) -> None: + """Add generation parameters to video metadata using ffmpeg.""" + import json + import subprocess + + # Convert parameters to JSON string + params_json = json.dumps(parameters, indent=2) + + # Temporary output path + temp_path = video_path.replace(".mp4", "_temp.mp4") + + # FFmpeg command to add metadata without re-encoding + cmd = [ + 'ffmpeg', + '-i', video_path, + '-metadata', f'comment={params_json}', + '-codec', 'copy', + temp_path + ] + + try: + # Execute FFmpeg command + subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + # Replace original file with the metadata-enhanced version + os.replace(temp_path, video_path) + except subprocess.CalledProcessError as e: + print(f"Failed to add metadata: {e.stderr.decode()}") + if os.path.exists(temp_path): + os.remove(temp_path) + except Exception as e: + print(f"Error: {str(e)}") + +def count_prompt_tokens(prompt: str) -> int: + enc = tiktoken.get_encoding("cl100k_base") + tokens = enc.encode(prompt) + return len(tokens) + + +def get_lora_options(lora_folder: str = "lora") -> List[str]: + if not os.path.exists(lora_folder): + return ["None"] + lora_files = [f for f in os.listdir(lora_folder) if f.endswith('.safetensors') or f.endswith('.pt')] + lora_files.sort(key=str.lower) + return ["None"] + lora_files + +def update_lora_dropdowns(lora_folder: str, *current_values) -> List[gr.update]: + new_choices = get_lora_options(lora_folder) + weights = current_values[:4] + multipliers = current_values[4:8] + + results = [] + for i in range(4): + weight = weights[i] if i < len(weights) else "None" + multiplier = multipliers[i] if i < len(multipliers) else 1.0 + if weight not in new_choices: + weight = "None" + results.extend([ + gr.update(choices=new_choices, value=weight), + gr.update(value=multiplier) + ]) + + return results + +def send_to_v2v(evt: gr.SelectData, gallery: list, prompt: str, selected_index: gr.State) -> Tuple[Optional[str], str, int]: + """Transfer selected video and prompt to Video2Video tab""" + if not gallery or evt.index >= len(gallery): + return None, "", selected_index.value + + selected_item = 
gallery[evt.index] + + # Handle different gallery item formats + if isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + elif isinstance(selected_item, (tuple, list)): + video_path = selected_item[0] + else: + video_path = selected_item + + # Final cleanup for Gradio Video component + if isinstance(video_path, tuple): + video_path = video_path[0] + + # Update the selected index + selected_index.value = evt.index + + return str(video_path), prompt, evt.index + +def send_selected_to_v2v(gallery: list, prompt: str, selected_index: gr.State) -> Tuple[Optional[str], str]: + """Send the currently selected video to V2V tab""" + if not gallery or selected_index.value is None or selected_index.value >= len(gallery): + return None, "" + + selected_item = gallery[selected_index.value] + + # Handle different gallery item formats + if isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + elif isinstance(selected_item, (tuple, list)): + video_path = selected_item[0] + else: + video_path = selected_item + + # Final cleanup for Gradio Video component + if isinstance(video_path, tuple): + video_path = video_path[0] + + return str(video_path), prompt + +def clear_cuda_cache(): + """Clear CUDA cache if available""" + import torch + if torch.cuda.is_available(): + torch.cuda.empty_cache() + # Optional: synchronize to ensure cache is cleared + torch.cuda.synchronize() + +def process_single_video( + prompt: str, + width: int, + height: int, + batch_size: int, + video_length: int, + fps: int, + infer_steps: int, + seed: int, + dit_folder: str, + model: str, + vae: str, + te1: str, + te2: str, + save_path: str, + flow_shift: float, + cfg_scale: float, + output_type: str, + attn_mode: str, + block_swap: int, + exclude_single_blocks: bool, + use_split_attn: bool, + lora_folder: str, + lora1: str = "", + lora2: str = "", + lora3: str = "", + lora4: str = "", + lora1_multiplier: float = 1.0, + lora2_multiplier: float = 1.0, + lora3_multiplier: float = 1.0, + lora4_multiplier: float = 1.0, + video_path: Optional[str] = None, + image_path: Optional[str] = None, + strength: Optional[float] = None, + negative_prompt: Optional[str] = None, + embedded_cfg_scale: Optional[float] = None, + split_uncond: Optional[bool] = None, + guidance_scale: Optional[float] = None, + use_fp8: bool = True +) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: + """Generate a single video with the given parameters""" + global stop_event + + if stop_event.is_set(): + yield [], "", "" + return + + # Determine if this is a SkyReels model and what type + is_skyreels = "skyreels" in model.lower() + is_skyreels_i2v = is_skyreels and "i2v" in model.lower() + is_skyreels_t2v = is_skyreels and "t2v" in model.lower() + + if is_skyreels: + # Force certain parameters for SkyReels + if negative_prompt is None: + negative_prompt = "" + if embedded_cfg_scale is None: + embedded_cfg_scale = 1.0 # Force to 1.0 for SkyReels + if split_uncond is None: + split_uncond = True + if guidance_scale is None: + guidance_scale = cfg_scale # Use cfg_scale as guidance_scale if not provided + + # Determine the input channels based on model type + if is_skyreels_i2v: + dit_in_channels = 32 # SkyReels I2V uses 32 channels + else: + dit_in_channels = 16 # SkyReels T2V uses 16 channels (same as regular models) + else: + dit_in_channels = 16 # Regular Hunyuan models use 16 channels + embedded_cfg_scale = cfg_scale + + if os.path.isabs(model): + 
model_path = model + else: + model_path = os.path.normpath(os.path.join(dit_folder, model)) + + env = os.environ.copy() + env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") + env["PYTHONIOENCODING"] = "utf-8" + env["BATCH_RUN_ID"] = f"{time.time()}" + + if seed == -1: + current_seed = random.randint(0, 2**32 - 1) + else: + batch_id = int(env.get("BATCH_RUN_ID", "0").split('.')[-1]) + if batch_size > 1: # Only modify seed for batch generation + current_seed = (seed + batch_id * 100003) % (2**32) + else: + current_seed = seed + + clear_cuda_cache() + + command = [ + sys.executable, + "hv_generate_video.py", + "--dit", model_path, + "--vae", vae, + "--text_encoder1", te1, + "--text_encoder2", te2, + "--prompt", prompt, + "--video_size", str(height), str(width), + "--video_length", str(video_length), + "--fps", str(fps), + "--infer_steps", str(infer_steps), + "--save_path", save_path, + "--seed", str(current_seed), + "--flow_shift", str(flow_shift), + "--embedded_cfg_scale", str(cfg_scale), + "--output_type", output_type, + "--attn_mode", attn_mode, + "--blocks_to_swap", str(block_swap), + "--fp8_llm", + "--vae_chunk_size", "32", + "--vae_spatial_tile_sample_min_size", "128" + ] + + if use_fp8: + command.append("--fp8") + + # Add negative prompt and embedded cfg scale for SkyReels + if is_skyreels: + command.extend(["--dit_in_channels", str(dit_in_channels)]) + command.extend(["--guidance_scale", str(guidance_scale)]) + + if negative_prompt: + command.extend(["--negative_prompt", negative_prompt]) + if split_uncond: + command.append("--split_uncond") + + # Add LoRA weights and multipliers if provided + valid_loras = [] + for weight, mult in zip([lora1, lora2, lora3, lora4], + [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier]): + if weight and weight != "None": + valid_loras.append((os.path.join(lora_folder, weight), mult)) + if valid_loras: + weights = [weight for weight, _ in valid_loras] + multipliers = [str(mult) for _, mult in valid_loras] + command.extend(["--lora_weight"] + weights) + command.extend(["--lora_multiplier"] + multipliers) + + if exclude_single_blocks: + command.append("--exclude_single_blocks") + if use_split_attn: + command.append("--split_attn") + + # Handle input paths + if video_path: + command.extend(["--video_path", video_path]) + if strength is not None: + command.extend(["--strength", str(strength)]) + elif image_path: + command.extend(["--image_path", image_path]) + # Only add strength parameter for non-SkyReels I2V models + # SkyReels I2V doesn't use strength parameter for image-to-video generation + if strength is not None and not is_skyreels_i2v: + command.extend(["--strength", str(strength)]) + + print(f"{command}") + + p = subprocess.Popen( + command, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=env, + text=True, + encoding='utf-8', + errors='replace', + bufsize=1 + ) + + videos = [] + + while True: + if stop_event.is_set(): + p.terminate() + p.wait() + yield [], "", "Generation stopped by user." 
+ return + + line = p.stdout.readline() + if not line: + if p.poll() is not None: + break + continue + + print(line, end='') + if '|' in line and '%' in line and '[' in line and ']' in line: + yield videos.copy(), f"Processing (seed: {current_seed})", line.strip() + + p.stdout.close() + p.wait() + + clear_cuda_cache() + time.sleep(0.5) + + # Collect generated video + save_path_abs = os.path.abspath(save_path) + if os.path.exists(save_path_abs): + all_videos = sorted( + [f for f in os.listdir(save_path_abs) if f.endswith('.mp4')], + key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)), + reverse=True + ) + matching_videos = [v for v in all_videos if f"_{current_seed}" in v] + if matching_videos: + video_path = os.path.join(save_path_abs, matching_videos[0]) + + # Collect parameters for metadata + parameters = { + "prompt": prompt, + "width": width, + "height": height, + "video_length": video_length, + "fps": fps, + "infer_steps": infer_steps, + "seed": current_seed, + "model": model, + "vae": vae, + "te1": te1, + "te2": te2, + "save_path": save_path, + "flow_shift": flow_shift, + "cfg_scale": cfg_scale, + "output_type": output_type, + "attn_mode": attn_mode, + "block_swap": block_swap, + "lora_weights": [lora1, lora2, lora3, lora4], + "lora_multipliers": [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier], + "input_video": video_path if video_path else None, + "input_image": image_path if image_path else None, + "strength": strength, + "negative_prompt": negative_prompt if is_skyreels else None, + "embedded_cfg_scale": embedded_cfg_scale if is_skyreels else None + } + + add_metadata_to_video(video_path, parameters) + videos.append((str(video_path), f"Seed: {current_seed}")) + + yield videos, f"Completed (seed: {current_seed})", "" + +# The issue is in the process_batch function, in the section that handles different input types +# Here's the corrected version of that section: + +def process_batch( + prompt: str, + width: int, + height: int, + batch_size: int, + video_length: int, + fps: int, + infer_steps: int, + seed: int, + dit_folder: str, + model: str, + vae: str, + te1: str, + te2: str, + save_path: str, + flow_shift: float, + cfg_scale: float, + output_type: str, + attn_mode: str, + block_swap: int, + exclude_single_blocks: bool, + use_split_attn: bool, + lora_folder: str, + *args +) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: + """Process a batch of videos using Gradio's queue""" + global stop_event + stop_event.clear() + + all_videos = [] + progress_text = "Starting generation..." 
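    # Editor's note (descriptive comment, not part of the original diff): process_batch receives a
    # positional *args tail that every click handler must reproduce exactly:
    #   args[0:4] -> the four LoRA weight selections
    #   args[4:8] -> the four LoRA multipliers
    #   args[8:]  -> extra args, at most: input video/image path, strength, negative prompt,
    #                guidance/cfg scale, split_uncond, with use_fp8 as the final element
    # (the Text to Video wiring passes a shorter tail, which the length checks below tolerate).
    # Any change to a click handler's inputs list must be mirrored in the slicing that follows.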
+ yield [], "Preparing...", progress_text + + # Extract additional arguments + num_lora_weights = 4 + lora_weights = args[:num_lora_weights] + lora_multipliers = args[num_lora_weights:num_lora_weights*2] + extra_args = args[num_lora_weights*2:] + + # Determine if this is a SkyReels model and what type + is_skyreels = "skyreels" in model.lower() + is_skyreels_i2v = is_skyreels and "i2v" in model.lower() + is_skyreels_t2v = is_skyreels and "t2v" in model.lower() + + # Handle input paths and additional parameters + input_path = extra_args[0] if extra_args else None + strength = float(extra_args[1]) if len(extra_args) > 1 else None + + # Get use_fp8 flag (it should be the last parameter) + use_fp8 = bool(extra_args[-1]) if extra_args and len(extra_args) >= 3 else True + + # Get SkyReels specific parameters if applicable + if is_skyreels: + # Always set embedded_cfg_scale to 1.0 for SkyReels models + embedded_cfg_scale = 1.0 + + negative_prompt = str(extra_args[2]) if len(extra_args) > 2 and extra_args[2] is not None else "" + # Use cfg_scale for guidance_scale parameter + guidance_scale = float(extra_args[3]) if len(extra_args) > 3 and extra_args[3] is not None else cfg_scale + split_uncond = True if len(extra_args) > 4 and extra_args[4] else False + else: + negative_prompt = str(extra_args[2]) if len(extra_args) > 2 and extra_args[2] is not None else None + guidance_scale = cfg_scale + embedded_cfg_scale = cfg_scale + split_uncond = bool(extra_args[4]) if len(extra_args) > 4 else None + + for i in range(batch_size): + if stop_event.is_set(): + break + + batch_text = f"Generating video {i + 1} of {batch_size}" + yield all_videos.copy(), batch_text, progress_text + + # Handle different input types + video_path = None + image_path = None + + if input_path: + # Check if it's an image file (common image extensions) + is_image = False + lower_path = input_path.lower() + image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp') + is_image = any(lower_path.endswith(ext) for ext in image_extensions) + + # Only use image_path for SkyReels I2V models and actual image files + if is_skyreels_i2v and is_image: + image_path = input_path + else: + video_path = input_path + + # Prepare arguments for process_single_video + single_video_args = [ + prompt, width, height, batch_size, video_length, fps, infer_steps, + seed, dit_folder, model, vae, te1, te2, save_path, flow_shift, cfg_scale, + output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, + lora_folder + ] + single_video_args.extend(lora_weights) + single_video_args.extend(lora_multipliers) + single_video_args.extend([video_path, image_path, strength, negative_prompt, embedded_cfg_scale, split_uncond, guidance_scale, use_fp8]) + + for videos, status, progress in process_single_video(*single_video_args): + if videos: + all_videos.extend(videos) + yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress + + yield all_videos, "Batch complete", "" + +def update_wanx_image_dimensions(image): + """Update dimensions from uploaded image""" + if image is None: + return "", gr.update(value=832), gr.update(value=480) + img = Image.open(image) + w, h = img.size + w = (w // 32) * 32 + h = (h // 32) * 32 + return f"{w}x{h}", w, h + +def calculate_wanx_width(height, original_dims): + """Calculate width based on height maintaining aspect ratio""" + if not original_dims: + return gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + aspect_ratio = orig_w / orig_h + new_width = math.floor((height * aspect_ratio) 
/ 32) * 32 + return gr.update(value=new_width) + +def calculate_wanx_height(width, original_dims): + """Calculate height based on width maintaining aspect ratio""" + if not original_dims: + return gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + aspect_ratio = orig_w / orig_h + new_height = math.floor((width / aspect_ratio) / 32) * 32 + return gr.update(value=new_height) + +def update_wanx_from_scale(scale, original_dims): + """Update dimensions based on scale percentage""" + if not original_dims: + return gr.update(), gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + new_w = math.floor((orig_w * scale / 100) / 32) * 32 + new_h = math.floor((orig_h * scale / 100) / 32) * 32 + return gr.update(value=new_w), gr.update(value=new_h) + +def recommend_wanx_flow_shift(width, height): + """Get recommended flow shift value based on dimensions""" + recommended_shift = 3.0 if (width == 832 and height == 480) or (width == 480 and height == 832) else 5.0 + return gr.update(value=recommended_shift) + +def handle_wanx_gallery_select(evt: gr.SelectData) -> int: + """Track selected index when gallery item is clicked""" + return evt.index + +def wanx_generate_video( + prompt, + negative_prompt, + input_image, + width, + height, + video_length, + fps, + infer_steps, + flow_shift, + guidance_scale, + seed, + task, + dit_path, + vae_path, + t5_path, + clip_path, + save_path, + output_type, + sample_solver, + attn_mode, + block_swap, + fp8, + fp8_t5 +) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: + """Generate video with WanX model (supports both i2v and t2v)""" + global stop_event + + if stop_event.is_set(): + yield [], "", "" + return + + if seed == -1: + current_seed = random.randint(0, 2**32 - 1) + else: + current_seed = seed + + # Check if we need input image (required for i2v, not for t2v) + if "i2v" in task and not input_image: + yield [], "Error: No input image provided", "Please provide an input image for image-to-video generation" + return + + # Prepare environment + env = os.environ.copy() + env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") + env["PYTHONIOENCODING"] = "utf-8" + + clear_cuda_cache() + + command = [ + sys.executable, + "wan_generate_video.py", + "--task", task, + "--prompt", prompt, + "--video_size", str(height), str(width), + "--video_length", str(video_length), + "--fps", str(fps), + "--infer_steps", str(infer_steps), + "--save_path", save_path, + "--seed", str(current_seed), + "--flow_shift", str(flow_shift), + "--guidance_scale", str(guidance_scale), + "--output_type", output_type, + "--attn_mode", attn_mode, + "--blocks_to_swap", str(block_swap), + "--dit", dit_path, + "--vae", vae_path, + "--t5", t5_path, + "--sample_solver", sample_solver + ] + + # Add image path only for i2v task and if input image is provided + if "i2v" in task and input_image: + command.extend(["--image_path", input_image]) + command.extend(["--clip", clip_path]) # CLIP is only needed for i2v + + if negative_prompt: + command.extend(["--negative_prompt", negative_prompt]) + + if fp8: + command.append("--fp8") + + if fp8_t5: + command.append("--fp8_t5") + + print(f"Running: {' '.join(command)}") + + p = subprocess.Popen( + command, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=env, + text=True, + encoding='utf-8', + errors='replace', + bufsize=1 + ) + + videos = [] + + while True: + if stop_event.is_set(): + p.terminate() + p.wait() + yield [], "", "Generation stopped by user." 
+ return + + line = p.stdout.readline() + if not line: + if p.poll() is not None: + break + continue + + print(line, end='') + if '|' in line and '%' in line and '[' in line and ']' in line: + yield videos.copy(), f"Processing (seed: {current_seed})", line.strip() + + p.stdout.close() + p.wait() + + clear_cuda_cache() + time.sleep(0.5) + + # Collect generated video + save_path_abs = os.path.abspath(save_path) + if os.path.exists(save_path_abs): + all_videos = sorted( + [f for f in os.listdir(save_path_abs) if f.endswith('.mp4')], + key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)), + reverse=True + ) + matching_videos = [v for v in all_videos if f"_{current_seed}" in v] + if matching_videos: + video_path = os.path.join(save_path_abs, matching_videos[0]) + + # Collect parameters for metadata + parameters = { + "prompt": prompt, + "width": width, + "height": height, + "video_length": video_length, + "fps": fps, + "infer_steps": infer_steps, + "seed": current_seed, + "task": task, + "flow_shift": flow_shift, + "guidance_scale": guidance_scale, + "output_type": output_type, + "attn_mode": attn_mode, + "block_swap": block_swap, + "input_image": input_image if "i2v" in task else None + } + + add_metadata_to_video(video_path, parameters) + videos.append((str(video_path), f"Seed: {current_seed}")) + + yield videos, f"Completed (seed: {current_seed})", "" + +def send_wanx_to_v2v( + gallery: list, + prompt: str, + selected_index: int, + width: int, + height: int, + video_length: int, + fps: int, + infer_steps: int, + seed: int, + flow_shift: float, + guidance_scale: float, + negative_prompt: str +) -> Tuple: + """Send the selected WanX video to Video2Video tab""" + if not gallery or selected_index is None or selected_index >= len(gallery): + return (None, "", width, height, video_length, fps, infer_steps, seed, + flow_shift, guidance_scale, negative_prompt) + + selected_item = gallery[selected_index] + + if isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + elif isinstance(selected_item, (tuple, list)): + video_path = selected_item[0] + else: + video_path = selected_item + + if isinstance(video_path, tuple): + video_path = video_path[0] + + return (str(video_path), prompt, width, height, video_length, fps, infer_steps, seed, + flow_shift, guidance_scale, negative_prompt) + +def wanx_generate_video_batch( + prompt, + negative_prompt, + width, + height, + video_length, + fps, + infer_steps, + flow_shift, + guidance_scale, + seed, + task, + dit_path, + vae_path, + t5_path, + clip_path, + save_path, + output_type, + sample_solver, + attn_mode, + block_swap, + fp8, + fp8_t5, + batch_size=1, + input_image=None, # Optional for i2v + lora_folder=None, + *args +) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: + """Generate videos with WanX with support for batches and LoRA""" + global stop_event + stop_event.clear() + + all_videos = [] + progress_text = "Starting generation..." 
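    # Editor's note (descriptive comment, not part of the original diff): the *args tail unpacked
    # below is positional: args[0:4] are the four LoRA weight selections, args[4:8] the matching
    # multipliers, and an optional args[8] toggles exclude_single_blocks. Keep any caller's
    # argument order in sync with this slicing.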
+ yield [], "Preparing...", progress_text + + # Extract LoRA parameters from args + num_loras = 4 # Fixed number of LoRA inputs + lora_weights = args[:num_loras] + lora_multipliers = args[num_loras:num_loras*2] + exclude_single_blocks = args[num_loras*2] if len(args) > num_loras*2 else False + + # Process each item in the batch + for i in range(batch_size): + if stop_event.is_set(): + yield all_videos, "Generation stopped by user", "" + return + + # Calculate seed for this batch item + current_seed = seed + if seed == -1: + current_seed = random.randint(0, 2**32 - 1) + elif batch_size > 1: + current_seed = seed + i + + batch_text = f"Generating video {i + 1} of {batch_size}" + yield all_videos.copy(), batch_text, progress_text + + # Prepare command + env = os.environ.copy() + env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") + env["PYTHONIOENCODING"] = "utf-8" + + command = [ + sys.executable, + "wan_generate_video.py", + "--task", task, + "--prompt", prompt, + "--video_size", str(height), str(width), + "--video_length", str(video_length), + "--fps", str(fps), + "--infer_steps", str(infer_steps), + "--save_path", save_path, + "--seed", str(current_seed), + "--flow_shift", str(flow_shift), + "--guidance_scale", str(guidance_scale), + "--output_type", output_type, + "--attn_mode", attn_mode, + "--dit", dit_path, + "--vae", vae_path, + "--t5", t5_path, + "--sample_solver", sample_solver + ] + + # Add image path if provided (for i2v) + if input_image and "i2v" in task: + command.extend(["--image_path", input_image]) + command.extend(["--clip", clip_path]) # CLIP is needed for i2v + + # Add negative prompt if provided + if negative_prompt: + command.extend(["--negative_prompt", negative_prompt]) + + # Add block swap if provided + if block_swap > 0: + command.extend(["--blocks_to_swap", str(block_swap)]) + + # Add fp8 flags if enabled + if fp8: + command.append("--fp8") + + if fp8_t5: + command.append("--fp8_t5") + + # Add LoRA parameters + valid_loras = [] + for j, (weight, mult) in enumerate(zip(lora_weights, lora_multipliers)): + if weight and weight != "None": + valid_loras.append((os.path.join(lora_folder, weight), float(mult))) + + if valid_loras: + weights = [weight for weight, _ in valid_loras] + multipliers = [str(mult) for _, mult in valid_loras] + command.extend(["--lora_weight"] + weights) + command.extend(["--lora_multiplier"] + multipliers) + + # Add LoRA options + if exclude_single_blocks: + command.append("--exclude_single_blocks") + + print(f"Running: {' '.join(command)}") + + # Execute command + p = subprocess.Popen( + command, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=env, + text=True, + encoding='utf-8', + errors='replace', + bufsize=1 + ) + + videos = [] + + # Process output + while True: + if stop_event.is_set(): + p.terminate() + p.wait() + yield all_videos, "Generation stopped by user", "" + return + + line = p.stdout.readline() + if not line: + if p.poll() is not None: + break + continue + + print(line, end='') + if '|' in line and '%' in line and '[' in line and ']' in line: + yield all_videos.copy(), f"Batch {i+1}/{batch_size}: Processing (seed: {current_seed})", line.strip() + + p.stdout.close() + p.wait() + + # Clean CUDA cache + clear_cuda_cache() + time.sleep(0.5) + + # Collect generated video + save_path_abs = os.path.abspath(save_path) + if os.path.exists(save_path_abs): + all_video_files = sorted( + [f for f in os.listdir(save_path_abs) if f.endswith('.mp4')], + key=lambda x: 
os.path.getmtime(os.path.join(save_path_abs, x)), + reverse=True + ) + matching_videos = [v for v in all_video_files if f"_{current_seed}" in v] + if matching_videos: + video_path = os.path.join(save_path_abs, matching_videos[0]) + videos.append((str(video_path), f"Seed: {current_seed}")) + all_videos.extend(videos) + + yield all_videos, "Batch complete", "" + +def update_wanx_t2v_dimensions(size): + """Update width and height based on selected size""" + width, height = map(int, size.split('*')) + return gr.update(value=width), gr.update(value=height) + +def handle_wanx_t2v_gallery_select(evt: gr.SelectData) -> int: + """Track selected index when gallery item is clicked""" + return evt.index + +def send_wanx_t2v_to_v2v( + gallery, prompt, selected_index, width, height, video_length, + fps, infer_steps, seed, flow_shift, guidance_scale, negative_prompt +) -> Tuple: + """Send the selected WanX T2V video to Video2Video tab""" + if not gallery or selected_index is None or selected_index >= len(gallery): + return (None, "", width, height, video_length, fps, infer_steps, seed, + flow_shift, guidance_scale, negative_prompt) + + selected_item = gallery[selected_index] + + if isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + elif isinstance(selected_item, (tuple, list)): + video_path = selected_item[0] + else: + video_path = selected_item + + if isinstance(video_path, tuple): + video_path = video_path[0] + + return (str(video_path), prompt, width, height, video_length, fps, infer_steps, seed, + flow_shift, guidance_scale, negative_prompt) + +# UI setup +with gr.Blocks( + theme=themes.Default( + primary_hue=colors.Color( + name="custom", + c50="#E6F0FF", + c100="#CCE0FF", + c200="#99C1FF", + c300="#66A3FF", + c400="#3384FF", + c500="#0060df", # This is your main color + c600="#0052C2", + c700="#003D91", + c800="#002961", + c900="#001430", + c950="#000A18" + ) + ), + css=""" + .gallery-item:first-child { border: 2px solid #4CAF50 !important; } + .gallery-item:first-child:hover { border-color: #45a049 !important; } + .green-btn { + background: linear-gradient(to bottom right, #2ecc71, #27ae60) !important; + color: white !important; + border: none !important; + } + .green-btn:hover { + background: linear-gradient(to bottom right, #27ae60, #219651) !important; + } + .refresh-btn { + max-width: 40px !important; + min-width: 40px !important; + height: 40px !important; + border-radius: 50% !important; + padding: 0 !important; + display: flex !important; + align-items: center !important; + justify-content: center !important; + } + """, + +) as demo: + # Add state for tracking selected video indices in both tabs + selected_index = gr.State(value=None) # For Text to Video + v2v_selected_index = gr.State(value=None) # For Video to Video + params_state = gr.State() #New addition + i2v_selected_index = gr.State(value=None) + skyreels_selected_index = gr.State(value=None) + demo.load(None, None, None, js=""" + () => { + document.title = 'H1111'; + + function updateTitle(text) { + if (text && text.trim()) { + const progressMatch = text.match(/(\d+)%.*\[.*<(\d+:\d+),/); + if (progressMatch) { + const percentage = progressMatch[1]; + const timeRemaining = progressMatch[2]; + document.title = `[${percentage}% ETA: ${timeRemaining}] - H1111`; + } + } + } + + setTimeout(() => { + const progressElements = document.querySelectorAll('textarea.scroll-hide'); + progressElements.forEach(element => { + if (element) { + new MutationObserver(() => { + 
updateTitle(element.value); + }).observe(element, { + attributes: true, + childList: true, + characterData: true + }); + } + }); + }, 1000); + } + """) + + with gr.Tabs() as tabs: + # Text to Video Tab + with gr.Tab(id=1, label="Text to Video"): + with gr.Row(): + with gr.Column(scale=4): + prompt = gr.Textbox(scale=3, label="Enter your prompt", value="POV video of a cat chasing a frob.", lines=5) + + with gr.Column(scale=1): + token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) + batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) + + with gr.Column(scale=2): + batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") + progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") + + with gr.Row(): + generate_btn = gr.Button("Generate Video", elem_classes="green-btn") + stop_btn = gr.Button("Stop Generation", variant="stop") + + with gr.Row(): + with gr.Column(): + + t2v_width = gr.Slider(minimum=64, maximum=1536, step=16, value=544, label="Video Width") + t2v_height = gr.Slider(minimum=64, maximum=1536, step=16, value=544, label="Video Height") + video_length = gr.Slider(minimum=1, maximum=201, step=1, label="Video Length in Frames", value=25, elem_id="my_special_slider") + fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=24, elem_id="my_special_slider") + infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30, elem_id="my_special_slider") + flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=11.0, elem_id="my_special_slider") + cfg_scale = gr.Slider(minimum=0.0, maximum=14.0, step=0.1, label="cfg Scale", value=7.0, elem_id="my_special_slider") + + with gr.Column(): + + with gr.Row(): + video_output = gr.Gallery( + label="Generated Videos (Click to select)", + columns=[2], + rows=[2], + object_fit="contain", + height="auto", + show_label=True, + elem_id="gallery", + allow_preview=True, + preview=True + ) + with gr.Row():send_t2v_to_v2v_btn = gr.Button("Send Selected to Video2Video") + + with gr.Row(): + refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") + lora_weights = [] + lora_multipliers = [] + for i in range(4): + with gr.Column(): + lora_weights.append(gr.Dropdown( + label=f"LoRA {i+1}", + choices=get_lora_options(), + value="None", + allow_custom_value=True, + interactive=True + )) + lora_multipliers.append(gr.Slider( + label=f"Multiplier", + minimum=0.0, + maximum=2.0, + step=0.05, + value=1.0 + )) + with gr.Row(): + exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) + seed = gr.Number(label="Seed (use -1 for random)", value=-1) + dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") + model = gr.Dropdown( + label="DiT Model", + choices=get_dit_models("hunyuan"), + value="mp_rank_00_model_states.pt", + allow_custom_value=True, + interactive=True + ) + vae = gr.Textbox(label="vae", value="hunyuan/pytorch_model.pt") + te1 = gr.Textbox(label="te1", value="hunyuan/llava_llama3_fp16.safetensors") + te2 = gr.Textbox(label="te2", value="hunyuan/clip_l.safetensors") + save_path = gr.Textbox(label="Save Path", value="outputs") + with gr.Row(): + lora_folder = gr.Textbox(label="LoRA Folder", value="lora") + output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") + use_split_attn = gr.Checkbox(label="Use Split Attention", value=False) + use_fp8 = gr.Checkbox(label="Use FP8 (faster but lower precision)", 
value=True) + attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") + block_swap = gr.Slider(minimum=0, maximum=36, step=1, label="Block Swap to Save Vram", value=0) + + #Image to Video Tab + with gr.Tab(label="Image to Video") as i2v_tab: + with gr.Row(): + with gr.Column(scale=4): + i2v_prompt = gr.Textbox(scale=3, label="Enter your prompt", value="POV video of a cat chasing a frob.", lines=5) + + with gr.Column(scale=1): + i2v_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) + i2v_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) + + with gr.Column(scale=2): + i2v_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") + i2v_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") + + with gr.Row(): + i2v_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") + i2v_stop_btn = gr.Button("Stop Generation", variant="stop") + + with gr.Row(): + with gr.Column(): + i2v_input = gr.Image(label="Input Image", type="filepath") + i2v_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.75, label="Denoise Strength") + # Scale slider as percentage + scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale %") + original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) + # Width and height inputs + with gr.Row(): + width = gr.Number(label="New Width", value=544, step=16) + calc_height_btn = gr.Button("→") + calc_width_btn = gr.Button("←") + height = gr.Number(label="New Height", value=544, step=16) + i2v_video_length = gr.Slider(minimum=1, maximum=201, step=1, label="Video Length in Frames", value=25) + i2v_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=24) + i2v_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30) + i2v_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=11.0) + i2v_cfg_scale = gr.Slider(minimum=0.0, maximum=14.0, step=0.1, label="cfg scale", value=7.0) + with gr.Column(): + i2v_output = gr.Gallery( + label="Generated Videos (Click to select)", + columns=[2], + rows=[2], + object_fit="contain", + height="auto", + show_label=True, + elem_id="gallery", + allow_preview=True, + preview=True + ) + i2v_send_to_v2v_btn = gr.Button("Send Selected to Video2Video") + + # Add LoRA section for Image2Video + i2v_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") + i2v_lora_weights = [] + i2v_lora_multipliers = [] + for i in range(4): + with gr.Column(): + i2v_lora_weights.append(gr.Dropdown( + label=f"LoRA {i+1}", + choices=get_lora_options(), + value="None", + allow_custom_value=True, + interactive=True + )) + i2v_lora_multipliers.append(gr.Slider( + label=f"Multiplier", + minimum=0.0, + maximum=2.0, + step=0.05, + value=1.0 + )) + + with gr.Row(): + i2v_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) + i2v_seed = gr.Number(label="Seed (use -1 for random)", value=-1) + i2v_dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") + i2v_model = gr.Dropdown( + label="DiT Model", + choices=get_dit_models("hunyuan"), + value="mp_rank_00_model_states.pt", + allow_custom_value=True, + interactive=True + ) + + i2v_vae = gr.Textbox(label="vae", value="hunyuan/pytorch_model.pt") + i2v_te1 = gr.Textbox(label="te1", value="hunyuan/llava_llama3_fp16.safetensors") + i2v_te2 = gr.Textbox(label="te2", 
value="hunyuan/clip_l.safetensors") + i2v_save_path = gr.Textbox(label="Save Path", value="outputs") + with gr.Row(): + i2v_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") + i2v_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") + i2v_use_split_attn = gr.Checkbox(label="Use Split Attention", value=False) + i2v_use_fp8 = gr.Checkbox(label="Use FP8 (faster but lower precision)", value=True) + i2v_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") + i2v_block_swap = gr.Slider(minimum=0, maximum=36, step=1, label="Block Swap to Save Vram", value=0) + + # Video to Video Tab + with gr.Tab(id=2, label="Video to Video") as v2v_tab: + with gr.Row(): + with gr.Column(scale=4): + v2v_prompt = gr.Textbox(scale=3, label="Enter your prompt", value="POV video of a cat chasing a frob.", lines=5) + v2v_negative_prompt = gr.Textbox( + scale=3, + label="Negative Prompt (for SkyReels models)", + value="Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion", + lines=3 + ) + + with gr.Column(scale=1): + v2v_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) + v2v_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) + + with gr.Column(scale=2): + v2v_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") + v2v_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") + + with gr.Row(): + v2v_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") + v2v_stop_btn = gr.Button("Stop Generation", variant="stop") + + with gr.Row(): + with gr.Column(): + v2v_input = gr.Video(label="Input Video", format="mp4") + v2v_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.75, label="Denoise Strength") + v2v_scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale %") + v2v_original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) + + # Width and Height Inputs + with gr.Row(): + v2v_width = gr.Number(label="New Width", value=544, step=16) + v2v_calc_height_btn = gr.Button("→") + v2v_calc_width_btn = gr.Button("←") + v2v_height = gr.Number(label="New Height", value=544, step=16) + v2v_video_length = gr.Slider(minimum=1, maximum=201, step=1, label="Video Length in Frames", value=25) + v2v_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=24) + v2v_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30) + v2v_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=11.0) + v2v_cfg_scale = gr.Slider(minimum=0.0, maximum=14.0, step=0.1, label="cfg scale", value=7.0) + with gr.Column(): + v2v_output = gr.Gallery( + label="Generated Videos", + columns=[1], + rows=[1], + object_fit="contain", + height="auto" + ) + v2v_send_to_input_btn = gr.Button("Send Selected to Input") # New button + v2v_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") + v2v_lora_weights = [] + v2v_lora_multipliers = [] + for i in range(4): + with gr.Column(): + v2v_lora_weights.append(gr.Dropdown( + label=f"LoRA {i+1}", + choices=get_lora_options(), + value="None", + allow_custom_value=True, + interactive=True + )) + v2v_lora_multipliers.append(gr.Slider( + label=f"Multiplier", + minimum=0.0, + maximum=2.0, + step=0.05, + value=1.0 + )) + + with gr.Row(): + 
v2v_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) + v2v_seed = gr.Number(label="Seed (use -1 for random)", value=-1) + v2v_dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") + v2v_model = gr.Dropdown( + label="DiT Model", + choices=get_dit_models("hunyuan"), + value="mp_rank_00_model_states.pt", + allow_custom_value=True, + interactive=True + ) + v2v_vae = gr.Textbox(label="vae", value="hunyuan/pytorch_model.pt") + v2v_te1 = gr.Textbox(label="te1", value="hunyuan/llava_llama3_fp16.safetensors") + v2v_te2 = gr.Textbox(label="te2", value="hunyuan/clip_l.safetensors") + v2v_save_path = gr.Textbox(label="Save Path", value="outputs") + with gr.Row(): + v2v_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") + v2v_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") + v2v_use_split_attn = gr.Checkbox(label="Use Split Attention", value=False) + v2v_use_fp8 = gr.Checkbox(label="Use FP8 (faster but lower precision)", value=True) + v2v_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") + v2v_block_swap = gr.Slider(minimum=0, maximum=36, step=1, label="Block Swap to Save Vram", value=0) + v2v_split_uncond = gr.Checkbox(label="Split Unconditional (for SkyReels)", value=True) + + with gr.Tab(label="SkyReels-i2v") as skyreels_tab: + with gr.Row(): + with gr.Column(scale=4): + skyreels_prompt = gr.Textbox( + scale=3, + label="Enter your prompt", + value="A person walking on a beach at sunset", + lines=5 + ) + skyreels_negative_prompt = gr.Textbox( + scale=3, + label="Negative Prompt", + value="Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion", + lines=3 + ) + + with gr.Column(scale=1): + skyreels_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) + skyreels_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) + + with gr.Column(scale=2): + skyreels_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") + skyreels_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") + + with gr.Row(): + skyreels_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") + skyreels_stop_btn = gr.Button("Stop Generation", variant="stop") + + with gr.Row(): + with gr.Column(): + skyreels_input = gr.Image(label="Input Image (optional)", type="filepath") + skyreels_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.75, label="Denoise Strength") + + # Scale slider as percentage + skyreels_scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale %") + skyreels_original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) + + # Width and height inputs + with gr.Row(): + skyreels_width = gr.Number(label="New Width", value=544, step=16) + skyreels_calc_height_btn = gr.Button("→") + skyreels_calc_width_btn = gr.Button("←") + skyreels_height = gr.Number(label="New Height", value=544, step=16) + + skyreels_video_length = gr.Slider(minimum=1, maximum=201, step=1, label="Video Length in Frames", value=25) + skyreels_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=24) + skyreels_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30) + skyreels_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=11.0) + 
skyreels_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.1, label="Guidance Scale", value=6.0) + skyreels_embedded_cfg_scale = gr.Slider(minimum=0.0, maximum=10.0, step=0.1, label="Embedded CFG Scale", value=1.0) + + with gr.Column(): + skyreels_output = gr.Gallery( + label="Generated Videos (Click to select)", + columns=[2], + rows=[2], + object_fit="contain", + height="auto", + show_label=True, + elem_id="gallery", + allow_preview=True, + preview=True + ) + skyreels_send_to_v2v_btn = gr.Button("Send Selected to Video2Video") + + # Add LoRA section for SKYREELS + skyreels_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") + skyreels_lora_weights = [] + skyreels_lora_multipliers = [] + for i in range(4): + with gr.Column(): + skyreels_lora_weights.append(gr.Dropdown( + label=f"LoRA {i+1}", + choices=get_lora_options(), + value="None", + allow_custom_value=True, + interactive=True + )) + skyreels_lora_multipliers.append(gr.Slider( + label=f"Multiplier", + minimum=0.0, + maximum=2.0, + step=0.05, + value=1.0 + )) + with gr.Row(): + skyreels_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) + skyreels_seed = gr.Number(label="Seed (use -1 for random)", value=-1) + skyreels_dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") + skyreels_model = gr.Dropdown( + label="DiT Model", + choices=get_dit_models("skyreels"), + value="skyreels_hunyuan_i2v_bf16.safetensors", + allow_custom_value=True, + interactive=True + ) + skyreels_vae = gr.Textbox(label="vae", value="hunyuan/pytorch_model.pt") + skyreels_te1 = gr.Textbox(label="te1", value="hunyuan/llava_llama3_fp16.safetensors") + skyreels_te2 = gr.Textbox(label="te2", value="hunyuan/clip_l.safetensors") + skyreels_save_path = gr.Textbox(label="Save Path", value="outputs") + + with gr.Row(): + skyreels_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") + skyreels_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") + skyreels_use_split_attn = gr.Checkbox(label="Use Split Attention", value=False) + skyreels_use_fp8 = gr.Checkbox(label="Use FP8 (faster but lower precision)", value=True) + skyreels_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") + skyreels_block_swap = gr.Slider(minimum=0, maximum=36, step=1, label="Block Swap to Save Vram", value=0) + skyreels_split_uncond = gr.Checkbox(label="Split Unconditional", value=True) + + # WanX Image to Video Tab + with gr.Tab(label="WanX-i2v") as wanx_i2v_tab: + with gr.Row(): + with gr.Column(scale=4): + wanx_prompt = gr.Textbox( + scale=3, + label="Enter your prompt", + value="A person walking on a beach at sunset", + lines=5 + ) + wanx_negative_prompt = gr.Textbox( + scale=3, + label="Negative Prompt", + value="", + lines=3, + info="Leave empty to use default negative prompt" + ) + + with gr.Column(scale=1): + wanx_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) + wanx_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) + + with gr.Column(scale=2): + wanx_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") + wanx_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") + + with gr.Row(): + wanx_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") + wanx_stop_btn = gr.Button("Stop Generation", variant="stop") + + with gr.Row(): + with gr.Column(): + wanx_input = gr.Image(label="Input Image", 
type="filepath") + wanx_scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale %") + wanx_original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) + + # Width and height display + with gr.Row(): + wanx_width = gr.Number(label="Width", value=832, interactive=True) + wanx_calc_height_btn = gr.Button("→") + wanx_calc_width_btn = gr.Button("←") + wanx_height = gr.Number(label="Height", value=480, interactive=True) + wanx_recommend_flow_btn = gr.Button("Recommend Flow Shift", size="sm") + + wanx_video_length = gr.Slider(minimum=1, maximum=201, step=4, label="Video Length in Frames", value=81) + wanx_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=16) + wanx_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=20) + wanx_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=3.0, + info="Recommended: 3.0 for 480p, 5.0 for others") + wanx_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.1, label="Guidance Scale", value=5.0) + + with gr.Column(): + wanx_output = gr.Gallery( + label="Generated Videos (Click to select)", + columns=[2], + rows=[2], + object_fit="contain", + height="auto", + show_label=True, + elem_id="gallery", + allow_preview=True, + preview=True + ) + wanx_send_to_v2v_btn = gr.Button("Send Selected to Video2Video") + + with gr.Row(): + wanx_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") + wanx_lora_weights = [] + wanx_lora_multipliers = [] + for i in range(4): + with gr.Column(): + wanx_lora_weights.append(gr.Dropdown( + label=f"LoRA {i+1}", + choices=get_lora_options(), + value="None", + allow_custom_value=True, + interactive=True + )) + wanx_lora_multipliers.append(gr.Slider( + label=f"Multiplier", + minimum=0.0, + maximum=2.0, + step=0.05, + value=1.0 + )) + + with gr.Row(): + wanx_seed = gr.Number(label="Seed (use -1 for random)", value=-1) + wanx_task = gr.Dropdown( + label="Task", + choices=["i2v-14B"], + value="i2v-14B", + info="Currently only i2v-14B is supported" + ) + wanx_dit_path = gr.Textbox(label="DiT Model Path", value="wan/wan2.1_i2v_480p_14B_bf16.safetensors") + wanx_vae_path = gr.Textbox(label="VAE Path", value="wan/Wan2.1_VAE.pth") + wanx_t5_path = gr.Textbox(label="T5 Path", value="wan/models_t5_umt5-xxl-enc-bf16.pth") + wanx_clip_path = gr.Textbox(label="CLIP Path", value="wan/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth") + wanx_save_path = gr.Textbox(label="Save Path", value="outputs") + + with gr.Row(): + wanx_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") + wanx_sample_solver = gr.Radio(choices=["unipc", "dpm++"], label="Sample Solver", value="unipc") + wanx_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") + wanx_block_swap = gr.Slider(minimum=0, maximum=39, step=1, label="Block Swap to Save VRAM", value=0) + wanx_fp8 = gr.Checkbox(label="Use FP8", value=True) + wanx_fp8_t5 = gr.Checkbox(label="Use FP8 for T5", value=False) + wanx_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") + wanx_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) + + #WanX-t2v Tab + + # WanX Text to Video Tab + with gr.Tab(label="WanX-t2v") as wanx_t2v_tab: + with gr.Row(): + with gr.Column(scale=4): + wanx_t2v_prompt = gr.Textbox( + scale=3, + label="Enter your prompt", + value="A person walking on a beach at sunset", + lines=5 + ) 
+ wanx_t2v_negative_prompt = gr.Textbox( + scale=3, + label="Negative Prompt", + value="", + lines=3, + info="Leave empty to use default negative prompt" + ) + + with gr.Column(scale=1): + wanx_t2v_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) + wanx_t2v_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) + + with gr.Column(scale=2): + wanx_t2v_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") + wanx_t2v_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") + + with gr.Row(): + wanx_t2v_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") + wanx_t2v_stop_btn = gr.Button("Stop Generation", variant="stop") + + with gr.Row(): + with gr.Column(): + with gr.Row(): + wanx_t2v_width = gr.Number(label="Width", value=832, interactive=True, info="Should be divisible by 32") + wanx_t2v_height = gr.Number(label="Height", value=480, interactive=True, info="Should be divisible by 32") + wanx_t2v_recommend_flow_btn = gr.Button("Recommend Flow Shift", size="sm") + + wanx_t2v_video_length = gr.Slider(minimum=1, maximum=201, step=4, label="Video Length in Frames", value=81) + wanx_t2v_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=16) + wanx_t2v_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=20) + wanx_t2v_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=5.0, + info="Recommended: 3.0 for I2V with 480p, 5.0 for others") + wanx_t2v_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.1, label="Guidance Scale", value=5.0) + + with gr.Column(): + wanx_t2v_output = gr.Gallery( + label="Generated Videos (Click to select)", + columns=[2], + rows=[2], + object_fit="contain", + height="auto", + show_label=True, + elem_id="gallery", + allow_preview=True, + preview=True + ) + wanx_t2v_send_to_v2v_btn = gr.Button("Send Selected to Video2Video") + + with gr.Row(): + wanx_t2v_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") + wanx_t2v_lora_weights = [] + wanx_t2v_lora_multipliers = [] + for i in range(4): + with gr.Column(): + wanx_t2v_lora_weights.append(gr.Dropdown( + label=f"LoRA {i+1}", + choices=get_lora_options(), + value="None", + allow_custom_value=True, + interactive=True + )) + wanx_t2v_lora_multipliers.append(gr.Slider( + label=f"Multiplier", + minimum=0.0, + maximum=2.0, + step=0.05, + value=1.0 + )) + + with gr.Row(): + wanx_t2v_seed = gr.Number(label="Seed (use -1 for random)", value=-1) + wanx_t2v_task = gr.Dropdown( + label="Task", + choices=["t2v-1.3B", "t2v-14B", "t2i-14B"], + value="t2v-14B", + info="Select model size: t2v-1.3B is faster, t2v-14B has higher quality" + ) + wanx_t2v_dit_path = gr.Textbox(label="DiT Model Path", value="wan/wan2.1_t2v_14B_bf16.safetensors") + wanx_t2v_vae_path = gr.Textbox(label="VAE Path", value="wan/Wan2.1_VAE.pth") + wanx_t2v_t5_path = gr.Textbox(label="T5 Path", value="wan/models_t5_umt5-xxl-enc-bf16.pth") + wanx_t2v_clip_path = gr.Textbox(label="CLIP Path", visible=False, value="") + wanx_t2v_save_path = gr.Textbox(label="Save Path", value="outputs") + + with gr.Row(): + wanx_t2v_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") + wanx_t2v_sample_solver = gr.Radio(choices=["unipc", "dpm++"], label="Sample Solver", value="unipc") + wanx_t2v_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") + 
wanx_t2v_block_swap = gr.Slider(minimum=0, maximum=39, step=1, label="Block Swap to Save VRAM", value=0, + info="Max 39 for 14B model, 29 for 1.3B model") + wanx_t2v_fp8 = gr.Checkbox(label="Use FP8", value=True) + wanx_t2v_fp8_t5 = gr.Checkbox(label="Use FP8 for T5", value=False) + wanx_t2v_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") + wanx_t2v_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) + + + #Video Info Tab + with gr.Tab("Video Info") as video_info_tab: + with gr.Row(): + video_input = gr.Video(label="Upload Video", interactive=True) + metadata_output = gr.JSON(label="Generation Parameters") + + with gr.Row(): + send_to_t2v_btn = gr.Button("Send to Text2Video", variant="primary") + send_to_v2v_btn = gr.Button("Send to Video2Video", variant="primary") + + with gr.Row(): + status = gr.Textbox(label="Status", interactive=False) + + #Merge Model's tab + with gr.Tab("Convert LoRA") as convert_lora_tab: + def suggest_output_name(file_obj) -> str: + """Generate suggested output name from input file""" + if not file_obj: + return "" + # Get input filename without extension and add MUSUBI + base_name = os.path.splitext(os.path.basename(file_obj.name))[0] + return f"{base_name}_MUSUBI" + + def convert_lora(input_file, output_name: str, target_format: str) -> str: + """Convert LoRA file to specified format""" + try: + if not input_file: + return "Error: No input file selected" + + # Ensure output directory exists + os.makedirs("lora", exist_ok=True) + + # Construct output path + output_path = os.path.join("lora", f"{output_name}.safetensors") + + # Build command + cmd = [ + sys.executable, + "convert_lora.py", + "--input", input_file.name, + "--output", output_path, + "--target", target_format + ] + + print(f"Converting {input_file.name} to {output_path}") + + # Execute conversion + result = subprocess.run( + cmd, + capture_output=True, + text=True, + check=True + ) + + if os.path.exists(output_path): + return f"Successfully converted LoRA to {output_path}" + else: + return "Error: Output file not created" + + except subprocess.CalledProcessError as e: + return f"Error during conversion: {e.stderr}" + except Exception as e: + return f"Error: {str(e)}" + + with gr.Row(): + input_file = gr.File(label="Input LoRA File", file_types=[".safetensors"]) + output_name = gr.Textbox(label="Output Name", placeholder="Output filename (without extension)") + format_radio = gr.Radio( + choices=["default", "other"], + value="default", + label="Target Format", + info="Choose 'default' for H1111/MUSUBI format or 'other' for diffusion pipe format" + ) + + with gr.Row(): + convert_btn = gr.Button("Convert LoRA", variant="primary") + status_output = gr.Textbox(label="Status", interactive=False) + + # Automatically update output name when file is selected + input_file.change( + fn=suggest_output_name, + inputs=[input_file], + outputs=[output_name] + ) + + # Handle conversion + convert_btn.click( + fn=convert_lora, + inputs=[input_file, output_name, format_radio], + outputs=status_output + ) + with gr.Tab("Model Merging") as model_merge_tab: + with gr.Row(): + with gr.Column(): + # Model selection + dit_model = gr.Dropdown( + label="Base DiT Model", + choices=["mp_rank_00_model_states.pt"], + value="mp_rank_00_model_states.pt", + allow_custom_value=True, + interactive=True + ) + merge_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") + with gr.Row(): + with gr.Column(): + # Output model name + output_model = gr.Textbox(label="Output Model Name", 
value="merged_model.safetensors") + exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) + merge_btn = gr.Button("Merge Models", variant="primary") + merge_status = gr.Textbox(label="Status", interactive=False) + with gr.Row(): + # LoRA selection section (similar to Text2Video) + merge_lora_weights = [] + merge_lora_multipliers = [] + for i in range(4): + with gr.Column(): + merge_lora_weights.append(gr.Dropdown( + label=f"LoRA {i+1}", + choices=get_lora_options(), + value="None", + allow_custom_value=True, + interactive=True + )) + merge_lora_multipliers.append(gr.Slider( + label=f"Multiplier", + minimum=0.0, + maximum=2.0, + step=0.05, + value=1.0 + )) + with gr.Row(): + merge_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") + dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") + + #text to video + def change_to_tab_one(): + return gr.Tabs(selected=1) #This will navigate + #video to video + def change_to_tab_two(): + return gr.Tabs(selected=2) #This will navigate + def change_to_skyreels_tab(): + return gr.Tabs(selected=3) + + #SKYREELS TAB!!! + # Add state management for dimensions + def sync_skyreels_dimensions(width, height): + return gr.update(value=width), gr.update(value=height) + + # Add this function to update the LoRA dropdowns in the SKYREELS tab + def update_skyreels_lora_dropdowns(lora_folder: str, *current_values) -> List[gr.update]: + new_choices = get_lora_options(lora_folder) + weights = current_values[:4] + multipliers = current_values[4:8] + + results = [] + for i in range(4): + weight = weights[i] if i < len(weights) else "None" + multiplier = multipliers[i] if i < len(multipliers) else 1.0 + if weight not in new_choices: + weight = "None" + results.extend([ + gr.update(choices=new_choices, value=weight), + gr.update(value=multiplier) + ]) + + return results + + # Add this function to update the models dropdown in the SKYREELS tab + def update_skyreels_model_dropdown(dit_folder: str) -> Dict: + models = get_dit_models(dit_folder) + return gr.update(choices=models, value=models[0] if models else None) + + # Add event handler for model dropdown refresh + skyreels_dit_folder.change( + fn=update_skyreels_model_dropdown, + inputs=[skyreels_dit_folder], + outputs=[skyreels_model] + ) + + # Add handlers for the refresh button + skyreels_refresh_btn.click( + fn=update_skyreels_lora_dropdowns, + inputs=[skyreels_lora_folder] + skyreels_lora_weights + skyreels_lora_multipliers, + outputs=[drop for _ in range(4) for drop in [skyreels_lora_weights[_], skyreels_lora_multipliers[_]]] + ) + # Skyreels dimension handling + def calculate_skyreels_width(height, original_dims): + if not original_dims: + return gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + aspect_ratio = orig_w / orig_h + new_width = math.floor((height * aspect_ratio) / 16) * 16 + return gr.update(value=new_width) + + def calculate_skyreels_height(width, original_dims): + if not original_dims: + return gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + aspect_ratio = orig_w / orig_h + new_height = math.floor((width / aspect_ratio) / 16) * 16 + return gr.update(value=new_height) + + def update_skyreels_from_scale(scale, original_dims): + if not original_dims: + return gr.update(), gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + new_w = math.floor((orig_w * scale / 100) / 16) * 16 + new_h = math.floor((orig_h * scale / 100) / 16) * 16 + return gr.update(value=new_w), gr.update(value=new_h) + + def 
update_skyreels_dimensions(image): + if image is None: + return "", gr.update(value=544), gr.update(value=544) + img = Image.open(image) + w, h = img.size + w = (w // 16) * 16 + h = (h // 16) * 16 + return f"{w}x{h}", w, h + + def handle_skyreels_gallery_select(evt: gr.SelectData) -> int: + return evt.index + + def send_skyreels_to_v2v( + gallery: list, + prompt: str, + selected_index: int, + width: int, + height: int, + video_length: int, + fps: int, + infer_steps: int, + seed: int, + flow_shift: float, + cfg_scale: float, + lora1: str, + lora2: str, + lora3: str, + lora4: str, + lora1_multiplier: float, + lora2_multiplier: float, + lora3_multiplier: float, + lora4_multiplier: float, + negative_prompt: str = "" # Add this parameter + ) -> Tuple: + if not gallery or selected_index is None or selected_index >= len(gallery): + return (None, "", width, height, video_length, fps, infer_steps, seed, + flow_shift, cfg_scale, lora1, lora2, lora3, lora4, + lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier, + negative_prompt) # Add negative_prompt to return + + selected_item = gallery[selected_index] + + if isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + elif isinstance(selected_item, (tuple, list)): + video_path = selected_item[0] + else: + video_path = selected_item + + if isinstance(video_path, tuple): + video_path = video_path[0] + + return (str(video_path), prompt, width, height, video_length, fps, infer_steps, seed, + flow_shift, cfg_scale, lora1, lora2, lora3, lora4, + lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier, + negative_prompt) # Add negative_prompt to return + + # Add event handlers for the SKYREELS tab + skyreels_prompt.change(fn=count_prompt_tokens, inputs=skyreels_prompt, outputs=skyreels_token_counter) + skyreels_stop_btn.click(fn=lambda: stop_event.set(), queue=False) + + # Image input handling + skyreels_input.change( + fn=update_skyreels_dimensions, + inputs=[skyreels_input], + outputs=[skyreels_original_dims, skyreels_width, skyreels_height] + ) + + skyreels_scale_slider.change( + fn=update_skyreels_from_scale, + inputs=[skyreels_scale_slider, skyreels_original_dims], + outputs=[skyreels_width, skyreels_height] + ) + + skyreels_calc_width_btn.click( + fn=calculate_skyreels_width, + inputs=[skyreels_height, skyreels_original_dims], + outputs=[skyreels_width] + ) + + skyreels_calc_height_btn.click( + fn=calculate_skyreels_height, + inputs=[skyreels_width, skyreels_original_dims], + outputs=[skyreels_height] + ) + + # SKYREELS tab generator button handler + skyreels_generate_btn.click( + fn=process_batch, + inputs=[ + skyreels_prompt, + skyreels_width, + skyreels_height, + skyreels_batch_size, + skyreels_video_length, + skyreels_fps, + skyreels_infer_steps, + skyreels_seed, + skyreels_dit_folder, + skyreels_model, + skyreels_vae, + skyreels_te1, + skyreels_te2, + skyreels_save_path, + skyreels_flow_shift, + skyreels_embedded_cfg_scale, + skyreels_output_type, + skyreels_attn_mode, + skyreels_block_swap, + skyreels_exclude_single_blocks, + skyreels_use_split_attn, + skyreels_lora_folder, + *skyreels_lora_weights, + *skyreels_lora_multipliers, + skyreels_input, + skyreels_strength, + skyreels_negative_prompt, + skyreels_guidance_scale, + skyreels_split_uncond, + skyreels_use_fp8 + ], + outputs=[skyreels_output, skyreels_batch_progress, skyreels_progress_text], + queue=True + ).then( + fn=lambda batch_size: 0 if batch_size == 1 else None, + inputs=[skyreels_batch_size], + 
outputs=skyreels_selected_index + ) + + # Gallery selection handling + skyreels_output.select( + fn=handle_skyreels_gallery_select, + outputs=skyreels_selected_index + ) + + # Send to Video2Video handler + skyreels_send_to_v2v_btn.click( + fn=send_skyreels_to_v2v, + inputs=[ + skyreels_output, skyreels_prompt, skyreels_selected_index, + skyreels_width, skyreels_height, skyreels_video_length, + skyreels_fps, skyreels_infer_steps, skyreels_seed, + skyreels_flow_shift, skyreels_guidance_scale + ] + skyreels_lora_weights + skyreels_lora_multipliers + [skyreels_negative_prompt], # This is ok because skyreels_negative_prompt is a Gradio component + outputs=[ + v2v_input, v2v_prompt, v2v_width, v2v_height, + v2v_video_length, v2v_fps, v2v_infer_steps, + v2v_seed, v2v_flow_shift, v2v_cfg_scale + ] + v2v_lora_weights + v2v_lora_multipliers + [v2v_negative_prompt] + ).then( + fn=change_to_tab_two, + inputs=None, + outputs=[tabs] + ) + + # Refresh button handler + skyreels_refresh_outputs = [skyreels_model] + for i in range(4): + skyreels_refresh_outputs.extend([skyreels_lora_weights[i], skyreels_lora_multipliers[i]]) + + skyreels_refresh_btn.click( + fn=update_dit_and_lora_dropdowns, + inputs=[skyreels_dit_folder, skyreels_lora_folder, skyreels_model] + skyreels_lora_weights + skyreels_lora_multipliers, + outputs=skyreels_refresh_outputs + ) + + # Add skyreels_selected_index to the initial states at the beginning of the script + skyreels_selected_index = gr.State(value=None) # Add this with other state declarations + + def calculate_v2v_width(height, original_dims): + if not original_dims: + return gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + aspect_ratio = orig_w / orig_h + new_width = math.floor((height * aspect_ratio) / 16) * 16 # Ensure divisible by 16 + return gr.update(value=new_width) + + def calculate_v2v_height(width, original_dims): + if not original_dims: + return gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + aspect_ratio = orig_w / orig_h + new_height = math.floor((width / aspect_ratio) / 16) * 16 # Ensure divisible by 16 + return gr.update(value=new_height) + + def update_v2v_from_scale(scale, original_dims): + if not original_dims: + return gr.update(), gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + new_w = math.floor((orig_w * scale / 100) / 16) * 16 # Ensure divisible by 16 + new_h = math.floor((orig_h * scale / 100) / 16) * 16 # Ensure divisible by 16 + return gr.update(value=new_w), gr.update(value=new_h) + + def update_v2v_dimensions(video): + if video is None: + return "", gr.update(value=544), gr.update(value=544) + cap = cv2.VideoCapture(video) + w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + cap.release() + # Make dimensions divisible by 16 + w = (w // 16) * 16 + h = (h // 16) * 16 + return f"{w}x{h}", w, h + + # Event Handlers for Video to Video Tab + v2v_input.change( + fn=update_v2v_dimensions, + inputs=[v2v_input], + outputs=[v2v_original_dims, v2v_width, v2v_height] + ) + + v2v_scale_slider.change( + fn=update_v2v_from_scale, + inputs=[v2v_scale_slider, v2v_original_dims], + outputs=[v2v_width, v2v_height] + ) + + v2v_calc_width_btn.click( + fn=calculate_v2v_width, + inputs=[v2v_height, v2v_original_dims], + outputs=[v2v_width] + ) + + v2v_calc_height_btn.click( + fn=calculate_v2v_height, + inputs=[v2v_width, v2v_original_dims], + outputs=[v2v_height] + ) + + ##Image 2 video dimension logic + def calculate_width(height, original_dims): + if not original_dims: + 
return gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + aspect_ratio = orig_w / orig_h + new_width = math.floor((height * aspect_ratio) / 16) * 16 # Changed from 8 to 16 + return gr.update(value=new_width) + + def calculate_height(width, original_dims): + if not original_dims: + return gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + aspect_ratio = orig_w / orig_h + new_height = math.floor((width / aspect_ratio) / 16) * 16 # Changed from 8 to 16 + return gr.update(value=new_height) + + def update_from_scale(scale, original_dims): + if not original_dims: + return gr.update(), gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + new_w = math.floor((orig_w * scale / 100) / 16) * 16 # Changed from 8 to 16 + new_h = math.floor((orig_h * scale / 100) / 16) * 16 # Changed from 8 to 16 + return gr.update(value=new_w), gr.update(value=new_h) + + def update_dimensions(image): + if image is None: + return "", gr.update(value=544), gr.update(value=544) + img = Image.open(image) + w, h = img.size + # Make dimensions divisible by 16 + w = (w // 16) * 16 # Changed from 8 to 16 + h = (h // 16) * 16 # Changed from 8 to 16 + return f"{w}x{h}", w, h + i2v_input.change( + fn=update_dimensions, + inputs=[i2v_input], + outputs=[original_dims, width, height] + ) + + scale_slider.change( + fn=update_from_scale, + inputs=[scale_slider, original_dims], + outputs=[width, height] + ) + + calc_width_btn.click( + fn=calculate_width, + inputs=[height, original_dims], + outputs=[width] + ) + + calc_height_btn.click( + fn=calculate_height, + inputs=[width, original_dims], + outputs=[height] + ) + + # Function to get available DiT models + def get_dit_models(dit_folder: str) -> List[str]: + if not os.path.exists(dit_folder): + return ["mp_rank_00_model_states.pt"] + models = [f for f in os.listdir(dit_folder) if f.endswith('.pt') or f.endswith('.safetensors')] + models.sort(key=str.lower) + return models if models else ["mp_rank_00_model_states.pt"] + + # Function to perform model merging + def merge_models( + dit_folder: str, + dit_model: str, + output_model: str, + exclude_single_blocks: bool, + merge_lora_folder: str, + *lora_params # Will contain both weights and multipliers + ) -> str: + try: + # Separate weights and multipliers + num_loras = len(lora_params) // 2 + weights = list(lora_params[:num_loras]) + multipliers = list(lora_params[num_loras:]) + + # Filter out "None" selections + valid_loras = [] + for weight, mult in zip(weights, multipliers): + if weight and weight != "None": + valid_loras.append((os.path.join(merge_lora_folder, weight), mult)) + + if not valid_loras: + return "No LoRA models selected for merging" + + # Create output path in the dit folder + os.makedirs(dit_folder, exist_ok=True) + output_path = os.path.join(dit_folder, output_model) + + # Prepare command + cmd = [ + sys.executable, + "merge_lora.py", + "--dit", os.path.join(dit_folder, dit_model), + "--save_merged_model", output_path + ] + + # Add LoRA weights and multipliers + weights = [weight for weight, _ in valid_loras] + multipliers = [str(mult) for _, mult in valid_loras] + cmd.extend(["--lora_weight"] + weights) + cmd.extend(["--lora_multiplier"] + multipliers) + + if exclude_single_blocks: + cmd.append("--exclude_single_blocks") + + # Execute merge operation + result = subprocess.run( + cmd, + capture_output=True, + text=True, + check=True + ) + + if os.path.exists(output_path): + return f"Successfully merged model and saved to {output_path}" + else: + return "Error: Output file 
not created" + + except subprocess.CalledProcessError as e: + return f"Error during merging: {e.stderr}" + except Exception as e: + return f"Error: {str(e)}" + + # Update DiT model dropdown + def update_dit_dropdown(dit_folder: str) -> Dict: + models = get_dit_models(dit_folder) + return gr.update(choices=models, value=models[0] if models else None) + + # Connect events + merge_btn.click( + fn=merge_models, + inputs=[ + dit_folder, + dit_model, + output_model, + exclude_single_blocks, + merge_lora_folder, + *merge_lora_weights, + *merge_lora_multipliers + ], + outputs=merge_status + ) + + # Refresh buttons for both DiT and LoRA dropdowns + merge_refresh_btn.click( + fn=lambda f: update_dit_dropdown(f), + inputs=[dit_folder], + outputs=[dit_model] + ) + + # LoRA refresh handling + merge_refresh_outputs = [] + for i in range(4): + merge_refresh_outputs.extend([merge_lora_weights[i], merge_lora_multipliers[i]]) + + merge_refresh_btn.click( + fn=update_lora_dropdowns, + inputs=[merge_lora_folder] + merge_lora_weights + merge_lora_multipliers, + outputs=merge_refresh_outputs + ) + # Event handlers + prompt.change(fn=count_prompt_tokens, inputs=prompt, outputs=token_counter) + v2v_prompt.change(fn=count_prompt_tokens, inputs=v2v_prompt, outputs=v2v_token_counter) + stop_btn.click(fn=lambda: stop_event.set(), queue=False) + v2v_stop_btn.click(fn=lambda: stop_event.set(), queue=False) + + #Image_to_Video + def image_to_video(image_path, output_path, width, height, frames=240): # Add width, height parameters + img = Image.open(image_path) + + # Resize to the specified dimensions + img_resized = img.resize((width, height), Image.LANCZOS) + temp_image_path = os.path.join(os.path.dirname(output_path), "temp_resized_image.png") + img_resized.save(temp_image_path) + + # Rest of function remains the same + frame_rate = 24 + duration = frames / frame_rate + command = [ + "ffmpeg", "-loop", "1", "-i", temp_image_path, "-c:v", "libx264", + "-t", str(duration), "-pix_fmt", "yuv420p", + "-vf", f"fps={frame_rate}", output_path + ] + + try: + subprocess.run(command, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + print(f"Video saved to {output_path}") + return True + except subprocess.CalledProcessError as e: + print(f"An error occurred while creating the video: {e}") + return False + finally: + # Clean up the temporary image file + if os.path.exists(temp_image_path): + os.remove(temp_image_path) + img.close() # Make sure to close the image file explicitly + + def generate_from_image( + image_path, + prompt, width, height, video_length, fps, infer_steps, + seed, model, vae, te1, te2, save_path, flow_shift, cfg_scale, + output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, + lora_folder, strength, batch_size, *lora_params + ): + """Generate video from input image with progressive updates""" + global stop_event + stop_event.clear() + + # Create temporary video path + temp_video_path = os.path.join(save_path, f"temp_{os.path.basename(image_path)}.mp4") + + try: + # Convert image to video + if not image_to_video(image_path, temp_video_path, width, height, frames=video_length): + yield [], "Failed to create temporary video", "Error in video creation" + return + + # Ensure video is fully written before proceeding + time.sleep(1) + if not os.path.exists(temp_video_path) or os.path.getsize(temp_video_path) == 0: + yield [], "Failed to create temporary video", "Temporary video file is empty or missing" + return + + # Get video dimensions + try: + probe = 
ffmpeg.probe(temp_video_path) + video_stream = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None) + if video_stream is None: + raise ValueError("No video stream found") + width = int(video_stream['width']) + height = int(video_stream['height']) + except Exception as e: + yield [], f"Error reading video dimensions: {str(e)}", "Video processing error" + return + + # Generate the video using the temporary file + try: + generator = process_single_video( + prompt, width, height, batch_size, video_length, fps, infer_steps, + seed, model, vae, te1, te2, save_path, flow_shift, cfg_scale, + output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, + lora_folder, *lora_params, video_path=temp_video_path, strength=strength + ) + + # Forward all generator updates + for videos, batch_text, progress_text in generator: + yield videos, batch_text, progress_text + + except Exception as e: + yield [], f"Error in video generation: {str(e)}", "Generation error" + return + + except Exception as e: + yield [], f"Unexpected error: {str(e)}", "Error occurred" + return + + finally: + # Clean up temporary file + try: + if os.path.exists(temp_video_path): + os.remove(temp_video_path) + except Exception: + pass # Ignore cleanup errors + + + # Add event handlers + i2v_prompt.change(fn=count_prompt_tokens, inputs=i2v_prompt, outputs=i2v_token_counter) + i2v_stop_btn.click(fn=lambda: stop_event.set(), queue=False) + + def handle_i2v_gallery_select(evt: gr.SelectData) -> int: + """Track selected index when I2V gallery item is clicked""" + return evt.index + + def send_i2v_to_v2v( + gallery: list, + prompt: str, + selected_index: int, + width: int, + height: int, + video_length: int, + fps: int, + infer_steps: int, + seed: int, + flow_shift: float, + cfg_scale: float, + lora1: str, + lora2: str, + lora3: str, + lora4: str, + lora1_multiplier: float, + lora2_multiplier: float, + lora3_multiplier: float, + lora4_multiplier: float + ) -> Tuple[Optional[str], str, int, int, int, int, int, int, float, float, str, str, str, str, float, float, float, float]: + """Send the selected video and parameters from Image2Video tab to Video2Video tab""" + if not gallery or selected_index is None or selected_index >= len(gallery): + return None, "", width, height, video_length, fps, infer_steps, seed, flow_shift, cfg_scale, \ + lora1, lora2, lora3, lora4, lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier + + selected_item = gallery[selected_index] + + # Handle different gallery item formats + if isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + elif isinstance(selected_item, (tuple, list)): + video_path = selected_item[0] + else: + video_path = selected_item + + # Final cleanup for Gradio Video component + if isinstance(video_path, tuple): + video_path = video_path[0] + + # Use the original width and height without doubling + return (str(video_path), prompt, width, height, video_length, fps, infer_steps, seed, + flow_shift, cfg_scale, lora1, lora2, lora3, lora4, + lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier) + + # Generate button handler + i2v_generate_btn.click( + fn=process_batch, + inputs=[ + i2v_prompt, width, height, + i2v_batch_size, i2v_video_length, + i2v_fps, i2v_infer_steps, i2v_seed, i2v_dit_folder, i2v_model, i2v_vae, i2v_te1, i2v_te2, + i2v_save_path, i2v_flow_shift, i2v_cfg_scale, i2v_output_type, i2v_attn_mode, + i2v_block_swap, i2v_exclude_single_blocks, 
i2v_use_split_attn, i2v_lora_folder, + *i2v_lora_weights, *i2v_lora_multipliers, i2v_input, i2v_strength, i2v_use_fp8 + ], + outputs=[i2v_output, i2v_batch_progress, i2v_progress_text], + queue=True + ).then( + fn=lambda batch_size: 0 if batch_size == 1 else None, + inputs=[i2v_batch_size], + outputs=i2v_selected_index + ) + # Send to Video2Video + i2v_output.select( + fn=handle_i2v_gallery_select, + outputs=i2v_selected_index + ) + + i2v_send_to_v2v_btn.click( + fn=send_i2v_to_v2v, + inputs=[ + i2v_output, i2v_prompt, i2v_selected_index, + width, height, + i2v_video_length, i2v_fps, i2v_infer_steps, + i2v_seed, i2v_flow_shift, i2v_cfg_scale + ] + i2v_lora_weights + i2v_lora_multipliers, + outputs=[ + v2v_input, v2v_prompt, + v2v_width, v2v_height, + v2v_video_length, v2v_fps, v2v_infer_steps, + v2v_seed, v2v_flow_shift, v2v_cfg_scale + ] + v2v_lora_weights + v2v_lora_multipliers + ).then( + fn=change_to_tab_two, inputs=None, outputs=[tabs] + ) + #Video Info + def clean_video_path(video_path) -> str: + """Extract clean video path from Gradio's various return formats""" + print(f"Input video_path: {video_path}, type: {type(video_path)}") + if isinstance(video_path, dict): + path = video_path.get("name", "") + elif isinstance(video_path, (tuple, list)): + path = video_path[0] + elif isinstance(video_path, str): + path = video_path + else: + path = "" + print(f"Cleaned path: {path}") + return path + def handle_video_upload(video_path: str) -> Dict: + """Handle video upload and metadata extraction""" + if not video_path: + return {}, "No video uploaded" + + metadata = extract_video_metadata(video_path) + if not metadata: + return {}, "No metadata found in video" + + return metadata, "Metadata extracted successfully" + + def get_video_info(video_path: str) -> dict: + try: + probe = ffmpeg.probe(video_path) + video_info = next(stream for stream in probe['streams'] if stream['codec_type'] == 'video') + + width = int(video_info['width']) + height = int(video_info['height']) + fps = eval(video_info['r_frame_rate']) # This converts '30/1' to 30.0 + + # Calculate total frames + duration = float(probe['format']['duration']) + total_frames = int(duration * fps) + + # Ensure video length does not exceed 201 frames + if total_frames > 201: + total_frames = 201 + duration = total_frames / fps # Adjust duration accordingly + + return { + 'width': width, + 'height': height, + 'fps': fps, + 'total_frames': total_frames, + 'duration': duration # Might be useful in some contexts + } + except Exception as e: + print(f"Error extracting video info: {e}") + return {} + + def extract_video_details(video_path: str) -> Tuple[dict, str]: + metadata = extract_video_metadata(video_path) + video_details = get_video_info(video_path) + + # Combine metadata with video details + for key, value in video_details.items(): + if key not in metadata: + metadata[key] = value + + # Ensure video length does not exceed 201 frames + if 'video_length' in metadata: + metadata['video_length'] = min(metadata['video_length'], 201) + else: + metadata['video_length'] = min(video_details.get('total_frames', 0), 201) + + # Return both the updated metadata and a status message + return metadata, "Video details extracted successfully" + + def send_parameters_to_tab(metadata: Dict, target_tab: str) -> Tuple[str, Dict]: + """Create parameter mapping for target tab""" + if not metadata: + return "No parameters to send", {} + + tab_name = "Text2Video" if target_tab == "t2v" else "Video2Video" + try: + mapping = create_parameter_transfer_map(metadata, 
target_tab) + return f"Parameters ready for {tab_name}", mapping + except Exception as e: + return f"Error: {str(e)}", {} + + video_input.upload( + fn=extract_video_details, + inputs=video_input, + outputs=[metadata_output, status] + ) + + send_to_t2v_btn.click( + fn=lambda m: send_parameters_to_tab(m, "t2v"), + inputs=metadata_output, + outputs=[status, params_state] + ).then( + fn=change_to_tab_one, inputs=None, outputs=[tabs] + ).then( + lambda params: [ + params.get("prompt", ""), + params.get("width", 544), + params.get("height", 544), + params.get("batch_size", 1), + params.get("video_length", 25), + params.get("fps", 24), + params.get("infer_steps", 30), + params.get("seed", -1), + params.get("model", "hunyuan/mp_rank_00_model_states.pt"), + params.get("vae", "hunyuan/pytorch_model.pt"), + params.get("te1", "hunyuan/llava_llama3_fp16.safetensors"), + params.get("te2", "hunyuan/clip_l.safetensors"), + params.get("save_path", "outputs"), + params.get("flow_shift", 11.0), + params.get("cfg_scale", 7.0), + params.get("output_type", "video"), + params.get("attn_mode", "sdpa"), + params.get("block_swap", "0"), + *[params.get(f"lora{i+1}", "") for i in range(4)], + *[params.get(f"lora{i+1}_multiplier", 1.0) for i in range(4)] + ] if params else [gr.update()]*26, + inputs=params_state, + outputs=[prompt, width, height, batch_size, video_length, fps, infer_steps, seed, + model, vae, te1, te2, save_path, flow_shift, cfg_scale, + output_type, attn_mode, block_swap] + lora_weights + lora_multipliers + ) + # Text to Video generation + generate_btn.click( + fn=process_batch, + inputs=[ + prompt, t2v_width, t2v_height, batch_size, video_length, fps, infer_steps, + seed, dit_folder, model, vae, te1, te2, save_path, flow_shift, cfg_scale, + output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, + lora_folder, *lora_weights, *lora_multipliers, gr.Textbox(visible=False), gr.Number(visible=False), use_fp8 + ], + outputs=[video_output, batch_progress, progress_text], + queue=True + ).then( + fn=lambda batch_size: 0 if batch_size == 1 else None, + inputs=[batch_size], + outputs=selected_index + ) + + # Update gallery selection handling + def handle_gallery_select(evt: gr.SelectData) -> int: + return evt.index + + # Track selected index when gallery item is clicked + video_output.select( + fn=handle_gallery_select, + outputs=selected_index + ) + + # Track selected index when Video2Video gallery item is clicked + def handle_v2v_gallery_select(evt: gr.SelectData) -> int: + """Handle gallery selection without automatically updating the input""" + return evt.index + + # Update the gallery selection event + v2v_output.select( + fn=handle_v2v_gallery_select, + outputs=v2v_selected_index + ) + + # Send button handler with gallery selection + def handle_send_button( + gallery: list, + prompt: str, + idx: int, + width: int, + height: int, + batch_size: int, + video_length: int, + fps: int, + infer_steps: int, + seed: int, + flow_shift: float, + cfg_scale: float, + lora1: str, + lora2: str, + lora3: str, + lora4: str, + lora1_multiplier: float, + lora2_multiplier: float, + lora3_multiplier: float, + lora4_multiplier: float + ) -> tuple: + if not gallery or idx is None or idx >= len(gallery): + return (None, "", width, height, batch_size, video_length, fps, infer_steps, + seed, flow_shift, cfg_scale, + lora1, lora2, lora3, lora4, + lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier, + "") # Add empty string for negative_prompt in the return values + + # Auto-select first item 
if only one exists and no selection made + if idx is None and len(gallery) == 1: + idx = 0 + + selected_item = gallery[idx] + + # Handle different gallery item formats + if isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + elif isinstance(selected_item, (tuple, list)): + video_path = selected_item[0] + else: + video_path = selected_item + + # Final cleanup for Gradio Video component + if isinstance(video_path, tuple): + video_path = video_path[0] + + return ( + str(video_path), + prompt, + width, + height, + batch_size, + video_length, + fps, + infer_steps, + seed, + flow_shift, + cfg_scale, + lora1, + lora2, + lora3, + lora4, + lora1_multiplier, + lora2_multiplier, + lora3_multiplier, + lora4_multiplier, + "" # Add empty string for negative_prompt + ) + + send_t2v_to_v2v_btn.click( + fn=handle_send_button, + inputs=[ + video_output, prompt, selected_index, + t2v_width, t2v_height, batch_size, video_length, + fps, infer_steps, seed, flow_shift, cfg_scale + ] + lora_weights + lora_multipliers, # Remove the string here + outputs=[ + v2v_input, + v2v_prompt, + v2v_width, + v2v_height, + v2v_batch_size, + v2v_video_length, + v2v_fps, + v2v_infer_steps, + v2v_seed, + v2v_flow_shift, + v2v_cfg_scale + ] + v2v_lora_weights + v2v_lora_multipliers + [v2v_negative_prompt] + ).then( + fn=change_to_tab_two, inputs=None, outputs=[tabs] + ) + + def handle_send_to_v2v(metadata: dict, video_path: str) -> Tuple[str, dict, str]: + """Handle both parameters and video transfer""" + status_msg, params = send_parameters_to_tab(metadata, "v2v") + return status_msg, params, video_path + + def handle_info_to_v2v(metadata: dict, video_path: str) -> Tuple[str, Dict, str]: + """Handle both parameters and video transfer from Video Info to V2V tab""" + if not video_path: + return "No video selected", {}, None + + status_msg, params = send_parameters_to_tab(metadata, "v2v") + # Just return the path directly + return status_msg, params, video_path + + # Send button click handler + send_to_v2v_btn.click( + fn=handle_info_to_v2v, + inputs=[metadata_output, video_input], + outputs=[status, params_state, v2v_input] + ).then( + lambda params: [ + params.get("v2v_prompt", ""), + params.get("v2v_width", 544), + params.get("v2v_height", 544), + params.get("v2v_batch_size", 1), + params.get("v2v_video_length", 25), + params.get("v2v_fps", 24), + params.get("v2v_infer_steps", 30), + params.get("v2v_seed", -1), + params.get("v2v_model", "hunyuan/mp_rank_00_model_states.pt"), + params.get("v2v_vae", "hunyuan/pytorch_model.pt"), + params.get("v2v_te1", "hunyuan/llava_llama3_fp16.safetensors"), + params.get("v2v_te2", "hunyuan/clip_l.safetensors"), + params.get("v2v_save_path", "outputs"), + params.get("v2v_flow_shift", 11.0), + params.get("v2v_cfg_scale", 7.0), + params.get("v2v_output_type", "video"), + params.get("v2v_attn_mode", "sdpa"), + params.get("v2v_block_swap", "0"), + *[params.get(f"v2v_lora_weights[{i}]", "") for i in range(4)], + *[params.get(f"v2v_lora_multipliers[{i}]", 1.0) for i in range(4)] + ] if params else [gr.update()] * 26, + inputs=params_state, + outputs=[ + v2v_prompt, v2v_width, v2v_height, v2v_batch_size, v2v_video_length, + v2v_fps, v2v_infer_steps, v2v_seed, v2v_model, v2v_vae, v2v_te1, + v2v_te2, v2v_save_path, v2v_flow_shift, v2v_cfg_scale, v2v_output_type, + v2v_attn_mode, v2v_block_swap + ] + v2v_lora_weights + v2v_lora_multipliers + ).then( + lambda: print(f"Tabs object: {tabs}"), # Debug print + outputs=None + ).then( + 
fn=change_to_tab_two, inputs=None, outputs=[tabs] + ) + + # Handler for sending selected video from Video2Video gallery to input + def handle_v2v_send_button(gallery: list, prompt: str, idx: int) -> Tuple[Optional[str], str]: + """Send the currently selected video in V2V gallery to V2V input""" + if not gallery or idx is None or idx >= len(gallery): + return None, "" + + selected_item = gallery[idx] + video_path = None + + # Handle different gallery item formats + if isinstance(selected_item, tuple): + video_path = selected_item[0] # Gallery returns (path, caption) + elif isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + elif isinstance(selected_item, str): + video_path = selected_item + + if not video_path: + return None, "" + + # Check if the file exists and is accessible + if not os.path.exists(video_path): + print(f"Warning: Video file not found at {video_path}") + return None, "" + + return video_path, prompt + + v2v_send_to_input_btn.click( + fn=handle_v2v_send_button, + inputs=[v2v_output, v2v_prompt, v2v_selected_index], + outputs=[v2v_input, v2v_prompt] + ).then( + lambda: gr.update(visible=True), # Ensure the video input is visible + outputs=v2v_input + ) + + # Video to Video generation + v2v_generate_btn.click( + fn=process_batch, + inputs=[ + v2v_prompt, v2v_width, v2v_height, v2v_batch_size, v2v_video_length, + v2v_fps, v2v_infer_steps, v2v_seed, v2v_dit_folder, v2v_model, v2v_vae, v2v_te1, v2v_te2, + v2v_save_path, v2v_flow_shift, v2v_cfg_scale, v2v_output_type, v2v_attn_mode, + v2v_block_swap, v2v_exclude_single_blocks, v2v_use_split_attn, v2v_lora_folder, + *v2v_lora_weights, *v2v_lora_multipliers, v2v_input, v2v_strength, + v2v_negative_prompt, v2v_cfg_scale, v2v_split_uncond, v2v_use_fp8 + ], + outputs=[v2v_output, v2v_batch_progress, v2v_progress_text], + queue=True + ).then( + fn=lambda batch_size: 0 if batch_size == 1 else None, + inputs=[v2v_batch_size], + outputs=v2v_selected_index + ) + refresh_outputs = [model] # Add model dropdown to outputs + for i in range(4): + refresh_outputs.extend([lora_weights[i], lora_multipliers[i]]) + + refresh_btn.click( + fn=update_dit_and_lora_dropdowns, + inputs=[dit_folder, lora_folder, model] + lora_weights + lora_multipliers, + outputs=refresh_outputs + ) + # Image2Video refresh + i2v_refresh_outputs = [i2v_model] # Add model dropdown to outputs + for i in range(4): + i2v_refresh_outputs.extend([i2v_lora_weights[i], i2v_lora_multipliers[i]]) + + i2v_refresh_btn.click( + fn=update_dit_and_lora_dropdowns, + inputs=[i2v_dit_folder, i2v_lora_folder, i2v_model] + i2v_lora_weights + i2v_lora_multipliers, + outputs=i2v_refresh_outputs + ) + + # Video2Video refresh + v2v_refresh_outputs = [v2v_model] # Add model dropdown to outputs + for i in range(4): + v2v_refresh_outputs.extend([v2v_lora_weights[i], v2v_lora_multipliers[i]]) + + v2v_refresh_btn.click( + fn=update_dit_and_lora_dropdowns, + inputs=[v2v_dit_folder, v2v_lora_folder, v2v_model] + v2v_lora_weights + v2v_lora_multipliers, + outputs=v2v_refresh_outputs + ) + + # WanX-i2v tab connections + wanx_prompt.change(fn=count_prompt_tokens, inputs=wanx_prompt, outputs=wanx_token_counter) + wanx_stop_btn.click(fn=lambda: stop_event.set(), queue=False) + + # Image input handling for WanX-i2v + wanx_input.change( + fn=update_wanx_image_dimensions, + inputs=[wanx_input], + outputs=[wanx_original_dims, wanx_width, wanx_height] + ) + + # Scale slider handling for WanX-i2v + wanx_scale_slider.change( + fn=update_wanx_from_scale, + 
inputs=[wanx_scale_slider, wanx_original_dims], + outputs=[wanx_width, wanx_height] + ) + + # Width/height calculation buttons for WanX-i2v + wanx_calc_width_btn.click( + fn=calculate_wanx_width, + inputs=[wanx_height, wanx_original_dims], + outputs=[wanx_width] + ) + + wanx_calc_height_btn.click( + fn=calculate_wanx_height, + inputs=[wanx_width, wanx_original_dims], + outputs=[wanx_height] + ) + + # Flow shift recommendation buttons + wanx_recommend_flow_btn.click( + fn=recommend_wanx_flow_shift, + inputs=[wanx_width, wanx_height], + outputs=[wanx_flow_shift] + ) + + wanx_t2v_recommend_flow_btn.click( + fn=recommend_wanx_flow_shift, + inputs=[wanx_t2v_width, wanx_t2v_height], + outputs=[wanx_t2v_flow_shift] + ) + + # Generate button handler + wanx_generate_btn.click( + fn=wanx_generate_video_batch, + inputs=[ + wanx_prompt, + wanx_negative_prompt, + wanx_width, + wanx_height, + wanx_video_length, + wanx_fps, + wanx_infer_steps, + wanx_flow_shift, + wanx_guidance_scale, + wanx_seed, + wanx_task, + wanx_dit_path, + wanx_vae_path, + wanx_t5_path, + wanx_clip_path, + wanx_save_path, + wanx_output_type, + wanx_sample_solver, + wanx_attn_mode, + wanx_block_swap, + wanx_fp8, + wanx_fp8_t5, + wanx_batch_size, + wanx_input, # Image input + wanx_lora_folder, + *wanx_lora_weights, + *wanx_lora_multipliers, + wanx_exclude_single_blocks + ], + outputs=[wanx_output, wanx_batch_progress, wanx_progress_text], + queue=True + ).then( + fn=lambda batch_size: 0 if batch_size == 1 else None, + inputs=[wanx_batch_size], + outputs=skyreels_selected_index + ) + + # Gallery selection handling + wanx_output.select( + fn=handle_wanx_gallery_select, + outputs=skyreels_selected_index # Reuse the skyreels_selected_index + ) + + # Send to Video2Video handler + wanx_send_to_v2v_btn.click( + fn=send_wanx_to_v2v, + inputs=[ + wanx_output, + wanx_prompt, + skyreels_selected_index, # Reuse the skyreels_selected_index + wanx_width, + wanx_height, + wanx_video_length, + wanx_fps, + wanx_infer_steps, + wanx_seed, + wanx_flow_shift, + wanx_guidance_scale, + wanx_negative_prompt + ], + outputs=[ + v2v_input, + v2v_prompt, + v2v_width, + v2v_height, + v2v_video_length, + v2v_fps, + v2v_infer_steps, + v2v_seed, + v2v_flow_shift, + v2v_cfg_scale, + v2v_negative_prompt + ] + ).then( + fn=change_to_tab_two, + inputs=None, + outputs=[tabs] + ) + + # Add state for T2V tab selected index + wanx_t2v_selected_index = gr.State(value=None) + + # Connect prompt token counter + wanx_t2v_prompt.change(fn=count_prompt_tokens, inputs=wanx_t2v_prompt, outputs=wanx_t2v_token_counter) + + # Stop button handler + wanx_t2v_stop_btn.click(fn=lambda: stop_event.set(), queue=False) + + # Flow shift recommendation button + wanx_t2v_recommend_flow_btn.click( + fn=recommend_wanx_flow_shift, + inputs=[wanx_t2v_width, wanx_t2v_height], + outputs=[wanx_t2v_flow_shift] + ) + + # Task change handler to update CLIP visibility and path + def update_clip_visibility(task): + is_i2v = "i2v" in task + return gr.update(visible=is_i2v) + + wanx_t2v_task.change( + fn=update_clip_visibility, + inputs=[wanx_t2v_task], + outputs=[wanx_t2v_clip_path] + ) + + # Generate button handler for T2V + wanx_t2v_generate_btn.click( + fn=wanx_generate_video_batch, + inputs=[ + wanx_t2v_prompt, + wanx_t2v_negative_prompt, + wanx_t2v_width, + wanx_t2v_height, + wanx_t2v_video_length, + wanx_t2v_fps, + wanx_t2v_infer_steps, + wanx_t2v_flow_shift, + wanx_t2v_guidance_scale, + wanx_t2v_seed, + wanx_t2v_task, + wanx_t2v_dit_path, + wanx_t2v_vae_path, + wanx_t2v_t5_path, + 
wanx_t2v_clip_path, + wanx_t2v_save_path, + wanx_t2v_output_type, + wanx_t2v_sample_solver, + wanx_t2v_attn_mode, + wanx_t2v_block_swap, + wanx_t2v_fp8, + wanx_t2v_fp8_t5, + wanx_t2v_batch_size, + wanx_t2v_lora_folder, + *wanx_t2v_lora_weights, + *wanx_t2v_lora_multipliers, + wanx_t2v_exclude_single_blocks + ], + outputs=[wanx_t2v_output, wanx_t2v_batch_progress, wanx_t2v_progress_text], + queue=True + ).then( + fn=lambda batch_size: 0 if batch_size == 1 else None, + inputs=[wanx_t2v_batch_size], + outputs=wanx_t2v_selected_index + ) + + # Gallery selection handling + wanx_t2v_output.select( + fn=handle_wanx_t2v_gallery_select, + outputs=wanx_t2v_selected_index + ) + + # Send to Video2Video handler + wanx_t2v_send_to_v2v_btn.click( + fn=send_wanx_t2v_to_v2v, + inputs=[ + wanx_t2v_output, + wanx_t2v_prompt, + wanx_t2v_selected_index, + wanx_t2v_width, + wanx_t2v_height, + wanx_t2v_video_length, + wanx_t2v_fps, + wanx_t2v_infer_steps, + wanx_t2v_seed, + wanx_t2v_flow_shift, + wanx_t2v_guidance_scale, + wanx_t2v_negative_prompt + ], + outputs=[ + v2v_input, + v2v_prompt, + v2v_width, + v2v_height, + v2v_video_length, + v2v_fps, + v2v_infer_steps, + v2v_seed, + v2v_flow_shift, + v2v_cfg_scale, + v2v_negative_prompt + ] + ).then( + fn=change_to_tab_two, + inputs=None, + outputs=[tabs] + ) + + # Refresh handlers for WanX-i2v + wanx_refresh_outputs = [] + for i in range(4): + wanx_refresh_outputs.extend([wanx_lora_weights[i], wanx_lora_multipliers[i]]) + + wanx_refresh_btn.click( + fn=update_lora_dropdowns, + inputs=[wanx_lora_folder] + wanx_lora_weights + wanx_lora_multipliers, + outputs=wanx_refresh_outputs + ) + + # Refresh handlers for WanX-t2v + wanx_t2v_refresh_outputs = [] + for i in range(4): + wanx_t2v_refresh_outputs.extend([wanx_t2v_lora_weights[i], wanx_t2v_lora_multipliers[i]]) + + wanx_t2v_refresh_btn.click( + fn=update_lora_dropdowns, + inputs=[wanx_t2v_lora_folder] + wanx_t2v_lora_weights + wanx_t2v_lora_multipliers, + outputs=wanx_t2v_refresh_outputs + ) +demo.queue().launch(server_name="0.0.0.0", share=False) \ No newline at end of file diff --git a/hunyuan/put_hunyuan_files_here.txt b/hunyuan/put_hunyuan_files_here.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/hunyuan_model/__init__.py b/hunyuan_model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/hunyuan_model/activation_layers.py b/hunyuan_model/activation_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..f8774c26ceef6081482ca0dbbf930b207d4ac03b --- /dev/null +++ b/hunyuan_model/activation_layers.py @@ -0,0 +1,23 @@ +import torch.nn as nn + + +def get_activation_layer(act_type): + """get activation layer + + Args: + act_type (str): the activation type + + Returns: + torch.nn.functional: the activation layer + """ + if act_type == "gelu": + return lambda: nn.GELU() + elif act_type == "gelu_tanh": + # Approximate `tanh` requires torch >= 1.13 + return lambda: nn.GELU(approximate="tanh") + elif act_type == "relu": + return nn.ReLU + elif act_type == "silu": + return nn.SiLU + else: + raise ValueError(f"Unknown activation type: {act_type}") diff --git a/hunyuan_model/attention.py b/hunyuan_model/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..e94253df0aceb11e4f5812b728df75b9d38bf8c2 --- /dev/null +++ b/hunyuan_model/attention.py @@ -0,0 +1,295 @@ +import importlib.metadata +import math + +import torch 
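+# NOTE: this module wraps several attention backends behind a single attention()
+# helper: flash-attn (varlen and fixed-length), SageAttention, xformers, PyTorch's
+# scaled_dot_product_attention ("torch"), and a vanilla fallback. The optional
+# backends are imported in try/except blocks below and set to None when missing,
+# so only the backend selected via `mode` needs to be installed.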
+import torch.nn as nn +import torch.nn.functional as F + +try: + import flash_attn + from flash_attn.flash_attn_interface import _flash_attn_forward + from flash_attn.flash_attn_interface import flash_attn_varlen_func + from flash_attn.flash_attn_interface import flash_attn_func +except ImportError: + flash_attn = None + flash_attn_varlen_func = None + _flash_attn_forward = None + flash_attn_func = None + +try: + print(f"Trying to import sageattention") + from sageattention import sageattn_varlen, sageattn + + print("Successfully imported sageattention") +except ImportError: + print(f"Failed to import sageattention") + sageattn_varlen = None + sageattn = None + +try: + import xformers.ops as xops +except ImportError: + xops = None + +MEMORY_LAYOUT = { + "flash": ( + lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]), + lambda x: x, + ), + "flash_fixlen": ( + lambda x: x, + lambda x: x, + ), + "sageattn": ( + lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]), + lambda x: x, + ), + "sageattn_fixlen": ( + lambda x: x.transpose(1, 2), + lambda x: x.transpose(1, 2), + ), + "torch": ( + lambda x: x.transpose(1, 2), + lambda x: x.transpose(1, 2), + ), + "xformers": ( + lambda x: x, + lambda x: x, + ), + "vanilla": ( + lambda x: x.transpose(1, 2), + lambda x: x.transpose(1, 2), + ), +} + + +def get_cu_seqlens(text_mask, img_len): + """Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len + + Args: + text_mask (torch.Tensor): the mask of text + img_len (int): the length of image + + Returns: + torch.Tensor: the calculated cu_seqlens for flash attention + """ + batch_size = text_mask.shape[0] + text_len = text_mask.sum(dim=1) + max_len = text_mask.shape[1] + img_len + + cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda") + + for i in range(batch_size): + s = text_len[i] + img_len + s1 = i * max_len + s + s2 = (i + 1) * max_len + cu_seqlens[2 * i + 1] = s1 + cu_seqlens[2 * i + 2] = s2 + + return cu_seqlens + + +def attention( + q_or_qkv_list, + k=None, + v=None, + mode="flash", + drop_rate=0, + attn_mask=None, + total_len=None, + causal=False, + cu_seqlens_q=None, + cu_seqlens_kv=None, + max_seqlen_q=None, + max_seqlen_kv=None, + batch_size=1, +): + """ + Perform QKV self attention. + + Args: + q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads. + k (torch.Tensor): Key tensor with shape [b, s1, a, d] + v (torch.Tensor): Value tensor with shape [b, s1, a, d] + mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'. + drop_rate (float): Dropout rate in attention map. (default: 0) + attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla). + (default: None) + causal (bool): Whether to use causal attention. (default: False) + cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, + used to index into q. + cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, + used to index into kv. + max_seqlen_q (int): The maximum sequence length in the batch of q. + max_seqlen_kv (int): The maximum sequence length in the batch of k and v. 
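+        total_len (torch.Tensor): per-sample valid sequence lengths. When provided, the
+            split-attention path is used: each sample is trimmed to its valid length
+            before attention and the outputs are padded back afterwards. (default: None)
+        batch_size (int): batch size used to reshape the variable-length output back to
+            [b, s, a, d] for the "flash" and "sageattn" modes. (default: 1)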
+ + Returns: + torch.Tensor: Output tensor after self attention with shape [b, s, ad] + """ + q, k, v = q_or_qkv_list if type(q_or_qkv_list) == list else (q_or_qkv_list, k, v) + if type(q_or_qkv_list) == list: + q_or_qkv_list.clear() + split_attn = total_len is not None + if split_attn and mode == "sageattn": + mode = "sageattn_fixlen" + elif split_attn and mode == "flash": + mode = "flash_fixlen" + # print(f"Attention mode: {mode}, split_attn: {split_attn}") + pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode] + + # trim the sequence length to the actual length instead of attn_mask + if split_attn: + trimmed_len = q.shape[1] - total_len + q = [q[i : i + 1, : total_len[i]] for i in range(len(q))] + k = [k[i : i + 1, : total_len[i]] for i in range(len(k))] + v = [v[i : i + 1, : total_len[i]] for i in range(len(v))] + q = [pre_attn_layout(q_i) for q_i in q] + k = [pre_attn_layout(k_i) for k_i in k] + v = [pre_attn_layout(v_i) for v_i in v] + # print( + # f"Trimming the sequence length to {total_len},trimmed_len: {trimmed_len}, q.shape: {[q_i.shape for q_i in q]}, mode: {mode}" + # ) + else: + q = pre_attn_layout(q) + k = pre_attn_layout(k) + v = pre_attn_layout(v) + + if mode == "torch": + if split_attn: + x = [] + for i in range(len(q)): + x_i = F.scaled_dot_product_attention(q[i], k[i], v[i], dropout_p=drop_rate, is_causal=causal) + q[i], k[i], v[i] = None, None, None + x.append(x_i) + del q, k, v + else: + if attn_mask is not None and attn_mask.dtype != torch.bool: + attn_mask = attn_mask.to(q.dtype) + x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal) + del q, k, v + del attn_mask + + elif mode == "xformers": + # B, M, H, K: M is the sequence length, H is the number of heads, K is the dimension of the heads -> it is same as input dimension + # currently only support batch_size = 1 + assert split_attn, "Xformers only supports splitting" + x = [] + for i in range(len(q)): + x_i = xops.memory_efficient_attention(q[i], k[i], v[i], p=drop_rate) # , causal=causal) + q[i], k[i], v[i] = None, None, None + x.append(x_i) + del q, k, v + + elif mode == "flash": + x = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv) + del q, k, v + # x with shape [(bxs), a, d] + x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d] + elif mode == "flash_fixlen": + x = [] + for i in range(len(q)): + # q: (batch_size, seqlen, nheads, headdim), k: (batch_size, seqlen, nheads_k, headdim), v: (batch_size, seqlen, nheads_k, headdim) + x_i = flash_attn_func(q[i], k[i], v[i], dropout_p=drop_rate, causal=causal) + q[i], k[i], v[i] = None, None, None + x.append(x_i) + del q, k, v + elif mode == "sageattn": + x = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv) + del q, k, v + # x with shape [(bxs), a, d] + x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d] + elif mode == "sageattn_fixlen": + x = [] + for i in range(len(q)): + # HND seems to cause an error + x_i = sageattn(q[i], k[i], v[i]) # (batch_size, seq_len, head_num, head_dim) + q[i], k[i], v[i] = None, None, None + x.append(x_i) + del q, k, v + elif mode == "vanilla": + assert not split_attn, "Vanilla attention does not support trimming" + scale_factor = 1 / math.sqrt(q.size(-1)) + + b, a, s, _ = q.shape + s1 = k.size(2) + attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device) + if causal: + # Only applied to self attention + assert attn_mask is 
None, "Causal mask and attn_mask cannot be used together" + temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(q.dtype) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias += attn_mask + + # TODO: Maybe force q and k to be float32 to avoid numerical overflow + attn = (q @ k.transpose(-2, -1)) * scale_factor + attn += attn_bias + attn = attn.softmax(dim=-1) + attn = torch.dropout(attn, p=drop_rate, train=True) + x = attn @ v + else: + raise NotImplementedError(f"Unsupported attention mode: {mode}") + + if split_attn: + x = [post_attn_layout(x_i) for x_i in x] + for i in range(len(x)): + x[i] = F.pad(x[i], (0, 0, 0, 0, 0, trimmed_len[i])) + x = torch.cat(x, dim=0) + else: + x = post_attn_layout(x) + + b, s, a, d = x.shape + out = x.reshape(b, s, -1) + return out + + +def parallel_attention(hybrid_seq_parallel_attn, q, k, v, img_q_len, img_kv_len, cu_seqlens_q, cu_seqlens_kv): + attn1 = hybrid_seq_parallel_attn( + None, + q[:, :img_q_len, :, :], + k[:, :img_kv_len, :, :], + v[:, :img_kv_len, :, :], + dropout_p=0.0, + causal=False, + joint_tensor_query=q[:, img_q_len : cu_seqlens_q[1]], + joint_tensor_key=k[:, img_kv_len : cu_seqlens_kv[1]], + joint_tensor_value=v[:, img_kv_len : cu_seqlens_kv[1]], + joint_strategy="rear", + ) + if flash_attn.__version__ >= "2.7.0": + attn2, *_ = _flash_attn_forward( + q[:, cu_seqlens_q[1] :], + k[:, cu_seqlens_kv[1] :], + v[:, cu_seqlens_kv[1] :], + dropout_p=0.0, + softmax_scale=q.shape[-1] ** (-0.5), + causal=False, + window_size_left=-1, + window_size_right=-1, + softcap=0.0, + alibi_slopes=None, + return_softmax=False, + ) + else: + attn2, *_ = _flash_attn_forward( + q[:, cu_seqlens_q[1] :], + k[:, cu_seqlens_kv[1] :], + v[:, cu_seqlens_kv[1] :], + dropout_p=0.0, + softmax_scale=q.shape[-1] ** (-0.5), + causal=False, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + return_softmax=False, + ) + attn = torch.cat([attn1, attn2], dim=1) + b, s, a, d = attn.shape + attn = attn.reshape(b, s, -1) + + return attn diff --git a/hunyuan_model/autoencoder_kl_causal_3d.py b/hunyuan_model/autoencoder_kl_causal_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..a7e70737325a50e1ee1fbbee96b4a0aafbdcd241 --- /dev/null +++ b/hunyuan_model/autoencoder_kl_causal_3d.py @@ -0,0 +1,609 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +# +# Modified from diffusers==0.29.2 +# +# ============================================================================== +from typing import Dict, Optional, Tuple, Union +from dataclasses import dataclass + +import torch +import torch.nn as nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config + +# try: +# # This diffusers is modified and packed in the mirror. +# from diffusers.loaders import FromOriginalVAEMixin +# except ImportError: +# # Use this to be compatible with the original diffusers. +# from diffusers.loaders.single_file_model import FromOriginalModelMixin as FromOriginalVAEMixin +from diffusers.utils.accelerate_utils import apply_forward_hook +from diffusers.models.attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + Attention, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, +) +from diffusers.models.modeling_outputs import AutoencoderKLOutput +from diffusers.models.modeling_utils import ModelMixin +from .vae import DecoderCausal3D, BaseOutput, DecoderOutput, DiagonalGaussianDistribution, EncoderCausal3D + + +@dataclass +class DecoderOutput2(BaseOutput): + sample: torch.FloatTensor + posterior: Optional[DiagonalGaussianDistribution] = None + + +class AutoencoderKLCausal3D(ModelMixin, ConfigMixin): + r""" + A VAE model with KL loss for encoding images/videos into latents and decoding latent representations into images/videos. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str] = ("DownEncoderBlockCausal3D",), + up_block_types: Tuple[str] = ("UpDecoderBlockCausal3D",), + block_out_channels: Tuple[int] = (64,), + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 4, + norm_num_groups: int = 32, + sample_size: int = 32, + sample_tsize: int = 64, + scaling_factor: float = 0.18215, + force_upcast: float = True, + spatial_compression_ratio: int = 8, + time_compression_ratio: int = 4, + mid_block_add_attention: bool = True, + ): + super().__init__() + + self.time_compression_ratio = time_compression_ratio + + self.encoder = EncoderCausal3D( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + time_compression_ratio=time_compression_ratio, + spatial_compression_ratio=spatial_compression_ratio, + mid_block_add_attention=mid_block_add_attention, + ) + + self.decoder = DecoderCausal3D( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + time_compression_ratio=time_compression_ratio, + spatial_compression_ratio=spatial_compression_ratio, + mid_block_add_attention=mid_block_add_attention, + ) + + self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1) + self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1) + + self.use_slicing = False + self.use_spatial_tiling = False + self.use_temporal_tiling = False 
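+
+        # Slicing and tiling are opt-in memory optimizations, toggled via
+        # enable_slicing() / enable_tiling() (and the spatial/temporal variants) below;
+        # all of them are disabled by default.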
+ + # only relevant if vae tiling is enabled + self.tile_sample_min_tsize = sample_tsize + self.tile_latent_min_tsize = sample_tsize // time_compression_ratio + + self.tile_sample_min_size = self.config.sample_size + sample_size = self.config.sample_size[0] if isinstance(self.config.sample_size, (list, tuple)) else self.config.sample_size + self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) + self.tile_overlap_factor = 0.25 + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (EncoderCausal3D, DecoderCausal3D)): + module.gradient_checkpointing = value + + def enable_temporal_tiling(self, use_tiling: bool = True): + self.use_temporal_tiling = use_tiling + + def disable_temporal_tiling(self): + self.enable_temporal_tiling(False) + + def enable_spatial_tiling(self, use_tiling: bool = True): + self.use_spatial_tiling = use_tiling + + def disable_spatial_tiling(self): + self.enable_spatial_tiling(False) + + def enable_tiling(self, use_tiling: bool = True): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger videos. + """ + self.enable_spatial_tiling(use_tiling) + self.enable_temporal_tiling(use_tiling) + + def disable_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.disable_spatial_tiling() + self.disable_temporal_tiling() + + def enable_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + def set_chunk_size_for_causal_conv_3d(self, chunk_size: int): + # set chunk_size to CausalConv3d recursively + def set_chunk_size(module): + if hasattr(module, "chunk_size"): + module.chunk_size = chunk_size + + self.apply(set_chunk_size) + + @property + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False): + r""" + Sets the attention processor to use to compute attention. 
+ + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor, _remove_lora=_remove_lora) + else: + module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor, _remove_lora=True) + + @apply_forward_hook + def encode( + self, x: torch.FloatTensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of images/videos into latents. + + Args: + x (`torch.FloatTensor`): Input batch of images/videos. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded images/videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + assert len(x.shape) == 5, "The input tensor should have 5 dimensions." 
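+        # Expected layout is [B, C, T, H, W], e.g. a single 16-frame RGB clip at
+        # 544x960 arrives as [1, 3, 16, 544, 960]. Dispatch order below: temporal
+        # tiling for long clips, then spatial tiling for large frames, then optional
+        # batch slicing; otherwise the whole batch is encoded in one pass.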
+ + if self.use_temporal_tiling and x.shape[2] > self.tile_sample_min_tsize: + return self.temporal_tiled_encode(x, return_dict=return_dict) + + if self.use_spatial_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size): + return self.spatial_tiled_encode(x, return_dict=return_dict) + + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self.encoder(x) + + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + assert len(z.shape) == 5, "The input tensor should have 5 dimensions." + + if self.use_temporal_tiling and z.shape[2] > self.tile_latent_min_tsize: + return self.temporal_tiled_decode(z, return_dict=return_dict) + + if self.use_spatial_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size): + return self.spatial_tiled_decode(z, return_dict=return_dict) + + z = self.post_quant_conv(z) + dec = self.decoder(z) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode(self, z: torch.FloatTensor, return_dict: bool = True, generator=None) -> Union[DecoderOutput, torch.FloatTensor]: + """ + Decode a batch of images/videos. + + Args: + z (`torch.FloatTensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent) + return b + + def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-3], b.shape[-3], blend_extent) + for x in range(blend_extent): + b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (x / blend_extent) + return b + + def spatial_tiled_encode( + self, x: torch.FloatTensor, return_dict: bool = True, return_moments: bool = False + ) -> AutoencoderKLOutput: + r"""Encode a batch of images/videos using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image/videos size. 
The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + + Args: + x (`torch.FloatTensor`): Input batch of images/videos. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`: + If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain + `tuple` is returned. + """ + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split video into tiles and encode them separately. + rows = [] + for i in range(0, x.shape[-2], overlap_size): + row = [] + for j in range(0, x.shape[-1], overlap_size): + tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] + tile = self.encoder(tile) + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=-1)) + + moments = torch.cat(result_rows, dim=-2) + if return_moments: + return moments + + posterior = DiagonalGaussianDistribution(moments) + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def spatial_tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + r""" + Decode a batch of images/videos using a tiled decoder. + + Args: + z (`torch.FloatTensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) + row_limit = self.tile_sample_min_size - blend_extent + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. 
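+        # Adjacent tiles are linearly cross-faded over `blend_extent` rows/columns
+        # (blend_v / blend_h) and cropped to `row_limit` before concatenation, which
+        # hides most seams at tile borders.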
+ rows = [] + for i in range(0, z.shape[-2], overlap_size): + row = [] + for j in range(0, z.shape[-1], overlap_size): + tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size] + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=-1)) + + dec = torch.cat(result_rows, dim=-2) + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def temporal_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: + + B, C, T, H, W = x.shape + overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor) + t_limit = self.tile_latent_min_tsize - blend_extent + + # Split the video into tiles and encode them separately. + row = [] + for i in range(0, T, overlap_size): + tile = x[:, :, i : i + self.tile_sample_min_tsize + 1, :, :] + if self.use_spatial_tiling and ( + tile.shape[-1] > self.tile_sample_min_size or tile.shape[-2] > self.tile_sample_min_size + ): + tile = self.spatial_tiled_encode(tile, return_moments=True) + else: + tile = self.encoder(tile) + tile = self.quant_conv(tile) + if i > 0: + tile = tile[:, :, 1:, :, :] + row.append(tile) + result_row = [] + for i, tile in enumerate(row): + if i > 0: + tile = self.blend_t(row[i - 1], tile, blend_extent) + result_row.append(tile[:, :, :t_limit, :, :]) + else: + result_row.append(tile[:, :, : t_limit + 1, :, :]) + + moments = torch.cat(result_row, dim=2) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def temporal_tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + # Split z into overlapping tiles and decode them separately. 
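+        # Each latent tile includes one extra leading frame (the `+ 1` in the slice);
+        # after decoding, the first output frame is dropped for every tile but the
+        # first, and the overlap is cross-faded along the time axis via blend_t.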
+ + B, C, T, H, W = z.shape + overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor) + t_limit = self.tile_sample_min_tsize - blend_extent + + row = [] + for i in range(0, T, overlap_size): + tile = z[:, :, i : i + self.tile_latent_min_tsize + 1, :, :] + if self.use_spatial_tiling and ( + tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size + ): + decoded = self.spatial_tiled_decode(tile, return_dict=True).sample + else: + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile) + if i > 0: + decoded = decoded[:, :, 1:, :, :] + row.append(decoded) + result_row = [] + for i, tile in enumerate(row): + if i > 0: + tile = self.blend_t(row[i - 1], tile, blend_extent) + result_row.append(tile[:, :, :t_limit, :, :]) + else: + result_row.append(tile[:, :, : t_limit + 1, :, :]) + + dec = torch.cat(result_row, dim=2) + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.FloatTensor, + sample_posterior: bool = False, + return_dict: bool = True, + return_posterior: bool = False, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput2, torch.FloatTensor]: + r""" + Args: + sample (`torch.FloatTensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z).sample + + if not return_dict: + if return_posterior: + return (dec, posterior) + else: + return (dec,) + if return_posterior: + return DecoderOutput2(sample=dec, posterior=posterior) + else: + return DecoderOutput2(sample=dec) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, + key, value) are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. 
+ + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) diff --git a/hunyuan_model/embed_layers.py b/hunyuan_model/embed_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..e31ba9cc58d1aa05e0f17b919762f69bd693b5c0 --- /dev/null +++ b/hunyuan_model/embed_layers.py @@ -0,0 +1,132 @@ +import collections +import math +import torch +import torch.nn as nn +from einops import rearrange, repeat + +from .helpers import to_2tuple + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding + + Image to Patch Embedding using Conv2d + + A convolution based approach to patchifying a 2D image w/ embedding projection. + + Based on the impl in https://github.com/google-research/vision_transformer + + Hacked together by / Copyright 2020 Ross Wightman + + Remove the _assert function in forward function to be compatible with multi-resolution images. + """ + + def __init__( + self, + patch_size=16, + in_chans=3, + embed_dim=768, + norm_layer=None, + flatten=True, + bias=True, + dtype=None, + device=None, + ): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + patch_size = to_2tuple(patch_size) + self.patch_size = patch_size + self.flatten = flatten + + self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, **factory_kwargs) + nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1)) + if bias: + nn.init.zeros_(self.proj.bias) + + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x + + +class TextProjection(nn.Module): + """ + Projects text embeddings. Also handles dropout for classifier-free guidance. + + Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py + """ + + def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True, **factory_kwargs) + self.act_1 = act_layer() + self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True, **factory_kwargs) + + def forward(self, caption): + hidden_states = self.linear_1(caption) + hidden_states = self.act_1(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + Args: + t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional. + dim (int): the dimension of the output. + max_period (int): controls the minimum frequency of the embeddings. + + Returns: + embedding (torch.Tensor): An (N, D) Tensor of positional embeddings. + + .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + """ + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. 
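+
+    A sinusoidal frequency embedding of size `frequency_embedding_size` is computed
+    first (see timestep_embedding above) and passed through a two-layer MLP
+    (Linear -> activation -> Linear) to give an embedding of size `out_size`,
+    which defaults to `hidden_size`.
+
+    Example (hypothetical usage, not part of the original code):
+
+        embedder = TimestepEmbedder(hidden_size=3072, act_layer=nn.SiLU)
+        t = torch.tensor([0.0, 500.0, 999.0])
+        emb = embedder(t)  # -> shape [3, 3072]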
+ """ + + def __init__( + self, + hidden_size, + act_layer, + frequency_embedding_size=256, + max_period=10000, + out_size=None, + dtype=None, + device=None, + ): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + self.frequency_embedding_size = frequency_embedding_size + self.max_period = max_period + if out_size is None: + out_size = hidden_size + + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs), + act_layer(), + nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs), + ) + nn.init.normal_(self.mlp[0].weight, std=0.02) + nn.init.normal_(self.mlp[2].weight, std=0.02) + + def forward(self, t): + t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype) + t_emb = self.mlp(t_freq) + return t_emb diff --git a/hunyuan_model/fp8_optimization.py b/hunyuan_model/fp8_optimization.py new file mode 100644 index 0000000000000000000000000000000000000000..90b978baca8cd9a3401b8b66a6575c0c3c29c991 --- /dev/null +++ b/hunyuan_model/fp8_optimization.py @@ -0,0 +1,39 @@ +#based on ComfyUI's and MinusZoneAI's fp8_linear optimization +#further borrowed from HunyuanVideoWrapper for Musubi Tuner +import torch +import torch.nn as nn + +def fp8_linear_forward(cls, original_dtype, input): + weight_dtype = cls.weight.dtype + if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: + if len(input.shape) == 3: + target_dtype = torch.float8_e5m2 if weight_dtype == torch.float8_e4m3fn else torch.float8_e4m3fn + inn = input.reshape(-1, input.shape[2]).to(target_dtype) + w = cls.weight.t() + + scale = torch.ones((1), device=input.device, dtype=torch.float32) + bias = cls.bias.to(original_dtype) if cls.bias is not None else None + + if bias is not None: + o = torch._scaled_mm(inn, w, out_dtype=original_dtype, bias=bias, scale_a=scale, scale_b=scale) + else: + o = torch._scaled_mm(inn, w, out_dtype=original_dtype, scale_a=scale, scale_b=scale) + + if isinstance(o, tuple): + o = o[0] + + return o.reshape((-1, input.shape[1], cls.weight.shape[0])) + else: + return cls.original_forward(input.to(original_dtype)) + else: + return cls.original_forward(input) + +def convert_fp8_linear(module, original_dtype, params_to_keep={}): + setattr(module, "fp8_matmul_enabled", True) + + for name, module in module.named_modules(): + if not any(keyword in name for keyword in params_to_keep): + if isinstance(module, nn.Linear): + original_forward = module.forward + setattr(module, "original_forward", original_forward) + setattr(module, "forward", lambda input, m=module: fp8_linear_forward(m, original_dtype, input)) diff --git a/hunyuan_model/helpers.py b/hunyuan_model/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..72ab8cb1feba4ce7782f1ea841fd42c71be7b0d1 --- /dev/null +++ b/hunyuan_model/helpers.py @@ -0,0 +1,40 @@ +import collections.abc + +from itertools import repeat + + +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + x = tuple(x) + if len(x) == 1: + x = tuple(repeat(x[0], n)) + return x + return tuple(repeat(x, n)) + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) + + +def as_tuple(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + if x is None or isinstance(x, (int, float, str)): + return (x,) + else: + raise ValueError(f"Unknown type {type(x)}") + + +def as_list_of_2tuple(x): + x = 
as_tuple(x) + if len(x) == 1: + x = (x[0], x[0]) + assert len(x) % 2 == 0, f"Expect even length, got {len(x)}." + lst = [] + for i in range(0, len(x), 2): + lst.append((x[i], x[i + 1])) + return lst diff --git a/hunyuan_model/mlp_layers.py b/hunyuan_model/mlp_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..fcc9547a6a0ba80ab19a472a9ea7aef525f46613 --- /dev/null +++ b/hunyuan_model/mlp_layers.py @@ -0,0 +1,118 @@ +# Modified from timm library: +# https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13 + +from functools import partial + +import torch +import torch.nn as nn + +from .modulate_layers import modulate +from .helpers import to_2tuple + + +class MLP(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_channels, + hidden_channels=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0.0, + use_conv=False, + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + out_features = out_features or in_channels + hidden_channels = hidden_channels or in_channels + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer( + in_channels, hidden_channels, bias=bias[0], **factory_kwargs + ) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.norm = ( + norm_layer(hidden_channels, **factory_kwargs) + if norm_layer is not None + else nn.Identity() + ) + self.fc2 = linear_layer( + hidden_channels, out_features, bias=bias[1], **factory_kwargs + ) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.norm(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +# +class MLPEmbedder(nn.Module): + """copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py""" + def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs) + self.silu = nn.SiLU() + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.out_layer(self.silu(self.in_layer(x))) + + +class FinalLayer(nn.Module): + """The final layer of DiT.""" + + def __init__( + self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + # Just use LayerNorm for the final layer + self.norm_final = nn.LayerNorm( + hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs + ) + if isinstance(patch_size, int): + self.linear = nn.Linear( + hidden_size, + patch_size * patch_size * out_channels, + bias=True, + **factory_kwargs + ) + else: + self.linear = nn.Linear( + hidden_size, + patch_size[0] * patch_size[1] * patch_size[2] * out_channels, + bias=True, + ) + nn.init.zeros_(self.linear.weight) + nn.init.zeros_(self.linear.bias) + + # Here we don't distinguish between the modulate types. Just use the simple one. 
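+        # adaLN modulation projects the conditioning vector c to (shift, scale); both
+        # this projection and self.linear above are zero-initialized, so the final
+        # layer initially outputs zeros.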
+ self.adaLN_modulation = nn.Sequential( + act_layer(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs), + ) + # Zero-initialize the modulation + nn.init.zeros_(self.adaLN_modulation[1].weight) + nn.init.zeros_(self.adaLN_modulation[1].bias) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift=shift, scale=scale) + x = self.linear(x) + return x diff --git a/hunyuan_model/models.py b/hunyuan_model/models.py new file mode 100644 index 0000000000000000000000000000000000000000..68ece16722b711242f0255d56a7a3e892fc861c9 --- /dev/null +++ b/hunyuan_model/models.py @@ -0,0 +1,1162 @@ +import os +from typing import Any, List, Tuple, Optional, Union, Dict +import accelerate +from einops import rearrange + +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint + +from .activation_layers import get_activation_layer +from .norm_layers import get_norm_layer +from .embed_layers import TimestepEmbedder, PatchEmbed, TextProjection +from .attention import attention, parallel_attention, get_cu_seqlens +from .posemb_layers import apply_rotary_emb +from .mlp_layers import MLP, MLPEmbedder, FinalLayer +from .modulate_layers import ModulateDiT, modulate, apply_gate +from .token_refiner import SingleTokenRefiner +from modules.custom_offloading_utils import ModelOffloader, synchronize_device, clean_memory_on_device +from hunyuan_model.posemb_layers import get_nd_rotary_pos_embed + +from utils.safetensors_utils import MemoryEfficientSafeOpen + +import logging + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +class MMDoubleStreamBlock(nn.Module): + """ + A multimodal dit block with seperate modulation for + text and image/video, see more details (SD3): https://arxiv.org/abs/2403.03206 + (Flux.1): https://github.com/black-forest-labs/flux + """ + + def __init__( + self, + hidden_size: int, + heads_num: int, + mlp_width_ratio: float, + mlp_act_type: str = "gelu_tanh", + qk_norm: bool = True, + qk_norm_type: str = "rms", + qkv_bias: bool = False, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + attn_mode: str = "flash", + split_attn: bool = False, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.attn_mode = attn_mode + self.split_attn = split_attn + + self.deterministic = False + self.heads_num = heads_num + head_dim = hidden_size // heads_num + mlp_hidden_dim = int(hidden_size * mlp_width_ratio) + + self.img_mod = ModulateDiT( + hidden_size, + factor=6, + act_layer=get_activation_layer("silu"), + **factory_kwargs, + ) + self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs) + qk_norm_layer = get_norm_layer(qk_norm_type) + self.img_attn_q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.img_attn_k_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs) + + self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.img_mlp = MLP( + hidden_size, + mlp_hidden_dim, + act_layer=get_activation_layer(mlp_act_type), + bias=True, + **factory_kwargs, + ) + + self.txt_mod = ModulateDiT( 
+ hidden_size, + factor=6, + act_layer=get_activation_layer("silu"), + **factory_kwargs, + ) + self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs) + self.txt_attn_q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.txt_attn_k_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs) + + self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.txt_mlp = MLP( + hidden_size, + mlp_hidden_dim, + act_layer=get_activation_layer(mlp_act_type), + bias=True, + **factory_kwargs, + ) + self.hybrid_seq_parallel_attn = None + + self.gradient_checkpointing = False + + def enable_deterministic(self): + self.deterministic = True + + def disable_deterministic(self): + self.deterministic = False + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + + def _forward( + self, + img: torch.Tensor, + txt: torch.Tensor, + vec: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + total_len: Optional[torch.Tensor] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, + freqs_cis: tuple = None, + condition_type: str = None, + token_replace_vec: torch.Tensor = None, + frist_frame_token_num: int = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if condition_type == "token_replace": + img_mod1, token_replace_img_mod1 = self.img_mod(vec, condition_type=condition_type, token_replace_vec=token_replace_vec) + (img_mod1_shift, img_mod1_scale, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate) = img_mod1.chunk( + 6, dim=-1 + ) + (tr_img_mod1_shift, tr_img_mod1_scale, tr_img_mod1_gate, tr_img_mod2_shift, tr_img_mod2_scale, tr_img_mod2_gate) = ( + token_replace_img_mod1.chunk(6, dim=-1) + ) + else: + (img_mod1_shift, img_mod1_scale, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate) = self.img_mod( + vec + ).chunk(6, dim=-1) + (txt_mod1_shift, txt_mod1_scale, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate) = self.txt_mod(vec).chunk( + 6, dim=-1 + ) + + # Prepare image for attention. + img_modulated = self.img_norm1(img) + + if condition_type == "token_replace": + img_modulated = modulate( + img_modulated, + shift=img_mod1_shift, + scale=img_mod1_scale, + condition_type=condition_type, + tr_shift=tr_img_mod1_shift, + tr_scale=tr_img_mod1_scale, + frist_frame_token_num=frist_frame_token_num, + ) + else: + img_modulated = modulate(img_modulated, shift=img_mod1_shift, scale=img_mod1_scale) + + img_qkv = self.img_attn_qkv(img_modulated) + img_modulated = None + img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) + img_qkv = None + # Apply QK-Norm if needed + img_q = self.img_attn_q_norm(img_q).to(img_v) + img_k = self.img_attn_k_norm(img_k).to(img_v) + + # Apply RoPE if needed. 
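+        # Rotary position embeddings are applied to the image/video tokens only; the
+        # text tokens are concatenated afterwards without positional rotation.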
+ if freqs_cis is not None: + img_q_shape = img_q.shape + img_k_shape = img_k.shape + img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False) + assert ( + img_q.shape == img_q_shape and img_k.shape == img_k_shape + ), f"img_kk: {img_q.shape}, img_q: {img_q_shape}, img_kk: {img_k.shape}, img_k: {img_k_shape}" + # img_q, img_k = img_qq, img_kk + + # Prepare txt for attention. + txt_modulated = self.txt_norm1(txt) + txt_modulated = modulate(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale) + txt_qkv = self.txt_attn_qkv(txt_modulated) + txt_modulated = None + txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) + txt_qkv = None + # Apply QK-Norm if needed. + txt_q = self.txt_attn_q_norm(txt_q).to(txt_v) + txt_k = self.txt_attn_k_norm(txt_k).to(txt_v) + + # Run actual attention. + img_q_len = img_q.shape[1] + img_kv_len = img_k.shape[1] + batch_size = img_k.shape[0] + q = torch.cat((img_q, txt_q), dim=1) + img_q = txt_q = None + k = torch.cat((img_k, txt_k), dim=1) + img_k = txt_k = None + v = torch.cat((img_v, txt_v), dim=1) + img_v = txt_v = None + + assert ( + cu_seqlens_q.shape[0] == 2 * img.shape[0] + 1 + ), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, img.shape[0]:{img.shape[0]}" + + # attention computation start + if not self.hybrid_seq_parallel_attn: + l = [q, k, v] + q = k = v = None + attn = attention( + l, + mode=self.attn_mode, + attn_mask=attn_mask, + total_len=total_len, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + batch_size=batch_size, + ) + else: + attn = parallel_attention( + self.hybrid_seq_parallel_attn, + q, + k, + v, + img_q_len=img_q_len, + img_kv_len=img_kv_len, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + ) + + # attention computation end + + img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :] + attn = None + + # Calculate the img bloks. + if condition_type == "token_replace": + img = img + apply_gate( + self.img_attn_proj(img_attn), + gate=img_mod1_gate, + condition_type=condition_type, + tr_gate=tr_img_mod1_gate, + frist_frame_token_num=frist_frame_token_num, + ) + img_attn = None + img = img + apply_gate( + self.img_mlp( + modulate( + self.img_norm2(img), + shift=img_mod2_shift, + scale=img_mod2_scale, + condition_type=condition_type, + tr_shift=tr_img_mod2_shift, + tr_scale=tr_img_mod2_scale, + frist_frame_token_num=frist_frame_token_num, + ) + ), + gate=img_mod2_gate, + condition_type=condition_type, + tr_gate=tr_img_mod2_gate, + frist_frame_token_num=frist_frame_token_num, + ) + else: + img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate) + img_attn = None + img = img + apply_gate( + self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)), + gate=img_mod2_gate, + ) + + # Calculate the txt bloks. 
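+        # Gated residual updates: the attention projection and the MLP output are each
+        # scaled by their modulation gate (apply_gate) before being added back to the
+        # text stream, mirroring the image-stream update above.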
+ txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate) + txt_attn = None + txt = txt + apply_gate( + self.txt_mlp(modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)), + gate=txt_mod2_gate, + ) + + return img, txt + + # def forward( + # self, + # img: torch.Tensor, + # txt: torch.Tensor, + # vec: torch.Tensor, + # attn_mask: Optional[torch.Tensor] = None, + # cu_seqlens_q: Optional[torch.Tensor] = None, + # cu_seqlens_kv: Optional[torch.Tensor] = None, + # max_seqlen_q: Optional[int] = None, + # max_seqlen_kv: Optional[int] = None, + # freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None, + # ) -> Tuple[torch.Tensor, torch.Tensor]: + def forward(self, *args, **kwargs): + if self.training and self.gradient_checkpointing: + return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) + else: + return self._forward(*args, **kwargs) + + +class MMSingleStreamBlock(nn.Module): + """ + A DiT block with parallel linear layers as described in + https://arxiv.org/abs/2302.05442 and adapted modulation interface. + Also refer to (SD3): https://arxiv.org/abs/2403.03206 + (Flux.1): https://github.com/black-forest-labs/flux + """ + + def __init__( + self, + hidden_size: int, + heads_num: int, + mlp_width_ratio: float = 4.0, + mlp_act_type: str = "gelu_tanh", + qk_norm: bool = True, + qk_norm_type: str = "rms", + qk_scale: float = None, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + attn_mode: str = "flash", + split_attn: bool = False, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.attn_mode = attn_mode + self.split_attn = split_attn + + self.deterministic = False + self.hidden_size = hidden_size + self.heads_num = heads_num + head_dim = hidden_size // heads_num + mlp_hidden_dim = int(hidden_size * mlp_width_ratio) + self.mlp_hidden_dim = mlp_hidden_dim + self.scale = qk_scale or head_dim**-0.5 + + # qkv and mlp_in + self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim, **factory_kwargs) + # proj and mlp_out + self.linear2 = nn.Linear(hidden_size + mlp_hidden_dim, hidden_size, **factory_kwargs) + + qk_norm_layer = get_norm_layer(qk_norm_type) + self.q_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + self.k_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + + self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.mlp_act = get_activation_layer(mlp_act_type)() + self.modulation = ModulateDiT(hidden_size, factor=3, act_layer=get_activation_layer("silu"), **factory_kwargs) + self.hybrid_seq_parallel_attn = None + + self.gradient_checkpointing = False + + def enable_deterministic(self): + self.deterministic = True + + def disable_deterministic(self): + self.deterministic = False + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + + def _forward( + self, + x: torch.Tensor, + vec: torch.Tensor, + txt_len: int, + attn_mask: Optional[torch.Tensor] = None, + total_len: Optional[torch.Tensor] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, + freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None, + condition_type: str = None, + token_replace_vec: 
torch.Tensor = None, + frist_frame_token_num: int = None, + ) -> torch.Tensor: + if condition_type == "token_replace": + mod, tr_mod = self.modulation(vec, condition_type=condition_type, token_replace_vec=token_replace_vec) + (mod_shift, mod_scale, mod_gate) = mod.chunk(3, dim=-1) + (tr_mod_shift, tr_mod_scale, tr_mod_gate) = tr_mod.chunk(3, dim=-1) + else: + mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1) + + if condition_type == "token_replace": + x_mod = modulate( + self.pre_norm(x), + shift=mod_shift, + scale=mod_scale, + condition_type=condition_type, + tr_shift=tr_mod_shift, + tr_scale=tr_mod_scale, + frist_frame_token_num=frist_frame_token_num, + ) + else: + x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale) + + qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) + x_mod = None + # mlp = mlp.to("cpu", non_blocking=True) + # clean_memory_on_device(x.device) + + q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) + qkv = None + + # Apply QK-Norm if needed. + q = self.q_norm(q).to(v) + k = self.k_norm(k).to(v) + + # Apply RoPE if needed. + if freqs_cis is not None: + img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :] + img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :] + q = k = None + img_q_shape = img_q.shape + img_k_shape = img_k.shape + img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False) + assert ( + img_q.shape == img_q_shape and img_k_shape == img_k.shape + ), f"img_kk: {img_q.shape}, img_q: {img_q.shape}, img_kk: {img_k.shape}, img_k: {img_k.shape}" + # img_q, img_k = img_qq, img_kk + # del img_qq, img_kk + q = torch.cat((img_q, txt_q), dim=1) + k = torch.cat((img_k, txt_k), dim=1) + del img_q, txt_q, img_k, txt_k + + # Compute attention. + assert cu_seqlens_q.shape[0] == 2 * x.shape[0] + 1, f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, x.shape[0]:{x.shape[0]}" + + # attention computation start + if not self.hybrid_seq_parallel_attn: + l = [q, k, v] + q = k = v = None + attn = attention( + l, + mode=self.attn_mode, + attn_mask=attn_mask, + total_len=total_len, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + batch_size=x.shape[0], + ) + else: + attn = parallel_attention( + self.hybrid_seq_parallel_attn, + q, + k, + v, + img_q_len=img_q.shape[1], + img_kv_len=img_k.shape[1], + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + ) + # attention computation end + + # Compute activation in mlp stream, cat again and run second linear layer. 
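That step follows below. As a shape-level sketch of the single-stream layout (one fused `linear1` emits both the packed QKV and the MLP hidden state; after attention the two streams are concatenated and projected back by `linear2`), here is a minimal toy version with QK-norm, RoPE, and modulation omitted. Sizes are illustrative assumptions.

```python
import torch
import torch.nn as nn
from einops import rearrange

B, L, hidden, heads, mlp_ratio = 2, 16, 64, 4, 4.0
head_dim = hidden // heads
mlp_hidden = int(hidden * mlp_ratio)

linear1 = nn.Linear(hidden, hidden * 3 + mlp_hidden)   # qkv and mlp_in fused
linear2 = nn.Linear(hidden + mlp_hidden, hidden)       # proj and mlp_out fused

x = torch.randn(B, L, hidden)
qkv, mlp = torch.split(linear1(x), [3 * hidden, mlp_hidden], dim=-1)
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=heads)

# Stand-in for the attention call: plain SDPA over the toy heads.
attn = torch.nn.functional.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2).reshape(B, L, hidden)

# Activate the MLP stream, concatenate with the attention output, project back.
out = x + linear2(torch.cat((attn, torch.nn.functional.gelu(mlp, approximate="tanh")), dim=2))
print(out.shape)  # torch.Size([2, 16, 64])
```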
+ # mlp = mlp.to(x.device) + mlp = self.mlp_act(mlp) + attn_mlp = torch.cat((attn, mlp), 2) + attn = None + mlp = None + output = self.linear2(attn_mlp) + attn_mlp = None + + if condition_type == "token_replace": + output = x + apply_gate( + output, + gate=mod_gate, + condition_type=condition_type, + tr_gate=tr_mod_gate, + frist_frame_token_num=frist_frame_token_num, + ) + return output + else: + return x + apply_gate(output, gate=mod_gate) + + # def forward( + # self, + # x: torch.Tensor, + # vec: torch.Tensor, + # txt_len: int, + # attn_mask: Optional[torch.Tensor] = None, + # cu_seqlens_q: Optional[torch.Tensor] = None, + # cu_seqlens_kv: Optional[torch.Tensor] = None, + # max_seqlen_q: Optional[int] = None, + # max_seqlen_kv: Optional[int] = None, + # freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None, + # ) -> torch.Tensor: + def forward(self, *args, **kwargs): + if self.training and self.gradient_checkpointing: + return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) + else: + return self._forward(*args, **kwargs) + + +class HYVideoDiffusionTransformer(nn.Module): # ModelMixin, ConfigMixin): + """ + HunyuanVideo Transformer backbone + + Inherited from ModelMixin and ConfigMixin for compatibility with diffusers' sampler StableDiffusionPipeline. + + Reference: + [1] Flux.1: https://github.com/black-forest-labs/flux + [2] MMDiT: http://arxiv.org/abs/2403.03206 + + Parameters + ---------- + args: argparse.Namespace + The arguments parsed by argparse. + patch_size: list + The size of the patch. + in_channels: int + The number of input channels. + out_channels: int + The number of output channels. + hidden_size: int + The hidden size of the transformer backbone. + heads_num: int + The number of attention heads. + mlp_width_ratio: float + The ratio of the hidden size of the MLP in the transformer block. + mlp_act_type: str + The activation function of the MLP in the transformer block. + depth_double_blocks: int + The number of transformer blocks in the double blocks. + depth_single_blocks: int + The number of transformer blocks in the single blocks. + rope_dim_list: list + The dimension of the rotary embedding for t, h, w. + qkv_bias: bool + Whether to use bias in the qkv linear layer. + qk_norm: bool + Whether to use qk norm. + qk_norm_type: str + The type of qk norm. + guidance_embed: bool + Whether to use guidance embedding for distillation. + text_projection: str + The type of the text projection, default is single_refiner. + use_attention_mask: bool + Whether to use attention mask for text encoder. + dtype: torch.dtype + The dtype of the model. + device: torch.device + The device of the model. + attn_mode: str + The mode of the attention, default is flash. + split_attn: bool + Whether to use split attention (make attention as batch size 1). + """ + + # @register_to_config + def __init__( + self, + text_states_dim: int, + text_states_dim_2: int, + patch_size: list = [1, 2, 2], + in_channels: int = 4, # Should be VAE.config.latent_channels. + out_channels: int = None, + hidden_size: int = 3072, + heads_num: int = 24, + mlp_width_ratio: float = 4.0, + mlp_act_type: str = "gelu_tanh", + mm_double_blocks_depth: int = 20, + mm_single_blocks_depth: int = 40, + rope_dim_list: List[int] = [16, 56, 56], + qkv_bias: bool = True, + qk_norm: bool = True, + qk_norm_type: str = "rms", + guidance_embed: bool = False, # For modulation. 
+ text_projection: str = "single_refiner", + use_attention_mask: bool = True, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + attn_mode: str = "flash", + split_attn: bool = False, + i2v_mode: bool = False, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.unpatchify_channels = self.out_channels + self.guidance_embed = guidance_embed + self.rope_dim_list = rope_dim_list + + # Text projection. Default to linear projection. + # Alternative: TokenRefiner. See more details (LI-DiT): http://arxiv.org/abs/2406.11831 + self.use_attention_mask = use_attention_mask + self.text_projection = text_projection + + self.text_states_dim = text_states_dim + self.text_states_dim_2 = text_states_dim_2 + + if hidden_size % heads_num != 0: + raise ValueError(f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}") + pe_dim = hidden_size // heads_num + if sum(rope_dim_list) != pe_dim: + raise ValueError(f"Got {rope_dim_list} but expected positional dim {pe_dim}") + self.hidden_size = hidden_size + self.heads_num = heads_num + + self.attn_mode = attn_mode + self.split_attn = split_attn + print(f"Using {self.attn_mode} attention mode, split_attn: {self.split_attn}") + + self.i2v_condition_type = "token_replace" if i2v_mode else None # only support token_replace for i2v mode + + # image projection + self.img_in = PatchEmbed(self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs) + + # text projection + if self.text_projection == "linear": + self.txt_in = TextProjection( + self.text_states_dim, + self.hidden_size, + get_activation_layer("silu"), + **factory_kwargs, + ) + elif self.text_projection == "single_refiner": + self.txt_in = SingleTokenRefiner(self.text_states_dim, hidden_size, heads_num, depth=2, **factory_kwargs) + else: + raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}") + + # time modulation + self.time_in = TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs) + + # text modulation + self.vector_in = MLPEmbedder(self.text_states_dim_2, self.hidden_size, **factory_kwargs) + + # guidance modulation + self.guidance_in = ( + TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs) if guidance_embed else None + ) + + # double blocks + self.double_blocks = nn.ModuleList( + [ + MMDoubleStreamBlock( + self.hidden_size, + self.heads_num, + mlp_width_ratio=mlp_width_ratio, + mlp_act_type=mlp_act_type, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + qkv_bias=qkv_bias, + attn_mode=attn_mode, + split_attn=split_attn, + **factory_kwargs, + ) + for _ in range(mm_double_blocks_depth) + ] + ) + + # single blocks + self.single_blocks = nn.ModuleList( + [ + MMSingleStreamBlock( + self.hidden_size, + self.heads_num, + mlp_width_ratio=mlp_width_ratio, + mlp_act_type=mlp_act_type, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + attn_mode=attn_mode, + split_attn=split_attn, + **factory_kwargs, + ) + for _ in range(mm_single_blocks_depth) + ] + ) + + self.final_layer = FinalLayer( + self.hidden_size, + self.patch_size, + self.out_channels, + get_activation_layer("silu"), + **factory_kwargs, + ) + + self.gradient_checkpointing = False + self.blocks_to_swap = None + self.offloader_double = None + self.offloader_single = None + self._enable_img_in_txt_in_offloading = False + + @property + def 
device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + self.txt_in.enable_gradient_checkpointing() + + for block in self.double_blocks + self.single_blocks: + block.enable_gradient_checkpointing() + + print(f"HYVideoDiffusionTransformer: Gradient checkpointing enabled.") + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + + self.txt_in.disable_gradient_checkpointing() + + for block in self.double_blocks + self.single_blocks: + block.disable_gradient_checkpointing() + + print(f"HYVideoDiffusionTransformer: Gradient checkpointing disabled.") + + def enable_img_in_txt_in_offloading(self): + self._enable_img_in_txt_in_offloading = True + + def enable_block_swap(self, num_blocks: int, device: torch.device, supports_backward: bool): + self.blocks_to_swap = num_blocks + self.num_double_blocks = len(self.double_blocks) + self.num_single_blocks = len(self.single_blocks) + double_blocks_to_swap = num_blocks // 2 + single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 + 1 + + assert double_blocks_to_swap <= self.num_double_blocks - 1 and single_blocks_to_swap <= self.num_single_blocks - 1, ( + f"Cannot swap more than {self.num_double_blocks - 1} double blocks and {self.num_single_blocks - 1} single blocks. " + f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks." + ) + + self.offloader_double = ModelOffloader( + "double", self.double_blocks, self.num_double_blocks, double_blocks_to_swap, supports_backward, device # , debug=True + ) + self.offloader_single = ModelOffloader( + "single", self.single_blocks, self.num_single_blocks, single_blocks_to_swap, supports_backward, device # , debug=True + ) + print( + f"HYVideoDiffusionTransformer: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}." + ) + + def switch_block_swap_for_inference(self): + if self.blocks_to_swap: + self.offloader_double.set_forward_only(True) + self.offloader_single.set_forward_only(True) + self.prepare_block_swap_before_forward() + print(f"HYVideoDiffusionTransformer: Block swap set to forward only.") + + def switch_block_swap_for_training(self): + if self.blocks_to_swap: + self.offloader_double.set_forward_only(False) + self.offloader_single.set_forward_only(False) + self.prepare_block_swap_before_forward() + print(f"HYVideoDiffusionTransformer: Block swap set to forward and backward.") + + def move_to_device_except_swap_blocks(self, device: torch.device): + # assume model is on cpu. 
do not move blocks to device to reduce temporary memory usage + if self.blocks_to_swap: + save_double_blocks = self.double_blocks + save_single_blocks = self.single_blocks + self.double_blocks = None + self.single_blocks = None + + self.to(device) + + if self.blocks_to_swap: + self.double_blocks = save_double_blocks + self.single_blocks = save_single_blocks + + def prepare_block_swap_before_forward(self): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + self.offloader_double.prepare_block_devices_before_forward(self.double_blocks) + self.offloader_single.prepare_block_devices_before_forward(self.single_blocks) + + def enable_deterministic(self): + for block in self.double_blocks: + block.enable_deterministic() + for block in self.single_blocks: + block.enable_deterministic() + + def disable_deterministic(self): + for block in self.double_blocks: + block.disable_deterministic() + for block in self.single_blocks: + block.disable_deterministic() + + def forward( + self, + x: torch.Tensor, + t: torch.Tensor, # Should be in range(0, 1000). + text_states: torch.Tensor = None, + text_mask: torch.Tensor = None, # Now we don't use it. + text_states_2: Optional[torch.Tensor] = None, # Text embedding for modulation. + freqs_cos: Optional[torch.Tensor] = None, + freqs_sin: Optional[torch.Tensor] = None, + guidance: torch.Tensor = None, # Guidance for modulation, should be cfg_scale x 1000. + return_dict: bool = True, + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + out = {} + img = x + txt = text_states + _, _, ot, oh, ow = x.shape + tt, th, tw = ( + ot // self.patch_size[0], + oh // self.patch_size[1], + ow // self.patch_size[2], + ) + + # Prepare modulation vectors. + vec = self.time_in(t) + + if self.i2v_condition_type == "token_replace": + token_replace_t = torch.zeros_like(t) + token_replace_vec = self.time_in(token_replace_t) + frist_frame_token_num = th * tw + else: + token_replace_vec = None + frist_frame_token_num = None + # token_replace_mask_img = None + # token_replace_mask_txt = None + + # text modulation + vec_2 = self.vector_in(text_states_2) + vec = vec + vec_2 + if self.i2v_condition_type == "token_replace": + token_replace_vec = token_replace_vec + vec_2 + vec_2 = None + + # guidance modulation + if self.guidance_embed: + if guidance is None: + raise ValueError("Didn't get guidance strength for guidance distilled model.") + + # our timestep_embedding is merged into guidance_in(TimestepEmbedder) + vec = vec + self.guidance_in(guidance) + + # Embed image and text. 
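The image/text token embedding follows below. One detail worth spelling out from the block above: in i2v `token_replace` mode, the first latent frame's `th * tw` tokens are treated as the known conditioning image, so their modulation vector is computed from a zero timestep while the remaining tokens use the sampled `t`. A toy illustration of that token split; the sizes are assumptions, not values from a real config.

```python
import torch

# Illustrative latent and patch sizes (assumptions).
ot, oh, ow = 9, 30, 40          # latent frames, height, width
pt, ph, pw = 1, 2, 2            # patch_size
tt, th, tw = ot // pt, oh // ph, ow // pw

# First-frame tokens are modulated as if t == 0, because that frame is the
# conditioning image rather than noise.
frist_frame_token_num = th * tw          # spelled as in the source
total_img_tokens = tt * th * tw

t = torch.full((1,), 999.0)              # timestep for the noisy frames
token_replace_t = torch.zeros_like(t)    # timestep used for the first-frame tokens

print(frist_frame_token_num, total_img_tokens)  # 300 2700
```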
+ if self._enable_img_in_txt_in_offloading: + self.img_in.to(x.device, non_blocking=True) + self.txt_in.to(x.device, non_blocking=True) + synchronize_device(x.device) + + img = self.img_in(img) + if self.text_projection == "linear": + txt = self.txt_in(txt) + elif self.text_projection == "single_refiner": + txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None) + else: + raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}") + + if self._enable_img_in_txt_in_offloading: + self.img_in.to(torch.device("cpu"), non_blocking=True) + self.txt_in.to(torch.device("cpu"), non_blocking=True) + synchronize_device(x.device) + clean_memory_on_device(x.device) + + txt_seq_len = txt.shape[1] + img_seq_len = img.shape[1] + + # Compute cu_squlens and max_seqlen for flash attention + cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len) + cu_seqlens_kv = cu_seqlens_q + max_seqlen_q = img_seq_len + txt_seq_len + max_seqlen_kv = max_seqlen_q + + attn_mask = total_len = None + if self.split_attn or self.attn_mode == "torch": + # calculate text length and total length + text_len = text_mask.sum(dim=1) # (bs, ) + total_len = img_seq_len + text_len # (bs, ) + if self.attn_mode == "torch" and not self.split_attn: + # initialize attention mask: bool tensor for sdpa, (b, 1, n, n) + bs = img.shape[0] + attn_mask = torch.zeros((bs, 1, max_seqlen_q, max_seqlen_q), dtype=torch.bool, device=text_mask.device) + + # set attention mask with total_len + for i in range(bs): + attn_mask[i, :, : total_len[i], : total_len[i]] = True + total_len = None # means we don't use split_attn + + freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None + # --------------------- Pass through DiT blocks ------------------------ + for block_idx, block in enumerate(self.double_blocks): + double_block_args = [ + img, + txt, + vec, + attn_mask, + total_len, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + freqs_cis, + self.i2v_condition_type, + token_replace_vec, + frist_frame_token_num, + ] + + if self.blocks_to_swap: + self.offloader_double.wait_for_block(block_idx) + + img, txt = block(*double_block_args) + + if self.blocks_to_swap: + self.offloader_double.submit_move_blocks_forward(self.double_blocks, block_idx) + + # Merge txt and img to pass through single stream blocks. + x = torch.cat((img, txt), 1) + if self.blocks_to_swap: + # delete img, txt to reduce memory usage + del img, txt + clean_memory_on_device(x.device) + + if len(self.single_blocks) > 0: + for block_idx, block in enumerate(self.single_blocks): + single_block_args = [ + x, + vec, + txt_seq_len, + attn_mask, + total_len, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + freqs_cis, + self.i2v_condition_type, + token_replace_vec, + frist_frame_token_num, + ] + if self.blocks_to_swap: + self.offloader_single.wait_for_block(block_idx) + + x = block(*single_block_args) + + if self.blocks_to_swap: + self.offloader_single.submit_move_blocks_forward(self.single_blocks, block_idx) + + img = x[:, :img_seq_len, ...] 
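At this point the single-stream output is sliced back to image tokens before the final layer. As a side note on the sequence bookkeeping earlier in this forward pass: with padded text, each sample's valid length is `img_seq_len + text_len[i]`, and the boolean SDPA mask simply enables that prefix. A toy reconstruction of that mask-building loop with assumed sizes (`get_cu_seqlens` itself is not reproduced here):

```python
import torch

img_seq_len = 6
text_mask = torch.tensor([[1, 1, 1, 0, 0],      # sample 0: 3 valid text tokens
                          [1, 1, 1, 1, 1]])     # sample 1: 5 valid text tokens
bs, txt_seq_len = text_mask.shape
max_seqlen = img_seq_len + txt_seq_len

text_len = text_mask.sum(dim=1)                 # (bs,)
total_len = img_seq_len + text_len              # (bs,)

# Boolean mask for torch SDPA: (bs, 1, n, n), True where attention is allowed.
attn_mask = torch.zeros((bs, 1, max_seqlen, max_seqlen), dtype=torch.bool)
for i in range(bs):
    attn_mask[i, :, : total_len[i], : total_len[i]] = True

print(total_len.tolist())             # [9, 11]
print(attn_mask[0, 0].sum().item())   # 81 == 9 * 9
```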
+ x = None + + # ---------------------------- Final layer ------------------------------ + img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + + img = self.unpatchify(img, tt, th, tw) + if return_dict: + out["x"] = img + return out + return img + + def unpatchify(self, x, t, h, w): + """ + x: (N, T, patch_size**2 * C) + imgs: (N, H, W, C) + """ + c = self.unpatchify_channels + pt, ph, pw = self.patch_size + assert t * h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw)) + x = torch.einsum("nthwcopq->nctohpwq", x) + imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw)) + + return imgs + + def params_count(self): + counts = { + "double": sum( + [ + sum(p.numel() for p in block.img_attn_qkv.parameters()) + + sum(p.numel() for p in block.img_attn_proj.parameters()) + + sum(p.numel() for p in block.img_mlp.parameters()) + + sum(p.numel() for p in block.txt_attn_qkv.parameters()) + + sum(p.numel() for p in block.txt_attn_proj.parameters()) + + sum(p.numel() for p in block.txt_mlp.parameters()) + for block in self.double_blocks + ] + ), + "single": sum( + [ + sum(p.numel() for p in block.linear1.parameters()) + sum(p.numel() for p in block.linear2.parameters()) + for block in self.single_blocks + ] + ), + "total": sum(p.numel() for p in self.parameters()), + } + counts["attn+mlp"] = counts["double"] + counts["single"] + return counts + + +################################################################################# +# HunyuanVideo Configs # +################################################################################# + +HUNYUAN_VIDEO_CONFIG = { + "HYVideo-T/2": { + "mm_double_blocks_depth": 20, + "mm_single_blocks_depth": 40, + "rope_dim_list": [16, 56, 56], + "hidden_size": 3072, + "heads_num": 24, + "mlp_width_ratio": 4, + }, + "HYVideo-T/2-cfgdistill": { + "mm_double_blocks_depth": 20, + "mm_single_blocks_depth": 40, + "rope_dim_list": [16, 56, 56], + "hidden_size": 3072, + "heads_num": 24, + "mlp_width_ratio": 4, + "guidance_embed": True, + }, +} + + +def load_dit_model(text_states_dim, text_states_dim_2, in_channels, out_channels, i2v_mode, factor_kwargs): + """load hunyuan video model + + NOTE: Only support HYVideo-T/2-cfgdistill now. + The config of I2V model is "HYVideo-T/2", but if embedded_cfg_scale is not 1.0, it has guidance embed. So it is same as "HYVideo-T/2-cfgdistill". + + Args: + text_state_dim (int): text state dimension + text_state_dim_2 (int): text state dimension 2 + in_channels (int): input channels number + out_channels (int): output channels number + i2v_mode (bool): whether to use i2v model + factor_kwargs (dict): factor kwargs + + Returns: + model (nn.Module): The hunyuan video model + """ + # if args.model in HUNYUAN_VIDEO_CONFIG.keys(): + model = HYVideoDiffusionTransformer( + text_states_dim=text_states_dim, + text_states_dim_2=text_states_dim_2, + in_channels=in_channels, + out_channels=out_channels, + i2v_mode=i2v_mode, + **HUNYUAN_VIDEO_CONFIG["HYVideo-T/2-cfgdistill"], + **factor_kwargs, + ) + return model + # else: + # raise NotImplementedError() + + +def load_state_dict(model, model_path): + state_dict = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=True) + + load_key = "module" + if load_key in state_dict: + state_dict = state_dict[load_key] + else: + raise KeyError( + f"Missing key: `{load_key}` in the checkpoint: {model_path}. The keys in the checkpoint " + f"are: {list(state_dict.keys())}." 
+ ) + info = model.load_state_dict(state_dict, strict=True, assign=True) + logger.info(f"Load state dict from {model_path} with info: {info}") + return model + + +def load_transformer(dit_path, attn_mode, split_attn, device, dtype, in_channels=16, i2v_mode=False) -> HYVideoDiffusionTransformer: + # =========================== Build main model =========================== + factor_kwargs = {"device": device, "dtype": dtype, "attn_mode": attn_mode, "split_attn": split_attn} + latent_channels = 16 + out_channels = latent_channels + + with accelerate.init_empty_weights(): + transformer = load_dit_model( + text_states_dim=4096, + text_states_dim_2=768, + in_channels=in_channels, + out_channels=out_channels, + i2v_mode=i2v_mode, + factor_kwargs=factor_kwargs, + ) + + if os.path.splitext(dit_path)[-1] == ".safetensors": + # loading safetensors: may be already fp8 + with MemoryEfficientSafeOpen(dit_path) as f: + state_dict = {} + for k in f.keys(): + tensor = f.get_tensor(k) + tensor = tensor.to(device=device, dtype=dtype) + # TODO support comfy model + # if k.startswith("model.model."): + # k = convert_comfy_model_key(k) + state_dict[k] = tensor + info = transformer.load_state_dict(state_dict, strict=True, assign=True) + logger.info(f"Load state dict from {dit_path} with info: {info}") + else: + transformer = load_state_dict(transformer, dit_path) + + return transformer + + +def get_rotary_pos_embed_by_shape(model, latents_size): + target_ndim = 3 + ndim = 5 - 2 + + if isinstance(model.patch_size, int): + assert all(s % model.patch_size == 0 for s in latents_size), ( + f"Latent size(last {ndim} dimensions) should be divisible by patch size({model.patch_size}), " + f"but got {latents_size}." + ) + rope_sizes = [s // model.patch_size for s in latents_size] + elif isinstance(model.patch_size, list): + assert all(s % model.patch_size[idx] == 0 for idx, s in enumerate(latents_size)), ( + f"Latent size(last {ndim} dimensions) should be divisible by patch size({model.patch_size}), " + f"but got {latents_size}." 
+ ) + rope_sizes = [s // model.patch_size[idx] for idx, s in enumerate(latents_size)] + + if len(rope_sizes) != target_ndim: + rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis + head_dim = model.hidden_size // model.heads_num + rope_dim_list = model.rope_dim_list + if rope_dim_list is None: + rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)] + assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer" + + rope_theta = 256 + freqs_cos, freqs_sin = get_nd_rotary_pos_embed( + rope_dim_list, rope_sizes, theta=rope_theta, use_real=True, theta_rescale_factor=1 + ) + return freqs_cos, freqs_sin + + +def get_rotary_pos_embed(vae_name, model, video_length, height, width): + # 884 + if "884" in vae_name: + latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8] + elif "888" in vae_name: + latents_size = [(video_length - 1) // 8 + 1, height // 8, width // 8] + else: + latents_size = [video_length, height // 8, width // 8] + + return get_rotary_pos_embed_by_shape(model, latents_size) diff --git a/hunyuan_model/modulate_layers.py b/hunyuan_model/modulate_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..f2a8be00a7b3eb5cac79cea5efd624a16ef13bd0 --- /dev/null +++ b/hunyuan_model/modulate_layers.py @@ -0,0 +1,101 @@ +from typing import Callable + +import torch +import torch.nn as nn + + +class ModulateDiT(nn.Module): + """Modulation layer for DiT.""" + + def __init__( + self, + hidden_size: int, + factor: int, + act_layer: Callable, + dtype=None, + device=None, + ): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + self.act = act_layer() + self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs) + # Zero-initialize the modulation + nn.init.zeros_(self.linear.weight) + nn.init.zeros_(self.linear.bias) + + def forward(self, x: torch.Tensor, condition_type=None, token_replace_vec=None) -> torch.Tensor: + x_out = self.linear(self.act(x)) + + if condition_type == "token_replace": + x_token_replace_out = self.linear(self.act(token_replace_vec)) + return x_out, x_token_replace_out + else: + return x_out + + +def modulate(x, shift=None, scale=None, condition_type=None, tr_shift=None, tr_scale=None, frist_frame_token_num=None): + """modulate by shift and scale + + Args: + x (torch.Tensor): input tensor. + shift (torch.Tensor, optional): shift tensor. Defaults to None. + scale (torch.Tensor, optional): scale tensor. Defaults to None. + + Returns: + torch.Tensor: the output tensor after modulate. + """ + if condition_type == "token_replace": + x_zero = x[:, :frist_frame_token_num] * (1 + tr_scale.unsqueeze(1)) + tr_shift.unsqueeze(1) + x_orig = x[:, frist_frame_token_num:] * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + x = torch.concat((x_zero, x_orig), dim=1) + return x + else: + if scale is None and shift is None: + return x + elif shift is None: + return x * (1 + scale.unsqueeze(1)) + elif scale is None: + return x + shift.unsqueeze(1) + else: + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +def apply_gate(x, gate=None, tanh=False, condition_type=None, tr_gate=None, frist_frame_token_num=None): + """AI is creating summary for apply_gate + + Args: + x (torch.Tensor): input tensor. + gate (torch.Tensor, optional): gate tensor. Defaults to None. + tanh (bool, optional): whether to use tanh function. Defaults to False. + + Returns: + torch.Tensor: the output tensor after apply gate. 
+ """ + if condition_type == "token_replace": + if gate is None: + return x + if tanh: + x_zero = x[:, :frist_frame_token_num] * tr_gate.unsqueeze(1).tanh() + x_orig = x[:, frist_frame_token_num:] * gate.unsqueeze(1).tanh() + x = torch.concat((x_zero, x_orig), dim=1) + return x + else: + x_zero = x[:, :frist_frame_token_num] * tr_gate.unsqueeze(1) + x_orig = x[:, frist_frame_token_num:] * gate.unsqueeze(1) + x = torch.concat((x_zero, x_orig), dim=1) + return x + else: + if gate is None: + return x + if tanh: + return x * gate.unsqueeze(1).tanh() + else: + return x * gate.unsqueeze(1) + + +def ckpt_wrapper(module): + def ckpt_forward(*inputs): + outputs = module(*inputs) + return outputs + + return ckpt_forward diff --git a/hunyuan_model/norm_layers.py b/hunyuan_model/norm_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..a53d167436b6971d3aabf5cfe51c0b9d6dfc022f --- /dev/null +++ b/hunyuan_model/norm_layers.py @@ -0,0 +1,79 @@ +import torch +import torch.nn as nn + + +class RMSNorm(nn.Module): + def __init__( + self, + dim: int, + elementwise_affine=True, + eps: float = 1e-6, + device=None, + dtype=None, + ): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs)) + + def _norm(self, x): + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + + """ + output = self._norm(x.float()).type_as(x) + if hasattr(self, "weight"): + # output = output * self.weight + # support fp8 + output = output * self.weight.to(output.dtype) + return output + + +def get_norm_layer(norm_layer): + """ + Get the normalization layer. + + Args: + norm_layer (str): The type of normalization layer. + + Returns: + norm_layer (nn.Module): The normalization layer. + """ + if norm_layer == "layer": + return nn.LayerNorm + elif norm_layer == "rms": + return RMSNorm + else: + raise NotImplementedError(f"Norm layer {norm_layer} is not implemented") diff --git a/hunyuan_model/pipeline_hunyuan_video.py b/hunyuan_model/pipeline_hunyuan_video.py new file mode 100644 index 0000000000000000000000000000000000000000..c1293161e13a47ae7dcedfef2c55e3baefc655f4 --- /dev/null +++ b/hunyuan_model/pipeline_hunyuan_video.py @@ -0,0 +1,1100 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Modified from diffusers==0.29.2 +# +# ============================================================================== +import inspect +from typing import Any, Callable, Dict, List, Optional, Union, Tuple +import torch +import torch.distributed as dist +import numpy as np +from dataclasses import dataclass +from packaging import version + +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.configuration_utils import FrozenDict +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL +from diffusers.models.lora import adjust_lora_scale_text_encoder +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.utils import BaseOutput + +from ...constants import PRECISION_TO_TYPE +from ...vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D +from ...text_encoder import TextEncoder +from ...modules import HYVideoDiffusionTransformer + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """""" + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std( + dim=list(range(1, noise_pred_text.ndim)), keepdim=True + ) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = ( + guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + ) + return noise_cfg + + +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. 
If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError( + "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" + ) + if timesteps is not None: + accepts_timesteps = "timesteps" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +@dataclass +class HunyuanVideoPipelineOutput(BaseOutput): + videos: Union[torch.Tensor, np.ndarray] + + +class HunyuanVideoPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using HunyuanVideo. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`TextEncoder`]): + Frozen text-encoder. + text_encoder_2 ([`TextEncoder`]): + Frozen text-encoder_2. + transformer ([`HYVideoDiffusionTransformer`]): + A `HYVideoDiffusionTransformer` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
+ """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = ["text_encoder_2"] + _exclude_from_cpu_offload = ["transformer"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: TextEncoder, + transformer: HYVideoDiffusionTransformer, + scheduler: KarrasDiffusionSchedulers, + text_encoder_2: Optional[TextEncoder] = None, + progress_bar_config: Dict[str, Any] = None, + args=None, + ): + super().__init__() + + # ========================================================================================== + if progress_bar_config is None: + progress_bar_config = {} + if not hasattr(self, "_progress_bar_config"): + self._progress_bar_config = {} + self._progress_bar_config.update(progress_bar_config) + + self.args = args + # ========================================================================================== + + if ( + hasattr(scheduler.config, "steps_offset") + and scheduler.config.steps_offset != 1 + ): + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate( + "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False + ) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if ( + hasattr(scheduler.config, "clip_sample") + and scheduler.config.clip_sample is True + ): + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate( + "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False + ) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + transformer=transformer, + scheduler=scheduler, + text_encoder_2=text_encoder_2, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def encode_prompt( + self, + prompt, + device, + num_videos_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_attention_mask: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + text_encoder: Optional[TextEncoder] = None, + data_type: Optional[str] = "image", + ): + r""" + Encodes the prompt into text encoder hidden states. 
+ + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_videos_per_prompt (`int`): + number of videos that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the video generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + attention_mask (`torch.Tensor`, *optional*): + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_attention_mask (`torch.Tensor`, *optional*): + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + text_encoder (TextEncoder, *optional*): + data_type (`str`, *optional*): + """ + if text_encoder is None: + text_encoder = self.text_encoder + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(text_encoder.model, lora_scale) + else: + scale_lora_layers(text_encoder.model, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, text_encoder.tokenizer) + + text_inputs = text_encoder.text2tokens(prompt, data_type=data_type) + + if clip_skip is None: + prompt_outputs = text_encoder.encode( + text_inputs, data_type=data_type, device=device + ) + prompt_embeds = prompt_outputs.hidden_state + else: + prompt_outputs = text_encoder.encode( + text_inputs, + output_hidden_states=True, + data_type=data_type, + device=device, + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_outputs.hidden_states_list[-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. 
+ prompt_embeds = text_encoder.model.text_model.final_layer_norm( + prompt_embeds + ) + + attention_mask = prompt_outputs.attention_mask + if attention_mask is not None: + attention_mask = attention_mask.to(device) + bs_embed, seq_len = attention_mask.shape + attention_mask = attention_mask.repeat(1, num_videos_per_prompt) + attention_mask = attention_mask.view( + bs_embed * num_videos_per_prompt, seq_len + ) + + if text_encoder is not None: + prompt_embeds_dtype = text_encoder.dtype + elif self.transformer is not None: + prompt_embeds_dtype = self.transformer.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + if prompt_embeds.ndim == 2: + bs_embed, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt) + prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, -1) + else: + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view( + bs_embed * num_videos_per_prompt, seq_len, -1 + ) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt( + uncond_tokens, text_encoder.tokenizer + ) + + # max_length = prompt_embeds.shape[1] + uncond_input = text_encoder.text2tokens(uncond_tokens, data_type=data_type) + + negative_prompt_outputs = text_encoder.encode( + uncond_input, data_type=data_type, device=device + ) + negative_prompt_embeds = negative_prompt_outputs.hidden_state + + negative_attention_mask = negative_prompt_outputs.attention_mask + if negative_attention_mask is not None: + negative_attention_mask = negative_attention_mask.to(device) + _, seq_len = negative_attention_mask.shape + negative_attention_mask = negative_attention_mask.repeat( + 1, num_videos_per_prompt + ) + negative_attention_mask = negative_attention_mask.view( + batch_size * num_videos_per_prompt, seq_len + ) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to( + dtype=prompt_embeds_dtype, device=device + ) + + if negative_prompt_embeds.ndim == 2: + negative_prompt_embeds = negative_prompt_embeds.repeat( + 1, num_videos_per_prompt + ) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_videos_per_prompt, -1 + ) + else: + negative_prompt_embeds = negative_prompt_embeds.repeat( + 1, num_videos_per_prompt, 1 + ) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_videos_per_prompt, seq_len, -1 + ) + + if text_encoder is not None: + if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(text_encoder.model, lora_scale) + + return ( + prompt_embeds, + negative_prompt_embeds, + attention_mask, + negative_attention_mask, + ) + + def decode_latents(self, latents, enable_tiling=True): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + if enable_tiling: + self.vae.enable_tiling() + image = self.vae.decode(latents, return_dict=False)[0] + else: + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + if image.ndim == 4: + image = image.cpu().permute(0, 2, 3, 1).float() + else: + image = image.cpu().float() + return image + + def prepare_extra_func_kwargs(self, func, kwargs): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
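The note on eta continues below. More generally, `prepare_extra_func_kwargs` forwards only the keyword arguments the target callable actually accepts, which is how `eta` and `generator` can be offered to schedulers with differing `step` signatures. A tiny self-contained illustration of the same `inspect.signature` filtering; `fake_scheduler_step` is a hypothetical stand-in, not a real scheduler API.

```python
import inspect

def filter_kwargs(func, kwargs):
    # Keep only the kwargs that appear in func's signature.
    accepted = set(inspect.signature(func).parameters.keys())
    return {k: v for k, v in kwargs.items() if k in accepted}

def fake_scheduler_step(model_output, timestep, sample, generator=None):
    return sample

extra = filter_kwargs(fake_scheduler_step, {"eta": 0.0, "generator": None})
print(extra)  # {'generator': None} -- 'eta' is dropped because step() doesn't accept it
```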
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + extra_step_kwargs = {} + + for k, v in kwargs.items(): + accepts = k in set(inspect.signature(func).parameters.keys()) + if accepts: + extra_step_kwargs[k] = v + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + video_length, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + vae_ver="88-4c-sd", + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError( + f"`height` and `width` have to be divisible by 8 but are {height} and {width}." + ) + + if video_length is not None: + if "884" in vae_ver: + if video_length != 1 and (video_length - 1) % 4 != 0: + raise ValueError( + f"`video_length` has to be 1 or a multiple of 4 but is {video_length}." + ) + elif "888" in vae_ver: + if video_length != 1 and (video_length - 1) % 8 != 0: + raise ValueError( + f"`video_length` has to be 1 or a multiple of 8 but is {video_length}." + ) + + if callback_steps is not None and ( + not isinstance(callback_steps, int) or callback_steps <= 0 + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs + for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and ( + not isinstance(prompt, str) and not isinstance(prompt, list) + ): + raise ValueError( + f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + video_length, + dtype, + device, + generator, + latents=None, + ): + shape = ( + batch_size, + num_channels_latents, + video_length, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + if latents is None: + latents = randn_tensor( + shape, generator=generator, device=device, dtype=dtype + ) + else: + latents = latents.to(device) + + # Check existence to make it compatible with FlowMatchEulerDiscreteScheduler + if hasattr(self.scheduler, "init_noise_sigma"): + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, + w: torch.Tensor, + embedding_dim: int = 512, + dtype: torch.dtype = torch.float32, + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. 
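The `do_classifier_free_guidance` property defined next simply checks `guidance_scale > 1`. Separately, `get_guidance_scale_embedding` above turns the scalar guidance weight into a sinusoidal vector for distilled-guidance conditioning; below is a standalone sketch of the same computation with a toy embedding size, so it can be run outside the pipeline class.

```python
import torch

def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
    # Same sinusoidal construction as get_guidance_scale_embedding above.
    w = w * 1000.0
    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad odd dimensions
        emb = torch.nn.functional.pad(emb, (0, 1))
    return emb

emb = guidance_scale_embedding(torch.tensor([6.0]), embedding_dim=256)
print(emb.shape)  # torch.Size([1, 256])
```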
+ @property + def do_classifier_free_guidance(self): + # return self._guidance_scale > 1 and self.transformer.config.time_cond_proj_dim is None + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]], + height: int, + width: int, + video_length: int, + data_type: str = "video", + num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_attention_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[ + Union[ + Callable[[int, int, Dict], None], + PipelineCallback, + MultiPipelineCallbacks, + ] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None, + vae_ver: str = "88-4c-sd", + enable_tiling: bool = False, + n_tokens: Optional[int] = None, + embedded_guidance_scale: Optional[float] = None, + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + video_length (`int`): + The number of frames in the generated video. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). 
+ num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. 
+ + Examples: + + Returns: + [`~HunyuanVideoPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. Default height and width to unet + # height = height or self.transformer.config.sample_size * self.vae_scale_factor + # width = width or self.transformer.config.sample_size * self.vae_scale_factor + # to deal with lora scaling and other possible forward hooks + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + video_length, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + vae_ver=vae_ver, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = torch.device(f"cuda:{dist.get_rank()}") if dist.is_initialized() else self._execution_device + + # 3. Encode input prompt + lora_scale = ( + self.cross_attention_kwargs.get("scale", None) + if self.cross_attention_kwargs is not None + else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + prompt_mask, + negative_prompt_mask, + ) = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + attention_mask=attention_mask, + negative_prompt_embeds=negative_prompt_embeds, + negative_attention_mask=negative_attention_mask, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + data_type=data_type, + ) + if self.text_encoder_2 is not None: + ( + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_mask_2, + negative_prompt_mask_2, + ) = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=None, + attention_mask=None, + negative_prompt_embeds=None, + negative_attention_mask=None, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + text_encoder=self.text_encoder_2, + data_type=data_type, + ) + else: + prompt_embeds_2 = None + negative_prompt_embeds_2 = None + prompt_mask_2 = None + negative_prompt_mask_2 = None + + # For classifier free guidance, we need to do two forward passes. 
+ # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + if prompt_mask is not None: + prompt_mask = torch.cat([negative_prompt_mask, prompt_mask]) + if prompt_embeds_2 is not None: + prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2]) + if prompt_mask_2 is not None: + prompt_mask_2 = torch.cat([negative_prompt_mask_2, prompt_mask_2]) + + + # 4. Prepare timesteps + extra_set_timesteps_kwargs = self.prepare_extra_func_kwargs( + self.scheduler.set_timesteps, {"n_tokens": n_tokens} + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + **extra_set_timesteps_kwargs, + ) + + if "884" in vae_ver: + video_length = (video_length - 1) // 4 + 1 + elif "888" in vae_ver: + video_length = (video_length - 1) // 8 + 1 + else: + video_length = video_length + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + video_length, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_func_kwargs( + self.scheduler.step, + {"generator": generator, "eta": eta}, + ) + + target_dtype = PRECISION_TO_TYPE[self.args.precision] + autocast_enabled = ( + target_dtype != torch.float32 + ) and not self.args.disable_autocast + vae_dtype = PRECISION_TO_TYPE[self.args.vae_precision] + vae_autocast_enabled = ( + vae_dtype != torch.float32 + ) and not self.args.disable_autocast + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + # if is_progress_bar: + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([latents] * 2) + if self.do_classifier_free_guidance + else latents + ) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) + + t_expand = t.repeat(latent_model_input.shape[0]) + guidance_expand = ( + torch.tensor( + [embedded_guidance_scale] * latent_model_input.shape[0], + dtype=torch.float32, + device=device, + ).to(target_dtype) + * 1000.0 + if embedded_guidance_scale is not None + else None + ) + + # predict the noise residual + with torch.autocast( + device_type="cuda", dtype=target_dtype, enabled=autocast_enabled + ): + noise_pred = self.transformer( # For an input image (129, 192, 336) (1, 256, 256) + latent_model_input, # [2, 16, 33, 24, 42] + t_expand, # [2] + text_states=prompt_embeds, # [2, 256, 4096] + text_mask=prompt_mask, # [2, 256] + text_states_2=prompt_embeds_2, # [2, 768] + freqs_cos=freqs_cis[0], # [seqlen, head_dim] + freqs_sin=freqs_cis[1], # [seqlen, head_dim] + guidance=guidance_expand, + return_dict=True, + )[ + "x" + ] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg( + noise_pred, + noise_pred_text, + guidance_rescale=self.guidance_rescale, + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs, return_dict=False + )[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop( + "negative_prompt_embeds", negative_prompt_embeds + ) + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + if progress_bar is not None: + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + expand_temporal_dim = False + if len(latents.shape) == 4: + if isinstance(self.vae, AutoencoderKLCausal3D): + latents = latents.unsqueeze(2) + expand_temporal_dim = True + elif len(latents.shape) == 5: + pass + else: + raise ValueError( + f"Only support latents with shape (b, c, h, w) or (b, c, f, h, w), but got {latents.shape}." + ) + + if ( + hasattr(self.vae.config, "shift_factor") + and self.vae.config.shift_factor + ): + latents = ( + latents / self.vae.config.scaling_factor + + self.vae.config.shift_factor + ) + else: + latents = latents / self.vae.config.scaling_factor + + with torch.autocast( + device_type="cuda", dtype=vae_dtype, enabled=vae_autocast_enabled + ): + if enable_tiling: + self.vae.enable_tiling() + image = self.vae.decode( + latents, return_dict=False, generator=generator + )[0] + else: + image = self.vae.decode( + latents, return_dict=False, generator=generator + )[0] + + if expand_temporal_dim or image.shape[2] == 1: + image = image.squeeze(2) + + else: + image = latents + + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().float() + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return image + + return HunyuanVideoPipelineOutput(videos=image) diff --git a/hunyuan_model/posemb_layers.py b/hunyuan_model/posemb_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..dfce82c690540d17a55a51b7997ee7ceb0bdbf44 --- /dev/null +++ b/hunyuan_model/posemb_layers.py @@ -0,0 +1,310 @@ +import torch +from typing import Union, Tuple, List + + +def _to_tuple(x, dim=2): + if isinstance(x, int): + return (x,) * dim + elif len(x) == dim: + return x + else: + raise ValueError(f"Expected length {dim} or int, but got {x}") + + +def get_meshgrid_nd(start, *args, dim=2): + """ + Get n-D meshgrid with start, stop and num. + + Args: + start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, + step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num + should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in + n-tuples. + *args: See above. + dim (int): Dimension of the meshgrid. Defaults to 2. + + Returns: + grid (np.ndarray): [dim, ...] 
+ """ + if len(args) == 0: + # start is grid_size + num = _to_tuple(start, dim=dim) + start = (0,) * dim + stop = num + elif len(args) == 1: + # start is start, args[0] is stop, step is 1 + start = _to_tuple(start, dim=dim) + stop = _to_tuple(args[0], dim=dim) + num = [stop[i] - start[i] for i in range(dim)] + elif len(args) == 2: + # start is start, args[0] is stop, args[1] is num + start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0 + stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32 + num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124 + else: + raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}") + + # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False) + axis_grid = [] + for i in range(dim): + a, b, n = start[i], stop[i], num[i] + g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n] + axis_grid.append(g) + grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D] + grid = torch.stack(grid, dim=0) # [dim, W, H, D] + + return grid + + +################################################################################# +# Rotary Positional Embedding Functions # +################################################################################# +# https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80 + + +def reshape_for_broadcast( + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], + x: torch.Tensor, + head_first=False, +): + """ + Reshape frequency tensor for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + Notes: + When using FlashMHAModified, head_first should be False. + When using Attention, head_first should be True. + + Args: + freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + head_first (bool): head dimension first (except batch dim) or not. + + Returns: + torch.Tensor: Reshaped frequency tensor. + + Raises: + AssertionError: If the frequency tensor doesn't match the expected shape. + AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions. 
+ """ + ndim = x.ndim + assert 0 <= 1 < ndim + + if isinstance(freqs_cis, tuple): + # freqs_cis: (cos, sin) in real space + if head_first: + assert freqs_cis[0].shape == ( + x.shape[-2], + x.shape[-1], + ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}" + shape = [ + d if i == ndim - 2 or i == ndim - 1 else 1 + for i, d in enumerate(x.shape) + ] + else: + assert freqs_cis[0].shape == ( + x.shape[1], + x.shape[-1], + ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}" + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape) + else: + # freqs_cis: values in complex space + if head_first: + assert freqs_cis.shape == ( + x.shape[-2], + x.shape[-1], + ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}" + shape = [ + d if i == ndim - 2 or i == ndim - 1 else 1 + for i, d in enumerate(x.shape) + ] + else: + assert freqs_cis.shape == ( + x.shape[1], + x.shape[-1], + ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}" + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def rotate_half(x): + x_real, x_imag = ( + x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) + ) # [B, S, H, D//2] + return torch.stack([-x_imag, x_real], dim=-1).flatten(3) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + head_first: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. + + This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided + frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are + returned as real tensors. + + Args: + xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D] + xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D] + freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential. + head_first (bool): head dimension first (except batch dim) or not. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. 
+ + """ + xk_out = None + if isinstance(freqs_cis, tuple): + cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D] + cos, sin = cos.to(xq.device), sin.to(xq.device) + # real * cos - imag * sin + # imag * cos + real * sin + xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq) + xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk) + else: + # view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex) + xq_ = torch.view_as_complex( + xq.float().reshape(*xq.shape[:-1], -1, 2) + ) # [B, S, H, D//2] + freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to( + xq.device + ) # [S, D//2] --> [1, S, 1, D//2] + # (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin) + # view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq) + xk_ = torch.view_as_complex( + xk.float().reshape(*xk.shape[:-1], -1, 2) + ) # [B, S, H, D//2] + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk) + + return xq_out, xk_out + + +def get_nd_rotary_pos_embed( + rope_dim_list, + start, + *args, + theta=10000.0, + use_real=False, + theta_rescale_factor: Union[float, List[float]] = 1.0, + interpolation_factor: Union[float, List[float]] = 1.0, +): + """ + This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure. + + Args: + rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n. + sum(rope_dim_list) should equal to head_dim of attention layer. + start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start, + args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. + *args: See above. + theta (float): Scaling factor for frequency computation. Defaults to 10000.0. + use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers. + Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real + part and an imaginary part separately. + theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0. 
+ + Returns: + pos_embed (torch.Tensor): [HW, D/2] + """ + + grid = get_meshgrid_nd( + start, *args, dim=len(rope_dim_list) + ) # [3, W, H, D] / [2, W, H] + + if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float): + theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list) + elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1: + theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list) + assert len(theta_rescale_factor) == len( + rope_dim_list + ), "len(theta_rescale_factor) should equal to len(rope_dim_list)" + + if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float): + interpolation_factor = [interpolation_factor] * len(rope_dim_list) + elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1: + interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list) + assert len(interpolation_factor) == len( + rope_dim_list + ), "len(interpolation_factor) should equal to len(rope_dim_list)" + + # use 1/ndim of dimensions to encode grid_axis + embs = [] + for i in range(len(rope_dim_list)): + emb = get_1d_rotary_pos_embed( + rope_dim_list[i], + grid[i].reshape(-1), + theta, + use_real=use_real, + theta_rescale_factor=theta_rescale_factor[i], + interpolation_factor=interpolation_factor[i], + ) # 2 x [WHD, rope_dim_list[i]] + embs.append(emb) + + if use_real: + cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2) + sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2) + return cos, sin + else: + emb = torch.cat(embs, dim=1) # (WHD, D/2) + return emb + + +def get_1d_rotary_pos_embed( + dim: int, + pos: Union[torch.FloatTensor, int], + theta: float = 10000.0, + use_real: bool = False, + theta_rescale_factor: float = 1.0, + interpolation_factor: float = 1.0, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Precompute the frequency tensor for complex exponential (cis) with given dimensions. + (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.) + + This function calculates a frequency tensor with complex exponential using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar + theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. + use_real (bool, optional): If True, return real part and imaginary part separately. + Otherwise, return complex numbers. + theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0. + + Returns: + freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2] + freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. 
[S, D] + """ + if isinstance(pos, int): + pos = torch.arange(pos).float() + + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + if theta_rescale_factor != 1.0: + theta *= theta_rescale_factor ** (dim / (dim - 2)) + + freqs = 1.0 / ( + theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) + ) # [D/2] + # assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}" + freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2] + if use_real: + freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D] + freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] + return freqs_cos, freqs_sin + else: + freqs_cis = torch.polar( + torch.ones_like(freqs), freqs + ) # complex64 # [S, D/2] + return freqs_cis diff --git a/hunyuan_model/text_encoder.py b/hunyuan_model/text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..8d3ab19e5c4866ad15769d06ebe614665d1d0305 --- /dev/null +++ b/hunyuan_model/text_encoder.py @@ -0,0 +1,1297 @@ +from dataclasses import dataclass +import json +import os +from typing import Optional, Tuple, Union +from copy import deepcopy + +import torch +import torch.nn as nn +from transformers import ( + CLIPTextModel, + CLIPTokenizer, + AutoTokenizer, + AutoModel, + CLIPConfig, + LlamaForCausalLM, + LlamaConfig, + LlavaConfig, + LlavaProcessor, + CLIPImageProcessor, +) +from transformers.utils import ModelOutput +from transformers.models.llama import LlamaModel +from transformers.models.llava import LlavaForConditionalGeneration +from safetensors.torch import load_file +from accelerate import init_empty_weights + +import logging + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +CLIP_L_HUGGINGFACE_MODEL_ID = "openai/clip-vit-large-patch14" +LLAVA_HUGGINGFACE_MODEL_ID = "xtuner/llava-llama-3-8b-v1_1-transformers" + +CLIP_CONFIG = { + "_name_or_path": "clip-vit-large-patch14/", + "architectures": ["CLIPModel"], + "initializer_factor": 1.0, + "logit_scale_init_value": 2.6592, + "model_type": "clip", + "projection_dim": 768, + # "text_config": { + "_name_or_path": "", + "add_cross_attention": False, + "architectures": None, + "attention_dropout": 0.0, + "bad_words_ids": None, + "bos_token_id": 0, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": None, + "decoder_start_token_id": None, + "diversity_penalty": 0.0, + "do_sample": False, + "dropout": 0.0, + "early_stopping": False, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": 2, + "finetuning_task": None, + "forced_bos_token_id": None, + "forced_eos_token_id": None, + "hidden_act": "quick_gelu", + "hidden_size": 768, + "id2label": {"0": "LABEL_0", "1": "LABEL_1"}, + "initializer_factor": 1.0, + "initializer_range": 0.02, + "intermediate_size": 3072, + "is_decoder": False, + "is_encoder_decoder": False, + "label2id": {"LABEL_0": 0, "LABEL_1": 1}, + "layer_norm_eps": 1e-05, + "length_penalty": 1.0, + "max_length": 20, + "max_position_embeddings": 77, + "min_length": 0, + "model_type": "clip_text_model", + "no_repeat_ngram_size": 0, + "num_attention_heads": 12, + "num_beam_groups": 1, + "num_beams": 1, + "num_hidden_layers": 12, + "num_return_sequences": 1, + "output_attentions": False, + "output_hidden_states": False, + "output_scores": False, + "pad_token_id": 1, + "prefix": None, + "problem_type": None, + "projection_dim": 768, + "pruned_heads": {}, + "remove_invalid_values": False, + "repetition_penalty": 
1.0, + "return_dict": True, + "return_dict_in_generate": False, + "sep_token_id": None, + "task_specific_params": None, + "temperature": 1.0, + "tie_encoder_decoder": False, + "tie_word_embeddings": True, + "tokenizer_class": None, + "top_k": 50, + "top_p": 1.0, + "torch_dtype": None, + "torchscript": False, + "transformers_version": "4.16.0.dev0", + "use_bfloat16": False, + "vocab_size": 49408, + # }, + # "text_config_dict": { + "hidden_size": 768, + "intermediate_size": 3072, + "num_attention_heads": 12, + "num_hidden_layers": 12, + "projection_dim": 768, + # }, + # "torch_dtype": "float32", + # "transformers_version": null +} + +LLAMA_CONFIG = { + "architectures": ["LlamaForCausalLM"], + "attention_bias": False, + "attention_dropout": 0.0, + "bos_token_id": 128000, + "eos_token_id": 128001, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 14336, + "max_position_embeddings": 8192, + "mlp_bias": False, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 8, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": None, + "rope_theta": 500000.0, + "tie_word_embeddings": False, + "torch_dtype": "float16", + "transformers_version": "4.46.3", + "use_cache": True, + "vocab_size": 128320, +} +LLAVA_CONFIG_JSON = json.loads( + """ +{ + "architectures": [ + "LlavaForConditionalGeneration" + ], + "ignore_index": -100, + "image_token_index": 128257, + "model_type": "llava", + "pad_token_id": 128258, + "projector_hidden_act": "gelu", + "text_config": { + "architectures": [ + "LlamaForCausalLM" + ], + "bos_token_id": 128000, + "eos_token_id": 128001, + "intermediate_size": 14336, + "max_position_embeddings": 8192, + "model_type": "llama", + "num_key_value_heads": 8, + "rms_norm_eps": 1e-05, + "rope_theta": 500000.0, + "torch_dtype": "float16", + "vocab_size": 128320 + }, + "torch_dtype": "float16", + "transformers_version": "4.40.1", + "vision_config": { + "architectures": [ + "CLIPVisionModel" + ], + "dropout": 0.0, + "hidden_size": 1024, + "image_size": 336, + "intermediate_size": 4096, + "model_type": "clip_vision_model", + "num_attention_heads": 16, + "num_hidden_layers": 24, + "patch_size": 14, + "projection_dim": 768, + "torch_dtype": "float32" + }, + "vision_feature_layer": -2, + "vision_feature_select_strategy": "default" +}""" +) + +LLAVA_PROCESSOR_CONFIG = json.loads( + """{ + "image_token": "", + "num_additional_image_tokens": 1, + "patch_size": 14, + "processor_class": "LlavaNextProcessor", + "vision_feature_select_strategy": "default" +}""" +) + +# When using decoder-only models, we must provide a prompt template to instruct the text encoder +# on how to generate the text. +# -------------------------------------------------------------------- +PROMPT_TEMPLATE_ENCODE = ( + "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, " + "quantity, text, spatial relationships of the objects and background:<|eot_id|>" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" +) +PROMPT_TEMPLATE_ENCODE_VIDEO = ( + "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " + "1. The main content and theme of the video." + "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." + "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." + "4. 
background environment, light, style and atmosphere." + "5. camera angles, movements, and transitions used in the video:<|eot_id|>" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" +) + +PROMPT_TEMPLATE_ENCODE_I2V = ( + "<|start_header_id|>system<|end_header_id|>\n\n\nDescribe the image by detailing the color, shape, size, texture, " + "quantity, text, spatial relationships of the objects and background:<|eot_id|>" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" +) + +PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = ( + "<|start_header_id|>system<|end_header_id|>\n\n\nDescribe the video by detailing the following aspects according to the reference image: " + "1. The main content and theme of the video." + "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." + "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." + "4. background environment, light, style and atmosphere." + "5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" +) + +NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion" +NEGATIVE_PROMPT_I2V = "deformation, a poor composition and deformed video, bad teeth, bad eyes, bad limbs" + + +PROMPT_TEMPLATE = { + "dit-llm-encode": { + "template": PROMPT_TEMPLATE_ENCODE, + "crop_start": 36, + }, + "dit-llm-encode-video": { + "template": PROMPT_TEMPLATE_ENCODE_VIDEO, + "crop_start": 95, + }, + "dit-llm-encode-i2v": { + "template": PROMPT_TEMPLATE_ENCODE_I2V, + "crop_start": 36, + "image_emb_start": 5, + "image_emb_end": 581, + "image_emb_len": 576, + "double_return_token_id": 271, + }, + "dit-llm-encode-video-i2v": { + "template": PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, + "crop_start": 103, + "image_emb_start": 5, + "image_emb_end": 581, + "image_emb_len": 576, + "double_return_token_id": 271, + }, +} + + +def use_default(value, default): + return value if value is not None else default + + +def load_clip_l(text_encoder_path: str, dtype: Optional[Union[str, torch.dtype]] = None): + if os.path.isdir(text_encoder_path): + # load from directory, configs are in the directory + text_encoder = CLIPTextModel.from_pretrained(text_encoder_path, torch_dtype=dtype) + else: + # load from file, we create the model with the appropriate config + config = CLIPConfig(**CLIP_CONFIG) + with init_empty_weights(): + text_encoder = CLIPTextModel._from_config(config, torch_dtype=dtype) + + state_dict = load_file(text_encoder_path) + + text_encoder.load_state_dict(state_dict, strict=True, assign=True) + # if dtype is not None: + # text_encoder.to(dtype=dtype) + + return text_encoder + + +def load_clip_l_tokenizer(tokenizer_path: str): + if os.path.isdir(tokenizer_path): + tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path, max_length=77) + else: + # load from Hugging Face + logger.info(f"Loading tokenizer from Hugging Face: {CLIP_L_HUGGINGFACE_MODEL_ID}") + tokenizer = CLIPTokenizer.from_pretrained(CLIP_L_HUGGINGFACE_MODEL_ID, max_length=77) + + return tokenizer + + +def load_llm(text_encoder_path: str, dtype: Optional[Union[str, torch.dtype]] = None): + if os.path.isdir(text_encoder_path): + # load from directory, configs are in the directory + text_encoder = AutoModel.from_pretrained(text_encoder_path, 
low_cpu_mem_usage=True, torch_dtype=dtype) + else: + # load from file, we create the model with the appropriate config + config = LlamaConfig(**LLAMA_CONFIG) + with init_empty_weights(): + text_encoder = LlamaForCausalLM._from_config(config, torch_dtype=dtype) + + state_dict = load_file(text_encoder_path) + + # support weights from ComfyUI + if "tokenizer" in state_dict: + state_dict.pop("tokenizer") + + text_encoder.load_state_dict(state_dict, strict=True, assign=True) + + return text_encoder + + +def load_llm_i2v(text_encoder_path: str, clip_vision_path: str, dtype: Optional[Union[str, torch.dtype]] = None): + if os.path.isdir(text_encoder_path): + # load from directory, configs are in the directory + text_encoder = LlavaForConditionalGeneration.from_pretrained(text_encoder_path, low_cpu_mem_usage=True) + else: + # load from file, we create the model with the appropriate config + config = LlavaConfig(**LLAVA_CONFIG_JSON) + with init_empty_weights(): + text_encoder = LlavaForConditionalGeneration._from_config(config, torch_dtype=dtype) + + state_dict = load_file(text_encoder_path) + + # support weights from ComfyUI + if "tokenizer" in state_dict: + state_dict.pop("tokenizer") + + state_dict = {"language_model." + k: v for k, v in state_dict.items()} + + state_dict_vision = load_file(clip_vision_path) + state_dict_vision = { + ("vision_tower." if "multi_modal_projector." not in k else "") + k: v for k, v in state_dict_vision.items() + } + state_dict.update(state_dict_vision) + + text_encoder.load_state_dict(state_dict, strict=True, assign=True) + + return text_encoder + + +def load_llm_tokenizer(tokenizer_path: str, padding_side="right"): + if os.path.isdir(tokenizer_path): + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + else: + # load from Hugging Face + logger.info(f"Loading tokenizer from Hugging Face: {LLAVA_HUGGINGFACE_MODEL_ID}") + tokenizer = AutoTokenizer.from_pretrained(LLAVA_HUGGINGFACE_MODEL_ID, padding_side=padding_side) + + return tokenizer + + +def load_text_encoder( + text_encoder_type: str, + text_encoder_path: str, + text_encoder_dtype: Optional[Union[str, torch.dtype]] = None, + clip_vision_path: Optional[str] = None, +): + logger.info(f"Loading text encoder model ({text_encoder_type}) from: {text_encoder_path}") + + # reduce peak memory usage by specifying the dtype of the model + dtype = text_encoder_dtype + processor = None + if text_encoder_type == "clipL": + text_encoder = load_clip_l(text_encoder_path, dtype=dtype) + text_encoder.final_layer_norm = text_encoder.text_model.final_layer_norm + elif text_encoder_type == "llm": + text_encoder = load_llm(text_encoder_path, dtype=dtype) + if hasattr(text_encoder, "norm"): + text_encoder.final_layer_norm = text_encoder.norm # by from_pretrained + else: + text_encoder.final_layer_norm = text_encoder.model.norm # by _from_config + elif text_encoder_type == "llm-i2v": + text_encoder = load_llm_i2v(text_encoder_path, clip_vision_path, dtype=dtype) + text_encoder.final_layer_norm = text_encoder.language_model.model.norm + else: + raise ValueError(f"Unsupported text encoder type: {text_encoder_type}") + # from_pretrained will ensure that the model is in eval mode. 
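The key remapping in `load_llm_i2v` above is what lets a single-file LLaVA text-encoder checkpoint plus a separate CLIP-vision checkpoint populate one `LlavaForConditionalGeneration`. A toy illustration of the renaming, with made-up keys rather than real checkpoint contents:

```python
# Made-up stand-ins for the two load_file(...) results.
state_dict = {"model.embed_tokens.weight": "llm weight"}
state_dict_vision = {
    "vision_model.embeddings.class_embedding": "clip weight",     # receives the vision_tower. prefix
    "multi_modal_projector.linear_1.weight": "projector weight",  # projector keys are kept as-is
}

state_dict = {"language_model." + k: v for k, v in state_dict.items()}
state_dict.update(
    {("vision_tower." if "multi_modal_projector." not in k else "") + k: v for k, v in state_dict_vision.items()}
)
# Resulting keys: language_model.model.embed_tokens.weight,
# vision_tower.vision_model.embeddings.class_embedding, multi_modal_projector.linear_1.weight
```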
+ + if dtype is not None: + text_encoder = text_encoder.to(dtype=dtype) + + text_encoder.requires_grad_(False) + + logger.info(f"Text encoder to dtype: {text_encoder.dtype}") + return text_encoder, processor, text_encoder_path + + +def load_tokenizer(tokenizer_type, tokenizer_path=None, padding_side="right"): + logger.info(f"Loading tokenizer ({tokenizer_type}) from: {tokenizer_path}") + + if tokenizer_type == "clipL": + tokenizer = load_clip_l_tokenizer(tokenizer_path) + elif tokenizer_type == "llm" or tokenizer_type == "llm-i2v": + tokenizer = load_llm_tokenizer(tokenizer_path, padding_side=padding_side) + else: + raise ValueError(f"Unsupported tokenizer type: {tokenizer_type}") + + return tokenizer, tokenizer_path + + +@dataclass +class TextEncoderModelOutput(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + hidden_states_list (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + text_outputs (`list`, *optional*, returned when `return_texts=True` is passed): + List of decoded texts. + """ + + hidden_state: torch.FloatTensor = None + attention_mask: Optional[torch.LongTensor] = None + hidden_states_list: Optional[Tuple[torch.FloatTensor, ...]] = None + text_outputs: Optional[list] = None + + +class TextEncoder(nn.Module): + def __init__( + self, + text_encoder_type: str, + max_length: int, + text_encoder_dtype: Optional[Union[str, torch.dtype]] = None, + text_encoder_path: Optional[str] = None, + clip_vision_path: Optional[str] = None, + tokenizer_type: Optional[str] = None, + tokenizer_path: Optional[str] = None, + i2v_mode: bool = False, + output_key: Optional[str] = None, + use_attention_mask: bool = True, + input_max_length: Optional[int] = None, + prompt_template: Optional[dict] = None, + prompt_template_video: Optional[dict] = None, + hidden_state_skip_layer: Optional[int] = None, + apply_final_norm: bool = False, + reproduce: bool = False, + image_embed_interleave: int = None, + ): + super().__init__() + self.text_encoder_type = text_encoder_type + self.max_length = max_length + # self.precision = text_encoder_precision + self.model_path = text_encoder_path + self.tokenizer_type = tokenizer_type if tokenizer_type is not None else text_encoder_type + self.tokenizer_path = tokenizer_path if tokenizer_path is not None else text_encoder_path + self.i2v_mode = i2v_mode + self.use_attention_mask = use_attention_mask + if prompt_template_video is not None: + assert use_attention_mask is True, "Attention mask is True required when training videos." 
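For context, the `prompt_template` / `prompt_template_video` arguments received here are entries of the `PROMPT_TEMPLATE` dictionary defined above: the template wraps the user prompt in an instruction prefix, and `crop_start` is later used in `encode` to drop the hidden states of that prefix. A rough sketch of the idea, assuming this module's `PROMPT_TEMPLATE` dict is in scope (this is not the class API itself):

```python
# Wrap a user prompt in the video instruction template used by the LLM encoder.
user_prompt = "A cat sitting on a table"
entry = PROMPT_TEMPLATE["dit-llm-encode-video"]
templated = entry["template"].format(user_prompt)

# After encoding, the first `crop_start` token positions (the instruction prefix)
# are discarded: last_hidden_state = last_hidden_state[:, entry["crop_start"]:]
```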
+ self.input_max_length = input_max_length if input_max_length is not None else max_length + self.prompt_template = prompt_template + self.prompt_template_video = prompt_template_video + self.hidden_state_skip_layer = hidden_state_skip_layer + self.apply_final_norm = apply_final_norm + self.reproduce = reproduce + self.image_embed_interleave = image_embed_interleave + + self.use_template = self.prompt_template is not None + if self.use_template: + assert ( + isinstance(self.prompt_template, dict) and "template" in self.prompt_template + ), f"`prompt_template` must be a dictionary with a key 'template', got {self.prompt_template}" + assert "{}" in str(self.prompt_template["template"]), ( + "`prompt_template['template']` must contain a placeholder `{}` for the input text, " + f"got {self.prompt_template['template']}" + ) + + self.use_video_template = self.prompt_template_video is not None + if self.use_video_template: + if self.prompt_template_video is not None: + assert ( + isinstance(self.prompt_template_video, dict) and "template" in self.prompt_template_video + ), f"`prompt_template_video` must be a dictionary with a key 'template', got {self.prompt_template_video}" + assert "{}" in str(self.prompt_template_video["template"]), ( + "`prompt_template_video['template']` must contain a placeholder `{}` for the input text, " + f"got {self.prompt_template_video['template']}" + ) + + if "t5" in text_encoder_type: + self.output_key = output_key or "last_hidden_state" + elif "clip" in text_encoder_type: + self.output_key = output_key or "pooler_output" + elif "llm" in text_encoder_type or "glm" in text_encoder_type: + self.output_key = output_key or "last_hidden_state" + else: + raise ValueError(f"Unsupported text encoder type: {text_encoder_type}") + + self.model, self.processor, self.model_path = load_text_encoder( + text_encoder_type=self.text_encoder_type, + text_encoder_path=self.model_path, + text_encoder_dtype=text_encoder_dtype, + clip_vision_path=clip_vision_path, + ) + self.dtype = self.model.dtype + + self.tokenizer, self.tokenizer_path = load_tokenizer( + tokenizer_type=self.tokenizer_type, tokenizer_path=self.tokenizer_path, padding_side="right" + ) + + if text_encoder_type == "llm-i2v": + clip_processor = CLIPImageProcessor.from_pretrained(LLAVA_HUGGINGFACE_MODEL_ID) + self.processor = LlavaProcessor.from_args_and_dict( + args=[clip_processor, self.tokenizer], processor_dict=LLAVA_PROCESSOR_CONFIG + ) + # print(f"patch size: {self.processor.patch_size}, vision strategy: {self.processor.vision_feature_select_strategy}") + else: + self.processor = None + + def __repr__(self): + return f"{self.text_encoder_type} ({self.precision} - {self.model_path})" + + @property + def device(self): + return self.model.device + + @staticmethod + def apply_text_to_template(text, template, prevent_empty_text=True): + """ + Apply text to template. + + Args: + text (str): Input text. + template (str or list): Template string or list of chat conversation. + prevent_empty_text (bool): If Ture, we will prevent the user text from being empty + by adding a space. Defaults to True. + """ + if isinstance(template, str): + # Will send string to tokenizer. Used for llm + return template.format(text) + else: + raise TypeError(f"Unsupported template type: {type(template)}") + + def text2tokens(self, text, data_type="image", semantic_images=None): + """ + Tokenize the input text. + + Args: + text (str or list): Input text. 
+ """ + tokenize_input_type = "str" + if self.use_template: + if data_type == "image": + prompt_template = self.prompt_template["template"] + elif data_type == "video": + prompt_template = self.prompt_template_video["template"] + else: + raise ValueError(f"Unsupported data type: {data_type}") + if isinstance(text, (list, tuple)): + text = [self.apply_text_to_template(one_text, prompt_template) for one_text in text] + if isinstance(text[0], list): + tokenize_input_type = "list" + elif isinstance(text, str): + text = self.apply_text_to_template(text, prompt_template) + if isinstance(text, list): + tokenize_input_type = "list" + else: + raise TypeError(f"Unsupported text type: {type(text)}") + else: + if isinstance(text, (list, tuple)): + tokenize_input_type = "list" + elif isinstance(text, str): + tokenize_input_type = "str" + else: + raise TypeError(f"Unsupported text type: {type(text)}") + + kwargs = dict( + truncation=True, + max_length=self.max_length, + padding="max_length", + return_tensors="pt", + ) + if tokenize_input_type == "str": + if self.text_encoder_type != "llm-i2v": + return self.tokenizer( + text, + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=True, + **kwargs, + ) + else: + # support transformers >= 4.47 + assert semantic_images is not None, "semantic_images is required for i2v mode tokenization." + kwargs["max_length"] += 575 # image feature length-1 + return self.processor( + semantic_images, + text, + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=True, + **kwargs, + ) + + elif tokenize_input_type == "list": + if self.use_template: + # this block is not tested yet + return self.tokenizer( + text, + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=True, + **kwargs, + ) + else: + return self.tokenizer.apply_chat_template( + text, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + **kwargs, + ) + else: + raise ValueError(f"Unsupported tokenize_input_type: {tokenize_input_type}") + + def encode( + self, + batch_encoding, + use_attention_mask=None, + output_hidden_states=False, + do_sample=None, + hidden_state_skip_layer=None, + return_texts=False, + data_type="image", + semantic_images=None, + device=None, + ): + """ + Args: + batch_encoding (dict): Batch encoding from tokenizer. + use_attention_mask (bool): Whether to use attention mask. If None, use self.use_attention_mask. + Defaults to None. + output_hidden_states (bool): Whether to output hidden states. If False, return the value of + self.output_key. If True, return the entire output. If set self.hidden_state_skip_layer, + output_hidden_states will be set True. Defaults to False. + do_sample (bool): Whether to sample from the model. Used for Decoder-Only LLMs. Defaults to None. + When self.produce is False, do_sample is set to True by default. + hidden_state_skip_layer (int): Number of hidden states to hidden_state_skip_layer. 0 means the last layer. + If None, self.output_key will be used. Defaults to None. + return_texts (bool): Whether to return the decoded texts. Defaults to False. 
+ """ + device = self.model.device if device is None else device + use_attention_mask = use_default(use_attention_mask, self.use_attention_mask) + hidden_state_skip_layer = use_default(hidden_state_skip_layer, self.hidden_state_skip_layer) + do_sample = use_default(do_sample, not self.reproduce) + + if not self.i2v_mode: + attention_mask = batch_encoding["attention_mask"].to(device) if use_attention_mask else None + outputs = self.model( + input_ids=batch_encoding["input_ids"].to(device), + attention_mask=attention_mask, + output_hidden_states=output_hidden_states or hidden_state_skip_layer is not None, + ) + if hidden_state_skip_layer is not None: + last_hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)] + # Real last hidden state already has layer norm applied. So here we only apply it + # for intermediate layers. + if hidden_state_skip_layer > 0 and self.apply_final_norm: + last_hidden_state = self.model.final_layer_norm(last_hidden_state) + else: + last_hidden_state = outputs[self.output_key] + + # Remove hidden states of instruction tokens, only keep prompt tokens. + if self.use_template: + if data_type == "image": + crop_start = self.prompt_template.get("crop_start", -1) + elif data_type == "video": + crop_start = self.prompt_template_video.get("crop_start", -1) + else: + raise ValueError(f"Unsupported data type: {data_type}") + if crop_start > 0: + last_hidden_state = last_hidden_state[:, crop_start:] + attention_mask = attention_mask[:, crop_start:] if use_attention_mask else None + + if output_hidden_states: + return TextEncoderModelOutput(last_hidden_state, attention_mask, outputs.hidden_states) + return TextEncoderModelOutput(last_hidden_state, attention_mask) + else: + # I2V mode + """ + # original code from HunyuanVideo + image_outputs = self.processor(semantic_images, return_tensors="pt")["pixel_values"].to(device) + attention_mask = batch_encoding["attention_mask"].to(device) if use_attention_mask else None + outputs = self.model( + input_ids=batch_encoding["input_ids"].to(device), + attention_mask=attention_mask, + output_hidden_states=output_hidden_states or hidden_state_skip_layer is not None, + pixel_values=image_outputs, + ) + + if hidden_state_skip_layer is not None: + last_hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)] + # Real last hidden state already has layer norm applied. So here we only apply it + # for intermediate layers. 
+ if hidden_state_skip_layer > 0 and self.apply_final_norm: + last_hidden_state = self.model.final_layer_norm(last_hidden_state) + else: + last_hidden_state = outputs[self.output_key] + + if self.use_template: + if data_type == "video": + crop_start = self.prompt_template_video.get("crop_start", -1) + text_crop_start = crop_start - 1 + self.prompt_template_video.get("image_emb_len", 576) + image_crop_start = self.prompt_template_video.get("image_emb_start", 5) + image_crop_end = self.prompt_template_video.get("image_emb_end", 581) + batch_indices, last_double_return_token_indices = torch.where( + batch_encoding["input_ids"] == self.prompt_template_video.get("double_return_token_id", 271) + ) + + if last_double_return_token_indices.shape[0] == 3: + # in case the prompt is too long + last_double_return_token_indices = torch.cat( + (last_double_return_token_indices, torch.tensor([batch_encoding["input_ids"].shape[-1]])) + ) + batch_indices = torch.cat((batch_indices, torch.tensor([0]))) + + last_double_return_token_indices = last_double_return_token_indices.reshape( + batch_encoding["input_ids"].shape[0], -1 + )[:, -1] + batch_indices = batch_indices.reshape(batch_encoding["input_ids"].shape[0], -1)[:, -1] + assistant_crop_start = ( + last_double_return_token_indices - 1 + self.prompt_template_video.get("image_emb_len", 576) - 4 + ) + assistant_crop_end = last_double_return_token_indices - 1 + self.prompt_template_video.get("image_emb_len", 576) + attention_mask_assistant_crop_start = last_double_return_token_indices - 4 + attention_mask_assistant_crop_end = last_double_return_token_indices + else: + raise ValueError(f"Unsupported data type: {data_type}") + """ + # modified code for i2v mode, support transformers >= 4.47 + assert use_attention_mask is True, "Attention mask is True required for backward compatibility." + batch_encoding = batch_encoding.to(device) + attention_mask = batch_encoding["attention_mask"] + outputs = self.model(**batch_encoding, output_hidden_states=True) + + if hidden_state_skip_layer is not None: + last_hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)] + # Real last hidden state already has layer norm applied. So here we only apply it + # for intermediate layers. 
+ if hidden_state_skip_layer > 0 and self.apply_final_norm: + last_hidden_state = self.model.final_layer_norm(last_hidden_state) + else: + last_hidden_state = outputs[self.output_key] + + if self.use_template: + if data_type == "video": + crop_start = self.prompt_template_video.get("crop_start", -1) + text_crop_start = crop_start - 1 + self.prompt_template_video.get("image_emb_len", 576) + image_crop_start = self.prompt_template_video.get("image_emb_start", 5) + image_crop_end = self.prompt_template_video.get("image_emb_end", 581) + batch_indices, last_double_return_token_indices = torch.where( + batch_encoding["input_ids"] == self.prompt_template_video.get("double_return_token_id", 271) + ) + + if last_double_return_token_indices.shape[0] == 3: + # in case the prompt is too long + last_double_return_token_indices = torch.cat( + (last_double_return_token_indices, torch.tensor([batch_encoding["input_ids"].shape[-1]])) + ) + batch_indices = torch.cat((batch_indices, torch.tensor([0]))) + + last_double_return_token_indices = last_double_return_token_indices.reshape( + batch_encoding["input_ids"].shape[0], -1 + )[:, -1] + batch_indices = batch_indices.reshape(batch_encoding["input_ids"].shape[0], -1)[:, -1] + + # with transformers >= 4.47, token in input_ids is already expanded to image embed size. + # so we don't need to add image_emb_len to the last_double_return_token_indices. + assistant_crop_start = last_double_return_token_indices - 4 + assistant_crop_end = last_double_return_token_indices + # attention mask is also expanded to image embed size, so the same as hidden state. + attention_mask_assistant_crop_start = last_double_return_token_indices - 4 + attention_mask_assistant_crop_end = last_double_return_token_indices + else: + raise ValueError(f"Unsupported data type: {data_type}") + + text_last_hidden_state = [] + text_attention_mask = [] + image_last_hidden_state = [] + image_attention_mask = [] + for i in range(batch_encoding["input_ids"].shape[0]): + text_last_hidden_state.append( + torch.cat( + [ + last_hidden_state[i, text_crop_start : assistant_crop_start[i].item()], + last_hidden_state[i, assistant_crop_end[i].item() :], + ] + ) + ) + text_attention_mask.append( + torch.cat( + [ + attention_mask[ + i, + text_crop_start : attention_mask_assistant_crop_start[i].item(), # this line is modified + ], + attention_mask[i, attention_mask_assistant_crop_end[i].item() :], + ] + ) + if use_attention_mask + else None + ) + image_last_hidden_state.append(last_hidden_state[i, image_crop_start:image_crop_end]) + image_attention_mask.append( + torch.ones(image_last_hidden_state[-1].shape[0]).to(last_hidden_state.device).to(attention_mask.dtype) + if use_attention_mask + else None + ) + + text_last_hidden_state = torch.stack(text_last_hidden_state) + text_attention_mask = torch.stack(text_attention_mask) + image_last_hidden_state = torch.stack(image_last_hidden_state) + image_attention_mask = torch.stack(image_attention_mask) + + if semantic_images is not None and 0 < self.image_embed_interleave < 6: + image_last_hidden_state = image_last_hidden_state[:, :: self.image_embed_interleave, :] + image_attention_mask = image_attention_mask[:, :: self.image_embed_interleave] + + assert ( + text_last_hidden_state.shape[0] == text_attention_mask.shape[0] + and image_last_hidden_state.shape[0] == image_attention_mask.shape[0] + ) + + last_hidden_state = torch.cat([image_last_hidden_state, text_last_hidden_state], dim=1) + attention_mask = torch.cat([image_attention_mask, text_attention_mask], dim=1) + + 
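As a concrete example of the interleave-and-concatenate step above, using dummy tensors in place of the encoder's cropped hidden states (256 text tokens is only an illustrative length; the 4096 hidden size matches the LLaMA config above): with `image_embed_interleave=4`, the 576 image tokens are subsampled to 144 before being prepended to the text tokens.

```python
import torch

# Dummy stand-ins for the cropped image/text hidden states of a single prompt.
image_last_hidden_state = torch.randn(1, 576, 4096)
text_last_hidden_state = torch.randn(1, 256, 4096)

image_embed_interleave = 4
image_last_hidden_state = image_last_hidden_state[:, ::image_embed_interleave, :]  # (1, 144, 4096)

combined = torch.cat([image_last_hidden_state, text_last_hidden_state], dim=1)     # (1, 400, 4096)
```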
if output_hidden_states: + return TextEncoderModelOutput( + last_hidden_state, + attention_mask, + hidden_states_list=outputs.hidden_states, + ) + return TextEncoderModelOutput(last_hidden_state, attention_mask) + + def forward( + self, + text, + use_attention_mask=None, + output_hidden_states=False, + do_sample=False, + hidden_state_skip_layer=None, + return_texts=False, + ): + batch_encoding = self.text2tokens(text) + return self.encode( + batch_encoding, + use_attention_mask=use_attention_mask, + output_hidden_states=output_hidden_states, + do_sample=do_sample, + hidden_state_skip_layer=hidden_state_skip_layer, + return_texts=return_texts, + ) + + +# region HunyanVideo architecture + + +def load_text_encoder_1( + text_encoder_dir: str, + device: torch.device, + fp8_llm: bool, + dtype: Optional[Union[str, torch.dtype]] = None, + i2v_mode: bool = False, + image_embed_interleave: int = None, + clip_vision_path: Optional[str] = None, +) -> TextEncoder: + """ + clip_vision_path is required for i2v mode with .safetensors file. + """ + text_encoder_dtype = dtype or torch.float16 + text_encoder_type = "llm" if not i2v_mode else "llm-i2v" + text_len = 256 + hidden_state_skip_layer = 2 + apply_final_norm = False + reproduce = False + + prompt_template = "dit-llm-encode" if not i2v_mode else "dit-llm-encode-i2v" + prompt_template = PROMPT_TEMPLATE[prompt_template] + prompt_template_video = "dit-llm-encode-video" if not i2v_mode else "dit-llm-encode-video-i2v" + prompt_template_video = PROMPT_TEMPLATE[prompt_template_video] + + crop_start = prompt_template_video["crop_start"] # .get("crop_start", 0) + max_length = text_len + crop_start + + text_encoder_1 = TextEncoder( + text_encoder_type=text_encoder_type, + max_length=max_length, + text_encoder_dtype=text_encoder_dtype, + text_encoder_path=text_encoder_dir, + clip_vision_path=clip_vision_path, + tokenizer_type=text_encoder_type, + i2v_mode=i2v_mode, + prompt_template=prompt_template, + prompt_template_video=prompt_template_video, + hidden_state_skip_layer=hidden_state_skip_layer, + apply_final_norm=apply_final_norm, + reproduce=reproduce, + image_embed_interleave=image_embed_interleave, + ) + text_encoder_1.eval() + + if fp8_llm: + org_dtype = text_encoder_1.dtype + logger.info(f"Moving and casting text encoder to {device} and torch.float8_e4m3fn") + text_encoder_1.to(device=device, dtype=torch.float8_e4m3fn) + + # prepare LLM for fp8 + def prepare_fp8(llama_model: LlamaModel, target_dtype): + def forward_hook(module): + def forward(hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + module.variance_epsilon) + return module.weight.to(input_dtype) * hidden_states.to(input_dtype) + + return forward + + for module in llama_model.modules(): + if module.__class__.__name__ in ["Embedding"]: + # print("set", module.__class__.__name__, "to", target_dtype) + module.to(target_dtype) + if module.__class__.__name__ in ["LlamaRMSNorm"]: + # print("set", module.__class__.__name__, "hooks") + module.forward = forward_hook(module) + + prepare_fp8(text_encoder_1.model, org_dtype) + else: + text_encoder_1.to(device=device) + + return text_encoder_1 + + +def load_text_encoder_2( + text_encoder_dir: str, device: torch.device, dtype: Optional[Union[str, torch.dtype]] = None +) -> TextEncoder: + text_encoder_dtype = dtype or torch.float16 + reproduce = False + + text_encoder_2_type = "clipL" + text_len_2 = 
77 + + text_encoder_2 = TextEncoder( + text_encoder_type=text_encoder_2_type, + max_length=text_len_2, + text_encoder_dtype=text_encoder_dtype, + text_encoder_path=text_encoder_dir, + tokenizer_type=text_encoder_2_type, + reproduce=reproduce, + ) + text_encoder_2.eval() + + text_encoder_2.to(device=device) + + return text_encoder_2 + + +# endregion + + +if __name__ == "__main__": + # Test the text encoder + import argparse + from utils.model_utils import str_to_dtype + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + if False: + # This is a test script to check if the text encoder is loaded correctly and the outputs are the same. + # Compare two directories or files of text encoders: Offcial ckpt and single file ckpt. + parser = argparse.ArgumentParser() + parser.add_argument("type", type=str, help="Text Encoder type") + parser.add_argument("path1", type=str, help="Text Encoder directory or file 1") + parser.add_argument("path2", type=str, help="Text Encoder directory or file 2") + parser.add_argument("--clip_vision_path1", type=str, default=None, help="Vision Encoder directory or file 1") + parser.add_argument("--clip_vision_path2", type=str, default=None, help="Vision Encoder directory or file 2") + parser.add_argument("--image_path", type=str, default=None, help="Image path, if set, use i2v mode") + parser.add_argument("--image_embed_interleave", type=int, default=None, help="Image embed interleave") + parser.add_argument("--dtype", type=str, default=None, help="Data type for Text Encoder") + args = parser.parse_args() + + dtype = str_to_dtype(args.dtype) if args.dtype is not None else torch.float16 + + i2v_mode = args.image_path is not None + if i2v_mode: + from PIL import Image + + image = Image.open(args.image_path).convert("RGB") + semantic_images = [image] + else: + semantic_images = None + + if args.type == "clipL": + text_encoder_1 = load_text_encoder_2(args.path1, device, dtype) + text_encoder_2nd = load_text_encoder_2(args.path2, "cpu", dtype) + elif args.type == "llm" or args.type == "llm-i2v": + print("loading text encoder 1st") + text_encoder_1 = load_text_encoder_1( + args.path1, device, False, dtype, i2v_mode, args.image_embed_interleave, args.clip_vision_path1 + ) + print("loading text encoder 2nd") + text_encoder_2nd = load_text_encoder_1( + args.path2, "cpu", False, dtype, i2v_mode, args.image_embed_interleave, args.clip_vision_path2 + ) + print(f"1st Text Encoder dtype: {text_encoder_1.dtype}") + print(f"2nd Text Encoder dtype: {text_encoder_2nd.dtype}") + + prompt = "A cat sitting on a table" + data_type = "video" # video only, image is not supported + text_inputs_new = text_encoder_1.text2tokens(prompt, data_type=data_type) + text_inputs_2nd = text_encoder_2nd.text2tokens(prompt, data_type=data_type) + print(text_inputs_new) + assert torch.allclose(text_inputs_new["input_ids"], text_inputs_2nd["input_ids"]) + + with torch.no_grad(): + print("Encoding with 1st text encoder") + prompt_outputs_new = text_encoder_1.encode(text_inputs_new, data_type=data_type, semantic_images=semantic_images) + del text_encoder_1 + text_encoder_2nd.to(device=device) + with torch.no_grad(): + prompt_outputs_2nd = text_encoder_2nd.encode(text_inputs_new, data_type=data_type, semantic_images=semantic_images) + + # prompt_outputs.hidden_state, prompt_outputs.attention_mask + assert torch.allclose(prompt_outputs_new.hidden_state, prompt_outputs_2nd.hidden_state) + print("Hidden states are the same.") + assert torch.allclose(prompt_outputs_new.attention_mask, 
prompt_outputs_2nd.attention_mask) + print("Attention masks are the same.") + print("All outputs are the same.") + + if True: + # Test Llava with image in old transformers and new transformers + # works only transformers < 4.47 (supports new behavior and legacy behavior) + parser = argparse.ArgumentParser() + parser.add_argument("path1", type=str, help="Text Encoder directory or file 1") + parser.add_argument("--clip_vision_path1", type=str, default=None, help="Vision Encoder directory or file 1") + parser.add_argument("--image_path", type=str, default=None, help="Image path, if set, use i2v mode") + parser.add_argument("--dtype", type=str, default=None, help="Data type for Text Encoder") + args = parser.parse_args() + + dtype = str_to_dtype(args.dtype) if args.dtype is not None else torch.float16 + + from PIL import Image + + image = Image.open(args.image_path).convert("RGB") + semantic_images = [image] + + text_encoder_1 = load_text_encoder_1(args.path1, device, False, dtype, True, 4, args.clip_vision_path1) + + prompt = "A short animated video of a girl standing in a classroom. The girl is wearing a sailor uniform. The upper body of the girl is shown, and the girl is talking to the camera with a rich expression and using gestures. The girl has a short black bob hairstyle, and the inner color of her hair is blue. She has red eyes and wears silver-framed glasses. High quality animated video, studio quality." + # prompt = ( + # "A short animated video of a girl standing in a classroom. The girl is wearing a sailor uniform. The upper body of the girl is shown, and the girl is talking to the camera with a rich expression and using gestures. The girl has a short black bob hairstyle, and the inner color of her hair is blue. She has red eyes and wears silver-framed glasses. High quality animated video, studio quality. " + # "A short animated video of a girl standing in a classroom. The girl is wearing a sailor uniform. The upper body of the girl is shown, and the girl is talking to the camera with a rich expression and using gestures. The girl has a short black bob hairstyle, and the inner color of her hair is blue. She has red eyes and wears silver-framed glasses. High quality animated video, studio quality. " + # "A short animated video of a girl standing in a classroom. The girl is wearing a sailor uniform. The upper body of the girl is shown, and the girl is talking to the camera with a rich expression and using gestures. The girl has a short black bob hairstyle, and the inner color of her hair is blue. She has red eyes and wears silver-framed glasses. High quality animated video, studio quality. " + # "A short animated video of a girl standing in a classroom. The girl is wearing a sailor uniform. The upper body of the girl is shown, and the girl is talking to the camera with a rich expression and using gestures. The girl has a short black bob hairstyle, and the inner color of her hair is blue. She has red eyes and wears silver-framed glasses. High quality animated video, studio quality. 
" + # ) + data_type = "video" # video only, image is not supported + + ### Test the new behavior of text encoder + + print("Encoding with text encoder, new behavior") + text_inputs_new = text_encoder_1.text2tokens(prompt, data_type=data_type, semantic_images=semantic_images).to(device) + print(f"text_inputs_new keys: {text_inputs_new.keys()}") + print(f"input_ids shape: {text_inputs_new['input_ids'].shape}") + print(f"attention_mask shape: {text_inputs_new['attention_mask'].shape}") + + with torch.no_grad(): + prompt_outputs_new = text_encoder_1.model(**text_inputs_new, output_hidden_states=True) + + ### Test the old behavior of text encoder + + print("Encoding with text encoder, old behavior") + text_encoder_1.text_encoder_type = "llm" # force old behavior, call tokenizer instead of processor + text_inputs_old = text_encoder_1.text2tokens(prompt, data_type=data_type).to(device) + print(f"text_inputs_old keys: {text_inputs_old.keys()}") + print(f"input_ids shape: {text_inputs_old['input_ids'].shape}") + print(f"attention_mask shape: {text_inputs_old['attention_mask'].shape}") + + with torch.no_grad(): + # original code from HunyuanVideo + clip_processor = CLIPImageProcessor.from_pretrained(LLAVA_HUGGINGFACE_MODEL_ID) + + image_outputs = clip_processor(semantic_images, return_tensors="pt")["pixel_values"].to(device) + attention_mask = text_inputs_old["attention_mask"].to(device) # if use_attention_mask else None + prompt_outputs_old = text_encoder_1.model( + input_ids=text_inputs_old["input_ids"].to(device), + attention_mask=attention_mask, + output_hidden_states=True, + pixel_values=image_outputs, + ) + + ### calc crop position + + crop_start = text_encoder_1.prompt_template_video.get("crop_start", -1) + text_crop_start = crop_start - 1 + text_encoder_1.prompt_template_video.get("image_emb_len", 576) + image_crop_start = text_encoder_1.prompt_template_video.get("image_emb_start", 5) + image_crop_end = text_encoder_1.prompt_template_video.get("image_emb_end", 581) + print(f"crop_start: {crop_start}") + print(f"text_crop_start: {text_crop_start}, image_crop_start: {image_crop_start}, image_crop_end: {image_crop_end}") + + # we test with a single prompt, so the batch_indices will be 0 + def get_batch_and_last_double_return_token_indices(batch_encoding): + batch_indices, last_double_return_token_indices = torch.where( + batch_encoding["input_ids"] == text_encoder_1.prompt_template_video.get("double_return_token_id", 271) + ) + if last_double_return_token_indices.shape[0] == 3: + # in case the prompt is too long + last_double_return_token_indices = torch.cat( + (last_double_return_token_indices, torch.tensor([batch_encoding["input_ids"].shape[-1]], device=device)) + ) + batch_indices = torch.cat((batch_indices, torch.tensor([0], device=device))) + return batch_indices, last_double_return_token_indices + + batch_indices_new, last_double_return_token_indices_new = get_batch_and_last_double_return_token_indices(text_inputs_new) + batch_indices_old, last_double_return_token_indices_old = get_batch_and_last_double_return_token_indices(text_inputs_old) + print( + f"batch_indices_new: {batch_indices_new}, last_double_return_token_indices_new: {last_double_return_token_indices_new}" + ) + print( + f"batch_indices_old: {batch_indices_old}, last_double_return_token_indices_old: {last_double_return_token_indices_old}" + ) + + def calc_attn_crop_new(batch_encoding, batch_indices, last_double_return_token_indices): + last_double_return_token_indices = 
last_double_return_token_indices.reshape(batch_encoding["input_ids"].shape[0], -1)[ + :, -1 + ] + print(f"new last_double_return_token_indices: {last_double_return_token_indices}") + batch_indices = batch_indices.reshape(batch_encoding["input_ids"].shape[0], -1)[:, -1] + assistant_crop_start = last_double_return_token_indices - 4 + assistant_crop_end = last_double_return_token_indices + attention_mask_assistant_crop_start = last_double_return_token_indices - 4 + attention_mask_assistant_crop_end = last_double_return_token_indices + return assistant_crop_start, assistant_crop_end, attention_mask_assistant_crop_start, attention_mask_assistant_crop_end + + def calc_attn_crop(batch_encoding, batch_indices, last_double_return_token_indices): + last_double_return_token_indices = last_double_return_token_indices.reshape(batch_encoding["input_ids"].shape[0], -1)[ + :, -1 + ] + print(f"old last_double_return_token_indices: {last_double_return_token_indices}") + batch_indices = batch_indices.reshape(batch_encoding["input_ids"].shape[0], -1)[:, -1] + assistant_crop_start = ( + last_double_return_token_indices - 1 + text_encoder_1.prompt_template_video.get("image_emb_len", 576) - 4 + ) + assistant_crop_end = ( + last_double_return_token_indices - 1 + text_encoder_1.prompt_template_video.get("image_emb_len", 576) + ) + attention_mask_assistant_crop_start = last_double_return_token_indices - 4 + attention_mask_assistant_crop_end = last_double_return_token_indices + return assistant_crop_start, assistant_crop_end, attention_mask_assistant_crop_start, attention_mask_assistant_crop_end + + ( + assistant_crop_start_new, + assistant_crop_end_new, + attention_mask_assistant_crop_start_new, + attention_mask_assistant_crop_end_new, + ) = calc_attn_crop_new(text_inputs_new, batch_indices_new, last_double_return_token_indices_new) + ( + assistant_crop_start_old, + assistant_crop_end_old, + attention_mask_assistant_crop_start_old, + attention_mask_assistant_crop_end_old, + ) = calc_attn_crop(text_inputs_old, batch_indices_old, last_double_return_token_indices_old) + + print("Assistant crop start and end:") + print( + "new", + assistant_crop_start_new, + assistant_crop_end_new, + attention_mask_assistant_crop_start_new, + attention_mask_assistant_crop_end_new, + ) + print( + "old", + assistant_crop_start_old, + assistant_crop_end_old, + attention_mask_assistant_crop_start_old, + attention_mask_assistant_crop_end_old, + ) + + ### Compare the outputs of the two models + + hidden_state_new = prompt_outputs_new.hidden_states[-(2 + 1)] + hidden_state_old = prompt_outputs_old.hidden_states[-(2 + 1)] + + def crop_hidden_state_and_attn_mask( + hidden_state, + attention_mask, + text_crop_start, + crop_start, + assistant_crop_start, + assistant_crop_end, + attention_mask_assistant_crop_start, + attention_mask_assistant_crop_end, + ): + hidden_state = torch.cat([hidden_state[0, text_crop_start:assistant_crop_start], hidden_state[0, assistant_crop_end:]]) + print(f"cropping attention mask: {attention_mask.shape}, {crop_start}, {assistant_crop_start}, {assistant_crop_end}") + attention_mask = torch.cat( + [ + attention_mask[0, crop_start:attention_mask_assistant_crop_start], + attention_mask[0, attention_mask_assistant_crop_end:], + ] + ) + return hidden_state, attention_mask + + with torch.no_grad(): + hidden_state_new = text_encoder_1.model.final_layer_norm(hidden_state_new) + hidden_state_old = text_encoder_1.model.final_layer_norm(hidden_state_old) + + hidden_state_new, attention_mask_new = crop_hidden_state_and_attn_mask( 
+ hidden_state_new, + text_inputs_new["attention_mask"], + text_crop_start, + text_crop_start, + assistant_crop_start_new, + assistant_crop_end_new, + attention_mask_assistant_crop_start_new, + attention_mask_assistant_crop_end_new, + ) + hidden_state_old, attention_mask_old = crop_hidden_state_and_attn_mask( + hidden_state_old, + text_inputs_old["attention_mask"], + text_crop_start, + crop_start, + assistant_crop_start_old, + assistant_crop_end_old, + attention_mask_assistant_crop_start_old, + attention_mask_assistant_crop_end_old, + ) + + assert ( + hidden_state_new.shape == hidden_state_old.shape + ), f"hidden state shape is not the same: {hidden_state_new.shape} vs {hidden_state_old.shape}" + assert ( + hidden_state_new.dtype == hidden_state_old.dtype + ), f"hidden state dtype is not the same: {hidden_state_new.dtype} vs {hidden_state_old.dtype}" + print(f"hidden state shape: {hidden_state_new.shape}") + + diff = (hidden_state_new - hidden_state_old).abs() + print(f"hidden state diff: {diff.max()}, {diff.mean()}, {diff.std()}") + print(hidden_state_new[-20:, 0]) + print(hidden_state_old[-20:, 0]) + print(diff[-20:, 0]) + + assert ( + attention_mask_new.shape == attention_mask_old.shape + ), f"attention mask shape is not the same: {attention_mask_new.shape} vs {attention_mask_old.shape}" + assert ( + attention_mask_new.dtype == attention_mask_old.dtype + ), f"attention mask dtype is not the same: {attention_mask_new.dtype} vs {attention_mask_old.dtype}" + print(f"attention mask shape: {attention_mask_new.shape}") + assert torch.allclose( + attention_mask_new, attention_mask_old + ), f"attention mask is not the same. diff: {(attention_mask_new - attention_mask_old).abs().max()}" + + print(f"final attention mask: {attention_mask_new}") + + # assert torch.allclose(hidden_state_new, hidden_state_old), f"hidden state is not the same. 
diff: {diff}" + + import numpy as np + + diff = diff.float().cpu().numpy() # (934, 4096) + diff = diff.mean(axis=1) + + attn_mask_np = attention_mask_new.float().cpu().numpy() + diff = diff * attn_mask_np + + assert diff.max() < 1e-3, f"hidden state diff is too large: {diff.max()}" + + # # show as bar plot + # import matplotlib.pyplot as plt + + # plt.bar(range(diff.shape[0]), diff) + # plt.title("Hidden state diff") + # plt.xlabel("Hidden state index") + # plt.ylabel("Diff") + # plt.show() diff --git a/hunyuan_model/token_refiner.py b/hunyuan_model/token_refiner.py new file mode 100644 index 0000000000000000000000000000000000000000..378bbab7d5b5483f552bc37699650506dc6f790c --- /dev/null +++ b/hunyuan_model/token_refiner.py @@ -0,0 +1,245 @@ +from typing import Optional + +from einops import rearrange +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint + +from .activation_layers import get_activation_layer +from .attention import attention +from .norm_layers import get_norm_layer +from .embed_layers import TimestepEmbedder, TextProjection +from .mlp_layers import MLP +from .modulate_layers import modulate, apply_gate + + +class IndividualTokenRefinerBlock(nn.Module): + def __init__( + self, + hidden_size, + heads_num, + mlp_width_ratio: str = 4.0, + mlp_drop_rate: float = 0.0, + act_type: str = "silu", + qk_norm: bool = False, + qk_norm_type: str = "layer", + qkv_bias: bool = True, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.heads_num = heads_num + head_dim = hidden_size // heads_num + mlp_hidden_dim = int(hidden_size * mlp_width_ratio) + + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs) + self.self_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs) + qk_norm_layer = get_norm_layer(qk_norm_type) + self.self_attn_q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.self_attn_k_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.self_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs) + + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs) + act_layer = get_activation_layer(act_type) + self.mlp = MLP( + in_channels=hidden_size, + hidden_channels=mlp_hidden_dim, + act_layer=act_layer, + drop=mlp_drop_rate, + **factory_kwargs, + ) + + self.adaLN_modulation = nn.Sequential( + act_layer(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs), + ) + # Zero-initialize the modulation + nn.init.zeros_(self.adaLN_modulation[1].weight) + nn.init.zeros_(self.adaLN_modulation[1].bias) + + self.gradient_checkpointing = False + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + + def _forward( + self, + x: torch.Tensor, + c: torch.Tensor, # timestep_aware_representations + context_aware_representations + attn_mask: torch.Tensor = None, + ): + gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1) + + norm_x = self.norm1(x) + qkv = self.self_attn_qkv(norm_x) + q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) + # Apply QK-Norm if needed + q = self.self_attn_q_norm(q).to(v) + k = 
self.self_attn_k_norm(k).to(v) + + # Self-Attention + attn = attention(q, k, v, mode="torch", attn_mask=attn_mask) + + x = x + apply_gate(self.self_attn_proj(attn), gate_msa) + + # FFN Layer + x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp) + + return x + + def forward(self, *args, **kwargs): + if self.training and self.gradient_checkpointing: + return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) + else: + return self._forward(*args, **kwargs) + + +class IndividualTokenRefiner(nn.Module): + def __init__( + self, + hidden_size, + heads_num, + depth, + mlp_width_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + act_type: str = "silu", + qk_norm: bool = False, + qk_norm_type: str = "layer", + qkv_bias: bool = True, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.blocks = nn.ModuleList( + [ + IndividualTokenRefinerBlock( + hidden_size=hidden_size, + heads_num=heads_num, + mlp_width_ratio=mlp_width_ratio, + mlp_drop_rate=mlp_drop_rate, + act_type=act_type, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + qkv_bias=qkv_bias, + **factory_kwargs, + ) + for _ in range(depth) + ] + ) + + def enable_gradient_checkpointing(self): + for block in self.blocks: + block.enable_gradient_checkpointing() + + def disable_gradient_checkpointing(self): + for block in self.blocks: + block.disable_gradient_checkpointing() + + def forward( + self, + x: torch.Tensor, + c: torch.LongTensor, + mask: Optional[torch.Tensor] = None, + ): + self_attn_mask = None + if mask is not None: + batch_size = mask.shape[0] + seq_len = mask.shape[1] + mask = mask.to(x.device) + # batch_size x 1 x seq_len x seq_len + self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1) + # batch_size x 1 x seq_len x seq_len + self_attn_mask_2 = self_attn_mask_1.transpose(2, 3) + # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num + self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool() + # avoids self-attention weight being NaN for padding tokens + self_attn_mask[:, :, :, 0] = True + + for block in self.blocks: + x = block(x, c, self_attn_mask) + return x + + +class SingleTokenRefiner(nn.Module): + """ + A single token refiner block for llm text embedding refine. + """ + + def __init__( + self, + in_channels, + hidden_size, + heads_num, + depth, + mlp_width_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + act_type: str = "silu", + qk_norm: bool = False, + qk_norm_type: str = "layer", + qkv_bias: bool = True, + attn_mode: str = "torch", + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.attn_mode = attn_mode + assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner." 
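+ # Overview of forward() below: c is the sum of the timestep embedding and the c_embedder projection
+ # of the (mask-weighted) mean of the text features; the text features themselves are projected by
+ # input_embedder and then refined by `depth` IndividualTokenRefinerBlock layers.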
+ + self.input_embedder = nn.Linear(in_channels, hidden_size, bias=True, **factory_kwargs) + + act_layer = get_activation_layer(act_type) + # Build timestep embedding layer + self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs) + # Build context embedding layer + self.c_embedder = TextProjection(in_channels, hidden_size, act_layer, **factory_kwargs) + + self.individual_token_refiner = IndividualTokenRefiner( + hidden_size=hidden_size, + heads_num=heads_num, + depth=depth, + mlp_width_ratio=mlp_width_ratio, + mlp_drop_rate=mlp_drop_rate, + act_type=act_type, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + qkv_bias=qkv_bias, + **factory_kwargs, + ) + + def enable_gradient_checkpointing(self): + self.individual_token_refiner.enable_gradient_checkpointing() + + def disable_gradient_checkpointing(self): + self.individual_token_refiner.disable_gradient_checkpointing() + + def forward( + self, + x: torch.Tensor, + t: torch.LongTensor, + mask: Optional[torch.LongTensor] = None, + ): + timestep_aware_representations = self.t_embedder(t) + + if mask is None: + context_aware_representations = x.mean(dim=1) + else: + mask_float = mask.float().unsqueeze(-1) # [b, s1, 1] + context_aware_representations = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1) + context_aware_representations = self.c_embedder(context_aware_representations) + c = timestep_aware_representations + context_aware_representations + + x = self.input_embedder(x) + + x = self.individual_token_refiner(x, c, mask) + + return x diff --git a/hunyuan_model/vae.py b/hunyuan_model/vae.py new file mode 100644 index 0000000000000000000000000000000000000000..2ae718a5634e98e53a0c0dec85254228229a01c3 --- /dev/null +++ b/hunyuan_model/vae.py @@ -0,0 +1,446 @@ +from dataclasses import dataclass +import json +from typing import Optional, Tuple, Union +from pathlib import Path + +import numpy as np +import torch +import torch.nn as nn + +from diffusers.utils import BaseOutput, is_torch_version +from diffusers.utils.torch_utils import randn_tensor +from diffusers.models.attention_processor import SpatialNorm +from modules.unet_causal_3d_blocks import CausalConv3d, UNetMidBlockCausal3D, get_down_block3d, get_up_block3d + +import logging + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +SCALING_FACTOR = 0.476986 +VAE_VER = "884-16c-hy" # We don't support other versions currently + + +def load_vae( + vae_type: str = "884-16c-hy", + vae_dtype: Optional[Union[str, torch.dtype]] = None, + sample_size: tuple = None, + vae_path: str = None, + device=None, +): + """The function to load the 3D VAE model + + Args: + vae_type (str): the type of the 3D VAE model. Defaults to "884-16c-hy". + vae_dtype (str or torch.dtype, optional): the dtype to load the VAE in. Defaults to None. + sample_size (tuple, optional): the tiling size. Defaults to None. + vae_path (str, optional): the path to the VAE checkpoint. Defaults to None. + device (torch.device, optional): the device to load the VAE on. Defaults to None. 
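+
+ Returns:
+ tuple: (vae, vae_path, spatial_compression_ratio, time_compression_ratio)
+
+ Example (illustrative only; the path is a placeholder):
+ vae, _, s_ratio, t_ratio = load_vae(vae_dtype=torch.float16, vae_path="path/to/vae.safetensors", device="cuda")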
+ """ + if vae_path is None: + vae_path = VAE_PATH[vae_type] + + logger.info(f"Loading 3D VAE model ({vae_type}) from: {vae_path}") + + # use fixed config for Hunyuan's VAE + CONFIG_JSON = """{ + "_class_name": "AutoencoderKLCausal3D", + "_diffusers_version": "0.4.2", + "act_fn": "silu", + "block_out_channels": [ + 128, + 256, + 512, + 512 + ], + "down_block_types": [ + "DownEncoderBlockCausal3D", + "DownEncoderBlockCausal3D", + "DownEncoderBlockCausal3D", + "DownEncoderBlockCausal3D" + ], + "in_channels": 3, + "latent_channels": 16, + "layers_per_block": 2, + "norm_num_groups": 32, + "out_channels": 3, + "sample_size": 256, + "sample_tsize": 64, + "up_block_types": [ + "UpDecoderBlockCausal3D", + "UpDecoderBlockCausal3D", + "UpDecoderBlockCausal3D", + "UpDecoderBlockCausal3D" + ], + "scaling_factor": 0.476986, + "time_compression_ratio": 4, + "mid_block_add_attention": true + }""" + + # config = AutoencoderKLCausal3D.load_config(vae_path) + config = json.loads(CONFIG_JSON) + + # import here to avoid circular import + from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D + + if sample_size: + vae = AutoencoderKLCausal3D.from_config(config, sample_size=sample_size) + else: + vae = AutoencoderKLCausal3D.from_config(config) + + # vae_ckpt = Path(vae_path) / "pytorch_model.pt" + # assert vae_ckpt.exists(), f"VAE checkpoint not found: {vae_ckpt}" + + if vae_path.endswith(".safetensors"): + from safetensors.torch import load_file + ckpt = load_file(vae_path) + else: + ckpt = torch.load(vae_path, map_location=vae.device, weights_only=True) + if "state_dict" in ckpt: + ckpt = ckpt["state_dict"] + if any(k.startswith("vae.") for k in ckpt.keys()): + ckpt = {k.replace("vae.", ""): v for k, v in ckpt.items() if k.startswith("vae.")} + vae.load_state_dict(ckpt) + + spatial_compression_ratio = vae.config.spatial_compression_ratio + time_compression_ratio = vae.config.time_compression_ratio + + if vae_dtype is not None: + vae = vae.to(vae_dtype) + + vae.requires_grad_(False) + + logger.info(f"VAE to dtype: {vae.dtype}") + + if device is not None: + vae = vae.to(device) + + vae.eval() + + return vae, vae_path, spatial_compression_ratio, time_compression_ratio + + +@dataclass +class DecoderOutput(BaseOutput): + r""" + Output of decoding method. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + The decoded output sample from the last layer of the model. + """ + + sample: torch.FloatTensor + + +class EncoderCausal3D(nn.Module): + r""" + The `EncoderCausal3D` layer of a variational autoencoder that encodes its input into a latent representation. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D",), + block_out_channels: Tuple[int, ...] 
= (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + double_z: bool = True, + mid_block_add_attention=True, + time_compression_ratio: int = 4, + spatial_compression_ratio: int = 8, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1) + self.mid_block = None + self.down_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio)) + num_time_downsample_layers = int(np.log2(time_compression_ratio)) + + if time_compression_ratio == 4: + add_spatial_downsample = bool(i < num_spatial_downsample_layers) + add_time_downsample = bool(i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block) + else: + raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.") + + downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1) + downsample_stride_T = (2,) if add_time_downsample else (1,) + downsample_stride = tuple(downsample_stride_T + downsample_stride_HW) + down_block = get_down_block3d( + down_block_type, + num_layers=self.layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=bool(add_spatial_downsample or add_time_downsample), + downsample_stride=downsample_stride, + resnet_eps=1e-6, + downsample_padding=0, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attention_head_dim=output_channel, + temb_channels=None, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlockCausal3D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=None, + add_attention=mid_block_add_attention, + ) + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + + conv_out_channels = 2 * out_channels if double_z else out_channels + self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3) + + def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor: + r"""The forward method of the `EncoderCausal3D` class.""" + assert len(sample.shape) == 5, "The input tensor should have 5 dimensions" + + sample = self.conv_in(sample) + + # down + for down_block in self.down_blocks: + sample = down_block(sample) + + # middle + sample = self.mid_block(sample) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + + +class DecoderCausal3D(nn.Module): + r""" + The `DecoderCausal3D` layer of a variational autoencoder that decodes its latent representation into an output sample. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D",), + block_out_channels: Tuple[int, ...] 
= (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + norm_type: str = "group", # group, spatial + mid_block_add_attention=True, + time_compression_ratio: int = 4, + spatial_compression_ratio: int = 8, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + temb_channels = in_channels if norm_type == "spatial" else None + + # mid + self.mid_block = UNetMidBlockCausal3D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default" if norm_type == "group" else norm_type, + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=temb_channels, + add_attention=mid_block_add_attention, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio)) + num_time_upsample_layers = int(np.log2(time_compression_ratio)) + + if time_compression_ratio == 4: + add_spatial_upsample = bool(i < num_spatial_upsample_layers) + add_time_upsample = bool(i >= len(block_out_channels) - 1 - num_time_upsample_layers and not is_final_block) + else: + raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.") + + upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1) + upsample_scale_factor_T = (2,) if add_time_upsample else (1,) + upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW) + up_block = get_up_block3d( + up_block_type, + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + prev_output_channel=None, + add_upsample=bool(add_spatial_upsample or add_time_upsample), + upsample_scale_factor=upsample_scale_factor, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attention_head_dim=output_channel, + temb_channels=temb_channels, + resnet_time_scale_shift=norm_type, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_type == "spatial": + self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) + else: + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3) + + self.gradient_checkpointing = False + + def forward( + self, + sample: torch.FloatTensor, + latent_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + r"""The forward method of the `DecoderCausal3D` class.""" + assert len(sample.shape) == 5, "The input tensor should have 5 dimensions." 
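+ # The latent `sample` is laid out as (B, C, T, H, W); it passes through conv_in, the mid block,
+ # then the up blocks (optionally under gradient checkpointing while training), and finally the
+ # norm / activation / conv_out post-processing below.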
+ + sample = self.conv_in(sample) + + upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + # middle + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), + sample, + latent_embeds, + use_reentrant=False, + ) + sample = sample.to(upscale_dtype) + + # up + for up_block in self.up_blocks: + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(up_block), + sample, + latent_embeds, + use_reentrant=False, + ) + else: + # middle + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample, latent_embeds) + sample = sample.to(upscale_dtype) + + # up + for up_block in self.up_blocks: + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds) + else: + # middle + sample = self.mid_block(sample, latent_embeds) + sample = sample.to(upscale_dtype) + + # up + for up_block in self.up_blocks: + sample = up_block(sample, latent_embeds) + + # post-process + if latent_embeds is None: + sample = self.conv_norm_out(sample) + else: + sample = self.conv_norm_out(sample, latent_embeds) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters: torch.Tensor, deterministic: bool = False): + if parameters.ndim == 3: + dim = 2 # (B, L, C) + elif parameters.ndim == 5 or parameters.ndim == 4: + dim = 1 # (B, C, T, H ,W) / (B, C, H, W) + else: + raise NotImplementedError + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device, dtype=self.parameters.dtype) + + def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor: + # make sure sample is on the same device as the parameters and has same dtype + sample = randn_tensor( + self.mean.shape, + generator=generator, + device=self.parameters.device, + dtype=self.parameters.dtype, + ) + x = self.mean + self.std * sample + return x + + def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor: + if self.deterministic: + return torch.Tensor([0.0]) + else: + reduce_dim = list(range(1, self.mean.ndim)) + if other is None: + return 0.5 * torch.sum( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=reduce_dim, + ) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=reduce_dim, + ) + + def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] 
= [1, 2, 3]) -> torch.Tensor: + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self) -> torch.Tensor: + return self.mean diff --git a/hv_generate_video.py b/hv_generate_video.py new file mode 100644 index 0000000000000000000000000000000000000000..cb39def339b5a1c790dcba8b3e501f540c78e4a1 --- /dev/null +++ b/hv_generate_video.py @@ -0,0 +1,956 @@ +import argparse +from datetime import datetime +from pathlib import Path +import random +import sys +import os +import time +from typing import Optional, Union + +import numpy as np +import torch +import torchvision +import accelerate +from diffusers.utils.torch_utils import randn_tensor +from transformers.models.llama import LlamaModel +from tqdm import tqdm +import av +from einops import rearrange +from safetensors.torch import load_file, save_file +from safetensors import safe_open +from PIL import Image + +from hunyuan_model import vae +from hunyuan_model.text_encoder import TextEncoder +from hunyuan_model.text_encoder import PROMPT_TEMPLATE +from hunyuan_model.vae import load_vae +from hunyuan_model.models import load_transformer, get_rotary_pos_embed +from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler +from networks import lora + +try: + from lycoris.kohya import create_network_from_weights +except: + pass + +from utils.model_utils import str_to_dtype +from utils.safetensors_utils import mem_eff_save_file +from dataset.image_video_dataset import load_video, glob_images, resize_image_to_bucket +import math +import logging + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +def clean_memory_on_device(device): + if device.type == "cuda": + torch.cuda.empty_cache() + elif device.type == "cpu": + pass + elif device.type == "mps": # not tested + torch.mps.empty_cache() + + +def synchronize_device(device: torch.device): + if device.type == "cuda": + torch.cuda.synchronize() + elif device.type == "xpu": + torch.xpu.synchronize() + elif device.type == "mps": + torch.mps.synchronize() + + +def extend_video_frames(video: torch.Tensor, target_frames: int) -> torch.Tensor: + current_frames = video.shape[2] + if current_frames >= target_frames: + return video + + base_repeats = target_frames // current_frames + extra = target_frames % current_frames # Remaining repeats to distribute + + # Create repeat tensor with partial repetition for even distribution + repeats = torch.full((current_frames,), base_repeats, + dtype=torch.int64, device=video.device) + repeats[:extra] += 1 # Distribute extra repeats to early frames + + # Create interleaved index pattern (e.g., 001122.. instead of 012012) + indices = torch.arange(current_frames, device=video.device) + indices = indices.repeat_interleave(repeats) + + extended_video = torch.index_select(video, 2, indices) + return extended_video + +def load_and_extend_video(args, video_length: int): + """ + Load video and extend it if needed to match target length. + """ + if os.path.isfile(args.video_path): + video = load_video(args.video_path, 0, video_length, bucket_reso=(args.video_size[1], args.video_size[0])) + else: + video = load_images(args.video_path, video_length, bucket_reso=(args.video_size[1], args.video_size[0])) + + if len(video) < video_length: + logger.info(f"Video length ({len(video)}) is less than target length ({video_length}). 
Extending video...") + # Convert list of frames to tensor + video_tensor = torch.from_numpy(np.stack(video, axis=0)) # [F, H, W, C] + video_tensor = video_tensor.permute(3, 0, 1, 2).unsqueeze(0) # [1, C, F, H, W] + + # Extend the video + extended_tensor = extend_video_frames(video_tensor, video_length) + + # Convert back to list of frames + extended_tensor = extended_tensor.squeeze(0).permute(1, 2, 3, 0) # [F, H, W, C] + video = [frame.numpy() for frame in extended_tensor] + + logger.info(f"Video extended to {len(video)} frames") + + video = np.stack(video, axis=0) # F, H, W, C + video = torch.from_numpy(video).permute(3, 0, 1, 2).unsqueeze(0).float() # 1, C, F, H, W + video = video / 255.0 + + return video + +def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=1, fps=24): + """save videos by video tensor + copy from https://github.com/guoyww/AnimateDiff/blob/e92bd5671ba62c0d774a32951453e328018b7c5b/animatediff/utils/util.py#L61 + + Args: + videos (torch.Tensor): video tensor predicted by the model + path (str): path to save video + rescale (bool, optional): rescale the video tensor from [-1, 1] to . Defaults to False. + n_rows (int, optional): Defaults to 1. + fps (int, optional): video save fps. Defaults to 8. + """ + videos = rearrange(videos, "b c t h w -> t b c h w") + outputs = [] + for x in videos: + x = torchvision.utils.make_grid(x, nrow=n_rows) + x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) + if rescale: + x = (x + 1.0) / 2.0 # -1,1 -> 0,1 + x = torch.clamp(x, 0, 1) + x = (x * 255).numpy().astype(np.uint8) + outputs.append(x) + + os.makedirs(os.path.dirname(path), exist_ok=True) + + # # save video with av + # container = av.open(path, "w") + # stream = container.add_stream("libx264", rate=fps) + # for x in outputs: + # frame = av.VideoFrame.from_ndarray(x, format="rgb24") + # packet = stream.encode(frame) + # container.mux(packet) + # packet = stream.encode(None) + # container.mux(packet) + # container.close() + + height, width, _ = outputs[0].shape + + # create output container + container = av.open(path, mode="w") + + # create video stream + codec = "libx264" + pixel_format = "yuv420p" + stream = container.add_stream(codec, rate=fps) + stream.width = width + stream.height = height + stream.pix_fmt = pixel_format + stream.bit_rate = 4000000 # 4Mbit/s + + for frame_array in outputs: + frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24") + packets = stream.encode(frame) + for packet in packets: + container.mux(packet) + + for packet in stream.encode(): + container.mux(packet) + + container.close() + + +def save_images_grid( + videos: torch.Tensor, parent_dir: str, image_name: str, rescale: bool = False, n_rows: int = 1, create_subdir=True +): + videos = rearrange(videos, "b c t h w -> t b c h w") + outputs = [] + for x in videos: + x = torchvision.utils.make_grid(x, nrow=n_rows) + x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) + if rescale: + x = (x + 1.0) / 2.0 # -1,1 -> 0,1 + x = torch.clamp(x, 0, 1) + x = (x * 255).numpy().astype(np.uint8) + outputs.append(x) + + if create_subdir: + output_dir = os.path.join(parent_dir, image_name) + else: + output_dir = parent_dir + + os.makedirs(output_dir, exist_ok=True) + for i, x in enumerate(outputs): + image_path = os.path.join(output_dir, f"{image_name}_{i:03d}.png") + image = Image.fromarray(x) + image.save(image_path) + + +# region Encoding prompt + + +def encode_prompt(prompt: Union[str, list[str]], device: torch.device, num_videos_per_prompt: int, text_encoder: TextEncoder): + r""" + 
Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`): + prompt to be encoded + device: (`torch.device`): + torch device + num_videos_per_prompt (`int`): + number of videos that should be generated per prompt + text_encoder (TextEncoder): + text encoder to be used for encoding the prompt + """ + # LoRA and Textual Inversion are not supported in this script + # negative prompt and prompt embedding are not supported in this script + # clip_skip is not supported in this script because it is not used in the original script + data_type = "video" # video only, image is not supported + + text_inputs = text_encoder.text2tokens(prompt, data_type=data_type) + + with torch.no_grad(): + prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type, device=device) + prompt_embeds = prompt_outputs.hidden_state + + attention_mask = prompt_outputs.attention_mask + if attention_mask is not None: + attention_mask = attention_mask.to(device) + bs_embed, seq_len = attention_mask.shape + attention_mask = attention_mask.repeat(1, num_videos_per_prompt) + attention_mask = attention_mask.view(bs_embed * num_videos_per_prompt, seq_len) + + prompt_embeds_dtype = text_encoder.dtype + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + if prompt_embeds.ndim == 2: + bs_embed, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt) + prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, -1) + else: + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds, attention_mask + + +def encode_input_prompt(prompt: Union[str, list[str]], args, device, fp8_llm=False, accelerator=None): + # constants + prompt_template_video = "dit-llm-encode-video" + prompt_template = "dit-llm-encode" + text_encoder_dtype = torch.float16 + text_encoder_type = "llm" + text_len = 256 + hidden_state_skip_layer = 2 + apply_final_norm = False + reproduce = False + + text_encoder_2_type = "clipL" + text_len_2 = 77 + + num_videos = 1 + + # if args.prompt_template_video is not None: + # crop_start = PROMPT_TEMPLATE[args.prompt_template_video].get("crop_start", 0) + # elif args.prompt_template is not None: + # crop_start = PROMPT_TEMPLATE[args.prompt_template].get("crop_start", 0) + # else: + # crop_start = 0 + crop_start = PROMPT_TEMPLATE[prompt_template_video].get("crop_start", 0) + max_length = text_len + crop_start + + # prompt_template + prompt_template = PROMPT_TEMPLATE[prompt_template] + + # prompt_template_video + prompt_template_video = PROMPT_TEMPLATE[prompt_template_video] # if args.prompt_template_video is not None else None + + # load text encoders + logger.info(f"loading text encoder: {args.text_encoder1}") + text_encoder = TextEncoder( + text_encoder_type=text_encoder_type, + max_length=max_length, + text_encoder_dtype=text_encoder_dtype, + text_encoder_path=args.text_encoder1, + tokenizer_type=text_encoder_type, + prompt_template=prompt_template, + prompt_template_video=prompt_template_video, + hidden_state_skip_layer=hidden_state_skip_layer, + apply_final_norm=apply_final_norm, + reproduce=reproduce, + ) + text_encoder.eval() + if fp8_llm: + org_dtype = text_encoder.dtype + 
logger.info(f"Moving and casting text encoder to {device} and torch.float8_e4m3fn") + text_encoder.to(device=device, dtype=torch.float8_e4m3fn) + + # prepare LLM for fp8 + def prepare_fp8(llama_model: LlamaModel, target_dtype): + def forward_hook(module): + def forward(hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + module.variance_epsilon) + return module.weight.to(input_dtype) * hidden_states.to(input_dtype) + + return forward + + for module in llama_model.modules(): + if module.__class__.__name__ in ["Embedding"]: + # print("set", module.__class__.__name__, "to", target_dtype) + module.to(target_dtype) + if module.__class__.__name__ in ["LlamaRMSNorm"]: + # print("set", module.__class__.__name__, "hooks") + module.forward = forward_hook(module) + + prepare_fp8(text_encoder.model, org_dtype) + + logger.info(f"loading text encoder 2: {args.text_encoder2}") + text_encoder_2 = TextEncoder( + text_encoder_type=text_encoder_2_type, + max_length=text_len_2, + text_encoder_dtype=text_encoder_dtype, + text_encoder_path=args.text_encoder2, + tokenizer_type=text_encoder_2_type, + reproduce=reproduce, + ) + text_encoder_2.eval() + + # encode prompt + logger.info(f"Encoding prompt with text encoder 1") + text_encoder.to(device=device) + if fp8_llm: + with accelerator.autocast(): + prompt_embeds, prompt_mask = encode_prompt(prompt, device, num_videos, text_encoder) + else: + prompt_embeds, prompt_mask = encode_prompt(prompt, device, num_videos, text_encoder) + text_encoder = None + clean_memory_on_device(device) + + logger.info(f"Encoding prompt with text encoder 2") + text_encoder_2.to(device=device) + prompt_embeds_2, prompt_mask_2 = encode_prompt(prompt, device, num_videos, text_encoder_2) + + prompt_embeds = prompt_embeds.to("cpu") + prompt_mask = prompt_mask.to("cpu") + prompt_embeds_2 = prompt_embeds_2.to("cpu") + prompt_mask_2 = prompt_mask_2.to("cpu") + + text_encoder_2 = None + clean_memory_on_device(device) + + return prompt_embeds, prompt_mask, prompt_embeds_2, prompt_mask_2 + + +# endregion + + +def load_images(image_dir, video_length, bucket_reso): + image_files = glob_images(image_dir) + if len(image_files) == 0: + raise ValueError(f"No image files found in {image_dir}") + if len(image_files) < video_length: + raise ValueError(f"Number of images in {image_dir} is less than {video_length}") + + image_files.sort() + images = [] + for image_file in image_files[:video_length]: + image = Image.open(image_file) + image = resize_image_to_bucket(image, bucket_reso) # returns a numpy array + images.append(image) + + return images + + +def prepare_vae(args, device): + vae_dtype = torch.float16 if args.vae_dtype is None else str_to_dtype(args.vae_dtype) + vae, _, s_ratio, t_ratio = load_vae(vae_dtype=vae_dtype, device=device, vae_path=args.vae) + vae.eval() + # vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio} + + # set chunk_size to CausalConv3d recursively + chunk_size = args.vae_chunk_size + if chunk_size is not None: + vae.set_chunk_size_for_causal_conv_3d(chunk_size) + logger.info(f"Set chunk_size to {chunk_size} for CausalConv3d") + + if args.vae_spatial_tile_sample_min_size is not None: + vae.enable_spatial_tiling(True) + vae.tile_sample_min_size = args.vae_spatial_tile_sample_min_size + vae.tile_latent_min_size = args.vae_spatial_tile_sample_min_size // 8 + # elif args.vae_tiling: + else: + 
vae.enable_spatial_tiling(True) + + return vae, vae_dtype + + +def encode_to_latents(args, video, device): + vae, vae_dtype = prepare_vae(args, device) + + video = video.to(device=device, dtype=vae_dtype) + video = video * 2 - 1 # 0, 1 -> -1, 1 + with torch.no_grad(): + latents = vae.encode(video).latent_dist.sample() + + if hasattr(vae.config, "shift_factor") and vae.config.shift_factor: + latents = (latents - vae.config.shift_factor) * vae.config.scaling_factor + else: + latents = latents * vae.config.scaling_factor + + return latents + + +def decode_latents(args, latents, device): + vae, vae_dtype = prepare_vae(args, device) + + expand_temporal_dim = False + if len(latents.shape) == 4: + latents = latents.unsqueeze(2) + expand_temporal_dim = True + elif len(latents.shape) == 5: + pass + else: + raise ValueError(f"Only support latents with shape (b, c, h, w) or (b, c, f, h, w), but got {latents.shape}.") + + if hasattr(vae.config, "shift_factor") and vae.config.shift_factor: + latents = latents / vae.config.scaling_factor + vae.config.shift_factor + else: + latents = latents / vae.config.scaling_factor + + latents = latents.to(device=device, dtype=vae_dtype) + with torch.no_grad(): + image = vae.decode(latents, return_dict=False)[0] + + if expand_temporal_dim: + image = image.squeeze(2) + + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().float() + + return image + + +def parse_args(): + parser = argparse.ArgumentParser(description="HunyuanVideo inference script") + + parser.add_argument("--dit", type=str, required=True, help="DiT checkpoint path or directory") + parser.add_argument( + "--dit_in_channels", + type=int, + default=None, + help="input channels for DiT, default is None (automatically detect). 32 for SkyReels-I2V, 16 for others", + ) + parser.add_argument("--vae", type=str, required=True, help="VAE checkpoint path or directory") + parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is float16") + parser.add_argument("--text_encoder1", type=str, required=True, help="Text Encoder 1 directory") + parser.add_argument("--text_encoder2", type=str, required=True, help="Text Encoder 2 directory") + + # LoRA + parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path") + parser.add_argument("--lora_multiplier", type=float, nargs="*", default=1.0, help="LoRA multiplier") + parser.add_argument( + "--save_merged_model", + type=str, + default=None, + help="Save merged model to path. 
If specified, no inference will be performed.", + ) + parser.add_argument("--exclude_single_blocks", action="store_true", help="Exclude single blocks when loading LoRA weights") + + # inference + parser.add_argument("--prompt", type=str, required=True, help="prompt for generation") + parser.add_argument("--negative_prompt", type=str, default=None, help="negative prompt for generation") + parser.add_argument("--video_size", type=int, nargs=2, default=[256, 256], help="video size") + parser.add_argument("--video_length", type=int, default=129, help="video length") + parser.add_argument("--fps", type=int, default=24, help="video fps") + parser.add_argument("--infer_steps", type=int, default=50, help="number of inference steps") + parser.add_argument("--save_path", type=str, required=True, help="path to save generated video") + parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.") + parser.add_argument( + "--guidance_scale", + type=float, + default=1.0, + help="Guidance scale for classifier free guidance. Default is 1.0 (means no guidance)", + ) + parser.add_argument("--embedded_cfg_scale", type=float, default=6.0, help="Embeded classifier free guidance scale.") + parser.add_argument("--video_path", type=str, default=None, help="path to video for video2video inference") + parser.add_argument( + "--image_path", type=str, default=None, help="path to image for image2video inference, only works for SkyReels-I2V model" + ) + parser.add_argument( + "--split_uncond", + action="store_true", + help="split unconditional call for classifier free guidance, slower but less memory usage", + ) + parser.add_argument("--strength", type=float, default=0.8, help="strength for video2video inference") + + # Flow Matching + parser.add_argument("--flow_shift", type=float, default=7.0, help="Shift factor for flow matching schedulers.") + + parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model") + parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)") + parser.add_argument( + "--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU" + ) + parser.add_argument( + "--attn_mode", type=str, default="torch", choices=["flash", "torch", "sageattn", "xformers", "sdpa"], help="attention mode" + ) + parser.add_argument( + "--split_attn", action="store_true", help="use split attention, default is False. if True, --split_uncond becomes True" + ) + parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE") + parser.add_argument( + "--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256" + ) + parser.add_argument("--blocks_to_swap", type=int, default=None, help="number of blocks to swap in the model") + parser.add_argument("--img_in_txt_in_offloading", action="store_true", help="offload img_in and txt_in to cpu") + parser.add_argument( + "--output_type", type=str, default="video", choices=["video", "images", "latent", "both"], help="output type" + ) + parser.add_argument("--no_metadata", action="store_true", help="do not save metadata") + parser.add_argument("--latent_path", type=str, nargs="*", default=None, help="path to latent for decode. 
no inference") + parser.add_argument("--lycoris", action="store_true", help="use lycoris for inference") + + args = parser.parse_args() + + assert (args.latent_path is None or len(args.latent_path) == 0) or ( + args.output_type == "images" or args.output_type == "video" + ), "latent_path is only supported for images or video output" + + # update dit_weight based on model_base if not exists + + return args + + +def check_inputs(args): + height = args.video_size[0] + width = args.video_size[1] + video_length = args.video_length + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + return height, width, video_length + + +def main(): + args = parse_args() + + device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device(device) + dit_dtype = torch.bfloat16 + dit_weight_dtype = torch.float8_e4m3fn if args.fp8 else dit_dtype + logger.info(f"Using device: {device}, DiT precision: {dit_dtype}, weight precision: {dit_weight_dtype}") + + original_base_names = None + if args.latent_path is not None and len(args.latent_path) > 0: + original_base_names = [] + latents_list = [] + seeds = [] + for latent_path in args.latent_path: + original_base_names.append(os.path.splitext(os.path.basename(latent_path))[0]) + seed = 0 + + if os.path.splitext(latent_path)[1] != ".safetensors": + latents = torch.load(latent_path, map_location="cpu") + else: + latents = load_file(latent_path)["latent"] + with safe_open(latent_path, framework="pt") as f: + metadata = f.metadata() + if metadata is None: + metadata = {} + logger.info(f"Loaded metadata: {metadata}") + + if "seeds" in metadata: + seed = int(metadata["seeds"]) + + seeds.append(seed) + latents_list.append(latents) + + logger.info(f"Loaded latent from {latent_path}. 
Shape: {latents.shape}") + latents = torch.stack(latents_list, dim=0) + else: + # prepare accelerator + mixed_precision = "bf16" if dit_dtype == torch.bfloat16 else "fp16" + accelerator = accelerate.Accelerator(mixed_precision=mixed_precision) + + # load prompt + prompt = args.prompt # TODO load prompts from file + assert prompt is not None, "prompt is required" + + # check inputs: may be height, width, video_length etc will be changed for each generation in future + height, width, video_length = check_inputs(args) + + # encode prompt with LLM and Text Encoder + logger.info(f"Encoding prompt: {prompt}") + + do_classifier_free_guidance = args.guidance_scale != 1.0 + if do_classifier_free_guidance: + negative_prompt = args.negative_prompt + if negative_prompt is None: + logger.info("Negative prompt is not provided, using empty prompt") + negative_prompt = "" + logger.info(f"Encoding negative prompt: {negative_prompt}") + prompt = [negative_prompt, prompt] + else: + if args.negative_prompt is not None: + logger.warning("Negative prompt is provided but guidance_scale is 1.0, negative prompt will be ignored.") + + prompt_embeds, prompt_mask, prompt_embeds_2, prompt_mask_2 = encode_input_prompt( + prompt, args, device, args.fp8_llm, accelerator + ) + + video_latents = None + if args.video_path is not None: + # v2v inference + logger.info(f"Video2Video inference: {args.video_path}") + video = load_and_extend_video(args, video_length) + + if not isinstance(video, torch.Tensor): + video = np.stack(video, axis=0) # F, H, W, C + video = torch.from_numpy(video).permute(3, 0, 1, 2).unsqueeze(0).float() # 1, C, F, H, W + video = video / 255.0 + + # Verify the video length after extension + if video.shape[2] < video_length: # Check tensor dimension F + raise ValueError(f"Video length {video.shape[2]} is less than target length {video_length} after extension") + + logger.info(f"Encoding video to latents") + video_latents = encode_to_latents(args, video, device) + video_latents = video_latents.to(device=device, dtype=dit_dtype) + + clean_memory_on_device(device) + + # encode latents for image2video inference + image_latents = None + if args.image_path is not None: + # i2v inference + logger.info(f"Image2Video inference: {args.image_path}") + + image = Image.open(args.image_path) + image = resize_image_to_bucket(image, (width, height)) # returns a numpy array + image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).unsqueeze(2).float() # 1, C, 1, H, W + image = image / 255.0 + + logger.info(f"Encoding image to latents") + image_latents = encode_to_latents(args, image, device) # 1, C, 1, H, W + image_latents = image_latents.to(device=device, dtype=dit_dtype) + + clean_memory_on_device(device) + + # load DiT model + blocks_to_swap = args.blocks_to_swap if args.blocks_to_swap else 0 + loading_device = "cpu" # if blocks_to_swap > 0 else device + + logger.info(f"Loading DiT model from {args.dit}") + if args.attn_mode == "sdpa": + args.attn_mode = "torch" + + # if image_latents is given, the model should be I2V model, so the in_channels should be 32 + dit_in_channels = args.dit_in_channels if args.dit_in_channels is not None else (32 if image_latents is not None else 16) + + # if we use LoRA, weigths should be bf16 instead of fp8, because merging should be done in bf16 + # the model is too large, so we load the model to cpu. 
in addition, the .pt file is loaded to cpu anyway + # on the fly merging will be a solution for this issue for .safetenors files (not implemented yet) + transformer = load_transformer( + args.dit, args.attn_mode, args.split_attn, loading_device, dit_dtype, in_channels=dit_in_channels + ) + transformer.eval() + + # load LoRA weights + if args.lora_weight is not None and len(args.lora_weight) > 0: + for i, lora_weight in enumerate(args.lora_weight): + if args.lora_multiplier is not None and len(args.lora_multiplier) > i: + lora_multiplier = args.lora_multiplier[i] + else: + lora_multiplier = 1.0 + + logger.info(f"Loading LoRA weights from {lora_weight} with multiplier {lora_multiplier}") + weights_sd = load_file(lora_weight) + # Filter to exclude keys that are part of single_blocks + if args.exclude_single_blocks: + filtered_weights = {k: v for k, v in weights_sd.items() if "single_blocks" not in k} + weights_sd = filtered_weights + + if args.lycoris: + lycoris_net, _ = create_network_from_weights( + multiplier=lora_multiplier, + file=None, + weights_sd=weights_sd, + unet=transformer, + text_encoder=None, + vae=None, + for_inference=True, + ) + else: + network = lora.create_arch_network_from_weights( + lora_multiplier, weights_sd, unet=transformer, for_inference=True + ) + logger.info("Merging LoRA weights to DiT model") + + # try: + # network.apply_to(None, transformer, apply_text_encoder=False, apply_unet=True) + # info = network.load_state_dict(weights_sd, strict=True) + # logger.info(f"Loaded LoRA weights from {weights_file}: {info}") + # network.eval() + # network.to(device) + # except Exception as e: + if args.lycoris: + lycoris_net.merge_to(None, transformer, weights_sd, dtype=None, device=device) + else: + network.merge_to(None, transformer, weights_sd, device=device, non_blocking=True) + + synchronize_device(device) + + logger.info("LoRA weights loaded") + + # save model here before casting to dit_weight_dtype + if args.save_merged_model: + logger.info(f"Saving merged model to {args.save_merged_model}") + mem_eff_save_file(transformer.state_dict(), args.save_merged_model) # save_file needs a lot of memory + logger.info("Merged model saved") + return + + if blocks_to_swap > 0: + logger.info(f"Casting model to {dit_weight_dtype}") + transformer.to(dtype=dit_weight_dtype) + logger.info(f"Enable swap {blocks_to_swap} blocks to CPU from device: {device}") + transformer.enable_block_swap(blocks_to_swap, device, supports_backward=False) + transformer.move_to_device_except_swap_blocks(device) + transformer.prepare_block_swap_before_forward() + else: + logger.info(f"Moving and casting model to {device} and {dit_weight_dtype}") + transformer.to(device=device, dtype=dit_weight_dtype) + if args.img_in_txt_in_offloading: + logger.info("Enable offloading img_in and txt_in to CPU") + transformer.enable_img_in_txt_in_offloading() + + # load scheduler + logger.info(f"Loading scheduler") + scheduler = FlowMatchDiscreteScheduler(shift=args.flow_shift, reverse=True, solver="euler") + + # Prepare timesteps + num_inference_steps = args.infer_steps + scheduler.set_timesteps(num_inference_steps, device=device) # n_tokens is not used in FlowMatchDiscreteScheduler + timesteps = scheduler.timesteps + + # Prepare generator + num_videos_per_prompt = 1 # args.num_videos # currently only support 1 video per prompt, this is a batch size + seed = args.seed + if seed is None: + seeds = [random.randint(0, 2**32 - 1) for _ in range(num_videos_per_prompt)] + elif isinstance(seed, int): + seeds = [seed + i for i in 
range(num_videos_per_prompt)] + else: + raise ValueError(f"Seed must be an integer or None, got {seed}.") + generator = [torch.Generator(device).manual_seed(seed) for seed in seeds] + + # Prepare noisy latents + num_channels_latents = 16 # transformer.config.in_channels + vae_scale_factor = 2 ** (4 - 1) # len(self.vae.config.block_out_channels) == 4 + + vae_ver = vae.VAE_VER + if "884" in vae_ver: + latent_video_length = (video_length - 1) // 4 + 1 + elif "888" in vae_ver: + latent_video_length = (video_length - 1) // 8 + 1 + else: + latent_video_length = video_length + + # shape = ( + # num_videos_per_prompt, + # num_channels_latents, + # latent_video_length, + # height // vae_scale_factor, + # width // vae_scale_factor, + # ) + # latents = randn_tensor(shape, generator=generator, device=device, dtype=dit_dtype) + + # make first N frames to be the same if the given seed is same + shape_of_frame = (num_videos_per_prompt, num_channels_latents, 1, height // vae_scale_factor, width // vae_scale_factor) + latents = [] + for i in range(latent_video_length): + latents.append(randn_tensor(shape_of_frame, generator=generator, device=device, dtype=dit_dtype)) + latents = torch.cat(latents, dim=2) + + # pad image_latents to match the length of video_latents + if image_latents is not None: + zero_latents = torch.zeros_like(latents) + zero_latents[:, :, :1, :, :] = image_latents + image_latents = zero_latents + + if args.video_path is not None: + # v2v inference + noise = latents + assert noise.shape == video_latents.shape, f"noise shape {noise.shape} != video_latents shape {video_latents.shape}" + + num_inference_steps = int(num_inference_steps * args.strength) + timestep_start = scheduler.timesteps[-num_inference_steps] # larger strength, less inference steps and more start time + t = timestep_start / 1000.0 + latents = noise * t + video_latents * (1 - t) + + timesteps = timesteps[-num_inference_steps:] + + logger.info(f"strength: {args.strength}, num_inference_steps: {num_inference_steps}, timestep_start: {timestep_start}") + + # FlowMatchDiscreteScheduler does not have init_noise_sigma + + # Denoising loop + embedded_guidance_scale = args.embedded_cfg_scale + if embedded_guidance_scale is not None: + guidance_expand = torch.tensor([embedded_guidance_scale * 1000.0] * latents.shape[0], dtype=torch.float32, device="cpu") + guidance_expand = guidance_expand.to(device=device, dtype=dit_dtype) + if do_classifier_free_guidance: + guidance_expand = torch.cat([guidance_expand, guidance_expand], dim=0) + else: + guidance_expand = None + freqs_cos, freqs_sin = get_rotary_pos_embed(vae_ver, transformer, video_length, height, width) + # n_tokens = freqs_cos.shape[0] + + # move and cast all inputs to the correct device and dtype + prompt_embeds = prompt_embeds.to(device=device, dtype=dit_dtype) + prompt_mask = prompt_mask.to(device=device) + prompt_embeds_2 = prompt_embeds_2.to(device=device, dtype=dit_dtype) + prompt_mask_2 = prompt_mask_2.to(device=device) + + freqs_cos = freqs_cos.to(device=device, dtype=dit_dtype) + freqs_sin = freqs_sin.to(device=device, dtype=dit_dtype) + + num_warmup_steps = len(timesteps) - num_inference_steps * scheduler.order # this should be 0 in v2v inference + + # assert split_uncond and split_attn + if args.split_attn and do_classifier_free_guidance and not args.split_uncond: + logger.warning("split_attn is enabled, split_uncond will be enabled as well.") + args.split_uncond = True + + # with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU, 
torch.profiler.ProfilerActivity.CUDA]) as p: + with tqdm(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + latents = scheduler.scale_model_input(latents, t) + + # predict the noise residual + with torch.no_grad(), accelerator.autocast(): + latents_input = latents if not do_classifier_free_guidance else torch.cat([latents, latents], dim=0) + if image_latents is not None: + latents_image_input = ( + image_latents if not do_classifier_free_guidance else torch.cat([image_latents, image_latents], dim=0) + ) + latents_input = torch.cat([latents_input, latents_image_input], dim=1) # 1 or 2, C*2, F, H, W + + batch_size = 1 if args.split_uncond else latents_input.shape[0] + + noise_pred_list = [] + for j in range(0, latents_input.shape[0], batch_size): + noise_pred = transformer( # For an input image (129, 192, 336) (1, 256, 256) + latents_input[j : j + batch_size], # [1, 16, 33, 24, 42] + t.repeat(batch_size).to(device=device, dtype=dit_dtype), # [1] + text_states=prompt_embeds[j : j + batch_size], # [1, 256, 4096] + text_mask=prompt_mask[j : j + batch_size], # [1, 256] + text_states_2=prompt_embeds_2[j : j + batch_size], # [1, 768] + freqs_cos=freqs_cos, # [seqlen, head_dim] + freqs_sin=freqs_sin, # [seqlen, head_dim] + guidance=guidance_expand[j : j + batch_size], # [1] + return_dict=True, + )["x"] + noise_pred_list.append(noise_pred) + noise_pred = torch.cat(noise_pred_list, dim=0) + + # perform classifier free guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond) + + # # SkyReels' rescale noise config is omitted for now + # if guidance_rescale > 0.0: + # # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + # noise_pred = rescale_noise_cfg( + # noise_pred, + # noise_pred_cond, + # guidance_rescale=self.guidance_rescale, + # ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # update progress bar + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0): + if progress_bar is not None: + progress_bar.update() + + # print(p.key_averages().table(sort_by="self_cpu_time_total", row_limit=-1)) + # print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) + + latents = latents.detach().cpu() + transformer = None + clean_memory_on_device(device) + + # Save samples + output_type = args.output_type + save_path = args.save_path # if args.save_path_suffix == "" else f"{args.save_path}_{args.save_path_suffix}" + os.makedirs(save_path, exist_ok=True) + time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S") + + if output_type == "latent" or output_type == "both": + # save latent + for i, latent in enumerate(latents): + latent_path = f"{save_path}/{time_flag}_{i}_{seeds[i]}_latent.safetensors" + + if args.no_metadata: + metadata = None + else: + metadata = { + "seeds": f"{seeds[i]}", + "prompt": f"{args.prompt}", + "height": f"{height}", + "width": f"{width}", + "video_length": f"{video_length}", + "infer_steps": f"{num_inference_steps}", + "guidance_scale": f"{args.guidance_scale}", + "embedded_cfg_scale": f"{args.embedded_cfg_scale}", + } + if args.negative_prompt is not None: + metadata["negative_prompt"] = f"{args.negative_prompt}" + sd = {"latent": latent} + save_file(sd, latent_path, metadata=metadata) + + logger.info(f"Latent save to: {latent_path}") + if output_type == 
"video" or output_type == "both": + # save video + videos = decode_latents(args, latents, device) + for i, sample in enumerate(videos): + original_name = "" if original_base_names is None else f"_{original_base_names[i]}" + sample = sample.unsqueeze(0) + video_path = f"{save_path}/{time_flag}_{i}_{seeds[i]}{original_name}.mp4" + save_videos_grid(sample, video_path, fps=args.fps) + logger.info(f"Sample save to: {video_path}") + elif output_type == "images": + # save images + videos = decode_latents(args, latents, device) + for i, sample in enumerate(videos): + original_name = "" if original_base_names is None else f"_{original_base_names[i]}" + sample = sample.unsqueeze(0) + image_name = f"{time_flag}_{i}_{seeds[i]}{original_name}" + save_images_grid(sample, save_path, image_name) + logger.info(f"Sample images save to: {save_path}/{image_name}") + + logger.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/hv_generate_video_with_hunyuani2v.py b/hv_generate_video_with_hunyuani2v.py new file mode 100644 index 0000000000000000000000000000000000000000..4dfb8de9dda2d98f073fe7aeba42115917806e5a --- /dev/null +++ b/hv_generate_video_with_hunyuani2v.py @@ -0,0 +1,973 @@ +import argparse +from datetime import datetime +from pathlib import Path +import random +import sys +import os +import time +from typing import Optional, Union + +import numpy as np +import torch +import torchvision +import accelerate +from diffusers.utils.torch_utils import randn_tensor +from transformers.models.llama import LlamaModel +from tqdm import tqdm +import av +from einops import rearrange +from safetensors.torch import load_file, save_file +from safetensors import safe_open +from PIL import Image + +from hunyuan_model import vae +from hunyuan_model import text_encoder +from hunyuan_model.text_encoder import TextEncoder +from hunyuan_model.text_encoder import PROMPT_TEMPLATE +from hunyuan_model.vae import load_vae +from hunyuan_model.models import load_transformer, get_rotary_pos_embed +from hunyuan_model.fp8_optimization import convert_fp8_linear +from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler +from networks import lora + +try: + from lycoris.kohya import create_network_from_weights +except: + pass + +from utils.model_utils import str_to_dtype +from utils.safetensors_utils import mem_eff_save_file +from dataset.image_video_dataset import load_video, glob_images, resize_image_to_bucket + +import logging + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +def clean_memory_on_device(device): + if device.type == "cuda": + torch.cuda.empty_cache() + elif device.type == "cpu": + pass + elif device.type == "mps": # not tested + torch.mps.empty_cache() + + +def synchronize_device(device: torch.device): + if device.type == "cuda": + torch.cuda.synchronize() + elif device.type == "xpu": + torch.xpu.synchronize() + elif device.type == "mps": + torch.mps.synchronize() + + +def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=1, fps=24): + """save videos by video tensor + copy from https://github.com/guoyww/AnimateDiff/blob/e92bd5671ba62c0d774a32951453e328018b7c5b/animatediff/utils/util.py#L61 + + Args: + videos (torch.Tensor): video tensor predicted by the model + path (str): path to save video + rescale (bool, optional): rescale the video tensor from [-1, 1] to . Defaults to False. + n_rows (int, optional): Defaults to 1. + fps (int, optional): video save fps. Defaults to 8. 
+ """ + videos = rearrange(videos, "b c t h w -> t b c h w") + outputs = [] + for x in videos: + x = torchvision.utils.make_grid(x, nrow=n_rows) + x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) + if rescale: + x = (x + 1.0) / 2.0 # -1,1 -> 0,1 + x = torch.clamp(x, 0, 1) + x = (x * 255).numpy().astype(np.uint8) + outputs.append(x) + + os.makedirs(os.path.dirname(path), exist_ok=True) + + # # save video with av + # container = av.open(path, "w") + # stream = container.add_stream("libx264", rate=fps) + # for x in outputs: + # frame = av.VideoFrame.from_ndarray(x, format="rgb24") + # packet = stream.encode(frame) + # container.mux(packet) + # packet = stream.encode(None) + # container.mux(packet) + # container.close() + + height, width, _ = outputs[0].shape + + # create output container + container = av.open(path, mode="w") + + # create video stream + codec = "libx264" + pixel_format = "yuv420p" + stream = container.add_stream(codec, rate=fps) + stream.width = width + stream.height = height + stream.pix_fmt = pixel_format + stream.bit_rate = 4000000 # 4Mbit/s + + for frame_array in outputs: + frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24") + packets = stream.encode(frame) + for packet in packets: + container.mux(packet) + + for packet in stream.encode(): + container.mux(packet) + + container.close() + + +def save_images_grid( + videos: torch.Tensor, parent_dir: str, image_name: str, rescale: bool = False, n_rows: int = 1, create_subdir=True +): + videos = rearrange(videos, "b c t h w -> t b c h w") + outputs = [] + for x in videos: + x = torchvision.utils.make_grid(x, nrow=n_rows) + x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) + if rescale: + x = (x + 1.0) / 2.0 # -1,1 -> 0,1 + x = torch.clamp(x, 0, 1) + x = (x * 255).numpy().astype(np.uint8) + outputs.append(x) + + if create_subdir: + output_dir = os.path.join(parent_dir, image_name) + else: + output_dir = parent_dir + + os.makedirs(output_dir, exist_ok=True) + for i, x in enumerate(outputs): + image_path = os.path.join(output_dir, f"{image_name}_{i:03d}.png") + image = Image.fromarray(x) + image.save(image_path) + + +# region Encoding prompt + + +def encode_prompt( + prompt: Union[str, list[str]], + semantic_images: Optional[Union[Image.Image, list[Image.Image]]], + device: torch.device, + num_videos_per_prompt: int, + text_encoder: TextEncoder, +): + r""" + Encodes the prompt into text encoder hidden states. 
+ + Args: + prompt (`str` or `List[str]`): + prompt to be encoded + semantic_images (`Image.Image` or `List[Image.Image]`): + semantic images to be used for I2V model + device: (`torch.device`): + torch device + num_videos_per_prompt (`int`): + number of videos that should be generated per prompt + text_encoder (TextEncoder): + text encoder to be used for encoding the prompt + """ + # LoRA and Textual Inversion are not supported in this script + # negative prompt and prompt embedding are not supported in this script + # clip_skip is not supported in this script because it is not used in the original script + data_type = "video" # video only, image is not supported + + text_inputs = text_encoder.text2tokens(prompt, data_type=data_type, semantic_images=semantic_images) + + with torch.no_grad(): + prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type, device=device, semantic_images=semantic_images) + prompt_embeds = prompt_outputs.hidden_state + + attention_mask = prompt_outputs.attention_mask + if attention_mask is not None: + attention_mask = attention_mask.to(device) + bs_embed, seq_len = attention_mask.shape + attention_mask = attention_mask.repeat(1, num_videos_per_prompt) + attention_mask = attention_mask.view(bs_embed * num_videos_per_prompt, seq_len) + + prompt_embeds_dtype = text_encoder.dtype + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + if prompt_embeds.ndim == 2: + bs_embed, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt) + prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, -1) + else: + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds, attention_mask + + +def encode_input_prompt( + hunyuan_video_i2v: bool, + prompt: Union[str, list[str]], + semantic_images: Optional[Union[Image.Image, list[Image.Image]]], + image_embed_interleave: Optional[int], + args, + device, + fp8_llm=False, + accelerator=None, +): + # constants + prompt_template_video = "dit-llm-encode-video" if not hunyuan_video_i2v else "dit-llm-encode-video-i2v" + prompt_template = "dit-llm-encode" if not hunyuan_video_i2v else "dit-llm-encode-i2v" + # text_encoder_dtype = torch.float16 + # text_encoder_type = "llm" if not hunyuan_video_i2v else "llm-i2v" + # text_len = 256 + # hidden_state_skip_layer = 2 + # apply_final_norm = False + # reproduce = False + + # text_encoder_2_type = "clipL" + # text_len_2 = 77 + + num_videos = 1 + + if type(prompt) == list: + # use default negative prompt for None + default_negative_prompt = text_encoder.NEGATIVE_PROMPT if not hunyuan_video_i2v else text_encoder.NEGATIVE_PROMPT_I2V + replaced_prompt = [p if p is not None else default_negative_prompt for p in prompt] + prompt = replaced_prompt + if semantic_images is not None and type(semantic_images) != list: + semantic_images = [semantic_images] + + # crop_start = PROMPT_TEMPLATE[prompt_template_video].get("crop_start", 0) + # max_length = text_len + crop_start + + # prompt_template + prompt_template = PROMPT_TEMPLATE[prompt_template] + + # prompt_template_video + prompt_template_video = PROMPT_TEMPLATE[prompt_template_video] # if args.prompt_template_video is not None else None + + # 
load text encoders + logger.info(f"loading text encoder: {args.text_encoder1}") + text_encoder_1 = text_encoder.load_text_encoder_1( + text_encoder_dir=args.text_encoder1, + device="cpu", + fp8_llm=fp8_llm, + i2v_mode=hunyuan_video_i2v, + image_embed_interleave=image_embed_interleave, + clip_vision_path=args.clip_vision_path, + ) + text_encoder_1.eval() + + logger.info(f"loading text encoder 2: {args.text_encoder2}") + text_encoder_2 = text_encoder.load_text_encoder_2(text_encoder_dir=args.text_encoder2, device="cpu") + text_encoder_2.eval() + + # encode prompt + logger.info(f"Encoding prompt with text encoder 1") + text_encoder_1.to(device=device) + if fp8_llm: + with accelerator.autocast(): + prompt_embeds, prompt_mask = encode_prompt(prompt, semantic_images, device, num_videos, text_encoder_1) + else: + prompt_embeds, prompt_mask = encode_prompt(prompt, semantic_images, device, num_videos, text_encoder_1) + text_encoder_1 = None + clean_memory_on_device(device) + + logger.info(f"Encoding prompt with text encoder 2") + text_encoder_2.to(device=device) + prompt_embeds_2, prompt_mask_2 = encode_prompt(prompt, semantic_images, device, num_videos, text_encoder_2) + + prompt_embeds = prompt_embeds.to("cpu") + prompt_mask = prompt_mask.to("cpu") + prompt_embeds_2 = prompt_embeds_2.to("cpu") + prompt_mask_2 = prompt_mask_2.to("cpu") + + text_encoder_2 = None + clean_memory_on_device(device) + + return prompt_embeds, prompt_mask, prompt_embeds_2, prompt_mask_2 + + +# endregion + + +def load_images(image_dir, video_length, bucket_reso): + image_files = glob_images(image_dir) + if len(image_files) == 0: + raise ValueError(f"No image files found in {image_dir}") + if len(image_files) < video_length: + raise ValueError(f"Number of images in {image_dir} is less than {video_length}") + + image_files.sort() + images = [] + for image_file in image_files[:video_length]: + image = Image.open(image_file) + image = resize_image_to_bucket(image, bucket_reso) # returns a numpy array + images.append(image) + + return images + + +def prepare_vae(args, device): + vae_dtype = torch.float16 if args.vae_dtype is None else str_to_dtype(args.vae_dtype) + vae, _, s_ratio, t_ratio = load_vae(vae_dtype=vae_dtype, device=device, vae_path=args.vae) + vae.eval() + # vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio} + + # set chunk_size to CausalConv3d recursively + chunk_size = args.vae_chunk_size + if chunk_size is not None: + vae.set_chunk_size_for_causal_conv_3d(chunk_size) + logger.info(f"Set chunk_size to {chunk_size} for CausalConv3d") + + if args.vae_spatial_tile_sample_min_size is not None: + vae.enable_spatial_tiling(True) + vae.tile_sample_min_size = args.vae_spatial_tile_sample_min_size + vae.tile_latent_min_size = args.vae_spatial_tile_sample_min_size // 8 + # elif args.vae_tiling: + else: + vae.enable_spatial_tiling(True) + + return vae, vae_dtype + + +def encode_to_latents(args, video, device): + vae, vae_dtype = prepare_vae(args, device) + + video = video.to(device=device, dtype=vae_dtype) + video = video * 2 - 1 # 0, 1 -> -1, 1 + with torch.no_grad(): + latents = vae.encode(video).latent_dist.sample() + + if hasattr(vae.config, "shift_factor") and vae.config.shift_factor: + latents = (latents - vae.config.shift_factor) * vae.config.scaling_factor + else: + latents = latents * vae.config.scaling_factor + + return latents + + +def decode_latents(args, latents, device): + vae, vae_dtype = prepare_vae(args, device) + + expand_temporal_dim = False + if len(latents.shape) == 4: + latents = 
latents.unsqueeze(2) + expand_temporal_dim = True + elif len(latents.shape) == 5: + pass + else: + raise ValueError(f"Only support latents with shape (b, c, h, w) or (b, c, f, h, w), but got {latents.shape}.") + + if hasattr(vae.config, "shift_factor") and vae.config.shift_factor: + latents = latents / vae.config.scaling_factor + vae.config.shift_factor + else: + latents = latents / vae.config.scaling_factor + + latents = latents.to(device=device, dtype=vae_dtype) + with torch.no_grad(): + image = vae.decode(latents, return_dict=False)[0] + + if expand_temporal_dim: + image = image.squeeze(2) + + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().float() + + return image + + +def parse_args(): + parser = argparse.ArgumentParser(description="HunyuanVideo inference script") + + parser.add_argument("--dit", type=str, required=True, help="DiT checkpoint path or directory") + parser.add_argument( + "--dit_in_channels", + type=int, + default=None, + help="input channels for DiT, default is None (automatically detect). 32 for SkyReels-I2V, 16 for others", + ) + parser.add_argument("--vae", type=str, required=True, help="VAE checkpoint path or directory") + parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is float16") + parser.add_argument("--text_encoder1", type=str, required=True, help="Text Encoder 1 path or directory") + parser.add_argument("--text_encoder2", type=str, required=True, help="Text Encoder 2 path or directory") + parser.add_argument("--clip_vision_path", type=str, default=None, help="CLIP vision model path for HunyuanVideo-I2V") + + # LoRA + parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path") + parser.add_argument("--lora_multiplier", type=float, nargs="*", default=1.0, help="LoRA multiplier") + parser.add_argument( + "--save_merged_model", + type=str, + default=None, + help="Save merged model to path. If specified, no inference will be performed.", + ) + parser.add_argument("--exclude_single_blocks", action="store_true", help="Exclude single blocks when loading LoRA weights") + + # inference + parser.add_argument("--prompt", type=str, required=True, help="prompt for generation") + parser.add_argument("--negative_prompt", type=str, default=None, help="negative prompt for generation") + parser.add_argument("--video_size", type=int, nargs=2, default=[256, 256], help="video size") + parser.add_argument("--video_length", type=int, default=129, help="video length") + parser.add_argument("--fps", type=int, default=24, help="video fps") + parser.add_argument("--infer_steps", type=int, default=50, help="number of inference steps") + parser.add_argument("--save_path", type=str, required=True, help="path to save generated video") + parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.") + parser.add_argument( + "--guidance_scale", + type=float, + default=1.0, + help="Guidance scale for classifier free guidance. 
Default is 1.0 (means no guidance)", + ) + parser.add_argument("--embedded_cfg_scale", type=float, default=6.0, help="Embeded classifier free guidance scale.") + parser.add_argument("--video_path", type=str, default=None, help="path to video for video2video inference") + parser.add_argument( + "--image_path", type=str, default=None, help="path to image for image2video inference, only works for I2V model" + ) + parser.add_argument("--i2v_stability", action="store_true", help="use stability for HunyuanVideo-I2V model") + parser.add_argument( + "--split_uncond", + action="store_true", + help="split unconditional call for classifier free guidance, slower but less memory usage", + ) + parser.add_argument("--strength", type=float, default=0.8, help="strength for video2video inference") + + # Flow Matching + parser.add_argument("--flow_shift", type=float, default=7.0, help="Shift factor for flow matching schedulers.") + + parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model") + parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)") + parser.add_argument( + "--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU" + ) + parser.add_argument( + "--attn_mode", type=str, default="torch", choices=["flash", "torch", "sageattn", "xformers", "sdpa"], help="attention mode" + ) + parser.add_argument( + "--split_attn", action="store_true", help="use split attention, default is False. if True, --split_uncond becomes True" + ) + parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE") + parser.add_argument( + "--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256" + ) + parser.add_argument("--blocks_to_swap", type=int, default=None, help="number of blocks to swap in the model") + parser.add_argument("--img_in_txt_in_offloading", action="store_true", help="offload img_in and txt_in to cpu") + parser.add_argument( + "--output_type", type=str, default="video", choices=["video", "images", "latent", "both"], help="output type" + ) + parser.add_argument("--no_metadata", action="store_true", help="do not save metadata") + parser.add_argument("--latent_path", type=str, nargs="*", default=None, help="path to latent for decode. 
no inference") + parser.add_argument("--lycoris", action="store_true", help="use lycoris for inference") + parser.add_argument("--fp8_fast", action="store_true", help="Enable fast FP8 arthimetic(RTX 4XXX+)") + parser.add_argument("--compile", action="store_true", help="Enable torch.compile") + parser.add_argument( + "--compile_args", nargs=4, metavar=("BACKEND", "MODE", "DYNAMIC", "FULLGRAPH"), + default=["inductor", "max-autotune-no-cudagraphs", "False", "False"], + help="Torch.compile settings" + ) + + args = parser.parse_args() + + assert (args.latent_path is None or len(args.latent_path) == 0) or ( + args.output_type == "images" or args.output_type == "video" + ), "latent_path is only supported for images or video output" + + # update dit_weight based on model_base if not exists + + if args.fp8_fast and not args.fp8: + raise ValueError("--fp8_fast requires --fp8") + + return args + + +def check_inputs(args): + height = args.video_size[0] + width = args.video_size[1] + video_length = args.video_length + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + return height, width, video_length + + +def main(): + args = parse_args() + + device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device(device) + dit_dtype = torch.bfloat16 + dit_weight_dtype = torch.float8_e4m3fn if args.fp8 else dit_dtype + logger.info(f"Using device: {device}, DiT precision: {dit_dtype}, weight precision: {dit_weight_dtype}") + + original_base_names = None + if args.latent_path is not None and len(args.latent_path) > 0: + original_base_names = [] + latents_list = [] + seeds = [] + for latent_path in args.latent_path: + original_base_names.append(os.path.splitext(os.path.basename(latent_path))[0]) + seed = 0 + + if os.path.splitext(latent_path)[1] != ".safetensors": + latents = torch.load(latent_path, map_location="cpu") + else: + latents = load_file(latent_path)["latent"] + with safe_open(latent_path, framework="pt") as f: + metadata = f.metadata() + if metadata is None: + metadata = {} + logger.info(f"Loaded metadata: {metadata}") + + if "seeds" in metadata: + seed = int(metadata["seeds"]) + + seeds.append(seed) + latents_list.append(latents) + + logger.info(f"Loaded latent from {latent_path}. 
Shape: {latents.shape}") + latents = torch.stack(latents_list, dim=0) + else: + # prepare accelerator + mixed_precision = "bf16" if dit_dtype == torch.bfloat16 else "fp16" + accelerator = accelerate.Accelerator(mixed_precision=mixed_precision) + + # SkyReels-I2V or HunyuanVideo-I2V + i2v = args.image_path is not None + dit_in_channels = 16 if args.dit_in_channels is None else args.dit_in_channels + semantic_image: Image.Image = None + if i2v: + semantic_image = Image.open(args.image_path).convert("RGB") + + sky_reels_i2v = dit_in_channels == 32 + hunyuan_video_i2v = not sky_reels_i2v # only supports "token_replace" mode for HunyuanVideo-I2V + if hunyuan_video_i2v: + image_embed_interleave = 4 + else: + image_embed_interleave = None + else: + sky_reels_i2v = False + hunyuan_video_i2v = False + image_embed_interleave = None + + # load prompt + prompt = args.prompt # TODO load prompts from file + assert prompt is not None, "prompt is required" + + # check inputs: may be height, width, video_length etc will be changed for each generation in future + height, width, video_length = check_inputs(args) + + # encode prompt with LLM and Text Encoder + logger.info(f"Encoding prompt: {prompt}") + + do_classifier_free_guidance = args.guidance_scale != 1.0 + semantic_images = [semantic_image] # for prompt + if do_classifier_free_guidance: + negative_prompt = args.negative_prompt + if negative_prompt is None: + logger.info("Negative prompt is not provided, using default prompt") + logger.info(f"Encoding negative prompt: {negative_prompt}") + prompt = [negative_prompt, prompt] + # use black image for negative prompt + semantic_images = [Image.new("RGB", semantic_image.size, (0, 0, 0)), semantic_image] + else: + if args.negative_prompt is not None: + logger.warning("Negative prompt is provided but guidance_scale is 1.0, negative prompt will be ignored.") + + prompt_embeds, prompt_mask, prompt_embeds_2, prompt_mask_2 = encode_input_prompt( + hunyuan_video_i2v, prompt, semantic_images, image_embed_interleave, args, device, args.fp8_llm, accelerator + ) + + # encode latents for video2video inference + video_latents = None + if args.video_path is not None: + # v2v inference + logger.info(f"Video2Video inference: {args.video_path}") + + if os.path.isfile(args.video_path): + video = load_video(args.video_path, 0, video_length, bucket_reso=(width, height)) # list of frames + else: + video = load_images(args.video_path, video_length, bucket_reso=(width, height)) # list of frames + + if len(video) < video_length: + raise ValueError(f"Video length is less than {video_length}") + video = np.stack(video, axis=0) # F, H, W, C + video = torch.from_numpy(video).permute(3, 0, 1, 2).unsqueeze(0).float() # 1, C, F, H, W + video = video / 255.0 + + logger.info(f"Encoding video to latents") + video_latents = encode_to_latents(args, video, device) + video_latents = video_latents.to(device=device, dtype=dit_dtype) + + clean_memory_on_device(device) + + # encode latents for image2video inference + image_latents = None + if i2v: + # i2v inference + logger.info(f"Image2Video inference: {args.image_path}") + + image = resize_image_to_bucket(semantic_image, (width, height)) # returns a numpy array + image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).unsqueeze(2).float() # 1, C, 1, H, W + image = image / 255.0 + + logger.info(f"Encoding image to latents") + image_latents = encode_to_latents(args, image, device) # 1, C, 1, H, W + image_latents = image_latents.to(device=device, dtype=dit_dtype) + + 
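+ # A rough sketch of the shape bookkeeping here, for illustration only (the
+ # 544x960 numbers are hypothetical, not taken from this script): the "884"
+ # HunyuanVideo VAE used below reduces spatial dims by 8 and maps frame count
+ # via (F - 1) // 4 + 1, with 16 latent channels, so a single conditioning
+ # frame encodes roughly as
+ #   frame:  (1, 3, 1, 544, 960), values rescaled from [0, 1] to [-1, 1]
+ #   latent: (1, 16, 1, 544 // 8, 960 // 8) == (1, 16, 1, 68, 120)
+ # Downstream, SkyReels-I2V zero-pads this single latent frame along the frame
+ # axis and concatenates it to the noise on the channel axis, while
+ # HunyuanVideo-I2V re-inserts it as the first latent frame (token replacement)
+ # at every denoising step.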
clean_memory_on_device(device) + + # load DiT model + blocks_to_swap = args.blocks_to_swap if args.blocks_to_swap else 0 + loading_device = "cpu" # if blocks_to_swap > 0 else device + + logger.info(f"Loading DiT model from {args.dit}") + if args.attn_mode == "sdpa": + args.attn_mode = "torch" + + # if we use LoRA, weigths should be bf16 instead of fp8, because merging should be done in bf16 + # the model is too large, so we load the model to cpu. in addition, the .pt file is loaded to cpu anyway + # on the fly merging will be a solution for this issue for .safetenors files (not implemented yet) + transformer = load_transformer( + args.dit, + args.attn_mode, + args.split_attn, + loading_device, + dit_dtype, + in_channels=dit_in_channels, + i2v_mode=hunyuan_video_i2v, + ) + transformer.eval() + + # load LoRA weights + if args.lora_weight is not None and len(args.lora_weight) > 0: + for i, lora_weight in enumerate(args.lora_weight): + if args.lora_multiplier is not None and len(args.lora_multiplier) > i: + lora_multiplier = args.lora_multiplier[i] + else: + lora_multiplier = 1.0 + + logger.info(f"Loading LoRA weights from {lora_weight} with multiplier {lora_multiplier}") + weights_sd = load_file(lora_weight) + + # Filter to exclude keys that are part of single_blocks + if args.exclude_single_blocks: + filtered_weights = {k: v for k, v in weights_sd.items() if "single_blocks" not in k} + weights_sd = filtered_weights + + if args.lycoris: + lycoris_net, _ = create_network_from_weights( + multiplier=lora_multiplier, + file=None, + weights_sd=weights_sd, + unet=transformer, + text_encoder=None, + vae=None, + for_inference=True, + ) + else: + network = lora.create_arch_network_from_weights( + lora_multiplier, weights_sd, unet=transformer, for_inference=True + ) + logger.info("Merging LoRA weights to DiT model") + + # try: + # network.apply_to(None, transformer, apply_text_encoder=False, apply_unet=True) + # info = network.load_state_dict(weights_sd, strict=True) + # logger.info(f"Loaded LoRA weights from {weights_file}: {info}") + # network.eval() + # network.to(device) + # except Exception as e: + if args.lycoris: + lycoris_net.merge_to(None, transformer, weights_sd, dtype=None, device=device) + else: + network.merge_to(None, transformer, weights_sd, device=device, non_blocking=True) + + synchronize_device(device) + + logger.info("LoRA weights loaded") + + # save model here before casting to dit_weight_dtype + if args.save_merged_model: + logger.info(f"Saving merged model to {args.save_merged_model}") + mem_eff_save_file(transformer.state_dict(), args.save_merged_model) # save_file needs a lot of memory + logger.info("Merged model saved") + return + + logger.info(f"Casting model to {dit_weight_dtype}") + transformer.to(dtype=dit_weight_dtype) + + if args.fp8_fast: + logger.info("Enabling FP8 acceleration") + params_to_keep = {"norm", "bias", "time_in", "vector_in", "guidance_in", "txt_in", "img_in"} + for name, param in transformer.named_parameters(): + dtype_to_use = dit_dtype if any(keyword in name for keyword in params_to_keep) else dit_weight_dtype + param.to(dtype=dtype_to_use) + convert_fp8_linear(transformer, dit_dtype, params_to_keep=params_to_keep) + + if args.compile: + compile_backend, compile_mode, compile_dynamic, compile_fullgraph = args.compile_args + logger.info( + f"Torch Compiling[Backend: {compile_backend}; Mode: {compile_mode}; Dynamic: {compile_dynamic}; Fullgraph: {compile_fullgraph}]" + ) + torch._dynamo.config.cache_size_limit = 32 + for i, block in 
enumerate(transformer.single_blocks): + compiled_block = torch.compile( + block, backend=compile_backend, mode=compile_mode, + dynamic=compile_dynamic.lower() in "true", + fullgraph=compile_fullgraph.lower() in "true" + ) + transformer.single_blocks[i] = compiled_block + for i, block in enumerate(transformer.double_blocks): + compiled_block = torch.compile( + block, backend=compile_backend, mode=compile_mode, + dynamic=compile_dynamic.lower() in "true", + fullgraph=compile_fullgraph.lower() in "true" + ) + transformer.double_blocks[i] = compiled_block + + if blocks_to_swap > 0: + logger.info(f"Enable swap {blocks_to_swap} blocks to CPU from device: {device}") + transformer.enable_block_swap(blocks_to_swap, device, supports_backward=False) + transformer.move_to_device_except_swap_blocks(device) + transformer.prepare_block_swap_before_forward() + else: + logger.info(f"Moving model to {device}") + transformer.to(device=device) + if args.img_in_txt_in_offloading: + logger.info("Enable offloading img_in and txt_in to CPU") + transformer.enable_img_in_txt_in_offloading() + + # load scheduler + logger.info(f"Loading scheduler") + scheduler = FlowMatchDiscreteScheduler(shift=args.flow_shift, reverse=True, solver="euler") + + # Prepare timesteps + num_inference_steps = args.infer_steps + scheduler.set_timesteps(num_inference_steps, device=device) # n_tokens is not used in FlowMatchDiscreteScheduler + timesteps = scheduler.timesteps + + # Prepare generator + num_videos_per_prompt = 1 # args.num_videos # currently only support 1 video per prompt, this is a batch size + seed = args.seed + if seed is None: + seeds = [random.randint(0, 2**32 - 1) for _ in range(num_videos_per_prompt)] + elif isinstance(seed, int): + seeds = [seed + i for i in range(num_videos_per_prompt)] + else: + raise ValueError(f"Seed must be an integer or None, got {seed}.") + generator = [torch.Generator(device).manual_seed(seed) for seed in seeds] + + # Prepare noisy latents + num_channels_latents = 16 # transformer.config.in_channels + vae_scale_factor = 2 ** (4 - 1) # len(self.vae.config.block_out_channels) == 4 + + vae_ver = vae.VAE_VER + if "884" in vae_ver: + latent_video_length = (video_length - 1) // 4 + 1 + elif "888" in vae_ver: + latent_video_length = (video_length - 1) // 8 + 1 + else: + latent_video_length = video_length + + # shape = ( + # num_videos_per_prompt, + # num_channels_latents, + # latent_video_length, + # height // vae_scale_factor, + # width // vae_scale_factor, + # ) + # latents = randn_tensor(shape, generator=generator, device=device, dtype=dit_dtype) + + # make first N frames to be the same if the given seed is same + shape_of_frame = (num_videos_per_prompt, num_channels_latents, 1, height // vae_scale_factor, width // vae_scale_factor) + latents = [] + for i in range(latent_video_length): + latents.append(randn_tensor(shape_of_frame, generator=generator, device=device, dtype=dit_dtype)) + latents = torch.cat(latents, dim=2) + + if sky_reels_i2v: + # pad image_latents to match the length of video_latents + zero_latents = torch.zeros_like(latents) + zero_latents[:, :, :1, :, :] = image_latents + image_latents = zero_latents + elif hunyuan_video_i2v: + if args.i2v_stability: + t = torch.tensor([0.999]).to(device=device) + latents = latents * t + image_latents.repeat(1, 1, latent_video_length, 1, 1) * (1 - t) + + if args.video_path is not None: + # v2v inference + noise = latents + assert noise.shape == video_latents.shape, f"noise shape {noise.shape} != video_latents shape {video_latents.shape}" + + 
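+ # How --strength maps onto the schedule computed just below (a worked example
+ # with hypothetical values, not taken from this script): with --infer_steps 50
+ # and --strength 0.8,
+ #   num_inference_steps = int(50 * 0.8)      # -> 40 steps actually run
+ #   timestep_start      = scheduler.timesteps[-40]
+ #   t                   = timestep_start / 1000.0
+ #   latents             = noise * t + video_latents * (1 - t)
+ # so only the tail of the schedule is denoised, starting from a blend of fresh
+ # noise and the encoded source video; a higher strength starts from a noisier
+ # blend and runs a larger share of the steps, drifting further from the source.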
num_inference_steps = int(num_inference_steps * args.strength) + timestep_start = scheduler.timesteps[-num_inference_steps] # larger strength, less inference steps and more start time + t = timestep_start / 1000.0 + latents = noise * t + video_latents * (1 - t) + + timesteps = timesteps[-num_inference_steps:] + + logger.info(f"strength: {args.strength}, num_inference_steps: {num_inference_steps}, timestep_start: {timestep_start}") + + # FlowMatchDiscreteScheduler does not have init_noise_sigma + + # Denoising loop + embedded_guidance_scale = args.embedded_cfg_scale + if embedded_guidance_scale is not None: + guidance_expand = torch.tensor([embedded_guidance_scale * 1000.0] * latents.shape[0], dtype=torch.float32, device="cpu") + guidance_expand = guidance_expand.to(device=device, dtype=dit_dtype) + if do_classifier_free_guidance: + guidance_expand = torch.cat([guidance_expand, guidance_expand], dim=0) + else: + guidance_expand = None + freqs_cos, freqs_sin = get_rotary_pos_embed(vae_ver, transformer, video_length, height, width) + # n_tokens = freqs_cos.shape[0] + + # move and cast all inputs to the correct device and dtype + prompt_embeds = prompt_embeds.to(device=device, dtype=dit_dtype) + prompt_mask = prompt_mask.to(device=device) + prompt_embeds_2 = prompt_embeds_2.to(device=device, dtype=dit_dtype) + prompt_mask_2 = prompt_mask_2.to(device=device) + + freqs_cos = freqs_cos.to(device=device, dtype=dit_dtype) + freqs_sin = freqs_sin.to(device=device, dtype=dit_dtype) + + num_warmup_steps = len(timesteps) - num_inference_steps * scheduler.order # this should be 0 in v2v inference + + # assert split_uncond and split_attn + if args.split_attn and do_classifier_free_guidance and not args.split_uncond: + logger.warning("split_attn is enabled, split_uncond will be enabled as well.") + args.split_uncond = True + + # we do not support "latent_concat" mode for HunyuanVideo-I2V model, only "token_replace" mode is supported + + # with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA]) as p: + with tqdm(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + latents = scheduler.scale_model_input(latents, t) + + # predict the noise residual + with torch.no_grad(), accelerator.autocast(): + if hunyuan_video_i2v: # and i2v_condition_type == "token_replace": + latents = torch.cat([image_latents, latents[:, :, 1:, :, :]], dim=2) + + latents_input = latents if not do_classifier_free_guidance else torch.cat([latents, latents], dim=0) + if sky_reels_i2v: + latents_image_input = ( + image_latents if not do_classifier_free_guidance else torch.cat([image_latents, image_latents], dim=0) + ) + latents_input = torch.cat([latents_input, latents_image_input], dim=1) # 1 or 2, C*2, F, H, W + + batch_size = 1 if args.split_uncond else latents_input.shape[0] + + noise_pred_list = [] + for j in range(0, latents_input.shape[0], batch_size): + noise_pred = transformer( # For an input image (129, 192, 336) (1, 256, 256) + latents_input[j : j + batch_size], # [1, 16, 33, 24, 42] + t.repeat(batch_size).to(device=device, dtype=dit_dtype), # [1] + text_states=prompt_embeds[j : j + batch_size], # [1, 256, 4096] + text_mask=prompt_mask[j : j + batch_size], # [1, 256] + text_states_2=prompt_embeds_2[j : j + batch_size], # [1, 768] + freqs_cos=freqs_cos, # [seqlen, head_dim] + freqs_sin=freqs_sin, # [seqlen, head_dim] + guidance=guidance_expand[j : j + batch_size], # [1] + return_dict=True, + )["x"] + noise_pred_list.append(noise_pred) + 
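+ # Note on the loop above: with --split_uncond the conditional and
+ # unconditional branches go through the DiT one at a time (batch_size == 1),
+ # roughly halving peak activation memory at the cost of two forward passes;
+ # otherwise both are batched together. The chunks are recombined below and,
+ # when guidance_scale != 1.0, standard classifier-free guidance is applied,
+ # e.g. with a hypothetical guidance_scale of 6.0:
+ #   uncond, cond = noise_pred.chunk(2)
+ #   noise_pred = uncond + 6.0 * (cond - uncond)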
noise_pred = torch.cat(noise_pred_list, dim=0) + + # perform classifier free guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond) + + # # SkyReels' rescale noise config is omitted for now + # if guidance_rescale > 0.0: + # # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + # noise_pred = rescale_noise_cfg( + # noise_pred, + # noise_pred_cond, + # guidance_rescale=self.guidance_rescale, + # ) + + # compute the previous noisy sample x_t -> x_t-1 + # generator and eta arguments are not used in FlowMatchDiscreteScheduler + if not hunyuan_video_i2v: + latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0] + else: + latents = scheduler.step(noise_pred[:, :, 1:, :, :], t, latents[:, :, 1:, :, :], return_dict=False)[0] + latents = torch.concat([image_latents, latents], dim=2) + + # update progress bar + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0): + if progress_bar is not None: + progress_bar.update() + + # print(p.key_averages().table(sort_by="self_cpu_time_total", row_limit=-1)) + # print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) + + latents = latents.detach().cpu() + transformer = None + clean_memory_on_device(device) + + # Save samples + output_type = args.output_type + save_path = args.save_path # if args.save_path_suffix == "" else f"{args.save_path}_{args.save_path_suffix}" + os.makedirs(save_path, exist_ok=True) + time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S") + + if output_type == "latent" or output_type == "both": + # save latent + for i, latent in enumerate(latents): + latent_path = f"{save_path}/{time_flag}_{i}_{seeds[i]}_latent.safetensors" + + if args.no_metadata: + metadata = None + else: + metadata = { + "seeds": f"{seeds[i]}", + "prompt": f"{args.prompt}", + "height": f"{height}", + "width": f"{width}", + "video_length": f"{video_length}", + "infer_steps": f"{num_inference_steps}", + "guidance_scale": f"{args.guidance_scale}", + "embedded_cfg_scale": f"{args.embedded_cfg_scale}", + } + if args.negative_prompt is not None: + metadata["negative_prompt"] = f"{args.negative_prompt}" + sd = {"latent": latent} + save_file(sd, latent_path, metadata=metadata) + + logger.info(f"Latent save to: {latent_path}") + if output_type == "video" or output_type == "both": + # save video + videos = decode_latents(args, latents, device) + for i, sample in enumerate(videos): + original_name = "" if original_base_names is None else f"_{original_base_names[i]}" + sample = sample.unsqueeze(0) + video_path = f"{save_path}/{time_flag}_{i}_{seeds[i]}{original_name}.mp4" + save_videos_grid(sample, video_path, fps=args.fps) + logger.info(f"Sample save to: {video_path}") + elif output_type == "images": + # save images + videos = decode_latents(args, latents, device) + for i, sample in enumerate(videos): + original_name = "" if original_base_names is None else f"_{original_base_names[i]}" + sample = sample.unsqueeze(0) + image_name = f"{time_flag}_{i}_{seeds[i]}{original_name}" + save_images_grid(sample, save_path, image_name) + logger.info(f"Sample images save to: {save_path}/{image_name}") + + logger.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/hv_i2v_generate_video.py b/hv_i2v_generate_video.py new file mode 100644 index 0000000000000000000000000000000000000000..4dfb8de9dda2d98f073fe7aeba42115917806e5a --- /dev/null +++ 
b/hv_i2v_generate_video.py @@ -0,0 +1,973 @@ +import argparse +from datetime import datetime +from pathlib import Path +import random +import sys +import os +import time +from typing import Optional, Union + +import numpy as np +import torch +import torchvision +import accelerate +from diffusers.utils.torch_utils import randn_tensor +from transformers.models.llama import LlamaModel +from tqdm import tqdm +import av +from einops import rearrange +from safetensors.torch import load_file, save_file +from safetensors import safe_open +from PIL import Image + +from hunyuan_model import vae +from hunyuan_model import text_encoder +from hunyuan_model.text_encoder import TextEncoder +from hunyuan_model.text_encoder import PROMPT_TEMPLATE +from hunyuan_model.vae import load_vae +from hunyuan_model.models import load_transformer, get_rotary_pos_embed +from hunyuan_model.fp8_optimization import convert_fp8_linear +from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler +from networks import lora + +try: + from lycoris.kohya import create_network_from_weights +except: + pass + +from utils.model_utils import str_to_dtype +from utils.safetensors_utils import mem_eff_save_file +from dataset.image_video_dataset import load_video, glob_images, resize_image_to_bucket + +import logging + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +def clean_memory_on_device(device): + if device.type == "cuda": + torch.cuda.empty_cache() + elif device.type == "cpu": + pass + elif device.type == "mps": # not tested + torch.mps.empty_cache() + + +def synchronize_device(device: torch.device): + if device.type == "cuda": + torch.cuda.synchronize() + elif device.type == "xpu": + torch.xpu.synchronize() + elif device.type == "mps": + torch.mps.synchronize() + + +def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=1, fps=24): + """save videos by video tensor + copy from https://github.com/guoyww/AnimateDiff/blob/e92bd5671ba62c0d774a32951453e328018b7c5b/animatediff/utils/util.py#L61 + + Args: + videos (torch.Tensor): video tensor predicted by the model + path (str): path to save video + rescale (bool, optional): rescale the video tensor from [-1, 1] to . Defaults to False. + n_rows (int, optional): Defaults to 1. + fps (int, optional): video save fps. Defaults to 8. 
+ """ + videos = rearrange(videos, "b c t h w -> t b c h w") + outputs = [] + for x in videos: + x = torchvision.utils.make_grid(x, nrow=n_rows) + x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) + if rescale: + x = (x + 1.0) / 2.0 # -1,1 -> 0,1 + x = torch.clamp(x, 0, 1) + x = (x * 255).numpy().astype(np.uint8) + outputs.append(x) + + os.makedirs(os.path.dirname(path), exist_ok=True) + + # # save video with av + # container = av.open(path, "w") + # stream = container.add_stream("libx264", rate=fps) + # for x in outputs: + # frame = av.VideoFrame.from_ndarray(x, format="rgb24") + # packet = stream.encode(frame) + # container.mux(packet) + # packet = stream.encode(None) + # container.mux(packet) + # container.close() + + height, width, _ = outputs[0].shape + + # create output container + container = av.open(path, mode="w") + + # create video stream + codec = "libx264" + pixel_format = "yuv420p" + stream = container.add_stream(codec, rate=fps) + stream.width = width + stream.height = height + stream.pix_fmt = pixel_format + stream.bit_rate = 4000000 # 4Mbit/s + + for frame_array in outputs: + frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24") + packets = stream.encode(frame) + for packet in packets: + container.mux(packet) + + for packet in stream.encode(): + container.mux(packet) + + container.close() + + +def save_images_grid( + videos: torch.Tensor, parent_dir: str, image_name: str, rescale: bool = False, n_rows: int = 1, create_subdir=True +): + videos = rearrange(videos, "b c t h w -> t b c h w") + outputs = [] + for x in videos: + x = torchvision.utils.make_grid(x, nrow=n_rows) + x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) + if rescale: + x = (x + 1.0) / 2.0 # -1,1 -> 0,1 + x = torch.clamp(x, 0, 1) + x = (x * 255).numpy().astype(np.uint8) + outputs.append(x) + + if create_subdir: + output_dir = os.path.join(parent_dir, image_name) + else: + output_dir = parent_dir + + os.makedirs(output_dir, exist_ok=True) + for i, x in enumerate(outputs): + image_path = os.path.join(output_dir, f"{image_name}_{i:03d}.png") + image = Image.fromarray(x) + image.save(image_path) + + +# region Encoding prompt + + +def encode_prompt( + prompt: Union[str, list[str]], + semantic_images: Optional[Union[Image.Image, list[Image.Image]]], + device: torch.device, + num_videos_per_prompt: int, + text_encoder: TextEncoder, +): + r""" + Encodes the prompt into text encoder hidden states. 
+ + Args: + prompt (`str` or `List[str]`): + prompt to be encoded + semantic_images (`Image.Image` or `List[Image.Image]`): + semantic images to be used for I2V model + device: (`torch.device`): + torch device + num_videos_per_prompt (`int`): + number of videos that should be generated per prompt + text_encoder (TextEncoder): + text encoder to be used for encoding the prompt + """ + # LoRA and Textual Inversion are not supported in this script + # negative prompt and prompt embedding are not supported in this script + # clip_skip is not supported in this script because it is not used in the original script + data_type = "video" # video only, image is not supported + + text_inputs = text_encoder.text2tokens(prompt, data_type=data_type, semantic_images=semantic_images) + + with torch.no_grad(): + prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type, device=device, semantic_images=semantic_images) + prompt_embeds = prompt_outputs.hidden_state + + attention_mask = prompt_outputs.attention_mask + if attention_mask is not None: + attention_mask = attention_mask.to(device) + bs_embed, seq_len = attention_mask.shape + attention_mask = attention_mask.repeat(1, num_videos_per_prompt) + attention_mask = attention_mask.view(bs_embed * num_videos_per_prompt, seq_len) + + prompt_embeds_dtype = text_encoder.dtype + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + if prompt_embeds.ndim == 2: + bs_embed, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt) + prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, -1) + else: + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds, attention_mask + + +def encode_input_prompt( + hunyuan_video_i2v: bool, + prompt: Union[str, list[str]], + semantic_images: Optional[Union[Image.Image, list[Image.Image]]], + image_embed_interleave: Optional[int], + args, + device, + fp8_llm=False, + accelerator=None, +): + # constants + prompt_template_video = "dit-llm-encode-video" if not hunyuan_video_i2v else "dit-llm-encode-video-i2v" + prompt_template = "dit-llm-encode" if not hunyuan_video_i2v else "dit-llm-encode-i2v" + # text_encoder_dtype = torch.float16 + # text_encoder_type = "llm" if not hunyuan_video_i2v else "llm-i2v" + # text_len = 256 + # hidden_state_skip_layer = 2 + # apply_final_norm = False + # reproduce = False + + # text_encoder_2_type = "clipL" + # text_len_2 = 77 + + num_videos = 1 + + if type(prompt) == list: + # use default negative prompt for None + default_negative_prompt = text_encoder.NEGATIVE_PROMPT if not hunyuan_video_i2v else text_encoder.NEGATIVE_PROMPT_I2V + replaced_prompt = [p if p is not None else default_negative_prompt for p in prompt] + prompt = replaced_prompt + if semantic_images is not None and type(semantic_images) != list: + semantic_images = [semantic_images] + + # crop_start = PROMPT_TEMPLATE[prompt_template_video].get("crop_start", 0) + # max_length = text_len + crop_start + + # prompt_template + prompt_template = PROMPT_TEMPLATE[prompt_template] + + # prompt_template_video + prompt_template_video = PROMPT_TEMPLATE[prompt_template_video] # if args.prompt_template_video is not None else None + + # 
load text encoders + logger.info(f"loading text encoder: {args.text_encoder1}") + text_encoder_1 = text_encoder.load_text_encoder_1( + text_encoder_dir=args.text_encoder1, + device="cpu", + fp8_llm=fp8_llm, + i2v_mode=hunyuan_video_i2v, + image_embed_interleave=image_embed_interleave, + clip_vision_path=args.clip_vision_path, + ) + text_encoder_1.eval() + + logger.info(f"loading text encoder 2: {args.text_encoder2}") + text_encoder_2 = text_encoder.load_text_encoder_2(text_encoder_dir=args.text_encoder2, device="cpu") + text_encoder_2.eval() + + # encode prompt + logger.info(f"Encoding prompt with text encoder 1") + text_encoder_1.to(device=device) + if fp8_llm: + with accelerator.autocast(): + prompt_embeds, prompt_mask = encode_prompt(prompt, semantic_images, device, num_videos, text_encoder_1) + else: + prompt_embeds, prompt_mask = encode_prompt(prompt, semantic_images, device, num_videos, text_encoder_1) + text_encoder_1 = None + clean_memory_on_device(device) + + logger.info(f"Encoding prompt with text encoder 2") + text_encoder_2.to(device=device) + prompt_embeds_2, prompt_mask_2 = encode_prompt(prompt, semantic_images, device, num_videos, text_encoder_2) + + prompt_embeds = prompt_embeds.to("cpu") + prompt_mask = prompt_mask.to("cpu") + prompt_embeds_2 = prompt_embeds_2.to("cpu") + prompt_mask_2 = prompt_mask_2.to("cpu") + + text_encoder_2 = None + clean_memory_on_device(device) + + return prompt_embeds, prompt_mask, prompt_embeds_2, prompt_mask_2 + + +# endregion + + +def load_images(image_dir, video_length, bucket_reso): + image_files = glob_images(image_dir) + if len(image_files) == 0: + raise ValueError(f"No image files found in {image_dir}") + if len(image_files) < video_length: + raise ValueError(f"Number of images in {image_dir} is less than {video_length}") + + image_files.sort() + images = [] + for image_file in image_files[:video_length]: + image = Image.open(image_file) + image = resize_image_to_bucket(image, bucket_reso) # returns a numpy array + images.append(image) + + return images + + +def prepare_vae(args, device): + vae_dtype = torch.float16 if args.vae_dtype is None else str_to_dtype(args.vae_dtype) + vae, _, s_ratio, t_ratio = load_vae(vae_dtype=vae_dtype, device=device, vae_path=args.vae) + vae.eval() + # vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio} + + # set chunk_size to CausalConv3d recursively + chunk_size = args.vae_chunk_size + if chunk_size is not None: + vae.set_chunk_size_for_causal_conv_3d(chunk_size) + logger.info(f"Set chunk_size to {chunk_size} for CausalConv3d") + + if args.vae_spatial_tile_sample_min_size is not None: + vae.enable_spatial_tiling(True) + vae.tile_sample_min_size = args.vae_spatial_tile_sample_min_size + vae.tile_latent_min_size = args.vae_spatial_tile_sample_min_size // 8 + # elif args.vae_tiling: + else: + vae.enable_spatial_tiling(True) + + return vae, vae_dtype + + +def encode_to_latents(args, video, device): + vae, vae_dtype = prepare_vae(args, device) + + video = video.to(device=device, dtype=vae_dtype) + video = video * 2 - 1 # 0, 1 -> -1, 1 + with torch.no_grad(): + latents = vae.encode(video).latent_dist.sample() + + if hasattr(vae.config, "shift_factor") and vae.config.shift_factor: + latents = (latents - vae.config.shift_factor) * vae.config.scaling_factor + else: + latents = latents * vae.config.scaling_factor + + return latents + + +def decode_latents(args, latents, device): + vae, vae_dtype = prepare_vae(args, device) + + expand_temporal_dim = False + if len(latents.shape) == 4: + latents = 
latents.unsqueeze(2) + expand_temporal_dim = True + elif len(latents.shape) == 5: + pass + else: + raise ValueError(f"Only support latents with shape (b, c, h, w) or (b, c, f, h, w), but got {latents.shape}.") + + if hasattr(vae.config, "shift_factor") and vae.config.shift_factor: + latents = latents / vae.config.scaling_factor + vae.config.shift_factor + else: + latents = latents / vae.config.scaling_factor + + latents = latents.to(device=device, dtype=vae_dtype) + with torch.no_grad(): + image = vae.decode(latents, return_dict=False)[0] + + if expand_temporal_dim: + image = image.squeeze(2) + + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().float() + + return image + + +def parse_args(): + parser = argparse.ArgumentParser(description="HunyuanVideo inference script") + + parser.add_argument("--dit", type=str, required=True, help="DiT checkpoint path or directory") + parser.add_argument( + "--dit_in_channels", + type=int, + default=None, + help="input channels for DiT, default is None (automatically detect). 32 for SkyReels-I2V, 16 for others", + ) + parser.add_argument("--vae", type=str, required=True, help="VAE checkpoint path or directory") + parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is float16") + parser.add_argument("--text_encoder1", type=str, required=True, help="Text Encoder 1 path or directory") + parser.add_argument("--text_encoder2", type=str, required=True, help="Text Encoder 2 path or directory") + parser.add_argument("--clip_vision_path", type=str, default=None, help="CLIP vision model path for HunyuanVideo-I2V") + + # LoRA + parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path") + parser.add_argument("--lora_multiplier", type=float, nargs="*", default=1.0, help="LoRA multiplier") + parser.add_argument( + "--save_merged_model", + type=str, + default=None, + help="Save merged model to path. If specified, no inference will be performed.", + ) + parser.add_argument("--exclude_single_blocks", action="store_true", help="Exclude single blocks when loading LoRA weights") + + # inference + parser.add_argument("--prompt", type=str, required=True, help="prompt for generation") + parser.add_argument("--negative_prompt", type=str, default=None, help="negative prompt for generation") + parser.add_argument("--video_size", type=int, nargs=2, default=[256, 256], help="video size") + parser.add_argument("--video_length", type=int, default=129, help="video length") + parser.add_argument("--fps", type=int, default=24, help="video fps") + parser.add_argument("--infer_steps", type=int, default=50, help="number of inference steps") + parser.add_argument("--save_path", type=str, required=True, help="path to save generated video") + parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.") + parser.add_argument( + "--guidance_scale", + type=float, + default=1.0, + help="Guidance scale for classifier free guidance. 
Default is 1.0 (means no guidance)", + ) + parser.add_argument("--embedded_cfg_scale", type=float, default=6.0, help="Embeded classifier free guidance scale.") + parser.add_argument("--video_path", type=str, default=None, help="path to video for video2video inference") + parser.add_argument( + "--image_path", type=str, default=None, help="path to image for image2video inference, only works for I2V model" + ) + parser.add_argument("--i2v_stability", action="store_true", help="use stability for HunyuanVideo-I2V model") + parser.add_argument( + "--split_uncond", + action="store_true", + help="split unconditional call for classifier free guidance, slower but less memory usage", + ) + parser.add_argument("--strength", type=float, default=0.8, help="strength for video2video inference") + + # Flow Matching + parser.add_argument("--flow_shift", type=float, default=7.0, help="Shift factor for flow matching schedulers.") + + parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model") + parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)") + parser.add_argument( + "--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU" + ) + parser.add_argument( + "--attn_mode", type=str, default="torch", choices=["flash", "torch", "sageattn", "xformers", "sdpa"], help="attention mode" + ) + parser.add_argument( + "--split_attn", action="store_true", help="use split attention, default is False. if True, --split_uncond becomes True" + ) + parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE") + parser.add_argument( + "--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256" + ) + parser.add_argument("--blocks_to_swap", type=int, default=None, help="number of blocks to swap in the model") + parser.add_argument("--img_in_txt_in_offloading", action="store_true", help="offload img_in and txt_in to cpu") + parser.add_argument( + "--output_type", type=str, default="video", choices=["video", "images", "latent", "both"], help="output type" + ) + parser.add_argument("--no_metadata", action="store_true", help="do not save metadata") + parser.add_argument("--latent_path", type=str, nargs="*", default=None, help="path to latent for decode. 
no inference") + parser.add_argument("--lycoris", action="store_true", help="use lycoris for inference") + parser.add_argument("--fp8_fast", action="store_true", help="Enable fast FP8 arthimetic(RTX 4XXX+)") + parser.add_argument("--compile", action="store_true", help="Enable torch.compile") + parser.add_argument( + "--compile_args", nargs=4, metavar=("BACKEND", "MODE", "DYNAMIC", "FULLGRAPH"), + default=["inductor", "max-autotune-no-cudagraphs", "False", "False"], + help="Torch.compile settings" + ) + + args = parser.parse_args() + + assert (args.latent_path is None or len(args.latent_path) == 0) or ( + args.output_type == "images" or args.output_type == "video" + ), "latent_path is only supported for images or video output" + + # update dit_weight based on model_base if not exists + + if args.fp8_fast and not args.fp8: + raise ValueError("--fp8_fast requires --fp8") + + return args + + +def check_inputs(args): + height = args.video_size[0] + width = args.video_size[1] + video_length = args.video_length + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + return height, width, video_length + + +def main(): + args = parse_args() + + device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device(device) + dit_dtype = torch.bfloat16 + dit_weight_dtype = torch.float8_e4m3fn if args.fp8 else dit_dtype + logger.info(f"Using device: {device}, DiT precision: {dit_dtype}, weight precision: {dit_weight_dtype}") + + original_base_names = None + if args.latent_path is not None and len(args.latent_path) > 0: + original_base_names = [] + latents_list = [] + seeds = [] + for latent_path in args.latent_path: + original_base_names.append(os.path.splitext(os.path.basename(latent_path))[0]) + seed = 0 + + if os.path.splitext(latent_path)[1] != ".safetensors": + latents = torch.load(latent_path, map_location="cpu") + else: + latents = load_file(latent_path)["latent"] + with safe_open(latent_path, framework="pt") as f: + metadata = f.metadata() + if metadata is None: + metadata = {} + logger.info(f"Loaded metadata: {metadata}") + + if "seeds" in metadata: + seed = int(metadata["seeds"]) + + seeds.append(seed) + latents_list.append(latents) + + logger.info(f"Loaded latent from {latent_path}. 
Shape: {latents.shape}") + latents = torch.stack(latents_list, dim=0) + else: + # prepare accelerator + mixed_precision = "bf16" if dit_dtype == torch.bfloat16 else "fp16" + accelerator = accelerate.Accelerator(mixed_precision=mixed_precision) + + # SkyReels-I2V or HunyuanVideo-I2V + i2v = args.image_path is not None + dit_in_channels = 16 if args.dit_in_channels is None else args.dit_in_channels + semantic_image: Image.Image = None + if i2v: + semantic_image = Image.open(args.image_path).convert("RGB") + + sky_reels_i2v = dit_in_channels == 32 + hunyuan_video_i2v = not sky_reels_i2v # only supports "token_replace" mode for HunyuanVideo-I2V + if hunyuan_video_i2v: + image_embed_interleave = 4 + else: + image_embed_interleave = None + else: + sky_reels_i2v = False + hunyuan_video_i2v = False + image_embed_interleave = None + + # load prompt + prompt = args.prompt # TODO load prompts from file + assert prompt is not None, "prompt is required" + + # check inputs: may be height, width, video_length etc will be changed for each generation in future + height, width, video_length = check_inputs(args) + + # encode prompt with LLM and Text Encoder + logger.info(f"Encoding prompt: {prompt}") + + do_classifier_free_guidance = args.guidance_scale != 1.0 + semantic_images = [semantic_image] # for prompt + if do_classifier_free_guidance: + negative_prompt = args.negative_prompt + if negative_prompt is None: + logger.info("Negative prompt is not provided, using default prompt") + logger.info(f"Encoding negative prompt: {negative_prompt}") + prompt = [negative_prompt, prompt] + # use black image for negative prompt + semantic_images = [Image.new("RGB", semantic_image.size, (0, 0, 0)), semantic_image] + else: + if args.negative_prompt is not None: + logger.warning("Negative prompt is provided but guidance_scale is 1.0, negative prompt will be ignored.") + + prompt_embeds, prompt_mask, prompt_embeds_2, prompt_mask_2 = encode_input_prompt( + hunyuan_video_i2v, prompt, semantic_images, image_embed_interleave, args, device, args.fp8_llm, accelerator + ) + + # encode latents for video2video inference + video_latents = None + if args.video_path is not None: + # v2v inference + logger.info(f"Video2Video inference: {args.video_path}") + + if os.path.isfile(args.video_path): + video = load_video(args.video_path, 0, video_length, bucket_reso=(width, height)) # list of frames + else: + video = load_images(args.video_path, video_length, bucket_reso=(width, height)) # list of frames + + if len(video) < video_length: + raise ValueError(f"Video length is less than {video_length}") + video = np.stack(video, axis=0) # F, H, W, C + video = torch.from_numpy(video).permute(3, 0, 1, 2).unsqueeze(0).float() # 1, C, F, H, W + video = video / 255.0 + + logger.info(f"Encoding video to latents") + video_latents = encode_to_latents(args, video, device) + video_latents = video_latents.to(device=device, dtype=dit_dtype) + + clean_memory_on_device(device) + + # encode latents for image2video inference + image_latents = None + if i2v: + # i2v inference + logger.info(f"Image2Video inference: {args.image_path}") + + image = resize_image_to_bucket(semantic_image, (width, height)) # returns a numpy array + image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).unsqueeze(2).float() # 1, C, 1, H, W + image = image / 255.0 + + logger.info(f"Encoding image to latents") + image_latents = encode_to_latents(args, image, device) # 1, C, 1, H, W + image_latents = image_latents.to(device=device, dtype=dit_dtype) + + 
clean_memory_on_device(device) + + # load DiT model + blocks_to_swap = args.blocks_to_swap if args.blocks_to_swap else 0 + loading_device = "cpu" # if blocks_to_swap > 0 else device + + logger.info(f"Loading DiT model from {args.dit}") + if args.attn_mode == "sdpa": + args.attn_mode = "torch" + + # if we use LoRA, weigths should be bf16 instead of fp8, because merging should be done in bf16 + # the model is too large, so we load the model to cpu. in addition, the .pt file is loaded to cpu anyway + # on the fly merging will be a solution for this issue for .safetenors files (not implemented yet) + transformer = load_transformer( + args.dit, + args.attn_mode, + args.split_attn, + loading_device, + dit_dtype, + in_channels=dit_in_channels, + i2v_mode=hunyuan_video_i2v, + ) + transformer.eval() + + # load LoRA weights + if args.lora_weight is not None and len(args.lora_weight) > 0: + for i, lora_weight in enumerate(args.lora_weight): + if args.lora_multiplier is not None and len(args.lora_multiplier) > i: + lora_multiplier = args.lora_multiplier[i] + else: + lora_multiplier = 1.0 + + logger.info(f"Loading LoRA weights from {lora_weight} with multiplier {lora_multiplier}") + weights_sd = load_file(lora_weight) + + # Filter to exclude keys that are part of single_blocks + if args.exclude_single_blocks: + filtered_weights = {k: v for k, v in weights_sd.items() if "single_blocks" not in k} + weights_sd = filtered_weights + + if args.lycoris: + lycoris_net, _ = create_network_from_weights( + multiplier=lora_multiplier, + file=None, + weights_sd=weights_sd, + unet=transformer, + text_encoder=None, + vae=None, + for_inference=True, + ) + else: + network = lora.create_arch_network_from_weights( + lora_multiplier, weights_sd, unet=transformer, for_inference=True + ) + logger.info("Merging LoRA weights to DiT model") + + # try: + # network.apply_to(None, transformer, apply_text_encoder=False, apply_unet=True) + # info = network.load_state_dict(weights_sd, strict=True) + # logger.info(f"Loaded LoRA weights from {weights_file}: {info}") + # network.eval() + # network.to(device) + # except Exception as e: + if args.lycoris: + lycoris_net.merge_to(None, transformer, weights_sd, dtype=None, device=device) + else: + network.merge_to(None, transformer, weights_sd, device=device, non_blocking=True) + + synchronize_device(device) + + logger.info("LoRA weights loaded") + + # save model here before casting to dit_weight_dtype + if args.save_merged_model: + logger.info(f"Saving merged model to {args.save_merged_model}") + mem_eff_save_file(transformer.state_dict(), args.save_merged_model) # save_file needs a lot of memory + logger.info("Merged model saved") + return + + logger.info(f"Casting model to {dit_weight_dtype}") + transformer.to(dtype=dit_weight_dtype) + + if args.fp8_fast: + logger.info("Enabling FP8 acceleration") + params_to_keep = {"norm", "bias", "time_in", "vector_in", "guidance_in", "txt_in", "img_in"} + for name, param in transformer.named_parameters(): + dtype_to_use = dit_dtype if any(keyword in name for keyword in params_to_keep) else dit_weight_dtype + param.to(dtype=dtype_to_use) + convert_fp8_linear(transformer, dit_dtype, params_to_keep=params_to_keep) + + if args.compile: + compile_backend, compile_mode, compile_dynamic, compile_fullgraph = args.compile_args + logger.info( + f"Torch Compiling[Backend: {compile_backend}; Mode: {compile_mode}; Dynamic: {compile_dynamic}; Fullgraph: {compile_fullgraph}]" + ) + torch._dynamo.config.cache_size_limit = 32 + for i, block in 
enumerate(transformer.single_blocks): + compiled_block = torch.compile( + block, backend=compile_backend, mode=compile_mode, + dynamic=compile_dynamic.lower() in "true", + fullgraph=compile_fullgraph.lower() in "true" + ) + transformer.single_blocks[i] = compiled_block + for i, block in enumerate(transformer.double_blocks): + compiled_block = torch.compile( + block, backend=compile_backend, mode=compile_mode, + dynamic=compile_dynamic.lower() in "true", + fullgraph=compile_fullgraph.lower() in "true" + ) + transformer.double_blocks[i] = compiled_block + + if blocks_to_swap > 0: + logger.info(f"Enable swap {blocks_to_swap} blocks to CPU from device: {device}") + transformer.enable_block_swap(blocks_to_swap, device, supports_backward=False) + transformer.move_to_device_except_swap_blocks(device) + transformer.prepare_block_swap_before_forward() + else: + logger.info(f"Moving model to {device}") + transformer.to(device=device) + if args.img_in_txt_in_offloading: + logger.info("Enable offloading img_in and txt_in to CPU") + transformer.enable_img_in_txt_in_offloading() + + # load scheduler + logger.info(f"Loading scheduler") + scheduler = FlowMatchDiscreteScheduler(shift=args.flow_shift, reverse=True, solver="euler") + + # Prepare timesteps + num_inference_steps = args.infer_steps + scheduler.set_timesteps(num_inference_steps, device=device) # n_tokens is not used in FlowMatchDiscreteScheduler + timesteps = scheduler.timesteps + + # Prepare generator + num_videos_per_prompt = 1 # args.num_videos # currently only support 1 video per prompt, this is a batch size + seed = args.seed + if seed is None: + seeds = [random.randint(0, 2**32 - 1) for _ in range(num_videos_per_prompt)] + elif isinstance(seed, int): + seeds = [seed + i for i in range(num_videos_per_prompt)] + else: + raise ValueError(f"Seed must be an integer or None, got {seed}.") + generator = [torch.Generator(device).manual_seed(seed) for seed in seeds] + + # Prepare noisy latents + num_channels_latents = 16 # transformer.config.in_channels + vae_scale_factor = 2 ** (4 - 1) # len(self.vae.config.block_out_channels) == 4 + + vae_ver = vae.VAE_VER + if "884" in vae_ver: + latent_video_length = (video_length - 1) // 4 + 1 + elif "888" in vae_ver: + latent_video_length = (video_length - 1) // 8 + 1 + else: + latent_video_length = video_length + + # shape = ( + # num_videos_per_prompt, + # num_channels_latents, + # latent_video_length, + # height // vae_scale_factor, + # width // vae_scale_factor, + # ) + # latents = randn_tensor(shape, generator=generator, device=device, dtype=dit_dtype) + + # make first N frames to be the same if the given seed is same + shape_of_frame = (num_videos_per_prompt, num_channels_latents, 1, height // vae_scale_factor, width // vae_scale_factor) + latents = [] + for i in range(latent_video_length): + latents.append(randn_tensor(shape_of_frame, generator=generator, device=device, dtype=dit_dtype)) + latents = torch.cat(latents, dim=2) + + if sky_reels_i2v: + # pad image_latents to match the length of video_latents + zero_latents = torch.zeros_like(latents) + zero_latents[:, :, :1, :, :] = image_latents + image_latents = zero_latents + elif hunyuan_video_i2v: + if args.i2v_stability: + t = torch.tensor([0.999]).to(device=device) + latents = latents * t + image_latents.repeat(1, 1, latent_video_length, 1, 1) * (1 - t) + + if args.video_path is not None: + # v2v inference + noise = latents + assert noise.shape == video_latents.shape, f"noise shape {noise.shape} != video_latents shape {video_latents.shape}" + + 
num_inference_steps = int(num_inference_steps * args.strength) + timestep_start = scheduler.timesteps[-num_inference_steps] # larger strength, less inference steps and more start time + t = timestep_start / 1000.0 + latents = noise * t + video_latents * (1 - t) + + timesteps = timesteps[-num_inference_steps:] + + logger.info(f"strength: {args.strength}, num_inference_steps: {num_inference_steps}, timestep_start: {timestep_start}") + + # FlowMatchDiscreteScheduler does not have init_noise_sigma + + # Denoising loop + embedded_guidance_scale = args.embedded_cfg_scale + if embedded_guidance_scale is not None: + guidance_expand = torch.tensor([embedded_guidance_scale * 1000.0] * latents.shape[0], dtype=torch.float32, device="cpu") + guidance_expand = guidance_expand.to(device=device, dtype=dit_dtype) + if do_classifier_free_guidance: + guidance_expand = torch.cat([guidance_expand, guidance_expand], dim=0) + else: + guidance_expand = None + freqs_cos, freqs_sin = get_rotary_pos_embed(vae_ver, transformer, video_length, height, width) + # n_tokens = freqs_cos.shape[0] + + # move and cast all inputs to the correct device and dtype + prompt_embeds = prompt_embeds.to(device=device, dtype=dit_dtype) + prompt_mask = prompt_mask.to(device=device) + prompt_embeds_2 = prompt_embeds_2.to(device=device, dtype=dit_dtype) + prompt_mask_2 = prompt_mask_2.to(device=device) + + freqs_cos = freqs_cos.to(device=device, dtype=dit_dtype) + freqs_sin = freqs_sin.to(device=device, dtype=dit_dtype) + + num_warmup_steps = len(timesteps) - num_inference_steps * scheduler.order # this should be 0 in v2v inference + + # assert split_uncond and split_attn + if args.split_attn and do_classifier_free_guidance and not args.split_uncond: + logger.warning("split_attn is enabled, split_uncond will be enabled as well.") + args.split_uncond = True + + # we do not support "latent_concat" mode for HunyuanVideo-I2V model, only "token_replace" mode is supported + + # with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA]) as p: + with tqdm(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + latents = scheduler.scale_model_input(latents, t) + + # predict the noise residual + with torch.no_grad(), accelerator.autocast(): + if hunyuan_video_i2v: # and i2v_condition_type == "token_replace": + latents = torch.cat([image_latents, latents[:, :, 1:, :, :]], dim=2) + + latents_input = latents if not do_classifier_free_guidance else torch.cat([latents, latents], dim=0) + if sky_reels_i2v: + latents_image_input = ( + image_latents if not do_classifier_free_guidance else torch.cat([image_latents, image_latents], dim=0) + ) + latents_input = torch.cat([latents_input, latents_image_input], dim=1) # 1 or 2, C*2, F, H, W + + batch_size = 1 if args.split_uncond else latents_input.shape[0] + + noise_pred_list = [] + for j in range(0, latents_input.shape[0], batch_size): + noise_pred = transformer( # For an input image (129, 192, 336) (1, 256, 256) + latents_input[j : j + batch_size], # [1, 16, 33, 24, 42] + t.repeat(batch_size).to(device=device, dtype=dit_dtype), # [1] + text_states=prompt_embeds[j : j + batch_size], # [1, 256, 4096] + text_mask=prompt_mask[j : j + batch_size], # [1, 256] + text_states_2=prompt_embeds_2[j : j + batch_size], # [1, 768] + freqs_cos=freqs_cos, # [seqlen, head_dim] + freqs_sin=freqs_sin, # [seqlen, head_dim] + guidance=guidance_expand[j : j + batch_size], # [1] + return_dict=True, + )["x"] + noise_pred_list.append(noise_pred) + 
noise_pred = torch.cat(noise_pred_list, dim=0) + + # perform classifier free guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond) + + # # SkyReels' rescale noise config is omitted for now + # if guidance_rescale > 0.0: + # # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + # noise_pred = rescale_noise_cfg( + # noise_pred, + # noise_pred_cond, + # guidance_rescale=self.guidance_rescale, + # ) + + # compute the previous noisy sample x_t -> x_t-1 + # generator and eta arguments are not used in FlowMatchDiscreteScheduler + if not hunyuan_video_i2v: + latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0] + else: + latents = scheduler.step(noise_pred[:, :, 1:, :, :], t, latents[:, :, 1:, :, :], return_dict=False)[0] + latents = torch.concat([image_latents, latents], dim=2) + + # update progress bar + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0): + if progress_bar is not None: + progress_bar.update() + + # print(p.key_averages().table(sort_by="self_cpu_time_total", row_limit=-1)) + # print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) + + latents = latents.detach().cpu() + transformer = None + clean_memory_on_device(device) + + # Save samples + output_type = args.output_type + save_path = args.save_path # if args.save_path_suffix == "" else f"{args.save_path}_{args.save_path_suffix}" + os.makedirs(save_path, exist_ok=True) + time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S") + + if output_type == "latent" or output_type == "both": + # save latent + for i, latent in enumerate(latents): + latent_path = f"{save_path}/{time_flag}_{i}_{seeds[i]}_latent.safetensors" + + if args.no_metadata: + metadata = None + else: + metadata = { + "seeds": f"{seeds[i]}", + "prompt": f"{args.prompt}", + "height": f"{height}", + "width": f"{width}", + "video_length": f"{video_length}", + "infer_steps": f"{num_inference_steps}", + "guidance_scale": f"{args.guidance_scale}", + "embedded_cfg_scale": f"{args.embedded_cfg_scale}", + } + if args.negative_prompt is not None: + metadata["negative_prompt"] = f"{args.negative_prompt}" + sd = {"latent": latent} + save_file(sd, latent_path, metadata=metadata) + + logger.info(f"Latent save to: {latent_path}") + if output_type == "video" or output_type == "both": + # save video + videos = decode_latents(args, latents, device) + for i, sample in enumerate(videos): + original_name = "" if original_base_names is None else f"_{original_base_names[i]}" + sample = sample.unsqueeze(0) + video_path = f"{save_path}/{time_flag}_{i}_{seeds[i]}{original_name}.mp4" + save_videos_grid(sample, video_path, fps=args.fps) + logger.info(f"Sample save to: {video_path}") + elif output_type == "images": + # save images + videos = decode_latents(args, latents, device) + for i, sample in enumerate(videos): + original_name = "" if original_base_names is None else f"_{original_base_names[i]}" + sample = sample.unsqueeze(0) + image_name = f"{time_flag}_{i}_{seeds[i]}{original_name}" + save_images_grid(sample, save_path, image_name) + logger.info(f"Sample images save to: {save_path}/{image_name}") + + logger.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/hv_train.py b/hv_train.py new file mode 100644 index 0000000000000000000000000000000000000000..041c35407db1467ff88bc41b389684e23afca03b --- /dev/null +++ b/hv_train.py @@ -0,0 
+1,1719 @@ +import ast +import asyncio +from datetime import datetime +import gc +import importlib +import argparse +import math +import os +import pathlib +import re +import sys +import random +import time +import json +from multiprocessing import Value +from typing import Any, Dict, List, Optional +import accelerate +import numpy as np +from packaging.version import Version + +import huggingface_hub +import toml + +import torch +from tqdm import tqdm +from accelerate.utils import set_seed +from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs +from safetensors.torch import load_file, save_file +import transformers +from diffusers.optimization import ( + SchedulerType as DiffusersSchedulerType, + TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION, +) +from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION + +from dataset import config_utils +from hunyuan_model.models import load_transformer, get_rotary_pos_embed_by_shape +import hunyuan_model.text_encoder as text_encoder_module +from hunyuan_model.vae import load_vae +import hunyuan_model.vae as vae_module +from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler +import networks.lora as lora_module +from dataset.config_utils import BlueprintGenerator, ConfigSanitizer + +import logging + +from utils import huggingface_utils, model_utils, train_utils, sai_model_spec + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +BASE_MODEL_VERSION_HUNYUAN_VIDEO = "hunyuan_video" + +# TODO make separate file for some functions to commonize with other scripts + + +def clean_memory_on_device(device: torch.device): + r""" + Clean memory on the specified device, will be called from training scripts. + """ + gc.collect() + + # device may "cuda" or "cuda:0", so we need to check the type of device + if device.type == "cuda": + torch.cuda.empty_cache() + if device.type == "xpu": + torch.xpu.empty_cache() + if device.type == "mps": + torch.mps.empty_cache() + + +# for collate_fn: epoch and step is multiprocessing.Value +class collator_class: + def __init__(self, epoch, step, dataset): + self.current_epoch = epoch + self.current_step = step + self.dataset = dataset # not used if worker_info is not None, in case of multiprocessing + + def __call__(self, examples): + worker_info = torch.utils.data.get_worker_info() + # worker_info is None in the main process + if worker_info is not None: + dataset = worker_info.dataset + else: + dataset = self.dataset + + # set epoch and step + dataset.set_current_epoch(self.current_epoch.value) + dataset.set_current_step(self.current_step.value) + return examples[0] + + +def prepare_accelerator(args: argparse.Namespace) -> Accelerator: + """ + DeepSpeed is not supported in this script currently. 
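+    Besides constructing the Accelerator, this sets up the logging backend (TensorBoard or wandb),
+    selects the gloo backend on Windows, and attaches the DDP kwargs handlers when they are needed.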
+ """ + if args.logging_dir is None: + logging_dir = None + else: + log_prefix = "" if args.log_prefix is None else args.log_prefix + logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime()) + + if args.log_with is None: + if logging_dir is not None: + log_with = "tensorboard" + else: + log_with = None + else: + log_with = args.log_with + if log_with in ["tensorboard", "all"]: + if logging_dir is None: + raise ValueError( + "logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください" + ) + if log_with in ["wandb", "all"]: + try: + import wandb + except ImportError: + raise ImportError("No wandb / wandb がインストールされていないようです") + if logging_dir is not None: + os.makedirs(logging_dir, exist_ok=True) + os.environ["WANDB_DIR"] = logging_dir + if args.wandb_api_key is not None: + wandb.login(key=args.wandb_api_key) + + kwargs_handlers = [ + ( + InitProcessGroupKwargs( + backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", + init_method=( + "env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None + ), + timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None, + ) + if torch.cuda.device_count() > 1 + else None + ), + ( + DistributedDataParallelKwargs( + gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph + ) + if args.ddp_gradient_as_bucket_view or args.ddp_static_graph + else None + ), + ] + kwargs_handlers = [i for i in kwargs_handlers if i is not None] + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=log_with, + project_dir=logging_dir, + kwargs_handlers=kwargs_handlers, + ) + print("accelerator device:", accelerator.device) + return accelerator + + +def line_to_prompt_dict(line: str) -> dict: + # subset of gen_img_diffusers + prompt_args = line.split(" --") + prompt_dict = {} + prompt_dict["prompt"] = prompt_args[0] + + for parg in prompt_args: + try: + m = re.match(r"w (\d+)", parg, re.IGNORECASE) + if m: + prompt_dict["width"] = int(m.group(1)) + continue + + m = re.match(r"h (\d+)", parg, re.IGNORECASE) + if m: + prompt_dict["height"] = int(m.group(1)) + continue + + m = re.match(r"f (\d+)", parg, re.IGNORECASE) + if m: + prompt_dict["frame_count"] = int(m.group(1)) + continue + + m = re.match(r"d (\d+)", parg, re.IGNORECASE) + if m: + prompt_dict["seed"] = int(m.group(1)) + continue + + m = re.match(r"s (\d+)", parg, re.IGNORECASE) + if m: # steps + prompt_dict["sample_steps"] = max(1, min(1000, int(m.group(1)))) + continue + + # m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) + # if m: # scale + # prompt_dict["scale"] = float(m.group(1)) + # continue + # m = re.match(r"n (.+)", parg, re.IGNORECASE) + # if m: # negative prompt + # prompt_dict["negative_prompt"] = m.group(1) + # continue + + except ValueError as ex: + logger.error(f"Exception in parsing / 解析エラー: {parg}") + logger.error(ex) + + return prompt_dict + + +def load_prompts(prompt_file: str) -> list[Dict]: + # read prompts + if prompt_file.endswith(".txt"): + with open(prompt_file, "r", encoding="utf-8") as f: + lines = f.readlines() + prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"] + elif prompt_file.endswith(".toml"): + with open(prompt_file, "r", encoding="utf-8") as f: + data = toml.load(f) + prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]] + 
elif prompt_file.endswith(".json"): + with open(prompt_file, "r", encoding="utf-8") as f: + prompts = json.load(f) + + # preprocess prompts + for i in range(len(prompts)): + prompt_dict = prompts[i] + if isinstance(prompt_dict, str): + prompt_dict = line_to_prompt_dict(prompt_dict) + prompts[i] = prompt_dict + assert isinstance(prompt_dict, dict) + + # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. + prompt_dict["enum"] = i + prompt_dict.pop("subset", None) + + return prompts + + +def compute_density_for_timestep_sampling( + weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None +): + """Compute the density for sampling the timesteps when doing SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "logit_normal": + # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). + u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") + u = torch.nn.functional.sigmoid(u) + elif weighting_scheme == "mode": + u = torch.rand(size=(batch_size,), device="cpu") + u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + else: + u = torch.rand(size=(batch_size,), device="cpu") + return u + + +def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype) + schedule_timesteps = noise_scheduler.timesteps.to(device) + timesteps = timesteps.to(device) + + # if sum([(schedule_timesteps == t) for t in timesteps]) < len(timesteps): + if any([(schedule_timesteps == t).sum() == 0 for t in timesteps]): + # raise ValueError("Some timesteps are not in the schedule / 一部のtimestepsがスケジュールに含まれていません") + # round to nearest timestep + logger.warning("Some timesteps are not in the schedule / 一部のtimestepsがスケジュールに含まれていません") + step_indices = [torch.argmin(torch.abs(schedule_timesteps - t)).item() for t in timesteps] + else: + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + +def compute_loss_weighting_for_sd3(weighting_scheme: str, noise_scheduler, timesteps, device, dtype): + """Computes loss weighting scheme for SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. 
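+
+    In closed form (matching the implementation below): ``sigma_sqrt`` uses w(sigma) = sigma**-2,
+    and ``cosmap`` uses w(sigma) = 2 / (pi * (1 - 2*sigma + 2*sigma**2)); other schemes return None.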
+ """ + if weighting_scheme == "sigma_sqrt" or weighting_scheme == "cosmap": + sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=5, dtype=dtype) + if weighting_scheme == "sigma_sqrt": + weighting = (sigmas**-2.0).float() + else: + bot = 1 - 2 * sigmas + 2 * sigmas**2 + weighting = 2 / (math.pi * bot) + else: + weighting = None # torch.ones_like(sigmas) + return weighting + + +class FineTuningTrainer: + def __init__(self): + pass + + def process_sample_prompts( + self, + args: argparse.Namespace, + accelerator: Accelerator, + sample_prompts: str, + text_encoder1: str, + text_encoder2: str, + fp8_llm: bool, + ): + logger.info(f"cache Text Encoder outputs for sample prompt: {sample_prompts}") + prompts = load_prompts(sample_prompts) + + def encode_for_text_encoder(text_encoder, is_llm=True): + sample_prompts_te_outputs = {} # (prompt) -> (embeds, mask) + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + + data_type = "video" + text_inputs = text_encoder.text2tokens(p, data_type=data_type) + + prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type) + sample_prompts_te_outputs[p] = (prompt_outputs.hidden_state, prompt_outputs.attention_mask) + + return sample_prompts_te_outputs + + # Load Text Encoder 1 and encode + text_encoder_dtype = torch.float16 if args.text_encoder_dtype is None else model_utils.str_to_dtype(args.text_encoder_dtype) + logger.info(f"loading text encoder 1: {text_encoder1}") + text_encoder_1 = text_encoder_module.load_text_encoder_1(text_encoder1, accelerator.device, fp8_llm, text_encoder_dtype) + + logger.info("encoding with Text Encoder 1") + te_outputs_1 = encode_for_text_encoder(text_encoder_1) + del text_encoder_1 + + # Load Text Encoder 2 and encode + logger.info(f"loading text encoder 2: {text_encoder2}") + text_encoder_2 = text_encoder_module.load_text_encoder_2(text_encoder2, accelerator.device, text_encoder_dtype) + + logger.info("encoding with Text Encoder 2") + te_outputs_2 = encode_for_text_encoder(text_encoder_2, is_llm=False) + del text_encoder_2 + + # prepare sample parameters + sample_parameters = [] + for prompt_dict in prompts: + prompt_dict_copy = prompt_dict.copy() + p = prompt_dict.get("prompt", "") + prompt_dict_copy["llm_embeds"] = te_outputs_1[p][0] + prompt_dict_copy["llm_mask"] = te_outputs_1[p][1] + prompt_dict_copy["clipL_embeds"] = te_outputs_2[p][0] + prompt_dict_copy["clipL_mask"] = te_outputs_2[p][1] + sample_parameters.append(prompt_dict_copy) + + clean_memory_on_device(accelerator.device) + + return sample_parameters + + def get_optimizer(self, args, trainable_params: list[torch.nn.Parameter]) -> tuple[str, str, torch.optim.Optimizer]: + # adamw, adamw8bit, adafactor + + optimizer_type = args.optimizer_type.lower() + + # split optimizer_type and optimizer_args + optimizer_kwargs = {} + if args.optimizer_args is not None and len(args.optimizer_args) > 0: + for arg in args.optimizer_args: + key, value = arg.split("=") + value = ast.literal_eval(value) + optimizer_kwargs[key] = value + + lr = args.learning_rate + optimizer = None + optimizer_class = None + + if optimizer_type.endswith("8bit".lower()): + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです") + + if optimizer_type == "AdamW8bit".lower(): + logger.info(f"use 8-bit AdamW optimizer | {optimizer_kwargs}") + 
optimizer_class = bnb.optim.AdamW8bit + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + elif optimizer_type == "Adafactor".lower(): + # Adafactor: check relative_step and warmup_init + if "relative_step" not in optimizer_kwargs: + optimizer_kwargs["relative_step"] = True # default + if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False): + logger.info( + f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします" + ) + optimizer_kwargs["relative_step"] = True + logger.info(f"use Adafactor optimizer | {optimizer_kwargs}") + + if optimizer_kwargs["relative_step"]: + logger.info(f"relative_step is true / relative_stepがtrueです") + if lr != 0.0: + logger.warning(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます") + args.learning_rate = None + + if args.lr_scheduler != "adafactor": + logger.info(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します") + args.lr_scheduler = f"adafactor:{lr}" # ちょっと微妙だけど + + lr = None + else: + if args.max_grad_norm != 0.0: + logger.warning( + f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません" + ) + if args.lr_scheduler != "constant_with_warmup": + logger.warning(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません") + if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0: + logger.warning(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません") + + optimizer_class = transformers.optimization.Adafactor + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + elif optimizer_type == "AdamW".lower(): + logger.info(f"use AdamW optimizer | {optimizer_kwargs}") + optimizer_class = torch.optim.AdamW + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + if optimizer is None: + # 任意のoptimizerを使う + case_sensitive_optimizer_type = args.optimizer_type # not lower + logger.info(f"use {case_sensitive_optimizer_type} | {optimizer_kwargs}") + + if "." not in case_sensitive_optimizer_type: # from torch.optim + optimizer_module = torch.optim + else: # from other library + values = case_sensitive_optimizer_type.split(".") + optimizer_module = importlib.import_module(".".join(values[:-1])) + case_sensitive_optimizer_type = values[-1] + + optimizer_class = getattr(optimizer_module, case_sensitive_optimizer_type) + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + # for logging + optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__ + optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()]) + + # get train and eval functions + if hasattr(optimizer, "train") and callable(optimizer.train): + train_fn = optimizer.train + eval_fn = optimizer.eval + else: + train_fn = lambda: None + eval_fn = lambda: None + + return optimizer_name, optimizer_args, optimizer, train_fn, eval_fn + + def is_schedulefree_optimizer(self, optimizer: torch.optim.Optimizer, args: argparse.Namespace) -> bool: + return args.optimizer_type.lower().endswith("schedulefree".lower()) # or args.optimizer_schedulefree_wrapper + + def get_dummy_scheduler(optimizer: torch.optim.Optimizer) -> Any: + # dummy scheduler for schedulefree optimizer. supports only empty step(), get_last_lr() and optimizers. + # this scheduler is used for logging only. 
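+        # (schedule-free optimizers manage their learning-rate behaviour internally, so no real
+        # LR scheduler is required; only get_last_lr() has to work for the logging code.)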
+ # this isn't be wrapped by accelerator because of this class is not a subclass of torch.optim.lr_scheduler._LRScheduler + class DummyScheduler: + def __init__(self, optimizer: torch.optim.Optimizer): + self.optimizer = optimizer + + def step(self): + pass + + def get_last_lr(self): + return [group["lr"] for group in self.optimizer.param_groups] + + return DummyScheduler(optimizer) + + def get_scheduler(self, args, optimizer: torch.optim.Optimizer, num_processes: int): + """ + Unified API to get any scheduler from its name. + """ + # if schedulefree optimizer, return dummy scheduler + if self.is_schedulefree_optimizer(optimizer, args): + return self.get_dummy_scheduler(optimizer) + + name = args.lr_scheduler + num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps + num_warmup_steps: Optional[int] = ( + int(args.lr_warmup_steps * num_training_steps) if isinstance(args.lr_warmup_steps, float) else args.lr_warmup_steps + ) + num_decay_steps: Optional[int] = ( + int(args.lr_decay_steps * num_training_steps) if isinstance(args.lr_decay_steps, float) else args.lr_decay_steps + ) + num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps + num_cycles = args.lr_scheduler_num_cycles + power = args.lr_scheduler_power + timescale = args.lr_scheduler_timescale + min_lr_ratio = args.lr_scheduler_min_lr_ratio + + lr_scheduler_kwargs = {} # get custom lr_scheduler kwargs + if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0: + for arg in args.lr_scheduler_args: + key, value = arg.split("=") + value = ast.literal_eval(value) + lr_scheduler_kwargs[key] = value + + def wrap_check_needless_num_warmup_steps(return_vals): + if num_warmup_steps is not None and num_warmup_steps != 0: + raise ValueError(f"{name} does not require `num_warmup_steps`. Set None or 0.") + return return_vals + + # using any lr_scheduler from other library + if args.lr_scheduler_type: + lr_scheduler_type = args.lr_scheduler_type + logger.info(f"use {lr_scheduler_type} | {lr_scheduler_kwargs} as lr_scheduler") + if "." 
not in lr_scheduler_type: # default to use torch.optim + lr_scheduler_module = torch.optim.lr_scheduler + else: + values = lr_scheduler_type.split(".") + lr_scheduler_module = importlib.import_module(".".join(values[:-1])) + lr_scheduler_type = values[-1] + lr_scheduler_class = getattr(lr_scheduler_module, lr_scheduler_type) + lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs) + return lr_scheduler + + if name.startswith("adafactor"): + assert ( + type(optimizer) == transformers.optimization.Adafactor + ), f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください" + initial_lr = float(name.split(":")[1]) + # logger.info(f"adafactor scheduler init lr {initial_lr}") + return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr)) + + if name == DiffusersSchedulerType.PIECEWISE_CONSTANT.value: + name = DiffusersSchedulerType(name) + schedule_func = DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[name] + return schedule_func(optimizer, **lr_scheduler_kwargs) # step_rules and last_epoch are given as kwargs + + name = SchedulerType(name) + schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] + + if name == SchedulerType.CONSTANT: + return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs)) + + # All other schedulers require `num_warmup_steps` + if num_warmup_steps is None: + raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") + + if name == SchedulerType.CONSTANT_WITH_WARMUP: + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **lr_scheduler_kwargs) + + if name == SchedulerType.INVERSE_SQRT: + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, timescale=timescale, **lr_scheduler_kwargs) + + # All other schedulers require `num_training_steps` + if num_training_steps is None: + raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") + + if name == SchedulerType.COSINE_WITH_RESTARTS: + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + num_cycles=num_cycles, + **lr_scheduler_kwargs, + ) + + if name == SchedulerType.POLYNOMIAL: + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + power=power, + **lr_scheduler_kwargs, + ) + + if name == SchedulerType.COSINE_WITH_MIN_LR: + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + num_cycles=num_cycles / 2, + min_lr_rate=min_lr_ratio, + **lr_scheduler_kwargs, + ) + + # these schedulers do not require `num_decay_steps` + if name == SchedulerType.LINEAR or name == SchedulerType.COSINE: + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + **lr_scheduler_kwargs, + ) + + # All other schedulers require `num_decay_steps` + if num_decay_steps is None: + raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.") + if name == SchedulerType.WARMUP_STABLE_DECAY: + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_stable_steps=num_stable_steps, + num_decay_steps=num_decay_steps, + num_cycles=num_cycles / 2, + min_lr_ratio=min_lr_ratio if min_lr_ratio is not None else 0.0, + **lr_scheduler_kwargs, + ) + + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + 
num_decay_steps=num_decay_steps, + **lr_scheduler_kwargs, + ) + + def resume_from_local_or_hf_if_specified(self, accelerator: Accelerator, args: argparse.Namespace) -> bool: + if not args.resume: + return False + + if not args.resume_from_huggingface: + logger.info(f"resume training from local state: {args.resume}") + accelerator.load_state(args.resume) + return True + + logger.info(f"resume training from huggingface state: {args.resume}") + repo_id = args.resume.split("/")[0] + "/" + args.resume.split("/")[1] + path_in_repo = "/".join(args.resume.split("/")[2:]) + revision = None + repo_type = None + if ":" in path_in_repo: + divided = path_in_repo.split(":") + if len(divided) == 2: + path_in_repo, revision = divided + repo_type = "model" + else: + path_in_repo, revision, repo_type = divided + logger.info(f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}") + + list_files = huggingface_utils.list_dir( + repo_id=repo_id, + subfolder=path_in_repo, + revision=revision, + token=args.huggingface_token, + repo_type=repo_type, + ) + + async def download(filename) -> str: + def task(): + return huggingface_hub.hf_hub_download( + repo_id=repo_id, + filename=filename, + revision=revision, + repo_type=repo_type, + token=args.huggingface_token, + ) + + return await asyncio.get_event_loop().run_in_executor(None, task) + + loop = asyncio.get_event_loop() + results = loop.run_until_complete(asyncio.gather(*[download(filename=filename.rfilename) for filename in list_files])) + if len(results) == 0: + raise ValueError( + "No files found in the specified repo id/path/revision / 指定されたリポジトリID/パス/リビジョンにファイルが見つかりませんでした" + ) + dirname = os.path.dirname(results[0]) + accelerator.load_state(dirname) + + return True + + def sample_images(self, accelerator, args, epoch, global_step, device, vae, transformer, sample_parameters): + pass + + def get_noisy_model_input_and_timesteps( + self, + args: argparse.Namespace, + noise: torch.Tensor, + latents: torch.Tensor, + noise_scheduler: FlowMatchDiscreteScheduler, + device: torch.device, + dtype: torch.dtype, + ): + batch_size = noise.shape[0] + + if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid" or args.timestep_sampling == "shift": + if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": + # Simple random t-based noise sampling + if args.timestep_sampling == "sigmoid": + t = torch.sigmoid(args.sigmoid_scale * torch.randn((batch_size,), device=device)) + else: + t = torch.rand((batch_size,), device=device) + + elif args.timestep_sampling == "shift": + shift = args.discrete_flow_shift + logits_norm = torch.randn(batch_size, device=device) + logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling + t = logits_norm.sigmoid() + t = (t * shift) / (1 + (shift - 1) * t) + + t_min = args.min_timestep if args.min_timestep is not None else 0 + t_max = args.max_timestep if args.max_timestep is not None else 1000.0 + t_min /= 1000.0 + t_max /= 1000.0 + t = t * (t_max - t_min) + t_min # scale to [t_min, t_max], default [0, 1] + + timesteps = t * 1000.0 + t = t.view(-1, 1, 1, 1, 1) + noisy_model_input = (1 - t) * latents + t * noise + + timesteps += 1 # 1 to 1000 + else: + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=batch_size, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) 
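+            # The sampled density value u is rescaled into [min_timestep, max_timestep) and truncated
+            # to an integer index, which is then used to look up a discrete timestep (and its sigma)
+            # from the flow-matching noise scheduler.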
+ # indices = (u * noise_scheduler.config.num_train_timesteps).long() + t_min = args.min_timestep if args.min_timestep is not None else 0 + t_max = args.max_timestep if args.max_timestep is not None else 1000 + indices = (u * (t_max - t_min) + t_min).long() + + timesteps = noise_scheduler.timesteps[indices].to(device=device) # 1 to 1000 + + # Add noise according to flow matching. + sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) + noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + + return noisy_model_input, timesteps + + def train(self, args): + if args.seed is None: + args.seed = random.randint(0, 2**32) + set_seed(args.seed) + + # Load dataset config + blueprint_generator = BlueprintGenerator(ConfigSanitizer()) + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_utils.load_user_config(args.dataset_config) + blueprint = blueprint_generator.generate(user_config, args) + train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group, training=True) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = collator_class(current_epoch, current_step, ds_for_collator) + + # prepare accelerator + logger.info("preparing accelerator") + accelerator = prepare_accelerator(args) + is_main_process = accelerator.is_main_process + + # prepare dtype + weight_dtype = torch.float32 + if args.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif args.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # HunyuanVideo specific + vae_dtype = torch.float16 if args.vae_dtype is None else model_utils.str_to_dtype(args.vae_dtype) + + # get embedding for sampling images + sample_parameters = vae = None + if args.sample_prompts: + sample_parameters = self.process_sample_prompts( + args, accelerator, args.sample_prompts, args.text_encoder1, args.text_encoder2, args.fp8_llm + ) + + # Load VAE model for sampling images: VAE is loaded to cpu to save gpu memory + vae, _, s_ratio, t_ratio = load_vae(vae_dtype=vae_dtype, device="cpu", vae_path=args.vae) + vae.requires_grad_(False) + vae.eval() + + if args.vae_chunk_size is not None: + vae.set_chunk_size_for_causal_conv_3d(args.vae_chunk_size) + logger.info(f"Set chunk_size to {args.vae_chunk_size} for CausalConv3d in VAE") + if args.vae_spatial_tile_sample_min_size is not None: + vae.enable_spatial_tiling(True) + vae.tile_sample_min_size = args.vae_spatial_tile_sample_min_size + vae.tile_latent_min_size = args.vae_spatial_tile_sample_min_size // 8 + elif args.vae_tiling: + vae.enable_spatial_tiling(True) + + # load DiT model + blocks_to_swap = args.blocks_to_swap if args.blocks_to_swap else 0 + loading_device = "cpu" if blocks_to_swap > 0 else accelerator.device + + logger.info(f"Loading DiT model from {args.dit}") + if args.sdpa: + attn_mode = "torch" + elif args.flash_attn: + attn_mode = "flash" + elif args.sage_attn: + attn_mode = "sageattn" + elif args.xformers: + attn_mode = "xformers" + else: + raise ValueError( + f"either --sdpa, --flash-attn, --sage-attn or --xformers must be specified / --sdpa, --flash-attn, --sage-attn, --xformersのいずれかを指定してください" + ) + transformer = load_transformer(args.dit, attn_mode, args.split_attn, loading_device, None) # load as is + + if blocks_to_swap > 0: + logger.info(f"enable swap {blocks_to_swap} blocks to CPU from device: {accelerator.device}") + 
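Both sampling branches above feed the DiT a linear interpolation between the clean latents and Gaussian noise; since `x_t = (1 - t) * x0 + t * noise`, its derivative with respect to `t` is `noise - x0`, which is the quantity the model is later trained to regress. A self-contained toy example of that construction, with illustrative shapes and names rather than the script's real tensors:

```python
# Toy illustration of the flow matching setup used above (not the training
# script): x_t = (1 - t) * x0 + t * noise, so dx_t/dt = noise - x0, which is
# the regression target for the model prediction.
import torch

def make_flow_matching_pair(x0: torch.Tensor, t: torch.Tensor):
    noise = torch.randn_like(x0)
    t = t.view(-1, *([1] * (x0.ndim - 1)))   # broadcast t over all non-batch dims
    x_t = (1 - t) * x0 + t * noise           # noisy model input
    target = noise - x0                      # flow matching (velocity) target
    return x_t, target

if __name__ == "__main__":
    latents = torch.randn(2, 16, 1, 8, 8)    # toy latent shape (B, C, F, H, W)
    timesteps = torch.rand(2)                # t in [0, 1)
    x_t, target = make_flow_matching_pair(latents, timesteps)
    print(x_t.shape, target.shape)
```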
transformer.enable_block_swap(blocks_to_swap, accelerator.device, supports_backward=True) + transformer.move_to_device_except_swap_blocks(accelerator.device) + if args.img_in_txt_in_offloading: + logger.info("Enable offloading img_in and txt_in to CPU") + transformer.enable_img_in_txt_in_offloading() + + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + + # prepare optimizer, data loader etc. + accelerator.print("prepare optimizer, data loader etc.") + + transformer.requires_grad_(False) + if accelerator.is_main_process: + accelerator.print( + f"Trainable modules '{args.trainable_modules}'." + ) + for name, param in transformer.named_parameters(): + for trainable_module_name in args.trainable_modules: + if trainable_module_name in name: + param.requires_grad = True + break + + total_params = list(transformer.parameters()) + trainable_params = list(filter(lambda p: p.requires_grad, transformer.parameters())) + logger.info(f"number of trainable parameters: {sum(p.numel() for p in trainable_params) / 1e6} M, total parameters: {sum(p.numel() for p in total_params) / 1e6} M") + optimizer_name, optimizer_args, optimizer, optimizer_train_fn, optimizer_eval_fn = self.get_optimizer( + args, trainable_params + ) + + # prepare dataloader + + # num workers for data loader: if 0, persistent_workers is not available + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers + + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # calculate max_train_steps + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # send max_train_steps to train_dataset_group + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # prepare lr_scheduler + lr_scheduler = self.get_scheduler(args, optimizer, accelerator.num_processes) + + # prepare training model.
accelerator does some magic here + + # experimental feature: train the model with gradients in fp16/bf16 + dit_dtype = torch.float32 + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + dit_weight_dtype = torch.float16 + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + dit_weight_dtype = torch.bfloat16 + else: + dit_weight_dtype = torch.float32 + + # TODO add fused optimizer and stochastic rounding + + # cast model to dit_weight_dtype + # if dit_dtype != dit_weight_dtype: + logger.info(f"casting model to {dit_weight_dtype}") + transformer.to(dit_weight_dtype) + + if blocks_to_swap > 0: + transformer = accelerator.prepare(transformer, device_placement=[not blocks_to_swap > 0]) + accelerator.unwrap_model(transformer).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + accelerator.unwrap_model(transformer).prepare_block_swap_before_forward() + else: + transformer = accelerator.prepare(transformer) + + optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + + transformer.train() + + if args.full_fp16: + # patch accelerator for fp16 training + # def patch_accelerator_for_fp16_training(accelerator): + org_unscale_grads = accelerator.scaler._unscale_grads_ + + def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): + return org_unscale_grads(optimizer, inv_scale, found_inf, True) + + accelerator.scaler._unscale_grads_ = _unscale_grads_replacer + + # resume from local or huggingface. 
accelerator.step is set + self.resume_from_local_or_hf_if_specified(accelerator, args) # accelerator.load_state(args.resume) + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # 学習する + # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + accelerator.print("running training / 学習開始") + accelerator.print(f" num train items / 学習画像、動画数: {train_dataset_group.num_train_items}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # accelerator.print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + "hunyuan_video_ft" if args.log_tracker_name is None else args.log_tracker_name, + config=train_utils.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + # TODO skip until initial step + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + + epoch_to_start = 0 + global_step = 0 + noise_scheduler = FlowMatchDiscreteScheduler(shift=args.discrete_flow_shift, reverse=True, solver="euler") + + loss_recorder = train_utils.LossRecorder() + del train_dataset_group + + # function for saving/removing + def save_model(ckpt_name: str, unwrapped_nw, steps, epoch_no, force_sync_upload=False): + os.makedirs(args.output_dir, exist_ok=True) + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + accelerator.print(f"\nsaving checkpoint: {ckpt_file}") + + title = args.metadata_title if args.metadata_title is not None else args.output_name + if args.min_timestep is not None or args.max_timestep is not None: + min_time_step = args.min_timestep if args.min_timestep is not None else 0 + max_time_step = args.max_timestep if args.max_timestep is not None else 1000 + md_timesteps = (min_time_step, max_time_step) + else: + md_timesteps = None + + sai_metadata = sai_model_spec.build_metadata( + None, + time.time(), + title, + None, + args.metadata_author, + args.metadata_description, + args.metadata_license, + args.metadata_tags, + timesteps=md_timesteps, + is_lora=False, + ) + + save_file(unwrapped_nw.state_dict(), ckpt_file, sai_metadata) + if args.huggingface_repo_id is not None: + huggingface_utils.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) + + def remove_model(old_ckpt_name): + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + accelerator.print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + # For --sample_at_first + optimizer_eval_fn() + self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, transformer, sample_parameters) + optimizer_train_fn() + if 
len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) + + # training loop + + # log device and dtype for each model + logger.info(f"DiT dtype: {transformer.dtype}, device: {transformer.device}") + + clean_memory_on_device(accelerator.device) + + pos_embed_cache = {} + + for epoch in range(epoch_to_start, num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + for step, batch in enumerate(train_dataloader): + latents, llm_embeds, llm_mask, clip_embeds = batch + bsz = latents.shape[0] + current_step.value = global_step + + with accelerator.accumulate(transformer): + latents = latents * vae_module.SCALING_FACTOR + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + + # calculate model input and timesteps + noisy_model_input, timesteps = self.get_noisy_model_input_and_timesteps( + args, noise, latents, noise_scheduler, accelerator.device, dit_dtype + ) + + weighting = compute_loss_weighting_for_sd3( + args.weighting_scheme, noise_scheduler, timesteps, accelerator.device, dit_dtype + ) + + # ensure guidance_scale in args is float + guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) # , dtype=dit_dtype) + + # ensure the hidden state will require grad + if args.gradient_checkpointing: + noisy_model_input.requires_grad_(True) + guidance_vec.requires_grad_(True) + + pos_emb_shape = latents.shape[1:] + if pos_emb_shape not in pos_embed_cache: + freqs_cos, freqs_sin = get_rotary_pos_embed_by_shape(accelerator.unwrap_model(transformer), latents.shape[2:]) + # freqs_cos = freqs_cos.to(device=accelerator.device, dtype=dit_dtype) + # freqs_sin = freqs_sin.to(device=accelerator.device, dtype=dit_dtype) + pos_embed_cache[pos_emb_shape] = (freqs_cos, freqs_sin) + else: + freqs_cos, freqs_sin = pos_embed_cache[pos_emb_shape] + + # call DiT + latents = latents.to(device=accelerator.device, dtype=dit_dtype) + noisy_model_input = noisy_model_input.to(device=accelerator.device, dtype=dit_dtype) + # timesteps = timesteps.to(device=accelerator.device, dtype=dit_dtype) + # llm_embeds = llm_embeds.to(device=accelerator.device, dtype=dit_dtype) + # llm_mask = llm_mask.to(device=accelerator.device) + # clip_embeds = clip_embeds.to(device=accelerator.device, dtype=dit_dtype) + with accelerator.autocast(): + model_pred = transformer( + noisy_model_input, + timesteps, + text_states=llm_embeds, + text_mask=llm_mask, + text_states_2=clip_embeds, + freqs_cos=freqs_cos, + freqs_sin=freqs_sin, + guidance=guidance_vec, + return_dict=False, + ) + + # flow matching loss + target = noise - latents + + loss = torch.nn.functional.mse_loss(model_pred.to(dit_dtype), target, reduction="none") + + if weighting is not None: + loss = loss * weighting + # loss = loss.mean([1, 2, 3]) + # # min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc. 
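The optional weighting applied just above comes from `compute_loss_weighting_for_sd3` (shown further down in this diff): `sigma_sqrt` scales the per-element MSE by `sigma**-2`, `cosmap` by `2 / (pi * (1 - 2*sigma + 2*sigma**2))`, and `none` leaves the plain MSE untouched before the mean. A minimal sketch of that step, assuming only PyTorch and a sigma tensor that is already broadcastable to the prediction shape:

```python
# Sketch of the weighted flow matching loss used above (illustrative only;
# sigma here plays the role of get_sigmas' output in the real script).
import math
import torch

def weighted_flow_loss(pred: torch.Tensor, target: torch.Tensor, sigma: torch.Tensor, scheme: str = "none") -> torch.Tensor:
    loss = torch.nn.functional.mse_loss(pred, target, reduction="none")
    if scheme == "sigma_sqrt":
        loss = loss * sigma.float() ** -2.0
    elif scheme == "cosmap":
        loss = loss * (2.0 / (math.pi * (1 - 2 * sigma + 2 * sigma**2)))
    return loss.mean()

if __name__ == "__main__":
    pred = torch.randn(2, 16, 1, 8, 8)
    target = torch.randn_like(pred)
    sigma = torch.rand(2).view(-1, 1, 1, 1, 1)  # per-sample sigma, broadcastable
    print(weighted_flow_loss(pred, target, sigma, "cosmap").item())
```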
+ # loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + accelerator.backward(loss) + if accelerator.sync_gradients: + # self.all_reduce_network(accelerator, network) # sync DDP grad manually + state = accelerate.PartialState() + if state.distributed_type != accelerate.DistributedType.NO: + for param in transformer.parameters(): + if param.grad is not None: + param.grad = accelerator.reduce(param.grad, reduction="mean") + + if args.max_grad_norm != 0.0: + params_to_clip = accelerator.unwrap_model(transformer).parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + optimizer_eval_fn() + self.sample_images( + accelerator, args, None, global_step, accelerator.device, vae, transformer, sample_parameters + ) + + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + ckpt_name = train_utils.get_step_ckpt_name(args.output_name, global_step) + save_model(ckpt_name, accelerator.unwrap_model(transformer), global_step, epoch) + + if args.save_state: + train_utils.save_and_remove_state_stepwise(args, accelerator, global_step) + + remove_step_no = train_utils.get_remove_step_no(args, global_step) + if remove_step_no is not None: + remove_ckpt_name = train_utils.get_step_ckpt_name(args.output_name, remove_step_no) + remove_model(remove_ckpt_name) + optimizer_train_fn() + + current_loss = loss.detach().item() + loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + avr_loss: float = loss_recorder.moving_average + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if len(accelerator.trackers) > 0: + logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]} + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if len(accelerator.trackers) > 0: + logs = {"loss/epoch": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + # 指定エポックごとにモデルを保存 + optimizer_eval_fn() + if args.save_every_n_epochs is not None: + saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs + if is_main_process and saving: + ckpt_name = train_utils.get_epoch_ckpt_name(args.output_name, epoch + 1) + save_model(ckpt_name, accelerator.unwrap_model(transformer), global_step, epoch + 1) + + remove_epoch_no = train_utils.get_remove_epoch_no(args, epoch + 1) + if remove_epoch_no is not None: + remove_ckpt_name = train_utils.get_epoch_ckpt_name(args.output_name, remove_epoch_no) + remove_model(remove_ckpt_name) + + if args.save_state: + train_utils.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) + + self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, transformer, sample_parameters) + optimizer_train_fn() + + # end of epoch + + if is_main_process: + transformer = accelerator.unwrap_model(transformer) + + accelerator.end_training() + optimizer_eval_fn() + + if args.save_state or args.save_state_on_train_end: + train_utils.save_state_on_train_end(args, accelerator) + + if is_main_process: + ckpt_name = 
train_utils.get_last_ckpt_name(args.output_name) + save_model(ckpt_name, transformer, global_step, num_train_epochs, force_sync_upload=True) + + logger.info("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + def int_or_float(value): + if value.endswith("%"): + try: + return float(value[:-1]) / 100.0 + except ValueError: + raise argparse.ArgumentTypeError(f"Value '{value}' is not a valid percentage") + try: + float_value = float(value) + if float_value >= 1 and float_value.is_integer(): + return int(value) + return float(value) + except ValueError: + raise argparse.ArgumentTypeError(f"'{value}' is not an int or float") + + parser = argparse.ArgumentParser() + + # general settings + parser.add_argument( + "--config_file", + type=str, + default=None, + help="using .toml instead of args to pass hyperparameter / ハイパーパラメータを引数ではなく.tomlファイルで渡す", + ) + parser.add_argument( + "--dataset_config", + type=pathlib.Path, + default=None, + required=True, + help="config file for dataset / データセットの設定ファイル", + ) + + # training settings + parser.add_argument( + "--sdpa", + action="store_true", + help="use sdpa for CrossAttention (requires PyTorch 2.0) / CrossAttentionにsdpaを使う(PyTorch 2.0が必要)", + ) + parser.add_argument( + "--flash_attn", + action="store_true", + help="use FlashAttention for CrossAttention, requires FlashAttention / CrossAttentionにFlashAttentionを使う、FlashAttentionが必要", + ) + parser.add_argument( + "--sage_attn", + action="store_true", + help="use SageAttention. requires SageAttention / SageAttentionを使う。SageAttentionが必要", + ) + parser.add_argument( + "--xformers", + action="store_true", + help="use xformers for CrossAttention, requires xformers / CrossAttentionにxformersを使う、xformersが必要", + ) + parser.add_argument( + "--split_attn", + action="store_true", + help="use split attention for attention calculation (split batch size=1, affects memory usage and speed)" + " / attentionを分割して計算する(バッチサイズ=1に分割、メモリ使用量と速度に影響)", + ) + + parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数") + parser.add_argument( + "--max_train_epochs", + type=int, + default=None, + help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)", + ) + parser.add_argument( + "--max_data_loader_n_workers", + type=int, + default=8, + help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)", + ) + parser.add_argument( + "--persistent_data_loader_workers", + action="store_true", + help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)", + ) + parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") + parser.add_argument( + "--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / gradient checkpointingを有効にする" + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数", + ) + parser.add_argument( + "--mixed_precision", + type=str, + default="no", + choices=["no", "fp16", "bf16"], + help="use mixed precision / 混合精度を使う場合、その精度", + ) + parser.add_argument( + "--trainable_modules", + nargs='+', + default=".", + help='Enter a list of trainable modules' + ) + + parser.add_argument( + "--logging_dir", + 
type=str, + default=None, + help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する", + ) + parser.add_argument( + "--log_with", + type=str, + default=None, + choices=["tensorboard", "wandb", "all"], + help="what logging tool(s) to use (if 'all', TensorBoard and WandB are both used) / ログ出力に使用するツール (allを指定するとTensorBoardとWandBの両方が使用される)", + ) + parser.add_argument( + "--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列" + ) + parser.add_argument( + "--log_tracker_name", + type=str, + default=None, + help="name of tracker to use for logging, default is script-specific default name / ログ出力に使用するtrackerの名前、省略時はスクリプトごとのデフォルト名", + ) + parser.add_argument( + "--wandb_run_name", + type=str, + default=None, + help="The name of the specific wandb session / wandb ログに表示される特定の実行の名前", + ) + parser.add_argument( + "--log_tracker_config", + type=str, + default=None, + help="path to tracker config file to use for logging / ログ出力に使用するtrackerの設定ファイルのパス", + ) + parser.add_argument( + "--wandb_api_key", + type=str, + default=None, + help="specify WandB API key to log in before starting training (optional). / WandB APIキーを指定して学習開始前にログインする(オプション)", + ) + parser.add_argument("--log_config", action="store_true", help="log training configuration / 学習設定をログに出力する") + + parser.add_argument( + "--ddp_timeout", + type=int, + default=None, + help="DDP timeout (min, None for default of accelerate) / DDPのタイムアウト(分、Noneでaccelerateのデフォルト)", + ) + parser.add_argument( + "--ddp_gradient_as_bucket_view", + action="store_true", + help="enable gradient_as_bucket_view for DDP / DDPでgradient_as_bucket_viewを有効にする", + ) + parser.add_argument( + "--ddp_static_graph", + action="store_true", + help="enable static_graph for DDP / DDPでstatic_graphを有効にする", + ) + + parser.add_argument( + "--sample_every_n_steps", + type=int, + default=None, + help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する", + ) + parser.add_argument( + "--sample_at_first", action="store_true", help="generate sample images before training / 学習前にサンプル出力する" + ) + parser.add_argument( + "--sample_every_n_epochs", + type=int, + default=None, + help="generate sample images every N epochs (overwrites n_steps) / 学習中のモデルで指定エポックごとにサンプル出力する(ステップ数指定を上書きします)", + ) + parser.add_argument( + "--sample_prompts", + type=str, + default=None, + help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル", + ) + + # optimizer and lr scheduler settings + parser.add_argument( + "--optimizer_type", + type=str, + default="", + help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, AdaFactor. " + "Also, you can use any optimizer by specifying the full path to the class, like 'torch.optim.AdamW', 'bitsandbytes.optim.AdEMAMix8bit' or 'bitsandbytes.optim.PagedAdEMAMix8bit' etc. 
/ ", + ) + parser.add_argument( + "--optimizer_args", + type=str, + default=None, + nargs="*", + help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")', + ) + parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率") + parser.add_argument( + "--max_grad_norm", + default=1.0, + type=float, + help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない", + ) + + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor", + ) + parser.add_argument( + "--lr_warmup_steps", + type=int_or_float, + default=0, + help="Int number of steps for the warmup in the lr scheduler (default is 0) or float with ratio of train steps" + " / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)", + ) + parser.add_argument( + "--lr_decay_steps", + type=int_or_float, + default=0, + help="Int number of steps for the decay in the lr scheduler (default is 0) or float (<1) with ratio of train steps" + " / 学習率のスケジューラを減衰させるステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)", + ) + parser.add_argument( + "--lr_scheduler_num_cycles", + type=int, + default=1, + help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数", + ) + parser.add_argument( + "--lr_scheduler_power", + type=float, + default=1, + help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power", + ) + parser.add_argument( + "--lr_scheduler_timescale", + type=int, + default=None, + help="Inverse sqrt timescale for inverse sqrt scheduler,defaults to `num_warmup_steps`" + + " / 逆平方根スケジューラのタイムスケール、デフォルトは`num_warmup_steps`", + ) + parser.add_argument( + "--lr_scheduler_min_lr_ratio", + type=float, + default=None, + help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler" + + " / 初期学習率の比率としての最小学習率を指定する、cosine with min lr と warmup decay スケジューラ で有効", + ) + parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module / 使用するスケジューラ") + parser.add_argument( + "--lr_scheduler_args", + type=str, + default=None, + nargs="*", + help='additional arguments for scheduler (like "T_max=100") / スケジューラの追加引数(例: "T_max100")', + ) + + # model settings + parser.add_argument("--dit", type=str, required=True, help="DiT checkpoint path / DiTのチェックポイントのパス") + parser.add_argument("--dit_dtype", type=str, default=None, help="data type for DiT, default is bfloat16") + parser.add_argument("--vae", type=str, help="VAE checkpoint path / VAEのチェックポイントのパス") + parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is float16") + parser.add_argument( + "--vae_tiling", + action="store_true", + help="enable spatial tiling for VAE, default is False. If vae_spatial_tile_sample_min_size is set, this is automatically enabled." 
+ " / VAEの空間タイリングを有効にする、デフォルトはFalse。vae_spatial_tile_sample_min_sizeが設定されている場合、自動的に有効になります。", + ) + parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE") + parser.add_argument( + "--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256" + ) + parser.add_argument("--text_encoder1", type=str, help="Text Encoder 1 directory / テキストエンコーダ1のディレクトリ") + parser.add_argument("--text_encoder2", type=str, help="Text Encoder 2 directory / テキストエンコーダ2のディレクトリ") + parser.add_argument("--text_encoder_dtype", type=str, default=None, help="data type for Text Encoder, default is float16") + parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for LLM / LLMにfp8を使う") + parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する") + parser.add_argument("--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する") + + parser.add_argument( + "--blocks_to_swap", + type=int, + default=None, + help="number of blocks to swap in the model, max XXX / モデル内のブロックの数、最大XXX", + ) + parser.add_argument( + "--img_in_txt_in_offloading", + action="store_true", + help="offload img_in and txt_in to cpu / img_inとtxt_inをCPUにオフロードする", + ) + + # parser.add_argument("--flow_shift", type=float, default=7.0, help="Shift factor for flow matching schedulers") + parser.add_argument("--guidance_scale", type=float, default=1.0, help="Embeded classifier free guidance scale.") + parser.add_argument( + "--timestep_sampling", + choices=["sigma", "uniform", "sigmoid", "shift"], + default="sigma", + help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal and shift of sigmoid." + " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト。", + ) + parser.add_argument( + "--discrete_flow_shift", + type=float, + default=1.0, + help="Discrete flow shift for the Euler Discrete Scheduler, default is 1.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは1.0。", + ) + parser.add_argument( + "--sigmoid_scale", + type=float, + default=1.0, + help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid" or "shift"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"または"shift"の場合のみ有効)。', + ) + parser.add_argument( + "--weighting_scheme", + type=str, + default="none", + choices=["logit_normal", "mode", "cosmap", "sigma_sqrt", "none"], + help="weighting scheme for timestep distribution. Default is none" + " / タイムステップ分布の重み付けスキーム、デフォルトはnone", + ) + parser.add_argument( + "--logit_mean", + type=float, + default=0.0, + help="mean to use when using the `'logit_normal'` weighting scheme / `'logit_normal'`重み付けスキームを使用する場合の平均", + ) + parser.add_argument( + "--logit_std", + type=float, + default=1.0, + help="std to use when using the `'logit_normal'` weighting scheme / `'logit_normal'`重み付けスキームを使用する場合のstd", + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. 
Only effective when using the `'mode'` as the `weighting_scheme` / モード重み付けスキームのスケール", + ) + parser.add_argument( + "--min_timestep", + type=int, + default=None, + help="set minimum time step for training (0~999, default is 0) / 学習時のtime stepの最小値を設定する(0~999で指定、省略時はデフォルト値(0)) ", + ) + parser.add_argument( + "--max_timestep", + type=int, + default=None, + help="set maximum time step for training (1~1000, default is 1000) / 学習時のtime stepの最大値を設定する(1~1000で指定、省略時はデフォルト値(1000))", + ) + + # save and load settings + parser.add_argument( + "--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ" + ) + parser.add_argument( + "--output_name", + type=str, + default=None, + required=True, + help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名", + ) + parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate") + + parser.add_argument( + "--save_every_n_epochs", + type=int, + default=None, + help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する", + ) + parser.add_argument( + "--save_every_n_steps", + type=int, + default=None, + help="save checkpoint every N steps / 学習中のモデルを指定ステップごとに保存する", + ) + parser.add_argument( + "--save_last_n_epochs", + type=int, + default=None, + help="save last N checkpoints when saving every N epochs (remove older checkpoints) / 指定エポックごとにモデルを保存するとき最大Nエポック保存する(古いチェックポイントは削除する)", + ) + parser.add_argument( + "--save_last_n_epochs_state", + type=int, + default=None, + help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きする)", + ) + parser.add_argument( + "--save_last_n_steps", + type=int, + default=None, + help="save checkpoints until N steps elapsed (remove older checkpoints if N steps elapsed) / 指定ステップごとにモデルを保存するとき、このステップ数経過するまで保存する(このステップ数経過したら削除する)", + ) + parser.add_argument( + "--save_last_n_steps_state", + type=int, + default=None, + help="save states until N steps elapsed (remove older states if N steps elapsed, overrides --save_last_n_steps) / 指定ステップごとにstateを保存するとき、このステップ数経過するまで保存する(このステップ数経過したら削除する。--save_last_n_stepsを上書きする)", + ) + parser.add_argument( + "--save_state", + action="store_true", + help="save training state additionally (including optimizer states etc.) when saving model / optimizerなど学習状態も含めたstateをモデル保存時に追加で保存する", + ) + parser.add_argument( + "--save_state_on_train_end", + action="store_true", + help="save training state (including optimizer states etc.) 
on train end even if --save_state is not specified" + " / --save_stateが未指定時にもoptimizerなど学習状態も含めたstateを学習終了時に保存する", + ) + + # SAI Model spec + parser.add_argument( + "--metadata_title", + type=str, + default=None, + help="title for model metadata (default is output_name) / メタデータに書き込まれるモデルタイトル、省略時はoutput_name", + ) + parser.add_argument( + "--metadata_author", + type=str, + default=None, + help="author name for model metadata / メタデータに書き込まれるモデル作者名", + ) + parser.add_argument( + "--metadata_description", + type=str, + default=None, + help="description for model metadata / メタデータに書き込まれるモデル説明", + ) + parser.add_argument( + "--metadata_license", + type=str, + default=None, + help="license for model metadata / メタデータに書き込まれるモデルライセンス", + ) + parser.add_argument( + "--metadata_tags", + type=str, + default=None, + help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り", + ) + + # huggingface settings + parser.add_argument( + "--huggingface_repo_id", + type=str, + default=None, + help="huggingface repo name to upload / huggingfaceにアップロードするリポジトリ名", + ) + parser.add_argument( + "--huggingface_repo_type", + type=str, + default=None, + help="huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類", + ) + parser.add_argument( + "--huggingface_path_in_repo", + type=str, + default=None, + help="huggingface model path to upload files / huggingfaceにアップロードするファイルのパス", + ) + parser.add_argument("--huggingface_token", type=str, default=None, help="huggingface token / huggingfaceのトークン") + parser.add_argument( + "--huggingface_repo_visibility", + type=str, + default=None, + help="huggingface repository visibility ('public' for public, 'private' or None for private) / huggingfaceにアップロードするリポジトリの公開設定('public'で公開、'private'またはNoneで非公開)", + ) + parser.add_argument( + "--save_state_to_huggingface", action="store_true", help="save state to huggingface / huggingfaceにstateを保存する" + ) + parser.add_argument( + "--resume_from_huggingface", + action="store_true", + help="resume from huggingface (ex: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type}) / huggingfaceから学習を再開する(例: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type})", + ) + parser.add_argument( + "--async_upload", + action="store_true", + help="upload to huggingface asynchronously / huggingfaceに非同期でアップロードする", + ) + + return parser + + +def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentParser): + if not args.config_file: + return args + + config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file + + if not os.path.exists(config_path): + logger.info(f"{config_path} not found.") + exit(1) + + logger.info(f"Loading settings from {config_path}...") + with open(config_path, "r", encoding="utf-8") as f: + config_dict = toml.load(f) + + # combine all sections into one + ignore_nesting_dict = {} + for section_name, section_dict in config_dict.items(): + # if value is not dict, save key and value as is + if not isinstance(section_dict, dict): + ignore_nesting_dict[section_name] = section_dict + continue + + # if value is dict, save all key and value into one dict + for key, value in section_dict.items(): + ignore_nesting_dict[key] = value + + config_args = argparse.Namespace(**ignore_nesting_dict) + args = parser.parse_args(namespace=config_args) + args.config_file = os.path.splitext(args.config_file)[0] + logger.info(args.config_file) + + return args + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + args = 
read_config_from_file(args, parser) + + trainer = FineTuningTrainer() + trainer.train(args) diff --git a/hv_train_network.py b/hv_train_network.py new file mode 100644 index 0000000000000000000000000000000000000000..7d7144398e44bccb9b29742b61f704fc75e34a5c --- /dev/null +++ b/hv_train_network.py @@ -0,0 +1,2366 @@ +import ast +import asyncio +from datetime import datetime +import gc +import importlib +import argparse +import math +import os +import pathlib +import re +import sys +import random +import time +import json +from multiprocessing import Value +from typing import Any, Dict, List, Optional +import accelerate +import numpy as np +from packaging.version import Version + +import huggingface_hub +import toml + +import torch +from tqdm import tqdm +from accelerate.utils import set_seed +from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState +from safetensors.torch import load_file +import transformers +from diffusers.optimization import ( + SchedulerType as DiffusersSchedulerType, + TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION, +) +from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION + +from dataset import config_utils +from hunyuan_model.models import load_transformer, get_rotary_pos_embed_by_shape, HYVideoDiffusionTransformer +import hunyuan_model.text_encoder as text_encoder_module +from hunyuan_model.vae import load_vae, VAE_VER +import hunyuan_model.vae as vae_module +from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler +import networks.lora as lora_module +from dataset.config_utils import BlueprintGenerator, ConfigSanitizer +from hv_generate_video import save_images_grid, save_videos_grid + +import logging + +from utils import huggingface_utils, model_utils, train_utils, sai_model_spec + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +BASE_MODEL_VERSION_HUNYUAN_VIDEO = "hunyuan_video" + +SS_METADATA_KEY_BASE_MODEL_VERSION = "ss_base_model_version" +SS_METADATA_KEY_NETWORK_MODULE = "ss_network_module" +SS_METADATA_KEY_NETWORK_DIM = "ss_network_dim" +SS_METADATA_KEY_NETWORK_ALPHA = "ss_network_alpha" +SS_METADATA_KEY_NETWORK_ARGS = "ss_network_args" + +SS_METADATA_MINIMUM_KEYS = [ + SS_METADATA_KEY_BASE_MODEL_VERSION, + SS_METADATA_KEY_NETWORK_MODULE, + SS_METADATA_KEY_NETWORK_DIM, + SS_METADATA_KEY_NETWORK_ALPHA, + SS_METADATA_KEY_NETWORK_ARGS, +] + + +def clean_memory_on_device(device: torch.device): + r""" + Clean memory on the specified device, will be called from training scripts. 
+ """ + gc.collect() + + # device may "cuda" or "cuda:0", so we need to check the type of device + if device.type == "cuda": + torch.cuda.empty_cache() + if device.type == "xpu": + torch.xpu.empty_cache() + if device.type == "mps": + torch.mps.empty_cache() + + +# for collate_fn: epoch and step is multiprocessing.Value +class collator_class: + def __init__(self, epoch, step, dataset): + self.current_epoch = epoch + self.current_step = step + self.dataset = dataset # not used if worker_info is not None, in case of multiprocessing + + def __call__(self, examples): + worker_info = torch.utils.data.get_worker_info() + # worker_info is None in the main process + if worker_info is not None: + dataset = worker_info.dataset + else: + dataset = self.dataset + + # set epoch and step + dataset.set_current_epoch(self.current_epoch.value) + dataset.set_current_step(self.current_step.value) + return examples[0] + + +def prepare_accelerator(args: argparse.Namespace) -> Accelerator: + """ + DeepSpeed is not supported in this script currently. + """ + if args.logging_dir is None: + logging_dir = None + else: + log_prefix = "" if args.log_prefix is None else args.log_prefix + logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime()) + + if args.log_with is None: + if logging_dir is not None: + log_with = "tensorboard" + else: + log_with = None + else: + log_with = args.log_with + if log_with in ["tensorboard", "all"]: + if logging_dir is None: + raise ValueError( + "logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください" + ) + if log_with in ["wandb", "all"]: + try: + import wandb + except ImportError: + raise ImportError("No wandb / wandb がインストールされていないようです") + if logging_dir is not None: + os.makedirs(logging_dir, exist_ok=True) + os.environ["WANDB_DIR"] = logging_dir + if args.wandb_api_key is not None: + wandb.login(key=args.wandb_api_key) + + kwargs_handlers = [ + ( + InitProcessGroupKwargs( + backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", + init_method=( + "env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None + ), + timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None, + ) + if torch.cuda.device_count() > 1 + else None + ), + ( + DistributedDataParallelKwargs( + gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph + ) + if args.ddp_gradient_as_bucket_view or args.ddp_static_graph + else None + ), + ] + kwargs_handlers = [i for i in kwargs_handlers if i is not None] + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=log_with, + project_dir=logging_dir, + kwargs_handlers=kwargs_handlers, + ) + print("accelerator device:", accelerator.device) + return accelerator + + +def line_to_prompt_dict(line: str) -> dict: + # subset of gen_img_diffusers + prompt_args = line.split(" --") + prompt_dict = {} + prompt_dict["prompt"] = prompt_args[0] + + for parg in prompt_args: + try: + m = re.match(r"w (\d+)", parg, re.IGNORECASE) + if m: + prompt_dict["width"] = int(m.group(1)) + continue + + m = re.match(r"h (\d+)", parg, re.IGNORECASE) + if m: + prompt_dict["height"] = int(m.group(1)) + continue + + m = re.match(r"f (\d+)", parg, re.IGNORECASE) + if m: + prompt_dict["frame_count"] = int(m.group(1)) + continue + + m = re.match(r"d (\d+)", parg, re.IGNORECASE) + if m: + prompt_dict["seed"] = 
int(m.group(1)) + continue + + m = re.match(r"s (\d+)", parg, re.IGNORECASE) + if m: # steps + prompt_dict["sample_steps"] = max(1, min(1000, int(m.group(1)))) + continue + + m = re.match(r"g ([\d\.]+)", parg, re.IGNORECASE) + if m: # scale + prompt_dict["guidance_scale"] = float(m.group(1)) + continue + + m = re.match(r"fs ([\d\.]+)", parg, re.IGNORECASE) + if m: # scale + prompt_dict["discrete_flow_shift"] = float(m.group(1)) + continue + + # m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) + # if m: # scale + # prompt_dict["scale"] = float(m.group(1)) + # continue + # m = re.match(r"n (.+)", parg, re.IGNORECASE) + # if m: # negative prompt + # prompt_dict["negative_prompt"] = m.group(1) + # continue + + except ValueError as ex: + logger.error(f"Exception in parsing / 解析エラー: {parg}") + logger.error(ex) + + return prompt_dict + + +def load_prompts(prompt_file: str) -> list[Dict]: + # read prompts + if prompt_file.endswith(".txt"): + with open(prompt_file, "r", encoding="utf-8") as f: + lines = f.readlines() + prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"] + elif prompt_file.endswith(".toml"): + with open(prompt_file, "r", encoding="utf-8") as f: + data = toml.load(f) + prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]] + elif prompt_file.endswith(".json"): + with open(prompt_file, "r", encoding="utf-8") as f: + prompts = json.load(f) + + # preprocess prompts + for i in range(len(prompts)): + prompt_dict = prompts[i] + if isinstance(prompt_dict, str): + prompt_dict = line_to_prompt_dict(prompt_dict) + prompts[i] = prompt_dict + assert isinstance(prompt_dict, dict) + + # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. + prompt_dict["enum"] = i + prompt_dict.pop("subset", None) + + return prompts + + +def compute_density_for_timestep_sampling( + weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None +): + """Compute the density for sampling the timesteps when doing SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "logit_normal": + # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). 
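For reference, the prompt files read by `load_prompts` above accept one prompt per line (in the `.txt` case), with optional ` --key value` switches parsed by `line_to_prompt_dict`. The following simplified, self-contained re-implementation is shown only to document the accepted switches; it is not the script's own parser:

```python
# Simplified illustration of the prompt-line format accepted above
# (one prompt per line, e.g. in the file passed via --sample_prompts):
#   "a cat walking on grass --w 512 --h 320 --f 25 --d 42 --s 20 --g 6.0 --fs 14.5"
# w/h: width/height, f: frame count, d: seed, s: sampling steps,
# g: guidance scale, fs: discrete flow shift.
import re

def parse_prompt_line(line: str) -> dict:
    head, *switches = line.split(" --")
    out = {"prompt": head}
    patterns = {
        "width": r"w (\d+)", "height": r"h (\d+)", "frame_count": r"f (\d+)",
        "seed": r"d (\d+)", "sample_steps": r"s (\d+)",
        "guidance_scale": r"g ([\d\.]+)", "discrete_flow_shift": r"fs ([\d\.]+)",
    }
    for sw in switches:
        for key, pat in patterns.items():
            m = re.fullmatch(pat, sw.strip())
            if m:
                value = m.group(1)
                out[key] = float(value) if "." in value else int(value)
                break
    return out

if __name__ == "__main__":
    print(parse_prompt_line("a cat walking on grass --w 512 --h 320 --f 25 --d 42 --s 20 --g 6.0 --fs 14.5"))
```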
+ u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") + u = torch.nn.functional.sigmoid(u) + elif weighting_scheme == "mode": + u = torch.rand(size=(batch_size,), device="cpu") + u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + else: + u = torch.rand(size=(batch_size,), device="cpu") + return u + + +def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype) + schedule_timesteps = noise_scheduler.timesteps.to(device) + timesteps = timesteps.to(device) + + # if sum([(schedule_timesteps == t) for t in timesteps]) < len(timesteps): + if any([(schedule_timesteps == t).sum() == 0 for t in timesteps]): + # raise ValueError("Some timesteps are not in the schedule / 一部のtimestepsがスケジュールに含まれていません") + # round to nearest timestep + logger.warning("Some timesteps are not in the schedule / 一部のtimestepsがスケジュールに含まれていません") + step_indices = [torch.argmin(torch.abs(schedule_timesteps - t)).item() for t in timesteps] + else: + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + +def compute_loss_weighting_for_sd3(weighting_scheme: str, noise_scheduler, timesteps, device, dtype): + """Computes loss weighting scheme for SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "sigma_sqrt" or weighting_scheme == "cosmap": + sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=5, dtype=dtype) + if weighting_scheme == "sigma_sqrt": + weighting = (sigmas**-2.0).float() + else: + bot = 1 - 2 * sigmas + 2 * sigmas**2 + weighting = 2 / (math.pi * bot) + else: + weighting = None # torch.ones_like(sigmas) + return weighting + + +def should_sample_images(args, steps, epoch=None): + if steps == 0: + if not args.sample_at_first: + return False + else: + should_sample_by_steps = args.sample_every_n_steps is not None and steps % args.sample_every_n_steps == 0 + should_sample_by_epochs = ( + args.sample_every_n_epochs is not None and epoch is not None and epoch % args.sample_every_n_epochs == 0 + ) + if not should_sample_by_steps and not should_sample_by_epochs: + return False + return True + + +def sample_images(accelerator, args, epoch, steps, vae, transformer, sample_parameters, dit_dtype): + if not should_sample_images(args, steps, epoch): + return + + logger.info("") + logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}") + if sample_parameters is None: + logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") + return + + distributed_state = PartialState() # for multi gpu distributed inference. 
this is a singleton, so it's safe to use it here + + # Use the unwrapped model + transformer: HYVideoDiffusionTransformer = accelerator.unwrap_model(transformer) + transformer.switch_block_swap_for_inference() + + # Create a directory to save the samples + save_dir = args.output_dir + "/sample" + os.makedirs(save_dir, exist_ok=True) + + # save random state to restore later + rng_state = torch.get_rng_state() + cuda_rng_state = None + try: + cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + except Exception: + pass + + if distributed_state.num_processes <= 1: + # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. + with torch.no_grad(), accelerator.autocast(): + for sample_parameter in sample_parameters: + sample_image_inference(accelerator, args, transformer, dit_dtype, vae, save_dir, sample_parameter, epoch, steps) + clean_memory_on_device(accelerator.device) + else: + # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) + # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. + per_process_params = [] # list of lists + for i in range(distributed_state.num_processes): + per_process_params.append(sample_parameters[i :: distributed_state.num_processes]) + + with torch.no_grad(): + with distributed_state.split_between_processes(per_process_params) as sample_parameter_lists: + for sample_parameter in sample_parameter_lists[0]: + sample_image_inference(accelerator, args, transformer, dit_dtype, vae, save_dir, sample_parameter, epoch, steps) + clean_memory_on_device(accelerator.device) + + torch.set_rng_state(rng_state) + if cuda_rng_state is not None: + torch.cuda.set_rng_state(cuda_rng_state) + + transformer.switch_block_swap_for_training() + clean_memory_on_device(accelerator.device) + + +def sample_image_inference(accelerator, args, transformer, dit_dtype, vae, save_dir, sample_parameter, epoch, steps): + sample_steps = sample_parameter.get("sample_steps", 20) + width = sample_parameter.get("width", 256) # make smaller for faster and memory saving inference + height = sample_parameter.get("height", 256) + frame_count = sample_parameter.get("frame_count", 1) + guidance_scale = sample_parameter.get("guidance_scale", 6.0) + discrete_flow_shift = sample_parameter.get("discrete_flow_shift", 14.5) + seed = sample_parameter.get("seed") + prompt: str = sample_parameter.get("prompt", "") + + # Calculate latent video length based on VAE version + if "884" in VAE_VER: + latent_video_length = (frame_count - 1) // 4 + 1 + elif "888" in VAE_VER: + latent_video_length = (frame_count - 1) // 8 + 1 + else: + latent_video_length = frame_count + + device = accelerator.device + if seed is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + generator = torch.Generator(device=device).manual_seed(seed) + else: + # True random sample image generation + torch.seed() + torch.cuda.seed() + generator = torch.Generator(device=device).manual_seed(torch.initial_seed()) + + logger.info(f"prompt: {prompt}") + logger.info(f"height: {height}") + logger.info(f"width: {width}") + logger.info(f"frame count: {frame_count}") + logger.info(f"sample steps: {sample_steps}") + logger.info(f"guidance scale: {guidance_scale}") + logger.info(f"discrete flow shift: 
{discrete_flow_shift}") + if seed is not None: + logger.info(f"seed: {seed}") + + # Prepare scheduler for each prompt + scheduler = FlowMatchDiscreteScheduler(shift=discrete_flow_shift, reverse=True, solver="euler") + + # Number of inference steps for sampling + scheduler.set_timesteps(sample_steps, device=device) + timesteps = scheduler.timesteps + + # Get embeddings + prompt_embeds = sample_parameter["llm_embeds"].to(device=device, dtype=dit_dtype) + prompt_mask = sample_parameter["llm_mask"].to(device=device) + prompt_embeds_2 = sample_parameter["clipL_embeds"].to(device=device, dtype=dit_dtype) + + num_channels_latents = 16 # transformer.config.in_channels + vae_scale_factor = 2 ** (4 - 1) # Assuming 4 VAE blocks + + # Initialize latents + shape_or_frame = ( + 1, + num_channels_latents, + 1, + height // vae_scale_factor, + width // vae_scale_factor, + ) + latents = [] + for _ in range(latent_video_length): + latents.append(torch.randn(shape_or_frame, generator=generator, device=device, dtype=dit_dtype)) + latents = torch.cat(latents, dim=2) + + # Guidance scale + guidance_expand = torch.tensor([guidance_scale * 1000.0], dtype=torch.float32, device=device).to(dit_dtype) + + # Get rotary positional embeddings + freqs_cos, freqs_sin = get_rotary_pos_embed_by_shape(transformer, latents.shape[2:]) + freqs_cos = freqs_cos.to(device=device, dtype=dit_dtype) + freqs_sin = freqs_sin.to(device=device, dtype=dit_dtype) + + # Wrap the inner loop with tqdm to track progress over timesteps + prompt_idx = sample_parameter.get("enum", 0) + with torch.no_grad(): + for i, t in enumerate(tqdm(timesteps, desc=f"Sampling timesteps for prompt {prompt_idx+1}")): + latents = scheduler.scale_model_input(latents, t) + noise_pred = transformer( + latents, + t.repeat(latents.shape[0]).to(device=device, dtype=dit_dtype), + text_states=prompt_embeds, + text_mask=prompt_mask, + text_states_2=prompt_embeds_2, + freqs_cos=freqs_cos, + freqs_sin=freqs_sin, + guidance=guidance_expand, + return_dict=True, + )["x"] + + # Compute the previous noisy sample x_t -> x_t-1 + latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # Move VAE to the appropriate device for sampling + vae.to(device) + vae.eval() + + # Decode latents to video + if hasattr(vae.config, "shift_factor") and vae.config.shift_factor: + latents = latents / vae.config.scaling_factor + vae.config.shift_factor + else: + latents = latents / vae.config.scaling_factor + + latents = latents.to(device=device, dtype=vae.dtype) + with torch.no_grad(): + video = vae.decode(latents, return_dict=False)[0] + video = (video / 2 + 0.5).clamp(0, 1) + video = video.cpu().float() + + # Save video + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" + seed_suffix = "" if seed is None else f"_{seed}" + save_path = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{prompt_idx:02d}_{ts_str}{seed_suffix}" + if video.shape[2] == 1: + save_images_grid(video, save_dir, save_path, create_subdir=False) + else: + save_videos_grid(video, os.path.join(save_dir, save_path) + ".mp4") + + # Move models back to initial state + vae.to("cpu") + + +class NetworkTrainer: + def __init__(self): + pass + + # TODO 他のスクリプトと共通化する + def generate_step_logs( + self, + args: argparse.Namespace, + current_loss, + avr_loss, + lr_scheduler, + lr_descriptions, + optimizer=None, + keys_scaled=None, + mean_norm=None, + maximum_norm=None, + ): + network_train_unet_only = True + logs = 
{"loss/current": current_loss, "loss/average": avr_loss} + + if keys_scaled is not None: + logs["max_norm/keys_scaled"] = keys_scaled + logs["max_norm/average_key_norm"] = mean_norm + logs["max_norm/max_key_norm"] = maximum_norm + + lrs = lr_scheduler.get_last_lr() + for i, lr in enumerate(lrs): + if lr_descriptions is not None: + lr_desc = lr_descriptions[i] + else: + idx = i - (0 if network_train_unet_only else -1) + if idx == -1: + lr_desc = "textencoder" + else: + if len(lrs) > 2: + lr_desc = f"group{idx}" + else: + lr_desc = "unet" + + logs[f"lr/{lr_desc}"] = lr + + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): + # tracking d*lr value + logs[f"lr/d*lr/{lr_desc}"] = ( + lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] + ) + if ( + args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None + ): # tracking d*lr value of unet. + logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] + else: + idx = 0 + if not network_train_unet_only: + logs["lr/textencoder"] = float(lrs[0]) + idx = 1 + + for i in range(idx, len(lrs)): + logs[f"lr/group{i}"] = float(lrs[i]) + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): + logs[f"lr/d*lr/group{i}"] = ( + lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] + ) + if args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None: + logs[f"lr/d*lr/group{i}"] = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"] + + return logs + + def process_sample_prompts( + self, + args: argparse.Namespace, + accelerator: Accelerator, + sample_prompts: str, + text_encoder1: str, + text_encoder2: str, + fp8_llm: bool, + ): + logger.info(f"cache Text Encoder outputs for sample prompt: {sample_prompts}") + prompts = load_prompts(sample_prompts) + + def encode_for_text_encoder(text_encoder, is_llm=True): + sample_prompts_te_outputs = {} # (prompt) -> (embeds, mask) + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + + data_type = "video" + text_inputs = text_encoder.text2tokens(p, data_type=data_type) + + prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type) + sample_prompts_te_outputs[p] = (prompt_outputs.hidden_state, prompt_outputs.attention_mask) + + return sample_prompts_te_outputs + + # Load Text Encoder 1 and encode + text_encoder_dtype = torch.float16 if args.text_encoder_dtype is None else model_utils.str_to_dtype(args.text_encoder_dtype) + logger.info(f"loading text encoder 1: {text_encoder1}") + text_encoder_1 = text_encoder_module.load_text_encoder_1(text_encoder1, accelerator.device, fp8_llm, text_encoder_dtype) + + logger.info("encoding with Text Encoder 1") + te_outputs_1 = encode_for_text_encoder(text_encoder_1) + del text_encoder_1 + + # Load Text Encoder 2 and encode + logger.info(f"loading text encoder 2: {text_encoder2}") + text_encoder_2 = text_encoder_module.load_text_encoder_2(text_encoder2, accelerator.device, text_encoder_dtype) + + logger.info("encoding with Text Encoder 2") + te_outputs_2 = encode_for_text_encoder(text_encoder_2, is_llm=False) + del text_encoder_2 + + # prepare sample parameters + sample_parameters = [] 
+ for prompt_dict in prompts: + prompt_dict_copy = prompt_dict.copy() + p = prompt_dict.get("prompt", "") + prompt_dict_copy["llm_embeds"] = te_outputs_1[p][0] + prompt_dict_copy["llm_mask"] = te_outputs_1[p][1] + prompt_dict_copy["clipL_embeds"] = te_outputs_2[p][0] + prompt_dict_copy["clipL_mask"] = te_outputs_2[p][1] + sample_parameters.append(prompt_dict_copy) + + clean_memory_on_device(accelerator.device) + + return sample_parameters + + def get_optimizer(self, args, trainable_params: list[torch.nn.Parameter]) -> tuple[str, str, torch.optim.Optimizer]: + # adamw, adamw8bit, adafactor + + optimizer_type = args.optimizer_type.lower() + + # split optimizer_type and optimizer_args + optimizer_kwargs = {} + if args.optimizer_args is not None and len(args.optimizer_args) > 0: + for arg in args.optimizer_args: + key, value = arg.split("=") + value = ast.literal_eval(value) + optimizer_kwargs[key] = value + + lr = args.learning_rate + optimizer = None + optimizer_class = None + + if optimizer_type.endswith("8bit".lower()): + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです") + + if optimizer_type == "AdamW8bit".lower(): + logger.info(f"use 8-bit AdamW optimizer | {optimizer_kwargs}") + optimizer_class = bnb.optim.AdamW8bit + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + elif optimizer_type == "Adafactor".lower(): + # Adafactor: check relative_step and warmup_init + if "relative_step" not in optimizer_kwargs: + optimizer_kwargs["relative_step"] = True # default + if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False): + logger.info( + f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします" + ) + optimizer_kwargs["relative_step"] = True + logger.info(f"use Adafactor optimizer | {optimizer_kwargs}") + + if optimizer_kwargs["relative_step"]: + logger.info(f"relative_step is true / relative_stepがtrueです") + if lr != 0.0: + logger.warning(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます") + args.learning_rate = None + + if args.lr_scheduler != "adafactor": + logger.info(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します") + args.lr_scheduler = f"adafactor:{lr}" # ちょっと微妙だけど + + lr = None + else: + if args.max_grad_norm != 0.0: + logger.warning( + f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません" + ) + if args.lr_scheduler != "constant_with_warmup": + logger.warning(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません") + if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0: + logger.warning(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません") + + optimizer_class = transformers.optimization.Adafactor + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + elif optimizer_type == "AdamW".lower(): + logger.info(f"use AdamW optimizer | {optimizer_kwargs}") + optimizer_class = torch.optim.AdamW + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + if optimizer is None: + # 任意のoptimizerを使う + case_sensitive_optimizer_type = args.optimizer_type # not lower + logger.info(f"use {case_sensitive_optimizer_type} | {optimizer_kwargs}") + + if "." 
not in case_sensitive_optimizer_type:  # from torch.optim
+                optimizer_module = torch.optim
+            else:  # from other library
+                values = case_sensitive_optimizer_type.split(".")
+                optimizer_module = importlib.import_module(".".join(values[:-1]))
+                case_sensitive_optimizer_type = values[-1]
+
+            optimizer_class = getattr(optimizer_module, case_sensitive_optimizer_type)
+            optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+        # for logging
+        optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
+        optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])
+
+        # get train and eval functions
+        if hasattr(optimizer, "train") and callable(optimizer.train):
+            train_fn = optimizer.train
+            eval_fn = optimizer.eval
+        else:
+            train_fn = lambda: None
+            eval_fn = lambda: None
+
+        return optimizer_name, optimizer_args, optimizer, train_fn, eval_fn
+
+    def is_schedulefree_optimizer(self, optimizer: torch.optim.Optimizer, args: argparse.Namespace) -> bool:
+        return args.optimizer_type.lower().endswith("schedulefree".lower())  # or args.optimizer_schedulefree_wrapper
+
+    def get_dummy_scheduler(self, optimizer: torch.optim.Optimizer) -> Any:
+        # dummy scheduler for schedulefree optimizer. supports only empty step(), get_last_lr() and optimizers.
+        # this scheduler is used for logging only.
+        # this isn't wrapped by accelerator because this class is not a subclass of torch.optim.lr_scheduler._LRScheduler
+        class DummyScheduler:
+            def __init__(self, optimizer: torch.optim.Optimizer):
+                self.optimizer = optimizer
+
+            def step(self):
+                pass
+
+            def get_last_lr(self):
+                return [group["lr"] for group in self.optimizer.param_groups]
+
+        return DummyScheduler(optimizer)
+
+    def get_scheduler(self, args, optimizer: torch.optim.Optimizer, num_processes: int):
+        """
+        Unified API to get any scheduler from its name.
+        """
+        # if schedulefree optimizer, return dummy scheduler
+        if self.is_schedulefree_optimizer(optimizer, args):
+            return self.get_dummy_scheduler(optimizer)
+
+        name = args.lr_scheduler
+        num_training_steps = args.max_train_steps * num_processes  # * args.gradient_accumulation_steps
+        num_warmup_steps: Optional[int] = (
+            int(args.lr_warmup_steps * num_training_steps) if isinstance(args.lr_warmup_steps, float) else args.lr_warmup_steps
+        )
+        num_decay_steps: Optional[int] = (
+            int(args.lr_decay_steps * num_training_steps) if isinstance(args.lr_decay_steps, float) else args.lr_decay_steps
+        )
+        num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps
+        num_cycles = args.lr_scheduler_num_cycles
+        power = args.lr_scheduler_power
+        timescale = args.lr_scheduler_timescale
+        min_lr_ratio = args.lr_scheduler_min_lr_ratio
+
+        lr_scheduler_kwargs = {}  # get custom lr_scheduler kwargs
+        if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0:
+            for arg in args.lr_scheduler_args:
+                key, value = arg.split("=")
+                value = ast.literal_eval(value)
+                lr_scheduler_kwargs[key] = value
+
+        def wrap_check_needless_num_warmup_steps(return_vals):
+            if num_warmup_steps is not None and num_warmup_steps != 0:
+                raise ValueError(f"{name} does not require `num_warmup_steps`. Set None or 0.")
+            return return_vals
+
+        # using any lr_scheduler from other library
+        if args.lr_scheduler_type:
+            lr_scheduler_type = args.lr_scheduler_type
+            logger.info(f"use {lr_scheduler_type} | {lr_scheduler_kwargs} as lr_scheduler")
+            if "." 
not in lr_scheduler_type: # default to use torch.optim + lr_scheduler_module = torch.optim.lr_scheduler + else: + values = lr_scheduler_type.split(".") + lr_scheduler_module = importlib.import_module(".".join(values[:-1])) + lr_scheduler_type = values[-1] + lr_scheduler_class = getattr(lr_scheduler_module, lr_scheduler_type) + lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs) + return lr_scheduler + + if name.startswith("adafactor"): + assert ( + type(optimizer) == transformers.optimization.Adafactor + ), f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください" + initial_lr = float(name.split(":")[1]) + # logger.info(f"adafactor scheduler init lr {initial_lr}") + return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr)) + + if name == DiffusersSchedulerType.PIECEWISE_CONSTANT.value: + name = DiffusersSchedulerType(name) + schedule_func = DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[name] + return schedule_func(optimizer, **lr_scheduler_kwargs) # step_rules and last_epoch are given as kwargs + + name = SchedulerType(name) + schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] + + if name == SchedulerType.CONSTANT: + return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs)) + + # All other schedulers require `num_warmup_steps` + if num_warmup_steps is None: + raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") + + if name == SchedulerType.CONSTANT_WITH_WARMUP: + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **lr_scheduler_kwargs) + + if name == SchedulerType.INVERSE_SQRT: + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, timescale=timescale, **lr_scheduler_kwargs) + + # All other schedulers require `num_training_steps` + if num_training_steps is None: + raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") + + if name == SchedulerType.COSINE_WITH_RESTARTS: + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + num_cycles=num_cycles, + **lr_scheduler_kwargs, + ) + + if name == SchedulerType.POLYNOMIAL: + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + power=power, + **lr_scheduler_kwargs, + ) + + if name == SchedulerType.COSINE_WITH_MIN_LR: + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + num_cycles=num_cycles / 2, + min_lr_rate=min_lr_ratio, + **lr_scheduler_kwargs, + ) + + # these schedulers do not require `num_decay_steps` + if name == SchedulerType.LINEAR or name == SchedulerType.COSINE: + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + **lr_scheduler_kwargs, + ) + + # All other schedulers require `num_decay_steps` + if num_decay_steps is None: + raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.") + if name == SchedulerType.WARMUP_STABLE_DECAY: + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_stable_steps=num_stable_steps, + num_decay_steps=num_decay_steps, + num_cycles=num_cycles / 2, + min_lr_ratio=min_lr_ratio if min_lr_ratio is not None else 0.0, + **lr_scheduler_kwargs, + ) + + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + 
num_decay_steps=num_decay_steps, + **lr_scheduler_kwargs, + ) + + def resume_from_local_or_hf_if_specified(self, accelerator: Accelerator, args: argparse.Namespace) -> bool: + if not args.resume: + return False + + if not args.resume_from_huggingface: + logger.info(f"resume training from local state: {args.resume}") + accelerator.load_state(args.resume) + return True + + logger.info(f"resume training from huggingface state: {args.resume}") + repo_id = args.resume.split("/")[0] + "/" + args.resume.split("/")[1] + path_in_repo = "/".join(args.resume.split("/")[2:]) + revision = None + repo_type = None + if ":" in path_in_repo: + divided = path_in_repo.split(":") + if len(divided) == 2: + path_in_repo, revision = divided + repo_type = "model" + else: + path_in_repo, revision, repo_type = divided + logger.info(f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}") + + list_files = huggingface_utils.list_dir( + repo_id=repo_id, + subfolder=path_in_repo, + revision=revision, + token=args.huggingface_token, + repo_type=repo_type, + ) + + async def download(filename) -> str: + def task(): + return huggingface_hub.hf_hub_download( + repo_id=repo_id, + filename=filename, + revision=revision, + repo_type=repo_type, + token=args.huggingface_token, + ) + + return await asyncio.get_event_loop().run_in_executor(None, task) + + loop = asyncio.get_event_loop() + results = loop.run_until_complete(asyncio.gather(*[download(filename=filename.rfilename) for filename in list_files])) + if len(results) == 0: + raise ValueError( + "No files found in the specified repo id/path/revision / 指定されたリポジトリID/パス/リビジョンにファイルが見つかりませんでした" + ) + dirname = os.path.dirname(results[0]) + accelerator.load_state(dirname) + + return True + + def get_noisy_model_input_and_timesteps( + self, + args: argparse.Namespace, + noise: torch.Tensor, + latents: torch.Tensor, + noise_scheduler: FlowMatchDiscreteScheduler, + device: torch.device, + dtype: torch.dtype, + ): + batch_size = noise.shape[0] + + if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid" or args.timestep_sampling == "shift": + if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": + # Simple random t-based noise sampling + if args.timestep_sampling == "sigmoid": + t = torch.sigmoid(args.sigmoid_scale * torch.randn((batch_size,), device=device)) + else: + t = torch.rand((batch_size,), device=device) + + elif args.timestep_sampling == "shift": + shift = args.discrete_flow_shift + logits_norm = torch.randn(batch_size, device=device) + logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling + t = logits_norm.sigmoid() + t = (t * shift) / (1 + (shift - 1) * t) + + t_min = args.min_timestep if args.min_timestep is not None else 0 + t_max = args.max_timestep if args.max_timestep is not None else 1000.0 + t_min /= 1000.0 + t_max /= 1000.0 + t = t * (t_max - t_min) + t_min # scale to [t_min, t_max], default [0, 1] + + timesteps = t * 1000.0 + t = t.view(-1, 1, 1, 1, 1) + noisy_model_input = (1 - t) * latents + t * noise + + timesteps += 1 # 1 to 1000 + else: + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=batch_size, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + # indices = (u * noise_scheduler.config.num_train_timesteps).long() + t_min = args.min_timestep if args.min_timestep 
is not None else 0 + t_max = args.max_timestep if args.max_timestep is not None else 1000 + indices = (u * (t_max - t_min) + t_min).long() + + timesteps = noise_scheduler.timesteps[indices].to(device=device) # 1 to 1000 + + # Add noise according to flow matching. + sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) + noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + + return noisy_model_input, timesteps + + def show_timesteps(self, args: argparse.Namespace): + N_TRY = 100000 + BATCH_SIZE = 1000 + CONSOLE_WIDTH = 64 + N_TIMESTEPS_PER_LINE = 25 + + noise_scheduler = FlowMatchDiscreteScheduler(shift=args.discrete_flow_shift, reverse=True, solver="euler") + # print(f"Noise scheduler timesteps: {noise_scheduler.timesteps}") + + latents = torch.zeros(BATCH_SIZE, 1, 1, 1, 1, dtype=torch.float16) + noise = torch.ones_like(latents) + + # sample timesteps + sampled_timesteps = [0] * noise_scheduler.config.num_train_timesteps + for i in tqdm(range(N_TRY // BATCH_SIZE)): + # we use noise=1, so retured noisy_model_input is same as timestep, because `noisy_model_input = (1 - t) * latents + t * noise` + actual_timesteps, _ = self.get_noisy_model_input_and_timesteps( + args, noise, latents, noise_scheduler, "cpu", torch.float16 + ) + actual_timesteps = actual_timesteps[:, 0, 0, 0, 0] * 1000 + for t in actual_timesteps: + t = int(t.item()) + sampled_timesteps[t] += 1 + + # sample weighting + sampled_weighting = [0] * noise_scheduler.config.num_train_timesteps + for i in tqdm(range(len(sampled_weighting))): + timesteps = torch.tensor([i + 1], device="cpu") + weighting = compute_loss_weighting_for_sd3(args.weighting_scheme, noise_scheduler, timesteps, "cpu", torch.float16) + if weighting is None: + weighting = torch.tensor(1.0, device="cpu") + elif torch.isinf(weighting).any(): + weighting = torch.tensor(1.0, device="cpu") + sampled_weighting[i] = weighting.item() + + # show results + if args.show_timesteps == "image": + # show timesteps with matplotlib + import matplotlib.pyplot as plt + + plt.figure(figsize=(10, 5)) + plt.subplot(1, 2, 1) + plt.bar(range(len(sampled_timesteps)), sampled_timesteps, width=1.0) + plt.title("Sampled timesteps") + plt.xlabel("Timestep") + plt.ylabel("Count") + + plt.subplot(1, 2, 2) + plt.bar(range(len(sampled_weighting)), sampled_weighting, width=1.0) + plt.title("Sampled loss weighting") + plt.xlabel("Timestep") + plt.ylabel("Weighting") + + plt.tight_layout() + plt.show() + + else: + sampled_timesteps = np.array(sampled_timesteps) + sampled_weighting = np.array(sampled_weighting) + + # average per line + sampled_timesteps = sampled_timesteps.reshape(-1, N_TIMESTEPS_PER_LINE).mean(axis=1) + sampled_weighting = sampled_weighting.reshape(-1, N_TIMESTEPS_PER_LINE).mean(axis=1) + + max_count = max(sampled_timesteps) + print(f"Sampled timesteps: max count={max_count}") + for i, t in enumerate(sampled_timesteps): + line = f"{(i)*N_TIMESTEPS_PER_LINE:4d}-{(i+1)*N_TIMESTEPS_PER_LINE-1:4d}: " + line += "#" * int(t / max_count * CONSOLE_WIDTH) + print(line) + + max_weighting = max(sampled_weighting) + print(f"Sampled loss weighting: max weighting={max_weighting}") + for i, w in enumerate(sampled_weighting): + line = f"{i*N_TIMESTEPS_PER_LINE:4d}-{(i+1)*N_TIMESTEPS_PER_LINE-1:4d}: {w:8.2f} " + line += "#" * int(w / max_weighting * CONSOLE_WIDTH) + print(line) + + def train(self, args): + # check required arguments + if args.dataset_config is None: + raise ValueError("dataset_config is required / dataset_configが必要です") + if args.dit is 
None: + raise ValueError("path to DiT model is required / DiTモデルのパスが必要です") + + # show timesteps for debugging + if args.show_timesteps: + self.show_timesteps(args) + return + + session_id = random.randint(0, 2**32) + training_started_at = time.time() + # setup_logging(args, reset=True) + + if args.seed is None: + args.seed = random.randint(0, 2**32) + set_seed(args.seed) + + # Load dataset config + blueprint_generator = BlueprintGenerator(ConfigSanitizer()) + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_utils.load_user_config(args.dataset_config) + blueprint = blueprint_generator.generate(user_config, args) + train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group, training=True) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = collator_class(current_epoch, current_step, ds_for_collator) + + # prepare accelerator + logger.info("preparing accelerator") + accelerator = prepare_accelerator(args) + is_main_process = accelerator.is_main_process + + # prepare dtype + weight_dtype = torch.float32 + if args.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif args.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # HunyuanVideo specific + dit_dtype = torch.bfloat16 if args.dit_dtype is None else model_utils.str_to_dtype(args.dit_dtype) + dit_weight_dtype = torch.float8_e4m3fn if args.fp8_base else dit_dtype + logger.info(f"DiT precision: {dit_dtype}, weight precision: {dit_weight_dtype}") + vae_dtype = torch.float16 if args.vae_dtype is None else model_utils.str_to_dtype(args.vae_dtype) + + # get embedding for sampling images + sample_parameters = vae = None + if args.sample_prompts: + sample_parameters = self.process_sample_prompts( + args, accelerator, args.sample_prompts, args.text_encoder1, args.text_encoder2, args.fp8_llm + ) + # Load VAE model for sampling images: VAE is loaded to cpu to save gpu memory + vae, _, s_ratio, t_ratio = load_vae(vae_dtype=vae_dtype, device="cpu", vae_path=args.vae) + vae.requires_grad_(False) + vae.eval() + + if args.vae_chunk_size is not None: + vae.set_chunk_size_for_causal_conv_3d(args.vae_chunk_size) + logger.info(f"Set chunk_size to {args.vae_chunk_size} for CausalConv3d in VAE") + if args.vae_spatial_tile_sample_min_size is not None: + vae.enable_spatial_tiling(True) + vae.tile_sample_min_size = args.vae_spatial_tile_sample_min_size + vae.tile_latent_min_size = args.vae_spatial_tile_sample_min_size // 8 + elif args.vae_tiling: + vae.enable_spatial_tiling(True) + + # load DiT model + blocks_to_swap = args.blocks_to_swap if args.blocks_to_swap else 0 + loading_device = "cpu" if blocks_to_swap > 0 else accelerator.device + + logger.info(f"Loading DiT model from {args.dit}") + if args.sdpa: + attn_mode = "torch" + elif args.flash_attn: + attn_mode = "flash" + elif args.sage_attn: + attn_mode = "sageattn" + elif args.xformers: + attn_mode = "xformers" + else: + raise ValueError( + f"either --sdpa, --flash-attn, --sage-attn or --xformers must be specified / --sdpa, --flash-attn, --sage-attn, --xformersのいずれかを指定してください" + ) + transformer = load_transformer(args.dit, attn_mode, args.split_attn, loading_device, dit_weight_dtype) + transformer.eval() + transformer.requires_grad_(False) + + if blocks_to_swap > 0: + logger.info(f"enable swap {blocks_to_swap} blocks to CPU from device: {accelerator.device}") + 
transformer.enable_block_swap(blocks_to_swap, accelerator.device, supports_backward=True) + transformer.move_to_device_except_swap_blocks(accelerator.device) + if args.img_in_txt_in_offloading: + logger.info("Enable offloading img_in and txt_in to CPU") + transformer.enable_img_in_txt_in_offloading() + + # load network model for differential training + sys.path.append(os.path.dirname(__file__)) + accelerator.print("import network module:", args.network_module) + network_module: lora_module = importlib.import_module(args.network_module) # actual module may be different + + if args.base_weights is not None: + # if base_weights is specified, merge the weights to DiT model + for i, weight_path in enumerate(args.base_weights): + if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i: + multiplier = 1.0 + else: + multiplier = args.base_weights_multiplier[i] + + accelerator.print(f"merging module: {weight_path} with multiplier {multiplier}") + + weights_sd = load_file(weight_path) + module = network_module.create_network_from_weights_hunyuan_video( + multiplier, weights_sd, unet=transformer, for_inference=True + ) + module.merge_to(None, transformer, weights_sd, weight_dtype, "cpu") + + accelerator.print(f"all weights merged: {', '.join(args.base_weights)}") + + # prepare network + net_kwargs = {} + if args.network_args is not None: + for net_arg in args.network_args: + key, value = net_arg.split("=") + net_kwargs[key] = value + + if args.dim_from_weights: + logger.info(f"Loading network from weights: {args.dim_from_weights}") + weights_sd = load_file(args.dim_from_weights) + network, _ = network_module.create_network_from_weights_hunyuan_video(1, weights_sd, unet=transformer) + else: + if hasattr(network_module, 'create_network_hunyuan_video'): + network = network_module.create_network_hunyuan_video( + 1.0, + args.network_dim, + args.network_alpha, + vae, + None, + transformer, + neuron_dropout=args.network_dropout, + **net_kwargs, + ) + else: + network = network_module.create_network( + 1.0, + args.network_dim, + args.network_alpha, + vae, + None, + transformer, + **net_kwargs, + ) + if network is None: + return + + if hasattr(network_module, 'prepare_network'): + network.prepare_network(args) + + # apply network to DiT + network.apply_to(None, transformer, apply_text_encoder=False, apply_unet=True) + + if args.network_weights is not None: + # FIXME consider alpha of weights: this assumes that the alpha is not changed + info = network.load_weights(args.network_weights) + accelerator.print(f"load network weights from {args.network_weights}: {info}") + + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + network.enable_gradient_checkpointing() # may have no effect + + # prepare optimizer, data loader etc. 
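+        # note: get_optimizer also returns train/eval callables; they are no-ops unless the optimizer
+        # exposes train()/eval() (e.g. schedule-free optimizers), and the loop calls them around sampling/saving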
+ accelerator.print("prepare optimizer, data loader etc.") + + trainable_params, lr_descriptions = network.prepare_optimizer_params(unet_lr=args.learning_rate) + optimizer_name, optimizer_args, optimizer, optimizer_train_fn, optimizer_eval_fn = self.get_optimizer( + args, trainable_params + ) + + # prepare dataloader + + # num workers for data loader: if 0, persistent_workers is not available + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers + + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # calculate max_train_steps + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # send max_train_steps to train_dataset_group + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # prepare lr_scheduler + lr_scheduler = self.get_scheduler(args, optimizer, accelerator.num_processes) + + # prepare training model. accelerator does some magic here + + # experimental feature: train the model with gradients in fp16/bf16 + network_dtype = torch.float32 + args.full_fp16 = args.full_bf16 = False # temporary disabled because stochastic rounding is not supported yet + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + network_dtype = weight_dtype + network.to(network_dtype) + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + network_dtype = weight_dtype + network.to(network_dtype) + + if dit_weight_dtype != dit_dtype: + logger.info(f"casting model to {dit_weight_dtype}") + transformer.to(dit_weight_dtype) + + if blocks_to_swap > 0: + transformer = accelerator.prepare(transformer, device_placement=[not blocks_to_swap > 0]) + accelerator.unwrap_model(transformer).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + accelerator.unwrap_model(transformer).prepare_block_swap_before_forward() + else: + transformer = accelerator.prepare(transformer) + + network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler) + training_model = network + + if args.gradient_checkpointing: + transformer.train() + else: + transformer.eval() + + accelerator.unwrap_model(network).prepare_grad_etc(transformer) + + if args.full_fp16: + # patch accelerator for fp16 training + # def patch_accelerator_for_fp16_training(accelerator): + org_unscale_grads = accelerator.scaler._unscale_grads_ + + def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): + return org_unscale_grads(optimizer, inv_scale, found_inf, True) + + accelerator.scaler._unscale_grads_ = _unscale_grads_replacer + + # before resuming make hook for saving/loading to save/load the network weights only + def save_model_hook(models, weights, output_dir): + # pop weights of other models than network to save only network weights + # only main 
process or deepspeed https://github.com/huggingface/diffusers/issues/2606
+            if accelerator.is_main_process:  # or args.deepspeed:
+                remove_indices = []
+                for i, model in enumerate(models):
+                    if not isinstance(model, type(accelerator.unwrap_model(network))):
+                        remove_indices.append(i)
+                for i in reversed(remove_indices):
+                    if len(weights) > i:
+                        weights.pop(i)
+                # print(f"save model hook: {len(weights)} weights will be saved")
+
+        def load_model_hook(models, input_dir):
+            # remove models except network
+            remove_indices = []
+            for i, model in enumerate(models):
+                if not isinstance(model, type(accelerator.unwrap_model(network))):
+                    remove_indices.append(i)
+            for i in reversed(remove_indices):
+                models.pop(i)
+            # print(f"load model hook: {len(models)} models will be loaded")
+
+        accelerator.register_save_state_pre_hook(save_model_hook)
+        accelerator.register_load_state_pre_hook(load_model_hook)
+
+        # resume from local or huggingface. accelerator.step is set
+        self.resume_from_local_or_hf_if_specified(accelerator, args)  # accelerator.load_state(args.resume)
+
+        # calculate the number of epochs
+        num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+        num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+        # start training
+        # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+        accelerator.print("running training / 学習開始")
+        accelerator.print(f" num train items / 学習画像、動画数: {train_dataset_group.num_train_items}")
+        accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
+        accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
+        accelerator.print(
+            f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
+        )
+        # accelerator.print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
+        accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
+        accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
+
+        # TODO refactor metadata creation and move to util
+        metadata = {
+            "ss_session_id": session_id,  # random integer indicating which group of epochs the model came from
+            "ss_training_started_at": training_started_at,  # unix timestamp
+            "ss_output_name": args.output_name,
+            "ss_learning_rate": args.learning_rate,
+            "ss_num_train_items": train_dataset_group.num_train_items,
+            "ss_num_batches_per_epoch": len(train_dataloader),
+            "ss_num_epochs": num_train_epochs,
+            "ss_gradient_checkpointing": args.gradient_checkpointing,
+            "ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
+            "ss_max_train_steps": args.max_train_steps,
+            "ss_lr_warmup_steps": args.lr_warmup_steps,
+            "ss_lr_scheduler": args.lr_scheduler,
+            SS_METADATA_KEY_BASE_MODEL_VERSION: BASE_MODEL_VERSION_HUNYUAN_VIDEO,
+            # "ss_network_module": args.network_module,
+            # "ss_network_dim": args.network_dim,  # None means default because another network than LoRA may have another default dim
+            # "ss_network_alpha": args.network_alpha,  # some networks may not have alpha
+            SS_METADATA_KEY_NETWORK_MODULE: args.network_module,
+            SS_METADATA_KEY_NETWORK_DIM: args.network_dim,
+            SS_METADATA_KEY_NETWORK_ALPHA: args.network_alpha,
+            "ss_network_dropout": args.network_dropout,  # some networks may not have dropout
+            "ss_mixed_precision": args.mixed_precision,
+            "ss_seed": args.seed,
+            "ss_training_comment": args.training_comment,  # 
will not be updated after training + # "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(), + "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""), + "ss_max_grad_norm": args.max_grad_norm, + "ss_fp8_base": bool(args.fp8_base), + "ss_fp8_llm": bool(args.fp8_llm), + "ss_full_fp16": bool(args.full_fp16), + "ss_full_bf16": bool(args.full_bf16), + "ss_weighting_scheme": args.weighting_scheme, + "ss_logit_mean": args.logit_mean, + "ss_logit_std": args.logit_std, + "ss_mode_scale": args.mode_scale, + "ss_guidance_scale": args.guidance_scale, + "ss_timestep_sampling": args.timestep_sampling, + "ss_sigmoid_scale": args.sigmoid_scale, + "ss_discrete_flow_shift": args.discrete_flow_shift, + } + + datasets_metadata = [] + # tag_frequency = {} # merge tag frequency for metadata editor # TODO support tag frequency + for dataset in train_dataset_group.datasets: + dataset_metadata = dataset.get_metadata() + datasets_metadata.append(dataset_metadata) + + metadata["ss_datasets"] = json.dumps(datasets_metadata) + + # add extra args + if args.network_args: + # metadata["ss_network_args"] = json.dumps(net_kwargs) + metadata[SS_METADATA_KEY_NETWORK_ARGS] = json.dumps(net_kwargs) + + # model name and hash + if args.dit is not None: + logger.info(f"calculate hash for DiT model: {args.dit}") + sd_model_name = args.dit + if os.path.exists(sd_model_name): + # metadata["ss_sd_model_hash"] = model_utils.model_hash(sd_model_name) + # metadata["ss_new_sd_model_hash"] = model_utils.calculate_sha256(sd_model_name) + sd_model_name = os.path.basename(sd_model_name) + metadata["ss_sd_model_name"] = sd_model_name + + if args.vae is not None: + logger.info(f"calculate hash for VAE model: {args.vae}") + vae_name = args.vae + if os.path.exists(vae_name): + # metadata["ss_vae_hash"] = model_utils.model_hash(vae_name) + # metadata["ss_new_vae_hash"] = model_utils.calculate_sha256(vae_name) + vae_name = os.path.basename(vae_name) + metadata["ss_vae_name"] = vae_name + + metadata = {k: str(v) for k, v in metadata.items()} + + # make minimum metadata for filtering + minimum_metadata = {} + for key in SS_METADATA_MINIMUM_KEYS: + if key in metadata: + minimum_metadata[key] = metadata[key] + + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + "network_train" if args.log_tracker_name is None else args.log_tracker_name, + config=train_utils.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + # TODO skip until initial step + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + + epoch_to_start = 0 + global_step = 0 + noise_scheduler = FlowMatchDiscreteScheduler(shift=args.discrete_flow_shift, reverse=True, solver="euler") + + loss_recorder = train_utils.LossRecorder() + del train_dataset_group + + # function for saving/removing + save_dtype = dit_dtype + + def save_model(ckpt_name: str, unwrapped_nw, steps, epoch_no, force_sync_upload=False): + os.makedirs(args.output_dir, exist_ok=True) + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + accelerator.print(f"\nsaving checkpoint: {ckpt_file}") + metadata["ss_training_finished_at"] = str(time.time()) + metadata["ss_steps"] = str(steps) + metadata["ss_epoch"] = str(epoch_no) + + metadata_to_save = minimum_metadata if args.no_metadata else 
metadata + + title = args.metadata_title if args.metadata_title is not None else args.output_name + if args.min_timestep is not None or args.max_timestep is not None: + min_time_step = args.min_timestep if args.min_timestep is not None else 0 + max_time_step = args.max_timestep if args.max_timestep is not None else 1000 + md_timesteps = (min_time_step, max_time_step) + else: + md_timesteps = None + + sai_metadata = sai_model_spec.build_metadata( + None, + time.time(), + title, + None, + args.metadata_author, + args.metadata_description, + args.metadata_license, + args.metadata_tags, + timesteps=md_timesteps, + ) + + metadata_to_save.update(sai_metadata) + + unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save) + if args.huggingface_repo_id is not None: + huggingface_utils.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) + + def remove_model(old_ckpt_name): + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + accelerator.print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + # For --sample_at_first + if should_sample_images(args, global_step, epoch=0): + optimizer_eval_fn() + sample_images(accelerator, args, 0, global_step, vae, transformer, sample_parameters, dit_dtype) + optimizer_train_fn() + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) + + # training loop + + # log device and dtype for each model + logger.info(f"DiT dtype: {transformer.dtype}, device: {transformer.device}") + + clean_memory_on_device(accelerator.device) + + pos_embed_cache = {} + + for epoch in range(epoch_to_start, num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + metadata["ss_epoch"] = str(epoch + 1) + + accelerator.unwrap_model(network).on_epoch_start(transformer) + + for step, batch in enumerate(train_dataloader): + latents, llm_embeds, llm_mask, clip_embeds = batch + bsz = latents.shape[0] + current_step.value = global_step + + with accelerator.accumulate(training_model): + accelerator.unwrap_model(network).on_step_start() + + latents = latents * vae_module.SCALING_FACTOR + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + + # calculate model input and timesteps + noisy_model_input, timesteps = self.get_noisy_model_input_and_timesteps( + args, noise, latents, noise_scheduler, accelerator.device, dit_dtype + ) + + weighting = compute_loss_weighting_for_sd3( + args.weighting_scheme, noise_scheduler, timesteps, accelerator.device, dit_dtype + ) + + # ensure guidance_scale in args is float + guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) # , dtype=dit_dtype) + + # ensure the hidden state will require grad + if args.gradient_checkpointing: + noisy_model_input.requires_grad_(True) + guidance_vec.requires_grad_(True) + + pos_emb_shape = latents.shape[1:] + if pos_emb_shape not in pos_embed_cache: + freqs_cos, freqs_sin = get_rotary_pos_embed_by_shape(transformer, latents.shape[2:]) + # freqs_cos = freqs_cos.to(device=accelerator.device, dtype=dit_dtype) + # freqs_sin = freqs_sin.to(device=accelerator.device, dtype=dit_dtype) + pos_embed_cache[pos_emb_shape] = (freqs_cos, freqs_sin) + else: + freqs_cos, freqs_sin = pos_embed_cache[pos_emb_shape] + + # call DiT + latents = latents.to(device=accelerator.device, dtype=network_dtype) + noisy_model_input = 
noisy_model_input.to(device=accelerator.device, dtype=network_dtype) + # timesteps = timesteps.to(device=accelerator.device, dtype=dit_dtype) + # llm_embeds = llm_embeds.to(device=accelerator.device, dtype=dit_dtype) + # llm_mask = llm_mask.to(device=accelerator.device) + # clip_embeds = clip_embeds.to(device=accelerator.device, dtype=dit_dtype) + with accelerator.autocast(): + model_pred = transformer( + noisy_model_input, + timesteps, + text_states=llm_embeds, + text_mask=llm_mask, + text_states_2=clip_embeds, + freqs_cos=freqs_cos, + freqs_sin=freqs_sin, + guidance=guidance_vec, + return_dict=False, + ) + + # flow matching loss + target = noise - latents + + loss = torch.nn.functional.mse_loss(model_pred.to(network_dtype), target, reduction="none") + + if weighting is not None: + loss = loss * weighting + # loss = loss.mean([1, 2, 3]) + # # min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc. + # loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + accelerator.backward(loss) + if accelerator.sync_gradients: + # self.all_reduce_network(accelerator, network) # sync DDP grad manually + state = accelerate.PartialState() + if state.distributed_type != accelerate.DistributedType.NO: + for param in network.parameters(): + if param.grad is not None: + param.grad = accelerator.reduce(param.grad, reduction="mean") + + if args.max_grad_norm != 0.0: + params_to_clip = accelerator.unwrap_model(network).get_trainable_params() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + if args.scale_weight_norms: + keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization( + args.scale_weight_norms, accelerator.device + ) + max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm} + else: + keys_scaled, mean_norm, maximum_norm = None, None, None + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + # to avoid calling optimizer_eval_fn() too frequently, we call it only when we need to sample images or save the model + should_sampling = should_sample_images(args, global_step, epoch=None) + should_saving = args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 + + if should_sampling or should_saving: + optimizer_eval_fn() + if should_sampling: + sample_images(accelerator, args, None, global_step, vae, transformer, sample_parameters, dit_dtype) + + if should_saving: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + ckpt_name = train_utils.get_step_ckpt_name(args.output_name, global_step) + save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch) + + if args.save_state: + train_utils.save_and_remove_state_stepwise(args, accelerator, global_step) + + remove_step_no = train_utils.get_remove_step_no(args, global_step) + if remove_step_no is not None: + remove_ckpt_name = train_utils.get_step_ckpt_name(args.output_name, remove_step_no) + remove_model(remove_ckpt_name) + optimizer_train_fn() + + current_loss = loss.detach().item() + loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + avr_loss: float = loss_recorder.moving_average + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if 
args.scale_weight_norms: + progress_bar.set_postfix(**{**max_mean_logs, **logs}) + + if len(accelerator.trackers) > 0: + logs = self.generate_step_logs( + args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm + ) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if len(accelerator.trackers) > 0: + logs = {"loss/epoch": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + # save model at the end of epoch if needed + optimizer_eval_fn() + if args.save_every_n_epochs is not None: + saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs + if is_main_process and saving: + ckpt_name = train_utils.get_epoch_ckpt_name(args.output_name, epoch + 1) + save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1) + + remove_epoch_no = train_utils.get_remove_epoch_no(args, epoch + 1) + if remove_epoch_no is not None: + remove_ckpt_name = train_utils.get_epoch_ckpt_name(args.output_name, remove_epoch_no) + remove_model(remove_ckpt_name) + + if args.save_state: + train_utils.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) + + sample_images(accelerator, args, epoch + 1, global_step, vae, transformer, sample_parameters, dit_dtype) + optimizer_train_fn() + + # end of epoch + + # metadata["ss_epoch"] = str(num_train_epochs) + metadata["ss_training_finished_at"] = str(time.time()) + + if is_main_process: + network = accelerator.unwrap_model(network) + + accelerator.end_training() + optimizer_eval_fn() + + if is_main_process and (args.save_state or args.save_state_on_train_end): + train_utils.save_state_on_train_end(args, accelerator) + + if is_main_process: + ckpt_name = train_utils.get_last_ckpt_name(args.output_name) + save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True) + + logger.info("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + def int_or_float(value): + if value.endswith("%"): + try: + return float(value[:-1]) / 100.0 + except ValueError: + raise argparse.ArgumentTypeError(f"Value '{value}' is not a valid percentage") + try: + float_value = float(value) + if float_value >= 1 and float_value.is_integer(): + return int(value) + return float(value) + except ValueError: + raise argparse.ArgumentTypeError(f"'{value}' is not an int or float") + + parser = argparse.ArgumentParser() + + # general settings + parser.add_argument( + "--config_file", + type=str, + default=None, + help="using .toml instead of args to pass hyperparameter / ハイパーパラメータを引数ではなく.tomlファイルで渡す", + ) + parser.add_argument( + "--dataset_config", + type=pathlib.Path, + default=None, + help="config file for dataset / データセットの設定ファイル", + ) + + # training settings + parser.add_argument( + "--sdpa", + action="store_true", + help="use sdpa for CrossAttention (requires PyTorch 2.0) / CrossAttentionにsdpaを使う(PyTorch 2.0が必要)", + ) + parser.add_argument( + "--flash_attn", + action="store_true", + help="use FlashAttention for CrossAttention, requires FlashAttention / CrossAttentionにFlashAttentionを使う、FlashAttentionが必要", + ) + parser.add_argument( + "--sage_attn", + action="store_true", + help="use SageAttention. 
requires SageAttention / SageAttentionを使う。SageAttentionが必要", + ) + parser.add_argument( + "--xformers", + action="store_true", + help="use xformers for CrossAttention, requires xformers / CrossAttentionにxformersを使う、xformersが必要", + ) + parser.add_argument( + "--split_attn", + action="store_true", + help="use split attention for attention calculation (split batch size=1, affects memory usage and speed)" + " / attentionを分割して計算する(バッチサイズ=1に分割、メモリ使用量と速度に影響)", + ) + + parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数") + parser.add_argument( + "--max_train_epochs", + type=int, + default=None, + help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)", + ) + parser.add_argument( + "--max_data_loader_n_workers", + type=int, + default=8, + help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)", + ) + parser.add_argument( + "--persistent_data_loader_workers", + action="store_true", + help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)", + ) + parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") + parser.add_argument( + "--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / gradient checkpointingを有効にする" + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数", + ) + parser.add_argument( + "--mixed_precision", + type=str, + default="no", + choices=["no", "fp16", "bf16"], + help="use mixed precision / 混合精度を使う場合、その精度", + ) + + parser.add_argument( + "--logging_dir", + type=str, + default=None, + help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する", + ) + parser.add_argument( + "--log_with", + type=str, + default=None, + choices=["tensorboard", "wandb", "all"], + help="what logging tool(s) to use (if 'all', TensorBoard and WandB are both used) / ログ出力に使用するツール (allを指定するとTensorBoardとWandBの両方が使用される)", + ) + parser.add_argument( + "--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列" + ) + parser.add_argument( + "--log_tracker_name", + type=str, + default=None, + help="name of tracker to use for logging, default is script-specific default name / ログ出力に使用するtrackerの名前、省略時はスクリプトごとのデフォルト名", + ) + parser.add_argument( + "--wandb_run_name", + type=str, + default=None, + help="The name of the specific wandb session / wandb ログに表示される特定の実行の名前", + ) + parser.add_argument( + "--log_tracker_config", + type=str, + default=None, + help="path to tracker config file to use for logging / ログ出力に使用するtrackerの設定ファイルのパス", + ) + parser.add_argument( + "--wandb_api_key", + type=str, + default=None, + help="specify WandB API key to log in before starting training (optional). 
/ WandB APIキーを指定して学習開始前にログインする(オプション)", + ) + parser.add_argument("--log_config", action="store_true", help="log training configuration / 学習設定をログに出力する") + + parser.add_argument( + "--ddp_timeout", + type=int, + default=None, + help="DDP timeout (min, None for default of accelerate) / DDPのタイムアウト(分、Noneでaccelerateのデフォルト)", + ) + parser.add_argument( + "--ddp_gradient_as_bucket_view", + action="store_true", + help="enable gradient_as_bucket_view for DDP / DDPでgradient_as_bucket_viewを有効にする", + ) + parser.add_argument( + "--ddp_static_graph", + action="store_true", + help="enable static_graph for DDP / DDPでstatic_graphを有効にする", + ) + + parser.add_argument( + "--sample_every_n_steps", + type=int, + default=None, + help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する", + ) + parser.add_argument( + "--sample_at_first", action="store_true", help="generate sample images before training / 学習前にサンプル出力する" + ) + parser.add_argument( + "--sample_every_n_epochs", + type=int, + default=None, + help="generate sample images every N epochs (overwrites n_steps) / 学習中のモデルで指定エポックごとにサンプル出力する(ステップ数指定を上書きします)", + ) + parser.add_argument( + "--sample_prompts", + type=str, + default=None, + help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル", + ) + + # optimizer and lr scheduler settings + parser.add_argument( + "--optimizer_type", + type=str, + default="", + help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, AdaFactor. " + "Also, you can use any optimizer by specifying the full path to the class, like 'torch.optim.AdamW', 'bitsandbytes.optim.AdEMAMix8bit' or 'bitsandbytes.optim.PagedAdEMAMix8bit' etc. / ", + ) + parser.add_argument( + "--optimizer_args", + type=str, + default=None, + nargs="*", + help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")', + ) + parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率") + parser.add_argument( + "--max_grad_norm", + default=1.0, + type=float, + help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない", + ) + + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor", + ) + parser.add_argument( + "--lr_warmup_steps", + type=int_or_float, + default=0, + help="Int number of steps for the warmup in the lr scheduler (default is 0) or float with ratio of train steps" + " / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)", + ) + parser.add_argument( + "--lr_decay_steps", + type=int_or_float, + default=0, + help="Int number of steps for the decay in the lr scheduler (default is 0) or float (<1) with ratio of train steps" + " / 学習率のスケジューラを減衰させるステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)", + ) + parser.add_argument( + "--lr_scheduler_num_cycles", + type=int, + default=1, + help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数", + ) + parser.add_argument( + "--lr_scheduler_power", + type=float, + default=1, + help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power", + ) + parser.add_argument( + "--lr_scheduler_timescale", + type=int, + default=None, + help="Inverse sqrt timescale for inverse sqrt scheduler,defaults to `num_warmup_steps`" + + " / 逆平方根スケジューラのタイムスケール、デフォルトは`num_warmup_steps`", + ) + parser.add_argument( + 
"--lr_scheduler_min_lr_ratio", + type=float, + default=None, + help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler" + + " / 初期学習率の比率としての最小学習率を指定する、cosine with min lr と warmup decay スケジューラ で有効", + ) + parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module / 使用するスケジューラ") + parser.add_argument( + "--lr_scheduler_args", + type=str, + default=None, + nargs="*", + help='additional arguments for scheduler (like "T_max=100") / スケジューラの追加引数(例: "T_max100")', + ) + + # model settings + parser.add_argument("--dit", type=str, help="DiT checkpoint path / DiTのチェックポイントのパス") + parser.add_argument("--dit_dtype", type=str, default=None, help="data type for DiT, default is bfloat16") + parser.add_argument("--vae", type=str, help="VAE checkpoint path / VAEのチェックポイントのパス") + parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is float16") + parser.add_argument( + "--vae_tiling", + action="store_true", + help="enable spatial tiling for VAE, default is False. If vae_spatial_tile_sample_min_size is set, this is automatically enabled." + " / VAEの空間タイリングを有効にする、デフォルトはFalse。vae_spatial_tile_sample_min_sizeが設定されている場合、自動的に有効になります。", + ) + parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE") + parser.add_argument( + "--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256" + ) + parser.add_argument("--text_encoder1", type=str, help="Text Encoder 1 directory / テキストエンコーダ1のディレクトリ") + parser.add_argument("--text_encoder2", type=str, help="Text Encoder 2 directory / テキストエンコーダ2のディレクトリ") + parser.add_argument("--text_encoder_dtype", type=str, default=None, help="data type for Text Encoder, default is float16") + parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for LLM / LLMにfp8を使う") + parser.add_argument("--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う") + # parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する") + # parser.add_argument("--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する") + + parser.add_argument( + "--blocks_to_swap", + type=int, + default=None, + help="number of blocks to swap in the model, max XXX / モデル内のブロックの数、最大XXX", + ) + parser.add_argument( + "--img_in_txt_in_offloading", + action="store_true", + help="offload img_in and txt_in to cpu / img_inとtxt_inをCPUにオフロードする", + ) + + # parser.add_argument("--flow_shift", type=float, default=7.0, help="Shift factor for flow matching schedulers") + parser.add_argument("--guidance_scale", type=float, default=1.0, help="Embeded classifier free guidance scale.") + parser.add_argument( + "--timestep_sampling", + choices=["sigma", "uniform", "sigmoid", "shift"], + default="sigma", + help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal and shift of sigmoid." + " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト。", + ) + parser.add_argument( + "--discrete_flow_shift", + type=float, + default=1.0, + help="Discrete flow shift for the Euler Discrete Scheduler, default is 1.0. 
/ Euler Discrete Schedulerの離散フローシフト、デフォルトは1.0。", + ) + parser.add_argument( + "--sigmoid_scale", + type=float, + default=1.0, + help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid" or "shift"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"または"shift"の場合のみ有効)。', + ) + parser.add_argument( + "--weighting_scheme", + type=str, + default="none", + choices=["logit_normal", "mode", "cosmap", "sigma_sqrt", "none"], + help="weighting scheme for timestep distribution. Default is none" + " / タイムステップ分布の重み付けスキーム、デフォルトはnone", + ) + parser.add_argument( + "--logit_mean", + type=float, + default=0.0, + help="mean to use when using the `'logit_normal'` weighting scheme / `'logit_normal'`重み付けスキームを使用する場合の平均", + ) + parser.add_argument( + "--logit_std", + type=float, + default=1.0, + help="std to use when using the `'logit_normal'` weighting scheme / `'logit_normal'`重み付けスキームを使用する場合のstd", + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme` / モード重み付けスキームのスケール", + ) + parser.add_argument( + "--min_timestep", + type=int, + default=None, + help="set minimum time step for training (0~999, default is 0) / 学習時のtime stepの最小値を設定する(0~999で指定、省略時はデフォルト値(0)) ", + ) + parser.add_argument( + "--max_timestep", + type=int, + default=None, + help="set maximum time step for training (1~1000, default is 1000) / 学習時のtime stepの最大値を設定する(1~1000で指定、省略時はデフォルト値(1000))", + ) + + parser.add_argument( + "--show_timesteps", + type=str, + default=None, + choices=["image", "console"], + help="show timesteps in image or console, and return to console / タイムステップを画像またはコンソールに表示し、コンソールに戻る", + ) + + # network settings + parser.add_argument( + "--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない" + ) + parser.add_argument( + "--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み" + ) + parser.add_argument( + "--network_module", type=str, default=None, help="network module to train / 学習対象のネットワークのモジュール" + ) + parser.add_argument( + "--network_dim", + type=int, + default=None, + help="network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)", + ) + parser.add_argument( + "--network_alpha", + type=float, + default=1, + help="alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRaの重み調整のalpha値、デフォルト1(旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定)", + ) + parser.add_argument( + "--network_dropout", + type=float, + default=None, + help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする(0またはNoneはdropoutなし、1は全ニューロンをdropout)", + ) + parser.add_argument( + "--network_args", + type=str, + default=None, + nargs="*", + help="additional arguments for network (key=value) / ネットワークへの追加の引数", + ) + parser.add_argument( + "--training_comment", + type=str, + default=None, + help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列", + ) + parser.add_argument( + "--dim_from_weights", + action="store_true", + help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する", + ) + parser.add_argument( + "--scale_weight_norms", + type=float, + default=None, + help="Scale the weight of each key pair to help prevent overtraing via exploding gradients. 
(1 is a good starting point) / 重みの値をスケーリングして勾配爆発を防ぐ(1が初期値としては適当)", + ) + parser.add_argument( + "--base_weights", + type=str, + default=None, + nargs="*", + help="network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みファイル", + ) + parser.add_argument( + "--base_weights_multiplier", + type=float, + default=None, + nargs="*", + help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率", + ) + + # save and load settings + parser.add_argument( + "--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ" + ) + parser.add_argument( + "--output_name", + type=str, + default=None, + help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名", + ) + parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate") + + parser.add_argument( + "--save_every_n_epochs", + type=int, + default=None, + help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する", + ) + parser.add_argument( + "--save_every_n_steps", + type=int, + default=None, + help="save checkpoint every N steps / 学習中のモデルを指定ステップごとに保存する", + ) + parser.add_argument( + "--save_last_n_epochs", + type=int, + default=None, + help="save last N checkpoints when saving every N epochs (remove older checkpoints) / 指定エポックごとにモデルを保存するとき最大Nエポック保存する(古いチェックポイントは削除する)", + ) + parser.add_argument( + "--save_last_n_epochs_state", + type=int, + default=None, + help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きする)", + ) + parser.add_argument( + "--save_last_n_steps", + type=int, + default=None, + help="save checkpoints until N steps elapsed (remove older checkpoints if N steps elapsed) / 指定ステップごとにモデルを保存するとき、このステップ数経過するまで保存する(このステップ数経過したら削除する)", + ) + parser.add_argument( + "--save_last_n_steps_state", + type=int, + default=None, + help="save states until N steps elapsed (remove older states if N steps elapsed, overrides --save_last_n_steps) / 指定ステップごとにstateを保存するとき、このステップ数経過するまで保存する(このステップ数経過したら削除する。--save_last_n_stepsを上書きする)", + ) + parser.add_argument( + "--save_state", + action="store_true", + help="save training state additionally (including optimizer states etc.) when saving model / optimizerなど学習状態も含めたstateをモデル保存時に追加で保存する", + ) + parser.add_argument( + "--save_state_on_train_end", + action="store_true", + help="save training state (including optimizer states etc.) 
on train end even if --save_state is not specified" + " / --save_stateが未指定時にもoptimizerなど学習状態も含めたstateを学習終了時に保存する", + ) + + # SAI Model spec + parser.add_argument( + "--metadata_title", + type=str, + default=None, + help="title for model metadata (default is output_name) / メタデータに書き込まれるモデルタイトル、省略時はoutput_name", + ) + parser.add_argument( + "--metadata_author", + type=str, + default=None, + help="author name for model metadata / メタデータに書き込まれるモデル作者名", + ) + parser.add_argument( + "--metadata_description", + type=str, + default=None, + help="description for model metadata / メタデータに書き込まれるモデル説明", + ) + parser.add_argument( + "--metadata_license", + type=str, + default=None, + help="license for model metadata / メタデータに書き込まれるモデルライセンス", + ) + parser.add_argument( + "--metadata_tags", + type=str, + default=None, + help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り", + ) + + # huggingface settings + parser.add_argument( + "--huggingface_repo_id", + type=str, + default=None, + help="huggingface repo name to upload / huggingfaceにアップロードするリポジトリ名", + ) + parser.add_argument( + "--huggingface_repo_type", + type=str, + default=None, + help="huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類", + ) + parser.add_argument( + "--huggingface_path_in_repo", + type=str, + default=None, + help="huggingface model path to upload files / huggingfaceにアップロードするファイルのパス", + ) + parser.add_argument("--huggingface_token", type=str, default=None, help="huggingface token / huggingfaceのトークン") + parser.add_argument( + "--huggingface_repo_visibility", + type=str, + default=None, + help="huggingface repository visibility ('public' for public, 'private' or None for private) / huggingfaceにアップロードするリポジトリの公開設定('public'で公開、'private'またはNoneで非公開)", + ) + parser.add_argument( + "--save_state_to_huggingface", action="store_true", help="save state to huggingface / huggingfaceにstateを保存する" + ) + parser.add_argument( + "--resume_from_huggingface", + action="store_true", + help="resume from huggingface (ex: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type}) / huggingfaceから学習を再開する(例: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type})", + ) + parser.add_argument( + "--async_upload", + action="store_true", + help="upload to huggingface asynchronously / huggingfaceに非同期でアップロードする", + ) + + return parser + + +def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentParser): + if not args.config_file: + return args + + config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file + + if not os.path.exists(config_path): + logger.info(f"{config_path} not found.") + exit(1) + + logger.info(f"Loading settings from {config_path}...") + with open(config_path, "r", encoding="utf-8") as f: + config_dict = toml.load(f) + + # combine all sections into one + ignore_nesting_dict = {} + for section_name, section_dict in config_dict.items(): + # if value is not dict, save key and value as is + if not isinstance(section_dict, dict): + ignore_nesting_dict[section_name] = section_dict + continue + + # if value is dict, save all key and value into one dict + for key, value in section_dict.items(): + ignore_nesting_dict[key] = value + + config_args = argparse.Namespace(**ignore_nesting_dict) + args = parser.parse_args(namespace=config_args) + args.config_file = os.path.splitext(args.config_file)[0] + logger.info(args.config_file) + + return args + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + args = 
read_config_from_file(args, parser) + + trainer = NetworkTrainer() + trainer.train(args) diff --git a/i1111.py b/i1111.py new file mode 100644 index 0000000000000000000000000000000000000000..86886b697b8f37005f345475c2f7a23ae26ca0ae --- /dev/null +++ b/i1111.py @@ -0,0 +1,4960 @@ +import gradio as gr +from gradio import update as gr_update +import subprocess +import threading +import time +import re +import os +import random +import tiktoken +import sys +import ffmpeg +from typing import List, Tuple, Optional, Generator, Dict +import json +from gradio import themes +from gradio.themes.utils import colors +import subprocess +from PIL import Image +import math +import cv2 +import glob +import shutil +from pathlib import Path +import logging +from datetime import datetime +from tqdm import tqdm + + +# Add global stop event +stop_event = threading.Event() + +logger = logging.getLogger(__name__) + +def process_hunyuani2v_video( + prompt: str, + width: int, + height: int, + batch_size: int, + video_length: int, + fps: int, + infer_steps: int, + seed: int, + dit_folder: str, + model: str, + vae: str, + te1: str, + te2: str, + save_path: str, + flow_shift: float, + cfg_scale: float, + output_type: str, + attn_mode: str, + block_swap: int, + exclude_single_blocks: bool, + use_split_attn: bool, + lora_folder: str, + lora1: str = "", + lora2: str = "", + lora3: str = "", + lora4: str = "", + lora1_multiplier: float = 1.0, + lora2_multiplier: float = 1.0, + lora3_multiplier: float = 1.0, + lora4_multiplier: float = 1.0, + video_path: Optional[str] = None, + image_path: Optional[str] = None, + strength: Optional[float] = None, + negative_prompt: Optional[str] = None, + embedded_cfg_scale: Optional[float] = None, + split_uncond: Optional[bool] = None, + guidance_scale: Optional[float] = None, + use_fp8: bool = True, + clip_vision_path: Optional[str] = None, + i2v_stability: bool = False, + fp8_fast: bool = False, + compile_model: bool = False, + compile_backend: str = "inductor", + compile_mode: str = "max-autotune-no-cudagraphs", + compile_dynamic: bool = False, + compile_fullgraph: bool = False +) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: + """Generate a single video with the hunyuani2v script with updated parameters""" + global stop_event + + if stop_event.is_set(): + yield [], "", "" + return + + # Determine if this is a SkyReels model and what type + is_skyreels = "skyreels" in model.lower() + is_skyreels_i2v = is_skyreels and "i2v" in model.lower() + is_skyreels_t2v = is_skyreels and "t2v" in model.lower() + + # Set defaults for hunyuani2v specific parameters + if is_skyreels: + # Force certain parameters for SkyReels + if negative_prompt is None: + negative_prompt = "" + if embedded_cfg_scale is None: + embedded_cfg_scale = 1.0 # Force to 1.0 for SkyReels + if split_uncond is None: + split_uncond = True + if guidance_scale is None: + guidance_scale = cfg_scale # Use cfg_scale as guidance_scale if not provided + + else: + embedded_cfg_scale = cfg_scale + + if os.path.isabs(model): + model_path = model + else: + model_path = os.path.normpath(os.path.join(dit_folder, model)) + + env = os.environ.copy() + env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") + env["PYTHONIOENCODING"] = "utf-8" + env["BATCH_RUN_ID"] = f"{time.time()}" + + if seed == -1: + current_seed = random.randint(0, 2**32 - 1) + else: + batch_id = int(env.get("BATCH_RUN_ID", "0").split('.')[-1]) + if batch_size > 1: # Only modify seed for batch generation + current_seed = (seed + 
batch_id * 100003) % (2**32) + else: + current_seed = seed + + clear_cuda_cache() + + # Now use hv_generate_video_with_hunyuani2v.py instead + command = [ + sys.executable, + "hv_generate_video_with_hunyuani2v.py", + "--dit", model_path, + "--vae", vae, + "--text_encoder1", te1, + "--text_encoder2", te2, + "--prompt", prompt, + "--video_size", str(height), str(width), + "--video_length", str(video_length), + "--fps", str(fps), + "--infer_steps", str(infer_steps), + "--save_path", save_path, + "--seed", str(current_seed), + "--flow_shift", str(flow_shift), + "--embedded_cfg_scale", str(cfg_scale), + "--output_type", output_type, + "--attn_mode", attn_mode, + "--blocks_to_swap", str(block_swap), + "--fp8_llm", + "--vae_chunk_size", "32", + "--vae_spatial_tile_sample_min_size", "128" + ] + + if use_fp8: + command.append("--fp8") + + # Add new parameters specific to hunyuani2v script + if clip_vision_path: + command.extend(["--clip_vision_path", clip_vision_path]) + + if i2v_stability: + command.append("--i2v_stability") + + if fp8_fast: + command.append("--fp8_fast") + + if compile_model: + command.append("--compile") + command.extend([ + "--compile_args", + compile_backend, + compile_mode, + str(compile_dynamic).lower(), + str(compile_fullgraph).lower() + ]) + + # Add negative prompt and embedded cfg scale + command.extend(["--guidance_scale", str(guidance_scale)]) + + if negative_prompt: + command.extend(["--negative_prompt", negative_prompt]) + if split_uncond: + command.append("--split_uncond") + + # Add LoRA weights and multipliers if provided + valid_loras = [] + for weight, mult in zip([lora1, lora2, lora3, lora4], + [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier]): + if weight and weight != "None": + valid_loras.append((os.path.join(lora_folder, weight), mult)) + if valid_loras: + weights = [weight for weight, _ in valid_loras] + multipliers = [str(mult) for _, mult in valid_loras] + command.extend(["--lora_weight"] + weights) + command.extend(["--lora_multiplier"] + multipliers) + + if exclude_single_blocks: + command.append("--exclude_single_blocks") + if use_split_attn: + command.append("--split_attn") + + # Handle input paths + if video_path: + command.extend(["--video_path", video_path]) + if strength is not None: + command.extend(["--strength", str(strength)]) + elif image_path: + command.extend(["--image_path", image_path]) + # Only add strength parameter for non-SkyReels I2V models + # SkyReels I2V doesn't use strength parameter for image-to-video generation + if strength is not None and not is_skyreels_i2v: + command.extend(["--strength", str(strength)]) + + print(f"{command}") + + p = subprocess.Popen( + command, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=env, + text=True, + encoding='utf-8', + errors='replace', + bufsize=1 + ) + + videos = [] + + while True: + if stop_event.is_set(): + p.terminate() + p.wait() + yield [], "", "Generation stopped by user." 
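The generation path above shells out to `hv_generate_video_with_hunyuani2v.py` and streams its stdout line by line while checking a global `threading.Event` so the UI can cancel mid-run. A minimal, self-contained sketch of that subprocess-streaming pattern (the command in the usage line is a placeholder, not the real generation script):

```python
import subprocess
import sys
import threading

stop_event = threading.Event()

def run_streaming(command: list[str]):
    """Launch a child process and yield its output lines until it exits
    or stop_event is set (mirrors the polling loop used above)."""
    p = subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,   # fold stderr into the same stream
        text=True,
        encoding="utf-8",
        errors="replace",
        bufsize=1,                  # line-buffered text mode
    )
    try:
        while True:
            if stop_event.is_set():
                p.terminate()       # cooperative cancel requested by the UI
                p.wait()
                return
            line = p.stdout.readline()
            if not line:
                if p.poll() is not None:   # EOF and process has finished
                    break
                continue
            yield line.rstrip("\n")
    finally:
        p.stdout.close()
        p.wait()

# Usage sketch: echo the child's output to the console.
if __name__ == "__main__":
    for out in run_streaming([sys.executable, "-c", "print('hello'); print('world')"]):
        print(out)
```

Line-buffered text mode plus folding stderr into stdout is what lets the caller forward tqdm progress lines to the UI as they arrive.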
+ return + + line = p.stdout.readline() + if not line: + if p.poll() is not None: + break + continue + + print(line, end='') + if '|' in line and '%' in line and '[' in line and ']' in line: + yield videos.copy(), f"Processing (seed: {current_seed})", line.strip() + + p.stdout.close() + p.wait() + + clear_cuda_cache() + time.sleep(0.5) + + # Collect generated video + save_path_abs = os.path.abspath(save_path) + if os.path.exists(save_path_abs): + all_videos = sorted( + [f for f in os.listdir(save_path_abs) if f.endswith('.mp4')], + key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)), + reverse=True + ) + matching_videos = [v for v in all_videos if f"_{current_seed}" in v] + if matching_videos: + video_path = os.path.join(save_path_abs, matching_videos[0]) + + # Collect parameters for metadata + parameters = { + "prompt": prompt, + "width": width, + "height": height, + "video_length": video_length, + "fps": fps, + "infer_steps": infer_steps, + "seed": current_seed, + "model": model, + "vae": vae, + "te1": te1, + "te2": te2, + "save_path": save_path, + "flow_shift": flow_shift, + "cfg_scale": cfg_scale, + "output_type": output_type, + "attn_mode": attn_mode, + "block_swap": block_swap, + "lora_weights": [lora1, lora2, lora3, lora4], + "lora_multipliers": [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier], + "input_video": video_path if video_path else None, + "input_image": image_path if image_path else None, + "strength": strength, + "negative_prompt": negative_prompt, + "embedded_cfg_scale": embedded_cfg_scale, + "clip_vision_path": clip_vision_path, + "i2v_stability": i2v_stability, + "fp8_fast": fp8_fast, + "compile_model": compile_model + } + + add_metadata_to_video(video_path, parameters) + videos.append((str(video_path), f"Seed: {current_seed}")) + + yield videos, f"Completed (seed: {current_seed})", "" + +# Now let's create a new batch processing function that uses the hunyuani2v function +def process_hunyuani2v_batch( + prompt: str, + width: int, + height: int, + batch_size: int, + video_length: int, + fps: int, + infer_steps: int, + seed: int, + dit_folder: str, + model: str, + vae: str, + te1: str, + te2: str, + save_path: str, + flow_shift: float, + cfg_scale: float, + output_type: str, + attn_mode: str, + block_swap: int, + exclude_single_blocks: bool, + use_split_attn: bool, + lora_folder: str, + *args +) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: + """Process a batch of videos using the hunyuani2v script""" + global stop_event + stop_event.clear() + + all_videos = [] + progress_text = "Starting generation..." 
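Once the child process exits, the script recovers its output by scanning `save_path` for the newest `.mp4` whose filename contains `_{seed}`. The same lookup as a stand-alone helper; `find_video_by_seed` is a hypothetical name, not a function in this file:

```python
import os
from typing import Optional

def find_video_by_seed(save_dir: str, seed: int) -> Optional[str]:
    """Return the most recently modified .mp4 in save_dir whose filename
    contains the seed marker, or None if nothing matches."""
    save_dir = os.path.abspath(save_dir)
    if not os.path.isdir(save_dir):
        return None
    candidates = [
        os.path.join(save_dir, f)
        for f in os.listdir(save_dir)
        if f.endswith(".mp4") and f"_{seed}" in f
    ]
    if not candidates:
        return None
    # Newest file first, matching the mtime sort used in the UI code.
    return max(candidates, key=os.path.getmtime)

# Usage sketch:
# path = find_video_by_seed("outputs", 123456789)
# if path:
#     print("latest video for this seed:", path)
```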
+ yield [], "Preparing...", progress_text + + # Extract additional arguments + num_lora_weights = 4 + lora_weights = args[:num_lora_weights] + lora_multipliers = args[num_lora_weights:num_lora_weights*2] + + # New parameters for hunyuani2v + # Base parameter list index after lora weights and multipliers + base_idx = num_lora_weights*2 + + # Extract parameters + input_path = args[base_idx] if len(args) > base_idx else None + strength = float(args[base_idx+1]) if len(args) > base_idx+1 and args[base_idx+1] is not None else None + negative_prompt = str(args[base_idx+2]) if len(args) > base_idx+2 and args[base_idx+2] is not None else None + guidance_scale = float(args[base_idx+3]) if len(args) > base_idx+3 and args[base_idx+3] is not None else cfg_scale + split_uncond = bool(args[base_idx+4]) if len(args) > base_idx+4 else None + use_fp8 = bool(args[base_idx+5]) if len(args) > base_idx+5 else True + + # New hunyuani2v parameters + clip_vision_path = str(args[base_idx+6]) if len(args) > base_idx+6 and args[base_idx+6] is not None else None + i2v_stability = bool(args[base_idx+7]) if len(args) > base_idx+7 else False + fp8_fast = bool(args[base_idx+8]) if len(args) > base_idx+8 else False + compile_model = bool(args[base_idx+9]) if len(args) > base_idx+9 else False + compile_backend = str(args[base_idx+10]) if len(args) > base_idx+10 and args[base_idx+10] is not None else "inductor" + compile_mode = str(args[base_idx+11]) if len(args) > base_idx+11 and args[base_idx+11] is not None else "max-autotune-no-cudagraphs" + compile_dynamic = bool(args[base_idx+12]) if len(args) > base_idx+12 else False + compile_fullgraph = bool(args[base_idx+13]) if len(args) > base_idx+13 else False + + embedded_cfg_scale = cfg_scale + + for i in range(batch_size): + if stop_event.is_set(): + break + + batch_text = f"Generating video {i + 1} of {batch_size}" + yield all_videos.copy(), batch_text, progress_text + + # Handle different input types + video_path = None + image_path = None + + if input_path: + is_image = False + lower_path = input_path.lower() + image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp') + is_image = any(lower_path.endswith(ext) for ext in image_extensions) + + if is_image: + image_path = input_path + else: + video_path = input_path + + # Prepare arguments for process_hunyuani2v_video + current_seed = seed + i if seed != -1 and batch_size > 1 else seed if seed != -1 else -1 + + hunyuani2v_args = [ + prompt, width, height, batch_size, video_length, fps, infer_steps, + current_seed, dit_folder, model, vae, te1, te2, save_path, flow_shift, cfg_scale, + output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, + lora_folder + ] + hunyuani2v_args.extend(lora_weights) + hunyuani2v_args.extend(lora_multipliers) + hunyuani2v_args.extend([ + video_path, image_path, strength, negative_prompt, embedded_cfg_scale, + split_uncond, guidance_scale, use_fp8, clip_vision_path, i2v_stability, + fp8_fast, compile_model, compile_backend, compile_mode, compile_dynamic, compile_fullgraph + ]) + + for videos, status, progress in process_hunyuani2v_video(*hunyuani2v_args): + if videos: + all_videos.extend(videos) + yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress + + yield all_videos, "Batch complete", "" + +def variance_of_laplacian(image): + """ + Compute the variance of the Laplacian of the image. + Higher variance indicates a sharper image. 
+ """ + return cv2.Laplacian(image, cv2.CV_64F).var() + +def extract_sharpest_frame(video_path, frames_to_check=30): + """ + Extract the sharpest frame from the last N frames of the video. + + Args: + video_path (str): Path to the video file + frames_to_check (int): Number of frames from the end to check + + Returns: + tuple: (temp_image_path, frame_number, sharpness_score) + """ + print(f"\n=== Extracting sharpest frame from the last {frames_to_check} frames ===") + print(f"Input video path: {video_path}") + + if not video_path or not os.path.exists(video_path): + print("❌ Error: Video file does not exist") + return None, None, None + + try: + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + print("❌ Error: Failed to open video file") + return None, None, None + + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + fps = cap.get(cv2.CAP_PROP_FPS) + print(f"Total frames detected: {total_frames}, FPS: {fps:.2f}") + + if total_frames < 1: + print("❌ Error: Video contains 0 frames") + return None, None, None + + # Determine how many frames to check (the last N frames) + if frames_to_check > total_frames: + frames_to_check = total_frames + start_frame = 0 + else: + start_frame = total_frames - frames_to_check + + print(f"Checking frames {start_frame} to {total_frames-1}") + + # Find the sharpest frame + sharpest_frame = None + max_sharpness = -1 + sharpest_frame_number = -1 + + # Set starting position + cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame) + + # Process frames with a progress bar + with tqdm(total=frames_to_check, desc="Finding sharpest frame") as pbar: + frame_idx = start_frame + while frame_idx < total_frames: + ret, frame = cap.read() + if not ret: + break + + # Convert to grayscale and calculate sharpness + gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + sharpness = variance_of_laplacian(gray) + + # Update if this is the sharpest frame so far + if sharpness > max_sharpness: + max_sharpness = sharpness + sharpest_frame = frame.copy() + sharpest_frame_number = frame_idx + + frame_idx += 1 + pbar.update(1) + + cap.release() + + if sharpest_frame is None: + print("❌ Error: Failed to find a sharp frame") + return None, None, None + + # Prepare output path + temp_dir = os.path.abspath("temp_frames") + os.makedirs(temp_dir, exist_ok=True) + temp_path = os.path.join(temp_dir, f"sharpest_frame_{os.path.basename(video_path)}.png") + print(f"Saving frame to: {temp_path}") + + # Write and verify + if not cv2.imwrite(temp_path, sharpest_frame): + print("❌ Error: Failed to write frame to file") + return None, None, None + + if not os.path.exists(temp_path): + print("❌ Error: Output file not created") + return None, None, None + + # Calculate frame time in seconds + frame_time = sharpest_frame_number / fps + + print(f"✅ Extracted sharpest frame: {sharpest_frame_number} (at {frame_time:.2f}s) with sharpness {max_sharpness:.2f}") + return temp_path, sharpest_frame_number, max_sharpness + + except Exception as e: + print(f"❌ Unexpected error: {str(e)}") + return None, None, None + finally: + if 'cap' in locals(): + cap.release() + +def trim_video_to_frame(video_path, frame_number, output_dir="outputs"): + """ + Trim video up to the specified frame and save as a new video. 
+ + Args: + video_path (str): Path to the video file + frame_number (int): Frame number to trim to + output_dir (str): Directory to save the trimmed video + + Returns: + str: Path to the trimmed video file + """ + print(f"\n=== Trimming video to frame {frame_number} ===") + if not video_path or not os.path.exists(video_path): + print("❌ Error: Video file does not exist") + return None + + try: + # Get video information + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + print("❌ Error: Failed to open video file") + return None + + fps = cap.get(cv2.CAP_PROP_FPS) + cap.release() + + # Calculate time in seconds + time_seconds = frame_number / fps + + # Create output directory if it doesn't exist + os.makedirs(output_dir, exist_ok=True) + + # Generate output filename + timestamp = f"{int(time_seconds)}s" + base_name = Path(video_path).stem + output_file = os.path.join(output_dir, f"{base_name}_trimmed_to_{timestamp}.mp4") + + # Use ffmpeg to trim the video + ( + ffmpeg + .input(video_path) + .output(output_file, to=time_seconds, c="copy") + .global_args('-y') # Overwrite output files + .run(quiet=True) + ) + + if not os.path.exists(output_file): + print("❌ Error: Failed to create trimmed video") + return None + + print(f"✅ Successfully trimmed video to {time_seconds:.2f}s: {output_file}") + return output_file + + except Exception as e: + print(f"❌ Error trimming video: {str(e)}") + return None + +def send_sharpest_frame_handler(gallery, selected_idx, frames_to_check=30): + """ + Extract the sharpest frame from the last N frames of the selected video + + Args: + gallery: Gradio gallery component with videos + selected_idx: Index of the selected video + frames_to_check: Number of frames from the end to check + + Returns: + tuple: (image_path, video_path, frame_number, sharpness) + """ + if gallery is None or not gallery: + return None, None, None, "No videos in gallery" + + if selected_idx is None and len(gallery) == 1: + selected_idx = 0 + + if selected_idx is None or selected_idx >= len(gallery): + return None, None, None, "No video selected" + + # Get the video path + item = gallery[selected_idx] + if isinstance(item, tuple): + video_path = item[0] + elif isinstance(item, dict): + video_path = item.get('name') or item.get('data') + else: + video_path = str(item) + + # Extract the sharpest frame + image_path, frame_number, sharpness = extract_sharpest_frame(video_path, frames_to_check) + + if image_path is None: + return None, None, None, "Failed to extract sharpest frame" + + return image_path, video_path, frame_number, f"Extracted frame {frame_number} with sharpness {sharpness:.2f}" + +def trim_and_prepare_for_extension(video_path, frame_number, save_path="outputs"): + """ + Trim the video to the specified frame and prepare for extension. 
+ + Args: + video_path: Path to the video file + frame_number: Frame number to trim to + save_path: Directory to save the trimmed video + + Returns: + tuple: (trimmed_video_path, status_message) + """ + if not video_path or not os.path.exists(video_path): + return None, "No video selected or video file does not exist" + + if frame_number is None: + return None, "No frame number provided, please extract sharpest frame first" + + # Trim the video + trimmed_video = trim_video_to_frame(video_path, frame_number, save_path) + + if trimmed_video is None: + return None, "Failed to trim video" + + return trimmed_video, f"Video trimmed to frame {frame_number} and ready for extension" + +def send_last_frame_handler(gallery, selected_idx): + """Handle sending last frame to input with better error handling""" + if gallery is None or not gallery: + return None, None + + if selected_idx is None and len(gallery) == 1: + selected_idx = 0 + + if selected_idx is None or selected_idx >= len(gallery): + return None, None + + # Get the frame and video path + frame = handle_last_frame_transfer(gallery, selected_idx) + video_path = None + + if selected_idx < len(gallery): + item = gallery[selected_idx] + video_path = parse_video_path(item) + + return frame, video_path + +def extract_last_frame(video_path: str) -> Optional[str]: + """Extract last frame from video and return temporary image path with error handling""" + print(f"\n=== Starting frame extraction ===") + print(f"Input video path: {video_path}") + + if not video_path or not os.path.exists(video_path): + print("❌ Error: Video file does not exist") + return None + + try: + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + print("❌ Error: Failed to open video file") + return None + + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + print(f"Total frames detected: {total_frames}") + + if total_frames < 1: + print("❌ Error: Video contains 0 frames") + return None + + # Extract last frame + cap.set(cv2.CAP_PROP_POS_FRAMES, total_frames - 1) + success, frame = cap.read() + + if not success or frame is None: + print("❌ Error: Failed to read last frame") + return None + + # Prepare output path + temp_dir = os.path.abspath("temp_frames") + os.makedirs(temp_dir, exist_ok=True) + temp_path = os.path.join(temp_dir, f"last_frame_{os.path.basename(video_path)}.png") + print(f"Saving frame to: {temp_path}") + + # Write and verify + if not cv2.imwrite(temp_path, frame): + print("❌ Error: Failed to write frame to file") + return None + + if not os.path.exists(temp_path): + print("❌ Error: Output file not created") + return None + + print("✅ Frame extraction successful") + return temp_path + + except Exception as e: + print(f"❌ Unexpected error: {str(e)}") + return None + finally: + if 'cap' in locals(): + cap.release() + +def handle_last_frame_transfer(gallery: list, selected_idx: int) -> Optional[str]: + """Improved frame transfer with video input validation""" + try: + if gallery is None or not gallery: + raise ValueError("No videos generated yet") + + if selected_idx is None: + # Auto-select last generated video if batch_size=1 + if len(gallery) == 1: + selected_idx = 0 + else: + raise ValueError("Please select a video first") + + if selected_idx >= len(gallery): + raise ValueError("Invalid selection index") + + item = gallery[selected_idx] + + # Video file existence check + video_path = parse_video_path(item) + if not os.path.exists(video_path): + raise FileNotFoundError(f"Video file missing: {video_path}") + + return extract_last_frame(video_path) 
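The frame-selection helpers above rank frames by the variance of the Laplacian: blur suppresses high-frequency edges, so sharper frames score higher. A compact sketch of that measurement over the last N frames of a clip, assuming only OpenCV; the function names are illustrative:

```python
import cv2

def sharpness_score(bgr_frame) -> float:
    """Variance of the Laplacian on the grayscale frame; higher = sharper."""
    gray = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2GRAY)
    return cv2.Laplacian(gray, cv2.CV_64F).var()

def sharpest_of_last(video_path: str, window: int = 30):
    """Return (frame_index, score) for the sharpest of the last `window` frames."""
    cap = cv2.VideoCapture(video_path)
    try:
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total < 1:
            return None
        start = max(0, total - window)
        cap.set(cv2.CAP_PROP_POS_FRAMES, start)
        best = (-1, -1.0)
        for idx in range(start, total):
            ok, frame = cap.read()
            if not ok:
                break
            score = sharpness_score(frame)
            if score > best[1]:
                best = (idx, score)
        return best if best[0] >= 0 else None
    finally:
        cap.release()

# Usage sketch:
# result = sharpest_of_last("outputs/example.mp4", window=30)
# if result:
#     print(f"sharpest frame {result[0]} with score {result[1]:.2f}")
```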
+ + except Exception as e: + print(f"Frame transfer failed: {str(e)}") + return None + +def parse_video_path(item) -> str: + """Parse different gallery item formats""" + if isinstance(item, tuple): + return item[0] + elif isinstance(item, dict): + return item.get('name') or item.get('data') + return str(item) + +def get_random_image_from_folder(folder_path): + """Get a random image from the specified folder""" + if not os.path.isdir(folder_path): + return None, f"Error: {folder_path} is not a valid directory" + + # Get all image files in the folder + image_files = [] + for ext in ('*.jpg', '*.jpeg', '*.png', '*.bmp', '*.webp'): + image_files.extend(glob.glob(os.path.join(folder_path, ext))) + for ext in ('*.JPG', '*.JPEG', '*.PNG', '*.BMP', '*.WEBP'): + image_files.extend(glob.glob(os.path.join(folder_path, ext))) + + if not image_files: + return None, f"Error: No image files found in {folder_path}" + + # Select a random image + random_image = random.choice(image_files) + return random_image, f"Selected: {os.path.basename(random_image)}" + +def resize_image_keeping_aspect_ratio(image_path, max_width, max_height): + """Resize image keeping aspect ratio and ensuring dimensions are divisible by 16""" + try: + img = Image.open(image_path) + width, height = img.size + + # Calculate aspect ratio + aspect_ratio = width / height + + # Calculate new dimensions while maintaining aspect ratio + if width > height: + new_width = min(max_width, width) + new_height = int(new_width / aspect_ratio) + else: + new_height = min(max_height, height) + new_width = int(new_height * aspect_ratio) + + # Make dimensions divisible by 16 + new_width = math.floor(new_width / 16) * 16 + new_height = math.floor(new_height / 16) * 16 + + # Ensure minimum size + new_width = max(16, new_width) + new_height = max(16, new_height) + + # Resize image + resized_img = img.resize((new_width, new_height), Image.LANCZOS) + + # Save to temporary file + temp_path = f"temp_resized_{os.path.basename(image_path)}" + resized_img.save(temp_path) + + return temp_path, (new_width, new_height) + except Exception as e: + return None, f"Error: {str(e)}" +# Function to process a batch of images from a folder +def batch_handler( + use_random, + prompt, negative_prompt, + width, height, + video_length, fps, infer_steps, + seed, flow_shift, guidance_scale, embedded_cfg_scale, + batch_size, input_folder_path, + dit_folder, model, vae, te1, te2, save_path, output_type, attn_mode, + block_swap, exclude_single_blocks, use_split_attn, use_fp8, split_uncond, + lora_folder, *lora_params +): + """Handle both folder-based batch processing and regular batch processing""" + global stop_event + + # Check if this is a SkyReels model that needs special handling + is_skyreels = "skyreels" in model.lower() + is_skyreels_i2v = is_skyreels and "i2v" in model.lower() + + if use_random: + # Random image from folder mode + stop_event.clear() + + all_videos = [] + progress_text = "Starting generation..." 
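`resize_image_keeping_aspect_ratio` scales the picked image to fit the requested box and then floors both sides to a multiple of 16, since the video pipeline expects block-aligned dimensions (the WanX helpers later snap to 32). A sketch of that snap-to-multiple resize; the scaling rule here is simplified (fit inside the box, never upscale) and the function name is hypothetical:

```python
import math
from PIL import Image

def fit_and_snap(image_path: str, max_w: int, max_h: int, multiple: int = 16):
    """Resize to fit inside (max_w, max_h), keep aspect ratio, and floor
    both dimensions to the nearest multiple (minimum one block)."""
    img = Image.open(image_path)
    w, h = img.size
    scale = min(max_w / w, max_h / h, 1.0)   # never upscale
    new_w = max(multiple, math.floor(w * scale / multiple) * multiple)
    new_h = max(multiple, math.floor(h * scale / multiple) * multiple)
    return img.resize((new_w, new_h), Image.LANCZOS), (new_w, new_h)

# Usage sketch:
# resized, (w, h) = fit_and_snap("input.png", 960, 544, multiple=16)
# resized.save("input_resized.png")
# print(f"snapped to {w}x{h}")
```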
+ yield [], "Preparing...", progress_text + + for i in range(batch_size): + if stop_event.is_set(): + break + + batch_text = f"Generating video {i + 1} of {batch_size}" + yield all_videos.copy(), batch_text, progress_text + + # Get random image from folder + random_image, status = get_random_image_from_folder(input_folder_path) + if random_image is None: + yield all_videos, f"Error in batch {i+1}: {status}", "" + continue + + # Resize image + resized_image, size_info = resize_image_keeping_aspect_ratio(random_image, width, height) + if resized_image is None: + yield all_videos, f"Error resizing image in batch {i+1}: {size_info}", "" + continue + + # If we have dimensions, update them + local_width, local_height = width, height + if isinstance(size_info, tuple): + local_width, local_height = size_info + progress_text = f"Using image: {os.path.basename(random_image)} - Resized to {local_width}x{local_height}" + else: + progress_text = f"Using image: {os.path.basename(random_image)}" + + yield all_videos.copy(), batch_text, progress_text + + # Calculate seed for this batch item + current_seed = seed + if seed == -1: + current_seed = random.randint(0, 2**32 - 1) + elif batch_size > 1: + current_seed = seed + i + + # Process the image + # For SkyReels models, we need to create a command with dit_in_channels=32 + if is_skyreels_i2v: + env = os.environ.copy() + env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") + env["PYTHONIOENCODING"] = "utf-8" + + model_path = os.path.join(dit_folder, model) if not os.path.isabs(model) else model + + # Extract parameters from lora_params + num_lora_weights = 4 + lora_weights = lora_params[:num_lora_weights] + lora_multipliers = lora_params[num_lora_weights:num_lora_weights*2] + + cmd = [ + sys.executable, + "hv_generate_video.py", + "--dit", model_path, + "--vae", vae, + "--text_encoder1", te1, + "--text_encoder2", te2, + "--prompt", prompt, + "--video_size", str(local_height), str(local_width), + "--video_length", str(video_length), + "--fps", str(fps), + "--infer_steps", str(infer_steps), + "--save_path", save_path, + "--seed", str(current_seed), + "--flow_shift", str(flow_shift), + "--embedded_cfg_scale", str(embedded_cfg_scale), + "--output_type", output_type, + "--attn_mode", attn_mode, + "--blocks_to_swap", str(block_swap), + "--fp8_llm", + "--vae_chunk_size", "32", + "--vae_spatial_tile_sample_min_size", "128", + "--dit_in_channels", "32", # This is crucial for SkyReels i2v + "--image_path", resized_image # Pass the image directly + ] + + if use_fp8: + cmd.append("--fp8") + + if split_uncond: + cmd.append("--split_uncond") + + if use_split_attn: + cmd.append("--split_attn") + + if exclude_single_blocks: + cmd.append("--exclude_single_blocks") + + if negative_prompt: + cmd.extend(["--negative_prompt", negative_prompt]) + + if guidance_scale is not None: + cmd.extend(["--guidance_scale", str(guidance_scale)]) + + # Add LoRA weights and multipliers if provided + valid_loras = [] + for weight, mult in zip(lora_weights, lora_multipliers): + if weight and weight != "None": + valid_loras.append((os.path.join(lora_folder, weight), mult)) + + if valid_loras: + weights = [weight for weight, _ in valid_loras] + multipliers = [str(mult) for _, mult in valid_loras] + cmd.extend(["--lora_weight"] + weights) + cmd.extend(["--lora_multiplier"] + multipliers) + + print(f"Running command: {' '.join(cmd)}") + + # Run the process + p = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=env, + text=True, + 
encoding='utf-8', + errors='replace', + bufsize=1 + ) + + while True: + if stop_event.is_set(): + p.terminate() + p.wait() + yield all_videos, "Generation stopped by user.", "" + return + + line = p.stdout.readline() + if not line: + if p.poll() is not None: + break + continue + + print(line, end='') + if '|' in line and '%' in line and '[' in line and ']' in line: + yield all_videos.copy(), f"Processing video {i+1} (seed: {current_seed})", line.strip() + + p.stdout.close() + p.wait() + + # Collect generated video + save_path_abs = os.path.abspath(save_path) + if os.path.exists(save_path_abs): + all_videos_files = sorted( + [f for f in os.listdir(save_path_abs) if f.endswith('.mp4')], + key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)), + reverse=True + ) + matching_videos = [v for v in all_videos_files if f"_{current_seed}" in v] + if matching_videos: + video_path = os.path.join(save_path_abs, matching_videos[0]) + all_videos.append((str(video_path), f"Seed: {current_seed}")) + else: + # For non-SkyReels models, use the regular process_single_video function + num_lora_weights = 4 + lora_weights = lora_params[:num_lora_weights] + lora_multipliers = lora_params[num_lora_weights:num_lora_weights*2] + + single_video_args = [ + prompt, local_width, local_height, 1, video_length, fps, infer_steps, + current_seed, dit_folder, model, vae, te1, te2, save_path, flow_shift, embedded_cfg_scale, + output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, + lora_folder + ] + single_video_args.extend(lora_weights) + single_video_args.extend(lora_multipliers) + single_video_args.extend([None, resized_image, None, negative_prompt, embedded_cfg_scale, split_uncond, guidance_scale, use_fp8]) + + for videos, status, progress in process_single_video(*single_video_args): + if videos: + all_videos.extend(videos) + yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress + + # Clean up temporary file + try: + if os.path.exists(resized_image): + os.remove(resized_image) + except: + pass + + # Clear CUDA cache between generations + clear_cuda_cache() + time.sleep(0.5) + + yield all_videos, "Batch complete", "" + else: + # Regular image input - this is the part we need to fix + # When a SkyReels I2V model is used, we need to use the direct command approach + # with dit_in_channels=32 explicitly specified, just like in the folder processing branch + if is_skyreels_i2v: + stop_event.clear() + + all_videos = [] + progress_text = "Starting generation..." 
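Both batch branches choose how to invoke `hv_generate_video.py` by sniffing the checkpoint filename: anything containing both "skyreels" and "i2v" is treated as a SkyReels image-to-video model and gets `--dit_in_channels 32`, everything else keeps the default 16 channels. That decision, factored into a small helper for clarity (hypothetical; the script inlines these checks):

```python
def detect_model_config(model_filename: str) -> dict:
    """Classify a DiT checkpoint by filename and pick the input channel count,
    mirroring the string checks used in the handlers above."""
    name = model_filename.lower()
    is_skyreels = "skyreels" in name
    is_i2v = is_skyreels and "i2v" in name
    return {
        "is_skyreels": is_skyreels,
        "is_skyreels_i2v": is_i2v,
        # SkyReels I2V checkpoints expect 32 input channels; others use 16.
        "dit_in_channels": 32 if is_i2v else 16,
    }

# Usage sketch:
# cfg = detect_model_config("skyreels_hunyuan_i2v_bf16.safetensors")
# extra_args = ["--dit_in_channels", str(cfg["dit_in_channels"])]
```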
+ yield [], "Preparing...", progress_text + + # Extract lora parameters + num_lora_weights = 4 + lora_weights = lora_params[:num_lora_weights] + lora_multipliers = lora_params[num_lora_weights:num_lora_weights*2] + extra_args = list(lora_params[num_lora_weights*2:]) if len(lora_params) > num_lora_weights*2 else [] + + # Print extra_args for debugging + print(f"Extra args: {extra_args}") + + # Get input image path from extra args - this is where we need to fix + # In skyreels_generate_btn.click, we're passing skyreels_input which + # should be the image path + image_path = None + if len(extra_args) > 0 and extra_args[0] is not None: + image_path = extra_args[0] + print(f"Image path found in extra_args[0]: {image_path}") + + # If we still don't have an image path, this is a problem + if not image_path: + # Let's try to debug what's happening - in the future, you can remove these + # debug prints once everything works correctly + print("No image path found in extra_args[0]") + print(f"Full lora_params: {lora_params}") + yield [], "Error: No input image provided", "An input image is required for SkyReels I2V models" + return + + for i in range(batch_size): + if stop_event.is_set(): + yield all_videos, "Generation stopped by user", "" + return + + # Calculate seed for this batch item + current_seed = seed + if seed == -1: + current_seed = random.randint(0, 2**32 - 1) + elif batch_size > 1: + current_seed = seed + i + + batch_text = f"Generating video {i + 1} of {batch_size}" + yield all_videos.copy(), batch_text, progress_text + + # Set up environment + env = os.environ.copy() + env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") + env["PYTHONIOENCODING"] = "utf-8" + + model_path = os.path.join(dit_folder, model) if not os.path.isabs(model) else model + + # Build the command with dit_in_channels=32 + cmd = [ + sys.executable, + "hv_generate_video.py", + "--dit", model_path, + "--vae", vae, + "--text_encoder1", te1, + "--text_encoder2", te2, + "--prompt", prompt, + "--video_size", str(height), str(width), + "--video_length", str(video_length), + "--fps", str(fps), + "--infer_steps", str(infer_steps), + "--save_path", save_path, + "--seed", str(current_seed), + "--flow_shift", str(flow_shift), + "--embedded_cfg_scale", str(embedded_cfg_scale), + "--output_type", output_type, + "--attn_mode", attn_mode, + "--blocks_to_swap", str(block_swap), + "--fp8_llm", + "--vae_chunk_size", "32", + "--vae_spatial_tile_sample_min_size", "128", + "--dit_in_channels", "32", # This is crucial for SkyReels i2v + "--image_path", image_path + ] + + if use_fp8: + cmd.append("--fp8") + + if split_uncond: + cmd.append("--split_uncond") + + if use_split_attn: + cmd.append("--split_attn") + + if exclude_single_blocks: + cmd.append("--exclude_single_blocks") + + if negative_prompt: + cmd.extend(["--negative_prompt", negative_prompt]) + + if guidance_scale is not None: + cmd.extend(["--guidance_scale", str(guidance_scale)]) + + # Add LoRA weights and multipliers if provided + valid_loras = [] + for weight, mult in zip(lora_weights, lora_multipliers): + if weight and weight != "None": + valid_loras.append((os.path.join(lora_folder, weight), mult)) + + if valid_loras: + weights = [weight for weight, _ in valid_loras] + multipliers = [str(mult) for _, mult in valid_loras] + cmd.extend(["--lora_weight"] + weights) + cmd.extend(["--lora_multiplier"] + multipliers) + + print(f"Running command: {' '.join(cmd)}") + + # Run the process + p = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + 
stderr=subprocess.STDOUT, + env=env, + text=True, + encoding='utf-8', + errors='replace', + bufsize=1 + ) + + while True: + if stop_event.is_set(): + p.terminate() + p.wait() + yield all_videos, "Generation stopped by user.", "" + return + + line = p.stdout.readline() + if not line: + if p.poll() is not None: + break + continue + + print(line, end='') + if '|' in line and '%' in line and '[' in line and ']' in line: + yield all_videos.copy(), f"Processing (seed: {current_seed})", line.strip() + + p.stdout.close() + p.wait() + + # Collect generated video + save_path_abs = os.path.abspath(save_path) + if os.path.exists(save_path_abs): + all_videos_files = sorted( + [f for f in os.listdir(save_path_abs) if f.endswith('.mp4')], + key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)), + reverse=True + ) + matching_videos = [v for v in all_videos_files if f"_{current_seed}" in v] + if matching_videos: + video_path = os.path.join(save_path_abs, matching_videos[0]) + all_videos.append((str(video_path), f"Seed: {current_seed}")) + + # Clear CUDA cache between generations + clear_cuda_cache() + time.sleep(0.5) + + yield all_videos, "Batch complete", "" + else: + # For regular non-SkyReels models, use the original process_batch function + regular_args = [ + prompt, width, height, batch_size, video_length, fps, infer_steps, + seed, dit_folder, model, vae, te1, te2, save_path, flow_shift, guidance_scale, + output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, + lora_folder + ] + yield from process_batch(*(regular_args + list(lora_params))) + +def get_dit_models(dit_folder: str) -> List[str]: + """Get list of available DiT models in the specified folder""" + if not os.path.exists(dit_folder): + return ["mp_rank_00_model_states.pt"] + models = [f for f in os.listdir(dit_folder) if f.endswith('.pt') or f.endswith('.safetensors')] + models.sort(key=str.lower) + return models if models else ["mp_rank_00_model_states.pt"] + +def update_dit_and_lora_dropdowns(dit_folder: str, lora_folder: str, *current_values) -> List[gr.update]: + """Update both DiT and LoRA dropdowns""" + # Get model lists + dit_models = get_dit_models(dit_folder) + lora_choices = get_lora_options(lora_folder) + + # Current values processing + dit_value = current_values[0] + if dit_value not in dit_models: + dit_value = dit_models[0] if dit_models else None + + weights = current_values[1:5] + multipliers = current_values[5:9] + + results = [gr.update(choices=dit_models, value=dit_value)] + + # Add LoRA updates + for i in range(4): + weight = weights[i] if i < len(weights) else "None" + multiplier = multipliers[i] if i < len(multipliers) else 1.0 + if weight not in lora_choices: + weight = "None" + results.extend([ + gr.update(choices=lora_choices, value=weight), + gr.update(value=multiplier) + ]) + + return results + +def extract_video_metadata(video_path: str) -> Dict: + """Extract metadata from video file using ffprobe.""" + cmd = [ + 'ffprobe', + '-v', 'quiet', + '-print_format', 'json', + '-show_format', + video_path + ] + + try: + result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True) + metadata = json.loads(result.stdout.decode('utf-8')) + if 'format' in metadata and 'tags' in metadata['format']: + comment = metadata['format']['tags'].get('comment', '{}') + return json.loads(comment) + return {} + except Exception as e: + print(f"Metadata extraction failed: {str(e)}") + return {} + +def create_parameter_transfer_map(metadata: Dict, target_tab: str) -> Dict: + """Map 
metadata parameters to Gradio components for different tabs""" + mapping = { + 'common': { + 'prompt': ('prompt', 'v2v_prompt'), + 'width': ('width', 'v2v_width'), + 'height': ('height', 'v2v_height'), + 'batch_size': ('batch_size', 'v2v_batch_size'), + 'video_length': ('video_length', 'v2v_video_length'), + 'fps': ('fps', 'v2v_fps'), + 'infer_steps': ('infer_steps', 'v2v_infer_steps'), + 'seed': ('seed', 'v2v_seed'), + 'model': ('model', 'v2v_model'), + 'vae': ('vae', 'v2v_vae'), + 'te1': ('te1', 'v2v_te1'), + 'te2': ('te2', 'v2v_te2'), + 'save_path': ('save_path', 'v2v_save_path'), + 'flow_shift': ('flow_shift', 'v2v_flow_shift'), + 'cfg_scale': ('cfg_scale', 'v2v_cfg_scale'), + 'output_type': ('output_type', 'v2v_output_type'), + 'attn_mode': ('attn_mode', 'v2v_attn_mode'), + 'block_swap': ('block_swap', 'v2v_block_swap') + }, + 'lora': { + 'lora_weights': [(f'lora{i+1}', f'v2v_lora_weights[{i}]') for i in range(4)], + 'lora_multipliers': [(f'lora{i+1}_multiplier', f'v2v_lora_multipliers[{i}]') for i in range(4)] + } + } + + results = {} + for param, value in metadata.items(): + # Handle common parameters + if param in mapping['common']: + target = mapping['common'][param][0 if target_tab == 't2v' else 1] + results[target] = value + + # Handle LoRA parameters + if param == 'lora_weights': + for i, weight in enumerate(value[:4]): + target = mapping['lora']['lora_weights'][i][1 if target_tab == 'v2v' else 0] + results[target] = weight + + if param == 'lora_multipliers': + for i, mult in enumerate(value[:4]): + target = mapping['lora']['lora_multipliers'][i][1 if target_tab == 'v2v' else 0] + results[target] = float(mult) + + return results + +def add_metadata_to_video(video_path: str, parameters: dict) -> None: + """Add generation parameters to video metadata using ffmpeg.""" + import json + import subprocess + + # Convert parameters to JSON string + params_json = json.dumps(parameters, indent=2) + + # Temporary output path + temp_path = video_path.replace(".mp4", "_temp.mp4") + + # FFmpeg command to add metadata without re-encoding + cmd = [ + 'ffmpeg', + '-i', video_path, + '-metadata', f'comment={params_json}', + '-codec', 'copy', + temp_path + ] + + try: + # Execute FFmpeg command + subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + # Replace original file with the metadata-enhanced version + os.replace(temp_path, video_path) + except subprocess.CalledProcessError as e: + print(f"Failed to add metadata: {e.stderr.decode()}") + if os.path.exists(temp_path): + os.remove(temp_path) + except Exception as e: + print(f"Error: {str(e)}") + +def count_prompt_tokens(prompt: str) -> int: + enc = tiktoken.get_encoding("cl100k_base") + tokens = enc.encode(prompt) + return len(tokens) + + +def get_lora_options(lora_folder: str = "lora") -> List[str]: + if not os.path.exists(lora_folder): + return ["None"] + lora_files = [f for f in os.listdir(lora_folder) if f.endswith('.safetensors') or f.endswith('.pt')] + lora_files.sort(key=str.lower) + return ["None"] + lora_files + +def update_lora_dropdowns(lora_folder: str, *current_values) -> List[gr.update]: + new_choices = get_lora_options(lora_folder) + weights = current_values[:4] + multipliers = current_values[4:8] + + results = [] + for i in range(4): + weight = weights[i] if i < len(weights) else "None" + multiplier = multipliers[i] if i < len(multipliers) else 1.0 + if weight not in new_choices: + weight = "None" + results.extend([ + gr.update(choices=new_choices, value=weight), + gr.update(value=multiplier) + ]) + + 
return results + +def send_to_v2v(evt: gr.SelectData, gallery: list, prompt: str, selected_index: gr.State) -> Tuple[Optional[str], str, int]: + """Transfer selected video and prompt to Video2Video tab""" + if not gallery or evt.index >= len(gallery): + return None, "", selected_index.value + + selected_item = gallery[evt.index] + + # Handle different gallery item formats + if isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + elif isinstance(selected_item, (tuple, list)): + video_path = selected_item[0] + else: + video_path = selected_item + + # Final cleanup for Gradio Video component + if isinstance(video_path, tuple): + video_path = video_path[0] + + # Update the selected index + selected_index.value = evt.index + + return str(video_path), prompt, evt.index + +def send_selected_to_v2v(gallery: list, prompt: str, selected_index: gr.State) -> Tuple[Optional[str], str]: + """Send the currently selected video to V2V tab""" + if not gallery or selected_index.value is None or selected_index.value >= len(gallery): + return None, "" + + selected_item = gallery[selected_index.value] + + # Handle different gallery item formats + if isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + elif isinstance(selected_item, (tuple, list)): + video_path = selected_item[0] + else: + video_path = selected_item + + # Final cleanup for Gradio Video component + if isinstance(video_path, tuple): + video_path = video_path[0] + + return str(video_path), prompt + +def clear_cuda_cache(): + """Clear CUDA cache if available""" + import torch + if torch.cuda.is_available(): + torch.cuda.empty_cache() + # Optional: synchronize to ensure cache is cleared + torch.cuda.synchronize() + +def wanx_batch_handler( + use_random, + prompt, + negative_prompt, + width, + height, + video_length, + fps, + infer_steps, + flow_shift, + guidance_scale, + seed, + batch_size, + input_folder_path, + task, + dit_path, + vae_path, + t5_path, + clip_path, + save_path, + output_type, + sample_solver, + exclude_single_blocks, + attn_mode, + block_swap, + fp8, + fp8_t5, + lora_folder, + *lora_params +): + """Handle both folder-based batch processing and regular processing for WanX""" + global stop_event + + if use_random: + # Random image from folder mode + stop_event.clear() + + all_videos = [] + progress_text = "Starting generation..." 
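`add_metadata_to_video` and `extract_video_metadata` above persist the generation parameters as JSON in the MP4 `comment` tag via a stream-copy ffmpeg pass and read them back with ffprobe. A condensed round-trip sketch using the same CLI flags; it assumes `ffmpeg`/`ffprobe` are on PATH, and the paths in the usage comments are placeholders:

```python
import json
import os
import subprocess

def write_comment(video_path: str, params: dict) -> None:
    """Copy streams into a temp file with the JSON parameters in the comment tag."""
    tmp = video_path.replace(".mp4", "_tagged.mp4")
    subprocess.run(
        ["ffmpeg", "-y", "-i", video_path,
         "-metadata", f"comment={json.dumps(params)}",
         "-codec", "copy", tmp],
        check=True, capture_output=True,
    )
    os.replace(tmp, video_path)

def read_comment(video_path: str) -> dict:
    """Read the comment tag back and decode it as JSON (empty dict on failure)."""
    out = subprocess.run(
        ["ffprobe", "-v", "quiet", "-print_format", "json", "-show_format", video_path],
        capture_output=True, check=True,
    )
    fmt = json.loads(out.stdout.decode("utf-8")).get("format", {})
    try:
        return json.loads(fmt.get("tags", {}).get("comment", "{}"))
    except json.JSONDecodeError:
        return {}

# Usage sketch:
# write_comment("outputs/video_123.mp4", {"prompt": "a cat", "seed": 123})
# print(read_comment("outputs/video_123.mp4"))
```

Because `-codec copy` avoids re-encoding, tagging is fast and lossless, which is why the UI can afford to do it for every generated clip.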
+ yield [], "Preparing...", progress_text + + # Ensure batch_size is treated as an integer + batch_size = int(batch_size) + + # Process each item in the batch separately + for i in range(batch_size): + if stop_event.is_set(): + yield all_videos, "Generation stopped by user", "" + return + + batch_text = f"Generating video {i + 1} of {batch_size}" + yield all_videos.copy(), batch_text, progress_text + + # Get random image from folder + random_image, status = get_random_image_from_folder(input_folder_path) + if random_image is None: + yield all_videos, f"Error in batch {i+1}: {status}", "" + continue + + # Resize image + resized_image, size_info = resize_image_keeping_aspect_ratio(random_image, width, height) + if resized_image is None: + yield all_videos, f"Error resizing image in batch {i+1}: {size_info}", "" + continue + + # Use the dimensions returned from the resize function + local_width, local_height = width, height # Default fallback + if isinstance(size_info, tuple): + local_width, local_height = size_info + progress_text = f"Using image: {os.path.basename(random_image)} - Resized to {local_width}x{local_height} (maintaining aspect ratio)" + else: + progress_text = f"Using image: {os.path.basename(random_image)}" + + yield all_videos.copy(), batch_text, progress_text + + # Calculate seed for this batch item + current_seed = seed + if seed == -1: + current_seed = random.randint(0, 2**32 - 1) + elif batch_size > 1: + current_seed = seed + i + + # Extract LoRA weights and multipliers + num_lora_weights = 4 + lora_weights = lora_params[:num_lora_weights] + lora_multipliers = lora_params[num_lora_weights:num_lora_weights*2] + + # Generate video for this image - one at a time + for videos, status, progress in wanx_generate_video( + prompt, + negative_prompt, + resized_image, + local_width, + local_height, + video_length, + fps, + infer_steps, + flow_shift, + guidance_scale, + current_seed, + task, + dit_path, + vae_path, + t5_path, + clip_path, + save_path, + output_type, + sample_solver, + exclude_single_blocks, + attn_mode, + block_swap, + fp8, + fp8_t5, + lora_folder, + *lora_weights, + *lora_multipliers + ): + if videos: + all_videos.extend(videos) + yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress + + # Clean up temporary file + try: + if os.path.exists(resized_image): + os.remove(resized_image) + except: + pass + + # Clear CUDA cache between generations + clear_cuda_cache() + time.sleep(0.5) + + yield all_videos, "Batch complete", "" + else: + # For non-random mode, if batch_size > 1, we need to process multiple times + # with the same input image but different seeds + if int(batch_size) > 1: + stop_event.clear() + + all_videos = [] + progress_text = "Starting generation..." 
+ yield [], "Preparing...", progress_text + + # Extract LoRA weights and multipliers and input image + num_lora_weights = 4 + lora_weights = lora_params[:num_lora_weights] + lora_multipliers = lora_params[num_lora_weights:num_lora_weights*2] + input_image = lora_params[num_lora_weights*2] if len(lora_params) > num_lora_weights*2 else None + + # Process each batch item + for i in range(int(batch_size)): + if stop_event.is_set(): + yield all_videos, "Generation stopped by user", "" + return + + # Calculate seed for this batch item + current_seed = seed + if seed == -1: + current_seed = random.randint(0, 2**32 - 1) + elif batch_size > 1: + current_seed = seed + i + + batch_text = f"Generating video {i + 1} of {batch_size}" + yield all_videos.copy(), batch_text, progress_text + + # Generate a single video with the current seed + for videos, status, progress in wanx_generate_video( + prompt, + negative_prompt, + input_image, + width, + height, + video_length, + fps, + infer_steps, + flow_shift, + guidance_scale, + current_seed, + task, + dit_path, + vae_path, + t5_path, + clip_path, + save_path, + output_type, + sample_solver, + exclude_single_blocks, + attn_mode, + block_swap, + fp8, + fp8_t5, + lora_folder, + *lora_weights, + *lora_multipliers + ): + if videos: + all_videos.extend(videos) + yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress + + # Clear CUDA cache between generations + clear_cuda_cache() + time.sleep(0.5) + + yield all_videos, "Batch complete", "" + else: + # Single image, single generation - use existing function + num_lora_weights = 4 + lora_weights = lora_params[:num_lora_weights] + lora_multipliers = lora_params[num_lora_weights:num_lora_weights*2] + input_image = lora_params[num_lora_weights*2] if len(lora_params) > num_lora_weights*2 else None + + yield from wanx_generate_video( + prompt, + negative_prompt, + input_image, + width, + height, + video_length, + fps, + infer_steps, + flow_shift, + guidance_scale, + seed, + task, + dit_path, + vae_path, + t5_path, + clip_path, + save_path, + output_type, + sample_solver, + exclude_single_blocks, + attn_mode, + block_swap, + fp8, + fp8_t5, + lora_folder, + *lora_weights, + *lora_multipliers + ) + +def process_single_video( + prompt: str, + width: int, + height: int, + batch_size: int, + video_length: int, + fps: int, + infer_steps: int, + seed: int, + dit_folder: str, + model: str, + vae: str, + te1: str, + te2: str, + save_path: str, + flow_shift: float, + cfg_scale: float, + output_type: str, + attn_mode: str, + block_swap: int, + exclude_single_blocks: bool, + use_split_attn: bool, + lora_folder: str, + lora1: str = "", + lora2: str = "", + lora3: str = "", + lora4: str = "", + lora1_multiplier: float = 1.0, + lora2_multiplier: float = 1.0, + lora3_multiplier: float = 1.0, + lora4_multiplier: float = 1.0, + video_path: Optional[str] = None, + image_path: Optional[str] = None, + strength: Optional[float] = None, + negative_prompt: Optional[str] = None, + embedded_cfg_scale: Optional[float] = None, + split_uncond: Optional[bool] = None, + guidance_scale: Optional[float] = None, + use_fp8: bool = True +) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: + """Generate a single video with the given parameters""" + global stop_event + + if stop_event.is_set(): + yield [], "", "" + return + + # Determine if this is a SkyReels model and what type + is_skyreels = "skyreels" in model.lower() + is_skyreels_i2v = is_skyreels and "i2v" in model.lower() + is_skyreels_t2v = is_skyreels and "t2v" in 
model.lower() + + if is_skyreels: + # Force certain parameters for SkyReels + if negative_prompt is None: + negative_prompt = "" + if embedded_cfg_scale is None: + embedded_cfg_scale = 1.0 # Force to 1.0 for SkyReels + if split_uncond is None: + split_uncond = True + if guidance_scale is None: + guidance_scale = cfg_scale # Use cfg_scale as guidance_scale if not provided + + # Determine the input channels based on model type + if is_skyreels_i2v: + dit_in_channels = 32 # SkyReels I2V uses 32 channels + else: + dit_in_channels = 16 # SkyReels T2V uses 16 channels (same as regular models) + else: + dit_in_channels = 16 # Regular Hunyuan models use 16 channels + embedded_cfg_scale = cfg_scale + + if os.path.isabs(model): + model_path = model + else: + model_path = os.path.normpath(os.path.join(dit_folder, model)) + + env = os.environ.copy() + env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") + env["PYTHONIOENCODING"] = "utf-8" + env["BATCH_RUN_ID"] = f"{time.time()}" + + if seed == -1: + current_seed = random.randint(0, 2**32 - 1) + else: + batch_id = int(env.get("BATCH_RUN_ID", "0").split('.')[-1]) + if batch_size > 1: # Only modify seed for batch generation + current_seed = (seed + batch_id * 100003) % (2**32) + else: + current_seed = seed + + clear_cuda_cache() + + command = [ + sys.executable, + "hv_generate_video.py", + "--dit", model_path, + "--vae", vae, + "--text_encoder1", te1, + "--text_encoder2", te2, + "--prompt", prompt, + "--video_size", str(height), str(width), + "--video_length", str(video_length), + "--fps", str(fps), + "--infer_steps", str(infer_steps), + "--save_path", save_path, + "--seed", str(current_seed), + "--flow_shift", str(flow_shift), + "--embedded_cfg_scale", str(cfg_scale), + "--output_type", output_type, + "--attn_mode", attn_mode, + "--blocks_to_swap", str(block_swap), + "--fp8_llm", + "--vae_chunk_size", "32", + "--vae_spatial_tile_sample_min_size", "128" + ] + + if use_fp8: + command.append("--fp8") + + # Add negative prompt and embedded cfg scale for SkyReels + if is_skyreels: + command.extend(["--dit_in_channels", str(dit_in_channels)]) + command.extend(["--guidance_scale", str(guidance_scale)]) + + if negative_prompt: + command.extend(["--negative_prompt", negative_prompt]) + if split_uncond: + command.append("--split_uncond") + + # Add LoRA weights and multipliers if provided + valid_loras = [] + for weight, mult in zip([lora1, lora2, lora3, lora4], + [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier]): + if weight and weight != "None": + valid_loras.append((os.path.join(lora_folder, weight), mult)) + if valid_loras: + weights = [weight for weight, _ in valid_loras] + multipliers = [str(mult) for _, mult in valid_loras] + command.extend(["--lora_weight"] + weights) + command.extend(["--lora_multiplier"] + multipliers) + + if exclude_single_blocks: + command.append("--exclude_single_blocks") + if use_split_attn: + command.append("--split_attn") + + # Handle input paths + if video_path: + command.extend(["--video_path", video_path]) + if strength is not None: + command.extend(["--strength", str(strength)]) + elif image_path: + command.extend(["--image_path", image_path]) + # Only add strength parameter for non-SkyReels I2V models + # SkyReels I2V doesn't use strength parameter for image-to-video generation + if strength is not None and not is_skyreels_i2v: + command.extend(["--strength", str(strength)]) + + print(f"{command}") + + p = subprocess.Popen( + command, + stdout=subprocess.PIPE, + 
stderr=subprocess.STDOUT, + env=env, + text=True, + encoding='utf-8', + errors='replace', + bufsize=1 + ) + + videos = [] + + while True: + if stop_event.is_set(): + p.terminate() + p.wait() + yield [], "", "Generation stopped by user." + return + + line = p.stdout.readline() + if not line: + if p.poll() is not None: + break + continue + + print(line, end='') + if '|' in line and '%' in line and '[' in line and ']' in line: + yield videos.copy(), f"Processing (seed: {current_seed})", line.strip() + + p.stdout.close() + p.wait() + + clear_cuda_cache() + time.sleep(0.5) + + # Collect generated video + save_path_abs = os.path.abspath(save_path) + if os.path.exists(save_path_abs): + all_videos = sorted( + [f for f in os.listdir(save_path_abs) if f.endswith('.mp4')], + key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)), + reverse=True + ) + matching_videos = [v for v in all_videos if f"_{current_seed}" in v] + if matching_videos: + video_path = os.path.join(save_path_abs, matching_videos[0]) + + # Collect parameters for metadata + parameters = { + "prompt": prompt, + "width": width, + "height": height, + "video_length": video_length, + "fps": fps, + "infer_steps": infer_steps, + "seed": current_seed, + "model": model, + "vae": vae, + "te1": te1, + "te2": te2, + "save_path": save_path, + "flow_shift": flow_shift, + "cfg_scale": cfg_scale, + "output_type": output_type, + "attn_mode": attn_mode, + "block_swap": block_swap, + "lora_weights": [lora1, lora2, lora3, lora4], + "lora_multipliers": [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier], + "input_video": video_path if video_path else None, + "input_image": image_path if image_path else None, + "strength": strength, + "negative_prompt": negative_prompt if is_skyreels else None, + "embedded_cfg_scale": embedded_cfg_scale if is_skyreels else None + } + + add_metadata_to_video(video_path, parameters) + videos.append((str(video_path), f"Seed: {current_seed}")) + + yield videos, f"Completed (seed: {current_seed})", "" + +# The issue is in the process_batch function, in the section that handles different input types +# Here's the corrected version of that section: + +def process_batch( + prompt: str, + width: int, + height: int, + batch_size: int, + video_length: int, + fps: int, + infer_steps: int, + seed: int, + dit_folder: str, + model: str, + vae: str, + te1: str, + te2: str, + save_path: str, + flow_shift: float, + cfg_scale: float, + output_type: str, + attn_mode: str, + block_swap: int, + exclude_single_blocks: bool, + use_split_attn: bool, + lora_folder: str, + *args +) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: + """Process a batch of videos using Gradio's queue""" + global stop_event + stop_event.clear() + + all_videos = [] + progress_text = "Starting generation..." 
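+    # NOTE: *args is expected to arrive in a fixed order: 4 LoRA weight names,
+    # 4 LoRA multipliers, then the optional extras (input path, strength,
+    # negative prompt, guidance scale, split_uncond flag, and use_fp8 as the
+    # last element). The slicing below relies on this ordering, so keep the
+    # Gradio inputs list wired to this handler in sync with it.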
+ yield [], "Preparing...", progress_text + + # Extract additional arguments + num_lora_weights = 4 + lora_weights = args[:num_lora_weights] + lora_multipliers = args[num_lora_weights:num_lora_weights*2] + extra_args = args[num_lora_weights*2:] + + # Determine if this is a SkyReels model and what type + is_skyreels = "skyreels" in model.lower() + is_skyreels_i2v = is_skyreels and "i2v" in model.lower() + is_skyreels_t2v = is_skyreels and "t2v" in model.lower() + + # Handle input paths and additional parameters + input_path = extra_args[0] if extra_args else None + strength = float(extra_args[1]) if len(extra_args) > 1 else None + + # Get use_fp8 flag (it should be the last parameter) + use_fp8 = bool(extra_args[-1]) if extra_args and len(extra_args) >= 3 else True + + # Get SkyReels specific parameters if applicable + if is_skyreels: + # Always set embedded_cfg_scale to 1.0 for SkyReels models + embedded_cfg_scale = 1.0 + + negative_prompt = str(extra_args[2]) if len(extra_args) > 2 and extra_args[2] is not None else "" + # Use cfg_scale for guidance_scale parameter + guidance_scale = float(extra_args[3]) if len(extra_args) > 3 and extra_args[3] is not None else cfg_scale + split_uncond = True if len(extra_args) > 4 and extra_args[4] else False + else: + negative_prompt = str(extra_args[2]) if len(extra_args) > 2 and extra_args[2] is not None else None + guidance_scale = cfg_scale + embedded_cfg_scale = cfg_scale + split_uncond = bool(extra_args[4]) if len(extra_args) > 4 else None + + for i in range(batch_size): + if stop_event.is_set(): + break + + batch_text = f"Generating video {i + 1} of {batch_size}" + yield all_videos.copy(), batch_text, progress_text + + # Handle different input types + video_path = None + image_path = None + + if input_path: + # Check if it's an image file (common image extensions) + is_image = False + lower_path = input_path.lower() + image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp') + is_image = any(lower_path.endswith(ext) for ext in image_extensions) + + # Only use image_path for SkyReels I2V models and actual image files + if is_skyreels_i2v and is_image: + image_path = input_path + else: + video_path = input_path + + # Prepare arguments for process_single_video + single_video_args = [ + prompt, width, height, batch_size, video_length, fps, infer_steps, + seed, dit_folder, model, vae, te1, te2, save_path, flow_shift, cfg_scale, + output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, + lora_folder + ] + single_video_args.extend(lora_weights) + single_video_args.extend(lora_multipliers) + single_video_args.extend([video_path, image_path, strength, negative_prompt, embedded_cfg_scale, split_uncond, guidance_scale, use_fp8]) + + for videos, status, progress in process_single_video(*single_video_args): + if videos: + all_videos.extend(videos) + yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress + + yield all_videos, "Batch complete", "" + +def update_wanx_image_dimensions(image): + """Update dimensions from uploaded image""" + if image is None: + return "", gr.update(value=832), gr.update(value=480) + img = Image.open(image) + w, h = img.size + w = (w // 32) * 32 + h = (h // 32) * 32 + return f"{w}x{h}", w, h + +def calculate_wanx_width(height, original_dims): + """Calculate width based on height maintaining aspect ratio""" + if not original_dims: + return gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + aspect_ratio = orig_w / orig_h + new_width = math.floor((height * aspect_ratio) 
/ 32) * 32 + return gr.update(value=new_width) + +def calculate_wanx_height(width, original_dims): + """Calculate height based on width maintaining aspect ratio""" + if not original_dims: + return gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + aspect_ratio = orig_w / orig_h + new_height = math.floor((width / aspect_ratio) / 32) * 32 + return gr.update(value=new_height) + +def update_wanx_from_scale(scale, original_dims): + """Update dimensions based on scale percentage""" + if not original_dims: + return gr.update(), gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + new_w = math.floor((orig_w * scale / 100) / 32) * 32 + new_h = math.floor((orig_h * scale / 100) / 32) * 32 + return gr.update(value=new_w), gr.update(value=new_h) + +def recommend_wanx_flow_shift(width, height): + """Get recommended flow shift value based on dimensions""" + recommended_shift = 3.0 if (width == 832 and height == 480) or (width == 480 and height == 832) else 5.0 + return gr.update(value=recommended_shift) + +def handle_wanx_gallery_select(evt: gr.SelectData, gallery) -> tuple: + """Track selected index and video path when gallery item is clicked""" + if gallery is None: + return None, None + + if evt.index >= len(gallery): + return None, None + + selected_item = gallery[evt.index] + video_path = None + + # Extract the video path based on the item type + if isinstance(selected_item, tuple): + video_path = selected_item[0] + elif isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + else: + video_path = selected_item + + return evt.index, video_path + +def wanx_generate_video( + prompt, + negative_prompt, + input_image, + width, + height, + video_length, + fps, + infer_steps, + flow_shift, + guidance_scale, + seed, + task, + dit_path, + vae_path, + t5_path, + clip_path, + save_path, + output_type, + sample_solver, + exclude_single_blocks, + attn_mode, + block_swap, + fp8, + fp8_t5, + lora_folder, + lora1="None", + lora2="None", + lora3="None", + lora4="None", + lora1_multiplier=1.0, + lora2_multiplier=1.0, + lora3_multiplier=1.0, + lora4_multiplier=1.0 +) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: + """Generate video with WanX model (supports both i2v and t2v)""" + global stop_event + + if stop_event.is_set(): + yield [], "", "" + return + + if seed == -1: + current_seed = random.randint(0, 2**32 - 1) + else: + current_seed = seed + + # Check if we need input image (required for i2v, not for t2v) + if "i2v" in task and not input_image: + yield [], "Error: No input image provided", "Please provide an input image for image-to-video generation" + return + + # Prepare environment + env = os.environ.copy() + env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") + env["PYTHONIOENCODING"] = "utf-8" + + clear_cuda_cache() + + command = [ + sys.executable, + "wan_generate_video.py", + "--task", task, + "--prompt", prompt, + "--video_size", str(height), str(width), + "--video_length", str(video_length), + "--fps", str(fps), + "--infer_steps", str(infer_steps), + "--save_path", save_path, + "--seed", str(current_seed), + "--flow_shift", str(flow_shift), + "--guidance_scale", str(guidance_scale), + "--output_type", output_type, + "--attn_mode", attn_mode, + "--blocks_to_swap", str(block_swap), + "--dit", dit_path, + "--vae", vae_path, + "--t5", t5_path, + "--sample_solver", sample_solver + ] + + # Add image path only for i2v task and if input image is provided + if "i2v" in task and 
input_image: + command.extend(["--image_path", input_image]) + command.extend(["--clip", clip_path]) # CLIP is only needed for i2v + + if negative_prompt: + command.extend(["--negative_prompt", negative_prompt]) + + if fp8: + command.append("--fp8") + + if fp8_t5: + command.append("--fp8_t5") + + if exclude_single_blocks: + command.append("--exclude_single_blocks") + + # Add LoRA weights and multipliers if provided + valid_loras = [] + for weight, mult in zip([lora1, lora2, lora3, lora4], + [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier]): + if weight and weight != "None": + valid_loras.append((os.path.join(lora_folder, weight), mult)) + if valid_loras: + weights = [weight for weight, _ in valid_loras] + multipliers = [str(mult) for _, mult in valid_loras] + command.extend(["--lora_weight"] + weights) + command.extend(["--lora_multiplier"] + multipliers) + + print(f"Running: {' '.join(command)}") + + p = subprocess.Popen( + command, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=env, + text=True, + encoding='utf-8', + errors='replace', + bufsize=1 + ) + + videos = [] + + while True: + if stop_event.is_set(): + p.terminate() + p.wait() + yield [], "", "Generation stopped by user." + return + + line = p.stdout.readline() + if not line: + if p.poll() is not None: + break + continue + + print(line, end='') + if '|' in line and '%' in line and '[' in line and ']' in line: + yield videos.copy(), f"Processing (seed: {current_seed})", line.strip() + + p.stdout.close() + p.wait() + + clear_cuda_cache() + time.sleep(0.5) + + # Collect generated video + save_path_abs = os.path.abspath(save_path) + if os.path.exists(save_path_abs): + all_videos = sorted( + [f for f in os.listdir(save_path_abs) if f.endswith('.mp4')], + key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)), + reverse=True + ) + matching_videos = [v for v in all_videos if f"_{current_seed}" in v] + if matching_videos: + video_path = os.path.join(save_path_abs, matching_videos[0]) + + # Collect parameters for metadata + parameters = { + "prompt": prompt, + "width": width, + "height": height, + "video_length": video_length, + "fps": fps, + "infer_steps": infer_steps, + "seed": current_seed, + "task": task, + "flow_shift": flow_shift, + "guidance_scale": guidance_scale, + "output_type": output_type, + "attn_mode": attn_mode, + "block_swap": block_swap, + "input_image": input_image if "i2v" in task else None + } + + add_metadata_to_video(video_path, parameters) + videos.append((str(video_path), f"Seed: {current_seed}")) + + yield videos, f"Completed (seed: {current_seed})", "" + +def send_wanx_to_v2v( + gallery: list, + prompt: str, + selected_index: int, + width: int, + height: int, + video_length: int, + fps: int, + infer_steps: int, + seed: int, + flow_shift: float, + guidance_scale: float, + negative_prompt: str +) -> Tuple: + """Send the selected WanX video to Video2Video tab""" + if gallery is None or not gallery: + return (None, "", width, height, video_length, fps, infer_steps, seed, + flow_shift, guidance_scale, negative_prompt) + + # If no selection made but we have videos, use the first one + if selected_index is None and len(gallery) > 0: + selected_index = 0 + + if selected_index is None or selected_index >= len(gallery): + return (None, "", width, height, video_length, fps, infer_steps, seed, + flow_shift, guidance_scale, negative_prompt) + + selected_item = gallery[selected_index] + + # Handle different gallery item formats + if isinstance(selected_item, tuple): + video_path = 
selected_item[0] + elif isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + else: + video_path = selected_item + + # Clean up path for Video component + if isinstance(video_path, tuple): + video_path = video_path[0] + + # Make sure it's a string + video_path = str(video_path) + + return (video_path, prompt, width, height, video_length, fps, infer_steps, seed, + flow_shift, guidance_scale, negative_prompt) + +def wanx_generate_video_batch( + prompt, + negative_prompt, + width, + height, + video_length, + fps, + infer_steps, + flow_shift, + guidance_scale, + seed, + task, + dit_path, + vae_path, + t5_path, + clip_path, + save_path, + output_type, + sample_solver, + exclude_single_blocks, + attn_mode, + block_swap, + fp8, + fp8_t5, + lora_folder, + lora1="None", + lora2="None", + lora3="None", + lora4="None", + lora1_multiplier=1.0, + lora2_multiplier=1.0, + lora3_multiplier=1.0, + lora4_multiplier=1.0, + batch_size=1, + input_image=None # Make input_image optional and place it at the end +) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: + """Generate videos with WanX with support for batches""" + global stop_event + stop_event.clear() + + all_videos = [] + progress_text = "Starting generation..." + yield [], "Preparing...", progress_text + + # Process each item in the batch + for i in range(batch_size): + if stop_event.is_set(): + yield all_videos, "Generation stopped by user", "" + return + + # Calculate seed for this batch item + current_seed = seed + if seed == -1: + current_seed = random.randint(0, 2**32 - 1) + elif batch_size > 1: + current_seed = seed + i + + batch_text = f"Generating video {i + 1} of {batch_size}" + yield all_videos.copy(), batch_text, progress_text + + # Generate a single video using the existing function + for videos, status, progress in wanx_generate_video( + prompt, negative_prompt, input_image, width, height, + video_length, fps, infer_steps, flow_shift, guidance_scale, + current_seed, task, dit_path, vae_path, t5_path, clip_path, + save_path, output_type, sample_solver, exclude_single_blocks, + attn_mode, block_swap, fp8, fp8_t5, + lora_folder, + lora1, + lora2, + lora3, + lora4, + lora1_multiplier, + lora2_multiplier, + lora3_multiplier, + lora4_multiplier + ): + if videos: + all_videos.extend(videos) + yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress + + yield all_videos, "Batch complete", "" + +def update_wanx_t2v_dimensions(size): + """Update width and height based on selected size""" + width, height = map(int, size.split('*')) + return gr.update(value=width), gr.update(value=height) + +def handle_wanx_t2v_gallery_select(evt: gr.SelectData) -> int: + """Track selected index when gallery item is clicked""" + return evt.index + +def send_wanx_t2v_to_v2v( + gallery, prompt, selected_index, width, height, video_length, + fps, infer_steps, seed, flow_shift, guidance_scale, negative_prompt +) -> Tuple: + """Send the selected WanX T2V video to Video2Video tab""" + if not gallery or selected_index is None or selected_index >= len(gallery): + return (None, "", width, height, video_length, fps, infer_steps, seed, + flow_shift, guidance_scale, negative_prompt) + + selected_item = gallery[selected_index] + + if isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + elif isinstance(selected_item, (tuple, list)): + video_path = selected_item[0] + else: + video_path = selected_item + + if isinstance(video_path, tuple): + 
video_path = video_path[0] + + return (str(video_path), prompt, width, height, video_length, fps, infer_steps, seed, + flow_shift, guidance_scale, negative_prompt) + +def prepare_for_batch_extension(input_img, base_video, batch_size): + """Prepare inputs for batch video extension""" + if input_img is None: + return None, None, batch_size, "No input image found", "" + + if base_video is None: + return input_img, None, batch_size, "No base video selected for extension", "" + + return input_img, base_video, batch_size, "Preparing batch extension...", f"Will create {batch_size} variations of extended video" + +def concat_batch_videos(base_video_path, generated_videos, save_path, original_video_path=None): + """Concatenate multiple generated videos with the base video""" + if not base_video_path: + return [], "No base video provided" + + if not generated_videos or len(generated_videos) == 0: + return [], "No new videos generated" + + # Create output directory if it doesn't exist + os.makedirs(save_path, exist_ok=True) + + # Track all extended videos + extended_videos = [] + + # For each generated video, create an extended version + for i, video_item in enumerate(generated_videos): + try: + # Extract video path from gallery item + if isinstance(video_item, tuple): + new_video_path = video_item[0] + seed_info = video_item[1] if len(video_item) > 1 else "" + elif isinstance(video_item, dict): + new_video_path = video_item.get("name", video_item.get("data", None)) + seed_info = "" + else: + new_video_path = video_item + seed_info = "" + + if not new_video_path or not os.path.exists(new_video_path): + print(f"Skipping missing video: {new_video_path}") + continue + + # Create unique output filename + timestamp = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S") + # Extract seed from seed_info if available + seed_match = re.search(r"Seed: (\d+)", seed_info) + seed_part = f"_seed{seed_match.group(1)}" if seed_match else f"_{i}" + + output_filename = f"extended_{timestamp}{seed_part}_{Path(base_video_path).stem}.mp4" + output_path = os.path.join(save_path, output_filename) + + # Create a temporary file list for ffmpeg + list_file = os.path.join(save_path, f"temp_list_{i}.txt") + with open(list_file, "w") as f: + f.write(f"file '{os.path.abspath(base_video_path)}'\n") + f.write(f"file '{os.path.abspath(new_video_path)}'\n") + + # Run ffmpeg concatenation + command = [ + "ffmpeg", + "-f", "concat", + "-safe", "0", + "-i", list_file, + "-c", "copy", + output_path + ] + + subprocess.run(command, check=True, capture_output=True) + + # Clean up temporary file + if os.path.exists(list_file): + os.remove(list_file) + + # Add to extended videos list if successful + if os.path.exists(output_path): + seed_display = f"Extended {seed_info}" if seed_info else f"Extended video #{i+1}" + extended_videos.append((output_path, seed_display)) + + except Exception as e: + print(f"Error processing video {i}: {str(e)}") + + if not extended_videos: + return [], "Failed to create any extended videos" + + return extended_videos, f"Successfully created {len(extended_videos)} extended videos" + +def handle_extend_generation(base_video_path: str, new_videos: list, save_path: str, current_gallery: list) -> tuple: + """Combine generated video with base video and update gallery""" + if not base_video_path: + return current_gallery, "Extend failed: No base video provided" + + if not new_videos: + return current_gallery, "Extend failed: No new video generated" + + # Ensure save path exists + os.makedirs(save_path, 
exist_ok=True) + + # Get the first video from new_videos (gallery item) + new_video_path = new_videos[0][0] if isinstance(new_videos[0], tuple) else new_videos[0] + + # Create a unique output filename + timestamp = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S") + output_filename = f"extended_{timestamp}_{Path(base_video_path).stem}.mp4" + output_path = str(Path(save_path) / output_filename) + + try: + # Concatenate the videos using ffmpeg + ( + ffmpeg + .input(base_video_path) + .concat( + ffmpeg.input(new_video_path) + ) + .output(output_path) + .run(overwrite_output=True, quiet=True) + ) + + # Create a new gallery entry with the combined video + updated_gallery = [(output_path, f"Extended video: {Path(output_path).stem}")] + + return updated_gallery, f"Successfully extended video to {Path(output_path).name}" + except Exception as e: + print(f"Error extending video: {str(e)}") + return current_gallery, f"Failed to extend video: {str(e)}" + +# UI setup +with gr.Blocks( + theme=themes.Default( + primary_hue=colors.Color( + name="custom", + c50="#E6F0FF", + c100="#CCE0FF", + c200="#99C1FF", + c300="#66A3FF", + c400="#3384FF", + c500="#0060df", # This is your main color + c600="#0052C2", + c700="#003D91", + c800="#002961", + c900="#001430", + c950="#000A18" + ) + ), + css=""" + .gallery-item:first-child { border: 2px solid #4CAF50 !important; } + .gallery-item:first-child:hover { border-color: #45a049 !important; } + .green-btn { + background: linear-gradient(to bottom right, #2ecc71, #27ae60) !important; + color: white !important; + border: none !important; + } + .green-btn:hover { + background: linear-gradient(to bottom right, #27ae60, #219651) !important; + } + .refresh-btn { + max-width: 40px !important; + min-width: 40px !important; + height: 40px !important; + border-radius: 50% !important; + padding: 0 !important; + display: flex !important; + align-items: center !important; + justify-content: center !important; + } + """, + +) as demo: + # Add state for tracking selected video indices in both tabs + selected_index = gr.State(value=None) # For Text to Video + v2v_selected_index = gr.State(value=None) # For Video to Video + params_state = gr.State() #New addition + i2v_selected_index = gr.State(value=None) + skyreels_selected_index = gr.State(value=None) + wanx_i2v_selected_index = gr.State(value=None) + extended_videos = gr.State(value=[]) + wanx_base_video = gr.State(value=None) + wanx_sharpest_frame_number = gr.State(value=None) + wanx_sharpest_frame_path = gr.State(value=None) + wanx_trimmed_video_path = gr.State(value=None) + demo.load(None, None, None, js=""" + () => { + document.title = 'H1111'; + + function updateTitle(text) { + if (text && text.trim()) { + const progressMatch = text.match(/(\d+)%.*\[.*<(\d+:\d+),/); + if (progressMatch) { + const percentage = progressMatch[1]; + const timeRemaining = progressMatch[2]; + document.title = `[${percentage}% ETA: ${timeRemaining}] - H1111`; + } + } + } + + setTimeout(() => { + const progressElements = document.querySelectorAll('textarea.scroll-hide'); + progressElements.forEach(element => { + if (element) { + new MutationObserver(() => { + updateTitle(element.value); + }).observe(element, { + attributes: true, + childList: true, + characterData: true + }); + } + }); + }, 1000); + } + """) + + with gr.Tabs() as tabs: + # Text to Video Tab + with gr.Tab(id=1, label="Hunyuan-t2v"): + with gr.Row(): + with gr.Column(scale=4): + prompt = gr.Textbox(scale=3, label="Enter your prompt", value="POV video of a cat chasing a 
frob.", lines=5) + + with gr.Column(scale=1): + token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) + batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) + + with gr.Column(scale=2): + batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") + progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") + + with gr.Row(): + generate_btn = gr.Button("Generate Video", elem_classes="green-btn") + stop_btn = gr.Button("Stop Generation", variant="stop") + + with gr.Row(): + with gr.Column(): + + t2v_width = gr.Slider(minimum=64, maximum=1536, step=16, value=544, label="Video Width") + t2v_height = gr.Slider(minimum=64, maximum=1536, step=16, value=544, label="Video Height") + video_length = gr.Slider(minimum=1, maximum=201, step=1, label="Video Length in Frames", value=25, elem_id="my_special_slider") + fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=24, elem_id="my_special_slider") + infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30, elem_id="my_special_slider") + flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=11.0, elem_id="my_special_slider") + cfg_scale = gr.Slider(minimum=0.0, maximum=14.0, step=0.1, label="cfg Scale", value=7.0, elem_id="my_special_slider") + + with gr.Column(): + + with gr.Row(): + video_output = gr.Gallery( + label="Generated Videos (Click to select)", + columns=[2], + rows=[2], + object_fit="contain", + height="auto", + show_label=True, + elem_id="gallery", + allow_preview=True, + preview=True + ) + with gr.Row():send_t2v_to_v2v_btn = gr.Button("Send Selected to Video2Video") + + with gr.Row(): + refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") + lora_weights = [] + lora_multipliers = [] + for i in range(4): + with gr.Column(): + lora_weights.append(gr.Dropdown( + label=f"LoRA {i+1}", + choices=get_lora_options(), + value="None", + allow_custom_value=True, + interactive=True + )) + lora_multipliers.append(gr.Slider( + label=f"Multiplier", + minimum=0.0, + maximum=2.0, + step=0.05, + value=1.0 + )) + with gr.Row(): + exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) + seed = gr.Number(label="Seed (use -1 for random)", value=-1) + dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") + model = gr.Dropdown( + label="DiT Model", + choices=get_dit_models("hunyuan"), + value="mp_rank_00_model_states.pt", + allow_custom_value=True, + interactive=True + ) + vae = gr.Textbox(label="vae", value="hunyuan/pytorch_model.pt") + te1 = gr.Textbox(label="te1", value="hunyuan/llava_llama3_fp16.safetensors") + te2 = gr.Textbox(label="te2", value="hunyuan/clip_l.safetensors") + save_path = gr.Textbox(label="Save Path", value="outputs") + with gr.Row(): + lora_folder = gr.Textbox(label="LoRA Folder", value="lora") + output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") + use_split_attn = gr.Checkbox(label="Use Split Attention", value=False) + use_fp8 = gr.Checkbox(label="Use FP8 (faster but lower precision)", value=True) + attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") + block_swap = gr.Slider(minimum=0, maximum=36, step=1, label="Block Swap to Save Vram", value=0) + + #Image to Video Tab + with gr.Tab(label="Hunyuan-i2v") as i2v_tab: + with gr.Row(): + with gr.Column(scale=4): + i2v_prompt = gr.Textbox(scale=3, 
label="Enter your prompt", value="POV video of a cat chasing a frob.", lines=5) + + with gr.Column(scale=1): + i2v_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) + i2v_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) + + with gr.Column(scale=2): + i2v_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") + i2v_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") + + with gr.Row(): + i2v_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") + i2v_stop_btn = gr.Button("Stop Generation", variant="stop") + + with gr.Row(): + with gr.Column(): + i2v_input = gr.Image(label="Input Image", type="filepath") + i2v_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.75, label="Denoise Strength") + # Scale slider as percentage + scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale %") + original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) + # Width and height inputs + with gr.Row(): + width = gr.Number(label="New Width", value=544, step=16) + calc_height_btn = gr.Button("→") + calc_width_btn = gr.Button("←") + height = gr.Number(label="New Height", value=544, step=16) + i2v_video_length = gr.Slider(minimum=1, maximum=201, step=1, label="Video Length in Frames", value=25) + i2v_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=24) + i2v_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30) + i2v_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=11.0) + i2v_cfg_scale = gr.Slider(minimum=0.0, maximum=14.0, step=0.1, label="cfg scale", value=7.0) + with gr.Column(): + i2v_output = gr.Gallery( + label="Generated Videos (Click to select)", + columns=[2], + rows=[2], + object_fit="contain", + height="auto", + show_label=True, + elem_id="gallery", + allow_preview=True, + preview=True + ) + i2v_send_to_v2v_btn = gr.Button("Send Selected to Video2Video") + + # Add LoRA section for Image2Video + i2v_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") + i2v_lora_weights = [] + i2v_lora_multipliers = [] + for i in range(4): + with gr.Column(): + i2v_lora_weights.append(gr.Dropdown( + label=f"LoRA {i+1}", + choices=get_lora_options(), + value="None", + allow_custom_value=True, + interactive=True + )) + i2v_lora_multipliers.append(gr.Slider( + label=f"Multiplier", + minimum=0.0, + maximum=2.0, + step=0.05, + value=1.0 + )) + + with gr.Row(): + i2v_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) + i2v_seed = gr.Number(label="Seed (use -1 for random)", value=-1) + i2v_dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") + i2v_model = gr.Dropdown( + label="DiT Model", + choices=get_dit_models("hunyuan"), + value="mp_rank_00_model_states.pt", + allow_custom_value=True, + interactive=True + ) + + i2v_vae = gr.Textbox(label="vae", value="hunyuan/pytorch_model.pt") + i2v_te1 = gr.Textbox(label="te1", value="hunyuan/llava_llama3_fp16.safetensors") + i2v_te2 = gr.Textbox(label="te2", value="hunyuan/clip_l.safetensors") + i2v_save_path = gr.Textbox(label="Save Path", value="outputs") + with gr.Row(): + i2v_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") + i2v_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") + i2v_use_split_attn = gr.Checkbox(label="Use Split Attention", value=False) + i2v_use_fp8 = 
gr.Checkbox(label="Use FP8 (faster but lower precision)", value=True) + i2v_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") + i2v_block_swap = gr.Slider(minimum=0, maximum=36, step=1, label="Block Swap to Save Vram", value=0) + + # Video to Video Tab + with gr.Tab(id=2, label="Hunyuan-v2v") as v2v_tab: + with gr.Row(): + with gr.Column(scale=4): + v2v_prompt = gr.Textbox(scale=3, label="Enter your prompt", value="POV video of a cat chasing a frob.", lines=5) + v2v_negative_prompt = gr.Textbox( + scale=3, + label="Negative Prompt (for SkyReels models)", + value="Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion", + lines=3 + ) + + with gr.Column(scale=1): + v2v_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) + v2v_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) + + with gr.Column(scale=2): + v2v_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") + v2v_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") + + with gr.Row(): + v2v_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") + v2v_stop_btn = gr.Button("Stop Generation", variant="stop") + + with gr.Row(): + with gr.Column(): + v2v_input = gr.Video(label="Input Video", format="mp4") + v2v_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.75, label="Denoise Strength") + v2v_scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale %") + v2v_original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) + + # Width and Height Inputs + with gr.Row(): + v2v_width = gr.Number(label="New Width", value=544, step=16) + v2v_calc_height_btn = gr.Button("→") + v2v_calc_width_btn = gr.Button("←") + v2v_height = gr.Number(label="New Height", value=544, step=16) + v2v_video_length = gr.Slider(minimum=1, maximum=201, step=1, label="Video Length in Frames", value=25) + v2v_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=24) + v2v_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30) + v2v_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=11.0) + v2v_cfg_scale = gr.Slider(minimum=0.0, maximum=14.0, step=0.1, label="cfg scale", value=7.0) + with gr.Column(): + v2v_output = gr.Gallery( + label="Generated Videos", + columns=[1], + rows=[1], + object_fit="contain", + height="auto" + ) + v2v_send_to_input_btn = gr.Button("Send Selected to Input") # New button + v2v_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") + v2v_lora_weights = [] + v2v_lora_multipliers = [] + for i in range(4): + with gr.Column(): + v2v_lora_weights.append(gr.Dropdown( + label=f"LoRA {i+1}", + choices=get_lora_options(), + value="None", + allow_custom_value=True, + interactive=True + )) + v2v_lora_multipliers.append(gr.Slider( + label=f"Multiplier", + minimum=0.0, + maximum=2.0, + step=0.05, + value=1.0 + )) + + with gr.Row(): + v2v_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) + v2v_seed = gr.Number(label="Seed (use -1 for random)", value=-1) + v2v_dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") + v2v_model = gr.Dropdown( + label="DiT Model", + choices=get_dit_models("hunyuan"), + value="mp_rank_00_model_states.pt", + allow_custom_value=True, + interactive=True + ) + v2v_vae = 
gr.Textbox(label="vae", value="hunyuan/pytorch_model.pt") + v2v_te1 = gr.Textbox(label="te1", value="hunyuan/llava_llama3_fp16.safetensors") + v2v_te2 = gr.Textbox(label="te2", value="hunyuan/clip_l.safetensors") + v2v_save_path = gr.Textbox(label="Save Path", value="outputs") + with gr.Row(): + v2v_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") + v2v_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") + v2v_use_split_attn = gr.Checkbox(label="Use Split Attention", value=False) + v2v_use_fp8 = gr.Checkbox(label="Use FP8 (faster but lower precision)", value=True) + v2v_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") + v2v_block_swap = gr.Slider(minimum=0, maximum=36, step=1, label="Block Swap to Save Vram", value=0) + v2v_split_uncond = gr.Checkbox(label="Split Unconditional (for SkyReels)", value=True) + +### SKYREELS + + with gr.Tab(label="SkyReels-i2v") as skyreels_tab: + with gr.Row(): + with gr.Column(scale=4): + skyreels_prompt = gr.Textbox( + scale=3, + label="Enter your prompt", + value="A person walking on a beach at sunset", + lines=5 + ) + skyreels_negative_prompt = gr.Textbox( + scale=3, + label="Negative Prompt", + value="Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion", + lines=3 + ) + + with gr.Column(scale=1): + skyreels_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) + skyreels_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) + + with gr.Column(scale=2): + skyreels_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") + skyreels_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") + + with gr.Row(): + skyreels_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") + skyreels_stop_btn = gr.Button("Stop Generation", variant="stop") + + with gr.Row(): + with gr.Column(): + skyreels_input = gr.Image(label="Input Image (optional)", type="filepath") + with gr.Row(): + skyreels_use_random_folder = gr.Checkbox(label="Use Random Images from Folder", value=False) + skyreels_input_folder = gr.Textbox( + label="Image Folder Path", + placeholder="Path to folder containing images", + visible=False + ) + skyreels_folder_status = gr.Textbox( + label="Folder Status", + placeholder="Status will appear here", + interactive=False, + visible=False + ) + skyreels_validate_folder_btn = gr.Button("Validate Folder", visible=False) + skyreels_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.75, label="Denoise Strength") + + # Scale slider as percentage + skyreels_scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale %") + skyreels_original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) + + # Width and height inputs + with gr.Row(): + skyreels_width = gr.Number(label="New Width", value=544, step=16) + skyreels_calc_height_btn = gr.Button("→") + skyreels_calc_width_btn = gr.Button("←") + skyreels_height = gr.Number(label="New Height", value=544, step=16) + + skyreels_video_length = gr.Slider(minimum=1, maximum=201, step=1, label="Video Length in Frames", value=25) + skyreels_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=24) + skyreels_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30) + skyreels_flow_shift = 
gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=11.0) + skyreels_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.1, label="Guidance Scale", value=6.0) + skyreels_embedded_cfg_scale = gr.Slider(minimum=0.0, maximum=10.0, step=0.1, label="Embedded CFG Scale", value=1.0) + + with gr.Column(): + skyreels_output = gr.Gallery( + label="Generated Videos (Click to select)", + columns=[2], + rows=[2], + object_fit="contain", + height="auto", + show_label=True, + elem_id="gallery", + allow_preview=True, + preview=True + ) + skyreels_send_to_v2v_btn = gr.Button("Send Selected to Video2Video") + + # Add LoRA section for SKYREELS + skyreels_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") + skyreels_lora_weights = [] + skyreels_lora_multipliers = [] + for i in range(4): + with gr.Column(): + skyreels_lora_weights.append(gr.Dropdown( + label=f"LoRA {i+1}", + choices=get_lora_options(), + value="None", + allow_custom_value=True, + interactive=True + )) + skyreels_lora_multipliers.append(gr.Slider( + label=f"Multiplier", + minimum=0.0, + maximum=2.0, + step=0.05, + value=1.0 + )) + with gr.Row(): + skyreels_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) + skyreels_seed = gr.Number(label="Seed (use -1 for random)", value=-1) + skyreels_dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") + skyreels_model = gr.Dropdown( + label="DiT Model", + choices=get_dit_models("skyreels"), + value="skyreels_hunyuan_i2v_bf16.safetensors", + allow_custom_value=True, + interactive=True + ) + skyreels_vae = gr.Textbox(label="vae", value="hunyuan/pytorch_model.pt") + skyreels_te1 = gr.Textbox(label="te1", value="hunyuan/llava_llama3_fp16.safetensors") + skyreels_te2 = gr.Textbox(label="te2", value="hunyuan/clip_l.safetensors") + skyreels_save_path = gr.Textbox(label="Save Path", value="outputs") + + with gr.Row(): + skyreels_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") + skyreels_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") + skyreels_use_split_attn = gr.Checkbox(label="Use Split Attention", value=False) + skyreels_use_fp8 = gr.Checkbox(label="Use FP8 (faster but lower precision)", value=True) + skyreels_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") + skyreels_block_swap = gr.Slider(minimum=0, maximum=36, step=1, label="Block Swap to Save Vram", value=0) + skyreels_split_uncond = gr.Checkbox(label="Split Unconditional", value=True) + + # WanX Image to Video Tab + with gr.Tab(id=4, label="WanX-i2v") as wanx_i2v_tab: + with gr.Row(): + with gr.Column(scale=4): + wanx_prompt = gr.Textbox( + scale=3, + label="Enter your prompt", + value="A person walking on a beach at sunset", + lines=5 + ) + wanx_negative_prompt = gr.Textbox( + scale=3, + label="Negative Prompt", + value="", + lines=3, + info="Leave empty to use default negative prompt" + ) + + with gr.Column(scale=1): + wanx_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) + wanx_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) + + with gr.Column(scale=2): + wanx_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") + wanx_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") + + with gr.Row(): + wanx_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") + wanx_stop_btn = gr.Button("Stop Generation", variant="stop") + + 
with gr.Row(): + with gr.Column(): + wanx_input = gr.Image(label="Input Image", type="filepath") + with gr.Row(): + wanx_use_random_folder = gr.Checkbox(label="Use Random Images from Folder", value=False) + wanx_input_folder = gr.Textbox( + label="Image Folder Path", + placeholder="Path to folder containing images", + visible=False + ) + wanx_folder_status = gr.Textbox( + label="Folder Status", + placeholder="Status will appear here", + interactive=False, + visible=False + ) + wanx_validate_folder_btn = gr.Button("Validate Folder", visible=False) + wanx_scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale %") + wanx_original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) + + # Width and height display + with gr.Row(): + wanx_width = gr.Number(label="Width", value=832, interactive=True) + wanx_calc_height_btn = gr.Button("→") + wanx_calc_width_btn = gr.Button("←") + wanx_height = gr.Number(label="Height", value=480, interactive=True) + wanx_recommend_flow_btn = gr.Button("Recommend Flow Shift", size="sm") + + wanx_video_length = gr.Slider(minimum=1, maximum=201, step=4, label="Video Length in Frames", value=81) + wanx_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=16) + wanx_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=20) + wanx_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=3.0, + info="Recommended: 3.0 for 480p, 5.0 for others") + wanx_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.1, label="Guidance Scale", value=5.0) + + with gr.Column(): + wanx_output = gr.Gallery( + label="Generated Videos (Click to select)", + columns=[2], + rows=[2], + object_fit="contain", + height="auto", + show_label=True, + elem_id="gallery", + allow_preview=True, + preview=True + ) + wanx_send_to_v2v_btn = gr.Button("Send Selected to Hunyuan-v2v") + wanx_send_last_frame_btn = gr.Button("Send Last Frame to Input") + wanx_extend_btn = gr.Button("Extend Video") + wanx_frames_to_check = gr.Slider(minimum=1, maximum=100, step=1, value=30, + label="Frames to Check from End", + info="Number of frames from the end to check for sharpness") + wanx_send_sharpest_frame_btn = gr.Button("Extract Sharpest Frame") + wanx_trim_and_extend_btn = gr.Button("Trim Video & Prepare for Extension") + wanx_sharpest_frame_status = gr.Textbox(label="Status", interactive=False) + + # Add a new button for directly extending with the trimmed video + wanx_extend_with_trimmed_btn = gr.Button("Extend with Trimmed Video") + + # Add LoRA section for WanX-i2v similar to other tabs + wanx_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") + wanx_lora_weights = [] + wanx_lora_multipliers = [] + for i in range(4): + with gr.Column(): + wanx_lora_weights.append(gr.Dropdown( + label=f"LoRA {i+1}", + choices=get_lora_options(), + value="None", + allow_custom_value=True, + interactive=True + )) + wanx_lora_multipliers.append(gr.Slider( + label=f"Multiplier", + minimum=0.0, + maximum=2.0, + step=0.05, + value=1.0 + )) + + with gr.Row(): + wanx_seed = gr.Number(label="Seed (use -1 for random)", value=-1) + wanx_task = gr.Dropdown( + label="Task", + choices=["i2v-14B"], + value="i2v-14B", + info="Currently only i2v-14B is supported" + ) + wanx_dit_path = gr.Textbox(label="DiT Model Path", value="wan/wan2.1_i2v_480p_14B_bf16.safetensors") + wanx_vae_path = gr.Textbox(label="VAE Path", value="wan/Wan2.1_VAE.pth") + wanx_t5_path = gr.Textbox(label="T5 Path", 
value="wan/models_t5_umt5-xxl-enc-bf16.pth") + wanx_clip_path = gr.Textbox(label="CLIP Path", value="wan/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth") + wanx_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") + wanx_save_path = gr.Textbox(label="Save Path", value="outputs") + + with gr.Row(): + wanx_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") + wanx_sample_solver = gr.Radio(choices=["unipc", "dpm++", "vanilla"], label="Sample Solver", value="unipc") + wanx_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) + wanx_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") + wanx_block_swap = gr.Slider(minimum=0, maximum=39, step=1, label="Block Swap to Save VRAM", value=0) + wanx_fp8 = gr.Checkbox(label="Use FP8", value=True) + wanx_fp8_t5 = gr.Checkbox(label="Use FP8 for T5", value=False) + + #WanX-t2v Tab + + # WanX Text to Video Tab + with gr.Tab(id=5, label="WanX-t2v") as wanx_t2v_tab: + with gr.Row(): + with gr.Column(scale=4): + wanx_t2v_prompt = gr.Textbox( + scale=3, + label="Enter your prompt", + value="A person walking on a beach at sunset", + lines=5 + ) + wanx_t2v_negative_prompt = gr.Textbox( + scale=3, + label="Negative Prompt", + value="", + lines=3, + info="Leave empty to use default negative prompt" + ) + + with gr.Column(scale=1): + wanx_t2v_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) + wanx_t2v_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) + + with gr.Column(scale=2): + wanx_t2v_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") + wanx_t2v_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") + + with gr.Row(): + wanx_t2v_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") + wanx_t2v_stop_btn = gr.Button("Stop Generation", variant="stop") + + with gr.Row(): + with gr.Column(): + with gr.Row(): + wanx_t2v_width = gr.Number(label="Width", value=832, interactive=True, info="Should be divisible by 32") + wanx_t2v_height = gr.Number(label="Height", value=480, interactive=True, info="Should be divisible by 32") + wanx_t2v_recommend_flow_btn = gr.Button("Recommend Flow Shift", size="sm") + + wanx_t2v_video_length = gr.Slider(minimum=1, maximum=201, step=4, label="Video Length in Frames", value=81) + wanx_t2v_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=16) + wanx_t2v_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=20) + wanx_t2v_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=5.0, + info="Recommended: 3.0 for I2V with 480p, 5.0 for others") + wanx_t2v_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.1, label="Guidance Scale", value=5.0) + + with gr.Column(): + wanx_t2v_output = gr.Gallery( + label="Generated Videos (Click to select)", + columns=[2], + rows=[2], + object_fit="contain", + height="auto", + show_label=True, + elem_id="gallery", + allow_preview=True, + preview=True + ) + wanx_t2v_send_to_v2v_btn = gr.Button("Send Selected to Video2Video") + + # Add LoRA section for WanX-t2v + wanx_t2v_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") + wanx_t2v_lora_weights = [] + wanx_t2v_lora_multipliers = [] + for i in range(4): + with gr.Column(): + wanx_t2v_lora_weights.append(gr.Dropdown( + label=f"LoRA {i+1}", + choices=get_lora_options(), + 
value="None", + allow_custom_value=True, + interactive=True + )) + wanx_t2v_lora_multipliers.append(gr.Slider( + label=f"Multiplier", + minimum=0.0, + maximum=2.0, + step=0.05, + value=1.0 + )) + + with gr.Row(): + wanx_t2v_seed = gr.Number(label="Seed (use -1 for random)", value=-1) + wanx_t2v_task = gr.Dropdown( + label="Task", + choices=["t2v-1.3B", "t2v-14B", "t2i-14B"], + value="t2v-14B", + info="Select model size: t2v-1.3B is faster, t2v-14B has higher quality" + ) + wanx_t2v_dit_path = gr.Textbox(label="DiT Model Path", value="wan/wan2.1_t2v_14B_bf16.safetensors") + wanx_t2v_vae_path = gr.Textbox(label="VAE Path", value="wan/Wan2.1_VAE.pth") + wanx_t2v_t5_path = gr.Textbox(label="T5 Path", value="wan/models_t5_umt5-xxl-enc-bf16.pth") + wanx_t2v_clip_path = gr.Textbox(label="CLIP Path", visible=False, value="") + wanx_t2v_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") + wanx_t2v_save_path = gr.Textbox(label="Save Path", value="outputs") + + with gr.Row(): + wanx_t2v_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") + wanx_t2v_sample_solver = gr.Radio(choices=["unipc", "dpm++", "vanilla"], label="Sample Solver", value="unipc") + wanx_t2v_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) + wanx_t2v_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") + wanx_t2v_block_swap = gr.Slider(minimum=0, maximum=39, step=1, label="Block Swap to Save VRAM", value=0, + info="Max 39 for 14B model, 29 for 1.3B model") + wanx_t2v_fp8 = gr.Checkbox(label="Use FP8", value=True) + wanx_t2v_fp8_t5 = gr.Checkbox(label="Use FP8 for T5", value=False) + + #Video Info Tab + with gr.Tab("Video Info") as video_info_tab: + with gr.Row(): + video_input = gr.Video(label="Upload Video", interactive=True) + metadata_output = gr.JSON(label="Generation Parameters") + + with gr.Row(): + send_to_t2v_btn = gr.Button("Send to Text2Video", variant="primary") + send_to_v2v_btn = gr.Button("Send to Video2Video", variant="primary") + send_to_wanx_i2v_btn = gr.Button("Send to WanX-i2v", variant="primary") + send_to_wanx_t2v_btn = gr.Button("Send to WanX-t2v", variant="primary") + + with gr.Row(): + status = gr.Textbox(label="Status", interactive=False) + + #Merge Model's tab + with gr.Tab("Convert LoRA") as convert_lora_tab: + def suggest_output_name(file_obj) -> str: + """Generate suggested output name from input file""" + if not file_obj: + return "" + # Get input filename without extension and add MUSUBI + base_name = os.path.splitext(os.path.basename(file_obj.name))[0] + return f"{base_name}_MUSUBI" + + def convert_lora(input_file, output_name: str, target_format: str) -> str: + """Convert LoRA file to specified format""" + try: + if not input_file: + return "Error: No input file selected" + + # Ensure output directory exists + os.makedirs("lora", exist_ok=True) + + # Construct output path + output_path = os.path.join("lora", f"{output_name}.safetensors") + + # Build command + cmd = [ + sys.executable, + "convert_lora.py", + "--input", input_file.name, + "--output", output_path, + "--target", target_format + ] + + print(f"Converting {input_file.name} to {output_path}") + + # Execute conversion + result = subprocess.run( + cmd, + capture_output=True, + text=True, + check=True + ) + + if os.path.exists(output_path): + return f"Successfully converted LoRA to {output_path}" + else: + return "Error: Output file not created" + + except subprocess.CalledProcessError 
as e: + return f"Error during conversion: {e.stderr}" + except Exception as e: + return f"Error: {str(e)}" + + with gr.Row(): + input_file = gr.File(label="Input LoRA File", file_types=[".safetensors"]) + output_name = gr.Textbox(label="Output Name", placeholder="Output filename (without extension)") + format_radio = gr.Radio( + choices=["default", "other"], + value="default", + label="Target Format", + info="Choose 'default' for H1111/MUSUBI format or 'other' for diffusion pipe format" + ) + + with gr.Row(): + convert_btn = gr.Button("Convert LoRA", variant="primary") + status_output = gr.Textbox(label="Status", interactive=False) + + # Automatically update output name when file is selected + input_file.change( + fn=suggest_output_name, + inputs=[input_file], + outputs=[output_name] + ) + + # Handle conversion + convert_btn.click( + fn=convert_lora, + inputs=[input_file, output_name, format_radio], + outputs=status_output + ) + with gr.Tab("Model Merging") as model_merge_tab: + with gr.Row(): + with gr.Column(): + # Model selection + dit_model = gr.Dropdown( + label="Base DiT Model", + choices=["mp_rank_00_model_states.pt"], + value="mp_rank_00_model_states.pt", + allow_custom_value=True, + interactive=True + ) + merge_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") + with gr.Row(): + with gr.Column(): + # Output model name + output_model = gr.Textbox(label="Output Model Name", value="merged_model.safetensors") + exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) + merge_btn = gr.Button("Merge Models", variant="primary") + merge_status = gr.Textbox(label="Status", interactive=False) + with gr.Row(): + # LoRA selection section (similar to Text2Video) + merge_lora_weights = [] + merge_lora_multipliers = [] + for i in range(4): + with gr.Column(): + merge_lora_weights.append(gr.Dropdown( + label=f"LoRA {i+1}", + choices=get_lora_options(), + value="None", + allow_custom_value=True, + interactive=True + )) + merge_lora_multipliers.append(gr.Slider( + label=f"Multiplier", + minimum=0.0, + maximum=2.0, + step=0.05, + value=1.0 + )) + with gr.Row(): + merge_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") + dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") + + #Video Extension + wanx_send_last_frame_btn.click( + fn=send_last_frame_handler, + inputs=[wanx_output, wanx_i2v_selected_index], + outputs=[wanx_input, wanx_base_video] + ) + + wanx_extend_btn.click( + fn=prepare_for_batch_extension, + inputs=[wanx_input, wanx_base_video, wanx_batch_size], + outputs=[wanx_input, wanx_base_video, wanx_batch_size, wanx_batch_progress, wanx_progress_text] + ).then( + fn=wanx_batch_handler, + inputs=[ + gr.Checkbox(value=False), # Not using random folder + wanx_prompt, wanx_negative_prompt, + wanx_width, wanx_height, wanx_video_length, + wanx_fps, wanx_infer_steps, wanx_flow_shift, + wanx_guidance_scale, wanx_seed, wanx_batch_size, + wanx_input_folder, # Not used but needed for function signature + wanx_task, + wanx_dit_path, wanx_vae_path, wanx_t5_path, + wanx_clip_path, wanx_save_path, wanx_output_type, + wanx_sample_solver, wanx_exclude_single_blocks, + wanx_attn_mode, wanx_block_swap, wanx_fp8, + wanx_fp8_t5, wanx_lora_folder, *wanx_lora_weights, + *wanx_lora_multipliers, wanx_input # Include input image + ], + outputs=[wanx_output, wanx_batch_progress, wanx_progress_text] + ).then( + fn=concat_batch_videos, + inputs=[wanx_base_video, wanx_output, wanx_save_path], + outputs=[wanx_output, wanx_progress_text] + ) + + # Extract and send 
sharpest frame to input + wanx_send_sharpest_frame_btn.click( + fn=send_sharpest_frame_handler, + inputs=[wanx_output, wanx_i2v_selected_index, wanx_frames_to_check], + outputs=[wanx_input, wanx_base_video, wanx_sharpest_frame_number, wanx_sharpest_frame_status] + ) + + # Trim video to sharpest frame and prepare for extension + wanx_trim_and_extend_btn.click( + fn=trim_and_prepare_for_extension, + inputs=[wanx_base_video, wanx_sharpest_frame_number, wanx_save_path], + outputs=[wanx_trimmed_video_path, wanx_sharpest_frame_status] + ).then( + fn=lambda path, status: (path, status if "Failed" in status else "Video trimmed successfully and ready for extension"), + inputs=[wanx_trimmed_video_path, wanx_sharpest_frame_status], + outputs=[wanx_base_video, wanx_sharpest_frame_status] + ) + + # Event handler for extending with the trimmed video + wanx_extend_with_trimmed_btn.click( + fn=prepare_for_batch_extension, + inputs=[wanx_input, wanx_trimmed_video_path, wanx_batch_size], + outputs=[wanx_input, wanx_base_video, wanx_batch_size, wanx_batch_progress, wanx_progress_text] + ).then( + fn=wanx_batch_handler, + inputs=[ + gr.Checkbox(value=False), # Not using random folder + wanx_prompt, wanx_negative_prompt, + wanx_width, wanx_height, wanx_video_length, + wanx_fps, wanx_infer_steps, wanx_flow_shift, + wanx_guidance_scale, wanx_seed, wanx_batch_size, + wanx_input_folder, # Not used but needed for function signature + wanx_task, + wanx_dit_path, wanx_vae_path, wanx_t5_path, + wanx_clip_path, wanx_save_path, wanx_output_type, + wanx_sample_solver, wanx_exclude_single_blocks, + wanx_attn_mode, wanx_block_swap, wanx_fp8, + wanx_fp8_t5, wanx_lora_folder, *wanx_lora_weights, + *wanx_lora_multipliers, wanx_input # Include input image + ], + outputs=[wanx_output, wanx_batch_progress, wanx_progress_text] + ).then( + fn=concat_batch_videos, + inputs=[wanx_trimmed_video_path, wanx_output, wanx_save_path], + outputs=[wanx_output, wanx_progress_text] + ) + + #Video Info + def handle_send_to_wanx_tab(metadata, target_tab): + """Common handler for sending video parameters to WanX tabs""" + if not metadata: + return "No parameters to send", {} + + # Tab names for clearer messages + tab_names = { + 'wanx_i2v': 'WanX-i2v', + 'wanx_t2v': 'WanX-t2v' + } + + # Just pass through all parameters - we'll use them in the .then() function + return f"Parameters ready for {tab_names.get(target_tab, target_tab)}", metadata + + def change_to_wanx_i2v_tab(): + return gr.Tabs(selected=4) # WanX-i2v tab index + + def change_to_wanx_t2v_tab(): + return gr.Tabs(selected=5) # WanX-t2v tab index + + send_to_wanx_i2v_btn.click( + fn=lambda m: handle_send_to_wanx_tab(m, 'wanx_i2v'), + inputs=[metadata_output], + outputs=[status, params_state] + ).then( + lambda params: [ + params.get("prompt", ""), + params.get("width", 832), + params.get("height", 480), + params.get("video_length", 81), + params.get("fps", 16), + params.get("infer_steps", 40), + params.get("seed", -1), + params.get("flow_shift", 3.0), + params.get("guidance_scale", 5.0), + params.get("attn_mode", "sdpa"), + params.get("block_swap", 0), + params.get("task", "i2v-14B") + ] if params else [gr.update()]*12, + inputs=params_state, + outputs=[ + wanx_prompt, + wanx_width, + wanx_height, + wanx_video_length, + wanx_fps, + wanx_infer_steps, + wanx_seed, + wanx_flow_shift, + wanx_guidance_scale, + wanx_attn_mode, + wanx_block_swap, + wanx_task + ] + ).then( + fn=change_to_wanx_i2v_tab, inputs=None, outputs=[tabs] + ) + + # 3. 
Update the WanX-t2v button handler + send_to_wanx_t2v_btn.click( + fn=lambda m: handle_send_to_wanx_tab(m, 'wanx_t2v'), + inputs=[metadata_output], + outputs=[status, params_state] + ).then( + lambda params: [ + params.get("prompt", ""), + params.get("width", 832), + params.get("height", 480), + params.get("video_length", 81), + params.get("fps", 16), + params.get("infer_steps", 50), + params.get("seed", -1), + params.get("flow_shift", 5.0), + params.get("guidance_scale", 5.0), + params.get("attn_mode", "sdpa"), + params.get("block_swap", 0) + ] if params else [gr.update()]*11, + inputs=params_state, + outputs=[ + wanx_t2v_prompt, + wanx_t2v_width, + wanx_t2v_height, + wanx_t2v_video_length, + wanx_t2v_fps, + wanx_t2v_infer_steps, + wanx_t2v_seed, + wanx_t2v_flow_shift, + wanx_t2v_guidance_scale, + wanx_t2v_attn_mode, + wanx_t2v_block_swap + ] + ).then( + fn=change_to_wanx_t2v_tab, inputs=None, outputs=[tabs] + ) + + #text to video + def change_to_tab_one(): + return gr.Tabs(selected=1) #This will navigate + #video to video + def change_to_tab_two(): + return gr.Tabs(selected=2) #This will navigate + def change_to_skyreels_tab(): + return gr.Tabs(selected=3) + + #SKYREELS TAB!!! + # Add state management for dimensions + def sync_skyreels_dimensions(width, height): + return gr.update(value=width), gr.update(value=height) + + # Add this function to update the LoRA dropdowns in the SKYREELS tab + def update_skyreels_lora_dropdowns(lora_folder: str, *current_values) -> List[gr.update]: + new_choices = get_lora_options(lora_folder) + weights = current_values[:4] + multipliers = current_values[4:8] + + results = [] + for i in range(4): + weight = weights[i] if i < len(weights) else "None" + multiplier = multipliers[i] if i < len(multipliers) else 1.0 + if weight not in new_choices: + weight = "None" + results.extend([ + gr.update(choices=new_choices, value=weight), + gr.update(value=multiplier) + ]) + + return results + + # Add this function to update the models dropdown in the SKYREELS tab + def update_skyreels_model_dropdown(dit_folder: str) -> Dict: + models = get_dit_models(dit_folder) + return gr.update(choices=models, value=models[0] if models else None) + + # Add event handler for model dropdown refresh + skyreels_dit_folder.change( + fn=update_skyreels_model_dropdown, + inputs=[skyreels_dit_folder], + outputs=[skyreels_model] + ) + + # Add handlers for the refresh button + skyreels_refresh_btn.click( + fn=update_skyreels_lora_dropdowns, + inputs=[skyreels_lora_folder] + skyreels_lora_weights + skyreels_lora_multipliers, + outputs=[drop for _ in range(4) for drop in [skyreels_lora_weights[_], skyreels_lora_multipliers[_]]] + ) + # Skyreels dimension handling + def calculate_skyreels_width(height, original_dims): + if not original_dims: + return gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + aspect_ratio = orig_w / orig_h + new_width = math.floor((height * aspect_ratio) / 16) * 16 + return gr.update(value=new_width) + + def calculate_skyreels_height(width, original_dims): + if not original_dims: + return gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + aspect_ratio = orig_w / orig_h + new_height = math.floor((width / aspect_ratio) / 16) * 16 + return gr.update(value=new_height) + + def update_skyreels_from_scale(scale, original_dims): + if not original_dims: + return gr.update(), gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + new_w = math.floor((orig_w * scale / 100) / 16) * 16 + new_h = math.floor((orig_h * scale / 100) / 16) * 
16 + return gr.update(value=new_w), gr.update(value=new_h) + + def update_skyreels_dimensions(image): + if image is None: + return "", gr.update(value=544), gr.update(value=544) + img = Image.open(image) + w, h = img.size + w = (w // 16) * 16 + h = (h // 16) * 16 + return f"{w}x{h}", w, h + + def handle_skyreels_gallery_select(evt: gr.SelectData) -> int: + return evt.index + + def send_skyreels_to_v2v( + gallery: list, + prompt: str, + selected_index: int, + width: int, + height: int, + video_length: int, + fps: int, + infer_steps: int, + seed: int, + flow_shift: float, + cfg_scale: float, + lora1: str, + lora2: str, + lora3: str, + lora4: str, + lora1_multiplier: float, + lora2_multiplier: float, + lora3_multiplier: float, + lora4_multiplier: float, + negative_prompt: str = "" # Add this parameter + ) -> Tuple: + if not gallery or selected_index is None or selected_index >= len(gallery): + return (None, "", width, height, video_length, fps, infer_steps, seed, + flow_shift, cfg_scale, lora1, lora2, lora3, lora4, + lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier, + negative_prompt) # Add negative_prompt to return + + selected_item = gallery[selected_index] + + if isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + elif isinstance(selected_item, (tuple, list)): + video_path = selected_item[0] + else: + video_path = selected_item + + if isinstance(video_path, tuple): + video_path = video_path[0] + + return (str(video_path), prompt, width, height, video_length, fps, infer_steps, seed, + flow_shift, cfg_scale, lora1, lora2, lora3, lora4, + lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier, + negative_prompt) # Add negative_prompt to return + + # Add event handlers for the SKYREELS tab + skyreels_prompt.change(fn=count_prompt_tokens, inputs=skyreels_prompt, outputs=skyreels_token_counter) + skyreels_stop_btn.click(fn=lambda: stop_event.set(), queue=False) + + # Image input handling + skyreels_input.change( + fn=update_skyreels_dimensions, + inputs=[skyreels_input], + outputs=[skyreels_original_dims, skyreels_width, skyreels_height] + ) + + skyreels_scale_slider.change( + fn=update_skyreels_from_scale, + inputs=[skyreels_scale_slider, skyreels_original_dims], + outputs=[skyreels_width, skyreels_height] + ) + + skyreels_calc_width_btn.click( + fn=calculate_skyreels_width, + inputs=[skyreels_height, skyreels_original_dims], + outputs=[skyreels_width] + ) + + skyreels_calc_height_btn.click( + fn=calculate_skyreels_height, + inputs=[skyreels_width, skyreels_original_dims], + outputs=[skyreels_height] + ) + + # Handle checkbox visibility toggling + skyreels_use_random_folder.change( + fn=lambda x: (gr.update(visible=x), gr.update(visible=x), gr.update(visible=not x)), + inputs=[skyreels_use_random_folder], + outputs=[skyreels_input_folder, skyreels_folder_status, skyreels_input] + ) + + # Validate folder button click handler + skyreels_validate_folder_btn.click( + fn=lambda folder: get_random_image_from_folder(folder)[1], + inputs=[skyreels_input_folder], + outputs=[skyreels_folder_status] + ) + + skyreels_use_random_folder.change( + fn=lambda x: gr.update(visible=x), + inputs=[skyreels_use_random_folder], + outputs=[skyreels_validate_folder_btn] + ) + + # Modify the skyreels_generate_btn.click event handler to use process_random_image_batch when folder mode is on + skyreels_generate_btn.click( + fn=batch_handler, + inputs=[ + skyreels_use_random_folder, + # Rest of the arguments + skyreels_prompt, + 
skyreels_negative_prompt, + skyreels_width, + skyreels_height, + skyreels_video_length, + skyreels_fps, + skyreels_infer_steps, + skyreels_seed, + skyreels_flow_shift, + skyreels_guidance_scale, + skyreels_embedded_cfg_scale, + skyreels_batch_size, + skyreels_input_folder, + skyreels_dit_folder, + skyreels_model, + skyreels_vae, + skyreels_te1, + skyreels_te2, + skyreels_save_path, + skyreels_output_type, + skyreels_attn_mode, + skyreels_block_swap, + skyreels_exclude_single_blocks, + skyreels_use_split_attn, + skyreels_use_fp8, + skyreels_split_uncond, + skyreels_lora_folder, + *skyreels_lora_weights, + *skyreels_lora_multipliers, + skyreels_input # Add the input image path + ], + outputs=[skyreels_output, skyreels_batch_progress, skyreels_progress_text], + queue=True + ).then( + fn=lambda batch_size: 0 if batch_size == 1 else None, + inputs=[skyreels_batch_size], + outputs=skyreels_selected_index + ) + + # Gallery selection handling + skyreels_output.select( + fn=handle_skyreels_gallery_select, + outputs=skyreels_selected_index + ) + + # Send to Video2Video handler + skyreels_send_to_v2v_btn.click( + fn=send_skyreels_to_v2v, + inputs=[ + skyreels_output, skyreels_prompt, skyreels_selected_index, + skyreels_width, skyreels_height, skyreels_video_length, + skyreels_fps, skyreels_infer_steps, skyreels_seed, + skyreels_flow_shift, skyreels_guidance_scale + ] + skyreels_lora_weights + skyreels_lora_multipliers + [skyreels_negative_prompt], # This is ok because skyreels_negative_prompt is a Gradio component + outputs=[ + v2v_input, v2v_prompt, v2v_width, v2v_height, + v2v_video_length, v2v_fps, v2v_infer_steps, + v2v_seed, v2v_flow_shift, v2v_cfg_scale + ] + v2v_lora_weights + v2v_lora_multipliers + [v2v_negative_prompt] + ).then( + fn=change_to_tab_two, + inputs=None, + outputs=[tabs] + ) + + # Refresh button handler + skyreels_refresh_outputs = [skyreels_model] + for i in range(4): + skyreels_refresh_outputs.extend([skyreels_lora_weights[i], skyreels_lora_multipliers[i]]) + + skyreels_refresh_btn.click( + fn=update_dit_and_lora_dropdowns, + inputs=[skyreels_dit_folder, skyreels_lora_folder, skyreels_model] + skyreels_lora_weights + skyreels_lora_multipliers, + outputs=skyreels_refresh_outputs + ) + + def calculate_v2v_width(height, original_dims): + if not original_dims: + return gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + aspect_ratio = orig_w / orig_h + new_width = math.floor((height * aspect_ratio) / 16) * 16 # Ensure divisible by 16 + return gr.update(value=new_width) + + def calculate_v2v_height(width, original_dims): + if not original_dims: + return gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + aspect_ratio = orig_w / orig_h + new_height = math.floor((width / aspect_ratio) / 16) * 16 # Ensure divisible by 16 + return gr.update(value=new_height) + + def update_v2v_from_scale(scale, original_dims): + if not original_dims: + return gr.update(), gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + new_w = math.floor((orig_w * scale / 100) / 16) * 16 # Ensure divisible by 16 + new_h = math.floor((orig_h * scale / 100) / 16) * 16 # Ensure divisible by 16 + return gr.update(value=new_w), gr.update(value=new_h) + + def update_v2v_dimensions(video): + if video is None: + return "", gr.update(value=544), gr.update(value=544) + cap = cv2.VideoCapture(video) + w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + cap.release() + # Make dimensions divisible by 16 + w = (w // 16) * 16 + h = (h // 16) * 
16 + return f"{w}x{h}", w, h + + # Event Handlers for Video to Video Tab + v2v_input.change( + fn=update_v2v_dimensions, + inputs=[v2v_input], + outputs=[v2v_original_dims, v2v_width, v2v_height] + ) + + v2v_scale_slider.change( + fn=update_v2v_from_scale, + inputs=[v2v_scale_slider, v2v_original_dims], + outputs=[v2v_width, v2v_height] + ) + + v2v_calc_width_btn.click( + fn=calculate_v2v_width, + inputs=[v2v_height, v2v_original_dims], + outputs=[v2v_width] + ) + + v2v_calc_height_btn.click( + fn=calculate_v2v_height, + inputs=[v2v_width, v2v_original_dims], + outputs=[v2v_height] + ) + + ##Image 2 video dimension logic + def calculate_width(height, original_dims): + if not original_dims: + return gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + aspect_ratio = orig_w / orig_h + new_width = math.floor((height * aspect_ratio) / 16) * 16 # Changed from 8 to 16 + return gr.update(value=new_width) + + def calculate_height(width, original_dims): + if not original_dims: + return gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + aspect_ratio = orig_w / orig_h + new_height = math.floor((width / aspect_ratio) / 16) * 16 # Changed from 8 to 16 + return gr.update(value=new_height) + + def update_from_scale(scale, original_dims): + if not original_dims: + return gr.update(), gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + new_w = math.floor((orig_w * scale / 100) / 16) * 16 # Changed from 8 to 16 + new_h = math.floor((orig_h * scale / 100) / 16) * 16 # Changed from 8 to 16 + return gr.update(value=new_w), gr.update(value=new_h) + + def update_dimensions(image): + if image is None: + return "", gr.update(value=544), gr.update(value=544) + img = Image.open(image) + w, h = img.size + # Make dimensions divisible by 16 + w = (w // 16) * 16 # Changed from 8 to 16 + h = (h // 16) * 16 # Changed from 8 to 16 + return f"{w}x{h}", w, h + i2v_input.change( + fn=update_dimensions, + inputs=[i2v_input], + outputs=[original_dims, width, height] + ) + + scale_slider.change( + fn=update_from_scale, + inputs=[scale_slider, original_dims], + outputs=[width, height] + ) + + calc_width_btn.click( + fn=calculate_width, + inputs=[height, original_dims], + outputs=[width] + ) + + calc_height_btn.click( + fn=calculate_height, + inputs=[width, original_dims], + outputs=[height] + ) + + # Function to get available DiT models + def get_dit_models(dit_folder: str) -> List[str]: + if not os.path.exists(dit_folder): + return ["mp_rank_00_model_states.pt"] + models = [f for f in os.listdir(dit_folder) if f.endswith('.pt') or f.endswith('.safetensors')] + models.sort(key=str.lower) + return models if models else ["mp_rank_00_model_states.pt"] + + # Function to perform model merging + def merge_models( + dit_folder: str, + dit_model: str, + output_model: str, + exclude_single_blocks: bool, + merge_lora_folder: str, + *lora_params # Will contain both weights and multipliers + ) -> str: + try: + # Separate weights and multipliers + num_loras = len(lora_params) // 2 + weights = list(lora_params[:num_loras]) + multipliers = list(lora_params[num_loras:]) + + # Filter out "None" selections + valid_loras = [] + for weight, mult in zip(weights, multipliers): + if weight and weight != "None": + valid_loras.append((os.path.join(merge_lora_folder, weight), mult)) + + if not valid_loras: + return "No LoRA models selected for merging" + + # Create output path in the dit folder + os.makedirs(dit_folder, exist_ok=True) + output_path = os.path.join(dit_folder, output_model) + + # Prepare 
command + cmd = [ + sys.executable, + "merge_lora.py", + "--dit", os.path.join(dit_folder, dit_model), + "--save_merged_model", output_path + ] + + # Add LoRA weights and multipliers + weights = [weight for weight, _ in valid_loras] + multipliers = [str(mult) for _, mult in valid_loras] + cmd.extend(["--lora_weight"] + weights) + cmd.extend(["--lora_multiplier"] + multipliers) + + if exclude_single_blocks: + cmd.append("--exclude_single_blocks") + + # Execute merge operation + result = subprocess.run( + cmd, + capture_output=True, + text=True, + check=True + ) + + if os.path.exists(output_path): + return f"Successfully merged model and saved to {output_path}" + else: + return "Error: Output file not created" + + except subprocess.CalledProcessError as e: + return f"Error during merging: {e.stderr}" + except Exception as e: + return f"Error: {str(e)}" + + # Update DiT model dropdown + def update_dit_dropdown(dit_folder: str) -> Dict: + models = get_dit_models(dit_folder) + return gr.update(choices=models, value=models[0] if models else None) + + # Connect events + merge_btn.click( + fn=merge_models, + inputs=[ + dit_folder, + dit_model, + output_model, + exclude_single_blocks, + merge_lora_folder, + *merge_lora_weights, + *merge_lora_multipliers + ], + outputs=merge_status + ) + + # Refresh buttons for both DiT and LoRA dropdowns + merge_refresh_btn.click( + fn=lambda f: update_dit_dropdown(f), + inputs=[dit_folder], + outputs=[dit_model] + ) + + # LoRA refresh handling + merge_refresh_outputs = [] + for i in range(4): + merge_refresh_outputs.extend([merge_lora_weights[i], merge_lora_multipliers[i]]) + + merge_refresh_btn.click( + fn=update_lora_dropdowns, + inputs=[merge_lora_folder] + merge_lora_weights + merge_lora_multipliers, + outputs=merge_refresh_outputs + ) + # Event handlers + prompt.change(fn=count_prompt_tokens, inputs=prompt, outputs=token_counter) + v2v_prompt.change(fn=count_prompt_tokens, inputs=v2v_prompt, outputs=v2v_token_counter) + stop_btn.click(fn=lambda: stop_event.set(), queue=False) + v2v_stop_btn.click(fn=lambda: stop_event.set(), queue=False) + + #Image_to_Video + def image_to_video(image_path, output_path, width, height, frames=240): # Add width, height parameters + img = Image.open(image_path) + + # Resize to the specified dimensions + img_resized = img.resize((width, height), Image.LANCZOS) + temp_image_path = os.path.join(os.path.dirname(output_path), "temp_resized_image.png") + img_resized.save(temp_image_path) + + # Rest of function remains the same + frame_rate = 24 + duration = frames / frame_rate + command = [ + "ffmpeg", "-loop", "1", "-i", temp_image_path, "-c:v", "libx264", + "-t", str(duration), "-pix_fmt", "yuv420p", + "-vf", f"fps={frame_rate}", output_path + ] + + try: + subprocess.run(command, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + print(f"Video saved to {output_path}") + return True + except subprocess.CalledProcessError as e: + print(f"An error occurred while creating the video: {e}") + return False + finally: + # Clean up the temporary image file + if os.path.exists(temp_image_path): + os.remove(temp_image_path) + img.close() # Make sure to close the image file explicitly + + def generate_from_image( + image_path, + prompt, width, height, video_length, fps, infer_steps, + seed, model, vae, te1, te2, save_path, flow_shift, cfg_scale, + output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, + lora_folder, strength, batch_size, *lora_params + ): + """Generate video from input image with progressive 
updates""" + global stop_event + stop_event.clear() + + # Create temporary video path + temp_video_path = os.path.join(save_path, f"temp_{os.path.basename(image_path)}.mp4") + + try: + # Convert image to video + if not image_to_video(image_path, temp_video_path, width, height, frames=video_length): + yield [], "Failed to create temporary video", "Error in video creation" + return + + # Ensure video is fully written before proceeding + time.sleep(1) + if not os.path.exists(temp_video_path) or os.path.getsize(temp_video_path) == 0: + yield [], "Failed to create temporary video", "Temporary video file is empty or missing" + return + + # Get video dimensions + try: + probe = ffmpeg.probe(temp_video_path) + video_stream = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None) + if video_stream is None: + raise ValueError("No video stream found") + width = int(video_stream['width']) + height = int(video_stream['height']) + except Exception as e: + yield [], f"Error reading video dimensions: {str(e)}", "Video processing error" + return + + # Generate the video using the temporary file + try: + generator = process_single_video( + prompt, width, height, batch_size, video_length, fps, infer_steps, + seed, model, vae, te1, te2, save_path, flow_shift, cfg_scale, + output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, + lora_folder, *lora_params, video_path=temp_video_path, strength=strength + ) + + # Forward all generator updates + for videos, batch_text, progress_text in generator: + yield videos, batch_text, progress_text + + except Exception as e: + yield [], f"Error in video generation: {str(e)}", "Generation error" + return + + except Exception as e: + yield [], f"Unexpected error: {str(e)}", "Error occurred" + return + + finally: + # Clean up temporary file + try: + if os.path.exists(temp_video_path): + os.remove(temp_video_path) + except Exception: + pass # Ignore cleanup errors + + + # Add event handlers + i2v_prompt.change(fn=count_prompt_tokens, inputs=i2v_prompt, outputs=i2v_token_counter) + i2v_stop_btn.click(fn=lambda: stop_event.set(), queue=False) + + def handle_i2v_gallery_select(evt: gr.SelectData) -> int: + """Track selected index when I2V gallery item is clicked""" + return evt.index + + def send_i2v_to_v2v( + gallery: list, + prompt: str, + selected_index: int, + width: int, + height: int, + video_length: int, + fps: int, + infer_steps: int, + seed: int, + flow_shift: float, + cfg_scale: float, + lora1: str, + lora2: str, + lora3: str, + lora4: str, + lora1_multiplier: float, + lora2_multiplier: float, + lora3_multiplier: float, + lora4_multiplier: float + ) -> Tuple[Optional[str], str, int, int, int, int, int, int, float, float, str, str, str, str, float, float, float, float]: + """Send the selected video and parameters from Image2Video tab to Video2Video tab""" + if not gallery or selected_index is None or selected_index >= len(gallery): + return None, "", width, height, video_length, fps, infer_steps, seed, flow_shift, cfg_scale, \ + lora1, lora2, lora3, lora4, lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier + + selected_item = gallery[selected_index] + + # Handle different gallery item formats + if isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + elif isinstance(selected_item, (tuple, list)): + video_path = selected_item[0] + else: + video_path = selected_item + + # Final cleanup for Gradio Video component + if isinstance(video_path, tuple): + 
video_path = video_path[0] + + # Use the original width and height without doubling + return (str(video_path), prompt, width, height, video_length, fps, infer_steps, seed, + flow_shift, cfg_scale, lora1, lora2, lora3, lora4, + lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier) + + # Generate button handler + i2v_generate_btn.click( + fn=process_batch, + inputs=[ + i2v_prompt, width, height, + i2v_batch_size, i2v_video_length, + i2v_fps, i2v_infer_steps, i2v_seed, i2v_dit_folder, i2v_model, i2v_vae, i2v_te1, i2v_te2, + i2v_save_path, i2v_flow_shift, i2v_cfg_scale, i2v_output_type, i2v_attn_mode, + i2v_block_swap, i2v_exclude_single_blocks, i2v_use_split_attn, i2v_lora_folder, + *i2v_lora_weights, *i2v_lora_multipliers, i2v_input, i2v_strength, i2v_use_fp8 + ], + outputs=[i2v_output, i2v_batch_progress, i2v_progress_text], + queue=True + ).then( + fn=lambda batch_size: 0 if batch_size == 1 else None, + inputs=[i2v_batch_size], + outputs=i2v_selected_index + ) + # Send to Video2Video + i2v_output.select( + fn=handle_i2v_gallery_select, + outputs=i2v_selected_index + ) + + i2v_send_to_v2v_btn.click( + fn=send_i2v_to_v2v, + inputs=[ + i2v_output, i2v_prompt, i2v_selected_index, + width, height, + i2v_video_length, i2v_fps, i2v_infer_steps, + i2v_seed, i2v_flow_shift, i2v_cfg_scale + ] + i2v_lora_weights + i2v_lora_multipliers, + outputs=[ + v2v_input, v2v_prompt, + v2v_width, v2v_height, + v2v_video_length, v2v_fps, v2v_infer_steps, + v2v_seed, v2v_flow_shift, v2v_cfg_scale + ] + v2v_lora_weights + v2v_lora_multipliers + ).then( + fn=change_to_tab_two, inputs=None, outputs=[tabs] + ) + #Video Info + def clean_video_path(video_path) -> str: + """Extract clean video path from Gradio's various return formats""" + print(f"Input video_path: {video_path}, type: {type(video_path)}") + if isinstance(video_path, dict): + path = video_path.get("name", "") + elif isinstance(video_path, (tuple, list)): + path = video_path[0] + elif isinstance(video_path, str): + path = video_path + else: + path = "" + print(f"Cleaned path: {path}") + return path + def handle_video_upload(video_path: str) -> Dict: + """Handle video upload and metadata extraction""" + if not video_path: + return {}, "No video uploaded" + + metadata = extract_video_metadata(video_path) + if not metadata: + return {}, "No metadata found in video" + + return metadata, "Metadata extracted successfully" + + def get_video_info(video_path: str) -> dict: + try: + probe = ffmpeg.probe(video_path) + video_info = next(stream for stream in probe['streams'] if stream['codec_type'] == 'video') + + width = int(video_info['width']) + height = int(video_info['height']) + fps = eval(video_info['r_frame_rate']) # This converts '30/1' to 30.0 + + # Calculate total frames + duration = float(probe['format']['duration']) + total_frames = int(duration * fps) + + # Ensure video length does not exceed 201 frames + if total_frames > 201: + total_frames = 201 + duration = total_frames / fps # Adjust duration accordingly + + return { + 'width': width, + 'height': height, + 'fps': fps, + 'total_frames': total_frames, + 'duration': duration # Might be useful in some contexts + } + except Exception as e: + print(f"Error extracting video info: {e}") + return {} + + def extract_video_details(video_path: str) -> Tuple[dict, str]: + metadata = extract_video_metadata(video_path) + video_details = get_video_info(video_path) + + # Combine metadata with video details + for key, value in video_details.items(): + if key not in metadata: + metadata[key] = value + 
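+        # Note: values probed from the file only fill keys missing from the embedded
+        # metadata, so parameters stored at generation time take precedence over ffprobe.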
+ # Ensure video length does not exceed 201 frames + if 'video_length' in metadata: + metadata['video_length'] = min(metadata['video_length'], 201) + else: + metadata['video_length'] = min(video_details.get('total_frames', 0), 201) + + # Return both the updated metadata and a status message + return metadata, "Video details extracted successfully" + + def send_parameters_to_tab(metadata: Dict, target_tab: str) -> Tuple[str, Dict]: + """Create parameter mapping for target tab""" + if not metadata: + return "No parameters to send", {} + + tab_name = "Text2Video" if target_tab == "t2v" else "Video2Video" + try: + mapping = create_parameter_transfer_map(metadata, target_tab) + return f"Parameters ready for {tab_name}", mapping + except Exception as e: + return f"Error: {str(e)}", {} + + video_input.upload( + fn=extract_video_details, + inputs=video_input, + outputs=[metadata_output, status] + ) + + send_to_t2v_btn.click( + fn=lambda m: send_parameters_to_tab(m, "t2v"), + inputs=metadata_output, + outputs=[status, params_state] + ).then( + fn=change_to_tab_one, inputs=None, outputs=[tabs] + ).then( + lambda params: [ + params.get("prompt", ""), + params.get("width", 544), + params.get("height", 544), + params.get("batch_size", 1), + params.get("video_length", 25), + params.get("fps", 24), + params.get("infer_steps", 30), + params.get("seed", -1), + params.get("model", "hunyuan/mp_rank_00_model_states.pt"), + params.get("vae", "hunyuan/pytorch_model.pt"), + params.get("te1", "hunyuan/llava_llama3_fp16.safetensors"), + params.get("te2", "hunyuan/clip_l.safetensors"), + params.get("save_path", "outputs"), + params.get("flow_shift", 11.0), + params.get("cfg_scale", 7.0), + params.get("output_type", "video"), + params.get("attn_mode", "sdpa"), + params.get("block_swap", "0"), + *[params.get(f"lora{i+1}", "") for i in range(4)], + *[params.get(f"lora{i+1}_multiplier", 1.0) for i in range(4)] + ] if params else [gr.update()]*26, + inputs=params_state, + outputs=[prompt, width, height, batch_size, video_length, fps, infer_steps, seed, + model, vae, te1, te2, save_path, flow_shift, cfg_scale, + output_type, attn_mode, block_swap] + lora_weights + lora_multipliers + ) + # Text to Video generation + generate_btn.click( + fn=process_batch, + inputs=[ + prompt, t2v_width, t2v_height, batch_size, video_length, fps, infer_steps, + seed, dit_folder, model, vae, te1, te2, save_path, flow_shift, cfg_scale, + output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, + lora_folder, *lora_weights, *lora_multipliers, gr.Textbox(visible=False), gr.Number(visible=False), use_fp8 + ], + outputs=[video_output, batch_progress, progress_text], + queue=True + ).then( + fn=lambda batch_size: 0 if batch_size == 1 else None, + inputs=[batch_size], + outputs=selected_index + ) + + # Update gallery selection handling + def handle_gallery_select(evt: gr.SelectData) -> int: + return evt.index + + # Track selected index when gallery item is clicked + video_output.select( + fn=handle_gallery_select, + outputs=selected_index + ) + + # Track selected index when Video2Video gallery item is clicked + def handle_v2v_gallery_select(evt: gr.SelectData) -> int: + """Handle gallery selection without automatically updating the input""" + return evt.index + + # Update the gallery selection event + v2v_output.select( + fn=handle_v2v_gallery_select, + outputs=v2v_selected_index + ) + + # Send button handler with gallery selection + def handle_send_button( + gallery: list, + prompt: str, + idx: int, + width: int, + height: int, 
+ batch_size: int, + video_length: int, + fps: int, + infer_steps: int, + seed: int, + flow_shift: float, + cfg_scale: float, + lora1: str, + lora2: str, + lora3: str, + lora4: str, + lora1_multiplier: float, + lora2_multiplier: float, + lora3_multiplier: float, + lora4_multiplier: float + ) -> tuple: + if not gallery or idx is None or idx >= len(gallery): + return (None, "", width, height, batch_size, video_length, fps, infer_steps, + seed, flow_shift, cfg_scale, + lora1, lora2, lora3, lora4, + lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier, + "") # Add empty string for negative_prompt in the return values + + # Auto-select first item if only one exists and no selection made + if idx is None and len(gallery) == 1: + idx = 0 + + selected_item = gallery[idx] + + # Handle different gallery item formats + if isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + elif isinstance(selected_item, (tuple, list)): + video_path = selected_item[0] + else: + video_path = selected_item + + # Final cleanup for Gradio Video component + if isinstance(video_path, tuple): + video_path = video_path[0] + + return ( + str(video_path), + prompt, + width, + height, + batch_size, + video_length, + fps, + infer_steps, + seed, + flow_shift, + cfg_scale, + lora1, + lora2, + lora3, + lora4, + lora1_multiplier, + lora2_multiplier, + lora3_multiplier, + lora4_multiplier, + "" # Add empty string for negative_prompt + ) + + send_t2v_to_v2v_btn.click( + fn=handle_send_button, + inputs=[ + video_output, prompt, selected_index, + t2v_width, t2v_height, batch_size, video_length, + fps, infer_steps, seed, flow_shift, cfg_scale + ] + lora_weights + lora_multipliers, # Remove the string here + outputs=[ + v2v_input, + v2v_prompt, + v2v_width, + v2v_height, + v2v_batch_size, + v2v_video_length, + v2v_fps, + v2v_infer_steps, + v2v_seed, + v2v_flow_shift, + v2v_cfg_scale + ] + v2v_lora_weights + v2v_lora_multipliers + [v2v_negative_prompt] + ).then( + fn=change_to_tab_two, inputs=None, outputs=[tabs] + ) + + def handle_send_to_v2v(metadata: dict, video_path: str) -> Tuple[str, dict, str]: + """Handle both parameters and video transfer""" + status_msg, params = send_parameters_to_tab(metadata, "v2v") + return status_msg, params, video_path + + def handle_info_to_v2v(metadata: dict, video_path: str) -> Tuple[str, Dict, str]: + """Handle both parameters and video transfer from Video Info to V2V tab""" + if not video_path: + return "No video selected", {}, None + + status_msg, params = send_parameters_to_tab(metadata, "v2v") + # Just return the path directly + return status_msg, params, video_path + + # Send button click handler + send_to_v2v_btn.click( + fn=handle_info_to_v2v, + inputs=[metadata_output, video_input], + outputs=[status, params_state, v2v_input] + ).then( + lambda params: [ + params.get("v2v_prompt", ""), + params.get("v2v_width", 544), + params.get("v2v_height", 544), + params.get("v2v_batch_size", 1), + params.get("v2v_video_length", 25), + params.get("v2v_fps", 24), + params.get("v2v_infer_steps", 30), + params.get("v2v_seed", -1), + params.get("v2v_model", "hunyuan/mp_rank_00_model_states.pt"), + params.get("v2v_vae", "hunyuan/pytorch_model.pt"), + params.get("v2v_te1", "hunyuan/llava_llama3_fp16.safetensors"), + params.get("v2v_te2", "hunyuan/clip_l.safetensors"), + params.get("v2v_save_path", "outputs"), + params.get("v2v_flow_shift", 11.0), + params.get("v2v_cfg_scale", 7.0), + params.get("v2v_output_type", "video"), + 
params.get("v2v_attn_mode", "sdpa"), + params.get("v2v_block_swap", "0"), + *[params.get(f"v2v_lora_weights[{i}]", "") for i in range(4)], + *[params.get(f"v2v_lora_multipliers[{i}]", 1.0) for i in range(4)] + ] if params else [gr.update()] * 26, + inputs=params_state, + outputs=[ + v2v_prompt, v2v_width, v2v_height, v2v_batch_size, v2v_video_length, + v2v_fps, v2v_infer_steps, v2v_seed, v2v_model, v2v_vae, v2v_te1, + v2v_te2, v2v_save_path, v2v_flow_shift, v2v_cfg_scale, v2v_output_type, + v2v_attn_mode, v2v_block_swap + ] + v2v_lora_weights + v2v_lora_multipliers + ).then( + lambda: print(f"Tabs object: {tabs}"), # Debug print + outputs=None + ).then( + fn=change_to_tab_two, inputs=None, outputs=[tabs] + ) + + # Handler for sending selected video from Video2Video gallery to input + def handle_v2v_send_button(gallery: list, prompt: str, idx: int) -> Tuple[Optional[str], str]: + """Send the currently selected video in V2V gallery to V2V input""" + if not gallery or idx is None or idx >= len(gallery): + return None, "" + + selected_item = gallery[idx] + video_path = None + + # Handle different gallery item formats + if isinstance(selected_item, tuple): + video_path = selected_item[0] # Gallery returns (path, caption) + elif isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + elif isinstance(selected_item, str): + video_path = selected_item + + if not video_path: + return None, "" + + # Check if the file exists and is accessible + if not os.path.exists(video_path): + print(f"Warning: Video file not found at {video_path}") + return None, "" + + return video_path, prompt + + v2v_send_to_input_btn.click( + fn=handle_v2v_send_button, + inputs=[v2v_output, v2v_prompt, v2v_selected_index], + outputs=[v2v_input, v2v_prompt] + ).then( + lambda: gr.update(visible=True), # Ensure the video input is visible + outputs=v2v_input + ) + + # Video to Video generation + v2v_generate_btn.click( + fn=process_batch, + inputs=[ + v2v_prompt, v2v_width, v2v_height, v2v_batch_size, v2v_video_length, + v2v_fps, v2v_infer_steps, v2v_seed, v2v_dit_folder, v2v_model, v2v_vae, v2v_te1, v2v_te2, + v2v_save_path, v2v_flow_shift, v2v_cfg_scale, v2v_output_type, v2v_attn_mode, + v2v_block_swap, v2v_exclude_single_blocks, v2v_use_split_attn, v2v_lora_folder, + *v2v_lora_weights, *v2v_lora_multipliers, v2v_input, v2v_strength, + v2v_negative_prompt, v2v_cfg_scale, v2v_split_uncond, v2v_use_fp8 + ], + outputs=[v2v_output, v2v_batch_progress, v2v_progress_text], + queue=True + ).then( + fn=lambda batch_size: 0 if batch_size == 1 else None, + inputs=[v2v_batch_size], + outputs=v2v_selected_index + ) + refresh_outputs = [model] # Add model dropdown to outputs + for i in range(4): + refresh_outputs.extend([lora_weights[i], lora_multipliers[i]]) + + refresh_btn.click( + fn=update_dit_and_lora_dropdowns, + inputs=[dit_folder, lora_folder, model] + lora_weights + lora_multipliers, + outputs=refresh_outputs + ) + # Image2Video refresh + i2v_refresh_outputs = [i2v_model] # Add model dropdown to outputs + for i in range(4): + i2v_refresh_outputs.extend([i2v_lora_weights[i], i2v_lora_multipliers[i]]) + + i2v_refresh_btn.click( + fn=update_dit_and_lora_dropdowns, + inputs=[i2v_dit_folder, i2v_lora_folder, i2v_model] + i2v_lora_weights + i2v_lora_multipliers, + outputs=i2v_refresh_outputs + ) + + # Video2Video refresh + v2v_refresh_outputs = [v2v_model] # Add model dropdown to outputs + for i in range(4): + v2v_refresh_outputs.extend([v2v_lora_weights[i], v2v_lora_multipliers[i]]) + 
+ v2v_refresh_btn.click( + fn=update_dit_and_lora_dropdowns, + inputs=[v2v_dit_folder, v2v_lora_folder, v2v_model] + v2v_lora_weights + v2v_lora_multipliers, + outputs=v2v_refresh_outputs + ) + + # WanX-i2v tab connections + wanx_prompt.change(fn=count_prompt_tokens, inputs=wanx_prompt, outputs=wanx_token_counter) + wanx_stop_btn.click(fn=lambda: stop_event.set(), queue=False) + + # Image input handling for WanX-i2v + wanx_input.change( + fn=update_wanx_image_dimensions, + inputs=[wanx_input], + outputs=[wanx_original_dims, wanx_width, wanx_height] + ) + + # Scale slider handling for WanX-i2v + wanx_scale_slider.change( + fn=update_wanx_from_scale, + inputs=[wanx_scale_slider, wanx_original_dims], + outputs=[wanx_width, wanx_height] + ) + + # Width/height calculation buttons for WanX-i2v + wanx_calc_width_btn.click( + fn=calculate_wanx_width, + inputs=[wanx_height, wanx_original_dims], + outputs=[wanx_width] + ) + + wanx_calc_height_btn.click( + fn=calculate_wanx_height, + inputs=[wanx_width, wanx_original_dims], + outputs=[wanx_height] + ) + # Add visibility toggle for the folder input components + wanx_use_random_folder.change( + fn=lambda x: (gr.update(visible=x), gr.update(visible=x), gr.update(visible=x), gr.update(visible=not x)), + inputs=[wanx_use_random_folder], + outputs=[wanx_input_folder, wanx_folder_status, wanx_validate_folder_btn, wanx_input] + ) + + # Validate folder button handler + wanx_validate_folder_btn.click( + fn=lambda folder: get_random_image_from_folder(folder)[1], + inputs=[wanx_input_folder], + outputs=[wanx_folder_status] + ) + + # Flow shift recommendation buttons + wanx_recommend_flow_btn.click( + fn=recommend_wanx_flow_shift, + inputs=[wanx_width, wanx_height], + outputs=[wanx_flow_shift] + ) + + wanx_t2v_recommend_flow_btn.click( + fn=recommend_wanx_flow_shift, + inputs=[wanx_t2v_width, wanx_t2v_height], + outputs=[wanx_t2v_flow_shift] + ) + + # Generate button handler + wanx_generate_btn.click( + fn=wanx_batch_handler, + inputs=[ + wanx_use_random_folder, + wanx_prompt, + wanx_negative_prompt, + wanx_width, + wanx_height, + wanx_video_length, + wanx_fps, + wanx_infer_steps, + wanx_flow_shift, + wanx_guidance_scale, + wanx_seed, + wanx_batch_size, + wanx_input_folder, + wanx_task, + wanx_dit_path, + wanx_vae_path, + wanx_t5_path, + wanx_clip_path, + wanx_save_path, + wanx_output_type, + wanx_sample_solver, + wanx_exclude_single_blocks, + wanx_attn_mode, + wanx_block_swap, + wanx_fp8, + wanx_fp8_t5, + wanx_lora_folder, + *wanx_lora_weights, + *wanx_lora_multipliers, + wanx_input # Include input image path for non-batch mode + ], + outputs=[wanx_output, wanx_batch_progress, wanx_progress_text], + queue=True + ).then( + fn=lambda batch_size: 0 if batch_size == 1 else None, + inputs=[wanx_batch_size], + outputs=wanx_i2v_selected_index # Update to use correct state + ) + + # Add refresh button handler for WanX-i2v tab + wanx_refresh_outputs = [] + for i in range(4): + wanx_refresh_outputs.extend([wanx_lora_weights[i], wanx_lora_multipliers[i]]) + + wanx_refresh_btn.click( + fn=update_lora_dropdowns, + inputs=[wanx_lora_folder] + wanx_lora_weights + wanx_lora_multipliers, + outputs=wanx_refresh_outputs + ) + + # Gallery selection handling + wanx_output.select( + fn=handle_wanx_gallery_select, + inputs=[wanx_output], + outputs=[wanx_i2v_selected_index, wanx_base_video] + ) + + # Send to Video2Video handler + wanx_send_to_v2v_btn.click( + fn=send_wanx_to_v2v, + inputs=[ + wanx_output, # Gallery with videos + wanx_prompt, # Prompt text + wanx_i2v_selected_index, # 
Use the correct selected index state + wanx_width, + wanx_height, + wanx_video_length, + wanx_fps, + wanx_infer_steps, + wanx_seed, + wanx_flow_shift, + wanx_guidance_scale, + wanx_negative_prompt + ], + outputs=[ + v2v_input, # Video input in V2V tab + v2v_prompt, # Prompt in V2V tab + v2v_width, + v2v_height, + v2v_video_length, + v2v_fps, + v2v_infer_steps, + v2v_seed, + v2v_flow_shift, + v2v_cfg_scale, + v2v_negative_prompt + ] + ).then( + fn=change_to_tab_two, # Function to switch to Video2Video tab + inputs=None, + outputs=[tabs] + ) + + # Add state for T2V tab selected index + wanx_t2v_selected_index = gr.State(value=None) + + # Connect prompt token counter + wanx_t2v_prompt.change(fn=count_prompt_tokens, inputs=wanx_t2v_prompt, outputs=wanx_t2v_token_counter) + + # Stop button handler + wanx_t2v_stop_btn.click(fn=lambda: stop_event.set(), queue=False) + + # Flow shift recommendation button + wanx_t2v_recommend_flow_btn.click( + fn=recommend_wanx_flow_shift, + inputs=[wanx_t2v_width, wanx_t2v_height], + outputs=[wanx_t2v_flow_shift] + ) + + # Task change handler to update CLIP visibility and path + def update_clip_visibility(task): + is_i2v = "i2v" in task + return gr.update(visible=is_i2v) + + wanx_t2v_task.change( + fn=update_clip_visibility, + inputs=[wanx_t2v_task], + outputs=[wanx_t2v_clip_path] + ) + + # Generate button handler for T2V + wanx_t2v_generate_btn.click( + fn=wanx_generate_video_batch, + inputs=[ + wanx_t2v_prompt, + wanx_t2v_negative_prompt, + wanx_t2v_width, + wanx_t2v_height, + wanx_t2v_video_length, + wanx_t2v_fps, + wanx_t2v_infer_steps, + wanx_t2v_flow_shift, + wanx_t2v_guidance_scale, + wanx_t2v_seed, + wanx_t2v_task, + wanx_t2v_dit_path, + wanx_t2v_vae_path, + wanx_t2v_t5_path, + wanx_t2v_clip_path, + wanx_t2v_save_path, + wanx_t2v_output_type, + wanx_t2v_sample_solver, + wanx_t2v_exclude_single_blocks, + wanx_t2v_attn_mode, + wanx_t2v_block_swap, + wanx_t2v_fp8, + wanx_t2v_fp8_t5, + wanx_t2v_lora_folder, + *wanx_t2v_lora_weights, + *wanx_t2v_lora_multipliers, + wanx_t2v_batch_size, + # input_image is now optional and not included here + ], + outputs=[wanx_t2v_output, wanx_t2v_batch_progress, wanx_t2v_progress_text], + queue=True + ).then( + fn=lambda batch_size: 0 if batch_size == 1 else None, + inputs=[wanx_t2v_batch_size], + outputs=wanx_t2v_selected_index + ) + + # Add refresh button handler for WanX-t2v tab + wanx_t2v_refresh_outputs = [] + for i in range(4): + wanx_t2v_refresh_outputs.extend([wanx_t2v_lora_weights[i], wanx_t2v_lora_multipliers[i]]) + + wanx_t2v_refresh_btn.click( + fn=update_lora_dropdowns, + inputs=[wanx_t2v_lora_folder] + wanx_t2v_lora_weights + wanx_t2v_lora_multipliers, + outputs=wanx_t2v_refresh_outputs + ) + + # Gallery selection handling + wanx_t2v_output.select( + fn=handle_wanx_t2v_gallery_select, + outputs=wanx_t2v_selected_index + ) + + # Send to Video2Video handler + wanx_t2v_send_to_v2v_btn.click( + fn=send_wanx_t2v_to_v2v, + inputs=[ + wanx_t2v_output, + wanx_t2v_prompt, + wanx_t2v_selected_index, + wanx_t2v_width, + wanx_t2v_height, + wanx_t2v_video_length, + wanx_t2v_fps, + wanx_t2v_infer_steps, + wanx_t2v_seed, + wanx_t2v_flow_shift, + wanx_t2v_guidance_scale, + wanx_t2v_negative_prompt + ], + outputs=[ + v2v_input, + v2v_prompt, + v2v_width, + v2v_height, + v2v_video_length, + v2v_fps, + v2v_infer_steps, + v2v_seed, + v2v_flow_shift, + v2v_cfg_scale, + v2v_negative_prompt + ] + ).then( + fn=change_to_tab_two, + inputs=None, + outputs=[tabs] + ) + +demo.queue().launch(server_name="0.0.0.0", share=False) \ No newline 
at end of file diff --git a/i2vhunyuan.py b/i2vhunyuan.py new file mode 100644 index 0000000000000000000000000000000000000000..4e4a158fb78f107ebbe25338ed880623014d72f6 --- /dev/null +++ b/i2vhunyuan.py @@ -0,0 +1,4970 @@ +import gradio as gr +from gradio import update as gr_update +import subprocess +import threading +import time +import re +import os +import random +import tiktoken +import sys +import ffmpeg +from typing import List, Tuple, Optional, Generator, Dict +import json +from gradio import themes +from gradio.themes.utils import colors +import subprocess +from PIL import Image +import math +import cv2 +import glob +import shutil +from pathlib import Path +import logging +from datetime import datetime +from tqdm import tqdm + + +# Add global stop event +stop_event = threading.Event() + +logger = logging.getLogger(__name__) + +def process_hunyuani2v_video( + prompt: str, + width: int, + height: int, + batch_size: int, + video_length: int, + fps: int, + infer_steps: int, + seed: int, + dit_folder: str, + model: str, + vae: str, + te1: str, + te2: str, + save_path: str, + flow_shift: float, + cfg_scale: float, + output_type: str, + attn_mode: str, + block_swap: int, + exclude_single_blocks: bool, + use_split_attn: bool, + lora_folder: str, + lora1: str = "", + lora2: str = "", + lora3: str = "", + lora4: str = "", + lora1_multiplier: float = 1.0, + lora2_multiplier: float = 1.0, + lora3_multiplier: float = 1.0, + lora4_multiplier: float = 1.0, + video_path: Optional[str] = None, + image_path: Optional[str] = None, + strength: Optional[float] = None, + negative_prompt: Optional[str] = None, + embedded_cfg_scale: Optional[float] = None, + split_uncond: Optional[bool] = None, + guidance_scale: Optional[float] = None, + use_fp8: bool = True, + clip_vision_path: Optional[str] = None, + i2v_stability: bool = False, + fp8_fast: bool = False, + compile_model: bool = False, + compile_backend: str = "inductor", + compile_mode: str = "max-autotune-no-cudagraphs", + compile_dynamic: bool = False, + compile_fullgraph: bool = False +) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: + """Generate a single video with the hunyuani2v script with updated parameters""" + global stop_event + + if stop_event.is_set(): + yield [], "", "" + return + + # Determine if this is a SkyReels model and what type + is_skyreels = "skyreels" in model.lower() + is_skyreels_i2v = is_skyreels and "i2v" in model.lower() + is_skyreels_t2v = is_skyreels and "t2v" in model.lower() + + # Set defaults for hunyuani2v specific parameters + if is_skyreels: + # Force certain parameters for SkyReels + if negative_prompt is None: + negative_prompt = "" + if embedded_cfg_scale is None: + embedded_cfg_scale = 1.0 # Force to 1.0 for SkyReels + if split_uncond is None: + split_uncond = True + if guidance_scale is None: + guidance_scale = cfg_scale # Use cfg_scale as guidance_scale if not provided + + else: + embedded_cfg_scale = cfg_scale + + if os.path.isabs(model): + model_path = model + else: + model_path = os.path.normpath(os.path.join(dit_folder, model)) + + env = os.environ.copy() + env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") + env["PYTHONIOENCODING"] = "utf-8" + env["BATCH_RUN_ID"] = f"{time.time()}" + + if seed == -1: + current_seed = random.randint(0, 2**32 - 1) + else: + batch_id = int(env.get("BATCH_RUN_ID", "0").split('.')[-1]) + if batch_size > 1: # Only modify seed for batch generation + current_seed = (seed + batch_id * 100003) % (2**32) + else: + current_seed = seed 
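+        # Seed note: BATCH_RUN_ID carries the launch timestamp; when batch_size > 1 its
+        # fractional digits (batch_id) offset the fixed seed so each batch item gets a
+        # distinct derived seed, while a single run keeps the seed exactly as given.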
+ + clear_cuda_cache() + + # Now use hv_generate_video_with_hunyuani2v.py instead + command = [ + sys.executable, + "hv_generate_video_with_hunyuani2v.py", + "--dit", model_path, + "--vae", vae, + "--text_encoder1", te1, + "--text_encoder2", te2, + "--prompt", prompt, + "--video_size", str(height), str(width), + "--video_length", str(video_length), + "--fps", str(fps), + "--infer_steps", str(infer_steps), + "--save_path", save_path, + "--seed", str(current_seed), + "--flow_shift", str(flow_shift), + "--embedded_cfg_scale", str(cfg_scale), + "--output_type", output_type, + "--attn_mode", attn_mode, + "--blocks_to_swap", str(block_swap), + "--fp8_llm", + "--vae_chunk_size", "32", + "--vae_spatial_tile_sample_min_size", "128" + ] + + if use_fp8: + command.append("--fp8") + + # Add new parameters specific to hunyuani2v script + if clip_vision_path: + command.extend(["--clip_vision_path", clip_vision_path]) + + if i2v_stability: + command.append("--i2v_stability") + + if fp8_fast: + command.append("--fp8_fast") + + if compile_model: + command.append("--compile") + command.extend([ + "--compile_args", + compile_backend, + compile_mode, + str(compile_dynamic).lower(), + str(compile_fullgraph).lower() + ]) + + # Add negative prompt and embedded cfg scale + command.extend(["--guidance_scale", str(guidance_scale)]) + + if negative_prompt: + command.extend(["--negative_prompt", negative_prompt]) + if split_uncond: + command.append("--split_uncond") + + # Add LoRA weights and multipliers if provided + valid_loras = [] + for weight, mult in zip([lora1, lora2, lora3, lora4], + [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier]): + if weight and weight != "None": + valid_loras.append((os.path.join(lora_folder, weight), mult)) + if valid_loras: + weights = [weight for weight, _ in valid_loras] + multipliers = [str(mult) for _, mult in valid_loras] + command.extend(["--lora_weight"] + weights) + command.extend(["--lora_multiplier"] + multipliers) + + if exclude_single_blocks: + command.append("--exclude_single_blocks") + if use_split_attn: + command.append("--split_attn") + + # Handle input paths + if video_path: + command.extend(["--video_path", video_path]) + if strength is not None: + command.extend(["--strength", str(strength)]) + elif image_path: + command.extend(["--image_path", image_path]) + # Only add strength parameter for non-SkyReels I2V models + # SkyReels I2V doesn't use strength parameter for image-to-video generation + if strength is not None and not is_skyreels_i2v: + command.extend(["--strength", str(strength)]) + + print(f"{command}") + + p = subprocess.Popen( + command, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=env, + text=True, + encoding='utf-8', + errors='replace', + bufsize=1 + ) + + videos = [] + + while True: + if stop_event.is_set(): + p.terminate() + p.wait() + yield [], "", "Generation stopped by user." 
+ return + + line = p.stdout.readline() + if not line: + if p.poll() is not None: + break + continue + + print(line, end='') + if '|' in line and '%' in line and '[' in line and ']' in line: + yield videos.copy(), f"Processing (seed: {current_seed})", line.strip() + + p.stdout.close() + p.wait() + + clear_cuda_cache() + time.sleep(0.5) + + # Collect generated video + save_path_abs = os.path.abspath(save_path) + if os.path.exists(save_path_abs): + all_videos = sorted( + [f for f in os.listdir(save_path_abs) if f.endswith('.mp4')], + key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)), + reverse=True + ) + matching_videos = [v for v in all_videos if f"_{current_seed}" in v] + if matching_videos: + video_path = os.path.join(save_path_abs, matching_videos[0]) + + # Collect parameters for metadata + parameters = { + "prompt": prompt, + "width": width, + "height": height, + "video_length": video_length, + "fps": fps, + "infer_steps": infer_steps, + "seed": current_seed, + "model": model, + "vae": vae, + "te1": te1, + "te2": te2, + "save_path": save_path, + "flow_shift": flow_shift, + "cfg_scale": cfg_scale, + "output_type": output_type, + "attn_mode": attn_mode, + "block_swap": block_swap, + "lora_weights": [lora1, lora2, lora3, lora4], + "lora_multipliers": [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier], + "input_video": video_path if video_path else None, + "input_image": image_path if image_path else None, + "strength": strength, + "negative_prompt": negative_prompt, + "embedded_cfg_scale": embedded_cfg_scale, + "clip_vision_path": clip_vision_path, + "i2v_stability": i2v_stability, + "fp8_fast": fp8_fast, + "compile_model": compile_model + } + + add_metadata_to_video(video_path, parameters) + videos.append((str(video_path), f"Seed: {current_seed}")) + + yield videos, f"Completed (seed: {current_seed})", "" + +# Now let's create a new batch processing function that uses the hunyuani2v function +def process_hunyuani2v_batch( + prompt: str, + width: int, + height: int, + batch_size: int, + video_length: int, + fps: int, + infer_steps: int, + seed: int, + dit_folder: str, + model: str, + vae: str, + te1: str, + te2: str, + save_path: str, + flow_shift: float, + cfg_scale: float, + output_type: str, + attn_mode: str, + block_swap: int, + exclude_single_blocks: bool, + use_split_attn: bool, + lora_folder: str, + *args +) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: + """Process a batch of videos using the hunyuani2v script""" + global stop_event + stop_event.clear() + + all_videos = [] + progress_text = "Starting generation..." 
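+    # Every yield is a (gallery items, batch status, progress text) triple, so the
+    # Gradio outputs can stream updates while the generation subprocess is running.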
+ yield [], "Preparing...", progress_text + + # Extract additional arguments + num_lora_weights = 4 + lora_weights = args[:num_lora_weights] + lora_multipliers = args[num_lora_weights:num_lora_weights*2] + + # New parameters for hunyuani2v + # Base parameter list index after lora weights and multipliers + base_idx = num_lora_weights*2 + + # Extract parameters + input_path = args[base_idx] if len(args) > base_idx else None + strength = float(args[base_idx+1]) if len(args) > base_idx+1 and args[base_idx+1] is not None else None + negative_prompt = str(args[base_idx+2]) if len(args) > base_idx+2 and args[base_idx+2] is not None else None + guidance_scale = float(args[base_idx+3]) if len(args) > base_idx+3 and args[base_idx+3] is not None else cfg_scale + split_uncond = bool(args[base_idx+4]) if len(args) > base_idx+4 else None + use_fp8 = bool(args[base_idx+5]) if len(args) > base_idx+5 else True + + # New hunyuani2v parameters + clip_vision_path = str(args[base_idx+6]) if len(args) > base_idx+6 and args[base_idx+6] is not None else None + i2v_stability = bool(args[base_idx+7]) if len(args) > base_idx+7 else False + fp8_fast = bool(args[base_idx+8]) if len(args) > base_idx+8 else False + compile_model = bool(args[base_idx+9]) if len(args) > base_idx+9 else False + compile_backend = str(args[base_idx+10]) if len(args) > base_idx+10 and args[base_idx+10] is not None else "inductor" + compile_mode = str(args[base_idx+11]) if len(args) > base_idx+11 and args[base_idx+11] is not None else "max-autotune-no-cudagraphs" + compile_dynamic = bool(args[base_idx+12]) if len(args) > base_idx+12 else False + compile_fullgraph = bool(args[base_idx+13]) if len(args) > base_idx+13 else False + + embedded_cfg_scale = cfg_scale + + for i in range(batch_size): + if stop_event.is_set(): + break + + batch_text = f"Generating video {i + 1} of {batch_size}" + yield all_videos.copy(), batch_text, progress_text + + # Handle different input types + video_path = None + image_path = None + + if input_path: + is_image = False + lower_path = input_path.lower() + image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp') + is_image = any(lower_path.endswith(ext) for ext in image_extensions) + + if is_image: + image_path = input_path + else: + video_path = input_path + + # Prepare arguments for process_hunyuani2v_video + current_seed = seed + i if seed != -1 and batch_size > 1 else seed if seed != -1 else -1 + + hunyuani2v_args = [ + prompt, width, height, batch_size, video_length, fps, infer_steps, + current_seed, dit_folder, model, vae, te1, te2, save_path, flow_shift, cfg_scale, + output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, + lora_folder + ] + hunyuani2v_args.extend(lora_weights) + hunyuani2v_args.extend(lora_multipliers) + hunyuani2v_args.extend([ + video_path, image_path, strength, negative_prompt, embedded_cfg_scale, + split_uncond, guidance_scale, use_fp8, clip_vision_path, i2v_stability, + fp8_fast, compile_model, compile_backend, compile_mode, compile_dynamic, compile_fullgraph + ]) + + for videos, status, progress in process_hunyuani2v_video(*hunyuani2v_args): + if videos: + all_videos.extend(videos) + yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress + + yield all_videos, "Batch complete", "" + +def variance_of_laplacian(image): + """ + Compute the variance of the Laplacian of the image. + Higher variance indicates a sharper image. 
+ """ + return cv2.Laplacian(image, cv2.CV_64F).var() + +def extract_sharpest_frame(video_path, frames_to_check=30): + """ + Extract the sharpest frame from the last N frames of the video. + + Args: + video_path (str): Path to the video file + frames_to_check (int): Number of frames from the end to check + + Returns: + tuple: (temp_image_path, frame_number, sharpness_score) + """ + print(f"\n=== Extracting sharpest frame from the last {frames_to_check} frames ===") + print(f"Input video path: {video_path}") + + if not video_path or not os.path.exists(video_path): + print("❌ Error: Video file does not exist") + return None, None, None + + try: + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + print("❌ Error: Failed to open video file") + return None, None, None + + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + fps = cap.get(cv2.CAP_PROP_FPS) + print(f"Total frames detected: {total_frames}, FPS: {fps:.2f}") + + if total_frames < 1: + print("❌ Error: Video contains 0 frames") + return None, None, None + + # Determine how many frames to check (the last N frames) + if frames_to_check > total_frames: + frames_to_check = total_frames + start_frame = 0 + else: + start_frame = total_frames - frames_to_check + + print(f"Checking frames {start_frame} to {total_frames-1}") + + # Find the sharpest frame + sharpest_frame = None + max_sharpness = -1 + sharpest_frame_number = -1 + + # Set starting position + cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame) + + # Process frames with a progress bar + with tqdm(total=frames_to_check, desc="Finding sharpest frame") as pbar: + frame_idx = start_frame + while frame_idx < total_frames: + ret, frame = cap.read() + if not ret: + break + + # Convert to grayscale and calculate sharpness + gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + sharpness = variance_of_laplacian(gray) + + # Update if this is the sharpest frame so far + if sharpness > max_sharpness: + max_sharpness = sharpness + sharpest_frame = frame.copy() + sharpest_frame_number = frame_idx + + frame_idx += 1 + pbar.update(1) + + cap.release() + + if sharpest_frame is None: + print("❌ Error: Failed to find a sharp frame") + return None, None, None + + # Prepare output path + temp_dir = os.path.abspath("temp_frames") + os.makedirs(temp_dir, exist_ok=True) + temp_path = os.path.join(temp_dir, f"sharpest_frame_{os.path.basename(video_path)}.png") + print(f"Saving frame to: {temp_path}") + + # Write and verify + if not cv2.imwrite(temp_path, sharpest_frame): + print("❌ Error: Failed to write frame to file") + return None, None, None + + if not os.path.exists(temp_path): + print("❌ Error: Output file not created") + return None, None, None + + # Calculate frame time in seconds + frame_time = sharpest_frame_number / fps + + print(f"✅ Extracted sharpest frame: {sharpest_frame_number} (at {frame_time:.2f}s) with sharpness {max_sharpness:.2f}") + return temp_path, sharpest_frame_number, max_sharpness + + except Exception as e: + print(f"❌ Unexpected error: {str(e)}") + return None, None, None + finally: + if 'cap' in locals(): + cap.release() + +def trim_video_to_frame(video_path, frame_number, output_dir="outputs"): + """ + Trim video up to the specified frame and save as a new video. 
+ + Args: + video_path (str): Path to the video file + frame_number (int): Frame number to trim to + output_dir (str): Directory to save the trimmed video + + Returns: + str: Path to the trimmed video file + """ + print(f"\n=== Trimming video to frame {frame_number} ===") + if not video_path or not os.path.exists(video_path): + print("❌ Error: Video file does not exist") + return None + + try: + # Get video information + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + print("❌ Error: Failed to open video file") + return None + + fps = cap.get(cv2.CAP_PROP_FPS) + cap.release() + + # Calculate time in seconds + time_seconds = frame_number / fps + + # Create output directory if it doesn't exist + os.makedirs(output_dir, exist_ok=True) + + # Generate output filename + timestamp = f"{int(time_seconds)}s" + base_name = Path(video_path).stem + output_file = os.path.join(output_dir, f"{base_name}_trimmed_to_{timestamp}.mp4") + + # Use ffmpeg to trim the video + ( + ffmpeg + .input(video_path) + .output(output_file, to=time_seconds, c="copy") + .global_args('-y') # Overwrite output files + .run(quiet=True) + ) + + if not os.path.exists(output_file): + print("❌ Error: Failed to create trimmed video") + return None + + print(f"✅ Successfully trimmed video to {time_seconds:.2f}s: {output_file}") + return output_file + + except Exception as e: + print(f"❌ Error trimming video: {str(e)}") + return None + +def send_sharpest_frame_handler(gallery, selected_idx, frames_to_check=30): + """ + Extract the sharpest frame from the last N frames of the selected video + + Args: + gallery: Gradio gallery component with videos + selected_idx: Index of the selected video + frames_to_check: Number of frames from the end to check + + Returns: + tuple: (image_path, video_path, frame_number, sharpness) + """ + if gallery is None or not gallery: + return None, None, None, "No videos in gallery" + + if selected_idx is None and len(gallery) == 1: + selected_idx = 0 + + if selected_idx is None or selected_idx >= len(gallery): + return None, None, None, "No video selected" + + # Get the video path + item = gallery[selected_idx] + if isinstance(item, tuple): + video_path = item[0] + elif isinstance(item, dict): + video_path = item.get('name') or item.get('data') + else: + video_path = str(item) + + # Extract the sharpest frame + image_path, frame_number, sharpness = extract_sharpest_frame(video_path, frames_to_check) + + if image_path is None: + return None, None, None, "Failed to extract sharpest frame" + + return image_path, video_path, frame_number, f"Extracted frame {frame_number} with sharpness {sharpness:.2f}" + +def trim_and_prepare_for_extension(video_path, frame_number, save_path="outputs"): + """ + Trim the video to the specified frame and prepare for extension. 
+ + Args: + video_path: Path to the video file + frame_number: Frame number to trim to + save_path: Directory to save the trimmed video + + Returns: + tuple: (trimmed_video_path, status_message) + """ + if not video_path or not os.path.exists(video_path): + return None, "No video selected or video file does not exist" + + if frame_number is None: + return None, "No frame number provided, please extract sharpest frame first" + + # Trim the video + trimmed_video = trim_video_to_frame(video_path, frame_number, save_path) + + if trimmed_video is None: + return None, "Failed to trim video" + + return trimmed_video, f"Video trimmed to frame {frame_number} and ready for extension" + +def send_last_frame_handler(gallery, selected_idx): + """Handle sending last frame to input with better error handling""" + if gallery is None or not gallery: + return None, None + + if selected_idx is None and len(gallery) == 1: + selected_idx = 0 + + if selected_idx is None or selected_idx >= len(gallery): + return None, None + + # Get the frame and video path + frame = handle_last_frame_transfer(gallery, selected_idx) + video_path = None + + if selected_idx < len(gallery): + item = gallery[selected_idx] + video_path = parse_video_path(item) + + return frame, video_path + +def extract_last_frame(video_path: str) -> Optional[str]: + """Extract last frame from video and return temporary image path with error handling""" + print(f"\n=== Starting frame extraction ===") + print(f"Input video path: {video_path}") + + if not video_path or not os.path.exists(video_path): + print("❌ Error: Video file does not exist") + return None + + try: + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + print("❌ Error: Failed to open video file") + return None + + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + print(f"Total frames detected: {total_frames}") + + if total_frames < 1: + print("❌ Error: Video contains 0 frames") + return None + + # Extract last frame + cap.set(cv2.CAP_PROP_POS_FRAMES, total_frames - 1) + success, frame = cap.read() + + if not success or frame is None: + print("❌ Error: Failed to read last frame") + return None + + # Prepare output path + temp_dir = os.path.abspath("temp_frames") + os.makedirs(temp_dir, exist_ok=True) + temp_path = os.path.join(temp_dir, f"last_frame_{os.path.basename(video_path)}.png") + print(f"Saving frame to: {temp_path}") + + # Write and verify + if not cv2.imwrite(temp_path, frame): + print("❌ Error: Failed to write frame to file") + return None + + if not os.path.exists(temp_path): + print("❌ Error: Output file not created") + return None + + print("✅ Frame extraction successful") + return temp_path + + except Exception as e: + print(f"❌ Unexpected error: {str(e)}") + return None + finally: + if 'cap' in locals(): + cap.release() + +def handle_last_frame_transfer(gallery: list, selected_idx: int) -> Optional[str]: + """Improved frame transfer with video input validation""" + try: + if gallery is None or not gallery: + raise ValueError("No videos generated yet") + + if selected_idx is None: + # Auto-select last generated video if batch_size=1 + if len(gallery) == 1: + selected_idx = 0 + else: + raise ValueError("Please select a video first") + + if selected_idx >= len(gallery): + raise ValueError("Invalid selection index") + + item = gallery[selected_idx] + + # Video file existence check + video_path = parse_video_path(item) + if not os.path.exists(video_path): + raise FileNotFoundError(f"Video file missing: {video_path}") + + return extract_last_frame(video_path) 
+ + except Exception as e: + print(f"Frame transfer failed: {str(e)}") + return None + +def parse_video_path(item) -> str: + """Parse different gallery item formats""" + if isinstance(item, tuple): + return item[0] + elif isinstance(item, dict): + return item.get('name') or item.get('data') + return str(item) + +def get_random_image_from_folder(folder_path): + """Get a random image from the specified folder""" + if not os.path.isdir(folder_path): + return None, f"Error: {folder_path} is not a valid directory" + + # Get all image files in the folder + image_files = [] + for ext in ('*.jpg', '*.jpeg', '*.png', '*.bmp', '*.webp'): + image_files.extend(glob.glob(os.path.join(folder_path, ext))) + for ext in ('*.JPG', '*.JPEG', '*.PNG', '*.BMP', '*.WEBP'): + image_files.extend(glob.glob(os.path.join(folder_path, ext))) + + if not image_files: + return None, f"Error: No image files found in {folder_path}" + + # Select a random image + random_image = random.choice(image_files) + return random_image, f"Selected: {os.path.basename(random_image)}" + +def resize_image_keeping_aspect_ratio(image_path, max_width, max_height): + """Resize image keeping aspect ratio and ensuring dimensions are divisible by 16""" + try: + img = Image.open(image_path) + width, height = img.size + + # Calculate aspect ratio + aspect_ratio = width / height + + # Calculate new dimensions while maintaining aspect ratio + if width > height: + new_width = min(max_width, width) + new_height = int(new_width / aspect_ratio) + else: + new_height = min(max_height, height) + new_width = int(new_height * aspect_ratio) + + # Make dimensions divisible by 16 + new_width = math.floor(new_width / 16) * 16 + new_height = math.floor(new_height / 16) * 16 + + # Ensure minimum size + new_width = max(16, new_width) + new_height = max(16, new_height) + + # Resize image + resized_img = img.resize((new_width, new_height), Image.LANCZOS) + + # Save to temporary file + temp_path = f"temp_resized_{os.path.basename(image_path)}" + resized_img.save(temp_path) + + return temp_path, (new_width, new_height) + except Exception as e: + return None, f"Error: {str(e)}" +# Function to process a batch of images from a folder +def batch_handler( + use_random, + prompt, negative_prompt, + width, height, + video_length, fps, infer_steps, + seed, flow_shift, guidance_scale, embedded_cfg_scale, + batch_size, input_folder_path, + dit_folder, model, vae, te1, te2, save_path, output_type, attn_mode, + block_swap, exclude_single_blocks, use_split_attn, use_fp8, split_uncond, + lora_folder, *lora_params +): + """Handle both folder-based batch processing and regular batch processing""" + global stop_event + + # Check if this is a SkyReels model that needs special handling + is_skyreels = "skyreels" in model.lower() + is_skyreels_i2v = is_skyreels and "i2v" in model.lower() + + if use_random: + # Random image from folder mode + stop_event.clear() + + all_videos = [] + progress_text = "Starting generation..." 
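+ # Folder mode only reads the first eight entries of *lora_params (four LoRA weight names
+ # followed by four multipliers, as sliced below); any further entries are ignored in this
+ # branch and are only consulted in the non-random branch further down.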
+ yield [], "Preparing...", progress_text + + for i in range(batch_size): + if stop_event.is_set(): + break + + batch_text = f"Generating video {i + 1} of {batch_size}" + yield all_videos.copy(), batch_text, progress_text + + # Get random image from folder + random_image, status = get_random_image_from_folder(input_folder_path) + if random_image is None: + yield all_videos, f"Error in batch {i+1}: {status}", "" + continue + + # Resize image + resized_image, size_info = resize_image_keeping_aspect_ratio(random_image, width, height) + if resized_image is None: + yield all_videos, f"Error resizing image in batch {i+1}: {size_info}", "" + continue + + # If we have dimensions, update them + local_width, local_height = width, height + if isinstance(size_info, tuple): + local_width, local_height = size_info + progress_text = f"Using image: {os.path.basename(random_image)} - Resized to {local_width}x{local_height}" + else: + progress_text = f"Using image: {os.path.basename(random_image)}" + + yield all_videos.copy(), batch_text, progress_text + + # Calculate seed for this batch item + current_seed = seed + if seed == -1: + current_seed = random.randint(0, 2**32 - 1) + elif batch_size > 1: + current_seed = seed + i + + # Process the image + # For SkyReels models, we need to create a command with dit_in_channels=32 + if is_skyreels_i2v: + env = os.environ.copy() + env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") + env["PYTHONIOENCODING"] = "utf-8" + + model_path = os.path.join(dit_folder, model) if not os.path.isabs(model) else model + + # Extract parameters from lora_params + num_lora_weights = 4 + lora_weights = lora_params[:num_lora_weights] + lora_multipliers = lora_params[num_lora_weights:num_lora_weights*2] + + cmd = [ + sys.executable, + "hv_generate_video.py", + "--dit", model_path, + "--vae", vae, + "--text_encoder1", te1, + "--text_encoder2", te2, + "--prompt", prompt, + "--video_size", str(local_height), str(local_width), + "--video_length", str(video_length), + "--fps", str(fps), + "--infer_steps", str(infer_steps), + "--save_path", save_path, + "--seed", str(current_seed), + "--flow_shift", str(flow_shift), + "--embedded_cfg_scale", str(embedded_cfg_scale), + "--output_type", output_type, + "--attn_mode", attn_mode, + "--blocks_to_swap", str(block_swap), + "--fp8_llm", + "--vae_chunk_size", "32", + "--vae_spatial_tile_sample_min_size", "128", + "--dit_in_channels", "32", # This is crucial for SkyReels i2v + "--image_path", resized_image # Pass the image directly + ] + + if use_fp8: + cmd.append("--fp8") + + if split_uncond: + cmd.append("--split_uncond") + + if use_split_attn: + cmd.append("--split_attn") + + if exclude_single_blocks: + cmd.append("--exclude_single_blocks") + + if negative_prompt: + cmd.extend(["--negative_prompt", negative_prompt]) + + if guidance_scale is not None: + cmd.extend(["--guidance_scale", str(guidance_scale)]) + + # Add LoRA weights and multipliers if provided + valid_loras = [] + for weight, mult in zip(lora_weights, lora_multipliers): + if weight and weight != "None": + valid_loras.append((os.path.join(lora_folder, weight), mult)) + + if valid_loras: + weights = [weight for weight, _ in valid_loras] + multipliers = [str(mult) for _, mult in valid_loras] + cmd.extend(["--lora_weight"] + weights) + cmd.extend(["--lora_multiplier"] + multipliers) + + print(f"Running command: {' '.join(cmd)}") + + # Run the process + p = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=env, + text=True, + 
encoding='utf-8', + errors='replace', + bufsize=1 + ) + + while True: + if stop_event.is_set(): + p.terminate() + p.wait() + yield all_videos, "Generation stopped by user.", "" + return + + line = p.stdout.readline() + if not line: + if p.poll() is not None: + break + continue + + print(line, end='') + if '|' in line and '%' in line and '[' in line and ']' in line: + yield all_videos.copy(), f"Processing video {i+1} (seed: {current_seed})", line.strip() + + p.stdout.close() + p.wait() + + # Collect generated video + save_path_abs = os.path.abspath(save_path) + if os.path.exists(save_path_abs): + all_videos_files = sorted( + [f for f in os.listdir(save_path_abs) if f.endswith('.mp4')], + key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)), + reverse=True + ) + matching_videos = [v for v in all_videos_files if f"_{current_seed}" in v] + if matching_videos: + video_path = os.path.join(save_path_abs, matching_videos[0]) + all_videos.append((str(video_path), f"Seed: {current_seed}")) + else: + # For non-SkyReels models, use the regular process_single_video function + num_lora_weights = 4 + lora_weights = lora_params[:num_lora_weights] + lora_multipliers = lora_params[num_lora_weights:num_lora_weights*2] + + single_video_args = [ + prompt, local_width, local_height, 1, video_length, fps, infer_steps, + current_seed, dit_folder, model, vae, te1, te2, save_path, flow_shift, embedded_cfg_scale, + output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, + lora_folder + ] + single_video_args.extend(lora_weights) + single_video_args.extend(lora_multipliers) + single_video_args.extend([None, resized_image, None, negative_prompt, embedded_cfg_scale, split_uncond, guidance_scale, use_fp8]) + + for videos, status, progress in process_single_video(*single_video_args): + if videos: + all_videos.extend(videos) + yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress + + # Clean up temporary file + try: + if os.path.exists(resized_image): + os.remove(resized_image) + except: + pass + + # Clear CUDA cache between generations + clear_cuda_cache() + time.sleep(0.5) + + yield all_videos, "Batch complete", "" + else: + # Regular image input - this is the part we need to fix + # When a SkyReels I2V model is used, we need to use the direct command approach + # with dit_in_channels=32 explicitly specified, just like in the folder processing branch + if is_skyreels_i2v: + stop_event.clear() + + all_videos = [] + progress_text = "Starting generation..." 
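+ # SkyReels I2V path: the eight LoRA entries are followed by extra UI inputs, and the input
+ # image path is expected at extra_args[0] (extracted below); generation shells out to
+ # hv_generate_video.py with --dit_in_channels 32, since the SkyReels I2V checkpoint uses 32 input channels.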
+ yield [], "Preparing...", progress_text + + # Extract lora parameters + num_lora_weights = 4 + lora_weights = lora_params[:num_lora_weights] + lora_multipliers = lora_params[num_lora_weights:num_lora_weights*2] + extra_args = list(lora_params[num_lora_weights*2:]) if len(lora_params) > num_lora_weights*2 else [] + + # Print extra_args for debugging + print(f"Extra args: {extra_args}") + image_path = None + if len(extra_args) > 0 and extra_args[0] is not None: + image_path = extra_args[0] + print(f"Image path found in extra_args[0]: {image_path}") + if not image_path: + print("No image path found in extra_args[0]") + print(f"Full lora_params: {lora_params}") + yield [], "Error: No input image provided", "An input image is required for SkyReels I2V models" + return + + for i in range(batch_size): + if stop_event.is_set(): + yield all_videos, "Generation stopped by user", "" + return + + # Calculate seed for this batch item + current_seed = seed + if seed == -1: + current_seed = random.randint(0, 2**32 - 1) + elif batch_size > 1: + current_seed = seed + i + + batch_text = f"Generating video {i + 1} of {batch_size}" + yield all_videos.copy(), batch_text, progress_text + + # Set up environment + env = os.environ.copy() + env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") + env["PYTHONIOENCODING"] = "utf-8" + + model_path = os.path.join(dit_folder, model) if not os.path.isabs(model) else model + + # Build the command with dit_in_channels=32 + cmd = [ + sys.executable, + "hv_generate_video.py", + "--dit", model_path, + "--vae", vae, + "--text_encoder1", te1, + "--text_encoder2", te2, + "--prompt", prompt, + "--video_size", str(height), str(width), + "--video_length", str(video_length), + "--fps", str(fps), + "--infer_steps", str(infer_steps), + "--save_path", save_path, + "--seed", str(current_seed), + "--flow_shift", str(flow_shift), + "--embedded_cfg_scale", str(embedded_cfg_scale), + "--output_type", output_type, + "--attn_mode", attn_mode, + "--blocks_to_swap", str(block_swap), + "--fp8_llm", + "--vae_chunk_size", "32", + "--vae_spatial_tile_sample_min_size", "128", + "--dit_in_channels", "32", # This is crucial for SkyReels i2v + "--image_path", image_path + ] + + if use_fp8: + cmd.append("--fp8") + + if split_uncond: + cmd.append("--split_uncond") + + if use_split_attn: + cmd.append("--split_attn") + + if exclude_single_blocks: + cmd.append("--exclude_single_blocks") + + if negative_prompt: + cmd.extend(["--negative_prompt", negative_prompt]) + + if guidance_scale is not None: + cmd.extend(["--guidance_scale", str(guidance_scale)]) + + # Add LoRA weights and multipliers if provided + valid_loras = [] + for weight, mult in zip(lora_weights, lora_multipliers): + if weight and weight != "None": + valid_loras.append((os.path.join(lora_folder, weight), mult)) + + if valid_loras: + weights = [weight for weight, _ in valid_loras] + multipliers = [str(mult) for _, mult in valid_loras] + cmd.extend(["--lora_weight"] + weights) + cmd.extend(["--lora_multiplier"] + multipliers) + + print(f"Running command: {' '.join(cmd)}") + + # Run the process + p = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=env, + text=True, + encoding='utf-8', + errors='replace', + bufsize=1 + ) + + while True: + if stop_event.is_set(): + p.terminate() + p.wait() + yield all_videos, "Generation stopped by user.", "" + return + + line = p.stdout.readline() + if not line: + if p.poll() is not None: + break + continue + + print(line, end='') + if '|' in line 
and '%' in line and '[' in line and ']' in line: + yield all_videos.copy(), f"Processing (seed: {current_seed})", line.strip() + + p.stdout.close() + p.wait() + + # Collect generated video + save_path_abs = os.path.abspath(save_path) + if os.path.exists(save_path_abs): + all_videos_files = sorted( + [f for f in os.listdir(save_path_abs) if f.endswith('.mp4')], + key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)), + reverse=True + ) + matching_videos = [v for v in all_videos_files if f"_{current_seed}" in v] + if matching_videos: + video_path = os.path.join(save_path_abs, matching_videos[0]) + all_videos.append((str(video_path), f"Seed: {current_seed}")) + + # Clear CUDA cache between generations + clear_cuda_cache() + time.sleep(0.5) + + yield all_videos, "Batch complete", "" + else: + # For regular non-SkyReels models, use the original process_batch function + regular_args = [ + prompt, width, height, batch_size, video_length, fps, infer_steps, + seed, dit_folder, model, vae, te1, te2, save_path, flow_shift, guidance_scale, + output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, + lora_folder + ] + yield from process_batch(*(regular_args + list(lora_params))) + +def get_dit_models(dit_folder: str) -> List[str]: + """Get list of available DiT models in the specified folder""" + if not os.path.exists(dit_folder): + return ["mp_rank_00_model_states.pt"] + models = [f for f in os.listdir(dit_folder) if f.endswith('.pt') or f.endswith('.safetensors')] + models.sort(key=str.lower) + return models if models else ["mp_rank_00_model_states.pt"] + +def update_dit_and_lora_dropdowns(dit_folder: str, lora_folder: str, *current_values) -> List[gr.update]: + """Update both DiT and LoRA dropdowns""" + # Get model lists + dit_models = get_dit_models(dit_folder) + lora_choices = get_lora_options(lora_folder) + + # Current values processing + dit_value = current_values[0] + if dit_value not in dit_models: + dit_value = dit_models[0] if dit_models else None + + weights = current_values[1:5] + multipliers = current_values[5:9] + + results = [gr.update(choices=dit_models, value=dit_value)] + + # Add LoRA updates + for i in range(4): + weight = weights[i] if i < len(weights) else "None" + multiplier = multipliers[i] if i < len(multipliers) else 1.0 + if weight not in lora_choices: + weight = "None" + results.extend([ + gr.update(choices=lora_choices, value=weight), + gr.update(value=multiplier) + ]) + + return results + +def extract_video_metadata(video_path: str) -> Dict: + """Extract metadata from video file using ffprobe.""" + cmd = [ + 'ffprobe', + '-v', 'quiet', + '-print_format', 'json', + '-show_format', + video_path + ] + + try: + result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True) + metadata = json.loads(result.stdout.decode('utf-8')) + if 'format' in metadata and 'tags' in metadata['format']: + comment = metadata['format']['tags'].get('comment', '{}') + return json.loads(comment) + return {} + except Exception as e: + print(f"Metadata extraction failed: {str(e)}") + return {} + +def create_parameter_transfer_map(metadata: Dict, target_tab: str) -> Dict: + """Map metadata parameters to Gradio components for different tabs""" + mapping = { + 'common': { + 'prompt': ('prompt', 'v2v_prompt'), + 'width': ('width', 'v2v_width'), + 'height': ('height', 'v2v_height'), + 'batch_size': ('batch_size', 'v2v_batch_size'), + 'video_length': ('video_length', 'v2v_video_length'), + 'fps': ('fps', 'v2v_fps'), + 'infer_steps': ('infer_steps', 
'v2v_infer_steps'), + 'seed': ('seed', 'v2v_seed'), + 'model': ('model', 'v2v_model'), + 'vae': ('vae', 'v2v_vae'), + 'te1': ('te1', 'v2v_te1'), + 'te2': ('te2', 'v2v_te2'), + 'save_path': ('save_path', 'v2v_save_path'), + 'flow_shift': ('flow_shift', 'v2v_flow_shift'), + 'cfg_scale': ('cfg_scale', 'v2v_cfg_scale'), + 'output_type': ('output_type', 'v2v_output_type'), + 'attn_mode': ('attn_mode', 'v2v_attn_mode'), + 'block_swap': ('block_swap', 'v2v_block_swap'), + 'negative_prompt': ('i2v_negative_prompt', 'v2v_negative_prompt'), + 'clip_vision_path': ('i2v_clip_vision_path', None), + 'i2v_stability': ('i2v_stability', None), + 'fp8_fast': ('i2v_fp8_fast', None) + }, + 'lora': { + 'lora_weights': [(f'lora{i+1}', f'v2v_lora_weights[{i}]') for i in range(4)], + 'lora_multipliers': [(f'lora{i+1}_multiplier', f'v2v_lora_multipliers[{i}]') for i in range(4)] + } + } + + results = {} + for param, value in metadata.items(): + # Handle common parameters + if param in mapping['common']: + target = mapping['common'][param][0 if target_tab == 't2v' else 1] + results[target] = value + + # Handle LoRA parameters + if param == 'lora_weights': + for i, weight in enumerate(value[:4]): + target = mapping['lora']['lora_weights'][i][1 if target_tab == 'v2v' else 0] + results[target] = weight + + if param == 'lora_multipliers': + for i, mult in enumerate(value[:4]): + target = mapping['lora']['lora_multipliers'][i][1 if target_tab == 'v2v' else 0] + results[target] = float(mult) + + return results + +def add_metadata_to_video(video_path: str, parameters: dict) -> None: + """Add generation parameters to video metadata using ffmpeg.""" + import json + import subprocess + + # Convert parameters to JSON string + params_json = json.dumps(parameters, indent=2) + + # Temporary output path + temp_path = video_path.replace(".mp4", "_temp.mp4") + + # FFmpeg command to add metadata without re-encoding + cmd = [ + 'ffmpeg', + '-i', video_path, + '-metadata', f'comment={params_json}', + '-codec', 'copy', + temp_path + ] + + try: + # Execute FFmpeg command + subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + # Replace original file with the metadata-enhanced version + os.replace(temp_path, video_path) + except subprocess.CalledProcessError as e: + print(f"Failed to add metadata: {e.stderr.decode()}") + if os.path.exists(temp_path): + os.remove(temp_path) + except Exception as e: + print(f"Error: {str(e)}") + +def count_prompt_tokens(prompt: str) -> int: + enc = tiktoken.get_encoding("cl100k_base") + tokens = enc.encode(prompt) + return len(tokens) + + +def get_lora_options(lora_folder: str = "lora") -> List[str]: + if not os.path.exists(lora_folder): + return ["None"] + lora_files = [f for f in os.listdir(lora_folder) if f.endswith('.safetensors') or f.endswith('.pt')] + lora_files.sort(key=str.lower) + return ["None"] + lora_files + +def update_lora_dropdowns(lora_folder: str, *current_values) -> List[gr.update]: + new_choices = get_lora_options(lora_folder) + weights = current_values[:4] + multipliers = current_values[4:8] + + results = [] + for i in range(4): + weight = weights[i] if i < len(weights) else "None" + multiplier = multipliers[i] if i < len(multipliers) else 1.0 + if weight not in new_choices: + weight = "None" + results.extend([ + gr.update(choices=new_choices, value=weight), + gr.update(value=multiplier) + ]) + + return results + +def send_to_v2v(evt: gr.SelectData, gallery: list, prompt: str, selected_index: gr.State) -> Tuple[Optional[str], str, int]: + """Transfer 
selected video and prompt to Video2Video tab""" + if not gallery or evt.index >= len(gallery): + return None, "", selected_index.value + + selected_item = gallery[evt.index] + + # Handle different gallery item formats + if isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + elif isinstance(selected_item, (tuple, list)): + video_path = selected_item[0] + else: + video_path = selected_item + + # Final cleanup for Gradio Video component + if isinstance(video_path, tuple): + video_path = video_path[0] + + # Update the selected index + selected_index.value = evt.index + + return str(video_path), prompt, evt.index + +def send_selected_to_v2v(gallery: list, prompt: str, selected_index: gr.State) -> Tuple[Optional[str], str]: + """Send the currently selected video to V2V tab""" + if not gallery or selected_index.value is None or selected_index.value >= len(gallery): + return None, "" + + selected_item = gallery[selected_index.value] + + # Handle different gallery item formats + if isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + elif isinstance(selected_item, (tuple, list)): + video_path = selected_item[0] + else: + video_path = selected_item + + # Final cleanup for Gradio Video component + if isinstance(video_path, tuple): + video_path = video_path[0] + + return str(video_path), prompt + +def clear_cuda_cache(): + """Clear CUDA cache if available""" + import torch + if torch.cuda.is_available(): + torch.cuda.empty_cache() + # Optional: synchronize to ensure cache is cleared + torch.cuda.synchronize() + +def wanx_batch_handler( + use_random, + prompt, + negative_prompt, + width, + height, + video_length, + fps, + infer_steps, + flow_shift, + guidance_scale, + seed, + batch_size, + input_folder_path, + task, + dit_path, + vae_path, + t5_path, + clip_path, + save_path, + output_type, + sample_solver, + exclude_single_blocks, + attn_mode, + block_swap, + fp8, + fp8_t5, + lora_folder, + *lora_params +): + """Handle both folder-based batch processing and regular processing for WanX""" + global stop_event + + if use_random: + # Random image from folder mode + stop_event.clear() + + all_videos = [] + progress_text = "Starting generation..." 
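+ # Folder mode: each batch item draws a random image from input_folder_path, resizes it to
+ # dimensions divisible by 16 while preserving aspect ratio, then calls wanx_generate_video
+ # once with a per-item seed before cleaning up the temporary resized file.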
+ yield [], "Preparing...", progress_text + + # Ensure batch_size is treated as an integer + batch_size = int(batch_size) + + # Process each item in the batch separately + for i in range(batch_size): + if stop_event.is_set(): + yield all_videos, "Generation stopped by user", "" + return + + batch_text = f"Generating video {i + 1} of {batch_size}" + yield all_videos.copy(), batch_text, progress_text + + # Get random image from folder + random_image, status = get_random_image_from_folder(input_folder_path) + if random_image is None: + yield all_videos, f"Error in batch {i+1}: {status}", "" + continue + + # Resize image + resized_image, size_info = resize_image_keeping_aspect_ratio(random_image, width, height) + if resized_image is None: + yield all_videos, f"Error resizing image in batch {i+1}: {size_info}", "" + continue + + # Use the dimensions returned from the resize function + local_width, local_height = width, height # Default fallback + if isinstance(size_info, tuple): + local_width, local_height = size_info + progress_text = f"Using image: {os.path.basename(random_image)} - Resized to {local_width}x{local_height} (maintaining aspect ratio)" + else: + progress_text = f"Using image: {os.path.basename(random_image)}" + + yield all_videos.copy(), batch_text, progress_text + + # Calculate seed for this batch item + current_seed = seed + if seed == -1: + current_seed = random.randint(0, 2**32 - 1) + elif batch_size > 1: + current_seed = seed + i + + # Extract LoRA weights and multipliers + num_lora_weights = 4 + lora_weights = lora_params[:num_lora_weights] + lora_multipliers = lora_params[num_lora_weights:num_lora_weights*2] + + # Generate video for this image - one at a time + for videos, status, progress in wanx_generate_video( + prompt, + negative_prompt, + resized_image, + local_width, + local_height, + video_length, + fps, + infer_steps, + flow_shift, + guidance_scale, + current_seed, + task, + dit_path, + vae_path, + t5_path, + clip_path, + save_path, + output_type, + sample_solver, + exclude_single_blocks, + attn_mode, + block_swap, + fp8, + fp8_t5, + lora_folder, + *lora_weights, + *lora_multipliers + ): + if videos: + all_videos.extend(videos) + yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress + + # Clean up temporary file + try: + if os.path.exists(resized_image): + os.remove(resized_image) + except: + pass + + # Clear CUDA cache between generations + clear_cuda_cache() + time.sleep(0.5) + + yield all_videos, "Batch complete", "" + else: + # For non-random mode, if batch_size > 1, we need to process multiple times + # with the same input image but different seeds + if int(batch_size) > 1: + stop_event.clear() + + all_videos = [] + progress_text = "Starting generation..." 
+ yield [], "Preparing...", progress_text + + # Extract LoRA weights and multipliers and input image + num_lora_weights = 4 + lora_weights = lora_params[:num_lora_weights] + lora_multipliers = lora_params[num_lora_weights:num_lora_weights*2] + input_image = lora_params[num_lora_weights*2] if len(lora_params) > num_lora_weights*2 else None + + # Process each batch item + for i in range(int(batch_size)): + if stop_event.is_set(): + yield all_videos, "Generation stopped by user", "" + return + + # Calculate seed for this batch item + current_seed = seed + if seed == -1: + current_seed = random.randint(0, 2**32 - 1) + elif batch_size > 1: + current_seed = seed + i + + batch_text = f"Generating video {i + 1} of {batch_size}" + yield all_videos.copy(), batch_text, progress_text + + # Generate a single video with the current seed + for videos, status, progress in wanx_generate_video( + prompt, + negative_prompt, + input_image, + width, + height, + video_length, + fps, + infer_steps, + flow_shift, + guidance_scale, + current_seed, + task, + dit_path, + vae_path, + t5_path, + clip_path, + save_path, + output_type, + sample_solver, + exclude_single_blocks, + attn_mode, + block_swap, + fp8, + fp8_t5, + lora_folder, + *lora_weights, + *lora_multipliers + ): + if videos: + all_videos.extend(videos) + yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress + + # Clear CUDA cache between generations + clear_cuda_cache() + time.sleep(0.5) + + yield all_videos, "Batch complete", "" + else: + # Single image, single generation - use existing function + num_lora_weights = 4 + lora_weights = lora_params[:num_lora_weights] + lora_multipliers = lora_params[num_lora_weights:num_lora_weights*2] + input_image = lora_params[num_lora_weights*2] if len(lora_params) > num_lora_weights*2 else None + + yield from wanx_generate_video( + prompt, + negative_prompt, + input_image, + width, + height, + video_length, + fps, + infer_steps, + flow_shift, + guidance_scale, + seed, + task, + dit_path, + vae_path, + t5_path, + clip_path, + save_path, + output_type, + sample_solver, + exclude_single_blocks, + attn_mode, + block_swap, + fp8, + fp8_t5, + lora_folder, + *lora_weights, + *lora_multipliers + ) + +def process_single_video( + prompt: str, + width: int, + height: int, + batch_size: int, + video_length: int, + fps: int, + infer_steps: int, + seed: int, + dit_folder: str, + model: str, + vae: str, + te1: str, + te2: str, + save_path: str, + flow_shift: float, + cfg_scale: float, + output_type: str, + attn_mode: str, + block_swap: int, + exclude_single_blocks: bool, + use_split_attn: bool, + lora_folder: str, + lora1: str = "", + lora2: str = "", + lora3: str = "", + lora4: str = "", + lora1_multiplier: float = 1.0, + lora2_multiplier: float = 1.0, + lora3_multiplier: float = 1.0, + lora4_multiplier: float = 1.0, + video_path: Optional[str] = None, + image_path: Optional[str] = None, + strength: Optional[float] = None, + negative_prompt: Optional[str] = None, + embedded_cfg_scale: Optional[float] = None, + split_uncond: Optional[bool] = None, + guidance_scale: Optional[float] = None, + use_fp8: bool = True +) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: + """Generate a single video with the given parameters""" + global stop_event + + if stop_event.is_set(): + yield [], "", "" + return + + # Determine if this is a SkyReels model and what type + is_skyreels = "skyreels" in model.lower() + is_skyreels_i2v = is_skyreels and "i2v" in model.lower() + is_skyreels_t2v = is_skyreels and "t2v" in 
model.lower() + + if is_skyreels: + # Force certain parameters for SkyReels + if negative_prompt is None: + negative_prompt = "" + if embedded_cfg_scale is None: + embedded_cfg_scale = 1.0 # Force to 1.0 for SkyReels + if split_uncond is None: + split_uncond = True + if guidance_scale is None: + guidance_scale = cfg_scale # Use cfg_scale as guidance_scale if not provided + + # Determine the input channels based on model type + if is_skyreels_i2v: + dit_in_channels = 32 # SkyReels I2V uses 32 channels + else: + dit_in_channels = 16 # SkyReels T2V uses 16 channels (same as regular models) + else: + dit_in_channels = 16 # Regular Hunyuan models use 16 channels + embedded_cfg_scale = cfg_scale + + if os.path.isabs(model): + model_path = model + else: + model_path = os.path.normpath(os.path.join(dit_folder, model)) + + env = os.environ.copy() + env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") + env["PYTHONIOENCODING"] = "utf-8" + env["BATCH_RUN_ID"] = f"{time.time()}" + + if seed == -1: + current_seed = random.randint(0, 2**32 - 1) + else: + batch_id = int(env.get("BATCH_RUN_ID", "0").split('.')[-1]) + if batch_size > 1: # Only modify seed for batch generation + current_seed = (seed + batch_id * 100003) % (2**32) + else: + current_seed = seed + + clear_cuda_cache() + + command = [ + sys.executable, + "hv_generate_video.py", + "--dit", model_path, + "--vae", vae, + "--text_encoder1", te1, + "--text_encoder2", te2, + "--prompt", prompt, + "--video_size", str(height), str(width), + "--video_length", str(video_length), + "--fps", str(fps), + "--infer_steps", str(infer_steps), + "--save_path", save_path, + "--seed", str(current_seed), + "--flow_shift", str(flow_shift), + "--embedded_cfg_scale", str(cfg_scale), + "--output_type", output_type, + "--attn_mode", attn_mode, + "--blocks_to_swap", str(block_swap), + "--fp8_llm", + "--vae_chunk_size", "32", + "--vae_spatial_tile_sample_min_size", "128" + ] + + if use_fp8: + command.append("--fp8") + + # Add negative prompt and embedded cfg scale for SkyReels + if is_skyreels: + command.extend(["--dit_in_channels", str(dit_in_channels)]) + command.extend(["--guidance_scale", str(guidance_scale)]) + + if negative_prompt: + command.extend(["--negative_prompt", negative_prompt]) + if split_uncond: + command.append("--split_uncond") + + # Add LoRA weights and multipliers if provided + valid_loras = [] + for weight, mult in zip([lora1, lora2, lora3, lora4], + [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier]): + if weight and weight != "None": + valid_loras.append((os.path.join(lora_folder, weight), mult)) + if valid_loras: + weights = [weight for weight, _ in valid_loras] + multipliers = [str(mult) for _, mult in valid_loras] + command.extend(["--lora_weight"] + weights) + command.extend(["--lora_multiplier"] + multipliers) + + if exclude_single_blocks: + command.append("--exclude_single_blocks") + if use_split_attn: + command.append("--split_attn") + + # Handle input paths + if video_path: + command.extend(["--video_path", video_path]) + if strength is not None: + command.extend(["--strength", str(strength)]) + elif image_path: + command.extend(["--image_path", image_path]) + # Only add strength parameter for non-SkyReels I2V models + # SkyReels I2V doesn't use strength parameter for image-to-video generation + if strength is not None and not is_skyreels_i2v: + command.extend(["--strength", str(strength)]) + + print(f"{command}") + + p = subprocess.Popen( + command, + stdout=subprocess.PIPE, + 
stderr=subprocess.STDOUT,
+        env=env,
+        text=True,
+        encoding='utf-8',
+        errors='replace',
+        bufsize=1
+    )
+
+    videos = []
+
+    while True:
+        if stop_event.is_set():
+            p.terminate()
+            p.wait()
+            yield [], "", "Generation stopped by user."
+            return
+
+        line = p.stdout.readline()
+        if not line:
+            if p.poll() is not None:
+                break
+            continue
+
+        print(line, end='')
+        if '|' in line and '%' in line and '[' in line and ']' in line:
+            yield videos.copy(), f"Processing (seed: {current_seed})", line.strip()
+
+    p.stdout.close()
+    p.wait()
+
+    clear_cuda_cache()
+    time.sleep(0.5)
+
+    # Collect generated video
+    save_path_abs = os.path.abspath(save_path)
+    if os.path.exists(save_path_abs):
+        all_videos = sorted(
+            [f for f in os.listdir(save_path_abs) if f.endswith('.mp4')],
+            key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)),
+            reverse=True
+        )
+        matching_videos = [v for v in all_videos if f"_{current_seed}" in v]
+        if matching_videos:
+            # Keep the original input video_path intact; store the generated file separately
+            generated_video_path = os.path.join(save_path_abs, matching_videos[0])
+
+            # Collect parameters for metadata
+            parameters = {
+                "prompt": prompt,
+                "width": width,
+                "height": height,
+                "video_length": video_length,
+                "fps": fps,
+                "infer_steps": infer_steps,
+                "seed": current_seed,
+                "model": model,
+                "vae": vae,
+                "te1": te1,
+                "te2": te2,
+                "save_path": save_path,
+                "flow_shift": flow_shift,
+                "cfg_scale": cfg_scale,
+                "output_type": output_type,
+                "attn_mode": attn_mode,
+                "block_swap": block_swap,
+                "lora_weights": [lora1, lora2, lora3, lora4],
+                "lora_multipliers": [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier],
+                "input_video": video_path if video_path else None,
+                "input_image": image_path if image_path else None,
+                "strength": strength,
+                "negative_prompt": negative_prompt if is_skyreels else None,
+                "embedded_cfg_scale": embedded_cfg_scale if is_skyreels else None
+            }
+
+            add_metadata_to_video(generated_video_path, parameters)
+            videos.append((str(generated_video_path), f"Seed: {current_seed}"))
+
+    yield videos, f"Completed (seed: {current_seed})", ""
+
+def process_batch(
+    prompt: str,
+    width: int,
+    height: int,
+    batch_size: int,
+    video_length: int,
+    fps: int,
+    infer_steps: int,
+    seed: int,
+    dit_folder: str,
+    model: str,
+    vae: str,
+    te1: str,
+    te2: str,
+    save_path: str,
+    flow_shift: float,
+    cfg_scale: float,
+    output_type: str,
+    attn_mode: str,
+    block_swap: int,
+    exclude_single_blocks: bool,
+    use_split_attn: bool,
+    lora_folder: str,
+    *args
+) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]:
+    """Process a batch of videos using Gradio's queue"""
+    global stop_event
+    stop_event.clear()
+
+    all_videos = []
+    progress_text = "Starting generation..."
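+    # Expected layout of *args (mirroring process_single_video's trailing parameters):
+    # four LoRA weight names, four multipliers, then optional extras -- input path,
+    # strength, negative prompt, guidance scale, split_uncond flag, with use_fp8 as the
+    # last entry.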
+ yield [], "Preparing...", progress_text + + # Extract additional arguments + num_lora_weights = 4 + lora_weights = args[:num_lora_weights] + lora_multipliers = args[num_lora_weights:num_lora_weights*2] + extra_args = args[num_lora_weights*2:] + + # Determine if this is a SkyReels model and what type + is_skyreels = "skyreels" in model.lower() + is_skyreels_i2v = is_skyreels and "i2v" in model.lower() + is_skyreels_t2v = is_skyreels and "t2v" in model.lower() + + # Handle input paths and additional parameters + input_path = extra_args[0] if extra_args else None + strength = float(extra_args[1]) if len(extra_args) > 1 else None + + # Get use_fp8 flag (it should be the last parameter) + use_fp8 = bool(extra_args[-1]) if extra_args and len(extra_args) >= 3 else True + + # Get SkyReels specific parameters if applicable + if is_skyreels: + # Always set embedded_cfg_scale to 1.0 for SkyReels models + embedded_cfg_scale = 1.0 + + negative_prompt = str(extra_args[2]) if len(extra_args) > 2 and extra_args[2] is not None else "" + # Use cfg_scale for guidance_scale parameter + guidance_scale = float(extra_args[3]) if len(extra_args) > 3 and extra_args[3] is not None else cfg_scale + split_uncond = True if len(extra_args) > 4 and extra_args[4] else False + else: + negative_prompt = str(extra_args[2]) if len(extra_args) > 2 and extra_args[2] is not None else None + guidance_scale = cfg_scale + embedded_cfg_scale = cfg_scale + split_uncond = bool(extra_args[4]) if len(extra_args) > 4 else None + + for i in range(batch_size): + if stop_event.is_set(): + break + + batch_text = f"Generating video {i + 1} of {batch_size}" + yield all_videos.copy(), batch_text, progress_text + + # Handle different input types + video_path = None + image_path = None + + if input_path: + # Check if it's an image file (common image extensions) + is_image = False + lower_path = input_path.lower() + image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp') + is_image = any(lower_path.endswith(ext) for ext in image_extensions) + + # Only use image_path for SkyReels I2V models and actual image files + if is_skyreels_i2v and is_image: + image_path = input_path + else: + video_path = input_path + + # Prepare arguments for process_single_video + single_video_args = [ + prompt, width, height, batch_size, video_length, fps, infer_steps, + seed, dit_folder, model, vae, te1, te2, save_path, flow_shift, cfg_scale, + output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, + lora_folder + ] + single_video_args.extend(lora_weights) + single_video_args.extend(lora_multipliers) + single_video_args.extend([video_path, image_path, strength, negative_prompt, embedded_cfg_scale, split_uncond, guidance_scale, use_fp8]) + + for videos, status, progress in process_single_video(*single_video_args): + if videos: + all_videos.extend(videos) + yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress + + yield all_videos, "Batch complete", "" + +def update_wanx_image_dimensions(image): + """Update dimensions from uploaded image""" + if image is None: + return "", gr.update(value=832), gr.update(value=480) + img = Image.open(image) + w, h = img.size + w = (w // 32) * 32 + h = (h // 32) * 32 + return f"{w}x{h}", w, h + +def calculate_wanx_width(height, original_dims): + """Calculate width based on height maintaining aspect ratio""" + if not original_dims: + return gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + aspect_ratio = orig_w / orig_h + new_width = math.floor((height * aspect_ratio) 
/ 32) * 32 + return gr.update(value=new_width) + +def calculate_wanx_height(width, original_dims): + """Calculate height based on width maintaining aspect ratio""" + if not original_dims: + return gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + aspect_ratio = orig_w / orig_h + new_height = math.floor((width / aspect_ratio) / 32) * 32 + return gr.update(value=new_height) + +def update_wanx_from_scale(scale, original_dims): + """Update dimensions based on scale percentage""" + if not original_dims: + return gr.update(), gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + new_w = math.floor((orig_w * scale / 100) / 32) * 32 + new_h = math.floor((orig_h * scale / 100) / 32) * 32 + return gr.update(value=new_w), gr.update(value=new_h) + +def recommend_wanx_flow_shift(width, height): + """Get recommended flow shift value based on dimensions""" + recommended_shift = 3.0 if (width == 832 and height == 480) or (width == 480 and height == 832) else 5.0 + return gr.update(value=recommended_shift) + +def handle_wanx_gallery_select(evt: gr.SelectData, gallery) -> tuple: + """Track selected index and video path when gallery item is clicked""" + if gallery is None: + return None, None + + if evt.index >= len(gallery): + return None, None + + selected_item = gallery[evt.index] + video_path = None + + # Extract the video path based on the item type + if isinstance(selected_item, tuple): + video_path = selected_item[0] + elif isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + else: + video_path = selected_item + + return evt.index, video_path + +def wanx_generate_video( + prompt, + negative_prompt, + input_image, + width, + height, + video_length, + fps, + infer_steps, + flow_shift, + guidance_scale, + seed, + task, + dit_path, + vae_path, + t5_path, + clip_path, + save_path, + output_type, + sample_solver, + exclude_single_blocks, + attn_mode, + block_swap, + fp8, + fp8_t5, + lora_folder, + lora1="None", + lora2="None", + lora3="None", + lora4="None", + lora1_multiplier=1.0, + lora2_multiplier=1.0, + lora3_multiplier=1.0, + lora4_multiplier=1.0 +) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: + """Generate video with WanX model (supports both i2v and t2v)""" + global stop_event + + if stop_event.is_set(): + yield [], "", "" + return + + if seed == -1: + current_seed = random.randint(0, 2**32 - 1) + else: + current_seed = seed + + # Check if we need input image (required for i2v, not for t2v) + if "i2v" in task and not input_image: + yield [], "Error: No input image provided", "Please provide an input image for image-to-video generation" + return + + # Prepare environment + env = os.environ.copy() + env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") + env["PYTHONIOENCODING"] = "utf-8" + + clear_cuda_cache() + + command = [ + sys.executable, + "wan_generate_video.py", + "--task", task, + "--prompt", prompt, + "--video_size", str(height), str(width), + "--video_length", str(video_length), + "--fps", str(fps), + "--infer_steps", str(infer_steps), + "--save_path", save_path, + "--seed", str(current_seed), + "--flow_shift", str(flow_shift), + "--guidance_scale", str(guidance_scale), + "--output_type", output_type, + "--attn_mode", attn_mode, + "--blocks_to_swap", str(block_swap), + "--dit", dit_path, + "--vae", vae_path, + "--t5", t5_path, + "--sample_solver", sample_solver + ] + + # Add image path only for i2v task and if input image is provided + if "i2v" in task and 
input_image: + command.extend(["--image_path", input_image]) + command.extend(["--clip", clip_path]) # CLIP is only needed for i2v + + if negative_prompt: + command.extend(["--negative_prompt", negative_prompt]) + + if fp8: + command.append("--fp8") + + if fp8_t5: + command.append("--fp8_t5") + + if exclude_single_blocks: + command.append("--exclude_single_blocks") + + # Add LoRA weights and multipliers if provided + valid_loras = [] + for weight, mult in zip([lora1, lora2, lora3, lora4], + [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier]): + if weight and weight != "None": + valid_loras.append((os.path.join(lora_folder, weight), mult)) + if valid_loras: + weights = [weight for weight, _ in valid_loras] + multipliers = [str(mult) for _, mult in valid_loras] + command.extend(["--lora_weight"] + weights) + command.extend(["--lora_multiplier"] + multipliers) + + print(f"Running: {' '.join(command)}") + + p = subprocess.Popen( + command, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=env, + text=True, + encoding='utf-8', + errors='replace', + bufsize=1 + ) + + videos = [] + + while True: + if stop_event.is_set(): + p.terminate() + p.wait() + yield [], "", "Generation stopped by user." + return + + line = p.stdout.readline() + if not line: + if p.poll() is not None: + break + continue + + print(line, end='') + if '|' in line and '%' in line and '[' in line and ']' in line: + yield videos.copy(), f"Processing (seed: {current_seed})", line.strip() + + p.stdout.close() + p.wait() + + clear_cuda_cache() + time.sleep(0.5) + + # Collect generated video + save_path_abs = os.path.abspath(save_path) + if os.path.exists(save_path_abs): + all_videos = sorted( + [f for f in os.listdir(save_path_abs) if f.endswith('.mp4')], + key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)), + reverse=True + ) + matching_videos = [v for v in all_videos if f"_{current_seed}" in v] + if matching_videos: + video_path = os.path.join(save_path_abs, matching_videos[0]) + + # Collect parameters for metadata + parameters = { + "prompt": prompt, + "width": width, + "height": height, + "video_length": video_length, + "fps": fps, + "infer_steps": infer_steps, + "seed": current_seed, + "task": task, + "flow_shift": flow_shift, + "guidance_scale": guidance_scale, + "output_type": output_type, + "attn_mode": attn_mode, + "block_swap": block_swap, + "input_image": input_image if "i2v" in task else None + } + + add_metadata_to_video(video_path, parameters) + videos.append((str(video_path), f"Seed: {current_seed}")) + + yield videos, f"Completed (seed: {current_seed})", "" + +def send_wanx_to_v2v( + gallery: list, + prompt: str, + selected_index: int, + width: int, + height: int, + video_length: int, + fps: int, + infer_steps: int, + seed: int, + flow_shift: float, + guidance_scale: float, + negative_prompt: str +) -> Tuple: + """Send the selected WanX video to Video2Video tab""" + if gallery is None or not gallery: + return (None, "", width, height, video_length, fps, infer_steps, seed, + flow_shift, guidance_scale, negative_prompt) + + # If no selection made but we have videos, use the first one + if selected_index is None and len(gallery) > 0: + selected_index = 0 + + if selected_index is None or selected_index >= len(gallery): + return (None, "", width, height, video_length, fps, infer_steps, seed, + flow_shift, guidance_scale, negative_prompt) + + selected_item = gallery[selected_index] + + # Handle different gallery item formats + if isinstance(selected_item, tuple): + video_path = 
selected_item[0] + elif isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + else: + video_path = selected_item + + # Clean up path for Video component + if isinstance(video_path, tuple): + video_path = video_path[0] + + # Make sure it's a string + video_path = str(video_path) + + return (video_path, prompt, width, height, video_length, fps, infer_steps, seed, + flow_shift, guidance_scale, negative_prompt) + +def wanx_generate_video_batch( + prompt, + negative_prompt, + width, + height, + video_length, + fps, + infer_steps, + flow_shift, + guidance_scale, + seed, + task, + dit_path, + vae_path, + t5_path, + clip_path, + save_path, + output_type, + sample_solver, + exclude_single_blocks, + attn_mode, + block_swap, + fp8, + fp8_t5, + lora_folder, + lora1="None", + lora2="None", + lora3="None", + lora4="None", + lora1_multiplier=1.0, + lora2_multiplier=1.0, + lora3_multiplier=1.0, + lora4_multiplier=1.0, + batch_size=1, + input_image=None # Make input_image optional and place it at the end +) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: + """Generate videos with WanX with support for batches""" + global stop_event + stop_event.clear() + + all_videos = [] + progress_text = "Starting generation..." + yield [], "Preparing...", progress_text + + # Process each item in the batch + for i in range(batch_size): + if stop_event.is_set(): + yield all_videos, "Generation stopped by user", "" + return + + # Calculate seed for this batch item + current_seed = seed + if seed == -1: + current_seed = random.randint(0, 2**32 - 1) + elif batch_size > 1: + current_seed = seed + i + + batch_text = f"Generating video {i + 1} of {batch_size}" + yield all_videos.copy(), batch_text, progress_text + + # Generate a single video using the existing function + for videos, status, progress in wanx_generate_video( + prompt, negative_prompt, input_image, width, height, + video_length, fps, infer_steps, flow_shift, guidance_scale, + current_seed, task, dit_path, vae_path, t5_path, clip_path, + save_path, output_type, sample_solver, exclude_single_blocks, + attn_mode, block_swap, fp8, fp8_t5, + lora_folder, + lora1, + lora2, + lora3, + lora4, + lora1_multiplier, + lora2_multiplier, + lora3_multiplier, + lora4_multiplier + ): + if videos: + all_videos.extend(videos) + yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress + + yield all_videos, "Batch complete", "" + +def update_wanx_t2v_dimensions(size): + """Update width and height based on selected size""" + width, height = map(int, size.split('*')) + return gr.update(value=width), gr.update(value=height) + +def handle_wanx_t2v_gallery_select(evt: gr.SelectData) -> int: + """Track selected index when gallery item is clicked""" + return evt.index + +def send_wanx_t2v_to_v2v( + gallery, prompt, selected_index, width, height, video_length, + fps, infer_steps, seed, flow_shift, guidance_scale, negative_prompt +) -> Tuple: + """Send the selected WanX T2V video to Video2Video tab""" + if not gallery or selected_index is None or selected_index >= len(gallery): + return (None, "", width, height, video_length, fps, infer_steps, seed, + flow_shift, guidance_scale, negative_prompt) + + selected_item = gallery[selected_index] + + if isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + elif isinstance(selected_item, (tuple, list)): + video_path = selected_item[0] + else: + video_path = selected_item + + if isinstance(video_path, tuple): + 
video_path = video_path[0] + + return (str(video_path), prompt, width, height, video_length, fps, infer_steps, seed, + flow_shift, guidance_scale, negative_prompt) + +def prepare_for_batch_extension(input_img, base_video, batch_size): + """Prepare inputs for batch video extension""" + if input_img is None: + return None, None, batch_size, "No input image found", "" + + if base_video is None: + return input_img, None, batch_size, "No base video selected for extension", "" + + return input_img, base_video, batch_size, "Preparing batch extension...", f"Will create {batch_size} variations of extended video" + +def concat_batch_videos(base_video_path, generated_videos, save_path, original_video_path=None): + """Concatenate multiple generated videos with the base video""" + if not base_video_path: + return [], "No base video provided" + + if not generated_videos or len(generated_videos) == 0: + return [], "No new videos generated" + + # Create output directory if it doesn't exist + os.makedirs(save_path, exist_ok=True) + + # Track all extended videos + extended_videos = [] + + # For each generated video, create an extended version + for i, video_item in enumerate(generated_videos): + try: + # Extract video path from gallery item + if isinstance(video_item, tuple): + new_video_path = video_item[0] + seed_info = video_item[1] if len(video_item) > 1 else "" + elif isinstance(video_item, dict): + new_video_path = video_item.get("name", video_item.get("data", None)) + seed_info = "" + else: + new_video_path = video_item + seed_info = "" + + if not new_video_path or not os.path.exists(new_video_path): + print(f"Skipping missing video: {new_video_path}") + continue + + # Create unique output filename + timestamp = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S") + # Extract seed from seed_info if available + seed_match = re.search(r"Seed: (\d+)", seed_info) + seed_part = f"_seed{seed_match.group(1)}" if seed_match else f"_{i}" + + output_filename = f"extended_{timestamp}{seed_part}_{Path(base_video_path).stem}.mp4" + output_path = os.path.join(save_path, output_filename) + + # Create a temporary file list for ffmpeg + list_file = os.path.join(save_path, f"temp_list_{i}.txt") + with open(list_file, "w") as f: + f.write(f"file '{os.path.abspath(base_video_path)}'\n") + f.write(f"file '{os.path.abspath(new_video_path)}'\n") + + # Run ffmpeg concatenation + command = [ + "ffmpeg", + "-f", "concat", + "-safe", "0", + "-i", list_file, + "-c", "copy", + output_path + ] + + subprocess.run(command, check=True, capture_output=True) + + # Clean up temporary file + if os.path.exists(list_file): + os.remove(list_file) + + # Add to extended videos list if successful + if os.path.exists(output_path): + seed_display = f"Extended {seed_info}" if seed_info else f"Extended video #{i+1}" + extended_videos.append((output_path, seed_display)) + + except Exception as e: + print(f"Error processing video {i}: {str(e)}") + + if not extended_videos: + return [], "Failed to create any extended videos" + + return extended_videos, f"Successfully created {len(extended_videos)} extended videos" + +def handle_extend_generation(base_video_path: str, new_videos: list, save_path: str, current_gallery: list) -> tuple: + """Combine generated video with base video and update gallery""" + if not base_video_path: + return current_gallery, "Extend failed: No base video provided" + + if not new_videos: + return current_gallery, "Extend failed: No new video generated" + + # Ensure save path exists + os.makedirs(save_path, 
exist_ok=True) + + # Get the first video from new_videos (gallery item) + new_video_path = new_videos[0][0] if isinstance(new_videos[0], tuple) else new_videos[0] + + # Create a unique output filename + timestamp = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S") + output_filename = f"extended_{timestamp}_{Path(base_video_path).stem}.mp4" + output_path = str(Path(save_path) / output_filename) + + try: + # Concatenate the videos using ffmpeg + ( + ffmpeg + .input(base_video_path) + .concat( + ffmpeg.input(new_video_path) + ) + .output(output_path) + .run(overwrite_output=True, quiet=True) + ) + + # Create a new gallery entry with the combined video + updated_gallery = [(output_path, f"Extended video: {Path(output_path).stem}")] + + return updated_gallery, f"Successfully extended video to {Path(output_path).name}" + except Exception as e: + print(f"Error extending video: {str(e)}") + return current_gallery, f"Failed to extend video: {str(e)}" + +# UI setup +with gr.Blocks( + theme=themes.Default( + primary_hue=colors.Color( + name="custom", + c50="#E6F0FF", + c100="#CCE0FF", + c200="#99C1FF", + c300="#66A3FF", + c400="#3384FF", + c500="#0060df", # This is your main color + c600="#0052C2", + c700="#003D91", + c800="#002961", + c900="#001430", + c950="#000A18" + ) + ), + css=""" + .gallery-item:first-child { border: 2px solid #4CAF50 !important; } + .gallery-item:first-child:hover { border-color: #45a049 !important; } + .green-btn { + background: linear-gradient(to bottom right, #2ecc71, #27ae60) !important; + color: white !important; + border: none !important; + } + .green-btn:hover { + background: linear-gradient(to bottom right, #27ae60, #219651) !important; + } + .refresh-btn { + max-width: 40px !important; + min-width: 40px !important; + height: 40px !important; + border-radius: 50% !important; + padding: 0 !important; + display: flex !important; + align-items: center !important; + justify-content: center !important; + } + """, + +) as demo: + # Add state for tracking selected video indices in both tabs + selected_index = gr.State(value=None) # For Text to Video + v2v_selected_index = gr.State(value=None) # For Video to Video + params_state = gr.State() #New addition + i2v_selected_index = gr.State(value=None) + skyreels_selected_index = gr.State(value=None) + wanx_i2v_selected_index = gr.State(value=None) + extended_videos = gr.State(value=[]) + wanx_base_video = gr.State(value=None) + wanx_sharpest_frame_number = gr.State(value=None) + wanx_sharpest_frame_path = gr.State(value=None) + wanx_trimmed_video_path = gr.State(value=None) + demo.load(None, None, None, js=""" + () => { + document.title = 'H1111'; + + function updateTitle(text) { + if (text && text.trim()) { + const progressMatch = text.match(/(\d+)%.*\[.*<(\d+:\d+),/); + if (progressMatch) { + const percentage = progressMatch[1]; + const timeRemaining = progressMatch[2]; + document.title = `[${percentage}% ETA: ${timeRemaining}] - H1111`; + } + } + } + + setTimeout(() => { + const progressElements = document.querySelectorAll('textarea.scroll-hide'); + progressElements.forEach(element => { + if (element) { + new MutationObserver(() => { + updateTitle(element.value); + }).observe(element, { + attributes: true, + childList: true, + characterData: true + }); + } + }); + }, 1000); + } + """) + + with gr.Tabs() as tabs: + # Text to Video Tab + with gr.Tab(id=1, label="Hunyuan-t2v"): + with gr.Row(): + with gr.Column(scale=4): + prompt = gr.Textbox(scale=3, label="Enter your prompt", value="POV video of a cat chasing a 
frob.", lines=5) + + with gr.Column(scale=1): + token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) + batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) + + with gr.Column(scale=2): + batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") + progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") + + with gr.Row(): + generate_btn = gr.Button("Generate Video", elem_classes="green-btn") + stop_btn = gr.Button("Stop Generation", variant="stop") + + with gr.Row(): + with gr.Column(): + + t2v_width = gr.Slider(minimum=64, maximum=1536, step=16, value=544, label="Video Width") + t2v_height = gr.Slider(minimum=64, maximum=1536, step=16, value=544, label="Video Height") + video_length = gr.Slider(minimum=1, maximum=201, step=1, label="Video Length in Frames", value=25, elem_id="my_special_slider") + fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=24, elem_id="my_special_slider") + infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30, elem_id="my_special_slider") + flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=11.0, elem_id="my_special_slider") + cfg_scale = gr.Slider(minimum=0.0, maximum=14.0, step=0.1, label="cfg Scale", value=7.0, elem_id="my_special_slider") + + with gr.Column(): + + with gr.Row(): + video_output = gr.Gallery( + label="Generated Videos (Click to select)", + columns=[2], + rows=[2], + object_fit="contain", + height="auto", + show_label=True, + elem_id="gallery", + allow_preview=True, + preview=True + ) + with gr.Row():send_t2v_to_v2v_btn = gr.Button("Send Selected to Video2Video") + + with gr.Row(): + refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") + lora_weights = [] + lora_multipliers = [] + for i in range(4): + with gr.Column(): + lora_weights.append(gr.Dropdown( + label=f"LoRA {i+1}", + choices=get_lora_options(), + value="None", + allow_custom_value=True, + interactive=True + )) + lora_multipliers.append(gr.Slider( + label=f"Multiplier", + minimum=0.0, + maximum=2.0, + step=0.05, + value=1.0 + )) + with gr.Row(): + exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) + seed = gr.Number(label="Seed (use -1 for random)", value=-1) + dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") + model = gr.Dropdown( + label="DiT Model", + choices=get_dit_models("hunyuan"), + value="mp_rank_00_model_states.pt", + allow_custom_value=True, + interactive=True + ) + vae = gr.Textbox(label="vae", value="hunyuan/pytorch_model.pt") + te1 = gr.Textbox(label="te1", value="hunyuan/llava_llama3_fp16.safetensors") + te2 = gr.Textbox(label="te2", value="hunyuan/clip_l.safetensors") + save_path = gr.Textbox(label="Save Path", value="outputs") + with gr.Row(): + lora_folder = gr.Textbox(label="LoRA Folder", value="lora") + output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") + use_split_attn = gr.Checkbox(label="Use Split Attention", value=False) + use_fp8 = gr.Checkbox(label="Use FP8 (faster but lower precision)", value=True) + attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") + block_swap = gr.Slider(minimum=0, maximum=36, step=1, label="Block Swap to Save Vram", value=0) + + #Image to Video Tab + with gr.Tab(label="Hunyuan-i2v") as i2v_tab: + with gr.Row(): + with gr.Column(scale=4): + i2v_prompt = gr.Textbox(scale=3, 
label="Enter your prompt", value="POV video of a cat chasing a frob.", lines=5) + i2v_negative_prompt = gr.Textbox(label="Negative Prompt", value="", lines=2, info="Negative prompt") + + with gr.Column(scale=1): + i2v_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) + i2v_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) + + with gr.Column(scale=2): + i2v_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") + i2v_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") + + with gr.Row(): + i2v_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") + i2v_stop_btn = gr.Button("Stop Generation", variant="stop") + + with gr.Row(): + with gr.Column(): + i2v_input = gr.Image(label="Input Image", type="filepath") + i2v_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.75, label="Denoise Strength") + # Scale slider as percentage + scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale %") + original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) + # Width and height inputs + with gr.Row(): + width = gr.Number(label="New Width", value=544, step=16) + calc_height_btn = gr.Button("→") + calc_width_btn = gr.Button("←") + height = gr.Number(label="New Height", value=544, step=16) + i2v_video_length = gr.Slider(minimum=1, maximum=201, step=1, label="Video Length in Frames", value=25) + i2v_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=24) + i2v_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30) + i2v_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=11.0) + i2v_cfg_scale = gr.Slider(minimum=0.0, maximum=14.0, step=0.1, label="cfg scale", value=7.0) + with gr.Column(): + i2v_output = gr.Gallery( + label="Generated Videos (Click to select)", + columns=[2], + rows=[2], + object_fit="contain", + height="auto", + show_label=True, + elem_id="gallery", + allow_preview=True, + preview=True + ) + i2v_send_to_v2v_btn = gr.Button("Send Selected to Video2Video") + + # Add LoRA section for Image2Video + i2v_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") + i2v_lora_weights = [] + i2v_lora_multipliers = [] + for i in range(4): + with gr.Column(): + i2v_lora_weights.append(gr.Dropdown( + label=f"LoRA {i+1}", + choices=get_lora_options(), + value="None", + allow_custom_value=True, + interactive=True + )) + i2v_lora_multipliers.append(gr.Slider( + label=f"Multiplier", + minimum=0.0, + maximum=2.0, + step=0.05, + value=1.0 + )) + + with gr.Row(): + i2v_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) + i2v_seed = gr.Number(label="Seed (use -1 for random)", value=-1) + i2v_dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") + i2v_model = gr.Dropdown( + label="DiT Model", + choices=get_dit_models("hunyuan"), + value="mp_rank_00_model_states.pt", + allow_custom_value=True, + interactive=True + ) + + i2v_vae = gr.Textbox(label="vae", value="hunyuan/pytorch_model.pt") + i2v_te1 = gr.Textbox(label="te1", value="hunyuan/llava-llama-3-8b-text-encoder-tokenizer") + i2v_te2 = gr.Textbox(label="te2", value="hunyuan/clip_l.safetensors") + i2v_clip_vision_path = gr.Textbox(label="CLIP Vision Path", value="hunyuan/clip-vit-large-patch14", info="Path to CLIP vision model for HunyuanI2V") + i2v_save_path = gr.Textbox(label="Save Path", value="outputs") + with gr.Row(): + i2v_lora_folder = 
gr.Textbox(label="LoRA Folder", value="lora") + i2v_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") + i2v_use_split_attn = gr.Checkbox(label="Use Split Attention", value=False) + i2v_use_fp8 = gr.Checkbox(label="Use FP8 (faster but lower precision)", value=True) + i2v_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") + i2v_block_swap = gr.Slider(minimum=0, maximum=36, step=1, label="Block Swap to Save Vram", value=0) + i2v_split_uncond = gr.Checkbox(label="Split Unconditional", value=True, visible=True) + with gr.Row(): + i2v_stability = gr.Checkbox(label="I2V Stability", value=False, info="Enable stability mode for HunyuanI2V") + i2v_fp8_fast = gr.Checkbox(label="FP8 Fast", value=False, info="Enable fast FP8 arithmetic (RTX 4XXX+)") + i2v_compile = gr.Checkbox(label="Compile Model", value=False, info="Enable torch.compile for potentially faster generation") + i2v_compile_backend = gr.Dropdown(label="Compile Backend", choices=["inductor", "cudagraphs", "onnxrt", "nvfuser"], value="inductor", info="Torch compile backend") + i2v_compile_mode = gr.Dropdown(label="Compile Mode", choices=["default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"], value="max-autotune-no-cudagraphs", info="Torch compile mode") + i2v_compile_dynamic = gr.Checkbox(label="Dynamic Shapes", value=False, info="Use dynamic shapes in compilation") + i2v_compile_fullgraph = gr.Checkbox(label="Full Graph", value=False, info="Use full graph compilation") + # Video to Video Tab + with gr.Tab(id=2, label="Hunyuan-v2v") as v2v_tab: + with gr.Row(): + with gr.Column(scale=4): + v2v_prompt = gr.Textbox(scale=3, label="Enter your prompt", value="POV video of a cat chasing a frob.", lines=5) + v2v_negative_prompt = gr.Textbox( + scale=3, + label="Negative Prompt (for SkyReels models)", + value="Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion", + lines=3 + ) + + with gr.Column(scale=1): + v2v_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) + v2v_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) + + with gr.Column(scale=2): + v2v_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") + v2v_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") + + with gr.Row(): + v2v_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") + v2v_stop_btn = gr.Button("Stop Generation", variant="stop") + + with gr.Row(): + with gr.Column(): + v2v_input = gr.Video(label="Input Video", format="mp4") + v2v_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.75, label="Denoise Strength") + v2v_scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale %") + v2v_original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) + + # Width and Height Inputs + with gr.Row(): + v2v_width = gr.Number(label="New Width", value=544, step=16) + v2v_calc_height_btn = gr.Button("→") + v2v_calc_width_btn = gr.Button("←") + v2v_height = gr.Number(label="New Height", value=544, step=16) + v2v_video_length = gr.Slider(minimum=1, maximum=201, step=1, label="Video Length in Frames", value=25) + v2v_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=24) + v2v_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", 
value=30) + v2v_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=11.0) + v2v_cfg_scale = gr.Slider(minimum=0.0, maximum=14.0, step=0.1, label="cfg scale", value=7.0) + with gr.Column(): + v2v_output = gr.Gallery( + label="Generated Videos", + columns=[1], + rows=[1], + object_fit="contain", + height="auto" + ) + v2v_send_to_input_btn = gr.Button("Send Selected to Input") # New button + v2v_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") + v2v_lora_weights = [] + v2v_lora_multipliers = [] + for i in range(4): + with gr.Column(): + v2v_lora_weights.append(gr.Dropdown( + label=f"LoRA {i+1}", + choices=get_lora_options(), + value="None", + allow_custom_value=True, + interactive=True + )) + v2v_lora_multipliers.append(gr.Slider( + label=f"Multiplier", + minimum=0.0, + maximum=2.0, + step=0.05, + value=1.0 + )) + + with gr.Row(): + v2v_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) + v2v_seed = gr.Number(label="Seed (use -1 for random)", value=-1) + v2v_dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") + v2v_model = gr.Dropdown( + label="DiT Model", + choices=get_dit_models("hunyuan"), + value="mp_rank_00_model_states.pt", + allow_custom_value=True, + interactive=True + ) + v2v_vae = gr.Textbox(label="vae", value="hunyuan/pytorch_model.pt") + v2v_te1 = gr.Textbox(label="te1", value="hunyuan/llava_llama3_fp16.safetensors") + v2v_te2 = gr.Textbox(label="te2", value="hunyuan/clip_l.safetensors") + v2v_save_path = gr.Textbox(label="Save Path", value="outputs") + with gr.Row(): + v2v_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") + v2v_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") + v2v_use_split_attn = gr.Checkbox(label="Use Split Attention", value=False) + v2v_use_fp8 = gr.Checkbox(label="Use FP8 (faster but lower precision)", value=True) + v2v_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") + v2v_block_swap = gr.Slider(minimum=0, maximum=36, step=1, label="Block Swap to Save Vram", value=0) + v2v_split_uncond = gr.Checkbox(label="Split Unconditional (for SkyReels)", value=True) + +### SKYREELS + + with gr.Tab(label="SkyReels-i2v") as skyreels_tab: + with gr.Row(): + with gr.Column(scale=4): + skyreels_prompt = gr.Textbox( + scale=3, + label="Enter your prompt", + value="A person walking on a beach at sunset", + lines=5 + ) + skyreels_negative_prompt = gr.Textbox( + scale=3, + label="Negative Prompt", + value="Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion", + lines=3 + ) + + with gr.Column(scale=1): + skyreels_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) + skyreels_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) + + with gr.Column(scale=2): + skyreels_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") + skyreels_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") + + with gr.Row(): + skyreels_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") + skyreels_stop_btn = gr.Button("Stop Generation", variant="stop") + + with gr.Row(): + with gr.Column(): + skyreels_input = gr.Image(label="Input Image (optional)", type="filepath") + with gr.Row(): + skyreels_use_random_folder = gr.Checkbox(label="Use Random Images from Folder", value=False) + 
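+                            # These folder-mode fields start hidden; the
+                            # skyreels_use_random_folder.change handlers below toggle their visibility.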
skyreels_input_folder = gr.Textbox( + label="Image Folder Path", + placeholder="Path to folder containing images", + visible=False + ) + skyreels_folder_status = gr.Textbox( + label="Folder Status", + placeholder="Status will appear here", + interactive=False, + visible=False + ) + skyreels_validate_folder_btn = gr.Button("Validate Folder", visible=False) + skyreels_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.75, label="Denoise Strength") + + # Scale slider as percentage + skyreels_scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale %") + skyreels_original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) + + # Width and height inputs + with gr.Row(): + skyreels_width = gr.Number(label="New Width", value=544, step=16) + skyreels_calc_height_btn = gr.Button("→") + skyreels_calc_width_btn = gr.Button("←") + skyreels_height = gr.Number(label="New Height", value=544, step=16) + + skyreels_video_length = gr.Slider(minimum=1, maximum=201, step=1, label="Video Length in Frames", value=25) + skyreels_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=24) + skyreels_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30) + skyreels_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=11.0) + skyreels_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.1, label="Guidance Scale", value=6.0) + skyreels_embedded_cfg_scale = gr.Slider(minimum=0.0, maximum=10.0, step=0.1, label="Embedded CFG Scale", value=1.0) + + with gr.Column(): + skyreels_output = gr.Gallery( + label="Generated Videos (Click to select)", + columns=[2], + rows=[2], + object_fit="contain", + height="auto", + show_label=True, + elem_id="gallery", + allow_preview=True, + preview=True + ) + skyreels_send_to_v2v_btn = gr.Button("Send Selected to Video2Video") + + # Add LoRA section for SKYREELS + skyreels_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") + skyreels_lora_weights = [] + skyreels_lora_multipliers = [] + for i in range(4): + with gr.Column(): + skyreels_lora_weights.append(gr.Dropdown( + label=f"LoRA {i+1}", + choices=get_lora_options(), + value="None", + allow_custom_value=True, + interactive=True + )) + skyreels_lora_multipliers.append(gr.Slider( + label=f"Multiplier", + minimum=0.0, + maximum=2.0, + step=0.05, + value=1.0 + )) + with gr.Row(): + skyreels_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) + skyreels_seed = gr.Number(label="Seed (use -1 for random)", value=-1) + skyreels_dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") + skyreels_model = gr.Dropdown( + label="DiT Model", + choices=get_dit_models("skyreels"), + value="skyreels_hunyuan_i2v_bf16.safetensors", + allow_custom_value=True, + interactive=True + ) + skyreels_vae = gr.Textbox(label="vae", value="hunyuan/pytorch_model.pt") + skyreels_te1 = gr.Textbox(label="te1", value="hunyuan/llava_llama3_fp16.safetensors") + skyreels_te2 = gr.Textbox(label="te2", value="hunyuan/clip_l.safetensors") + skyreels_save_path = gr.Textbox(label="Save Path", value="outputs") + + with gr.Row(): + skyreels_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") + skyreels_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") + skyreels_use_split_attn = gr.Checkbox(label="Use Split Attention", value=False) + skyreels_use_fp8 = gr.Checkbox(label="Use FP8 (faster but lower precision)", 
value=True) + skyreels_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") + skyreels_block_swap = gr.Slider(minimum=0, maximum=36, step=1, label="Block Swap to Save Vram", value=0) + skyreels_split_uncond = gr.Checkbox(label="Split Unconditional", value=True) + + # WanX Image to Video Tab + with gr.Tab(id=4, label="WanX-i2v") as wanx_i2v_tab: + with gr.Row(): + with gr.Column(scale=4): + wanx_prompt = gr.Textbox( + scale=3, + label="Enter your prompt", + value="A person walking on a beach at sunset", + lines=5 + ) + wanx_negative_prompt = gr.Textbox( + scale=3, + label="Negative Prompt", + value="", + lines=3, + info="Leave empty to use default negative prompt" + ) + + with gr.Column(scale=1): + wanx_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) + wanx_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) + + with gr.Column(scale=2): + wanx_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") + wanx_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") + + with gr.Row(): + wanx_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") + wanx_stop_btn = gr.Button("Stop Generation", variant="stop") + + with gr.Row(): + with gr.Column(): + wanx_input = gr.Image(label="Input Image", type="filepath") + with gr.Row(): + wanx_use_random_folder = gr.Checkbox(label="Use Random Images from Folder", value=False) + wanx_input_folder = gr.Textbox( + label="Image Folder Path", + placeholder="Path to folder containing images", + visible=False + ) + wanx_folder_status = gr.Textbox( + label="Folder Status", + placeholder="Status will appear here", + interactive=False, + visible=False + ) + wanx_validate_folder_btn = gr.Button("Validate Folder", visible=False) + wanx_scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale %") + wanx_original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) + + # Width and height display + with gr.Row(): + wanx_width = gr.Number(label="Width", value=832, interactive=True) + wanx_calc_height_btn = gr.Button("→") + wanx_calc_width_btn = gr.Button("←") + wanx_height = gr.Number(label="Height", value=480, interactive=True) + wanx_recommend_flow_btn = gr.Button("Recommend Flow Shift", size="sm") + + wanx_video_length = gr.Slider(minimum=1, maximum=201, step=4, label="Video Length in Frames", value=81) + wanx_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=16) + wanx_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=20) + wanx_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=3.0, + info="Recommended: 3.0 for 480p, 5.0 for others") + wanx_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.1, label="Guidance Scale", value=5.0) + + with gr.Column(): + wanx_output = gr.Gallery( + label="Generated Videos (Click to select)", + columns=[2], + rows=[2], + object_fit="contain", + height="auto", + show_label=True, + elem_id="gallery", + allow_preview=True, + preview=True + ) + wanx_send_to_v2v_btn = gr.Button("Send Selected to Hunyuan-v2v") + wanx_send_last_frame_btn = gr.Button("Send Last Frame to Input") + wanx_extend_btn = gr.Button("Extend Video") + wanx_frames_to_check = gr.Slider(minimum=1, maximum=100, step=1, value=30, + label="Frames to Check from End", + info="Number of frames from the end to check for sharpness") + 
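+                        # Sharpest-frame workflow: extract the sharpest of the last N frames,
+                        # trim the base video at that frame, then extend from the trimmed clip.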
wanx_send_sharpest_frame_btn = gr.Button("Extract Sharpest Frame") + wanx_trim_and_extend_btn = gr.Button("Trim Video & Prepare for Extension") + wanx_sharpest_frame_status = gr.Textbox(label="Status", interactive=False) + + # Add a new button for directly extending with the trimmed video + wanx_extend_with_trimmed_btn = gr.Button("Extend with Trimmed Video") + + # Add LoRA section for WanX-i2v similar to other tabs + wanx_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") + wanx_lora_weights = [] + wanx_lora_multipliers = [] + for i in range(4): + with gr.Column(): + wanx_lora_weights.append(gr.Dropdown( + label=f"LoRA {i+1}", + choices=get_lora_options(), + value="None", + allow_custom_value=True, + interactive=True + )) + wanx_lora_multipliers.append(gr.Slider( + label=f"Multiplier", + minimum=0.0, + maximum=2.0, + step=0.05, + value=1.0 + )) + + with gr.Row(): + wanx_seed = gr.Number(label="Seed (use -1 for random)", value=-1) + wanx_task = gr.Dropdown( + label="Task", + choices=["i2v-14B"], + value="i2v-14B", + info="Currently only i2v-14B is supported" + ) + wanx_dit_path = gr.Textbox(label="DiT Model Path", value="wan/wan2.1_i2v_480p_14B_bf16.safetensors") + wanx_vae_path = gr.Textbox(label="VAE Path", value="wan/Wan2.1_VAE.pth") + wanx_t5_path = gr.Textbox(label="T5 Path", value="wan/models_t5_umt5-xxl-enc-bf16.pth") + wanx_clip_path = gr.Textbox(label="CLIP Path", value="wan/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth") + wanx_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") + wanx_save_path = gr.Textbox(label="Save Path", value="outputs") + + with gr.Row(): + wanx_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") + wanx_sample_solver = gr.Radio(choices=["unipc", "dpm++", "vanilla"], label="Sample Solver", value="unipc") + wanx_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) + wanx_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") + wanx_block_swap = gr.Slider(minimum=0, maximum=39, step=1, label="Block Swap to Save VRAM", value=0) + wanx_fp8 = gr.Checkbox(label="Use FP8", value=True) + wanx_fp8_t5 = gr.Checkbox(label="Use FP8 for T5", value=False) + + #WanX-t2v Tab + + # WanX Text to Video Tab + with gr.Tab(id=5, label="WanX-t2v") as wanx_t2v_tab: + with gr.Row(): + with gr.Column(scale=4): + wanx_t2v_prompt = gr.Textbox( + scale=3, + label="Enter your prompt", + value="A person walking on a beach at sunset", + lines=5 + ) + wanx_t2v_negative_prompt = gr.Textbox( + scale=3, + label="Negative Prompt", + value="", + lines=3, + info="Leave empty to use default negative prompt" + ) + + with gr.Column(scale=1): + wanx_t2v_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) + wanx_t2v_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) + + with gr.Column(scale=2): + wanx_t2v_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") + wanx_t2v_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") + + with gr.Row(): + wanx_t2v_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") + wanx_t2v_stop_btn = gr.Button("Stop Generation", variant="stop") + + with gr.Row(): + with gr.Column(): + with gr.Row(): + wanx_t2v_width = gr.Number(label="Width", value=832, interactive=True, info="Should be divisible by 32") + wanx_t2v_height = gr.Number(label="Height", value=480, interactive=True, info="Should be 
divisible by 32") + wanx_t2v_recommend_flow_btn = gr.Button("Recommend Flow Shift", size="sm") + + wanx_t2v_video_length = gr.Slider(minimum=1, maximum=201, step=4, label="Video Length in Frames", value=81) + wanx_t2v_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=16) + wanx_t2v_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=20) + wanx_t2v_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=5.0, + info="Recommended: 3.0 for I2V with 480p, 5.0 for others") + wanx_t2v_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.1, label="Guidance Scale", value=5.0) + + with gr.Column(): + wanx_t2v_output = gr.Gallery( + label="Generated Videos (Click to select)", + columns=[2], + rows=[2], + object_fit="contain", + height="auto", + show_label=True, + elem_id="gallery", + allow_preview=True, + preview=True + ) + wanx_t2v_send_to_v2v_btn = gr.Button("Send Selected to Video2Video") + + # Add LoRA section for WanX-t2v + wanx_t2v_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") + wanx_t2v_lora_weights = [] + wanx_t2v_lora_multipliers = [] + for i in range(4): + with gr.Column(): + wanx_t2v_lora_weights.append(gr.Dropdown( + label=f"LoRA {i+1}", + choices=get_lora_options(), + value="None", + allow_custom_value=True, + interactive=True + )) + wanx_t2v_lora_multipliers.append(gr.Slider( + label=f"Multiplier", + minimum=0.0, + maximum=2.0, + step=0.05, + value=1.0 + )) + + with gr.Row(): + wanx_t2v_seed = gr.Number(label="Seed (use -1 for random)", value=-1) + wanx_t2v_task = gr.Dropdown( + label="Task", + choices=["t2v-1.3B", "t2v-14B", "t2i-14B"], + value="t2v-14B", + info="Select model size: t2v-1.3B is faster, t2v-14B has higher quality" + ) + wanx_t2v_dit_path = gr.Textbox(label="DiT Model Path", value="wan/wan2.1_t2v_14B_bf16.safetensors") + wanx_t2v_vae_path = gr.Textbox(label="VAE Path", value="wan/Wan2.1_VAE.pth") + wanx_t2v_t5_path = gr.Textbox(label="T5 Path", value="wan/models_t5_umt5-xxl-enc-bf16.pth") + wanx_t2v_clip_path = gr.Textbox(label="CLIP Path", visible=False, value="") + wanx_t2v_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") + wanx_t2v_save_path = gr.Textbox(label="Save Path", value="outputs") + + with gr.Row(): + wanx_t2v_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") + wanx_t2v_sample_solver = gr.Radio(choices=["unipc", "dpm++", "vanilla"], label="Sample Solver", value="unipc") + wanx_t2v_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) + wanx_t2v_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") + wanx_t2v_block_swap = gr.Slider(minimum=0, maximum=39, step=1, label="Block Swap to Save VRAM", value=0, + info="Max 39 for 14B model, 29 for 1.3B model") + wanx_t2v_fp8 = gr.Checkbox(label="Use FP8", value=True) + wanx_t2v_fp8_t5 = gr.Checkbox(label="Use FP8 for T5", value=False) + + #Video Info Tab + with gr.Tab("Video Info") as video_info_tab: + with gr.Row(): + video_input = gr.Video(label="Upload Video", interactive=True) + metadata_output = gr.JSON(label="Generation Parameters") + + with gr.Row(): + send_to_t2v_btn = gr.Button("Send to Text2Video", variant="primary") + send_to_v2v_btn = gr.Button("Send to Video2Video", variant="primary") + send_to_wanx_i2v_btn = gr.Button("Send to WanX-i2v", variant="primary") + send_to_wanx_t2v_btn = gr.Button("Send to WanX-t2v", variant="primary") + + with 
gr.Row(): + status = gr.Textbox(label="Status", interactive=False) + + #Merge Model's tab + with gr.Tab("Convert LoRA") as convert_lora_tab: + def suggest_output_name(file_obj) -> str: + """Generate suggested output name from input file""" + if not file_obj: + return "" + # Get input filename without extension and add MUSUBI + base_name = os.path.splitext(os.path.basename(file_obj.name))[0] + return f"{base_name}_MUSUBI" + + def convert_lora(input_file, output_name: str, target_format: str) -> str: + """Convert LoRA file to specified format""" + try: + if not input_file: + return "Error: No input file selected" + + # Ensure output directory exists + os.makedirs("lora", exist_ok=True) + + # Construct output path + output_path = os.path.join("lora", f"{output_name}.safetensors") + + # Build command + cmd = [ + sys.executable, + "convert_lora.py", + "--input", input_file.name, + "--output", output_path, + "--target", target_format + ] + + print(f"Converting {input_file.name} to {output_path}") + + # Execute conversion + result = subprocess.run( + cmd, + capture_output=True, + text=True, + check=True + ) + + if os.path.exists(output_path): + return f"Successfully converted LoRA to {output_path}" + else: + return "Error: Output file not created" + + except subprocess.CalledProcessError as e: + return f"Error during conversion: {e.stderr}" + except Exception as e: + return f"Error: {str(e)}" + + with gr.Row(): + input_file = gr.File(label="Input LoRA File", file_types=[".safetensors"]) + output_name = gr.Textbox(label="Output Name", placeholder="Output filename (without extension)") + format_radio = gr.Radio( + choices=["default", "other"], + value="default", + label="Target Format", + info="Choose 'default' for H1111/MUSUBI format or 'other' for diffusion pipe format" + ) + + with gr.Row(): + convert_btn = gr.Button("Convert LoRA", variant="primary") + status_output = gr.Textbox(label="Status", interactive=False) + + # Automatically update output name when file is selected + input_file.change( + fn=suggest_output_name, + inputs=[input_file], + outputs=[output_name] + ) + + # Handle conversion + convert_btn.click( + fn=convert_lora, + inputs=[input_file, output_name, format_radio], + outputs=status_output + ) + with gr.Tab("Model Merging") as model_merge_tab: + with gr.Row(): + with gr.Column(): + # Model selection + dit_model = gr.Dropdown( + label="Base DiT Model", + choices=["mp_rank_00_model_states.pt"], + value="mp_rank_00_model_states.pt", + allow_custom_value=True, + interactive=True + ) + merge_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") + with gr.Row(): + with gr.Column(): + # Output model name + output_model = gr.Textbox(label="Output Model Name", value="merged_model.safetensors") + exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) + merge_btn = gr.Button("Merge Models", variant="primary") + merge_status = gr.Textbox(label="Status", interactive=False) + with gr.Row(): + # LoRA selection section (similar to Text2Video) + merge_lora_weights = [] + merge_lora_multipliers = [] + for i in range(4): + with gr.Column(): + merge_lora_weights.append(gr.Dropdown( + label=f"LoRA {i+1}", + choices=get_lora_options(), + value="None", + allow_custom_value=True, + interactive=True + )) + merge_lora_multipliers.append(gr.Slider( + label=f"Multiplier", + minimum=0.0, + maximum=2.0, + step=0.05, + value=1.0 + )) + with gr.Row(): + merge_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") + dit_folder = gr.Textbox(label="DiT Model Folder", 
value="hunyuan") + + #Video Extension + wanx_send_last_frame_btn.click( + fn=send_last_frame_handler, + inputs=[wanx_output, wanx_i2v_selected_index], + outputs=[wanx_input, wanx_base_video] + ) + + wanx_extend_btn.click( + fn=prepare_for_batch_extension, + inputs=[wanx_input, wanx_base_video, wanx_batch_size], + outputs=[wanx_input, wanx_base_video, wanx_batch_size, wanx_batch_progress, wanx_progress_text] + ).then( + fn=wanx_batch_handler, + inputs=[ + gr.Checkbox(value=False), # Not using random folder + wanx_prompt, wanx_negative_prompt, + wanx_width, wanx_height, wanx_video_length, + wanx_fps, wanx_infer_steps, wanx_flow_shift, + wanx_guidance_scale, wanx_seed, wanx_batch_size, + wanx_input_folder, # Not used but needed for function signature + wanx_task, + wanx_dit_path, wanx_vae_path, wanx_t5_path, + wanx_clip_path, wanx_save_path, wanx_output_type, + wanx_sample_solver, wanx_exclude_single_blocks, + wanx_attn_mode, wanx_block_swap, wanx_fp8, + wanx_fp8_t5, wanx_lora_folder, *wanx_lora_weights, + *wanx_lora_multipliers, wanx_input # Include input image + ], + outputs=[wanx_output, wanx_batch_progress, wanx_progress_text] + ).then( + fn=concat_batch_videos, + inputs=[wanx_base_video, wanx_output, wanx_save_path], + outputs=[wanx_output, wanx_progress_text] + ) + + # Extract and send sharpest frame to input + wanx_send_sharpest_frame_btn.click( + fn=send_sharpest_frame_handler, + inputs=[wanx_output, wanx_i2v_selected_index, wanx_frames_to_check], + outputs=[wanx_input, wanx_base_video, wanx_sharpest_frame_number, wanx_sharpest_frame_status] + ) + + # Trim video to sharpest frame and prepare for extension + wanx_trim_and_extend_btn.click( + fn=trim_and_prepare_for_extension, + inputs=[wanx_base_video, wanx_sharpest_frame_number, wanx_save_path], + outputs=[wanx_trimmed_video_path, wanx_sharpest_frame_status] + ).then( + fn=lambda path, status: (path, status if "Failed" in status else "Video trimmed successfully and ready for extension"), + inputs=[wanx_trimmed_video_path, wanx_sharpest_frame_status], + outputs=[wanx_base_video, wanx_sharpest_frame_status] + ) + + # Event handler for extending with the trimmed video + wanx_extend_with_trimmed_btn.click( + fn=prepare_for_batch_extension, + inputs=[wanx_input, wanx_trimmed_video_path, wanx_batch_size], + outputs=[wanx_input, wanx_base_video, wanx_batch_size, wanx_batch_progress, wanx_progress_text] + ).then( + fn=wanx_batch_handler, + inputs=[ + gr.Checkbox(value=False), # Not using random folder + wanx_prompt, wanx_negative_prompt, + wanx_width, wanx_height, wanx_video_length, + wanx_fps, wanx_infer_steps, wanx_flow_shift, + wanx_guidance_scale, wanx_seed, wanx_batch_size, + wanx_input_folder, # Not used but needed for function signature + wanx_task, + wanx_dit_path, wanx_vae_path, wanx_t5_path, + wanx_clip_path, wanx_save_path, wanx_output_type, + wanx_sample_solver, wanx_exclude_single_blocks, + wanx_attn_mode, wanx_block_swap, wanx_fp8, + wanx_fp8_t5, wanx_lora_folder, *wanx_lora_weights, + *wanx_lora_multipliers, wanx_input # Include input image + ], + outputs=[wanx_output, wanx_batch_progress, wanx_progress_text] + ).then( + fn=concat_batch_videos, + inputs=[wanx_trimmed_video_path, wanx_output, wanx_save_path], + outputs=[wanx_output, wanx_progress_text] + ) + + #Video Info + def handle_send_to_wanx_tab(metadata, target_tab): + """Common handler for sending video parameters to WanX tabs""" + if not metadata: + return "No parameters to send", {} + + # Tab names for clearer messages + tab_names = { + 'wanx_i2v': 'WanX-i2v', + 
'wanx_t2v': 'WanX-t2v' + } + + # Just pass through all parameters - we'll use them in the .then() function + return f"Parameters ready for {tab_names.get(target_tab, target_tab)}", metadata + + def change_to_wanx_i2v_tab(): + return gr.Tabs(selected=4) # WanX-i2v tab index + + def change_to_wanx_t2v_tab(): + return gr.Tabs(selected=5) # WanX-t2v tab index + + send_to_wanx_i2v_btn.click( + fn=lambda m: handle_send_to_wanx_tab(m, 'wanx_i2v'), + inputs=[metadata_output], + outputs=[status, params_state] + ).then( + lambda params: [ + params.get("prompt", ""), + params.get("width", 832), + params.get("height", 480), + params.get("video_length", 81), + params.get("fps", 16), + params.get("infer_steps", 40), + params.get("seed", -1), + params.get("flow_shift", 3.0), + params.get("guidance_scale", 5.0), + params.get("attn_mode", "sdpa"), + params.get("block_swap", 0), + params.get("task", "i2v-14B") + ] if params else [gr.update()]*12, + inputs=params_state, + outputs=[ + wanx_prompt, + wanx_width, + wanx_height, + wanx_video_length, + wanx_fps, + wanx_infer_steps, + wanx_seed, + wanx_flow_shift, + wanx_guidance_scale, + wanx_attn_mode, + wanx_block_swap, + wanx_task + ] + ).then( + fn=change_to_wanx_i2v_tab, inputs=None, outputs=[tabs] + ) + + # 3. Update the WanX-t2v button handler + send_to_wanx_t2v_btn.click( + fn=lambda m: handle_send_to_wanx_tab(m, 'wanx_t2v'), + inputs=[metadata_output], + outputs=[status, params_state] + ).then( + lambda params: [ + params.get("prompt", ""), + params.get("width", 832), + params.get("height", 480), + params.get("video_length", 81), + params.get("fps", 16), + params.get("infer_steps", 50), + params.get("seed", -1), + params.get("flow_shift", 5.0), + params.get("guidance_scale", 5.0), + params.get("attn_mode", "sdpa"), + params.get("block_swap", 0) + ] if params else [gr.update()]*11, + inputs=params_state, + outputs=[ + wanx_t2v_prompt, + wanx_t2v_width, + wanx_t2v_height, + wanx_t2v_video_length, + wanx_t2v_fps, + wanx_t2v_infer_steps, + wanx_t2v_seed, + wanx_t2v_flow_shift, + wanx_t2v_guidance_scale, + wanx_t2v_attn_mode, + wanx_t2v_block_swap + ] + ).then( + fn=change_to_wanx_t2v_tab, inputs=None, outputs=[tabs] + ) + + #text to video + def change_to_tab_one(): + return gr.Tabs(selected=1) #This will navigate + #video to video + def change_to_tab_two(): + return gr.Tabs(selected=2) #This will navigate + def change_to_skyreels_tab(): + return gr.Tabs(selected=3) + + #SKYREELS TAB!!! 
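+    # Helpers and event wiring for the SkyReels tab: dimension syncing, LoRA dropdown
+    # refresh, gallery selection, and sending results to the Hunyuan-v2v tab.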
+ # Add state management for dimensions + def sync_skyreels_dimensions(width, height): + return gr.update(value=width), gr.update(value=height) + + # Add this function to update the LoRA dropdowns in the SKYREELS tab + def update_skyreels_lora_dropdowns(lora_folder: str, *current_values) -> List[gr.update]: + new_choices = get_lora_options(lora_folder) + weights = current_values[:4] + multipliers = current_values[4:8] + + results = [] + for i in range(4): + weight = weights[i] if i < len(weights) else "None" + multiplier = multipliers[i] if i < len(multipliers) else 1.0 + if weight not in new_choices: + weight = "None" + results.extend([ + gr.update(choices=new_choices, value=weight), + gr.update(value=multiplier) + ]) + + return results + + # Add this function to update the models dropdown in the SKYREELS tab + def update_skyreels_model_dropdown(dit_folder: str) -> Dict: + models = get_dit_models(dit_folder) + return gr.update(choices=models, value=models[0] if models else None) + + # Add event handler for model dropdown refresh + skyreels_dit_folder.change( + fn=update_skyreels_model_dropdown, + inputs=[skyreels_dit_folder], + outputs=[skyreels_model] + ) + + # Add handlers for the refresh button + skyreels_refresh_btn.click( + fn=update_skyreels_lora_dropdowns, + inputs=[skyreels_lora_folder] + skyreels_lora_weights + skyreels_lora_multipliers, + outputs=[drop for _ in range(4) for drop in [skyreels_lora_weights[_], skyreels_lora_multipliers[_]]] + ) + # Skyreels dimension handling + def calculate_skyreels_width(height, original_dims): + if not original_dims: + return gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + aspect_ratio = orig_w / orig_h + new_width = math.floor((height * aspect_ratio) / 16) * 16 + return gr.update(value=new_width) + + def calculate_skyreels_height(width, original_dims): + if not original_dims: + return gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + aspect_ratio = orig_w / orig_h + new_height = math.floor((width / aspect_ratio) / 16) * 16 + return gr.update(value=new_height) + + def update_skyreels_from_scale(scale, original_dims): + if not original_dims: + return gr.update(), gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + new_w = math.floor((orig_w * scale / 100) / 16) * 16 + new_h = math.floor((orig_h * scale / 100) / 16) * 16 + return gr.update(value=new_w), gr.update(value=new_h) + + def update_skyreels_dimensions(image): + if image is None: + return "", gr.update(value=544), gr.update(value=544) + img = Image.open(image) + w, h = img.size + w = (w // 16) * 16 + h = (h // 16) * 16 + return f"{w}x{h}", w, h + + def handle_skyreels_gallery_select(evt: gr.SelectData) -> int: + return evt.index + + def send_skyreels_to_v2v( + gallery: list, + prompt: str, + selected_index: int, + width: int, + height: int, + video_length: int, + fps: int, + infer_steps: int, + seed: int, + flow_shift: float, + cfg_scale: float, + lora1: str, + lora2: str, + lora3: str, + lora4: str, + lora1_multiplier: float, + lora2_multiplier: float, + lora3_multiplier: float, + lora4_multiplier: float, + negative_prompt: str = "" # Add this parameter + ) -> Tuple: + if not gallery or selected_index is None or selected_index >= len(gallery): + return (None, "", width, height, video_length, fps, infer_steps, seed, + flow_shift, cfg_scale, lora1, lora2, lora3, lora4, + lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier, + negative_prompt) # Add negative_prompt to return + + selected_item = gallery[selected_index] + 
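+        # Handle the different gallery item formats (dict, tuple/list, or bare path),
+        # mirroring the other send-to-Video2Video helpers.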
+ if isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + elif isinstance(selected_item, (tuple, list)): + video_path = selected_item[0] + else: + video_path = selected_item + + if isinstance(video_path, tuple): + video_path = video_path[0] + + return (str(video_path), prompt, width, height, video_length, fps, infer_steps, seed, + flow_shift, cfg_scale, lora1, lora2, lora3, lora4, + lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier, + negative_prompt) # Add negative_prompt to return + + # Add event handlers for the SKYREELS tab + skyreels_prompt.change(fn=count_prompt_tokens, inputs=skyreels_prompt, outputs=skyreels_token_counter) + skyreels_stop_btn.click(fn=lambda: stop_event.set(), queue=False) + + # Image input handling + skyreels_input.change( + fn=update_skyreels_dimensions, + inputs=[skyreels_input], + outputs=[skyreels_original_dims, skyreels_width, skyreels_height] + ) + + skyreels_scale_slider.change( + fn=update_skyreels_from_scale, + inputs=[skyreels_scale_slider, skyreels_original_dims], + outputs=[skyreels_width, skyreels_height] + ) + + skyreels_calc_width_btn.click( + fn=calculate_skyreels_width, + inputs=[skyreels_height, skyreels_original_dims], + outputs=[skyreels_width] + ) + + skyreels_calc_height_btn.click( + fn=calculate_skyreels_height, + inputs=[skyreels_width, skyreels_original_dims], + outputs=[skyreels_height] + ) + + # Handle checkbox visibility toggling + skyreels_use_random_folder.change( + fn=lambda x: (gr.update(visible=x), gr.update(visible=x), gr.update(visible=not x)), + inputs=[skyreels_use_random_folder], + outputs=[skyreels_input_folder, skyreels_folder_status, skyreels_input] + ) + + # Validate folder button click handler + skyreels_validate_folder_btn.click( + fn=lambda folder: get_random_image_from_folder(folder)[1], + inputs=[skyreels_input_folder], + outputs=[skyreels_folder_status] + ) + + skyreels_use_random_folder.change( + fn=lambda x: gr.update(visible=x), + inputs=[skyreels_use_random_folder], + outputs=[skyreels_validate_folder_btn] + ) + + # Modify the skyreels_generate_btn.click event handler to use process_random_image_batch when folder mode is on + skyreels_generate_btn.click( + fn=batch_handler, + inputs=[ + skyreels_use_random_folder, + # Rest of the arguments + skyreels_prompt, + skyreels_negative_prompt, + skyreels_width, + skyreels_height, + skyreels_video_length, + skyreels_fps, + skyreels_infer_steps, + skyreels_seed, + skyreels_flow_shift, + skyreels_guidance_scale, + skyreels_embedded_cfg_scale, + skyreels_batch_size, + skyreels_input_folder, + skyreels_dit_folder, + skyreels_model, + skyreels_vae, + skyreels_te1, + skyreels_te2, + skyreels_save_path, + skyreels_output_type, + skyreels_attn_mode, + skyreels_block_swap, + skyreels_exclude_single_blocks, + skyreels_use_split_attn, + skyreels_use_fp8, + skyreels_split_uncond, + skyreels_lora_folder, + *skyreels_lora_weights, + *skyreels_lora_multipliers, + skyreels_input # Add the input image path + ], + outputs=[skyreels_output, skyreels_batch_progress, skyreels_progress_text], + queue=True + ).then( + fn=lambda batch_size: 0 if batch_size == 1 else None, + inputs=[skyreels_batch_size], + outputs=skyreels_selected_index + ) + + # Gallery selection handling + skyreels_output.select( + fn=handle_skyreels_gallery_select, + outputs=skyreels_selected_index + ) + + # Send to Video2Video handler + skyreels_send_to_v2v_btn.click( + fn=send_skyreels_to_v2v, + inputs=[ + skyreels_output, skyreels_prompt, 
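+            # Input order must match send_skyreels_to_v2v's parameter order.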
skyreels_selected_index, + skyreels_width, skyreels_height, skyreels_video_length, + skyreels_fps, skyreels_infer_steps, skyreels_seed, + skyreels_flow_shift, skyreels_guidance_scale + ] + skyreels_lora_weights + skyreels_lora_multipliers + [skyreels_negative_prompt], # This is ok because skyreels_negative_prompt is a Gradio component + outputs=[ + v2v_input, v2v_prompt, v2v_width, v2v_height, + v2v_video_length, v2v_fps, v2v_infer_steps, + v2v_seed, v2v_flow_shift, v2v_cfg_scale + ] + v2v_lora_weights + v2v_lora_multipliers + [v2v_negative_prompt] + ).then( + fn=change_to_tab_two, + inputs=None, + outputs=[tabs] + ) + + # Refresh button handler + skyreels_refresh_outputs = [skyreels_model] + for i in range(4): + skyreels_refresh_outputs.extend([skyreels_lora_weights[i], skyreels_lora_multipliers[i]]) + + skyreels_refresh_btn.click( + fn=update_dit_and_lora_dropdowns, + inputs=[skyreels_dit_folder, skyreels_lora_folder, skyreels_model] + skyreels_lora_weights + skyreels_lora_multipliers, + outputs=skyreels_refresh_outputs + ) + + def calculate_v2v_width(height, original_dims): + if not original_dims: + return gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + aspect_ratio = orig_w / orig_h + new_width = math.floor((height * aspect_ratio) / 16) * 16 # Ensure divisible by 16 + return gr.update(value=new_width) + + def calculate_v2v_height(width, original_dims): + if not original_dims: + return gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + aspect_ratio = orig_w / orig_h + new_height = math.floor((width / aspect_ratio) / 16) * 16 # Ensure divisible by 16 + return gr.update(value=new_height) + + def update_v2v_from_scale(scale, original_dims): + if not original_dims: + return gr.update(), gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + new_w = math.floor((orig_w * scale / 100) / 16) * 16 # Ensure divisible by 16 + new_h = math.floor((orig_h * scale / 100) / 16) * 16 # Ensure divisible by 16 + return gr.update(value=new_w), gr.update(value=new_h) + + def update_v2v_dimensions(video): + if video is None: + return "", gr.update(value=544), gr.update(value=544) + cap = cv2.VideoCapture(video) + w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + cap.release() + # Make dimensions divisible by 16 + w = (w // 16) * 16 + h = (h // 16) * 16 + return f"{w}x{h}", w, h + + # Event Handlers for Video to Video Tab + v2v_input.change( + fn=update_v2v_dimensions, + inputs=[v2v_input], + outputs=[v2v_original_dims, v2v_width, v2v_height] + ) + + v2v_scale_slider.change( + fn=update_v2v_from_scale, + inputs=[v2v_scale_slider, v2v_original_dims], + outputs=[v2v_width, v2v_height] + ) + + v2v_calc_width_btn.click( + fn=calculate_v2v_width, + inputs=[v2v_height, v2v_original_dims], + outputs=[v2v_width] + ) + + v2v_calc_height_btn.click( + fn=calculate_v2v_height, + inputs=[v2v_width, v2v_original_dims], + outputs=[v2v_height] + ) + + ##Image 2 video dimension logic + def calculate_width(height, original_dims): + if not original_dims: + return gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + aspect_ratio = orig_w / orig_h + new_width = math.floor((height * aspect_ratio) / 16) * 16 # Changed from 8 to 16 + return gr.update(value=new_width) + + def calculate_height(width, original_dims): + if not original_dims: + return gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + aspect_ratio = orig_w / orig_h + new_height = math.floor((width / aspect_ratio) / 16) * 16 # Changed from 8 to 16 + 
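+        # Floored to a multiple of 16 to satisfy the divisible-by-16 sizing used throughout.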
return gr.update(value=new_height) + + def update_from_scale(scale, original_dims): + if not original_dims: + return gr.update(), gr.update() + orig_w, orig_h = map(int, original_dims.split('x')) + new_w = math.floor((orig_w * scale / 100) / 16) * 16 # Changed from 8 to 16 + new_h = math.floor((orig_h * scale / 100) / 16) * 16 # Changed from 8 to 16 + return gr.update(value=new_w), gr.update(value=new_h) + + def update_dimensions(image): + if image is None: + return "", gr.update(value=544), gr.update(value=544) + img = Image.open(image) + w, h = img.size + # Make dimensions divisible by 16 + w = (w // 16) * 16 # Changed from 8 to 16 + h = (h // 16) * 16 # Changed from 8 to 16 + return f"{w}x{h}", w, h + i2v_input.change( + fn=update_dimensions, + inputs=[i2v_input], + outputs=[original_dims, width, height] + ) + + scale_slider.change( + fn=update_from_scale, + inputs=[scale_slider, original_dims], + outputs=[width, height] + ) + + calc_width_btn.click( + fn=calculate_width, + inputs=[height, original_dims], + outputs=[width] + ) + + calc_height_btn.click( + fn=calculate_height, + inputs=[width, original_dims], + outputs=[height] + ) + + # Function to get available DiT models + def get_dit_models(dit_folder: str) -> List[str]: + if not os.path.exists(dit_folder): + return ["mp_rank_00_model_states.pt"] + models = [f for f in os.listdir(dit_folder) if f.endswith('.pt') or f.endswith('.safetensors')] + models.sort(key=str.lower) + return models if models else ["mp_rank_00_model_states.pt"] + + # Function to perform model merging + def merge_models( + dit_folder: str, + dit_model: str, + output_model: str, + exclude_single_blocks: bool, + merge_lora_folder: str, + *lora_params # Will contain both weights and multipliers + ) -> str: + try: + # Separate weights and multipliers + num_loras = len(lora_params) // 2 + weights = list(lora_params[:num_loras]) + multipliers = list(lora_params[num_loras:]) + + # Filter out "None" selections + valid_loras = [] + for weight, mult in zip(weights, multipliers): + if weight and weight != "None": + valid_loras.append((os.path.join(merge_lora_folder, weight), mult)) + + if not valid_loras: + return "No LoRA models selected for merging" + + # Create output path in the dit folder + os.makedirs(dit_folder, exist_ok=True) + output_path = os.path.join(dit_folder, output_model) + + # Prepare command + cmd = [ + sys.executable, + "merge_lora.py", + "--dit", os.path.join(dit_folder, dit_model), + "--save_merged_model", output_path + ] + + # Add LoRA weights and multipliers + weights = [weight for weight, _ in valid_loras] + multipliers = [str(mult) for _, mult in valid_loras] + cmd.extend(["--lora_weight"] + weights) + cmd.extend(["--lora_multiplier"] + multipliers) + + if exclude_single_blocks: + cmd.append("--exclude_single_blocks") + + # Execute merge operation + result = subprocess.run( + cmd, + capture_output=True, + text=True, + check=True + ) + + if os.path.exists(output_path): + return f"Successfully merged model and saved to {output_path}" + else: + return "Error: Output file not created" + + except subprocess.CalledProcessError as e: + return f"Error during merging: {e.stderr}" + except Exception as e: + return f"Error: {str(e)}" + + # Update DiT model dropdown + def update_dit_dropdown(dit_folder: str) -> Dict: + models = get_dit_models(dit_folder) + return gr.update(choices=models, value=models[0] if models else None) + + # Connect events + merge_btn.click( + fn=merge_models, + inputs=[ + dit_folder, + dit_model, + output_model, + exclude_single_blocks, + 
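+            # The starred LoRA lists that follow arrive as four weights then four
+            # multipliers; merge_models splits them by halving len(lora_params).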
merge_lora_folder, + *merge_lora_weights, + *merge_lora_multipliers + ], + outputs=merge_status + ) + + # Refresh buttons for both DiT and LoRA dropdowns + merge_refresh_btn.click( + fn=lambda f: update_dit_dropdown(f), + inputs=[dit_folder], + outputs=[dit_model] + ) + + # LoRA refresh handling + merge_refresh_outputs = [] + for i in range(4): + merge_refresh_outputs.extend([merge_lora_weights[i], merge_lora_multipliers[i]]) + + merge_refresh_btn.click( + fn=update_lora_dropdowns, + inputs=[merge_lora_folder] + merge_lora_weights + merge_lora_multipliers, + outputs=merge_refresh_outputs + ) + # Event handlers + prompt.change(fn=count_prompt_tokens, inputs=prompt, outputs=token_counter) + v2v_prompt.change(fn=count_prompt_tokens, inputs=v2v_prompt, outputs=v2v_token_counter) + stop_btn.click(fn=lambda: stop_event.set(), queue=False) + v2v_stop_btn.click(fn=lambda: stop_event.set(), queue=False) + + #Image_to_Video + + def image_to_video(image_path, output_path, width, height, frames=240): # Add width, height parameters + img = Image.open(image_path) + + # Resize to the specified dimensions + img_resized = img.resize((width, height), Image.LANCZOS) + temp_image_path = os.path.join(os.path.dirname(output_path), "temp_resized_image.png") + img_resized.save(temp_image_path) + + # Rest of function remains the same + frame_rate = 24 + duration = frames / frame_rate + command = [ + "ffmpeg", "-loop", "1", "-i", temp_image_path, "-c:v", "libx264", + "-t", str(duration), "-pix_fmt", "yuv420p", + "-vf", f"fps={frame_rate}", output_path + ] + + try: + subprocess.run(command, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + print(f"Video saved to {output_path}") + return True + except subprocess.CalledProcessError as e: + print(f"An error occurred while creating the video: {e}") + return False + finally: + # Clean up the temporary image file + if os.path.exists(temp_image_path): + os.remove(temp_image_path) + img.close() # Make sure to close the image file explicitly + + def generate_from_image( + image_path, + prompt, width, height, video_length, fps, infer_steps, + seed, model, vae, te1, te2, save_path, flow_shift, cfg_scale, + output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, + lora_folder, strength, batch_size, *lora_params + ): + """Generate video from input image with progressive updates""" + global stop_event + stop_event.clear() + + # Create temporary video path + temp_video_path = os.path.join(save_path, f"temp_{os.path.basename(image_path)}.mp4") + + try: + # Convert image to video + if not image_to_video(image_path, temp_video_path, width, height, frames=video_length): + yield [], "Failed to create temporary video", "Error in video creation" + return + + # Ensure video is fully written before proceeding + time.sleep(1) + if not os.path.exists(temp_video_path) or os.path.getsize(temp_video_path) == 0: + yield [], "Failed to create temporary video", "Temporary video file is empty or missing" + return + + # Get video dimensions + try: + probe = ffmpeg.probe(temp_video_path) + video_stream = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None) + if video_stream is None: + raise ValueError("No video stream found") + width = int(video_stream['width']) + height = int(video_stream['height']) + except Exception as e: + yield [], f"Error reading video dimensions: {str(e)}", "Video processing error" + return + + # Generate the video using the temporary file + try: + generator = process_single_video( + prompt, width, height, 
batch_size, video_length, fps, infer_steps, + seed, model, vae, te1, te2, save_path, flow_shift, cfg_scale, + output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, + lora_folder, *lora_params, video_path=temp_video_path, strength=strength + ) + + # Forward all generator updates + for videos, batch_text, progress_text in generator: + yield videos, batch_text, progress_text + + except Exception as e: + yield [], f"Error in video generation: {str(e)}", "Generation error" + return + + except Exception as e: + yield [], f"Unexpected error: {str(e)}", "Error occurred" + return + + finally: + # Clean up temporary file + try: + if os.path.exists(temp_video_path): + os.remove(temp_video_path) + except Exception: + pass # Ignore cleanup errors + + + # Add event handlers + i2v_prompt.change(fn=count_prompt_tokens, inputs=i2v_prompt, outputs=i2v_token_counter) + i2v_stop_btn.click(fn=lambda: stop_event.set(), queue=False) + + def handle_i2v_gallery_select(evt: gr.SelectData) -> int: + """Track selected index when I2V gallery item is clicked""" + return evt.index + + def send_i2v_to_v2v( + gallery: list, + prompt: str, + selected_index: int, + width: int, + height: int, + video_length: int, + fps: int, + infer_steps: int, + seed: int, + flow_shift: float, + cfg_scale: float, + lora1: str, + lora2: str, + lora3: str, + lora4: str, + lora1_multiplier: float, + lora2_multiplier: float, + lora3_multiplier: float, + lora4_multiplier: float + ) -> Tuple[Optional[str], str, int, int, int, int, int, int, float, float, str, str, str, str, float, float, float, float]: + """Send the selected video and parameters from Image2Video tab to Video2Video tab""" + if not gallery or selected_index is None or selected_index >= len(gallery): + return None, "", width, height, video_length, fps, infer_steps, seed, flow_shift, cfg_scale, \ + lora1, lora2, lora3, lora4, lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier + + selected_item = gallery[selected_index] + + # Handle different gallery item formats + if isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + elif isinstance(selected_item, (tuple, list)): + video_path = selected_item[0] + else: + video_path = selected_item + + # Final cleanup for Gradio Video component + if isinstance(video_path, tuple): + video_path = video_path[0] + + # Use the original width and height without doubling + return (str(video_path), prompt, width, height, video_length, fps, infer_steps, seed, + flow_shift, cfg_scale, lora1, lora2, lora3, lora4, + lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier) + + # Generate button handler + i2v_generate_btn.click( + fn=process_hunyuani2v_batch, + inputs=[ + i2v_prompt, width, height, + i2v_batch_size, i2v_video_length, + i2v_fps, i2v_infer_steps, i2v_seed, i2v_dit_folder, i2v_model, i2v_vae, i2v_te1, i2v_te2, + i2v_save_path, i2v_flow_shift, i2v_cfg_scale, i2v_output_type, i2v_attn_mode, + i2v_block_swap, i2v_exclude_single_blocks, i2v_use_split_attn, i2v_lora_folder, + *i2v_lora_weights, *i2v_lora_multipliers, i2v_input, i2v_strength, i2v_negative_prompt, + i2v_cfg_scale, i2v_split_uncond, i2v_use_fp8, i2v_clip_vision_path, i2v_stability, + i2v_fp8_fast, i2v_compile, i2v_compile_backend, i2v_compile_mode, + i2v_compile_dynamic, i2v_compile_fullgraph + ], + outputs=[i2v_output, i2v_batch_progress, i2v_progress_text], + queue=True + ).then( + fn=lambda batch_size: 0 if batch_size == 1 else None, + inputs=[i2v_batch_size], + 
outputs=i2v_selected_index + ) + # Send to Video2Video + i2v_output.select( + fn=handle_i2v_gallery_select, + outputs=i2v_selected_index + ) + + i2v_send_to_v2v_btn.click( + fn=send_i2v_to_v2v, + inputs=[ + i2v_output, i2v_prompt, i2v_selected_index, + width, height, + i2v_video_length, i2v_fps, i2v_infer_steps, + i2v_seed, i2v_flow_shift, i2v_cfg_scale + ] + i2v_lora_weights + i2v_lora_multipliers, + outputs=[ + v2v_input, v2v_prompt, + v2v_width, v2v_height, + v2v_video_length, v2v_fps, v2v_infer_steps, + v2v_seed, v2v_flow_shift, v2v_cfg_scale + ] + v2v_lora_weights + v2v_lora_multipliers + ).then( + fn=change_to_tab_two, inputs=None, outputs=[tabs] + ) + #Video Info + def clean_video_path(video_path) -> str: + """Extract clean video path from Gradio's various return formats""" + print(f"Input video_path: {video_path}, type: {type(video_path)}") + if isinstance(video_path, dict): + path = video_path.get("name", "") + elif isinstance(video_path, (tuple, list)): + path = video_path[0] + elif isinstance(video_path, str): + path = video_path + else: + path = "" + print(f"Cleaned path: {path}") + return path + def handle_video_upload(video_path: str) -> Dict: + """Handle video upload and metadata extraction""" + if not video_path: + return {}, "No video uploaded" + + metadata = extract_video_metadata(video_path) + if not metadata: + return {}, "No metadata found in video" + + return metadata, "Metadata extracted successfully" + + def get_video_info(video_path: str) -> dict: + try: + probe = ffmpeg.probe(video_path) + video_info = next(stream for stream in probe['streams'] if stream['codec_type'] == 'video') + + width = int(video_info['width']) + height = int(video_info['height']) + fps = eval(video_info['r_frame_rate']) # This converts '30/1' to 30.0 + + # Calculate total frames + duration = float(probe['format']['duration']) + total_frames = int(duration * fps) + + # Ensure video length does not exceed 201 frames + if total_frames > 201: + total_frames = 201 + duration = total_frames / fps # Adjust duration accordingly + + return { + 'width': width, + 'height': height, + 'fps': fps, + 'total_frames': total_frames, + 'duration': duration # Might be useful in some contexts + } + except Exception as e: + print(f"Error extracting video info: {e}") + return {} + + def extract_video_details(video_path: str) -> Tuple[dict, str]: + metadata = extract_video_metadata(video_path) + video_details = get_video_info(video_path) + + # Combine metadata with video details + for key, value in video_details.items(): + if key not in metadata: + metadata[key] = value + + # Ensure video length does not exceed 201 frames + if 'video_length' in metadata: + metadata['video_length'] = min(metadata['video_length'], 201) + else: + metadata['video_length'] = min(video_details.get('total_frames', 0), 201) + + # Return both the updated metadata and a status message + return metadata, "Video details extracted successfully" + + def send_parameters_to_tab(metadata: Dict, target_tab: str) -> Tuple[str, Dict]: + """Create parameter mapping for target tab""" + if not metadata: + return "No parameters to send", {} + + tab_name = "Text2Video" if target_tab == "t2v" else "Video2Video" + try: + mapping = create_parameter_transfer_map(metadata, target_tab) + return f"Parameters ready for {tab_name}", mapping + except Exception as e: + return f"Error: {str(e)}", {} + + video_input.upload( + fn=extract_video_details, + inputs=video_input, + outputs=[metadata_output, status] + ) + + send_to_t2v_btn.click( + fn=lambda m: 
send_parameters_to_tab(m, "t2v"), + inputs=metadata_output, + outputs=[status, params_state] + ).then( + fn=change_to_tab_one, inputs=None, outputs=[tabs] + ).then( + lambda params: [ + params.get("prompt", ""), + params.get("width", 544), + params.get("height", 544), + params.get("batch_size", 1), + params.get("video_length", 25), + params.get("fps", 24), + params.get("infer_steps", 30), + params.get("seed", -1), + params.get("model", "hunyuan/mp_rank_00_model_states.pt"), + params.get("vae", "hunyuan/pytorch_model.pt"), + params.get("te1", "hunyuan/llava_llama3_fp16.safetensors"), + params.get("te2", "hunyuan/clip_l.safetensors"), + params.get("save_path", "outputs"), + params.get("flow_shift", 11.0), + params.get("cfg_scale", 7.0), + params.get("output_type", "video"), + params.get("attn_mode", "sdpa"), + params.get("block_swap", "0"), + *[params.get(f"lora{i+1}", "") for i in range(4)], + *[params.get(f"lora{i+1}_multiplier", 1.0) for i in range(4)] + ] if params else [gr.update()]*26, + inputs=params_state, + outputs=[prompt, width, height, batch_size, video_length, fps, infer_steps, seed, + model, vae, te1, te2, save_path, flow_shift, cfg_scale, + output_type, attn_mode, block_swap] + lora_weights + lora_multipliers + ) + # Text to Video generation + generate_btn.click( + fn=process_batch, + inputs=[ + prompt, t2v_width, t2v_height, batch_size, video_length, fps, infer_steps, + seed, dit_folder, model, vae, te1, te2, save_path, flow_shift, cfg_scale, + output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, + lora_folder, *lora_weights, *lora_multipliers, gr.Textbox(visible=False), gr.Number(visible=False), use_fp8 + ], + outputs=[video_output, batch_progress, progress_text], + queue=True + ).then( + fn=lambda batch_size: 0 if batch_size == 1 else None, + inputs=[batch_size], + outputs=selected_index + ) + + # Update gallery selection handling + def handle_gallery_select(evt: gr.SelectData) -> int: + return evt.index + + # Track selected index when gallery item is clicked + video_output.select( + fn=handle_gallery_select, + outputs=selected_index + ) + + # Track selected index when Video2Video gallery item is clicked + def handle_v2v_gallery_select(evt: gr.SelectData) -> int: + """Handle gallery selection without automatically updating the input""" + return evt.index + + # Update the gallery selection event + v2v_output.select( + fn=handle_v2v_gallery_select, + outputs=v2v_selected_index + ) + + # Send button handler with gallery selection + def handle_send_button( + gallery: list, + prompt: str, + idx: int, + width: int, + height: int, + batch_size: int, + video_length: int, + fps: int, + infer_steps: int, + seed: int, + flow_shift: float, + cfg_scale: float, + lora1: str, + lora2: str, + lora3: str, + lora4: str, + lora1_multiplier: float, + lora2_multiplier: float, + lora3_multiplier: float, + lora4_multiplier: float + ) -> tuple: + if not gallery or idx is None or idx >= len(gallery): + return (None, "", width, height, batch_size, video_length, fps, infer_steps, + seed, flow_shift, cfg_scale, + lora1, lora2, lora3, lora4, + lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier, + "") # Add empty string for negative_prompt in the return values + + # Auto-select first item if only one exists and no selection made + if idx is None and len(gallery) == 1: + idx = 0 + + selected_item = gallery[idx] + + # Handle different gallery item formats + if isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", 
None)) + elif isinstance(selected_item, (tuple, list)): + video_path = selected_item[0] + else: + video_path = selected_item + + # Final cleanup for Gradio Video component + if isinstance(video_path, tuple): + video_path = video_path[0] + + return ( + str(video_path), + prompt, + width, + height, + batch_size, + video_length, + fps, + infer_steps, + seed, + flow_shift, + cfg_scale, + lora1, + lora2, + lora3, + lora4, + lora1_multiplier, + lora2_multiplier, + lora3_multiplier, + lora4_multiplier, + "" # Add empty string for negative_prompt + ) + + send_t2v_to_v2v_btn.click( + fn=handle_send_button, + inputs=[ + video_output, prompt, selected_index, + t2v_width, t2v_height, batch_size, video_length, + fps, infer_steps, seed, flow_shift, cfg_scale + ] + lora_weights + lora_multipliers, # Remove the string here + outputs=[ + v2v_input, + v2v_prompt, + v2v_width, + v2v_height, + v2v_batch_size, + v2v_video_length, + v2v_fps, + v2v_infer_steps, + v2v_seed, + v2v_flow_shift, + v2v_cfg_scale + ] + v2v_lora_weights + v2v_lora_multipliers + [v2v_negative_prompt] + ).then( + fn=change_to_tab_two, inputs=None, outputs=[tabs] + ) + + def handle_send_to_v2v(metadata: dict, video_path: str) -> Tuple[str, dict, str]: + """Handle both parameters and video transfer""" + status_msg, params = send_parameters_to_tab(metadata, "v2v") + return status_msg, params, video_path + + def handle_info_to_v2v(metadata: dict, video_path: str) -> Tuple[str, Dict, str]: + """Handle both parameters and video transfer from Video Info to V2V tab""" + if not video_path: + return "No video selected", {}, None + + status_msg, params = send_parameters_to_tab(metadata, "v2v") + # Just return the path directly + return status_msg, params, video_path + + # Send button click handler + send_to_v2v_btn.click( + fn=handle_info_to_v2v, + inputs=[metadata_output, video_input], + outputs=[status, params_state, v2v_input] + ).then( + lambda params: [ + params.get("v2v_prompt", ""), + params.get("v2v_width", 544), + params.get("v2v_height", 544), + params.get("v2v_batch_size", 1), + params.get("v2v_video_length", 25), + params.get("v2v_fps", 24), + params.get("v2v_infer_steps", 30), + params.get("v2v_seed", -1), + params.get("v2v_model", "hunyuan/mp_rank_00_model_states.pt"), + params.get("v2v_vae", "hunyuan/pytorch_model.pt"), + params.get("v2v_te1", "hunyuan/llava_llama3_fp16.safetensors"), + params.get("v2v_te2", "hunyuan/clip_l.safetensors"), + params.get("v2v_save_path", "outputs"), + params.get("v2v_flow_shift", 11.0), + params.get("v2v_cfg_scale", 7.0), + params.get("v2v_output_type", "video"), + params.get("v2v_attn_mode", "sdpa"), + params.get("v2v_block_swap", "0"), + *[params.get(f"v2v_lora_weights[{i}]", "") for i in range(4)], + *[params.get(f"v2v_lora_multipliers[{i}]", 1.0) for i in range(4)] + ] if params else [gr.update()] * 26, + inputs=params_state, + outputs=[ + v2v_prompt, v2v_width, v2v_height, v2v_batch_size, v2v_video_length, + v2v_fps, v2v_infer_steps, v2v_seed, v2v_model, v2v_vae, v2v_te1, + v2v_te2, v2v_save_path, v2v_flow_shift, v2v_cfg_scale, v2v_output_type, + v2v_attn_mode, v2v_block_swap + ] + v2v_lora_weights + v2v_lora_multipliers + ).then( + lambda: print(f"Tabs object: {tabs}"), # Debug print + outputs=None + ).then( + fn=change_to_tab_two, inputs=None, outputs=[tabs] + ) + + # Handler for sending selected video from Video2Video gallery to input + def handle_v2v_send_button(gallery: list, prompt: str, idx: int) -> Tuple[Optional[str], str]: + """Send the currently selected video in V2V gallery to V2V 
input""" + if not gallery or idx is None or idx >= len(gallery): + return None, "" + + selected_item = gallery[idx] + video_path = None + + # Handle different gallery item formats + if isinstance(selected_item, tuple): + video_path = selected_item[0] # Gallery returns (path, caption) + elif isinstance(selected_item, dict): + video_path = selected_item.get("name", selected_item.get("data", None)) + elif isinstance(selected_item, str): + video_path = selected_item + + if not video_path: + return None, "" + + # Check if the file exists and is accessible + if not os.path.exists(video_path): + print(f"Warning: Video file not found at {video_path}") + return None, "" + + return video_path, prompt + + v2v_send_to_input_btn.click( + fn=handle_v2v_send_button, + inputs=[v2v_output, v2v_prompt, v2v_selected_index], + outputs=[v2v_input, v2v_prompt] + ).then( + lambda: gr.update(visible=True), # Ensure the video input is visible + outputs=v2v_input + ) + + # Video to Video generation + v2v_generate_btn.click( + fn=process_batch, + inputs=[ + v2v_prompt, v2v_width, v2v_height, v2v_batch_size, v2v_video_length, + v2v_fps, v2v_infer_steps, v2v_seed, v2v_dit_folder, v2v_model, v2v_vae, v2v_te1, v2v_te2, + v2v_save_path, v2v_flow_shift, v2v_cfg_scale, v2v_output_type, v2v_attn_mode, + v2v_block_swap, v2v_exclude_single_blocks, v2v_use_split_attn, v2v_lora_folder, + *v2v_lora_weights, *v2v_lora_multipliers, v2v_input, v2v_strength, + v2v_negative_prompt, v2v_cfg_scale, v2v_split_uncond, v2v_use_fp8 + ], + outputs=[v2v_output, v2v_batch_progress, v2v_progress_text], + queue=True + ).then( + fn=lambda batch_size: 0 if batch_size == 1 else None, + inputs=[v2v_batch_size], + outputs=v2v_selected_index + ) + refresh_outputs = [model] # Add model dropdown to outputs + for i in range(4): + refresh_outputs.extend([lora_weights[i], lora_multipliers[i]]) + + refresh_btn.click( + fn=update_dit_and_lora_dropdowns, + inputs=[dit_folder, lora_folder, model] + lora_weights + lora_multipliers, + outputs=refresh_outputs + ) + # Image2Video refresh + i2v_refresh_outputs = [i2v_model] # Add model dropdown to outputs + for i in range(4): + i2v_refresh_outputs.extend([i2v_lora_weights[i], i2v_lora_multipliers[i]]) + + i2v_refresh_btn.click( + fn=update_dit_and_lora_dropdowns, + inputs=[i2v_dit_folder, i2v_lora_folder, i2v_model] + i2v_lora_weights + i2v_lora_multipliers, + outputs=i2v_refresh_outputs + ) + + # Video2Video refresh + v2v_refresh_outputs = [v2v_model] # Add model dropdown to outputs + for i in range(4): + v2v_refresh_outputs.extend([v2v_lora_weights[i], v2v_lora_multipliers[i]]) + + v2v_refresh_btn.click( + fn=update_dit_and_lora_dropdowns, + inputs=[v2v_dit_folder, v2v_lora_folder, v2v_model] + v2v_lora_weights + v2v_lora_multipliers, + outputs=v2v_refresh_outputs + ) + + # WanX-i2v tab connections + wanx_prompt.change(fn=count_prompt_tokens, inputs=wanx_prompt, outputs=wanx_token_counter) + wanx_stop_btn.click(fn=lambda: stop_event.set(), queue=False) + + # Image input handling for WanX-i2v + wanx_input.change( + fn=update_wanx_image_dimensions, + inputs=[wanx_input], + outputs=[wanx_original_dims, wanx_width, wanx_height] + ) + + # Scale slider handling for WanX-i2v + wanx_scale_slider.change( + fn=update_wanx_from_scale, + inputs=[wanx_scale_slider, wanx_original_dims], + outputs=[wanx_width, wanx_height] + ) + + # Width/height calculation buttons for WanX-i2v + wanx_calc_width_btn.click( + fn=calculate_wanx_width, + inputs=[wanx_height, wanx_original_dims], + outputs=[wanx_width] + ) + + 
wanx_calc_height_btn.click( + fn=calculate_wanx_height, + inputs=[wanx_width, wanx_original_dims], + outputs=[wanx_height] + ) + # Add visibility toggle for the folder input components + wanx_use_random_folder.change( + fn=lambda x: (gr.update(visible=x), gr.update(visible=x), gr.update(visible=x), gr.update(visible=not x)), + inputs=[wanx_use_random_folder], + outputs=[wanx_input_folder, wanx_folder_status, wanx_validate_folder_btn, wanx_input] + ) + + # Validate folder button handler + wanx_validate_folder_btn.click( + fn=lambda folder: get_random_image_from_folder(folder)[1], + inputs=[wanx_input_folder], + outputs=[wanx_folder_status] + ) + + # Flow shift recommendation buttons + wanx_recommend_flow_btn.click( + fn=recommend_wanx_flow_shift, + inputs=[wanx_width, wanx_height], + outputs=[wanx_flow_shift] + ) + + wanx_t2v_recommend_flow_btn.click( + fn=recommend_wanx_flow_shift, + inputs=[wanx_t2v_width, wanx_t2v_height], + outputs=[wanx_t2v_flow_shift] + ) + + # Generate button handler + wanx_generate_btn.click( + fn=wanx_batch_handler, + inputs=[ + wanx_use_random_folder, + wanx_prompt, + wanx_negative_prompt, + wanx_width, + wanx_height, + wanx_video_length, + wanx_fps, + wanx_infer_steps, + wanx_flow_shift, + wanx_guidance_scale, + wanx_seed, + wanx_batch_size, + wanx_input_folder, + wanx_task, + wanx_dit_path, + wanx_vae_path, + wanx_t5_path, + wanx_clip_path, + wanx_save_path, + wanx_output_type, + wanx_sample_solver, + wanx_exclude_single_blocks, + wanx_attn_mode, + wanx_block_swap, + wanx_fp8, + wanx_fp8_t5, + wanx_lora_folder, + *wanx_lora_weights, + *wanx_lora_multipliers, + wanx_input # Include input image path for non-batch mode + ], + outputs=[wanx_output, wanx_batch_progress, wanx_progress_text], + queue=True + ).then( + fn=lambda batch_size: 0 if batch_size == 1 else None, + inputs=[wanx_batch_size], + outputs=wanx_i2v_selected_index # Update to use correct state + ) + + # Add refresh button handler for WanX-i2v tab + wanx_refresh_outputs = [] + for i in range(4): + wanx_refresh_outputs.extend([wanx_lora_weights[i], wanx_lora_multipliers[i]]) + + wanx_refresh_btn.click( + fn=update_lora_dropdowns, + inputs=[wanx_lora_folder] + wanx_lora_weights + wanx_lora_multipliers, + outputs=wanx_refresh_outputs + ) + + # Gallery selection handling + wanx_output.select( + fn=handle_wanx_gallery_select, + inputs=[wanx_output], + outputs=[wanx_i2v_selected_index, wanx_base_video] + ) + + # Send to Video2Video handler + wanx_send_to_v2v_btn.click( + fn=send_wanx_to_v2v, + inputs=[ + wanx_output, # Gallery with videos + wanx_prompt, # Prompt text + wanx_i2v_selected_index, # Use the correct selected index state + wanx_width, + wanx_height, + wanx_video_length, + wanx_fps, + wanx_infer_steps, + wanx_seed, + wanx_flow_shift, + wanx_guidance_scale, + wanx_negative_prompt + ], + outputs=[ + v2v_input, # Video input in V2V tab + v2v_prompt, # Prompt in V2V tab + v2v_width, + v2v_height, + v2v_video_length, + v2v_fps, + v2v_infer_steps, + v2v_seed, + v2v_flow_shift, + v2v_cfg_scale, + v2v_negative_prompt + ] + ).then( + fn=change_to_tab_two, # Function to switch to Video2Video tab + inputs=None, + outputs=[tabs] + ) + + # Add state for T2V tab selected index + wanx_t2v_selected_index = gr.State(value=None) + + # Connect prompt token counter + wanx_t2v_prompt.change(fn=count_prompt_tokens, inputs=wanx_t2v_prompt, outputs=wanx_t2v_token_counter) + + # Stop button handler + wanx_t2v_stop_btn.click(fn=lambda: stop_event.set(), queue=False) + + # Flow shift recommendation button + 
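+    # recommend_wanx_flow_shift suggests a discrete flow-shift value from the target
+    # resolution (Wan2.1 generally favours a lower shift for 480p-class outputs and a
+    # higher one for 720p-class outputs); the button below writes the suggestion into
+    # the T2V flow-shift field.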
wanx_t2v_recommend_flow_btn.click( + fn=recommend_wanx_flow_shift, + inputs=[wanx_t2v_width, wanx_t2v_height], + outputs=[wanx_t2v_flow_shift] + ) + + # Task change handler to update CLIP visibility and path + def update_clip_visibility(task): + is_i2v = "i2v" in task + return gr.update(visible=is_i2v) + + wanx_t2v_task.change( + fn=update_clip_visibility, + inputs=[wanx_t2v_task], + outputs=[wanx_t2v_clip_path] + ) + + # Generate button handler for T2V + wanx_t2v_generate_btn.click( + fn=wanx_generate_video_batch, + inputs=[ + wanx_t2v_prompt, + wanx_t2v_negative_prompt, + wanx_t2v_width, + wanx_t2v_height, + wanx_t2v_video_length, + wanx_t2v_fps, + wanx_t2v_infer_steps, + wanx_t2v_flow_shift, + wanx_t2v_guidance_scale, + wanx_t2v_seed, + wanx_t2v_task, + wanx_t2v_dit_path, + wanx_t2v_vae_path, + wanx_t2v_t5_path, + wanx_t2v_clip_path, + wanx_t2v_save_path, + wanx_t2v_output_type, + wanx_t2v_sample_solver, + wanx_t2v_exclude_single_blocks, + wanx_t2v_attn_mode, + wanx_t2v_block_swap, + wanx_t2v_fp8, + wanx_t2v_fp8_t5, + wanx_t2v_lora_folder, + *wanx_t2v_lora_weights, + *wanx_t2v_lora_multipliers, + wanx_t2v_batch_size, + # input_image is now optional and not included here + ], + outputs=[wanx_t2v_output, wanx_t2v_batch_progress, wanx_t2v_progress_text], + queue=True + ).then( + fn=lambda batch_size: 0 if batch_size == 1 else None, + inputs=[wanx_t2v_batch_size], + outputs=wanx_t2v_selected_index + ) + + # Add refresh button handler for WanX-t2v tab + wanx_t2v_refresh_outputs = [] + for i in range(4): + wanx_t2v_refresh_outputs.extend([wanx_t2v_lora_weights[i], wanx_t2v_lora_multipliers[i]]) + + wanx_t2v_refresh_btn.click( + fn=update_lora_dropdowns, + inputs=[wanx_t2v_lora_folder] + wanx_t2v_lora_weights + wanx_t2v_lora_multipliers, + outputs=wanx_t2v_refresh_outputs + ) + + # Gallery selection handling + wanx_t2v_output.select( + fn=handle_wanx_t2v_gallery_select, + outputs=wanx_t2v_selected_index + ) + + # Send to Video2Video handler + wanx_t2v_send_to_v2v_btn.click( + fn=send_wanx_t2v_to_v2v, + inputs=[ + wanx_t2v_output, + wanx_t2v_prompt, + wanx_t2v_selected_index, + wanx_t2v_width, + wanx_t2v_height, + wanx_t2v_video_length, + wanx_t2v_fps, + wanx_t2v_infer_steps, + wanx_t2v_seed, + wanx_t2v_flow_shift, + wanx_t2v_guidance_scale, + wanx_t2v_negative_prompt + ], + outputs=[ + v2v_input, + v2v_prompt, + v2v_width, + v2v_height, + v2v_video_length, + v2v_fps, + v2v_infer_steps, + v2v_seed, + v2v_flow_shift, + v2v_cfg_scale, + v2v_negative_prompt + ] + ).then( + fn=change_to_tab_two, + inputs=None, + outputs=[tabs] + ) + +demo.queue().launch(server_name="0.0.0.0", share=False) \ No newline at end of file diff --git a/images/screenshot.png b/images/screenshot.png new file mode 100644 index 0000000000000000000000000000000000000000..ed7676acd1a950ffdea755bcf2b9e2ba26c10ea9 --- /dev/null +++ b/images/screenshot.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:125efddda63b53bc9ab48c63e6c4a2ca381bcecbf505fe325e51d71ca0630e9f +size 9408018 diff --git a/merge_diffusers2I2Vsky.py b/merge_diffusers2I2Vsky.py new file mode 100644 index 0000000000000000000000000000000000000000..e4bc9bf1e96c6bdb7dfdb1994d5b811c715ab4ae --- /dev/null +++ b/merge_diffusers2I2Vsky.py @@ -0,0 +1,237 @@ +import os +import json +import torch +from safetensors.torch import load_file, save_file +import logging +import shutil +from typing import Dict, Any, Set +import re + +logger = logging.getLogger("PeftMerger") +logger.setLevel(logging.INFO) +console_handler = logging.StreamHandler() 
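+# Core merge rule used by merge_lora_weights() below: for every LoRA pair the
+# weight delta is
+#     delta = alpha * torch.matmul(lora_B, lora_A)
+# and is added onto the matching base tensor; fused qkv projections are first
+# split with torch.chunk(delta, 3, dim=0) so q, k and v each receive their slice.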
+console_handler.setLevel(logging.INFO) +formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") +console_handler.setFormatter(formatter) +logger.addHandler(console_handler) + +def normalize_key(key: str) -> str: + """Normalize key format to match base model""" + key = key.replace("transformer.double_blocks", "transformer_blocks") + key = key.replace("transformer.single_blocks", "single_transformer_blocks") + key = re.sub(r'\.+', '.', key) # Remove double dots + if key.endswith('.'): + key = key[:-1] + return key + +def merge_lora_weights(base_weights: Dict[str, torch.Tensor], + lora_weights: Dict[str, torch.Tensor], + alpha: float = 1.0) -> Dict[str, torch.Tensor]: + """Merge LoRA weights into base model weights""" + merged = base_weights.copy() + + # Print first few keys for debugging + logger.info(f"Base model keys (first 5): {list(base_weights.keys())[:5]}") + logger.info(f"LoRA keys (first 5): {list(lora_weights.keys())[:5]}") + + # Process LoRA keys + for key in lora_weights.keys(): + if '.lora_A.weight' not in key: + continue + + logger.info(f"Processing LoRA key: {key}") + base_key = key.replace('.lora_A.weight', '') + lora_a = lora_weights[key] + lora_b = lora_weights[base_key + '.lora_B.weight'] + + # Normalize after getting both A and B weights + normalized_key = normalize_key(base_key) + logger.info(f"Normalized key: {normalized_key}") + + # Map double blocks + if 'img_attn_qkv' in base_key: + weights = torch.matmul(lora_b, lora_a) + q, k, v = torch.chunk(weights, 3, dim=0) + block_match = re.search(r'transformer_blocks\.(\d+)', normalized_key) + block_num = block_match.group(1) + + q_key = f'transformer_blocks.{block_num}.attn.to_q.weight' + k_key = f'transformer_blocks.{block_num}.attn.to_k.weight' + v_key = f'transformer_blocks.{block_num}.attn.to_v.weight' + + if all(k in merged for k in [q_key, k_key, v_key]): + merged[q_key] = merged[q_key] + alpha * q + merged[k_key] = merged[k_key] + alpha * k + merged[v_key] = merged[v_key] + alpha * v + logger.info(f"Updated keys: {q_key}, {k_key}, {v_key}") + else: + logger.warning(f"Missing some keys: {[k for k in [q_key, k_key, v_key] if k not in merged]}") + + elif 'txt_attn_qkv' in base_key: + weights = torch.matmul(lora_b, lora_a) + q, k, v = torch.chunk(weights, 3, dim=0) + block_match = re.search(r'transformer_blocks\.(\d+)', normalized_key) + block_num = block_match.group(1) + + q_key = f'transformer_blocks.{block_num}.attn.add_q_proj.weight' + k_key = f'transformer_blocks.{block_num}.attn.add_k_proj.weight' + v_key = f'transformer_blocks.{block_num}.attn.add_v_proj.weight' + + if all(k in merged for k in [q_key, k_key, v_key]): + merged[q_key] = merged[q_key] + alpha * q + merged[k_key] = merged[k_key] + alpha * k + merged[v_key] = merged[v_key] + alpha * v + logger.info(f"Updated keys: {q_key}, {k_key}, {v_key}") + else: + logger.warning(f"Missing some keys: {[k for k in [q_key, k_key, v_key] if k not in merged]}") + + elif 'img_attn_proj' in base_key: + block_match = re.search(r'transformer_blocks\.(\d+)', normalized_key) + block_num = block_match.group(1) + model_key = f'transformer_blocks.{block_num}.attn.to_out.0.weight' + + if model_key in merged: + weights = torch.matmul(lora_b, lora_a) + merged[model_key] = merged[model_key] + alpha * weights + logger.info(f"Updated key: {model_key}") + else: + logger.warning(f"Missing key: {model_key}") + + elif 'txt_attn_proj' in base_key: + block_match = re.search(r'transformer_blocks\.(\d+)', normalized_key) + block_num = block_match.group(1) + 
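+            # "txt_attn_proj" in the source naming corresponds to the diffusers
+            # text-branch output projection ("attn.to_add_out"):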
model_key = f'transformer_blocks.{block_num}.attn.to_add_out.weight' + + if model_key in merged: + weights = torch.matmul(lora_b, lora_a) + merged[model_key] = merged[model_key] + alpha * weights + logger.info(f"Updated key: {model_key}") + else: + logger.warning(f"Missing key: {model_key}") + + elif 'img_mlp.fc1' in base_key: + block_match = re.search(r'transformer_blocks\.(\d+)', normalized_key) + block_num = block_match.group(1) + model_key = f'transformer_blocks.{block_num}.ff.net.0.proj.weight' + + if model_key in merged: + weights = torch.matmul(lora_b, lora_a) + merged[model_key] = merged[model_key] + alpha * weights + logger.info(f"Updated key: {model_key}") + else: + logger.warning(f"Missing key: {model_key}") + + elif 'img_mlp.fc2' in base_key: + block_match = re.search(r'transformer_blocks\.(\d+)', normalized_key) + block_num = block_match.group(1) + model_key = f'transformer_blocks.{block_num}.ff.net.2.weight' + + if model_key in merged: + weights = torch.matmul(lora_b, lora_a) + merged[model_key] = merged[model_key] + alpha * weights + logger.info(f"Updated key: {model_key}") + else: + logger.warning(f"Missing key: {model_key}") + + elif 'txt_mlp.fc1' in base_key: + block_match = re.search(r'transformer_blocks\.(\d+)', normalized_key) + block_num = block_match.group(1) + model_key = f'transformer_blocks.{block_num}.ff_context.net.0.proj.weight' + + if model_key in merged: + weights = torch.matmul(lora_b, lora_a) + merged[model_key] = merged[model_key] + alpha * weights + logger.info(f"Updated key: {model_key}") + else: + logger.warning(f"Missing key: {model_key}") + + elif 'txt_mlp.fc2' in base_key: + block_match = re.search(r'transformer_blocks\.(\d+)', normalized_key) + block_num = block_match.group(1) + model_key = f'transformer_blocks.{block_num}.ff_context.net.2.weight' + + if model_key in merged: + weights = torch.matmul(lora_b, lora_a) + merged[model_key] = merged[model_key] + alpha * weights + logger.info(f"Updated key: {model_key}") + else: + logger.warning(f"Missing key: {model_key}") + + return merged + +def save_sharded_model(weights: Dict[str, torch.Tensor], + index_data: dict, + output_dir: str, + base_model_path: str): + """Save merged weights in same sharded format as original""" + os.makedirs(output_dir, exist_ok=True) + + # Copy all non-safetensor files from original directory + index_dir = os.path.dirname(os.path.abspath(base_model_path)) + for file in os.listdir(index_dir): + if not file.endswith('.safetensors'): + src = os.path.join(index_dir, file) + dst = os.path.join(output_dir, file) + if os.path.isfile(src): + shutil.copy2(src, dst) + elif os.path.isdir(src): + shutil.copytree(src, dst) + + # Group weights by shard + weight_map = index_data['weight_map'] + shard_weights = {} + + for key, shard in weight_map.items(): + if shard not in shard_weights: + shard_weights[shard] = {} + if key in weights: + shard_weights[shard][key] = weights[key] + + # Save each shard + for shard, shard_dict in shard_weights.items(): + if not shard_dict: # Skip empty shards + continue + shard_path = os.path.join(output_dir, shard) + logger.info(f"Saving shard {shard} with {len(shard_dict)} tensors") + save_file(shard_dict, shard_path) + +def main(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--base_model", type=str, required=True) + parser.add_argument("--adapter", type=str, required=True) + parser.add_argument("--output", type=str, required=True) + parser.add_argument("--alpha", type=float, default=1.0) + args = parser.parse_args() + + # 
Load base model index + logger.info("Loading base model index...") + with open(args.base_model, 'r') as f: + index_data = json.load(f) + weight_map = index_data['weight_map'] + + # Load base weights + logger.info("Loading base model weights...") + base_dir = os.path.dirname(args.base_model) + base_weights = {} + for part_file in set(weight_map.values()): + part_path = os.path.join(base_dir, part_file) + logger.info(f"Loading from {part_path}") + weights = load_file(part_path) + base_weights.update(weights) + + # Load LoRA + logger.info("Loading LoRA weights...") + lora_weights = load_file(args.adapter) + + # Merge + logger.info(f"Merging with alpha={args.alpha}") + merged_weights = merge_lora_weights(base_weights, lora_weights, args.alpha) + + # Save in sharded format + logger.info(f"Saving merged model to {args.output}") + save_sharded_model(merged_weights, index_data, args.output, args.base_model) + logger.info("Done!") + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/merge_lora.py b/merge_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..c73cadc731ec471637c9006f53829aeeb197c2cc --- /dev/null +++ b/merge_lora.py @@ -0,0 +1,69 @@ +import argparse +import logging +import torch +from safetensors.torch import load_file +from networks import lora +from utils.safetensors_utils import mem_eff_save_file +from hunyuan_model.models import load_transformer + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +def parse_args(): + parser = argparse.ArgumentParser(description="HunyuanVideo model merger script") + + parser.add_argument("--dit", type=str, required=True, help="DiT checkpoint path or directory") + parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path") + parser.add_argument("--lora_multiplier", type=float, nargs="*", default=[1.0], help="LoRA multiplier (can specify multiple values)") + parser.add_argument("--save_merged_model", type=str, required=True, help="Path to save the merged model") + parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use for merging") + parser.add_argument("--exclude_single_blocks", action="store_true", help="Exclude single blocks when loading LoRA weights") + + + return parser.parse_args() + + +def main(): + args = parse_args() + + device = torch.device(args.device) + logger.info(f"Using device: {device}") + + # Load DiT model + logger.info(f"Loading DiT model from {args.dit}") + transformer = load_transformer(args.dit, "torch", False, "cpu", torch.bfloat16) + transformer.eval() + + # Load LoRA weights and merge + if args.lora_weight is not None and len(args.lora_weight) > 0: + for i, lora_weight in enumerate(args.lora_weight): + # Use the corresponding lora_multiplier or default to 1.0 + if args.lora_multiplier is not None and len(args.lora_multiplier) > i: + lora_multiplier = args.lora_multiplier[i] + else: + lora_multiplier = 1.0 + + logger.info(f"Loading LoRA weights from {lora_weight} with multiplier {lora_multiplier}") + weights_sd = load_file(lora_weight) + + if args.exclude_single_blocks: + filtered_weights = {k: v for k, v in weights_sd.items() if "single_blocks" not in k} + weights_sd = filtered_weights + + network = lora.create_network_from_weights_hunyuan_video( + lora_multiplier, weights_sd, unet=transformer, for_inference=True + ) + logger.info("Merging LoRA weights to DiT model") + network.merge_to(None, transformer, weights_sd, 
device=device, non_blocking=True) + + logger.info("LoRA weights loaded") + + # Save the merged model + logger.info(f"Saving merged model to {args.save_merged_model}") + mem_eff_save_file(transformer.state_dict(), args.save_merged_model) + logger.info("Merged model saved") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/modules/__init__.py b/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modules/custom_offloading_utils.py b/modules/custom_offloading_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d813575af2ce4fcccf4a305c1002bf618844e591 --- /dev/null +++ b/modules/custom_offloading_utils.py @@ -0,0 +1,266 @@ +from concurrent.futures import ThreadPoolExecutor +import gc +import time +from typing import Optional +import torch +import torch.nn as nn + + +def clean_memory_on_device(device: torch.device): + r""" + Clean memory on the specified device, will be called from training scripts. + """ + gc.collect() + + # device may "cuda" or "cuda:0", so we need to check the type of device + if device.type == "cuda": + torch.cuda.empty_cache() + if device.type == "xpu": + torch.xpu.empty_cache() + if device.type == "mps": + torch.mps.empty_cache() + + +def synchronize_device(device: torch.device): + if device.type == "cuda": + torch.cuda.synchronize() + elif device.type == "xpu": + torch.xpu.synchronize() + elif device.type == "mps": + torch.mps.synchronize() + + +def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): + assert layer_to_cpu.__class__ == layer_to_cuda.__class__ + + weight_swap_jobs = [] + + # This is not working for all cases (e.g. SD3), so we need to find the corresponding modules + # for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): + # print(module_to_cpu.__class__, module_to_cuda.__class__) + # if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: + # weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + + modules_to_cpu = {k: v for k, v in layer_to_cpu.named_modules()} + for module_to_cuda_name, module_to_cuda in layer_to_cuda.named_modules(): + if hasattr(module_to_cuda, "weight") and module_to_cuda.weight is not None: + module_to_cpu = modules_to_cpu.get(module_to_cuda_name, None) + if module_to_cpu is not None and module_to_cpu.weight.shape == module_to_cuda.weight.shape: + weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + else: + if module_to_cuda.weight.data.device.type != device.type: + # print( + # f"Module {module_to_cuda_name} not found in CPU model or shape mismatch, so not swapping and moving to device" + # ) + module_to_cuda.weight.data = module_to_cuda.weight.data.to(device) + + torch.cuda.current_stream().synchronize() # this prevents the illegal loss value + + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + # cuda to cpu + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.record_stream(stream) + module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True) + + stream.synchronize() + + # cpu to cuda + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) + module_to_cuda.weight.data = cuda_data_view 
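+        # The swap reuses the existing GPU buffer: cuda_data_view already holds the
+        # storage freed when the outgoing layer's weights were copied to CPU, so the
+        # incoming layer's weights are copied into it in place and its parameter is
+        # re-pointed at that buffer, avoiding a fresh CUDA allocation per swap.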
+ + stream.synchronize() + torch.cuda.current_stream().synchronize() # this prevents the illegal loss value + + +def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): + """ + not tested + """ + assert layer_to_cpu.__class__ == layer_to_cuda.__class__ + + weight_swap_jobs = [] + for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): + if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: + weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + + # device to cpu + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True) + + synchronize_device() + + # cpu to device + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) + module_to_cuda.weight.data = cuda_data_view + + synchronize_device() + + +def weighs_to_device(layer: nn.Module, device: torch.device): + for module in layer.modules(): + if hasattr(module, "weight") and module.weight is not None: + module.weight.data = module.weight.data.to(device, non_blocking=True) + + +class Offloader: + """ + common offloading class + """ + + def __init__(self, block_type: str, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False): + self.block_type = block_type + self.num_blocks = num_blocks + self.blocks_to_swap = blocks_to_swap + self.device = device + self.debug = debug + + self.thread_pool = ThreadPoolExecutor(max_workers=1) + self.futures = {} + self.cuda_available = device.type == "cuda" + + def swap_weight_devices(self, block_to_cpu: nn.Module, block_to_cuda: nn.Module): + if self.cuda_available: + swap_weight_devices_cuda(self.device, block_to_cpu, block_to_cuda) + else: + swap_weight_devices_no_cuda(self.device, block_to_cpu, block_to_cuda) + + def _submit_move_blocks(self, blocks, block_idx_to_cpu, block_idx_to_cuda): + def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda): + if self.debug: + start_time = time.perf_counter() + print( + f"[{self.block_type}] Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to {'CUDA' if self.cuda_available else 'device'}" + ) + + self.swap_weight_devices(block_to_cpu, block_to_cuda) + + if self.debug: + print(f"[{self.block_type}] Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s") + return bidx_to_cpu, bidx_to_cuda # , event + + block_to_cpu = blocks[block_idx_to_cpu] + block_to_cuda = blocks[block_idx_to_cuda] + + self.futures[block_idx_to_cuda] = self.thread_pool.submit( + move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda + ) + + def _wait_blocks_move(self, block_idx): + if block_idx not in self.futures: + return + + if self.debug: + print(f"[{self.block_type}] Wait for block {block_idx}") + start_time = time.perf_counter() + + future = self.futures.pop(block_idx) + _, bidx_to_cuda = future.result() + + assert block_idx == bidx_to_cuda, f"Block index mismatch: {block_idx} != {bidx_to_cuda}" + + if self.debug: + print(f"[{self.block_type}] Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s") + + +class ModelOffloader(Offloader): + """ + supports forward offloading + """ + + def __init__( + self, + block_type: str, + blocks: list[nn.Module], + num_blocks: int, + blocks_to_swap: int, + supports_backward: 
bool, + device: torch.device, + debug: bool = False, + ): + super().__init__(block_type, num_blocks, blocks_to_swap, device, debug) + + self.supports_backward = supports_backward + self.forward_only = not supports_backward # forward only offloading: can be changed to True for inference + + if self.supports_backward: + # register backward hooks + self.remove_handles = [] + for i, block in enumerate(blocks): + hook = self.create_backward_hook(blocks, i) + if hook is not None: + handle = block.register_full_backward_hook(hook) + self.remove_handles.append(handle) + + def set_forward_only(self, forward_only: bool): + self.forward_only = forward_only + + def __del__(self): + if self.supports_backward: + for handle in self.remove_handles: + handle.remove() + + def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]: + # -1 for 0-based index + num_blocks_propagated = self.num_blocks - block_index - 1 + swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap + waiting = block_index > 0 and block_index <= self.blocks_to_swap + + if not swapping and not waiting: + return None + + # create hook + block_idx_to_cpu = self.num_blocks - num_blocks_propagated + block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated + block_idx_to_wait = block_index - 1 + + def backward_hook(module, grad_input, grad_output): + if self.debug: + print(f"Backward hook for block {block_index}") + + if swapping: + self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda) + if waiting: + self._wait_blocks_move(block_idx_to_wait) + return None + + return backward_hook + + def prepare_block_devices_before_forward(self, blocks: list[nn.Module]): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + + if self.debug: + print(f"[{self.block_type}] Prepare block devices before forward") + + for b in blocks[0 : self.num_blocks - self.blocks_to_swap]: + b.to(self.device) + weighs_to_device(b, self.device) # make sure weights are on device + + for b in blocks[self.num_blocks - self.blocks_to_swap :]: + b.to(self.device) # move block to device first + weighs_to_device(b, "cpu") # make sure weights are on cpu + + synchronize_device(self.device) + clean_memory_on_device(self.device) + + def wait_for_block(self, block_idx: int): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + self._wait_blocks_move(block_idx) + + def submit_move_blocks_forward(self, blocks: list[nn.Module], block_idx: int): + # check if blocks_to_swap is enabled + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + + # if supports_backward and backward is enabled, we swap blocks more than blocks_to_swap in backward pass + if not self.forward_only and block_idx >= self.blocks_to_swap: + return + + block_idx_to_cpu = block_idx + block_idx_to_cuda = self.num_blocks - self.blocks_to_swap + block_idx + block_idx_to_cuda = block_idx_to_cuda % self.num_blocks # this works for forward-only offloading + self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda) diff --git a/modules/fp8_optimization_utils.py b/modules/fp8_optimization_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e7aff9b28d764f4481b47a3d327debc1aedd3c8a --- /dev/null +++ b/modules/fp8_optimization_utils.py @@ -0,0 +1,456 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import logging + +from tqdm import tqdm + +from utils.safetensors_utils import MemoryEfficientSafeOpen + +logger = 
logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + +from utils.device_utils import clean_memory_on_device + + +def calculate_fp8_maxval(exp_bits=4, mantissa_bits=3, sign_bits=1): + """ + Calculate the maximum representable value in FP8 format. + Default is E4M3 format (4-bit exponent, 3-bit mantissa, 1-bit sign). + + Args: + exp_bits (int): Number of exponent bits + mantissa_bits (int): Number of mantissa bits + sign_bits (int): Number of sign bits (0 or 1) + + Returns: + float: Maximum value representable in FP8 format + """ + assert exp_bits + mantissa_bits + sign_bits == 8, "Total bits must be 8" + + # Calculate exponent bias + bias = 2 ** (exp_bits - 1) - 1 + + # Calculate maximum mantissa value + mantissa_max = 1.0 + for i in range(mantissa_bits - 1): + mantissa_max += 2 ** -(i + 1) + + # Calculate maximum value + max_value = mantissa_max * (2 ** (2**exp_bits - 1 - bias)) + + return max_value + + +def quantize_tensor_to_fp8(tensor, scale, exp_bits=4, mantissa_bits=3, sign_bits=1, max_value=None, min_value=None): + """ + Quantize a tensor to FP8 format. + + Args: + tensor (torch.Tensor): Tensor to quantize + scale (float or torch.Tensor): Scale factor + exp_bits (int): Number of exponent bits + mantissa_bits (int): Number of mantissa bits + sign_bits (int): Number of sign bits + + Returns: + tuple: (quantized_tensor, scale_factor) + """ + # Create scaled tensor + scaled_tensor = tensor / scale + + # Calculate FP8 parameters + bias = 2 ** (exp_bits - 1) - 1 + + if max_value is None: + # Calculate max and min values + max_value = calculate_fp8_maxval(exp_bits, mantissa_bits, sign_bits) + min_value = -max_value if sign_bits > 0 else 0.0 + + # Clamp tensor to range + clamped_tensor = torch.clamp(scaled_tensor, min_value, max_value) + + # Quantization process + abs_values = torch.abs(clamped_tensor) + nonzero_mask = abs_values > 0 + + # Calculate log scales (only for non-zero elements) + log_scales = torch.zeros_like(clamped_tensor) + if nonzero_mask.any(): + log_scales[nonzero_mask] = torch.floor(torch.log2(abs_values[nonzero_mask]) + bias).detach() + + # Limit log scales and calculate quantization factor + log_scales = torch.clamp(log_scales, min=1.0) + quant_factor = 2.0 ** (log_scales - mantissa_bits - bias) + + # Quantize and dequantize + quantized = torch.round(clamped_tensor / quant_factor) * quant_factor + + return quantized, scale + + +def optimize_state_dict_with_fp8_on_the_fly( + model_files, + calc_device, + target_layer_keys=None, + exclude_layer_keys=None, + exp_bits=4, + mantissa_bits=3, + move_to_device=False, + weight_hook=None, +): + """ + Optimize Linear layer weights in a model's state dict to FP8 format. 
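+    Each targeted "*.weight" tensor is stored as FP8 together with a per-tensor
+    "*.scale_weight" factor, where scale = max(|W|) / fp8_max (448 for the default
+    E4M3 format), so the original weight is approximately fp8_weight * scale.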
+ + Args: + model_files (list): List of model files to optimize + calc_device (str): Device to quantize tensors on + target_layer_keys (list, optional): Layer key patterns to target (None for all Linear layers) + exclude_layer_keys (list, optional): Layer key patterns to exclude + exp_bits (int): Number of exponent bits + mantissa_bits (int): Number of mantissa bits + move_to_device (bool): Move optimized tensors to the calculating device + + Returns: + dict: FP8 optimized state dict + """ + if exp_bits == 4 and mantissa_bits == 3: + fp8_dtype = torch.float8_e4m3fn + elif exp_bits == 5 and mantissa_bits == 2: + fp8_dtype = torch.float8_e5m2 + else: + raise ValueError(f"Unsupported FP8 format: E{exp_bits}M{mantissa_bits}") + + # Calculate FP8 max value + max_value = calculate_fp8_maxval(exp_bits, mantissa_bits) + min_value = -max_value # this function supports only signed FP8 + + # Create optimized state dict + def is_target_key(key): + # Check if it's a weight key and matches target patterns + is_target = (target_layer_keys is None or any(pattern in key for pattern in target_layer_keys)) and key.endswith(".weight") + is_excluded = exclude_layer_keys is not None and any(pattern in key for pattern in exclude_layer_keys) + is_target = is_target and not is_excluded + return is_target + + optimized_count = 0 + # process each model file + state_dict = {} + for model_file in model_files: + with MemoryEfficientSafeOpen(model_file) as f: + keys = f.keys() + for key in tqdm(keys, desc=f"Loading {model_file}", unit="key"): + value = f.get_tensor(key) + if weight_hook is not None: + value = weight_hook(key, value) + + if not is_target_key(key): + state_dict[key] = value + continue + + # Save original device and dtype + original_device = value.device + original_dtype = value.dtype + + # Move to calculation device + if calc_device is not None: + value = value.to(calc_device) + + # Calculate scale factor + scale = torch.max(torch.abs(value.flatten())) / max_value + # print(f"Optimizing {key} with scale: {scale}") + + # Quantize weight to FP8 + quantized_weight, _ = quantize_tensor_to_fp8(value, scale, exp_bits, mantissa_bits, 1, max_value, min_value) + + # Add to state dict using original key for weight and new key for scale + fp8_key = key # Maintain original key + scale_key = key.replace(".weight", ".scale_weight") + + quantized_weight = quantized_weight.to(fp8_dtype) + + if not move_to_device: + quantized_weight = quantized_weight.to(original_device) + + scale_tensor = torch.tensor([scale], dtype=original_dtype, device=quantized_weight.device) + + state_dict[fp8_key] = quantized_weight + state_dict[scale_key] = scale_tensor + + optimized_count += 1 + + if calc_device is not None: # optimized_count % 10 == 0 and + # free memory on calculation device + clean_memory_on_device(calc_device) + + logger.info(f"Number of optimized Linear layers: {optimized_count}") + return state_dict + + +def optimize_state_dict_with_fp8( + state_dict, calc_device, target_layer_keys=None, exclude_layer_keys=None, exp_bits=4, mantissa_bits=3, move_to_device=False +): + """ + Optimize Linear layer weights in a model's state dict to FP8 format. 
+ + Args: + state_dict (dict): State dict to optimize, replaced in-place + calc_device (str): Device to quantize tensors on + target_layer_keys (list, optional): Layer key patterns to target (None for all Linear layers) + exclude_layer_keys (list, optional): Layer key patterns to exclude + exp_bits (int): Number of exponent bits + mantissa_bits (int): Number of mantissa bits + move_to_device (bool): Move optimized tensors to the calculating device + + Returns: + dict: FP8 optimized state dict + """ + if exp_bits == 4 and mantissa_bits == 3: + fp8_dtype = torch.float8_e4m3fn + elif exp_bits == 5 and mantissa_bits == 2: + fp8_dtype = torch.float8_e5m2 + else: + raise ValueError(f"Unsupported FP8 format: E{exp_bits}M{mantissa_bits}") + + # Calculate FP8 max value + max_value = calculate_fp8_maxval(exp_bits, mantissa_bits) + min_value = -max_value # this function supports only signed FP8 + + # Create optimized state dict + optimized_count = 0 + + # Enumerate tarket keys + target_state_dict_keys = [] + for key in state_dict.keys(): + # Check if it's a weight key and matches target patterns + is_target = (target_layer_keys is None or any(pattern in key for pattern in target_layer_keys)) and key.endswith(".weight") + is_excluded = exclude_layer_keys is not None and any(pattern in key for pattern in exclude_layer_keys) + is_target = is_target and not is_excluded + + if is_target and isinstance(state_dict[key], torch.Tensor): + target_state_dict_keys.append(key) + + # Process each key + for key in tqdm(target_state_dict_keys): + value = state_dict[key] + + # Save original device and dtype + original_device = value.device + original_dtype = value.dtype + + # Move to calculation device + if calc_device is not None: + value = value.to(calc_device) + + # Calculate scale factor + scale = torch.max(torch.abs(value.flatten())) / max_value + # print(f"Optimizing {key} with scale: {scale}") + + # Quantize weight to FP8 + quantized_weight, _ = quantize_tensor_to_fp8(value, scale, exp_bits, mantissa_bits, 1, max_value, min_value) + + # Add to state dict using original key for weight and new key for scale + fp8_key = key # Maintain original key + scale_key = key.replace(".weight", ".scale_weight") + + quantized_weight = quantized_weight.to(fp8_dtype) + + if not move_to_device: + quantized_weight = quantized_weight.to(original_device) + + scale_tensor = torch.tensor([scale], dtype=original_dtype, device=quantized_weight.device) + + state_dict[fp8_key] = quantized_weight + state_dict[scale_key] = scale_tensor + + optimized_count += 1 + + if calc_device is not None: # optimized_count % 10 == 0 and + # free memory on calculation device + clean_memory_on_device(calc_device) + + logger.info(f"Number of optimized Linear layers: {optimized_count}") + return state_dict + + +def fp8_linear_forward_patch(self: nn.Linear, x, use_scaled_mm=False, max_value=None): + """ + Patched forward method for Linear layers with FP8 weights. + + Args: + self: Linear layer instance + x (torch.Tensor): Input tensor + use_scaled_mm (bool): Use scaled_mm for FP8 Linear layers, requires SM 8.9+ (RTX 40 series) + max_value (float): Maximum value for FP8 quantization. If None, no quantization is applied for input tensor. 
+ + Returns: + torch.Tensor: Result of linear transformation + """ + if use_scaled_mm: + input_dtype = x.dtype + original_weight_dtype = self.scale_weight.dtype + weight_dtype = self.weight.dtype + target_dtype = torch.float8_e5m2 + assert weight_dtype == torch.float8_e4m3fn, "Only FP8 E4M3FN format is supported" + assert x.ndim == 3, "Input tensor must be 3D (batch_size, seq_len, hidden_dim)" + + if max_value is None: + # no input quantization + scale_x = torch.tensor(1.0, dtype=torch.float32, device=x.device) + else: + # calculate scale factor for input tensor + scale_x = (torch.max(torch.abs(x.flatten())) / max_value).to(torch.float32) + + # quantize input tensor to FP8: this seems to consume a lot of memory + x, _ = quantize_tensor_to_fp8(x, scale_x, 5, 2, 1, max_value, -max_value) + + original_shape = x.shape + x = x.reshape(-1, x.shape[2]).to(target_dtype) + + weight = self.weight.t() + scale_weight = self.scale_weight.to(torch.float32) + + if self.bias is not None: + # float32 is not supported with bias in scaled_mm + o = torch._scaled_mm(x, weight, out_dtype=original_weight_dtype, bias=self.bias, scale_a=scale_x, scale_b=scale_weight) + else: + o = torch._scaled_mm(x, weight, out_dtype=input_dtype, scale_a=scale_x, scale_b=scale_weight) + + return o.reshape(original_shape[0], original_shape[1], -1).to(input_dtype) + + else: + # Dequantize the weight + original_dtype = self.scale_weight.dtype + dequantized_weight = self.weight.to(original_dtype) * self.scale_weight + + # Perform linear transformation + if self.bias is not None: + output = F.linear(x, dequantized_weight, self.bias) + else: + output = F.linear(x, dequantized_weight) + + return output + + +def apply_fp8_monkey_patch(model, optimized_state_dict, use_scaled_mm=False): + """ + Apply monkey patching to a model using FP8 optimized state dict. + + Args: + model (nn.Module): Model instance to patch + optimized_state_dict (dict): FP8 optimized state dict + use_scaled_mm (bool): Use scaled_mm for FP8 Linear layers, requires SM 8.9+ (RTX 40 series) + + Returns: + nn.Module: The patched model (same instance, modified in-place) + """ + # # Calculate FP8 float8_e5m2 max value + # max_value = calculate_fp8_maxval(5, 2) + max_value = None # do not quantize input tensor + + # Find all scale keys to identify FP8-optimized layers + scale_keys = [k for k in optimized_state_dict.keys() if k.endswith(".scale_weight")] + + # Enumerate patched layers + patched_module_paths = set() + for scale_key in scale_keys: + # Extract module path from scale key (remove .scale_weight) + module_path = scale_key.rsplit(".scale_weight", 1)[0] + patched_module_paths.add(module_path) + + patched_count = 0 + + # Apply monkey patch to each layer with FP8 weights + for name, module in model.named_modules(): + # Check if this module has a corresponding scale_weight + has_scale = name in patched_module_paths + + # Apply patch if it's a Linear layer with FP8 scale + if isinstance(module, nn.Linear) and has_scale: + # register the scale_weight as a buffer to load the state_dict + module.register_buffer("scale_weight", torch.tensor(1.0, dtype=module.weight.dtype)) + + # Create a new forward method with the patched version. 
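+            # new_forward closes over use_scaled_mm and max_value from this function's
+            # scope; binding it with __get__ below attaches it to this specific module
+            # instance, so only FP8-optimized Linear layers bypass nn.Linear.forward.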
+ def new_forward(self, x): + return fp8_linear_forward_patch(self, x, use_scaled_mm, max_value) + + # Bind method to module + module.forward = new_forward.__get__(module, type(module)) + + patched_count += 1 + + logger.info(f"Number of monkey-patched Linear layers: {patched_count}") + return model + + +# Example usage +def example_usage(): + # Small test model + class TestModel(nn.Module): + def __init__(self): + super().__init__() + fc1 = nn.Linear(768, 3072) + act1 = nn.GELU() + fc2 = nn.Linear(3072, 768) + act2 = nn.GELU() + fc3 = nn.Linear(768, 768) + + # Set layer names for testing + self.single_blocks = nn.ModuleList([fc1, act1, fc2, act2, fc3]) + + self.fc4 = nn.Linear(768, 128) + + def forward(self, x): + for layer in self.single_blocks: + x = layer(x) + x = self.fc4(x) + return x + + # Instantiate model + test_model = TestModel() + test_model.to(torch.float16) # convert to FP16 for testing + + # Test input tensor + test_input = torch.randn(1, 768, dtype=torch.float16) + + # Calculate output before optimization + with torch.no_grad(): + original_output = test_model(test_input) + print("original output", original_output[0, :5]) + + # Get state dict + state_dict = test_model.state_dict() + + # Apply FP8 optimization to state dict + cuda_device = torch.device("cuda") + optimized_state_dict = optimize_state_dict_with_fp8(state_dict, cuda_device, ["single_blocks"], ["2"]) + + # Apply monkey patching to the model + optimized_model = TestModel() # re-instantiate model + optimized_model.to(torch.float16) # convert to FP16 for testing + apply_fp8_monkey_patch(optimized_model, optimized_state_dict) + + # Load optimized state dict + optimized_model.load_state_dict(optimized_state_dict, strict=True, assign=True) # assign=True to load buffer + + # Calculate output after optimization + with torch.no_grad(): + optimized_output = optimized_model(test_input) + print("optimized output", optimized_output[0, :5]) + + # Compare accuracy + error = torch.mean(torch.abs(original_output - optimized_output)) + print(f"Mean absolute error: {error.item()}") + + # Check memory usage + original_params = sum(p.nelement() * p.element_size() for p in test_model.parameters()) / (1024 * 1024) + print(f"Model parameter memory: {original_params:.2f} MB") + optimized_params = sum(p.nelement() * p.element_size() for p in optimized_model.parameters()) / (1024 * 1024) + print(f"Optimized model parameter memory: {optimized_params:.2f} MB") + + return test_model + + +if __name__ == "__main__": + example_usage() diff --git a/modules/model.py b/modules/model.py new file mode 100644 index 0000000000000000000000000000000000000000..9b0a34f5b43bdbbfa140f321944801fc799a583d --- /dev/null +++ b/modules/model.py @@ -0,0 +1,928 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
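+#
+# WanModel: Wan video DiT backbone with optional fp8 weight optimization,
+# gradient checkpointing, and block swapping for low-VRAM environments.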
+import math +from typing import Optional, Union + +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint +from accelerate import init_empty_weights + +import logging + +from utils.safetensors_utils import MemoryEfficientSafeOpen, load_safetensors + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + +from utils.device_utils import clean_memory_on_device + +from .attention import flash_attention +from utils.device_utils import clean_memory_on_device +from modules.custom_offloading_utils import ModelOffloader +from modules.fp8_optimization_utils import apply_fp8_monkey_patch, optimize_state_dict_with_fp8 + +__all__ = ["WanModel"] + + +def sinusoidal_embedding_1d(dim, position): + # preprocess + assert dim % 2 == 0 + half = dim // 2 + position = position.type(torch.float64) + + # calculation + sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half).to(position).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + return x + + +# @amp.autocast(enabled=False) +# no autocast is needed for rope_apply, because it is already in float64 +def rope_params(max_seq_len, dim, theta=10000): + assert dim % 2 == 0 + freqs = torch.outer(torch.arange(max_seq_len), 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim))) + freqs = torch.polar(torch.ones_like(freqs), freqs) + return freqs + + +# @amp.autocast(enabled=False) +def rope_apply(x, grid_sizes, freqs): + device_type = x.device.type + with torch.amp.autocast(device_type=device_type, enabled=False): + n, c = x.size(2), x.size(3) // 2 + + # split freqs + freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + + # loop over samples + output = [] + for i, (f, h, w) in enumerate(grid_sizes.tolist()): + seq_len = f * h * w + + # precompute multipliers + x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(seq_len, n, -1, 2)) + freqs_i = torch.cat( + [ + freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1), + ], + dim=-1, + ).reshape(seq_len, 1, -1) + + # apply rotary embedding + x_i = torch.view_as_real(x_i * freqs_i).flatten(2) + x_i = torch.cat([x_i, x[i, seq_len:]]) + + # append to collection + output.append(x_i) + return torch.stack(output).float() + + +def calculate_freqs_i(fhw, c, freqs): + f, h, w = fhw + freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + freqs_i = torch.cat( + [ + freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1), + ], + dim=-1, + ).reshape(f * h * w, 1, -1) + return freqs_i + + +# inplace version of rope_apply +def rope_apply_inplace_cached(x, grid_sizes, freqs_list): + # with torch.amp.autocast(device_type=device_type, enabled=False): + rope_dtype = torch.float64 # float32 does not reduce memory usage significantly + + n, c = x.size(2), x.size(3) // 2 + + # loop over samples + for i, (f, h, w) in enumerate(grid_sizes.tolist()): + seq_len = f * h * w + + # precompute multipliers + x_i = torch.view_as_complex(x[i, :seq_len].to(rope_dtype).reshape(seq_len, n, -1, 2)) + freqs_i = freqs_list[i] + + # apply rotary embedding + x_i = torch.view_as_real(x_i * freqs_i).flatten(2) + # x_i = torch.cat([x_i, x[i, seq_len:]]) + + # inplace update + x[i, :seq_len] = x_i.to(x.dtype) + + return x + + +class WanRMSNorm(nn.Module): + + def __init__(self, dim, eps=1e-5): + 
super().__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + # return self._norm(x.float()).type_as(x) * self.weight + # support fp8 + return self._norm(x.float()).type_as(x) * self.weight.to(x.dtype) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) + + # def forward(self, x): + # r""" + # Args: + # x(Tensor): Shape [B, L, C] + # """ + # # inplace version, also supports fp8 -> does not have significant performance improvement + # original_dtype = x.dtype + # x = x.float() + # y = x.pow(2).mean(dim=-1, keepdim=True) + # y.add_(self.eps) + # y.rsqrt_() + # x *= y + # x = x.to(original_dtype) + # x *= self.weight.to(original_dtype) + # return x + + +class WanLayerNorm(nn.LayerNorm): + + def __init__(self, dim, eps=1e-6, elementwise_affine=False): + super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) + + def forward(self, x): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + return super().forward(x.float()).type_as(x) + + +class WanSelfAttention(nn.Module): + + def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6, attn_mode="torch", split_attn=False): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.eps = eps + self.attn_mode = attn_mode + self.split_attn = split_attn + + # layers + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + + def forward(self, x, seq_lens, grid_sizes, freqs): + r""" + Args: + x(Tensor): Shape [B, L, num_heads, C / num_heads] + seq_lens(Tensor): Shape [B] + grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + """ + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + + # # query, key, value function + # def qkv_fn(x): + # q = self.norm_q(self.q(x)).view(b, s, n, d) + # k = self.norm_k(self.k(x)).view(b, s, n, d) + # v = self.v(x).view(b, s, n, d) + # return q, k, v + # q, k, v = qkv_fn(x) + # del x + # query, key, value function + + q = self.q(x) + k = self.k(x) + v = self.v(x) + del x + q = self.norm_q(q) + k = self.norm_k(k) + q = q.view(b, s, n, d) + k = k.view(b, s, n, d) + v = v.view(b, s, n, d) + + rope_apply_inplace_cached(q, grid_sizes, freqs) + rope_apply_inplace_cached(k, grid_sizes, freqs) + qkv = [q, k, v] + del q, k, v + x = flash_attention( + qkv, k_lens=seq_lens, window_size=self.window_size, attn_mode=self.attn_mode, split_attn=self.split_attn + ) + + # output + x = x.flatten(2) + x = self.o(x) + return x + + +class WanT2VCrossAttention(WanSelfAttention): + + def forward(self, x, context, context_lens): + r""" + Args: + x(Tensor): Shape [B, L1, C] + context(Tensor): Shape [B, L2, C] + context_lens(Tensor): Shape [B] + """ + b, n, d = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + # q = self.norm_q(self.q(x)).view(b, -1, n, d) + # k = self.norm_k(self.k(context)).view(b, -1, n, d) + # v = self.v(context).view(b, -1, n, d) + q = self.q(x) + del x + k = self.k(context) + v = self.v(context) + del context + q = self.norm_q(q) + k = self.norm_k(k) + q = q.view(b, -1, n, d) + k = 
k.view(b, -1, n, d) + v = v.view(b, -1, n, d) + + # compute attention + qkv = [q, k, v] + del q, k, v + x = flash_attention(qkv, k_lens=context_lens, attn_mode=self.attn_mode, split_attn=self.split_attn) + + # output + x = x.flatten(2) + x = self.o(x) + return x + + +class WanI2VCrossAttention(WanSelfAttention): + + def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6, attn_mode="torch", split_attn=False): + super().__init__(dim, num_heads, window_size, qk_norm, eps, attn_mode, split_attn) + + self.k_img = nn.Linear(dim, dim) + self.v_img = nn.Linear(dim, dim) + # self.alpha = nn.Parameter(torch.zeros((1, ))) + self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + + def forward(self, x, context, context_lens): + r""" + Args: + x(Tensor): Shape [B, L1, C] + context(Tensor): Shape [B, L2, C] + context_lens(Tensor): Shape [B] + """ + context_img = context[:, :257] + context = context[:, 257:] + b, n, d = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.q(x) + del x + q = self.norm_q(q) + q = q.view(b, -1, n, d) + k = self.k(context) + k = self.norm_k(k).view(b, -1, n, d) + v = self.v(context).view(b, -1, n, d) + del context + + # compute attention + qkv = [q, k, v] + del k, v + x = flash_attention(qkv, k_lens=context_lens, attn_mode=self.attn_mode, split_attn=self.split_attn) + + # compute query, key, value + k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d) + v_img = self.v_img(context_img).view(b, -1, n, d) + del context_img + + # compute attention + qkv = [q, k_img, v_img] + del q, k_img, v_img + img_x = flash_attention(qkv, k_lens=None, attn_mode=self.attn_mode, split_attn=self.split_attn) + + # output + x = x.flatten(2) + img_x = img_x.flatten(2) + if self.training: + x = x + img_x # avoid inplace + else: + x += img_x + del img_x + + x = self.o(x) + return x + + +WAN_CROSSATTENTION_CLASSES = { + "t2v_cross_attn": WanT2VCrossAttention, + "i2v_cross_attn": WanI2VCrossAttention, +} + + +class WanAttentionBlock(nn.Module): + + def __init__( + self, + cross_attn_type, + dim, + ffn_dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6, + attn_mode="torch", + split_attn=False, + ): + super().__init__() + self.dim = dim + self.ffn_dim = ffn_dim + self.num_heads = num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + + # layers + self.norm1 = WanLayerNorm(dim, eps) + self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps, attn_mode, split_attn) + self.norm3 = WanLayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() + self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim, num_heads, (-1, -1), qk_norm, eps, attn_mode, split_attn) + self.norm2 = WanLayerNorm(dim, eps) + self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(approximate="tanh"), nn.Linear(ffn_dim, dim)) + + # modulation + self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + + self.gradient_checkpointing = False + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + + def _forward(self, x, e, seq_lens, grid_sizes, freqs, context, context_lens): + r""" + Args: + x(Tensor): Shape [B, L, C] + e(Tensor): Shape [B, 6, C] + seq_lens(Tensor): Shape [B], length of each sequence in batch + grid_sizes(Tensor): Shape [B, 3], the second dimension 
contains (F, H, W) + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + """ + assert e.dtype == torch.float32 + # with amp.autocast(dtype=torch.float32): + # e = (self.modulation + e).chunk(6, dim=1) + # support fp8 + e = self.modulation.to(torch.float32) + e + e = e.chunk(6, dim=1) + assert e[0].dtype == torch.float32 + + # self-attention + y = self.self_attn(self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes, freqs) + # with amp.autocast(dtype=torch.float32): + # x = x + y * e[2] + x = x + y.to(torch.float32) * e[2] + del y + + # cross-attention & ffn function + # def cross_attn_ffn(x, context, context_lens, e): + # x += self.cross_attn(self.norm3(x), context, context_lens) + # y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3]) + # # with amp.autocast(dtype=torch.float32): + # # x = x + y * e[5] + # x += y.to(torch.float32) * e[5] + # return x + # x = cross_attn_ffn(x, context, context_lens, e) + + # x += self.cross_attn(self.norm3(x), context, context_lens) # backward error + x = x + self.cross_attn(self.norm3(x), context, context_lens) + del context + y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3]) + x = x + y.to(torch.float32) * e[5] + del y + return x + + def forward(self, x, e, seq_lens, grid_sizes, freqs, context, context_lens): + if self.training and self.gradient_checkpointing: + return checkpoint(self._forward, x, e, seq_lens, grid_sizes, freqs, context, context_lens, use_reentrant=False) + return self._forward(x, e, seq_lens, grid_sizes, freqs, context, context_lens) + + +class Head(nn.Module): + + def __init__(self, dim, out_dim, patch_size, eps=1e-6): + super().__init__() + self.dim = dim + self.out_dim = out_dim + self.patch_size = patch_size + self.eps = eps + + # layers + out_dim = math.prod(patch_size) * out_dim + self.norm = WanLayerNorm(dim, eps) + self.head = nn.Linear(dim, out_dim) + + # modulation + self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) + + def forward(self, x, e): + r""" + Args: + x(Tensor): Shape [B, L1, C] + e(Tensor): Shape [B, C] + """ + assert e.dtype == torch.float32 + # with amp.autocast(dtype=torch.float32): + # e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1) + # x = self.head(self.norm(x) * (1 + e[1]) + e[0]) + # support fp8 + e = (self.modulation.to(torch.float32) + e.unsqueeze(1)).chunk(2, dim=1) + x = self.head(self.norm(x) * (1 + e[1]) + e[0]) + return x + + +class MLPProj(torch.nn.Module): + + def __init__(self, in_dim, out_dim): + super().__init__() + + self.proj = torch.nn.Sequential( + torch.nn.LayerNorm(in_dim), + torch.nn.Linear(in_dim, in_dim), + torch.nn.GELU(), + torch.nn.Linear(in_dim, out_dim), + torch.nn.LayerNorm(out_dim), + ) + + def forward(self, image_embeds): + clip_extra_context_tokens = self.proj(image_embeds) + return clip_extra_context_tokens + + +class WanModel(nn.Module): # ModelMixin, ConfigMixin): + r""" + Wan diffusion backbone supporting both text-to-video and image-to-video. + """ + + ignore_for_config = ["patch_size", "cross_attn_norm", "qk_norm", "text_dim", "window_size"] + _no_split_modules = ["WanAttentionBlock"] + + # @register_to_config + def __init__( + self, + model_type="t2v", + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + attn_mode=None, + split_attn=False, + ): + r""" + Initialize the diffusion model backbone. 
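+
+        Builds the patch/text/time embeddings, the stack of WanAttentionBlock layers,
+        the output head, and the precomputed RoPE frequency table.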
+ + Args: + model_type (`str`, *optional*, defaults to 't2v'): + Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) + patch_size (`tuple`, *optional*, defaults to (1, 2, 2)): + 3D patch dimensions for video embedding (t_patch, h_patch, w_patch) + text_len (`int`, *optional*, defaults to 512): + Fixed length for text embeddings + in_dim (`int`, *optional*, defaults to 16): + Input video channels (C_in) + dim (`int`, *optional*, defaults to 2048): + Hidden dimension of the transformer + ffn_dim (`int`, *optional*, defaults to 8192): + Intermediate dimension in feed-forward network + freq_dim (`int`, *optional*, defaults to 256): + Dimension for sinusoidal time embeddings + text_dim (`int`, *optional*, defaults to 4096): + Input dimension for text embeddings + out_dim (`int`, *optional*, defaults to 16): + Output video channels (C_out) + num_heads (`int`, *optional*, defaults to 16): + Number of attention heads + num_layers (`int`, *optional*, defaults to 32): + Number of transformer blocks + window_size (`tuple`, *optional*, defaults to (-1, -1)): + Window size for local attention (-1 indicates global attention) + qk_norm (`bool`, *optional*, defaults to True): + Enable query/key normalization + cross_attn_norm (`bool`, *optional*, defaults to False): + Enable cross-attention normalization + eps (`float`, *optional*, defaults to 1e-6): + Epsilon value for normalization layers + """ + + super().__init__() + + assert model_type in ["t2v", "i2v"] + self.model_type = model_type + + self.patch_size = patch_size + self.text_len = text_len + self.in_dim = in_dim + self.dim = dim + self.ffn_dim = ffn_dim + self.freq_dim = freq_dim + self.text_dim = text_dim + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.window_size = window_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + self.attn_mode = attn_mode if attn_mode is not None else "torch" + self.split_attn = split_attn + + # embeddings + self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size) + self.text_embedding = nn.Sequential(nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim)) + + self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) + self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6)) + + # blocks + cross_attn_type = "t2v_cross_attn" if model_type == "t2v" else "i2v_cross_attn" + self.blocks = nn.ModuleList( + [ + WanAttentionBlock( + cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, attn_mode, split_attn + ) + for _ in range(num_layers) + ] + ) + + # head + self.head = Head(dim, out_dim, patch_size, eps) + + # buffers (don't use register_buffer otherwise dtype will be changed in to()) + assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 + d = dim // num_heads + self.freqs = torch.cat( + [rope_params(1024, d - 4 * (d // 6)), rope_params(1024, 2 * (d // 6)), rope_params(1024, 2 * (d // 6))], dim=1 + ) + self.freqs_fhw = {} + + if model_type == "i2v": + self.img_emb = MLPProj(1280, dim) + + # initialize weights + self.init_weights() + + self.gradient_checkpointing = False + + # offloading + self.blocks_to_swap = None + self.offloader = None + + @property + def dtype(self): + return next(self.parameters()).dtype + + @property + def device(self): + return next(self.parameters()).device + + def fp8_optimization(self, state_dict: dict[str, torch.Tensor], device: torch.device, move_to_device: 
bool) -> int: + """ + Optimize the model state_dict with fp8. + + Args: + state_dict (dict[str, torch.Tensor]): + The state_dict of the model. + device (torch.device): + The device to calculate the weight. + move_to_device (bool): + Whether to move the weight to the device after optimization. + """ + TARGET_KEYS = ["blocks"] + EXCLUDE_KEYS = [ + "norm", + "patch_embedding", + "text_embedding", + "time_embedding", + "time_projection", + "head", + "modulation", + "img_emb", + ] + + # inplace optimization + state_dict = optimize_state_dict_with_fp8(state_dict, device, TARGET_KEYS, EXCLUDE_KEYS, move_to_device=move_to_device) + + # apply monkey patching + apply_fp8_monkey_patch(self, state_dict) + + return state_dict + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + for block in self.blocks: + block.enable_gradient_checkpointing() + + print(f"WanModel: Gradient checkpointing enabled.") + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + + for block in self.blocks: + block.disable_gradient_checkpointing() + + print(f"WanModel: Gradient checkpointing disabled.") + + def enable_block_swap(self, blocks_to_swap: int, device: torch.device, supports_backward: bool): + self.blocks_to_swap = blocks_to_swap + self.num_blocks = len(self.blocks) + + assert ( + self.blocks_to_swap <= self.num_blocks - 1 + ), f"Cannot swap more than {self.num_blocks - 1} blocks. Requested {self.blocks_to_swap} blocks to swap." + + self.offloader = ModelOffloader( + "wan_attn_block", self.blocks, self.num_blocks, self.blocks_to_swap, supports_backward, device # , debug=True + ) + print( + f"WanModel: Block swap enabled. Swapping {self.blocks_to_swap} blocks out of {self.num_blocks} blocks. Supports backward: {supports_backward}" + ) + + def switch_block_swap_for_inference(self): + if self.blocks_to_swap: + self.offloader.set_forward_only(True) + self.prepare_block_swap_before_forward() + print(f"WanModel: Block swap set to forward only.") + + def switch_block_swap_for_training(self): + if self.blocks_to_swap: + self.offloader.set_forward_only(False) + self.prepare_block_swap_before_forward() + print(f"WanModel: Block swap set to forward and backward.") + + def move_to_device_except_swap_blocks(self, device: torch.device): + # assume model is on cpu. 
do not move blocks to device to reduce temporary memory usage + if self.blocks_to_swap: + save_blocks = self.blocks + self.blocks = None + + self.to(device) + + if self.blocks_to_swap: + self.blocks = save_blocks + + def prepare_block_swap_before_forward(self): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + self.offloader.prepare_block_devices_before_forward(self.blocks) + + def forward(self, x, t, context, seq_len, clip_fea=None, y=None): + r""" + Forward pass through the diffusion model + + Args: + x (List[Tensor]): + List of input video tensors, each with shape [C_in, F, H, W] + t (Tensor): + Diffusion timesteps tensor of shape [B] + context (List[Tensor]): + List of text embeddings each with shape [L, C] + seq_len (`int`): + Maximum sequence length for positional encoding + clip_fea (Tensor, *optional*): + CLIP image features for image-to-video mode + y (List[Tensor], *optional*): + Conditional video inputs for image-to-video mode, same shape as x + + Returns: + List[Tensor]: + List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8] + """ + if self.model_type == "i2v": + assert clip_fea is not None and y is not None + # params + device = self.patch_embedding.weight.device + if self.freqs.device != device: + self.freqs = self.freqs.to(device) + + if y is not None: + x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] + y = None + + # embeddings + x = [self.patch_embedding(u.unsqueeze(0)) for u in x] + grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) + + freqs_list = [] + for fhw in grid_sizes: + fhw = tuple(fhw.tolist()) + if fhw not in self.freqs_fhw: + c = self.dim // self.num_heads // 2 + self.freqs_fhw[fhw] = calculate_freqs_i(fhw, c, self.freqs) + freqs_list.append(self.freqs_fhw[fhw]) + + x = [u.flatten(2).transpose(1, 2) for u in x] + seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) + assert seq_lens.max() <= seq_len, f"Sequence length exceeds maximum allowed length {seq_len}. 
Got {seq_lens.max()}" + x = torch.cat([torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) for u in x]) + + # time embeddings + # with amp.autocast(dtype=torch.float32): + with torch.amp.autocast(device_type=device.type, dtype=torch.float32): + e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t).float()) + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) + assert e.dtype == torch.float32 and e0.dtype == torch.float32 + + # context + context_lens = None + if type(context) is list: + context = torch.stack([torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) for u in context]) + context = self.text_embedding(context) + + if clip_fea is not None: + context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context = torch.concat([context_clip, context], dim=1) + clip_fea = None + context_clip = None + + # arguments + kwargs = dict(e=e0, seq_lens=seq_lens, grid_sizes=grid_sizes, freqs=freqs_list, context=context, context_lens=context_lens) + + if self.blocks_to_swap: + clean_memory_on_device(device) + + # print(f"x: {x.shape}, e: {e0.shape}, context: {context.shape}, seq_lens: {seq_lens}") + for block_idx, block in enumerate(self.blocks): + if self.blocks_to_swap: + self.offloader.wait_for_block(block_idx) + + x = block(x, **kwargs) + + if self.blocks_to_swap: + self.offloader.submit_move_blocks_forward(self.blocks, block_idx) + + # head + x = self.head(x, e) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + return [u.float() for u in x] + + def unpatchify(self, x, grid_sizes): + r""" + Reconstruct video tensors from patch embeddings. + + Args: + x (List[Tensor]): + List of patchified features, each with shape [L, C_out * prod(patch_size)] + grid_sizes (Tensor): + Original spatial-temporal grid dimensions before patching, + shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches) + + Returns: + List[Tensor]: + Reconstructed video tensors with shape [C_out, F, H / 8, W / 8] + """ + + c = self.out_dim + out = [] + for u, v in zip(x, grid_sizes.tolist()): + u = u[: math.prod(v)].view(*v, *self.patch_size, c) + u = torch.einsum("fhwpqrc->cfphqwr", u) + u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)]) + out.append(u) + return out + + def init_weights(self): + r""" + Initialize model parameters using Xavier initialization. 
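+        Linear layers use Xavier; the text/time embedding layers are then re-initialized
+        with normal(std=0.02), and the output head weight is zero-initialized.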
+ """ + + # basic init + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + + # init embeddings + nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1)) + for m in self.text_embedding.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.02) + for m in self.time_embedding.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.02) + + # init output layer + nn.init.zeros_(self.head.head.weight) + + +def detect_wan_sd_dtype(path: str) -> torch.dtype: + # get dtype from model weights + with MemoryEfficientSafeOpen(path) as f: + keys = set(f.keys()) + key1 = "model.diffusion_model.blocks.0.cross_attn.k.weight" # 1.3B + key2 = "blocks.0.cross_attn.k.weight" # 14B + if key1 in keys: + dit_dtype = f.get_tensor(key1).dtype + elif key2 in keys: + dit_dtype = f.get_tensor(key2).dtype + else: + raise ValueError(f"Could not find the dtype in the model weights: {path}") + logger.info(f"Detected DiT dtype: {dit_dtype}") + return dit_dtype + + +def load_wan_model( + config: any, + i2v: bool, + device: Union[str, torch.device], + dit_path: str, + attn_mode: str, + split_attn: bool, + loading_device: Union[str, torch.device], + dit_weight_dtype: Optional[torch.dtype], + fp8_scaled: bool = False, +) -> WanModel: + # dit_weight_dtype is None for fp8_scaled + assert (not fp8_scaled and dit_weight_dtype is not None) or (fp8_scaled and dit_weight_dtype is None) + + device = torch.device(device) + loading_device = torch.device(loading_device) + + with init_empty_weights(): + logger.info(f"Creating WanModel") + model = WanModel( + model_type="i2v" if i2v else "t2v", + dim=config.dim, + eps=config.eps, + ffn_dim=config.ffn_dim, + freq_dim=config.freq_dim, + in_dim=36 if i2v else 16, # 36 for I2V, 16 for T2V + num_heads=config.num_heads, + num_layers=config.num_layers, + out_dim=16, + text_len=512, + attn_mode=attn_mode, + split_attn=split_attn, + ) + if dit_weight_dtype is not None: + model.to(dit_weight_dtype) + + # if fp8_scaled, load model weights to CPU to reduce VRAM usage. Otherwise, load to the specified device (CPU for block swap or CUDA for others) + wan_loading_device = torch.device("cpu") if fp8_scaled else loading_device + logger.info(f"Loading DiT model from {dit_path}, device={wan_loading_device}, dtype={dit_weight_dtype}") + + # load model weights with the specified dtype or as is + sd = load_safetensors(dit_path, wan_loading_device, disable_mmap=True, dtype=dit_weight_dtype) + + # remove "model.diffusion_model." prefix: 1.3B model has this prefix + for key in list(sd.keys()): + if key.startswith("model.diffusion_model."): + sd[key[22:]] = sd.pop(key) + + if fp8_scaled: + # fp8 optimization: calculate on CUDA, move back to CPU if loading_device is CPU (block swap) + logger.info(f"Optimizing model weights to fp8. 
This may take a while.") + sd = model.fp8_optimization(sd, device, move_to_device=loading_device.type == "cpu") + + if loading_device.type != "cpu": + # make sure all the model weights are on the loading_device + logger.info(f"Moving weights to {loading_device}") + for key in sd.keys(): + sd[key] = sd[key].to(loading_device) + + info = model.load_state_dict(sd, strict=True, assign=True) + logger.info(f"Loaded DiT model from {dit_path}, info={info}") + + return model diff --git a/modules/scheduling_flow_match_discrete.py b/modules/scheduling_flow_match_discrete.py new file mode 100644 index 0000000000000000000000000000000000000000..c507ec4eb050463188e250c20aec8d1fde2c4a5d --- /dev/null +++ b/modules/scheduling_flow_match_discrete.py @@ -0,0 +1,257 @@ +# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Modified from diffusers==0.29.2 +# +# ============================================================================== + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput, logging +from diffusers.schedulers.scheduling_utils import SchedulerMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class FlowMatchDiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.FloatTensor + + +class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Euler scheduler. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + shift (`float`, defaults to 1.0): + The shift value for the timestep schedule. + reverse (`bool`, defaults to `True`): + Whether to reverse the timestep schedule. 
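+        solver (`str`, defaults to `"euler"`):
+            The ODE solver to use; only `"euler"` is currently supported.
+        n_tokens (`int`, *optional*):
+            Number of tokens in the input sequence (accepted for API compatibility,
+            not used by this scheduler).
+
+    Example (minimal sketch; `model` and `latents` are illustrative placeholders):
+
+    ```py
+    scheduler = FlowMatchDiscreteScheduler(shift=7.0)
+    scheduler.set_timesteps(num_inference_steps=50, device="cuda")
+    for t in scheduler.timesteps:
+        noise_pred = model(latents, t)  # placeholder denoiser call
+        latents = scheduler.step(noise_pred, t, latents).prev_sample
+    ```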
+ """ + + _compatibles = [] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + reverse: bool = True, + solver: str = "euler", + n_tokens: Optional[int] = None, + ): + sigmas = torch.linspace(1, 0, num_train_timesteps + 1) + + if not reverse: + sigmas = sigmas.flip(0) + + self.sigmas = sigmas + # the value fed to model + self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32) + + self._step_index = None + self._begin_index = None + + self.supported_solver = ["euler"] + if solver not in self.supported_solver: + raise ValueError( + f"Solver {solver} not supported. Supported solvers: {self.supported_solver}" + ) + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def set_timesteps( + self, + num_inference_steps: int, + device: Union[str, torch.device] = None, + n_tokens: int = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + n_tokens (`int`, *optional*): + Number of tokens in the input sequence. + """ + self.num_inference_steps = num_inference_steps + + sigmas = torch.linspace(1, 0, num_inference_steps + 1) + sigmas = self.sd3_time_shift(sigmas) + + if not self.config.reverse: + sigmas = 1 - sigmas + + self.sigmas = sigmas + self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to( + dtype=torch.float32, device=device + ) + + # Reset step index + self._step_index = None + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. 
for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def scale_model_input( + self, sample: torch.Tensor, timestep: Optional[int] = None + ) -> torch.Tensor: + return sample + + def sd3_time_shift(self, t: torch.Tensor): + return (self.config.shift * t) / (1 + (self.config.shift - 1) * t) + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + return_dict: bool = True, + ) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + n_tokens (`int`, *optional*): + Number of tokens in the input sequence. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or + tuple. + + Returns: + [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is + returned, otherwise a tuple is returned where the first element is the sample tensor. + """ + + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + + dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index] + + if self.config.solver == "euler": + prev_sample = sample + model_output.to(torch.float32) * dt + else: + raise ValueError( + f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}" + ) + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample) + + def __len__(self): + return self.config.num_train_timesteps diff --git a/modules/unet_causal_3d_blocks.py b/modules/unet_causal_3d_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..27d544170ece6a370cdacfe9e31367b884c2e516 --- /dev/null +++ b/modules/unet_causal_3d_blocks.py @@ -0,0 +1,818 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Modified from diffusers==0.29.2 +# +# ============================================================================== + +from typing import Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn +from einops import rearrange + +from diffusers.utils import logging +from diffusers.models.activations import get_activation +from diffusers.models.attention_processor import SpatialNorm +from diffusers.models.attention_processor import Attention +from diffusers.models.normalization import AdaGroupNorm +from diffusers.models.normalization import RMSNorm + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, device, batch_size: int = None): + seq_len = n_frame * n_hw + mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device) + for i in range(seq_len): + i_frame = i // n_hw + mask[i, : (i_frame + 1) * n_hw] = 0 + if batch_size is not None: + mask = mask.unsqueeze(0).expand(batch_size, -1, -1) + return mask + + +class CausalConv3d(nn.Module): + """ + Implements a causal 3D convolution layer where each position only depends on previous timesteps and current spatial locations. + This maintains temporal causality in video generation tasks. 
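+    Temporal padding of (kernel_size - 1) frames is applied on the past side only, while
+    height and width are padded symmetrically (using `pad_mode`). When `chunk_size` > 0,
+    wide inputs are convolved in overlapping chunks along the width axis to reduce peak memory.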
+ """ + + def __init__( + self, + chan_in, + chan_out, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Union[int, Tuple[int, int, int]] = 1, + dilation: Union[int, Tuple[int, int, int]] = 1, + pad_mode="replicate", + chunk_size=0, + **kwargs, + ): + super().__init__() + + self.pad_mode = pad_mode + padding = (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size - 1, 0) # W, H, T + self.time_causal_padding = padding + self.chunk_size = chunk_size + + self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + + def original_forward(self, x): + x = F.pad(x, self.time_causal_padding, mode=self.pad_mode) + return self.conv(x) + + def forward(self, x): + if self.chunk_size == 0: + return self.original_forward(x) + + # if not large, call original forward + if x.shape[4] < self.chunk_size * 1.5: + return self.original_forward(x) + + # # debug: verify the original forward is the same as chunked forward + # orig_forwarded_value = None + # if x.shape[4] < self.chunk_size * 4: + # orig_forwarded_value = self.original_forward(x) + + # get the kernel size + kernel_size = self.conv.kernel_size[0] # assume cubic kernel + assert kernel_size == self.conv.kernel_size[1] == self.conv.kernel_size[2], "Only cubic kernels are supported" + padding_size = kernel_size // 2 # 1 for kernel_size=3, 0 for kernel_size=1 + + x = F.pad(x, self.time_causal_padding, mode=self.pad_mode) + + B, C, D, H, W = orig_shape = x.shape + chunk_size = self.chunk_size + chunk_size -= chunk_size % self.conv.stride[2] # make sure the chunk size is divisible by stride + # print(f"chunked forward: {x.shape}, chunk_size: {chunk_size}") + + # calculate the indices for chunking with overlap and padding by kernel size and stride + indices = [] + i = 0 + while i < W - padding_size: + start_idx = i - padding_size + end_idx = min(i + chunk_size + padding_size, W) + if i == 0: + start_idx = 0 + end_idx += padding_size # to make sure the first chunk is divisible by stride + if W - end_idx < chunk_size // 2: # small chunk at the end + end_idx = W + indices.append((start_idx, end_idx)) + i = end_idx - padding_size + # print(f"chunked forward: {x.shape}, chunked indices: {indices}") + + chunks = [] + for start_idx, end_idx in indices: + chunk = x[:, :, :, :, start_idx:end_idx] + chunk_output = self.conv(chunk) + # print(chunk.shape, chunk_output.shape) + chunks.append(chunk_output) + + # concatenate the chunks + x = torch.cat(chunks, dim=4) + + assert ( + x.shape[2] == ((D - padding_size * 2) + self.conv.stride[0] - 1) // self.conv.stride[0] + ), f"Invalid shape: {x.shape}, {orig_shape}, {padding_size}, {self.conv.stride}" + assert ( + x.shape[3] == ((H - padding_size * 2) + self.conv.stride[1] - 1) // self.conv.stride[1] + ), f"Invalid shape: {x.shape}, {orig_shape}, {padding_size}, {self.conv.stride}" + assert ( + x.shape[4] == ((W - padding_size * 2) + self.conv.stride[2] - 1) // self.conv.stride[2] + ), f"Invalid shape: {x.shape}, {orig_shape}, {padding_size}, {self.conv.stride}" + + # # debug: verify the original forward is the same as chunked forward + # if orig_forwarded_value is not None: + # assert torch.allclose( + # orig_forwarded_value, x, rtol=1e-4, atol=1e-2 + # ), f"Chunked forward is different from original forward. {x.shape}, {orig_shape}, {padding_size}, {self.conv.stride}, {self.conv.kernel_size}" + + return x + + +class UpsampleCausal3D(nn.Module): + """ + A 3D upsampling layer with an optional convolution. 
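+    When interpolating, the first frame is upsampled spatially only while the remaining
+    frames are upsampled along time and space, preserving causality; an optional
+    CausalConv3d is applied afterwards.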
+ """ + + def __init__( + self, + channels: int, + use_conv: bool = False, + use_conv_transpose: bool = False, + out_channels: Optional[int] = None, + name: str = "conv", + kernel_size: Optional[int] = None, + padding=1, + norm_type=None, + eps=None, + elementwise_affine=None, + bias=True, + interpolate=True, + upsample_factor=(2, 2, 2), + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + self.interpolate = interpolate + self.upsample_factor = upsample_factor + + if norm_type == "ln_norm": + self.norm = nn.LayerNorm(channels, eps, elementwise_affine) + elif norm_type == "rms_norm": + self.norm = RMSNorm(channels, eps, elementwise_affine) + elif norm_type is None: + self.norm = None + else: + raise ValueError(f"unknown norm_type: {norm_type}") + + conv = None + if use_conv_transpose: + raise NotImplementedError + elif use_conv: + if kernel_size is None: + kernel_size = 3 + conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, bias=bias) + + if name == "conv": + self.conv = conv + else: + self.Conv2d_0 = conv + + def forward( + self, + hidden_states: torch.FloatTensor, + output_size: Optional[int] = None, + scale: float = 1.0, + ) -> torch.FloatTensor: + assert hidden_states.shape[1] == self.channels + + if self.norm is not None: + raise NotImplementedError + + if self.use_conv_transpose: + return self.conv(hidden_states) + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + dtype = hidden_states.dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.float32) + + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + + # if `output_size` is passed we force the interpolation output + # size and do not make use of `scale_factor=2` + if self.interpolate: + B, C, T, H, W = hidden_states.shape + first_h, other_h = hidden_states.split((1, T - 1), dim=2) + if output_size is None: + if T > 1: + other_h = F.interpolate(other_h, scale_factor=self.upsample_factor, mode="nearest") + + first_h = first_h.squeeze(2) + first_h = F.interpolate(first_h, scale_factor=self.upsample_factor[1:], mode="nearest") + first_h = first_h.unsqueeze(2) + else: + raise NotImplementedError + + if T > 1: + hidden_states = torch.cat((first_h, other_h), dim=2) + else: + hidden_states = first_h + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(dtype) + + if self.use_conv: + if self.name == "conv": + hidden_states = self.conv(hidden_states) + else: + hidden_states = self.Conv2d_0(hidden_states) + + return hidden_states + + +class DownsampleCausal3D(nn.Module): + """ + A 3D downsampling layer with an optional convolution. 
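+    Downsampling is performed by a strided CausalConv3d; `use_conv=True` is required,
+    as pooling is not implemented.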
+ """ + + def __init__( + self, + channels: int, + use_conv: bool = False, + out_channels: Optional[int] = None, + padding: int = 1, + name: str = "conv", + kernel_size=3, + norm_type=None, + eps=None, + elementwise_affine=None, + bias=True, + stride=2, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = stride + self.name = name + + if norm_type == "ln_norm": + self.norm = nn.LayerNorm(channels, eps, elementwise_affine) + elif norm_type == "rms_norm": + self.norm = RMSNorm(channels, eps, elementwise_affine) + elif norm_type is None: + self.norm = None + else: + raise ValueError(f"unknown norm_type: {norm_type}") + + if use_conv: + conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, bias=bias) + else: + raise NotImplementedError + + if name == "conv": + self.Conv2d_0 = conv + self.conv = conv + elif name == "Conv2d_0": + self.conv = conv + else: + self.conv = conv + + def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor: + assert hidden_states.shape[1] == self.channels + + if self.norm is not None: + hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + assert hidden_states.shape[1] == self.channels + + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class ResnetBlockCausal3D(nn.Module): + r""" + A Resnet block. + """ + + def __init__( + self, + *, + in_channels: int, + out_channels: Optional[int] = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + temb_channels: int = 512, + groups: int = 32, + groups_out: Optional[int] = None, + pre_norm: bool = True, + eps: float = 1e-6, + non_linearity: str = "swish", + skip_time_act: bool = False, + # default, scale_shift, ada_group, spatial + time_embedding_norm: str = "default", + kernel: Optional[torch.FloatTensor] = None, + output_scale_factor: float = 1.0, + use_in_shortcut: Optional[bool] = None, + up: bool = False, + down: bool = False, + conv_shortcut_bias: bool = True, + conv_3d_out_channels: Optional[int] = None, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.up = up + self.down = down + self.output_scale_factor = output_scale_factor + self.time_embedding_norm = time_embedding_norm + self.skip_time_act = skip_time_act + + linear_cls = nn.Linear + + if groups_out is None: + groups_out = groups + + if self.time_embedding_norm == "ada_group": + self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps) + elif self.time_embedding_norm == "spatial": + self.norm1 = SpatialNorm(in_channels, temb_channels) + else: + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + + self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1) + + if temb_channels is not None: + if self.time_embedding_norm == "default": + self.time_emb_proj = linear_cls(temb_channels, out_channels) + elif self.time_embedding_norm == "scale_shift": + self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels) + elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": + self.time_emb_proj = None + else: + raise ValueError(f"Unknown time_embedding_norm : {self.time_embedding_norm} ") + else: + 
self.time_emb_proj = None + + if self.time_embedding_norm == "ada_group": + self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps) + elif self.time_embedding_norm == "spatial": + self.norm2 = SpatialNorm(out_channels, temb_channels) + else: + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + + self.dropout = torch.nn.Dropout(dropout) + conv_3d_out_channels = conv_3d_out_channels or out_channels + self.conv2 = CausalConv3d(out_channels, conv_3d_out_channels, kernel_size=3, stride=1) + + self.nonlinearity = get_activation(non_linearity) + + self.upsample = self.downsample = None + if self.up: + self.upsample = UpsampleCausal3D(in_channels, use_conv=False) + elif self.down: + self.downsample = DownsampleCausal3D(in_channels, use_conv=False, name="op") + + self.use_in_shortcut = self.in_channels != conv_3d_out_channels if use_in_shortcut is None else use_in_shortcut + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = CausalConv3d( + in_channels, + conv_3d_out_channels, + kernel_size=1, + stride=1, + bias=conv_shortcut_bias, + ) + + def forward( + self, + input_tensor: torch.FloatTensor, + temb: torch.FloatTensor, + scale: float = 1.0, + ) -> torch.FloatTensor: + hidden_states = input_tensor + + if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": + hidden_states = self.norm1(hidden_states, temb) + else: + hidden_states = self.norm1(hidden_states) + + hidden_states = self.nonlinearity(hidden_states) + + if self.upsample is not None: + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + input_tensor = input_tensor.contiguous() + hidden_states = hidden_states.contiguous() + input_tensor = self.upsample(input_tensor, scale=scale) + hidden_states = self.upsample(hidden_states, scale=scale) + elif self.downsample is not None: + input_tensor = self.downsample(input_tensor, scale=scale) + hidden_states = self.downsample(hidden_states, scale=scale) + + hidden_states = self.conv1(hidden_states) + + if self.time_emb_proj is not None: + if not self.skip_time_act: + temb = self.nonlinearity(temb) + temb = self.time_emb_proj(temb, scale)[:, :, None, None] + + if temb is not None and self.time_embedding_norm == "default": + hidden_states = hidden_states + temb + + if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": + hidden_states = self.norm2(hidden_states, temb) + else: + hidden_states = self.norm2(hidden_states) + + if temb is not None and self.time_embedding_norm == "scale_shift": + scale, shift = torch.chunk(temb, 2, dim=1) + hidden_states = hidden_states * (1 + scale) + shift + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor + + +def get_down_block3d( + down_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + temb_channels: int, + add_downsample: bool, + downsample_stride: int, + resnet_eps: float, + resnet_act_fn: str, + transformer_layers_per_block: int = 1, + num_attention_heads: Optional[int] = None, + resnet_groups: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + downsample_padding: Optional[int] = None, + 
dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + attention_type: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: float = 1.0, + cross_attention_norm: Optional[str] = None, + attention_head_dim: Optional[int] = None, + downsample_type: Optional[str] = None, + dropout: float = 0.0, +): + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warn( + f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}." + ) + attention_head_dim = num_attention_heads + + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownEncoderBlockCausal3D": + return DownEncoderBlockCausal3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + add_downsample=add_downsample, + downsample_stride=downsample_stride, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block3d( + up_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + add_upsample: bool, + upsample_scale_factor: Tuple, + resnet_eps: float, + resnet_act_fn: str, + resolution_idx: Optional[int] = None, + transformer_layers_per_block: int = 1, + num_attention_heads: Optional[int] = None, + resnet_groups: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + attention_type: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: float = 1.0, + cross_attention_norm: Optional[str] = None, + attention_head_dim: Optional[int] = None, + upsample_type: Optional[str] = None, + dropout: float = 0.0, +) -> nn.Module: + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warn( + f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}." + ) + attention_head_dim = num_attention_heads + + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpDecoderBlockCausal3D": + return UpDecoderBlockCausal3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + upsample_scale_factor=upsample_scale_factor, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + temb_channels=temb_channels, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlockCausal3D(nn.Module): + """ + A 3D UNet mid-block [`UNetMidBlockCausal3D`] with multiple residual blocks and optional attention blocks. 
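+    Attention (when enabled) operates over flattened (frame, height, width) tokens with a
+    causal mask, so tokens in frame t attend only to frames <= t.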
+ """ + + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + attn_groups: Optional[int] = None, + resnet_pre_norm: bool = True, + add_attention: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + ): + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + self.add_attention = add_attention + + if attn_groups is None: + attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None + + # there is always at least one resnet + resnets = [ + ResnetBlockCausal3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + if attention_head_dim is None: + logger.warn( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}." + ) + attention_head_dim = in_channels + + for _ in range(num_layers): + if self.add_attention: + attentions.append( + Attention( + in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=attn_groups, + spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + else: + attentions.append(None) + + resnets.append( + ResnetBlockCausal3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + B, C, T, H, W = hidden_states.shape + hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c") + attention_mask = prepare_causal_attention_mask(T, H * W, hidden_states.dtype, hidden_states.device, batch_size=B) + hidden_states = attn(hidden_states, temb=temb, attention_mask=attention_mask) + hidden_states = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=T, h=H, w=W) + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class DownEncoderBlockCausal3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_stride: int = 2, + downsample_padding: int = 1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + 
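+            # only the first resnet maps in_channels -> out_channels; the remaining resnets keep out_channels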
resnets.append( + ResnetBlockCausal3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + DownsampleCausal3D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + stride=downsample_stride, + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor: + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=None, scale=scale) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, scale) + + return hidden_states + + +class UpDecoderBlockCausal3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + upsample_scale_factor=(2, 2, 2), + temb_channels: Optional[int] = None, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + ResnetBlockCausal3D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [ + UpsampleCausal3D( + out_channels, + use_conv=True, + out_channels=out_channels, + upsample_factor=upsample_scale_factor, + ) + ] + ) + else: + self.upsamplers = None + + self.resolution_idx = resolution_idx + + def forward( + self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0 + ) -> torch.FloatTensor: + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=temb, scale=scale) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states diff --git a/networks/__init__.py b/networks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/networks/lora.py b/networks/lora.py new file mode 100644 index 0000000000000000000000000000000000000000..971828958abc6d7a47e4bf3103f75abaf7299700 --- /dev/null +++ b/networks/lora.py @@ -0,0 +1,913 @@ +# LoRA network module: currently conv2d is not fully supported +# reference: +# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py +# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py + +import ast +import math +import os +import re +from typing import Dict, List, Optional, Type, Union +from transformers import CLIPTextModel +import numpy as np +import torch +import torch.nn as nn + +import logging + +logger = logging.getLogger(__name__) 
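+# Each LoRAModule wraps a Linear/Conv2d and adds a low-rank update to its output:
+#   y = W(x) + multiplier * (alpha / rank) * lora_up(lora_down(x))
+# Illustrative sketch (not part of the training scripts) of how this module is typically wired up;
+# `transformer` stands in for a loaded HunyuanVideo DiT model and the hyperparameters are examples only:
+#
+#   network = create_arch_network(1.0, 32, 16.0, vae=None, text_encoders=[], unet=transformer)
+#   network.apply_to(None, transformer, apply_text_encoder=False, apply_unet=True)
+#   params, _ = network.prepare_optimizer_params(unet_lr=2e-4)
+#   optimizer = torch.optim.AdamW(params)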
+logging.basicConfig(level=logging.INFO) + +HUNYUAN_TARGET_REPLACE_MODULES = ["MMDoubleStreamBlock", "MMSingleStreamBlock"] + + +class LoRAModule(torch.nn.Module): + """ + replaces forward method of the original Linear, instead of replacing the original Linear module. + """ + + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=None, + rank_dropout=None, + module_dropout=None, + split_dims: Optional[List[int]] = None, + ): + """ + if alpha == 0 or None, alpha is rank (no scaling). + + split_dims is used to mimic the split qkv of multi-head attention. + """ + super().__init__() + self.lora_name = lora_name + + if org_module.__class__.__name__ == "Conv2d": + in_dim = org_module.in_channels + out_dim = org_module.out_channels + else: + in_dim = org_module.in_features + out_dim = org_module.out_features + + self.lora_dim = lora_dim + self.split_dims = split_dims + + if split_dims is None: + if org_module.__class__.__name__ == "Conv2d": + kernel_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) + self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) + else: + self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) + self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) + + torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + torch.nn.init.zeros_(self.lora_up.weight) + else: + # conv2d not supported + assert sum(split_dims) == out_dim, "sum of split_dims must be equal to out_dim" + assert org_module.__class__.__name__ == "Linear", "split_dims is only supported for Linear" + # print(f"split_dims: {split_dims}") + self.lora_down = torch.nn.ModuleList( + [torch.nn.Linear(in_dim, self.lora_dim, bias=False) for _ in range(len(split_dims))] + ) + self.lora_up = torch.nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims]) + for lora_down in self.lora_down: + torch.nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5)) + for lora_up in self.lora_up: + torch.nn.init.zeros_(lora_up.weight) + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = self.lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + self.register_buffer("alpha", torch.tensor(alpha)) # for save/load + + # same as microsoft's + self.multiplier = multiplier + self.org_module = org_module # remove in applying + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + def apply_to(self): + self.org_forward = self.org_module.forward + self.org_module.forward = self.forward + del self.org_module + + def forward(self, x): + org_forwarded = self.org_forward(x) + + # module dropout + if self.module_dropout is not None and self.training: + if torch.rand(1) < self.module_dropout: + return org_forwarded + + if self.split_dims is None: + lx = self.lora_down(x) + + # normal dropout + if self.dropout is not None and self.training: + lx = torch.nn.functional.dropout(lx, p=self.dropout) + + # rank dropout + if self.rank_dropout is not None and self.training: + mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout + if len(lx.size()) == 3: + mask = mask.unsqueeze(1) # for Text Encoder + elif len(lx.size()) == 4: + mask = mask.unsqueeze(-1).unsqueeze(-1) # 
for Conv2d + lx = lx * mask + + # scaling for rank dropout: treat as if the rank is changed + scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + else: + scale = self.scale + + lx = self.lora_up(lx) + + return org_forwarded + lx * self.multiplier * scale + else: + lxs = [lora_down(x) for lora_down in self.lora_down] + + # normal dropout + if self.dropout is not None and self.training: + lxs = [torch.nn.functional.dropout(lx, p=self.dropout) for lx in lxs] + + # rank dropout + if self.rank_dropout is not None and self.training: + masks = [torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout for lx in lxs] + for i in range(len(lxs)): + if len(lx.size()) == 3: + masks[i] = masks[i].unsqueeze(1) + elif len(lx.size()) == 4: + masks[i] = masks[i].unsqueeze(-1).unsqueeze(-1) + lxs[i] = lxs[i] * masks[i] + + # scaling for rank dropout: treat as if the rank is changed + scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + else: + scale = self.scale + + lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)] + + return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale + + +class LoRAInfModule(LoRAModule): + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + **kwargs, + ): + # no dropout for inference + super().__init__(lora_name, org_module, multiplier, lora_dim, alpha) + + self.org_module_ref = [org_module] # for reference + self.enabled = True + self.network: LoRANetwork = None + + def set_network(self, network): + self.network = network + + # merge weight to org_module + # def merge_to(self, sd, dtype, device, non_blocking=False): + # if torch.cuda.is_available(): + # stream = torch.cuda.Stream(device=device) + # with torch.cuda.stream(stream): + # print(f"merge_to {self.lora_name}") + # self._merge_to(sd, dtype, device, non_blocking) + # torch.cuda.synchronize(device=device) + # print(f"merge_to {self.lora_name} done") + # torch.cuda.empty_cache() + # else: + # self._merge_to(sd, dtype, device, non_blocking) + + def merge_to(self, sd, dtype, device, non_blocking=False): + # extract weight from org_module + org_sd = self.org_module.state_dict() + weight = org_sd["weight"] + org_dtype = weight.dtype + org_device = weight.device + weight = weight.to(device, dtype=torch.float, non_blocking=non_blocking) # for calculation + + if dtype is None: + dtype = org_dtype + if device is None: + device = org_device + + if self.split_dims is None: + # get up/down weight + down_weight = sd["lora_down.weight"].to(device, dtype=torch.float, non_blocking=non_blocking) + up_weight = sd["lora_up.weight"].to(device, dtype=torch.float, non_blocking=non_blocking) + + # merge weight + if len(weight.size()) == 2: + # linear + weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + self.multiplier * conved * self.scale + + # set weight to org_module + org_sd["weight"] = weight.to(org_device, dtype=dtype) # back to CPU without non_blocking + self.org_module.load_state_dict(org_sd) + else: + # split_dims + 
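+            # split_dims mimics a fused qkv-style Linear: each split has its own lora_down/lora_up pair,
+            # and each split's low-rank delta is merged into the packed weight below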
total_dims = sum(self.split_dims) + for i in range(len(self.split_dims)): + # get up/down weight + down_weight = sd[f"lora_down.{i}.weight"].to(device, torch.float, non_blocking=non_blocking) # (rank, in_dim) + up_weight = sd[f"lora_up.{i}.weight"].to(device, torch.float, non_blocking=non_blocking) # (split dim, rank) + + # pad up_weight -> (total_dims, rank) + padded_up_weight = torch.zeros((total_dims, up_weight.size(0)), device=device, dtype=torch.float) + padded_up_weight[sum(self.split_dims[:i]) : sum(self.split_dims[: i + 1])] = up_weight + + # merge weight + weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale + + # set weight to org_module + org_sd["weight"] = weight.to(org_device, dtype) # back to CPU without non_blocking + self.org_module.load_state_dict(org_sd) + + # return weight for merge + def get_weight(self, multiplier=None): + if multiplier is None: + multiplier = self.multiplier + + # get up/down weight from module + up_weight = self.lora_up.weight.to(torch.float) + down_weight = self.lora_down.weight.to(torch.float) + + # pre-calculated weight + if len(down_weight.size()) == 2: + # linear + weight = self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + weight = self.multiplier * conved * self.scale + + return weight + + def default_forward(self, x): + # logger.info(f"default_forward {self.lora_name} {x.size()}") + if self.split_dims is None: + lx = self.lora_down(x) + lx = self.lora_up(lx) + return self.org_forward(x) + lx * self.multiplier * self.scale + else: + lxs = [lora_down(x) for lora_down in self.lora_down] + lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)] + return self.org_forward(x) + torch.cat(lxs, dim=-1) * self.multiplier * self.scale + + def forward(self, x): + if not self.enabled: + return self.org_forward(x) + return self.default_forward(x) + + +def create_arch_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + vae: nn.Module, + text_encoders: List[nn.Module], + unet: nn.Module, + neuron_dropout: Optional[float] = None, + **kwargs, +): + # add default exclude patterns + exclude_patterns = kwargs.get("exclude_patterns", None) + if exclude_patterns is None: + exclude_patterns = [] + else: + exclude_patterns = ast.literal_eval(exclude_patterns) + + # exclude if 'img_mod', 'txt_mod' or 'modulation' in the name + exclude_patterns.append(r".*(img_mod|txt_mod|modulation).*") + + kwargs["exclude_patterns"] = exclude_patterns + + return create_network( + HUNYUAN_TARGET_REPLACE_MODULES, + "lora_unet", + multiplier, + network_dim, + network_alpha, + vae, + text_encoders, + unet, + neuron_dropout=neuron_dropout, + **kwargs, + ) + + +def create_network( + target_replace_modules: List[str], + prefix: str, + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + vae: nn.Module, + text_encoders: List[nn.Module], + unet: nn.Module, + neuron_dropout: Optional[float] = None, + **kwargs, +): + """ architecture independent network creation """ + if network_dim is None: + network_dim = 4 # default + if network_alpha is None: + network_alpha = 1.0 + + # extract dim/alpha for conv2d, and block dim + conv_dim = kwargs.get("conv_dim", None) 
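+    # values in kwargs usually arrive as strings (e.g. from --network_args), hence the explicit casts below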
+ conv_alpha = kwargs.get("conv_alpha", None) + if conv_dim is not None: + conv_dim = int(conv_dim) + if conv_alpha is None: + conv_alpha = 1.0 + else: + conv_alpha = float(conv_alpha) + + # TODO generic rank/dim setting with regular expression + + # rank/module dropout + rank_dropout = kwargs.get("rank_dropout", None) + if rank_dropout is not None: + rank_dropout = float(rank_dropout) + module_dropout = kwargs.get("module_dropout", None) + if module_dropout is not None: + module_dropout = float(module_dropout) + + # verbose + verbose = kwargs.get("verbose", False) + if verbose is not None: + verbose = True if verbose == "True" else False + + # regular expression for module selection: exclude and include + exclude_patterns = kwargs.get("exclude_patterns", None) + if exclude_patterns is not None and isinstance(exclude_patterns, str): + exclude_patterns = ast.literal_eval(exclude_patterns) + include_patterns = kwargs.get("include_patterns", None) + if include_patterns is not None and isinstance(include_patterns, str): + include_patterns = ast.literal_eval(include_patterns) + + # too many arguments ( ^ω^)・・・ + network = LoRANetwork( + target_replace_modules, + prefix, + text_encoders, + unet, + multiplier=multiplier, + lora_dim=network_dim, + alpha=network_alpha, + dropout=neuron_dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + conv_lora_dim=conv_dim, + conv_alpha=conv_alpha, + exclude_patterns=exclude_patterns, + include_patterns=include_patterns, + verbose=verbose, + ) + + loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) + # loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None) + # loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None) + loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None + # loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None + # loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None + if loraplus_lr_ratio is not None: # or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None: + network.set_loraplus_lr_ratio(loraplus_lr_ratio) # , loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) + + return network + + +class LoRANetwork(torch.nn.Module): + # only supports U-Net (DiT), Text Encoders are not supported + + def __init__( + self, + target_replace_modules: List[str], + prefix: str, + text_encoders: Union[List[CLIPTextModel], CLIPTextModel], + unet: nn.Module, + multiplier: float = 1.0, + lora_dim: int = 4, + alpha: float = 1, + dropout: Optional[float] = None, + rank_dropout: Optional[float] = None, + module_dropout: Optional[float] = None, + conv_lora_dim: Optional[int] = None, + conv_alpha: Optional[float] = None, + module_class: Type[object] = LoRAModule, + modules_dim: Optional[Dict[str, int]] = None, + modules_alpha: Optional[Dict[str, int]] = None, + exclude_patterns: Optional[List[str]] = None, + include_patterns: Optional[List[str]] = None, + verbose: Optional[bool] = False, + ) -> None: + super().__init__() + self.multiplier = multiplier + + self.lora_dim = lora_dim + self.alpha = alpha + self.conv_lora_dim = conv_lora_dim + self.conv_alpha = conv_alpha + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + self.target_replace_modules = target_replace_modules + self.prefix = prefix + + self.loraplus_lr_ratio = None + # self.loraplus_unet_lr_ratio = None + # 
self.loraplus_text_encoder_lr_ratio = None + + if modules_dim is not None: + logger.info(f"create LoRA network from weights") + else: + logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) + # if self.conv_lora_dim is not None: + # logger.info( + # f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" + # ) + # if train_t5xxl: + # logger.info(f"train T5XXL as well") + + # compile regular expression if specified + exclude_re_patterns = [] + if exclude_patterns is not None: + for pattern in exclude_patterns: + try: + re_pattern = re.compile(pattern) + except re.error as e: + logger.error(f"Invalid exclude pattern '{pattern}': {e}") + continue + exclude_re_patterns.append(re_pattern) + + include_re_patterns = [] + if include_patterns is not None: + for pattern in include_patterns: + try: + re_pattern = re.compile(pattern) + except re.error as e: + logger.error(f"Invalid include pattern '{pattern}': {e}") + continue + include_re_patterns.append(re_pattern) + + # create module instances + def create_modules( + is_unet: bool, + pfx: str, + root_module: torch.nn.Module, + target_replace_mods: Optional[List[str]] = None, + filter: Optional[str] = None, + default_dim: Optional[int] = None, + ) -> List[LoRAModule]: + loras = [] + skipped = [] + for name, module in root_module.named_modules(): + if target_replace_mods is None or module.__class__.__name__ in target_replace_mods: + if target_replace_mods is None: # dirty hack for all modules + module = root_module # search all modules + + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ == "Linear" + is_conv2d = child_module.__class__.__name__ == "Conv2d" + is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) + + if is_linear or is_conv2d: + original_name = (name + "." 
if name else "") + child_name + lora_name = f"{pfx}.{original_name}".replace(".", "_") + + # exclude/include filter + excluded = False + for pattern in exclude_re_patterns: + if pattern.match(original_name): + excluded = True + break + included = False + for pattern in include_re_patterns: + if pattern.match(original_name): + included = True + break + if excluded and not included: + if verbose: + logger.info(f"exclude: {original_name}") + continue + + # filter by name (not used in the current implementation) + if filter is not None and not filter in lora_name: + continue + + dim = None + alpha = None + + if modules_dim is not None: + # モジュール指定あり + if lora_name in modules_dim: + dim = modules_dim[lora_name] + alpha = modules_alpha[lora_name] + else: + # 通常、すべて対象とする + if is_linear or is_conv2d_1x1: + dim = default_dim if default_dim is not None else self.lora_dim + alpha = self.alpha + elif self.conv_lora_dim is not None: + dim = self.conv_lora_dim + alpha = self.conv_alpha + + if dim is None or dim == 0: + # skipした情報を出力 + if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None): + skipped.append(lora_name) + continue + + lora = module_class( + lora_name, + child_module, + self.multiplier, + dim, + alpha, + dropout=dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + ) + loras.append(lora) + + if target_replace_mods is None: + break # all modules are searched + return loras, skipped + + # # create LoRA for text encoder + # # it is redundant to create LoRA modules even if they are not used + + self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = [] + # skipped_te = [] + # for i, text_encoder in enumerate(text_encoders): + # index = i + # if not train_t5xxl and index > 0: # 0: CLIP, 1: T5XXL, so we skip T5XXL if train_t5xxl is False + # break + # logger.info(f"create LoRA for Text Encoder {index+1}:") + # text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + # logger.info(f"create LoRA for Text Encoder {index+1}: {len(text_encoder_loras)} modules.") + # self.text_encoder_loras.extend(text_encoder_loras) + # skipped_te += skipped + + # create LoRA for U-Net + self.unet_loras: List[Union[LoRAModule, LoRAInfModule]] + self.unet_loras, skipped_un = create_modules(True, prefix, unet, target_replace_modules) + + logger.info(f"create LoRA for U-Net/DiT: {len(self.unet_loras)} modules.") + if verbose: + for lora in self.unet_loras: + logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}") + + skipped = skipped_un + if verbose and len(skipped) > 0: + logger.warning( + f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" + ) + for name in skipped: + logger.info(f"\t{name}") + + # assertion + names = set() + for lora in self.text_encoder_loras + self.unet_loras: + assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" + names.add(lora.lora_name) + + def prepare_network(self, args): + """ + called after the network is created + """ + pass + + def set_multiplier(self, multiplier): + self.multiplier = multiplier + for lora in self.text_encoder_loras + self.unet_loras: + lora.multiplier = self.multiplier + + def set_enabled(self, is_enabled): + for lora in self.text_encoder_loras + self.unet_loras: + lora.enabled = is_enabled + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + 
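+            # non-safetensors checkpoints (e.g. .pt) fall back to torch.load on CPU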
weights_sd = torch.load(file, map_location="cpu") + + info = self.load_state_dict(weights_sd, False) + return info + + def apply_to( + self, + text_encoders: Optional[nn.Module], + unet: Optional[nn.Module], + apply_text_encoder: bool = True, + apply_unet: bool = True, + ): + if apply_text_encoder: + logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + lora.apply_to() + self.add_module(lora.lora_name, lora) + + # マージできるかどうかを返す + def is_mergeable(self): + return True + + # TODO refactor to common function with apply_to + def merge_to(self, text_encoders, unet, weights_sd, dtype=None, device=None, non_blocking=False): + from concurrent.futures import ThreadPoolExecutor + + with ThreadPoolExecutor(max_workers=2) as executor: # 2 workers is enough + futures = [] + for lora in self.text_encoder_loras + self.unet_loras: + sd_for_lora = {} + for key in weights_sd.keys(): + if key.startswith(lora.lora_name): + sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] + if len(sd_for_lora) == 0: + logger.info(f"no weight for {lora.lora_name}") + continue + + # lora.merge_to(sd_for_lora, dtype, device) + futures.append(executor.submit(lora.merge_to, sd_for_lora, dtype, device, non_blocking)) + + for future in futures: + future.result() + + logger.info(f"weights are merged") + + def set_loraplus_lr_ratio(self, loraplus_lr_ratio): # , loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio): + self.loraplus_lr_ratio = loraplus_lr_ratio + + logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_lr_ratio}") + # logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}") + + def prepare_optimizer_params(self, unet_lr: float = 1e-4, **kwargs): + self.requires_grad_(True) + + all_params = [] + lr_descriptions = [] + + def assemble_params(loras, lr, loraplus_ratio): + param_groups = {"lora": {}, "plus": {}} + for lora in loras: + for name, param in lora.named_parameters(): + if loraplus_ratio is not None and "lora_up" in name: + param_groups["plus"][f"{lora.lora_name}.{name}"] = param + else: + param_groups["lora"][f"{lora.lora_name}.{name}"] = param + + params = [] + descriptions = [] + for key in param_groups.keys(): + param_data = {"params": param_groups[key].values()} + + if len(param_data["params"]) == 0: + continue + + if lr is not None: + if key == "plus": + param_data["lr"] = lr * loraplus_ratio + else: + param_data["lr"] = lr + + if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: + logger.info("NO LR skipping!") + continue + + params.append(param_data) + descriptions.append("plus" if key == "plus" else "") + + return params, descriptions + + if self.unet_loras: + params, descriptions = assemble_params(self.unet_loras, unet_lr, self.loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions]) + + return all_params, lr_descriptions + + def enable_gradient_checkpointing(self): + # not supported + pass + + def prepare_grad_etc(self, unet): + self.requires_grad_(True) + + def on_epoch_start(self, unet): + self.train() + + def on_step_start(self): + pass + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + 
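+            # treat an empty metadata dict the same as no metadata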
metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + from utils import model_utils + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + model_hash, legacy_hash = model_utils.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + def backup_weights(self): + # 重みのバックアップを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not hasattr(org_module, "_lora_org_weight"): + sd = org_module.state_dict() + org_module._lora_org_weight = sd["weight"].detach().clone() + org_module._lora_restored = True + + def restore_weights(self): + # 重みのリストアを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not org_module._lora_restored: + sd = org_module.state_dict() + sd["weight"] = org_module._lora_org_weight + org_module.load_state_dict(sd) + org_module._lora_restored = True + + def pre_calculation(self): + # 事前計算を行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + sd = org_module.state_dict() + + org_weight = sd["weight"] + lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype) + sd["weight"] = org_weight + lora_weight + assert sd["weight"].shape == org_weight.shape + org_module.load_state_dict(sd) + + org_module._lora_restored = False + lora.enabled = False + + def apply_max_norm_regularization(self, max_norm_value, device): + downkeys = [] + upkeys = [] + alphakeys = [] + norms = [] + keys_scaled = 0 + + state_dict = self.state_dict() + for key in state_dict.keys(): + if "lora_down" in key and "weight" in key: + downkeys.append(key) + upkeys.append(key.replace("lora_down", "lora_up")) + alphakeys.append(key.replace("lora_down.weight", "alpha")) + + for i in range(len(downkeys)): + down = state_dict[downkeys[i]].to(device) + up = state_dict[upkeys[i]].to(device) + alpha = state_dict[alphakeys[i]].to(device) + dim = down.shape[0] + scale = alpha / dim + + if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): + updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) + elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3): + updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3) + else: + updown = up @ down + + updown *= scale + + norm = updown.norm().clamp(min=max_norm_value / 2) + desired = torch.clamp(norm, max=max_norm_value) + ratio = desired.cpu() / norm.cpu() + sqrt_ratio = ratio**0.5 + if ratio != 1: + keys_scaled += 1 + state_dict[upkeys[i]] *= sqrt_ratio + state_dict[downkeys[i]] *= sqrt_ratio + scalednorm = updown.norm() * ratio + norms.append(scalednorm.item()) + + return keys_scaled, sum(norms) / len(norms), max(norms) + + +def create_arch_network_from_weights( + multiplier: float, + weights_sd: Dict[str, torch.Tensor], + text_encoders: Optional[List[nn.Module]] = None, + unet: Optional[nn.Module] = None, + for_inference: bool = False, + **kwargs, +) -> LoRANetwork: + return 
create_network_from_weights( + HUNYUAN_TARGET_REPLACE_MODULES, multiplier, weights_sd, text_encoders, unet, for_inference, **kwargs + ) + + +# Create network from weights for inference, weights are not loaded here (because can be merged) +def create_network_from_weights( + target_replace_modules: List[str], + multiplier: float, + weights_sd: Dict[str, torch.Tensor], + text_encoders: Optional[List[nn.Module]] = None, + unet: Optional[nn.Module] = None, + for_inference: bool = False, + **kwargs, +) -> LoRANetwork: + # get dim/alpha mapping + modules_dim = {} + modules_alpha = {} + for key, value in weights_sd.items(): + if "." not in key: + continue + + lora_name = key.split(".")[0] + if "alpha" in key: + modules_alpha[lora_name] = value + elif "lora_down" in key: + dim = value.shape[0] + modules_dim[lora_name] = dim + # logger.info(lora_name, value.size(), dim) + + module_class = LoRAInfModule if for_inference else LoRAModule + + network = LoRANetwork( + target_replace_modules, + "lora_unet", + text_encoders, + unet, + multiplier=multiplier, + modules_dim=modules_dim, + modules_alpha=modules_alpha, + module_class=module_class, + ) + return network diff --git a/networks/lora_framepack.py b/networks/lora_framepack.py new file mode 100644 index 0000000000000000000000000000000000000000..4b627d4d5188257f5ceca9e467e7c0964e4dd5e8 --- /dev/null +++ b/networks/lora_framepack.py @@ -0,0 +1,65 @@ +# LoRA module for FramePack + +import ast +from typing import Dict, List, Optional +import torch +import torch.nn as nn + +import logging + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + +import networks.lora as lora + + +FRAMEPACK_TARGET_REPLACE_MODULES = ["HunyuanVideoTransformerBlock", "HunyuanVideoSingleTransformerBlock"] + + +def create_arch_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + vae: nn.Module, + text_encoders: List[nn.Module], + unet: nn.Module, + neuron_dropout: Optional[float] = None, + **kwargs, +): + # add default exclude patterns + exclude_patterns = kwargs.get("exclude_patterns", None) + if exclude_patterns is None: + exclude_patterns = [] + else: + exclude_patterns = ast.literal_eval(exclude_patterns) + + # exclude if 'norm' in the name of the module + exclude_patterns.append(r".*(norm).*") + + kwargs["exclude_patterns"] = exclude_patterns + + return lora.create_network( + FRAMEPACK_TARGET_REPLACE_MODULES, + "lora_unet", + multiplier, + network_dim, + network_alpha, + vae, + text_encoders, + unet, + neuron_dropout=neuron_dropout, + **kwargs, + ) + + +def create_arch_network_from_weights( + multiplier: float, + weights_sd: Dict[str, torch.Tensor], + text_encoders: Optional[List[nn.Module]] = None, + unet: Optional[nn.Module] = None, + for_inference: bool = False, + **kwargs, +) -> lora.LoRANetwork: + return lora.create_network_from_weights( + FRAMEPACK_TARGET_REPLACE_MODULES, multiplier, weights_sd, text_encoders, unet, for_inference, **kwargs + ) diff --git a/networks/lora_wan.py b/networks/lora_wan.py new file mode 100644 index 0000000000000000000000000000000000000000..c9b171a741d317a551f17d1f45046e7eed6b161e --- /dev/null +++ b/networks/lora_wan.py @@ -0,0 +1,65 @@ +# LoRA module for Wan2.1 + +import ast +from typing import Dict, List, Optional +import torch +import torch.nn as nn + +import logging + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + +import networks.lora as lora + + +WAN_TARGET_REPLACE_MODULES = ["WanAttentionBlock"] + + +def create_arch_network( + 
multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + vae: nn.Module, + text_encoders: List[nn.Module], + unet: nn.Module, + neuron_dropout: Optional[float] = None, + **kwargs, +): + # add default exclude patterns + exclude_patterns = kwargs.get("exclude_patterns", None) + if exclude_patterns is None: + exclude_patterns = [] + else: + exclude_patterns = ast.literal_eval(exclude_patterns) + + # exclude if 'img_mod', 'txt_mod' or 'modulation' in the name + exclude_patterns.append(r".*(patch_embedding|text_embedding|time_embedding|time_projection|norm|head).*") + + kwargs["exclude_patterns"] = exclude_patterns + + return lora.create_network( + WAN_TARGET_REPLACE_MODULES, + "lora_unet", + multiplier, + network_dim, + network_alpha, + vae, + text_encoders, + unet, + neuron_dropout=neuron_dropout, + **kwargs, + ) + + +def create_arch_network_from_weights( + multiplier: float, + weights_sd: Dict[str, torch.Tensor], + text_encoders: Optional[List[nn.Module]] = None, + unet: Optional[nn.Module] = None, + for_inference: bool = False, + **kwargs, +) -> lora.LoRANetwork: + return lora.create_network_from_weights( + WAN_TARGET_REPLACE_MODULES, multiplier, weights_sd, text_encoders, unet, for_inference, **kwargs + ) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..66c812e35d41eaf2438419b38abfc3dc9824b8fe --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,40 @@ +[project] +name = "musubi-tuner" +version = "0.1.0" +description = "Musubi Tuner by kohya_ss" +readme = "README.md" +requires-python = ">=3.10, <3.11" +dependencies = [ + "accelerate>=1.0.0", + "ascii-magic==2.3.0", + "av==14.0.1", + "bitsandbytes>=0.45.0", + "diffusers>=0.32.1", + "einops>=0.7.0", + "huggingface-hub>=0.26.5", + "matplotlib>=3.10.0", + "opencv-python>=4.10.0.84", + "pillow>=10.2.0", + "safetensors>=0.4.5", + "sageattention>=1.0.6", + "tensorboard>=2.18.0", + "toml>=0.10.2", + "torch>=2.5.1", + "torchvision>=0.20.1", + "tqdm>=4.66.5", + "transformers>=4.46.3", + "voluptuous>=0.15.2", +] + +[tool.uv.sources] +torch = [ + { index = "pytorch-cu124" }, +] +torchvision = [ + { index = "pytorch-cu124" }, +] + +[[tool.uv.index]] +name = "pytorch-cu124" +url = "https://download.pytorch.org/whl/cu124" +explicit = true diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..76396f5329686165e9985792b89faeba020e1098 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,100 @@ +absl-py==2.1.0 +accelerate==1.6.0 +aiofiles==23.2.1 +annotated-types==0.7.0 +anyio==4.8.0 +ascii-magic==2.3.0 +av==12.1.0 +bitsandbytes==0.45.0 +certifi==2025.1.31 +charset-normalizer==3.4.1 +click==8.1.8 +colorama==0.4.6 +contourpy==1.3.1 +cycler==0.12.1 +diffusers==0.33.1 +easydict==1.13 +einops==0.7.0 +exceptiongroup==1.2.2 +fastapi==0.115.8 +ffmpeg==1.4 +ffmpeg-python==0.2.0 +ffmpy==0.5.0 +filelock==3.13.1 +fonttools==4.55.8 +fsspec==2024.6.1 +ftfy==6.3.1 +future==1.0.0 +gradio==5.14.0 +gradio_client==1.7.0 +grpcio==1.70.0 +h11==0.14.0 +httpcore==1.0.7 +httpx==0.28.1 +huggingface-hub==0.27.0 +idna==3.10 +importlib_metadata==8.6.1 +Jinja2==3.1.3 +kiwisolver==1.4.8 +Markdown==3.7 +markdown-it-py==3.0.0 +MarkupSafe==2.1.5 +matplotlib==3.10.0 +mdurl==0.1.2 +mpmath==1.3.0 +networkx==3.3 +numpy==1.26.2 +opencv-python==4.10.0.84 +orjson==3.10.15 +packaging==24.2 +pandas==2.2.3 +pillow==11.1.0 +protobuf==5.29.3 +psutil==6.1.1 +pydantic==2.10.6 +pydantic_core==2.27.2 +pydub==0.25.1 +Pygments==2.19.1 +pyparsing==3.2.1 
+python-dateutil==2.9.0.post0 +python-multipart==0.0.20 +pytz==2025.1 +PyYAML==6.0.2 +regex==2024.11.6 +requests==2.31.0 +rich==13.9.4 +rich-argparse>=1.5.0 +ruff==0.9.4 +safehttpx==0.1.6 +safetensors==0.4.5 +scipy==1.12.0 +semantic-version==2.10.0 +sentencepiece==0.2.0 +shellingham==1.5.4 +six==1.17.0 +sniffio==1.3.1 +starlette==0.45.3 +sympy==1.13.1 +tensorboard==2.18.0 +tensorboard-data-server==0.7.2 +tiktoken==0.9.0 +tokenizers==0.20.3 +toml==0.10.2 +tomlkit==0.13.2 +torch==2.5.1+cu124 +torchvision==0.20.1+cu124 +torchsde==0.2.6 +tqdm==4.67.1 +transformers==4.46.2 +typer==0.15.1 +typing_extensions==4.12.2 +tzdata==2025.1 +urllib3==2.3.0 +uvicorn==0.34.0 +voluptuous==0.15.2 +wcwidth==0.2.13 +websockets==14.2 +Werkzeug==3.1.3 +zipp==3.21.0 +imageio-ffmpeg +decord \ No newline at end of file diff --git a/requirementsFP.txt b/requirementsFP.txt new file mode 100644 index 0000000000000000000000000000000000000000..4d884e025c14aaf49e1809d60e56afb8ca49c793 --- /dev/null +++ b/requirementsFP.txt @@ -0,0 +1,14 @@ +accelerate==1.6.0 +diffusers==0.33.1 +transformers==4.46.2 +sentencepiece==0.2.0 +pillow==11.1.0 +av==12.1.0 +numpy==1.26.2 +scipy==1.12.0 +requests==2.31.0 +torchsde==0.2.6 + +einops +opencv-contrib-python +safetensors diff --git a/requirementsPinokio.txt b/requirementsPinokio.txt new file mode 100644 index 0000000000000000000000000000000000000000..a5ca3c241f31ca9f2230a1f90b8bb07b4bdba6e9 --- /dev/null +++ b/requirementsPinokio.txt @@ -0,0 +1,96 @@ +absl-py==2.1.0 +accelerate==1.6.0 +aiofiles==23.2.1 +annotated-types==0.7.0 +anyio==4.8.0 +ascii-magic==2.3.0 +av==12.1.0 +bitsandbytes==0.45.0 +certifi==2025.1.31 +charset-normalizer==3.4.1 +click==8.1.8 +colorama==0.4.6 +contourpy==1.3.1 +cycler==0.12.1 +diffusers==0.33.1 +easydict==1.13 +einops==0.7.0 +exceptiongroup==1.2.2 +fastapi==0.115.8 +ffmpeg==1.4 +ffmpeg-python==0.2.0 +ffmpy==0.5.0 +filelock==3.13.1 +fonttools==4.55.8 +fsspec==2024.6.1 +ftfy==6.3.1 +future==1.0.0 +gradio==5.14.0 +gradio_client==1.7.0 +grpcio==1.70.0 +h11==0.14.0 +httpcore==1.0.7 +httpx==0.28.1 +huggingface-hub==0.27.0 +idna==3.10 +importlib_metadata==8.6.1 +Jinja2==3.1.3 +kiwisolver==1.4.8 +Markdown==3.7 +markdown-it-py==3.0.0 +MarkupSafe==2.1.5 +matplotlib==3.10.0 +mdurl==0.1.2 +mpmath==1.3.0 +networkx==3.3 +numpy==1.26.2 +opencv-python==4.10.0.84 +orjson==3.10.15 +packaging==24.2 +pandas==2.2.3 +pillow==11.1.0 +protobuf==5.29.3 +psutil==6.1.1 +pydantic==2.10.6 +pydantic_core==2.27.2 +pydub==0.25.1 +Pygments==2.19.1 +pyparsing==3.2.1 +python-dateutil==2.9.0.post0 +python-multipart==0.0.20 +pytz==2025.1 +PyYAML==6.0.2 +regex==2024.11.6 +requests==2.31.0 +rich==13.9.4 +ruff==0.9.4 +safehttpx==0.1.6 +safetensors==0.4.5 +semantic-version==2.10.0 +shellingham==1.5.4 +six==1.17.0 +sniffio==1.3.1 +starlette==0.45.3 +sympy==1.13.1 +tensorboard==2.18.0 +tensorboard-data-server==0.7.2 +tiktoken==0.9.0 +tokenizers==0.20.3 +toml==0.10.2 +tomlkit==0.13.2 +tqdm==4.67.1 +transformers==4.46.2 +typer==0.15.1 +typing_extensions==4.12.2 +tzdata==2025.1 +urllib3==2.3.0 +uvicorn==0.34.0 +voluptuous==0.15.2 +wcwidth==0.2.13 +websockets==14.2 +Werkzeug==3.1.3 +zipp==3.21.0 +sentencepiece==0.2.0 +scipy==1.12.0 +torchsde==0.2.6 +opencv-contrib-python \ No newline at end of file diff --git a/requirementsTorch27.txt b/requirementsTorch27.txt new file mode 100644 index 0000000000000000000000000000000000000000..a92e24ac7a1de7d3be548de7de80e098c2c0f70f --- /dev/null +++ b/requirementsTorch27.txt @@ -0,0 +1,98 @@ +torch==2.7.0+cu128 +torchvision==0.22.0+cu128 +torchaudio==2.7.0 
+absl-py==2.1.0 +accelerate==1.2.1 +aiofiles==23.2.1 +annotated-types==0.7.0 +anyio==4.8.0 +ascii-magic==2.3.0 +av==14.0.1 +bitsandbytes==0.45.0 +certifi==2025.1.31 +charset-normalizer==3.4.1 +click==8.1.8 +colorama==0.4.6 +contourpy==1.3.1 +cycler==0.12.1 +diffusers==0.32.1 +easydict==1.13 +einops==0.7.0 +exceptiongroup==1.2.2 +fastapi==0.115.8 +ffmpeg==1.4 +ffmpeg-python==0.2.0 +ffmpy==0.5.0 +filelock==3.13.1 +fonttools==4.55.8 +fsspec==2024.6.1 +ftfy==6.3.1 +future==1.0.0 +gradio==5.14.0 +gradio_client==1.7.0 +grpcio==1.70.0 +h11==0.14.0 +httpcore==1.0.7 +httpx==0.28.1 +huggingface-hub==0.26.5 +idna==3.10 +importlib_metadata==8.6.1 +Jinja2==3.1.3 +kiwisolver==1.4.8 +Markdown==3.7 +markdown-it-py==3.0.0 +MarkupSafe==2.1.5 +matplotlib==3.10.0 +mdurl==0.1.2 +mpmath==1.3.0 +networkx==3.3 +numpy==2.1.2 +opencv-python==4.10.0.84 +orjson==3.10.15 +packaging==24.2 +pandas==2.2.3 +pillow==10.2.0 +protobuf==5.29.3 +psutil==6.1.1 +pydantic==2.10.6 +pydantic_core==2.27.2 +pydub==0.25.1 +Pygments==2.19.1 +pyparsing==3.2.1 +python-dateutil==2.9.0.post0 +python-multipart==0.0.20 +pytz==2025.1 +PyYAML==6.0.2 +regex==2024.11.6 +requests==2.32.3 +rich==13.9.4 +rich-argparse>=1.5.0 +ruff==0.9.4 +safehttpx==0.1.6 +safetensors==0.4.5 +semantic-version==2.10.0 +shellingham==1.5.4 +six==1.17.0 +sniffio==1.3.1 +starlette==0.45.3 +sympy +tensorboard==2.18.0 +tensorboard-data-server==0.7.2 +tiktoken==0.9.0 +tokenizers==0.20.3 +toml==0.10.2 +tomlkit==0.13.2 +tqdm==4.67.1 +transformers==4.46.3 +typer==0.15.1 +typing_extensions==4.12.2 +tzdata==2025.1 +urllib3==2.3.0 +uvicorn==0.34.0 +voluptuous==0.15.2 +wcwidth==0.2.13 +websockets==14.2 +Werkzeug==3.1.3 +zipp==3.21.0 +imageio-ffmpeg +decord \ No newline at end of file diff --git a/run_video_generate.py b/run_video_generate.py new file mode 100644 index 0000000000000000000000000000000000000000..20212743a7b1014ef2e61d7a2ca4aba922ef00fa --- /dev/null +++ b/run_video_generate.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 + +import sys +import subprocess +import random + +def main(): + """ + Usage: + python run_hv_generate_video.py "Your prompt text here" + """ + if len(sys.argv) < 2: + print("Error: No prompt provided.") + print("Usage: python run_hv_generate_video.py \"\"") + #sys.exit(1) + + # Capture the prompt from command-line arguments + #prompt = sys.argv[1] + SkyReelsModel = "Skywork/SkyReels-V1-Hunyuan-I2V" + # Generate a random seed + random_seed = random.randint(0, 2**32 - 1) + # Construct the command + cmd = [ + # quant: Enable FP8 weight-only quantization + # offload: Enable offload model + # high_cpu_memory: Enable pinned memory to reduce the overhead of model offloading. + # parameters_level: Further reduce GPU VRAM usage. + "python3", "video_generate.py", + "--model_id", SkyReelsModel, + "--guidance_scale", "6.0", + "--height", "720", + "--width", "720", + "--num_frames", "97", + "--prompt", "FPS-24, In a serene scene along a detailed oceanfront, a feral female alicorn Twilight Sparkle from My Little Pony stands alone. The waves crash against the shore, splashing her face with salty water, as she gazes out at the vast, indifferent sea. 
", + "--embedded_guidance_scale", "1.0", + "--quant", + "--offload", + "--high_cpu_memory", + "--parameters_level", + "--image", "img/ocean.webp", + "--seed", str(random_seed), + "--task_type", "i2v" + ] + # Print the exact command (for debugging/logging) + print("Executing command with random seed:", random_seed) + print(" ".join(cmd)) + + # Run the command + subprocess.run(cmd, check=True) + +if __name__ == "__main__": + main() diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/utils/device_utils.py b/utils/device_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b14803e499d7b92acebf8d8bddc3426d178695c4 --- /dev/null +++ b/utils/device_utils.py @@ -0,0 +1,19 @@ +import torch + + +def clean_memory_on_device(device): + if device.type == "cuda": + torch.cuda.empty_cache() + elif device.type == "cpu": + pass + elif device.type == "mps": # not tested + torch.mps.empty_cache() + + +def synchronize_device(device: torch.device): + if device.type == "cuda": + torch.cuda.synchronize() + elif device.type == "xpu": + torch.xpu.synchronize() + elif device.type == "mps": + torch.mps.synchronize() diff --git a/utils/huggingface_utils.py b/utils/huggingface_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0dc7bd7dbb2ef70e0b6244b9db686aae00f46408 --- /dev/null +++ b/utils/huggingface_utils.py @@ -0,0 +1,89 @@ +import threading +from typing import Union, BinaryIO +from huggingface_hub import HfApi +from pathlib import Path +import argparse +import os +import logging + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +def fire_in_thread(f, *args, **kwargs): + threading.Thread(target=f, args=args, kwargs=kwargs).start() + + +def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None): + api = HfApi( + token=token, + ) + try: + api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type) + return True + except: + return False + + +def upload( + args: argparse.Namespace, + src: Union[str, Path, bytes, BinaryIO], + dest_suffix: str = "", + force_sync_upload: bool = False, +): + repo_id = args.huggingface_repo_id + repo_type = args.huggingface_repo_type + token = args.huggingface_token + path_in_repo = args.huggingface_path_in_repo + dest_suffix if args.huggingface_path_in_repo is not None else None + private = args.huggingface_repo_visibility is None or args.huggingface_repo_visibility != "public" + api = HfApi(token=token) + if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token): + try: + api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private) + except Exception as e: # RepositoryNotFoundError or something else + logger.error("===========================================") + logger.error(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}") + logger.error("===========================================") + + is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir()) + + def uploader(): + try: + if is_folder: + api.upload_folder( + repo_id=repo_id, + repo_type=repo_type, + folder_path=src, + path_in_repo=path_in_repo, + ) + else: + api.upload_file( + repo_id=repo_id, + repo_type=repo_type, + path_or_fileobj=src, + path_in_repo=path_in_repo, + ) + except Exception as e: # RuntimeError or something else + logger.error("===========================================") + 
logger.error(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}") + logger.error("===========================================") + + if args.async_upload and not force_sync_upload: + fire_in_thread(uploader) + else: + uploader() + + +def list_dir( + repo_id: str, + subfolder: str, + repo_type: str, + revision: str = "main", + token: str = None, +): + api = HfApi( + token=token, + ) + repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type) + file_list = [file for file in repo_info.siblings if file.rfilename.startswith(subfolder)] + return file_list diff --git a/utils/model_utils.py b/utils/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5beed8ec4e09f433ba2e84556a6c8f342a2903f5 --- /dev/null +++ b/utils/model_utils.py @@ -0,0 +1,151 @@ +import hashlib +from io import BytesIO +from typing import Optional + +import safetensors.torch +import torch + + +def model_hash(filename): + """Old model hash used by stable-diffusion-webui""" + try: + with open(filename, "rb") as file: + m = hashlib.sha256() + + file.seek(0x100000) + m.update(file.read(0x10000)) + return m.hexdigest()[0:8] + except FileNotFoundError: + return "NOFILE" + except IsADirectoryError: # Linux? + return "IsADirectory" + except PermissionError: # Windows + return "IsADirectory" + + +def calculate_sha256(filename): + """New model hash used by stable-diffusion-webui""" + try: + hash_sha256 = hashlib.sha256() + blksize = 1024 * 1024 + + with open(filename, "rb") as f: + for chunk in iter(lambda: f.read(blksize), b""): + hash_sha256.update(chunk) + + return hash_sha256.hexdigest() + except FileNotFoundError: + return "NOFILE" + except IsADirectoryError: # Linux? + return "IsADirectory" + except PermissionError: # Windows + return "IsADirectory" + + +def addnet_hash_legacy(b): + """Old model hash used by sd-webui-additional-networks for .safetensors format files""" + m = hashlib.sha256() + + b.seek(0x100000) + m.update(b.read(0x10000)) + return m.hexdigest()[0:8] + + +def addnet_hash_safetensors(b): + """New model hash used by sd-webui-additional-networks for .safetensors format files""" + hash_sha256 = hashlib.sha256() + blksize = 1024 * 1024 + + b.seek(0) + header = b.read(8) + n = int.from_bytes(header, "little") + + offset = n + 8 + b.seek(offset) + for chunk in iter(lambda: b.read(blksize), b""): + hash_sha256.update(chunk) + + return hash_sha256.hexdigest() + + +def precalculate_safetensors_hashes(tensors, metadata): + """Precalculate the model hashes needed by sd-webui-additional-networks to + save time on indexing the model later.""" + + # Because writing user metadata to the file can change the result of + # sd_models.model_hash(), only retain the training metadata for purposes of + # calculating the hash, as they are meant to be immutable + metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")} + + bytes = safetensors.torch.save(tensors, metadata) + b = BytesIO(bytes) + + model_hash = addnet_hash_safetensors(b) + legacy_hash = addnet_hash_legacy(b) + return model_hash, legacy_hash + + +def dtype_to_str(dtype: torch.dtype) -> str: + # get name of the dtype + dtype_name = str(dtype).split(".")[-1] + return dtype_name + + +def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype: + """ + Convert a string to a torch.dtype + + Args: + s: string representation of the dtype + default_dtype: default dtype to return if s is None + + Returns: + torch.dtype: the corresponding torch.dtype + + Raises: + 
ValueError: if the dtype is not supported + + Examples: + >>> str_to_dtype("float32") + torch.float32 + >>> str_to_dtype("fp32") + torch.float32 + >>> str_to_dtype("float16") + torch.float16 + >>> str_to_dtype("fp16") + torch.float16 + >>> str_to_dtype("bfloat16") + torch.bfloat16 + >>> str_to_dtype("bf16") + torch.bfloat16 + >>> str_to_dtype("fp8") + torch.float8_e4m3fn + >>> str_to_dtype("fp8_e4m3fn") + torch.float8_e4m3fn + >>> str_to_dtype("fp8_e4m3fnuz") + torch.float8_e4m3fnuz + >>> str_to_dtype("fp8_e5m2") + torch.float8_e5m2 + >>> str_to_dtype("fp8_e5m2fnuz") + torch.float8_e5m2fnuz + """ + if s is None: + return default_dtype + if s in ["bf16", "bfloat16"]: + return torch.bfloat16 + elif s in ["fp16", "float16"]: + return torch.float16 + elif s in ["fp32", "float32", "float"]: + return torch.float32 + elif s in ["fp8_e4m3fn", "e4m3fn", "float8_e4m3fn"]: + return torch.float8_e4m3fn + elif s in ["fp8_e4m3fnuz", "e4m3fnuz", "float8_e4m3fnuz"]: + return torch.float8_e4m3fnuz + elif s in ["fp8_e5m2", "e5m2", "float8_e5m2"]: + return torch.float8_e5m2 + elif s in ["fp8_e5m2fnuz", "e5m2fnuz", "float8_e5m2fnuz"]: + return torch.float8_e5m2fnuz + elif s in ["fp8", "float8"]: + return torch.float8_e4m3fn # default fp8 + else: + raise ValueError(f"Unsupported dtype: {s}") diff --git a/utils/safetensors_utils.py b/utils/safetensors_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d33b3c25ddee119212de80332b609f9dbd6b251d --- /dev/null +++ b/utils/safetensors_utils.py @@ -0,0 +1,221 @@ +import os +import re +import torch +import json +import struct +from typing import Dict, Any, Union, Optional + +from safetensors.torch import load_file + + +def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None): + """ + memory efficient save file + """ + + _TYPES = { + torch.float64: "F64", + torch.float32: "F32", + torch.float16: "F16", + torch.bfloat16: "BF16", + torch.int64: "I64", + torch.int32: "I32", + torch.int16: "I16", + torch.int8: "I8", + torch.uint8: "U8", + torch.bool: "BOOL", + getattr(torch, "float8_e5m2", None): "F8_E5M2", + getattr(torch, "float8_e4m3fn", None): "F8_E4M3", + } + _ALIGN = 256 + + def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]: + validated = {} + for key, value in metadata.items(): + if not isinstance(key, str): + raise ValueError(f"Metadata key must be a string, got {type(key)}") + if not isinstance(value, str): + print(f"Warning: Metadata value for key '{key}' is not a string. 
Converting to string.") + validated[key] = str(value) + else: + validated[key] = value + return validated + + # print(f"Using memory efficient save file: {filename}") + + header = {} + offset = 0 + if metadata: + header["__metadata__"] = validate_metadata(metadata) + for k, v in tensors.items(): + if v.numel() == 0: # empty tensor + header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]} + else: + size = v.numel() * v.element_size() + header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]} + offset += size + + hjson = json.dumps(header).encode("utf-8") + hjson += b" " * (-(len(hjson) + 8) % _ALIGN) + + with open(filename, "wb") as f: + f.write(struct.pack(" Dict[str, str]: + return self.header.get("__metadata__", {}) + + def get_tensor(self, key): + if key not in self.header: + raise KeyError(f"Tensor '{key}' not found in the file") + + metadata = self.header[key] + offset_start, offset_end = metadata["data_offsets"] + + if offset_start == offset_end: + tensor_bytes = None + else: + # adjust offset by header size + self.file.seek(self.header_size + 8 + offset_start) + tensor_bytes = self.file.read(offset_end - offset_start) + + return self._deserialize_tensor(tensor_bytes, metadata) + + def _read_header(self): + header_size = struct.unpack(" dict[str, torch.Tensor]: + if disable_mmap: + # return safetensors.torch.load(open(path, "rb").read()) + # use experimental loader + # logger.info(f"Loading without mmap (experimental)") + state_dict = {} + with MemoryEfficientSafeOpen(path) as f: + for key in f.keys(): + state_dict[key] = f.get_tensor(key).to(device, dtype=dtype) + return state_dict + else: + try: + state_dict = load_file(path, device=device) + except: + state_dict = load_file(path) # prevent device invalid Error + if dtype is not None: + for key in state_dict.keys(): + state_dict[key] = state_dict[key].to(dtype=dtype) + return state_dict + + +def load_split_weights( + file_path: str, device: Union[str, torch.device] = "cpu", disable_mmap: bool = False +) -> Dict[str, torch.Tensor]: + """ + Load split weights from a file. If the file name ends with 00001-of-00004 etc, it will load all files with the same prefix. + dtype is as is, no conversion is done. 
+ """ + device = torch.device(device) + + # if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix + basename = os.path.basename(file_path) + match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename) + if match: + prefix = basename[: match.start(2)] + count = int(match.group(3)) + state_dict = {} + for i in range(count): + filename = f"{prefix}{i+1:05d}-of-{count:05d}.safetensors" + filepath = os.path.join(os.path.dirname(file_path), filename) + if os.path.exists(filepath): + state_dict.update(load_safetensors(filepath, device=device, disable_mmap=disable_mmap)) + else: + raise FileNotFoundError(f"File {filepath} not found") + else: + state_dict = load_safetensors(file_path, device=device, disable_mmap=disable_mmap) + return state_dict diff --git a/utils/sai_model_spec.py b/utils/sai_model_spec.py new file mode 100644 index 0000000000000000000000000000000000000000..264340cf532166922849db9cf520a23c133cca99 --- /dev/null +++ b/utils/sai_model_spec.py @@ -0,0 +1,286 @@ +# based on https://github.com/Stability-AI/ModelSpec +import datetime +import hashlib +from io import BytesIO +import os +from typing import List, Optional, Tuple, Union +import safetensors +import logging + +from dataset.image_video_dataset import ARCHITECTURE_HUNYUAN_VIDEO, ARCHITECTURE_WAN, ARCHITECTURE_FRAMEPACK + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +r""" +# Metadata Example +metadata = { + # === Must === + "modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec + "modelspec.architecture": "stable-diffusion-xl-v1-base", # Architecture, reference the ID of the original model of the arch to match the ID + "modelspec.implementation": "sgm", + "modelspec.title": "Example Model Version 1.0", # Clean, human-readable title. May use your own phrasing/language/etc + # === Should === + "modelspec.author": "Example Corp", # Your name or company name + "modelspec.description": "This is my example model to show you how to do it!", # Describe the model in your own words/language/etc. Focus on what users need to know + "modelspec.date": "2023-07-20", # ISO-8601 compliant date of when the model was created + # === Can === + "modelspec.license": "ExampleLicense-1.0", # eg CreativeML Open RAIL, etc. 
+ "modelspec.usage_hint": "Use keyword 'example'" # In your own language, very short hints about how the user should use the model +} +""" + +BASE_METADATA = { + # === Must === + "modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec + "modelspec.architecture": None, + "modelspec.implementation": None, + "modelspec.title": None, + "modelspec.resolution": None, + # === Should === + "modelspec.description": None, + "modelspec.author": None, + "modelspec.date": None, + # === Can === + "modelspec.license": None, + "modelspec.tags": None, + "modelspec.merged_from": None, + "modelspec.prediction_type": None, + "modelspec.timestep_range": None, + "modelspec.encoder_layer": None, +} + +# 別に使うやつだけ定義 +MODELSPEC_TITLE = "modelspec.title" + +ARCH_HUNYUAN_VIDEO = "hunyuan-video" + +# Official Wan2.1 weights does not have sai_model_spec, so we use this as an architecture name +ARCH_WAN = "wan2.1" + +ARCH_FRAMEPACK = "framepack" + +ADAPTER_LORA = "lora" + +IMPL_HUNYUAN_VIDEO = "https://github.com/Tencent/HunyuanVideo" +IMPL_WAN = "https://github.com/Wan-Video/Wan2.1" +IMPL_FRAMEPACK = "https://github.com/lllyasviel/FramePack" + +PRED_TYPE_EPSILON = "epsilon" +# PRED_TYPE_V = "v" + + +def load_bytes_in_safetensors(tensors): + bytes = safetensors.torch.save(tensors) + b = BytesIO(bytes) + + b.seek(0) + header = b.read(8) + n = int.from_bytes(header, "little") + + offset = n + 8 + b.seek(offset) + + return b.read() + + +def precalculate_safetensors_hashes(state_dict): + # calculate each tensor one by one to reduce memory usage + hash_sha256 = hashlib.sha256() + for tensor in state_dict.values(): + single_tensor_sd = {"tensor": tensor} + bytes_for_tensor = load_bytes_in_safetensors(single_tensor_sd) + hash_sha256.update(bytes_for_tensor) + + return f"0x{hash_sha256.hexdigest()}" + + +def update_hash_sha256(metadata: dict, state_dict: dict): + raise NotImplementedError + + +def build_metadata( + state_dict: Optional[dict], + architecture: str, + timestamp: float, + title: Optional[str] = None, + reso: Optional[Union[int, Tuple[int, int]]] = None, + author: Optional[str] = None, + description: Optional[str] = None, + license: Optional[str] = None, + tags: Optional[str] = None, + merged_from: Optional[str] = None, + timesteps: Optional[Tuple[int, int]] = None, + is_lora: bool = True, +): + metadata = {} + metadata.update(BASE_METADATA) + + # TODO implement if we can calculate hash without loading all tensors + # if state_dict is not None: + # hash = precalculate_safetensors_hashes(state_dict) + # metadata["modelspec.hash_sha256"] = hash + + # arch = ARCH_HUNYUAN_VIDEO + if architecture == ARCHITECTURE_HUNYUAN_VIDEO: + arch = ARCH_HUNYUAN_VIDEO + impl = IMPL_HUNYUAN_VIDEO + elif architecture == ARCHITECTURE_WAN: + arch = ARCH_WAN + impl = IMPL_WAN + elif architecture == ARCHITECTURE_FRAMEPACK: + arch = ARCH_FRAMEPACK + impl = IMPL_FRAMEPACK + else: + raise ValueError(f"Unknown architecture: {architecture}") + + if is_lora: + arch += f"/{ADAPTER_LORA}" + metadata["modelspec.architecture"] = arch + + metadata["modelspec.implementation"] = impl + + if title is None: + title = "LoRA" if is_lora else "Hunyuan-Video" + title += f"@{timestamp}" + metadata[MODELSPEC_TITLE] = title + + if author is not None: + metadata["modelspec.author"] = author + else: + del metadata["modelspec.author"] + + if description is not None: + metadata["modelspec.description"] = description + else: + del metadata["modelspec.description"] + + if merged_from is not None: + metadata["modelspec.merged_from"] = merged_from + 
else: + del metadata["modelspec.merged_from"] + + if license is not None: + metadata["modelspec.license"] = license + else: + del metadata["modelspec.license"] + + if tags is not None: + metadata["modelspec.tags"] = tags + else: + del metadata["modelspec.tags"] + + # remove microsecond from time + int_ts = int(timestamp) + + # time to iso-8601 compliant date + date = datetime.datetime.fromtimestamp(int_ts).isoformat() + metadata["modelspec.date"] = date + + if reso is not None: + # comma separated to tuple + if isinstance(reso, str): + reso = tuple(map(int, reso.split(","))) + if len(reso) == 1: + reso = (reso[0], reso[0]) + else: + # resolution is defined in dataset, so use default + reso = (1280, 720) + if isinstance(reso, int): + reso = (reso, reso) + + metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}" + + # metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON + del metadata["modelspec.prediction_type"] + + if timesteps is not None: + if isinstance(timesteps, str) or isinstance(timesteps, int): + timesteps = (timesteps, timesteps) + if len(timesteps) == 1: + timesteps = (timesteps[0], timesteps[0]) + metadata["modelspec.timestep_range"] = f"{timesteps[0]},{timesteps[1]}" + else: + del metadata["modelspec.timestep_range"] + + # if clip_skip is not None: + # metadata["modelspec.encoder_layer"] = f"{clip_skip}" + # else: + del metadata["modelspec.encoder_layer"] + + # # assert all values are filled + # assert all([v is not None for v in metadata.values()]), metadata + if not all([v is not None for v in metadata.values()]): + logger.error(f"Internal error: some metadata values are None: {metadata}") + + return metadata + + +# region utils + + +def get_title(metadata: dict) -> Optional[str]: + return metadata.get(MODELSPEC_TITLE, None) + + +def load_metadata_from_safetensors(model: str) -> dict: + if not model.endswith(".safetensors"): + return {} + + with safetensors.safe_open(model, framework="pt") as f: + metadata = f.metadata() + if metadata is None: + metadata = {} + return metadata + + +def build_merged_from(models: List[str]) -> str: + def get_title(model: str): + metadata = load_metadata_from_safetensors(model) + title = metadata.get(MODELSPEC_TITLE, None) + if title is None: + title = os.path.splitext(os.path.basename(model))[0] # use filename + return title + + titles = [get_title(model) for model in models] + return ", ".join(titles) + + +# endregion + + +r""" +if __name__ == "__main__": + import argparse + import torch + from safetensors.torch import load_file + from library import train_util + + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt", type=str, required=True) + args = parser.parse_args() + + print(f"Loading {args.ckpt}") + state_dict = load_file(args.ckpt) + + print(f"Calculating metadata") + metadata = get(state_dict, False, False, False, False, "sgm", False, False, "title", "date", 256, 1000, 0) + print(metadata) + del state_dict + + # by reference implementation + with open(args.ckpt, mode="rb") as file_data: + file_hash = hashlib.sha256() + head_len = struct.unpack("Q", file_data.read(8)) # int64 header length prefix + header = json.loads(file_data.read(head_len[0])) # header itself, json string + content = ( + file_data.read() + ) # All other content is tightly packed tensors. Copy to RAM for simplicity, but you can avoid this read with a more careful FS-dependent impl. 
+ file_hash.update(content) + # ===== Update the hash for modelspec ===== + by_ref = f"0x{file_hash.hexdigest()}" + print(by_ref) + print("is same?", by_ref == metadata["modelspec.hash_sha256"]) + +""" diff --git a/utils/train_utils.py b/utils/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9d6af9ae69c5f4748406ae96577f373ae6df5da1 --- /dev/null +++ b/utils/train_utils.py @@ -0,0 +1,177 @@ +import argparse +import logging +import os +import shutil + +import accelerate +import torch + +from utils import huggingface_utils + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +# checkpointファイル名 +EPOCH_STATE_NAME = "{}-{:06d}-state" +EPOCH_FILE_NAME = "{}-{:06d}" +EPOCH_DIFFUSERS_DIR_NAME = "{}-{:06d}" +LAST_STATE_NAME = "{}-state" +STEP_STATE_NAME = "{}-step{:08d}-state" +STEP_FILE_NAME = "{}-step{:08d}" +STEP_DIFFUSERS_DIR_NAME = "{}-step{:08d}" + + +def get_sanitized_config_or_none(args: argparse.Namespace): + # if `--log_config` is enabled, return args for logging. if not, return None. + # when `--log_config is enabled, filter out sensitive values from args + # if wandb is not enabled, the log is not exposed to the public, but it is fine to filter out sensitive values to be safe + + if not args.log_config: + return None + + sensitive_args = ["wandb_api_key", "huggingface_token"] + sensitive_path_args = [ + "dit", + "vae", + "text_encoder1", + "text_encoder2", + "base_weights", + "network_weights", + "output_dir", + "logging_dir", + ] + filtered_args = {} + for k, v in vars(args).items(): + # filter out sensitive values and convert to string if necessary + if k not in sensitive_args + sensitive_path_args: + # Accelerate values need to have type `bool`,`str`, `float`, `int`, or `None`. + if v is None or isinstance(v, bool) or isinstance(v, str) or isinstance(v, float) or isinstance(v, int): + filtered_args[k] = v + # accelerate does not support lists + elif isinstance(v, list): + filtered_args[k] = f"{v}" + # accelerate does not support objects + elif isinstance(v, object): + filtered_args[k] = f"{v}" + + return filtered_args + + +class LossRecorder: + def __init__(self): + self.loss_list: list[float] = [] + self.loss_total: float = 0.0 + + def add(self, *, epoch: int, step: int, loss: float) -> None: + if epoch == 0: + self.loss_list.append(loss) + else: + while len(self.loss_list) <= step: + self.loss_list.append(0.0) + self.loss_total -= self.loss_list[step] + self.loss_list[step] = loss + self.loss_total += loss + + @property + def moving_average(self) -> float: + return self.loss_total / len(self.loss_list) + + +def get_epoch_ckpt_name(model_name, epoch_no: int): + return EPOCH_FILE_NAME.format(model_name, epoch_no) + ".safetensors" + + +def get_step_ckpt_name(model_name, step_no: int): + return STEP_FILE_NAME.format(model_name, step_no) + ".safetensors" + + +def get_last_ckpt_name(model_name): + return model_name + ".safetensors" + + +def get_remove_epoch_no(args: argparse.Namespace, epoch_no: int): + if args.save_last_n_epochs is None: + return None + + remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs + if remove_epoch_no < 0: + return None + return remove_epoch_no + + +def get_remove_step_no(args: argparse.Namespace, step_no: int): + if args.save_last_n_steps is None: + return None + + # calculate the step number to remove from the last_n_steps and save_every_n_steps + # e.g. 
if save_every_n_steps=10, save_last_n_steps=30, at step 50, keep 30 steps and remove step 10 + remove_step_no = step_no - args.save_last_n_steps - 1 + remove_step_no = remove_step_no - (remove_step_no % args.save_every_n_steps) + if remove_step_no < 0: + return None + return remove_step_no + + +def save_and_remove_state_on_epoch_end(args: argparse.Namespace, accelerator: accelerate.Accelerator, epoch_no: int): + model_name = args.output_name + + logger.info("") + logger.info(f"saving state at epoch {epoch_no}") + os.makedirs(args.output_dir, exist_ok=True) + + state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)) + accelerator.save_state(state_dir) + if args.save_state_to_huggingface: + logger.info("uploading state to huggingface.") + huggingface_utils.upload(args, state_dir, "/" + EPOCH_STATE_NAME.format(model_name, epoch_no)) + + last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs + if last_n_epochs is not None: + remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epochs + state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no)) + if os.path.exists(state_dir_old): + logger.info(f"removing old state: {state_dir_old}") + shutil.rmtree(state_dir_old) + + +def save_and_remove_state_stepwise(args: argparse.Namespace, accelerator: accelerate.Accelerator, step_no: int): + model_name = args.output_name + + logger.info("") + logger.info(f"saving state at step {step_no}") + os.makedirs(args.output_dir, exist_ok=True) + + state_dir = os.path.join(args.output_dir, STEP_STATE_NAME.format(model_name, step_no)) + accelerator.save_state(state_dir) + if args.save_state_to_huggingface: + logger.info("uploading state to huggingface.") + huggingface_utils.upload(args, state_dir, "/" + STEP_STATE_NAME.format(model_name, step_no)) + + last_n_steps = args.save_last_n_steps_state if args.save_last_n_steps_state else args.save_last_n_steps + if last_n_steps is not None: + # last_n_steps前のstep_noから、save_every_n_stepsの倍数のstep_noを計算して削除する + remove_step_no = step_no - last_n_steps - 1 + remove_step_no = remove_step_no - (remove_step_no % args.save_every_n_steps) + + if remove_step_no > 0: + state_dir_old = os.path.join(args.output_dir, STEP_STATE_NAME.format(model_name, remove_step_no)) + if os.path.exists(state_dir_old): + logger.info(f"removing old state: {state_dir_old}") + shutil.rmtree(state_dir_old) + + +def save_state_on_train_end(args: argparse.Namespace, accelerator: accelerate.Accelerator): + model_name = args.output_name + + logger.info("") + logger.info("saving last state.") + os.makedirs(args.output_dir, exist_ok=True) + + state_dir = os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)) + accelerator.save_state(state_dir) + + if args.save_state_to_huggingface: + logger.info("uploading last state to huggingface.") + huggingface_utils.upload(args, state_dir, "/" + LAST_STATE_NAME.format(model_name)) + diff --git a/wan/__init__.py b/wan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8f7ed4df10cae220744639f079b4e11d985f9d05 --- /dev/null +++ b/wan/__init__.py @@ -0,0 +1 @@ +# from . import configs, distributed, modules diff --git a/wan/configs/__init__.py b/wan/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d9c1c37c417b8c32c956bab78686c7a6d7172e69 --- /dev/null +++ b/wan/configs/__init__.py @@ -0,0 +1,143 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. 
All rights reserved. +import copy +import os +import torch +from easydict import EasyDict + +os.environ["TOKENIZERS_PARALLELISM"] = "false" +from .shared_config import wan_shared_cfg +from .wan_i2v_14B import i2v_14B +from .wan_t2v_1_3B import t2v_1_3B +from .wan_t2v_14B import t2v_14B + +# the config of t2i_14B is the same as t2v_14B +t2i_14B = copy.deepcopy(t2v_14B) +t2i_14B.__name__ = "Config: Wan T2I 14B" + +# ================== START: Add New 1.3B I2V Model Config ================== +i2v_1_3B_new = EasyDict(__name__="Config: Wan I2V 1.3B New") +i2v_1_3B_new.update(wan_shared_cfg) # Start with shared defaults + +# --- Core Model Parameters from your config.json --- +i2v_1_3B_new.dim = 1536 +i2v_1_3B_new.ffn_dim = 8960 +i2v_1_3B_new.num_heads = 12 +i2v_1_3B_new.num_layers = 30 +i2v_1_3B_new.in_dim = 36 # From config.json (latent + mask) +i2v_1_3B_new.out_dim = 16 # From config.json +i2v_1_3B_new.freq_dim = 256 # From config.json +i2v_1_3B_new.text_len = 512 # From config.json +i2v_1_3B_new.eps = 1e-06 # From config.json + +# --- I2V Specific Settings --- +i2v_1_3B_new.i2v = True # Mark as I2V +i2v_1_3B_new.is_fun_control = False # This is NOT a FunControl model + +# --- Assumed Component Checkpoints & Settings (ADJUST IF NEEDED) --- +# Assume it uses the same components as other models unless specified +# DiT: User MUST provide this path via --dit +# VAE: Assume standard VAE, user can override with --vae +i2v_1_3B_new.vae_checkpoint = "Wan2.1_VAE.pth" # Or specific VAE if different +i2v_1_3B_new.vae_stride = (4, 8, 8) # Standard stride + +# T5: Assume standard T5, user can override with --t5 +i2v_1_3B_new.t5_checkpoint = "models_t5_umt5-xxl-enc-bf16.pth" # Or smaller T5 if available +i2v_1_3B_new.t5_tokenizer = "google/umt5-xxl" +i2v_1_3B_new.t5_dtype = torch.bfloat16 # Default T5 dtype + +# CLIP: Needed for I2V, assume standard CLIP, user can override with --clip +i2v_1_3B_new.clip_model = "clip_xlm_roberta_vit_h_14" +i2v_1_3B_new.clip_dtype = torch.float16 # Default CLIP dtype +i2v_1_3B_new.clip_checkpoint = "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" +i2v_1_3B_new.clip_tokenizer = "xlm-roberta-large" + +# Transformer structure (Assume standard based on WanModel) +i2v_1_3B_new.patch_size = (1, 2, 2) # Standard patch size +i2v_1_3B_new.window_size = (-1, -1) # Global attention +i2v_1_3B_new.qk_norm = True # Standard norm +i2v_1_3B_new.cross_attn_norm = True # Standard norm (often True for I2V) + +# Default sample prompts (can be kept or adjusted) +i2v_1_3B_new.sample_prompts = ["cinematic video of a sports car"] +i2v_1_3B_new.sample_neg_prompt = "text, watermark, copyright, blurry, low quality, noisy" +i2v_1_3B_new.num_train_timesteps = 1000 # Standard diffusion timesteps + +# ================== END: Add New 1.3B I2V Model Config ================== + +# support Fun models: deepcopy and change some configs. 
FC denotes Fun Control +t2v_1_3B_FC = copy.deepcopy(t2v_1_3B) +t2v_1_3B_FC.__name__ = "Config: Wan-Fun-Control T2V 1.3B" +t2v_1_3B_FC.in_dim = 48 +i2v_14B.is_fun_control = False +t2v_14B_FC = copy.deepcopy(t2v_14B) +t2v_14B_FC.__name__ = "Config: Wan-Fun-Control T2V 14B" +t2v_14B_FC.i2v = True # this is strange, but Fun-Control model needs this because it has img cross-attention +t2v_14B_FC.in_dim = 48 # same as i2v_14B, use zeros for image latents +t2v_14B_FC.is_fun_control = True +i2v_14B_FC = copy.deepcopy(i2v_14B) +i2v_14B_FC.__name__ = "Config: Wan-Fun-Control I2V 14B" +i2v_14B_FC.in_dim = 48 +i2v_14B_FC.is_fun_control = True + +i2v_14B_FC_1_1 = copy.deepcopy(i2v_14B_FC) # Copy the existing FunControl I2V 14B config +i2v_14B_FC_1_1.__name__ = "Config: Wan-Fun-Control I2V 14B v1.1" +# Explicitly add the flag for clarity, though loading logic will derive it +# i2v_14B_FC_1_1.add_ref_conv = True # This flag isn't directly used by the Python config struct, but good for documentation +# The key is that the loaded weights for this model WILL contain 'ref_conv.weight' +# All other parameters are inherited from i2v_14B_FC (in_dim=48, is_fun_control=True, etc.) + +WAN_CONFIGS = { + "t2v-14B": t2v_14B, + "t2v-1.3B": t2v_1_3B, + "i2v-14B": i2v_14B, + "t2i-14B": t2i_14B, + "i2v-1.3B-new": i2v_1_3B_new, + # Fun Control models + "t2v-1.3B-FC": t2v_1_3B_FC, + "t2v-14B-FC": t2v_14B_FC, + "i2v-14B-FC": i2v_14B_FC, + "i2v-14B-FC-1.1": i2v_14B_FC_1_1, +} + +SIZE_CONFIGS = { + "720*1280": (720, 1280), + "1280*720": (1280, 720), + "480*832": (480, 832), + "832*480": (832, 480), + "1024*1024": (1024, 1024), + "512*512": (512, 512), # <--- Example: Added 512x512 if used + "672*352": (672, 352), # <--- Added from your command line example + "352*672": (352, 672), # <--- Added from your command line example (vertical) +} +# --- ^^^ MODIFY THIS DICTIONARY ^^^ --- + + +# --- vvv MODIFY THIS DICTIONARY vvv --- +MAX_AREA_CONFIGS = { + "720*1280": 720 * 1280, + "1280*720": 1280 * 720, + "480*832": 480 * 832, + "832*480": 832 * 480, + "1024*1024": 1024 * 1024, + "512*512": 512 * 512, # <--- Added 512x512 if used + "672*352": 672 * 352, # <--- Added from your command line example + "352*672": 352 * 672, # <--- Added from your command line example (vertical) +} +# --- ^^^ MODIFY THIS DICTIONARY ^^^ --- + + +# --- vvv MODIFY THIS DICTIONARY vvv --- +SUPPORTED_SIZES = { + "t2v-14B": ("720*1280", "1280*720", "480*832", "832*480"), + "t2v-1.3B": ("480*832", "832*480"), + "i2v-14B": ("720*1280", "1280*720", "480*832", "832*480"), + "t2i-14B": tuple(SIZE_CONFIGS.keys()), + # Fun Control models + "t2v-1.3B-FC": ("480*832", "832*480"), + "t2v-14B-FC": ("720*1280", "1280*720", "480*832", "832*480"), + "i2v-14B-FC": ("720*1280", "1280*720", "480*832", "832*480"), + "i2v-14B-FC-1.1": ("720*1280", "1280*720", "480*832", "832*480"), + # Add supported sizes for the new model + "i2v-1.3B-new": ("480*832", "832*480", "512*512", "672*352", "352*672"), + +} diff --git a/wan/configs/shared_config.py b/wan/configs/shared_config.py new file mode 100644 index 0000000000000000000000000000000000000000..ff603d52244336acc864835c2cd30c1c6110e39b --- /dev/null +++ b/wan/configs/shared_config.py @@ -0,0 +1,20 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
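To make the registries above concrete, here is a minimal sketch of how an inference script might resolve a task name and size string against `WAN_CONFIGS`, `SIZE_CONFIGS`, `MAX_AREA_CONFIGS`, and `SUPPORTED_SIZES`. It is illustrative only: the `resolve_task` helper and the chosen task/size values are assumptions, not code from this diff.

```python
# Illustrative sketch (not part of this diff): resolving a task/size pair
# against the registries defined in wan/configs/__init__.py.
from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS


def resolve_task(task: str, size: str):
    """Hypothetical helper: look up config, size tuple and max pixel area."""
    if task not in WAN_CONFIGS:
        raise ValueError(f"unknown task {task!r}, choose from {sorted(WAN_CONFIGS)}")
    if size not in SUPPORTED_SIZES[task]:
        raise ValueError(f"size {size!r} not supported for {task!r}: {SUPPORTED_SIZES[task]}")
    return WAN_CONFIGS[task], SIZE_CONFIGS[size], MAX_AREA_CONFIGS[size]


cfg, size_tuple, max_area = resolve_task("i2v-1.3B-new", "832*480")
print(cfg.__name__, size_tuple, max_area, cfg.dim, cfg.num_layers)
# -> Config: Wan I2V 1.3B New (832, 480) 399360 1536 30
```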
+import torch +from easydict import EasyDict + +#------------------------ Wan shared config ------------------------# +wan_shared_cfg = EasyDict() + +# t5 +wan_shared_cfg.t5_model = 'umt5_xxl' +wan_shared_cfg.t5_dtype = torch.bfloat16 +wan_shared_cfg.text_len = 512 + +# transformer +wan_shared_cfg.param_dtype = torch.bfloat16 +wan_shared_cfg.out_dim = 16 + +# inference +wan_shared_cfg.num_train_timesteps = 1000 +wan_shared_cfg.sample_fps = 16 +wan_shared_cfg.sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走' diff --git a/wan/configs/wan_i2v_14B.py b/wan/configs/wan_i2v_14B.py new file mode 100644 index 0000000000000000000000000000000000000000..434f59c3d1dd75c9cdc816c5f976afed0ef08631 --- /dev/null +++ b/wan/configs/wan_i2v_14B.py @@ -0,0 +1,38 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import torch +from easydict import EasyDict + +from .shared_config import wan_shared_cfg + +# ------------------------ Wan I2V 14B ------------------------# + +i2v_14B = EasyDict(__name__="Config: Wan I2V 14B") +i2v_14B.update(wan_shared_cfg) +i2v_14B.i2v = True +i2v_14B.is_fun_control = False + +i2v_14B.t5_checkpoint = "models_t5_umt5-xxl-enc-bf16.pth" +i2v_14B.t5_tokenizer = "google/umt5-xxl" + +# clip +i2v_14B.clip_model = "clip_xlm_roberta_vit_h_14" +i2v_14B.clip_dtype = torch.float16 +i2v_14B.clip_checkpoint = "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" +i2v_14B.clip_tokenizer = "xlm-roberta-large" + +# vae +i2v_14B.vae_checkpoint = "Wan2.1_VAE.pth" +i2v_14B.vae_stride = (4, 8, 8) + +# transformer +i2v_14B.patch_size = (1, 2, 2) +i2v_14B.dim = 5120 +i2v_14B.ffn_dim = 13824 +i2v_14B.freq_dim = 256 +i2v_14B.in_dim = 36 +i2v_14B.num_heads = 40 +i2v_14B.num_layers = 40 +i2v_14B.window_size = (-1, -1) +i2v_14B.qk_norm = True +i2v_14B.cross_attn_norm = True +i2v_14B.eps = 1e-6 diff --git a/wan/configs/wan_t2v_14B.py b/wan/configs/wan_t2v_14B.py new file mode 100644 index 0000000000000000000000000000000000000000..76433f058b159d74ce41a539e69a1bcd8bb9901e --- /dev/null +++ b/wan/configs/wan_t2v_14B.py @@ -0,0 +1,32 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +from easydict import EasyDict + +from .shared_config import wan_shared_cfg + +# ------------------------ Wan T2V 14B ------------------------# + +t2v_14B = EasyDict(__name__="Config: Wan T2V 14B") +t2v_14B.update(wan_shared_cfg) +t2v_14B.i2v = False +t2v_14B.is_fun_control = False + +# t5 +t2v_14B.t5_checkpoint = "models_t5_umt5-xxl-enc-bf16.pth" +t2v_14B.t5_tokenizer = "google/umt5-xxl" + +# vae +t2v_14B.vae_checkpoint = "Wan2.1_VAE.pth" +t2v_14B.vae_stride = (4, 8, 8) + +# transformer +t2v_14B.patch_size = (1, 2, 2) +t2v_14B.dim = 5120 +t2v_14B.ffn_dim = 13824 +t2v_14B.freq_dim = 256 +t2v_14B.in_dim = 16 +t2v_14B.num_heads = 40 +t2v_14B.num_layers = 40 +t2v_14B.window_size = (-1, -1) +t2v_14B.qk_norm = True +t2v_14B.cross_attn_norm = True +t2v_14B.eps = 1e-6 diff --git a/wan/configs/wan_t2v_1_3B.py b/wan/configs/wan_t2v_1_3B.py new file mode 100644 index 0000000000000000000000000000000000000000..ccb9e10ef41cf249004e1e46d22591471e284882 --- /dev/null +++ b/wan/configs/wan_t2v_1_3B.py @@ -0,0 +1,32 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
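The `vae_stride` and `patch_size` values in these configs fix the latent geometry that the training and inference code works with. The worked example below is purely illustrative (the numbers simply follow the formulas used later in `wan/image2video.py`); it is not part of the diff.

```python
# Worked example (illustrative): latent geometry for an 81-frame 480x832 clip
# with vae_stride = (4, 8, 8) and patch_size = (1, 2, 2).
frames, height, width = 81, 480, 832
vae_stride = (4, 8, 8)
patch_size = (1, 2, 2)

lat_f = (frames - 1) // vae_stride[0] + 1        # 21 latent frames
lat_h = height // vae_stride[1]                  # 60 latent rows
lat_w = width // vae_stride[2]                   # 104 latent columns

# one DiT token per 1x2x2 latent patch
seq_len = lat_f * lat_h * lat_w // (patch_size[1] * patch_size[2])
assert (lat_f, lat_h, lat_w, seq_len) == (21, 60, 104, 32760)
```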
+from easydict import EasyDict + +from .shared_config import wan_shared_cfg + +# ------------------------ Wan T2V 1.3B ------------------------# + +t2v_1_3B = EasyDict(__name__="Config: Wan T2V 1.3B") +t2v_1_3B.update(wan_shared_cfg) +t2v_1_3B.i2v = False +t2v_1_3B.is_fun_control = False + +# t5 +t2v_1_3B.t5_checkpoint = "models_t5_umt5-xxl-enc-bf16.pth" +t2v_1_3B.t5_tokenizer = "google/umt5-xxl" + +# vae +t2v_1_3B.vae_checkpoint = "Wan2.1_VAE.pth" +t2v_1_3B.vae_stride = (4, 8, 8) + +# transformer +t2v_1_3B.patch_size = (1, 2, 2) +t2v_1_3B.dim = 1536 +t2v_1_3B.ffn_dim = 8960 +t2v_1_3B.freq_dim = 256 +t2v_1_3B.in_dim = 16 +t2v_1_3B.num_heads = 12 +t2v_1_3B.num_layers = 30 +t2v_1_3B.window_size = (-1, -1) +t2v_1_3B.qk_norm = True +t2v_1_3B.cross_attn_norm = True +t2v_1_3B.eps = 1e-6 diff --git a/wan/image2video.py b/wan/image2video.py new file mode 100644 index 0000000000000000000000000000000000000000..500e4158efe51fed133190c76b108bc1e2214e1b --- /dev/null +++ b/wan/image2video.py @@ -0,0 +1,419 @@ +# Modified from official implementation + +# Original source: +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. + +import logging +import os +import random +import sys +from typing import Optional, Union + +import cv2 +import numpy as np +import torch +import torchvision.transforms.functional as TF +from tqdm import tqdm +from accelerate import Accelerator, init_empty_weights +from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler +from utils.safetensors_utils import load_safetensors + +# from .distributed.fsdp import shard_model +from .modules.clip import CLIPModel +from .modules.model import WanModel, load_wan_model +from .modules.t5 import T5EncoderModel +from .modules.vae import WanVAE +from .utils.fm_solvers import FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps +from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler + +from utils.device_utils import clean_memory_on_device, synchronize_device + +import logging + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +class WanI2V: + + def __init__( + self, + config, + checkpoint_dir, + device_id=0, + rank=0, + t5_fsdp=False, + dit_fsdp=False, + use_usp=False, + t5_cpu=False, + init_on_cpu=True, + device=None, + dit_dtype=None, + dit_weight_dtype=None, + dit_path=None, + dit_attn_mode=None, + t5_path=None, + clip_path=None, + t5_fp8=False, + ): + r""" + Initializes the image-to-video generation model components. + + Args: + config (EasyDict): + Object containing model parameters initialized from config.py + checkpoint_dir (`str`): + Path to directory containing model checkpoints + device_id (`int`, *optional*, defaults to 0) **IGNORED**: + Id of target GPU device + rank (`int`, *optional*, defaults to 0) **IGNORED**: + Process rank for distributed training + t5_fsdp (`bool`, *optional*, defaults to False) **IGNORED**: + Enable FSDP sharding for T5 model + dit_fsdp (`bool`, *optional*, defaults to False) **IGNORED**: + Enable FSDP sharding for DiT model + use_usp (`bool`, *optional*, defaults to False) **IGNORED**: + Enable distribution strategy of USP. + t5_cpu (`bool`, *optional*, defaults to False) **IGNORED**: + Whether to place T5 model on CPU. Only works without t5_fsdp. + init_on_cpu (`bool`, *optional*, defaults to True) **IGNORED**: + Enable initializing Transformer Model on CPU. Only works without FSDP or USP. + + device (`torch.device`, *optional*, defaults to None): + Device to place the model on. 
If None, use the default device (cuda) + dtype (`torch.dtype`, *optional*, defaults to None): + Data type for DiT model parameters. If None, use the default parameter data type from config + dit_path (`str`, *optional*, defaults to None): + Path to DiT model checkpoint. checkpoint_dir is used if None. + dit_attn_mode (`str`, *optional*, defaults to None): + Attention mode for DiT model. If None, use "torch" attention mode. + t5_path (`str`, *optional*, defaults to None): + Path to T5 model checkpoint. checkpoint_dir is used if None. + clip_path (`str`, *optional*, defaults to None): + Path to CLIP model checkpoint. checkpoint_dir is used if None. + t5_fp8 (`bool`, *optional*, defaults to False): + Enable FP8 quantization for T5 model + """ + self.device = device if device is not None else torch.device("cuda") + self.config = config + self.rank = rank + self.t5_cpu = t5_cpu + self.t5_fp8 = t5_fp8 + + self.num_train_timesteps = config.num_train_timesteps + self.param_dtype = config.param_dtype + + # shard_fn = partial(shard_model, device_id=device_id) + checkpoint_path = None if checkpoint_dir is None else os.path.join(checkpoint_dir, config.t5_checkpoint) + tokenizer_path = None if checkpoint_dir is None else os.path.join(checkpoint_dir, config.t5_tokenizer) + self.text_encoder = T5EncoderModel( + text_len=config.text_len, + dtype=config.t5_dtype, + device=device, + checkpoint_path=checkpoint_path, + tokenizer_path=tokenizer_path, + weight_path=t5_path, + fp8=t5_fp8, + # shard_fn=shard_fn if t5_fsdp else None, + ) + + self.vae_stride = config.vae_stride + self.patch_size = config.patch_size + + self.checkpoint_dir = checkpoint_dir + self.dit_path = dit_path + self.dit_dtype = dit_dtype # if dit_dtype is not None else config.param_dtype + self.dit_weight_dtype = dit_weight_dtype + self.dit_attn_mode = dit_attn_mode + self.clip_path = clip_path + + self.sample_neg_prompt = config.sample_neg_prompt + + def generate( + self, + accelerator: Accelerator, + merge_lora: Optional[callable], + fp8_scaled: bool, + input_prompt, + img, + size=(1280, 720), + frame_num=81, + shift=5.0, + sample_solver="unipc", + sampling_steps=40, + guide_scale=5.0, + n_prompt="", + seed=-1, + blocks_to_swap=0, + vae: WanVAE = None, + ): + r""" + Generates video frames from input image and text prompt using diffusion process. + + Args: + input_prompt (`str`): + Text prompt for content generation. + img (PIL.Image.Image): + Input image tensor. Shape: [3, H, W] + max_area (`int`, *optional*, defaults to 720*1280): + Maximum pixel area for latent space calculation. Controls video resolution scaling + frame_num (`int`, *optional*, defaults to 81): + How many frames to sample from a video. The number should be 4n+1 + shift (`float`, *optional*, defaults to 5.0): + Noise schedule shift parameter. Affects temporal dynamics + [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0. + sample_solver (`str`, *optional*, defaults to 'unipc'): + Solver used to sample the video. + sampling_steps (`int`, *optional*, defaults to 40): + Number of diffusion sampling steps. Higher values improve quality but slow generation + guide_scale (`float`, *optional*, defaults 5.0): + Classifier-free guidance scale. Controls prompt adherence vs. creativity + n_prompt (`str`, *optional*, defaults to ""): + Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` + seed (`int`, *optional*, defaults to -1): + Random seed for noise generation. If -1, use random seed. 
+ blocks_to_swap (`int`, *optional*, defaults to 0): + Number of blocks to swap (offload) to CPU. If 0, no blocks are offloaded. + + Returns: + torch.Tensor: + Generated video frames tensor. Dimensions: (C, N, H, W) where: + - C: Color channels (3 for RGB) + - N: Number of frames (81) + - H: Frame height (from size) + - W: Frame width (from size) + """ + max_area = size[0] * size[1] + + # save original image as numpy array + img_cv2 = np.array(img) # PIL to numpy + img_cv2 = cv2.cvtColor(img_cv2, cv2.COLOR_BGR2RGB) + + img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device) # -1 to 1 + + F = frame_num # number of frames + h, w = img.shape[1:] + aspect_ratio = h / w + lat_h = round(np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] // self.patch_size[1] * self.patch_size[1]) + lat_w = round(np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] // self.patch_size[2] * self.patch_size[2]) + h = lat_h * self.vae_stride[1] + w = lat_w * self.vae_stride[2] + lat_f = (F - 1) // self.vae_stride[0] + 1 # size of latent frames + max_seq_len = lat_f * lat_h * lat_w // (self.patch_size[1] * self.patch_size[2]) + + # set seed + seed = seed if seed >= 0 else random.randint(0, sys.maxsize) + seed_g = torch.Generator(device=self.device) + seed_g.manual_seed(seed) + + # Generate noise for the required number of frames only + noise = torch.randn(16, lat_f, lat_h, lat_w, dtype=torch.float32, generator=seed_g, device=self.device) + + if n_prompt == "": + n_prompt = self.sample_neg_prompt + + # preprocess + self.text_encoder.model.to(self.device) + with torch.no_grad(): + if self.t5_fp8: + with accelerator.autocast(): + context = self.text_encoder([input_prompt], self.device) + context_null = self.text_encoder([n_prompt], self.device) + else: + context = self.text_encoder([input_prompt], self.device) + context_null = self.text_encoder([n_prompt], self.device) + + del self.text_encoder + clean_memory_on_device(self.device) + + # load CLIP model + checkpoint_path = None if self.checkpoint_dir is None else os.path.join(self.checkpoint_dir, self.config.clip_checkpoint) + tokenizer_path = None if self.checkpoint_dir is None else os.path.join(self.checkpoint_dir, self.config.clip_tokenizer) + clip = CLIPModel( + dtype=self.config.clip_dtype, + device=self.device, + checkpoint_path=checkpoint_path, + tokenizer_path=tokenizer_path, + weight_path=self.clip_path, + ) + + clip.model.to(self.device) + logger.info(f"Encoding image to CLIP context") + # use torch.amp.autocast instead of accelerator.autocast, because CLIP dtype is not bfloat16 + with torch.amp.autocast(device_type=self.device.type, dtype=torch.float16), torch.no_grad(): + clip_context = clip.visual([img[:, None, :, :]]) + logger.info(f"Encoding complete") + + del clip + clean_memory_on_device(self.device) + + # y should be encoded with 81 frames, and trim to lat_f frames? encoding F frames causes invalid results? + logger.info(f"Encoding image to latent space") + vae.to_device(self.device) + + # resize image for the first frame.
INTER_AREA is the best for downsampling + interpolation = cv2.INTER_AREA if h < img_cv2.shape[0] else cv2.INTER_CUBIC + img_resized = cv2.resize(img_cv2, (w, h), interpolation=interpolation) + img_resized = cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB) + img_resized = TF.to_tensor(img_resized).sub_(0.5).div_(0.5).to(self.device) # -1 to 1, CHW + img_resized = img_resized.unsqueeze(1) # CFHW + + # Create mask for the required number of frames + msk = torch.ones(1, F, lat_h, lat_w, device=self.device) + msk[:, 1:] = 0 + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) + msk = msk.transpose(1, 2)[0] + + with accelerator.autocast(), torch.no_grad(): + # Zero padding for the required number of frames only + padding_frames = F - 1 # The first frame is the input image + img_resized = torch.concat([img_resized, torch.zeros(3, padding_frames, h, w, device=self.device)], dim=1) + y = vae.encode([img_resized])[0] + + y = y[:, :lat_f] # may be not needed + y = torch.concat([msk, y]) + logger.info(f"Encoding complete") + + vae.to_device("cpu") + clean_memory_on_device(self.device) + + # load DiT model + loading_device = "cpu" + if blocks_to_swap == 0 and merge_lora is None and not fp8_scaled: + loading_device = self.device + + loading_weight_dtype = self.dit_weight_dtype + if fp8_scaled or merge_lora is not None: + loading_weight_dtype = self.dit_dtype # load as-is + + # set fp8_scaled to False, because we optimize the model after merging LoRA + # TODO state dict based LoRA merge + self.model: WanModel = load_wan_model( + self.config, + True, + self.device, + self.dit_path, + self.dit_attn_mode, + False, + loading_device, + loading_weight_dtype, + False, + ) + + if merge_lora is not None: + # merge LoRA to the model, cast and move to the device + merge_lora(self.model) + + if fp8_scaled: + state_dict = self.model.state_dict() + move_to_device = blocks_to_swap == 0 # if blocks_to_swap > 0, we will keep the model on CPU + state_dict = self.model.fp8_optimization(state_dict, self.device, move_to_device) + info = self.model.load_state_dict(state_dict, strict=True, assign=True) + logger.info(f"Loaded FP8 optimized weights: {info}") + if blocks_to_swap == 0: + self.model.to(self.device) # make sure all parameters are on the right device + else: + target_dtype = None + target_device = None + if self.dit_weight_dtype is not None: # in case of args.fp8 (not fp8_scaled) + logger.info(f"Convert model to {self.dit_weight_dtype}") + target_dtype = self.dit_weight_dtype + if blocks_to_swap == 0: + logger.info(f"Move model to device: {self.device}") + target_device = self.device + self.model.to(target_device, target_dtype) + + if blocks_to_swap > 0: + logger.info(f"Enable swap {blocks_to_swap} blocks to CPU from device: {self.device}") + self.model.enable_block_swap(blocks_to_swap, self.device, supports_backward=False) + self.model.move_to_device_except_swap_blocks(self.device) + self.model.prepare_block_swap_before_forward() + else: + # make sure the model is on the right device + self.model.to(self.device) + + self.model.eval().requires_grad_(False) + clean_memory_on_device(self.device) + + # evaluation mode + with torch.no_grad(): + + if sample_solver == "unipc": + sample_scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, shift=1, use_dynamic_shifting=False + ) + sample_scheduler.set_timesteps(sampling_steps, device=self.device, shift=shift) + timesteps = 
sample_scheduler.timesteps + elif sample_solver == "dpm++": + sample_scheduler = FlowDPMSolverMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, shift=1, use_dynamic_shifting=False + ) + sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) + timesteps, _ = retrieve_timesteps(sample_scheduler, device=self.device, sigmas=sampling_sigmas) + elif sample_solver == "vanilla": + sample_scheduler = FlowMatchDiscreteScheduler(num_train_timesteps=self.num_train_timesteps, shift=shift) + sample_scheduler.set_timesteps(sampling_steps, device=self.device) + timesteps = sample_scheduler.timesteps + + org_step = sample_scheduler.step + + def step_wrapper( + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor, + return_dict: bool = True, + generator=None, + ): + return org_step(model_output, timestep, sample, return_dict=return_dict) + + sample_scheduler.step = step_wrapper + else: + raise NotImplementedError("Unsupported solver.") + + # sample videos + latent = noise # on device + del noise + + arg_c = { + "context": [context[0]], + "clip_fea": clip_context, + "seq_len": max_seq_len, + "y": [y], + } + + arg_null = { + "context": context_null, + "clip_fea": clip_context, + "seq_len": max_seq_len, + "y": [y], + } + + # self.model.to(self.device) + for _, t in enumerate(tqdm(timesteps)): + latent_model_input = [latent.to(self.device)] + latent = latent.to("cpu") + timestep = [t] + + timestep = torch.stack(timestep).to(self.device) + + with accelerator.autocast(): + noise_pred_cond = self.model(latent_model_input, t=timestep, **arg_c)[0].to("cpu") + noise_pred_uncond = self.model(latent_model_input, t=timestep, **arg_null)[0].to("cpu") + + latent_model_input = None + noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond) + + temp_x0 = sample_scheduler.step( + noise_pred.unsqueeze(0), t, latent.unsqueeze(0), return_dict=False, generator=seed_g + )[0] + latent = temp_x0.squeeze(0) + + # x0 = [latent.to(self.device)] + del latent_model_input, timestep + + del sample_scheduler + del self.model + synchronize_device(self.device) + clean_memory_on_device(self.device) + return latent diff --git a/wan/modules/__init__.py b/wan/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f8935bbb45ab4e3f349d203b673102f7cfc07553 --- /dev/null +++ b/wan/modules/__init__.py @@ -0,0 +1,16 @@ +from .attention import flash_attention +from .model import WanModel +from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model +from .tokenizers import HuggingfaceTokenizer +from .vae import WanVAE + +__all__ = [ + 'WanVAE', + 'WanModel', + 'T5Model', + 'T5Encoder', + 'T5Decoder', + 'T5EncoderModel', + 'HuggingfaceTokenizer', + 'flash_attention', +] diff --git a/wan/modules/attention.py b/wan/modules/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..7653f7c7c1ceee172f6fd32686fa038dff3472dc --- /dev/null +++ b/wan/modules/attention.py @@ -0,0 +1,312 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
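As a reading aid for the conditioning path in `WanI2V.generate` above: the image branch contributes a 4-channel temporal mask (first frame repeated four times, then regrouped) plus the 16-channel VAE latent of the zero-padded input clip, and this is concatenated with the 16-channel noise latent being denoised, which is where the I2V configs' `in_dim = 36` comes from. The check below is illustrative only, not code from the diff.

```python
# Illustrative channel-budget check for WanI2V.generate (not part of the diff).
frames = 81
lat_f = (frames - 1) // 4 + 1        # 21 latent frames (vae_stride[0] == 4)

mask_channels = 4                    # msk: (1, F+3, h, w) regrouped to (4, lat_f, h, w)
image_latent_channels = 16           # y: VAE latent of first frame plus zero padding
noise_latent_channels = 16           # latent the DiT denoises

assert mask_channels + image_latent_channels + noise_latent_channels == 36  # I2V in_dim
```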
+from typing import Optional +import torch + +try: + import flash_attn_interface + + FLASH_ATTN_3_AVAILABLE = True +except ModuleNotFoundError: + FLASH_ATTN_3_AVAILABLE = False + +try: + import flash_attn + + FLASH_ATTN_2_AVAILABLE = True +except ModuleNotFoundError: + FLASH_ATTN_2_AVAILABLE = False + +try: + import sageattention + + SAGE_ATTN_AVAILABLE = True +except ModuleNotFoundError: + SAGE_ATTN_AVAILABLE = False + +try: + import xformers.ops as xops + + XFORMERS_AVAILABLE = True +except ImportError: + XFORMERS_AVAILABLE = False + + +import warnings + +__all__ = [ + "flash_attention", + "attention", +] + + +def flash_attention( + qkv, + q_lens=None, + k_lens=None, + dropout_p=0.0, + softmax_scale=None, + q_scale=None, + causal=False, + window_size=(-1, -1), + deterministic=False, + dtype=torch.bfloat16, + version=None, + attn_mode: Optional[str] = "torch", + split_attn: bool = False, +): + """ + qkv: list containing [q, k, v]; it is cleared in place after unpacking. + q: [B, Lq, Nq, C1]. + k: [B, Lk, Nk, C1]. + v: [B, Lk, Nk, C2]. Nq must be divisible by Nk. + q_lens: [B]. + k_lens: [B]. + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + causal: bool. Whether to apply causal attention mask. + window_size: (left, right). If not (-1, -1), apply sliding window local attention. + deterministic: bool. If True, slightly slower and uses more memory. + dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16. + """ + q, k, v = qkv + qkv.clear() + + half_dtypes = (torch.float16, torch.bfloat16) + assert dtype in half_dtypes + # assert q.device.type == "cuda" and q.size(-1) <= 256 + + # params + b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype + + def half(x): + return x if x.dtype in half_dtypes else x.to(dtype) + + # We cannot test Flash attention 3 in musubi tuner, so keep the original code. + # Customized code (except for flash attention 3) does not support q_lens and k_lens. + if attn_mode != "flash3" and attn_mode != "sageattn": + assert q_lens is None, "q_lens is not supported except for flash attention 3." + assert k_lens is None or ( + min(k_lens) == max(k_lens) and k_lens[0] == lk + ), "k_lens is not supported except for flash attention 3." + + # SDPA + if attn_mode == "torch" or attn_mode == "sdpa": + assert not deterministic, "deterministic is not supported in scaled_dot_product_attention."
+ if q_scale is not None: + q = q * q_scale + q = half(q.transpose(1, 2)) + k = half(k.transpose(1, 2)) + v = half(v.transpose(1, 2)) + + if not split_attn: + q = torch.nn.functional.scaled_dot_product_attention( + q, k, v, is_causal=causal, dropout_p=dropout_p, scale=softmax_scale + ) + x = q + else: + x = torch.empty_like(q) + for i in range(q.size(0)): + x[i : i + 1] = torch.nn.functional.scaled_dot_product_attention( + q[i : i + 1], k[i : i + 1], v[i : i + 1], is_causal=causal, dropout_p=dropout_p, scale=softmax_scale + ) + + del q, k, v + x = x.transpose(1, 2).contiguous() + return x.type(out_dtype) + + # flash attention 2 + if attn_mode == "flash" or attn_mode == "flash2": + if q_scale is not None: + q = q * q_scale + q = half(q) + k = half(k) + v = half(v) + + if not split_attn: + q = flash_attn.flash_attn_func(q, k, v, dropout_p, softmax_scale, causal, window_size, deterministic=deterministic) + x = q + else: + x = torch.empty_like(q) + for i in range(q.size(0)): + x[i : i + 1] = flash_attn.flash_attn_func( + q[i : i + 1], + k[i : i + 1], + v[i : i + 1], + dropout_p, + softmax_scale, + causal, + window_size, + deterministic=deterministic, + ) + del q, k, v + return x.type(out_dtype) + + # xformers + if attn_mode == "xformers": + assert not deterministic, "deterministic is not supported in xformers." + assert not causal, "causal is not supported in xformers." + if q_scale is not None: + q = q * q_scale + q = half(q) + k = half(k) + v = half(v) + + if not split_attn: + q = xops.memory_efficient_attention(q, k, v, p=dropout_p, scale=softmax_scale) + x = q + else: + x = torch.empty_like(q) + for i in range(q.size(0)): + x[i : i + 1] = xops.memory_efficient_attention( + q[i : i + 1], k[i : i + 1], v[i : i + 1], p=dropout_p, scale=softmax_scale + ) + + del q, k, v + return x.type(out_dtype) + + # sage attention with fixed length seems to cause NaN in I2V inference. + # # sage attention + # if attn_mode == "sageattn": + # print("Using sage attention") + # assert not deterministic, "deterministic is not supported in sage attention." + # if q_scale is not None: + # q = q * q_scale + # q, k, v = half(q), half(k), half(v) + # x = sageattention.sageattn(q, k, v, "NHD", is_causal=causal, sm_scale=softmax_scale) + # del q, k, v + # return x.type(out_dtype) + + assert not split_attn, "split_attn is not supported in flash attention 3 or sage attention." + + # preprocess query: in Wan 2.1, q_lens is always None. + if q_lens is None: + q = half(q.flatten(0, 1)) + q_lens = torch.tensor([lq] * b, dtype=torch.int32).to(device=q.device, non_blocking=True) + else: + q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)])) + + # preprocess key, value + if k_lens is None: + k = half(k.flatten(0, 1)) + v = half(v.flatten(0, 1)) + k_lens = torch.tensor([lk] * b, dtype=torch.int32).to(device=k.device, non_blocking=True) + else: + # Note: in Wan 2.1, all k_lens are same if we have same image size in the batch. 
+ if min(k_lens) == max(k_lens) and k.shape[1] == k_lens[0]: + # B, L, N, C -> BN, L, C + k = half(k.flatten(0, 1)) + v = half(v.flatten(0, 1)) + else: + k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)])) + v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)])) + + q = q.to(v.dtype) + k = k.to(v.dtype) + + if q_scale is not None: + q = q * q_scale + + # if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE: + # warnings.warn("Flash attention 3 is not available, use flash attention 2 instead.") + + # apply attention + # if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE: + if attn_mode == "flash3": + # Not tested yet in musubi tuner. + # Note: dropout_p, window_size are not supported in FA3 now. + x = flash_attn_interface.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True), + cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True), + seqused_q=None, + seqused_k=None, + max_seqlen_q=lq, + max_seqlen_k=lk, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + )[0].unflatten(0, (b, lq)) + # elif (version is None or version == 2) and FLASH_ATTN_2_AVAILABLE: + # # assert FLASH_ATTN_2_AVAILABLE + # x = flash_attn.flash_attn_varlen_func( + # q=q, + # k=k, + # v=v, + # cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True), + # cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True), + # max_seqlen_q=lq, + # max_seqlen_k=lk, + # dropout_p=dropout_p, + # softmax_scale=softmax_scale, + # causal=causal, + # window_size=window_size, + # deterministic=deterministic, + # ).unflatten(0, (b, lq)) + # elif version is None and SAGE_ATTN_AVAILABLE: + elif attn_mode == "sageattn": + # print("Using sage attention") + assert not causal, "SAGE attention does not support causal attention." + x = sageattention.sageattn_varlen( + q=q, + k=k, + v=v, + cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True), + cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True), + max_seqlen_q=lq, + max_seqlen_k=lk, + sm_scale=softmax_scale, + ).unflatten(0, (b, lq)) + else: + raise ValueError(f"Unknown attention mode: {attn_mode}") + + # output + return x.type(out_dtype) + + +def attention( + q, + k, + v, + q_lens=None, + k_lens=None, + dropout_p=0.0, + softmax_scale=None, + q_scale=None, + causal=False, + window_size=(-1, -1), + deterministic=False, + dtype=torch.bfloat16, + fa_version=None, +): + if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE: + return flash_attention( + q=q, + k=k, + v=v, + q_lens=q_lens, + k_lens=k_lens, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + q_scale=q_scale, + causal=causal, + window_size=window_size, + deterministic=deterministic, + dtype=dtype, + version=fa_version, + ) + else: + if q_lens is not None or k_lens is not None: + warnings.warn( + "Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance." 
+ ) + attn_mask = None + + q = q.transpose(1, 2).to(dtype) + k = k.transpose(1, 2).to(dtype) + v = v.transpose(1, 2).to(dtype) + + out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p) + + out = out.transpose(1, 2).contiguous() + return out diff --git a/wan/modules/clip.py b/wan/modules/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..2fbd867678e1d75d402583c91ea97bba74194c52 --- /dev/null +++ b/wan/modules/clip.py @@ -0,0 +1,546 @@ +# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip'' +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import logging +import math +import os + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms as T +from accelerate import init_empty_weights + +from .attention import flash_attention +from .tokenizers import HuggingfaceTokenizer +from .xlm_roberta import XLMRoberta + +from utils.safetensors_utils import load_safetensors + +__all__ = [ + "XLMRobertaCLIP", + "clip_xlm_roberta_vit_h_14", + "CLIPModel", +] + + +def pos_interpolate(pos, seq_len): + if pos.size(1) == seq_len: + return pos + else: + src_grid = int(math.sqrt(pos.size(1))) + tar_grid = int(math.sqrt(seq_len)) + n = pos.size(1) - src_grid * src_grid + return torch.cat( + [ + pos[:, :n], + F.interpolate( + pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(0, 3, 1, 2), + size=(tar_grid, tar_grid), + mode="bicubic", + align_corners=False, + ) + .flatten(2) + .transpose(1, 2), + ], + dim=1, + ) + + +class QuickGELU(nn.Module): + + def forward(self, x): + return x * torch.sigmoid(1.702 * x) + + +class LayerNorm(nn.LayerNorm): + + def forward(self, x): + return super().forward(x.float()).type_as(x) + + +class SelfAttention(nn.Module): + + def __init__(self, dim, num_heads, causal=False, attn_dropout=0.0, proj_dropout=0.0): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.causal = causal + self.attn_dropout = attn_dropout + self.proj_dropout = proj_dropout + + # layers + self.to_qkv = nn.Linear(dim, dim * 3) + self.proj = nn.Linear(dim, dim) + + def forward(self, x): + """ + x: [B, L, C]. 
+ """ + b, s, c, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2) + + # compute attention + p = self.attn_dropout if self.training else 0.0 + # x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2) + # print(q.shape, k.shape, v.shape) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + x = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=p, is_causal=self.causal) + # print(x.shape) + x = x.transpose(1, 2).contiguous() + x = x.reshape(b, s, c) + + # output + x = self.proj(x) + x = F.dropout(x, self.proj_dropout, self.training) + return x + + +class SwiGLU(nn.Module): + + def __init__(self, dim, mid_dim): + super().__init__() + self.dim = dim + self.mid_dim = mid_dim + + # layers + self.fc1 = nn.Linear(dim, mid_dim) + self.fc2 = nn.Linear(dim, mid_dim) + self.fc3 = nn.Linear(mid_dim, dim) + + def forward(self, x): + x = F.silu(self.fc1(x)) * self.fc2(x) + x = self.fc3(x) + return x + + +class AttentionBlock(nn.Module): + + def __init__( + self, + dim, + mlp_ratio, + num_heads, + post_norm=False, + causal=False, + activation="quick_gelu", + attn_dropout=0.0, + proj_dropout=0.0, + norm_eps=1e-5, + ): + assert activation in ["quick_gelu", "gelu", "swi_glu"] + super().__init__() + self.dim = dim + self.mlp_ratio = mlp_ratio + self.num_heads = num_heads + self.post_norm = post_norm + self.causal = causal + self.norm_eps = norm_eps + + # layers + self.norm1 = LayerNorm(dim, eps=norm_eps) + self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, proj_dropout) + self.norm2 = LayerNorm(dim, eps=norm_eps) + if activation == "swi_glu": + self.mlp = SwiGLU(dim, int(dim * mlp_ratio)) + else: + self.mlp = nn.Sequential( + nn.Linear(dim, int(dim * mlp_ratio)), + QuickGELU() if activation == "quick_gelu" else nn.GELU(), + nn.Linear(int(dim * mlp_ratio), dim), + nn.Dropout(proj_dropout), + ) + + def forward(self, x): + if self.post_norm: + x = x + self.norm1(self.attn(x)) + x = x + self.norm2(self.mlp(x)) + else: + x = x + self.attn(self.norm1(x)) + x = x + self.mlp(self.norm2(x)) + return x + + +class AttentionPool(nn.Module): + + def __init__(self, dim, mlp_ratio, num_heads, activation="gelu", proj_dropout=0.0, norm_eps=1e-5): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.mlp_ratio = mlp_ratio + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.proj_dropout = proj_dropout + self.norm_eps = norm_eps + + # layers + gain = 1.0 / math.sqrt(dim) + self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim)) + self.to_q = nn.Linear(dim, dim) + self.to_kv = nn.Linear(dim, dim * 2) + self.proj = nn.Linear(dim, dim) + self.norm = LayerNorm(dim, eps=norm_eps) + self.mlp = nn.Sequential( + nn.Linear(dim, int(dim * mlp_ratio)), + QuickGELU() if activation == "quick_gelu" else nn.GELU(), + nn.Linear(int(dim * mlp_ratio), dim), + nn.Dropout(proj_dropout), + ) + + def forward(self, x): + """ + x: [B, L, C]. 
+ """ + b, s, c, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1) + k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2) + + # compute attention + # this line is never used because pool_type="token" in Wan2.1 + x = flash_attention(q, k, v, version=2) + x = x.reshape(b, 1, c) + + # output + x = self.proj(x) + x = F.dropout(x, self.proj_dropout, self.training) + + # mlp + x = x + self.mlp(self.norm(x)) + return x[:, 0] + + +class VisionTransformer(nn.Module): + + def __init__( + self, + image_size=224, + patch_size=16, + dim=768, + mlp_ratio=4, + out_dim=512, + num_heads=12, + num_layers=12, + pool_type="token", + pre_norm=True, + post_norm=False, + activation="quick_gelu", + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0, + norm_eps=1e-5, + ): + if image_size % patch_size != 0: + print("[WARNING] image_size is not divisible by patch_size", flush=True) + assert pool_type in ("token", "token_fc", "attn_pool") + out_dim = out_dim or dim + super().__init__() + self.image_size = image_size + self.patch_size = patch_size + self.num_patches = (image_size // patch_size) ** 2 + self.dim = dim + self.mlp_ratio = mlp_ratio + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.pool_type = pool_type + self.post_norm = post_norm + self.norm_eps = norm_eps + + # embeddings + gain = 1.0 / math.sqrt(dim) + self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size, bias=not pre_norm) + if pool_type in ("token", "token_fc"): + self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim)) + self.pos_embedding = nn.Parameter( + gain * torch.randn(1, self.num_patches + (1 if pool_type in ("token", "token_fc") else 0), dim) + ) + self.dropout = nn.Dropout(embedding_dropout) + + # transformer + self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None + self.transformer = nn.Sequential( + *[ + AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False, activation, attn_dropout, proj_dropout, norm_eps) + for _ in range(num_layers) + ] + ) + self.post_norm = LayerNorm(dim, eps=norm_eps) + + # head + if pool_type == "token": + self.head = nn.Parameter(gain * torch.randn(dim, out_dim)) + elif pool_type == "token_fc": + self.head = nn.Linear(dim, out_dim) + elif pool_type == "attn_pool": + self.head = AttentionPool(dim, mlp_ratio, num_heads, activation, proj_dropout, norm_eps) + + def forward(self, x, interpolation=False, use_31_block=False): + b = x.size(0) + + # embeddings + x = self.patch_embedding(x).flatten(2).permute(0, 2, 1) + if self.pool_type in ("token", "token_fc"): + x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1) + if interpolation: + e = pos_interpolate(self.pos_embedding, x.size(1)) + else: + e = self.pos_embedding + x = self.dropout(x + e) + if self.pre_norm is not None: + x = self.pre_norm(x) + + # transformer + if use_31_block: + x = self.transformer[:-1](x) + return x + else: + x = self.transformer(x) + return x + + +class XLMRobertaWithHead(XLMRoberta): + + def __init__(self, **kwargs): + self.out_dim = kwargs.pop("out_dim") + super().__init__(**kwargs) + + # head + mid_dim = (self.dim + self.out_dim) // 2 + self.head = nn.Sequential(nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(), nn.Linear(mid_dim, self.out_dim, bias=False)) + + def forward(self, ids): + # xlm-roberta + x = super().forward(ids) + + # average pooling + mask = ids.ne(self.pad_id).unsqueeze(-1).to(x) + x = (x * 
mask).sum(dim=1) / mask.sum(dim=1) + + # head + x = self.head(x) + return x + + +class XLMRobertaCLIP(nn.Module): + + def __init__( + self, + embed_dim=1024, + image_size=224, + patch_size=14, + vision_dim=1280, + vision_mlp_ratio=4, + vision_heads=16, + vision_layers=32, + vision_pool="token", + vision_pre_norm=True, + vision_post_norm=False, + activation="gelu", + vocab_size=250002, + max_text_len=514, + type_size=1, + pad_id=1, + text_dim=1024, + text_heads=16, + text_layers=24, + text_post_norm=True, + text_dropout=0.1, + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0, + norm_eps=1e-5, + ): + super().__init__() + self.embed_dim = embed_dim + self.image_size = image_size + self.patch_size = patch_size + self.vision_dim = vision_dim + self.vision_mlp_ratio = vision_mlp_ratio + self.vision_heads = vision_heads + self.vision_layers = vision_layers + self.vision_pre_norm = vision_pre_norm + self.vision_post_norm = vision_post_norm + self.activation = activation + self.vocab_size = vocab_size + self.max_text_len = max_text_len + self.type_size = type_size + self.pad_id = pad_id + self.text_dim = text_dim + self.text_heads = text_heads + self.text_layers = text_layers + self.text_post_norm = text_post_norm + self.norm_eps = norm_eps + + # models + self.visual = VisionTransformer( + image_size=image_size, + patch_size=patch_size, + dim=vision_dim, + mlp_ratio=vision_mlp_ratio, + out_dim=embed_dim, + num_heads=vision_heads, + num_layers=vision_layers, + pool_type=vision_pool, + pre_norm=vision_pre_norm, + post_norm=vision_post_norm, + activation=activation, + attn_dropout=attn_dropout, + proj_dropout=proj_dropout, + embedding_dropout=embedding_dropout, + norm_eps=norm_eps, + ) + self.textual = XLMRobertaWithHead( + vocab_size=vocab_size, + max_seq_len=max_text_len, + type_size=type_size, + pad_id=pad_id, + dim=text_dim, + out_dim=embed_dim, + num_heads=text_heads, + num_layers=text_layers, + post_norm=text_post_norm, + dropout=text_dropout, + ) + self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([])) + + def forward(self, imgs, txt_ids): + """ + imgs: [B, 3, H, W] of torch.float32. + - mean: [0.48145466, 0.4578275, 0.40821073] + - std: [0.26862954, 0.26130258, 0.27577711] + txt_ids: [B, L] of torch.long. + Encoded by data.CLIPTokenizer. 
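+        Returns: a tuple (xi, xt) of visual and textual features for the batch.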
+ """ + xi = self.visual(imgs) + xt = self.textual(txt_ids) + return xi, xt + + def param_groups(self): + groups = [ + {"params": [p for n, p in self.named_parameters() if "norm" in n or n.endswith("bias")], "weight_decay": 0.0}, + {"params": [p for n, p in self.named_parameters() if not ("norm" in n or n.endswith("bias"))]}, + ] + return groups + + +def _clip( + pretrained=False, + pretrained_name=None, + model_cls=XLMRobertaCLIP, + return_transforms=False, + return_tokenizer=False, + tokenizer_padding="eos", + dtype=torch.float32, + device="cpu", + **kwargs, +): + # # init a model on device + # with torch.device(device): + model = model_cls(**kwargs) + + # # set device + # model = model.to(dtype=dtype, device=device) + output = (model,) + + # init transforms + if return_transforms: + # mean and std + if "siglip" in pretrained_name.lower(): + mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] + else: + mean = [0.48145466, 0.4578275, 0.40821073] + std = [0.26862954, 0.26130258, 0.27577711] + + # transforms + transforms = T.Compose( + [ + T.Resize((model.image_size, model.image_size), interpolation=T.InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(mean=mean, std=std), + ] + ) + output += (transforms,) + return output[0] if len(output) == 1 else output + + +def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-roberta-large-vit-huge-14", **kwargs): + cfg = dict( + embed_dim=1024, + image_size=224, + patch_size=14, + vision_dim=1280, + vision_mlp_ratio=4, + vision_heads=16, + vision_layers=32, + vision_pool="token", + activation="gelu", + vocab_size=250002, + max_text_len=514, + type_size=1, + pad_id=1, + text_dim=1024, + text_heads=16, + text_layers=24, + text_post_norm=True, + text_dropout=0.1, + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0, + ) + cfg.update(**kwargs) + return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg) + + +class CLIPModel: + + def __init__(self, dtype, device, checkpoint_path=None, tokenizer_path=None, weight_path=None): + self.dtype = dtype + self.device = device + self.checkpoint_path = checkpoint_path + self.tokenizer_path = tokenizer_path + self.weight_path = weight_path + + # init model + with init_empty_weights(): + self.model, self.transforms = clip_xlm_roberta_vit_h_14( + pretrained=False, return_transforms=True, return_tokenizer=False, dtype=dtype, device=device + ) + self.model = self.model.eval().requires_grad_(False) + + logging.info(f"loading {weight_path}") + if os.path.splitext(weight_path)[-1] == ".safetensors": + sd = load_safetensors(weight_path, device=device, disable_mmap=True, dtype=dtype) + else: + sd = torch.load(weight_path, map_location=device, weights_only=True) + info = self.model.load_state_dict(sd, strict=True, assign=True) + self.model = self.model.to(dtype=dtype, device=device) + logging.info(f"weights loaded from {weight_path}: {info}") + + # init tokenizer + if tokenizer_path is None: + tokenizer_path = "Wan-AI/Wan2.1-I2V-14B-720P" + subfolder = "xlm-roberta-large" + else: + subfolder = None + + self.tokenizer = HuggingfaceTokenizer( + name=tokenizer_path, seq_len=self.model.max_text_len - 2, clean="whitespace", subfolder=subfolder + ) + + def visual(self, videos): + # preprocess + size = (self.model.image_size,) * 2 + videos = torch.cat([F.interpolate(u.transpose(0, 1), size=size, mode="bicubic", align_corners=False) for u in videos]) + videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5)) + + # forward + # with torch.cuda.amp.autocast(dtype=self.dtype): + out = 
self.model.visual(videos, use_31_block=True) + return out diff --git a/wan/modules/model.py b/wan/modules/model.py new file mode 100644 index 0000000000000000000000000000000000000000..0462c75da80091499ba65d34106e121c869d810d --- /dev/null +++ b/wan/modules/model.py @@ -0,0 +1,1024 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import math +from typing import Optional, Union + +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint +from accelerate import init_empty_weights + +import logging + +from utils.safetensors_utils import MemoryEfficientSafeOpen, load_safetensors + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + +from utils.device_utils import clean_memory_on_device + +from .attention import flash_attention +from utils.device_utils import clean_memory_on_device +from modules.custom_offloading_utils import ModelOffloader +from modules.fp8_optimization_utils import apply_fp8_monkey_patch, optimize_state_dict_with_fp8 + +__all__ = ["WanModel"] + + +def sinusoidal_embedding_1d(dim, position): + # preprocess + assert dim % 2 == 0 + half = dim // 2 + position = position.type(torch.float64) + + # calculation + sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half).to(position).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + return x + + +# @amp.autocast(enabled=False) +# no autocast is needed for rope_apply, because it is already in float64 +def rope_params(max_seq_len, dim, theta=10000): + assert dim % 2 == 0 + freqs = torch.outer(torch.arange(max_seq_len), 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim))) + freqs = torch.polar(torch.ones_like(freqs), freqs) + return freqs + + +# @amp.autocast(enabled=False) +def rope_apply(x, grid_sizes, freqs): + device_type = x.device.type + with torch.amp.autocast(device_type=device_type, enabled=False): + n, c = x.size(2), x.size(3) // 2 + + # split freqs + freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + + # loop over samples + output = [] + for i, (f, h, w) in enumerate(grid_sizes.tolist()): + seq_len = f * h * w + + # precompute multipliers + x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(seq_len, n, -1, 2)) + freqs_i = torch.cat( + [ + freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1), + ], + dim=-1, + ).reshape(seq_len, 1, -1) + + # apply rotary embedding + x_i = torch.view_as_real(x_i * freqs_i).flatten(2) + x_i = torch.cat([x_i, x[i, seq_len:]]) + + # append to collection + output.append(x_i) + return torch.stack(output).float() + + +def calculate_freqs_i(fhw, c, freqs): + f, h, w = fhw + freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + freqs_i = torch.cat( + [ + freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1), + ], + dim=-1, + ).reshape(f * h * w, 1, -1) + return freqs_i + + +# inplace version of rope_apply +def rope_apply_inplace_cached(x, grid_sizes, freqs_list): + # with torch.amp.autocast(device_type=device_type, enabled=False): + rope_dtype = torch.float64 # float32 does not reduce memory usage significantly + + n, c = x.size(2), x.size(3) // 2 + + # loop over samples + for i, (f, h, w) in enumerate(grid_sizes.tolist()): + seq_len = f * h * w + + # precompute multipliers + x_i = torch.view_as_complex(x[i, 
:seq_len].to(rope_dtype).reshape(seq_len, n, -1, 2)) + freqs_i = freqs_list[i] + + # apply rotary embedding + x_i = torch.view_as_real(x_i * freqs_i).flatten(2) + # x_i = torch.cat([x_i, x[i, seq_len:]]) + + # inplace update + x[i, :seq_len] = x_i.to(x.dtype) + + return x + + +class WanRMSNorm(nn.Module): + + def __init__(self, dim, eps=1e-5): + super().__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + # return self._norm(x.float()).type_as(x) * self.weight + # support fp8 + return self._norm(x.float()).type_as(x) * self.weight.to(x.dtype) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) + + # def forward(self, x): + # r""" + # Args: + # x(Tensor): Shape [B, L, C] + # """ + # # inplace version, also supports fp8 -> does not have significant performance improvement + # original_dtype = x.dtype + # x = x.float() + # y = x.pow(2).mean(dim=-1, keepdim=True) + # y.add_(self.eps) + # y.rsqrt_() + # x *= y + # x = x.to(original_dtype) + # x *= self.weight.to(original_dtype) + # return x + + +class WanLayerNorm(nn.LayerNorm): + + def __init__(self, dim, eps=1e-6, elementwise_affine=False): + super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) + + def forward(self, x): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + return super().forward(x.float()).type_as(x) + + +class WanSelfAttention(nn.Module): + + def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6, attn_mode="torch", split_attn=False): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.eps = eps + self.attn_mode = attn_mode + self.split_attn = split_attn + + # layers + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + + def forward(self, x, seq_lens, grid_sizes, freqs): + r""" + Args: + x(Tensor): Shape [B, L, num_heads, C / num_heads] + seq_lens(Tensor): Shape [B] + grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + """ + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + + # # query, key, value function + # def qkv_fn(x): + # q = self.norm_q(self.q(x)).view(b, s, n, d) + # k = self.norm_k(self.k(x)).view(b, s, n, d) + # v = self.v(x).view(b, s, n, d) + # return q, k, v + # q, k, v = qkv_fn(x) + # del x + # query, key, value function + + q = self.q(x) + k = self.k(x) + v = self.v(x) + del x + q = self.norm_q(q) + k = self.norm_k(k) + q = q.view(b, s, n, d) + k = k.view(b, s, n, d) + v = v.view(b, s, n, d) + + rope_apply_inplace_cached(q, grid_sizes, freqs) + rope_apply_inplace_cached(k, grid_sizes, freqs) + qkv = [q, k, v] + del q, k, v + x = flash_attention( + qkv, k_lens=seq_lens, window_size=self.window_size, attn_mode=self.attn_mode, split_attn=self.split_attn + ) + + # output + x = x.flatten(2) + x = self.o(x) + return x + + +class WanT2VCrossAttention(WanSelfAttention): + + def forward(self, x, context, context_lens): + r""" + Args: + x(Tensor): Shape [B, L1, C] + context(Tensor): Shape [B, L2, C] + context_lens(Tensor): Shape [B] + """ + b, n, d = x.size(0), self.num_heads, 
self.head_dim + + # compute query, key, value + # q = self.norm_q(self.q(x)).view(b, -1, n, d) + # k = self.norm_k(self.k(context)).view(b, -1, n, d) + # v = self.v(context).view(b, -1, n, d) + q = self.q(x) + del x + k = self.k(context) + v = self.v(context) + del context + q = self.norm_q(q) + k = self.norm_k(k) + q = q.view(b, -1, n, d) + k = k.view(b, -1, n, d) + v = v.view(b, -1, n, d) + + # compute attention + qkv = [q, k, v] + del q, k, v + x = flash_attention(qkv, k_lens=context_lens, attn_mode=self.attn_mode, split_attn=self.split_attn) + + # output + x = x.flatten(2) + x = self.o(x) + return x + + +class WanI2VCrossAttention(WanSelfAttention): + + def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6, attn_mode="torch", split_attn=False): + super().__init__(dim, num_heads, window_size, qk_norm, eps, attn_mode, split_attn) + + self.k_img = nn.Linear(dim, dim) + self.v_img = nn.Linear(dim, dim) + # self.alpha = nn.Parameter(torch.zeros((1, ))) + self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + + def forward(self, x, context, context_lens): + r""" + Args: + x(Tensor): Shape [B, L1, C] + context(Tensor): Shape [B, L2, C] + context_lens(Tensor): Shape [B] + """ + context_img = context[:, :257] + context = context[:, 257:] + b, n, d = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.q(x) + del x + q = self.norm_q(q) + q = q.view(b, -1, n, d) + k = self.k(context) + k = self.norm_k(k).view(b, -1, n, d) + v = self.v(context).view(b, -1, n, d) + del context + + # compute attention + qkv = [q, k, v] + del k, v + x = flash_attention(qkv, k_lens=context_lens, attn_mode=self.attn_mode, split_attn=self.split_attn) + + # compute query, key, value + k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d) + v_img = self.v_img(context_img).view(b, -1, n, d) + del context_img + + # compute attention + qkv = [q, k_img, v_img] + del q, k_img, v_img + img_x = flash_attention(qkv, k_lens=None, attn_mode=self.attn_mode, split_attn=self.split_attn) + + # output + x = x.flatten(2) + img_x = img_x.flatten(2) + if self.training: + x = x + img_x # avoid inplace + else: + x += img_x + del img_x + + x = self.o(x) + return x + + +WAN_CROSSATTENTION_CLASSES = { + "t2v_cross_attn": WanT2VCrossAttention, + "i2v_cross_attn": WanI2VCrossAttention, +} + + +class WanAttentionBlock(nn.Module): + + def __init__( + self, + cross_attn_type, + dim, + ffn_dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6, + attn_mode="torch", + split_attn=False, + ): + super().__init__() + self.dim = dim + self.ffn_dim = ffn_dim + self.num_heads = num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + + # layers + self.norm1 = WanLayerNorm(dim, eps) + self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps, attn_mode, split_attn) + self.norm3 = WanLayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() + self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim, num_heads, (-1, -1), qk_norm, eps, attn_mode, split_attn) + self.norm2 = WanLayerNorm(dim, eps) + self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(approximate="tanh"), nn.Linear(ffn_dim, dim)) + + # modulation + self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + + self.gradient_checkpointing = False + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + def 
disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + + def _forward(self, x, e, seq_lens, grid_sizes, freqs, context, context_lens): + r""" + Args: + x(Tensor): Shape [B, L, C] + e(Tensor): Shape [B, 6, C] + seq_lens(Tensor): Shape [B], length of each sequence in batch + grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + """ + assert e.dtype == torch.float32 + # with amp.autocast(dtype=torch.float32): + # e = (self.modulation + e).chunk(6, dim=1) + # support fp8 + e = self.modulation.to(torch.float32) + e + e = e.chunk(6, dim=1) + assert e[0].dtype == torch.float32 + + # self-attention + y = self.self_attn(self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes, freqs) + # with amp.autocast(dtype=torch.float32): + # x = x + y * e[2] + x = x + y.to(torch.float32) * e[2] + del y + + # cross-attention & ffn function + # def cross_attn_ffn(x, context, context_lens, e): + # x += self.cross_attn(self.norm3(x), context, context_lens) + # y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3]) + # # with amp.autocast(dtype=torch.float32): + # # x = x + y * e[5] + # x += y.to(torch.float32) * e[5] + # return x + # x = cross_attn_ffn(x, context, context_lens, e) + + # x += self.cross_attn(self.norm3(x), context, context_lens) # backward error + x = x + self.cross_attn(self.norm3(x), context, context_lens) + del context + y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3]) + x = x + y.to(torch.float32) * e[5] + del y + return x + + def forward(self, x, e, seq_lens, grid_sizes, freqs, context, context_lens): + if self.training and self.gradient_checkpointing: + return checkpoint(self._forward, x, e, seq_lens, grid_sizes, freqs, context, context_lens, use_reentrant=False) + return self._forward(x, e, seq_lens, grid_sizes, freqs, context, context_lens) + + +class Head(nn.Module): + + def __init__(self, dim, out_dim, patch_size, eps=1e-6): + super().__init__() + self.dim = dim + self.out_dim = out_dim + self.patch_size = patch_size + self.eps = eps + + # layers + out_dim = math.prod(patch_size) * out_dim + self.norm = WanLayerNorm(dim, eps) + self.head = nn.Linear(dim, out_dim) + + # modulation + self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) + + def forward(self, x, e): + r""" + Args: + x(Tensor): Shape [B, L1, C] + e(Tensor): Shape [B, C] + """ + assert e.dtype == torch.float32 + # with amp.autocast(dtype=torch.float32): + # e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1) + # x = self.head(self.norm(x) * (1 + e[1]) + e[0]) + # support fp8 + e = (self.modulation.to(torch.float32) + e.unsqueeze(1)).chunk(2, dim=1) + x = self.head(self.norm(x) * (1 + e[1]) + e[0]) + return x + + +class MLPProj(torch.nn.Module): + + def __init__(self, in_dim, out_dim): + super().__init__() + + self.proj = torch.nn.Sequential( + torch.nn.LayerNorm(in_dim), + torch.nn.Linear(in_dim, in_dim), + torch.nn.GELU(), + torch.nn.Linear(in_dim, out_dim), + torch.nn.LayerNorm(out_dim), + ) + + def forward(self, image_embeds): + clip_extra_context_tokens = self.proj(image_embeds) + return clip_extra_context_tokens + + +class WanModel(nn.Module): # ModelMixin, ConfigMixin): + r""" + Wan diffusion backbone supporting both text-to-video and image-to-video. 
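+    The ref_conv branch is an optional addition for Fun-Control style reference-image conditioning (enabled when ref_conv weights are present in the checkpoint).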
+ """ + + ignore_for_config = ["patch_size", "cross_attn_norm", "qk_norm", "text_dim", "window_size"] + _no_split_modules = ["WanAttentionBlock"] + + # @register_to_config + def __init__( + self, + model_type="t2v", + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + attn_mode=None, + split_attn=False, + add_ref_conv=False, + in_dim_ref_conv=16, + ): + r""" + Initialize the diffusion model backbone. + + Args: + model_type (`str`, *optional*, defaults to 't2v'): + Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) + patch_size (`tuple`, *optional*, defaults to (1, 2, 2)): + 3D patch dimensions for video embedding (t_patch, h_patch, w_patch) + text_len (`int`, *optional*, defaults to 512): + Fixed length for text embeddings + in_dim (`int`, *optional*, defaults to 16): + Input video channels (C_in) + dim (`int`, *optional*, defaults to 2048): + Hidden dimension of the transformer + ffn_dim (`int`, *optional*, defaults to 8192): + Intermediate dimension in feed-forward network + freq_dim (`int`, *optional*, defaults to 256): + Dimension for sinusoidal time embeddings + text_dim (`int`, *optional*, defaults to 4096): + Input dimension for text embeddings + out_dim (`int`, *optional*, defaults to 16): + Output video channels (C_out) + num_heads (`int`, *optional*, defaults to 16): + Number of attention heads + num_layers (`int`, *optional*, defaults to 32): + Number of transformer blocks + window_size (`tuple`, *optional*, defaults to (-1, -1)): + Window size for local attention (-1 indicates global attention) + qk_norm (`bool`, *optional*, defaults to True): + Enable query/key normalization + cross_attn_norm (`bool`, *optional*, defaults to False): + Enable cross-attention normalization + eps (`float`, *optional*, defaults to 1e-6): + Epsilon value for normalization layers + """ + + super().__init__() + + assert model_type in ["t2v", "i2v"] + self.model_type = model_type + + self.patch_size = patch_size + self.text_len = text_len + self.in_dim = in_dim + self.dim = dim + self.ffn_dim = ffn_dim + self.freq_dim = freq_dim + self.text_dim = text_dim + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.window_size = window_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + self.attn_mode = attn_mode if attn_mode is not None else "torch" + self.split_attn = split_attn + + # embeddings + self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size) + self.text_embedding = nn.Sequential(nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim)) + + self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) + self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6)) + + # blocks + cross_attn_type = "t2v_cross_attn" if model_type == "t2v" else "i2v_cross_attn" + self.blocks = nn.ModuleList( + [ + WanAttentionBlock( + cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, attn_mode, split_attn + ) + for _ in range(num_layers) + ] + ) + + # head + self.head = Head(dim, out_dim, patch_size, eps) + + # buffers (don't use register_buffer otherwise dtype will be changed in to()) + assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 + d = dim // num_heads + self.freqs = torch.cat( + [rope_params(1024, d 
- 4 * (d // 6)), rope_params(1024, 2 * (d // 6)), rope_params(1024, 2 * (d // 6))], dim=1 + ) + self.freqs_fhw = {} + + if model_type == "i2v": + self.img_emb = MLPProj(1280, dim) + + self.add_ref_conv = add_ref_conv # Store the flag + if add_ref_conv: + # Use spatial dimensions from patch_size for Conv2d + self.ref_conv = nn.Conv2d(in_dim_ref_conv, dim, kernel_size=patch_size[1:], stride=patch_size[1:]) + logger.info(f"Initialized ref_conv layer with in_channels={in_dim_ref_conv}, out_channels={dim}") + else: + self.ref_conv = None + + # initialize weights + self.init_weights() + + self.gradient_checkpointing = False + + # offloading + self.blocks_to_swap = None + self.offloader = None + + @property + def dtype(self): + return next(self.parameters()).dtype + + @property + def device(self): + return next(self.parameters()).device + + def fp8_optimization( + self, state_dict: dict[str, torch.Tensor], device: torch.device, move_to_device: bool, use_scaled_mm: bool = False + ) -> int: + """ + Optimize the model state_dict with fp8. + + Args: + state_dict (dict[str, torch.Tensor]): + The state_dict of the model. + device (torch.device): + The device to calculate the weight. + move_to_device (bool): + Whether to move the weight to the device after optimization. + """ + TARGET_KEYS = ["blocks"] + EXCLUDE_KEYS = [ + "norm", + "patch_embedding", + "text_embedding", + "time_embedding", + "time_projection", + "head", + "modulation", + "img_emb", + ] + + # inplace optimization + state_dict = optimize_state_dict_with_fp8(state_dict, device, TARGET_KEYS, EXCLUDE_KEYS, move_to_device=move_to_device) + + # apply monkey patching + apply_fp8_monkey_patch(self, state_dict, use_scaled_mm=use_scaled_mm) + + return state_dict + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + for block in self.blocks: + block.enable_gradient_checkpointing() + + print(f"WanModel: Gradient checkpointing enabled.") + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + + for block in self.blocks: + block.disable_gradient_checkpointing() + + print(f"WanModel: Gradient checkpointing disabled.") + + def enable_block_swap(self, blocks_to_swap: int, device: torch.device, supports_backward: bool): + self.blocks_to_swap = blocks_to_swap + self.num_blocks = len(self.blocks) + + assert ( + self.blocks_to_swap <= self.num_blocks - 1 + ), f"Cannot swap more than {self.num_blocks - 1} blocks. Requested {self.blocks_to_swap} blocks to swap." + + self.offloader = ModelOffloader( + "wan_attn_block", self.blocks, self.num_blocks, self.blocks_to_swap, supports_backward, device # , debug=True + ) + print( + f"WanModel: Block swap enabled. Swapping {self.blocks_to_swap} blocks out of {self.num_blocks} blocks. Supports backward: {supports_backward}" + ) + + def switch_block_swap_for_inference(self): + if self.blocks_to_swap: + self.offloader.set_forward_only(True) + self.prepare_block_swap_before_forward() + print(f"WanModel: Block swap set to forward only.") + + def switch_block_swap_for_training(self): + if self.blocks_to_swap: + self.offloader.set_forward_only(False) + self.prepare_block_swap_before_forward() + print(f"WanModel: Block swap set to forward and backward.") + + def move_to_device_except_swap_blocks(self, device: torch.device): + # assume model is on cpu. 
do not move blocks to device to reduce temporary memory usage + if self.blocks_to_swap: + save_blocks = self.blocks + self.blocks = None + + self.to(device) + + if self.blocks_to_swap: + self.blocks = save_blocks + + def prepare_block_swap_before_forward(self): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + self.offloader.prepare_block_devices_before_forward(self.blocks) + + def forward(self, x, t, context, seq_len, clip_fea=None, y=None, skip_block_indices=None, fun_ref=None): + r""" + Forward pass through the diffusion model + + Args: + x (List[Tensor]): + List of input video tensors, each with shape [C_in, F, H, W] + t (Tensor): + Diffusion timesteps tensor of shape [B] + context (List[Tensor]): + List of text embeddings each with shape [L, C] + seq_len (`int`): + Maximum sequence length for positional encoding + clip_fea (Tensor, *optional*): + CLIP image features for image-to-video mode + y (List[Tensor], *optional*): + Conditional video inputs for image-to-video mode, same shape as x + + Returns: + List[Tensor]: + List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8] + """ + # remove assertions to work with Fun-Control T2V + # if self.model_type == "i2v": + # assert clip_fea is not None and y is not None + # params + device = self.patch_embedding.weight.device + if self.freqs.device != device: + self.freqs = self.freqs.to(device) + + if isinstance(x, list) and len(x) > 0: + _, F_orig, H_orig, W_orig = x[0].shape + else: + # Fallback or error handling if x is not as expected + raise ValueError("Input x is not in the expected list format.") + + if y is not None: + print('WanModel concat debug:') + for i, (u, v) in enumerate(zip(x, y)): + print(f"x[{i}]: {u.shape}, y[{i}]: {v.shape}, y[{i}].dim(): {v.dim()}") + x = [ + torch.cat([u, v], dim=0) + for u, v in zip(x, y) + ] + y = None + + + # embeddings + x = [self.patch_embedding(u.unsqueeze(0)) for u in x] + grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) + + # <<< START: Process fun_ref if applicable >>> + F = F_orig # Use original frame count for RoPE calculation unless fun_ref modifies it + if self.ref_conv is not None and fun_ref is not None: + # fun_ref is expected to be the raw reference image latent [B, C_ref, H_ref, W_ref] + # Ensure it's on the correct device + fun_ref = fun_ref.to(device) + logger.debug(f"Processing fun_ref with shape: {fun_ref.shape}") + + # Apply the 2D convolution + # Note: fun_ref needs batch dim for Conv2d, add if missing + if fun_ref.dim() == 3: fun_ref = fun_ref.unsqueeze(0) + processed_ref = self.ref_conv(fun_ref) # Output: [B, C, H_out, W_out] + logger.debug(f"Processed ref_conv output shape: {processed_ref.shape}") + + # Reshape to token sequence: [B, L_ref, C] + processed_ref = processed_ref.flatten(2).transpose(1, 2) + logger.debug(f"Reshaped processed_ref shape: {processed_ref.shape}") + + # Adjust grid_sizes, seq_len, and F to account for the prepended tokens + # Assuming the reference adds effectively one "frame" worth of tokens spatially + # Note: This might need adjustment depending on how seq_len is used later. + # We increment the frame dimension 'F' in grid_sizes. 
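+            # The prepended reference tokens are stripped again after the transformer blocks (before the head), and grid_sizes is restored there for unpatchify.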
+ grid_sizes = torch.stack([torch.tensor([gs[0] + 1, gs[1], gs[2]], dtype=torch.long) for gs in grid_sizes]).to(grid_sizes.device) + seq_len += processed_ref.size(1) # Add number of reference tokens + F = F_orig + 1 # Indicate one extra effective frame for RoPE/freq calculation + logger.debug(f"Adjusted grid_sizes: {grid_sizes}, seq_len: {seq_len}, F for RoPE: {F}") + + # Prepend the reference tokens to each element in the list x + x = [torch.cat([processed_ref, u.flatten(2).transpose(1, 2)], dim=1) for u in x] # x was already flattened+transposed below, do it here + # x is now list of [B, L_new, C] + else: + # Original flattening if no fun_ref + x = [u.flatten(2).transpose(1, 2) for u in x] + # <<< END: Process fun_ref if applicable >>> + + freqs_list = [] + for fhw in grid_sizes: # Use the potentially updated grid_sizes + fhw_tuple = tuple(fhw.tolist()) + if fhw_tuple not in self.freqs_fhw: + c_rope = self.dim // self.num_heads // 2 + # Use the potentially updated frame count F from fhw[0] + self.freqs_fhw[fhw_tuple] = calculate_freqs_i(fhw, c_rope, self.freqs) + freqs_list.append(self.freqs_fhw[fhw_tuple]) + + # ... (seq_len calculation and padding using potentially updated seq_len) ... + seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) + if seq_lens.max() > seq_len: + # This might happen if seq_len wasn't updated correctly or padding logic needs review + logger.warning(f"Calculated seq_lens.max()={seq_lens.max()} > adjusted seq_len={seq_len}. Adjusting seq_len.") + seq_len = seq_lens.max().item() # Use the actual max length required + + x = torch.cat([torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) for u in x]) + + # time embeddings + # with amp.autocast(dtype=torch.float32): + with torch.amp.autocast(device_type=device.type, dtype=torch.float32): + e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t).float()) + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) + assert e.dtype == torch.float32 and e0.dtype == torch.float32 + + # context + context_lens = None + if type(context) is list: + context = torch.stack([torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) for u in context]) + context = self.text_embedding(context) + + if clip_fea is not None: + context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context = torch.concat([context_clip, context], dim=1) + clip_fea = None + context_clip = None + + # arguments + kwargs = dict(e=e0, seq_lens=seq_lens, grid_sizes=grid_sizes, freqs=freqs_list, context=context, context_lens=context_lens) + + if self.blocks_to_swap: + clean_memory_on_device(device) + + # print(f"x: {x.shape}, e: {e0.shape}, context: {context.shape}, seq_lens: {seq_lens}") + for block_idx, block in enumerate(self.blocks): + is_block_skipped = skip_block_indices is not None and block_idx in skip_block_indices + + if self.blocks_to_swap and not is_block_skipped: + self.offloader.wait_for_block(block_idx) + + if not is_block_skipped: + x = block(x, **kwargs) + + if self.blocks_to_swap: + self.offloader.submit_move_blocks_forward(self.blocks, block_idx) + + if self.ref_conv is not None and fun_ref is not None: + num_ref_tokens = processed_ref.size(1) + logger.debug(f"Removing {num_ref_tokens} prepended reference tokens before head.") + x = x[:, num_ref_tokens:, :] + # Restore original grid_sizes F dimension for unpatchify + grid_sizes = torch.stack([torch.tensor([gs[0] - 1, gs[1], gs[2]], dtype=torch.long) for gs in grid_sizes]).to(grid_sizes.device) + + # head + x = self.head(x, e) + + # unpatchify + 
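+        # grid_sizes holds per-sample (F, H, W) patch counts; unpatchify folds the flattened patch tokens back into [C_out, F, H / 8, W / 8] latents.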
x = self.unpatchify(x, grid_sizes) + return [u.float() for u in x] + + def unpatchify(self, x, grid_sizes): + r""" + Reconstruct video tensors from patch embeddings. + + Args: + x (List[Tensor]): + List of patchified features, each with shape [L, C_out * prod(patch_size)] + grid_sizes (Tensor): + Original spatial-temporal grid dimensions before patching, + shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches) + + Returns: + List[Tensor]: + Reconstructed video tensors with shape [C_out, F, H / 8, W / 8] + """ + + c = self.out_dim + out = [] + for u, v in zip(x, grid_sizes.tolist()): + u = u[: math.prod(v)].view(*v, *self.patch_size, c) + u = torch.einsum("fhwpqrc->cfphqwr", u) + u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)]) + out.append(u) + return out + + def init_weights(self): + r""" + Initialize model parameters using Xavier initialization. + """ + + # basic init + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + + # init embeddings + nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1)) + for m in self.text_embedding.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.02) + for m in self.time_embedding.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.02) + + # init output layer + nn.init.zeros_(self.head.head.weight) + + +def detect_wan_sd_dtype(path: str) -> torch.dtype: + # get dtype from model weights + with MemoryEfficientSafeOpen(path) as f: + keys = set(f.keys()) + key1 = "model.diffusion_model.blocks.0.cross_attn.k.weight" # 1.3B + key2 = "blocks.0.cross_attn.k.weight" # 14B + if key1 in keys: + dit_dtype = f.get_tensor(key1).dtype + elif key2 in keys: + dit_dtype = f.get_tensor(key2).dtype + else: + raise ValueError(f"Could not find the dtype in the model weights: {path}") + logger.info(f"Detected DiT dtype: {dit_dtype}") + return dit_dtype + + +def load_wan_model( + config: any, + device: Union[str, torch.device], + dit_path: str, + attn_mode: str, + split_attn: bool, + loading_device: Union[str, torch.device], + dit_weight_dtype: Optional[torch.dtype], + fp8_scaled: bool = False, +) -> WanModel: + # dit_weight_dtype is None for fp8_scaled + assert (not fp8_scaled and dit_weight_dtype is not None) or (fp8_scaled and dit_weight_dtype is None) + + device = torch.device(device) + loading_device = torch.device(loading_device) + + wan_loading_device = torch.device("cpu") if fp8_scaled else loading_device + logger.info(f"Loading DiT model state dict from {dit_path}, device={wan_loading_device}, dtype={dit_weight_dtype}") + sd = load_safetensors(dit_path, wan_loading_device, disable_mmap=True, dtype=dit_weight_dtype) + + # remove "model.diffusion_model." prefix: 1.3B model has this prefix + sd_keys = list(sd.keys()) # Keep original keys for potential prefix removal + for key in sd_keys: + if key.startswith("model.diffusion_model."): + sd[key[22:]] = sd.pop(key) + + # Check for ref_conv layer weights + has_ref_conv = "ref_conv.weight" in sd + in_dim_ref_conv = sd["ref_conv.weight"].shape[1] if has_ref_conv else 16 # Default if not found + if has_ref_conv: + logger.info(f"Detected ref_conv layer in model weights. 
Input channels: {in_dim_ref_conv}") + + with init_empty_weights(): + logger.info(f"Creating WanModel") + model = WanModel( + model_type="i2v" if config.i2v else "t2v", + dim=config.dim, + eps=config.eps, + ffn_dim=config.ffn_dim, + freq_dim=config.freq_dim, + in_dim=config.in_dim, + num_heads=config.num_heads, + num_layers=config.num_layers, + out_dim=config.out_dim, + text_len=config.text_len, + attn_mode=attn_mode, + split_attn=split_attn, + add_ref_conv=has_ref_conv, # <<< Pass detected flag + in_dim_ref_conv=in_dim_ref_conv, + ) + if dit_weight_dtype is not None and not fp8_scaled: # Don't pre-cast if optimizing to FP8 later + model.to(dit_weight_dtype) + + # ... (fp8 optimization - sd is already loaded) ... + if fp8_scaled: + # fp8 optimization: calculate on CUDA, move back to CPU if loading_device is CPU (block swap) + logger.info(f"Optimizing model weights to fp8. This may take a while.") + sd = model.fp8_optimization(sd, device, move_to_device=loading_device.type == "cpu") + + if loading_device.type != "cpu": + # make sure all the model weights are on the loading_device + logger.info(f"Moving weights to {loading_device}") + for key in sd.keys(): + sd[key] = sd[key].to(loading_device) + + # Load the potentially modified state dict + # Use strict=False initially if ref_conv might be missing in older models but present in the class + # After confirming your models, you might set strict=True if all target models have the layer or None. + info = model.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded DiT model from {dit_path}, info={info}") + if not info.missing_keys and not info.unexpected_keys: + logger.info("State dict loaded successfully (strict check passed).") + else: + logger.warning(f"State dict load info: Missing={info.missing_keys}, Unexpected={info.unexpected_keys}") + # If add_ref_conv is True but ref_conv keys are missing, it's an issue. + if has_ref_conv and any("ref_conv" in k for k in info.missing_keys): + raise ValueError("Model configuration indicates ref_conv=True, but weights are missing!") + # If add_ref_conv is False but ref_conv keys are unexpected, it's also an issue with model/config mismatch. + if not has_ref_conv and any("ref_conv" in k for k in info.unexpected_keys): + raise ValueError("Model configuration indicates ref_conv=False, but weights are present!") + + + return model \ No newline at end of file diff --git a/wan/modules/t5.py b/wan/modules/t5.py new file mode 100644 index 0000000000000000000000000000000000000000..bbc89c8342ae9c799fc4674e51fa5661131e38b4 --- /dev/null +++ b/wan/modules/t5.py @@ -0,0 +1,514 @@ +# Modified from transformers.models.t5.modeling_t5 +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
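+# Encoder-only UMT5-XXL text encoder used by Wan2.1 (see T5EncoderModel below); supports loading .pth or .safetensors weights and an optional fp8 mode.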
+# import logging +import math +import os + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .tokenizers import HuggingfaceTokenizer +from accelerate import init_empty_weights +from safetensors.torch import load_file + +import logging + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + +__all__ = [ + "T5Model", + "T5Encoder", + "T5Decoder", + "T5EncoderModel", +] + + +def fp16_clamp(x): + if x.dtype == torch.float16 and torch.isinf(x).any(): + clamp = torch.finfo(x.dtype).max - 1000 + x = torch.clamp(x, min=-clamp, max=clamp) + return x + + +def init_weights(m): + if isinstance(m, T5LayerNorm): + nn.init.ones_(m.weight) + elif isinstance(m, T5Model): + nn.init.normal_(m.token_embedding.weight, std=1.0) + elif isinstance(m, T5FeedForward): + nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5) + nn.init.normal_(m.fc1.weight, std=m.dim**-0.5) + nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5) + elif isinstance(m, T5Attention): + nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn) ** -0.5) + nn.init.normal_(m.k.weight, std=m.dim**-0.5) + nn.init.normal_(m.v.weight, std=m.dim**-0.5) + nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn) ** -0.5) + elif isinstance(m, T5RelativeEmbedding): + nn.init.normal_(m.embedding.weight, std=(2 * m.num_buckets * m.num_heads) ** -0.5) + + +class GELU(nn.Module): + + def forward(self, x): + return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) + + +class T5LayerNorm(nn.Module): + + def __init__(self, dim, eps=1e-6): + super(T5LayerNorm, self).__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps) + if self.weight.dtype in [torch.float16, torch.bfloat16]: + x = x.type_as(self.weight) + return self.weight * x + + +class T5Attention(nn.Module): + + def __init__(self, dim, dim_attn, num_heads, dropout=0.1): + assert dim_attn % num_heads == 0 + super(T5Attention, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.num_heads = num_heads + self.head_dim = dim_attn // num_heads + + # layers + self.q = nn.Linear(dim, dim_attn, bias=False) + self.k = nn.Linear(dim, dim_attn, bias=False) + self.v = nn.Linear(dim, dim_attn, bias=False) + self.o = nn.Linear(dim_attn, dim, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, context=None, mask=None, pos_bias=None): + """ + x: [B, L1, C]. + context: [B, L2, C] or None. + mask: [B, L2] or [B, L1, L2] or None. 
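+        pos_bias: [1, N, L1, L2] relative position bias added to the attention logits, or None.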
+ """ + # check inputs + context = x if context is None else context + b, n, c = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.q(x).view(b, -1, n, c) + k = self.k(context).view(b, -1, n, c) + v = self.v(context).view(b, -1, n, c) + + # attention bias + attn_bias = x.new_zeros(b, n, q.size(1), k.size(1)) + if pos_bias is not None: + attn_bias += pos_bias + if mask is not None: + assert mask.ndim in [2, 3] + mask = mask.view(b, 1, 1, -1) if mask.ndim == 2 else mask.unsqueeze(1) + attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min) + + # compute attention (T5 does not use scaling) + attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias + attn = F.softmax(attn.float(), dim=-1).type_as(attn) + x = torch.einsum("bnij,bjnc->binc", attn, v) + + # output + x = x.reshape(b, -1, n * c) + x = self.o(x) + x = self.dropout(x) + return x + + +class T5FeedForward(nn.Module): + + def __init__(self, dim, dim_ffn, dropout=0.1): + super(T5FeedForward, self).__init__() + self.dim = dim + self.dim_ffn = dim_ffn + + # layers + self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU()) + self.fc1 = nn.Linear(dim, dim_ffn, bias=False) + self.fc2 = nn.Linear(dim_ffn, dim, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + x = self.fc1(x) * self.gate(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + return x + + +class T5SelfAttention(nn.Module): + + def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1): + super(T5SelfAttention, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.norm1 = T5LayerNorm(dim) + self.attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm2 = T5LayerNorm(dim) + self.ffn = T5FeedForward(dim, dim_ffn, dropout) + self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) + + def forward(self, x, mask=None, pos_bias=None): + e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1)) + x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e)) + x = fp16_clamp(x + self.ffn(self.norm2(x))) + return x + + +class T5CrossAttention(nn.Module): + + def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1): + super(T5CrossAttention, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.norm1 = T5LayerNorm(dim) + self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm2 = T5LayerNorm(dim) + self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm3 = T5LayerNorm(dim) + self.ffn = T5FeedForward(dim, dim_ffn, dropout) + self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False) + + def forward(self, x, mask=None, encoder_states=None, encoder_mask=None, pos_bias=None): + e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1)) + x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e)) + x = fp16_clamp(x + self.cross_attn(self.norm2(x), context=encoder_states, mask=encoder_mask)) + x = fp16_clamp(x + self.ffn(self.norm3(x))) + return x + + +class T5RelativeEmbedding(nn.Module): + + def __init__(self, num_buckets, num_heads, 
bidirectional, max_dist=128): + super(T5RelativeEmbedding, self).__init__() + self.num_buckets = num_buckets + self.num_heads = num_heads + self.bidirectional = bidirectional + self.max_dist = max_dist + + # layers + self.embedding = nn.Embedding(num_buckets, num_heads) + + def forward(self, lq, lk): + device = self.embedding.weight.device + # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \ + # torch.arange(lq).unsqueeze(1).to(device) + rel_pos = torch.arange(lk, device=device).unsqueeze(0) - torch.arange(lq, device=device).unsqueeze(1) + rel_pos = self._relative_position_bucket(rel_pos) + rel_pos_embeds = self.embedding(rel_pos) + rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(0) # [1, N, Lq, Lk] + return rel_pos_embeds.contiguous() + + def _relative_position_bucket(self, rel_pos): + # preprocess + if self.bidirectional: + num_buckets = self.num_buckets // 2 + rel_buckets = (rel_pos > 0).long() * num_buckets + rel_pos = torch.abs(rel_pos) + else: + num_buckets = self.num_buckets + rel_buckets = 0 + rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos)) + + # embeddings for small and large positions + max_exact = num_buckets // 2 + rel_pos_large = ( + max_exact + + (torch.log(rel_pos.float() / max_exact) / math.log(self.max_dist / max_exact) * (num_buckets - max_exact)).long() + ) + rel_pos_large = torch.min(rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1)) + rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large) + return rel_buckets + + +class T5Encoder(nn.Module): + + def __init__(self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1): + super(T5Encoder, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_layers = num_layers + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim) + self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) if shared_pos else None + self.dropout = nn.Dropout(dropout) + self.blocks = nn.ModuleList( + [T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout) for _ in range(num_layers)] + ) + self.norm = T5LayerNorm(dim) + + # initialize weights + self.apply(init_weights) + + def prepare_fp8(self, target_dtype=torch.bfloat16): + def forward_hook(module): + def forward(hidden_states): + hidden_gelu = module.act(module.wi_0(hidden_states)) + hidden_linear = module.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = module.dropout(hidden_states) + + hidden_states = module.wo(hidden_states) + return hidden_states + + return forward + + for module in self.modules(): + if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]: + # print("set", module.__class__.__name__, "to", target_dtype) + module.to(target_dtype) + if module.__class__.__name__ in ["T5DenseGatedActDense"]: + # print("set", module.__class__.__name__, "hooks") + module.forward = forward_hook(module) + + def forward(self, ids, mask=None): + x = self.token_embedding(ids) + x = self.dropout(x) + e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None + for block in self.blocks: + x = block(x, mask, pos_bias=e) + x = self.norm(x) + x = self.dropout(x) + return x + + +class T5Decoder(nn.Module): + + def __init__(self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, 
dropout=0.1): + super(T5Decoder, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_layers = num_layers + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim) + self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False) if shared_pos else None + self.dropout = nn.Dropout(dropout) + self.blocks = nn.ModuleList( + [T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout) for _ in range(num_layers)] + ) + self.norm = T5LayerNorm(dim) + + # initialize weights + self.apply(init_weights) + + def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None): + b, s = ids.size() + + # causal mask + if mask is None: + mask = torch.tril(torch.ones(1, s, s).to(ids.device)) + elif mask.ndim == 2: + mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1)) + + # layers + x = self.token_embedding(ids) + x = self.dropout(x) + e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None + for block in self.blocks: + x = block(x, mask, encoder_states, encoder_mask, pos_bias=e) + x = self.norm(x) + x = self.dropout(x) + return x + + +class T5Model(nn.Module): + + def __init__( + self, + vocab_size, + dim, + dim_attn, + dim_ffn, + num_heads, + encoder_layers, + decoder_layers, + num_buckets, + shared_pos=True, + dropout=0.1, + ): + super(T5Model, self).__init__() + self.vocab_size = vocab_size + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.encoder_layers = encoder_layers + self.decoder_layers = decoder_layers + self.num_buckets = num_buckets + + # layers + self.token_embedding = nn.Embedding(vocab_size, dim) + self.encoder = T5Encoder( + self.token_embedding, dim, dim_attn, dim_ffn, num_heads, encoder_layers, num_buckets, shared_pos, dropout + ) + self.decoder = T5Decoder( + self.token_embedding, dim, dim_attn, dim_ffn, num_heads, decoder_layers, num_buckets, shared_pos, dropout + ) + self.head = nn.Linear(dim, vocab_size, bias=False) + + # initialize weights + self.apply(init_weights) + + def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask): + x = self.encoder(encoder_ids, encoder_mask) + x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask) + x = self.head(x) + return x + + +def _t5( + name, + encoder_only=False, + decoder_only=False, + return_tokenizer=False, + tokenizer_kwargs={}, + **kwargs, +): + # dtype=torch.float32, + # device="cpu", + # sanity check + assert not (encoder_only and decoder_only) + + # params + if encoder_only: + model_cls = T5Encoder + kwargs["vocab"] = kwargs.pop("vocab_size") + kwargs["num_layers"] = kwargs.pop("encoder_layers") + _ = kwargs.pop("decoder_layers") + elif decoder_only: + model_cls = T5Decoder + kwargs["vocab"] = kwargs.pop("vocab_size") + kwargs["num_layers"] = kwargs.pop("decoder_layers") + _ = kwargs.pop("encoder_layers") + else: + model_cls = T5Model + + # # init model + # with torch.device(device): + model = model_cls(**kwargs) + + # # set device + # model = model.to(dtype=dtype, device=device) + + # init tokenizer + if return_tokenizer: + from .tokenizers import HuggingfaceTokenizer + + tokenizer = HuggingfaceTokenizer(f"google/{name}", **tokenizer_kwargs) + return model, tokenizer + else: + return model + + +def umt5_xxl(**kwargs): + cfg = dict( + vocab_size=256384, + dim=4096, + dim_attn=4096, + dim_ffn=10240, + 
num_heads=64, + encoder_layers=24, + decoder_layers=24, + num_buckets=32, + shared_pos=False, + dropout=0.1, + ) + cfg.update(**kwargs) + return _t5("umt5-xxl", **cfg) + + +class T5EncoderModel: + + def __init__( + self, + text_len, + dtype=torch.bfloat16, + device=torch.cuda.current_device(), + checkpoint_path=None, + tokenizer_path=None, + shard_fn=None, + weight_path=None, + fp8=False, + ): + self.text_len = text_len + self.dtype = dtype if not fp8 else torch.float8_e4m3fn + self.device = device + self.checkpoint_path = checkpoint_path + self.tokenizer_path = tokenizer_path + + # init model + with init_empty_weights(): + model = umt5_xxl(encoder_only=True, return_tokenizer=False) + + model = model.eval().requires_grad_(False) + if checkpoint_path is not None: + logger.info(f"loading {checkpoint_path}") + model.load_state_dict(torch.load(checkpoint_path, map_location="cpu")) + else: + logger.info(f"loading weights from {weight_path}") + if os.path.splitext(weight_path)[1] == ".safetensors": + sd = load_file(weight_path) + else: + sd = torch.load(weight_path, map_location="cpu", weights_only=True) + # remove prefix "encoder." from the state dict + sd = {k.replace("encoder.", ""): v for k, v in sd.items()} + model.load_state_dict(sd, strict=True, assign=True) + + logger.info(f"moving model to {device} and casting to {self.dtype}") + model = model.to(device, dtype=self.dtype) + + if fp8: + logger.info("preparing model for fp8") + model.prepare_fp8(dtype) + + self.model = model + # if shard_fn is not None: + # self.model = shard_fn(self.model, sync_module_states=False) + # else: + # self.model.to(self.device) + # init tokenizer + if tokenizer_path is None: + tokenizer_path = "Wan-AI/Wan2.1-T2V-14B" + subfolder = "google/umt5-xxl" + else: + subfolder = None + self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=text_len, clean="whitespace", subfolder=subfolder) + + def __call__(self, texts, device): + ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True) + ids = ids.to(device) + mask = mask.to(device) + seq_lens = mask.gt(0).sum(dim=1).long() + context = self.model(ids, mask) + return [u[:v] for u, v in zip(context, seq_lens)] diff --git a/wan/modules/tokenizers.py b/wan/modules/tokenizers.py new file mode 100644 index 0000000000000000000000000000000000000000..121e591c48f82f82daa51a6ce38ae9a27beea8d2 --- /dev/null +++ b/wan/modules/tokenizers.py @@ -0,0 +1,82 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
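+# Thin wrapper around transformers.AutoTokenizer with optional text cleaning (whitespace / lower / canonicalize via ftfy).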
+import html +import string + +import ftfy +import regex as re +from transformers import AutoTokenizer + +__all__ = ['HuggingfaceTokenizer'] + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +def canonicalize(text, keep_punctuation_exact_string=None): + text = text.replace('_', ' ') + if keep_punctuation_exact_string: + text = keep_punctuation_exact_string.join( + part.translate(str.maketrans('', '', string.punctuation)) + for part in text.split(keep_punctuation_exact_string)) + else: + text = text.translate(str.maketrans('', '', string.punctuation)) + text = text.lower() + text = re.sub(r'\s+', ' ', text) + return text.strip() + + +class HuggingfaceTokenizer: + + def __init__(self, name, seq_len=None, clean=None, **kwargs): + assert clean in (None, 'whitespace', 'lower', 'canonicalize') + self.name = name + self.seq_len = seq_len + self.clean = clean + + # init tokenizer + self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs) + self.vocab_size = self.tokenizer.vocab_size + + def __call__(self, sequence, **kwargs): + return_mask = kwargs.pop('return_mask', False) + + # arguments + _kwargs = {'return_tensors': 'pt'} + if self.seq_len is not None: + _kwargs.update({ + 'padding': 'max_length', + 'truncation': True, + 'max_length': self.seq_len + }) + _kwargs.update(**kwargs) + + # tokenization + if isinstance(sequence, str): + sequence = [sequence] + if self.clean: + sequence = [self._clean(u) for u in sequence] + ids = self.tokenizer(sequence, **_kwargs) + + # output + if return_mask: + return ids.input_ids, ids.attention_mask + else: + return ids.input_ids + + def _clean(self, text): + if self.clean == 'whitespace': + text = whitespace_clean(basic_clean(text)) + elif self.clean == 'lower': + text = whitespace_clean(basic_clean(text)).lower() + elif self.clean == 'canonicalize': + text = canonicalize(basic_clean(text)) + return text diff --git a/wan/modules/vae.py b/wan/modules/vae.py new file mode 100644 index 0000000000000000000000000000000000000000..4580e716ce492fa4fbe713cda1c8d9759f3381ef --- /dev/null +++ b/wan/modules/vae.py @@ -0,0 +1,752 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import logging +import os +from typing import Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from safetensors.torch import load_file + +__all__ = [ + "WanVAE", +] + +CACHE_T = 2 + + +class CausalConv3d(nn.Conv3d): + """ + Causal 3d convolusion. 
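    (Temporal padding is applied only on the past side of the time axis, so an output frame never depends on future frames; `forward` can additionally prepend `cache_x`, the trailing frames of the previous chunk, and shrinks the front padding accordingly so that chunked processing stays causal.)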
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + + return super().forward(x) + + +class RMS_norm(nn.Module): + + def __init__(self, dim, channel_first=True, images=True, bias=False): + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + + +class Upsample(nn.Upsample): + + def forward(self, x): + """ + Fix bfloat16 support for nearest neighbor interpolation. + """ + return super().forward(x.float()).type_as(x) + + +class Resample(nn.Module): + + def __init__(self, dim, mode): + assert mode in ("none", "upsample2d", "upsample3d", "downsample2d", "downsample3d") + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) + ) + self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + + elif mode == "downsample2d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == "downsample3d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + + else: + self.resample = nn.Identity() + + self.cache_device = None + + def set_cache_device(self, device): + self.cache_device = device + + def forward(self, x, feat_cache=None, feat_idx=[0]): + cache_device = self.cache_device if self.cache_device is not None else x.device + + b, c, t, h, w = x.size() + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + + cache_x = x[:, :, -CACHE_T:, :, :].clone().to(cache_device) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": + cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx].to(x.device) if feat_cache[idx] is not None else None) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = rearrange(x, "b c t h w -> (b t) c h 
w") + x = self.resample(x) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone().to(cache_device) + feat_idx[0] += 1 + else: + + cache_x = x[:, :, -1:, :, :].clone().to(cache_device) + # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep': + # # cache last frame of last two chunk + # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :].to(x.device), x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + def init_weight(self, conv): + conv_weight = conv.weight + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + one_matrix = torch.eye(c1, c2) + init_matrix = one_matrix + nn.init.zeros_(conv_weight) + # conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5 + conv_weight.data[:, :, 1, 0, 0] = init_matrix # * 0.5 + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + def init_weight2(self, conv): + conv_weight = conv.weight.data + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + init_matrix = torch.eye(c1 // 2, c2) + # init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2) + conv_weight[: c1 // 2, :, -1, 0, 0] = init_matrix + conv_weight[c1 // 2 :, :, -1, 0, 0] = init_matrix + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + +class ResidualBlock(nn.Module): + + def __init__(self, in_dim, out_dim, dropout=0.0): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # layers + self.residual = nn.Sequential( + RMS_norm(in_dim, images=False), + nn.SiLU(), + CausalConv3d(in_dim, out_dim, 3, padding=1), + RMS_norm(out_dim, images=False), + nn.SiLU(), + nn.Dropout(dropout), + CausalConv3d(out_dim, out_dim, 3, padding=1), + ) + self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity() + + self.cache_device = None + + def set_cache_device(self, device): + self.cache_device = device + + def forward(self, x, feat_cache=None, feat_idx=[0]): + cache_device = self.cache_device if self.cache_device is not None else x.device + + h = self.shortcut(x) + for layer in self.residual: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone().to(cache_device) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = layer(x, feat_cache[idx].to(x.device) if feat_cache[idx] is not None else None) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + h + + +class AttentionBlock(nn.Module): + """ + Causal self-attention with a single head. 
+ """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = RMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x): + identity = x + b, c, t, h, w = x.size() + x = rearrange(x, "b c t h w -> (b t) c h w") + x = self.norm(x) + # compute query, key, value + q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(0, 1, 3, 2).contiguous().chunk(3, dim=-1) + + # apply attention + x = F.scaled_dot_product_attention( + q, + k, + v, + ) + x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) + + # output + x = self.proj(x) + x = rearrange(x, "(b t) c h w-> b c t h w", t=t) + return x + identity + + +class Encoder3d(nn.Module): + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv1 = CausalConv3d(3, dims[0], 3, padding=1) + + # downsample blocks + downsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks): + downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + downsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = "downsample3d" if temperal_downsample[i] else "downsample2d" + downsamples.append(Resample(out_dim, mode=mode)) + scale /= 2.0 + self.downsamples = nn.Sequential(*downsamples) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim), ResidualBlock(out_dim, out_dim, dropout) + ) + + # output blocks + self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, z_dim, 3, padding=1)) + + self.cache_device = None + + def set_cache_device(self, device): + self.cache_device = device + + # set cache device for all layers + for layer in self.downsamples + self.middle + self.head: + if isinstance(layer, Resample) or isinstance(layer, ResidualBlock): + layer.set_cache_device(device) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + cache_device = self.cache_device if self.cache_device is not None else x.device + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone().to(cache_device) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv1(x, feat_cache[idx].to(x.device) if feat_cache[idx] is not None else None) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## downsamples + for layer in self.downsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = 
feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone().to(cache_device) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = layer(x, feat_cache[idx].to(x.device) if feat_cache[idx] is not None else None) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +class Decoder3d(nn.Module): + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2 ** (len(dim_mult) - 2) + + # init block + self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]), ResidualBlock(dims[0], dims[0], dropout) + ) + + # upsample blocks + upsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if i == 1 or i == 2 or i == 3: + in_dim = in_dim // 2 + for _ in range(num_res_blocks + 1): + upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + upsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # upsample block + if i != len(dim_mult) - 1: + mode = "upsample3d" if temperal_upsample[i] else "upsample2d" + upsamples.append(Resample(out_dim, mode=mode)) + scale *= 2.0 + self.upsamples = nn.Sequential(*upsamples) + + # output blocks + self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, 3, 3, padding=1)) + + self.cache_device = None + + def set_cache_device(self, device): + self.cache_device = device + + # set cache device for all layers + for layer in self.middle + self.upsamples + self.head: + if isinstance(layer, Resample) or isinstance(layer, ResidualBlock): + layer.set_cache_device(device) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + cache_device = self.cache_device if self.cache_device is not None else x.device + + ## conv1 + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone().to(cache_device) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv1(x, feat_cache[idx].to(x.device) if feat_cache[idx] is not None else None) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## middle + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## upsamples + for layer in self.upsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone().to(cache_device) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), 
cache_x], dim=2) + x = layer(x, feat_cache[idx].to(x.device) if feat_cache[idx] is not None else None) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +def count_conv3d(model): + count = 0 + for m in model.modules(): + if isinstance(m, CausalConv3d): + count += 1 + return count + + +class WanVAE_(nn.Module): + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + # modules + self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout) + self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) + self.conv2 = CausalConv3d(z_dim, z_dim, 1) + self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout) + + self.cache_device = None + + def set_cache_device(self, device): + # set cache device + self.cache_device = device + self.encoder.set_cache_device(device) + self.decoder.set_cache_device(device) + + def forward(self, x): + mu, log_var = self.encode(x) + z = self.reparameterize(mu, log_var) + x_recon = self.decode(z) + return x_recon, mu, log_var + + def encode(self, x, scale): + self.clear_cache() + ## cache + t = x.shape[2] + iter_ = 1 + (t - 1) // 4 + # ## 对encode输入的x,按时间拆分为1、4、4、4.... + + # if self.cache_device is None: + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + else: + out_ = self.encoder( + x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx + ) + out = torch.cat([out, out_], 2) + # else: + # # VRAM optimization + # device = x.device + # clean_memory_on_device(device) + # outs = [] + # for i in range(iter_): + # self._enc_conv_idx = [0] + # if i == 0: + # out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + # else: + # out = self.encoder( + # x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx + # ) + # outs.append(out.to(self.cache_device)) + # out = torch.cat(outs, 2).to(device) + mu, log_var = self.conv1(out).chunk(2, dim=1) + if isinstance(scale[0], torch.Tensor): + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1) + else: + mu = (mu - scale[0]) * scale[1] + self.clear_cache() + return mu + + def decode(self, z, scale): + self.clear_cache() + # z: [b,c,t,h,w] + if isinstance(scale[0], torch.Tensor): + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1) + else: + z = z / scale[1] + scale[0] + iter_ = z.shape[2] + x = self.conv2(z) + + # if self.cache_device is None: + for i in range(iter_): + self._conv_idx = [0] + if i == 0: + out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + else: + out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) + # else: + # # VRAM optimization + # device = z.device + # x = x.to("cpu") + # clean_memory_on_device(device) + # outs = [] + # for i in range(iter_): + # self._conv_idx = [0] + 
# out = self.decoder(x[:, :, i : i + 1, :, :].to(device), feat_cache=self._feat_map, feat_idx=self._conv_idx).to( + # self.cache_device + # ) + # outs.append(out) + # out = torch.cat(outs, 2) # on cache_device + self.clear_cache() + return out + + def reparameterize(self, mu, log_var): + std = torch.exp(0.5 * log_var) + eps = torch.randn_like(std) + return eps * std + mu + + def sample(self, imgs, deterministic=False): + mu, log_var = self.encode(imgs) + if deterministic: + return mu + std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0)) + return mu + std * torch.randn_like(std) + + def clear_cache(self): + self._conv_num = count_conv3d(self.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + # cache encode + self._enc_conv_num = count_conv3d(self.encoder) + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + + +def _video_vae(pretrained_path=None, z_dim=None, device="cpu", **kwargs): + """ + Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL. + """ + # params + cfg = dict( + dim=96, + z_dim=z_dim, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[False, True, True], + dropout=0.0, + ) + cfg.update(**kwargs) + + # init model + with torch.device("meta"): + model = WanVAE_(**cfg) + + # load checkpoint + logging.info(f"loading {pretrained_path}") + if os.path.splitext(pretrained_path)[-1] == ".safetensors": + sd = load_file(pretrained_path) + model.load_state_dict(sd, strict=False, assign=True) + else: + model.load_state_dict(torch.load(pretrained_path, map_location=device, weights_only=True), assign=True) + + return model + + +class WanVAE: + + def __init__(self, z_dim=16, vae_path="cache/vae_step_411000.pth", dtype=torch.float, device="cuda", cache_device=None): + self.dtype = dtype + self.device = device + + mean = [ + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921, + ] + std = [ + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.9160, + ] + self.mean = torch.tensor(mean, dtype=dtype, device=device) + self.std = torch.tensor(std, dtype=dtype, device=device) + self.scale = [self.mean, 1.0 / self.std] + + # init model + self.model = ( + _video_vae( + pretrained_path=vae_path, + z_dim=z_dim, + ) + .eval() + .requires_grad_(False) + .to(device) + ) + if cache_device is not None: + self.model.set_cache_device(torch.device(cache_device)) + + def to_device(self, device): + self.device = device + self.model.to(device) + self.mean = self.mean.to(device) + self.std = self.std.to(device) + self.scale = [t.to(device) for t in self.scale] + + def to_dtype(self, dtype): + self.dtype = dtype + self.model.to(dtype=dtype) + self.mean = self.mean.to(dtype) + self.std = self.std.to(dtype) + self.scale = [t.to(dtype) for t in self.scale] + + def eval(self): + self.model.eval() + + def train(self, mode: bool = True): + self.model.train(mode) + + def requires_grad_(self, requires_grad: bool = True): + self.model.requires_grad_(requires_grad) + + def to(self, device_or_dtype: Union[torch.device, torch.dtype, str], dtype: Optional[torch.dtype] = None): + """ + Add nn.Module.to() support for device and dtype. 
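        Accepts `.to(device)`, `.to(dtype)`, or `.to(device, dtype)`; a string or `torch.device` argument is treated as a device, a `torch.dtype` as a dtype.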
+ """ + if isinstance(device_or_dtype, str) or isinstance(device_or_dtype, torch.device): + self.to_device(device_or_dtype) + else: + self.to_dtype(device_or_dtype) + + if dtype is not None: + self.to_dtype(dtype) + + def encode(self, videos): + """ + videos: A list of videos each with shape [C, T, H, W]. + """ + # with amp.autocast(dtype=self.dtype): + return [self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) for u in videos] + + def decode(self, zs): + # with amp.autocast(dtype=self.dtype): + return [self.model.decode(u.unsqueeze(0), self.scale).float().clamp_(-1, 1).squeeze(0) for u in zs] diff --git a/wan/modules/xlm_roberta.py b/wan/modules/xlm_roberta.py new file mode 100644 index 0000000000000000000000000000000000000000..4bd38c1016fdaec90b77a6222d75d01c38c1291c --- /dev/null +++ b/wan/modules/xlm_roberta.py @@ -0,0 +1,170 @@ +# Modified from transformers.models.xlm_roberta.modeling_xlm_roberta +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['XLMRoberta', 'xlm_roberta_large'] + + +class SelfAttention(nn.Module): + + def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.eps = eps + + # layers + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, mask): + """ + x: [B, L, C]. + """ + b, s, c, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3) + k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3) + v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3) + + # compute attention + p = self.dropout.p if self.training else 0.0 + x = F.scaled_dot_product_attention(q, k, v, mask, p) + x = x.permute(0, 2, 1, 3).reshape(b, s, c) + + # output + x = self.o(x) + x = self.dropout(x) + return x + + +class AttentionBlock(nn.Module): + + def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.post_norm = post_norm + self.eps = eps + + # layers + self.attn = SelfAttention(dim, num_heads, dropout, eps) + self.norm1 = nn.LayerNorm(dim, eps=eps) + self.ffn = nn.Sequential( + nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim), + nn.Dropout(dropout)) + self.norm2 = nn.LayerNorm(dim, eps=eps) + + def forward(self, x, mask): + if self.post_norm: + x = self.norm1(x + self.attn(x, mask)) + x = self.norm2(x + self.ffn(x)) + else: + x = x + self.attn(self.norm1(x), mask) + x = x + self.ffn(self.norm2(x)) + return x + + +class XLMRoberta(nn.Module): + """ + XLMRobertaModel with no pooler and no LM head. 
+ """ + + def __init__(self, + vocab_size=250002, + max_seq_len=514, + type_size=1, + pad_id=1, + dim=1024, + num_heads=16, + num_layers=24, + post_norm=True, + dropout=0.1, + eps=1e-5): + super().__init__() + self.vocab_size = vocab_size + self.max_seq_len = max_seq_len + self.type_size = type_size + self.pad_id = pad_id + self.dim = dim + self.num_heads = num_heads + self.num_layers = num_layers + self.post_norm = post_norm + self.eps = eps + + # embeddings + self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id) + self.type_embedding = nn.Embedding(type_size, dim) + self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id) + self.dropout = nn.Dropout(dropout) + + # blocks + self.blocks = nn.ModuleList([ + AttentionBlock(dim, num_heads, post_norm, dropout, eps) + for _ in range(num_layers) + ]) + + # norm layer + self.norm = nn.LayerNorm(dim, eps=eps) + + def forward(self, ids): + """ + ids: [B, L] of torch.LongTensor. + """ + b, s = ids.shape + mask = ids.ne(self.pad_id).long() + + # embeddings + x = self.token_embedding(ids) + \ + self.type_embedding(torch.zeros_like(ids)) + \ + self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask) + if self.post_norm: + x = self.norm(x) + x = self.dropout(x) + + # blocks + mask = torch.where( + mask.view(b, 1, 1, s).gt(0), 0.0, + torch.finfo(x.dtype).min) + for block in self.blocks: + x = block(x, mask) + + # output + if not self.post_norm: + x = self.norm(x) + return x + + +def xlm_roberta_large(pretrained=False, + return_tokenizer=False, + device='cpu', + **kwargs): + """ + XLMRobertaLarge adapted from Huggingface. + """ + # params + cfg = dict( + vocab_size=250002, + max_seq_len=514, + type_size=1, + pad_id=1, + dim=1024, + num_heads=16, + num_layers=24, + post_norm=True, + dropout=0.1, + eps=1e-5) + cfg.update(**kwargs) + + # init a model on device + with torch.device(device): + model = XLMRoberta(**cfg) + return model diff --git a/wan/text2video.py b/wan/text2video.py new file mode 100644 index 0000000000000000000000000000000000000000..9daf89cbf93dbf131b9bdff07b8a6004bd791e8a --- /dev/null +++ b/wan/text2video.py @@ -0,0 +1,335 @@ +# Modified from official implementation + +# Original source: +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. + +import logging +import math +import os +import random +import sys +from typing import Optional, Union + +import torch +from tqdm import tqdm +from accelerate import Accelerator, init_empty_weights +from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler +from utils.safetensors_utils import load_safetensors + +# from .distributed.fsdp import shard_model +from .modules.model import WanModel, load_wan_model +from .modules.t5 import T5EncoderModel +from .utils.fm_solvers import FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps +from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler + +from utils.device_utils import clean_memory_on_device, synchronize_device + +import logging + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +class WanT2V: + + def __init__( + self, + config, + checkpoint_dir, + device_id=0, + rank=0, + t5_fsdp=False, + dit_fsdp=False, + use_usp=False, + t5_cpu=False, + device=None, + dit_dtype=None, + dit_weight_dtype=None, + dit_path=None, + dit_attn_mode=None, + t5_path=None, + t5_fp8=False, + ): + r""" + Initializes the Wan text-to-video generation model components. 
+ + Args: + config (EasyDict): + Object containing model parameters initialized from config.py + checkpoint_dir (`str`): + Path to directory containing model checkpoints + device_id (`int`, *optional*, defaults to 0) **IGNORED**: + Id of target GPU device + rank (`int`, *optional*, defaults to 0) **IGNORED**: + Process rank for distributed training + t5_fsdp (`bool`, *optional*, defaults to False) **IGNORED**: + Enable FSDP sharding for T5 model + dit_fsdp (`bool`, *optional*, defaults to False) **IGNORED**: + Enable FSDP sharding for DiT model + use_usp (`bool`, *optional*, defaults to False) **IGNORED**: + Enable distribution strategy of USP. + t5_cpu (`bool`, *optional*, defaults to False) **IGNORED**: + Whether to place T5 model on CPU. Only works without t5_fsdp. + device (`torch.device`, *optional*, defaults to None): + Device to place the model on. If None, use the default device (cuda) + dtype (`torch.dtype`, *optional*, defaults to None): + Data type for DiT model parameters. If None, use the default parameter data type from config + dit_path (`str`, *optional*, defaults to None): + Path to DiT model checkpoint. checkpoint_dir is used if None. + dit_attn_mode (`str`, *optional*, defaults to None): + Attention mode for DiT model. If None, use "torch" attention mode. + t5_path (`str`, *optional*, defaults to None): + Path to T5 model checkpoint. checkpoint_dir is used if None. + t5_fp8 (`bool`, *optional*, defaults to False): + Enable FP8 quantization for T5 model + """ + self.device = device if device is not None else torch.device("cuda") + self.config = config + self.rank = rank + self.t5_cpu = t5_cpu + self.t5_fp8 = t5_fp8 + + self.num_train_timesteps = config.num_train_timesteps + self.param_dtype = config.param_dtype + + # shard_fn = partial(shard_model, device_id=device_id) + checkpoint_path = None if checkpoint_dir is None else os.path.join(checkpoint_dir, config.t5_checkpoint) + tokenizer_path = None if checkpoint_dir is None else os.path.join(checkpoint_dir, config.t5_tokenizer) + self.text_encoder = T5EncoderModel( + text_len=config.text_len, + dtype=config.t5_dtype, + device=device, + checkpoint_path=checkpoint_path, + tokenizer_path=tokenizer_path, + weight_path=t5_path, + fp8=t5_fp8, + # shard_fn=shard_fn if t5_fsdp else None, + ) + + self.vae_stride = config.vae_stride + self.patch_size = config.patch_size + + self.checkpoint_dir = checkpoint_dir + self.dit_path = dit_path + self.dit_dtype = dit_dtype # if dtype is not None else config.param_dtype + self.dit_weight_dtype = dit_weight_dtype + self.dit_attn_mode = dit_attn_mode + + self.sample_neg_prompt = config.sample_neg_prompt + + def generate( + self, + accelerator: Accelerator, + merge_lora: Optional[callable], + fp8_scaled: bool, + input_prompt, + size=(1280, 720), + frame_num=81, + shift=5.0, + sample_solver="unipc", + sampling_steps=50, + guide_scale=5.0, + n_prompt="", + seed=-1, + blocks_to_swap=0, + ): + r""" + Generates video frames from text prompt using diffusion process. + + Args: + input_prompt (`str`): + Text prompt for content generation + size (tupele[`int`], *optional*, defaults to (1280,720)): + Controls video resolution, (width,height). + frame_num (`int`, *optional*, defaults to 81): + How many frames to sample from a video. The number should be 4n+1 + shift (`float`, *optional*, defaults to 5.0): + Noise schedule shift parameter. Affects temporal dynamics + sample_solver (`str`, *optional*, defaults to 'unipc'): + Solver used to sample the video. 
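                Supported values are 'unipc', 'dpm++' and 'vanilla'.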
+ sampling_steps (`int`, *optional*, defaults to 40): + Number of diffusion sampling steps. Higher values improve quality but slow generation + guide_scale (`float`, *optional*, defaults 5.0): + Classifier-free guidance scale. Controls prompt adherence vs. creativity + n_prompt (`str`, *optional*, defaults to ""): + Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` + seed (`int`, *optional*, defaults to -1): + Random seed for noise generation. If -1, use random seed. + blocks_to_swap (`int`, *optional*, defaults to 0): + Number of blocks to swap (offload) to CPU. If 0, no blocks are offloaded. + + Returns: + torch.Tensor: + Generated video frames tensor. Dimensions: (C, N H, W) where: + - C: Color channels (3 for RGB) + - N: Number of frames (81) + - H: Frame height (from size) + - W: Frame width from size) + """ + # preprocess + F = frame_num + # self.vae.model.z_dim == 16 + target_shape = (16, (F - 1) // self.vae_stride[0] + 1, size[1] // self.vae_stride[1], size[0] // self.vae_stride[2]) + + seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.patch_size[1] * self.patch_size[2]) * target_shape[1]) + + if n_prompt == "": + n_prompt = self.sample_neg_prompt + seed = seed if seed >= 0 else random.randint(0, sys.maxsize) + seed_g = torch.Generator(device=self.device) + seed_g.manual_seed(seed) + + self.text_encoder.model.to(self.device) + with torch.no_grad(): + if self.t5_fp8: + with accelerator.autocast(): + context = self.text_encoder([input_prompt], self.device) + context_null = self.text_encoder([n_prompt], self.device) + else: + context = self.text_encoder([input_prompt], self.device) + context_null = self.text_encoder([n_prompt], self.device) + + del self.text_encoder + clean_memory_on_device(self.device) + + # load DiT model + loading_device = "cpu" + if blocks_to_swap == 0 and merge_lora is None and not fp8_scaled: + loading_device = self.device + + loading_weight_dtype = self.dit_weight_dtype + if fp8_scaled or merge_lora is not None: + loading_weight_dtype = self.dit_dtype # load as-is + + # set fp8_scaled to False, because we optimize the model after merging LoRA + # TODO state dict based LoRA merge + self.model: WanModel = load_wan_model( + self.config, + False, + self.device, + self.dit_path, + self.dit_attn_mode, + False, + loading_device, + loading_weight_dtype, + False, + ) + + if merge_lora is not None: + # merge LoRA to the model, cast and move to the device + merge_lora(self.model) + + if fp8_scaled: + state_dict = self.model.state_dict() + move_to_device = blocks_to_swap == 0 # if blocks_to_swap > 0, we will keep the model on CPU + state_dict = self.model.fp8_optimization(state_dict, self.device, move_to_device) + info = self.model.load_state_dict(state_dict, strict=True, assign=True) + logger.info(f"Loaded FP8 optimized weights: {info}") + if blocks_to_swap == 0: + self.model.to(self.device) # make sure all parameters are on the right device + else: + target_dtype = None + target_device = None + if self.dit_weight_dtype is not None: # in case of args.fp8 (not fp8_scaled) + logger.info(f"Convert model to {self.dit_weight_dtype}") + target_dtype = self.dit_weight_dtype + if blocks_to_swap == 0: + logger.info(f"Move model to device: {self.device}") + target_device = self.device + self.model.to(target_device, target_dtype) + + if blocks_to_swap > 0: + logger.info(f"Enable swap {blocks_to_swap} blocks to CPU from device: {self.device}") + self.model.enable_block_swap(blocks_to_swap, self.device, supports_backward=False) + 
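            # Note (added for clarity): with block swapping enabled, `blocks_to_swap` of the DiT's
            # transformer blocks are kept in CPU memory and moved to the GPU only when the forward
            # pass reaches them (see `enable_block_swap` / `prepare_block_swap_before_forward` on the
            # model), trading generation speed for a lower peak VRAM footprint.
            # `supports_backward=False` because this code path only runs inference.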
self.model.move_to_device_except_swap_blocks(self.device) + self.model.prepare_block_swap_before_forward() + else: + # make sure the model is on the right device + self.model.to(self.device) + + self.model.eval().requires_grad_(False) + clean_memory_on_device(self.device) + + noise = [ + torch.randn( + target_shape[0], + target_shape[1], + target_shape[2], + target_shape[3], + dtype=torch.float32, + device=self.device, + generator=seed_g, + ) + ] + + # evaluation mode + # with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync(): + with accelerator.autocast(), torch.no_grad(): + if sample_solver == "unipc": + sample_scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, shift=1, use_dynamic_shifting=False + ) + sample_scheduler.set_timesteps(sampling_steps, device=self.device, shift=shift) + timesteps = sample_scheduler.timesteps + elif sample_solver == "dpm++": + sample_scheduler = FlowDPMSolverMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, shift=1, use_dynamic_shifting=False + ) + sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) + timesteps, _ = retrieve_timesteps(sample_scheduler, device=self.device, sigmas=sampling_sigmas) + elif sample_solver == "vanilla": + sample_scheduler = FlowMatchDiscreteScheduler(num_train_timesteps=self.num_train_timesteps, shift=shift) + sample_scheduler.set_timesteps(sampling_steps, device=self.device) + timesteps = sample_scheduler.timesteps + + org_step = sample_scheduler.step + + def step_wrapper( + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor, + return_dict: bool = True, + generator=None, + ): + return org_step(model_output, timestep, sample, return_dict=return_dict) + + sample_scheduler.step = step_wrapper + else: + raise NotImplementedError("Unsupported solver.") + + # sample videos + latents = noise + del noise + + arg_c = {"context": context, "seq_len": seq_len} + arg_null = {"context": context_null, "seq_len": seq_len} + + for _, t in enumerate(tqdm(timesteps)): + latent_model_input = latents + timestep = [t] + + timestep = torch.stack(timestep) + + noise_pred_cond = self.model(latent_model_input, t=timestep, **arg_c)[0] + noise_pred_uncond = self.model(latent_model_input, t=timestep, **arg_null)[0] + + noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond) + del noise_pred_cond, noise_pred_uncond + + temp_x0 = sample_scheduler.step( + noise_pred.unsqueeze(0), t, latents[0].unsqueeze(0), return_dict=False, generator=seed_g + )[0] + del noise_pred + latents = [temp_x0.squeeze(0)] + del temp_x0 + + x0 = latents + + del latents + del sample_scheduler + del self.model + synchronize_device(self.device) + clean_memory_on_device(self.device) + + # return latents + return x0[0] diff --git a/wan/utils/__init__.py b/wan/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6e9a339e69fd55dd226d3ce242613c19bd690522 --- /dev/null +++ b/wan/utils/__init__.py @@ -0,0 +1,8 @@ +from .fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas, + retrieve_timesteps) +from .fm_solvers_unipc import FlowUniPCMultistepScheduler + +__all__ = [ + 'HuggingfaceTokenizer', 'get_sampling_sigmas', 'retrieve_timesteps', + 'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler' +] diff --git a/wan/utils/fm_solvers.py b/wan/utils/fm_solvers.py new file mode 100644 index 0000000000000000000000000000000000000000..c908969e24849ce1381a8df9d5eb401dccf66524 --- /dev/null +++ 
b/wan/utils/fm_solvers.py @@ -0,0 +1,857 @@ +# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +# Convert dpm solver for flow matching +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. + +import inspect +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers, + SchedulerMixin, + SchedulerOutput) +from diffusers.utils import deprecate, is_scipy_available +from diffusers.utils.torch_utils import randn_tensor + +if is_scipy_available(): + pass + + +def get_sampling_sigmas(sampling_steps, shift): + sigma = np.linspace(1, 0, sampling_steps + 1)[:sampling_steps] + sigma = (shift * sigma / (1 + (shift - 1) * sigma)) + + return sigma + + +def retrieve_timesteps( + scheduler, + num_inference_steps=None, + device=None, + timesteps=None, + sigmas=None, + **kwargs, +): + if timesteps is not None and sigmas is not None: + raise ValueError( + "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" + ) + if timesteps is not None: + accepts_timesteps = "timesteps" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FlowDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + `FlowDPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs. + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. This determines the resolution of the diffusion process. + solver_order (`int`, defaults to 2): + The DPMSolver order which can be `1`, `2`, or `3`. It is recommended to use `solver_order=2` for guided + sampling, and `solver_order=3` for unconditional sampling. This affects the number of model outputs stored + and used in multistep updates. + prediction_type (`str`, defaults to "flow_prediction"): + Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts + the flow of the diffusion process. + shift (`float`, *optional*, defaults to 1.0): + A factor used to adjust the sigmas in the noise schedule. 
It modifies the step sizes during the sampling + process. + use_dynamic_shifting (`bool`, defaults to `False`): + Whether to apply dynamic shifting to the timesteps based on image resolution. If `True`, the shifting is + applied on the fly. + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This method adjusts the predicted sample to prevent + saturation and improve photorealism. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and + `algorithm_type="dpmsolver++"`. + algorithm_type (`str`, defaults to `dpmsolver++`): + Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The + `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) + paper, and the `dpmsolver++` type implements the algorithms in the + [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or + `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion. + solver_type (`str`, defaults to `midpoint`): + Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the + sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. + lower_order_final (`bool`, defaults to `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + euler_at_final (`bool`, defaults to `False`): + Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail + richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference + steps, but sometimes may result in blurring. + final_sigmas_type (`str`, *optional*, defaults to "zero"): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. + lambda_min_clipped (`float`, defaults to `-inf`): + Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the + cosine (`squaredcos_cap_v2`) noise schedule. + variance_type (`str`, *optional*): + Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output + contains the predicted Gaussian variance. 
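        Note (illustrative, not in the upstream docstring): when `use_dynamic_shifting=False`, each
        sigma is remapped as
            sigma_shifted = shift * sigma / (1 + (shift - 1) * sigma),
        e.g. `shift=5.0` maps sigma = 0.5 to 2.5 / 3 ≈ 0.833, biasing the schedule toward higher
        noise levels.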
+ """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + solver_order: int = 2, + prediction_type: str = "flow_prediction", + shift: Optional[float] = 1.0, + use_dynamic_shifting=False, + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + algorithm_type: str = "dpmsolver++", + solver_type: str = "midpoint", + lower_order_final: bool = True, + euler_at_final: bool = False, + final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + lambda_min_clipped: float = -float("inf"), + variance_type: Optional[str] = None, + invert_sigmas: bool = False, + ): + if algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead" + deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", + deprecation_message) + + # settings for DPM-Solver + if algorithm_type not in [ + "dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++" + ]: + if algorithm_type == "deis": + self.register_to_config(algorithm_type="dpmsolver++") + else: + raise NotImplementedError( + f"{algorithm_type} is not implemented for {self.__class__}") + + if solver_type not in ["midpoint", "heun"]: + if solver_type in ["logrho", "bh1", "bh2"]: + self.register_to_config(solver_type="midpoint") + else: + raise NotImplementedError( + f"{solver_type} is not implemented for {self.__class__}") + + if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++" + ] and final_sigmas_type == "zero": + raise ValueError( + f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead." + ) + + # setable values + self.num_inference_steps = None + alphas = np.linspace(1, 1 / num_train_timesteps, + num_train_timesteps)[::-1].copy() + sigmas = 1.0 - alphas + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) + + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + + (shift - 1) * sigmas) # pyright: ignore + + self.sigmas = sigmas + self.timesteps = sigmas * num_train_timesteps + + self.model_outputs = [None] * solver_order + self.lower_order_nums = 0 + self._step_index = None + self._begin_index = None + + # self.sigmas = self.sigmas.to( + # "cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + Args: + begin_index (`int`): + The begin index for the scheduler. 
+ """ + self._begin_index = begin_index + + # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps + def set_timesteps( + self, + num_inference_steps: Union[int, None] = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[Union[float, None]] = None, + shift: Optional[Union[float, None]] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + Args: + num_inference_steps (`int`): + Total number of the spacing of the time steps. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + + if self.config.use_dynamic_shifting and mu is None: + raise ValueError( + " you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`" + ) + + if sigmas is None: + sigmas = np.linspace(self.sigma_max, self.sigma_min, + num_inference_steps + + 1).copy()[:-1] # pyright: ignore + + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore + else: + if shift is None: + shift = self.config.shift + sigmas = shift * sigmas / (1 + + (shift - 1) * sigmas) # pyright: ignore + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / + self.alphas_cumprod[0])**0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + timesteps = sigmas * self.config.num_train_timesteps + sigmas = np.concatenate([sigmas, [sigma_last] + ]).astype(np.float32) # pyright: ignore + + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to( + device=device, dtype=torch.int64) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + + self._step_index = None + self._begin_index = None + # self.sigmas = self.sigmas.to( + # "cpu") # to avoid too much CPU/GPU communication + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." 
+ https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float( + ) # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile( + abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze( + 1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp( + sample, -s, s + ) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def _sigma_to_alpha_sigma_t(self, sigma): + return 1 - sigma, sigma + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma) + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output + def convert_model_output( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + """ + Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is + designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an + integral of the data prediction model. + + The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise + prediction and data prediction models. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + Returns: + `torch.Tensor`: + The converted model output. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError( + "missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + # DPM-Solver++ needs to solve an integral of the data prediction model. + if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction`, or `flow_prediction` for the FlowDPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + + # DPM-Solver needs to solve an integral of the noise prediction model. 
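        # For this flow-matching parameterization alpha_t = 1 - sigma_t, so the noise estimate
        # below is epsilon = sample - (1 - sigma_t) * model_output, the counterpart of
        # x0 = sample - sigma_t * model_output used in the DPM-Solver++ branch above.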
+ elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + epsilon = sample - (1 - sigma_t) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the FlowDPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + x0_pred = self._threshold_sample(x0_pred) + epsilon = model_output + x0_pred + + return epsilon + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.dpm_solver_first_order_update + def dpm_solver_first_order_update( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the first-order DPMSolver (equivalent to DDIM). + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop( + "prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError( + " missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[ + self.step_index] # pyright: ignore + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + + h = lambda_t - lambda_s + if self.config.algorithm_type == "dpmsolver++": + x_t = (sigma_t / + sigma_s) * sample - (alpha_t * + (torch.exp(-h) - 1.0)) * model_output + elif self.config.algorithm_type == "dpmsolver": + x_t = (alpha_t / + alpha_s) * sample - (sigma_t * + (torch.exp(h) - 1.0)) * model_output + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + x_t = ((sigma_t / sigma_s * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + x_t = ((alpha_t / alpha_s) * sample - 2.0 * + (sigma_t * (torch.exp(h) - 1.0)) * model_output + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise) + return x_t # pyright: ignore + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update + def multistep_dpm_solver_second_order_update( + self, + model_output_list: List[torch.Tensor], + *args, + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + One 
step for the second-order multistep DPMSolver. + Args: + model_output_list (`List[torch.Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + timestep_list = args[0] if len(args) > 0 else kwargs.pop( + "timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop( + "prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError( + " missing `sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1 = ( + self.sigmas[self.step_index + 1], # pyright: ignore + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], # pyright: ignore + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + + m0, m1 = model_output_list[-1], model_output_list[-2] + + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m0, (1.0 / r0) * (m0 - m1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2211.01095 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ((sigma_t / sigma_s0) * sample - + (alpha_t * (torch.exp(-h) - 1.0)) * D0 - 0.5 * + (alpha_t * (torch.exp(-h) - 1.0)) * D1) + elif self.config.solver_type == "heun": + x_t = ((sigma_t / sigma_s0) * sample - + (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ((alpha_t / alpha_s0) * sample - + (sigma_t * (torch.exp(h) - 1.0)) * D0 - 0.5 * + (sigma_t * (torch.exp(h) - 1.0)) * D1) + elif self.config.solver_type == "heun": + x_t = ((alpha_t / alpha_s0) * sample - + (sigma_t * (torch.exp(h) - 1.0)) * D0 - + (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1) + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ((sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + 0.5 * + (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise) + elif self.config.solver_type == "heun": + x_t = ((sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / + (-2.0 * h) + 1.0)) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ((alpha_t / alpha_s0) * sample - 2.0 * + 
(sigma_t * (torch.exp(h) - 1.0)) * D0 - + (sigma_t * (torch.exp(h) - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise) + elif self.config.solver_type == "heun": + x_t = ((alpha_t / alpha_s0) * sample - 2.0 * + (sigma_t * (torch.exp(h) - 1.0)) * D0 - 2.0 * + (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise) + return x_t # pyright: ignore + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update + def multistep_dpm_solver_third_order_update( + self, + model_output_list: List[torch.Tensor], + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the third-order multistep DPMSolver. + Args: + model_output_list (`List[torch.Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.Tensor`): + A current instance of a sample created by diffusion process. + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + + timestep_list = args[0] if len(args) > 0 else kwargs.pop( + "timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop( + "prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError( + " missing`sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( + self.sigmas[self.step_index + 1], # pyright: ignore + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], # pyright: ignore + self.sigmas[self.step_index - 2], # pyright: ignore + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2) + + m0, m1, m2 = model_output_list[-1], model_output_list[ + -2], model_output_list[-3] + + h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 + r0, r1 = h_0 / h, h_1 / h + D0 = m0 + D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ((sigma_t / sigma_s0) * sample - + (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 - + (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ((alpha_t / alpha_s0) * sample - (sigma_t * + (torch.exp(h) - 1.0)) * D0 - + (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - + (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2) 
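+        # For reference, the update computed by the two branches above, with
+        # h = lambda_t - lambda_s0 and D0, D1, D2 as defined earlier, is:
+        #   dpmsolver++: x_t = (sigma_t/sigma_s0) * sample
+        #                      - alpha_t * (e^{-h} - 1) * D0
+        #                      + alpha_t * ((e^{-h} - 1)/h + 1) * D1
+        #                      - alpha_t * ((e^{-h} - 1 + h)/h**2 - 0.5) * D2
+        #   dpmsolver:   x_t = (alpha_t/alpha_s0) * sample
+        #                      - sigma_t * (e^{h} - 1) * D0
+        #                      - sigma_t * ((e^{h} - 1)/h - 1) * D1
+        #                      - sigma_t * ((e^{h} - 1 - h)/h**2 - 0.5) * D2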
+ return x_t # pyright: ignore + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + # Modified from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.step + def step( + self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor, + generator=None, + variance_noise: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep DPMSolver. + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + variance_noise (`torch.Tensor`): + Alternative to generating noise with `generator` by directly providing the noise for the variance + itself. Useful for methods such as [`LEdits++`]. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. 
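+
+        Example (illustrative sketch only; `denoise` is a placeholder for the caller's model
+        call, and `set_timesteps` is assumed to take `num_inference_steps` and `device` as in
+        the upstream diffusers scheduler this file is adapted from):
+
+            >>> scheduler.set_timesteps(num_inference_steps=50, device="cuda")
+            >>> for t in scheduler.timesteps:
+            ...     latents = scheduler.step(denoise(latents, t), t, latents, return_dict=False)[0]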
+ """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Improve numerical stability for small number of steps + lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( + self.config.euler_at_final or + (self.config.lower_order_final and len(self.timesteps) < 15) or + self.config.final_sigmas_type == "zero") + lower_order_second = ((self.step_index == len(self.timesteps) - 2) and + self.config.lower_order_final and + len(self.timesteps) < 15) + + model_output = self.convert_model_output(model_output, sample=sample) + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.model_outputs[-1] = model_output + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++" + ] and variance_noise is None: + noise = randn_tensor( + model_output.shape, + generator=generator, + device=model_output.device, + dtype=torch.float32) + elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: + noise = variance_noise.to( + device=model_output.device, + dtype=torch.float32) # pyright: ignore + else: + noise = None + + if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: + prev_sample = self.dpm_solver_first_order_update( + model_output, sample=sample, noise=noise) + elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: + prev_sample = self.multistep_dpm_solver_second_order_update( + self.model_outputs, sample=sample, noise=noise) + else: + prev_sample = self.multistep_dpm_solver_third_order_update( + self.model_outputs, sample=sample) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # Cast sample back to expected dtype + prev_sample = prev_sample.to(model_output.dtype) + + # upon completion increase step index by one + self._step_index += 1 # pyright: ignore + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input + def scale_model_input(self, sample: torch.Tensor, *args, + **kwargs) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + Args: + sample (`torch.Tensor`): + The input sample. + Returns: + `torch.Tensor`: + A scaled input sample. 
+ """ + return sample + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to( + device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point( + timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to( + original_samples.device, dtype=torch.float32) + timesteps = timesteps.to( + original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [ + self.index_for_timestep(t, schedule_timesteps) + for t in timesteps + ] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/wan/utils/fm_solvers_unipc.py b/wan/utils/fm_solvers_unipc.py new file mode 100644 index 0000000000000000000000000000000000000000..57321baa35359782b33143321cd31c8d934a7b29 --- /dev/null +++ b/wan/utils/fm_solvers_unipc.py @@ -0,0 +1,800 @@ +# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py +# Convert unipc for flow matching +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers, + SchedulerMixin, + SchedulerOutput) +from diffusers.utils import deprecate, is_scipy_available + +if is_scipy_available(): + import scipy.stats + + +class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + solver_order (`int`, default `2`): + The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1` + due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for + unconditional sampling. 
+ prediction_type (`str`, defaults to "flow_prediction"): + Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts + the flow of the diffusion process. + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`. + predict_x0 (`bool`, defaults to `True`): + Whether to use the updating algorithm on the predicted x0. + solver_type (`str`, default `bh2`): + Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2` + otherwise. + lower_order_final (`bool`, default `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + disable_corrector (`list`, default `[]`): + Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)` + and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is + usually disabled during the first few steps. + solver_p (`SchedulerMixin`, default `None`): + Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. + use_exponential_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + final_sigmas_type (`str`, defaults to `"zero"`): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. 
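+
+    Example (illustrative sketch; `denoise` is a placeholder for the caller's model call and
+    `latents` for the noisy sample being denoised, neither of which is part of this class):
+
+        >>> scheduler = FlowUniPCMultistepScheduler(num_train_timesteps=1000, shift=1.0)
+        >>> scheduler.set_timesteps(num_inference_steps=50, device="cuda", shift=5.0)
+        >>> for t in scheduler.timesteps:
+        ...     latents = scheduler.step(denoise(latents, t), t, latents, return_dict=False)[0]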
+ """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + solver_order: int = 2, + prediction_type: str = "flow_prediction", + shift: Optional[float] = 1.0, + use_dynamic_shifting=False, + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + predict_x0: bool = True, + solver_type: str = "bh2", + lower_order_final: bool = True, + disable_corrector: List[int] = [], + solver_p: SchedulerMixin = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + ): + + if solver_type not in ["bh1", "bh2"]: + if solver_type in ["midpoint", "heun", "logrho"]: + self.register_to_config(solver_type="bh2") + else: + raise NotImplementedError( + f"{solver_type} is not implemented for {self.__class__}") + + self.predict_x0 = predict_x0 + # setable values + self.num_inference_steps = None + alphas = np.linspace(1, 1 / num_train_timesteps, + num_train_timesteps)[::-1].copy() + sigmas = 1.0 - alphas + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) + + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + + (shift - 1) * sigmas) # pyright: ignore + + self.sigmas = sigmas + self.timesteps = sigmas * num_train_timesteps + + self.model_outputs = [None] * solver_order + self.timestep_list = [None] * solver_order + self.lower_order_nums = 0 + self.disable_corrector = disable_corrector + self.solver_p = solver_p + self.last_sample = None + self._step_index = None + self._begin_index = None + + self.sigmas = self.sigmas.to( + "cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps + def set_timesteps( + self, + num_inference_steps: Union[int, None] = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[Union[float, None]] = None, + shift: Optional[Union[float, None]] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + Args: + num_inference_steps (`int`): + Total number of the spacing of the time steps. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
+ """ + + if self.config.use_dynamic_shifting and mu is None: + raise ValueError( + " you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`" + ) + + if sigmas is None: + sigmas = np.linspace(self.sigma_max, self.sigma_min, + num_inference_steps + + 1).copy()[:-1] # pyright: ignore + + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore + else: + if shift is None: + shift = self.config.shift + sigmas = shift * sigmas / (1 + + (shift - 1) * sigmas) # pyright: ignore + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / + self.alphas_cumprod[0])**0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + timesteps = sigmas * self.config.num_train_timesteps + sigmas = np.concatenate([sigmas, [sigma_last] + ]).astype(np.float32) # pyright: ignore + + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to( + device=device, dtype=torch.int64) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + self.last_sample = None + if self.solver_p: + self.solver_p.set_timesteps(self.num_inference_steps, device=device) + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to( + "cpu") # to avoid too much CPU/GPU communication + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." 
+ + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float( + ) # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile( + abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze( + 1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp( + sample, -s, s + ) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def _sigma_to_alpha_sigma_t(self, sigma): + return 1 - sigma, sigma + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma) + + def convert_model_output( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + r""" + Convert the model output to the corresponding type the UniPC algorithm needs. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The converted model output. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError( + "missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + + if self.predict_x0: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + else: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + epsilon = sample - (1 - sigma_t) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." 
+ ) + + if self.config.thresholding: + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + x0_pred = self._threshold_sample(x0_pred) + epsilon = model_output + x0_pred + + return epsilon + + def multistep_uni_p_bh_update( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + order: int = None, # pyright: ignore + **kwargs, + ) -> torch.Tensor: + """ + One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model at the current timestep. + prev_timestep (`int`): + The previous discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + order (`int`): + The order of UniP at this timestep (corresponds to the *p* in UniPC-p). + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + prev_timestep = args[0] if len(args) > 0 else kwargs.pop( + "prev_timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError( + " missing `sample` as a required keyward argument") + if order is None: + if len(args) > 2: + order = args[2] + else: + raise ValueError( + " missing `order` as a required keyward argument") + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + model_output_list = self.model_outputs + + s0 = self.timestep_list[-1] + m0 = model_output_list[-1] + x = sample + + if self.solver_p: + x_t = self.solver_p.step(model_output, s0, x).prev_sample + return x_t + + sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[ + self.step_index] # pyright: ignore + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + device = sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - i # pyright: ignore + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) # pyright: ignore + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) # (B, K) + # for order 2, we use a simplified version + if order == 2: + rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_p = torch.linalg.solve(R[:-1, :-1], + b[:-1]).to(device).to(x.dtype) + else: + D1s = None + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + pred_res = 
torch.einsum("k,bkc...->bc...", rhos_p, + D1s) # pyright: ignore + else: + pred_res = 0 + x_t = x_t_ - alpha_t * B_h * pred_res + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + pred_res = torch.einsum("k,bkc...->bc...", rhos_p, + D1s) # pyright: ignore + else: + pred_res = 0 + x_t = x_t_ - sigma_t * B_h * pred_res + + x_t = x_t.to(x.dtype) + return x_t + + def multistep_uni_c_bh_update( + self, + this_model_output: torch.Tensor, + *args, + last_sample: torch.Tensor = None, + this_sample: torch.Tensor = None, + order: int = None, # pyright: ignore + **kwargs, + ) -> torch.Tensor: + """ + One step for the UniC (B(h) version). + + Args: + this_model_output (`torch.Tensor`): + The model outputs at `x_t`. + this_timestep (`int`): + The current timestep `t`. + last_sample (`torch.Tensor`): + The generated sample before the last predictor `x_{t-1}`. + this_sample (`torch.Tensor`): + The generated sample after the last predictor `x_{t}`. + order (`int`): + The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`. + + Returns: + `torch.Tensor`: + The corrected sample tensor at the current timestep. + """ + this_timestep = args[0] if len(args) > 0 else kwargs.pop( + "this_timestep", None) + if last_sample is None: + if len(args) > 1: + last_sample = args[1] + else: + raise ValueError( + " missing`last_sample` as a required keyward argument") + if this_sample is None: + if len(args) > 2: + this_sample = args[2] + else: + raise ValueError( + " missing`this_sample` as a required keyward argument") + if order is None: + if len(args) > 3: + order = args[3] + else: + raise ValueError( + " missing`order` as a required keyward argument") + if this_timestep is not None: + deprecate( + "this_timestep", + "1.0.0", + "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + model_output_list = self.model_outputs + + m0 = model_output_list[-1] + x = last_sample + x_t = this_sample + model_t = this_model_output + + sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[ + self.step_index - 1] # pyright: ignore + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + device = this_sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - (i + 1) # pyright: ignore + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) # pyright: ignore + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) + else: + D1s = None + + # for order 1, we use a 
simplified version + if order == 1: + rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype) + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t) + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t) + x_t = x_t.to(x.dtype) + return x_t + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step(self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor, + return_dict: bool = True, + generator=None) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep UniPC. + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. 
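+
+        Each call first applies the UniC corrector to the stored previous sample (skipped on the
+        first step and whenever `self.step_index - 1` is listed in `disable_corrector`), then
+        advances the sample with the UniP predictor via `multistep_uni_p_bh_update`.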
+ + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + use_corrector = ( + self.step_index > 0 and + self.step_index - 1 not in self.disable_corrector and + self.last_sample is not None # pyright: ignore + ) + + model_output_convert = self.convert_model_output( + model_output, sample=sample) + if use_corrector: + sample = self.multistep_uni_c_bh_update( + this_model_output=model_output_convert, + last_sample=self.last_sample, + this_sample=sample, + order=self.this_order, + ) + + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.timestep_list[i] = self.timestep_list[i + 1] + + self.model_outputs[-1] = model_output_convert + self.timestep_list[-1] = timestep # pyright: ignore + + if self.config.lower_order_final: + this_order = min(self.config.solver_order, + len(self.timesteps) - + self.step_index) # pyright: ignore + else: + this_order = self.config.solver_order + + self.this_order = min(this_order, + self.lower_order_nums + 1) # warmup for multistep + assert self.this_order > 0 + + self.last_sample = sample + prev_sample = self.multistep_uni_p_bh_update( + model_output=model_output, # pass the original non-converted model output, in case solver-p is used + sample=sample, + order=self.this_order, + ) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # upon completion increase step index by one + self._step_index += 1 # pyright: ignore + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.Tensor, *args, + **kwargs) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + + Returns: + `torch.Tensor`: + A scaled input sample. 
+ """ + return sample + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to( + device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point( + timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to( + original_samples.device, dtype=torch.float32) + timesteps = timesteps.to( + original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [ + self.index_for_timestep(t, schedule_timesteps) + for t in timesteps + ] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/wan/utils/utils.py b/wan/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d72599967f0a5a491e722e7d7a942efe5137b210 --- /dev/null +++ b/wan/utils/utils.py @@ -0,0 +1,118 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import argparse +import binascii +import os +import os.path as osp + +import imageio +import torch +import torchvision + +__all__ = ['cache_video', 'cache_image', 'str2bool'] + + +def rand_name(length=8, suffix=''): + name = binascii.b2a_hex(os.urandom(length)).decode('utf-8') + if suffix: + if not suffix.startswith('.'): + suffix = '.' 
+ suffix + name += suffix + return name + + +def cache_video(tensor, + save_file=None, + fps=30, + suffix='.mp4', + nrow=8, + normalize=True, + value_range=(-1, 1), + retry=5): + # cache file + cache_file = osp.join('/tmp', rand_name( + suffix=suffix)) if save_file is None else save_file + + # save to cache + error = None + for _ in range(retry): + try: + # preprocess + tensor = tensor.clamp(min(value_range), max(value_range)) + tensor = torch.stack([ + torchvision.utils.make_grid( + u, nrow=nrow, normalize=normalize, value_range=value_range) + for u in tensor.unbind(2) + ], + dim=1).permute(1, 2, 3, 0) + tensor = (tensor * 255).type(torch.uint8).cpu() + + # write video + writer = imageio.get_writer( + cache_file, fps=fps, codec='libx264', quality=8) + for frame in tensor.numpy(): + writer.append_data(frame) + writer.close() + return cache_file + except Exception as e: + error = e + continue + else: + print(f'cache_video failed, error: {error}', flush=True) + return None + + +def cache_image(tensor, + save_file, + nrow=8, + normalize=True, + value_range=(-1, 1), + retry=5): + # cache file + suffix = osp.splitext(save_file)[1] + if suffix.lower() not in [ + '.jpg', '.jpeg', '.png', '.tiff', '.gif', '.webp' + ]: + suffix = '.png' + + # save to cache + error = None + for _ in range(retry): + try: + tensor = tensor.clamp(min(value_range), max(value_range)) + torchvision.utils.save_image( + tensor, + save_file, + nrow=nrow, + normalize=normalize, + value_range=value_range) + return save_file + except Exception as e: + error = e + continue + + +def str2bool(v): + """ + Convert a string to a boolean. + + Supported true values: 'yes', 'true', 't', 'y', '1' + Supported false values: 'no', 'false', 'f', 'n', '0' + + Args: + v (str): String to convert. + + Returns: + bool: Converted boolean value. + + Raises: + argparse.ArgumentTypeError: If the value cannot be converted to boolean. 
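+
+    Example:
+        >>> str2bool("yes")
+        True
+        >>> str2bool("0")
+        False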
+ """ + if isinstance(v, bool): + return v + v_lower = v.lower() + if v_lower in ('yes', 'true', 't', 'y', '1'): + return True + elif v_lower in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected (True/False)') diff --git a/wan_generate_video.py b/wan_generate_video.py new file mode 100644 index 0000000000000000000000000000000000000000..60228d902c1a47da01847a8726adc1b573411769 --- /dev/null +++ b/wan_generate_video.py @@ -0,0 +1,2510 @@ +# --- START OF FILE wanFUN_generate_video.py --- + +import argparse +from datetime import datetime +import gc +import random +import os +import re +import time +import math +from typing import Tuple, Optional, List, Union, Any +from pathlib import Path # Added for glob_images in V2V + +import torch +import accelerate +from accelerate import Accelerator +from safetensors.torch import load_file, save_file +from safetensors import safe_open +from PIL import Image +import cv2 # Added for V2V video loading/resizing +import numpy as np # Added for V2V video processing +import torchvision.transforms.functional as TF +from tqdm import tqdm + +from networks import lora_wan +from utils.safetensors_utils import mem_eff_save_file, load_safetensors +from wan.configs import WAN_CONFIGS, SUPPORTED_SIZES +import wan +from wan.modules.model import WanModel, load_wan_model, detect_wan_sd_dtype +from wan.modules.vae import WanVAE +from wan.modules.t5 import T5EncoderModel +from wan.modules.clip import CLIPModel +from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler +from wan.utils.fm_solvers import FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps +from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler + +from blissful_tuner.latent_preview import LatentPreviewer + +try: + from lycoris.kohya import create_network_from_weights +except: + pass + +from utils.model_utils import str_to_dtype +from utils.device_utils import clean_memory_on_device +# Original load_video/load_images are still needed for Fun-Control / image loading +from hv_generate_video import save_images_grid, save_videos_grid, synchronize_device, load_images as hv_load_images, load_video as hv_load_video + +import logging + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +def parse_args() -> argparse.Namespace: + """parse command line arguments""" + parser = argparse.ArgumentParser(description="Wan 2.1 inference script") + + # WAN arguments + parser.add_argument("--ckpt_dir", type=str, default=None, help="The path to the checkpoint directory (Wan 2.1 official).") + parser.add_argument("--task", type=str, default="t2v-14B", choices=list(WAN_CONFIGS.keys()), help="The task to run.") + parser.add_argument( + "--sample_solver", type=str, default="unipc", choices=["unipc", "dpm++", "vanilla"], help="The solver used to sample." 
+ ) + + parser.add_argument("--dit", type=str, default=None, help="DiT checkpoint path") + parser.add_argument("--vae", type=str, default=None, help="VAE checkpoint path") + parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is bfloat16") + parser.add_argument("--vae_cache_cpu", action="store_true", help="cache features in VAE on CPU") + parser.add_argument("--t5", type=str, default=None, help="text encoder (T5) checkpoint path") + parser.add_argument("--clip", type=str, default=None, help="text encoder (CLIP) checkpoint path") + # LoRA + parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path") + parser.add_argument("--lora_multiplier", type=float, nargs="*", default=1.0, help="LoRA multiplier") + parser.add_argument("--include_patterns", type=str, nargs="*", default=None, help="LoRA module include patterns") + parser.add_argument("--exclude_patterns", type=str, nargs="*", default=None, help="LoRA module exclude patterns") + parser.add_argument( + "--save_merged_model", + type=str, + default=None, + help="Save merged model to path. If specified, no inference will be performed.", + ) + + # inference + parser.add_argument("--prompt", type=str, required=True, help="prompt for generation") + parser.add_argument( + "--negative_prompt", + type=str, + default=None, + help="negative prompt for generation, use default negative prompt if not specified", + ) + parser.add_argument("--video_size", type=int, nargs=2, default=[256, 256], help="video size, height and width") + parser.add_argument("--video_length", type=int, default=None, help="video length, Default depends on task") + parser.add_argument("--fps", type=int, default=16, help="video fps, Default is 16") + parser.add_argument("--infer_steps", type=int, default=None, help="number of inference steps") + parser.add_argument("--save_path", type=str, required=True, help="path to save generated video") + parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.") + parser.add_argument( + "--cpu_noise", action="store_true", help="Use CPU to generate noise (compatible with ComfyUI). Default is False." + ) + parser.add_argument( + "--guidance_scale", + type=float, + default=5.0, + help="Guidance scale for classifier free guidance. Default is 5.0.", + ) + # V2V arguments + parser.add_argument("--video_path", type=str, default=None, help="path to video for video2video inference (standard Wan V2V)") + parser.add_argument("--strength", type=float, default=0.75, help="Strength for video2video inference (0.0-1.0)") + # I2V arguments + parser.add_argument("--image_path", type=str, default=None, help="path to image for image2video inference") + parser.add_argument("--end_image_path", type=str, default=None, help="path to end image for image2video inference") + # Fun-Control arguments (NEW/MODIFIED) + parser.add_argument( + "--control_path", # Keep this argument name + type=str, + default=None, + help="path to control video for inference with Fun-Control model. 
video file or directory with images", + ) + parser.add_argument( + "--control_start", + type=float, + default=0.0, + help="Start point (0.0-1.0) in the timeline where control influence is full (after fade-in)", + ) + parser.add_argument( + "--control_end", + type=float, + default=1.0, + help="End point (0.0-1.0) in the timeline where control influence starts to fade out", + ) + parser.add_argument( + "--control_falloff_percentage", # NEW name + type=float, + default=0.3, + help="Falloff percentage (0.0-0.49) for smooth transitions at start/end of control influence region", + ) + parser.add_argument( + "--control_weight", # NEW name + type=float, + default=1.0, + help="Overall weight/strength of control video influence for Fun-Control (0.0 to high values)", + ) + parser.add_argument("--trim_tail_frames", type=int, default=0, help="trim tail N frames from the video before saving") + parser.add_argument( + "--cfg_skip_mode", + type=str, + default="none", + choices=["early", "late", "middle", "early_late", "alternate", "none"], + help="CFG skip mode. each mode skips different parts of the CFG. " + " early: initial steps, late: later steps, middle: middle steps, early_late: both early and late, alternate: alternate, none: no skip (default)", + ) + parser.add_argument( + "--cfg_apply_ratio", + type=float, + default=None, + help="The ratio of steps to apply CFG (0.0 to 1.0). Default is None (apply all steps).", + ) + parser.add_argument( + "--slg_layers", type=str, default=None, help="Skip block (layer) indices for SLG (Skip Layer Guidance), comma separated" + ) + parser.add_argument( + "--slg_scale", + type=float, + default=3.0, + help="scale for SLG classifier free guidance. Default is 3.0. Ignored if slg_mode is None or uncond", + ) + parser.add_argument("--slg_start", type=float, default=0.0, help="start ratio for inference steps for SLG. Default is 0.0.") + parser.add_argument("--slg_end", type=float, default=0.3, help="end ratio for inference steps for SLG. Default is 0.3.") + parser.add_argument( + "--slg_mode", + type=str, + default=None, + choices=["original", "uncond"], + help="SLG mode. original: same as SD3, uncond: replace uncond pred with SLG pred", + ) + + # Flow Matching + parser.add_argument( + "--flow_shift", + type=float, + default=None, + help="Shift factor for flow matching schedulers. Default depends on task.", + ) + + parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model") + parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT, only for fp8") + parser.add_argument("--fp8_fast", action="store_true", help="Enable fast FP8 arithmetic (RTX 4XXX+), only for fp8_scaled") + parser.add_argument("--fp8_t5", action="store_true", help="use fp8 for Text Encoder model") + parser.add_argument( + "--device", type=str, default=None, help="device to use for inference. 
If None, use CUDA if available, otherwise use CPU" + ) + parser.add_argument( + "--attn_mode", + type=str, + default="torch", + choices=["flash", "flash2", "flash3", "torch", "sageattn", "xformers", "sdpa"], + help="attention mode", + ) + parser.add_argument("--blocks_to_swap", type=int, default=0, help="number of blocks to swap in the model") + parser.add_argument( + "--output_type", type=str, default="video", choices=["video", "images", "latent", "both"], help="output type" + ) + parser.add_argument("--no_metadata", action="store_true", help="do not save metadata") + parser.add_argument("--latent_path", type=str, nargs="*", default=None, help="path to latent for decode. no inference") + parser.add_argument("--lycoris", action="store_true", help="use lycoris for inference") + parser.add_argument("--compile", action="store_true", help="Enable torch.compile") + parser.add_argument( + "--compile_args", + nargs=4, + metavar=("BACKEND", "MODE", "DYNAMIC", "FULLGRAPH"), + default=["inductor", "max-autotune-no-cudagraphs", "False", "False"], + help="Torch.compile settings", + ) + parser.add_argument("--preview", type=int, default=None, metavar="N", + help="Enable latent preview every N steps. Generates previews in 'previews' subdirectory.", + ) + parser.add_argument("--preview_suffix", type=str, default=None, + help="Unique suffix for preview files to avoid conflicts in concurrent runs.", + ) + + args = parser.parse_args() + + assert (args.latent_path is None or len(args.latent_path) == 0) or ( + args.output_type == "images" or args.output_type == "video" + ), "latent_path is only supported for images or video output" + + # Add checks for mutually exclusive arguments + if args.video_path is not None and args.image_path is not None: + raise ValueError("--video_path and --image_path cannot be used together.") + if args.video_path is not None and args.control_path is not None: + raise ValueError("--video_path (standard V2V) and --control_path (Fun-Control) cannot be used together.") + if args.image_path is not None and "t2v" in args.task: + logger.warning("--image_path is provided, but task is set to t2v. Task type does not directly affect I2V mode.") + if args.control_path is not None and not WAN_CONFIGS[args.task].is_fun_control: + raise ValueError("--control_path is provided, but the selected task does not support Fun-Control.") + if not (0.0 <= args.control_falloff_percentage <= 0.49): + raise ValueError("--control_falloff_percentage must be between 0.0 and 0.49") + if args.task == "i2v-14B-FC-1.1" and args.image_path is None: + logger.warning(f"Task '{args.task}' typically uses --image_path as the reference image for ref_conv. Proceeding without it.") + return args + +def create_funcontrol_conditioning_latent( + args: argparse.Namespace, + config, + vae: WanVAE, + device: torch.device, + lat_f: int, + lat_h: int, + lat_w: int, + pixel_height: int, # Actual pixel height for resizing + pixel_width: int # Actual pixel width for resizing +) -> Optional[torch.Tensor]: + """ + Creates the conditioning latent tensor 'y' for FunControl models, + replicating the logic from WanWeightedControlToVideo node. + + Args: + args: Command line arguments. + config: Model configuration. + vae: Loaded VAE model instance. + device: Target computation device. + lat_f: Number of latent frames. + lat_h: Latent height. + lat_w: Latent width. + pixel_height: Target pixel height for image/video processing. + pixel_width: Target pixel width for image/video processing. 
+ + Returns: + torch.Tensor: The final conditioning latent 'y' [1, 32, lat_f, lat_h, lat_w], + or None if VAE is missing when required. + """ + logger.info("Creating FunControl conditioning latent 'y'...") + if vae is None: + # Should not happen if called correctly, but check anyway + logger.error("VAE is required to create FunControl conditioning latent but was not provided.") + return None + + batch_size = 1 # Hardcoded for script execution + total_latent_frames = lat_f + vae_dtype = vae.dtype # Use VAE's dtype for encoding + + # Initialize the two parts of the concat latent + # Control part (first 16 channels) - will be filled later + control_latent_part = torch.zeros([batch_size, 16, total_latent_frames, lat_h, lat_w], + device=device, dtype=vae_dtype).contiguous() + # Image guidance part (last 16 channels) + image_guidance_latent = torch.zeros([batch_size, 16, total_latent_frames, lat_h, lat_w], + device=device, dtype=vae_dtype).contiguous() + + # --- Image Guidance Processing (Start/End Images) --- + timeline_mask = torch.zeros([1, 1, total_latent_frames], device=device, dtype=torch.float32).contiguous() + has_start_image = args.image_path is not None + has_end_image = args.end_image_path is not None + + # Process start image if provided + start_latent = None + if has_start_image: + logger.info(f"Processing start image: {args.image_path}") + try: + img = Image.open(args.image_path).convert("RGB") + img_np = np.array(img) + # Resize to target pixel dimensions + interpolation = cv2.INTER_AREA if pixel_height < img_np.shape[0] else cv2.INTER_CUBIC + img_resized_np = cv2.resize(img_np, (pixel_width, pixel_height), interpolation=interpolation) + # Convert to tensor CFHW, range [-1, 1] + img_tensor = TF.to_tensor(img_resized_np).sub_(0.5).div_(0.5).to(device) + img_tensor = img_tensor.unsqueeze(1) # Add frame dim: C,F,H,W + + with torch.no_grad(), torch.autocast(device_type=device.type, dtype=vae_dtype): + # vae.encode expects a list, returns a list. Take first element. + # Result shape [C', F', H', W'] - needs batch dim for processing here + start_latent = vae.encode([img_tensor])[0].unsqueeze(0).to(device).contiguous() # [1, 16, 1, lat_h, lat_w] + + # Calculate influence and falloff + start_frames_influence = min(start_latent.shape[2], total_latent_frames) # Usually 1 + if start_frames_influence > 0: + # Use falloff_percentage for smooth transition *away* from start image + falloff_len_frames = max(1, int(total_latent_frames * args.control_falloff_percentage)) + start_influence_mask = torch.ones([1, 1, total_latent_frames], device=device, dtype=torch.float32).contiguous() + + # Apply falloff starting *after* the first frame + if total_latent_frames > 1 + falloff_len_frames: + # Falloff from frame 1 to 1+falloff_len_frames + t = torch.linspace(0, 1, falloff_len_frames, device=device) + falloff = 0.5 + 0.5 * torch.cos(t * math.pi) # 1 -> 0 + start_influence_mask[0, 0, 1:1+falloff_len_frames] = falloff + # Set influence to 0 after falloff + start_influence_mask[0, 0, 1+falloff_len_frames:] = 0.0 + elif total_latent_frames > 1: + # Shorter falloff if video is too short + t = torch.linspace(0, 1, total_latent_frames - 1, device=device) + falloff = 0.5 + 0.5 * torch.cos(t * math.pi) # 1 -> 0 + start_influence_mask[0, 0, 1:] = falloff + + # Place start latent in the image guidance part, weighted by mask + # Since start_latent is only frame 0, we just place it there. + # The mask influences how other elements (like end image) blend *in*. 
+ image_guidance_latent[:, :, 0:1, :, :] = start_latent[:, :, 0:1, :, :] # Take first frame + + # Update the main timeline mask + timeline_mask = torch.max(timeline_mask, start_influence_mask) # Start image dominates beginning + logger.info(f"Start image processed. Latent shape: {start_latent.shape}") + + except Exception as e: + logger.error(f"Error processing start image: {e}") + # Continue without start image guidance + + # Process end image if provided + end_latent = None + if has_end_image: + logger.info(f"Processing end image: {args.end_image_path}") + try: + img = Image.open(args.end_image_path).convert("RGB") + img_np = np.array(img) + # Resize to target pixel dimensions + interpolation = cv2.INTER_AREA if pixel_height < img_np.shape[0] else cv2.INTER_CUBIC + img_resized_np = cv2.resize(img_np, (pixel_width, pixel_height), interpolation=interpolation) + # Convert to tensor CFHW, range [-1, 1] + img_tensor = TF.to_tensor(img_resized_np).sub_(0.5).div_(0.5).to(device) + img_tensor = img_tensor.unsqueeze(1) # Add frame dim: C,F,H,W + + with torch.no_grad(), torch.autocast(device_type=device.type, dtype=vae_dtype): + # vae.encode expects a list, returns a list. Take first element. + # Result shape [C', F', H', W'] - needs batch dim for processing here + end_latent = vae.encode([img_tensor])[0].unsqueeze(0).to(device).contiguous() # [1, 16, 1, lat_h, lat_w] + + # Calculate end image influence transition (S-curve / cubic) + end_influence_mask = torch.zeros([1, 1, total_latent_frames], device=device, dtype=torch.float32).contiguous() + falloff_len_frames = max(1, int(total_latent_frames * args.control_falloff_percentage)) + + # Determine when the end image influence should start ramping up + # More sophisticated start point based on control_end if control video exists + if args.control_path and args.control_end < 1.0: + # Start fade-in just before control video fades out significantly + influence_start_frame = max(0, int(total_latent_frames * args.control_end) - falloff_len_frames // 2) + else: + # Default: start influence around 60% mark if no control or control runs full length + influence_start_frame = max(0, int(total_latent_frames * 0.6)) + + # Ensure start frame isn't too close to the beginning if start image exists + if has_start_image: + influence_start_frame = max(influence_start_frame, 1 + falloff_len_frames) # Ensure it starts after start_img falloff + + transition_length = total_latent_frames - influence_start_frame + if transition_length > 0: + logger.info(f"End image influence transition: frames {influence_start_frame} to {total_latent_frames-1}") + curve_positions = torch.linspace(0, 1, transition_length, device=device) + for i, pos in enumerate(curve_positions): + idx = influence_start_frame + i + if idx < total_latent_frames: + # Cubic ease-in-out curve (smoother than cosine) + if pos < 0.5: influence = 4 * pos * pos * pos + else: p = pos - 1; influence = 1 + 4 * p * p * p + # Ensure full influence near the end + if idx >= total_latent_frames - 3: influence = 1.0 + end_influence_mask[0, 0, idx] = influence + + # Blending logic (similar to base_nodes) + blend_start_frame = influence_start_frame + blend_length = total_latent_frames - blend_start_frame + if blend_length > 0: + # Create reference end latent (just the single frame repeated conceptually) + # Blend existing content with end latent based on influence weight + for i in range(blend_length): + idx = blend_start_frame + i + if idx < total_latent_frames: + weight = end_influence_mask[0, 0, idx].item() + if weight > 0: 
+ # Blend: (1-w)*current + w*end_latent + image_guidance_latent[:, :, idx] = ( + (1.0 - weight) * image_guidance_latent[:, :, idx] + + weight * end_latent[:, :, 0] # Use the single frame end_latent + ) + + # Ensure final frames are exactly the end image latent + last_frames_exact = min(3, total_latent_frames) # Ensure at least last 3 frames are end image + if last_frames_exact > 0: + end_offset = total_latent_frames - last_frames_exact + if end_offset >= 0: + image_guidance_latent[:, :, end_offset:] = end_latent[:, :, 0:1].repeat(1, 1, last_frames_exact, 1, 1) + + # Update the main timeline mask + timeline_mask = torch.max(timeline_mask, end_influence_mask) + logger.info(f"End image processed. Latent shape: {end_latent.shape}") + + except Exception as e: + logger.error(f"Error processing end image: {e}") + # Continue without end image guidance + + # --- Control Video Processing --- + control_video_latent = None + if args.control_path: + logger.info(f"Processing control video: {args.control_path}") + try: + # Load control video frames (use helper from hv_generate_video for consistency) + # Use args.video_length for the number of frames + if os.path.isfile(args.control_path): + video_frames_np = hv_load_video(args.control_path, 0, args.video_length, bucket_reso=(pixel_width, pixel_height)) + elif os.path.isdir(args.control_path): + video_frames_np = hv_load_images(args.control_path, args.video_length, bucket_reso=(pixel_width, pixel_height)) + else: + raise FileNotFoundError(f"Control path not found: {args.control_path}") + + if not video_frames_np: + raise ValueError("No frames loaded from control path.") + + num_control_frames_loaded = len(video_frames_np) + if num_control_frames_loaded < args.video_length: + logger.warning(f"Control video loaded {num_control_frames_loaded} frames, less than target {args.video_length}. 
Padding with last frame.") + # Pad with the last frame + last_frame = video_frames_np[-1] + padding = [last_frame] * (args.video_length - num_control_frames_loaded) + video_frames_np.extend(padding) + + # Stack and convert to tensor: F, H, W, C -> B, C, F, H, W, range [-1, 1] + video_frames_np = np.stack(video_frames_np[:args.video_length], axis=0) # Ensure correct length + control_tensor = torch.from_numpy(video_frames_np).permute(0, 3, 1, 2).float() / 127.5 - 1.0 # F,C,H,W + control_tensor = control_tensor.permute(1, 0, 2, 3) # C,F,H,W + control_tensor = control_tensor.unsqueeze(0).to(device) # B,C,F,H,W + + # Encode control video + with torch.no_grad(), torch.autocast(device_type=device.type, dtype=vae_dtype): + # vae.encode expects list of [C, F, H, W], returns list of [C', F', H', W'] + control_video_latent = vae.encode([control_tensor[0]])[0].unsqueeze(0).to(device).contiguous() # [1, 16, lat_f, lat_h, lat_w] + + # Calculate weighted control mask (replicating base_nodes logic) + control_frames_latent = control_video_latent.shape[2] # Should match total_latent_frames + control_mask = torch.zeros([1, 1, control_frames_latent], device=device, dtype=torch.float32).contiguous() + + start_frame_idx = max(0, min(control_frames_latent - 1, int(control_frames_latent * args.control_start))) + end_frame_idx = max(start_frame_idx + 1, min(control_frames_latent, int(control_frames_latent * args.control_end))) + falloff_len_frames = max(2, int(control_frames_latent * args.control_falloff_percentage)) + + # Main active region + if start_frame_idx < end_frame_idx: + control_mask[:, :, start_frame_idx:end_frame_idx] = 1.0 + + # Fall-on at the start + if start_frame_idx > 0: + fallon_start = max(0, start_frame_idx - falloff_len_frames) + fallon_len = start_frame_idx - fallon_start + if fallon_len > 0: + t = torch.linspace(0, 1, fallon_len, device=device) + smooth_t = 0.5 - 0.5 * torch.cos(t * math.pi) # 0 -> 1 + control_mask[:, :, fallon_start:start_frame_idx] = smooth_t.reshape(1, 1, -1) + + # Fall-off at the end (interacting with end_image influence) + if end_frame_idx < control_frames_latent: + falloff_start = end_frame_idx + falloff_end = min(control_frames_latent, falloff_start + falloff_len_frames) + falloff_actual_len = falloff_end - falloff_start + if falloff_actual_len > 0: + # Check for end image influence in this region + if has_end_image: + for i in range(falloff_start, falloff_end): + # Calculate original falloff (1 -> 0) + fade_pos = (i - falloff_start) / falloff_actual_len + original_falloff = 0.5 + 0.5 * math.cos(fade_pos * math.pi) + # Get end image influence (already calculated in timeline_mask) + end_influence_here = timeline_mask[0, 0, i].item() + # Adjust control falloff: decrease faster if end image is taking over + # Use a factor (e.g., 0.8) to control how much end image preempts control + adjusted_falloff = original_falloff * (1.0 - (end_influence_here * 0.8)) + control_mask[0, 0, i] = adjusted_falloff + logger.info("Applied end-image interaction to control falloff.") + else: + # Standard falloff if no end image + t = torch.linspace(0, 1, falloff_actual_len, device=device) + smooth_t = 0.5 + 0.5 * torch.cos(t * math.pi) # 1 -> 0 + control_mask[:, :, falloff_start:falloff_end] = smooth_t.reshape(1, 1, -1) + + # Apply final control weight + control_mask = control_mask * args.control_weight + + # Expand mask and apply to control latent + control_mask_expanded = control_mask.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) # [1, 1, 1, F, 1, 1] ? 
-> needs [1, 1, F, 1, 1] + control_mask_expanded = control_mask.unsqueeze(-1).unsqueeze(-1) # Shape: [1, 1, F, 1, 1] + + # Apply weighting to the control_video_latent + weighted_control_latent = control_video_latent * control_mask_expanded # [1, 16, F, H, W] + + # Place into the first 16 channels of the final latent + control_latent_part = weighted_control_latent + + # Log mask pattern + mask_pattern = "".join(["#" if v > 0.8*args.control_weight else "+" if v > 0.4*args.control_weight else "." if v > 0.1*args.control_weight else " " + for v in control_mask[0, 0, :].tolist()]) + logger.info(f"Control mask pattern (weight={args.control_weight:.2f}): |{mask_pattern}|") + logger.info(f"Control video processed. Latent shape: {control_video_latent.shape}") + + except Exception as e: + logger.error(f"Error processing control video: {e}") + # Continue without control video guidance (control_latent_part remains zeros) + + # --- Final Assembly --- + # Concatenate the control part and the image guidance part + final_y = torch.cat([control_latent_part, image_guidance_latent], dim=1) # Concat along channel dim: [1, 16+16, F, H, W] + final_y = final_y.contiguous() + + logger.info(f"FunControl conditioning latent 'y' created. Final shape: {final_y.shape}") + + # Optional: Clean up intermediate tensors explicitly if memory is tight + del start_latent, end_latent, control_video_latent, control_latent_part, image_guidance_latent + del timeline_mask, control_mask + if 'control_tensor' in locals(): del control_tensor + if 'img_tensor' in locals(): del img_tensor + clean_memory_on_device(device) # Be cautious with frequent cache clearing + + return final_y + +def get_task_defaults(task: str, size: Optional[Tuple[int, int]] = None) -> Tuple[int, float, int, bool]: + """Return default values for each task + + Args: + task: task name (t2v, t2i, i2v etc.) 
+ size: size of the video (width, height) + + Returns: + Tuple[int, float, int, bool]: (infer_steps, flow_shift, video_length, needs_clip) + """ + width, height = size if size else (0, 0) + + if "t2i" in task: + return 50, 5.0, 1, False + elif "i2v" in task: + flow_shift = 3.0 if (width == 832 and height == 480) or (width == 480 and height == 832) else 5.0 + return 40, flow_shift, 81, True + else: # t2v or default + return 50, 5.0, 81, False + + +def setup_args(args: argparse.Namespace) -> argparse.Namespace: + """Validate and set default values for optional arguments + + Args: + args: command line arguments + + Returns: + argparse.Namespace: updated arguments + """ + # Get default values for the task + infer_steps, flow_shift, video_length, _ = get_task_defaults(args.task, tuple(args.video_size)) + + # Apply default values to unset arguments + if args.infer_steps is None: + args.infer_steps = infer_steps + if args.flow_shift is None: + args.flow_shift = flow_shift + # For V2V, video_length might be determined by the input video later if not set + if args.video_length is None and args.video_path is None: + args.video_length = video_length + elif args.video_length is None and args.video_path is not None: + # Delay setting default if V2V and length not specified + pass + elif args.video_length is not None: + # Use specified length + pass + + # Force video_length to 1 for t2i tasks + if "t2i" in args.task: + assert args.video_length == 1, f"video_length should be 1 for task {args.task}" + + # parse slg_layers + if args.slg_layers is not None: + args.slg_layers = list(map(int, args.slg_layers.split(","))) + + return args + + +def check_inputs(args: argparse.Namespace) -> Tuple[int, int, Optional[int]]: + """Validate video size and potentially length (if not V2V auto-detect) + + Args: + args: command line arguments + + Returns: + Tuple[int, int, Optional[int]]: (height, width, video_length) + """ + height = args.video_size[0] + width = args.video_size[1] + size = f"{width}*{height}" + + # Only check supported sizes if not doing V2V (V2V might use custom sizes from input) + # Or if it's FunControl, which might have different size constraints + if args.video_path is None and not WAN_CONFIGS[args.task].is_fun_control: + if size not in SUPPORTED_SIZES[args.task]: + logger.warning(f"Size {size} is not supported for task {args.task}. 
Supported sizes are {SUPPORTED_SIZES[args.task]}.") + + video_length = args.video_length # Might be None if V2V auto-detect + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + return height, width, video_length + + +def calculate_dimensions(video_size: Tuple[int, int], video_length: int, config) -> Tuple[Tuple[int, int, int, int], int]: + """calculate dimensions for the generation + + Args: + video_size: video frame size (height, width) + video_length: number of frames in the video + config: model configuration + + Returns: + Tuple[Tuple[int, int, int, int], int]: + ((channels, frames, height, width), seq_len) + """ + height, width = video_size + frames = video_length + + # calculate latent space dimensions + lat_f = (frames - 1) // config.vae_stride[0] + 1 + lat_h = height // config.vae_stride[1] + lat_w = width // config.vae_stride[2] + + # calculate sequence length + seq_len = math.ceil((lat_h * lat_w) / (config.patch_size[1] * config.patch_size[2]) * lat_f) + + return ((16, lat_f, lat_h, lat_w), seq_len) + + +# Modified function (replace the original) +def load_vae(args: argparse.Namespace, config, device: torch.device, dtype: torch.dtype) -> WanVAE: + """load VAE model with robust path handling + + Args: + args: command line arguments + config: model configuration + device: device to use + dtype: data type for the model + + Returns: + WanVAE: loaded VAE model + """ + vae_override_path = args.vae + vae_filename = config.vae_checkpoint # Get expected filename, e.g., "Wan2.1_VAE.pth" + # Assume models are in 'wan' dir relative to script if not otherwise specified + vae_base_dir = "wan" + + final_vae_path = None + + # 1. Check if args.vae is a valid *existing file path* + if vae_override_path and isinstance(vae_override_path, str) and \ + (vae_override_path.endswith(".pth") or vae_override_path.endswith(".safetensors")) and \ + os.path.isfile(vae_override_path): + final_vae_path = vae_override_path + logger.info(f"Using VAE override path from --vae: {final_vae_path}") + + # 2. If override is invalid or not provided, construct default path + if final_vae_path is None: + constructed_path = os.path.join(vae_base_dir, vae_filename) + if os.path.isfile(constructed_path): + final_vae_path = constructed_path + logger.info(f"Constructed default VAE path: {final_vae_path}") + if vae_override_path: + logger.warning(f"Ignoring potentially invalid --vae argument: {vae_override_path}") + else: + # 3. Fallback using ckpt_dir if provided and default construction failed + if args.ckpt_dir: + fallback_path = os.path.join(args.ckpt_dir, vae_filename) + if os.path.isfile(fallback_path): + final_vae_path = fallback_path + logger.info(f"Using VAE path from --ckpt_dir fallback: {final_vae_path}") + else: + # If all attempts fail, raise error + raise FileNotFoundError(f"Cannot find VAE. Checked override '{vae_override_path}', constructed '{constructed_path}', and fallback '{fallback_path}'") + else: + raise FileNotFoundError(f"Cannot find VAE. Checked override '{vae_override_path}' and constructed '{constructed_path}'. 
No --ckpt_dir provided for fallback.") + + # At this point, final_vae_path should be valid + logger.info(f"Loading VAE model from final path: {final_vae_path}") + cache_device = torch.device("cpu") if args.vae_cache_cpu else None + vae = WanVAE(vae_path=final_vae_path, device=device, dtype=dtype, cache_device=cache_device) + return vae + + +def load_text_encoder(args: argparse.Namespace, config, device: torch.device) -> T5EncoderModel: + """load text encoder (T5) model + + Args: + args: command line arguments + config: model configuration + device: device to use + + Returns: + T5EncoderModel: loaded text encoder model + """ + checkpoint_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.t5_checkpoint) + tokenizer_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.t5_tokenizer) + + text_encoder = T5EncoderModel( + text_len=config.text_len, + dtype=config.t5_dtype, + device=device, + checkpoint_path=checkpoint_path, + tokenizer_path=tokenizer_path, + weight_path=args.t5, + fp8=args.fp8_t5, + ) + + return text_encoder + + +def load_clip_model(args: argparse.Namespace, config, device: torch.device) -> CLIPModel: + """load CLIP model (for I2V only) + + Args: + args: command line arguments + config: model configuration + device: device to use + + Returns: + CLIPModel: loaded CLIP model + """ + checkpoint_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.clip_checkpoint) + tokenizer_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.clip_tokenizer) + + clip = CLIPModel( + dtype=config.clip_dtype, + device=device, + checkpoint_path=checkpoint_path, + tokenizer_path=tokenizer_path, + weight_path=args.clip, + ) + + return clip + + +def load_dit_model( + args: argparse.Namespace, + config, + device: torch.device, + dit_dtype: torch.dtype, + dit_weight_dtype: Optional[torch.dtype] = None, + is_i2v: bool = False, # is_i2v might influence model loading specifics in some versions +) -> WanModel: + """load DiT model + + Args: + args: command line arguments + config: model configuration + device: device to use + dit_dtype: data type for the model + dit_weight_dtype: data type for the model weights. 
None for as-is + is_i2v: I2V mode (might affect some model config details) + + Returns: + WanModel: loaded DiT model + """ + loading_device = "cpu" + if args.blocks_to_swap == 0 and args.lora_weight is None and not args.fp8_scaled: + loading_device = device + + loading_weight_dtype = dit_weight_dtype + if args.fp8_scaled or args.lora_weight is not None: + loading_weight_dtype = dit_dtype # load as-is + + # do not fp8 optimize because we will merge LoRA weights + # The 'is_i2v' flag might be used internally by load_wan_model if needed by specific Wan versions + model = load_wan_model(config, device, args.dit, args.attn_mode, False, loading_device, loading_weight_dtype, False) + + return model + + +def merge_lora_weights(model: WanModel, args: argparse.Namespace, device: torch.device) -> None: + """merge LoRA weights to the model + + Args: + model: DiT model + args: command line arguments + device: device to use + """ + if args.lora_weight is None or len(args.lora_weight) == 0: + return + + for i, lora_weight in enumerate(args.lora_weight): + if args.lora_multiplier is not None and len(args.lora_multiplier) > i: + lora_multiplier = args.lora_multiplier[i] + else: + lora_multiplier = 1.0 + + logger.info(f"Loading LoRA weights from {lora_weight} with multiplier {lora_multiplier}") + weights_sd = load_file(lora_weight) + + # apply include/exclude patterns + original_key_count = len(weights_sd.keys()) + if args.include_patterns is not None and len(args.include_patterns) > i: + include_pattern = args.include_patterns[i] + regex_include = re.compile(include_pattern) + weights_sd = {k: v for k, v in weights_sd.items() if regex_include.search(k)} + logger.info(f"Filtered keys with include pattern {include_pattern}: {original_key_count} -> {len(weights_sd.keys())}") + if args.exclude_patterns is not None and len(args.exclude_patterns) > i: + original_key_count_ex = len(weights_sd.keys()) + exclude_pattern = args.exclude_patterns[i] + regex_exclude = re.compile(exclude_pattern) + weights_sd = {k: v for k, v in weights_sd.items() if not regex_exclude.search(k)} + logger.info( + f"Filtered keys with exclude pattern {exclude_pattern}: {original_key_count_ex} -> {len(weights_sd.keys())}" + ) + if len(weights_sd) != original_key_count: + remaining_keys = list(set([k.split(".", 1)[0] for k in weights_sd.keys()])) + remaining_keys.sort() + logger.info(f"Remaining LoRA modules after filtering: {remaining_keys}") + if len(weights_sd) == 0: + logger.warning(f"No keys left after filtering.") + + if args.lycoris: + lycoris_net, _ = create_network_from_weights( + multiplier=lora_multiplier, + file=None, + weights_sd=weights_sd, + unet=model, + text_encoder=None, + vae=None, + for_inference=True, + ) + lycoris_net.merge_to(None, model, weights_sd, dtype=None, device=device) + else: + network = lora_wan.create_arch_network_from_weights(lora_multiplier, weights_sd, unet=model, for_inference=True) + network.merge_to(None, model, weights_sd, device=device, non_blocking=True) + + synchronize_device(device) + logger.info("LoRA weights loaded") + + # save model here before casting to dit_weight_dtype + if args.save_merged_model: + logger.info(f"Saving merged model to {args.save_merged_model}") + mem_eff_save_file(model.state_dict(), args.save_merged_model) # save_file needs a lot of memory + logger.info("Merged model saved") + + +def optimize_model( + model: WanModel, args: argparse.Namespace, device: torch.device, dit_dtype: torch.dtype, dit_weight_dtype: torch.dtype +) -> None: + """optimize the model (FP8 conversion, 
device move etc.) + + Args: + model: dit model + args: command line arguments + device: device to use + dit_dtype: dtype for the model + dit_weight_dtype: dtype for the model weights + """ + if args.fp8_scaled: + # load state dict as-is and optimize to fp8 + state_dict = model.state_dict() + + # if no blocks to swap, we can move the weights to GPU after optimization on GPU (omit redundant CPU->GPU copy) + move_to_device = args.blocks_to_swap == 0 # if blocks_to_swap > 0, we will keep the model on CPU + state_dict = model.fp8_optimization(state_dict, device, move_to_device, use_scaled_mm=args.fp8_fast) + + info = model.load_state_dict(state_dict, strict=True, assign=True) + logger.info(f"Loaded FP8 optimized weights: {info}") + + if args.blocks_to_swap == 0: + model.to(device) # make sure all parameters are on the right device (e.g. RoPE etc.) + else: + # simple cast to dit_dtype + target_dtype = None # load as-is (dit_weight_dtype == dtype of the weights in state_dict) + target_device = None + + if dit_weight_dtype is not None: # in case of args.fp8 and not args.fp8_scaled + logger.info(f"Convert model to {dit_weight_dtype}") + target_dtype = dit_weight_dtype + + if args.blocks_to_swap == 0: + logger.info(f"Move model to device: {device}") + target_device = device + + model.to(target_device, target_dtype) # move and cast at the same time. this reduces redundant copy operations + + if args.compile: + compile_backend, compile_mode, compile_dynamic, compile_fullgraph = args.compile_args + logger.info( + f"Torch Compiling[Backend: {compile_backend}; Mode: {compile_mode}; Dynamic: {compile_dynamic}; Fullgraph: {compile_fullgraph}]" + ) + torch._dynamo.config.cache_size_limit = 32 + for i in range(len(model.blocks)): + model.blocks[i] = torch.compile( + model.blocks[i], + backend=compile_backend, + mode=compile_mode, + dynamic=compile_dynamic.lower() in "true", + fullgraph=compile_fullgraph.lower() in "true", + ) + + if args.blocks_to_swap > 0: + logger.info(f"Enable swap {args.blocks_to_swap} blocks to CPU from device: {device}") + model.enable_block_swap(args.blocks_to_swap, device, supports_backward=False) + model.move_to_device_except_swap_blocks(device) + model.prepare_block_swap_before_forward() + else: + # make sure the model is on the right device + model.to(device) + + model.eval().requires_grad_(False) + clean_memory_on_device(device) + + +def prepare_t2v_inputs( + args: argparse.Namespace, config, accelerator: Accelerator, device: torch.device, vae: Optional[WanVAE] = None +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]: + """Prepare inputs for T2V (including Fun-Control variation) + + Args: + args: command line arguments + config: model configuration + accelerator: Accelerator instance + device: device to use + vae: VAE model, required only for Fun-Control + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]: + (noise, context, context_null, (arg_c, arg_null)) + """ + # Prepare inputs for T2V + # calculate dimensions and sequence length + height, width = args.video_size + frames = args.video_length # Should be set by now + (ch, lat_f, lat_h, lat_w), seq_len = calculate_dimensions(args.video_size, args.video_length, config) + target_shape = (ch, lat_f, lat_h, lat_w) # Should be (16, lat_f, lat_h, lat_w) for base latent + + # configure negative prompt + n_prompt = args.negative_prompt if args.negative_prompt else config.sample_neg_prompt + + # set seed + seed = args.seed # Seed should be set in generate() + if not args.cpu_noise: + 
seed_g = torch.Generator(device=device) + seed_g.manual_seed(seed) + else: + # ComfyUI compatible noise + seed_g = torch.manual_seed(seed) + + # load text encoder + text_encoder = load_text_encoder(args, config, device) + text_encoder.model.to(device) + + # encode prompt + with torch.no_grad(): + if args.fp8_t5: + with torch.amp.autocast(device_type=device.type, dtype=config.t5_dtype): + context = text_encoder([args.prompt], device) + context_null = text_encoder([n_prompt], device) + else: + context = text_encoder([args.prompt], device) + context_null = text_encoder([n_prompt], device) + + # free text encoder and clean memory + del text_encoder + clean_memory_on_device(device) + + # Initialize 'y' (conditioning latent) to None + y = None + + # Handle Fun-Control T2V case + if config.is_fun_control: + logger.info("Preparing inputs for Fun-Control T2V.") + if vae is None: + raise ValueError("VAE is required for Fun-Control T2V input preparation.") + + # Calculate pixel dimensions needed for encoding helper + pixel_height = lat_h * config.vae_stride[1] + pixel_width = lat_w * config.vae_stride[2] + + # Create the conditioning latent 'y' + # This function handles control video encoding (if path provided) + # and creates the [1, 32, F, H, W] tensor. + # If no control path, it creates the control part as zeros. + # Since this is T2V, image_path and end_image_path are None in args, + # so the image guidance part will also be zeros. + vae.to_device(device) # Ensure VAE is on device + y = create_funcontrol_conditioning_latent( + args, config, vae, device, lat_f, lat_h, lat_w, pixel_height, pixel_width + ) + # Move VAE back after use + vae.to_device(args.vae_cache_cpu if args.vae_cache_cpu else "cpu") + clean_memory_on_device(device) + + if y is None: + raise RuntimeError("Failed to create FunControl conditioning latent 'y'.") + + # generate noise (base latent noise, shape [16, F, H, W]) + noise = torch.randn(target_shape, dtype=torch.float32, generator=seed_g, device=device if not args.cpu_noise else "cpu") + noise = noise.to(device) + + # prepare model input arguments + arg_c = {"context": context, "seq_len": seq_len} + arg_null = {"context": context_null, "seq_len": seq_len} + + # Add 'y' ONLY if it was created (i.e., for Fun-Control) + if y is not None: + arg_c["y"] = [y] # Model expects y as a list + arg_null["y"] = [y] + logger.info(f"Added FunControl conditioning 'y' (shape: {y.shape}) to model inputs.") + elif config.is_fun_control: + # This case should technically be handled by y being zeros, but double-check + logger.warning("FunControl task but 'y' tensor was not generated. Model might error.") + # Create a zero tensor as fallback if y generation failed somehow? 
+ # y = torch.zeros([1, 32, lat_f, lat_h, lat_w], device=device, dtype=noise.dtype) + # arg_c["y"] = [y] + # arg_null["y"] = [y] + + + return noise, context, context_null, (arg_c, arg_null) + +# ========================================================================= # +# START OF MODIFIED FUNCTION prepare_i2v_inputs +# ========================================================================= # +def prepare_i2v_inputs( + args: argparse.Namespace, config, accelerator: Accelerator, device: torch.device, vae: WanVAE +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]: + """Prepare inputs for I2V (including Fun-Control I2V variation) + + Args: + args: command line arguments + config: model configuration + accelerator: Accelerator instance + device: device to use + vae: VAE model, used for image encoding + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]: + (noise, context, context_null, y, (arg_c, arg_null)) + 'y' is the conditioning latent ([1, 32, F, H, W] for FunControl, + [1, C+4, F, H, W] for standard I2V with mask). + """ + if vae is None: + raise ValueError("VAE must be provided for I2V input preparation.") + + # --- Prepare Conditioning Latent 'y' --- + # This check MUST come first to decide the entire logic path + if config.is_fun_control: + # --- FunControl I2V Path --- + logger.info("Preparing inputs for Fun-Control I2V.") + + # Calculate dimensions (FunControl might use different aspect logic) + height, width = args.video_size + frames = args.video_length # Should be set by now + (_, lat_f, lat_h, lat_w), seq_len = calculate_dimensions(args.video_size, args.video_length, config) + pixel_height = lat_h * config.vae_stride[1] + pixel_width = lat_w * config.vae_stride[2] + noise_channels = 16 # FunControl DiT denoises 16 channels + + logger.info(f"FunControl I2V target pixel resolution: {pixel_height}x{pixel_width}, latent shape: ({lat_f}, {lat_h}, {lat_w}), seq_len: {seq_len}") + + # set seed + seed = args.seed + if not args.cpu_noise: + seed_g = torch.Generator(device=device) + seed_g.manual_seed(seed) + else: + seed_g = torch.manual_seed(seed) + + # generate noise (for the part being denoised by the DiT) + noise = torch.randn( + noise_channels, lat_f, lat_h, lat_w, + dtype=torch.float32, generator=seed_g, + device=device if not args.cpu_noise else "cpu", + ) + noise = noise.to(device) + + # configure negative prompt + n_prompt = args.negative_prompt if args.negative_prompt else config.sample_neg_prompt + + # load text encoder & encode prompts + text_encoder = load_text_encoder(args, config, device) + text_encoder.model.to(device) + with torch.no_grad(): + if args.fp8_t5: + with torch.amp.autocast(device_type=device.type, dtype=config.t5_dtype): + context = text_encoder([args.prompt], device) + context_null = text_encoder([n_prompt], device) + else: + context = text_encoder([args.prompt], device) + context_null = text_encoder([n_prompt], device) + del text_encoder + clean_memory_on_device(device) + + # load CLIP model & encode image + clip = load_clip_model(args, config, device) + clip.model.to(device) + if not args.image_path: + raise ValueError("--image_path is required for FunControl I2V mode.") + img_clip = Image.open(args.image_path).convert("RGB") + img_tensor_clip = TF.to_tensor(img_clip).sub_(0.5).div_(0.5).to(device) # CHW, [-1, 1] + with torch.amp.autocast(device_type=device.type, dtype=torch.float16), torch.no_grad(): + clip_context = clip.visual([img_tensor_clip.unsqueeze(1)]) # Add Frame 
dim + del clip, img_clip, img_tensor_clip + clean_memory_on_device(device) + + fun_ref_latent = None + # Check if the task requires ref_conv and if a reference image is provided via --image_path + if args.task == "i2v-14B-FC-1.1" and args.image_path is not None: + logger.info(f"Task {args.task} requires ref_conv. Encoding reference image from --image_path: {args.image_path}") + try: + ref_img = Image.open(args.image_path).convert("RGB") + ref_img_np = np.array(ref_img) + # Resize ref image to target pixel dimensions + interpolation = cv2.INTER_AREA if pixel_height < ref_img_np.shape[0] else cv2.INTER_CUBIC + ref_img_resized_np = cv2.resize(ref_img_np, (pixel_width, pixel_height), interpolation=interpolation) + # Convert to tensor CFHW, range [-1, 1] + ref_img_tensor = TF.to_tensor(ref_img_resized_np).sub_(0.5).div_(0.5).to(device) + ref_img_tensor = ref_img_tensor.unsqueeze(1) # Add frame dim: C,F,H,W + + vae.to_device(device) # Ensure VAE is on device for encoding + with torch.no_grad(), torch.autocast(device_type=device.type, dtype=vae.dtype): + # Encode the single reference frame + # vae.encode returns list, take first element. Result shape [C', 1, H', W'] + fun_ref_latent = vae.encode([ref_img_tensor])[0] + # Squeeze the frame dimension for Conv2d in the model: [C', H', W'] + fun_ref_latent = fun_ref_latent.squeeze(1) + logger.info(f"Encoded fun_ref latent. Shape: {fun_ref_latent.shape}") + # Keep VAE on device for main conditioning latent creation below + + except Exception as e: + logger.error(f"Error processing reference image for fun_ref: {e}") + fun_ref_latent = None # Continue without ref if encoding fails + + # **IMPORTANT**: Since --image_path is now used for fun_ref, + # temporarily set it to None *before* calling create_funcontrol_conditioning_latent + # so it doesn't get processed *again* as a start image inside that function. 
+ original_image_path = args.image_path + args.image_path = None + + # Use the FunControl helper function to create the 32-channel 'y' + vae.to_device(device) # Ensure VAE is on compute device + y = create_funcontrol_conditioning_latent( + args, config, vae, device, lat_f, lat_h, lat_w, pixel_height, pixel_width + ) + if args.task == "i2v-14B-FC-1.1": + args.image_path = original_image_path + if y is None: + raise RuntimeError("Failed to create FunControl conditioning latent 'y'.") + vae.to_device(args.vae_cache_cpu if args.vae_cache_cpu else "cpu") # Move VAE back + clean_memory_on_device(device) + + # Prepare Model Input Arguments for FunControl + y_for_model = y[0] # Shape becomes [32, F, H, W] + arg_c = { + "context": context, + "clip_fea": clip_context, + "seq_len": seq_len, + "y": [y_for_model], # Pass the 4D tensor in the list + } + arg_null = { + "context": context_null, + "clip_fea": clip_context, + "seq_len": seq_len, + "y": [y_for_model], # Pass the 4D tensor in the list + } + + if fun_ref_latent is not None: + # Model forward expects fun_ref directly, not in a list like 'y' + arg_c["fun_ref"] = fun_ref_latent + arg_null["fun_ref"] = fun_ref_latent # Pass to both cond and uncond + logger.info("Added fun_ref latent to model inputs.") + + # Return noise, context, context_null, y (for potential debugging), (arg_c, arg_null) + return noise, context, context_null, y, (arg_c, arg_null) + + else: + # --- Standard I2V Path (Logic copied/adapted from original wan_generate_video.py) --- + logger.info("Preparing inputs for standard I2V.") + + # get video dimensions + height, width = args.video_size + frames = args.video_length # Should be set by now + max_area = width * height + + # load image + if not args.image_path: + raise ValueError("--image_path is required for standard I2V mode.") + img = Image.open(args.image_path).convert("RGB") + img_cv2 = np.array(img) # PIL to numpy + img_tensor = TF.to_tensor(img).sub_(0.5).div_(0.5).to(device) # For CLIP + + # end frame image + end_img = None + end_img_cv2 = None + if args.end_image_path is not None: + end_img = Image.open(args.end_image_path).convert("RGB") + end_img_cv2 = np.array(end_img) # PIL to numpy + has_end_image = end_img is not None + + # calculate latent dimensions: keep aspect ratio (Original Method) + img_height, img_width = img.size[::-1] # PIL size is W,H + aspect_ratio = img_height / img_width + lat_h = round(np.sqrt(max_area * aspect_ratio) / config.vae_stride[1] / config.patch_size[1]) * config.patch_size[1] + lat_w = round(np.sqrt(max_area / aspect_ratio) / config.vae_stride[2] / config.patch_size[2]) * config.patch_size[2] + target_height = lat_h * config.vae_stride[1] + target_width = lat_w * config.vae_stride[2] + + # --- CRITICAL ORIGINAL LOGIC DIFFERENCE #1: Frame Dimension --- + lat_f_base = (frames - 1) // config.vae_stride[0] + 1 # size of latent frames + lat_f_effective = lat_f_base + (1 if has_end_image else 0) # Adjust frame dim if end image exists + + # --- CRITICAL ORIGINAL LOGIC DIFFERENCE #2: Sequence Length --- + max_seq_len = math.ceil(lat_f_effective * lat_h * lat_w / (config.patch_size[1] * config.patch_size[2])) + + logger.info(f"Standard I2V target pixel resolution: {target_height}x{target_width}, latent shape: ({lat_f_effective}, {lat_h}, {lat_w}), seq_len: {max_seq_len}") + + # set seed + seed = args.seed + if not args.cpu_noise: + seed_g = torch.Generator(device=device) + seed_g.manual_seed(seed) + else: + seed_g = torch.manual_seed(seed) + + # --- CRITICAL ORIGINAL LOGIC DIFFERENCE #3: Noise Shape 
--- + noise = torch.randn( + 16, # Channel dim for latent + lat_f_effective, # Use adjusted frame dim + lat_h, lat_w, + dtype=torch.float32, generator=seed_g, + device=device if not args.cpu_noise else "cpu", + ) + noise = noise.to(device) + + # configure negative prompt + n_prompt = args.negative_prompt if args.negative_prompt else config.sample_neg_prompt + + # load text encoder & encode prompts + text_encoder = load_text_encoder(args, config, device) + text_encoder.model.to(device) + with torch.no_grad(): + if args.fp8_t5: + with torch.amp.autocast(device_type=device.type, dtype=config.t5_dtype): + context = text_encoder([args.prompt], device) + context_null = text_encoder([n_prompt], device) + else: + context = text_encoder([args.prompt], device) + context_null = text_encoder([n_prompt], device) + del text_encoder + clean_memory_on_device(device) + + # load CLIP model & encode image + clip = load_clip_model(args, config, device) + clip.model.to(device) + logger.info(f"Encoding image to CLIP context") + with torch.amp.autocast(device_type=device.type, dtype=torch.float16), torch.no_grad(): + # Use the [-1, 1] tensor directly if clip.visual expects that format + # clip_context = clip.visual([img_tensor[:, None, :, :]]).squeeze(1) # Original had [img_tensor[:, None, :, :]] which adds frame dim + # Use unsqueeze(1) which seems more consistent with other parts + clip_context = clip.visual([img_tensor.unsqueeze(1)]) # Add Frame dim + logger.info(f"CLIP Encoding complete") + del clip + clean_memory_on_device(device) + + # --- CRITICAL ORIGINAL LOGIC DIFFERENCE #4: VAE Encoding and 'y' construction --- + logger.info(f"Encoding image(s) to latent space (Standard I2V method)") + vae.to_device(device) + + # Resize image(s) for VAE + interpolation = cv2.INTER_AREA if target_height < img_cv2.shape[0] else cv2.INTER_CUBIC + img_resized_np = cv2.resize(img_cv2, (target_width, target_height), interpolation=interpolation) + img_resized = TF.to_tensor(img_resized_np).sub_(0.5).div_(0.5).to(device) # [-1, 1], CHW + img_resized = img_resized.unsqueeze(1) # Add frame dimension -> CFHW, Shape [C, 1, H, W] + + end_img_resized = None + if has_end_image and end_img_cv2 is not None: + interpolation_end = cv2.INTER_AREA if target_height < end_img_cv2.shape[0] else cv2.INTER_CUBIC + end_img_resized_np = cv2.resize(end_img_cv2, (target_width, target_height), interpolation=interpolation_end) + end_img_resized = TF.to_tensor(end_img_resized_np).sub_(0.5).div_(0.5).to(device) # [-1, 1], CHW + end_img_resized = end_img_resized.unsqueeze(1) # Add frame dimension -> CFHW, Shape [C, 1, H, W] + + # --- CRITICAL ORIGINAL LOGIC DIFFERENCE #5: Mask Shape --- + msk = torch.zeros(4, lat_f_effective, lat_h, lat_w, device=device, dtype=vae.dtype) # Use adjusted frame dim + msk[:, 0] = 1 # Mask first frame + if has_end_image: + msk[:, -1] = 1 # Mask last frame (the lat_f+1'th frame) + + # Encode image(s) using VAE (Padded Method) + with accelerator.autocast(), torch.no_grad(): + # Pad the *start* image tensor temporally before encoding + # Calculate padding needed to reach base frame count (before adding end frame) + padding_frames_needed = frames - 1 # Number of frames to generate *after* the first + if padding_frames_needed < 0: padding_frames_needed = 0 + + img_padded = img_resized # Start with [C, 1, H, W] + if padding_frames_needed > 0: + # Create padding tensor [C, padding_frames_needed, H, W] + padding_tensor = torch.zeros( + img_resized.shape[0], padding_frames_needed, img_resized.shape[2], img_resized.shape[3], + 
device=device, dtype=img_resized.dtype + ) + # Concatenate along frame dimension (dim=1) + img_padded = torch.cat([img_resized, padding_tensor], dim=1) + # Shape should now be [C, 1 + padding_frames_needed, H, W] = [C, frames, H, W] + + # Encode the padded start image tensor. VAE output matches latent frame count. + # vae.encode expects [C, F, H, W] + y_latent_base = vae.encode([img_padded])[0] # Shape [C', lat_f_base, H, W] + + if has_end_image and end_img_resized is not None: + # Encode the single end frame + y_end = vae.encode([end_img_resized])[0] # Shape [C', 1, H, W] + # Concatenate along frame dimension (dim=1) + y_latent_combined = torch.cat([y_latent_base, y_end], dim=1) # Shape [C', lat_f_base + 1, H, W] = [C', lat_f_effective, H, W] + else: + y_latent_combined = y_latent_base # Shape [C', lat_f_base, H, W] = [C', lat_f_effective, H, W] + + # Concatenate mask and the combined latent + # --- CRITICAL ORIGINAL LOGIC DIFFERENCE #6: Final 'y' Tensor --- + y = torch.cat([msk, y_latent_combined], dim=0) # Shape [4+C', lat_f_effective, H, W] + # y = y.unsqueeze(0) # Add batch dimension? Check model input requirements. Assume model forward handles list/batching. + + logger.info(f"Standard I2V conditioning 'y' constructed. Shape: {y.shape}") + logger.info(f"Image encoding complete") + + # Move VAE back + vae.to_device(args.vae_cache_cpu if args.vae_cache_cpu else "cpu") + clean_memory_on_device(device) + + # Prepare model input arguments for Standard I2V + arg_c = { + "context": context, # Model expects batch dim? Assuming yes. + "clip_fea": clip_context, + "seq_len": max_seq_len, # Use original seq len calculation + "y": [y], # Use the 'original method' y + } + arg_null = { + "context": context_null, + "clip_fea": clip_context, + "seq_len": max_seq_len, + "y": [y], # Use the 'original method' y + } + + # Return noise, context, context_null, y (for debugging), (arg_c, arg_null) + return noise, context, context_null, y, (arg_c, arg_null) +# ========================================================================= # +# END OF MODIFIED FUNCTION prepare_i2v_inputs +# ========================================================================= # + + +# --- V2V Helper Functions --- + +def load_video(video_path, start_frame=0, num_frames=None, bucket_reso=(256, 256)): + """Load video frames and resize them to the target resolution for V2V. + + Args: + video_path (str): Path to the video file + start_frame (int): First frame to load (0-indexed) + num_frames (int, optional): Number of frames to load. If None, load all frames from start_frame. + bucket_reso (tuple): Target resolution (height, width) + + Returns: + list: List of numpy arrays containing video frames in RGB format, resized. + int: Actual number of frames loaded. 
+ """ + logger.info(f"Loading video for V2V from {video_path}, target reso {bucket_reso}, frames {start_frame}-{start_frame+num_frames if num_frames else 'end'}") + + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + raise ValueError(f"Failed to open video file: {video_path}") + + # Get total frame count and FPS + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + fps = cap.get(cv2.CAP_PROP_FPS) + logger.info(f"Input video has {total_frames} frames, {fps} FPS") + + # Calculate how many frames to load + if num_frames is None: + frames_to_load = total_frames - start_frame + else: + # Make sure we don't try to load more frames than exist + frames_to_load = min(num_frames, total_frames - start_frame) + + if frames_to_load <= 0: + cap.release() + logger.warning(f"No frames to load (start_frame={start_frame}, num_frames={num_frames}, total_frames={total_frames})") + return [], 0 + + # Skip to start frame + if start_frame > 0: + cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame) + + # Read frames + frames = [] + target_h, target_w = bucket_reso + for i in range(frames_to_load): + ret, frame = cap.read() + if not ret: + logger.warning(f"Could only read {len(frames)} frames out of {frames_to_load} requested.") + break + + # Convert from BGR to RGB + frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + + # Resize the frame + # Use INTER_AREA for downscaling, INTER_LANCZOS4/CUBIC for upscaling + current_h, current_w = frame_rgb.shape[:2] + if target_h * target_w < current_h * current_w: + interpolation = cv2.INTER_AREA + else: + interpolation = cv2.INTER_LANCZOS4 # Higher quality for upscaling + frame_resized = cv2.resize(frame_rgb, (target_w, target_h), interpolation=interpolation) + + frames.append(frame_resized) + + cap.release() + actual_frames_loaded = len(frames) + logger.info(f"Successfully loaded and resized {actual_frames_loaded} frames for V2V.") + + return frames, actual_frames_loaded + + +def encode_video_to_latents(video_tensor: torch.Tensor, vae: WanVAE, device: torch.device, vae_dtype: torch.dtype, args: argparse.Namespace) -> torch.Tensor: # Added args parameter + """Encode video tensor to latent space using VAE for V2V. + + Args: + video_tensor (torch.Tensor): Video tensor with shape [B, C, F, H, W], values in [0, 1]. + vae (WanVAE): VAE model instance. + device (torch.device): Device to perform encoding on. + vae_dtype (torch.dtype): Target dtype for the output latents. + args (argparse.Namespace): Command line arguments (needed for vae_cache_cpu). # Added args description + + Returns: + torch.Tensor: Encoded latents with shape [B, C', F', H', W']. 
+ """ + if vae is None: + raise ValueError("VAE must be provided for video encoding.") + + logger.info(f"Encoding video tensor to latents: input shape {video_tensor.shape}") + + # Ensure VAE is on the correct device + vae.to_device(device) + + # Prepare video tensor: move to device, ensure float32, scale to [-1, 1] + video_tensor = video_tensor.to(device=device, dtype=torch.float32) + video_tensor = video_tensor * 2.0 - 1.0 # Scale from [0, 1] to [-1, 1] + + # WanVAE expects input as a list of [C, F, H, W] tensors (no batch dim) + # Process each video in the batch if batch size > 1 (usually 1 here) + latents_list = [] + batch_size = video_tensor.shape[0] + for i in range(batch_size): + video_single = video_tensor[i] # Shape [C, F, H, W] + with torch.no_grad(), torch.autocast(device_type=device.type, dtype=vae.dtype): # Use VAE's internal dtype for autocast + # vae.encode expects a list containing the tensor + encoded_latent = vae.encode([video_single])[0] # Returns tensor [C', F', H', W'] + latents_list.append(encoded_latent) + + # Stack results back into a batch + latents = torch.stack(latents_list, dim=0) # Shape [B, C', F', H', W'] + + # Move VAE back to CPU (or cache device) + # Use the passed args object here + vae_target_device = torch.device("cpu") if not args.vae_cache_cpu else torch.device("cpu") # Default to CPU, TODO: check if cache device needs specific name + if args.vae_cache_cpu: + # Determine the actual cache device if needed, for now, CPU is safe fallback + logger.info("Moving VAE to CPU for caching (as configured by --vae_cache_cpu).") + else: + logger.info("Moving VAE to CPU after encoding.") + vae.to_device(vae_target_device) # Use args to decide target device + clean_memory_on_device(device) # Clean the GPU memory + + # Convert latents to the desired final dtype (e.g., bfloat16) + latents = latents.to(dtype=vae_dtype) + logger.info(f"Encoded video latents shape: {latents.shape}, dtype: {latents.dtype}") + + return latents + + +def prepare_v2v_inputs(args: argparse.Namespace, config, accelerator: Accelerator, device: torch.device, video_latents: torch.Tensor): + """Prepare inputs for Video2Video inference based on encoded video latents. + + Args: + args (argparse.Namespace): Command line arguments. + config: Model configuration. + accelerator: Accelerator instance. + device (torch.device): Device to use. + video_latents (torch.Tensor): Encoded latent representation of input video [B, C', F', H', W']. + + Returns: + Tuple containing noise, context, context_null, (arg_c, arg_null). + """ + # Get dimensions directly from the video latents + if len(video_latents.shape) != 5: + raise ValueError(f"Expected video_latents to have 5 dimensions [B, C, F, H, W], but got shape {video_latents.shape}") + + batch_size, latent_channels, lat_f, lat_h, lat_w = video_latents.shape + if batch_size != 1: + logger.warning(f"V2V input preparation currently assumes batch size 1, but got {batch_size}. 
Using first item.") + video_latents = video_latents[0:1] # Keep batch dim + + target_shape = video_latents.shape[1:] # Get shape without batch dim: [C', F', H', W'] + + # Calculate the sequence length based on actual latent dimensions + patch_h, patch_w = config.patch_size[1], config.patch_size[2] + spatial_tokens_per_frame = (lat_h * lat_w) // (patch_h * patch_w) + seq_len = spatial_tokens_per_frame * lat_f + logger.info(f"V2V derived latent shape: {target_shape}, seq_len: {seq_len}") + + # Configure negative prompt + n_prompt = args.negative_prompt if args.negative_prompt else config.sample_neg_prompt + + # Set seed (already set in generate(), just need generator) + seed = args.seed + if not args.cpu_noise: + seed_g = torch.Generator(device=device) + seed_g.manual_seed(seed) + else: + seed_g = torch.manual_seed(seed) + + # Load text encoder + text_encoder = load_text_encoder(args, config, device) + text_encoder.model.to(device) + + # Encode prompt + with torch.no_grad(): + if args.fp8_t5: + with torch.amp.autocast(device_type=device.type, dtype=config.t5_dtype): + context = text_encoder([args.prompt], device) + context_null = text_encoder([n_prompt], device) + else: + context = text_encoder([args.prompt], device) + context_null = text_encoder([n_prompt], device) + + # Free text encoder and clean memory + del text_encoder + clean_memory_on_device(device) + + # Generate noise with the same shape as video_latents (including batch dimension) + noise = torch.randn( + video_latents.shape, # [B, C', F', H', W'] + dtype=torch.float32, + device=device if not args.cpu_noise else "cpu", + generator=seed_g + ) + noise = noise.to(device) # Ensure noise is on the target device + + # Prepare model input arguments (context needs to match batch size of latents) + # Assuming batch size 1 for now based on implementation + arg_c = {"context": context, "seq_len": seq_len} + arg_null = {"context": context_null, "seq_len": seq_len} + + # V2V does not use 'y' or 'clip_fea' in the standard Wan model case + # If a specific V2V variant *did* need them, they would be added here. + + return noise, context, context_null, (arg_c, arg_null) + + +# --- End V2V Helper Functions --- + +def load_control_video(control_path: str, frames: int, height: int, width: int, args=None) -> torch.Tensor: + """Load control video to pixel space for Fun-Control model with enhanced control. 
+ + Args: + control_path: path to control video + frames: number of frames in the video + height: height of the video + width: width of the video + args: command line arguments (optional, for logging) + + Returns: + torch.Tensor: control video tensor, CFHW, range [-1, 1] + """ + logger = logging.getLogger(__name__) + msg = f"Load control video for Fun-Control from {control_path}" + if args: + # Use the correct argument names from wanFUN_generate_video.py + msg += f" (weight={args.control_weight}, start={args.control_start}, end={args.control_end})" + logger.info(msg) + + # Use the original helper from hv_generate_video for consistency + if os.path.isfile(control_path): + # Use hv_load_video which returns list of numpy arrays (HWC, 0-255) + video_frames_np = hv_load_video(control_path, 0, frames, bucket_reso=(width, height)) + elif os.path.isdir(control_path): + # Use hv_load_images which returns list of numpy arrays (HWC, 0-255) + video_frames_np = hv_load_images(control_path, frames, bucket_reso=(width, height)) + else: + raise FileNotFoundError(f"Control path not found: {control_path}") + + if not video_frames_np: + raise ValueError(f"No frames loaded from control path: {control_path}") + if len(video_frames_np) < frames: + logger.warning(f"Control video has {len(video_frames_np)} frames, less than requested {frames}. Using available frames.") + # Optionally, could repeat last frame or loop, but using available is simplest + frames = len(video_frames_np) # Adjust frame count + + # Stack and convert to tensor: F, H, W, C (0-255) -> F, C, H, W (-1 to 1) + video_frames_np = np.stack(video_frames_np, axis=0) + video_tensor = torch.from_numpy(video_frames_np).permute(0, 3, 1, 2).float() / 127.5 - 1.0 # Normalize to [-1, 1] + + # Permute to C, F, H, W + video_tensor = video_tensor.permute(1, 0, 2, 3) + logger.info(f"Loaded Fun-Control video tensor shape: {video_tensor.shape}") + + return video_tensor + +def setup_scheduler(args: argparse.Namespace, config, device: torch.device) -> Tuple[Any, torch.Tensor]: + """setup scheduler for sampling + + Args: + args: command line arguments + config: model configuration + device: device to use + + Returns: + Tuple[Any, torch.Tensor]: (scheduler, timesteps) + """ + if args.sample_solver == "unipc": + scheduler = FlowUniPCMultistepScheduler(num_train_timesteps=config.num_train_timesteps, shift=1, use_dynamic_shifting=False) + scheduler.set_timesteps(args.infer_steps, device=device, shift=args.flow_shift) + timesteps = scheduler.timesteps + elif args.sample_solver == "dpm++": + scheduler = FlowDPMSolverMultistepScheduler( + num_train_timesteps=config.num_train_timesteps, shift=1, use_dynamic_shifting=False + ) + sampling_sigmas = get_sampling_sigmas(args.infer_steps, args.flow_shift) + timesteps, _ = retrieve_timesteps(scheduler, device=device, sigmas=sampling_sigmas) + elif args.sample_solver == "vanilla": + scheduler = FlowMatchDiscreteScheduler(num_train_timesteps=config.num_train_timesteps, shift=args.flow_shift) + scheduler.set_timesteps(args.infer_steps, device=device) + timesteps = scheduler.timesteps + + # FlowMatchDiscreteScheduler does not support generator argument in step method + org_step = scheduler.step + + def step_wrapper( + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor, + return_dict: bool = True, + generator=None, # Add generator argument here + ): + # Call original step, ignoring generator if it doesn't accept it + try: + # Try calling with generator if the underlying class was updated + return 
org_step(model_output, timestep, sample, return_dict=return_dict, generator=generator) + except TypeError: + # Fallback to calling without generator + logger.warning("Scheduler step does not support generator argument, proceeding without it.") + return org_step(model_output, timestep, sample, return_dict=return_dict) + + + scheduler.step = step_wrapper + else: + raise NotImplementedError(f"Unsupported solver: {args.sample_solver}") + + logger.info(f"Using scheduler: {args.sample_solver}, timesteps shape: {timesteps.shape}") + return scheduler, timesteps + + +def run_sampling( + model: WanModel, + noise: torch.Tensor, # This might be pure noise (T2V/I2V) or mixed noise+latent (V2V) + scheduler: Any, + timesteps: torch.Tensor, # Might be a subset for V2V + args: argparse.Namespace, + inputs: Tuple[dict, dict], + device: torch.device, + seed_g: torch.Generator, + accelerator: Accelerator, + previewer: Optional[LatentPreviewer] = None, # Add previewer argument + use_cpu_offload: bool = True, # Example parameter, adjust as needed + preview_suffix: Optional[str] = None # <<< ADD suffix argument +) -> torch.Tensor: + """run sampling loop (Denoising) + Args: + model: dit model + noise: initial latent state (pure noise or mixed noise/video latent) + scheduler: scheduler for sampling + timesteps: time steps for sampling (can be subset for V2V) + args: command line arguments + inputs: model input dictionaries (arg_c, arg_null) containing context etc. + device: device to use + seed_g: random generator + accelerator: Accelerator instance + previewer: LatentPreviewer instance or None # Added description + use_cpu_offload: Whether to offload tensors to CPU during processing (example) + preview_suffix: Unique suffix for preview files to avoid conflicts in concurrent runs. + Returns: + torch.Tensor: generated latent + """ + arg_c, arg_null = inputs + + latent = noise # Initialize latent state + # Determine storage device (CPU if offloading, otherwise compute device) + latent_storage_device = torch.device("cpu") if use_cpu_offload else device + latent = latent.to(latent_storage_device) # Move initial state to storage device + + # cfg skip logic + apply_cfg_array = [] + num_timesteps = len(timesteps) + + # ... (keep existing cfg skip logic) ... 
+
+    if args.cfg_skip_mode != "none" and args.cfg_apply_ratio is not None:
+        # Calculate thresholds based on cfg_apply_ratio
+        apply_steps = int(num_timesteps * args.cfg_apply_ratio)
+
+        if args.cfg_skip_mode == "early":
+            start_index = num_timesteps - apply_steps
+            end_index = num_timesteps
+        elif args.cfg_skip_mode == "late":
+            start_index = 0
+            end_index = apply_steps
+        elif args.cfg_skip_mode == "early_late":
+            start_index = (num_timesteps - apply_steps) // 2
+            end_index = start_index + apply_steps
+        elif args.cfg_skip_mode == "middle":
+            skip_steps = num_timesteps - apply_steps
+            middle_start = (num_timesteps - skip_steps) // 2
+            middle_end = middle_start + skip_steps
+        else:  # "alternate" is handled inside the loop
+            start_index = 0
+            end_index = num_timesteps
+
+        w = 0.0  # accumulator for "alternate" mode
+        for step_idx in range(num_timesteps):
+            apply = True  # default
+            if args.cfg_skip_mode == "alternate":
+                w += args.cfg_apply_ratio
+                apply = w >= 1.0
+                if apply:
+                    w -= 1.0
+            elif args.cfg_skip_mode == "middle":
+                apply = not (middle_start <= step_idx < middle_end)
+            elif args.cfg_skip_mode != "none":  # early, late, early_late
+                apply = start_index <= step_idx < end_index
+
+            apply_cfg_array.append(apply)
+
+        pattern = "".join("A" if apply else "S" for apply in apply_cfg_array)
+        logger.info(f"CFG skip mode: {args.cfg_skip_mode}, apply ratio: {args.cfg_apply_ratio}, steps: {num_timesteps}, pattern: {pattern}")
+    else:
+        # Apply CFG on all steps
+        apply_cfg_array = [True] * num_timesteps
+
+    # SLG (Skip Layer Guidance) setup
+    apply_slg_global = args.slg_layers is not None and args.slg_mode is not None
+    slg_start_step = int(args.slg_start * num_timesteps)
+    slg_end_step = int(args.slg_end * num_timesteps)
+
+    logger.info(f"Starting sampling loop for {num_timesteps} steps.")
+    for i, t in enumerate(tqdm(timesteps)):
+        # Prepare input for the model (move latent to compute device)
+        # Latent should be [B, C, F, H, W] or [C, F, H, W]
+        latent_on_device = latent.to(device)
+
+        # The model expects the latent input x as a list of tensors with shape [C, F, H, W].
+        # Defensively squeeze away extra leading singleton dimensions; anything else falls
+        # through to the shape check below.
+        if len(latent_on_device.shape) > 5:
+            while len(latent_on_device.shape) > 5 and latent_on_device.shape[0] == 1:
+                latent_on_device = latent_on_device.squeeze(0)
+            logger.debug(f"Adjusted latent shape for model input: {latent_on_device.shape}")
+
+        # If a batch dimension is present, split the tensor into a list of per-sample tensors
+        if len(latent_on_device.shape) == 5:
+            # Has batch dimension [B, C, F, H, W]
+            latent_model_input_list = [latent_on_device[b] for b in range(latent_on_device.shape[0])]
+        elif len(latent_on_device.shape) == 4:
+            # No batch dimension [C, F, H, W]
+            latent_model_input_list = [latent_on_device]
+        else:
+            raise ValueError(f"Latent tensor has unexpected shape {latent_on_device.shape} for model input.")
+
+        timestep = torch.stack([t]).to(device)  # ensure the timestep is a tensor on device
+
+        with accelerator.autocast(), torch.no_grad():
+            # 1. 
Predict conditional noise estimate + noise_pred_cond = model(latent_model_input_list, t=timestep, **arg_c)[0] + # Move result to storage device early if offloading to potentially save VRAM during uncond/slg pred + noise_pred_cond = noise_pred_cond.to(latent_storage_device) + + # 2. Predict unconditional noise estimate (potentially with SLG) + apply_cfg = apply_cfg_array[i] + if apply_cfg: + apply_slg_step = apply_slg_global and (i >= slg_start_step and i < slg_end_step) + slg_indices_for_call = args.slg_layers if apply_slg_step else None + uncond_input_args = arg_null + + if apply_slg_step and args.slg_mode == "original": + noise_pred_uncond = model(latent_model_input_list, t=timestep, **uncond_input_args)[0].to(latent_storage_device) + skip_layer_out = model(latent_model_input_list, t=timestep, skip_block_indices=slg_indices_for_call, **uncond_input_args)[0].to(latent_storage_device) + noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond) + noise_pred = noise_pred + args.slg_scale * (noise_pred_cond - skip_layer_out) + + elif apply_slg_step and args.slg_mode == "uncond": + noise_pred_uncond = model(latent_model_input_list, t=timestep, skip_block_indices=slg_indices_for_call, **uncond_input_args)[0].to(latent_storage_device) + noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond) + + else: # Regular CFG + noise_pred_uncond = model(latent_model_input_list, t=timestep, **uncond_input_args)[0].to(latent_storage_device) + noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond) + else: + # CFG is skipped, use conditional prediction directly + noise_pred = noise_pred_cond + # --- End prediction logic --- + + # 3. Compute previous sample state with the scheduler + # Ensure noise_pred and latent_on_device have matching batch dimensions for scheduler + if len(noise_pred.shape) < len(latent_on_device.shape): + noise_pred = noise_pred.unsqueeze(0) # Add batch dim if missing ([C,F,H,W]->[1,C,F,H,W]) + elif len(noise_pred.shape) > len(latent_on_device.shape): + # This shouldn't happen if latent_on_device handles batch correctly + logger.warning(f"Noise pred shape {noise_pred.shape} has more dims than latent {latent_on_device.shape}") + + # Scheduler expects noise_pred [B, C, F, H, W] and sample [B, C, F, H, W] + # latent_on_device should already have the batch dim handled by the logic above + scheduler_output = scheduler.step( + noise_pred.to(device), # Ensure noise_pred is on compute device for step + t, + latent_on_device, # Pass the tensor (with batch dim) on compute device + return_dict=False, + generator=seed_g + ) + prev_latent = scheduler_output[0] # Get the new latent state [B, C, F, H, W] + + # 4. Update latent state (move back to storage device) + latent = prev_latent.to(latent_storage_device) + + # --- Latent Preview Call --- + # Preview the state *after* step 'i' is completed + if previewer is not None and (i + 1) % args.preview == 0 and (i + 1) < num_timesteps: + try: + logger.debug(f"Generating preview for step {i + 1}") + # Pass the *resulting* latent from this step (prev_latent). + # Ensure it's on the compute device for the previewer call. + # LatentPreviewer handles internal device management. 
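+                    # Previews are generated every `args.preview` steps; the final step is
+                    # skipped because the full VAE decode follows right after the loop.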
+
+                    # LatentPreviewer.preview expects a latent without batch dim: [C, F, H, W]
+                    if len(prev_latent.shape) == 5:
+                        preview_latent_input = prev_latent.squeeze(0)  # remove batch dim
+                    else:
+                        preview_latent_input = prev_latent  # assume already [C, F, H, W]
+
+                    # Pass the latent on the main compute device (0-based step index 'i')
+                    logger.debug(f"Step {i}: prev_latent shape {prev_latent.shape}, preview input shape {preview_latent_input.shape}")
+                    previewer.preview(preview_latent_input.to(device), i, preview_suffix=preview_suffix)
+                except Exception as e:
+                    logger.error(f"Error during latent preview generation at step {i + 1}: {e}", exc_info=True)
+                    # Optional: disable the previewer after the first error to avoid repeated logs
+                    # previewer = None
+
+    # Return the final denoised latent (on the storage device)
+    logger.info("Sampling loop finished.")
+    return latent
+
+def generate(args: argparse.Namespace) -> Optional[torch.Tensor]:
+    """main function for the generation pipeline (T2V, I2V, V2V, Fun-Control)
+
+    Args:
+        args: command line arguments
+
+    Returns:
+        Optional[torch.Tensor]: generated latent tensor [B, C, F, H, W], or None if only saving the merged model.
+    """
+    device = torch.device(args.device)
+    cfg = WAN_CONFIGS[args.task]
+
+    # --- Determine Mode ---
+    is_i2v = args.image_path is not None
+    is_v2v = args.video_path is not None
+    is_fun_control = args.control_path is not None and cfg.is_fun_control
+    is_t2v = not is_i2v and not is_v2v and not is_fun_control
+
+    if is_v2v:
+        logger.info(f"Running Video-to-Video (V2V) inference with strength {args.strength}")
+    elif is_i2v:
+        logger.info("Running Image-to-Video (I2V) inference")
+    elif is_fun_control:
+        logger.info("Running Text-to-Video with Fun-Control")  # Fun-Control can also be I2V if image_path is given
+    else:
+        logger.info("Running Text-to-Video (T2V) inference")
+
+    # --- Data Types ---
+    dit_dtype = detect_wan_sd_dtype(args.dit) if args.dit is not None else torch.bfloat16
+    if dit_dtype.itemsize == 1:  # FP8 weights loaded
+        dit_dtype = torch.bfloat16  # use bfloat16 for computation
+        if args.fp8_scaled:
+            raise ValueError("Cannot use --fp8_scaled with pre-quantized FP8 weights.")
+        dit_weight_dtype = None  # weights are already FP8
+    elif args.fp8_scaled:
+        dit_weight_dtype = None  # optimized later
+    elif args.fp8:
+        dit_weight_dtype = torch.float8_e4m3fn
+    else:
+        dit_weight_dtype = dit_dtype  # use compute dtype for weights
+
+    vae_dtype = str_to_dtype(args.vae_dtype) if args.vae_dtype is not None else (torch.bfloat16 if dit_dtype == torch.bfloat16 else torch.float16)
+    logger.info(
+        f"Using device: {device}, DiT compute: {dit_dtype}, DiT weight: {dit_weight_dtype or ('Mixed (FP8 Scaled)' if args.fp8_scaled else dit_dtype)}, VAE: {vae_dtype}, T5 FP8: {args.fp8_t5}"
+    )
+
+    # --- Accelerator ---
+    mixed_precision = "bf16" if dit_dtype == torch.bfloat16 else "fp16"
+    accelerator = accelerate.Accelerator(mixed_precision=mixed_precision)
+
+    # --- Seed ---
+    seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
+    args.seed = seed  # store seed back for metadata
+    logger.info(f"Using seed: {seed}")
+
+    # --- Load VAE (if needed for input processing) ---
+    vae = None
+    # The VAE is needed early for V2V, I2V, and any Fun-Control mode (to encode reference/control pixels)
+    needs_vae_early = is_v2v or is_i2v or is_fun_control
+    if needs_vae_early:
+        vae = load_vae(args, 
cfg, device, vae_dtype) + # Keep VAE on specified device for now, will be moved as needed + + # --- Prepare Inputs --- + noise = None + context = None + context_null = None + inputs = None + video_latents = None # For V2V mixing + + if is_v2v: + # Standard V2V path (mutually exclusive with FunControl) + # 1. Load and prepare video + video_frames_np, actual_frames_loaded = load_video( + args.video_path, + start_frame=0, + num_frames=args.video_length, # Can be None + bucket_reso=tuple(args.video_size) + ) + if actual_frames_loaded == 0: + raise ValueError(f"Could not load any frames from video: {args.video_path}") + + # Update video_length if it was None or if fewer frames were loaded + if args.video_length is None or actual_frames_loaded < args.video_length: + logger.info(f"Updating video_length based on loaded frames: {actual_frames_loaded}") + args.video_length = actual_frames_loaded + # Re-check height/width/length now that length is known + height, width, video_length = check_inputs(args) + args.video_size = [height, width] # Update args + else: + video_length = args.video_length # Use the specified length + + # Convert frames to tensor [1, C, F, H, W], range [0, 1] + video_tensor = torch.from_numpy(np.stack(video_frames_np, axis=0)) #[F,H,W,C] + video_tensor = video_tensor.permute(0, 3, 1, 2).float() / 255.0 #[F,C,H,W] + video_tensor = video_tensor.permute(1, 0, 2, 3).unsqueeze(0) #[1,C,F,H,W] + + # 2. Encode video to latents + video_latents = encode_video_to_latents(video_tensor, vae, device, dit_dtype, args) # Use DiT dtype for latents + del video_tensor # Free pixel video memory + clean_memory_on_device(device) + + # 3. Prepare V2V inputs (noise matching latent shape, context, etc.) + noise, context, context_null, inputs = prepare_v2v_inputs(args, cfg, accelerator, device, video_latents) + + elif is_i2v: + # I2V path (handles both standard and FunControl internally based on config) + if args.video_length is None: + raise ValueError("video_length must be specified for I2V mode.") + noise, context, context_null, _, inputs = prepare_i2v_inputs(args, cfg, accelerator, device, vae) + # Note: prepare_i2v_inputs moves VAE to CPU/cache after use + + elif is_fun_control: # Pure FunControl T2V (no image input unless using start/end image) + if args.video_length is None: + raise ValueError("video_length must be specified for Fun-Control T2V mode.") + noise, context, context_null, inputs = prepare_t2v_inputs(args, cfg, accelerator, device, vae) + # Note: prepare_t2v_inputs moves VAE to CPU/cache if it used it + + elif is_t2v: # Standard T2V + if args.video_length is None: + raise ValueError("video_length must be specified for standard T2V mode.") + noise, context, context_null, inputs = prepare_t2v_inputs(args, cfg, accelerator, device, None) # Pass None for VAE + + + # At this point, VAE should be on CPU/cache unless still needed for decoding + # If VAE wasn't loaded early (standard T2V), vae is still None + + # --- Load DiT Model --- + model = load_dit_model(args, cfg, device, dit_dtype, dit_weight_dtype, is_i2v) # Pass is_i2v flag (for potential internal use) + + # --- Merge LoRA --- + if args.lora_weight is not None and len(args.lora_weight) > 0: + merge_lora_weights(model, args, device) + if args.save_merged_model: + logger.info("Merged model saved. 
Exiting without generation.") + # Clean up resources if exiting early + if 'model' in locals(): del model + if 'vae' in locals() and vae is not None: del vae + clean_memory_on_device(device) + return None # Exit early + + # --- Optimize Model (FP8, Swapping, Compile) --- + optimize_model(model, args, device, dit_dtype, dit_weight_dtype) + + # --- Setup Scheduler & Timesteps --- + scheduler, timesteps = setup_scheduler(args, cfg, device) + + # --- Prepare for Sampling --- + seed_g = torch.Generator(device=device) + seed_g.manual_seed(seed) + + # `latent` here is the initial state *before* the sampling loop starts + latent = noise # Start with noise (already shaped correctly for T2V/I2V/V2V) + + # --- V2V Strength Adjustment --- + if is_v2v and args.strength < 1.0: + if video_latents is None: + raise RuntimeError("video_latents not available for V2V strength adjustment.") + + # Calculate number of inference steps based on strength + num_inference_steps = max(1, int(args.infer_steps * args.strength)) + logger.info(f"V2V Strength: {args.strength}, adjusting inference steps from {args.infer_steps} to {num_inference_steps}") + + # Get starting timestep index and value + t_start_idx = len(timesteps) - num_inference_steps + if t_start_idx < 0: t_start_idx = 0 # Ensure non-negative index + t_start = timesteps[t_start_idx] # Timestep value at the start of sampling + + # Mix noise and video latents based on starting timestep using scheduler + # Ensure video_latents are on the same device and dtype as noise for mixing + video_latents = video_latents.to(device=latent.device, dtype=latent.dtype) + + if latent.shape != video_latents.shape: + logger.error(f"Noise shape {latent.shape} does not match video latent shape {video_latents.shape} for V2V mixing. Cannot proceed.") + raise ValueError("Shape mismatch between noise and video latents in V2V.") + + # Use scheduler's add_noise for better mixing + latent = scheduler.add_noise(video_latents, latent, t_start.unsqueeze(0)) + logger.info(f"Mixed video latents and noise using scheduler at timestep {t_start.item():.1f}") + + # Use only the required subset of timesteps + timesteps = timesteps[t_start_idx:] + logger.info(f"Using last {len(timesteps)} timesteps for V2V sampling.") + else: + logger.info(f"Using full {len(timesteps)} timesteps for sampling.") + # Latent remains the initial noise + + # --- Initialize Latent Previewer --- # ADDED SECTION + previewer = None + if LatentPreviewer is not None and args.preview is not None and args.preview > 0: + logger.info(f"Initializing Latent Previewer (every {args.preview} steps)...") + try: + # Use the initial 'latent' state which might be pure noise or mixed V2V start + # Pass without batch dim [C, F, H, W] + initial_latent_for_preview = latent.clone().squeeze(0) + previewer = LatentPreviewer(args, initial_latent_for_preview, timesteps, device, dit_dtype, model_type="wan") + logger.info("Latent Previewer initialized successfully.") + except Exception as e: + logger.error(f"Failed to initialize Latent Previewer: {e}", exc_info=True) + previewer = None # Ensure it's None if init fails + # --- END ADDED SECTION --- + + # --- Run Sampling Loop --- + logger.info("Starting denoising sampling loop...") + final_latent = run_sampling( + model, + latent, # Initial state (noise or mixed) + scheduler, + timesteps, # Full or partial timesteps + args, + inputs, # Contains context etc. 
+
+        device,
+        seed_g,
+        accelerator,
+        previewer=previewer,  # may be None when previews are disabled
+        use_cpu_offload=(args.blocks_to_swap > 0),  # keep latents on CPU when block swapping is enabled
+        preview_suffix=args.preview_suffix,  # unique suffix so concurrent runs don't clash
+    )
+
+    # --- Cleanup ---
+    del model
+    if 'scheduler' in locals(): del scheduler
+    if 'context' in locals(): del context
+    if 'context_null' in locals(): del context_null
+    if 'inputs' in locals(): del inputs  # free memory from encoded inputs
+    if video_latents is not None: del video_latents
+    # previewer instance will be garbage collected
+
+    synchronize_device(device)
+
+    if args.blocks_to_swap > 0:
+        logger.info("Waiting for 5 seconds to ensure block swap finishes...")
+        time.sleep(5)
+
+    gc.collect()
+    clean_memory_on_device(device)
+
+    # Store the VAE instance in args for the decoding function (may be None for pure T2V)
+    args._vae = vae
+
+    # Return latent with batch dimension [1, C, F, H, W]
+    # final_latent may be on CPU if use_cpu_offload=True
+    if len(final_latent.shape) == 4:  # run_sampling returned [C, F, H, W]
+        final_latent = final_latent.unsqueeze(0)
+
+    return final_latent
+
+def decode_latent(latent: torch.Tensor, args: argparse.Namespace, cfg) -> torch.Tensor:
+    """decode latent tensor to video frames
+
+    Args:
+        latent: latent tensor [B, C, F, H, W]
+        args: command line arguments (contains _vae instance)
+        cfg: model configuration
+
+    Returns:
+        torch.Tensor: decoded video tensor [B, C, F, H, W], range [0, 1], on CPU
+    """
+    device = torch.device(args.device)
+
+    # Load the VAE model or reuse the one from the generation pipeline
+    vae = None
+    if hasattr(args, "_vae") and args._vae is not None:
+        vae = args._vae
+        logger.info("Using VAE instance from generation pipeline for decoding.")
+    else:
+        # Need to load the VAE if it wasn't used/stored (e.g., pure T2V or latent input mode)
+        logger.info("Loading VAE for decoding...")
+        # Attempt to detect the DiT dtype even if the DiT wasn't loaded (e.g., latent mode);
+        # fall back to bfloat16 if the DiT path isn't available
+        try:
+            dit_dtype_ref = detect_wan_sd_dtype(args.dit) if args.dit else torch.bfloat16
+        except Exception:  # DiT path is invalid or missing in latent mode
+            dit_dtype_ref = torch.bfloat16
+            logger.warning("Could not detect DiT dtype for VAE decoding, defaulting to bfloat16.")
+
+        vae_dtype_decode = str_to_dtype(args.vae_dtype) if args.vae_dtype is not None else (torch.bfloat16 if dit_dtype_ref == torch.bfloat16 else torch.float16)
+        vae = load_vae(args, cfg, device, vae_dtype_decode)
+        args._vae = vae  # cache for potential reuse
+
+    # Ensure the VAE is on device for decoding
+    vae.to_device(device)
+
+    logger.info(f"Decoding video from latents: shape {latent.shape}, dtype {latent.dtype}")
+    # Ensure the latent is on the correct device and expected dtype for the VAE
+    latent_decode = latent.to(device=device, dtype=vae.dtype)
+
+    # The WanVAE wrapper is assumed to accept the batched latent directly and to return a list of per-sample videos
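+    # Expected shape flow (derived from the code below): latent [B, C, F', H', W'] -> vae.decode
+    # -> list of [C, F, H, W] tensors in [-1, 1] -> stacked to [B, C, F, H, W], then rescaled
+    # to [0, 1] and moved to CPU for saving.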
+ videos = None + with torch.autocast(device_type=device.type, dtype=vae.dtype), torch.no_grad(): + # WanVAE.decode returns a list of decoded videos [C, F, H, W] + decoded_list = vae.decode(latent_decode) # Pass the batch tensor + if decoded_list and len(decoded_list) > 0: + # Stack list back into batch dimension: B, C, F, H, W + videos = torch.stack(decoded_list, dim=0) + else: + raise RuntimeError("VAE decoding failed or returned empty list.") + + + # Move VAE back to CPU/cache + vae.to_device(args.vae_cache_cpu if args.vae_cache_cpu else "cpu") + clean_memory_on_device(device) + + logger.info(f"Decoded video shape: {videos.shape}") + + # Post-processing: trim tail frames, convert to float32 CPU, scale to [0, 1] + if args.trim_tail_frames > 0: + logger.info(f"Trimming last {args.trim_tail_frames} frames.") + videos = videos[:, :, : -args.trim_tail_frames, :, :] + + # Scale from [-1, 1] (VAE output range) to [0, 1] (video save range) + videos = (videos + 1.0) / 2.0 + videos = torch.clamp(videos, 0.0, 1.0) + + # Move to CPU and convert to float32 for saving + video_final = videos.cpu().to(torch.float32) + logger.info(f"Decoding complete. Final video tensor shape: {video_final.shape}") + + return video_final + + +def save_output( + video_tensor: torch.Tensor, # Expects [B, C, F, H, W] range [0, 1] + args: argparse.Namespace, + original_base_names: Optional[List[str]] = None, + latent_to_save: Optional[torch.Tensor] = None # Optional latent [B, C, F, H, W] +) -> None: + """save output video, images, or latent + + Args: + video_tensor: decoded video tensor [B, C, F, H, W], range [0, 1] + args: command line arguments + original_base_names: original base names (if latents are loaded from files) + latent_to_save: optional raw latent tensor to save + """ + save_path = args.save_path + os.makedirs(save_path, exist_ok=True) + time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S") + + seed = args.seed + # Get dimensions from the *decoded* video tensor + batch_size, channels, video_length, height, width = video_tensor.shape + + base_name = f"{time_flag}_{seed}" + if original_base_names: + # Use first original name if loading multiple latents (though currently unsupported) + base_name += f"_{original_base_names[0]}" + + # --- Save Latent --- + if (args.output_type == "latent" or args.output_type == "both") and latent_to_save is not None: + latent_path = os.path.join(save_path, f"{base_name}_latent.safetensors") + logger.info(f"Saving latent tensor shape: {latent_to_save.shape}") + + metadata = {} + if not args.no_metadata: + # Try to get model paths robustly for metadata + cfg = WAN_CONFIGS.get(args.task) # Get config if task exists + dit_path_meta = "N/A" + if args.dit: dit_path_meta = args.dit + elif cfg and cfg.dit_checkpoint and args.ckpt_dir: dit_path_meta = os.path.join(args.ckpt_dir, cfg.dit_checkpoint) + elif cfg and cfg.dit_checkpoint: dit_path_meta = cfg.dit_checkpoint # Use relative path if no ckpt_dir + + vae_path_meta = "N/A" + if args.vae: vae_path_meta = args.vae + elif cfg and cfg.vae_checkpoint and args.ckpt_dir: vae_path_meta = os.path.join(args.ckpt_dir, cfg.vae_checkpoint) + elif cfg and cfg.vae_checkpoint: vae_path_meta = cfg.vae_checkpoint # Use relative path if no ckpt_dir + + metadata = { + "prompt": f"{args.prompt}", + "negative_prompt": f"{args.negative_prompt or ''}", + "seeds": f"{seed}", + "height": f"{height}", # Use decoded height/width + "width": f"{width}", + "video_length": f"{video_length}", # Use decoded length + "infer_steps": 
f"{args.infer_steps}", + "guidance_scale": f"{args.guidance_scale}", + "flow_shift": f"{args.flow_shift}", + "task": f"{args.task}", + "dit_model": f"{dit_path_meta}", + "vae_model": f"{vae_path_meta}", + # Add V2V/I2V specific info + "mode": "V2V" if args.video_path else ("I2V" if args.image_path else ("FunControl" if args.control_path else "T2V")), + } + if args.video_path: metadata["v2v_strength"] = f"{args.strength}" + if args.image_path: metadata["i2v_image"] = f"{os.path.basename(args.image_path)}" + if args.end_image_path: metadata["i2v_end_image"] = f"{os.path.basename(args.end_image_path)}" + if args.control_path: + metadata["funcontrol_video"] = f"{os.path.basename(args.control_path)}" + metadata["funcontrol_weight"] = f"{args.control_weight}" + metadata["funcontrol_start"] = f"{args.control_start}" + metadata["funcontrol_end"] = f"{args.control_end}" + metadata["funcontrol_falloff"] = f"{args.control_falloff_percentage}" + # Add LoRA info if used + if args.lora_weight: + metadata["lora_weights"] = ", ".join([os.path.basename(p) for p in args.lora_weight]) + metadata["lora_multipliers"] = ", ".join(map(str, args.lora_multiplier)) + + + # Ensure latent is on CPU for saving + sd = {"latent": latent_to_save.cpu()} + try: + save_file(sd, latent_path, metadata=metadata) + logger.info(f"Latent saved to: {latent_path}") + except Exception as e: + logger.error(f"Failed to save latent file: {e}") + + + # --- Save Video or Images --- + if args.output_type == "video" or args.output_type == "both": + video_path = os.path.join(save_path, f"{base_name}.mp4") + # save_videos_grid expects [B, T, H, W, C], need to permute and rescale if needed + # Input video_tensor is [B, C, T, H, W], range [0, 1] + # save_videos_grid handles the rescale flag correctly if input is [0,1] + try: + save_videos_grid(video_tensor, video_path, fps=args.fps, rescale=False) # Pass rescale=False as tensor is already [0,1] + logger.info(f"Video saved to: {video_path}") + except Exception as e: + logger.error(f"Failed to save video file: {e}") + logger.error(f"Video tensor info: shape={video_tensor.shape}, dtype={video_tensor.dtype}, min={video_tensor.min()}, max={video_tensor.max()}") + + + elif args.output_type == "images": + image_save_dir = os.path.join(save_path, base_name) + os.makedirs(image_save_dir, exist_ok=True) + # save_images_grid expects [B, T, H, W, C], need to permute and rescale if needed + # Input video_tensor is [B, C, T, H, W], range [0, 1] + # save_images_grid handles the rescale flag correctly if input is [0,1] + try: + # Save as individual frames + save_images_grid(video_tensor, image_save_dir, "frame", rescale=False, save_individually=True) # Pass rescale=False + logger.info(f"Image frames saved to directory: {image_save_dir}") + except Exception as e: + logger.error(f"Failed to save image files: {e}") + + +def main(): + # --- Argument Parsing & Setup --- + args = parse_args() + + # Determine mode: generation or loading latents + latents_mode = args.latent_path is not None and len(args.latent_path) > 0 + + # Set device + device_str = args.device if args.device is not None else ("cuda" if torch.cuda.is_available() else "cpu") + args.device = torch.device(device_str) # Store device back in args + logger.info(f"Using device: {args.device}") + + generated_latent = None # To hold the generated latent if not in latents_mode + cfg = WAN_CONFIGS[args.task] # Get config early for potential use + height, width, video_length = None, None, None # Initialize dimensions + original_base_names = None # For naming 
output when loading latents + + if not latents_mode: + # --- Generation Mode (T2V, I2V, V2V, Fun-Control) --- + logger.info("Running in Generation Mode") + # Setup arguments (defaults, etc.) + args = setup_args(args) + # Validate inputs (initial check, V2V might refine length later) + height, width, video_length = check_inputs(args) + args.video_size = [height, width] # Ensure args reflect checked dimensions + args.video_length = video_length # May still be None for V2V + + # Determine specific mode string + mode_str = "Unknown" + if args.video_path: mode_str = "V2V" + elif args.image_path and args.control_path: mode_str = "FunControl-I2V" # FunControl overrides if control_path is present + elif args.control_path: mode_str = "FunControl-T2V" + elif args.image_path: mode_str = "I2V" + else: mode_str = "T2V" + + logger.info(f"Mode: {mode_str} (Task: {args.task})") + logger.info( + f"Initial settings: video size: {height}x{width}@{video_length or 'auto'} (HxW@F), fps: {args.fps}, " + f"infer_steps: {args.infer_steps}, guidance: {args.guidance_scale}, flow_shift: {args.flow_shift}" + ) + if mode_str == "V2V": logger.info(f"V2V Strength: {args.strength}") + if "FunControl" in mode_str: logger.info(f"FunControl Weight: {args.control_weight}, Start: {args.control_start}, End: {args.control_end}, Falloff: {args.control_falloff_percentage}") + + # Core generation pipeline + generated_latent = generate(args) # Returns [B, C, F, H, W] or None + + if args.save_merged_model: + logger.info("Exiting after saving merged model.") + return # Exit if only saving model + + if generated_latent is None: + logger.error("Generation failed or was skipped, exiting.") + return + + # Update dimensions based on the *actual* generated latent + # Latent shape might differ slightly from input request depending on VAE/model strides + _, lat_c, lat_f, lat_h, lat_w = generated_latent.shape + # Convert latent dimensions back to pixel dimensions for metadata/logging + pixel_height = lat_h * cfg.vae_stride[1] + pixel_width = lat_w * cfg.vae_stride[2] + pixel_frames = (lat_f - 1) * cfg.vae_stride[0] + 1 + logger.info(f"Generation complete. Latent shape: {generated_latent.shape} -> Pixel Video: {pixel_height}x{pixel_width}@{pixel_frames}") + # Use these derived pixel dimensions for saving metadata + height, width, video_length = pixel_height, pixel_width, pixel_frames + + + else: + # --- Latents Mode (Load and Decode) --- + logger.info("Running in Latent Loading Mode") + original_base_names = [] + latents_list = [] + seeds = [] # Try to recover seed from metadata + + # Currently only supporting one latent file input + if len(args.latent_path) > 1: + logger.warning("Loading multiple latent files is not fully supported for metadata merging. Using first file's info.") + + latent_path = args.latent_path[0] + original_base_names.append(os.path.splitext(os.path.basename(latent_path))[0]) + loaded_latent = None + metadata = {} + seed = args.seed if args.seed is not None else random.randint(0, 2**32-1) # Default seed if none in metadata + + try: + if os.path.splitext(latent_path)[1].lower() != ".safetensors": + logger.warning("Loading non-safetensors latent file. 
Metadata might be missing.") + loaded_latent = torch.load(latent_path, map_location="cpu") + # Attempt to handle different save formats (dict vs raw tensor) + if isinstance(loaded_latent, dict): + if "latent" in loaded_latent: + loaded_latent = loaded_latent["latent"] + elif "state_dict" in loaded_latent: # Might be a full model checkpoint by mistake + raise ValueError("Loaded file appears to be a model checkpoint, not a latent tensor.") + else: # Try the first value if it's a tensor + first_key = next(iter(loaded_latent)) + if isinstance(loaded_latent[first_key], torch.Tensor): + loaded_latent = loaded_latent[first_key] + else: + raise ValueError("Could not find latent tensor in loaded dictionary.") + elif not isinstance(loaded_latent, torch.Tensor): + raise ValueError(f"Loaded file content is not a tensor or expected dictionary format: {type(loaded_latent)}") + + + else: + # Load latent tensor + loaded_latent = load_file(latent_path, device="cpu")["latent"] + # Load metadata + with safe_open(latent_path, framework="pt", device="cpu") as f: + metadata = f.metadata() or {} + logger.info(f"Loaded metadata: {metadata}") + + # Restore args from metadata if available AND not overridden by command line + # Command line args take precedence if provided + if args.seed is None and "seeds" in metadata: seed = int(metadata["seeds"]) + if "prompt" in metadata: args.prompt = args.prompt or metadata["prompt"] # Keep command line if provided + if "negative_prompt" in metadata: args.negative_prompt = args.negative_prompt or metadata["negative_prompt"] + # We need height/width/length to decode, so always load if available + if "height" in metadata: height = int(metadata["height"]) + if "width" in metadata: width = int(metadata["width"]) + if "video_length" in metadata: video_length = int(metadata["video_length"]) + # Restore other relevant args if not set by user + if args.guidance_scale == 5.0 and "guidance_scale" in metadata: args.guidance_scale = float(metadata["guidance_scale"]) # Assuming 5.0 is default + if args.infer_steps is None and "infer_steps" in metadata: args.infer_steps = int(metadata["infer_steps"]) + if args.flow_shift is None and "flow_shift" in metadata: args.flow_shift = float(metadata["flow_shift"]) + if "task" in metadata: args.task = args.task or metadata["task"] # Restore task if not specified + # FunControl specific args + if "funcontrol_weight" in metadata: args.control_weight = args.control_weight or float(metadata["funcontrol_weight"]) + if "funcontrol_start" in metadata: args.control_start = args.control_start or float(metadata["funcontrol_start"]) + if "funcontrol_end" in metadata: args.control_end = args.control_end or float(metadata["funcontrol_end"]) + if "funcontrol_falloff" in metadata: args.control_falloff_percentage = args.control_falloff_percentage or float(metadata["funcontrol_falloff"]) + # V2V specific args + if "v2v_strength" in metadata: args.strength = args.strength or float(metadata["v2v_strength"]) + + # Update config based on restored task + cfg = WAN_CONFIGS[args.task] + + seeds.append(seed) + latents_list.append(loaded_latent) + logger.info(f"Loaded latent from {latent_path}. 
Shape: {loaded_latent.shape}, dtype: {loaded_latent.dtype}") + + except Exception as e: + logger.error(f"Failed to load latent file {latent_path}: {e}") + return + + if not latents_list: + logger.error("No latent tensors were loaded.") + return + + # Stack latents (currently just one) - ensure batch dimension + generated_latent = torch.stack(latents_list, dim=0) # [B, C, F, H, W] + if len(generated_latent.shape) != 5: + # Maybe saved without batch dim? Try adding it. + if len(generated_latent.shape) == 4: + logger.warning(f"Loaded latent has 4 dimensions {generated_latent.shape}. Adding batch dimension.") + generated_latent = generated_latent.unsqueeze(0) + else: + raise ValueError(f"Loaded latent has incorrect shape: {generated_latent.shape}. Expected 4 or 5 dimensions.") + + # Set seed from metadata (or default) + args.seed = seeds[0] + + # Infer pixel dimensions from latent shape and config if not available in metadata + _, _, lat_f, lat_h, lat_w = generated_latent.shape # Get dimensions from loaded latent + if height is None or width is None or video_length is None: + logger.warning("Dimensions not found in metadata, inferring from latent shape.") + height = lat_h * cfg.vae_stride[1] + width = lat_w * cfg.vae_stride[2] + video_length = (lat_f - 1) * cfg.vae_stride[0] + 1 + logger.info(f"Inferred pixel dimensions: {height}x{width}@{video_length}") + # Store final dimensions in args for consistency + args.video_size = [height, width] + args.video_length = video_length + + # --- Decode and Save --- + if generated_latent is not None: + # Decode latent to video tensor [B, C, F, H, W], range [0, 1] + decoded_video = decode_latent(generated_latent, args, cfg) + + # Save the output (latent and/or video/images) + save_output( + decoded_video, + args, + original_base_names=original_base_names, + latent_to_save=generated_latent if (args.output_type == "latent" or args.output_type == "both") else None + ) + else: + logger.error("No latent available for decoding and saving.") + + logger.info("Done!") + + +if __name__ == "__main__": + main() +# --- END OF FILE wanFUN_generate_video.py --- \ No newline at end of file