Spaces:
Running
on
Zero
Running
on
Zero
File size: 12,830 Bytes
e6a18b7 beedc2e e6a18b7 beedc2e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 |
import os
import traceback
import logging
from typing import Dict, Optional, Any, Tuple
from spatial_analyzer import SpatialAnalyzer
from scene_description import SceneDescriptor
from enhanced_scene_describer import EnhancedSceneDescriber
from clip_analyzer import CLIPAnalyzer
from clip_zero_shot_classifier import CLIPZeroShotClassifier
from llm_enhancer import LLMEnhancer
from landmark_activities import LANDMARK_ACTIVITIES
from scene_type import SCENE_TYPES
from object_categories import OBJECT_CATEGORIES
class ComponentInitializer:
"""
負責初始化和管理 SceneAnalyzer 的所有子組件。
處理組件初始化失敗的情況並提供優雅的降級機制。
"""
def __init__(self, class_names: Dict[int, str] = None, use_llm: bool = True,
use_clip: bool = True, enable_landmark: bool = True,
llm_model_path: str = None):
"""
初始化組件管理器。
Args:
class_names: YOLO 類別 ID 到名稱的映射字典
use_llm: 是否啟用 LLM 增強功能
use_clip: 是否啟用 CLIP 分析功能
enable_landmark: 是否啟用地標檢測功能
llm_model_path: LLM 模型路徑(可選)
"""
self.logger = logging.getLogger(__name__)
# 存儲初始化參數
self.class_names = class_names
self.use_llm = use_llm
self.use_clip = use_clip
self.enable_landmark = enable_landmark
self.llm_model_path = llm_model_path
# 初始化組件容器
self.components = {}
self.data_structures = {}
self.initialization_status = {}
# 初始化所有組件
self._initialize_all_components()
def _initialize_all_components(self):
"""初始化所有必要的組件和數據結構。"""
try:
# 1. 首先載入數據
self._load_data_structures()
# 2. 初始化核心分析組件
self._initialize_core_analyzers()
# 3. 初始化 CLIP 相關內容
if self.use_clip:
self._initialize_clip_components()
# 4. 初始化 LLM 組件
if self.use_llm:
self._initialize_llm_components()
self.logger.info("All components initialized successfully")
except Exception as e:
self.logger.error(f"Error during component initialization: {e}")
traceback.print_exc()
raise
def _load_data_structures(self):
"""載入必要的數據結構。"""
data_loaders = {
'LANDMARK_ACTIVITIES': self._load_landmark_activities,
'SCENE_TYPES': self._load_scene_types,
'OBJECT_CATEGORIES': self._load_object_categories
}
for data_name, loader_func in data_loaders.items():
try:
self.data_structures[data_name] = loader_func()
self.initialization_status[data_name] = True
self.logger.info(f"Loaded {data_name} successfully")
except Exception as e:
self.logger.warning(f"Failed to load {data_name}: {e}")
self.data_structures[data_name] = {}
self.initialization_status[data_name] = False
def _load_landmark_activities(self) -> Dict:
"""載入地標活動數據。"""
try:
return LANDMARK_ACTIVITIES
except ImportError as e:
self.logger.warning(f"Could not import LANDMARK_ACTIVITIES: {e}")
return {}
def _load_scene_types(self) -> Dict:
"""載入場景類型數據。"""
try:
return SCENE_TYPES
except ImportError as e:
self.logger.warning(f"Could not import SCENE_TYPES: {e}")
return {}
def _load_object_categories(self) -> Dict:
"""載入物體類別數據。"""
try:
return OBJECT_CATEGORIES
except ImportError as e:
self.logger.warning(f"Could not import OBJECT_CATEGORIES: {e}")
return {}
def _initialize_core_analyzers(self):
"""初始化核心分析組件。"""
# 初始化 SpatialAnalyzer
try:
self.components['spatial_analyzer'] = SpatialAnalyzer(
class_names=self.class_names,
object_categories=self.data_structures.get('OBJECT_CATEGORIES', {})
)
self.initialization_status['spatial_analyzer'] = True
self.logger.info("Initialized SpatialAnalyzer successfully")
except Exception as e:
self.logger.error(f"Error initializing SpatialAnalyzer: {e}")
traceback.print_exc()
self.initialization_status['spatial_analyzer'] = False
self.components['spatial_analyzer'] = None
# 初始化 SceneDescriptor
try:
self.components['descriptor'] = SceneDescriptor(
scene_types=self.data_structures.get('SCENE_TYPES', {}),
object_categories=self.data_structures.get('OBJECT_CATEGORIES', {})
)
self.initialization_status['descriptor'] = True
self.logger.info("Initialized SceneDescriptor successfully")
except Exception as e:
self.logger.error(f"Error initializing SceneDescriptor: {e}")
traceback.print_exc()
self.initialization_status['descriptor'] = False
self.components['descriptor'] = None
# 初始化 EnhancedSceneDescriber
try:
if self.components.get('spatial_analyzer'):
self.components['scene_describer'] = EnhancedSceneDescriber(
scene_types=self.data_structures.get('SCENE_TYPES', {}),
spatial_analyzer_instance=self.components['spatial_analyzer']
)
self.initialization_status['scene_describer'] = True
self.logger.info("Initialized EnhancedSceneDescriber successfully")
else:
self.logger.warning("Cannot initialize EnhancedSceneDescriber without SpatialAnalyzer")
self.initialization_status['scene_describer'] = False
self.components['scene_describer'] = None
except Exception as e:
self.logger.error(f"Error initializing EnhancedSceneDescriber: {e}")
traceback.print_exc()
self.initialization_status['scene_describer'] = False
self.components['scene_describer'] = None
def _initialize_clip_components(self):
"""初始化 CLIP 相關組件。"""
# 初始化 CLIPAnalyzer
try:
self.components['clip_analyzer'] = CLIPAnalyzer()
self.initialization_status['clip_analyzer'] = True
self.logger.info("Initialized CLIPAnalyzer successfully")
# 如果啟用地標檢測,初始化 CLIPZeroShotClassifier
if self.enable_landmark:
self._initialize_landmark_classifier()
except Exception as e:
self.logger.warning(f"Could not initialize CLIP analyzer: {e}")
self.logger.info("Scene analysis will proceed without CLIP. Install CLIP with 'pip install clip' for enhanced scene understanding.")
self.use_clip = False
self.initialization_status['clip_analyzer'] = False
self.components['clip_analyzer'] = None
def _initialize_landmark_classifier(self):
"""初始化地標分類器。"""
try:
# 嘗試使用已載入的 CLIP 模型實例
if (self.components.get('clip_analyzer') and
hasattr(self.components['clip_analyzer'], 'get_clip_instance')):
model, preprocess, device = self.components['clip_analyzer'].get_clip_instance()
self.components['landmark_classifier'] = CLIPZeroShotClassifier(device=device)
self.logger.info("Initialized landmark classifier with shared CLIP model")
else:
self.components['landmark_classifier'] = CLIPZeroShotClassifier()
self.logger.info("Initialized landmark classifier with independent CLIP model")
# 配置地標檢測器參數
self._configure_landmark_classifier()
self.initialization_status['landmark_classifier'] = True
except (ImportError, Exception) as e:
self.logger.warning(f"Could not initialize landmark classifier: {e}")
self.initialization_status['landmark_classifier'] = False
self.components['landmark_classifier'] = None
# 不完全禁用地標檢測,允許運行時重新嘗試
def _configure_landmark_classifier(self):
"""配置地標分類器的參數。"""
if self.components.get('landmark_classifier'):
try:
classifier = self.components['landmark_classifier']
classifier.set_batch_size(8)
classifier.adjust_confidence_threshold("full_image", 0.8)
classifier.adjust_confidence_threshold("distant", 0.65)
self.logger.info("Landmark detection enabled with optimized settings")
except Exception as e:
self.logger.warning(f"Error configuring landmark classifier: {e}")
def _initialize_llm_components(self):
"""初始化 LLM 組件。"""
try:
self.components['llm_enhancer'] = LLMEnhancer(model_path=self.llm_model_path)
self.initialization_status['llm_enhancer'] = True
self.logger.info("LLM enhancer initialized successfully")
except Exception as e:
self.logger.warning(f"Could not initialize LLM enhancer: {e}")
self.logger.info("Scene analysis will proceed without LLM. Make sure required packages are installed.")
self.use_llm = False
self.initialization_status['llm_enhancer'] = False
self.components['llm_enhancer'] = None
def get_component(self, component_name: str) -> Optional[Any]:
"""
獲取指定的組件實例。
Args:
component_name: 組件名稱
Returns:
組件實例或 None(如果未初始化成功)
"""
return self.components.get(component_name)
def get_data_structure(self, data_name: str) -> Dict:
"""
獲取指定的數據結構。
Args:
data_name: 數據結構名稱
Returns:
數據結構字典
"""
return self.data_structures.get(data_name, {})
def is_component_available(self, component_name: str) -> bool:
"""
檢查指定組件是否可用。
Args:
component_name: 組件名稱
Returns:
組件是否可用
"""
return self.initialization_status.get(component_name, False)
def get_initialization_summary(self) -> Dict[str, bool]:
"""
獲取所有組件的初始化狀態摘要。
Returns:
組件名稱到初始化狀態的映射
"""
return self.initialization_status.copy()
def reinitialize_component(self, component_name: str) -> bool:
"""
重新初始化指定的組件。
Args:
component_name: 要重新初始化的組件名稱
Returns:
重新初始化是否成功
"""
try:
if component_name == 'landmark_classifier' and self.use_clip and self.enable_landmark:
self._initialize_landmark_classifier()
return self.initialization_status.get('landmark_classifier', False)
else:
self.logger.warning(f"Reinitializing {component_name} is not supported")
return False
except Exception as e:
self.logger.error(f"Error reinitializing {component_name}: {e}")
return False
def update_landmark_enable_status(self, enable_landmark: bool):
"""
更新地標檢測的啟用狀態。
Args:
enable_landmark: 是否啟用地標檢測
"""
self.enable_landmark = enable_landmark
# 如果啟用地標檢測但分類器不可用,嘗試重新初始化
if enable_landmark and not self.is_component_available('landmark_classifier'):
if self.use_clip:
self.reinitialize_component('landmark_classifier')
# 更新相關組件的狀態
for component_name in ['scene_describer', 'clip_analyzer', 'landmark_classifier']:
component = self.get_component(component_name)
if component and hasattr(component, 'enable_landmark'):
component.enable_landmark = enable_landmark |