jeffacce
initial commit
393d3de
import abc
import os
import glob
import random
import torch
from typing import List, NamedTuple, Type
from .. import get_libero_path
from .libero_suite_task_map import libero_task_map
BENCHMARK_MAPPING = {}
def register_benchmark(target_class):
"""We design the mapping to be case-INsensitive."""
BENCHMARK_MAPPING[target_class.__name__.lower()] = target_class
def get_benchmark_dict(help=False):
if help:
print("Available benchmarks:")
for benchmark_name in BENCHMARK_MAPPING.keys():
print(f"\t{benchmark_name}")
return BENCHMARK_MAPPING
def get_benchmark(benchmark_name):
return BENCHMARK_MAPPING[benchmark_name.lower()]
def print_benchmark():
print(BENCHMARK_MAPPING)
class Task(NamedTuple):
name: str
language: str
problem: str
problem_folder: str
bddl_file: str
init_states_file: str
def grab_language_from_filename(x):
if x[0].isupper(): # LIBERO-100
if "SCENE10" in x:
language = " ".join(x[x.find("SCENE") + 8 :].split("_"))
else:
language = " ".join(x[x.find("SCENE") + 7 :].split("_"))
else:
language = " ".join(x.split("_"))
en = language.find(".bddl")
return language[:en]
libero_suites = [
"libero_spatial",
"libero_object",
"libero_goal",
"libero_90",
"libero_10",
]
task_maps = {}
max_len = 0
for libero_suite in libero_suites:
task_maps[libero_suite] = {}
for task in libero_task_map[libero_suite]:
language = grab_language_from_filename(task + ".bddl")
task_maps[libero_suite][task] = Task(
name=task,
language=language,
problem="Libero",
problem_folder=libero_suite,
bddl_file=f"{task}.bddl",
init_states_file=f"{task}.pruned_init",
)
# print(language, "\n", f"{task}.bddl", "\n")
# print("")
task_orders = [
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
[4, 6, 8, 7, 3, 1, 2, 0, 9, 5],
[6, 3, 5, 0, 4, 2, 9, 1, 8, 7],
[7, 4, 3, 0, 8, 1, 2, 5, 9, 6],
[4, 5, 6, 3, 8, 0, 2, 7, 1, 9],
[1, 2, 3, 0, 6, 9, 5, 7, 4, 8],
[3, 7, 8, 1, 6, 2, 9, 4, 0, 5],
[4, 2, 9, 7, 6, 8, 5, 1, 3, 0],
[1, 8, 5, 4, 0, 9, 6, 7, 2, 3],
[8, 3, 6, 4, 9, 5, 1, 2, 0, 7],
[6, 9, 0, 5, 7, 1, 2, 8, 3, 4],
[6, 8, 3, 1, 0, 2, 5, 9, 7, 4],
[8, 0, 6, 9, 4, 1, 7, 3, 2, 5],
[3, 8, 6, 4, 2, 5, 0, 7, 1, 9],
[7, 1, 5, 6, 3, 2, 8, 9, 4, 0],
[2, 0, 9, 5, 3, 6, 8, 7, 1, 4],
[3, 5, 9, 6, 2, 4, 8, 7, 1, 0],
[7, 6, 5, 9, 0, 3, 4, 2, 8, 1],
[2, 5, 0, 9, 3, 1, 6, 4, 8, 7],
[3, 5, 1, 2, 7, 8, 6, 0, 4, 9],
[3, 4, 1, 9, 7, 6, 8, 2, 0, 5],
]
class Benchmark(abc.ABC):
"""A Benchmark."""
def __init__(self, task_order_index=0):
self.task_embs = None
self.task_order_index = task_order_index
def _make_benchmark(self):
tasks = list(task_maps[self.name].values())
if self.name == "libero_90":
self.tasks = tasks
else:
print(f"[info] using task orders {task_orders[self.task_order_index]}")
self.tasks = [tasks[i] for i in task_orders[self.task_order_index]]
self.n_tasks = len(self.tasks)
def get_num_tasks(self):
return self.n_tasks
def get_task_names(self):
return [task.name for task in self.tasks]
def get_task_problems(self):
return [task.problem for task in self.tasks]
def get_task_bddl_files(self):
return [task.bddl_file for task in self.tasks]
def get_task_bddl_file_path(self, i):
bddl_file_path = os.path.join(
get_libero_path("bddl_files"),
self.tasks[i].problem_folder,
self.tasks[i].bddl_file,
)
return bddl_file_path
def get_task_demonstration(self, i):
assert (
0 <= i and i < self.n_tasks
), f"[error] task number {i} is outer of range {self.n_tasks}"
# this path is relative to the datasets folder
demo_path = f"{self.tasks[i].problem_folder}/{self.tasks[i].name}_demo.hdf5"
return demo_path
def get_task(self, i):
return self.tasks[i]
def get_task_emb(self, i):
return self.task_embs[i]
def get_task_init_states(self, i):
init_states_path = os.path.join(
get_libero_path("init_states"),
self.tasks[i].problem_folder,
self.tasks[i].init_states_file,
)
init_states = torch.load(init_states_path)
return init_states
def set_task_embs(self, task_embs):
self.task_embs = task_embs
@register_benchmark
class LIBERO_SPATIAL(Benchmark):
def __init__(self, task_order_index=0):
super().__init__(task_order_index=task_order_index)
self.name = "libero_spatial"
self._make_benchmark()
@register_benchmark
class LIBERO_OBJECT(Benchmark):
def __init__(self, task_order_index=0):
super().__init__(task_order_index=task_order_index)
self.name = "libero_object"
self._make_benchmark()
@register_benchmark
class LIBERO_GOAL(Benchmark):
def __init__(self, task_order_index=0):
super().__init__(task_order_index=task_order_index)
self.name = "libero_goal"
self._make_benchmark()
@register_benchmark
class LIBERO_90(Benchmark):
def __init__(self, task_order_index=0):
super().__init__(task_order_index=task_order_index)
assert (
task_order_index == 0
), "[error] currently only support task order for 10-task suites"
self.name = "libero_90"
self._make_benchmark()
@register_benchmark
class LIBERO_10(Benchmark):
def __init__(self, task_order_index=0):
super().__init__(task_order_index=task_order_index)
self.name = "libero_10"
self._make_benchmark()
@register_benchmark
class LIBERO_100(Benchmark):
def __init__(self, task_order_index=0):
super().__init__(task_order_index=task_order_index)
self.name = "libero_100"
self._make_benchmark()