import logging
from pathlib import Path
from typing import Dict, Optional, Sequence, Tuple

import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from accelerate import Accelerator

from .gpt import GPT, GPTConfig
from .utils import MLP, batch_idx
from .vqvae import VqVae

accelerator = Accelerator()


def repeat_start_to_length(x: torch.Tensor, length: int, dim: int = 0):
    """Pad tensor x to length along dim, repeating the first value at the start."""
    pad_size = length - x.shape[dim]
    if pad_size <= 0:
        return x
    first_frame = x.index_select(dim, torch.tensor(0, device=x.device))
    repeat_shape = [1] * len(x.shape)
    repeat_shape[dim] = pad_size
    pad = first_frame.repeat(*repeat_shape)
    return torch.cat([pad, x], dim=dim)
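# Example: padding a (3, D) tensor to length 5 along dim 0 yields
# [x[0], x[0], x[0], x[1], x[2]] -- the first frame is repeated twice.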

class GroupedOptimizer(torch.optim.Optimizer):
    """Presents several optimizers behind the single torch.optim.Optimizer interface."""

    def __init__(self, optimizers: Sequence[torch.optim.Optimizer]):
        self.optimizers = optimizers
        param_groups = []
        for optimizer in self.optimizers:
            param_groups.extend(optimizer.param_groups)
        super().__init__(param_groups, optimizers[0].defaults)

    def step(self, closure=None):
        for optimizer in self.optimizers:
            optimizer.step(closure)

    def zero_grad(self, set_to_none: bool = False):
        for optimizer in self.optimizers:
            optimizer.zero_grad(set_to_none)

    def state_dict(self):
        return {i: optimizer.state_dict() for i, optimizer in enumerate(self.optimizers)}

    def load_state_dict(self, state_dict):
        for i, optimizer in enumerate(self.optimizers):
            optimizer.load_state_dict(state_dict[i])
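# Usage sketch (hypothetical models m1/m2): the wrapper steps, zeroes, and
# (de)serializes all sub-optimizers as one.
#   opt = GroupedOptimizer([torch.optim.AdamW(m1.parameters(), lr=3e-4),
#                           torch.optim.Adam(m2.parameters(), lr=1e-3)])
#   opt.zero_grad(); loss.backward(); opt.step()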


class BehaviorTransformer(nn.Module):
    """Behavior transformer: a GPT trunk over observation (and optional goal)
    windows predicts discrete VQ-VAE action codes plus continuous offsets."""

    def __init__(
        self,
        obs_dim: int,
        act_dim: int,
        goal_dim: int,
        views: int,
        vqvae_latent_dim: int,
        vqvae_n_embed: int,
        vqvae_groups: int,
        vqvae_fit_steps: int,
        vqvae_iters: int,
        n_layer: int,
        n_head: int,
        n_embd: int,
        dropout: float = 0.0,
        vqvae_encoder_loss_multiplier: float = 1.0,
        vqvae_batch_size: int = 1024,
        act_scale: float = 1.0,
        offset_loss_multiplier: float = 1.0e3,
        secondary_code_multiplier: float = 0.5,
        gamma: float = 2.0,
        obs_window_size: int = 10,
        act_window_size: int = 10,
    ):
        super().__init__()
        self.GOAL_SPEC = ["concat", "stack", "unconditional"]
        self._obs_dim = obs_dim * views
        self._act_dim = act_dim
        self._goal_dim = goal_dim * views
        self._obs_window_size = obs_window_size
        self._act_window_size = act_window_size
        # Only "unconditional" and "stack" are selected automatically;
        # "concat" is handled in forward() but never chosen here.
        if goal_dim <= 0:
            self._cbet_method = "unconditional"
        else:
            self._cbet_method = "stack"

        self._gpt_model = GPT(
            GPTConfig(
                block_size=obs_window_size + act_window_size,
                n_layer=n_layer,
                n_head=n_head,
                n_embd=n_embd,
                dropout=dropout,
                input_dim=self._obs_dim + self._goal_dim,
            )
        )

        # Action VQ-VAE: fitted once on the first `vqvae_fit_steps` collected
        # batches, then frozen (see _maybe_fit_vq).
        self.vqvae_fit_steps = vqvae_fit_steps
        self.vqvae_iters = vqvae_iters
        self.vqvae_batch_size = vqvae_batch_size
        self.vqvae_is_fit = False
        self._vqvae_model = VqVae(
            input_dim_h=act_window_size,
            input_dim_w=act_dim,
            n_latent_dims=vqvae_latent_dim,
            vqvae_n_embed=vqvae_n_embed,
            vqvae_groups=vqvae_groups,
            encoder_loss_multiplier=vqvae_encoder_loss_multiplier,
            act_scale=act_scale,
        )
        self._vqvae_optim = torch.optim.Adam(
            self._vqvae_model.parameters(),
            lr=1e-3,
            weight_decay=1e-4,
        )

        self._G = vqvae_groups  # number of codebook groups
        self._C = vqvae_n_embed  # codes per group
        self._D = vqvae_latent_dim

        # Heads on top of the GPT trunk: one classifies a code per group, the
        # other regresses a continuous offset per (group, code) pair.
        self._map_to_cbet_preds_bin = MLP(
            in_channels=self._gpt_model.config.output_dim,
            hidden_channels=[1024, 1024, self._G * self._C],
        )
        self._map_to_cbet_preds_offset = MLP(
            in_channels=self._gpt_model.config.output_dim,
            hidden_channels=[
                1024,
                1024,
                self._G * self._C * (act_dim * self._act_window_size),
            ],
        )

        self._collected_actions = []
        self._offset_loss_multiplier = offset_loss_multiplier
        self._secondary_code_multiplier = secondary_code_multiplier
        self._criterion = FocalLoss(gamma=gamma)

    def _unpack_actions(self, action_seq: torch.Tensor):
        """Unpack actions from (N, obs_w + act_w - 1, A) into overlapping
        windows of shape (N, obs_w, act_w, A), one window per timestep."""
        n, total_w, act_dim = action_seq.shape
        act_w = self._act_window_size
        obs_w = total_w + 1 - act_w
        result = torch.empty(
            (n, obs_w, act_w, act_dim), dtype=action_seq.dtype, device=action_seq.device
        )
        for i in range(obs_w):
            result[:, i, :, :] = action_seq[:, i : i + act_w, :]
        return result
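    # Shape example (hypothetical sizes): with act_window_size W = 10 and a
    # 19-step action_seq of shape (N, 19, A), obs_w = 19 + 1 - 10 = 10, so the
    # result is (N, 10, 10, A); window i covers action steps i .. i+9.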

    def _maybe_fit_vq(self):
        """Fit the action VQ-VAE once enough action batches have been
        collected, then freeze it for the rest of training."""
        if self.vqvae_is_fit or len(self._collected_actions) < self.vqvae_fit_steps:
            return
        all_actions = torch.cat(self._collected_actions)
        all_actions = einops.rearrange(all_actions, "N T W A -> (N T) (W A)")

        # Deduplicate the flattened action windows before fitting.
        all_actions = torch.unique(all_actions, dim=0)
        all_actions = einops.rearrange(
            all_actions,
            "... (W A) -> ... W A",
            W=self._act_window_size,
        )
        pbar = tqdm.trange(
            self.vqvae_iters,
            desc="VQ training",
            disable=not accelerator.is_local_main_process,
        )
        for epoch in pbar:
            shuffle_idx = torch.randperm(len(all_actions))
            for i in range(0, len(all_actions), self.vqvae_batch_size):
                batch = all_actions[shuffle_idx[i : i + self.vqvae_batch_size]]
                loss, vq_code, loss_dict = self._vqvae_model(batch)
                self._vqvae_optim.zero_grad()
                loss.backward()
                self._vqvae_optim.step()
        accelerator.wait_for_everyone()

        # The prepare()/unwrap_model() round-trip runs the model through DDP
        # setup so every process ends up with the same fitted weights.
        self._vqvae_model = accelerator.prepare(self._vqvae_model)
        self._vqvae_model = accelerator.unwrap_model(self._vqvae_model)
        self._vqvae_model.eval()
        if accelerator.is_local_main_process:
            print("n_different_codes", len(torch.unique(vq_code)))
            print("n_different_combinations", len(torch.unique(vq_code, dim=0)))
            print("losses", loss_dict)
        self.vqvae_is_fit = True

    def train(self, mode=True):
        super().train(mode)
        # The VQ-VAE stays frozen (eval mode) once fitted, even in train mode.
        if self.vqvae_is_fit:
            self._vqvae_model.eval()

    def forward(
        self,
        obs_seq: torch.Tensor,
        goal_seq: Optional[torch.Tensor],
        action_seq: Optional[torch.Tensor],
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Dict[str, float]]:
        # While the VQ-VAE is still unfitted, collect gathered action batches.
        if (
            (action_seq is not None)
            and (len(self._collected_actions) < self.vqvae_fit_steps)
            and self.training
        ):
            action_seq_all = accelerator.gather(action_seq)
            self._collected_actions.append(self._unpack_actions(action_seq_all))
            self._maybe_fit_vq()

        if obs_seq.shape[1] < self._obs_window_size:
            obs_seq = repeat_start_to_length(obs_seq, self._obs_window_size, dim=1)
        if self._cbet_method == "unconditional":
            gpt_input = obs_seq
        elif self._cbet_method == "stack":
            gpt_input = torch.cat([goal_seq, obs_seq], dim=-1)
        else:
            raise NotImplementedError

        gpt_output = self._gpt_model(gpt_input)

        if self._cbet_method == "concat":
            # Drop the goal tokens that were prepended to the sequence.
            gpt_output = gpt_output[:, goal_seq.size(1) :, :]

        cbet_logits, cbet_offsets = self._forward_heads(gpt_output)
        predicted_action, decoded_action, sampled_centers, sampled_offsets = (
            self._sample_action(cbet_logits, cbet_offsets)
        )

        if action_seq is not None:
            loss, loss_dict = self._calc_loss(
                action_seq,
                predicted_action,
                decoded_action,
                sampled_centers,
                cbet_logits,
            )
        else:
            loss, loss_dict = None, {}
        return predicted_action, loss, loss_dict

    def _calc_loss(
        self, action_seq, predicted_action, decoded_action, sampled_centers, cbet_logits
    ):
        action_seq = self._unpack_actions(action_seq)
        _, action_bins = self._vqvae_model.get_code(action_seq)

        action_bins_flat = einops.rearrange(action_bins, "N T ... -> (N T) ...")
        cbet_logits_flat = einops.rearrange(cbet_logits, "N T ... -> (N T) ...")

        offset_loss = F.l1_loss(action_seq, predicted_action)

        # Diagnostics on the final timestep's first predicted action step.
        action_diff = F.mse_loss(action_seq[:, -1, 0], predicted_action[:, -1, 0])
        action_diff_tot = F.mse_loss(action_seq[:, -1], predicted_action[:, -1])
        action_diff_mean_res1 = (action_seq - decoded_action)[:, -1, 0].abs().mean()
        action_diff_mean_res2 = (action_seq - predicted_action)[:, -1, 0].abs().mean()
        action_diff_max = (action_seq - predicted_action)[:, -1, 0].abs().max()

        # Focal loss per codebook group; the primary code is weighted more
        # heavily than the secondary one.
        cbet_loss1 = self._criterion(cbet_logits_flat[:, 0], action_bins_flat[:, 0])
        cbet_loss2 = self._criterion(cbet_logits_flat[:, 1], action_bins_flat[:, 1])
        cbet_loss = cbet_loss1 * 5 + cbet_loss2 * self._secondary_code_multiplier

        eq_mask = action_bins == sampled_centers
        equal_total_code_rate = (eq_mask.sum(-1) == self._G).float().mean()
        equal_single_code_rate = eq_mask[..., 0].float().mean()
        equal_single_code_rate2 = eq_mask[..., 1].float().mean()

        loss = cbet_loss + self._offset_loss_multiplier * offset_loss
        loss_dict = {
            "classification_loss": cbet_loss.detach().cpu().item(),
            "offset_loss": offset_loss.detach().cpu().item(),
            "total_loss": loss.detach().cpu().item(),
            "equal_total_code_rate": equal_total_code_rate.item(),
            "equal_single_code_rate": equal_single_code_rate.item(),
            "equal_single_code_rate2": equal_single_code_rate2.item(),
            "action_diff": action_diff.detach().cpu().item(),
            "action_diff_tot": action_diff_tot.detach().cpu().item(),
            "action_diff_mean_res1": action_diff_mean_res1.detach().cpu().item(),
            "action_diff_mean_res2": action_diff_mean_res2.detach().cpu().item(),
            "action_diff_max": action_diff_max.detach().cpu().item(),
        }
        if not self.vqvae_is_fit:
            # No gradient signal until the VQ codebook has been fitted.
            loss = loss * 0.0
        return loss, loss_dict

    def _forward_heads(self, gpt_output):
        cbet_logits = self._map_to_cbet_preds_bin(gpt_output)
        cbet_logits = einops.rearrange(cbet_logits, "N T (G C) -> N T G C", G=self._G)
        cbet_offsets = self._map_to_cbet_preds_offset(gpt_output)
        cbet_offsets = einops.rearrange(
            cbet_offsets,
            "N T (G C W A) -> N T G C W A",
            G=self._G,
            C=self._C,
            W=self._act_window_size,
            A=self._act_dim,
        )
        return cbet_logits, cbet_offsets
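    # Shape example (hypothetical sizes): with G = 2 groups, C = 16 codes,
    # W = 10 action steps, and A = 7 action dims, the bin head emits
    # 2 * 16 = 32 logits per timestep and the offset head emits
    # 2 * 16 * 10 * 7 = 2240 values, reshaped to (N, T, 2, 16, 10, 7).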

    def _sample_action(self, cbet_logits, cbet_offsets):
        # Sample one code per (batch, timestep, group) from the head's softmax.
        cbet_probs = torch.softmax(cbet_logits, dim=-1)
        sampled_centers = einops.rearrange(
            torch.multinomial(cbet_probs.view(-1, self._C), num_samples=1),
            "(N T G) 1 -> N T G",
            N=cbet_probs.shape[0],
            T=cbet_probs.shape[1],
            G=self._G,
        )
        # Decode the sampled codes through the (frozen) VQ-VAE.
        centers = self._vqvae_model.draw_code_forward(sampled_centers).clone().detach()
        decoded_action = (
            self._vqvae_model.get_action_from_latent(centers).clone().detach()
        )

        # Select the offset predicted for each sampled code, then sum over groups.
        sampled_offsets = batch_idx(cbet_offsets, sampled_centers)
        sampled_offsets = sampled_offsets.sum(dim=2)
        predicted_action = decoded_action + sampled_offsets
        return predicted_action, decoded_action, sampled_centers, sampled_offsets
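    # The continuous action is reconstructed as decoded codebook center plus
    # offset: the discrete heads pick a coarse action mode and the offset head
    # refines it -- the same decomposition the offset_loss trains against.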

    def configure_optimizers(self, weight_decay, learning_rate, betas):
        # The GPT trunk builds its own optimizer; the bin head joins its param
        # groups, while the offset head gets a separate AdamW.
        optimizer1 = self._gpt_model.configure_optimizers(
            weight_decay=weight_decay,
            learning_rate=learning_rate,
            betas=betas,
        )
        optimizer1.add_param_group({"params": self._map_to_cbet_preds_bin.parameters()})
        optimizer2 = torch.optim.AdamW(
            self._map_to_cbet_preds_offset.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
            betas=betas,
        )
        return GroupedOptimizer([optimizer1, optimizer2])
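    # Usage sketch (hypothetical hyperparameters):
    #   optim = model.configure_optimizers(
    #       weight_decay=0.1, learning_rate=3e-4, betas=(0.9, 0.95)
    #   )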

    def load_model(self, path: Path):
        if (path / "cbet_model.pt").exists():
            self.load_state_dict(torch.load(path / "cbet_model.pt"))
        else:
            logging.warning("No model found at %s", path)


class FocalLoss(nn.Module):
    """Focal loss (Lin et al., 2017): cross-entropy scaled by (1 - p_t) ** gamma,
    which down-weights already well-classified examples."""

    def __init__(self, gamma: float = 0, reduction: str = "mean"):
        super().__init__()
        self.gamma = gamma
        if reduction not in ("mean", "sum", "none"):
            raise NotImplementedError
        self.reduction = reduction

    def forward(self, input, target):
        logpt = F.log_softmax(input, dim=-1)
        logpt = logpt.gather(1, target.view(-1, 1)).view(-1)
        pt = logpt.exp()

        loss = -1 * (1 - pt) ** self.gamma * logpt
        if self.reduction == "mean":
            return loss.mean()
        elif self.reduction == "sum":
            return loss.sum()
        else:
            return loss
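# With gamma = 0 the modulating factor (1 - p_t) ** 0 is 1, so FocalLoss
# reduces to ordinary cross-entropy; the gamma = 2.0 default used by
# BehaviorTransformer focuses training on hard, misclassified codes.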