mrfakename committed
Commit 641fd3c · verified · 1 parent: 069a328

Sync from GitHub repo


This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the GitHub repo rather than to this Space.

src/f5_tts/model/cfm.py CHANGED

@@ -22,6 +22,7 @@ from f5_tts.model.modules import MelSpec
 from f5_tts.model.utils import (
     default,
     exists,
+    get_epss_timesteps,
     lens_to_mask,
     list_str_to_idx,
     list_str_to_tensor,
@@ -92,6 +93,7 @@ class CFM(nn.Module):
         seed: int | None = None,
         max_duration=4096,
         vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None,  # noqa: F722
+        use_epss=True,
         no_ref_audio=False,
         duplicate_test=False,
         t_inter=0.1,
@@ -190,7 +192,10 @@ class CFM(nn.Module):
             y0 = (1 - t_start) * y0 + t_start * test_cond
             steps = int(steps * (1 - t_start))
 
-        t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype)
+        if t_start == 0 and use_epss:  # use Empirically Pruned Step Sampling for low NFE
+            t = get_epss_timesteps(steps, device=self.device, dtype=step_cond.dtype)
+        else:
+            t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype)
         if sway_sampling_coef is not None:
             t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
 
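For orientation, a minimal standalone sketch (not part of the commit) of the timestep selection this change introduces; the steps, t_start, use_epss, and sway_sampling_coef values below are illustrative, and get_epss_timesteps is the helper added in src/f5_tts/model/utils.py (diff below).

import torch

# Sketch of the new timestep-schedule selection, assuming f5_tts is installed.
from f5_tts.model.utils import get_epss_timesteps

steps = 7                  # low NFE where the pruned schedule applies (illustrative)
t_start = 0                # stays 0 unless duplicate_test shifts the start time
use_epss = True
sway_sampling_coef = -1.0  # illustrative coefficient; None disables sway sampling

if t_start == 0 and use_epss:  # Empirically Pruned Step Sampling for low NFE
    t = get_epss_timesteps(steps, device="cpu", dtype=torch.float32)
else:
    t = torch.linspace(t_start, 1, steps + 1, device="cpu", dtype=torch.float32)

if sway_sampling_coef is not None:
    # sway sampling warps the schedule; a negative coefficient concentrates steps near t = 0
    t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)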
src/f5_tts/model/utils.py CHANGED

@@ -189,3 +189,22 @@ def repetition_found(text, length=2, tolerance=10):
         if count > tolerance:
             return True
     return False
+
+
+# get the empirically pruned step for sampling
+
+
+def get_epss_timesteps(n, device, dtype):
+    dt = 1 / 32
+    predefined_timesteps = {
+        5: [0, 2, 4, 8, 16, 32],
+        6: [0, 2, 4, 6, 8, 16, 32],
+        7: [0, 2, 4, 6, 8, 16, 24, 32],
+        10: [0, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32],
+        12: [0, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32],
+        16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32],
+    }
+    t = predefined_timesteps.get(n, [])
+    if not t:
+        return torch.linspace(0, 1, n + 1, device=device, dtype=dtype)
+    return dt * torch.tensor(t, device=device, dtype=dtype)
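
For reference, a small usage sketch of the helper above (assuming f5_tts is installed so it can be imported): step counts with a predefined schedule return those indices scaled by dt = 1/32, and any other count falls back to a uniform grid.

import torch

from f5_tts.model.utils import get_epss_timesteps  # the helper added above

# 7 steps has a predefined schedule: [0, 2, 4, 6, 8, 16, 24, 32] * (1 / 32)
t7 = get_epss_timesteps(7, device="cpu", dtype=torch.float32)
# -> tensor([0.0000, 0.0625, 0.1250, 0.1875, 0.2500, 0.5000, 0.7500, 1.0000])

# 9 steps has no predefined schedule, so the helper falls back to a uniform grid,
# equivalent to torch.linspace(0, 1, 10)
t9 = get_epss_timesteps(9, device="cpu", dtype=torch.float32)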