dynamo_ssl / envs /libero /utils /download_utils.py
jeffacce
initial commit
393d3de
"""
Download functionalities adapted from Mandlekar et al.: https://github.com/ARISE-Initiative/robomimic/blob/master/robomimic/utils/file_utils.py
"""
import os
import time
from tqdm import tqdm
from termcolor import colored
from pathlib import Path
import zipfile
import io
import urllib.request
import shutil
from libero import get_libero_path
# Directory part of this module's file path.
DIR = os.path.dirname(__file__)
# Maps each downloadable dataset name to the URL of its zip archive.
DATASET_LINKS = {
    "libero_object": "https://utexas.box.com/shared/static/avkklgeq0e1dgzxz52x488whpu8mgspk.zip",
    "libero_goal": "https://utexas.box.com/shared/static/iv5e4dos8yy2b212pkzkpxu9wbdgjfeg.zip",
    "libero_spatial": "https://utexas.box.com/shared/static/04k94hyizn4huhbv5sz4ev9p2h1p6s7f.zip",
    "libero_100": "https://utexas.box.com/shared/static/cv73j8zschq8auh9npzt876fdc1akvmk.zip",
}
class DownloadProgressBar(tqdm):
    """tqdm progress bar that can be driven by a download ``reporthook``.

    ``update_to`` accepts the ``(block_count, block_size, total_size)``
    arguments that ``urllib.request.urlretrieve`` passes to its reporthook.
    Construction is inherited unchanged from ``tqdm``.
    """

    def update_to(self, b=1, bsize=1, tsize=None):
        """Move the bar to the absolute position ``b * bsize``.

        Args:
            b (int): number of blocks transferred so far
            bsize (int): size of each block
            tsize (int): total size; when not None it replaces ``self.total``
        """
        if tsize is not None:
            self.total = tsize
        transferred = b * bsize
        # update() expects an increment, so pass the distance from the
        # current counter to the new absolute position.
        self.update(transferred - self.n)
def url_is_alive(url):
    """
    Checks that a given URL is reachable.
    From https://gist.github.com/dehowell/884204.

    Args:
        url (str): url string

    Returns:
        is_alive (bool): True if the url could be opened, False if opening it
            failed (HTTP error status, connection or name-resolution failure,
            missing local file, or a url that cannot be parsed into a request)
    """
    try:
        # A default request is issued (the HEAD variant was disabled in the
        # original code). Only whether the request succeeds matters, so the
        # response is closed immediately instead of being left open.
        with urllib.request.urlopen(urllib.request.Request(url)):
            return True
    except (OSError, ValueError):
        # URLError and its subclass HTTPError derive from OSError, as do
        # socket-level failures; ValueError is raised for urls with an
        # unknown or missing scheme.
        return False
def download_url(url, download_dir, check_overwrite=True, is_zipfile=True):
    """
    First checks that @url is reachable, then downloads the file
    at that url into the directory specified by @download_dir.
    Prints a progress bar during the download using tqdm.
    Modified from https://github.com/tqdm/tqdm#hooks-and-callbacks, and
    https://stackoverflow.com/a/53877507.

    Args:
        url (str): url string
        download_dir (str): path to directory where file should be downloaded;
            created if it does not exist
        check_overwrite (bool): if True, will sanity check the download fpath to make sure a file of that name
            doesn't already exist there. If one exists the user is prompted, and any answer
            other than "y"/"yes" skips the download and leaves the existing file untouched
        is_zipfile (bool): if True, the downloaded file is extracted into @download_dir
            and the archive itself is deleted afterwards

    Raises:
        AssertionError: if @url is not reachable
    """
    # check if url is reachable. We need the sleep to make sure server doesn't reject subsequent requests
    assert url_is_alive(url), "@download_url got unreachable url: {}".format(url)
    time.sleep(0.5)

    # infer filename from url link
    fname = url.split("/")[-1]
    file_to_write = os.path.join(download_dir, fname)

    # If we're checking overwrite and the path already exists,
    # we ask the user to verify that they want to overwrite the file
    if check_overwrite and os.path.exists(file_to_write):
        user_response = input(
            f"Warning: file {file_to_write} already exists. Overwrite? y/n\n"
        )
        if user_response.lower() not in {"yes", "y"}:
            # The user declined: do not download, and in particular do not
            # extract or delete the file they asked to keep.
            print(f"Skipping download, keeping existing file {file_to_write}")
            return

    # urlretrieve does not create missing parent directories.
    os.makedirs(download_dir, exist_ok=True)
    with DownloadProgressBar(
        unit="B", unit_scale=True, miniters=1, desc=fname
    ) as t:
        urllib.request.urlretrieve(
            url, filename=file_to_write, reporthook=t.update_to
        )

    if is_zipfile:
        with zipfile.ZipFile(file_to_write, "r") as archive:
            archive.extractall(path=download_dir)
        # The archive is only an intermediate artifact once extracted.
        if os.path.isfile(file_to_write):
            os.remove(file_to_write)
def libero_dataset_download(datasets="all", download_dir=None, check_overwrite=True):
    """Download libero datasets

    Args:
        datasets (str, optional): Specify which datasets to save. Defaults to "all", downloading all the datasets.
            Otherwise must be one of the keys of DATASET_LINKS.
        download_dir (str, optional): Target location for storing datasets. Defaults to None, using the default path.
        check_overwrite (bool, optional): Check if overwriting datasets. Defaults to True.

    Raises:
        AssertionError: if datasets is neither "all" nor a known dataset name.
    """
    if download_dir is None:
        download_dir = get_libero_path("datasets")
    # exist_ok avoids failing when the directory appears between an
    # existence check and the creation call.
    os.makedirs(download_dir, exist_ok=True)

    # DATASET_LINKS is the single source of truth for the dataset names, so
    # the accepted values cannot drift from the available links.
    valid_choices = ["all", *DATASET_LINKS]
    assert (
        datasets in valid_choices
    ), f"Unknown dataset {datasets!r}; expected one of {valid_choices}"

    for dataset_name, link in DATASET_LINKS.items():
        if datasets == dataset_name or datasets == "all":
            print(f"Downloading {dataset_name}")
            # download_url extracts the archive itself (its is_zipfile
            # argument defaults to True), so no separate unzip step is needed.
            download_url(
                link,
                download_dir=download_dir,
                check_overwrite=check_overwrite,
            )
def check_libero_dataset(download_dir=None):
    """Check the integrity of the downloaded datasets.

    A dataset counts as complete when its directory under download_dir
    contains exactly the expected number of ``*.hdf5`` files. One status
    line per dataset is printed.

    Args:
        download_dir (str, optional): The path where datasets are stored. Defaults to None, using the default path.

    Returns:
        bool: True if the datasets are successfully downloaded, False otherwise.
    """
    if download_dir is None:
        download_dir = get_libero_path("datasets")

    # Expected number of *.hdf5 files directly inside each dataset directory.
    expected_counts = {
        "libero_object": 10,
        "libero_goal": 10,
        "libero_spatial": 10,
        "libero_10": 10,
        "libero_90": 90,
    }

    check_result = True
    for dataset_name, expected in expected_counts.items():
        dataset_status = False
        dataset_dir = os.path.join(download_dir, dataset_name)
        if not os.path.exists(dataset_dir):
            info_str = colored(
                f"[ ] Dataset {dataset_name} not found!!!", "red", attrs=["bold"]
            )
        else:
            count = sum(1 for _ in Path(dataset_dir).glob("*.hdf5"))
            if count == expected:
                dataset_status = True
                info_str = colored(
                    f"[X] Dataset {dataset_name} is complete", "green", attrs=["bold"]
                )
            else:
                # The message must be assigned: previously the colored()
                # result was discarded and an empty line was printed.
                info_str = colored(
                    f"[?] Dataset {dataset_name} is not downloaded completely",
                    "yellow",
                    attrs=["bold"],
                )
        print(info_str)
        check_result = check_result and dataset_status
    return check_result