import torch
import einops
import torch.nn as nn
import torch.nn.functional as F

from .vector_quantize_pytorch import ResidualVQ


def weights_init_encoder(m):
    """Orthogonal init: full weight matrix for Linear layers; zero weights with
    an orthogonal centre tap (ReLU gain) for square conv kernels."""
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
        m.bias.data.fill_(0.0)
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        assert m.weight.size(2) == m.weight.size(3)
        m.weight.data.fill_(0.0)
        m.bias.data.fill_(0.0)
        mid = m.weight.size(2) // 2
        gain = nn.init.calculate_gain("relu")
        nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain)


class EncoderMLP(nn.Module):
    def __init__(
        self,
        input_dim,
        output_dim=16,
        hidden_dim=128,
        layer_num=1,
        last_activation=None,
    ):
        super().__init__()
        layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU()]
        for _ in range(layer_num):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(nn.ReLU())

        self.encoder = nn.Sequential(*layers)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.last_layer = last_activation  # optional activation on the output
        self.apply(weights_init_encoder)

    def forward(self, x):
        h = self.encoder(x)
        state = self.fc(h)
        if self.last_layer is not None:
            state = self.last_layer(state)
        return state
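

# Usage sketch for EncoderMLP (illustrative shapes, not from the original code):
#
#     mlp = EncoderMLP(input_dim=90, output_dim=512)
#     z = mlp(torch.randn(8, 90))  # z.shape == torch.Size([8, 512])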


class VqVae(nn.Module):
    def __init__(
        self,
        input_dim_h=10,  # length of the action chunk (timesteps T)
        input_dim_w=9,  # action dimension per timestep (A)
        n_latent_dims=512,
        vqvae_n_embed=32,
        vqvae_groups=4,
        encoder_loss_multiplier=1.0,
        act_scale=1.0,
    ):
        super().__init__()
        self.n_latent_dims = n_latent_dims
        self.input_dim_h = input_dim_h
        self.input_dim_w = input_dim_w
        self.rep_dim = self.n_latent_dims
        self.vqvae_n_embed = vqvae_n_embed
        self.vqvae_groups = vqvae_groups
        self.encoder_loss_multiplier = encoder_loss_multiplier
        self.act_scale = act_scale

        discrete_cfg = {"groups": self.vqvae_groups, "n_embed": self.vqvae_n_embed}

        self.vq_layer = ResidualVQ(
            dim=self.n_latent_dims,
            num_quantizers=discrete_cfg["groups"],
            codebook_size=self.vqvae_n_embed,
        )
        self.embedding_dim = self.n_latent_dims

        self.encoder = EncoderMLP(
            input_dim=input_dim_w * self.input_dim_h, output_dim=n_latent_dims
        )
        self.decoder = EncoderMLP(
            input_dim=n_latent_dims, output_dim=input_dim_w * self.input_dim_h
        )

    def draw_logits_forward(self, encoding_logits):
        z_embed = self.vq_layer.draw_logits_forward(encoding_logits)
        return z_embed

    def draw_code_forward(self, encoding_indices):
        with torch.no_grad():
            z_embed = self.vq_layer.get_codes_from_indices(encoding_indices)
            # Residual VQ: the final embedding is the sum over quantizer levels.
            z_embed = z_embed.sum(dim=0)
        return z_embed

    def get_action_from_latent(self, latent) -> torch.Tensor:
        output = self.decoder(latent) * self.act_scale
        return einops.rearrange(output, "... (T A) -> ... T A", A=self.input_dim_w)

    def preprocess(self, state):
        state = torch.as_tensor(state, dtype=torch.float32)
        state = state / self.act_scale
        if self.input_dim_h == 1:
            state = state.squeeze(-2)
        else:
            state = einops.rearrange(state, "... T A -> ... (T A)")
        return state

    def get_code(self, state):
        state = self.preprocess(state)
        with torch.no_grad():
            state_rep = self.encoder(state)
            state_vq, vq_code, _ = self.vq_layer(state_rep)
        return state_vq, vq_code

    def forward(self, state):
        state = self.preprocess(state)
        state_rep = self.encoder(state)
        state_rep_shape = state_rep.shape[:-1]
        state_rep_flat = state_rep.view(state_rep.size(0), -1, state_rep.size(1))
        state_rep_flat, vq_code, vq_loss_state = self.vq_layer(state_rep_flat)
        state_vq = state_rep_flat.view(*state_rep_shape, -1)
        vq_code = vq_code.view(*state_rep_shape, -1)
        dec_out = self.decoder(state_vq)

        # Commitment/codebook loss, summed over the residual quantizers.
        vq_loss_state = vq_loss_state.sum()
        # L1 reconstruction drives training; the MSE is computed for logging only.
        encoder_loss = (state - dec_out).abs().mean()
        vqvae_recon_loss = F.mse_loss(state, dec_out)

        loss = encoder_loss * self.encoder_loss_multiplier + (vq_loss_state * 5)
        loss_dict = {
            "loss": loss.detach().clone(),
            "vq_loss_state": vq_loss_state.detach().clone(),
            "vqvae_recon_loss": vqvae_recon_loss.detach().clone(),
            "encoder_loss": encoder_loss.detach().clone(),
        }
        return loss, vq_code, loss_dict

    def state_dict(
        self,
        destination: dict = None,
        prefix: str = "",
        keep_vars: bool = False,
    ):
        # Store each sub-module under its own key so load_state_dict below can
        # restore them symmetrically.
        if destination is None:
            destination = {}
        destination[prefix + "encoder"] = self.encoder.state_dict(keep_vars=keep_vars)
        destination[prefix + "decoder"] = self.decoder.state_dict(keep_vars=keep_vars)
        destination[prefix + "vq_embedding"] = self.vq_layer.state_dict(
            keep_vars=keep_vars
        )
        return destination

    def load_state_dict(self, state_dict):
        self.encoder.load_state_dict(state_dict["encoder"])
        self.decoder.load_state_dict(state_dict["decoder"])
        self.vq_layer.load_state_dict(state_dict["vq_embedding"])
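

# ---------------------------------------------------------------------------
# Minimal smoke test (a sketch, not part of the original module). It assumes
# the file is run as part of its package (e.g. `python -m <package>.<module>`)
# so that the relative ResidualVQ import resolves; shapes follow the
# constructor defaults (chunks of 10 timesteps x 9 action dims) and are
# illustrative.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    vqvae = VqVae()
    actions = torch.randn(8, 10, 9)  # (batch, input_dim_h, input_dim_w)

    # Training-style pass: total loss, discrete codes, detached logging dict.
    loss, vq_code, loss_dict = vqvae(actions)
    print({k: float(v) for k, v in loss_dict.items()})

    # Encode to quantized latents, then decode back to an action sequence.
    state_vq, code = vqvae.get_code(actions)
    recon = vqvae.get_action_from_latent(state_vq)
    print(recon.shape)  # expected: torch.Size([8, 10, 9])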