File size: 3,877 Bytes
393d3de |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
import h5py
import numpy as np
import json
def get_dataset_info(dataset_path, filter_key=None, verbose=True):
    """Print summary statistics and structure of an hdf5 demonstration dataset.

    Reports transition/trajectory counts, trajectory length statistics, the
    global action min/max, the language instruction, filter keys, environment
    metadata, and (per episode) the keys and shapes stored under "data".

    Args:
        dataset_path (str): path to the hdf5 dataset file.
        filter_key (str or None): if provided, restrict the report to the
            demonstrations listed under "mask/<filter_key>"; otherwise all
            demonstrations under "data" are used.
        verbose (bool): if True, also print the contents of every filter key
            and the structure of every episode (otherwise only the first
            episode's structure is printed).
    """
    # extract demonstration list from file
    all_filter_keys = None
    # context manager guarantees the hdf5 handle is released even if a
    # KeyError/OSError is raised below (the original only closed the file
    # on the success path and leaked the handle on error)
    with h5py.File(dataset_path, "r") as f:
        if filter_key is not None:
            # use the demonstrations from the filter key instead
            print("NOTE: using filter key {}".format(filter_key))
            demos = sorted(
                [elem.decode("utf-8") for elem in np.array(f["mask/{}".format(filter_key)])]
            )
        else:
            # use all demonstrations
            demos = sorted(list(f["data"].keys()))

        # extract filter key information
        if "mask" in f:
            all_filter_keys = {}
            for fk in f["mask"]:
                fk_demos = sorted(
                    [elem.decode("utf-8") for elem in np.array(f["mask/{}".format(fk)])]
                )
                all_filter_keys[fk] = fk_demos

        # put demonstration list in increasing episode order
        # (assumes names of the form "demo_N" — elem[5:] strips "demo_";
        # TODO(review): confirm this naming convention against the writer)
        inds = np.argsort([int(elem[5:]) for elem in demos])
        demos = [demos[i] for i in inds]

        # extract length of each trajectory and the global action range
        traj_lengths = []
        action_min = np.inf
        action_max = -np.inf
        for ep in demos:
            # read the actions dataset once per episode instead of three
            # separate hdf5 reads for shape / min / max
            actions = f["data/{}/actions".format(ep)][()]
            traj_lengths.append(actions.shape[0])
            action_min = min(action_min, np.min(actions))
            action_max = max(action_max, np.max(actions))
        traj_lengths = np.array(traj_lengths)

        problem_info = json.loads(f["data"].attrs["problem_info"])
        language_instruction = "".join(problem_info["language_instruction"])

        # report statistics on the data
        print("")
        print("total transitions: {}".format(np.sum(traj_lengths)))
        print("total trajectories: {}".format(traj_lengths.shape[0]))
        print("traj length mean: {}".format(np.mean(traj_lengths)))
        print("traj length std: {}".format(np.std(traj_lengths)))
        print("traj length min: {}".format(np.min(traj_lengths)))
        print("traj length max: {}".format(np.max(traj_lengths)))
        print("action min: {}".format(action_min))
        print("action max: {}".format(action_max))
        print("language instruction: {}".format(language_instruction.strip('"')))
        print("")
        print("==== Filter Keys ====")
        if all_filter_keys is not None:
            for fk in all_filter_keys:
                print("filter key {} with {} demos".format(fk, len(all_filter_keys[fk])))
        else:
            print("no filter keys")
        print("")
        if verbose:
            if all_filter_keys is not None:
                print("==== Filter Key Contents ====")
                for fk in all_filter_keys:
                    print(
                        "filter_key {} with {} demos: {}".format(
                            fk, len(all_filter_keys[fk]), all_filter_keys[fk]
                        )
                    )
            print("")
        env_meta = json.loads(f["data"].attrs["env_args"])
        print("==== Env Meta ====")
        print(json.dumps(env_meta, indent=4))
        print("")

        print("==== Dataset Structure ====")
        for ep in demos:
            print(
                "episode {} with {} transitions".format(
                    ep, f["data/{}".format(ep)].attrs["num_samples"]
                )
            )
            for k in f["data/{}".format(ep)]:
                if k in ["obs", "next_obs"]:
                    print("    key: {}".format(k))
                    for obs_k in f["data/{}/{}".format(ep, k)]:
                        shape = f["data/{}/{}/{}".format(ep, k, obs_k)].shape
                        print(
                            "        observation key {} with shape {}".format(obs_k, shape)
                        )
                elif isinstance(f["data/{}/{}".format(ep, k)], h5py.Dataset):
                    key_shape = f["data/{}/{}".format(ep, k)].shape
                    print("    key: {} with shape {}".format(k, key_shape))
            # when not verbose, only show the structure of the first episode
            if not verbose:
                break
|