Michael Hu committed on
Commit
3ed3b5a
·
1 Parent(s): 31708ca

refactor tts part

Files changed (6)
  1. utils/tts.py +86 -312
  2. utils/tts_base.py +152 -0
  3. utils/tts_dia.py +14 -115
  4. utils/tts_dummy.py +8 -22
  5. utils/tts_engines.py +322 -0
  6. utils/tts_factory.py +118 -0
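This commit splits the old monolithic utils/tts.py into a base class (utils/tts_base.py), concrete engines (utils/tts_engines.py), and a factory (utils/tts_factory.py), while keeping the old entry points importable. As orientation, a minimal usage sketch of the new layout follows; the imports and calls come from this diff, but the surrounding snippet is illustrative and not part of the commit.

# Minimal usage sketch of the refactored layout (illustrative driver code).
from utils.tts_factory import TTSFactory, generate_speech

# Let the factory pick the best available engine (priority: kokoro > kokoro_space > dia > dummy).
engine = TTSFactory.create_engine(lang_code='z')
wav_path = engine.generate_speech("Hello from the refactored TTS stack", voice='af_heart', speed=1.0)

# The module-level helper keeps the old one-call interface working.
wav_path = generate_speech("Hello again", voice='af_heart', speed=1.0)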
utils/tts.py CHANGED
@@ -1,85 +1,52 @@
1
- import os
2
  import logging
3
- import time
4
- import soundfile as sf
5
- from gradio_client import Client
6
-
7
 
 
8
  logger = logging.getLogger(__name__)
9
 
10
- # Flag to track TTS engine availability
11
- KOKORO_AVAILABLE = False
12
- KOKORO_SPACE_AVAILABLE = True
13
- DIA_AVAILABLE = False
14
 
15
- # Try to import Kokoro first
16
- try:
17
- from kokoro import KPipeline
18
- KOKORO_AVAILABLE = True
19
- logger.info("Kokoro TTS engine is available")
20
- except AttributeError as e:
21
- # Specifically catch the EspeakWrapper.set_data_path error
22
- if "EspeakWrapper" in str(e) and "set_data_path" in str(e):
23
- logger.warning("Kokoro import failed due to EspeakWrapper.set_data_path issue, falling back to Kokoro FastAPI server")
24
- else:
25
- # Re-raise if it's a different error
26
- logger.error(f"Kokoro import failed with unexpected error: {str(e)}")
27
- raise
28
- except ImportError:
29
- logger.warning("Kokoro TTS engine is not available")
30
 
 
31
  class TTSEngine:
 
 
 
 
 
 
32
  def __init__(self, lang_code='z'):
33
- """Initialize TTS Engine with Kokoro or Dia as fallback
34
 
35
  Args:
36
  lang_code (str): Language code ('a' for US English, 'b' for British English,
37
  'j' for Japanese, 'z' for Mandarin Chinese)
38
- Note: lang_code is only used for Kokoro, not for Dia
39
  """
40
- logger.info("Initializing TTS Engine")
41
  logger.info(f"Available engines - Kokoro: {KOKORO_AVAILABLE}, Dia: {DIA_AVAILABLE}")
42
- self.engine_type = None
43
 
44
- if KOKORO_AVAILABLE:
45
- logger.info(f"Using Kokoro as primary TTS engine with language code: {lang_code}")
46
- try:
47
- self.pipeline = KPipeline(lang_code=lang_code)
48
- self.engine_type = "kokoro"
49
- logger.info("TTS engine successfully initialized with Kokoro")
50
- except Exception as kokoro_err:
51
- logger.error(f"Failed to initialize Kokoro pipeline: {str(kokoro_err)}")
52
- logger.error(f"Error type: {type(kokoro_err).__name__}")
53
- logger.info("Will try to fall back to Dia TTS engine")
54
-
55
- if KOKORO_SPACE_AVAILABLE:
56
- logger.info(f"Using Kokoro FastAPI server as primary TTS engine with language code: {lang_code}")
57
- try:
58
- self.client = Client("Remsky/Kokoro-TTS-Zero")
59
- self.engine_type = "kokoro_space"
60
- logger.info("TTS engine successfully initialized with Kokoro FastAPI server")
61
- except Exception as kokoro_err:
62
- logger.error(f"Failed to initialize Kokoro space: {str(kokoro_err)}")
63
- logger.error(f"Error type: {type(kokoro_err).__name__}")
64
- logger.info("Will try to fall back to Dia TTS engine")
65
-
66
- # Try Dia if Kokoro is not available or failed to initialize
67
- if self.engine_type is None and DIA_AVAILABLE:
68
- logger.info("Using Dia as fallback TTS engine")
69
- # For Dia, we don't need to initialize anything here
70
- # The model will be lazy-loaded when needed
71
- self.pipeline = None
72
- self.client = None
73
- self.engine_type = "dia"
74
- logger.info("TTS engine initialized with Dia (lazy loading)")
75
 
76
- # Use dummy if no TTS engines are available
77
- if self.engine_type is None:
78
- logger.warning("Using dummy TTS implementation as no TTS engines are available")
79
- logger.warning("Check logs above for specific errors that prevented Kokoro or Dia initialization")
80
- self.pipeline = None
81
- self.client = None
 
 
 
82
  self.engine_type = "dummy"
 
 
 
 
 
 
83
 
84
  def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
85
  """Generate speech from text using available TTS engine
@@ -87,272 +54,79 @@ class TTSEngine:
87
  Args:
88
  text (str): Input text to synthesize
89
  voice (str): Voice ID to use (e.g., 'af_heart', 'af_bella', etc.)
90
- Note: voice parameter is only used for Kokoro, not for Dia
91
  speed (float): Speech speed multiplier (0.5 to 2.0)
92
- Note: speed parameter is only used for Kokoro, not for Dia
93
 
94
  Returns:
95
  str: Path to the generated audio file
96
  """
97
- logger.info(f"Generating speech for text length: {len(text)}")
98
-
99
- try:
100
- # Create output directory if it doesn't exist
101
- os.makedirs("temp/outputs", exist_ok=True)
102
-
103
- # Generate unique output path
104
- output_path = f"temp/outputs/output_{int(time.time())}.wav"
105
-
106
- # Use the appropriate TTS engine based on availability
107
- if self.engine_type == "kokoro":
108
- # Use Kokoro for TTS generation
109
- generator = self.pipeline(text, voice=voice, speed=speed)
110
- for _, _, audio in generator:
111
- logger.info(f"Saving Kokoro audio to {output_path}")
112
- sf.write(output_path, audio, 24000)
113
- break
114
- elif self.engine_type == "kokoro_space":
115
- # Use Kokoro FastAPI server for TTS generation
116
- logger.info("Generating speech using Kokoro FastAPI server")
117
- logger.info(f"text to generate speech on is: {text}")
118
- try:
119
- result = self.client.predict(
120
- text=text,
121
- voice_names='af_nova',
122
- speed=speed,
123
- api_name="/generate_speech_from_ui"
124
- )
125
- logger.info(f"Received audio from Kokoro FastAPI server: {result}")
126
- except Exception as e:
127
- logger.error(f"Failed to generate speech from Kokoro FastAPI server: {str(e)}")
128
- logger.error(f"Error type: {type(e).__name__}")
129
- logger.info("Falling back to dummy audio generation")
130
- elif self.engine_type == "dia":
131
- # Use Dia for TTS generation
132
- try:
133
- logger.info("Attempting to use Dia TTS for speech generation")
134
- # Import here to avoid circular imports
135
- try:
136
- logger.info("Importing Dia speech generation module")
137
- from utils.tts_dia import generate_speech as dia_generate_speech
138
- logger.info("Successfully imported Dia speech generation function")
139
- except ImportError as import_err:
140
- logger.error(f"Failed to import Dia speech generation function: {str(import_err)}")
141
- logger.error(f"Import path: {import_err.__traceback__.tb_frame.f_globals.get('__name__', 'unknown')}")
142
- raise
143
-
144
- # Call Dia's generate_speech function
145
- logger.info("Calling Dia's generate_speech function")
146
- output_path = dia_generate_speech(text)
147
- logger.info(f"Generated audio with Dia: {output_path}")
148
- except ImportError as import_err:
149
- logger.error(f"Dia TTS generation failed due to import error: {str(import_err)}")
150
- logger.error("Falling back to dummy audio generation")
151
- return self._generate_dummy_audio(output_path)
152
- except Exception as dia_error:
153
- logger.error(f"Dia TTS generation failed: {str(dia_error)}", exc_info=True)
154
- logger.error(f"Error type: {type(dia_error).__name__}")
155
- logger.error("Falling back to dummy audio generation")
156
- # Fall back to dummy audio if Dia fails
157
- return self._generate_dummy_audio(output_path)
158
- else:
159
- # Generate dummy audio as fallback
160
- return self._generate_dummy_audio(output_path)
161
-
162
- logger.info(f"Audio generation complete: {output_path}")
163
- return output_path
164
-
165
- except Exception as e:
166
- logger.error(f"TTS generation failed: {str(e)}", exc_info=True)
167
- raise
168
-
169
- def _generate_dummy_audio(self, output_path):
170
- """Generate a dummy audio file with a simple sine wave
171
-
172
- Args:
173
- output_path (str): Path to save the dummy audio file
174
-
175
- Returns:
176
- str: Path to the generated dummy audio file
177
- """
178
- import numpy as np
179
- sample_rate = 24000
180
- duration = 3.0 # seconds
181
- t = np.linspace(0, duration, int(sample_rate * duration), False)
182
- tone = np.sin(2 * np.pi * 440 * t) * 0.3
183
-
184
- logger.info(f"Saving dummy audio to {output_path}")
185
- sf.write(output_path, tone, sample_rate)
186
- logger.info(f"Dummy audio generation complete: {output_path}")
187
- return output_path
188
-
189
  def generate_speech_stream(self, text: str, voice: str = 'af_heart', speed: float = 1.0):
190
  """Generate speech from text and yield each segment
191
 
192
  Args:
193
  text (str): Input text to synthesize
194
- voice (str): Voice ID to use (e.g., 'af_heart', 'af_bella', etc.)
195
- speed (float): Speech speed multiplier (0.5 to 2.0)
196
 
197
  Yields:
198
  tuple: (sample_rate, audio_data) pairs for each segment
199
  """
200
- try:
201
- # Use the appropriate TTS engine based on availability
202
- if self.engine_type == "kokoro":
203
- # Use Kokoro for streaming TTS
204
- generator = self.pipeline(text, voice=voice, speed=speed)
205
- for _, _, audio in generator:
206
- yield 24000, audio
207
- elif self.engine_type == "dia":
208
- # Dia doesn't support streaming natively, so we generate the full audio
209
- # and then yield it as a single chunk
210
- try:
211
- logger.info("Attempting to use Dia TTS for speech streaming")
212
- # Import here to avoid circular imports
213
- try:
214
- logger.info("Importing required modules for Dia streaming")
215
- import torch
216
- logger.info("PyTorch successfully imported for Dia streaming")
217
-
218
- try:
219
- from utils.tts_dia import _get_model, DEFAULT_SAMPLE_RATE
220
- logger.info("Successfully imported Dia model and sample rate")
221
- except ImportError as import_err:
222
- logger.error(f"Failed to import Dia model for streaming: {str(import_err)}")
223
- logger.error(f"Import path: {import_err.__traceback__.tb_frame.f_globals.get('__name__', 'unknown')}")
224
- raise
225
- except ImportError as torch_err:
226
- logger.error(f"PyTorch import failed for Dia streaming: {str(torch_err)}")
227
- raise
228
-
229
- # Get the Dia model
230
- logger.info("Getting Dia model instance")
231
- try:
232
- model = _get_model()
233
- logger.info("Successfully obtained Dia model instance")
234
- except Exception as model_err:
235
- logger.error(f"Failed to get Dia model instance: {str(model_err)}")
236
- logger.error(f"Error type: {type(model_err).__name__}")
237
- raise
238
-
239
- # Generate audio
240
- logger.info("Generating audio with Dia model")
241
- with torch.inference_mode():
242
- output_audio_np = model.generate(
243
- text,
244
- max_tokens=None,
245
- cfg_scale=3.0,
246
- temperature=1.3,
247
- top_p=0.95,
248
- cfg_filter_top_k=35,
249
- use_torch_compile=False,
250
- verbose=False
251
- )
252
-
253
- if output_audio_np is not None:
254
- logger.info(f"Successfully generated audio with Dia (length: {len(output_audio_np)})")
255
- yield DEFAULT_SAMPLE_RATE, output_audio_np
256
- else:
257
- logger.warning("Dia model returned None for audio output")
258
- logger.warning("Falling back to dummy audio stream")
259
- # Fall back to dummy audio if Dia fails
260
- yield from self._generate_dummy_audio_stream()
261
- except ImportError as import_err:
262
- logger.error(f"Dia TTS streaming failed due to import error: {str(import_err)}")
263
- logger.error("Falling back to dummy audio stream")
264
- # Fall back to dummy audio if Dia fails
265
- yield from self._generate_dummy_audio_stream()
266
- except Exception as dia_error:
267
- logger.error(f"Dia TTS streaming failed: {str(dia_error)}", exc_info=True)
268
- logger.error(f"Error type: {type(dia_error).__name__}")
269
- logger.error("Falling back to dummy audio stream")
270
- # Fall back to dummy audio if Dia fails
271
- yield from self._generate_dummy_audio_stream()
272
- else:
273
- # Generate dummy audio chunks as fallback
274
- yield from self._generate_dummy_audio_stream()
275
-
276
- except Exception as e:
277
- logger.error(f"TTS streaming failed: {str(e)}", exc_info=True)
278
- raise
279
 
 
 
 
 
 
 
 
 
280
  def _generate_dummy_audio_stream(self):
281
- """Generate dummy audio chunks with simple sine waves
282
 
283
  Yields:
284
  tuple: (sample_rate, audio_data) pairs for each dummy segment
285
  """
286
- import numpy as np
287
- sample_rate = 24000
288
- duration = 1.0 # seconds per chunk
289
-
290
- # Create 3 chunks of dummy audio
291
- for i in range(3):
292
- t = np.linspace(0, duration, int(sample_rate * duration), False)
293
- freq = 440 + (i * 220) # Different frequency for each chunk
294
- tone = np.sin(2 * np.pi * freq * t) * 0.3
295
- yield sample_rate, tone
296
 
297
- # Initialize TTS engine with cache decorator if using Streamlit
298
- def get_tts_engine(lang_code='a'):
299
- """Get or create TTS engine instance
300
-
301
- Args:
302
- lang_code (str): Language code for the pipeline
303
-
304
- Returns:
305
- TTSEngine: Initialized TTS engine instance
306
- """
307
- logger.info(f"Requesting TTS engine with language code: {lang_code}")
308
- try:
309
- import streamlit as st
310
- logger.info("Streamlit detected, using cached TTS engine")
311
- @st.cache_resource
312
- def _get_engine():
313
- logger.info("Creating cached TTS engine instance")
314
- engine = TTSEngine(lang_code)
315
- logger.info(f"Cached TTS engine created with type: {engine.engine_type}")
316
- return engine
317
-
318
- engine = _get_engine()
319
- logger.info(f"Retrieved TTS engine from cache with type: {engine.engine_type}")
320
- return engine
321
- except ImportError:
322
- logger.info("Streamlit not available, creating direct TTS engine instance")
323
- engine = TTSEngine(lang_code)
324
- logger.info(f"Direct TTS engine created with type: {engine.engine_type}")
325
- return engine
326
 
327
- def generate_speech(text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
328
- """Public interface for TTS generation
329
-
330
- Args:
331
- text (str): Input text to synthesize
332
- voice (str): Voice ID to use
333
- speed (float): Speech speed multiplier
334
-
335
- Returns:
336
- str: Path to generated audio file
337
- """
338
- logger.info(f"Public generate_speech called with text length: {len(text)}, voice: {voice}, speed: {speed}")
339
- try:
340
- # Get the TTS engine
341
- logger.info("Getting TTS engine instance")
342
- engine = get_tts_engine()
343
- logger.info(f"Using TTS engine type: {engine.engine_type}")
344
-
345
- # Generate speech
346
- logger.info("Calling engine.generate_speech")
347
- output_path = engine.generate_speech(text, voice, speed)
348
- logger.info(f"Speech generation complete, output path: {output_path}")
349
- return output_path
350
- except Exception as e:
351
- logger.error(f"Error in public generate_speech function: {str(e)}", exc_info=True)
352
- logger.error(f"Error type: {type(e).__name__}")
353
- if hasattr(e, '__traceback__'):
354
- tb = e.__traceback__
355
- while tb.tb_next:
356
- tb = tb.tb_next
357
- logger.error(f"Error occurred in file: {tb.tb_frame.f_code.co_filename}, line {tb.tb_lineno}")
358
- raise
 
 
1
  import logging
 
 
 
 
2
 
3
+ # Configure logging
4
  logger = logging.getLogger(__name__)
5
 
6
+ # Import from the new factory pattern implementation
7
+ from utils.tts_factory import get_tts_engine, generate_speech, TTSFactory
8
+ from utils.tts_engines import get_available_engines
 
9
 
10
+ # For backward compatibility
11
+ from utils.tts_engines import KOKORO_AVAILABLE, KOKORO_SPACE_AVAILABLE, DIA_AVAILABLE
12
 
13
+ # Backward compatibility class
14
  class TTSEngine:
15
+ """Legacy TTSEngine class for backward compatibility
16
+
17
+ This class is maintained for backward compatibility with existing code.
18
+ New code should use the factory pattern implementation directly.
19
+ """
20
+
21
  def __init__(self, lang_code='z'):
22
+ """Initialize TTS Engine using the factory pattern
23
 
24
  Args:
25
  lang_code (str): Language code ('a' for US English, 'b' for British English,
26
  'j' for Japanese, 'z' for Mandarin Chinese)
 
27
  """
28
+ logger.info("Initializing legacy TTSEngine wrapper")
29
  logger.info(f"Available engines - Kokoro: {KOKORO_AVAILABLE}, Dia: {DIA_AVAILABLE}")
 
30
 
31
+ # Create the appropriate engine using the factory
32
+ self._engine = TTSFactory.create_engine(lang_code=lang_code)
33
 
34
+ # Set engine_type for backward compatibility
35
+ engine_class = self._engine.__class__.__name__
36
+ if 'Kokoro' in engine_class and 'Space' in engine_class:
37
+ self.engine_type = "kokoro_space"
38
+ elif 'Kokoro' in engine_class:
39
+ self.engine_type = "kokoro"
40
+ elif 'Dia' in engine_class:
41
+ self.engine_type = "dia"
42
+ else:
43
  self.engine_type = "dummy"
44
+
45
+ # Set pipeline and client attributes for backward compatibility
46
+ self.pipeline = getattr(self._engine, 'pipeline', None)
47
+ self.client = getattr(self._engine, 'client', None)
48
+
49
+ logger.info(f"Legacy TTSEngine wrapper initialized with engine type: {self.engine_type}")
50
 
51
  def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
52
  """Generate speech from text using available TTS engine
 
54
  Args:
55
  text (str): Input text to synthesize
56
  voice (str): Voice ID to use (e.g., 'af_heart', 'af_bella', etc.)
 
57
  speed (float): Speech speed multiplier (0.5 to 2.0)
 
58
 
59
  Returns:
60
  str: Path to the generated audio file
61
  """
62
+ logger.info(f"Legacy TTSEngine wrapper calling generate_speech for text length: {len(text)}")
63
+ return self._engine.generate_speech(text, voice, speed)
64
+
65
  def generate_speech_stream(self, text: str, voice: str = 'af_heart', speed: float = 1.0):
66
  """Generate speech from text and yield each segment
67
 
68
  Args:
69
  text (str): Input text to synthesize
70
+ voice (str): Voice ID to use
71
+ speed (float): Speech speed multiplier
72
 
73
  Yields:
74
  tuple: (sample_rate, audio_data) pairs for each segment
75
  """
76
+ logger.info(f"Legacy TTSEngine wrapper calling generate_speech_stream for text length: {len(text)}")
77
+ yield from self._engine.generate_speech_stream(text, voice, speed)
78
+
79
+ # For backward compatibility
80
+ def _generate_dummy_audio(self, output_path):
81
+ """Generate a dummy audio file with a simple sine wave (backward compatibility)
82
+
83
+ Args:
84
+ output_path (str): Path to save the dummy audio file
85
 
86
+ Returns:
87
+ str: Path to the generated dummy audio file
88
+ """
89
+ from utils.tts_base import DummyTTSEngine
90
+ dummy_engine = DummyTTSEngine()
91
+ return dummy_engine.generate_speech("", "", 1.0)
92
+
93
+ # For backward compatibility
94
  def _generate_dummy_audio_stream(self):
95
+ """Generate dummy audio chunks (backward compatibility)
96
 
97
  Yields:
98
  tuple: (sample_rate, audio_data) pairs for each dummy segment
99
  """
100
+ from utils.tts_base import DummyTTSEngine
101
+ dummy_engine = DummyTTSEngine()
102
+ yield from dummy_engine.generate_speech_stream("", "", 1.0)
 
 
 
 
 
 
 
103
 
104
+ # Import the new implementations from tts_base
105
+ # These functions are already defined in tts_base.py and imported at the top of this file
106
+ # They are kept here as comments for reference
107
 
108
+ # def get_tts_engine(lang_code='a'):
109
+ # """Get or create TTS engine instance
110
+ #
111
+ # Args:
112
+ # lang_code (str): Language code for the pipeline
113
+ #
114
+ # Returns:
115
+ # TTSEngineBase: Initialized TTS engine instance
116
+ # """
117
+ # # Implementation moved to tts_base.py
118
+ # pass
119
+
120
+ # def generate_speech(text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
121
+ # """Public interface for TTS generation
122
+ #
123
+ # Args:
124
+ # text (str): Input text to synthesize
125
+ # voice (str): Voice ID to use
126
+ # speed (float): Speech speed multiplier
127
+ #
128
+ # Returns:
129
+ # str: Path to generated audio file
130
+ # """
131
+ # # Implementation moved to tts_base.py
132
+ # pass
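Because the rewritten utils/tts.py re-exports get_tts_engine and generate_speech from the factory and keeps a thin TTSEngine wrapper, existing imports continue to work unchanged. A brief illustrative sketch of that backward-compatible path (not part of the diff itself):

# Backward-compatible path: old imports keep working (illustrative snippet).
from utils.tts import TTSEngine

legacy = TTSEngine(lang_code='a')      # internally delegates to TTSFactory.create_engine
print(legacy.engine_type)              # 'kokoro', 'kokoro_space', 'dia' or 'dummy'
audio_path = legacy.generate_speech("Same call signature as before", voice='af_heart', speed=1.0)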
 
 
 
 
 
 
 
utils/tts_base.py ADDED
@@ -0,0 +1,152 @@
1
+ import os
2
+ import time
3
+ import logging
4
+ import soundfile as sf
5
+ import numpy as np
6
+ from abc import ABC, abstractmethod
7
+ from typing import Tuple, Generator, Optional
8
+
9
+ # Configure logging
10
+ logger = logging.getLogger(__name__)
11
+
12
+ class TTSEngineBase(ABC):
13
+ """Base class for all TTS engines
14
+
15
+ This abstract class defines the interface that all TTS engines must implement.
16
+ It also provides common utility methods for file handling and audio generation.
17
+ """
18
+
19
+ def __init__(self, lang_code: str = 'z'):
20
+ """Initialize the TTS engine
21
+
22
+ Args:
23
+ lang_code (str): Language code ('a' for US English, 'b' for British English,
24
+ 'j' for Japanese, 'z' for Mandarin Chinese)
25
+ Note: Not all engines support all language codes
26
+ """
27
+ self.lang_code = lang_code
28
+ logger.info(f"Initializing {self.__class__.__name__} with language code: {lang_code}")
29
+
30
+ @abstractmethod
31
+ def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
32
+ """Generate speech from text
33
+
34
+ Args:
35
+ text (str): Input text to synthesize
36
+ voice (str): Voice ID to use (e.g., 'af_heart', 'af_bella', etc.)
37
+ Note: Not all engines support all voices
38
+ speed (float): Speech speed multiplier (0.5 to 2.0)
39
+ Note: Not all engines support speed adjustment
40
+
41
+ Returns:
42
+ str: Path to the generated audio file
43
+ """
44
+ pass
45
+
46
+ def generate_speech_stream(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> Generator[Tuple[int, np.ndarray], None, None]:
47
+ """Generate speech from text and yield each segment
48
+
49
+ Args:
50
+ text (str): Input text to synthesize
51
+ voice (str): Voice ID to use
52
+ speed (float): Speech speed multiplier
53
+
54
+ Yields:
55
+ tuple: (sample_rate, audio_data) pairs for each segment
56
+ """
57
+ # Default implementation: generate full audio and yield as a single chunk
58
+ output_path = self.generate_speech(text, voice, speed)
59
+ audio_data, sample_rate = sf.read(output_path)
60
+ yield sample_rate, audio_data
61
+
62
+ def _create_output_dir(self) -> str:
63
+ """Create output directory for audio files
64
+
65
+ Returns:
66
+ str: Path to the output directory
67
+ """
68
+ output_dir = "temp/outputs"
69
+ os.makedirs(output_dir, exist_ok=True)
70
+ return output_dir
71
+
72
+ def _generate_output_path(self, prefix: str = "output") -> str:
73
+ """Generate a unique output path for audio files
74
+
75
+ Args:
76
+ prefix (str): Prefix for the output filename
77
+
78
+ Returns:
79
+ str: Path to the output file
80
+ """
81
+ output_dir = self._create_output_dir()
82
+ timestamp = int(time.time())
83
+ return f"{output_dir}/{prefix}_{timestamp}.wav"
84
+
85
+
86
+ class DummyTTSEngine(TTSEngineBase):
87
+ """Dummy TTS engine that generates a simple sine wave
88
+
89
+ This engine is used as a fallback when no other engines are available.
90
+ """
91
+
92
+ def __init__(self, lang_code: str = 'z'):
93
+ super().__init__(lang_code)
94
+ logger.warning("Using dummy TTS implementation as no other engines are available")
95
+
96
+ def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
97
+ """Generate a dummy audio file with a simple sine wave
98
+
99
+ Args:
100
+ text (str): Input text (not used)
101
+ voice (str): Voice ID (not used)
102
+ speed (float): Speed multiplier (not used)
103
+
104
+ Returns:
105
+ str: Path to the generated dummy audio file
106
+ """
107
+ logger.info(f"Generating dummy speech for text length: {len(text)}")
108
+
109
+ # Generate unique output path
110
+ output_path = self._generate_output_path("dummy")
111
+
112
+ # Generate a simple sine wave
113
+ sample_rate = 24000
114
+ duration = 3.0 # seconds
115
+ t = np.linspace(0, duration, int(sample_rate * duration), False)
116
+ tone = np.sin(2 * np.pi * 440 * t) * 0.3
117
+
118
+ # Save the audio file
119
+ logger.info(f"Saving dummy audio to {output_path}")
120
+ sf.write(output_path, tone, sample_rate)
121
+ logger.info(f"Dummy audio generation complete: {output_path}")
122
+
123
+ return output_path
124
+
125
+ def generate_speech_stream(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> Generator[Tuple[int, np.ndarray], None, None]:
126
+ """Generate dummy audio chunks with simple sine waves
127
+
128
+ Args:
129
+ text (str): Input text (not used)
130
+ voice (str): Voice ID (not used)
131
+ speed (float): Speed multiplier (not used)
132
+
133
+ Yields:
134
+ tuple: (sample_rate, audio_data) pairs for each dummy segment
135
+ """
136
+ logger.info(f"Generating dummy speech stream for text length: {len(text)}")
137
+
138
+ sample_rate = 24000
139
+ duration = 1.0 # seconds per chunk
140
+
141
+ # Create 3 chunks of dummy audio
142
+ for i in range(3):
143
+ t = np.linspace(0, duration, int(sample_rate * duration), False)
144
+ freq = 440 + (i * 220) # Different frequency for each chunk
145
+ tone = np.sin(2 * np.pi * freq * t) * 0.3
146
+ yield sample_rate, tone
147
+
148
+
149
+ # Factory functionality moved to tts_factory.py to avoid circular imports
150
+
151
+
152
+ # Note: Backward compatibility functions moved to tts_factory.py
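TTSEngineBase only obliges subclasses to implement generate_speech; generate_speech_stream already has a default that reads the generated file back and yields it as a single chunk. A hypothetical subclass, just to illustrate the contract (BeepTTSEngine and its behaviour are invented for this sketch):

# Hypothetical subclass illustrating the TTSEngineBase contract.
import numpy as np
import soundfile as sf
from utils.tts_base import TTSEngineBase

class BeepTTSEngine(TTSEngineBase):
    """Toy engine: emits a beep whose length grows with the input text."""

    def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
        output_path = self._generate_output_path("beep")   # helper inherited from the base class
        sample_rate = 24000
        duration = max(0.5, 0.05 * len(text))               # stand-in for real synthesis
        t = np.linspace(0, duration, int(sample_rate * duration), False)
        sf.write(output_path, np.sin(2 * np.pi * 330 * t) * 0.2, sample_rate)
        return output_path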
utils/tts_dia.py CHANGED
@@ -68,6 +68,9 @@ def _get_model() -> Dia:
68
  def generate_speech(text: str, language: str = "zh") -> str:
69
  """Public interface for TTS generation using Dia model
70
 
 
 
 
71
  Args:
72
  text (str): Input text to synthesize
73
  language (str): Language code (not used in Dia model, kept for API compatibility)
@@ -75,122 +78,18 @@ def generate_speech(text: str, language: str = "zh") -> str:
75
  Returns:
76
  str: Path to the generated audio file
77
  """
78
- logger.info(f"Generating speech for text length: {len(text)}")
79
- logger.info(f"Text content (first 50 chars): {text[:50]}...")
80
-
81
- # Create output directory if it doesn't exist
82
- output_dir = "temp/outputs"
83
- logger.info(f"Ensuring output directory exists: {output_dir}")
84
- try:
85
- os.makedirs(output_dir, exist_ok=True)
86
- logger.info(f"Output directory ready: {output_dir}")
87
- except PermissionError as perm_err:
88
- logger.error(f"Permission error creating output directory: {perm_err}")
89
- # Fall back to dummy TTS
90
- logger.info("Falling back to dummy TTS due to directory creation error")
91
- from utils.tts_dummy import generate_speech as dummy_generate_speech
92
- return dummy_generate_speech(text, language)
93
- except Exception as dir_err:
94
- logger.error(f"Error creating output directory: {dir_err}")
95
- # Fall back to dummy TTS
96
- logger.info("Falling back to dummy TTS due to directory creation error")
97
- from utils.tts_dummy import generate_speech as dummy_generate_speech
98
- return dummy_generate_speech(text, language)
99
-
100
- # Generate unique output path
101
- timestamp = int(time.time())
102
- output_path = f"{output_dir}/output_{timestamp}.wav"
103
- logger.info(f"Output will be saved to: {output_path}")
104
-
105
- # Get the model
106
- logger.info("Retrieving Dia model instance")
107
- try:
108
- model = _get_model()
109
- logger.info("Successfully retrieved Dia model instance")
110
- except Exception as model_err:
111
- logger.error(f"Failed to get Dia model: {model_err}")
112
- logger.error(f"Error type: {type(model_err).__name__}")
113
- # Fall back to dummy TTS
114
- logger.info("Falling back to dummy TTS due to model loading error")
115
- from utils.tts_dummy import generate_speech as dummy_generate_speech
116
- return dummy_generate_speech(text, language)
117
 
118
- # Generate audio
119
- logger.info("Starting audio generation with Dia model")
120
- start_time = time.time()
121
 
122
  try:
123
- with torch.inference_mode():
124
- logger.info("Calling model.generate() with inference_mode")
125
- output_audio_np = model.generate(
126
- text,
127
- max_tokens=None, # Use default from model config
128
- cfg_scale=3.0,
129
- temperature=1.3,
130
- top_p=0.95,
131
- cfg_filter_top_k=35,
132
- use_torch_compile=False, # Keep False for stability
133
- verbose=False
134
- )
135
- logger.info("Model.generate() completed")
136
- except RuntimeError as rt_err:
137
- logger.error(f"Runtime error during generation: {rt_err}")
138
- if "CUDA out of memory" in str(rt_err):
139
- logger.error("CUDA out of memory error - consider reducing batch size or model size")
140
  # Fall back to dummy TTS
141
- logger.info("Falling back to dummy TTS due to runtime error during generation")
142
- from utils.tts_dummy import generate_speech as dummy_generate_speech
143
- return dummy_generate_speech(text, language)
144
- except Exception as gen_err:
145
- logger.error(f"Error during audio generation: {gen_err}")
146
- logger.error(f"Error type: {type(gen_err).__name__}")
147
- # Fall back to dummy TTS
148
- logger.info("Falling back to dummy TTS due to error during generation")
149
- from utils.tts_dummy import generate_speech as dummy_generate_speech
150
- return dummy_generate_speech(text, language)
151
-
152
- end_time = time.time()
153
- generation_time = end_time - start_time
154
- logger.info(f"Generation finished in {generation_time:.2f} seconds")
155
-
156
- # Process the output
157
- if output_audio_np is not None:
158
- logger.info(f"Generated audio array shape: {output_audio_np.shape}, dtype: {output_audio_np.dtype}")
159
- logger.info(f"Audio stats - min: {output_audio_np.min():.4f}, max: {output_audio_np.max():.4f}, mean: {output_audio_np.mean():.4f}")
160
-
161
- # Apply a slight slowdown for better quality (0.94x speed)
162
- speed_factor = 0.94
163
- original_len = len(output_audio_np)
164
- target_len = int(original_len / speed_factor)
165
-
166
- logger.info(f"Applying speed adjustment factor: {speed_factor}")
167
- if target_len != original_len and target_len > 0:
168
- try:
169
- x_original = np.arange(original_len)
170
- x_resampled = np.linspace(0, original_len - 1, target_len)
171
- output_audio_np = np.interp(x_resampled, x_original, output_audio_np)
172
- logger.info(f"Resampled audio from {original_len} to {target_len} samples for {speed_factor:.2f}x speed")
173
- except Exception as resample_err:
174
- logger.error(f"Error during audio resampling: {resample_err}")
175
- logger.warning("Using original audio without resampling")
176
-
177
- # Save the audio file
178
- logger.info(f"Saving audio to file: {output_path}")
179
- try:
180
- sf.write(output_path, output_audio_np, DEFAULT_SAMPLE_RATE)
181
- logger.info(f"Audio successfully saved to {output_path}")
182
- except Exception as save_err:
183
- logger.error(f"Error saving audio file: {save_err}")
184
- logger.error(f"Error type: {type(save_err).__name__}")
185
- # Fall back to dummy TTS
186
- logger.info("Falling back to dummy TTS due to error saving audio file")
187
- from utils.tts_dummy import generate_speech as dummy_generate_speech
188
- return dummy_generate_speech(text, language)
189
-
190
- return output_path
191
- else:
192
- logger.warning("Generation produced no output (None returned from model)")
193
- logger.warning("This may indicate a model configuration issue or empty input text")
194
- dummy_path = f"{output_dir}/dummy_{timestamp}.wav"
195
- logger.warning(f"Returning dummy audio path: {dummy_path}")
196
- return dummy_path
 
68
  def generate_speech(text: str, language: str = "zh") -> str:
69
  """Public interface for TTS generation using Dia model
70
 
71
+ This is a legacy function maintained for backward compatibility.
72
+ New code should use the factory pattern implementation directly.
73
+
74
  Args:
75
  text (str): Input text to synthesize
76
  language (str): Language code (not used in Dia model, kept for API compatibility)
 
78
  Returns:
79
  str: Path to the generated audio file
80
  """
81
+ logger.info(f"Legacy Dia generate_speech called with text length: {len(text)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
+ # Use the new implementation via factory pattern
84
+ from utils.tts_engines import DiaTTSEngine
 
85
 
86
  try:
87
+ # Create a Dia engine and generate speech
88
+ dia_engine = DiaTTSEngine(language)
89
+ return dia_engine.generate_speech(text)
90
+ except Exception as e:
91
+ logger.error(f"Error in legacy Dia generate_speech: {str(e)}", exc_info=True)
92
  # Fall back to dummy TTS
93
+ from utils.tts_base import DummyTTSEngine
94
+ dummy_engine = DummyTTSEngine()
95
+ return dummy_engine.generate_speech(text)
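With this change utils.tts_dia.generate_speech becomes a thin shim over DiaTTSEngine, falling back to DummyTTSEngine on any failure. Direct calls keep working, for example (illustrative only, and assumes torch and the dia package are installed):

# Illustrative call into the legacy shim; requires torch and the dia package for real output.
from utils.tts_dia import generate_speech as dia_generate_speech

wav_path = dia_generate_speech("Testing the Dia path", language="zh")
print(wav_path)   # a file under temp/outputs/, or a dummy sine-wave file if Dia is unavailable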
utils/tts_dummy.py CHANGED
@@ -1,25 +1,11 @@
1
  def generate_speech(text: str, language: str = "zh") -> str:
2
- """Public interface for TTS generation"""
3
- import os
4
- import numpy as np
5
- import soundfile as sf
6
- import time
7
 
8
- # Create output directory if it doesn't exist
9
- output_dir = "temp/outputs"
10
- os.makedirs(output_dir, exist_ok=True)
 
11
 
12
- # Generate a unique filename
13
- timestamp = int(time.time())
14
- output_path = f"{output_dir}/dummy_{timestamp}.wav"
15
-
16
- # Generate a simple sine wave as dummy audio
17
- sample_rate = 24000
18
- duration = 2.0 # seconds
19
- t = np.linspace(0, duration, int(sample_rate * duration), False)
20
- tone = np.sin(2 * np.pi * 440 * t) * 0.3
21
-
22
- # Save the audio file
23
- sf.write(output_path, tone, sample_rate)
24
-
25
- return output_path
 
1
  def generate_speech(text: str, language: str = "zh") -> str:
2
+ """Public interface for TTS generation
 
 
 
 
3
 
4
+ This is a legacy function maintained for backward compatibility.
5
+ New code should use the factory pattern implementation directly.
6
+ """
7
+ from utils.tts_base import DummyTTSEngine
8
 
9
+ # Create a dummy engine and generate speech
10
+ dummy_engine = DummyTTSEngine()
11
+ return dummy_engine.generate_speech(text, "af_heart", 1.0)
utils/tts_engines.py ADDED
@@ -0,0 +1,322 @@
1
+ import logging
2
+ import time
3
+ import os
4
+ import numpy as np
5
+ import soundfile as sf
6
+ from typing import Dict, List, Optional, Tuple, Generator, Any
7
+
8
+ from utils.tts_base import TTSEngineBase, DummyTTSEngine
9
+
10
+ # Configure logging
11
+ logger = logging.getLogger(__name__)
12
+
13
+ # Flag to track TTS engine availability
14
+ KOKORO_AVAILABLE = False
15
+ KOKORO_SPACE_AVAILABLE = True
16
+ DIA_AVAILABLE = False
17
+
18
+ # Try to import Kokoro
19
+ try:
20
+ from kokoro import KPipeline
21
+ KOKORO_AVAILABLE = True
22
+ logger.info("Kokoro TTS engine is available")
23
+ except AttributeError as e:
24
+ # Specifically catch the EspeakWrapper.set_data_path error
25
+ if "EspeakWrapper" in str(e) and "set_data_path" in str(e):
26
+ logger.warning("Kokoro import failed due to EspeakWrapper.set_data_path issue, falling back to Kokoro FastAPI server")
27
+ else:
28
+ # Re-raise if it's a different error
29
+ logger.error(f"Kokoro import failed with unexpected error: {str(e)}")
30
+ raise
31
+ except ImportError:
32
+ logger.warning("Kokoro TTS engine is not available")
33
+
34
+ # Try to import Dia dependencies to check availability
35
+ try:
36
+ import torch
37
+ from dia.model import Dia
38
+ DIA_AVAILABLE = True
39
+ logger.info("Dia TTS engine is available")
40
+ except ImportError:
41
+ logger.warning("Dia TTS engine is not available")
42
+
43
+
44
+ class KokoroTTSEngine(TTSEngineBase):
45
+ """Kokoro TTS engine implementation
46
+
47
+ This engine uses the Kokoro library for TTS generation.
48
+ """
49
+
50
+ def __init__(self, lang_code: str = 'z'):
51
+ super().__init__(lang_code)
52
+ try:
53
+ self.pipeline = KPipeline(lang_code=lang_code)
54
+ logger.info("Kokoro TTS engine successfully initialized")
55
+ except Exception as e:
56
+ logger.error(f"Failed to initialize Kokoro pipeline: {str(e)}")
57
+ logger.error(f"Error type: {type(e).__name__}")
58
+ raise
59
+
60
+ def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
61
+ """Generate speech using Kokoro TTS engine
62
+
63
+ Args:
64
+ text (str): Input text to synthesize
65
+ voice (str): Voice ID to use (e.g., 'af_heart', 'af_bella', etc.)
66
+ speed (float): Speech speed multiplier (0.5 to 2.0)
67
+
68
+ Returns:
69
+ str: Path to the generated audio file
70
+ """
71
+ logger.info(f"Generating speech with Kokoro for text length: {len(text)}")
72
+
73
+ # Generate unique output path
74
+ output_path = self._generate_output_path()
75
+
76
+ # Generate speech
77
+ generator = self.pipeline(text, voice=voice, speed=speed)
78
+ for _, _, audio in generator:
79
+ logger.info(f"Saving Kokoro audio to {output_path}")
80
+ sf.write(output_path, audio, 24000)
81
+ break
82
+
83
+ logger.info(f"Kokoro audio generation complete: {output_path}")
84
+ return output_path
85
+
86
+ def generate_speech_stream(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> Generator[Tuple[int, np.ndarray], None, None]:
87
+ """Generate speech stream using Kokoro TTS engine
88
+
89
+ Args:
90
+ text (str): Input text to synthesize
91
+ voice (str): Voice ID to use
92
+ speed (float): Speech speed multiplier
93
+
94
+ Yields:
95
+ tuple: (sample_rate, audio_data) pairs for each segment
96
+ """
97
+ logger.info(f"Generating speech stream with Kokoro for text length: {len(text)}")
98
+
99
+ # Generate speech stream
100
+ generator = self.pipeline(text, voice=voice, speed=speed)
101
+ for _, _, audio in generator:
102
+ yield 24000, audio
103
+
104
+
105
+ class KokoroSpaceTTSEngine(TTSEngineBase):
106
+ """Kokoro Space TTS engine implementation
107
+
108
+ This engine uses the Kokoro FastAPI server for TTS generation.
109
+ """
110
+
111
+ def __init__(self, lang_code: str = 'z'):
112
+ super().__init__(lang_code)
113
+ try:
114
+ from gradio_client import Client
115
+ self.client = Client("Remsky/Kokoro-TTS-Zero")
116
+ logger.info("Kokoro Space TTS engine successfully initialized")
117
+ except Exception as e:
118
+ logger.error(f"Failed to initialize Kokoro Space client: {str(e)}")
119
+ logger.error(f"Error type: {type(e).__name__}")
120
+ raise
121
+
122
+ def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
123
+ """Generate speech using Kokoro Space TTS engine
124
+
125
+ Args:
126
+ text (str): Input text to synthesize
127
+ voice (str): Voice ID to use (e.g., 'af_heart', 'af_bella', etc.)
128
+ speed (float): Speech speed multiplier (0.5 to 2.0)
129
+
130
+ Returns:
131
+ str: Path to the generated audio file
132
+ """
133
+ logger.info(f"Generating speech with Kokoro Space for text length: {len(text)}")
134
+ logger.info(f"Text to generate speech on is: {text[:50]}..." if len(text) > 50 else f"Text to generate speech on is: {text}")
135
+
136
+ # Generate unique output path
137
+ output_path = self._generate_output_path()
138
+
139
+ try:
140
+ # Use af_nova as the default voice for Kokoro Space
141
+ voice_to_use = 'af_nova' if voice == 'af_heart' else voice
142
+
143
+ # Generate speech
144
+ result = self.client.predict(
145
+ text=text,
146
+ voice_names=voice_to_use,
147
+ speed=speed,
148
+ api_name="/generate_speech_from_ui"
149
+ )
150
+ logger.info(f"Received audio from Kokoro FastAPI server: {result}")
151
+
152
+ # TODO: Process the result and save to output_path
153
+ # For now, we'll return the result path directly if it's a string
154
+ if isinstance(result, str) and os.path.exists(result):
155
+ return result
156
+ else:
157
+ logger.warning("Unexpected result from Kokoro Space, falling back to dummy audio")
158
+ return DummyTTSEngine().generate_speech(text, voice, speed)
159
+
160
+ except Exception as e:
161
+ logger.error(f"Failed to generate speech from Kokoro FastAPI server: {str(e)}")
162
+ logger.error(f"Error type: {type(e).__name__}")
163
+ logger.info("Falling back to dummy audio generation")
164
+ return DummyTTSEngine().generate_speech(text, voice, speed)
165
+
166
+
167
+ class DiaTTSEngine(TTSEngineBase):
168
+ """Dia TTS engine implementation
169
+
170
+ This engine uses the Dia model for TTS generation.
171
+ """
172
+
173
+ def __init__(self, lang_code: str = 'z'):
174
+ super().__init__(lang_code)
175
+ # Dia doesn't need initialization here, it will be lazy-loaded when needed
176
+ logger.info("Dia TTS engine initialized (lazy loading)")
177
+
178
+ def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
179
+ """Generate speech using Dia TTS engine
180
+
181
+ Args:
182
+ text (str): Input text to synthesize
183
+ voice (str): Voice ID (not used in Dia)
184
+ speed (float): Speech speed multiplier (not used in Dia)
185
+
186
+ Returns:
187
+ str: Path to the generated audio file
188
+ """
189
+ logger.info(f"Generating speech with Dia for text length: {len(text)}")
190
+
191
+ try:
192
+ # Import here to avoid circular imports
193
+ from utils.tts_dia import generate_speech as dia_generate_speech
194
+ logger.info("Successfully imported Dia speech generation function")
195
+
196
+ # Call Dia's generate_speech function
197
+ # Note: Dia's function expects a language parameter, not voice or speed
198
+ output_path = dia_generate_speech(text, language=self.lang_code)
199
+ logger.info(f"Generated audio with Dia: {output_path}")
200
+ return output_path
201
+
202
+ except ImportError as import_err:
203
+ logger.error(f"Dia TTS generation failed due to import error: {str(import_err)}")
204
+ logger.error("Falling back to dummy audio generation")
205
+ return DummyTTSEngine().generate_speech(text, voice, speed)
206
+
207
+ except Exception as dia_error:
208
+ logger.error(f"Dia TTS generation failed: {str(dia_error)}", exc_info=True)
209
+ logger.error(f"Error type: {type(dia_error).__name__}")
210
+ logger.error("Falling back to dummy audio generation")
211
+ return DummyTTSEngine().generate_speech(text, voice, speed)
212
+
213
+ def generate_speech_stream(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> Generator[Tuple[int, np.ndarray], None, None]:
214
+ """Generate speech stream using Dia TTS engine
215
+
216
+ Args:
217
+ text (str): Input text to synthesize
218
+ voice (str): Voice ID (not used in Dia)
219
+ speed (float): Speech speed multiplier (not used in Dia)
220
+
221
+ Yields:
222
+ tuple: (sample_rate, audio_data) pairs for each segment
223
+ """
224
+ logger.info(f"Generating speech stream with Dia for text length: {len(text)}")
225
+
226
+ try:
227
+ # Import required modules
228
+ import torch
229
+ from utils.tts_dia import _get_model, DEFAULT_SAMPLE_RATE
230
+
231
+ # Get the Dia model
232
+ model = _get_model()
233
+
234
+ # Generate audio
235
+ with torch.inference_mode():
236
+ output_audio_np = model.generate(
237
+ text,
238
+ max_tokens=None,
239
+ cfg_scale=3.0,
240
+ temperature=1.3,
241
+ top_p=0.95,
242
+ cfg_filter_top_k=35,
243
+ use_torch_compile=False,
244
+ verbose=False
245
+ )
246
+
247
+ if output_audio_np is not None:
248
+ logger.info(f"Successfully generated audio with Dia (length: {len(output_audio_np)})")
249
+ yield DEFAULT_SAMPLE_RATE, output_audio_np
250
+ else:
251
+ logger.warning("Dia model returned None for audio output")
252
+ logger.warning("Falling back to dummy audio stream")
253
+ yield from DummyTTSEngine().generate_speech_stream(text, voice, speed)
254
+
255
+ except ImportError as import_err:
256
+ logger.error(f"Dia TTS streaming failed due to import error: {str(import_err)}")
257
+ logger.error("Falling back to dummy audio stream")
258
+ yield from DummyTTSEngine().generate_speech_stream(text, voice, speed)
259
+
260
+ except Exception as dia_error:
261
+ logger.error(f"Dia TTS streaming failed: {str(dia_error)}", exc_info=True)
262
+ logger.error(f"Error type: {type(dia_error).__name__}")
263
+ logger.error("Falling back to dummy audio stream")
264
+ yield from DummyTTSEngine().generate_speech_stream(text, voice, speed)
265
+
266
+
267
+ def get_available_engines() -> List[str]:
268
+ """Get a list of available TTS engines
269
+
270
+ Returns:
271
+ List[str]: List of available engine names
272
+ """
273
+ available = []
274
+
275
+ if KOKORO_AVAILABLE:
276
+ available.append('kokoro')
277
+
278
+ if KOKORO_SPACE_AVAILABLE:
279
+ available.append('kokoro_space')
280
+
281
+ if DIA_AVAILABLE:
282
+ available.append('dia')
283
+
284
+ # Dummy is always available
285
+ available.append('dummy')
286
+
287
+ return available
288
+
289
+
290
+ def create_engine(engine_type: str, lang_code: str = 'z') -> TTSEngineBase:
291
+ """Create a specific TTS engine
292
+
293
+ Args:
294
+ engine_type (str): Type of engine to create ('kokoro', 'kokoro_space', 'dia', 'dummy')
295
+ lang_code (str): Language code for the engine
296
+
297
+ Returns:
298
+ TTSEngineBase: An instance of the requested TTS engine
299
+
300
+ Raises:
301
+ ValueError: If the requested engine type is not supported
302
+ """
303
+ if engine_type == 'kokoro':
304
+ if not KOKORO_AVAILABLE:
305
+ raise ValueError("Kokoro TTS engine is not available")
306
+ return KokoroTTSEngine(lang_code)
307
+
308
+ elif engine_type == 'kokoro_space':
309
+ if not KOKORO_SPACE_AVAILABLE:
310
+ raise ValueError("Kokoro Space TTS engine is not available")
311
+ return KokoroSpaceTTSEngine(lang_code)
312
+
313
+ elif engine_type == 'dia':
314
+ if not DIA_AVAILABLE:
315
+ raise ValueError("Dia TTS engine is not available")
316
+ return DiaTTSEngine(lang_code)
317
+
318
+ elif engine_type == 'dummy':
319
+ return DummyTTSEngine(lang_code)
320
+
321
+ else:
322
+ raise ValueError(f"Unsupported TTS engine type: {engine_type}")
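utils/tts_engines.py exposes both the availability flags and an explicit create_engine helper, so callers can check what is installed before committing to an engine. A minimal sketch (the preference logic below is an example, not something the module mandates):

# Example: probe availability, then create a specific engine with a safe fallback.
from utils.tts_engines import get_available_engines, create_engine

engines = get_available_engines()             # e.g. ['kokoro_space', 'dummy']
choice = 'kokoro' if 'kokoro' in engines else engines[0]
try:
    tts = create_engine(choice, lang_code='z')
except ValueError:                            # raised when the requested engine is unavailable
    tts = create_engine('dummy', lang_code='z')
print(tts.generate_speech("Quick availability check"))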
utils/tts_factory.py ADDED
@@ -0,0 +1,118 @@
1
+ import logging
2
+ from typing import Optional, List
3
+
4
+ # Configure logging
5
+ logger = logging.getLogger(__name__)
6
+
7
+ # Import the base class
8
+ from utils.tts_base import TTSEngineBase, DummyTTSEngine
9
+
10
+ class TTSFactory:
11
+ """Factory class for creating TTS engines
12
+
13
+ This class is responsible for creating the appropriate TTS engine based on
14
+ availability and configuration.
15
+ """
16
+
17
+ @staticmethod
18
+ def create_engine(engine_type: Optional[str] = None, lang_code: str = 'z') -> TTSEngineBase:
19
+ """Create a TTS engine instance
20
+
21
+ Args:
22
+ engine_type (str, optional): Type of engine to create ('kokoro', 'kokoro_space', 'dia', 'dummy')
23
+ If None, the best available engine will be used
24
+ lang_code (str): Language code for the engine
25
+
26
+ Returns:
27
+ TTSEngineBase: An instance of a TTS engine
28
+ """
29
+ from utils.tts_engines import get_available_engines, create_engine
30
+
31
+ # Get available engines
32
+ available_engines = get_available_engines()
33
+ logger.info(f"Available TTS engines: {available_engines}")
34
+
35
+ # If engine_type is specified, try to create that specific engine
36
+ if engine_type is not None:
37
+ if engine_type in available_engines:
38
+ logger.info(f"Creating requested engine: {engine_type}")
39
+ return create_engine(engine_type, lang_code)
40
+ else:
41
+ logger.warning(f"Requested engine '{engine_type}' is not available")
42
+
43
+ # Try to create the best available engine
44
+ # Priority: kokoro > kokoro_space > dia > dummy
45
+ for engine in ['kokoro', 'kokoro_space', 'dia']:
46
+ if engine in available_engines:
47
+ logger.info(f"Creating best available engine: {engine}")
48
+ return create_engine(engine, lang_code)
49
+
50
+ # Fall back to dummy engine
51
+ logger.warning("No TTS engines available, falling back to dummy engine")
52
+ return DummyTTSEngine(lang_code)
53
+
54
+
55
+ # Backward compatibility function
56
+ def get_tts_engine(lang_code: str = 'a') -> TTSEngineBase:
57
+ """Get or create TTS engine instance (backward compatibility function)
58
+
59
+ Args:
60
+ lang_code (str): Language code for the pipeline
61
+
62
+ Returns:
63
+ TTSEngineBase: Initialized TTS engine instance
64
+ """
65
+ logger.info(f"Requesting TTS engine with language code: {lang_code}")
66
+ try:
67
+ import streamlit as st
68
+ logger.info("Streamlit detected, using cached TTS engine")
69
+ @st.cache_resource
70
+ def _get_engine():
71
+ logger.info("Creating cached TTS engine instance")
72
+ engine = TTSFactory.create_engine(lang_code=lang_code)
73
+ logger.info(f"Cached TTS engine created with type: {engine.__class__.__name__}")
74
+ return engine
75
+
76
+ engine = _get_engine()
77
+ logger.info(f"Retrieved TTS engine from cache with type: {engine.__class__.__name__}")
78
+ return engine
79
+ except ImportError:
80
+ logger.info("Streamlit not available, creating direct TTS engine instance")
81
+ engine = TTSFactory.create_engine(lang_code=lang_code)
82
+ logger.info(f"Direct TTS engine created with type: {engine.__class__.__name__}")
83
+ return engine
84
+
85
+
86
+ # Backward compatibility function
87
+ def generate_speech(text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
88
+ """Public interface for TTS generation (backward compatibility function)
89
+
90
+ Args:
91
+ text (str): Input text to synthesize
92
+ voice (str): Voice ID to use
93
+ speed (float): Speech speed multiplier
94
+
95
+ Returns:
96
+ str: Path to generated audio file
97
+ """
98
+ logger.info(f"Public generate_speech called with text length: {len(text)}, voice: {voice}, speed: {speed}")
99
+ try:
100
+ # Get the TTS engine
101
+ logger.info("Getting TTS engine instance")
102
+ engine = get_tts_engine()
103
+ logger.info(f"Using TTS engine type: {engine.__class__.__name__}")
104
+
105
+ # Generate speech
106
+ logger.info("Calling engine.generate_speech")
107
+ output_path = engine.generate_speech(text, voice, speed)
108
+ logger.info(f"Speech generation complete, output path: {output_path}")
109
+ return output_path
110
+ except Exception as e:
111
+ logger.error(f"Error in public generate_speech function: {str(e)}", exc_info=True)
112
+ logger.error(f"Error type: {type(e).__name__}")
113
+ if hasattr(e, '__traceback__'):
114
+ tb = e.__traceback__
115
+ while tb.tb_next:
116
+ tb = tb.tb_next
117
+ logger.error(f"Error occurred in file: {tb.tb_frame.f_code.co_filename}, line {tb.tb_lineno}")
118
+ raise
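get_tts_engine wraps engine creation in st.cache_resource when Streamlit can be imported, so a Streamlit app reuses one engine per process; outside Streamlit it simply builds a fresh engine. Illustrative usage:

# Illustrative usage of the cached accessor (behaves the same with or without Streamlit).
from utils.tts_factory import get_tts_engine

engine = get_tts_engine(lang_code='a')        # cached per process when running under Streamlit
path = engine.generate_speech("Cached engine example")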