# VisionScout / object_extractor.py
# Author: DawnC
# Commit 95b3ba7 (verified): "Update object_extractor.py"
import logging
import traceback
from typing import Dict, List, Any, Optional
logger = logging.getLogger(__name__)


class ObjectExtractor:
    """
    Extract and pre-process object detections from YOLO detection results.

    Responsibilities:
      * convert a YOLO detection result into plain dicts (box, center, size,
        area, their image-normalised counterparts, and a region label),
      * filter detections with a per-class ("dynamic") confidence threshold,
      * map objects to coarse functional categories,
      * pick out the "core" objects that characterise a given scene type.
    """

    # Default per-class factors applied to base_conf_threshold
    # (key: lower-case class name, value: multiplier).
    # final threshold = base_conf_threshold * factor; classes not listed use 1.0.
    # Factors < 1.0 make the filter more permissive for that class.
    DEFAULT_DYNAMIC_CONF_MAP = {
        "traffic light": 0.6,
        "car": 0.8,
        "person": 0.7,
    }

    # Class IDs that define each scene type. The IDs follow the 80-class COCO
    # ordering that the per-entry comments refer to (e.g. 59 bed, 60 dining table).
    SCENE_CORE_CLASS_IDS = {
        "bedroom": [59],                       # bed
        "kitchen": [68, 69, 71, 72],           # microwave, oven, sink, refrigerator
        "living_room": [56, 57, 62],           # chair, couch (sofa), tv
        "dining_area": [60, 42, 43],           # dining table, fork, knife
        "office_workspace": [63, 64, 66, 73],  # laptop, mouse, keyboard, book
    }

    # Name-based fallback categories (COCO class names), checked in this order.
    _FALLBACK_CATEGORIES = (
        ("furniture", ("chair", "couch", "bed", "dining table", "toilet")),
        ("plant", ("potted plant",)),
        ("electronics", ("tv", "laptop", "mouse", "remote", "keyboard", "cell phone")),
        ("vehicle", ("bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat")),
        ("person", ("person",)),
        ("kitchen_items", ("bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl",
                           "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog",
                           "pizza", "donut", "cake", "refrigerator", "oven", "toaster", "sink",
                           "microwave")),
        ("sports", ("frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat",
                    "baseball glove", "skateboard", "surfboard", "tennis racket")),
        ("personal_items", ("handbag", "tie", "suitcase", "umbrella", "backpack")),
    )

    def __init__(
        self,
        class_names: Optional[Dict[int, str]] = None,
        object_categories: Optional[Dict[str, List[int]]] = None,
        base_conf_threshold: float = 0.25,
        dynamic_conf_map: Optional[Dict[str, float]] = None,
    ):
        """
        Initialise the extractor.

        Args:
            class_names: Mapping from class ID to class name.
            object_categories: Mapping from category name to the class IDs it contains.
                When non-empty it takes precedence over the name-based fallback in
                categorize_object.
            base_conf_threshold: Base confidence threshold (default 0.25).
            dynamic_conf_map: Per-class multipliers for base_conf_threshold. Keys are
                matched case-insensitively. None uses DEFAULT_DYNAMIC_CONF_MAP; pass {}
                to disable per-class adjustment.

        Raises:
            Any exception raised while converting the threshold settings is logged
            and re-raised.
        """
        try:
            self.class_names = class_names or {}
            self.object_categories = object_categories or {}
            self.base_conf_threshold = float(base_conf_threshold)

            source_map = self.DEFAULT_DYNAMIC_CONF_MAP if dynamic_conf_map is None else dynamic_conf_map
            # Lookups lower-case the class name, so normalise the keys the same way.
            self.dynamic_conf_map = {
                str(name).lower(): float(factor) for name, factor in source_map.items()
            }

            logger.info(f"ObjectExtractor initialized with {len(self.class_names)} class names and {len(self.object_categories)} object categories")
        except Exception as e:
            logger.error(f"Failed to initialize ObjectExtractor: {str(e)}")
            logger.error(traceback.format_exc())
            raise

    def _get_dynamic_threshold(self, class_name: str) -> float:
        """
        Return the confidence threshold for a class.

        The result is base_conf_threshold * factor, where factor comes from
        dynamic_conf_map (case-insensitive lookup) and defaults to 1.0.
        """
        factor = self.dynamic_conf_map.get(str(class_name).lower(), 1.0)
        return self.base_conf_threshold * factor

    @staticmethod
    def _build_object_record(box, class_id, class_name: str, confidence,
                             img_width, img_height, region_analyzer=None) -> Dict:
        """Compute absolute and image-normalised geometry for one detection and pack it into a dict."""
        x1, y1, x2, y2 = box
        width = x2 - x1
        height = y2 - y1

        center_x = (x1 + x2) / 2
        center_y = (y1 + y2) / 2

        # Normalised (0-1 for in-frame boxes) position and size.
        norm_x = center_x / img_width
        norm_y = center_y / img_height
        norm_width = width / img_width
        norm_height = height / img_height

        area = width * height
        norm_area = area / (img_width * img_height)

        object_region = "unknown"
        if region_analyzer is not None:
            object_region = region_analyzer.determine_region(norm_x, norm_y)

        return {
            "class_id": int(class_id),
            "class_name": class_name,
            "confidence": float(confidence),
            "box": [float(x1), float(y1), float(x2), float(y2)],
            "center": [float(center_x), float(center_y)],
            "normalized_center": [float(norm_x), float(norm_y)],
            "size": [float(width), float(height)],
            "normalized_size": [float(norm_width), float(norm_height)],
            "area": float(area),
            "normalized_area": float(norm_area),
            "region": object_region,
        }

    def extract_detected_objects(
        self,
        detection_result: Any,
        confidence_threshold: float = 0.25,
        region_analyzer=None
    ) -> List[Dict]:
        """
        Extract per-object information, including position, from a detection result.

        Args:
            detection_result: YOLO detection result. Read attributes: boxes.xyxy,
                boxes.cls, boxes.conf (each via .cpu().numpy()) and orig_shape.
            confidence_threshold: Ignored; kept for backward compatibility. Filtering
                uses the per-class dynamic threshold from _get_dynamic_threshold.
            region_analyzer: Optional object whose determine_region(norm_x, norm_y)
                supplies each object's region label ("unknown" when absent).

        Returns:
            A list of dicts, one per detection that passes its class threshold.
            An empty list is returned when the result is None, has no boxes, has a
            non-positive image size, or extraction fails.
        """
        try:
            logger.debug("ObjectExtractor.extract_detected_objects called")
            logger.debug("Current class_names keys: %s",
                         list(self.class_names.keys()) if self.class_names else 'None')

            if detection_result is None:
                logger.warning("Detection result is None")
                return []

            # Checking via getattr also covers a present-but-None `boxes` attribute.
            boxes_obj = getattr(detection_result, 'boxes', None)
            if boxes_obj is None:
                logger.error("Detection result does not have boxes attribute")
                return []

            boxes = boxes_obj.xyxy.cpu().numpy()
            classes = boxes_obj.cls.cpu().numpy().astype(int)
            confidences = boxes_obj.conf.cpu().numpy()

            img_height, img_width = detection_result.orig_shape[:2]
            if img_width <= 0 or img_height <= 0:
                # Normalisation would otherwise produce inf/nan values.
                logger.error(f"Invalid image size {img_width}x{img_height}; cannot normalize object positions")
                return []

            detected_objects = []
            for box, class_id, confidence in zip(boxes, classes, confidences):
                try:
                    class_name = self.class_names.get(int(class_id), f"unknown_class_{class_id}")

                    # Skip detections below this class's dynamic threshold.
                    if confidence < self._get_dynamic_threshold(class_name):
                        continue

                    if class_name.startswith("unknown_class_"):
                        logger.warning(
                            f"Class ID {class_id} not found in class_names. "
                            f"Available keys: {list(self.class_names.keys())}"
                        )
                    else:
                        logger.debug("Successfully mapped class ID %s to '%s'", class_id, class_name)

                    detected_objects.append(self._build_object_record(
                        box, class_id, class_name, confidence,
                        img_width, img_height, region_analyzer,
                    ))
                except Exception as e:
                    logger.error(f"Error processing object with class_id {class_id}: {str(e)}")
                    continue

            logger.info(f"Extracted {len(detected_objects)} objects from detection result")
            return detected_objects

        except Exception as e:
            logger.error(f"Error extracting detected objects: {str(e)}")
            logger.error(traceback.format_exc())
            return []

    def update_class_names(self, class_names: Dict[int, str]):
        """
        Replace the class-ID-to-name mapping.

        Args:
            class_names: New mapping; None or empty resets it to {}.
        """
        try:
            self.class_names = class_names or {}
            logger.info(f"Class names updated: {len(self.class_names)} classes")
            logger.debug("Updated class names: %s", self.class_names)
        except Exception as e:
            logger.error(f"Failed to update class names: {str(e)}")

    def categorize_object(self, obj: Dict) -> str:
        """
        Map a detected object to a functional category (used for region identification).

        Lookup order:
          1. self.object_categories by class_id, when configured;
          2. exact (case-insensitive) class-name match against the COCO fallback lists;
          3. substring match against the same lists, for non-canonical names
             such as "office chair".

        Args:
            obj: Object dict; reads "class_id" and "class_name".

        Returns:
            The category name, or "misc" when nothing matches or an error occurs.
        """
        try:
            class_id = obj.get("class_id", -1)
            class_name = str(obj.get("class_name") or "").lower()

            if self.object_categories:
                for category, ids in self.object_categories.items():
                    if class_id in ids:
                        return category

            # Exact match first: substring matching alone misfiles names that
            # contain a shorter class name (e.g. "carrot" contains "car").
            for category, items in self._FALLBACK_CATEGORIES:
                if class_name in items:
                    return category

            for category, items in self._FALLBACK_CATEGORIES:
                if any(item in class_name for item in items):
                    return category

            return "misc"

        except Exception as e:
            logger.error(f"Error categorizing object: {str(e)}")
            logger.error(traceback.format_exc())
            return "misc"

    def get_object_categories(self, detected_objects: List[Dict]) -> set:
        """
        Return the set of distinct categories among the detected objects.

        Args:
            detected_objects: List of object dicts.

        Returns:
            The set of categories from categorize_object (empty set on error).
        """
        try:
            object_categories = {
                category
                for category in (self.categorize_object(obj) for obj in detected_objects)
                if category
            }
            logger.info(f"Found {len(object_categories)} unique object categories")
            return object_categories

        except Exception as e:
            logger.error(f"Error getting object categories: {str(e)}")
            logger.error(traceback.format_exc())
            return set()

    def identify_core_objects_for_scene(self, detected_objects: List[Dict], scene_type: str,
                                        min_confidence: float = 0.4) -> List[Dict]:
        """
        Return the objects that define a given scene type.

        Args:
            detected_objects: List of object dicts.
            scene_type: Key into SCENE_CORE_CLASS_IDS; unknown scene types yield [].
            min_confidence: Minimum confidence for an object to count as core (default 0.4).

        Returns:
            Objects whose class_id is a core ID for the scene and whose confidence is
            at least min_confidence, in input order ([] on error).
        """
        try:
            core_class_ids = self.SCENE_CORE_CLASS_IDS.get(scene_type, ())
            core_objects = [
                obj for obj in detected_objects
                if obj.get("class_id") in core_class_ids
                and obj.get("confidence", 0) >= min_confidence
            ]
            logger.info(f"Identified {len(core_objects)} core objects for scene type '{scene_type}'")
            return core_objects

        except Exception as e:
            logger.error(f"Error identifying core objects for scene '{scene_type}': {str(e)}")
            logger.error(traceback.format_exc())
            return []

    def group_objects_by_category_and_region(self, detected_objects: List[Dict]) -> Dict:
        """
        Group objects by category, then by region.

        Args:
            detected_objects: List of object dicts; a missing "region" is treated as "center".

        Returns:
            {category: {region: [objects...]}} ({} on error).
        """
        try:
            category_regions = {}
            for obj in detected_objects:
                category = self.categorize_object(obj)
                if not category:
                    continue
                region = obj.get("region", "center")
                category_regions.setdefault(category, {}).setdefault(region, []).append(obj)

            logger.info(f"Grouped objects into {len(category_regions)} categories across regions")
            return category_regions

        except Exception as e:
            logger.error(f"Error grouping objects by category and region: {str(e)}")
            logger.error(traceback.format_exc())
            return {}

    def filter_objects_by_confidence(self, detected_objects: List[Dict], min_confidence: float) -> List[Dict]:
        """
        Keep only objects whose confidence is at least min_confidence.

        Args:
            detected_objects: List of object dicts; a missing "confidence" counts as 0.
            min_confidence: Minimum confidence threshold (inclusive).

        Returns:
            The filtered list. On error the original list is returned unchanged.
        """
        try:
            filtered_objects = [
                obj for obj in detected_objects
                if obj.get("confidence", 0) >= min_confidence
            ]
            logger.info(f"Filtered {len(detected_objects)} objects to {len(filtered_objects)} objects with confidence >= {min_confidence}")
            return filtered_objects

        except Exception as e:
            logger.error(f"Error filtering objects by confidence: {str(e)}")
            logger.error(traceback.format_exc())
            return detected_objects  # best effort: fall back to the unfiltered input