Spaces:
Runtime error
Runtime error
Sync from GitHub repo
Browse filesThis Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there
- src/f5_tts/model/cfm.py +6 -1
- src/f5_tts/model/utils.py +19 -0
src/f5_tts/model/cfm.py
CHANGED
@@ -22,6 +22,7 @@ from f5_tts.model.modules import MelSpec
|
|
22 |
from f5_tts.model.utils import (
|
23 |
default,
|
24 |
exists,
|
|
|
25 |
lens_to_mask,
|
26 |
list_str_to_idx,
|
27 |
list_str_to_tensor,
|
@@ -92,6 +93,7 @@ class CFM(nn.Module):
|
|
92 |
seed: int | None = None,
|
93 |
max_duration=4096,
|
94 |
vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722
|
|
|
95 |
no_ref_audio=False,
|
96 |
duplicate_test=False,
|
97 |
t_inter=0.1,
|
@@ -190,7 +192,10 @@ class CFM(nn.Module):
|
|
190 |
y0 = (1 - t_start) * y0 + t_start * test_cond
|
191 |
steps = int(steps * (1 - t_start))
|
192 |
|
193 |
-
|
|
|
|
|
|
|
194 |
if sway_sampling_coef is not None:
|
195 |
t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
|
196 |
|
|
|
22 |
from f5_tts.model.utils import (
|
23 |
default,
|
24 |
exists,
|
25 |
+
get_epss_timesteps,
|
26 |
lens_to_mask,
|
27 |
list_str_to_idx,
|
28 |
list_str_to_tensor,
|
|
|
93 |
seed: int | None = None,
|
94 |
max_duration=4096,
|
95 |
vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722
|
96 |
+
use_epss=True,
|
97 |
no_ref_audio=False,
|
98 |
duplicate_test=False,
|
99 |
t_inter=0.1,
|
|
|
192 |
y0 = (1 - t_start) * y0 + t_start * test_cond
|
193 |
steps = int(steps * (1 - t_start))
|
194 |
|
195 |
+
if t_start == 0 and use_epss: # use Empirically Pruned Step Sampling for low NFE
|
196 |
+
t = get_epss_timesteps(steps, device=self.device, dtype=step_cond.dtype)
|
197 |
+
else:
|
198 |
+
t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype)
|
199 |
if sway_sampling_coef is not None:
|
200 |
t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
|
201 |
|
src/f5_tts/model/utils.py
CHANGED
@@ -189,3 +189,22 @@ def repetition_found(text, length=2, tolerance=10):
|
|
189 |
if count > tolerance:
|
190 |
return True
|
191 |
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
189 |
if count > tolerance:
|
190 |
return True
|
191 |
return False
|
192 |
+
|
193 |
+
|
194 |
+
# get the empirically pruned step for sampling
|
195 |
+
|
196 |
+
|
197 |
+
def get_epss_timesteps(n, device, dtype):
|
198 |
+
dt = 1 / 32
|
199 |
+
predefined_timesteps = {
|
200 |
+
5: [0, 2, 4, 8, 16, 32],
|
201 |
+
6: [0, 2, 4, 6, 8, 16, 32],
|
202 |
+
7: [0, 2, 4, 6, 8, 16, 24, 32],
|
203 |
+
10: [0, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32],
|
204 |
+
12: [0, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32],
|
205 |
+
16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32],
|
206 |
+
}
|
207 |
+
t = predefined_timesteps.get(n, [])
|
208 |
+
if not t:
|
209 |
+
return torch.linspace(0, 1, n + 1, device=device, dtype=dtype)
|
210 |
+
return dt * torch.tensor(t, device=device, dtype=dtype)
|