File size: 1,156 Bytes
393d3de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch
import torch.nn as nn
from ..transformer_encoder import TransformerEncoderConfig, TransformerEncoder


class InverseDynamicsProjector(nn.Module):
    """Project every view of an observation encoding with one shared encoder.

    A single ``TransformerEncoder`` is built from the constructor arguments and
    is applied independently to each view slice (axis 2) of the input. The
    per-view results are then re-stacked along that same axis. Because only
    one ``self.model`` exists, all views share the same weights.
    """

    def __init__(
        self,
        window_size: int,
        input_dim: int,
        n_layer: int,
        n_head: int,
        n_embd: int,
        output_dim: int,
        dropout: float,
    ):
        """Build the shared encoder.

        Args:
            window_size: Forwarded to the encoder config as ``block_size``.
                Presumably this is the maximum sequence length T; confirm
                against ``TransformerEncoderConfig``.
            input_dim: Feature size of each per-view input slice.
            n_layer: Number of encoder layers.
            n_head: Number of attention heads.
            n_embd: Embedding width used inside the encoder.
            output_dim: Feature size produced for each view.
            dropout: Dropout probability passed to the encoder config.
        """
        super().__init__()
        encoder_config = TransformerEncoderConfig(
            block_size=window_size,
            input_dim=input_dim,
            n_layer=n_layer,
            n_head=n_head,
            n_embd=n_embd,
            output_dim=output_dim,
            dropout=dropout,
        )
        # Keep the config on the module before constructing the encoder from it.
        self.config = encoder_config
        self.model = TransformerEncoder(encoder_config)

    def forward(self, obs_enc: torch.Tensor) -> torch.Tensor:
        """Run the shared encoder over each view and re-stack the outputs.

        Args:
            obs_enc: Tensor whose axis 2 indexes views. It is presumably
                shaped (N, T, V, D); TODO confirm with callers.

        Returns:
            The encoder outputs for every view, stacked on axis 2. Per the
            original shape notes this is (N, T, V, Z).

        Raises:
            ValueError: If ``obs_enc`` has fewer than three dimensions.
        """
        # The 3-way unpack doubles as a rank check (ValueError below 3 dims).
        _, _, num_views = obs_enc.shape[:3]
        # Views are encoded one at a time, in index order, with the same model.
        per_view = [self.model(obs_enc[:, :, view]) for view in range(num_views)]
        return torch.stack(per_view, dim=2)

    def configure_optimizers(self, weight_decay, lr, betas):
        """Delegate optimizer setup to the wrapped encoder.

        The arguments are passed positionally, unchanged, to
        ``self.model.configure_optimizers``, and its result is returned as-is.
        """
        encoder = self.model
        return encoder.configure_optimizers(weight_decay, lr, betas)