Spaces:
Sleeping
Sleeping
from __future__ import annotations | |
from typing import List, Tuple | |
import torch | |
from modules.core import ( | |
RectifiedFlow, PitchRectifiedFlow, MultiVarianceRectifiedFlow | |
) | |
class RectifiedFlowONNX(RectifiedFlow):
    """RectifiedFlow variant whose ``forward`` runs the whole Euler sampling loop.

    ``forward`` draws noise, optionally blends it with a normalized reference
    (``x_end``) for shallow sampling, integrates the velocity network with a
    fixed number of explicit Euler steps, and denormalizes the result.
    """

    @property
    def backbone(self):
        """The velocity network, i.e. ``self.velocity_fn``."""
        return self.velocity_fn

    # We give up the setter for the property `backbone` because this will cause TorchScript to fail
    # @backbone.setter
    def set_backbone(self, value):
        """Replace the velocity network (``self.velocity_fn``) with ``value``."""
        self.velocity_fn = value

    def sample_euler(self, x, t, dt: float, cond):
        """Take one explicit Euler step: ``x + velocity_fn(x, t', cond) * dt``.

        ``x`` is updated in place and also returned. ``t`` is multiplied by
        ``self.time_scale_factor`` before being passed to the velocity network.
        """
        # NOTE(review): when ``forward`` receives ``x_end``, ``dt`` is computed from
        # a tensor ``t_start`` and is therefore a tensor, not a float as annotated.
        x += self.velocity_fn(x, t * self.time_scale_factor, cond) * dt
        return x

    def norm_spec(self, x):
        """Affinely map ``[spec_min, spec_max]`` onto ``[-1, 1]``."""
        k = (self.spec_max - self.spec_min) / 2.
        b = (self.spec_max + self.spec_min) / 2.
        return (x - b) / k

    def denorm_spec(self, x):
        """Inverse of ``norm_spec``: map ``[-1, 1]`` back to ``[spec_min, spec_max]``."""
        k = (self.spec_max - self.spec_min) / 2.
        b = (self.spec_max + self.spec_min) / 2.
        return x * k + b

    def forward(self, condition, x_end=None, depth=None, steps: int = 10):
        """Sample by Euler-integrating the velocity field from ``t_start`` to 1.

        Args:
            condition: conditioning tensor, ``[1, T, H]`` per the inline comment.
                It is transposed to ``[1, H, T]``; ``T`` sets the number of frames.
            x_end: optional reference in data space. When given, it is
                normalized and its last two axes are swapped. For
                ``num_feats == 1`` a feature axis is also inserted. It is then
                blended with noise to form the starting point.
            depth: required when ``x_end`` is given; integration starts at
                ``max(1 - depth, self.t_start)``.
            steps: number of Euler steps. The step size is
                ``(1 - t_start) / max(1, steps)``.

        Returns:
            The denormalized sample: ``[B, T, M]`` when ``num_feats == 1``,
            otherwise ``[B, F, T, M]`` (shapes per the inline comments).
        """
        condition = condition.transpose(1, 2)  # [1, T, H] => [1, H, T]
        device = condition.device
        n_frames = condition.shape[2]

        # Noise is always drawn with a batch dimension of 1.
        noise = torch.randn((1, self.num_feats, self.out_dims, n_frames), device=device)
        if x_end is None:
            # Full sampling: start from pure noise at t = 0.
            t_start = 0.
            x = noise
        else:
            # Shallow sampling: start part-way along the noise -> data path.
            t_start = torch.max(1 - depth, torch.tensor(self.t_start, dtype=torch.float32, device=device))
            x_end = self.norm_spec(x_end).transpose(-2, -1)
            if self.num_feats == 1:
                x_end = x_end[:, None, :, :]
            if t_start <= 0.:
                x = noise
            elif t_start >= 1.:
                x = x_end
            else:
                # Linear interpolation between noise (t = 0) and data (t = 1).
                x = t_start * x_end + (1 - t_start) * noise

        t_width = 1. - t_start
        # NOTE(review): when t_start == 1 exactly, `steps` updates with dt == 0 are
        # still evaluated; when t_start > 1, no step runs at all.
        if t_width >= 0.:
            dt = t_width / max(1, steps)
            for t in torch.arange(steps, dtype=torch.long, device=device)[:, None].float() * dt + t_start:
                x = self.sample_euler(x, t, dt, condition)

        if self.num_feats == 1:
            x = x.squeeze(1).permute(0, 2, 1)  # [B, 1, M, T] => [B, T, M]
        else:
            x = x.permute(0, 1, 3, 2)  # [B, F, M, T] => [B, F, T, M]
        x = self.denorm_spec(x)
        return x
class PitchRectifiedFlowONNX(RectifiedFlowONNX, PitchRectifiedFlow):
    """Pitch flavour of ``RectifiedFlowONNX``.

    Stores the value range (``vmin``/``vmax``) and the clamp range
    (``cmin``/``cmax``) on the instance. ``denorm_spec`` averages the
    denormalized output over its last axis.
    """

    def __init__(self, vmin: float, vmax: float,
                 cmin: float, cmax: float, repeat_bins,
                 time_scale_factor=1000,
                 backbone_type=None, backbone_args=None):
        # Ranges are attached before the base constructor runs.
        self.vmin, self.vmax = vmin, vmax
        self.cmin, self.cmax = cmin, cmax
        # The two-argument super() starts the MRO lookup *after* PitchRectifiedFlow,
        # so PitchRectifiedFlow.__init__ itself is intentionally not called here.
        super(PitchRectifiedFlow, self).__init__(
            vmin=vmin,
            vmax=vmax,
            repeat_bins=repeat_bins,
            time_scale_factor=time_scale_factor,
            backbone_type=backbone_type,
            backbone_args=backbone_args,
        )

    def clamp_spec(self, x):
        """Clip ``x`` element-wise into ``[self.cmin, self.cmax]``."""
        lo, hi = self.cmin, self.cmax
        return x.clamp(min=lo, max=hi)

    def denorm_spec(self, x):
        """Map ``[-1, 1]`` back to ``[spec_min, spec_max]``, then average the last axis."""
        half_span = (self.spec_max - self.spec_min) / 2.
        midpoint = (self.spec_max + self.spec_min) / 2.
        restored = x * half_span + midpoint
        return restored.mean(dim=-1)
class MultiVarianceRectifiedFlowONNX(RectifiedFlowONNX, MultiVarianceRectifiedFlow):
    """Multi-variance flavour of ``RectifiedFlowONNX``.

    Each entry of ``ranges`` supplies a ``(low, high)`` pair. The per-entry lows
    and highs are forwarded as ``vmin``/``vmax``, as scalars when there is
    exactly one range. ``denorm_spec`` averages the denormalized output over
    its last axis.
    """

    def __init__(
            self, ranges: List[Tuple[float, float]],
            clamps: List[Tuple[float | None, float | None] | None],
            repeat_bins, time_scale_factor=1000,
            backbone_type=None, backbone_args=None
    ):
        # NOTE(review): assert-based check; it is skipped under `python -O`.
        assert len(ranges) == len(clamps)
        self.clamps = clamps
        lows = [bounds[0] for bounds in ranges]
        highs = [bounds[1] for bounds in ranges]
        # The two-argument super() starts the MRO lookup *after* MultiVarianceRectifiedFlow,
        # so MultiVarianceRectifiedFlow.__init__ itself is intentionally not called here.
        # With exactly one range, scalar bounds are passed instead of one-element lists.
        super(MultiVarianceRectifiedFlow, self).__init__(
            vmin=lows[0] if len(lows) == 1 else lows,
            vmax=highs[0] if len(highs) == 1 else highs,
            repeat_bins=repeat_bins,
            time_scale_factor=time_scale_factor,
            backbone_type=backbone_type,
            backbone_args=backbone_args,
        )

    def denorm_spec(self, x):
        """Map ``[-1, 1]`` back to ``[spec_min, spec_max]``, then average the last axis."""
        half_span = (self.spec_max - self.spec_min) / 2.
        midpoint = (self.spec_max + self.spec_min) / 2.
        restored = x * half_span + midpoint
        return restored.mean(dim=-1)