"""
Models service for Video Model Studio

Handles model history tracking and management
"""

import logging
import json
import shutil
from pathlib import Path
from typing import Dict, Any, List, Optional, Tuple
from dataclasses import dataclass
from datetime import datetime

from vms.config import (
    STORAGE_PATH, MODEL_TYPES, TRAINING_TYPES
)
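
# MODEL_TYPES is assumed to map display names to internal model identifiers
# (e.g. a key like "LTX Video" mapping to "ltx_video"); that direction is
# inferred from how the mapping is iterated in Model.from_dir below, not from
# vms.config itself.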

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

@dataclass
class Model:
    """Class for tracking model metadata"""
    id: str
    status: str  # 'draft', 'training', 'trained', 'error'
    model_type: str  # Base model family (e.g. 'hunyuan_video', 'ltx_video', 'wan')
    model_display_name: str  # Display name for the model type
    created_at: datetime
    updated_at: datetime
    training_progress: Optional[float] = 0.0  # Progress as percentage
    current_step: Optional[int] = 0
    total_steps: Optional[int] = 0
    
    @classmethod
    def from_dir(cls, model_dir: Path) -> 'Model':
        """Create a Model instance from a directory"""
        model_id = model_dir.name
        
        # Default values
        status = 'draft'
        model_type = ''
        model_display_name = ''
        created_at = datetime.fromtimestamp(model_dir.stat().st_ctime)
        updated_at = datetime.fromtimestamp(model_dir.stat().st_mtime)
        training_progress = 0.0
        current_step = 0
        total_steps = 0
        
        # Check for UI state file
        ui_state_file = model_dir / "output" / "ui_state.json"
        if ui_state_file.exists():
            try:
                with open(ui_state_file, 'r') as f:
                    ui_state = json.load(f)
                    status = ui_state.get('project_status', 'draft')
                    
                    # Get model type from UI state
                    model_type_value = ui_state.get('model_type', '')
                    
                    # First check if model_type_value is a display name
                    display_name_found = False
                    for display_name, internal_name in MODEL_TYPES.items():
                        if display_name == model_type_value:
                            model_type = internal_name
                            model_display_name = display_name
                            display_name_found = True
                            break
                    
                    # If not a display name, check if it's an internal name
                    if not display_name_found:
                        for display_name, internal_name in MODEL_TYPES.items():
                            if internal_name == model_type_value:
                                model_type = internal_name
                                model_display_name = display_name
                                break
            except Exception as e:
                logger.error(f"Error loading UI state for model {model_id}: {str(e)}")
        
        # Check for status file to get training progress
        status_file = model_dir / "output" / "status.json"
        if status_file.exists():
            try:
                with open(status_file, 'r') as f:
                    status_data = json.load(f)
                    if status_data.get('status') == 'training':
                        status = 'training'
                        current_step = status_data.get('step', 0)
                        total_steps = status_data.get('total_steps', 0)
                        if total_steps > 0:
                            training_progress = (current_step / total_steps) * 100
                    elif status_data.get('status') == 'completed':
                        status = 'trained'
                        training_progress = 100.0
                    elif status_data.get('status') == 'error':
                        status = 'error'
            except Exception as e:
                logger.error(f"Error loading status for model {model_id}: {str(e)}")
        
        # Check for pid file to determine if training is active
        pid_file = model_dir / "output" / "training.pid"
        if pid_file.exists():
            status = 'training'
        
        # Check for model weights to determine if trained
        model_weights = model_dir / "output" / "pytorch_lora_weights.safetensors"
        if model_weights.exists() and status != 'training':
            status = 'trained'
            training_progress = 100.0
        
        return cls(
            id=model_id,
            status=status,
            model_type=model_type,
            model_display_name=model_display_name,
            created_at=created_at,
            updated_at=updated_at,
            training_progress=training_progress,
            current_step=current_step,
            total_steps=total_steps
        )
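
# Illustrative shape of the files read by Model.from_dir above, inferred from
# the keys it accesses rather than from a formal spec:
#
#   output/ui_state.json -> {"project_status": "draft",
#                            "model_type": "<display or internal model name>"}
#   output/status.json   -> {"status": "training", "step": 1200,
#                            "total_steps": 2000}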

class ModelsService:
    """Service for tracking and managing model history"""
    
    def __init__(self, app_state=None):
        """Initialize the models service
        
        Args:
            app_state: Reference to main application state
        """
        self.app = app_state
    
    def get_all_models(self) -> List[Model]:
        """Get a list of all models
        
        Returns:
            List of Model objects
        """
        models_dir = STORAGE_PATH / "models"
        if not models_dir.exists():
            return []
            
        models = []
        
        for model_dir in models_dir.iterdir():
            if not model_dir.is_dir():
                continue
                
            try:
                model = Model.from_dir(model_dir)
                models.append(model)
            except Exception as e:
                logger.error(f"Error loading model from {model_dir}: {str(e)}")
                
        # Sort models by updated_at (newest first)
        return sorted(models, key=lambda m: m.updated_at, reverse=True)
    
    def get_draft_models(self) -> List[Model]:
        """Get a list of draft models
        
        Returns:
            List of Model objects with 'draft' status
        """
        return [m for m in self.get_all_models() if m.status == 'draft']
    
    def get_training_models(self) -> List[Model]:
        """Get a list of models currently in training
        
        Returns:
            List of Model objects with 'training' status
        """
        return [m for m in self.get_all_models() if m.status == 'training']
    
    def get_trained_models(self) -> List[Model]:
        """Get a list of completed trained models
        
        Returns:
            List of Model objects with 'trained' status
        """
        return [m for m in self.get_all_models() if m.status == 'trained']
        
    def delete_model(self, model_id: str) -> bool:
        """Delete a model by ID
        
        Args:
            model_id: The model ID to delete
            
        Returns:
            True if deletion was successful
        """
        if not model_id:
            return False
            
        model_dir = STORAGE_PATH / "models" / model_id
        if not model_dir.exists():
            return False
            
        try:
            shutil.rmtree(model_dir)
            return True
        except Exception as e:
            logger.error(f"Error deleting model {model_id}: {str(e)}")
            return False
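

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the service API):
    # list every known model and its training progress. Assumes STORAGE_PATH
    # from vms.config points at a project directory containing a "models/"
    # folder with one subdirectory per model.
    logging.basicConfig(level=logging.INFO)

    service = ModelsService()
    for model in service.get_all_models():
        print(
            f"{model.id}: {model.status} "
            f"({model.training_progress:.0f}%, step {model.current_step}/{model.total_steps})"
        )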