Commit 2ba9257 · 1 Parent(s): 267b599
jbilcke-hf committed

improve UI to support conditioning
CLAUDE.md CHANGED
@@ -1,18 +1,26 @@
+# CLAUDE.md
+
+This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
+
 # Video Model Studio - Guidelines for Claude
 
 ## Build & Run Commands
-- Setup: `./setup.sh` (with flash attention) or `./setup_no_captions.sh` (without)
-- Run: `./run.sh` or `python3.10 app.py`
-- Test: `python3 tests/test_dataset.py`
-- Single model test: `bash tests/scripts/dummy_cogvideox_lora.sh`
+- Setup: `./setup.sh` (with flash attention) or `./degraded_setup.sh` (without flash-attention)
+- Run: `./run.sh` or `python3.10 app.py`
+- Test:
+  - Full test: `python3 tests/test_dataset.py`
+  - Single model test: `bash tests/scripts/dummy_cogvideox_lora.sh` (or other model variants)
+  - Run test suite: `bash tests/test_model_runs_minimally_lora.sh`
 
 ## Code Style
 - Python version: 3.10 (required for flash-attention compatibility)
-- Type hints: Use typing module annotations for all functions
-- Docstrings: Google style with Args/Returns sections
-- Error handling: Use try/except with specific exceptions, log errors
+- Type hints: Use typing module annotations for all functions (from typing import Any, Optional, Dict, List, Union, Tuple)
+- Docstrings: Google style with Args/Returns sections for all functions
+- Error handling: Use try/except with specific exceptions, log errors appropriately
 - Imports: Group standard lib, third-party, and project imports
 - Naming: snake_case for functions/variables, PascalCase for classes
 - Use Path objects from pathlib instead of string paths
-- Format utility functions: Extract reusable logic to separate functions
-- Environment variables: Use parse_bool_env for boolean env vars
+- Extract reusable logic to separate utility functions
+- Environment variables: Use parse_bool_env for boolean env vars
+- Logging: Use the logging module with appropriate log levels (DEBUG, INFO, WARNING, ERROR)
+- UI components: Organize in tabs and use consistent naming for components dict
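The Code Style list above points to a `parse_bool_env` helper for boolean environment variables. That helper is not part of this diff; the sketch below is a plausible implementation (the accepted truthy spellings are an assumption), written to match the repository's stated conventions of type hints and Google-style docstrings:

```python
import os

def parse_bool_env(name: str, default: bool = False) -> bool:
    """Parse a boolean environment variable.

    Args:
        name: Name of the environment variable.
        default: Value to use when the variable is unset.

    Returns:
        True if the variable is set to a truthy string, else False.
    """
    value = os.environ.get(name)
    if value is None:
        return default
    # Assumed truthy spellings; the project's real helper may differ
    return value.strip().lower() in ("1", "true", "yes", "on")
```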
degraded_requirements.txt CHANGED
@@ -9,8 +9,16 @@ diffusers @ git+https://github.com/huggingface/diffusers.git@main
 imageio
 imageio-ffmpeg
 
+#--------------- MACOS HACKS ----------------
+
+# use eva-decord for better compatibility on macOS
+eva-decord
+
+# don't install flash attention on macOS
 #flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
 
+#--------------- / MACOS HACKS --------------
+
 # for youtube video download
 pytube
 pytubefix
vms/config.py CHANGED
@@ -205,7 +205,9 @@ MODEL_TYPES = {
 # Training types
 TRAINING_TYPES = {
     "LoRA Finetune": "lora",
-    "Full Finetune": "full-finetune"
+    "Full Finetune": "full-finetune",
+    "Control LoRA": "control-lora",
+    "Control Full Finetune": "control-full-finetune"
 }
 
 # Model versions for each model type
@@ -288,6 +290,13 @@ DEFAULT_NB_LR_WARMUP_STEPS = math.ceil(0.20 * DEFAULT_NB_TRAINING_STEPS) # 20%
 # Whether to automatically restart a training job after a server reboot or not
 DEFAULT_AUTO_RESUME = False
 
+# Control training defaults
+DEFAULT_CONTROL_TYPE = "canny"
+DEFAULT_TRAIN_QK_NORM = False
+DEFAULT_FRAME_CONDITIONING_TYPE = "full"
+DEFAULT_FRAME_CONDITIONING_INDEX = 0
+DEFAULT_FRAME_CONDITIONING_CONCATENATE_MASK = False
+
 # For validation
 DEFAULT_VALIDATION_NB_STEPS = 50
 DEFAULT_VALIDATION_HEIGHT = 512
@@ -468,6 +477,69 @@ TRAINING_PRESETS = {
         "num_gpus": DEFAULT_NUM_GPUS,
         "precomputation_items": DEFAULT_PRECOMPUTATION_ITEMS,
         "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS,
+    },
+    "Wan-2.1-I2V (Control LoRA)": {
+        "model_type": "wan",
+        "training_type": "control-lora",
+        "lora_rank": "32",
+        "lora_alpha": "32",
+        "train_steps": DEFAULT_NB_TRAINING_STEPS,
+        "batch_size": DEFAULT_BATCH_SIZE,
+        "learning_rate": 5e-5,
+        "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
+        "training_buckets": SMALL_TRAINING_BUCKETS,
+        "flow_weighting_scheme": "logit_normal",
+        "num_gpus": DEFAULT_NUM_GPUS,
+        "precomputation_items": DEFAULT_PRECOMPUTATION_ITEMS,
+        "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS,
+        "control_type": "custom",
+        "train_qk_norm": True,
+        "frame_conditioning_type": "index",
+        "frame_conditioning_index": 0,
+        "frame_conditioning_concatenate_mask": True,
+        "description": "Image-conditioned video generation with LoRA adapters"
+    },
+    "LTX-Video (Control LoRA)": {
+        "model_type": "ltx_video",
+        "training_type": "control-lora",
+        "lora_rank": "128",
+        "lora_alpha": "128",
+        "train_steps": DEFAULT_NB_TRAINING_STEPS,
+        "batch_size": DEFAULT_BATCH_SIZE,
+        "learning_rate": DEFAULT_LEARNING_RATE,
+        "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
+        "training_buckets": SMALL_TRAINING_BUCKETS,
+        "flow_weighting_scheme": "logit_normal",
+        "num_gpus": DEFAULT_NUM_GPUS,
+        "precomputation_items": DEFAULT_PRECOMPUTATION_ITEMS,
+        "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS,
+        "control_type": "custom",
+        "train_qk_norm": True,
+        "frame_conditioning_type": "index",
+        "frame_conditioning_index": 0,
+        "frame_conditioning_concatenate_mask": True,
+        "description": "Image-conditioned video generation with LoRA adapters"
+    },
+    "HunyuanVideo (Control LoRA)": {
+        "model_type": "hunyuan_video",
+        "training_type": "control-lora",
+        "lora_rank": "128",
+        "lora_alpha": "128",
+        "train_steps": DEFAULT_NB_TRAINING_STEPS,
+        "batch_size": DEFAULT_BATCH_SIZE,
+        "learning_rate": 2e-5,
+        "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
+        "training_buckets": SMALL_TRAINING_BUCKETS,
+        "flow_weighting_scheme": "none",
+        "num_gpus": DEFAULT_NUM_GPUS,
+        "precomputation_items": DEFAULT_PRECOMPUTATION_ITEMS,
+        "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS,
+        "control_type": "custom",
+        "train_qk_norm": True,
+        "frame_conditioning_type": "index",
+        "frame_conditioning_index": 0,
+        "frame_conditioning_concatenate_mask": True,
+        "description": "Image-conditioned video generation with HunyuanVideo and LoRA adapters"
     }
 }
 
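To see how these presets are meant to be consumed, a short illustrative snippet (the import path and keys come from this diff; the surrounding usage is a sketch, not repository code):

```python
from vms.config import TRAINING_PRESETS, TRAINING_TYPES

preset = TRAINING_PRESETS["Wan-2.1-I2V (Control LoRA)"]

# Presets store the internal training-type id directly, matching the
# values registered in TRAINING_TYPES ("control-lora" here).
assert preset["training_type"] in TRAINING_TYPES.values()

# Control settings ride along with the usual hyperparameters; downstream
# code reads them with the DEFAULT_* constants as fallbacks.
print(preset.get("control_type"))                         # custom
print(preset.get("frame_conditioning_type"))              # index
print(preset.get("lora_rank"), preset.get("lora_alpha"))  # 32 32 (alpha = rank)
```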
vms/ui/models/tabs/drafts_tab.py CHANGED
@@ -88,7 +88,7 @@ class DraftsTab(BaseTab):
                     edit_btn.click(
                         fn=lambda model_id=model.id: self.edit_model(model_id),
                         inputs=[],
-                        outputs=[]
+                        outputs=[self.app.main_tabs]
                     )
                 with gr.Column(scale=1, min_width=10):
                     delete_btn = gr.Button("🗑️ Delete", size="sm", variant="stop")
@@ -107,7 +107,7 @@
             # Switch to project view with this model
             self.app.switch_project(model_id)
             # Set main tab to Project (index 0)
-            self.app.switch_to_tab(0)
+            return self.app.main_tabs.update(selected=0)
 
     def delete_model(self, model_id: str) -> gr.Column:
         """Delete a model and refresh the list"""
vms/ui/models/tabs/trained_tab.py CHANGED
@@ -94,7 +94,7 @@ class TrainedTab(BaseTab):
             preview_btn.click(
                 fn=lambda model_id=model.id: self.preview_model(model_id),
                 inputs=[],
-                outputs=[]
+                outputs=[self.app.main_tabs]
             )
 
             download_btn.click(
@@ -106,7 +106,7 @@
             publish_btn.click(
                 fn=lambda model_id=model.id: self.publish_model(model_id),
                 inputs=[],
-                outputs=[]
+                outputs=[self.app.main_tabs]
             )
 
             delete_btn.click(
@@ -117,28 +117,28 @@
 
         return new_container
 
-    def preview_model(self, model_id: str) -> None:
+    def preview_model(self, model_id: str) -> gr.Tabs:
         """Open model preview"""
         if self.app:
             # Switch to project view with this model
             self.app.switch_project(model_id)
             # Set main tab to Project (index 0)
-            self.app.switch_to_tab(0)
-            # Navigate to preview tab
-            # TODO: Implement proper tab navigation
+            return self.app.main_tabs.update(selected=0)
+            # TODO: Navigate to preview tab
 
     def download_model(self, model_id: str) -> None:
         """Download model weights"""
         # TODO: Implement file download
         gr.Info(f"Download for model {model_id[:8]}... is not yet implemented")
 
-    def publish_model(self, model_id: str) -> None:
+    def publish_model(self, model_id: str) -> gr.Tabs:
         """Publish model to Hugging Face Hub"""
         if self.app:
             # Switch to the selected model project
             self.app.switch_project(model_id)
-            # Navigate to publish tab (typically in Manage tab)
-            # TODO: Implement proper tab navigation
+            # Navigate to the main project tab
+            return self.app.main_tabs.update(selected=0)
+            # TODO: Navigate to publish tab
 
     def delete_model(self, model_id: str) -> gr.Column:
         """Delete a model and refresh the list"""
vms/ui/models/tabs/training_tab.py CHANGED
@@ -109,7 +109,7 @@ class TrainingTab(BaseTab):
             preview_btn.click(
                 fn=lambda model_id=model.id: self.preview_model(model_id),
                 inputs=[],
-                outputs=[]
+                outputs=[self.app.main_tabs]
             )
 
             download_btn.click(
@@ -147,15 +147,14 @@
         # Refresh the list
         return self.refresh_models()
 
-    def preview_model(self, model_id: str) -> None:
+    def preview_model(self, model_id: str) -> gr.Tabs:
         """Open model preview"""
         if self.app:
             # Switch to project view with this model
             self.app.switch_project(model_id)
             # Set main tab to Project (index 0)
-            self.app.switch_to_tab(0)
-            # Switch to preview tab (index 3)
-            # TODO: Implement proper tab navigation
+            return self.app.main_tabs.update(selected=0)
+            # TODO: Navigate to preview tab
 
     def download_model(self, model_id: str) -> None:
         """Download model weights"""
vms/ui/project/services/training.py CHANGED
@@ -40,6 +40,9 @@ from vms.config import (
     DEFAULT_NB_TRAINING_STEPS,
     DEFAULT_NB_LR_WARMUP_STEPS,
     DEFAULT_AUTO_RESUME,
+    DEFAULT_CONTROL_TYPE, DEFAULT_TRAIN_QK_NORM,
+    DEFAULT_FRAME_CONDITIONING_TYPE, DEFAULT_FRAME_CONDITIONING_INDEX,
+    DEFAULT_FRAME_CONDITIONING_CONCATENATE_MASK,
     generate_model_project_id
 )
 from vms.utils import (
@@ -229,7 +232,13 @@
             "num_gpus": DEFAULT_NUM_GPUS,
             "precomputation_items": DEFAULT_PRECOMPUTATION_ITEMS,
             "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS,
-            "auto_resume": DEFAULT_AUTO_RESUME
+            "auto_resume": DEFAULT_AUTO_RESUME,
+            # Control parameters
+            "control_type": DEFAULT_CONTROL_TYPE,
+            "train_qk_norm": DEFAULT_TRAIN_QK_NORM,
+            "frame_conditioning_type": DEFAULT_FRAME_CONDITIONING_TYPE,
+            "frame_conditioning_index": DEFAULT_FRAME_CONDITIONING_INDEX,
+            "frame_conditioning_concatenate_mask": DEFAULT_FRAME_CONDITIONING_CONCATENATE_MASK
         }
 
         return default_state
@@ -756,9 +765,37 @@
         config.data_root = str(dataset_config_file)
 
         # Update LoRA parameters if using LoRA training type
-        if training_type == "lora":
+        if training_type == "lora" or training_type == "control-lora":
             config.lora_rank = int(lora_rank)
             config.lora_alpha = int(lora_alpha)
+
+        # Update Control parameters if using control training types
+        if training_type in ["control-lora", "control-full-finetune"]:
+            # Get control parameters from UI state
+            current_state = self.load_ui_state()
+
+            # Add control-specific parameters
+            control_type = current_state.get("control_type", DEFAULT_CONTROL_TYPE)
+            train_qk_norm = current_state.get("train_qk_norm", DEFAULT_TRAIN_QK_NORM)
+            frame_conditioning_type = current_state.get("frame_conditioning_type", DEFAULT_FRAME_CONDITIONING_TYPE)
+            frame_conditioning_index = current_state.get("frame_conditioning_index", DEFAULT_FRAME_CONDITIONING_INDEX)
+            frame_conditioning_concatenate_mask = current_state.get("frame_conditioning_concatenate_mask", DEFAULT_FRAME_CONDITIONING_CONCATENATE_MASK)
+
+            # Map boolean from UI state to command line args
+            config_args.extend([
+                "--control_type", control_type,
+            ])
+
+            if train_qk_norm:
+                config_args.append("--train_qk_norm")
+
+            config_args.extend([
+                "--frame_conditioning_type", frame_conditioning_type,
+                "--frame_conditioning_index", str(frame_conditioning_index)
+            ])
+
+            if frame_conditioning_concatenate_mask:
+                config_args.append("--frame_conditioning_concatenate_mask")
 
         # Update with resume_from_checkpoint if provided
         if resume_from_checkpoint:
@@ -882,8 +919,11 @@
         with open(self.app.output_pid_file, 'w') as f:
             f.write(str(process.pid))
 
-        # Save session info including repo_id for later hub upload
-        self.save_session({
+        # Get current UI state for all parameters
+        current_state = self.load_ui_state()
+
+        # Build session data
+        session_data = {
             "model_type": model_type,
             "model_version": model_version,
             "training_type": training_type,
@@ -898,7 +938,20 @@
             "lr_warmup_steps": lr_warmup_steps,
             "repo_id": repo_id,
             "start_time": datetime.now().isoformat()
-        })
+        }
+
+        # Add control parameters if relevant
+        if training_type in ["control-lora", "control-full-finetune"]:
+            session_data.update({
+                "control_type": current_state.get("control_type", DEFAULT_CONTROL_TYPE),
+                "train_qk_norm": current_state.get("train_qk_norm", DEFAULT_TRAIN_QK_NORM),
+                "frame_conditioning_type": current_state.get("frame_conditioning_type", DEFAULT_FRAME_CONDITIONING_TYPE),
+                "frame_conditioning_index": current_state.get("frame_conditioning_index", DEFAULT_FRAME_CONDITIONING_INDEX),
+                "frame_conditioning_concatenate_mask": current_state.get("frame_conditioning_concatenate_mask", DEFAULT_FRAME_CONDITIONING_CONCATENATE_MASK)
+            })
+
+        # Save session
+        self.save_session(session_data)
 
         # Update initial training status
         total_steps = int(train_steps)
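For readers skimming the hunk above, the boolean-to-flag mapping can be isolated into a small helper. The following is an illustrative restatement of the diff's logic, not repository code (the defaults mirror the `DEFAULT_*` constants, hardcoded here so the snippet is self-contained):

```python
from typing import Any, Dict, List

def build_control_args(state: Dict[str, Any]) -> List[str]:
    """Rebuild the control-specific CLI arguments as the diff above does.

    Args:
        state: UI state dict as returned by load_ui_state().

    Returns:
        Flat list of command line arguments.
    """
    args = ["--control_type", state.get("control_type", "canny")]
    # Booleans from the UI become presence-only flags on the command line
    if state.get("train_qk_norm", False):
        args.append("--train_qk_norm")
    args += [
        "--frame_conditioning_type", state.get("frame_conditioning_type", "full"),
        "--frame_conditioning_index", str(state.get("frame_conditioning_index", 0)),
    ]
    if state.get("frame_conditioning_concatenate_mask", False):
        args.append("--frame_conditioning_concatenate_mask")
    return args

print(build_control_args({"control_type": "custom", "train_qk_norm": True}))
# ['--control_type', 'custom', '--train_qk_norm',
#  '--frame_conditioning_type', 'full', '--frame_conditioning_index', '0']
```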
vms/ui/project/tabs/caption_tab.py CHANGED
@@ -11,7 +11,7 @@ from pathlib import Path
 import mimetypes
 
 from vms.utils import BaseTab, is_image_file, is_video_file, copy_files_to_training_dir
-from vms.config import DEFAULT_CAPTIONING_BOT_INSTRUCTIONS, DEFAULT_PROMPT_PREFIX, STAGING_PATH, TRAINING_VIDEOS_PATH, USE_LARGE_DATASET
+from vms.config import DEFAULT_CAPTIONING_BOT_INSTRUCTIONS, DEFAULT_PROMPT_PREFIX, STAGING_PATH, USE_LARGE_DATASET
 
 logger = logging.getLogger(__name__)
 
@@ -52,11 +52,11 @@ class CaptionTab(BaseTab):
             )
         with gr.Row():
             self.components["run_autocaption_btn"] = gr.Button(
-                "Automatically fill missing captions",
+                "Automatically caption data",
                 variant="primary"
             )
             self.components["copy_files_to_training_dir_btn"] = gr.Button(
-                "Copy assets to training directory",
+                "Copy assets to training folder",
                 variant="primary"
             )
             self.components["stop_autocaption_btn"] = gr.Button(
vms/ui/project/tabs/manage_tab.py CHANGED
@@ -10,8 +10,8 @@ from typing import Dict, Any, List, Optional
 
 from vms.utils import BaseTab, validate_model_repo
 from vms.config import (
-    HF_API_TOKEN, VIDEOS_TO_SPLIT_PATH, STAGING_PATH, TRAINING_VIDEOS_PATH,
-    TRAINING_PATH, MODEL_PATH, OUTPUT_PATH, LOG_FILE_PATH, USE_LARGE_DATASET
+    HF_API_TOKEN, VIDEOS_TO_SPLIT_PATH, STAGING_PATH,
+    USE_LARGE_DATASET
 )
 
 logger = logging.getLogger(__name__)
vms/ui/project/tabs/train_tab.py CHANGED
@@ -26,7 +26,10 @@ from vms.config import (
     DEFAULT_PRECOMPUTATION_ITEMS,
     DEFAULT_NB_TRAINING_STEPS,
     DEFAULT_NB_LR_WARMUP_STEPS,
-    DEFAULT_AUTO_RESUME
+    DEFAULT_AUTO_RESUME,
+    DEFAULT_CONTROL_TYPE, DEFAULT_TRAIN_QK_NORM,
+    DEFAULT_FRAME_CONDITIONING_TYPE, DEFAULT_FRAME_CONDITIONING_INDEX,
+    DEFAULT_FRAME_CONDITIONING_CONCATENATE_MASK
 )
 
 logger = logging.getLogger(__name__)
@@ -116,18 +119,165 @@ class TrainTab(BaseTab):
         # LoRA specific parameters (will show/hide based on training type)
         with gr.Row(visible=True) as lora_params_row:
             self.components["lora_params_row"] = lora_params_row
-            self.components["lora_rank"] = gr.Dropdown(
-                label="LoRA Rank",
-                choices=["16", "32", "64", "128", "256", "512", "1024"],
-                value=DEFAULT_LORA_RANK_STR,
-                type="value"
-            )
-            self.components["lora_alpha"] = gr.Dropdown(
-                label="LoRA Alpha",
-                choices=["16", "32", "64", "128", "256", "512", "1024"],
-                value=DEFAULT_LORA_ALPHA_STR,
-                type="value"
-            )
+            with gr.Column():
+                gr.Markdown("""
+                ## 🔄 LoRA Training Parameters
+
+                LoRA (Low-Rank Adaptation) trains small adapter matrices instead of the full model, requiring much less memory while still achieving great results.
+                """)
+
+        # Second row for actual LoRA parameters
+        with gr.Row(visible=True) as lora_settings_row:
+            self.components["lora_settings_row"] = lora_settings_row
+            with gr.Column():
+                self.components["lora_rank"] = gr.Dropdown(
+                    label="LoRA Rank",
+                    choices=["16", "32", "64", "128", "256", "512", "1024"],
+                    value=DEFAULT_LORA_RANK_STR,
+                    type="value",
+                    info="Controls the size and expressiveness of LoRA adapters. Higher values = better quality but larger file size"
+                )
+
+                with gr.Accordion("What is LoRA Rank?", open=False):
+                    gr.Markdown("""
+                    **LoRA Rank** determines the complexity of the LoRA adapters:
+
+                    - **Lower rank (16-32)**: Smaller file size, faster training, but less expressive
+                    - **Medium rank (64-128)**: Good balance between quality and file size
+                    - **Higher rank (256-1024)**: More expressive adapters, better quality but larger file size
+
+                    Think of rank as the "capacity" of your adapter. Higher ranks can learn more complex modifications to the base model but require more VRAM during training and result in larger files.
+
+                    **Quick guide:**
+                    - For Wan models: Use 32-64 (Wan models work well with lower ranks)
+                    - For LTX-Video: Use 128-256
+                    - For Hunyuan Video: Use 128
+                    """)
+
+            with gr.Column():
+                self.components["lora_alpha"] = gr.Dropdown(
+                    label="LoRA Alpha",
+                    choices=["16", "32", "64", "128", "256", "512", "1024"],
+                    value=DEFAULT_LORA_ALPHA_STR,
+                    type="value",
+                    info="Controls the effective learning rate scaling of LoRA adapters. Usually set to same value as rank"
+                )
+
+                with gr.Accordion("What is LoRA Alpha?", open=False):
+                    gr.Markdown("""
+                    **LoRA Alpha** controls the effective scale of the LoRA updates:
+
+                    - The actual scaling factor is calculated as `alpha ÷ rank`
+                    - Usually set to match the rank value (alpha = rank)
+                    - Higher alpha = stronger effect from the adapters
+                    - Lower alpha = more subtle adapter influence
+
+                    **Best practice:**
+                    - For most cases, set alpha equal to rank
+                    - For more aggressive training, set alpha higher than rank
+                    - For more conservative training, set alpha lower than rank
+                    """)
+
+
+        # Control specific parameters (will show/hide based on training type)
+        with gr.Row(visible=False) as control_params_row:
+            self.components["control_params_row"] = control_params_row
+            with gr.Column():
+                gr.Markdown("""
+                ## 🖼️ Control Training Settings
+
+                Control training enables **image-to-video generation** by teaching the model how to use an image as a guide for video creation.
+                This is ideal for turning still images into dynamic videos while preserving composition, style, and content.
+                """)
+
+        # Second row for control parameters
+        with gr.Row(visible=False) as control_settings_row:
+            self.components["control_settings_row"] = control_settings_row
+            with gr.Column():
+                self.components["control_type"] = gr.Dropdown(
+                    label="Control Type",
+                    choices=["canny", "custom"],
+                    value=DEFAULT_CONTROL_TYPE,
+                    info="Type of control conditioning. 'canny' uses edge detection preprocessing, 'custom' allows direct image conditioning."
+                )
+
+                with gr.Accordion("What is Control Conditioning?", open=False):
+                    gr.Markdown("""
+                    **Control Conditioning** allows the model to be guided by an input image, adapting the video generation based on the image content. This is used for image-to-video generation where you want to turn an image into a moving video while maintaining its style, composition or content.
+
+                    - **canny**: Uses edge detection to extract outlines from images for structure-preserving video generation
+                    - **custom**: Direct image conditioning without preprocessing, preserving more image details
+                    """)
+
+            with gr.Column():
+                self.components["train_qk_norm"] = gr.Checkbox(
+                    label="Train QK Normalization Layers",
+                    value=DEFAULT_TRAIN_QK_NORM,
+                    info="Enable to train query-key normalization layers for better control signal integration"
+                )
+
+                with gr.Accordion("What is QK Normalization?", open=False):
+                    gr.Markdown("""
+                    **QK Normalization** refers to normalizing the query and key values in the attention mechanism of transformers.
+
+                    - When enabled, allows the model to better integrate control signals with content generation
+                    - Improves training stability for control models
+                    - Generally recommended for control training, especially with image conditioning
+                    """)
+
+        with gr.Row(visible=False) as frame_conditioning_row:
+            self.components["frame_conditioning_row"] = frame_conditioning_row
+            with gr.Column():
+                self.components["frame_conditioning_type"] = gr.Dropdown(
+                    label="Frame Conditioning Type",
+                    choices=["index", "prefix", "random", "first_and_last", "full"],
+                    value=DEFAULT_FRAME_CONDITIONING_TYPE,
+                    info="Determines which frames receive conditioning during training"
+                )
+
+                with gr.Accordion("Frame Conditioning Type Explanation", open=False):
+                    gr.Markdown("""
+                    **Frame Conditioning Types** determine which frames in the video receive image conditioning:
+
+                    - **index**: Only applies conditioning to a single frame at the specified index
+                    - **prefix**: Applies conditioning to all frames before a certain point
+                    - **random**: Randomly selects frames to receive conditioning during training
+                    - **first_and_last**: Only applies conditioning to the first and last frames
+                    - **full**: Applies conditioning to all frames in the video
+
+                    For image-to-video tasks, 'index' (usually with index 0) is most common as it conditions only the first frame.
+                    """)
+
+            with gr.Column():
+                self.components["frame_conditioning_index"] = gr.Number(
+                    label="Frame Conditioning Index",
+                    value=DEFAULT_FRAME_CONDITIONING_INDEX,
+                    precision=0,
+                    info="Specifies which frame receives conditioning when using 'index' type (0 = first frame)"
+                )
+
+        with gr.Row(visible=False) as control_options_row:
+            self.components["control_options_row"] = control_options_row
+            with gr.Column():
+                self.components["frame_conditioning_concatenate_mask"] = gr.Checkbox(
+                    label="Concatenate Frame Mask",
+                    value=DEFAULT_FRAME_CONDITIONING_CONCATENATE_MASK,
+                    info="Enable to add frame mask information to the conditioning channels"
+                )
+
+                with gr.Accordion("What is Frame Mask Concatenation?", open=False):
+                    gr.Markdown("""
+                    **Frame Mask Concatenation** adds an additional channel to the control signal that indicates which frames are being conditioned:
+
+                    - Creates a binary mask (0/1) indicating which frames receive conditioning
+                    - Helps the model distinguish between conditioned and unconditioned frames
+                    - Particularly useful for 'index' conditioning where only select frames are conditioned
+                    - Generally improves temporal consistency between conditioned and unconditioned frames
+                    """)
+
+            with gr.Column():
+                # Empty column for layout balance
+                pass
 
         with gr.Row():
             self.components["train_steps"] = gr.Number(
@@ -426,6 +576,37 @@
             inputs=[self.components["lora_alpha"]],
             outputs=[]
         )
+
+        # Control parameters change events
+        self.components["control_type"].change(
+            fn=lambda v: self.app.update_ui_state(control_type=v),
+            inputs=[self.components["control_type"]],
+            outputs=[]
+        )
+
+        self.components["train_qk_norm"].change(
+            fn=lambda v: self.app.update_ui_state(train_qk_norm=v),
+            inputs=[self.components["train_qk_norm"]],
+            outputs=[]
+        )
+
+        self.components["frame_conditioning_type"].change(
+            fn=lambda v: self.app.update_ui_state(frame_conditioning_type=v),
+            inputs=[self.components["frame_conditioning_type"]],
+            outputs=[]
+        )
+
+        self.components["frame_conditioning_index"].change(
+            fn=lambda v: self.app.update_ui_state(frame_conditioning_index=v),
+            inputs=[self.components["frame_conditioning_index"]],
+            outputs=[]
+        )
+
+        self.components["frame_conditioning_concatenate_mask"].change(
+            fn=lambda v: self.app.update_ui_state(frame_conditioning_concatenate_mask=v),
+            inputs=[self.components["frame_conditioning_concatenate_mask"]],
+            outputs=[]
+        )
 
         self.components["train_steps"].change(
             fn=lambda v: self.app.update_ui_state(train_steps=v),
@@ -470,11 +651,23 @@
                 self.components["save_iterations"],
                 self.components["preset_info"],
                 self.components["lora_params_row"],
+                self.components["lora_settings_row"],
                 self.components["num_gpus"],
                 self.components["precomputation_items"],
                 self.components["lr_warmup_steps"],
                 # Add model_version to the outputs
-                self.components["model_version"]
+                self.components["model_version"],
+                # Control parameters rows visibility
+                self.components["control_params_row"],
+                self.components["control_settings_row"],
+                self.components["frame_conditioning_row"],
+                self.components["control_options_row"],
+                # Control parameter values
+                self.components["control_type"],
+                self.components["train_qk_norm"],
+                self.components["frame_conditioning_type"],
+                self.components["frame_conditioning_index"],
+                self.components["frame_conditioning_concatenate_mask"],
             ]
         )
 
@@ -702,11 +895,28 @@
         # Get model info text
         model_info = self.get_model_info(model_type, training_type)
 
+        # Add general information about the selected training type
+        if training_type == "Full Finetune":
+            finetune_info = """
+            ## 🧠 Full Finetune Mode
+
+            Full finetune mode trains all parameters of the model, requiring more VRAM but potentially enabling higher quality results.
+
+            - Requires 20-50GB+ VRAM depending on model
+            - Creates a complete standalone model (~8GB+ file size)
+            - Recommended only for high-end GPUs (A100, H100, etc.)
+            - Not recommended for the larger models like Hunyuan Video on consumer hardware
+            """
+            model_info = finetune_info + "\n\n" + model_info
+
         # Get default parameters for this model type and training type
         params = self.get_default_params(MODEL_TYPES.get(model_type), TRAINING_TYPES.get(training_type))
 
         # Check if LoRA params should be visible
-        show_lora_params = training_type == "LoRA Finetune"
+        show_lora_params = training_type in ["LoRA Finetune", "Control LoRA"]
+
+        # Check if Control-specific params should be visible
+        show_control_params = training_type in ["Control LoRA", "Control Full Finetune"]
 
         # Return updates for UI components
         return {
@@ -715,7 +925,12 @@
             self.components["batch_size"]: params["batch_size"],
             self.components["learning_rate"]: params["learning_rate"],
             self.components["save_iterations"]: params["save_iterations"],
-            self.components["lora_params_row"]: gr.Row(visible=show_lora_params)
+            self.components["lora_params_row"]: gr.Row(visible=show_lora_params),
+            self.components["lora_settings_row"]: gr.Row(visible=show_lora_params),
+            self.components["control_params_row"]: gr.Row(visible=show_control_params),
+            self.components["control_settings_row"]: gr.Row(visible=show_control_params),
+            self.components["frame_conditioning_row"]: gr.Row(visible=show_control_params),
+            self.components["control_options_row"]: gr.Row(visible=show_control_params)
         }
 
     def get_model_info(self, model_type: str, training_type: str) -> str:
@@ -729,6 +944,10 @@
 
         if training_type == "LoRA Finetune":
             return base_info + "\n- Required VRAM: ~18GB minimum\n- Default LoRA rank: 128 (~400 MB)"
+        elif training_type == "Control LoRA":
+            return base_info + "\n- Required VRAM: ~20GB minimum\n- Default LoRA rank: 128 (~400 MB)\n- Supports image conditioning"
+        elif training_type == "Control Full Finetune":
+            return base_info + "\n- Required VRAM: ~50GB minimum\n- Supports image conditioning\n- **Not recommended due to VRAM requirements**"
         else:
            return base_info + "\n- Required VRAM: ~48GB minimum\n- **Full finetune not recommended due to VRAM requirements**"
 
@@ -740,6 +959,10 @@
 
         if training_type == "LoRA Finetune":
            return base_info + "\n- Required VRAM: ~18GB minimum\n- Default LoRA rank: 128 (~400 MB)"
+        elif training_type == "Control LoRA":
+            return base_info + "\n- Required VRAM: ~20GB minimum\n- Default LoRA rank: 128 (~400 MB)\n- Supports image conditioning"
+        elif training_type == "Control Full Finetune":
+            return base_info + "\n- Required VRAM: ~23GB minimum\n- Supports image conditioning"
         else:
            return base_info + "\n- Required VRAM: ~21GB minimum\n- Full model size: ~8GB"
 
@@ -751,6 +974,10 @@
 
         if training_type == "LoRA Finetune":
            return base_info + "\n- Required VRAM: ~16GB minimum\n- Default LoRA rank: 32 (~120 MB)"
+        elif training_type == "Control LoRA":
+            return base_info + "\n- Required VRAM: ~18GB minimum\n- Default LoRA rank: 32 (~120 MB)\n- Supports image conditioning"
+        elif training_type == "Control Full Finetune":
+            return base_info + "\n- Required VRAM: ~40GB minimum\n- Supports image conditioning\n- **Not recommended due to VRAM requirements**"
         else:
            return base_info + "\n- **Full finetune not recommended due to VRAM requirements**"
 
@@ -848,7 +1075,11 @@
         info_text = f"{description}{bucket_info}"
 
         # Check if LoRA params should be visible
-        show_lora_params = preset["training_type"] == "lora"
+        training_type_internal = preset["training_type"]
+        show_lora_params = training_type_internal == "lora" or training_type_internal == "control-lora"
+
+        # Check if Control params should be visible
+        show_control_params = training_type_internal == "control-lora" or training_type_internal == "control-full-finetune"
 
         # Use preset defaults but preserve user-modified values if they exist
         lora_rank_val = current_state.get("lora_rank") if current_state.get("lora_rank") != preset.get("lora_rank", DEFAULT_LORA_RANK_STR) else preset.get("lora_rank", DEFAULT_LORA_RANK_STR)
@@ -861,6 +1092,13 @@
         precomputation_items_val = current_state.get("precomputation_items") if current_state.get("precomputation_items") != preset.get("precomputation_items", DEFAULT_PRECOMPUTATION_ITEMS) else preset.get("precomputation_items", DEFAULT_PRECOMPUTATION_ITEMS)
         lr_warmup_steps_val = current_state.get("lr_warmup_steps") if current_state.get("lr_warmup_steps") != preset.get("lr_warmup_steps", DEFAULT_NB_LR_WARMUP_STEPS) else preset.get("lr_warmup_steps", DEFAULT_NB_LR_WARMUP_STEPS)
 
+        # Control parameters
+        control_type_val = current_state.get("control_type") if current_state.get("control_type") != preset.get("control_type", DEFAULT_CONTROL_TYPE) else preset.get("control_type", DEFAULT_CONTROL_TYPE)
+        train_qk_norm_val = current_state.get("train_qk_norm") if current_state.get("train_qk_norm") != preset.get("train_qk_norm", DEFAULT_TRAIN_QK_NORM) else preset.get("train_qk_norm", DEFAULT_TRAIN_QK_NORM)
+        frame_conditioning_type_val = current_state.get("frame_conditioning_type") if current_state.get("frame_conditioning_type") != preset.get("frame_conditioning_type", DEFAULT_FRAME_CONDITIONING_TYPE) else preset.get("frame_conditioning_type", DEFAULT_FRAME_CONDITIONING_TYPE)
+        frame_conditioning_index_val = current_state.get("frame_conditioning_index") if current_state.get("frame_conditioning_index") != preset.get("frame_conditioning_index", DEFAULT_FRAME_CONDITIONING_INDEX) else preset.get("frame_conditioning_index", DEFAULT_FRAME_CONDITIONING_INDEX)
+        frame_conditioning_concatenate_mask_val = current_state.get("frame_conditioning_concatenate_mask") if current_state.get("frame_conditioning_concatenate_mask") != preset.get("frame_conditioning_concatenate_mask", DEFAULT_FRAME_CONDITIONING_CONCATENATE_MASK) else preset.get("frame_conditioning_concatenate_mask", DEFAULT_FRAME_CONDITIONING_CONCATENATE_MASK)
+
         # Get the appropriate model version for the selected model type
         model_versions = self.get_model_version_choices(model_display_name)
         default_model_version = self.get_default_model_version(model_display_name)
@@ -896,6 +1134,16 @@
             precomputation_items_val,
             lr_warmup_steps_val,
             model_version_update,
+            # Control parameters rows visibility
+            gr.Row(visible=show_control_params),
+            gr.Row(visible=show_control_params),
+            gr.Row(visible=show_control_params),
+            # Control parameter values
+            control_type_val,
+            train_qk_norm_val,
+            frame_conditioning_type_val,
+            frame_conditioning_index_val,
+            frame_conditioning_concatenate_mask_val,
        )
 
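The `alpha ÷ rank` relationship described in the "What is LoRA Alpha?" accordion can be made concrete in a few lines. This is an illustrative sketch of standard LoRA scaling (not code from this repository; dimensions are arbitrary):

```python
import numpy as np

# Standard LoRA forward pass: the adapter update (B @ A) is applied
# scaled by alpha / rank, exactly the factor the accordion describes.
d_in, d_out, rank, alpha = 64, 64, 32, 32

W = np.random.randn(d_out, d_in)        # frozen base weight
A = np.random.randn(rank, d_in) * 0.01  # trained down-projection
B = np.zeros((d_out, rank))             # trained up-projection (zero-init)

scale = alpha / rank                    # 32 / 32 = 1.0, the neutral default
x = np.random.randn(d_in)
y = W @ x + scale * (B @ (A @ x))       # base output + scaled adapter output
```

Doubling alpha while keeping rank fixed doubles the adapter's influence, which is why "alpha = rank" (as in the Control LoRA presets above, e.g. 32/32 or 128/128) is the recommended starting point.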