import os
import torch
import einops
import numpy as np
from pathlib import Path
from typing import Optional
from datasets.core import TrajectoryDataset


class PushMultiviewTrajectoryDataset(TrajectoryDataset):
    """Trajectory dataset for the multimodal push task with multi-view image observations.

    Expects `multimodal_push_observations.npy`, `multimodal_push_actions.npy`,
    `multimodal_push_masks.npy`, an `obs_multiview/` directory of per-trajectory
    `.pth` tensors, and optionally `onehot_goals.pth` inside `data_directory`.
    """

    def __init__(
        self,
        data_directory: os.PathLike,
        onehot_goals: bool = False,
        subset_fraction: Optional[float] = None,
        prefetch: bool = False,
    ):
        self.data_directory = Path(data_directory)
        self.states = np.load(self.data_directory / "multimodal_push_observations.npy")
        self.actions = np.load(self.data_directory / "multimodal_push_actions.npy")
        self.masks = np.load(self.data_directory / "multimodal_push_masks.npy")

        # Optionally keep only the first `subset_fraction` of trajectories.
        self.subset_fraction = subset_fraction
        if self.subset_fraction:
            assert 0 < self.subset_fraction <= 1
            n = int(len(self.states) * self.subset_fraction)
        else:
            n = len(self.states)
        self.states = self.states[:n]
        self.actions = self.actions[:n]
        self.masks = self.masks[:n]

        # Convert to tensors; actions are rescaled by a factor of 1 / 0.03.
        self.states = torch.from_numpy(self.states).float()
        self.actions = torch.from_numpy(self.actions).float() / 0.03
        self.masks = torch.from_numpy(self.masks).bool()
        # Optionally load every trajectory's multi-view observation tensor up front.
        self.prefetch = prefetch
        if self.prefetch:
            self.obses = []
            for i in range(n):
                vid_path = self.data_directory / "obs_multiview" / f"{i:03d}.pth"
                self.obses.append(torch.load(vid_path))
        self.onehot_goals = onehot_goals
        if self.onehot_goals:
            self.goals = torch.load(self.data_directory / "onehot_goals.pth").float()
            self.goals = self.goals[:n]

    def get_seq_length(self, idx):
        # Number of valid (unpadded) steps in trajectory `idx`.
        return int(self.masks[idx].sum().item())

    def get_all_actions(self):
        # Concatenate the valid (unpadded) actions of every trajectory.
        result = []
        for i in range(len(self.masks)):
            T = self.get_seq_length(i)
            result.append(self.actions[i, :T, :])
        return torch.cat(result, dim=0)

    def get_frames(self, idx, frames):
        # Fetch the requested frames from memory (if prefetched) or from disk.
        if self.prefetch:
            obs = self.obses[idx][frames]
        else:
            obs = torch.load(self.data_directory / "obs_multiview" / f"{idx:03d}.pth")[
                frames
            ]
        # Frames are stored as (T, V, H, W, C); convert to channel-first and
        # scale pixel values to [0, 1].
        obs = einops.rearrange(obs, "T V H W C -> T V C H W") / 255.0
        act = self.actions[idx, frames]
        mask = self.masks[idx, frames]
        if self.onehot_goals:
            goal = self.goals[idx, frames]
            return obs, act, mask, goal
        else:
            return obs, act, mask

    def __getitem__(self, idx):
        # Return the full (unpadded) trajectory.
        T = self.get_seq_length(idx)
        return self.get_frames(idx, range(T))

    def __len__(self):
        return len(self.states)
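

# Minimal usage sketch (not part of the original module). The data directory
# below is hypothetical; any path containing the files listed in the class
# docstring would work.
if __name__ == "__main__":
    dataset = PushMultiviewTrajectoryDataset(
        data_directory="data/multimodal_push",  # hypothetical path
        subset_fraction=0.1,
        prefetch=False,
    )
    obs, act, mask = dataset[0]
    print(len(dataset), obs.shape, act.shape, mask.shape)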