|
import h5py |
|
import numpy as np |
|
import json |
|
|
|
|
|
def get_dataset_info(dataset_path, filter_key=None, verbose=True):
    """Print summary statistics and structure information for an hdf5 dataset.

    Prints total transitions/trajectories, trajectory length statistics,
    the global action range, the dataset's language instruction, the filter
    keys present in the file, and (if ``verbose``) env metadata plus the
    per-key structure of every episode.

    Args:
        dataset_path (str): path to the hdf5 dataset file.
        filter_key (str or None): if provided, only consider the demos listed
            under ``mask/<filter_key>`` in the file.
        verbose (bool): if True, also print each filter key's contents, the
            env metadata, and the structure of every episode; otherwise only
            the first episode's structure is shown.
    """
    all_filter_keys = None
    # Use a context manager so the file handle is closed even if an
    # exception occurs partway through (the original leaked it on error).
    with h5py.File(dataset_path, "r") as f:
        if filter_key is not None:
            # restrict to the demos listed under this filter key
            print("NOTE: using filter key {}".format(filter_key))
            demos = sorted(
                [elem.decode("utf-8") for elem in np.array(f["mask/{}".format(filter_key)])]
            )
        else:
            # use all demonstrations in the file
            demos = sorted(list(f["data"].keys()))

        # collect the demo lists for every filter key present in the file
        if "mask" in f:
            all_filter_keys = {}
            for fk in f["mask"]:
                fk_demos = sorted(
                    [elem.decode("utf-8") for elem in np.array(f["mask/{}".format(fk)])]
                )
                all_filter_keys[fk] = fk_demos

        # Put demonstrations in increasing episode-number order in one pass.
        # Demo names are assumed to look like "demo_<N>" (elem[5:] is <N>) —
        # TODO confirm against the dataset's naming convention.
        demos = sorted(demos, key=lambda elem: int(elem[5:]))

        # gather per-episode lengths and the global action value range
        traj_lengths = []
        action_min = np.inf
        action_max = -np.inf
        for ep in demos:
            # read the actions dataset once per episode instead of twice
            actions = f["data/{}/actions".format(ep)][()]
            traj_lengths.append(actions.shape[0])
            action_min = min(action_min, np.min(actions))
            action_max = max(action_max, np.max(actions))
        traj_lengths = np.array(traj_lengths)

        # problem_info is stored as a JSON string attribute on the data group
        problem_info = json.loads(f["data"].attrs["problem_info"])

        language_instruction = "".join(problem_info["language_instruction"])

        print("")
        print("total transitions: {}".format(np.sum(traj_lengths)))
        print("total trajectories: {}".format(traj_lengths.shape[0]))
        print("traj length mean: {}".format(np.mean(traj_lengths)))
        print("traj length std: {}".format(np.std(traj_lengths)))
        print("traj length min: {}".format(np.min(traj_lengths)))
        print("traj length max: {}".format(np.max(traj_lengths)))
        print("action min: {}".format(action_min))
        print("action max: {}".format(action_max))
        print("language instruction: {}".format(language_instruction.strip('"')))
        print("")
        print("==== Filter Keys ====")
        if all_filter_keys is not None:
            for fk in all_filter_keys:
                print("filter key {} with {} demos".format(fk, len(all_filter_keys[fk])))
        else:
            print("no filter keys")
        print("")
        if verbose:
            if all_filter_keys is not None:
                print("==== Filter Key Contents ====")
                for fk in all_filter_keys:
                    print(
                        "filter_key {} with {} demos: {}".format(
                            fk, len(all_filter_keys[fk]), all_filter_keys[fk]
                        )
                    )
            print("")
            # env_args is stored as a JSON string attribute on the data group
            env_meta = json.loads(f["data"].attrs["env_args"])
            print("==== Env Meta ====")
            print(json.dumps(env_meta, indent=4))
            print("")

        print("==== Dataset Structure ====")
        for ep in demos:
            print(
                "episode {} with {} transitions".format(
                    ep, f["data/{}".format(ep)].attrs["num_samples"]
                )
            )
            for k in f["data/{}".format(ep)]:
                if k in ["obs", "next_obs"]:
                    # observation groups: print each observation key's shape
                    print("    key: {}".format(k))
                    for obs_k in f["data/{}/{}".format(ep, k)]:
                        shape = f["data/{}/{}/{}".format(ep, k, obs_k)].shape
                        print(
                            "        observation key {} with shape {}".format(obs_k, shape)
                        )
                elif isinstance(f["data/{}/{}".format(ep, k)], h5py.Dataset):
                    key_shape = f["data/{}/{}".format(ep, k)].shape
                    print("    key: {} with shape {}".format(k, key_shape))

            # when not verbose, only show the structure of the first episode
            if not verbose:
                break
|
|