File size: 4,079 Bytes
c42fe7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import torch
import torchmetrics
from torch import Tensor

from modules.fastspeech.tts_modules import RhythmRegulator


def linguistic_checks(pred, target, ph2word, mask=None):
    """
    Sanity-check the inputs shared by the duration metrics in this module.

    :param pred: predicted ph_dur; must be 2D and non-negative
    :param target: reference ph_dur; same shape as pred, non-negative
    :param ph2word: word division sequence; same shape as pred, no negative
        entries, at least one positive entry, and no entry larger than
        ``pred.shape[1]``
    :param mask: optional valid or non-padding mask; same shape as pred when given
    :raises AssertionError: if any of the conditions above is violated
    """
    if mask is None:
        assert pred.shape == target.shape == ph2word.shape, \
            f'shapes of pred, target and ph2word mismatch: {pred.shape}, {target.shape}, {ph2word.shape}'
    else:
        # The message names every tensor whose shape is compared and printed.
        assert pred.shape == target.shape == ph2word.shape == mask.shape, \
            f'shapes of pred, target, ph2word and mask mismatch: ' \
            f'{pred.shape}, {target.shape}, {ph2word.shape}, {mask.shape}'
    assert pred.ndim == 2, f'all inputs should be 2D, but got {pred.shape}'
    # The "any positive" check runs first so an all-zero ph2word is reported
    # as an empty word sequence rather than passing the range checks below.
    assert torch.any(ph2word > 0), 'empty word sequence'
    assert torch.all(ph2word >= 0), 'unexpected negative word index'
    assert ph2word.max() <= pred.shape[1], f'word index out of range: {ph2word.max()} > {pred.shape[1]}'
    assert torch.all(pred >= 0.), 'unexpected negative ph_dur prediction'
    assert torch.all(target >= 0.), 'unexpected negative ph_dur target'


class RhythmCorrectness(torchmetrics.Metric):
    """
    Ratio of words whose summed predicted phoneme duration lies within a
    relative tolerance of the summed reference duration.
    """

    def __init__(self, *, tolerance, **kwargs):
        """
        :param tolerance: allowed relative deviation, strictly between 0 and 1
        :param kwargs: forwarded unchanged to ``torchmetrics.Metric``
        """
        super().__init__(**kwargs)
        assert 0. < tolerance < 1., 'tolerance should be within (0, 1)'
        self.tolerance = tolerance
        # Two integer counters, summed across processes when reduced.
        for counter in ('correct', 'total'):
            self.add_state(counter, default=torch.tensor(0, dtype=torch.int), dist_reduce_fx='sum')

    def update(self, pdur_pred: Tensor, pdur_target: Tensor, ph2word: Tensor, mask=None) -> None:
        """
        Accumulate word-level hit and word counts for one batch.

        :param pdur_pred: predicted ph_dur
        :param pdur_target: reference ph_dur
        :param ph2word: word division sequence
        :param mask: valid or non-padding mask
        """
        linguistic_checks(pdur_pred, pdur_target, ph2word, mask=mask)

        num_rows = pdur_pred.shape[0]
        num_slots = ph2word.max() + 1

        def pool_by_word(ph_values):
            # Sum phoneme-level values into their word slot, then discard the
            # column for word index 0.  [B, T_ph] => [B, T_w]
            slots = ph_values.new_zeros(num_rows, num_slots)
            return slots.scatter_add(1, ph2word, ph_values)[:, 1:]

        wdur_pred = pool_by_word(pdur_pred)
        wdur_target = pool_by_word(pdur_target)
        if mask is not None:
            # A word is counted when at least one of its phonemes is unmasked.
            wdur_mask = pool_by_word(mask).bool()
        else:
            wdur_mask = torch.ones_like(wdur_pred, dtype=torch.bool)

        deviation = (wdur_pred - wdur_target).abs()
        hits = (deviation <= wdur_target * self.tolerance) & wdur_mask

        self.correct += hits.sum()
        self.total += wdur_mask.sum()

    def compute(self) -> Tensor:
        """Return the accumulated correct count divided by the accumulated total."""
        hits, seen = self.correct, self.total
        return hits / seen


class PhonemeDurationAccuracy(torchmetrics.Metric):
    """
    Ratio of phonemes whose regulated predicted duration lies within a
    relative tolerance of the reference duration.
    """

    def __init__(self, *, tolerance, **kwargs):
        """
        :param tolerance: allowed relative deviation from the reference ph_dur
        :param kwargs: forwarded unchanged to ``torchmetrics.Metric``
        """
        super().__init__(**kwargs)
        self.tolerance = tolerance
        self.rr = RhythmRegulator()
        # Two integer counters, summed across processes when reduced.
        for counter in ('accurate', 'total'):
            self.add_state(counter, default=torch.tensor(0, dtype=torch.int), dist_reduce_fx='sum')

    def update(self, pdur_pred: Tensor, pdur_target: Tensor, ph2word: Tensor, mask=None) -> None:
        """
        Accumulate phoneme-level hit and phoneme counts for one batch.

        :param pdur_pred: predicted ph_dur
        :param pdur_target: reference ph_dur
        :param ph2word: word division sequence
        :param mask: valid or non-padding mask
        """
        linguistic_checks(pdur_pred, pdur_target, ph2word, mask=mask)

        # Reference word durations: sum phoneme durations per word slot, then
        # discard the column for word index 0.  [B, T_ph] => [B, T_w]
        slot_shape = (pdur_pred.shape[0], ph2word.max() + 1)
        slots = pdur_target.new_zeros(*slot_shape)
        wdur_target = slots.scatter_add(1, ph2word, pdur_target)[:, 1:]

        # NOTE(review): presumably RhythmRegulator rescales the predicted
        # phoneme durations so each word matches word_dur - confirm there.
        pdur_align = self.rr(pdur_pred, ph2word=ph2word, word_dur=wdur_target)

        hits = (pdur_align - pdur_target).abs() <= pdur_target * self.tolerance
        if mask is None:
            seen = pdur_pred.numel()
        else:
            hits &= mask
            seen = mask.sum()

        self.accurate += hits.sum()
        self.total += seen

    def compute(self) -> Tensor:
        """Return the accumulated accurate count divided by the accumulated total."""
        hits, seen = self.accurate, self.total
        return hits / seen