# File: dynamo_ssl/models/projector/inverse_dynamics.py
# Commit: 393d3de "initial commit" (jeffacce)
import torch
import torch.nn as nn
from ..transformer_encoder import TransformerEncoderConfig, TransformerEncoder
class InverseDynamicsProjector(nn.Module):
    """Project per-view observation encodings with one shared TransformerEncoder.

    A single ``TransformerEncoder`` is built from the constructor arguments and
    applied separately to every view (each index along dim 2 of the input), so
    all views share the same encoder weights. The per-view outputs are stacked
    back along dim 2.

    Args:
        window_size: Passed to the encoder config as ``block_size``
            (presumably the maximum sequence length T -- TODO confirm).
        input_dim: Encoder config ``input_dim``.
        n_layer: Encoder config ``n_layer``.
        n_head: Encoder config ``n_head``.
        n_embd: Encoder config ``n_embd``.
        output_dim: Encoder config ``output_dim``.
        dropout: Encoder config ``dropout``.
    """

    def __init__(
        self,
        window_size: int,
        input_dim: int,
        n_layer: int,
        n_head: int,
        n_embd: int,
        output_dim: int,
        dropout: float,
    ):
        super().__init__()
        self.config = TransformerEncoderConfig(
            block_size=window_size,
            input_dim=input_dim,
            n_layer=n_layer,
            n_head=n_head,
            n_embd=n_embd,
            output_dim=output_dim,
            dropout=dropout,
        )
        self.model = TransformerEncoder(self.config)

    def forward(self, obs_enc: torch.Tensor) -> torch.Tensor:
        """Apply the shared encoder to each view and restack the results.

        Args:
            obs_enc: Tensor with at least three dimensions, laid out as
                ``(N, T, V, ...)``; dim 2 indexes the views. Each slice
                ``obs_enc[:, :, i]`` is passed to ``self.model`` unchanged.

        Returns:
            The per-view encoder outputs stacked along dim 2, i.e.
            ``(N, T, V, Z)`` when the encoder returns ``(N, T, Z)`` per view.

        Raises:
            ValueError: If ``obs_enc`` has fewer than three dimensions.
        """
        if obs_enc.dim() < 3:
            raise ValueError(
                "obs_enc must have at least 3 dims (N, T, V, ...), "
                f"got shape {tuple(obs_enc.shape)}"
            )
        # unbind(dim=2) yields the same slices as obs_enc[:, :, i] for each i.
        # Views are encoded one call at a time rather than folded into the
        # batch dim: folding is only equivalent if the encoder treats batch
        # rows independently, which is not verified here.
        obs_proj = [self.model(view) for view in obs_enc.unbind(dim=2)]  # each (N, T, Z)
        return torch.stack(obs_proj, dim=2)  # (N, T, V, Z)

    def configure_optimizers(self, weight_decay, lr, betas):
        """Delegate optimizer construction to the wrapped encoder.

        All arguments are forwarded unchanged to
        ``self.model.configure_optimizers`` and its result is returned as-is.
        """
        return self.model.configure_optimizers(weight_decay, lr, betas)