#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
#
# Copyright (c) 2021 Kazuhiro KOBAYASHI
#
# Distributed under terms of the MIT license.

"""PyTorch layers for spectral feature extraction.

Provides STFT magnitude, log-mel spectrogram, mel-to-linear inversion,
cepstral liftering and mean/variance scaling as ``nn.Module`` layers.
"""

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# librosa is only needed to build mel filterbanks; the STFT, liftering and
# scaler layers stay importable without it.
try:
    import librosa
except ImportError:  # pragma: no cover - depends on the environment
    librosa = None


def _mel_filterbank(fs, fft_size, n_mels, fmin, fmax):
    """Build the mel filterbank matrix with librosa.

    Args:
        fs: Sampling rate in Hz.
        fft_size: FFT size used for the linear spectrogram.
        n_mels: Number of mel bands.
        fmin: Lowest filter frequency; ``None`` means 0.
        fmax: Highest filter frequency; ``None`` means ``fs / 2``.

    Returns:
        The array returned by ``librosa.filters.mel``.

    Raises:
        ImportError: If librosa is not installed.
    """
    if librosa is None:
        raise ImportError("librosa is required to build the mel filterbank.")
    fmin = 0 if fmin is None else fmin
    fmax = fs / 2 if fmax is None else fmax
    # Keyword arguments: librosa >= 0.10 no longer accepts sr/n_fft
    # positionally, while older releases accept the keywords as well.
    return librosa.filters.mel(
        sr=fs, n_fft=fft_size, n_mels=n_mels, fmin=fmin, fmax=fmax
    )


class LogMelSpectrogram(nn.Module):
    """Waveform -> (optionally normalized) log10 mel spectrogram.

    Chains :class:`Magnitude`, :class:`Magnitude2LogMelSpectrogram` and,
    when ``scaler_file`` is given, :class:`LogMelSpectrogramScaler`.

    Args:
        fs: Sampling rate in Hz.
        hop_size: STFT hop length in samples.
        fft_size: FFT size.
        win_length: Window length; ``None`` means ``fft_size``.
        window: Window name; ``torch.<window>_window`` must exist.
        center: Forwarded to ``torch.stft``.
        pad_mode: Forwarded to ``torch.stft``.
        n_mels: Number of mel bands.
        fmin: Lowest mel filter frequency; ``None`` means 0.
        fmax: Highest mel filter frequency; ``None`` means ``fs / 2``.
        scaler_file: Passed unchanged to :class:`LogMelSpectrogramScaler`,
            which reads ``.mean_`` and ``.var_`` from it, so despite the
            name it must be a scaler object exposing those attributes.
    """

    def __init__(
        self,
        fs=24000,
        hop_size=128,
        fft_size=1024,
        win_length=None,
        window="hann",
        center=True,
        pad_mode="reflect",
        n_mels=80,
        fmin=None,
        fmax=None,
        scaler_file=None,
    ):
        super().__init__()
        # center/pad_mode used to be accepted but never forwarded.
        self.mag_layer = Magnitude(
            fs,
            hop_size,
            fft_size,
            win_length,
            window,
            center=center,
            pad_mode=pad_mode,
        )
        self.mel_layer = Magnitude2LogMelSpectrogram(fs, fft_size, n_mels, fmin, fmax)
        if scaler_file is not None:
            self.melspc_scaler = LogMelSpectrogramScaler(scaler_file)
        else:
            self.melspc_scaler = None

    def forward(self, x):
        """Return the log-mel spectrogram of waveform ``x``.

        The result is normalized when a scaler was configured.
        """
        mag = self.mag_layer(x)
        log_melspc = self.mel_layer(mag)
        if self.melspc_scaler is not None:
            # The original called the non-existent ``self.scaler_layer``.
            log_melspc = self.melspc_scaler(log_melspc)
        return log_melspc


class Magnitude(torch.nn.Module):
    """STFT magnitude layer.

    Args:
        fs: Sampling rate; accepted for signature symmetry, not used here.
        hop_size: STFT hop length in samples.
        fft_size: FFT size.
        win_length: Window length; ``None`` means ``fft_size``.
        window: Window name; ``torch.<window>_window`` must exist.
        center: Forwarded to ``torch.stft``.
        pad_mode: Forwarded to ``torch.stft``.
        return_complex: Forwarded to ``torch.stft``; the returned magnitude
            is the same for either setting.
    """

    def __init__(
        self,
        fs=24000,
        hop_size=128,
        fft_size=1024,
        win_length=None,
        window="hann",
        center=True,
        pad_mode="reflect",
        return_complex=True,
    ):
        super().__init__()
        self.hop_size = hop_size
        self.fft_size = fft_size
        self.win_length = fft_size if win_length is None else win_length
        self.window = window
        self.center = center
        self.pad_mode = pad_mode
        self.return_complex = return_complex

    def forward(self, x):
        """Compute the magnitude spectrogram.

        x: (B, 1, T) or (B, T)
        ret: (B, frames, fft_size // 2 + 1)
        """
        window_fn = getattr(torch, f"{self.window}_window")
        window = window_fn(self.win_length, dtype=x.dtype, device=x.device)
        if x.dim() == 3:
            # torch.stft only takes 1-D/2-D input; drop the channel axis.
            x = x.squeeze(1)
        y = torch.stft(
            x,
            n_fft=self.fft_size,
            win_length=self.win_length,
            hop_length=self.hop_size,
            window=window,
            center=self.center,
            pad_mode=self.pad_mode,
            return_complex=self.return_complex,
        )
        if torch.is_complex(y):
            mag = y.abs()
        else:
            # return_complex=False yields a trailing (real, imag) axis.
            mag = torch.hypot(y[..., 0], y[..., 1])
        # Swap the frequency and frame axes.
        return mag.transpose(-2, -1)


class Magnitude2LogMelSpectrogram(torch.nn.Module):
    """Project a magnitude spectrogram onto mel bands and take log10.

    Args:
        fs: Sampling rate in Hz.
        fft_size: FFT size the magnitude spectrogram was computed with.
        n_mels: Number of mel bands.
        fmin: Lowest mel filter frequency; ``None`` means 0.
        fmax: Highest mel filter frequency; ``None`` means ``fs / 2``.
        eps: Floor applied before the logarithm.
    """

    def __init__(
        self, fs=24000, fft_size=1024, n_mels=80, fmin=None, fmax=None, eps=1.0e-10
    ):
        super().__init__()
        self.eps = eps
        mel_basis = _mel_filterbank(fs, fft_size, n_mels, fmin, fmax)
        # Stored transposed so that forward() is a plain right-multiplication.
        self.register_buffer("mel_basis", torch.from_numpy(mel_basis.T).float())

    def forward(self, x):
        """Return ``log10(max(x @ mel_basis, eps))``."""
        melspc = torch.matmul(x, self.mel_basis)
        log_melspc = torch.clamp(melspc, min=self.eps).log10()
        return log_melspc


class LogMelSpectrogram2LogMagnitude(nn.Module):
    """Map a log10 mel spectrogram back to a smoothed log linear spectrum.

    The mel filterbank is inverted with a pseudo-inverse, the result is
    floored at ``eps``, converted with ``10 * log10`` and median-filtered
    along the last axis with a window of ``roll_size``.

    Args:
        fs: Sampling rate in Hz.
        fft_size: FFT size of the target linear spectrum.
        n_mels: Number of mel bands of the input.
        fmin: Lowest mel filter frequency; ``None`` means 0.
        fmax: Highest mel filter frequency; ``None`` means ``fs / 2``.
        eps: Floor applied before the logarithm.
        roll_size: Median filter window length.
        melspc_scaler_fn: Optional object whose ``inverse_transform`` is
            applied to the input first to undo normalization.
    """

    def __init__(
        self,
        fs,
        fft_size,
        n_mels=80,
        fmin=None,
        fmax=None,
        eps=1.0e-10,
        roll_size=24,
        melspc_scaler_fn=None,
    ):
        super().__init__()
        self.eps = eps
        self.roll_size = roll_size
        self.melspc_scaler_fn = melspc_scaler_fn

        mel_basis = _mel_filterbank(fs, fft_size, n_mels, fmin, fmax)
        inv_mel_basis = np.linalg.pinv(mel_basis)
        self.register_buffer("inv_mel_basis", torch.from_numpy(inv_mel_basis.T).float())

    def forward(self, x):
        """Return the median-smoothed ``10 * log10`` linear spectrum of ``x``."""
        if self.melspc_scaler_fn is not None:
            # Undo the normalization of the log-mel features.
            x = self.melspc_scaler_fn.inverse_transform(x)

        x = torch.pow(10.0, x)
        spc = torch.matmul(x, self.inv_mel_basis)
        log_spc = 10 * torch.clamp(spc, min=self.eps).log10()

        # Pad by roll_size - 1 in total so the sliding window keeps the axis
        # length. The original left pad, roll_size // 2 - 1, is the same for
        # even sizes but lost one bin for odd ones. F.pad zero-fills here.
        left = (self.roll_size - 1) // 2
        right = self.roll_size // 2
        z = F.pad(log_spc, (left, right))
        z = z.unfold(-1, self.roll_size, step=1)
        log_spc = torch.median(z, dim=-1)[0]
        return log_spc


class CepstrumLiftering(nn.Module):
    """Keep only the lowest ``lifter_size`` quefrency bins on each side.

    Args:
        lifter_size: Number of bins kept at each end of the inverse FFT of
            the input; must be a positive integer.

    Raises:
        ValueError: If ``lifter_size`` is ``None`` or not positive.
    """

    def __init__(self, lifter_size=None):
        super().__init__()
        # The None default used to hit ``None <= 0`` and raise TypeError.
        if lifter_size is None or lifter_size <= 0:
            raise ValueError("lifter_size must be > 0.")
        self.lifter_size = lifter_size

    def forward(self, x):
        """Lifter ``x`` along its last axis and return the complex result.

        NOTE(review): presumably ``x`` is a log spectrum - confirm with
        the callers.
        """
        cep = torch.fft.ifft(x, dim=-1)
        cep[..., self.lifter_size : -self.lifter_size] = 0
        x = torch.fft.fft(cep, dim=-1)
        return x


class LogMelSpectrogramScaler(nn.Module):
    """Mean/variance normalization with statistics taken from ``scaler``.

    Args:
        scaler: Object exposing NumPy arrays ``mean_`` and ``var_``
            (presumably a fitted scikit-learn ``StandardScaler`` - confirm
            with the callers).
    """

    def __init__(self, scaler):
        super().__init__()
        # Stored as frozen parameters so they follow .to()/.state_dict().
        self.register_parameter(
            "mean",
            nn.Parameter(torch.from_numpy(scaler.mean_).float(), requires_grad=False),
        )
        self.register_parameter(
            "scale",
            nn.Parameter(
                torch.from_numpy(scaler.var_).float().sqrt(), requires_grad=False
            ),
        )

    def forward(self, x):
        """Return ``(x - mean) / scale``."""
        return (x - self.mean) / self.scale

    def inverse_transform(self, x):
        """Return ``x * scale + mean``, the inverse of :meth:`forward`."""
        return x * self.scale + self.mean