File size: 6,342 Bytes
393d3de |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
"""
Download functionalities adapted from Mandlekar et al.: https://github.com/ARISE-Initiative/robomimic/blob/master/robomimic/utils/file_utils.py
"""
import os
import time
from tqdm import tqdm
from termcolor import colored
from pathlib import Path
import zipfile
import io
import urllib.request
import shutil
from libero import get_libero_path
# Directory that contains this module.
DIR = os.path.split(__file__)[0]

# Every dataset archive is published under the same Box "shared/static" URL
# pattern; only the file identifier differs between datasets.
_BOX_STATIC_URL = "https://utexas.box.com/shared/static/{}.zip"

# Maps each dataset name to the URL of its zip archive.
DATASET_LINKS = {
    name: _BOX_STATIC_URL.format(file_id)
    for name, file_id in (
        ("libero_object", "avkklgeq0e1dgzxz52x488whpu8mgspk"),
        ("libero_goal", "iv5e4dos8yy2b212pkzkpxu9wbdgjfeg"),
        ("libero_spatial", "04k94hyizn4huhbv5sz4ev9p2h1p6s7f"),
        ("libero_100", "cv73j8zschq8auh9npzt876fdc1akvmk"),
    )
}
class DownloadProgressBar(tqdm):
    """tqdm progress bar with an ``update_to`` method usable as a download hook.

    ``download_url`` in this module passes ``update_to`` as the ``reporthook``
    of ``urllib.request.urlretrieve``. The constructor is inherited from
    ``tqdm`` unchanged.
    """

    def update_to(self, b=1, bsize=1, tsize=None):
        """Advance the bar to the absolute position ``b * bsize``.

        Args:
            b (int): number of blocks transferred so far
            bsize (int): size of each block
            tsize (int, optional): total size; when given, it replaces the
                bar's current ``total``
        """
        if tsize is not None:
            self.total = tsize
        # tqdm.update() takes an increment, so convert the absolute position
        # into the distance from what the bar has already counted (self.n).
        increment = b * bsize - self.n
        self.update(increment)
def url_is_alive(url):
    """
    Checks that a given URL is reachable.

    From https://gist.github.com/dehowell/884204.

    Args:
        url (str): url string

    Returns:
        is_alive (bool): True if the url could be opened, False if the url is
            malformed or opening it failed (HTTP error status, connection
            failure, ...)
    """
    try:
        # Building the request is inside the try because a malformed url
        # already raises ValueError here.
        request = urllib.request.Request(url)
        # The HEAD override from the original gist stays disabled, so the
        # request goes out with urllib's default method.
        # request.get_method = lambda: 'HEAD'

        # The context manager closes the response instead of leaking it.
        with urllib.request.urlopen(request):
            return True
    except (OSError, ValueError):
        # OSError covers urllib's URLError and HTTPError (both are OSError
        # subclasses) as well as raw socket errors such as timeouts.
        return False
def download_url(url, download_dir, check_overwrite=True, is_zipfile=True):
    """
    Download the file at @url into the directory @download_dir.

    The url is first probed with url_is_alive. The file is then fetched with
    urllib.request.urlretrieve while a tqdm progress bar is displayed.
    Modified from https://github.com/tqdm/tqdm#hooks-and-callbacks, and
    https://stackoverflow.com/a/53877507.

    Args:
        url (str): url string
        download_dir (str): path to directory where file should be downloaded
        check_overwrite (bool): if True and a file with the name inferred from
            the url already exists in @download_dir, the user is prompted;
            any answer other than "y" / "yes" (case-insensitive) skips the
            download
        is_zipfile (bool): if True, the downloaded file is opened as a zip
            archive, extracted into @download_dir and then deleted
    """
    # Probe the url first. The short pause keeps the server from rejecting
    # the request that follows right after the probe.
    assert url_is_alive(url), "@download_url got unreachable url: {}".format(url)
    time.sleep(0.5)

    # The last path component of the url doubles as the local file name.
    fname = url.rsplit("/", 1)[-1]
    file_to_write = os.path.join(download_dir, fname)

    # Only prompt when asked to and when something would be overwritten;
    # without confirmation nothing is downloaded.
    if check_overwrite and os.path.exists(file_to_write):
        answer = input(
            f"Warning: file {file_to_write} already exists. Overwrite? y/n\n"
        )
        if answer.lower() not in {"yes", "y"}:
            return

    progress_options = dict(unit="B", unit_scale=True, miniters=1, desc=fname)
    with DownloadProgressBar(**progress_options) as progress:
        urllib.request.urlretrieve(
            url, filename=file_to_write, reporthook=progress.update_to
        )

    if not is_zipfile:
        return

    # Unpack next to the archive, then drop the archive itself.
    with zipfile.ZipFile(file_to_write, "r") as archive:
        archive.extractall(path=download_dir)
    if os.path.isfile(file_to_write):
        os.remove(file_to_write)
def libero_dataset_download(datasets="all", download_dir=None, check_overwrite=True):
    """Download libero datasets

    Args:
        datasets (str, optional): Which dataset to download: "libero_object",
            "libero_goal", "libero_spatial", "libero_100", or "all" (the
            default) for every one of them.
        download_dir (str, optional): Target location for storing datasets.
            Defaults to None, in which case get_libero_path("datasets") is
            used. The directory is created if it does not exist.
        check_overwrite (bool, optional): Forwarded to download_url, which
            uses it to decide whether to ask before overwriting. Defaults to
            True.
    """
    known_datasets = (
        "libero_object",
        "libero_goal",
        "libero_spatial",
        "libero_100",
    )

    if download_dir is None:
        download_dir = get_libero_path("datasets")
    if not os.path.exists(download_dir):
        os.makedirs(download_dir)

    assert datasets in ("all",) + known_datasets

    # Keep the fixed download order; "all" selects every known dataset.
    selected = [name for name in known_datasets if datasets in (name, "all")]
    for dataset_name in selected:
        print(f"Downloading {dataset_name}")
        download_url(
            DATASET_LINKS[dataset_name],
            download_dir=download_dir,
            check_overwrite=check_overwrite,
        )
    # (TODO): unzip the files
    # NOTE(review): download_url already extracts archives when its
    # is_zipfile flag is True (its default), so this TODO looks stale - confirm.
def check_libero_dataset(download_dir=None):
    """Check the integrity of the downloaded datasets.

    For each expected dataset folder under ``download_dir``, the ``*.hdf5``
    files directly inside it are counted and compared with the expected
    number (90 for ``libero_90``, 10 for every other dataset). One colored
    status line is printed per dataset.

    Args:
        download_dir (str, optional): The path where datasets are stored.
            Defaults to None, in which case get_libero_path("datasets") is
            used.

    Returns:
        bool: True only if every dataset folder exists and holds exactly the
            expected number of ``*.hdf5`` files, False otherwise.
    """
    if download_dir is None:
        download_dir = get_libero_path("datasets")

    # Expected number of *.hdf5 files per dataset folder.
    expected_counts = {
        "libero_object": 10,
        "libero_goal": 10,
        "libero_spatial": 10,
        "libero_10": 10,
        "libero_90": 90,
    }

    check_result = True
    for dataset_name, expected in expected_counts.items():
        dataset_dir = os.path.join(download_dir, dataset_name)
        dataset_status = False

        if not os.path.exists(dataset_dir):
            info_str = colored(
                f"[ ] Dataset {dataset_name} not found!!!", "red", attrs=["bold"]
            )
        else:
            count = sum(1 for _ in Path(dataset_dir).glob("*.hdf5"))
            dataset_status = count == expected
            if dataset_status:
                info_str = colored(
                    f"[X] Dataset {dataset_name} is complete",
                    "green",
                    attrs=["bold"],
                )
            else:
                # The warning must be assigned: previously the colored()
                # result was discarded and an empty line was printed.
                info_str = colored(
                    f"[?] Dataset {dataset_name} is not downloaded completely",
                    "yellow",
                    attrs=["bold"],
                )

        print(info_str)
        check_result = check_result and dataset_status
    return check_result
|