# NOTE: removed extraction artifact ("Spaces: / Sleeping / Sleeping" — Hugging Face Spaces page header, not source code).
import torch | |
import torchmetrics | |
from torch import Tensor | |
from modules.fastspeech.tts_modules import RhythmRegulator | |
def linguistic_checks(pred, target, ph2word, mask=None):
    """Validate the shared invariants of the duration-metric inputs.

    :param pred: predicted ph_dur, shape [B, T_ph], non-negative
    :param target: reference ph_dur, shape [B, T_ph], non-negative
    :param ph2word: word index per phoneme, shape [B, T_ph]; index 0 is
        treated as padding, valid word indices start at 1
    :param mask: optional valid/non-padding mask, shape [B, T_ph]
    :raises AssertionError: if any invariant is violated
    """
    if mask is None:
        assert pred.shape == target.shape == ph2word.shape, \
            f'shapes of pred, target and ph2word mismatch: {pred.shape}, {target.shape}, {ph2word.shape}'
    else:
        # Fix: the message previously omitted ph2word even though its shape was printed.
        assert pred.shape == target.shape == ph2word.shape == mask.shape, \
            f'shapes of pred, target, ph2word and mask mismatch: ' \
            f'{pred.shape}, {target.shape}, {ph2word.shape}, {mask.shape}'
    assert pred.ndim == 2, f'all inputs should be 2D, but got {pred.shape}'
    assert torch.any(ph2word > 0), 'empty word sequence'
    assert torch.all(ph2word >= 0), 'unexpected negative word index'
    # A batch cannot contain more words than phonemes, so the highest word
    # index must not exceed T_ph.
    assert ph2word.max() <= pred.shape[1], f'word index out of range: {ph2word.max()} > {pred.shape[1]}'
    assert torch.all(pred >= 0.), 'unexpected negative ph_dur prediction'
    assert torch.all(target >= 0.), 'unexpected negative ph_dur target'
class RhythmCorrectness(torchmetrics.Metric):
    """Fraction of words whose predicted duration falls within a relative
    tolerance of the reference word duration.

    Phoneme durations are accumulated into word durations via ``ph2word``
    before comparison, so only word-level rhythm is scored.
    """

    def __init__(self, *, tolerance, **kwargs):
        super().__init__(**kwargs)
        assert 0. < tolerance < 1., 'tolerance should be within (0, 1)'
        self.tolerance = tolerance
        self.add_state('correct', default=torch.tensor(0, dtype=torch.int), dist_reduce_fx='sum')
        self.add_state('total', default=torch.tensor(0, dtype=torch.int), dist_reduce_fx='sum')

    def update(self, pdur_pred: Tensor, pdur_target: Tensor, ph2word: Tensor, mask=None) -> None:
        """
        :param pdur_pred: predicted ph_dur
        :param pdur_target: reference ph_dur
        :param ph2word: word division sequence
        :param mask: valid or non-padding mask
        """
        linguistic_checks(pdur_pred, pdur_target, ph2word, mask=mask)

        # Sum phoneme durations into per-word buckets; bucket 0 collects the
        # padding positions and is dropped by the [:, 1:] slice.
        word_shape = (pdur_pred.shape[0], ph2word.max() + 1)
        wdur_pred = pdur_pred.new_zeros(*word_shape).scatter_add(1, ph2word, pdur_pred)[:, 1:]
        wdur_target = pdur_target.new_zeros(*word_shape).scatter_add(1, ph2word, pdur_target)[:, 1:]
        if mask is not None:
            # A word counts as valid when any of its phonemes is valid.
            wdur_mask = mask.new_zeros(*word_shape).scatter_add(1, ph2word, mask)[:, 1:].bool()
        else:
            wdur_mask = torch.ones_like(wdur_pred, dtype=torch.bool)

        within_tol = torch.abs(wdur_pred - wdur_target) <= wdur_target * self.tolerance
        within_tol &= wdur_mask
        self.correct += within_tol.sum()
        self.total += wdur_mask.sum()

    def compute(self) -> Tensor:
        # Ratio of within-tolerance words over all valid words.
        return self.correct / self.total
class PhonemeDurationAccuracy(torchmetrics.Metric):
    """Fraction of phonemes whose rhythm-aligned predicted duration falls
    within a relative tolerance of the reference phoneme duration.

    Predictions are first re-scaled per word by ``RhythmRegulator`` so that
    word-level totals match the reference before phonemes are compared.
    """

    def __init__(self, *, tolerance, **kwargs):
        super().__init__(**kwargs)
        # Consistency fix: validate tolerance exactly like RhythmCorrectness;
        # previously any value was silently accepted.
        assert 0. < tolerance < 1., 'tolerance should be within (0, 1)'
        self.tolerance = tolerance
        self.rr = RhythmRegulator()
        self.add_state('accurate', default=torch.tensor(0, dtype=torch.int), dist_reduce_fx='sum')
        self.add_state('total', default=torch.tensor(0, dtype=torch.int), dist_reduce_fx='sum')

    def update(self, pdur_pred: Tensor, pdur_target: Tensor, ph2word: Tensor, mask=None) -> None:
        """
        :param pdur_pred: predicted ph_dur
        :param pdur_target: reference ph_dur
        :param ph2word: word division sequence
        :param mask: valid or non-padding mask
        """
        linguistic_checks(pdur_pred, pdur_target, ph2word, mask=mask)
        # Reference word durations; bucket 0 collects padding and is dropped.
        shape = pdur_pred.shape[0], ph2word.max() + 1
        wdur_target = pdur_target.new_zeros(*shape).scatter_add(
            1, ph2word, pdur_target
        )[:, 1:]  # [B, T_ph] => [B, T_w]
        # Re-scale predicted phoneme durations so their word totals match the
        # reference word durations.
        pdur_align = self.rr(pdur_pred, ph2word=ph2word, word_dur=wdur_target)
        accurate = torch.abs(pdur_align - pdur_target) <= pdur_target * self.tolerance
        if mask is not None:
            accurate &= mask
        self.accurate += accurate.sum()
        self.total += pdur_pred.numel() if mask is None else mask.sum()

    def compute(self) -> Tensor:
        # Ratio of within-tolerance phonemes over all valid phonemes.
        return self.accurate / self.total