|
""" |
|
Download utilities adapted from Mandlekar et al. (robomimic): https://github.com/ARISE-Initiative/robomimic/blob/master/robomimic/utils/file_utils.py
|
""" |
|
import os |
|
import time |
|
from tqdm import tqdm |
|
from termcolor import colored |
|
from pathlib import Path |
|
import zipfile |
|
import io |
|
import urllib.request |
|
import shutil |
|
|
|
from libero import get_libero_path |
|
|
|
# Directory containing this module.
DIR = os.path.split(__file__)[0]

# Download link (a .zip archive) for each LIBERO dataset, keyed by dataset name.
DATASET_LINKS = dict(
    libero_object="https://utexas.box.com/shared/static/avkklgeq0e1dgzxz52x488whpu8mgspk.zip",
    libero_goal="https://utexas.box.com/shared/static/iv5e4dos8yy2b212pkzkpxu9wbdgjfeg.zip",
    libero_spatial="https://utexas.box.com/shared/static/04k94hyizn4huhbv5sz4ev9p2h1p6s7f.zip",
    libero_100="https://utexas.box.com/shared/static/cv73j8zschq8auh9npzt876fdc1akvmk.zip",
)
|
|
|
|
|
class DownloadProgressBar(tqdm):
    """tqdm progress bar whose ``update_to`` method can serve as a download hook.

    Construction arguments are forwarded unchanged to ``tqdm``.
    """

    def update_to(self, b=1, bsize=1, tsize=None):
        """Move the bar to the position of ``b`` blocks of ``bsize`` units.

        Args:
            b (int): number of blocks transferred so far
            bsize (int): size of each block
            tsize (int, optional): total size; when None the current total is kept
        """
        if tsize is not None:
            self.total = tsize
        transferred = b * bsize
        # tqdm.update takes an increment, so subtract the current position
        self.update(transferred - self.n)
|
|
|
|
|
def url_is_alive(url, timeout=None):
    """
    Checks that a given URL is reachable.

    From https://gist.github.com/dehowell/884204.

    Args:
        url (str): url string
        timeout (float, optional): seconds to wait for the server before giving up.
            Defaults to None, which leaves urllib's default timeout in place.

    Returns:
        is_alive (bool): True if url is reachable, False otherwise
    """
    try:
        request = urllib.request.Request(url)
        # Only forward an explicit timeout: passing timeout=None to urlopen would
        # override any global default set via socket.setdefaulttimeout.
        open_kwargs = {} if timeout is None else {"timeout": timeout}
        # Use the response as a context manager so the connection is closed.
        with urllib.request.urlopen(request, **open_kwargs):
            return True
    except (OSError, ValueError):
        # OSError covers urllib.error.URLError and its subclass HTTPError (HTTP error
        # status, DNS failure, refused connection) as well as socket timeouts;
        # ValueError covers malformed urls such as a missing scheme.
        return False
|
|
|
|
|
def download_url(url, download_dir, check_overwrite=True, is_zipfile=True):
    """
    First checks that @url is reachable, then downloads the file
    at that url into the directory specified by @download_dir.
    Prints a progress bar during the download using tqdm.

    Modified from https://github.com/tqdm/tqdm#hooks-and-callbacks, and
    https://stackoverflow.com/a/53877507.

    Args:
        url (str): url string
        download_dir (str): path to directory where file should be downloaded;
            created if it does not exist
        check_overwrite (bool): if True, will sanity check the download fpath to make sure a file of that name
            doesn't already exist there, asking the user before overwriting it
        is_zipfile (bool): if True, extract the downloaded file as a zip archive into
            @download_dir and delete the archive afterwards

    Raises:
        AssertionError: if @url is not reachable
    """
    assert url_is_alive(url), "@download_url got unreachable url: {}".format(url)
    time.sleep(0.5)

    # The local file name is the last path component of the url.
    fname = url.split("/")[-1]
    os.makedirs(download_dir, exist_ok=True)
    file_to_write = os.path.join(download_dir, fname)

    user_response = None
    if check_overwrite and os.path.exists(file_to_write):
        user_response = input(
            f"Warning: file {file_to_write} already exists. Overwrite? y/n\n"
        )

    if user_response is not None and user_response.lower() not in {"yes", "y"}:
        # The user declined to overwrite the existing file; nothing to do.
        return

    try:
        with DownloadProgressBar(
            unit="B", unit_scale=True, miniters=1, desc=fname
        ) as t:
            urllib.request.urlretrieve(
                url, filename=file_to_write, reporthook=t.update_to
            )
    except BaseException:
        # Don't leave a truncated file behind (also on Ctrl-C): it would trigger the
        # overwrite prompt next time or be extracted as a corrupt archive.
        if os.path.isfile(file_to_write):
            os.remove(file_to_write)
        raise

    if is_zipfile:
        with zipfile.ZipFile(file_to_write, "r") as archive:
            archive.extractall(path=download_dir)
        # Only the extracted contents are kept.
        if os.path.isfile(file_to_write):
            os.remove(file_to_write)
|
|
|
|
|
def libero_dataset_download(datasets="all", download_dir=None, check_overwrite=True):
    """Download libero datasets.

    Args:
        datasets (str, optional): Specify which datasets to save: "all", or one of the
            keys of DATASET_LINKS ("libero_object", "libero_goal", "libero_spatial",
            "libero_100"). Defaults to "all", downloading all the datasets.
        download_dir (str, optional): Target location for storing datasets. Defaults to None, using the default path.
        check_overwrite (bool, optional): Check if overwriting datasets. Defaults to True.

    Raises:
        AssertionError: if @datasets is neither "all" nor a known dataset name.
    """
    if download_dir is None:
        download_dir = get_libero_path("datasets")
    os.makedirs(download_dir, exist_ok=True)

    # DATASET_LINKS is the single source of truth for the valid dataset names.
    valid_choices = ["all", *DATASET_LINKS]
    assert datasets in valid_choices, (
        f"Unknown dataset {datasets!r}; expected one of {valid_choices}"
    )

    for dataset_name, url in DATASET_LINKS.items():
        if datasets == "all" or datasets == dataset_name:
            print(f"Downloading {dataset_name}")
            download_url(
                url,
                download_dir=download_dir,
                check_overwrite=check_overwrite,
            )
|
|
|
|
|
|
|
|
|
def check_libero_dataset(download_dir=None):
    """Check the integrity of the downloaded datasets.

    For each expected dataset folder under @download_dir, counts the ``*.hdf5``
    files directly inside it and prints a colored status line:
    complete when the count matches the expected number (90 for ``libero_90``,
    10 for every other dataset), incomplete otherwise, and not found when the
    folder does not exist.

    Args:
        download_dir (str, optional): The path where datasets are stored. Defaults to None, using the default path.

    Returns:
        bool: True if every dataset is complete, False otherwise.
    """
    if download_dir is None:
        download_dir = get_libero_path("datasets")

    # Number of *.hdf5 files a complete dataset folder must contain.
    expected_counts = {
        "libero_object": 10,
        "libero_goal": 10,
        "libero_spatial": 10,
        "libero_10": 10,
        "libero_90": 90,
    }

    check_result = True
    for dataset_name, expected_count in expected_counts.items():
        dataset_status = False
        dataset_dir = os.path.join(download_dir, dataset_name)
        if os.path.exists(dataset_dir):
            count = sum(1 for _ in Path(dataset_dir).glob("*.hdf5"))
            if count == expected_count:
                dataset_status = True
                info_str = colored(
                    f"[X] Dataset {dataset_name} is complete", "green", attrs=["bold"]
                )
            else:
                # Previously this message was built but never assigned, so an
                # empty line was printed for incomplete datasets.
                info_str = colored(
                    f"[?] Dataset {dataset_name} is not downloaded completely",
                    "yellow",
                    attrs=["bold"],
                )
        else:
            info_str = colored(
                f"[ ] Dataset {dataset_name} not found!!!", "red", attrs=["bold"]
            )

        print(info_str)
        check_result = check_result and dataset_status
    return check_result
|
|