from einops import rearrange, reduce
import torch
import torch.nn as nn
from torch.autograd import Function


class DifferentiableEntropyFunction(Function):
    """Entropy of the empirical code histogram, with a hand-written backward.

    The forward pass counts how often each of the ``2 ** K`` code indices
    occurs, smooths the counts with ``eps`` and returns the entropy of the
    resulting distribution. Counting is not differentiable, so ``backward``
    supplies an explicit gradient with respect to ``zq``.
    """

    @staticmethod
    def forward(ctx, zq, basis, K, eps):
        """Return the entropy of the smoothed code-usage histogram.

        Args:
            zq: Tensor of codes whose last axis holds the bits; entries are
                expected to be -1 or +1.
            basis: Tensor of per-bit weights, broadcastable against the last
                axis of ``zq`` (powers of two at the visible call sites).
            K: Number of bits; the histogram has ``2 ** K`` bins.
            eps: Additive smoothing applied to every bin before normalising.

        Returns:
            Scalar tensor ``-(p * log p).sum()``.
        """
        # Build the index with integer arithmetic. Summing bit weights in
        # floating point and truncating afterwards is inexact for wide codes
        # and can flip an index when an entry is a ULP away from +/-1.
        bits = (zq > 0).to(torch.int64)
        zi = (bits * basis.to(torch.int64)).sum(-1)
        flat_idx = zi.flatten()
        cnt = torch.scatter_reduce(
            torch.zeros(2 ** K, device=zq.device, dtype=zq.dtype),
            0,
            flat_idx,
            torch.ones_like(flat_idx).to(zq.dtype),
            'sum',
        )
        prob = (cnt + eps) / (cnt + eps).sum()
        H = -(prob * torch.log(prob)).sum()
        ctx.save_for_backward(zq, zi, prob)
        ctx.K = K
        return H

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient for ``zq`` and ``None`` for the other inputs.

        Every position receives ``-(log p + 1) / (N * K)`` of the bin it fell
        into (``N`` is the number of codes), multiplied by its own value in
        ``zq``.
        """
        zq, zi, prob = ctx.saved_tensors
        grad_array = -grad_output * (torch.log(prob) + 1) / zi.numel() / ctx.K
        reord_grad = grad_array[zi.flatten()].reshape(zi.shape)
        grad_input = reord_grad.unsqueeze(-1) * zq
        # One entry per forward input: zq, basis, K, eps.
        return grad_input, None, None, None


def codebook_entropy(zq, basis, K, eps=1e-4):
    """Apply :class:`DifferentiableEntropyFunction` with a default ``eps``."""
    return DifferentiableEntropyFunction.apply(zq, basis, K, eps)


class BinarySphericalQuantizer(nn.Module):
    """Quantise every channel to -1/+1 and compute commit + entropy losses.

    Args:
        embed_dim: Size of the channel axis (number of bits per code).
        beta: Loss weight for the commit loss.
        gamma0: Weight of the per-sample entropy term.
        gamma: Weight of the codebook entropy term (subtracted).
        zeta: Weight of the whole entropy penalty.
        input_format: ``'bchw'`` makes ``forward`` move channels last on entry
            and back on exit; any other value leaves the layout untouched.
        soft_entropy: Use the soft (softmax based) entropy estimate instead of
            the hard histogram based one.
        group_size: Number of bits per sub-group; must divide ``embed_dim``.
        persample_entropy_compute: ``'group'`` or ``'analytical'``.
        cb_entropy_compute: ``'group'`` or ``'nce'``. Validated and stored,
            but not read by any method in this file.
        l2_norm: Scale the codes by ``1 / sqrt(embed_dim)``.
        inv_temperature: Multiplies the logits of the soft assignment and
            divides the entropy penalty returned by ``forward``.
    """

    def __init__(self, embed_dim, beta, gamma0, gamma, zeta,
                 input_format='bchw',
                 soft_entropy=True, group_size=9,
                 persample_entropy_compute='group',
                 cb_entropy_compute='group',
                 l2_norm=False,
                 inv_temperature=1):
        super().__init__()
        self.embed_dim = embed_dim
        self.beta = beta  # loss weight for commit loss
        self.gamma0 = gamma0  # loss weight for per-sample entropy
        self.gamma = gamma  # loss weight for codebook entropy
        self.zeta = zeta  # loss weight for entire entropy penalty
        self.input_format = input_format
        assert self.embed_dim % group_size == 0, "embed_dim must be divisible by group_size"
        self.num_groups = self.embed_dim // group_size
        self.group_size = group_size
        assert persample_entropy_compute in ['group', 'analytical'], "persample_entropy_compute must be either 'group' or 'analytical'"
        assert cb_entropy_compute in ['group', 'nce'], "cb_entropy_compute must be either 'group' or 'nce'"
        self.persample_entropy_compute = persample_entropy_compute
        self.cb_entropy_compute = cb_entropy_compute
        self.l2_norm = l2_norm
        self.inv_temperature = inv_temperature

        # Bit weights, most significant bit first.
        self.register_buffer('basis', 2 ** torch.arange(embed_dim - 1, -1, -1))
        self.register_buffer('group_basis', 2 ** torch.arange(group_size - 1, -1, -1))

        self.num_dimensions = 2 ** embed_dim
        self.bits_per_index = embed_dim

        # we only need to keep the codebook portion up to the group size
        # because we approximate the H loss with this subcode
        group_codes = torch.arange(2 ** self.group_size)
        group_codebook = self.indexes_to_codes(group_codes).float()[:, -group_size:]
        self.register_buffer('group_codebook', group_codebook, persistent=False)

        self.soft_entropy = soft_entropy  # soft_entropy: Sec 3.2 of https://arxiv.org/pdf/1911.05894.pdf

    def quantize(self, z):
        """Return sign codes (+1 where ``z > 0``, else -1) with a straight-through gradient.

        The forward value is the sign code; the ``detach`` makes the gradient
        flow to ``z`` unchanged.
        """
        assert z.shape[-1] == self.embed_dim, f"Expected {self.embed_dim} dimensions, got {z.shape[-1]}"

        zhat = torch.where(
            z > 0,
            torch.tensor(1, dtype=z.dtype, device=z.device),
            torch.tensor(-1, dtype=z.dtype, device=z.device),
        )
        return z + (zhat - z).detach()

    def forward(self, z):
        """Quantise ``z`` and compute the training loss.

        Args:
            z: Input tensor. With ``input_format == 'bchw'`` it is rearranged
                to channels-last first; otherwise the last axis must already
                have size ``embed_dim``.

        Returns:
            A tuple ``(zq, loss, info)`` where ``zq`` is the (optionally
            scaled) quantised tensor in the input layout, ``loss`` is
            ``commit_loss + zeta * entropy_penalty / inv_temperature`` and
            ``info`` is a dict with keys ``"H"``, ``"used_codes"`` (``None``
            in training mode), ``"indices"``, ``"group_indices"`` and
            ``"avg_prob"`` (``None`` when ``soft_entropy`` is false).
        """
        if self.input_format == 'bchw':
            z = rearrange(z, 'b c h w -> b h w c')
        zq = self.quantize(z)

        indices = self.codes_to_indexes(zq.detach())
        group_indices = self.codes_to_group_indexes(zq.detach())
        if not self.training:
            used_codes = torch.unique(indices, return_counts=False)
        else:
            used_codes = None

        q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1.

        if self.soft_entropy:
            persample_entropy, cb_entropy, avg_prob = self.soft_entropy_loss(z)
            entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy
        else:
            zb_by_sample = ((zq + 1) / 2).reshape(z.shape[0], -1, z.shape[-1]).to(torch.float32)
            persample_entropy = self.get_hard_per_sample_entropy(zb_by_sample)
            cb_entropy = codebook_entropy(zq, self.basis, self.embed_dim)
            entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy
            # The hard path has no soft assignment; without this the return
            # statement below would raise UnboundLocalError.
            avg_prob = None

        zq = zq * q_scale

        # commit loss
        commit_loss = self.beta * torch.mean(((zq.detach() - z) ** 2).sum(dim=-1))

        if self.input_format == 'bchw':
            zq = rearrange(zq, 'b h w c -> b c h w')

        return (
            zq,
            commit_loss + self.zeta * entropy_penalty / self.inv_temperature,
            {"H": cb_entropy, "used_codes": used_codes, "indices": indices, "group_indices": group_indices,
             "avg_prob": avg_prob}
        )

    def soft_entropy_loss(self, z):
        """Return ``(per_sample_entropy, codebook_entropy, avg_prob)`` from soft assignments.

        ``z`` is split along its last axis into groups of ``group_size`` bits
        and each group is softly assigned to the ``2 ** group_size`` entries of
        ``group_codebook``.
        """
        # if we divide the code in subgroups of size group_size, the codebook will be of size 2 ** group_size
        # the sub-code is the last group_size bits of the full code
        group_code_book = self.group_codebook / (self.embed_dim ** 0.5 if self.l2_norm else 1)
        divided_z = rearrange(z, '... (g c) -> ... g c', c=self.group_size)

        # we calculate the distance between the divided_z and the codebook for each subgroup
        distance = - 2 * torch.einsum('... g c, d c ->... g d', divided_z, group_code_book)
        prob = (-distance * self.inv_temperature).softmax(dim=-1)
        if self.persample_entropy_compute == 'analytical':
            if self.l2_norm:
                p = torch.sigmoid(-4 * z / (self.embed_dim ** 0.5) * self.inv_temperature)
            else:
                p = torch.sigmoid(-4 * z * self.inv_temperature)
            # NOTE(review): this overwrites the group-wise `prob`, so avg_prob
            # and the codebook entropy below become per-bit quantities in the
            # 'analytical' mode. Kept as-is; confirm this is intended.
            prob = torch.stack([p, 1 - p], dim=-1)
            per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean()
        else:
            per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean()

        # macro average of the probability of each subgroup
        avg_prob = reduce(prob, '... g d ->g d', 'mean')
        group_entropy = self.get_entropy(avg_prob, dim=-1, normalize=False)

        # the approximation of the entropy is the sum of the entropy of each subgroup
        return per_sample_entropy, group_entropy.sum(), avg_prob

    def get_hard_per_sample_entropy(self, zb_by_sample):
        """Return the mean over samples of the summed per-bit binary entropies.

        Args:
            zb_by_sample: Tensor whose axis 1 is averaged to obtain, for each
                sample and bit, the fraction of ones.
        """
        probs_per_dim = zb_by_sample.sum(1) / zb_by_sample.shape[1]
        persample_entropy = (
            - probs_per_dim * torch.log(probs_per_dim + 1e-8)
            - (1 - probs_per_dim) * torch.log(1 - probs_per_dim + 1e-8)
        )
        persample_entropy = persample_entropy.sum(-1)
        return persample_entropy.mean()

    def codes_to_indexes(self, zhat):
        """Converts a `code` to an index in the codebook.

        Args:
            zhat: A tensor of shape (B, ..., C) containing the codes. must be in {-1, 1}

        Returns:
            An int64 tensor of shape (B, ...).
        """
        assert zhat.shape[-1] == self.embed_dim, f"Expected {self.embed_dim} dimensions, got {zhat.shape[-1]}"
        # Integer arithmetic keeps the index exact: a float32 sum of bit
        # weights cannot represent every index once embed_dim exceeds 24.
        bits = (zhat > 0).to(torch.int64)
        return (bits * self.basis).sum(dim=-1)

    def codes_to_group_indexes(self, zhat):
        """Converts a `code` to a list of indexes (in groups) in the codebook.

        Args:
            zhat: A tensor of shape (B, ..., C) containing the codes. must be in {-1, 1}

        Returns:
            An int64 tensor of shape (B, ..., C // group_size).
        """
        zhat_in_group = rearrange(zhat, 'b ... (g c) -> b ... g c', c=self.group_size)
        bits = (zhat_in_group > 0).to(torch.int64)
        return (bits * self.group_basis).sum(dim=-1)

    def indexes_to_codes(self, indices):
        """Inverse of `codes_to_indexes`: expand indices into -1/+1 bits on a new last axis."""
        indices = indices.unsqueeze(-1)
        codes_non_centered = torch.remainder(
            torch.floor_divide(indices, self.basis), 2
        )
        return codes_non_centered * 2 - 1

    def group_indexes_to_codes(self, group_indices):
        """Inverse of `codes_to_group_indexes`: expand group indices and merge the groups."""
        group_indices = group_indices.unsqueeze(-1)
        codes_non_centered = torch.remainder(
            torch.floor_divide(group_indices, self.group_basis), 2
        )
        codes_non_centered = rearrange(codes_non_centered, 'b ... g c -> b ... (g c)')
        return codes_non_centered * 2 - 1

    def get_entropy(self, count, dim=-1, eps=1e-4, normalize=True):
        """Return the entropy of ``count`` along ``dim``.

        With ``normalize`` the input is smoothed by ``eps`` and normalised to
        sum to one along ``dim``; otherwise it is used as probabilities as-is.
        A constant ``1e-8`` is added inside the logarithm in both cases.
        """
        if normalize:
            probs = (count + eps) / (count + eps).sum(dim=dim, keepdim=True)
        else:
            probs = count
        H = -(probs * torch.log(probs + 1e-8)).sum(dim=dim)
        return H

    def get_group_codebook_entry(self, group_indices):
        """Decode group indices into (optionally scaled) codes.

        With ``input_format == 'bchw'`` axis 1 of the decoded tensor must be a
        perfect square and is unfolded into a square ``h x w`` grid.
        """
        z_q = self.group_indexes_to_codes(group_indices)
        q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1.
        z_q = z_q * q_scale
        if self.input_format == 'bchw':
            # Square grid: both sides equal the integer square root.
            h = w = int(z_q.shape[1] ** 0.5)
            assert h * w == z_q.shape[1], 'Invalid sequence length'
            z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h)
        return z_q

    def get_codebook_entry(self, indices):
        """Decode full-code indices into (optionally scaled) codes.

        With ``input_format == 'bchw'`` axis 1 of the decoded tensor must be a
        perfect square and is unfolded into a square ``h x w`` grid.
        """
        z_q = self.indexes_to_codes(indices)
        q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1.
        z_q = z_q * q_scale
        if self.input_format == 'bchw':
            # Square grid: both sides equal the integer square root.
            h = w = int(z_q.shape[1] ** 0.5)
            assert h * w == z_q.shape[1], 'Invalid sequence length'
            z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h)
        return z_q


if __name__ == "__main__":
    # Smoke test for the custom autograd function.
    K = 8
    # Fall back to the CPU so the script also runs on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # zq = torch.randint(0, 2, (4, 32, K), dtype=torch.bfloat16, device='cuda') * 2 - 1
    zq = torch.zeros((4, 32, K), dtype=torch.bfloat16, device=device) * 2 - 1
    basis = (2 ** torch.arange(K - 1, -1, -1)).to(torch.bfloat16).to(device)
    zq.requires_grad = True
    h = codebook_entropy(zq, basis, K)
    h.backward()
    print(zq.grad, zq)