"""Minimal MLP-based VAE for 32x32 RGB images, wrapped as a Hugging Face
``PreTrainedModel`` so checkpoints can be loaded from / pushed to the Hub.

Loading this module downloads pretrained weights at import time (see the
note at the bottom of the file).
"""

import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa: F401  (kept: may be used by other parts of the project)
from transformers import PreTrainedModel, PretrainedConfig


class BaseVAE(nn.Module):
    """Fully-connected variational autoencoder over flattened 3x32x32 images.

    Encoder maps an image to the parameters (mu, logvar) of a diagonal
    Gaussian posterior q(z|x); the decoder maps a latent sample back to an
    image with values in [0, 1] (final Sigmoid).

    Args:
        latent_dim: Dimensionality of the latent code z. Defaults to 16.
    """

    def __init__(self, latent_dim=16):
        super().__init__()
        self.latent_dim = latent_dim
        input_dim = 3 * 32 * 32  # flattened RGB 32x32 image

        # Encoder trunk: flattened image -> 512-d feature vector.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
        )
        # Posterior heads producing mu and log-variance of q(z|x).
        self.fc_mu = nn.Linear(512, latent_dim)
        self.fc_logvar = nn.Linear(512, latent_dim)

        # Decoder: latent -> 512 -> 1024 -> image, squashed to [0, 1].
        self.decoder_input = nn.Linear(latent_dim, 512)
        self.decoder = nn.Sequential(
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, input_dim),
            nn.Sigmoid(),
        )

    def encode(self, x):
        """Map a batch of images to posterior parameters.

        Args:
            x: Tensor reshapeable to (batch, 3*32*32) — e.g. (batch, 3, 32, 32).

        Returns:
            Tuple (mu, logvar), each of shape (batch, latent_dim).
        """
        x = x.view(x.size(0), -1)  # flatten per-sample
        x = self.encoder(x)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample z ~ N(mu, sigma^2) with the reparameterization trick.

        std = exp(0.5 * logvar) keeps the sampling step differentiable
        w.r.t. mu and logvar.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes of shape (batch, latent_dim) to images.

        Returns:
            Tensor of shape (batch, 3, 32, 32) with values in [0, 1].
        """
        x = self.decoder_input(z)
        x = self.decoder(x)
        x = x.view(-1, 3, 32, 32)
        return x

    def forward(self, x):
        """Full VAE pass.

        Returns:
            Tuple (reconstruction, mu, logvar).
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z)
        return recon, mu, logvar


class VAEConfig(PretrainedConfig):
    """Hugging Face config object carrying the VAE hyperparameters."""

    model_type = "vae"

    def __init__(self, latent_dim=16, **kwargs):
        super().__init__(**kwargs)
        self.latent_dim = latent_dim


class VAEModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper around :class:`BaseVAE`.

    Enables ``from_pretrained`` / ``save_pretrained`` checkpointing while
    delegating all computation to the inner ``BaseVAE``.
    """

    config_class = VAEConfig

    def __init__(self, config):
        super().__init__(config)
        self.vae = BaseVAE(latent_dim=config.latent_dim)
        self.post_init()  # HF hook: weight init + final setup

    def forward(self, x):
        """Return (reconstruction, mu, logvar) — see BaseVAE.forward."""
        return self.vae(x)

    def encode(self, x):
        """Return (mu, logvar) — see BaseVAE.encode."""
        return self.vae.encode(x)

    def decode(self, z):
        """Return reconstructed images — see BaseVAE.decode."""
        return self.vae.decode(z)


# NOTE(review): the lines below run at import time and download weights from
# the Hub; consider an `if __name__ == "__main__":` guard if this module is
# ever imported rather than run as a script.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VAEModel.from_pretrained("BioMike/emoji-vae-init").to(device)
model.eval()