|
from typing import Dict, Optional, Tuple

import torch

from wenet.ssl.bestrq.mask import compute_mask_indices
from wenet.utils.mask import make_pad_mask

|
class BestRQModel(torch.nn.Module): |
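    """BEST-RQ: BERT-based speech pre-training with a random-projection
    quantizer (Chiu et al., 2022).

    Input frames are projected through a frozen random matrix and matched
    against frozen random codebooks to produce discrete target ids; the
    encoder is trained to predict the targets of masked frames.
    """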
|
    def __init__(
        self,
        encoder: torch.nn.Module,
        input_dim: int = 256,
        embedding_dim: int = 256,
        num_embeddings: int = 8192,
        num_codebooks: int = 1,
        dropout_rate: float = 0.1,
        mask_prob: float = 0.01,
        mask_length: int = 10,
        min_masks: int = 2,
        layer_norm_epsilon: float = 1e-5,
    ) -> None:
        super().__init__()

        assert mask_prob > 0.0
        self.mask_prob = mask_prob
        self.mask_length = mask_length
        self.min_masks = min_masks

        self.input_dropout = torch.nn.Dropout(dropout_rate)

        # Frozen random codebooks: `num_codebooks` codebooks, each holding
        # `num_embeddings` codewords of size `embedding_dim`. Kept as plain
        # tensors (not parameters) and moved to the input device at use sites.
        random_embedding_weight = torch.empty(
            num_codebooks, embedding_dim, num_embeddings, requires_grad=False
        )
        self.embeddings = torch.nn.init.normal_(random_embedding_weight)

        # Frozen random projection from input features to the codebook space.
        random_projection_weight = torch.empty(
            input_dim, embedding_dim, requires_grad=False
        )
        self.projection = torch.nn.init.xavier_normal_(random_projection_weight)

        # Learnable embedding that replaces masked frames. Registered as a
        # Parameter so it is actually trained and saved with the model.
        self.mask_emb = torch.nn.parameter.Parameter(
            torch.nn.init.normal_(torch.empty(input_dim), mean=0, std=0.1)
        )

        self.input_layer_norm = torch.nn.LayerNorm(input_dim, layer_norm_epsilon)
        self.encoder = encoder
        # Per-codebook prediction head on top of the encoder output; it needs
        # an explicit init because torch.empty returns uninitialized memory.
        self.encoder_top_n_out = torch.nn.parameter.Parameter(
            torch.empty(num_codebooks, self.encoder.output_size(), num_embeddings)
        )
        torch.nn.init.trunc_normal_(self.encoder_top_n_out, std=0.02)
|
    def forward(
        self,
        xs: torch.Tensor,
        xs_lens: torch.Tensor,
        text: Optional[torch.Tensor] = None,
        text_length: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
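        """Run one BEST-RQ pre-training step and return {"loss": ...}.

        `text` and `text_length` are unused; they keep the signature
        compatible with the supervised models' forward interface.
        """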
        # 1 forward subsampling
        xs, pos_emb, masks = self._forward_subsampling(xs, xs_lens)
        unmasked_xs = xs
        # 2 mask the subsampled features
        masked_xs, masked_masks = self._apply_mask(xs)
        # 3 quantize the unmasked features into target ids
        target_ids = self._nearest_embedding_idx(unmasked_xs)
        # 4 run the encoder blocks on the masked features
        out, out_mask = self._forward_encoder_blocks(masked_xs, masks, pos_emb, masks)
        # 5 per-codebook logits:
        #    [B, 1, T', D] x [1, num_codebooks, D, num_embeddings]
        #    -> [B, num_codebooks, T', num_embeddings]
        out = out.unsqueeze(1)
        top_n_out = self.encoder_top_n_out.unsqueeze(0)
        out = torch.matmul(out, top_n_out)
        # 6 cross entropy on masked, non-padded frames only
        loss = self._compute_loss(out, target_ids, out_mask.squeeze(1) * masked_masks)
        return {"loss": loss}
|
    def _compute_loss(
        self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
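        """Cross entropy between the predicted logits and the quantized
        target ids, averaged over masked frames and codebooks.
        """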
|
        # [B, num_codebooks, T', num_embeddings] ->
        # [B, num_embeddings, T', num_codebooks]: cross_entropy expects the
        # class dimension second.
        input = input.transpose(1, 3)
        entropy = torch.nn.functional.cross_entropy(input, target, reduction="none")
        # Zero out unmasked positions, then average over frames and codebooks.
        loss = entropy * mask.unsqueeze(2)
        return loss.sum() / (mask.sum() * loss.size(2))
|
    def _forward_encoder_blocks(
        self,
        xs: torch.Tensor,
        xs_masks: torch.Tensor,
        pos_emb: torch.Tensor,
        mask_pad: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
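        """Run the encoder layers (without the embed front-end), then the
        final layer norm when the encoder normalizes before attention.
        """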
|
        masks = xs_masks
        for layer in self.encoder.encoders:
            xs, masks, _, _ = layer(xs, masks, pos_emb, mask_pad)
        if self.encoder.normalize_before:
            xs = self.encoder.after_norm(xs)
        return xs, masks
|
def _nearest_embedding_idx(self, xs: torch.Tensor) -> torch.Tensor: |
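        """Return, for every frame, the index of the nearest codeword in
        each codebook, as ids of shape [B, T, num_codebooks].
        """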
|
        xs = self.input_layer_norm(xs)
        xs = self.input_dropout(xs)
        # Frozen random projection into the codebook space:
        # [B, T, input_dim] -> [B, T, embedding_dim]
        xs = torch.matmul(xs, self.projection.to(xs.device))

        B, T, C = xs.size()
        flattened_input = xs.view(-1, C)
        # [num_codebooks, embedding_dim, num_embeddings]
        embeddings = self.embeddings.to(xs.device)

        # Squared euclidean distance ||x||^2 + ||e||^2 - 2<x, e> between
        # every frame and every codeword: [num_codebooks, B*T, num_embeddings]
        distance = (
            torch.sum(flattened_input**2, dim=1, keepdim=True).unsqueeze(0)
            + torch.sum(embeddings**2, dim=1, keepdim=True)
            - 2 * torch.matmul(flattened_input.unsqueeze(0), embeddings)
        )

        out = torch.argmin(distance, dim=-1)  # [num_codebooks, B*T]
        out = out.transpose(0, 1)  # [B*T, num_codebooks]
        return out.reshape(B, T, -1)
|
def _apply_mask(self, xs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: |
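        """Replace randomly selected spans of frames with the learnable mask
        embedding; return the masked features and the boolean frame mask.
        """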
|
        masks = compute_mask_indices(
            xs.size()[:-1],
            self.mask_prob,
            self.mask_length,
            self.min_masks,
            device=xs.device,
        )
        masks_expand = masks.unsqueeze(-1)  # [B, T, 1]

        mask_emb = self.mask_emb.to(xs.device).view(1, 1, -1)
        xs = torch.where(masks_expand, mask_emb, xs)
        return xs, masks
|
    def _forward_subsampling(
        self, xs: torch.Tensor, xs_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
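        """Apply global CMVN (when configured) and the encoder's subsampling
        front-end; return the subsampled features, positional embeddings and
        the padding mask.
        """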
|
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # [B, 1, T]
        if self.encoder.global_cmvn is not None:
            xs = self.encoder.global_cmvn(xs)
        xs, pos_emb, masks = self.encoder.embed(xs, masks)
        return xs, pos_emb, masks
|
|
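if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the model. It assumes WeNet's
    # ConformerEncoder with default settings; any encoder exposing `embed`,
    # `encoders`, `after_norm`, `normalize_before`, `global_cmvn` and
    # `output_size()` would work the same way. With 80-dim fbank in and a
    # 256-dim encoder, `input_dim` (the post-subsampling feature size) is 256.
    from wenet.transformer.encoder import ConformerEncoder

    encoder = ConformerEncoder(input_size=80, output_size=256)
    model = BestRQModel(encoder=encoder, input_dim=256)

    xs = torch.randn(2, 200, 80)        # [batch, frames, fbank_dim]
    xs_lens = torch.tensor([200, 160])  # valid frames per utterance
    print(model(xs, xs_lens)["loss"])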