import os
import sys

import torch
import torchaudio
from torch import nn

# Make the directory containing this file importable so that sibling modules
# such as common.py (which provides safe_log) can be imported directly.
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)
print(sys.path)

from common import safe_log


class FeatureExtractor(nn.Module):
    """Base class for feature extractors."""

    def forward(self, audio: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Extract features from the given audio.

        Args:
            audio (Tensor): Input audio waveform.

        Returns:
            Tensor: Extracted features of shape (B, C, L), where B is the batch size,
                C denotes output features, and L is the sequence length.
        """
        raise NotImplementedError("Subclasses must implement the forward method.")


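# Illustrative sketch (not part of the original module): the smallest possible
# FeatureExtractor simply reshapes the raw waveform to the (B, C, L) contract
# documented above, treating the waveform as a single feature channel.
class RawWaveformFeatures(FeatureExtractor):
    """Hypothetical example extractor that returns the waveform as (B, 1, L)."""

    def forward(self, audio: torch.Tensor, **kwargs) -> torch.Tensor:
        return audio.unsqueeze(1)  # (B, L) -> (B, 1, L)

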
class MelSpectrogramFeatures(FeatureExtractor):
    def __init__(self, sample_rate=24000, n_fft=1024, hop_length=256, win_length=None,
                 n_mels=100, mel_fmin=0, mel_fmax=None, normalize=False, padding="center"):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.mel_spec = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            power=1,
            normalized=normalize,
            f_min=mel_fmin,
            f_max=mel_fmax,
            n_mels=n_mels,
            center=padding == "center",
        )

    def forward(self, audio, **kwargs):
        if self.padding == "same":
            # "same" padding: center padding is disabled above, so pad the
            # waveform manually to keep roughly one frame per hop.
            pad = self.mel_spec.win_length - self.mel_spec.hop_length
            audio = torch.nn.functional.pad(audio, (pad // 2, pad // 2), mode="reflect")
        mel = self.mel_spec(audio)
        # Apply a numerically safe log compression (see common.safe_log).
        mel = safe_log(mel)
        return mel
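

# Usage sketch (assumption: not part of the original module). It runs the mel
# extractor on a dummy batch of 24 kHz waveforms and prints the feature shape.
if __name__ == "__main__":
    batch_size, num_samples = 2, 24000  # one second of audio per item at 24 kHz
    dummy_audio = torch.randn(batch_size, num_samples)

    extractor = MelSpectrogramFeatures(sample_rate=24000, padding="center")
    features = extractor(dummy_audio)

    # With center padding, L = num_samples // hop_length + 1 frames,
    # so the expected shape here is (2, 100, 94).
    print(features.shape)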