Spaces:
Sleeping
Sleeping
Michael Hu
commited on
Commit
·
3ed3b5a
1
Parent(s):
31708ca
refator tts part
Browse files- utils/tts.py +86 -312
- utils/tts_base.py +152 -0
- utils/tts_dia.py +14 -115
- utils/tts_dummy.py +8 -22
- utils/tts_engines.py +322 -0
- utils/tts_factory.py +118 -0
utils/tts.py
CHANGED
@@ -1,85 +1,52 @@
|
|
1 |
-
import os
|
2 |
import logging
|
3 |
-
import time
|
4 |
-
import soundfile as sf
|
5 |
-
from gradio_client import Client
|
6 |
-
|
7 |
|
|
|
8 |
logger = logging.getLogger(__name__)
|
9 |
|
10 |
-
#
|
11 |
-
|
12 |
-
|
13 |
-
DIA_AVAILABLE = False
|
14 |
|
15 |
-
#
|
16 |
-
|
17 |
-
from kokoro import KPipeline
|
18 |
-
KOKORO_AVAILABLE = True
|
19 |
-
logger.info("Kokoro TTS engine is available")
|
20 |
-
except AttributeError as e:
|
21 |
-
# Specifically catch the EspeakWrapper.set_data_path error
|
22 |
-
if "EspeakWrapper" in str(e) and "set_data_path" in str(e):
|
23 |
-
logger.warning("Kokoro import failed due to EspeakWrapper.set_data_path issue, falling back to Kokoro FastAPI server")
|
24 |
-
else:
|
25 |
-
# Re-raise if it's a different error
|
26 |
-
logger.error(f"Kokoro import failed with unexpected error: {str(e)}")
|
27 |
-
raise
|
28 |
-
except ImportError:
|
29 |
-
logger.warning("Kokoro TTS engine is not available")
|
30 |
|
|
|
31 |
class TTSEngine:
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
def __init__(self, lang_code='z'):
|
33 |
-
"""Initialize TTS Engine
|
34 |
|
35 |
Args:
|
36 |
lang_code (str): Language code ('a' for US English, 'b' for British English,
|
37 |
'j' for Japanese, 'z' for Mandarin Chinese)
|
38 |
-
Note: lang_code is only used for Kokoro, not for Dia
|
39 |
"""
|
40 |
-
logger.info("Initializing
|
41 |
logger.info(f"Available engines - Kokoro: {KOKORO_AVAILABLE}, Dia: {DIA_AVAILABLE}")
|
42 |
-
self.engine_type = None
|
43 |
|
44 |
-
|
45 |
-
|
46 |
-
try:
|
47 |
-
self.pipeline = KPipeline(lang_code=lang_code)
|
48 |
-
self.engine_type = "kokoro"
|
49 |
-
logger.info("TTS engine successfully initialized with Kokoro")
|
50 |
-
except Exception as kokoro_err:
|
51 |
-
logger.error(f"Failed to initialize Kokoro pipeline: {str(kokoro_err)}")
|
52 |
-
logger.error(f"Error type: {type(kokoro_err).__name__}")
|
53 |
-
logger.info("Will try to fall back to Dia TTS engine")
|
54 |
-
|
55 |
-
if KOKORO_SPACE_AVAILABLE:
|
56 |
-
logger.info(f"Using Kokoro FastAPI server as primary TTS engine with language code: {lang_code}")
|
57 |
-
try:
|
58 |
-
self.client = Client("Remsky/Kokoro-TTS-Zero")
|
59 |
-
self.engine_type = "kokoro_space"
|
60 |
-
logger.info("TTS engine successfully initialized with Kokoro FastAPI server")
|
61 |
-
except Exception as kokoro_err:
|
62 |
-
logger.error(f"Failed to initialize Kokoro space: {str(kokoro_err)}")
|
63 |
-
logger.error(f"Error type: {type(kokoro_err).__name__}")
|
64 |
-
logger.info("Will try to fall back to Dia TTS engine")
|
65 |
-
|
66 |
-
# Try Dia if Kokoro is not available or failed to initialize
|
67 |
-
if self.engine_type is None and DIA_AVAILABLE:
|
68 |
-
logger.info("Using Dia as fallback TTS engine")
|
69 |
-
# For Dia, we don't need to initialize anything here
|
70 |
-
# The model will be lazy-loaded when needed
|
71 |
-
self.pipeline = None
|
72 |
-
self.client = None
|
73 |
-
self.engine_type = "dia"
|
74 |
-
logger.info("TTS engine initialized with Dia (lazy loading)")
|
75 |
|
76 |
-
#
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
self.
|
|
|
|
|
|
|
82 |
self.engine_type = "dummy"
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
|
84 |
def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
|
85 |
"""Generate speech from text using available TTS engine
|
@@ -87,272 +54,79 @@ class TTSEngine:
|
|
87 |
Args:
|
88 |
text (str): Input text to synthesize
|
89 |
voice (str): Voice ID to use (e.g., 'af_heart', 'af_bella', etc.)
|
90 |
-
Note: voice parameter is only used for Kokoro, not for Dia
|
91 |
speed (float): Speech speed multiplier (0.5 to 2.0)
|
92 |
-
Note: speed parameter is only used for Kokoro, not for Dia
|
93 |
|
94 |
Returns:
|
95 |
str: Path to the generated audio file
|
96 |
"""
|
97 |
-
logger.info(f"
|
98 |
-
|
99 |
-
|
100 |
-
# Create output directory if it doesn't exist
|
101 |
-
os.makedirs("temp/outputs", exist_ok=True)
|
102 |
-
|
103 |
-
# Generate unique output path
|
104 |
-
output_path = f"temp/outputs/output_{int(time.time())}.wav"
|
105 |
-
|
106 |
-
# Use the appropriate TTS engine based on availability
|
107 |
-
if self.engine_type == "kokoro":
|
108 |
-
# Use Kokoro for TTS generation
|
109 |
-
generator = self.pipeline(text, voice=voice, speed=speed)
|
110 |
-
for _, _, audio in generator:
|
111 |
-
logger.info(f"Saving Kokoro audio to {output_path}")
|
112 |
-
sf.write(output_path, audio, 24000)
|
113 |
-
break
|
114 |
-
elif self.engine_type == "kokoro_space":
|
115 |
-
# Use Kokoro FastAPI server for TTS generation
|
116 |
-
logger.info("Generating speech using Kokoro FastAPI server")
|
117 |
-
logger.info(f"text to generate speech on is: {text}")
|
118 |
-
try:
|
119 |
-
result = self.client.predict(
|
120 |
-
text=text,
|
121 |
-
voice_names='af_nova',
|
122 |
-
speed=speed,
|
123 |
-
api_name="/generate_speech_from_ui"
|
124 |
-
)
|
125 |
-
logger.info(f"Received audio from Kokoro FastAPI server: {result}")
|
126 |
-
except Exception as e:
|
127 |
-
logger.error(f"Failed to generate speech from Kokoro FastAPI server: {str(e)}")
|
128 |
-
logger.error(f"Error type: {type(e).__name__}")
|
129 |
-
logger.info("Falling back to dummy audio generation")
|
130 |
-
elif self.engine_type == "dia":
|
131 |
-
# Use Dia for TTS generation
|
132 |
-
try:
|
133 |
-
logger.info("Attempting to use Dia TTS for speech generation")
|
134 |
-
# Import here to avoid circular imports
|
135 |
-
try:
|
136 |
-
logger.info("Importing Dia speech generation module")
|
137 |
-
from utils.tts_dia import generate_speech as dia_generate_speech
|
138 |
-
logger.info("Successfully imported Dia speech generation function")
|
139 |
-
except ImportError as import_err:
|
140 |
-
logger.error(f"Failed to import Dia speech generation function: {str(import_err)}")
|
141 |
-
logger.error(f"Import path: {import_err.__traceback__.tb_frame.f_globals.get('__name__', 'unknown')}")
|
142 |
-
raise
|
143 |
-
|
144 |
-
# Call Dia's generate_speech function
|
145 |
-
logger.info("Calling Dia's generate_speech function")
|
146 |
-
output_path = dia_generate_speech(text)
|
147 |
-
logger.info(f"Generated audio with Dia: {output_path}")
|
148 |
-
except ImportError as import_err:
|
149 |
-
logger.error(f"Dia TTS generation failed due to import error: {str(import_err)}")
|
150 |
-
logger.error("Falling back to dummy audio generation")
|
151 |
-
return self._generate_dummy_audio(output_path)
|
152 |
-
except Exception as dia_error:
|
153 |
-
logger.error(f"Dia TTS generation failed: {str(dia_error)}", exc_info=True)
|
154 |
-
logger.error(f"Error type: {type(dia_error).__name__}")
|
155 |
-
logger.error("Falling back to dummy audio generation")
|
156 |
-
# Fall back to dummy audio if Dia fails
|
157 |
-
return self._generate_dummy_audio(output_path)
|
158 |
-
else:
|
159 |
-
# Generate dummy audio as fallback
|
160 |
-
return self._generate_dummy_audio(output_path)
|
161 |
-
|
162 |
-
logger.info(f"Audio generation complete: {output_path}")
|
163 |
-
return output_path
|
164 |
-
|
165 |
-
except Exception as e:
|
166 |
-
logger.error(f"TTS generation failed: {str(e)}", exc_info=True)
|
167 |
-
raise
|
168 |
-
|
169 |
-
def _generate_dummy_audio(self, output_path):
|
170 |
-
"""Generate a dummy audio file with a simple sine wave
|
171 |
-
|
172 |
-
Args:
|
173 |
-
output_path (str): Path to save the dummy audio file
|
174 |
-
|
175 |
-
Returns:
|
176 |
-
str: Path to the generated dummy audio file
|
177 |
-
"""
|
178 |
-
import numpy as np
|
179 |
-
sample_rate = 24000
|
180 |
-
duration = 3.0 # seconds
|
181 |
-
t = np.linspace(0, duration, int(sample_rate * duration), False)
|
182 |
-
tone = np.sin(2 * np.pi * 440 * t) * 0.3
|
183 |
-
|
184 |
-
logger.info(f"Saving dummy audio to {output_path}")
|
185 |
-
sf.write(output_path, tone, sample_rate)
|
186 |
-
logger.info(f"Dummy audio generation complete: {output_path}")
|
187 |
-
return output_path
|
188 |
-
|
189 |
def generate_speech_stream(self, text: str, voice: str = 'af_heart', speed: float = 1.0):
|
190 |
"""Generate speech from text and yield each segment
|
191 |
|
192 |
Args:
|
193 |
text (str): Input text to synthesize
|
194 |
-
voice (str): Voice ID to use
|
195 |
-
speed (float): Speech speed multiplier
|
196 |
|
197 |
Yields:
|
198 |
tuple: (sample_rate, audio_data) pairs for each segment
|
199 |
"""
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
# and then yield it as a single chunk
|
210 |
-
try:
|
211 |
-
logger.info("Attempting to use Dia TTS for speech streaming")
|
212 |
-
# Import here to avoid circular imports
|
213 |
-
try:
|
214 |
-
logger.info("Importing required modules for Dia streaming")
|
215 |
-
import torch
|
216 |
-
logger.info("PyTorch successfully imported for Dia streaming")
|
217 |
-
|
218 |
-
try:
|
219 |
-
from utils.tts_dia import _get_model, DEFAULT_SAMPLE_RATE
|
220 |
-
logger.info("Successfully imported Dia model and sample rate")
|
221 |
-
except ImportError as import_err:
|
222 |
-
logger.error(f"Failed to import Dia model for streaming: {str(import_err)}")
|
223 |
-
logger.error(f"Import path: {import_err.__traceback__.tb_frame.f_globals.get('__name__', 'unknown')}")
|
224 |
-
raise
|
225 |
-
except ImportError as torch_err:
|
226 |
-
logger.error(f"PyTorch import failed for Dia streaming: {str(torch_err)}")
|
227 |
-
raise
|
228 |
-
|
229 |
-
# Get the Dia model
|
230 |
-
logger.info("Getting Dia model instance")
|
231 |
-
try:
|
232 |
-
model = _get_model()
|
233 |
-
logger.info("Successfully obtained Dia model instance")
|
234 |
-
except Exception as model_err:
|
235 |
-
logger.error(f"Failed to get Dia model instance: {str(model_err)}")
|
236 |
-
logger.error(f"Error type: {type(model_err).__name__}")
|
237 |
-
raise
|
238 |
-
|
239 |
-
# Generate audio
|
240 |
-
logger.info("Generating audio with Dia model")
|
241 |
-
with torch.inference_mode():
|
242 |
-
output_audio_np = model.generate(
|
243 |
-
text,
|
244 |
-
max_tokens=None,
|
245 |
-
cfg_scale=3.0,
|
246 |
-
temperature=1.3,
|
247 |
-
top_p=0.95,
|
248 |
-
cfg_filter_top_k=35,
|
249 |
-
use_torch_compile=False,
|
250 |
-
verbose=False
|
251 |
-
)
|
252 |
-
|
253 |
-
if output_audio_np is not None:
|
254 |
-
logger.info(f"Successfully generated audio with Dia (length: {len(output_audio_np)})")
|
255 |
-
yield DEFAULT_SAMPLE_RATE, output_audio_np
|
256 |
-
else:
|
257 |
-
logger.warning("Dia model returned None for audio output")
|
258 |
-
logger.warning("Falling back to dummy audio stream")
|
259 |
-
# Fall back to dummy audio if Dia fails
|
260 |
-
yield from self._generate_dummy_audio_stream()
|
261 |
-
except ImportError as import_err:
|
262 |
-
logger.error(f"Dia TTS streaming failed due to import error: {str(import_err)}")
|
263 |
-
logger.error("Falling back to dummy audio stream")
|
264 |
-
# Fall back to dummy audio if Dia fails
|
265 |
-
yield from self._generate_dummy_audio_stream()
|
266 |
-
except Exception as dia_error:
|
267 |
-
logger.error(f"Dia TTS streaming failed: {str(dia_error)}", exc_info=True)
|
268 |
-
logger.error(f"Error type: {type(dia_error).__name__}")
|
269 |
-
logger.error("Falling back to dummy audio stream")
|
270 |
-
# Fall back to dummy audio if Dia fails
|
271 |
-
yield from self._generate_dummy_audio_stream()
|
272 |
-
else:
|
273 |
-
# Generate dummy audio chunks as fallback
|
274 |
-
yield from self._generate_dummy_audio_stream()
|
275 |
-
|
276 |
-
except Exception as e:
|
277 |
-
logger.error(f"TTS streaming failed: {str(e)}", exc_info=True)
|
278 |
-
raise
|
279 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
280 |
def _generate_dummy_audio_stream(self):
|
281 |
-
"""Generate dummy audio chunks
|
282 |
|
283 |
Yields:
|
284 |
tuple: (sample_rate, audio_data) pairs for each dummy segment
|
285 |
"""
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
# Create 3 chunks of dummy audio
|
291 |
-
for i in range(3):
|
292 |
-
t = np.linspace(0, duration, int(sample_rate * duration), False)
|
293 |
-
freq = 440 + (i * 220) # Different frequency for each chunk
|
294 |
-
tone = np.sin(2 * np.pi * freq * t) * 0.3
|
295 |
-
yield sample_rate, tone
|
296 |
|
297 |
-
#
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
Args:
|
302 |
-
lang_code (str): Language code for the pipeline
|
303 |
-
|
304 |
-
Returns:
|
305 |
-
TTSEngine: Initialized TTS engine instance
|
306 |
-
"""
|
307 |
-
logger.info(f"Requesting TTS engine with language code: {lang_code}")
|
308 |
-
try:
|
309 |
-
import streamlit as st
|
310 |
-
logger.info("Streamlit detected, using cached TTS engine")
|
311 |
-
@st.cache_resource
|
312 |
-
def _get_engine():
|
313 |
-
logger.info("Creating cached TTS engine instance")
|
314 |
-
engine = TTSEngine(lang_code)
|
315 |
-
logger.info(f"Cached TTS engine created with type: {engine.engine_type}")
|
316 |
-
return engine
|
317 |
-
|
318 |
-
engine = _get_engine()
|
319 |
-
logger.info(f"Retrieved TTS engine from cache with type: {engine.engine_type}")
|
320 |
-
return engine
|
321 |
-
except ImportError:
|
322 |
-
logger.info("Streamlit not available, creating direct TTS engine instance")
|
323 |
-
engine = TTSEngine(lang_code)
|
324 |
-
logger.info(f"Direct TTS engine created with type: {engine.engine_type}")
|
325 |
-
return engine
|
326 |
|
327 |
-
def
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
logger.error(f"Error type: {type(e).__name__}")
|
353 |
-
if hasattr(e, '__traceback__'):
|
354 |
-
tb = e.__traceback__
|
355 |
-
while tb.tb_next:
|
356 |
-
tb = tb.tb_next
|
357 |
-
logger.error(f"Error occurred in file: {tb.tb_frame.f_code.co_filename}, line {tb.tb_lineno}")
|
358 |
-
raise
|
|
|
|
|
1 |
import logging
|
|
|
|
|
|
|
|
|
2 |
|
3 |
+
# Configure logging
|
4 |
logger = logging.getLogger(__name__)
|
5 |
|
6 |
+
# Import from the new factory pattern implementation
|
7 |
+
from utils.tts_factory import get_tts_engine, generate_speech, TTSFactory
|
8 |
+
from utils.tts_engines import get_available_engines
|
|
|
9 |
|
10 |
+
# For backward compatibility
|
11 |
+
from utils.tts_engines import KOKORO_AVAILABLE, KOKORO_SPACE_AVAILABLE, DIA_AVAILABLE
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
+
# Backward compatibility class
|
14 |
class TTSEngine:
|
15 |
+
"""Legacy TTSEngine class for backward compatibility
|
16 |
+
|
17 |
+
This class is maintained for backward compatibility with existing code.
|
18 |
+
New code should use the factory pattern implementation directly.
|
19 |
+
"""
|
20 |
+
|
21 |
def __init__(self, lang_code='z'):
|
22 |
+
"""Initialize TTS Engine using the factory pattern
|
23 |
|
24 |
Args:
|
25 |
lang_code (str): Language code ('a' for US English, 'b' for British English,
|
26 |
'j' for Japanese, 'z' for Mandarin Chinese)
|
|
|
27 |
"""
|
28 |
+
logger.info("Initializing legacy TTSEngine wrapper")
|
29 |
logger.info(f"Available engines - Kokoro: {KOKORO_AVAILABLE}, Dia: {DIA_AVAILABLE}")
|
|
|
30 |
|
31 |
+
# Create the appropriate engine using the factory
|
32 |
+
self._engine = TTSFactory.create_engine(lang_code=lang_code)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
+
# Set engine_type for backward compatibility
|
35 |
+
engine_class = self._engine.__class__.__name__
|
36 |
+
if 'Kokoro' in engine_class and 'Space' in engine_class:
|
37 |
+
self.engine_type = "kokoro_space"
|
38 |
+
elif 'Kokoro' in engine_class:
|
39 |
+
self.engine_type = "kokoro"
|
40 |
+
elif 'Dia' in engine_class:
|
41 |
+
self.engine_type = "dia"
|
42 |
+
else:
|
43 |
self.engine_type = "dummy"
|
44 |
+
|
45 |
+
# Set pipeline and client attributes for backward compatibility
|
46 |
+
self.pipeline = getattr(self._engine, 'pipeline', None)
|
47 |
+
self.client = getattr(self._engine, 'client', None)
|
48 |
+
|
49 |
+
logger.info(f"Legacy TTSEngine wrapper initialized with engine type: {self.engine_type}")
|
50 |
|
51 |
def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
|
52 |
"""Generate speech from text using available TTS engine
|
|
|
54 |
Args:
|
55 |
text (str): Input text to synthesize
|
56 |
voice (str): Voice ID to use (e.g., 'af_heart', 'af_bella', etc.)
|
|
|
57 |
speed (float): Speech speed multiplier (0.5 to 2.0)
|
|
|
58 |
|
59 |
Returns:
|
60 |
str: Path to the generated audio file
|
61 |
"""
|
62 |
+
logger.info(f"Legacy TTSEngine wrapper calling generate_speech for text length: {len(text)}")
|
63 |
+
return self._engine.generate_speech(text, voice, speed)
|
64 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
def generate_speech_stream(self, text: str, voice: str = 'af_heart', speed: float = 1.0):
|
66 |
"""Generate speech from text and yield each segment
|
67 |
|
68 |
Args:
|
69 |
text (str): Input text to synthesize
|
70 |
+
voice (str): Voice ID to use
|
71 |
+
speed (float): Speech speed multiplier
|
72 |
|
73 |
Yields:
|
74 |
tuple: (sample_rate, audio_data) pairs for each segment
|
75 |
"""
|
76 |
+
logger.info(f"Legacy TTSEngine wrapper calling generate_speech_stream for text length: {len(text)}")
|
77 |
+
yield from self._engine.generate_speech_stream(text, voice, speed)
|
78 |
+
|
79 |
+
# For backward compatibility
|
80 |
+
def _generate_dummy_audio(self, output_path):
|
81 |
+
"""Generate a dummy audio file with a simple sine wave (backward compatibility)
|
82 |
+
|
83 |
+
Args:
|
84 |
+
output_path (str): Path to save the dummy audio file
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
|
86 |
+
Returns:
|
87 |
+
str: Path to the generated dummy audio file
|
88 |
+
"""
|
89 |
+
from utils.tts_base import DummyTTSEngine
|
90 |
+
dummy_engine = DummyTTSEngine()
|
91 |
+
return dummy_engine.generate_speech("", "", 1.0)
|
92 |
+
|
93 |
+
# For backward compatibility
|
94 |
def _generate_dummy_audio_stream(self):
|
95 |
+
"""Generate dummy audio chunks (backward compatibility)
|
96 |
|
97 |
Yields:
|
98 |
tuple: (sample_rate, audio_data) pairs for each dummy segment
|
99 |
"""
|
100 |
+
from utils.tts_base import DummyTTSEngine
|
101 |
+
dummy_engine = DummyTTSEngine()
|
102 |
+
yield from dummy_engine.generate_speech_stream("", "", 1.0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
|
104 |
+
# Import the new implementations from tts_base
|
105 |
+
# These functions are already defined in tts_base.py and imported at the top of this file
|
106 |
+
# They are kept here as comments for reference
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
|
108 |
+
# def get_tts_engine(lang_code='a'):
|
109 |
+
# """Get or create TTS engine instance
|
110 |
+
#
|
111 |
+
# Args:
|
112 |
+
# lang_code (str): Language code for the pipeline
|
113 |
+
#
|
114 |
+
# Returns:
|
115 |
+
# TTSEngineBase: Initialized TTS engine instance
|
116 |
+
# """
|
117 |
+
# # Implementation moved to tts_base.py
|
118 |
+
# pass
|
119 |
+
|
120 |
+
# def generate_speech(text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
|
121 |
+
# """Public interface for TTS generation
|
122 |
+
#
|
123 |
+
# Args:
|
124 |
+
# text (str): Input text to synthesize
|
125 |
+
# voice (str): Voice ID to use
|
126 |
+
# speed (float): Speech speed multiplier
|
127 |
+
#
|
128 |
+
# Returns:
|
129 |
+
# str: Path to generated audio file
|
130 |
+
# "\"""
|
131 |
+
# # Implementation moved to tts_base.py
|
132 |
+
# pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/tts_base.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import logging
|
4 |
+
import soundfile as sf
|
5 |
+
import numpy as np
|
6 |
+
from abc import ABC, abstractmethod
|
7 |
+
from typing import Tuple, Generator, Optional
|
8 |
+
|
9 |
+
# Configure logging
|
10 |
+
logger = logging.getLogger(__name__)
|
11 |
+
|
12 |
+
class TTSEngineBase(ABC):
|
13 |
+
"""Base class for all TTS engines
|
14 |
+
|
15 |
+
This abstract class defines the interface that all TTS engines must implement.
|
16 |
+
It also provides common utility methods for file handling and audio generation.
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(self, lang_code: str = 'z'):
|
20 |
+
"""Initialize the TTS engine
|
21 |
+
|
22 |
+
Args:
|
23 |
+
lang_code (str): Language code ('a' for US English, 'b' for British English,
|
24 |
+
'j' for Japanese, 'z' for Mandarin Chinese)
|
25 |
+
Note: Not all engines support all language codes
|
26 |
+
"""
|
27 |
+
self.lang_code = lang_code
|
28 |
+
logger.info(f"Initializing {self.__class__.__name__} with language code: {lang_code}")
|
29 |
+
|
30 |
+
@abstractmethod
|
31 |
+
def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
|
32 |
+
"""Generate speech from text
|
33 |
+
|
34 |
+
Args:
|
35 |
+
text (str): Input text to synthesize
|
36 |
+
voice (str): Voice ID to use (e.g., 'af_heart', 'af_bella', etc.)
|
37 |
+
Note: Not all engines support all voices
|
38 |
+
speed (float): Speech speed multiplier (0.5 to 2.0)
|
39 |
+
Note: Not all engines support speed adjustment
|
40 |
+
|
41 |
+
Returns:
|
42 |
+
str: Path to the generated audio file
|
43 |
+
"""
|
44 |
+
pass
|
45 |
+
|
46 |
+
def generate_speech_stream(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> Generator[Tuple[int, np.ndarray], None, None]:
|
47 |
+
"""Generate speech from text and yield each segment
|
48 |
+
|
49 |
+
Args:
|
50 |
+
text (str): Input text to synthesize
|
51 |
+
voice (str): Voice ID to use
|
52 |
+
speed (float): Speech speed multiplier
|
53 |
+
|
54 |
+
Yields:
|
55 |
+
tuple: (sample_rate, audio_data) pairs for each segment
|
56 |
+
"""
|
57 |
+
# Default implementation: generate full audio and yield as a single chunk
|
58 |
+
output_path = self.generate_speech(text, voice, speed)
|
59 |
+
audio_data, sample_rate = sf.read(output_path)
|
60 |
+
yield sample_rate, audio_data
|
61 |
+
|
62 |
+
def _create_output_dir(self) -> str:
|
63 |
+
"""Create output directory for audio files
|
64 |
+
|
65 |
+
Returns:
|
66 |
+
str: Path to the output directory
|
67 |
+
"""
|
68 |
+
output_dir = "temp/outputs"
|
69 |
+
os.makedirs(output_dir, exist_ok=True)
|
70 |
+
return output_dir
|
71 |
+
|
72 |
+
def _generate_output_path(self, prefix: str = "output") -> str:
|
73 |
+
"""Generate a unique output path for audio files
|
74 |
+
|
75 |
+
Args:
|
76 |
+
prefix (str): Prefix for the output filename
|
77 |
+
|
78 |
+
Returns:
|
79 |
+
str: Path to the output file
|
80 |
+
"""
|
81 |
+
output_dir = self._create_output_dir()
|
82 |
+
timestamp = int(time.time())
|
83 |
+
return f"{output_dir}/{prefix}_{timestamp}.wav"
|
84 |
+
|
85 |
+
|
86 |
+
class DummyTTSEngine(TTSEngineBase):
|
87 |
+
"""Dummy TTS engine that generates a simple sine wave
|
88 |
+
|
89 |
+
This engine is used as a fallback when no other engines are available.
|
90 |
+
"""
|
91 |
+
|
92 |
+
def __init__(self, lang_code: str = 'z'):
|
93 |
+
super().__init__(lang_code)
|
94 |
+
logger.warning("Using dummy TTS implementation as no other engines are available")
|
95 |
+
|
96 |
+
def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
|
97 |
+
"""Generate a dummy audio file with a simple sine wave
|
98 |
+
|
99 |
+
Args:
|
100 |
+
text (str): Input text (not used)
|
101 |
+
voice (str): Voice ID (not used)
|
102 |
+
speed (float): Speed multiplier (not used)
|
103 |
+
|
104 |
+
Returns:
|
105 |
+
str: Path to the generated dummy audio file
|
106 |
+
"""
|
107 |
+
logger.info(f"Generating dummy speech for text length: {len(text)}")
|
108 |
+
|
109 |
+
# Generate unique output path
|
110 |
+
output_path = self._generate_output_path("dummy")
|
111 |
+
|
112 |
+
# Generate a simple sine wave
|
113 |
+
sample_rate = 24000
|
114 |
+
duration = 3.0 # seconds
|
115 |
+
t = np.linspace(0, duration, int(sample_rate * duration), False)
|
116 |
+
tone = np.sin(2 * np.pi * 440 * t) * 0.3
|
117 |
+
|
118 |
+
# Save the audio file
|
119 |
+
logger.info(f"Saving dummy audio to {output_path}")
|
120 |
+
sf.write(output_path, tone, sample_rate)
|
121 |
+
logger.info(f"Dummy audio generation complete: {output_path}")
|
122 |
+
|
123 |
+
return output_path
|
124 |
+
|
125 |
+
def generate_speech_stream(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> Generator[Tuple[int, np.ndarray], None, None]:
|
126 |
+
"""Generate dummy audio chunks with simple sine waves
|
127 |
+
|
128 |
+
Args:
|
129 |
+
text (str): Input text (not used)
|
130 |
+
voice (str): Voice ID (not used)
|
131 |
+
speed (float): Speed multiplier (not used)
|
132 |
+
|
133 |
+
Yields:
|
134 |
+
tuple: (sample_rate, audio_data) pairs for each dummy segment
|
135 |
+
"""
|
136 |
+
logger.info(f"Generating dummy speech stream for text length: {len(text)}")
|
137 |
+
|
138 |
+
sample_rate = 24000
|
139 |
+
duration = 1.0 # seconds per chunk
|
140 |
+
|
141 |
+
# Create 3 chunks of dummy audio
|
142 |
+
for i in range(3):
|
143 |
+
t = np.linspace(0, duration, int(sample_rate * duration), False)
|
144 |
+
freq = 440 + (i * 220) # Different frequency for each chunk
|
145 |
+
tone = np.sin(2 * np.pi * freq * t) * 0.3
|
146 |
+
yield sample_rate, tone
|
147 |
+
|
148 |
+
|
149 |
+
# Factory functionality moved to tts_factory.py to avoid circular imports
|
150 |
+
|
151 |
+
|
152 |
+
# Note: Backward compatibility functions moved to tts_factory.py
|
utils/tts_dia.py
CHANGED
@@ -68,6 +68,9 @@ def _get_model() -> Dia:
|
|
68 |
def generate_speech(text: str, language: str = "zh") -> str:
|
69 |
"""Public interface for TTS generation using Dia model
|
70 |
|
|
|
|
|
|
|
71 |
Args:
|
72 |
text (str): Input text to synthesize
|
73 |
language (str): Language code (not used in Dia model, kept for API compatibility)
|
@@ -75,122 +78,18 @@ def generate_speech(text: str, language: str = "zh") -> str:
|
|
75 |
Returns:
|
76 |
str: Path to the generated audio file
|
77 |
"""
|
78 |
-
logger.info(f"
|
79 |
-
logger.info(f"Text content (first 50 chars): {text[:50]}...")
|
80 |
-
|
81 |
-
# Create output directory if it doesn't exist
|
82 |
-
output_dir = "temp/outputs"
|
83 |
-
logger.info(f"Ensuring output directory exists: {output_dir}")
|
84 |
-
try:
|
85 |
-
os.makedirs(output_dir, exist_ok=True)
|
86 |
-
logger.info(f"Output directory ready: {output_dir}")
|
87 |
-
except PermissionError as perm_err:
|
88 |
-
logger.error(f"Permission error creating output directory: {perm_err}")
|
89 |
-
# Fall back to dummy TTS
|
90 |
-
logger.info("Falling back to dummy TTS due to directory creation error")
|
91 |
-
from utils.tts_dummy import generate_speech as dummy_generate_speech
|
92 |
-
return dummy_generate_speech(text, language)
|
93 |
-
except Exception as dir_err:
|
94 |
-
logger.error(f"Error creating output directory: {dir_err}")
|
95 |
-
# Fall back to dummy TTS
|
96 |
-
logger.info("Falling back to dummy TTS due to directory creation error")
|
97 |
-
from utils.tts_dummy import generate_speech as dummy_generate_speech
|
98 |
-
return dummy_generate_speech(text, language)
|
99 |
-
|
100 |
-
# Generate unique output path
|
101 |
-
timestamp = int(time.time())
|
102 |
-
output_path = f"{output_dir}/output_{timestamp}.wav"
|
103 |
-
logger.info(f"Output will be saved to: {output_path}")
|
104 |
-
|
105 |
-
# Get the model
|
106 |
-
logger.info("Retrieving Dia model instance")
|
107 |
-
try:
|
108 |
-
model = _get_model()
|
109 |
-
logger.info("Successfully retrieved Dia model instance")
|
110 |
-
except Exception as model_err:
|
111 |
-
logger.error(f"Failed to get Dia model: {model_err}")
|
112 |
-
logger.error(f"Error type: {type(model_err).__name__}")
|
113 |
-
# Fall back to dummy TTS
|
114 |
-
logger.info("Falling back to dummy TTS due to model loading error")
|
115 |
-
from utils.tts_dummy import generate_speech as dummy_generate_speech
|
116 |
-
return dummy_generate_speech(text, language)
|
117 |
|
118 |
-
#
|
119 |
-
|
120 |
-
start_time = time.time()
|
121 |
|
122 |
try:
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
cfg_scale=3.0,
|
129 |
-
temperature=1.3,
|
130 |
-
top_p=0.95,
|
131 |
-
cfg_filter_top_k=35,
|
132 |
-
use_torch_compile=False, # Keep False for stability
|
133 |
-
verbose=False
|
134 |
-
)
|
135 |
-
logger.info("Model.generate() completed")
|
136 |
-
except RuntimeError as rt_err:
|
137 |
-
logger.error(f"Runtime error during generation: {rt_err}")
|
138 |
-
if "CUDA out of memory" in str(rt_err):
|
139 |
-
logger.error("CUDA out of memory error - consider reducing batch size or model size")
|
140 |
# Fall back to dummy TTS
|
141 |
-
|
142 |
-
|
143 |
-
return
|
144 |
-
except Exception as gen_err:
|
145 |
-
logger.error(f"Error during audio generation: {gen_err}")
|
146 |
-
logger.error(f"Error type: {type(gen_err).__name__}")
|
147 |
-
# Fall back to dummy TTS
|
148 |
-
logger.info("Falling back to dummy TTS due to error during generation")
|
149 |
-
from utils.tts_dummy import generate_speech as dummy_generate_speech
|
150 |
-
return dummy_generate_speech(text, language)
|
151 |
-
|
152 |
-
end_time = time.time()
|
153 |
-
generation_time = end_time - start_time
|
154 |
-
logger.info(f"Generation finished in {generation_time:.2f} seconds")
|
155 |
-
|
156 |
-
# Process the output
|
157 |
-
if output_audio_np is not None:
|
158 |
-
logger.info(f"Generated audio array shape: {output_audio_np.shape}, dtype: {output_audio_np.dtype}")
|
159 |
-
logger.info(f"Audio stats - min: {output_audio_np.min():.4f}, max: {output_audio_np.max():.4f}, mean: {output_audio_np.mean():.4f}")
|
160 |
-
|
161 |
-
# Apply a slight slowdown for better quality (0.94x speed)
|
162 |
-
speed_factor = 0.94
|
163 |
-
original_len = len(output_audio_np)
|
164 |
-
target_len = int(original_len / speed_factor)
|
165 |
-
|
166 |
-
logger.info(f"Applying speed adjustment factor: {speed_factor}")
|
167 |
-
if target_len != original_len and target_len > 0:
|
168 |
-
try:
|
169 |
-
x_original = np.arange(original_len)
|
170 |
-
x_resampled = np.linspace(0, original_len - 1, target_len)
|
171 |
-
output_audio_np = np.interp(x_resampled, x_original, output_audio_np)
|
172 |
-
logger.info(f"Resampled audio from {original_len} to {target_len} samples for {speed_factor:.2f}x speed")
|
173 |
-
except Exception as resample_err:
|
174 |
-
logger.error(f"Error during audio resampling: {resample_err}")
|
175 |
-
logger.warning("Using original audio without resampling")
|
176 |
-
|
177 |
-
# Save the audio file
|
178 |
-
logger.info(f"Saving audio to file: {output_path}")
|
179 |
-
try:
|
180 |
-
sf.write(output_path, output_audio_np, DEFAULT_SAMPLE_RATE)
|
181 |
-
logger.info(f"Audio successfully saved to {output_path}")
|
182 |
-
except Exception as save_err:
|
183 |
-
logger.error(f"Error saving audio file: {save_err}")
|
184 |
-
logger.error(f"Error type: {type(save_err).__name__}")
|
185 |
-
# Fall back to dummy TTS
|
186 |
-
logger.info("Falling back to dummy TTS due to error saving audio file")
|
187 |
-
from utils.tts_dummy import generate_speech as dummy_generate_speech
|
188 |
-
return dummy_generate_speech(text, language)
|
189 |
-
|
190 |
-
return output_path
|
191 |
-
else:
|
192 |
-
logger.warning("Generation produced no output (None returned from model)")
|
193 |
-
logger.warning("This may indicate a model configuration issue or empty input text")
|
194 |
-
dummy_path = f"{output_dir}/dummy_{timestamp}.wav"
|
195 |
-
logger.warning(f"Returning dummy audio path: {dummy_path}")
|
196 |
-
return dummy_path
|
|
|
68 |
def generate_speech(text: str, language: str = "zh") -> str:
    """Public interface for TTS generation using Dia model

    This is a legacy function maintained for backward compatibility.
    New code should use the factory pattern implementation directly.

    Args:
        text (str): Input text to synthesize
        language (str): Language code (not used in Dia model, kept for API compatibility)

    Returns:
        str: Path to the generated audio file
    """
    logger.info(f"Legacy Dia generate_speech called with text length: {len(text)}")

    try:
        # Import inside the try block so that a missing or broken
        # utils.tts_engines module also triggers the dummy-TTS fallback
        # below instead of propagating an unhandled ImportError.
        from utils.tts_engines import DiaTTSEngine

        # Create a Dia engine and generate speech
        dia_engine = DiaTTSEngine(language)
        return dia_engine.generate_speech(text)
    except Exception as e:
        logger.error(f"Error in legacy Dia generate_speech: {str(e)}", exc_info=True)
        # Fall back to dummy TTS
        from utils.tts_base import DummyTTSEngine
        dummy_engine = DummyTTSEngine()
        return dummy_engine.generate_speech(text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/tts_dummy.py
CHANGED
@@ -1,25 +1,11 @@
|
|
1 |
def generate_speech(text: str, language: str = "zh") -> str:
|
2 |
-
"""Public interface for TTS generation
|
3 |
-
import os
|
4 |
-
import numpy as np
|
5 |
-
import soundfile as sf
|
6 |
-
import time
|
7 |
|
8 |
-
|
9 |
-
|
10 |
-
|
|
|
11 |
|
12 |
-
#
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
# Generate a simple sine wave as dummy audio
|
17 |
-
sample_rate = 24000
|
18 |
-
duration = 2.0 # seconds
|
19 |
-
t = np.linspace(0, duration, int(sample_rate * duration), False)
|
20 |
-
tone = np.sin(2 * np.pi * 440 * t) * 0.3
|
21 |
-
|
22 |
-
# Save the audio file
|
23 |
-
sf.write(output_path, tone, sample_rate)
|
24 |
-
|
25 |
-
return output_path
|
|
|
1 |
def generate_speech(text: str, language: str = "zh") -> str:
    """Public interface for TTS generation

    This is a legacy function maintained for backward compatibility.
    New code should use the factory pattern implementation directly.
    """
    from utils.tts_base import DummyTTSEngine

    # Delegate straight to the dummy engine with the default voice
    # and normal speed; the language argument is accepted only for
    # API compatibility with the other legacy generate_speech functions.
    return DummyTTSEngine().generate_speech(text, "af_heart", 1.0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/tts_engines.py
ADDED
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import time
|
3 |
+
import os
|
4 |
+
import numpy as np
|
5 |
+
import soundfile as sf
|
6 |
+
from typing import Dict, List, Optional, Tuple, Generator, Any
|
7 |
+
|
8 |
+
from utils.tts_base import TTSEngineBase, DummyTTSEngine
|
9 |
+
|
10 |
+
# Configure logging
|
11 |
+
logger = logging.getLogger(__name__)
|
12 |
+
|
13 |
+
# Flags recording which TTS back-ends can actually be imported in this
# environment. Probed once at module import time; read later by
# get_available_engines() and create_engine() below.
KOKORO_AVAILABLE = False
KOKORO_SPACE_AVAILABLE = True  # remote Gradio space — assumed usable until a call fails
DIA_AVAILABLE = False

# Try to import Kokoro
try:
    from kokoro import KPipeline
    KOKORO_AVAILABLE = True
    logger.info("Kokoro TTS engine is available")
except AttributeError as e:
    # Specifically catch the EspeakWrapper.set_data_path error
    # (a known incompatibility between kokoro and some phonemizer versions);
    # in that case the remote Kokoro Space remains the fallback.
    if "EspeakWrapper" in str(e) and "set_data_path" in str(e):
        logger.warning("Kokoro import failed due to EspeakWrapper.set_data_path issue, falling back to Kokoro FastAPI server")
    else:
        # Re-raise if it's a different error
        logger.error(f"Kokoro import failed with unexpected error: {str(e)}")
        raise
except ImportError:
    logger.warning("Kokoro TTS engine is not available")

# Try to import Dia dependencies to check availability
# (the model itself is lazy-loaded later by DiaTTSEngine)
try:
    import torch
    from dia.model import Dia
    DIA_AVAILABLE = True
    logger.info("Dia TTS engine is available")
except ImportError:
    logger.warning("Dia TTS engine is not available")
|
42 |
+
|
43 |
+
|
44 |
+
class KokoroTTSEngine(TTSEngineBase):
    """Kokoro TTS engine implementation

    This engine uses the Kokoro library for TTS generation.
    """

    # Kokoro's pipeline output is written/streamed at 24 kHz.
    SAMPLE_RATE = 24000

    def __init__(self, lang_code: str = 'z'):
        """Initialize the local Kokoro pipeline.

        Args:
            lang_code (str): Language code passed through to KPipeline.

        Raises:
            Exception: Re-raised from KPipeline construction so the caller
                (factory) can fall back to another engine.
        """
        super().__init__(lang_code)
        try:
            self.pipeline = KPipeline(lang_code=lang_code)
            logger.info("Kokoro TTS engine successfully initialized")
        except Exception as e:
            logger.error(f"Failed to initialize Kokoro pipeline: {str(e)}")
            logger.error(f"Error type: {type(e).__name__}")
            raise

    def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
        """Generate speech using Kokoro TTS engine

        Args:
            text (str): Input text to synthesize
            voice (str): Voice ID to use (e.g., 'af_heart', 'af_bella', etc.)
            speed (float): Speech speed multiplier (0.5 to 2.0)

        Returns:
            str: Path to the generated audio file

        Raises:
            RuntimeError: If the pipeline yields no audio segments.
        """
        logger.info(f"Generating speech with Kokoro for text length: {len(text)}")

        # Generate unique output path
        output_path = self._generate_output_path()

        # The pipeline yields one audio chunk per text segment. Collect them
        # all so long inputs are not silently truncated to the first segment.
        segments = [audio for _, _, audio in self.pipeline(text, voice=voice, speed=speed)]
        if not segments:
            raise RuntimeError("Kokoro pipeline produced no audio segments")
        full_audio = segments[0] if len(segments) == 1 else np.concatenate(segments)

        logger.info(f"Saving Kokoro audio to {output_path}")
        sf.write(output_path, full_audio, self.SAMPLE_RATE)

        logger.info(f"Kokoro audio generation complete: {output_path}")
        return output_path

    def generate_speech_stream(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> Generator[Tuple[int, np.ndarray], None, None]:
        """Generate speech stream using Kokoro TTS engine

        Args:
            text (str): Input text to synthesize
            voice (str): Voice ID to use
            speed (float): Speech speed multiplier

        Yields:
            tuple: (sample_rate, audio_data) pairs for each segment
        """
        logger.info(f"Generating speech stream with Kokoro for text length: {len(text)}")

        # Stream each pipeline segment to the caller as it is produced.
        for _, _, audio in self.pipeline(text, voice=voice, speed=speed):
            yield self.SAMPLE_RATE, audio
|
103 |
+
|
104 |
+
|
105 |
+
class KokoroSpaceTTSEngine(TTSEngineBase):
    """Kokoro Space TTS engine implementation

    This engine uses the Kokoro FastAPI server for TTS generation.
    """

    def __init__(self, lang_code: str = 'z'):
        """Connect to the remote Kokoro-TTS-Zero Gradio space.

        Raises:
            Exception: Re-raised from client construction so the factory
                can choose a different engine.
        """
        super().__init__(lang_code)
        try:
            from gradio_client import Client
            self.client = Client("Remsky/Kokoro-TTS-Zero")
            logger.info("Kokoro Space TTS engine successfully initialized")
        except Exception as e:
            logger.error(f"Failed to initialize Kokoro Space client: {str(e)}")
            logger.error(f"Error type: {type(e).__name__}")
            raise

    def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
        """Generate speech using Kokoro Space TTS engine

        Args:
            text (str): Input text to synthesize
            voice (str): Voice ID to use (e.g., 'af_heart', 'af_bella', etc.)
            speed (float): Speech speed multiplier (0.5 to 2.0)

        Returns:
            str: Path to the generated audio file
        """
        logger.info(f"Generating speech with Kokoro Space for text length: {len(text)}")
        preview = f"{text[:50]}..." if len(text) > 50 else text
        logger.info(f"Text to generate speech on is: {preview}")

        # Generate unique output path
        output_path = self._generate_output_path()

        try:
            # Use af_nova as the default voice for Kokoro Space
            voice_to_use = voice if voice != 'af_heart' else 'af_nova'

            remote_result = self.client.predict(
                text=text,
                voice_names=voice_to_use,
                speed=speed,
                api_name="/generate_speech_from_ui"
            )
            logger.info(f"Received audio from Kokoro FastAPI server: {remote_result}")

            # TODO: Process the result and save to output_path
            # For now, return the result path directly when the server
            # handed back an existing local file.
            if isinstance(remote_result, str) and os.path.exists(remote_result):
                return remote_result

            logger.warning("Unexpected result from Kokoro Space, falling back to dummy audio")
            return DummyTTSEngine().generate_speech(text, voice, speed)

        except Exception as e:
            logger.error(f"Failed to generate speech from Kokoro FastAPI server: {str(e)}")
            logger.error(f"Error type: {type(e).__name__}")
            logger.info("Falling back to dummy audio generation")
            return DummyTTSEngine().generate_speech(text, voice, speed)
|
165 |
+
|
166 |
+
|
167 |
+
class DiaTTSEngine(TTSEngineBase):
    """Dia TTS engine implementation

    This engine uses the Dia model for TTS generation.
    The model itself lives in utils.tts_dia and is imported/loaded lazily,
    so constructing this engine is cheap even when torch is heavy to load.
    """

    def __init__(self, lang_code: str = 'z'):
        super().__init__(lang_code)
        # Dia doesn't need initialization here, it will be lazy-loaded when needed
        logger.info("Dia TTS engine initialized (lazy loading)")

    def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
        """Generate speech using Dia TTS engine

        Any failure (import error or runtime error) falls back to
        DummyTTSEngine so callers always get a usable audio path.

        Args:
            text (str): Input text to synthesize
            voice (str): Voice ID (not used in Dia)
            speed (float): Speech speed multiplier (not used in Dia)

        Returns:
            str: Path to the generated audio file
        """
        logger.info(f"Generating speech with Dia for text length: {len(text)}")

        try:
            # Import here to avoid circular imports
            from utils.tts_dia import generate_speech as dia_generate_speech
            logger.info("Successfully imported Dia speech generation function")

            # Call Dia's generate_speech function
            # Note: Dia's function expects a language parameter, not voice or speed
            output_path = dia_generate_speech(text, language=self.lang_code)
            logger.info(f"Generated audio with Dia: {output_path}")
            return output_path

        except ImportError as import_err:
            # torch / dia not installed in this environment
            logger.error(f"Dia TTS generation failed due to import error: {str(import_err)}")
            logger.error("Falling back to dummy audio generation")
            return DummyTTSEngine().generate_speech(text, voice, speed)

        except Exception as dia_error:
            logger.error(f"Dia TTS generation failed: {str(dia_error)}", exc_info=True)
            logger.error(f"Error type: {type(dia_error).__name__}")
            logger.error("Falling back to dummy audio generation")
            return DummyTTSEngine().generate_speech(text, voice, speed)

    def generate_speech_stream(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> Generator[Tuple[int, np.ndarray], None, None]:
        """Generate speech stream using Dia TTS engine

        Dia generates the whole clip in one pass, so this "stream" yields a
        single (sample_rate, audio) pair; failures fall back to the dummy
        engine's stream.

        Args:
            text (str): Input text to synthesize
            voice (str): Voice ID (not used in Dia)
            speed (float): Speech speed multiplier (not used in Dia)

        Yields:
            tuple: (sample_rate, audio_data) pairs for each segment
        """
        logger.info(f"Generating speech stream with Dia for text length: {len(text)}")

        try:
            # Import required modules
            import torch
            from utils.tts_dia import _get_model, DEFAULT_SAMPLE_RATE

            # Get the Dia model (lazy-loaded and cached inside utils.tts_dia)
            model = _get_model()

            # Generate audio
            with torch.inference_mode():
                output_audio_np = model.generate(
                    text,
                    max_tokens=None,
                    cfg_scale=3.0,
                    temperature=1.3,
                    top_p=0.95,
                    cfg_filter_top_k=35,
                    use_torch_compile=False,
                    verbose=False
                )

            if output_audio_np is not None:
                logger.info(f"Successfully generated audio with Dia (length: {len(output_audio_np)})")
                yield DEFAULT_SAMPLE_RATE, output_audio_np
            else:
                logger.warning("Dia model returned None for audio output")
                logger.warning("Falling back to dummy audio stream")
                yield from DummyTTSEngine().generate_speech_stream(text, voice, speed)

        except ImportError as import_err:
            logger.error(f"Dia TTS streaming failed due to import error: {str(import_err)}")
            logger.error("Falling back to dummy audio stream")
            yield from DummyTTSEngine().generate_speech_stream(text, voice, speed)

        except Exception as dia_error:
            logger.error(f"Dia TTS streaming failed: {str(dia_error)}", exc_info=True)
            logger.error(f"Error type: {type(dia_error).__name__}")
            logger.error("Falling back to dummy audio stream")
            yield from DummyTTSEngine().generate_speech_stream(text, voice, speed)
|
265 |
+
|
266 |
+
|
267 |
+
def get_available_engines() -> List[str]:
    """Get a list of available TTS engines

    Returns:
        List[str]: List of available engine names
    """
    # Pair each engine name with its availability flag; 'dummy' is
    # unconditionally available as the last-resort fallback.
    candidates = [
        ('kokoro', KOKORO_AVAILABLE),
        ('kokoro_space', KOKORO_SPACE_AVAILABLE),
        ('dia', DIA_AVAILABLE),
        ('dummy', True),
    ]
    return [name for name, is_available in candidates if is_available]
|
288 |
+
|
289 |
+
|
290 |
+
def create_engine(engine_type: str, lang_code: str = 'z') -> TTSEngineBase:
    """Create a specific TTS engine

    Args:
        engine_type (str): Type of engine to create ('kokoro', 'kokoro_space', 'dia', 'dummy')
        lang_code (str): Language code for the engine

    Returns:
        TTSEngineBase: An instance of the requested TTS engine

    Raises:
        ValueError: If the requested engine type is not supported
    """
    # Dispatch table: engine name -> (availability flag, class, unavailable-message).
    registry = {
        'kokoro': (KOKORO_AVAILABLE, KokoroTTSEngine,
                   "Kokoro TTS engine is not available"),
        'kokoro_space': (KOKORO_SPACE_AVAILABLE, KokoroSpaceTTSEngine,
                         "Kokoro Space TTS engine is not available"),
        'dia': (DIA_AVAILABLE, DiaTTSEngine,
                "Dia TTS engine is not available"),
        'dummy': (True, DummyTTSEngine, None),
    }

    entry = registry.get(engine_type)
    if entry is None:
        raise ValueError(f"Unsupported TTS engine type: {engine_type}")

    is_available, engine_cls, unavailable_msg = entry
    if not is_available:
        raise ValueError(unavailable_msg)
    return engine_cls(lang_code)
|
utils/tts_factory.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from typing import Optional, List
|
3 |
+
|
4 |
+
# Configure logging
|
5 |
+
logger = logging.getLogger(__name__)
|
6 |
+
|
7 |
+
# Import the base class
|
8 |
+
from utils.tts_base import TTSEngineBase, DummyTTSEngine
|
9 |
+
|
10 |
+
class TTSFactory:
    """Factory class for creating TTS engines

    This class is responsible for creating the appropriate TTS engine based on
    availability and configuration.
    """

    # Preference order when no engine is explicitly requested
    # (dummy is the implicit last resort).
    _PRIORITY = ('kokoro', 'kokoro_space', 'dia')

    @staticmethod
    def create_engine(engine_type: Optional[str] = None, lang_code: str = 'z') -> TTSEngineBase:
        """Create a TTS engine instance

        Args:
            engine_type (str, optional): Type of engine to create ('kokoro', 'kokoro_space', 'dia', 'dummy')
                If None, the best available engine will be used
            lang_code (str): Language code for the engine

        Returns:
            TTSEngineBase: An instance of a TTS engine
        """
        # Imported here to avoid a circular import with utils.tts_engines.
        from utils.tts_engines import get_available_engines, create_engine

        available_engines = get_available_engines()
        logger.info(f"Available TTS engines: {available_engines}")

        # Honour an explicit request when that engine is actually available;
        # otherwise fall through to the priority scan below.
        if engine_type is not None:
            if engine_type in available_engines:
                logger.info(f"Creating requested engine: {engine_type}")
                return create_engine(engine_type, lang_code)
            logger.warning(f"Requested engine '{engine_type}' is not available")

        # Pick the best available engine: kokoro > kokoro_space > dia > dummy
        for engine in TTSFactory._PRIORITY:
            if engine in available_engines:
                logger.info(f"Creating best available engine: {engine}")
                return create_engine(engine, lang_code)

        logger.warning("No TTS engines available, falling back to dummy engine")
        return DummyTTSEngine(lang_code)
|
53 |
+
|
54 |
+
|
55 |
+
# Backward compatibility function
|
56 |
+
def get_tts_engine(lang_code: str = 'a') -> TTSEngineBase:
    """Get or create TTS engine instance (backward compatibility function)

    Args:
        lang_code (str): Language code for the pipeline

    Returns:
        TTSEngineBase: Initialized TTS engine instance
    """
    logger.info(f"Requesting TTS engine with language code: {lang_code}")
    try:
        import streamlit as st
        logger.info("Streamlit detected, using cached TTS engine")

        # lang_code must be a PARAMETER of the cached function:
        # st.cache_resource keys its cache on the function's arguments, so a
        # zero-argument closure would return the engine built for whichever
        # language happened to be requested first, for every later language.
        @st.cache_resource
        def _get_engine(code: str) -> TTSEngineBase:
            logger.info("Creating cached TTS engine instance")
            engine = TTSFactory.create_engine(lang_code=code)
            logger.info(f"Cached TTS engine created with type: {engine.__class__.__name__}")
            return engine

        engine = _get_engine(lang_code)
        logger.info(f"Retrieved TTS engine from cache with type: {engine.__class__.__name__}")
        return engine
    except ImportError:
        # No Streamlit available (e.g. CLI usage): build an engine directly.
        logger.info("Streamlit not available, creating direct TTS engine instance")
        engine = TTSFactory.create_engine(lang_code=lang_code)
        logger.info(f"Direct TTS engine created with type: {engine.__class__.__name__}")
        return engine
|
84 |
+
|
85 |
+
|
86 |
+
# Backward compatibility function
|
87 |
+
def generate_speech(text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
    """Public interface for TTS generation (backward compatibility function)

    Args:
        text (str): Input text to synthesize
        voice (str): Voice ID to use
        speed (float): Speech speed multiplier

    Returns:
        str: Path to generated audio file

    Raises:
        Exception: Re-raised from the underlying engine after logging.
    """
    logger.info(f"Public generate_speech called with text length: {len(text)}, voice: {voice}, speed: {speed}")
    try:
        # Get the TTS engine
        logger.info("Getting TTS engine instance")
        engine = get_tts_engine()
        logger.info(f"Using TTS engine type: {engine.__class__.__name__}")

        # Generate speech
        logger.info("Calling engine.generate_speech")
        output_path = engine.generate_speech(text, voice, speed)
        logger.info(f"Speech generation complete, output path: {output_path}")
        return output_path
    except Exception as e:
        logger.error(f"Error in public generate_speech function: {str(e)}", exc_info=True)
        logger.error(f"Error type: {type(e).__name__}")
        # Walk to the innermost frame to report where the error originated.
        # Note: every exception HAS a __traceback__ attribute, but it can be
        # None; the old hasattr() check did not guard against that and would
        # itself crash on tb.tb_next inside this handler.
        tb = getattr(e, '__traceback__', None)
        if tb is not None:
            while tb.tb_next:
                tb = tb.tb_next
            logger.error(f"Error occurred in file: {tb.tb_frame.f_code.co_filename}, line {tb.tb_lineno}")
        raise
|