import abc
import utils
import torch
import numpy as np
from torch import default_generator, randperm
from torch.utils.data import Dataset, Subset
from typing import Callable, Optional, Sequence, List, Any
from torch.nn.utils.rnn import pad_sequence
# Taken from python 3.5 docs
def _accumulate(iterable, fn=lambda x, y: x + y):
"Return running totals"
# _accumulate([1,2,3,4,5]) --> 1 3 6 10 15
# _accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120
it = iter(iterable)
try:
total = next(it)
except StopIteration:
return
yield total
for element in it:
total = fn(total, element)
yield total
class TrajectoryDataset(Dataset, abc.ABC):
"""
A dataset containing trajectories.
TrajectoryDataset[i] returns: (observations, actions, mask)
observations: Tensor[T, ...], T frames of observations
actions: Tensor[T, ...], T frames of actions
mask: Tensor[T]: False: invalid; True: valid
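    Example (a minimal in-memory sketch; `TensorTrajectoryDataset` and its
    tensors are hypothetical, not part of this module):
        >>> class TensorTrajectoryDataset(TrajectoryDataset):
        ...     def __init__(self, obs, act, mask):  # each: Tensor[N, T, ...]
        ...         self.obs, self.act, self.mask = obs, act, mask
        ...     def __len__(self):
        ...         return len(self.obs)
        ...     def __getitem__(self, idx):
        ...         return self.obs[idx], self.act[idx], self.mask[idx]
        ...     def get_seq_length(self, idx):
        ...         return int(self.mask[idx].sum())
        ...     def get_frames(self, idx, frames):
        ...         f = list(frames)
        ...         return self.obs[idx, f], self.act[idx, f], self.mask[idx, f]
        ...     def get_all_actions(self):
        ...         return self.act.flatten(0, 1)  # includes padded frames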
"""
@abc.abstractmethod
def get_seq_length(self, idx):
"""
Returns the length of the idx-th trajectory.
"""
raise NotImplementedError
    @abc.abstractmethod
    def get_frames(self, idx, frames):
        """
        Returns the frames of the idx-th trajectory at the given frame indices.
        Used to speed up slicing.
        """
        raise NotImplementedError
    def get_all_actions(self):
        """Returns all actions across trajectories, concatenated along dim 0.
        Expected from concrete subclasses; used by TrajectorySubset and
        TrajectorySlicerDataset."""
        raise NotImplementedError
class TrajectorySubset(TrajectoryDataset, Subset):
"""
Subset of a trajectory dataset at specified indices.
Args:
dataset (TrajectoryDataset): The whole Dataset
indices (sequence): Indices in the whole set selected for subset
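    Example (illustrative; `traj_ds` is a hypothetical TrajectoryDataset):
        >>> first_two = TrajectorySubset(traj_ds, indices=[0, 1])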
"""
def __init__(self, dataset: TrajectoryDataset, indices: Sequence[int]):
Subset.__init__(self, dataset, indices)
def get_seq_length(self, idx):
return self.dataset.get_seq_length(self.indices[idx])
def get_all_actions(self):
return self.dataset.get_all_actions()
def get_frames(self, idx, frames):
return self.dataset.get_frames(self.indices[idx], frames)
class TrajectorySlicerDataset(Dataset):
def __init__(
self,
dataset: TrajectoryDataset,
window: int,
future_conditional: bool = False,
min_future_sep: int = 0,
future_seq_len: Optional[int] = None,
only_sample_tail: bool = False,
transform: Optional[Callable] = None,
num_extra_predicted_actions: Optional[int] = None,
frame_step: int = 1,
repeat_first_frame: bool = False,
):
"""
Slice a trajectory dataset into unique (but overlapping) sequences of length `window`.
dataset: a trajectory dataset that satisfies:
dataset.get_seq_length(i) is implemented to return the length of sequence i
dataset[i] = (observations, actions, mask)
observations: Tensor[T, ...]
actions: Tensor[T, ...]
mask: Tensor[T]
False: invalid
True: valid
window: int
number of timesteps to include in each slice
future_conditional: bool = False
if True, observations will be augmented with future observations sampled from the same trajectory
min_future_sep: int = 0
minimum number of timesteps between the end of the current sequence and the start of the future sequence
for the future conditional
future_seq_len: Optional[int] = None
the length of the future conditional sequence;
required if future_conditional is True
only_sample_tail: bool = False
if True, only sample future sequences from the tail of the trajectory
        transform: function (observations, actions, mask[, goal]) -> (observations, actions, mask[, goal])
        num_extra_predicted_actions: Optional[int] = None
            number of extra action frames to include beyond the observation window
        frame_step: int = 1
            stride between consecutive frames within a slice (temporal subsampling)
        repeat_first_frame: bool = False
            if True, also emit shorter slices anchored at frame 0, which are
            padded back to full length by repeating the first frame
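        Example (illustrative; `traj_ds` stands for any TrajectoryDataset
        whose sequences have at least 10 frames):
            >>> slices = TrajectorySlicerDataset(traj_ds, window=10)
            >>> obs, act, mask = slices[0]  # each spans 10 consecutive frames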
"""
if future_conditional:
assert future_seq_len is not None, "must specify a future_seq_len"
self.dataset = dataset
self.window = window
self.future_conditional = future_conditional
self.min_future_sep = min_future_sep
self.future_seq_len = future_seq_len
self.only_sample_tail = only_sample_tail
self.transform = transform
self.num_extra_predicted_actions = num_extra_predicted_actions or 0
self.slices = []
self.frame_step = frame_step
min_seq_length = np.inf
if num_extra_predicted_actions:
window = window + num_extra_predicted_actions
for i in range(len(self.dataset)): # type: ignore
T = self.dataset.get_seq_length(i) # avoid reading actual seq (slow)
min_seq_length = min(T, min_seq_length)
if T - window < 0:
print(f"Ignored short sequence #{i}: len={T}, window={window}")
else:
                if repeat_first_frame:
                    # also emit shorter slices anchored at frame 0; these are
                    # padded back to full length in __getitem__ by repeating
                    # the first frame
                    self.slices += [(i, 0, end + 1) for end in range(window - 1)]
window_len_with_step = (window - 1) * frame_step + 1
last_start = T - window_len_with_step
self.slices += [
(i, start, start + window_len_with_step)
for start in range(last_start)
] # slice indices follow convention [start, end)
if min_seq_length < window:
print(
f"Ignored short sequences. To include all, set window <= {min_seq_length}."
)
def get_seq_length(self, idx: int) -> int:
if self.future_conditional:
return self.future_seq_len + self.window
else:
return self.window
def get_all_actions(self) -> torch.Tensor:
return self.dataset.get_all_actions()
def __len__(self):
return len(self.slices)
def __getitem__(self, idx):
i, start, end = self.slices[idx]
T = self.dataset.get_seq_length(i)
        # num_extra_predicted_actions is normalized to an int in __init__,
        # so a simple truthiness check suffices here
        if self.num_extra_predicted_actions:
assert self.frame_step == 1, "NOT TESTED"
if self.future_conditional:
raise NotImplementedError(
"num_extra_predicted_actions with future_conditional not implemented"
)
assert end <= T, f"end={end} > T={T}"
observations, actions, mask = self.dataset.get_frames(i, range(start, end))
observations = observations[: self.window]
values = [observations, actions, mask.bool()]
else:
if self.future_conditional:
assert self.frame_step == 1, "NOT TESTED"
valid_start_range = (
end + self.min_future_sep,
self.dataset.get_seq_length(i) - self.future_seq_len,
)
if valid_start_range[0] < valid_start_range[1]:
if self.only_sample_tail:
future_obs_range = range(T - self.future_seq_len, T)
else:
future_start = np.random.randint(*valid_start_range)
future_end = future_start + self.future_seq_len
future_obs_range = range(future_start, future_end)
obs, actions, mask = self.dataset.get_frames(
i, list(range(start, end)) + list(future_obs_range)
)
future_obs = obs[end - start :]
obs = obs[: end - start]
actions = actions[: end - start]
mask = mask[: end - start]
else:
# zeros placeholder T x obs_dim
obs, actions, mask = self.dataset.get_frames(i, range(start, end))
obs_dims = obs.shape[1:]
future_obs = torch.zeros((self.future_seq_len, *obs_dims))
# [observations, actions, mask, future_obs (goal conditional)]
values = [obs, actions, mask.bool(), future_obs]
else:
observations, actions, mask = self.dataset.get_frames(
i, range(start, end, self.frame_step)
)
values = [observations, actions, mask.bool()]
if end - start < self.window + self.num_extra_predicted_actions:
# this only happens for repeating the very first frames
values = [
utils.inference.repeat_start_to_length(
x, self.window + self.num_extra_predicted_actions, dim=0
)
for x in values
]
values[0] = values[0][: self.window]
# optionally apply transform
if self.transform is not None:
values = self.transform(values)
return tuple(values)
class TrajectoryEmbeddingDataset(TrajectoryDataset):
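    """
    A TrajectoryDataset that serves precomputed embeddings: every trajectory
    of `dataset` is embedded with `model` (via
    utils.inference.embed_trajectory_dataset), padded to a common length with
    pad_sequence, and kept on `device` for fast indexing.
    """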
def __init__(
self,
model,
dataset: TrajectoryDataset,
device="cpu",
embed_goal=False,
):
self.data = utils.inference.embed_trajectory_dataset(
model,
dataset,
obs_only=False,
device=device,
embed_goal=embed_goal,
)
assert len(self.data) == len(dataset)
self.seq_lengths = [len(x[0]) for x in self.data]
self.on_device_data = []
n_tensors = len(self.data[0])
for i in range(n_tensors):
self.on_device_data.append(
pad_sequence([x[i] for x in self.data], batch_first=True).to(device)
)
self.data = self.on_device_data
def get_seq_length(self, idx):
return self.seq_lengths[idx]
    def get_all_actions(self):
        # After __init__, self.data is a list of padded tensors
        # [observations, actions, masks], each (num_trajs, max_T, ...);
        # index 1 holds the actions. Concatenate only the unpadded frames.
        return torch.cat(
            [self.data[1][i, :T] for i, T in enumerate(self.seq_lengths)], dim=0
        )
def get_frames(self, idx, frames):
return [x[idx, frames] for x in self.data]
def __getitem__(self, idx):
return self.get_frames(idx, range(self.get_seq_length(idx)))
def __len__(self):
return len(self.seq_lengths)
def get_train_val_sliced(
traj_dataset: TrajectoryDataset,
train_fraction: float = 0.9,
random_seed: int = 42,
window_size: int = 10,
future_conditional: bool = False,
min_future_sep: int = 0,
future_seq_len: Optional[int] = None,
only_sample_tail: bool = False,
transform: Optional[Callable[[Any], Any]] = None,
num_extra_predicted_actions: Optional[int] = None,
frame_step: int = 1,
):
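    """
    Split `traj_dataset` into train and validation subsets of whole
    trajectories, then slice each subset into windows of `window_size` frames.
    Returns (train_slices, val_slices).
    Example (illustrative; `traj_ds` is any TrajectoryDataset):
        >>> train_slices, val_slices = get_train_val_sliced(traj_ds, window_size=10)
        >>> obs, act, mask = train_slices[0]
    """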
train, val = split_traj_datasets(
traj_dataset,
train_fraction=train_fraction,
random_seed=random_seed,
)
traj_slicer_kwargs = {
"window": window_size,
"future_conditional": future_conditional,
"min_future_sep": min_future_sep,
"future_seq_len": future_seq_len,
"only_sample_tail": only_sample_tail,
"transform": transform,
"num_extra_predicted_actions": num_extra_predicted_actions,
"frame_step": frame_step,
}
train_slices = TrajectorySlicerDataset(train, **traj_slicer_kwargs)
val_slices = TrajectorySlicerDataset(val, **traj_slicer_kwargs)
return train_slices, val_slices
def random_split_traj(
dataset: TrajectoryDataset,
lengths: Sequence[int],
generator: Optional[torch.Generator] = default_generator,
) -> List[TrajectorySubset]:
"""
(Modified from torch.utils.data.dataset.random_split)
Randomly split a trajectory dataset into non-overlapping new datasets of given lengths.
    Optionally fix the generator for reproducible results, e.g., for a
    trajectory dataset `traj_ds` containing 10 trajectories:
        >>> random_split_traj(traj_ds, [3, 7], generator=torch.Generator().manual_seed(42))
Args:
dataset (TrajectoryDataset): TrajectoryDataset to be split
lengths (sequence): lengths of splits to be produced
generator (Generator): Generator used for the random permutation.
"""
# Cannot verify that dataset is Sized
if sum(lengths) != len(dataset): # type: ignore[arg-type]
raise ValueError(
"Sum of input lengths does not equal the length of the input dataset!"
)
indices = randperm(sum(lengths), generator=generator).tolist()
return [
TrajectorySubset(dataset, indices[offset - length : offset])
for offset, length in zip(_accumulate(lengths), lengths)
]
def split_traj_datasets(dataset, train_fraction=0.95, random_seed=42):
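    """
    Randomly split a trajectory dataset into train/val subsets of whole
    trajectories; `random_seed` fixes the permutation for reproducibility.
    """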
dataset_length = len(dataset)
lengths = [
int(train_fraction * dataset_length),
dataset_length - int(train_fraction * dataset_length),
]
train_set, val_set = random_split_traj(
dataset, lengths, generator=torch.Generator().manual_seed(random_seed)
)
return train_set, val_set