Commit 2ba9257
Parent(s): 267b599

improve UI to support conditioning
Files changed:
- CLAUDE.md (+17 -9)
- degraded_requirements.txt (+8 -0)
- vms/config.py (+73 -1)
- vms/ui/models/tabs/drafts_tab.py (+2 -2)
- vms/ui/models/tabs/trained_tab.py (+9 -9)
- vms/ui/models/tabs/training_tab.py (+4 -5)
- vms/ui/project/services/training.py (+58 -5)
- vms/ui/project/tabs/caption_tab.py (+3 -3)
- vms/ui/project/tabs/manage_tab.py (+2 -2)
- vms/ui/project/tabs/train_tab.py (+265 -17)
CLAUDE.md
CHANGED
@@ -1,18 +1,26 @@
+# CLAUDE.md
+
+This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
+
 # Video Model Studio - Guidelines for Claude
 
 ## Build & Run Commands
-- Setup: `./setup.sh` (with flash attention) or `./
-- Run: `./run.sh` or `python3.10 app.py`
-- Test:
-
+- Setup: `./setup.sh` (with flash attention) or `./degraded_setup.sh` (without flash-attention)
+- Run: `./run.sh` or `python3.10 app.py`
+- Test:
+  - Full test: `python3 tests/test_dataset.py`
+  - Single model test: `bash tests/scripts/dummy_cogvideox_lora.sh` (or other model variants)
+  - Run test suite: `bash tests/test_model_runs_minimally_lora.sh`
 
 ## Code Style
 - Python version: 3.10 (required for flash-attention compatibility)
-- Type hints: Use typing module annotations for all functions
-- Docstrings: Google style with Args/Returns sections
-- Error handling: Use try/except with specific exceptions, log errors
+- Type hints: Use typing module annotations for all functions (from typing import Any, Optional, Dict, List, Union, Tuple)
+- Docstrings: Google style with Args/Returns sections for all functions
+- Error handling: Use try/except with specific exceptions, log errors appropriately
 - Imports: Group standard lib, third-party, and project imports
 - Naming: snake_case for functions/variables, PascalCase for classes
 - Use Path objects from pathlib instead of string paths
-
-- Environment variables: Use parse_bool_env for boolean env vars
+- Extract reusable logic to separate utility functions
+- Environment variables: Use parse_bool_env for boolean env vars
+- Logging: Use the logging module with appropriate log levels (DEBUG, INFO, WARNING, ERROR)
+- UI components: Organize in tabs and use consistent naming for components dict
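
The guidelines above reference a `parse_bool_env` helper for boolean environment variables. The repo's actual implementation is not shown in this commit, so the following is only a rough sketch of what such a helper conventionally looks like; treat the exact accepted values as assumptions:

```python
import os

def parse_bool_env(name: str, default: bool = False) -> bool:
    """Interpret common truthy strings from an environment variable.

    Hypothetical sketch: the real vms helper may accept a different
    set of values or signature.
    """
    value = os.environ.get(name)
    if value is None:
        return default
    return value.strip().lower() in ("1", "true", "yes", "on")
```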
degraded_requirements.txt
CHANGED
@@ -9,8 +9,16 @@ diffusers @ git+https://github.com/huggingface/diffusers.git@main
 imageio
 imageio-ffmpeg
 
+#--------------- MACOS HACKS ----------------
+
+# use eva-decord for better compatiblity on macOS
+eva-decord
+
+# don't install flash attention on macOS
 #flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
 
+#--------------- / MACOS HACKS --------------
+
 # for youtube video download
 pytube
 pytubefix
vms/config.py
CHANGED
@@ -205,7 +205,9 @@ MODEL_TYPES = {
 # Training types
 TRAINING_TYPES = {
     "LoRA Finetune": "lora",
-    "Full Finetune": "full-finetune"
+    "Full Finetune": "full-finetune",
+    "Control LoRA": "control-lora",
+    "Control Full Finetune": "control-full-finetune"
 }
 
 # Model versions for each model type
@@ -288,6 +290,13 @@ DEFAULT_NB_LR_WARMUP_STEPS = math.ceil(0.20 * DEFAULT_NB_TRAINING_STEPS) # 20%
 # Whether to automatically restart a training job after a server reboot or not
 DEFAULT_AUTO_RESUME = False
 
+# Control training defaults
+DEFAULT_CONTROL_TYPE = "canny"
+DEFAULT_TRAIN_QK_NORM = False
+DEFAULT_FRAME_CONDITIONING_TYPE = "full"
+DEFAULT_FRAME_CONDITIONING_INDEX = 0
+DEFAULT_FRAME_CONDITIONING_CONCATENATE_MASK = False
+
 # For validation
 DEFAULT_VALIDATION_NB_STEPS = 50
 DEFAULT_VALIDATION_HEIGHT = 512
@@ -468,6 +477,69 @@ TRAINING_PRESETS = {
         "num_gpus": DEFAULT_NUM_GPUS,
         "precomputation_items": DEFAULT_PRECOMPUTATION_ITEMS,
         "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS,
+    },
+    "Wan-2.1-I2V (Control LoRA)": {
+        "model_type": "wan",
+        "training_type": "control-lora",
+        "lora_rank": "32",
+        "lora_alpha": "32",
+        "train_steps": DEFAULT_NB_TRAINING_STEPS,
+        "batch_size": DEFAULT_BATCH_SIZE,
+        "learning_rate": 5e-5,
+        "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
+        "training_buckets": SMALL_TRAINING_BUCKETS,
+        "flow_weighting_scheme": "logit_normal",
+        "num_gpus": DEFAULT_NUM_GPUS,
+        "precomputation_items": DEFAULT_PRECOMPUTATION_ITEMS,
+        "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS,
+        "control_type": "custom",
+        "train_qk_norm": True,
+        "frame_conditioning_type": "index",
+        "frame_conditioning_index": 0,
+        "frame_conditioning_concatenate_mask": True,
+        "description": "Image-conditioned video generation with LoRA adapters"
+    },
+    "LTX-Video (Control LoRA)": {
+        "model_type": "ltx_video",
+        "training_type": "control-lora",
+        "lora_rank": "128",
+        "lora_alpha": "128",
+        "train_steps": DEFAULT_NB_TRAINING_STEPS,
+        "batch_size": DEFAULT_BATCH_SIZE,
+        "learning_rate": DEFAULT_LEARNING_RATE,
+        "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
+        "training_buckets": SMALL_TRAINING_BUCKETS,
+        "flow_weighting_scheme": "logit_normal",
+        "num_gpus": DEFAULT_NUM_GPUS,
+        "precomputation_items": DEFAULT_PRECOMPUTATION_ITEMS,
+        "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS,
+        "control_type": "custom",
+        "train_qk_norm": True,
+        "frame_conditioning_type": "index",
+        "frame_conditioning_index": 0,
+        "frame_conditioning_concatenate_mask": True,
+        "description": "Image-conditioned video generation with LoRA adapters"
+    },
+    "HunyuanVideo (Control LoRA)": {
+        "model_type": "hunyuan_video",
+        "training_type": "control-lora",
+        "lora_rank": "128",
+        "lora_alpha": "128",
+        "train_steps": DEFAULT_NB_TRAINING_STEPS,
+        "batch_size": DEFAULT_BATCH_SIZE,
+        "learning_rate": 2e-5,
+        "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
+        "training_buckets": SMALL_TRAINING_BUCKETS,
+        "flow_weighting_scheme": "none",
+        "num_gpus": DEFAULT_NUM_GPUS,
+        "precomputation_items": DEFAULT_PRECOMPUTATION_ITEMS,
+        "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS,
+        "control_type": "custom",
+        "train_qk_norm": True,
+        "frame_conditioning_type": "index",
+        "frame_conditioning_index": 0,
+        "frame_conditioning_concatenate_mask": True,
+        "description": "Image-conditioned video generation with HunyuanVideo and LoRA adapters"
     }
 }
 
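
To see how the new entries fit together: a preset's `training_type` carries one of the internal identifiers from `TRAINING_TYPES`, and control-specific keys fall back to the new defaults when a preset omits them. A minimal, self-contained sketch (dict contents abridged from the diff above; not the repo's own lookup code):

```python
# Abridged copies of the config values added in this commit.
TRAINING_TYPES = {
    "LoRA Finetune": "lora",
    "Full Finetune": "full-finetune",
    "Control LoRA": "control-lora",
    "Control Full Finetune": "control-full-finetune",
}
DEFAULT_CONTROL_TYPE = "canny"

# One preset, reduced to the keys that matter for this sketch.
preset = {"training_type": "control-lora", "control_type": "custom"}

training_type = preset["training_type"]
is_control = training_type in ("control-lora", "control-full-finetune")
# Presets may override the control defaults; fall back otherwise.
control_type = preset.get("control_type", DEFAULT_CONTROL_TYPE)

print(training_type, is_control, control_type)  # control-lora True custom
```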
vms/ui/models/tabs/drafts_tab.py
CHANGED
@@ -88,7 +88,7 @@ class DraftsTab(BaseTab):
                     edit_btn.click(
                         fn=lambda model_id=model.id: self.edit_model(model_id),
                         inputs=[],
-                        outputs=[]
+                        outputs=[self.app.main_tabs]
                     )
                 with gr.Column(scale=1, min_width=10):
                     delete_btn = gr.Button("🗑️ Delete", size="sm", variant="stop")
@@ -107,7 +107,7 @@ class DraftsTab(BaseTab):
             # Switch to project view with this model
             self.app.switch_project(model_id)
             # Set main tab to Project (index 0)
-            self.app.
+            return self.app.main_tabs.update(selected=0)
 
     def delete_model(self, model_id: str) -> gr.Column:
         """Delete a model and refresh the list"""
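
The fix in this file wires the button's return value to the top-level `gr.Tabs` container so the click can actually switch the visible tab; with `outputs=[]` the update was silently dropped. The `main_tabs.update(selected=0)` form matches older Gradio releases; in Gradio 4+ the equivalent is to return a `gr.Tabs(selected=...)` update. A minimal sketch of the modern form (component names here are illustrative, not from the repo):

```python
import gradio as gr

with gr.Blocks() as demo:
    with gr.Tabs() as main_tabs:
        with gr.Tab("Project", id=0):
            gr.Markdown("Project view")
        with gr.Tab("Models", id=1):
            open_btn = gr.Button("Open in Project")

    # Routing the handler's return value to `main_tabs` is what makes the
    # tab switch take effect; the returned update selects the tab whose
    # id matches `selected`.
    open_btn.click(fn=lambda: gr.Tabs(selected=0), inputs=[], outputs=[main_tabs])

demo.launch()
```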
vms/ui/models/tabs/trained_tab.py
CHANGED
@@ -94,7 +94,7 @@ class TrainedTab(BaseTab):
                     preview_btn.click(
                         fn=lambda model_id=model.id: self.preview_model(model_id),
                         inputs=[],
-                        outputs=[]
+                        outputs=[self.app.main_tabs]
                     )
 
                     download_btn.click(
@@ -106,7 +106,7 @@ class TrainedTab(BaseTab):
                     publish_btn.click(
                         fn=lambda model_id=model.id: self.publish_model(model_id),
                         inputs=[],
-                        outputs=[]
+                        outputs=[self.app.main_tabs]
                     )
 
                     delete_btn.click(
@@ -117,28 +117,28 @@ class TrainedTab(BaseTab):
 
         return new_container
 
-    def preview_model(self, model_id: str) ->
+    def preview_model(self, model_id: str) -> gr.Tabs:
         """Open model preview"""
         if self.app:
             # Switch to project view with this model
             self.app.switch_project(model_id)
             # Set main tab to Project (index 0)
-            self.app.
-            # Navigate to preview tab
-            # TODO: Implement proper tab navigation
+            return self.app.main_tabs.update(selected=0)
+        # TODO: Navigate to preview tab
 
     def download_model(self, model_id: str) -> None:
        """Download model weights"""
        # TODO: Implement file download
        gr.Info(f"Download for model {model_id[:8]}... is not yet implemented")
 
-    def publish_model(self, model_id: str) ->
+    def publish_model(self, model_id: str) -> gr.Tabs:
         """Publish model to Hugging Face Hub"""
         if self.app:
             # Switch to the selected model project
             self.app.switch_project(model_id)
-            # Navigate to
-
+            # Navigate to the main project tab
+            return self.app.main_tabs.update(selected=0)
+        # TODO: Navigate to publish tab
 
     def delete_model(self, model_id: str) -> gr.Column:
         """Delete a model and refresh the list"""
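
A side note on the `fn=lambda model_id=model.id: ...` handlers that recur in these tabs: the default argument freezes the current `model.id` at lambda-creation time. Without it, every handler created in a loop over models would close over the same variable and see only the last model. A self-contained demonstration of the difference:

```python
# Python closures capture variables, not values. Lambdas created in a loop
# all share one binding unless the value is frozen via a default argument.
handlers_late = [lambda: m for m in ("a", "b", "c")]
handlers_bound = [lambda m=m: m for m in ("a", "b", "c")]

print([h() for h in handlers_late])   # ['c', 'c', 'c']  (late binding)
print([h() for h in handlers_bound])  # ['a', 'b', 'c']  (value frozen per lambda)
```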
vms/ui/models/tabs/training_tab.py
CHANGED
@@ -109,7 +109,7 @@ class TrainingTab(BaseTab):
                     preview_btn.click(
                         fn=lambda model_id=model.id: self.preview_model(model_id),
                         inputs=[],
-                        outputs=[]
+                        outputs=[self.app.main_tabs]
                     )
 
                     download_btn.click(
@@ -147,15 +147,14 @@ class TrainingTab(BaseTab):
         # Refresh the list
         return self.refresh_models()
 
-    def preview_model(self, model_id: str) ->
+    def preview_model(self, model_id: str) -> gr.Tabs:
         """Open model preview"""
         if self.app:
             # Switch to project view with this model
             self.app.switch_project(model_id)
             # Set main tab to Project (index 0)
-            self.app.
-            #
-            # TODO: Implement proper tab navigation
+            return self.app.main_tabs.update(selected=0)
+        # TODO: Navigate to preview tab
 
     def download_model(self, model_id: str) -> None:
         """Download model weights"""
vms/ui/project/services/training.py
CHANGED
@@ -40,6 +40,9 @@ from vms.config import (
     DEFAULT_NB_TRAINING_STEPS,
     DEFAULT_NB_LR_WARMUP_STEPS,
     DEFAULT_AUTO_RESUME,
+    DEFAULT_CONTROL_TYPE, DEFAULT_TRAIN_QK_NORM,
+    DEFAULT_FRAME_CONDITIONING_TYPE, DEFAULT_FRAME_CONDITIONING_INDEX,
+    DEFAULT_FRAME_CONDITIONING_CONCATENATE_MASK,
     generate_model_project_id
 )
 from vms.utils import (
@@ -229,7 +232,13 @@ class TrainingService:
             "num_gpus": DEFAULT_NUM_GPUS,
             "precomputation_items": DEFAULT_PRECOMPUTATION_ITEMS,
             "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS,
-            "auto_resume": DEFAULT_AUTO_RESUME
+            "auto_resume": DEFAULT_AUTO_RESUME,
+            # Control parameters
+            "control_type": DEFAULT_CONTROL_TYPE,
+            "train_qk_norm": DEFAULT_TRAIN_QK_NORM,
+            "frame_conditioning_type": DEFAULT_FRAME_CONDITIONING_TYPE,
+            "frame_conditioning_index": DEFAULT_FRAME_CONDITIONING_INDEX,
+            "frame_conditioning_concatenate_mask": DEFAULT_FRAME_CONDITIONING_CONCATENATE_MASK
         }
 
         return default_state
@@ -756,9 +765,37 @@ class TrainingService:
         config.data_root = str(dataset_config_file)
 
         # Update LoRA parameters if using LoRA training type
-        if training_type == "lora":
+        if training_type == "lora" or training_type == "control-lora":
             config.lora_rank = int(lora_rank)
             config.lora_alpha = int(lora_alpha)
+
+        # Update Control parameters if using control training types
+        if training_type in ["control-lora", "control-full-finetune"]:
+            # Get control parameters from UI state
+            current_state = self.load_ui_state()
+
+            # Add control-specific parameters
+            control_type = current_state.get("control_type", DEFAULT_CONTROL_TYPE)
+            train_qk_norm = current_state.get("train_qk_norm", DEFAULT_TRAIN_QK_NORM)
+            frame_conditioning_type = current_state.get("frame_conditioning_type", DEFAULT_FRAME_CONDITIONING_TYPE)
+            frame_conditioning_index = current_state.get("frame_conditioning_index", DEFAULT_FRAME_CONDITIONING_INDEX)
+            frame_conditioning_concatenate_mask = current_state.get("frame_conditioning_concatenate_mask", DEFAULT_FRAME_CONDITIONING_CONCATENATE_MASK)
+
+            # Map boolean from UI state to command line args
+            config_args.extend([
+                "--control_type", control_type,
+            ])
+
+            if train_qk_norm:
+                config_args.append("--train_qk_norm")
+
+            config_args.extend([
+                "--frame_conditioning_type", frame_conditioning_type,
+                "--frame_conditioning_index", str(frame_conditioning_index)
+            ])
+
+            if frame_conditioning_concatenate_mask:
+                config_args.append("--frame_conditioning_concatenate_mask")
 
         # Update with resume_from_checkpoint if provided
         if resume_from_checkpoint:
@@ -882,8 +919,11 @@ class TrainingService:
         with open(self.app.output_pid_file, 'w') as f:
             f.write(str(process.pid))
 
-        #
-        self.
+        # Get current UI state for all parameters
+        current_state = self.load_ui_state()
+
+        # Build session data
+        session_data = {
             "model_type": model_type,
             "model_version": model_version,
             "training_type": training_type,
@@ -898,7 +938,20 @@ class TrainingService:
             "lr_warmup_steps": lr_warmup_steps,
             "repo_id": repo_id,
             "start_time": datetime.now().isoformat()
-        }
+        }
+
+        # Add control parameters if relevant
+        if training_type in ["control-lora", "control-full-finetune"]:
+            session_data.update({
+                "control_type": current_state.get("control_type", DEFAULT_CONTROL_TYPE),
+                "train_qk_norm": current_state.get("train_qk_norm", DEFAULT_TRAIN_QK_NORM),
+                "frame_conditioning_type": current_state.get("frame_conditioning_type", DEFAULT_FRAME_CONDITIONING_TYPE),
+                "frame_conditioning_index": current_state.get("frame_conditioning_index", DEFAULT_FRAME_CONDITIONING_INDEX),
+                "frame_conditioning_concatenate_mask": current_state.get("frame_conditioning_concatenate_mask", DEFAULT_FRAME_CONDITIONING_CONCATENATE_MASK)
+            })
+
+        # Save session
+        self.save_session(session_data)
 
         # Update initial training status
         total_steps = int(train_steps)
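
For clarity, here is what the control branch above produces on the command line. This is a hypothetical rerun of the same argument-assembly logic with example values; the surrounding launcher and the consumer of `config_args` are outside this diff:

```python
# Rebuilding the argument list the control branch appends, with sample values.
config_args = []

control_type = "custom"
train_qk_norm = True
frame_conditioning_type = "index"
frame_conditioning_index = 0
frame_conditioning_concatenate_mask = True

config_args.extend(["--control_type", control_type])
if train_qk_norm:
    # Booleans become presence/absence flags rather than "--flag true" pairs.
    config_args.append("--train_qk_norm")
config_args.extend([
    "--frame_conditioning_type", frame_conditioning_type,
    "--frame_conditioning_index", str(frame_conditioning_index),
])
if frame_conditioning_concatenate_mask:
    config_args.append("--frame_conditioning_concatenate_mask")

print(" ".join(config_args))
# --control_type custom --train_qk_norm --frame_conditioning_type index
# --frame_conditioning_index 0 --frame_conditioning_concatenate_mask
```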
vms/ui/project/tabs/caption_tab.py
CHANGED
@@ -11,7 +11,7 @@ from pathlib import Path
 import mimetypes
 
 from vms.utils import BaseTab, is_image_file, is_video_file, copy_files_to_training_dir
-from vms.config import DEFAULT_CAPTIONING_BOT_INSTRUCTIONS, DEFAULT_PROMPT_PREFIX, STAGING_PATH,
+from vms.config import DEFAULT_CAPTIONING_BOT_INSTRUCTIONS, DEFAULT_PROMPT_PREFIX, STAGING_PATH, USE_LARGE_DATASET
 
 logger = logging.getLogger(__name__)
 
@@ -52,11 +52,11 @@ class CaptionTab(BaseTab):
                 )
                 with gr.Row():
                     self.components["run_autocaption_btn"] = gr.Button(
-                        "Automatically
+                        "Automatically caption data",
                         variant="primary"
                     )
                     self.components["copy_files_to_training_dir_btn"] = gr.Button(
-                        "Copy assets to training
+                        "Copy assets to training folder",
                         variant="primary"
                     )
                     self.components["stop_autocaption_btn"] = gr.Button(
vms/ui/project/tabs/manage_tab.py
CHANGED
@@ -10,8 +10,8 @@ from typing import Dict, Any, List, Optional
 
 from vms.utils import BaseTab, validate_model_repo
 from vms.config import (
-    HF_API_TOKEN, VIDEOS_TO_SPLIT_PATH, STAGING_PATH,
-
+    HF_API_TOKEN, VIDEOS_TO_SPLIT_PATH, STAGING_PATH,
+    USE_LARGE_DATASET
 )
 
 logger = logging.getLogger(__name__)
vms/ui/project/tabs/train_tab.py
CHANGED
@@ -26,7 +26,10 @@ from vms.config import (
     DEFAULT_PRECOMPUTATION_ITEMS,
     DEFAULT_NB_TRAINING_STEPS,
     DEFAULT_NB_LR_WARMUP_STEPS,
-    DEFAULT_AUTO_RESUME
+    DEFAULT_AUTO_RESUME,
+    DEFAULT_CONTROL_TYPE, DEFAULT_TRAIN_QK_NORM,
+    DEFAULT_FRAME_CONDITIONING_TYPE, DEFAULT_FRAME_CONDITIONING_INDEX,
+    DEFAULT_FRAME_CONDITIONING_CONCATENATE_MASK
 )
 
 logger = logging.getLogger(__name__)
@@ -116,18 +119,165 @@ class TrainTab(BaseTab):
         # LoRA specific parameters (will show/hide based on training type)
         with gr.Row(visible=True) as lora_params_row:
             self.components["lora_params_row"] = lora_params_row
-            [12 removed lines: the previous inline LoRA rank/alpha controls; content truncated in this diff view]
+            with gr.Column():
+                gr.Markdown("""
+                ## 🔄 LoRA Training Parameters
+
+                LoRA (Low-Rank Adaptation) trains small adapter matrices instead of the full model, requiring much less memory while still achieving great results.
+                """)
+
+        # Second row for actual LoRA parameters
+        with gr.Row(visible=True) as lora_settings_row:
+            self.components["lora_settings_row"] = lora_settings_row
+            with gr.Column():
+                self.components["lora_rank"] = gr.Dropdown(
+                    label="LoRA Rank",
+                    choices=["16", "32", "64", "128", "256", "512", "1024"],
+                    value=DEFAULT_LORA_RANK_STR,
+                    type="value",
+                    info="Controls the size and expressiveness of LoRA adapters. Higher values = better quality but larger file size"
+                )
+
+                with gr.Accordion("What is LoRA Rank?", open=False):
+                    gr.Markdown("""
+                    **LoRA Rank** determines the complexity of the LoRA adapters:
+
+                    - **Lower rank (16-32)**: Smaller file size, faster training, but less expressive
+                    - **Medium rank (64-128)**: Good balance between quality and file size
+                    - **Higher rank (256-1024)**: More expressive adapters, better quality but larger file size
+
+                    Think of rank as the "capacity" of your adapter. Higher ranks can learn more complex modifications to the base model but require more VRAM during training and result in larger files.
+
+                    **Quick guide:**
+                    - For Wan models: Use 32-64 (Wan models work well with lower ranks)
+                    - For LTX-Video: Use 128-256
+                    - For Hunyuan Video: Use 128
+                    """)
+
+            with gr.Column():
+                self.components["lora_alpha"] = gr.Dropdown(
+                    label="LoRA Alpha",
+                    choices=["16", "32", "64", "128", "256", "512", "1024"],
+                    value=DEFAULT_LORA_ALPHA_STR,
+                    type="value",
+                    info="Controls the effective learning rate scaling of LoRA adapters. Usually set to same value as rank"
+                )
+
+                with gr.Accordion("What is LoRA Alpha?", open=False):
+                    gr.Markdown("""
+                    **LoRA Alpha** controls the effective scale of the LoRA updates:
+
+                    - The actual scaling factor is calculated as `alpha ÷ rank`
+                    - Usually set to match the rank value (alpha = rank)
+                    - Higher alpha = stronger effect from the adapters
+                    - Lower alpha = more subtle adapter influence
+
+                    **Best practice:**
+                    - For most cases, set alpha equal to rank
+                    - For more aggressive training, set alpha higher than rank
+                    - For more conservative training, set alpha lower than rank
+                    """)
+
+
+        # Control specific parameters (will show/hide based on training type)
+        with gr.Row(visible=False) as control_params_row:
+            self.components["control_params_row"] = control_params_row
+            with gr.Column():
+                gr.Markdown("""
+                ## 🖼️ Control Training Settings
+
+                Control training enables **image-to-video generation** by teaching the model how to use an image as a guide for video creation.
+                This is ideal for turning still images into dynamic videos while preserving composition, style, and content.
+                """)
+
+        # Second row for control parameters
+        with gr.Row(visible=False) as control_settings_row:
+            self.components["control_settings_row"] = control_settings_row
+            with gr.Column():
+                self.components["control_type"] = gr.Dropdown(
+                    label="Control Type",
+                    choices=["canny", "custom"],
+                    value=DEFAULT_CONTROL_TYPE,
+                    info="Type of control conditioning. 'canny' uses edge detection preprocessing, 'custom' allows direct image conditioning."
+                )
+
+                with gr.Accordion("What is Control Conditioning?", open=False):
+                    gr.Markdown("""
+                    **Control Conditioning** allows the model to be guided by an input image, adapting the video generation based on the image content. This is used for image-to-video generation where you want to turn an image into a moving video while maintaining its style, composition or content.
+
+                    - **canny**: Uses edge detection to extract outlines from images for structure-preserving video generation
+                    - **custom**: Direct image conditioning without preprocessing, preserving more image details
+                    """)
+
+            with gr.Column():
+                self.components["train_qk_norm"] = gr.Checkbox(
+                    label="Train QK Normalization Layers",
+                    value=DEFAULT_TRAIN_QK_NORM,
+                    info="Enable to train query-key normalization layers for better control signal integration"
+                )
+
+                with gr.Accordion("What is QK Normalization?", open=False):
+                    gr.Markdown("""
+                    **QK Normalization** refers to normalizing the query and key values in the attention mechanism of transformers.
+
+                    - When enabled, allows the model to better integrate control signals with content generation
+                    - Improves training stability for control models
+                    - Generally recommended for control training, especially with image conditioning
+                    """)
+
+        with gr.Row(visible=False) as frame_conditioning_row:
+            self.components["frame_conditioning_row"] = frame_conditioning_row
+            with gr.Column():
+                self.components["frame_conditioning_type"] = gr.Dropdown(
+                    label="Frame Conditioning Type",
+                    choices=["index", "prefix", "random", "first_and_last", "full"],
+                    value=DEFAULT_FRAME_CONDITIONING_TYPE,
+                    info="Determines which frames receive conditioning during training"
+                )
+
+                with gr.Accordion("Frame Conditioning Type Explanation", open=False):
+                    gr.Markdown("""
+                    **Frame Conditioning Types** determine which frames in the video receive image conditioning:
+
+                    - **index**: Only applies conditioning to a single frame at the specified index
+                    - **prefix**: Applies conditioning to all frames before a certain point
+                    - **random**: Randomly selects frames to receive conditioning during training
+                    - **first_and_last**: Only applies conditioning to the first and last frames
+                    - **full**: Applies conditioning to all frames in the video
+
+                    For image-to-video tasks, 'index' (usually with index 0) is most common as it conditions only the first frame.
+                    """)
+
+            with gr.Column():
+                self.components["frame_conditioning_index"] = gr.Number(
+                    label="Frame Conditioning Index",
+                    value=DEFAULT_FRAME_CONDITIONING_INDEX,
+                    precision=0,
+                    info="Specifies which frame receives conditioning when using 'index' type (0 = first frame)"
+                )
+
+        with gr.Row(visible=False) as control_options_row:
+            self.components["control_options_row"] = control_options_row
+            with gr.Column():
+                self.components["frame_conditioning_concatenate_mask"] = gr.Checkbox(
+                    label="Concatenate Frame Mask",
+                    value=DEFAULT_FRAME_CONDITIONING_CONCATENATE_MASK,
+                    info="Enable to add frame mask information to the conditioning channels"
+                )
+
+                with gr.Accordion("What is Frame Mask Concatenation?", open=False):
+                    gr.Markdown("""
+                    **Frame Mask Concatenation** adds an additional channel to the control signal that indicates which frames are being conditioned:
+
+                    - Creates a binary mask (0/1) indicating which frames receive conditioning
+                    - Helps the model distinguish between conditioned and unconditioned frames
+                    - Particularly useful for 'index' conditioning where only select frames are conditioned
+                    - Generally improves temporal consistency between conditioned and unconditioned frames
+                    """)
+
+            with gr.Column():
+                # Empty column for layout balance
+                pass
 
         with gr.Row():
             self.components["train_steps"] = gr.Number(
@@ -426,6 +576,37 @@ class TrainTab(BaseTab):
             inputs=[self.components["lora_alpha"]],
             outputs=[]
         )
+
+        # Control parameters change events
+        self.components["control_type"].change(
+            fn=lambda v: self.app.update_ui_state(control_type=v),
+            inputs=[self.components["control_type"]],
+            outputs=[]
+        )
+
+        self.components["train_qk_norm"].change(
+            fn=lambda v: self.app.update_ui_state(train_qk_norm=v),
+            inputs=[self.components["train_qk_norm"]],
+            outputs=[]
+        )
+
+        self.components["frame_conditioning_type"].change(
+            fn=lambda v: self.app.update_ui_state(frame_conditioning_type=v),
+            inputs=[self.components["frame_conditioning_type"]],
+            outputs=[]
+        )
+
+        self.components["frame_conditioning_index"].change(
+            fn=lambda v: self.app.update_ui_state(frame_conditioning_index=v),
+            inputs=[self.components["frame_conditioning_index"]],
+            outputs=[]
+        )
+
+        self.components["frame_conditioning_concatenate_mask"].change(
+            fn=lambda v: self.app.update_ui_state(frame_conditioning_concatenate_mask=v),
+            inputs=[self.components["frame_conditioning_concatenate_mask"]],
+            outputs=[]
+        )
 
         self.components["train_steps"].change(
             fn=lambda v: self.app.update_ui_state(train_steps=v),
@@ -470,11 +651,23 @@ class TrainTab(BaseTab):
                 self.components["save_iterations"],
                 self.components["preset_info"],
                 self.components["lora_params_row"],
+                self.components["lora_settings_row"],
                 self.components["num_gpus"],
                 self.components["precomputation_items"],
                 self.components["lr_warmup_steps"],
                 # Add model_version to the outputs
-                self.components["model_version"]
+                self.components["model_version"],
+                # Control parameters rows visibility
+                self.components["control_params_row"],
+                self.components["control_settings_row"],
+                self.components["frame_conditioning_row"],
+                self.components["control_options_row"],
+                # Control parameter values
+                self.components["control_type"],
+                self.components["train_qk_norm"],
+                self.components["frame_conditioning_type"],
+                self.components["frame_conditioning_index"],
+                self.components["frame_conditioning_concatenate_mask"],
             ]
         )
 
@@ -702,11 +895,28 @@ class TrainTab(BaseTab):
         # Get model info text
         model_info = self.get_model_info(model_type, training_type)
 
+        # Add general information about the selected training type
+        if training_type == "Full Finetune":
+            finetune_info = """
+        ## 🧠 Full Finetune Mode
+
+        Full finetune mode trains all parameters of the model, requiring more VRAM but potentially enabling higher quality results.
+
+        - Requires 20-50GB+ VRAM depending on model
+        - Creates a complete standalone model (~8GB+ file size)
+        - Recommended only for high-end GPUs (A100, H100, etc.)
+        - Not recommended for the larger models like Hunyuan Video on consumer hardware
+        """
+            model_info = finetune_info + "\n\n" + model_info
+
         # Get default parameters for this model type and training type
         params = self.get_default_params(MODEL_TYPES.get(model_type), TRAINING_TYPES.get(training_type))
 
         # Check if LoRA params should be visible
-        show_lora_params = training_type
+        show_lora_params = training_type in ["LoRA Finetune", "Control LoRA"]
+
+        # Check if Control-specific params should be visible
+        show_control_params = training_type in ["Control LoRA", "Control Full Finetune"]
 
         # Return updates for UI components
         return {
@@ -715,7 +925,12 @@ class TrainTab(BaseTab):
             self.components["batch_size"]: params["batch_size"],
             self.components["learning_rate"]: params["learning_rate"],
             self.components["save_iterations"]: params["save_iterations"],
-            self.components["lora_params_row"]: gr.Row(visible=show_lora_params)
+            self.components["lora_params_row"]: gr.Row(visible=show_lora_params),
+            self.components["lora_settings_row"]: gr.Row(visible=show_lora_params),
+            self.components["control_params_row"]: gr.Row(visible=show_control_params),
+            self.components["control_settings_row"]: gr.Row(visible=show_control_params),
+            self.components["frame_conditioning_row"]: gr.Row(visible=show_control_params),
+            self.components["control_options_row"]: gr.Row(visible=show_control_params)
         }
 
     def get_model_info(self, model_type: str, training_type: str) -> str:
@@ -729,6 +944,10 @@ class TrainTab(BaseTab):
 
         if training_type == "LoRA Finetune":
             return base_info + "\n- Required VRAM: ~18GB minimum\n- Default LoRA rank: 128 (~400 MB)"
+        elif training_type == "Control LoRA":
+            return base_info + "\n- Required VRAM: ~20GB minimum\n- Default LoRA rank: 128 (~400 MB)\n- Supports image conditioning"
+        elif training_type == "Control Full Finetune":
+            return base_info + "\n- Required VRAM: ~50GB minimum\n- Supports image conditioning\n- **Not recommended due to VRAM requirements**"
         else:
             return base_info + "\n- Required VRAM: ~48GB minimum\n- **Full finetune not recommended due to VRAM requirements**"
 
@@ -740,6 +959,10 @@ class TrainTab(BaseTab):
 
         if training_type == "LoRA Finetune":
             return base_info + "\n- Required VRAM: ~18GB minimum\n- Default LoRA rank: 128 (~400 MB)"
+        elif training_type == "Control LoRA":
+            return base_info + "\n- Required VRAM: ~20GB minimum\n- Default LoRA rank: 128 (~400 MB)\n- Supports image conditioning"
+        elif training_type == "Control Full Finetune":
+            return base_info + "\n- Required VRAM: ~23GB minimum\n- Supports image conditioning"
         else:
             return base_info + "\n- Required VRAM: ~21GB minimum\n- Full model size: ~8GB"
 
@@ -751,6 +974,10 @@ class TrainTab(BaseTab):
 
         if training_type == "LoRA Finetune":
             return base_info + "\n- Required VRAM: ~16GB minimum\n- Default LoRA rank: 32 (~120 MB)"
+        elif training_type == "Control LoRA":
+            return base_info + "\n- Required VRAM: ~18GB minimum\n- Default LoRA rank: 32 (~120 MB)\n- Supports image conditioning"
+        elif training_type == "Control Full Finetune":
+            return base_info + "\n- Required VRAM: ~40GB minimum\n- Supports image conditioning\n- **Not recommended due to VRAM requirements**"
         else:
             return base_info + "\n- **Full finetune not recommended due to VRAM requirements**"
 
@@ -848,7 +1075,11 @@ class TrainTab(BaseTab):
         info_text = f"{description}{bucket_info}"
 
         # Check if LoRA params should be visible
-
+        training_type_internal = preset["training_type"]
+        show_lora_params = training_type_internal == "lora" or training_type_internal == "control-lora"
+
+        # Check if Control params should be visible
+        show_control_params = training_type_internal == "control-lora" or training_type_internal == "control-full-finetune"
 
         # Use preset defaults but preserve user-modified values if they exist
         lora_rank_val = current_state.get("lora_rank") if current_state.get("lora_rank") != preset.get("lora_rank", DEFAULT_LORA_RANK_STR) else preset.get("lora_rank", DEFAULT_LORA_RANK_STR)
@@ -861,6 +1092,13 @@ class TrainTab(BaseTab):
         precomputation_items_val = current_state.get("precomputation_items") if current_state.get("precomputation_items") != preset.get("precomputation_items", DEFAULT_PRECOMPUTATION_ITEMS) else preset.get("precomputation_items", DEFAULT_PRECOMPUTATION_ITEMS)
         lr_warmup_steps_val = current_state.get("lr_warmup_steps") if current_state.get("lr_warmup_steps") != preset.get("lr_warmup_steps", DEFAULT_NB_LR_WARMUP_STEPS) else preset.get("lr_warmup_steps", DEFAULT_NB_LR_WARMUP_STEPS)
 
+        # Control parameters
+        control_type_val = current_state.get("control_type") if current_state.get("control_type") != preset.get("control_type", DEFAULT_CONTROL_TYPE) else preset.get("control_type", DEFAULT_CONTROL_TYPE)
+        train_qk_norm_val = current_state.get("train_qk_norm") if current_state.get("train_qk_norm") != preset.get("train_qk_norm", DEFAULT_TRAIN_QK_NORM) else preset.get("train_qk_norm", DEFAULT_TRAIN_QK_NORM)
+        frame_conditioning_type_val = current_state.get("frame_conditioning_type") if current_state.get("frame_conditioning_type") != preset.get("frame_conditioning_type", DEFAULT_FRAME_CONDITIONING_TYPE) else preset.get("frame_conditioning_type", DEFAULT_FRAME_CONDITIONING_TYPE)
+        frame_conditioning_index_val = current_state.get("frame_conditioning_index") if current_state.get("frame_conditioning_index") != preset.get("frame_conditioning_index", DEFAULT_FRAME_CONDITIONING_INDEX) else preset.get("frame_conditioning_index", DEFAULT_FRAME_CONDITIONING_INDEX)
+        frame_conditioning_concatenate_mask_val = current_state.get("frame_conditioning_concatenate_mask") if current_state.get("frame_conditioning_concatenate_mask") != preset.get("frame_conditioning_concatenate_mask", DEFAULT_FRAME_CONDITIONING_CONCATENATE_MASK) else preset.get("frame_conditioning_concatenate_mask", DEFAULT_FRAME_CONDITIONING_CONCATENATE_MASK)
+
         # Get the appropriate model version for the selected model type
         model_versions = self.get_model_version_choices(model_display_name)
         default_model_version = self.get_default_model_version(model_display_name)
@@ -896,6 +1134,16 @@ class TrainTab(BaseTab):
             precomputation_items_val,
             lr_warmup_steps_val,
             model_version_update,
+            # Control parameters rows visibility
+            gr.Row(visible=show_control_params),
+            gr.Row(visible=show_control_params),
+            gr.Row(visible=show_control_params),
+            # Control parameter values
+            control_type_val,
+            train_qk_norm_val,
+            frame_conditioning_type_val,
+            frame_conditioning_index_val,
+            frame_conditioning_concatenate_mask_val,
        )
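
The visibility plumbing added in this file follows one Gradio pattern throughout: a change handler computes boolean flags from the selected training type and returns `gr.Row(visible=...)` updates keyed by component, so a single handler can show or hide several rows at once. A minimal, runnable sketch of that pattern under Gradio 4 conventions (component and option names here are illustrative, not the repo's):

```python
import gradio as gr

def toggle_sections(training_type: str):
    # Same booleans as in the diff: LoRA rows for LoRA-style training,
    # control rows for the two control training types.
    show_lora = training_type in ["LoRA Finetune", "Control LoRA"]
    show_control = training_type in ["Control LoRA", "Control Full Finetune"]
    # Dict-style returns map each component to its update.
    return {
        lora_row: gr.Row(visible=show_lora),
        control_row: gr.Row(visible=show_control),
    }

with gr.Blocks() as demo:
    training_type = gr.Dropdown(
        ["LoRA Finetune", "Full Finetune", "Control LoRA", "Control Full Finetune"],
        value="LoRA Finetune", label="Training Type",
    )
    with gr.Row(visible=True) as lora_row:
        gr.Markdown("LoRA parameters go here")
    with gr.Row(visible=False) as control_row:
        gr.Markdown("Control parameters go here")

    training_type.change(
        fn=toggle_sections,
        inputs=[training_type],
        outputs=[lora_row, control_row],
    )

demo.launch()
```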