File size: 4,499 Bytes
393d3de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import re

from libero.envs.objects import OBJECTS_DICT
from libero.utils.bddl_generation_utils import (
    get_affordance_region_kwargs_list_from_fixture_info,
    get_object_dict,
)
from libero.utils.object_utils import get_affordance_regions

MU_DICT = {}

SCENE_DICT = {}


def register_mu(scene_type="general"):
    def _func(target_class):
        """For reusing initial conditions easily, we register each pre-defined initial conditions in a dictionary."""
        key = "_".join(
            re.sub(r"([A-Z])", r" \1", target_class.__name__).split()
        ).lower()
        MU_DICT[key.lower()] = target_class
        if scene_type not in SCENE_DICT:
            SCENE_DICT[scene_type.lower()] = []
        SCENE_DICT[scene_type.lower()].append(target_class)
        return target_class

    return _func


def get_scene_dict(scene_type=None):
    if scene_type is None:
        return SCENE_DICT
    else:
        return SCENE_DICT[scene_type.lower()]


def get_scene_class(scene_name):
    return MU_DICT[scene_name.lower()]


class InitialSceneTemplates:
    def __init__(
        self, workspace_name="main_table", fixture_num_info={}, object_num_info={}
    ):

        self.workspace_name = workspace_name
        # print(self.workspace_name)

        self.fixture_object_dict = get_object_dict(fixture_num_info)
        self.movable_object_dict = get_object_dict(object_num_info)

        affordances = get_affordance_regions(OBJECTS_DICT)
        affordance_fixture_info_dict = {}
        for fixture_category_name in self.fixture_object_dict.keys():
            if (
                fixture_category_name != self.workspace_name
                and fixture_category_name != "table"
                and fixture_category_name != "living_room_table"
                and fixture_category_name != "study_table"
                and fixture_category_name != "kitchen_table"
            ):
                for fixture_name in self.fixture_object_dict[fixture_category_name]:
                    affordance_fixture_info_dict[fixture_name] = affordances[
                        fixture_category_name
                    ]
        for category_name in self.movable_object_dict.keys():
            if category_name in affordances:
                for object_name in self.movable_object_dict[category_name]:
                    affordance_fixture_info_dict[object_name] = affordances[
                        category_name
                    ]
        # print(affordance_fixture_info_dict)
        self.affordance_region_kwargs_list = (
            get_affordance_region_kwargs_list_from_fixture_info(
                affordance_fixture_info_dict
            )
        )

        self.regions = {}
        self.define_regions()

    @property
    def possible_objects_of_interest(self):
        # objects_of_interest = list(self.fixture_object_dict.keys()) + list(self.movable_object_dict.keys())
        # return objects_of_interest
        objects_of_interest = []
        for category_name in self.fixture_object_dict.keys():
            objects_of_interest += self.fixture_object_dict[category_name]
        for category_name in self.movable_object_dict.keys():
            objects_of_interest += self.movable_object_dict[category_name]
        return objects_of_interest

    @property
    def movable_objects(self):
        return list(self.movable_object_dict)

    def define_regions(self):
        """Override this method to define the layout of a scene."""
        raise NotImplementedError

    def get_region_dict(
        self,
        region_centroid_xy,
        region_name,
        target_name=None,
        region_half_len=0.02,
        yaw_rotation=(0.0, 0.0),
    ):
        """This is a function that creates a default region with rectangular shape."""
        if target_name is None:
            target_name = self.workspace_name
        region_key_value = {
            region_name: {
                "target": target_name,
                "ranges": [
                    (
                        region_centroid_xy[0] - region_half_len,
                        region_centroid_xy[1] - region_half_len,
                        region_centroid_xy[0] + region_half_len,
                        region_centroid_xy[1] + region_half_len,
                    )
                ],
                "yaw_rotation": [yaw_rotation],
            }
        }
        print
        return region_key_value

    @property
    def init_states(self):
        raise NotImplementedError