File size: 15,469 Bytes
e6a18b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e3868ba
e6a18b7
 
 
 
 
e3868ba
 
 
e6a18b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be82503
ac7b808
 
 
 
 
 
 
 
be82503
e6a18b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95b3ba7
e6a18b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368

import logging
import traceback
from typing import Dict, List, Any, Optional

logger = logging.getLogger(__name__)

class ObjectExtractor:
    """
    專門處理物件檢測結果的提取和預處理
    負責從YOLO檢測結果提取物件資訊、物件分類和核心物件的辨識
    """

    def __init__(self, class_names: Dict[int, str] = None, object_categories: Dict[str, List[int]] = None):
        """
        初始化物件提取器

        Args:
            class_names: 類別ID到類別名稱的映射字典
            object_categories: 物件類別分組字典
        """
        try:
            self.class_names = class_names or {}
            self.object_categories = object_categories or {}

            # 1. 讀取並設定基本信心度門檻(如果外部沒傳,就預設 0.25)
            self.base_conf_threshold = 0.25

            # 2. 動態信心度調整映射表 (key: 小寫 class_name, value: 調整係數)
            #    最終的門檻 = base_conf_threshold * factor
            #    如果某個 class_name 沒在這裡,就直接用 base_conf_threshold(相當於 factor=1.0)
            self.dynamic_conf_map = {
                "traffic light": 0.6,  
                "car": 0.8,            
                "person": 0.7,         
                
            }

            logger.info(f"ObjectExtractor initialized with {len(self.class_names)} class names and {len(self.object_categories)} object categories")

        except Exception as e:
            logger.error(f"Failed to initialize ObjectExtractor: {str(e)}")
            logger.error(traceback.format_exc())
            raise

    def _get_dynamic_threshold(self, class_name: str) -> float:
        """
        根據 class_name 從 dynamic_conf_map 拿到 factor,計算最終的信心度門檻:
            threshold = base_conf_threshold * factor

        如果 class_name 不在映射表裡,就回傳 base_conf_threshold。
        """
        # 使用小寫做匹配,確保在 dynamic_conf_map 裡的 key 也都用小寫
        key = class_name.lower()
        factor = self.dynamic_conf_map.get(key, 1.0)
        return self.base_conf_threshold * factor

    def extract_detected_objects(
            self,
            detection_result: Any,
            confidence_threshold: float = 0.25,
            region_analyzer=None
        ) -> List[Dict]:
            """
            從檢測結果中提取物件資訊,包含位置資訊

            Args:
                detection_result: YOLO檢測結果
                confidence_threshold: 改由動態門檻決定
                region_analyzer: 區域分析器實例,用於判斷物件所屬區域

            Returns:
                包含檢測物件資訊的字典列表
            """
            try:
                # 調試信息:記錄當前類別映射狀態
                logger.info(f"ObjectExtractor.extract_detected_objects called")
                logger.info(f"Current class_names keys: {list(self.class_names.keys()) if self.class_names else 'None'}")

                if detection_result is None:
                    logger.warning("Detection result is None")
                    return []

                if not hasattr(detection_result, 'boxes'):
                    logger.error("Detection result does not have boxes attribute")
                    return []

                boxes = detection_result.boxes.xyxy.cpu().numpy()
                classes = detection_result.boxes.cls.cpu().numpy().astype(int)
                confidences = detection_result.boxes.conf.cpu().numpy()

                # 獲取圖像尺寸
                img_height, img_width = detection_result.orig_shape[:2]

                detected_objects = []

                for box, class_id, confidence in zip(boxes, classes, confidences):
                    try:
                        # 1. 先拿到這筆偵測物件的 class_name
                        class_name = self.class_names.get(int(class_id), f"unknown_class_{class_id}")
                        # 2. 計算這個 class 應該採用的動態 threshold
                        dyn_thr = self._get_dynamic_threshold(class_name)  # e.g. 0.25 * factor
                        # 3. 如果 confidence < dyn_thr,就跳過這一筆
                        if confidence < dyn_thr:
                            continue

                        # 後面維持原本的座標、中心、大小、區域等資訊計算
                        x1, y1, x2, y2 = box
                        width = x2 - x1
                        height = y2 - y1

                        # 中心點計算
                        center_x = (x1 + x2) / 2
                        center_y = (y1 + y2) / 2

                        # 標準化位置 (0-1)
                        norm_x = center_x / img_width
                        norm_y = center_y / img_height
                        norm_width = width / img_width
                        norm_height = height / img_height

                        # 面積計算
                        area = width * height
                        norm_area = area / (img_width * img_height)

                        # 區域判斷
                        object_region = "unknown"
                        if region_analyzer:
                            object_region = region_analyzer.determine_region(norm_x, norm_y)

                        # 調試信息:記錄映射過程
                        if class_name.startswith("unknown_class_"):
                            logger.warning(
                                f"Class ID {class_id} not found in class_names. "
                                f"Available keys: {list(self.class_names.keys())}"
                            )
                        else:
                            logger.debug(f"Successfully mapped class ID {class_id} to '{class_name}'")

                        detected_objects.append({
                            "class_id": int(class_id),
                            "class_name": class_name,
                            "confidence": float(confidence),
                            "box": [float(x1), float(y1), float(x2), float(y2)],
                            "center": [float(center_x), float(center_y)],
                            "normalized_center": [float(norm_x), float(norm_y)],
                            "size": [float(width), float(height)],
                            "normalized_size": [float(norm_width), float(norm_height)],
                            "area": float(area),
                            "normalized_area": float(norm_area),
                            "region": object_region
                        })

                    except Exception as e:
                        logger.error(f"Error processing object with class_id {class_id}: {str(e)}")
                        continue

                logger.info(f"Extracted {len(detected_objects)} objects from detection result")

                # print(f"DEBUG: ObjectExtractor filtered objects by class:")
                # for class_name in ["car", "traffic light", "person", "handbag"]:    
                #     class_objects = [obj for obj in detected_objects if obj.get("class_name") == class_name]
                #     if class_objects:                        
                #         confidences = [obj.get("confidence", 0) for obj in class_objects]
                #         print(f"DEBUG: {class_name}: {len(class_objects)} objects, confidences: {confidences}")
                # print(f"DEBUG: base_conf_threshold: {self.base_conf_threshold}")
                # print(f"DEBUG: dynamic_conf_map: {self.dynamic_conf_map}")
                
                return detected_objects

            except Exception as e:
                logger.error(f"Error extracting detected objects: {str(e)}")
                logger.error(traceback.format_exc())
                return []

    def update_class_names(self, class_names: Dict[int, str]):
        """
        動態更新類別名稱映射

        Args:
            class_names: 新的類別名稱映射字典
        """
        try:
            self.class_names = class_names or {}
            logger.info(f"Class names updated: {len(self.class_names)} classes")
            logger.debug(f"Updated class names: {self.class_names}")
        except Exception as e:
            logger.error(f"Failed to update class names: {str(e)}")

    def categorize_object(self, obj: Dict) -> str:
        """
        將檢測到的物件分類到功能類別中,用於區域識別

        Args:
            obj: 物件字典

        Returns:
            物件功能類別字串
        """
        try:
            class_id = obj.get("class_id", -1)
            class_name = obj.get("class_name", "").lower()

            # 使用現有的類別映射(如果可用)
            if self.object_categories:
                for category, ids in self.object_categories.items():
                    if class_id in ids:
                        return category

            # 基於COCO類別名稱的後備分類
            furniture_items = ["chair", "couch", "bed", "dining table", "toilet"]
            plant_items = ["potted plant"]
            electronic_items = ["tv", "laptop", "mouse", "remote", "keyboard", "cell phone"]
            vehicle_items = ["bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat"]
            person_items = ["person"]
            kitchen_items = ["bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl",
                            "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog",
                            "pizza", "donut", "cake", "refrigerator", "oven", "toaster", "sink", "microwave"]
            sports_items = ["frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat",
                        "baseball glove", "skateboard", "surfboard", "tennis racket"]
            personal_items = ["handbag", "tie", "suitcase", "umbrella", "backpack"]

            if any(item in class_name for item in furniture_items):
                return "furniture"
            elif any(item in class_name for item in plant_items):
                return "plant"
            elif any(item in class_name for item in electronic_items):
                return "electronics"
            elif any(item in class_name for item in vehicle_items):
                return "vehicle"
            elif any(item in class_name for item in person_items):
                return "person"
            elif any(item in class_name for item in kitchen_items):
                return "kitchen_items"
            elif any(item in class_name for item in sports_items):
                return "sports"
            elif any(item in class_name for item in personal_items):
                return "personal_items"
            else:
                return "misc"

        except Exception as e:
            logger.error(f"Error categorizing object: {str(e)}")
            logger.error(traceback.format_exc())
            return "misc"

    def get_object_categories(self, detected_objects: List[Dict]) -> set:
        """
        從檢測到的物件中取得唯一的物件類別

        Args:
            detected_objects: 檢測到的物件列表

        Returns:
            唯一物件類別的集合
        """
        try:
            object_categories = set()
            for obj in detected_objects:
                category = self.categorize_object(obj)
                if category:
                    object_categories.add(category)

            logger.info(f"Found {len(object_categories)} unique object categories")
            return object_categories

        except Exception as e:
            logger.error(f"Error getting object categories: {str(e)}")
            logger.error(traceback.format_exc())
            return set()

    def identify_core_objects_for_scene(self, detected_objects: List[Dict], scene_type: str) -> List[Dict]:
        """
        識別定義特定場景類型的核心物件

        Args:
            detected_objects: 檢測到的物件列表
            scene_type: 場景類型

        Returns:
            場景的核心物件列表
        """
        try:
            core_objects = []

            # 場景核心物件映射
            scene_core_mapping = {
                "bedroom": [59],  # bed
                "kitchen": [68, 69, 71, 72],  # microwave, oven, sink, refrigerator
                "living_room": [57, 58, 62],  # sofa, chair, tv
                "dining_area": [60, 42, 43],  # dining table, fork, knife
                "office_workspace": [63, 64, 66, 73]  # laptop, mouse, keyboard, book
            }

            if scene_type in scene_core_mapping:
                core_class_ids = scene_core_mapping[scene_type]
                for obj in detected_objects:
                    if obj.get("class_id") in core_class_ids and obj.get("confidence", 0) >= 0.4:
                        core_objects.append(obj)

            logger.info(f"Identified {len(core_objects)} core objects for scene type '{scene_type}'")
            return core_objects

        except Exception as e:
            logger.error(f"Error identifying core objects for scene '{scene_type}': {str(e)}")
            logger.error(traceback.format_exc())
            return []

    def group_objects_by_category_and_region(self, detected_objects: List[Dict]) -> Dict:
        """
        將物件按類別和區域分組

        Args:
            detected_objects: 檢測到的物件列表

        Returns:
            按類別和區域分組的物件字典
        """
        try:
            category_regions = {}

            for obj in detected_objects:
                category = self.categorize_object(obj)
                if not category:
                    continue

                if category not in category_regions:
                    category_regions[category] = {}

                region = obj.get("region", "center")
                if region not in category_regions[category]:
                    category_regions[category][region] = []

                category_regions[category][region].append(obj)

            logger.info(f"Grouped objects into {len(category_regions)} categories across regions")
            return category_regions

        except Exception as e:
            logger.error(f"Error grouping objects by category and region: {str(e)}")
            logger.error(traceback.format_exc())
            return {}

    def filter_objects_by_confidence(self, detected_objects: List[Dict], min_confidence: float) -> List[Dict]:
        """
        根據信心度過濾物件

        Args:
            detected_objects: 檢測到的物件列表
            min_confidence: 最小信心度閾值

        Returns:
            過濾後的物件列表
        """
        try:
            filtered_objects = [
                obj for obj in detected_objects
                if obj.get("confidence", 0) >= min_confidence
            ]

            logger.info(f"Filtered {len(detected_objects)} objects to {len(filtered_objects)} objects with confidence >= {min_confidence}")
            return filtered_objects

        except Exception as e:
            logger.error(f"Error filtering objects by confidence: {str(e)}")
            logger.error(traceback.format_exc())
            return detected_objects  # 發生錯誤時返回原始列表