# Spaces: Running
"""
Models service for Video Model Studio

Handles the model history tracking and management
"""
import logging | |
import json | |
from pathlib import Path | |
from typing import Dict, Any, List, Optional, Tuple | |
from dataclasses import dataclass | |
from datetime import datetime | |
from vms.config import ( | |
STORAGE_PATH, MODEL_TYPES, TRAINING_TYPES | |
) | |
# Module-level logger named after this module; its level is pinned to INFO
# here, independently of whatever the root logger is configured to.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@dataclass
class Model:
    """Class for tracking model metadata.

    Instances are plain value objects; ``from_dir`` builds one by inspecting
    the files found under a model directory's ``output`` sub-folder.
    """

    id: str
    status: str  # 'draft', 'training', 'trained', 'error'
    model_type: str  # Base model family (e.g. 'hunyuan_video', 'ltx_video', 'wan')
    model_display_name: str  # Display name for the model type
    created_at: datetime
    updated_at: datetime
    training_progress: Optional[float] = 0.0  # Progress as percentage
    current_step: Optional[int] = 0
    total_steps: Optional[int] = 0

    @classmethod
    def from_dir(cls, model_dir: Path) -> 'Model':
        """Create a Model instance from a directory.

        The status is derived, in increasing order of precedence, from:
        ``output/ui_state.json`` (``project_status``), ``output/status.json``
        (``status``), the presence of ``output/training.pid`` (forces
        'training'), and the presence of
        ``output/pytorch_lora_weights.safetensors`` (forces 'trained' unless
        the status is 'training').

        Unreadable or malformed JSON files are logged and otherwise ignored,
        so a damaged project still yields a Model with default values.

        Args:
            model_dir: Directory of a single model; its name becomes the ID.

        Returns:
            A populated Model instance.

        Raises:
            OSError: If ``model_dir`` cannot be stat'ed.
        """
        model_id = model_dir.name
        output_dir = model_dir / "output"

        # Default values
        status = 'draft'
        model_type = ''
        model_display_name = ''
        training_progress = 0.0
        current_step = 0
        total_steps = 0

        # One stat() call gives both timestamps from the same snapshot.
        dir_stat = model_dir.stat()
        created_at = datetime.fromtimestamp(dir_stat.st_ctime)
        updated_at = datetime.fromtimestamp(dir_stat.st_mtime)

        # Check for UI state file
        ui_state_file = output_dir / "ui_state.json"
        if ui_state_file.exists():
            try:
                with open(ui_state_file, 'r', encoding='utf-8') as f:
                    ui_state = json.load(f)
                status = ui_state.get('project_status', 'draft')

                # The stored value may be either a display name (a key of
                # MODEL_TYPES) or an internal name (a value of MODEL_TYPES).
                model_type_value = ui_state.get('model_type', '')

                # First check if model_type_value is a display name
                display_name_found = False
                for display_name, internal_name in MODEL_TYPES.items():
                    if display_name == model_type_value:
                        model_type = internal_name
                        model_display_name = display_name
                        display_name_found = True
                        break

                # If not a display name, check if it's an internal name
                if not display_name_found:
                    for display_name, internal_name in MODEL_TYPES.items():
                        if internal_name == model_type_value:
                            model_type = internal_name
                            model_display_name = display_name
                            break
            except Exception as e:
                # Best effort: keep whatever was resolved before the failure.
                logger.error(f"Error loading UI state for model {model_id}: {str(e)}")

        # Check for status file to get training progress
        status_file = output_dir / "status.json"
        if status_file.exists():
            try:
                with open(status_file, 'r', encoding='utf-8') as f:
                    status_data = json.load(f)

                reported_status = status_data.get('status')
                if reported_status == 'training':
                    status = 'training'
                    # `or 0` also covers keys that are present but null.
                    current_step = status_data.get('step') or 0
                    total_steps = status_data.get('total_steps') or 0
                    if total_steps > 0:
                        training_progress = (current_step / total_steps) * 100
                elif reported_status == 'completed':
                    status = 'trained'
                    training_progress = 100.0
                elif reported_status == 'error':
                    status = 'error'
            except Exception as e:
                logger.error(f"Error loading status for model {model_id}: {str(e)}")

        # Check for pid file to determine if training is active
        pid_file = output_dir / "training.pid"
        if pid_file.exists():
            status = 'training'

        # Check for model weights to determine if trained
        model_weights = output_dir / "pytorch_lora_weights.safetensors"
        if model_weights.exists() and status != 'training':
            status = 'trained'
            training_progress = 100.0

        return cls(
            id=model_id,
            status=status,
            model_type=model_type,
            model_display_name=model_display_name,
            created_at=created_at,
            updated_at=updated_at,
            training_progress=training_progress,
            current_step=current_step,
            total_steps=total_steps,
        )
class ModelsService:
    """Service for tracking and managing model history"""

    def __init__(self, app_state=None):
        """Initialize the models service

        Args:
            app_state: Reference to main application state
        """
        self.app = app_state

    def get_all_models(self) -> List["Model"]:
        """Get a list of all models

        Every sub-directory of ``STORAGE_PATH / "models"`` is loaded with
        ``Model.from_dir``; directories that fail to load are logged and
        skipped rather than aborting the listing.

        Returns:
            List of Model objects, sorted by ``updated_at`` (newest first)
        """
        models_dir = STORAGE_PATH / "models"
        # is_dir() also covers the "path exists but is a file" case, where
        # iterdir() would raise.
        if not models_dir.is_dir():
            return []

        models = []
        for model_dir in models_dir.iterdir():
            if not model_dir.is_dir():
                continue
            try:
                models.append(Model.from_dir(model_dir))
            except Exception as e:
                logger.error(f"Error loading model from {model_dir}: {str(e)}")

        # Sort models by updated_at (newest first)
        return sorted(models, key=lambda m: m.updated_at, reverse=True)

    def _models_with_status(self, status: str) -> List["Model"]:
        """Return all models whose ``status`` equals the given value."""
        return [m for m in self.get_all_models() if m.status == status]

    def get_draft_models(self) -> List["Model"]:
        """Get a list of draft models

        Returns:
            List of Model objects with 'draft' status
        """
        return self._models_with_status('draft')

    def get_training_models(self) -> List["Model"]:
        """Get a list of models currently in training

        Returns:
            List of Model objects with 'training' status
        """
        return self._models_with_status('training')

    def get_trained_models(self) -> List["Model"]:
        """Get a list of completed trained models

        Returns:
            List of Model objects with 'trained' status
        """
        return self._models_with_status('trained')

    def delete_model(self, model_id: str) -> bool:
        """Delete a model by ID

        Args:
            model_id: The model ID to delete. It must be a single path
                component (the model directory's name).

        Returns:
            True if deletion was successful; False if the ID is empty or not
            a plain directory name, the model does not exist, or removal
            failed.
        """
        if not model_id:
            return False

        # Refuse anything that is not a bare directory name ("..", "a/b",
        # absolute paths): joining it onto the models directory would let
        # rmtree escape that directory.
        if model_id in ('.', '..') or Path(model_id).name != model_id:
            return False

        model_dir = STORAGE_PATH / "models" / model_id
        if not model_dir.exists():
            return False

        try:
            import shutil
            shutil.rmtree(model_dir)
            return True
        except Exception as e:
            logger.error(f"Error deleting model {model_id}: {str(e)}")
            return False