File size: 4,212 Bytes
c42fe7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from __future__ import annotations

from typing import List, Tuple

import torch

from modules.core import (
    RectifiedFlow, PitchRectifiedFlow, MultiVarianceRectifiedFlow
)


class RectifiedFlowONNX(RectifiedFlow):
    """Variant of ``RectifiedFlow`` intended (per its name) for ONNX export.

    The full Euler sampling loop is written out inside ``forward`` so the whole
    sampler can be captured in one graph. Attributes such as ``velocity_fn``,
    ``spec_min``, ``spec_max``, ``num_feats``, ``out_dims``, ``t_start`` and
    ``time_scale_factor`` are read here but defined by the base class, which is
    not visible in this file.
    """

    @property
    def backbone(self):
        """Expose the velocity network (``self.velocity_fn``) under the name ``backbone``."""
        return self.velocity_fn

    # We forgo a setter for the `backbone` property because defining one makes TorchScript
    # compilation fail; `set_backbone` below is the replacement.
    # @backbone.setter
    @torch.jit.unused
    def set_backbone(self, value):
        """Replace the velocity network.

        Decorated with ``@torch.jit.unused`` so TorchScript does not compile this method.
        """
        self.velocity_fn = value

    def sample_euler(self, x, t, dt: float, cond):
        """Take one explicit Euler step: ``x + velocity_fn(x, t * time_scale_factor, cond) * dt``.

        Note that ``x`` is updated in place (``+=``) and then returned.
        NOTE(review): ``dt`` is annotated ``float``, but ``forward`` passes a Tensor
        whenever ``x_end`` is given (``t_start`` comes from ``torch.max`` there).
        """
        x += self.velocity_fn(x, t * self.time_scale_factor, cond) * dt
        return x

    def norm_spec(self, x):
        """Affinely map ``x`` so that ``spec_min`` -> -1 and ``spec_max`` -> 1."""
        k = (self.spec_max - self.spec_min) / 2.  # half of the value range
        b = (self.spec_max + self.spec_min) / 2.  # midpoint of the value range
        return (x - b) / k

    def denorm_spec(self, x):
        """Inverse of ``norm_spec``: map -1 -> ``spec_min`` and 1 -> ``spec_max``."""
        k = (self.spec_max - self.spec_min) / 2.
        b = (self.spec_max + self.spec_min) / 2.
        return x * k + b

    def forward(self, condition, x_end=None, depth=None, steps: int = 10):
        """Sample by explicit Euler integration of ``velocity_fn`` from ``t_start`` to 1.

        Args:
            condition: conditioning features. Per the inline comment below the
                layout is [1, T, H], and it is transposed to [1, H, T] before use.
            x_end: optional reference sample. When given, it is normalised with
                ``norm_spec`` and linearly blended with noise, so integration starts
                part-way at ``t_start``. When ``None``, integration starts from pure
                noise at t = 0.
            depth: only read when ``x_end`` is given. The start time is
                ``max(1 - depth, self.t_start)``. It goes through ``torch.max``
                against a float32 tensor and is then used in a Python ``if``, so it is
                presumably a single-element tensor (TODO confirm).
            steps: number of Euler steps. The step size is
                ``(1 - t_start) / max(1, steps)``. With ``steps == 0`` the loop body
                never runs.

        Returns:
            The integrated sample passed through ``self.denorm_spec`` (which
            subclasses may override). Before that call the layout is [B, T, M] when
            ``num_feats == 1`` and [B, F, T, M] otherwise (per the permute comments).
        """
        condition = condition.transpose(1, 2)  # [1, T, H] => [1, H, T]
        device = condition.device
        n_frames = condition.shape[2]
        # NOTE(review): the batch dimension of the noise is hard-coded to 1, so this
        # assumes a single-item batch.
        noise = torch.randn((1, self.num_feats, self.out_dims, n_frames), device=device)
        if x_end is None:
            # Sample from scratch: start at t = 0 from pure noise.
            t_start = 0.
            x = noise
        else:
            # Start part-way: clamp the start time from below by the configured self.t_start.
            t_start = torch.max(1 - depth, torch.tensor(self.t_start, dtype=torch.float32, device=device))
            # Swap the last two axes of the normalised reference; presumably this brings
            # it from [B, T, M] into the noise's [.., M, T] layout (TODO confirm).
            x_end = self.norm_spec(x_end).transpose(-2, -1)
            if self.num_feats == 1:
                # Insert the feature axis so x_end is rank-4 like the noise.
                x_end = x_end[:, None, :, :]
            if t_start <= 0.:
                x = noise
            elif t_start >= 1.:
                x = x_end
            else:
                # Linear interpolation between noise (t = 0) and the reference (t = 1).
                x = t_start * x_end + (1 - t_start) * noise

        t_width = 1. - t_start
        # NOTE(review): with t_start == 1 this still runs `steps` velocity evaluations
        # with dt == 0; `> 0.` would skip them — confirm the effect on the exported graph.
        if t_width >= 0.:
            dt = t_width / max(1, steps)  # max() only guards the division; arange(0) is empty
            # Time points t_start, t_start + dt, ...; shape [steps, 1], so each iteration
            # yields a one-element tensor.
            for t in torch.arange(steps, dtype=torch.long, device=device)[:, None].float() * dt + t_start:
                x = self.sample_euler(x, t, dt, condition)

        if self.num_feats == 1:
            x = x.squeeze(1).permute(0, 2, 1)  # [B, 1, M, T] => [B, T, M]
        else:
            x = x.permute(0, 1, 3, 2)  # [B, F, M, T] => [B, F, T, M]
        x = self.denorm_spec(x)
        return x


class PitchRectifiedFlowONNX(RectifiedFlowONNX, PitchRectifiedFlow):
    """Pitch rectified-flow model combined with the export-oriented sampler of ``RectifiedFlowONNX``.

    ``cmin``/``cmax`` are kept on the instance and used by ``clamp_spec``.
    ``vmin``/``vmax`` are kept on the instance and also forwarded to the base initialiser.
    """

    def __init__(self, vmin: float, vmax: float,
                 cmin: float, cmax: float, repeat_bins,
                 time_scale_factor=1000,
                 backbone_type=None, backbone_args=None):
        # The bounds are recorded before any base-class initialisation runs.
        self.vmin, self.vmax = vmin, vmax
        self.cmin, self.cmax = cmin, cmax
        base_kwargs = dict(
            vmin=vmin,
            vmax=vmax,
            repeat_bins=repeat_bins,
            time_scale_factor=time_scale_factor,
            backbone_type=backbone_type,
            backbone_args=backbone_args,
        )
        # Two-argument super() resolves __init__ starting *after* PitchRectifiedFlow in
        # the MRO, so PitchRectifiedFlow.__init__ itself is intentionally bypassed.
        super(PitchRectifiedFlow, self).__init__(**base_kwargs)

    def clamp_spec(self, x):
        """Clip ``x`` element-wise into the closed interval ``[cmin, cmax]``."""
        lower, upper = self.cmin, self.cmax
        return x.clamp(min=lower, max=upper)

    def denorm_spec(self, x):
        """Undo the ``spec_min``/``spec_max`` affine normalisation, then average the last axis.

        The arithmetic is spelled out rather than delegated through ``super()``,
        presumably to stay TorchScript-compatible (TODO confirm).
        """
        half_span = (self.spec_max - self.spec_min) / 2.
        center = (self.spec_max + self.spec_min) / 2.
        return (x * half_span + center).mean(dim=-1)


class MultiVarianceRectifiedFlowONNX(RectifiedFlowONNX, MultiVarianceRectifiedFlow):
    """Multi-variance rectified-flow model combined with the export-oriented sampler of ``RectifiedFlowONNX``.

    Each entry of ``ranges`` supplies a lower bound (index 0) and an upper bound
    (index 1) that are forwarded to the base initialiser as ``vmin``/``vmax``.
    ``clamps`` is stored on the instance. Presumably it is consumed by an
    inherited clamping method (not visible here — TODO confirm).
    """

    def __init__(
            self, ranges: List[Tuple[float, float]],
            clamps: List[Tuple[float | None, float | None] | None],
            repeat_bins, time_scale_factor=1000,
            backbone_type=None, backbone_args=None
    ):
        assert len(ranges) == len(clamps)
        self.clamps = clamps
        lower_bounds = [bounds[0] for bounds in ranges]
        upper_bounds = [bounds[1] for bounds in ranges]
        # A single variance is forwarded as plain scalars instead of one-element lists.
        is_single = len(ranges) == 1
        # Two-argument super() skips MultiVarianceRectifiedFlow.__init__ and continues
        # the MRO lookup with the class after it.
        super(MultiVarianceRectifiedFlow, self).__init__(
            vmin=lower_bounds[0] if is_single else lower_bounds,
            vmax=upper_bounds[0] if is_single else upper_bounds,
            repeat_bins=repeat_bins,
            time_scale_factor=time_scale_factor,
            backbone_type=backbone_type,
            backbone_args=backbone_args,
        )

    def denorm_spec(self, x):
        """Undo the ``spec_min``/``spec_max`` affine normalisation, then average the last axis."""
        scale = (self.spec_max - self.spec_min) / 2.
        offset = (self.spec_max + self.spec_min) / 2.
        restored = x * scale + offset
        return restored.mean(dim=-1)