Spaces:
Build error
Build error
File size: 9,880 Bytes
1f9c751 fdc056d 1f9c751 fdc056d 1f9c751 fdc056d 1f9c751 fdc056d 1f9c751 fdc056d 1f9c751 fdc056d 1f9c751 fdc056d 1f9c751 fdc056d 1f9c751 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 |
"""Dia TTS provider implementation."""
import logging
import numpy as np
import soundfile as sf
import io
from typing import Iterator, TYPE_CHECKING
if TYPE_CHECKING:
from ...domain.models.speech_synthesis_request import SpeechSynthesisRequest
from ..base.tts_provider_base import TTSProviderBase
from ...domain.exceptions import SpeechSynthesisException
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Flag to track Dia availability. Starts as False and is updated by
# _check_and_install_dia_dependencies() according to whether torch and
# dia.model could be imported.
DIA_AVAILABLE = False
# Sample rate handed to the WAV encoder and returned alongside generated audio.
# NOTE(review): presumably this matches the rate of the audio Dia produces - confirm.
DEFAULT_SAMPLE_RATE = 24000
# Try to import Dia dependencies
def _check_and_install_dia_dependencies():
    """Check whether Dia can be imported, installing its dependencies if not.

    Tries to import ``torch`` and ``dia.model``. If either import fails with
    an ``ImportError`` (including ``ModuleNotFoundError``), the dependency
    installer is asked to install the Dia dependencies and the imports are
    retried once. The module-level ``DIA_AVAILABLE`` flag is set to True on
    every successful path and to False on every failed installation path.

    Returns:
        bool: True if both imports succeeded (immediately or after
        installation), False otherwise. Import failures and installation
        errors are logged rather than raised.
    """
    import importlib

    global DIA_AVAILABLE

    logger.info("π Checking Dia TTS dependencies...")
    try:
        logger.info("Attempting to import torch...")
        import torch  # noqa: F401 - imported only to probe availability
        logger.info("β Successfully imported torch")
        logger.info("Attempting to import dia.model...")
        from dia.model import Dia  # noqa: F401 - imported only to probe availability
        logger.info("β Successfully imported dia.model")
        DIA_AVAILABLE = True
        logger.info("β Dia TTS engine is available")
        return True
    except ModuleNotFoundError as e:
        # This handler must come before ImportError: ModuleNotFoundError is a
        # subclass of ImportError, so listed second it could never run.
        error_text = str(e)
        if "dac" in error_text:
            logger.warning("β Dia TTS engine is not available due to missing 'dac' module")
        elif "dia" in error_text:
            logger.warning("β Dia TTS engine is not available due to missing 'dia' module")
        else:
            logger.warning(f"β Dia TTS engine is not available: {str(e)}")
        logger.info(f"ModuleNotFoundError details: {type(e).__name__}: {e}")
    except ImportError as e:
        # Any other import failure (module found, but importing it failed).
        logger.warning(f"β οΈ Dia TTS engine dependencies not available: {e}")
        logger.info(f"ImportError details: {type(e).__name__}: {e}")

    # Try to install missing dependencies
    logger.info("π§ Attempting to install Dia TTS dependencies...")
    try:
        # NOTE(review): get_dependency_installer is not imported in the visible
        # part of this file. If it really is missing, the resulting NameError is
        # caught by the broad handler below and reported as an installation error.
        installer = get_dependency_installer()
        success, errors = installer.install_dia_dependencies()
        if not success:
            logger.error(f"β Failed to install Dia TTS dependencies: {errors}")
            DIA_AVAILABLE = False
            return False

        logger.info("β Successfully installed Dia TTS dependencies")
        # Packages installed while the interpreter is running may be invisible
        # to the import system until its finder caches are invalidated.
        importlib.invalidate_caches()

        # Try importing again after installation
        try:
            logger.info("Re-attempting import after installation...")
            import torch  # noqa: F401,F811
            from dia.model import Dia  # noqa: F401,F811
            DIA_AVAILABLE = True
            logger.info("π Dia TTS engine is now available after installation")
            return True
        except Exception as e:
            logger.error(f"β Dia TTS still not available after installation: {e}")
            logger.info(f"Post-installation import error: {type(e).__name__}: {e}")
            DIA_AVAILABLE = False
            return False
    except Exception as e:
        logger.error(f"β Error during dependency installation: {e}")
        logger.info(f"Installation error details: {type(e).__name__}: {e}")
        DIA_AVAILABLE = False
        return False
# Initial check
# Runs once when this module is imported, so DIA_AVAILABLE is populated before
# any provider is used. Note that the check may trigger a dependency
# installation attempt as a side effect of the import.
logger.info("π Initializing Dia TTS provider...")
_check_and_install_dia_dependencies()
class DiaTTSProvider(TTSProviderBase):
    """Dia TTS provider implementation.

    The Dia model is loaded lazily on the first availability check and kept
    on the instance afterwards. Audio is produced in a single pass and
    returned as WAV-encoded bytes at ``DEFAULT_SAMPLE_RATE``.
    """

    def __init__(self, lang_code: str = 'z'):
        """Initialize the Dia TTS provider.

        Args:
            lang_code: Language code stored on the instance as ``lang_code``.
        """
        super().__init__(
            provider_name="Dia",
            supported_languages=['en', 'z']  # Dia supports English and multilingual
        )
        self.lang_code = lang_code
        # Loaded on demand by _ensure_model(); None means "not loaded".
        self.model = None

    def _ensure_model(self) -> bool:
        """Ensure the model is loaded.

        If no model is loaded yet, first makes sure the Dia dependencies are
        importable (attempting installation when they are not) and then loads
        the pretrained model. Loading errors are logged, not raised.

        Returns:
            True if ``self.model`` is set after the call, False otherwise.
        """
        global DIA_AVAILABLE

        if self.model is None:
            logger.info("π Ensuring Dia model is loaded...")

            # If Dia is not available, try to install dependencies
            if not DIA_AVAILABLE:
                logger.info("β οΈ Dia not available, attempting to install dependencies...")
                if not _check_and_install_dia_dependencies():
                    logger.error("β Failed to install dependencies, Dia remains unavailable")
                    return False
                # Mirror the successful result into the module-level flag.
                DIA_AVAILABLE = True
                logger.info("β Dependencies installed, Dia is now available")

            if DIA_AVAILABLE:
                try:
                    logger.info("π₯ Loading Dia model from pretrained...")
                    import torch  # noqa: F401 - imported before Dia, as in the dependency check
                    from dia.model import Dia
                    self.model = Dia.from_pretrained()
                    logger.info("π Dia model successfully loaded")
                except ImportError as e:
                    logger.error(f"β Failed to import Dia dependencies: {str(e)}")
                    self.model = None
                except FileNotFoundError as e:
                    logger.error(f"β Failed to load Dia model files: {str(e)}")
                    logger.info("βΉοΈ This might be the first time loading the model. It will be downloaded automatically.")
                    self.model = None
                except Exception as e:
                    logger.error(f"β Failed to initialize Dia model: {str(e)}")
                    logger.info(f"Model initialization error: {type(e).__name__}: {e}")
                    self.model = None

        is_available = self.model is not None
        logger.info(f"Model availability check result: {is_available}")
        return is_available

    def is_available(self) -> bool:
        """Check if Dia TTS is available.

        Returns False immediately when the dependency flag is unset;
        otherwise the result depends on whether the model can be loaded.
        """
        logger.info(f"π Checking Dia availability: DIA_AVAILABLE={DIA_AVAILABLE}")

        if not DIA_AVAILABLE:
            logger.info("β Dia dependencies not available")
            return False

        model_available = self._ensure_model()
        logger.info(f"π Model availability: {model_available}")

        result = DIA_AVAILABLE and model_available
        logger.info(f"π― Dia TTS availability result: {result}")
        return result

    def get_available_voices(self) -> list[str]:
        """Get available voices for Dia."""
        # Dia typically uses a default voice
        return ['default']

    def _synthesize_wav(self, request: 'SpeechSynthesisRequest') -> bytes:
        """Run Dia on the request text and return the audio as WAV bytes.

        Shared by the one-shot and streaming entry points so both use the
        same generation settings. Expects ``self.model`` to be loaded.

        Raises:
            SpeechSynthesisException: If the model returns None, or if the
                audio cannot be encoded.
        """
        import torch

        # Extract parameters from request
        text = request.text_content.text

        # Generate audio using Dia
        with torch.inference_mode():
            output_audio_np = self.model.generate(
                text,
                max_tokens=None,
                cfg_scale=3.0,
                temperature=1.3,
                top_p=0.95,
                cfg_filter_top_k=35,
                use_torch_compile=False,
                verbose=False
            )

        if output_audio_np is None:
            raise SpeechSynthesisException("Dia model returned None for audio output")

        # Convert numpy array to bytes
        return self._numpy_to_bytes(output_audio_np, sample_rate=DEFAULT_SAMPLE_RATE)

    def _generate_audio(self, request: 'SpeechSynthesisRequest') -> tuple[bytes, int]:
        """Generate audio using Dia TTS.

        Returns:
            A ``(wav_bytes, sample_rate)`` tuple.

        Raises:
            SpeechSynthesisException: If Dia is unavailable or the 'dac'
                module is missing. Other errors are passed to
                ``_handle_provider_error``.
        """
        if not self.is_available():
            raise SpeechSynthesisException("Dia TTS engine is not available")

        try:
            audio_bytes = self._synthesize_wav(request)
            return audio_bytes, DEFAULT_SAMPLE_RATE
        except ModuleNotFoundError as e:
            if "dac" in str(e):
                raise SpeechSynthesisException("Dia TTS engine failed due to missing 'dac' module") from e
            else:
                self._handle_provider_error(e, "audio generation")
        except Exception as e:
            self._handle_provider_error(e, "audio generation")

    def _generate_audio_stream(self, request: 'SpeechSynthesisRequest') -> Iterator[tuple[bytes, int, bool]]:
        """Generate audio stream using Dia TTS.

        Yields:
            A single ``(wav_bytes, sample_rate, True)`` tuple; the third
            element is always True for this one and only chunk.

        Raises:
            SpeechSynthesisException: If Dia is unavailable or the 'dac'
                module is missing. Other errors are passed to
                ``_handle_provider_error``.
        """
        if not self.is_available():
            raise SpeechSynthesisException("Dia TTS engine is not available")

        try:
            audio_bytes = self._synthesize_wav(request)
            # Dia generates complete audio in one go
            yield audio_bytes, DEFAULT_SAMPLE_RATE, True
        except ModuleNotFoundError as e:
            if "dac" in str(e):
                raise SpeechSynthesisException("Dia TTS engine failed due to missing 'dac' module") from e
            else:
                self._handle_provider_error(e, "streaming audio generation")
        except Exception as e:
            self._handle_provider_error(e, "streaming audio generation")

    def _numpy_to_bytes(self, audio_array: np.ndarray, sample_rate: int) -> bytes:
        """Convert numpy audio array to bytes.

        Args:
            audio_array: Audio samples to encode.
            sample_rate: Sample rate written into the WAV data.

        Returns:
            The complete WAV data as bytes.

        Raises:
            SpeechSynthesisException: If encoding fails for any reason.
        """
        try:
            # Encode into an in-memory buffer; getvalue() returns everything
            # written regardless of the current stream position.
            with io.BytesIO() as buffer:
                sf.write(buffer, audio_array, sample_rate, format='WAV')
                return buffer.getvalue()
        except Exception as e:
            raise SpeechSynthesisException(f"Failed to convert audio to bytes: {str(e)}") from e