Spaces:
Build error
Build error
File size: 9,880 Bytes
1f9c751 fdc056d 1f9c751 fdc056d 1f9c751 fdc056d 1f9c751 fdc056d 1f9c751 fdc056d 1f9c751 fdc056d 1f9c751 fdc056d 1f9c751 fdc056d 1f9c751 |
|
"""Dia TTS provider implementation."""
import importlib
import io
import logging
from typing import Iterator, TYPE_CHECKING

import numpy as np
import soundfile as sf

from ..base.tts_provider_base import TTSProviderBase
from ...domain.exceptions import SpeechSynthesisException

if TYPE_CHECKING:
    from ...domain.models.speech_synthesis_request import SpeechSynthesisRequest
# Module-level logger shared by the dependency check and the provider class.
logger = logging.getLogger(__name__)
# Flag to track Dia availability. Written by
# _check_and_install_dia_dependencies() and read by DiaTTSProvider before it
# tries to load the model.
DIA_AVAILABLE = False
# Sample rate passed to soundfile when encoding Dia output as WAV and returned
# to callers alongside the audio bytes.
DEFAULT_SAMPLE_RATE = 24000
# Try to import Dia dependencies
def _check_and_install_dia_dependencies() -> bool:
    """Probe for the Dia TTS dependencies and try to install them if missing.

    First attempts to import ``torch`` and ``dia.model``. If both import, the
    module-level ``DIA_AVAILABLE`` flag is set to True. Otherwise the
    dependency installer is asked to install them and the imports are retried
    once.

    Returns:
        bool: True if Dia could be imported (directly or after installation),
        False otherwise. ``DIA_AVAILABLE`` is left holding the same value.
        No exception escapes; every failure is logged.
    """
    global DIA_AVAILABLE
    logger.info("π Checking Dia TTS dependencies...")
    try:
        logger.info("Attempting to import torch...")
        import torch  # noqa: F401  (imported only to prove it is installed)
        logger.info("β Successfully imported torch")
        logger.info("Attempting to import dia.model...")
        from dia.model import Dia  # noqa: F401
        logger.info("β Successfully imported dia.model")
    except ModuleNotFoundError as e:
        # ModuleNotFoundError is a subclass of ImportError, so this handler
        # must come first or the module-specific messages are never reached.
        missing = str(e)
        if "dac" in missing:
            logger.warning("β Dia TTS engine is not available due to missing 'dac' module")
        elif "dia" in missing:
            logger.warning("β Dia TTS engine is not available due to missing 'dia' module")
        else:
            logger.warning(f"β Dia TTS engine is not available: {missing}")
        logger.info(f"ModuleNotFoundError details: {type(e).__name__}: {e}")
    except ImportError as e:
        logger.warning(f"β οΈ Dia TTS engine dependencies not available: {e}")
        logger.info(f"ImportError details: {type(e).__name__}: {e}")
    else:
        DIA_AVAILABLE = True
        logger.info("✅ Dia TTS engine is available")
        return True

    # Try to install missing dependencies
    logger.info("π§ Attempting to install Dia TTS dependencies...")
    try:
        # NOTE(review): get_dependency_installer is not imported anywhere in
        # the visible part of this file. If it is genuinely missing, the
        # resulting NameError is logged below and the function returns False.
        # TODO: confirm the import exists.
        installer = get_dependency_installer()
        success, errors = installer.install_dia_dependencies()
    except Exception as e:
        logger.error(f"β Error during dependency installation: {e}")
        logger.info(f"Installation error details: {type(e).__name__}: {e}")
        DIA_AVAILABLE = False
        return False

    if not success:
        logger.error(f"β Failed to install Dia TTS dependencies: {errors}")
        DIA_AVAILABLE = False
        return False

    logger.info("✅ Successfully installed Dia TTS dependencies")
    # Try importing again after installation
    try:
        logger.info("Re-attempting import after installation...")
        # Packages installed while the interpreter is running may be invisible
        # to the import system's finder caches until they are invalidated.
        importlib.invalidate_caches()
        import torch  # noqa: F401,F811
        from dia.model import Dia  # noqa: F401,F811
    except Exception as e:
        logger.error(f"β Dia TTS still not available after installation: {e}")
        logger.info(f"Post-installation import error: {type(e).__name__}: {e}")
        DIA_AVAILABLE = False
        return False

    DIA_AVAILABLE = True
    logger.info("π Dia TTS engine is now available after installation")
    return True
# Initial check. This runs at import time: importing this module probes for
# torch/dia and, when they are missing, attempts to install them, so
# DIA_AVAILABLE is already settled before DiaTTSProvider is instantiated.
logger.info("π Initializing Dia TTS provider...")
_check_and_install_dia_dependencies()
class DiaTTSProvider(TTSProviderBase):
    """Dia TTS provider implementation.

    The Dia model is loaded lazily on first use (see ``_ensure_model``).
    Generated audio is returned as WAV-encoded bytes together with
    ``DEFAULT_SAMPLE_RATE``.
    """

    def __init__(self, lang_code: str = 'z'):
        """Initialize the Dia TTS provider.

        Args:
            lang_code: Language code stored on the instance as
                ``self.lang_code``. Defaults to ``'z'``.
        """
        super().__init__(
            provider_name="Dia",
            supported_languages=['en', 'z']  # Dia supports English and multilingual
        )
        self.lang_code = lang_code
        self.model = None  # populated by _ensure_model() on first use

    def _ensure_model(self) -> bool:
        """Ensure the model is loaded.

        If the dependencies are flagged as unavailable, one installation
        attempt is made before loading. Load failures are logged, never
        raised.

        Returns:
            bool: True if ``self.model`` holds a loaded model, else False.
        """
        global DIA_AVAILABLE
        if self.model is None:
            logger.info("π Ensuring Dia model is loaded...")
            # If Dia is not available, try to install dependencies
            if not DIA_AVAILABLE:
                logger.info("β οΈ Dia not available, attempting to install dependencies...")
                if not _check_and_install_dia_dependencies():
                    logger.error("β Failed to install dependencies, Dia remains unavailable")
                    return False
                DIA_AVAILABLE = True
                logger.info("✅ Dependencies installed, Dia is now available")

            # self.model is still None here and is only assigned when
            # from_pretrained() succeeds, so the handlers need not reset it.
            try:
                logger.info("π₯ Loading Dia model from pretrained...")
                import torch  # noqa: F401  (fails fast if torch is missing)
                from dia.model import Dia
                self.model = Dia.from_pretrained()
                logger.info("π Dia model successfully loaded")
            except ImportError as e:
                logger.error(f"β Failed to import Dia dependencies: {e}")
            except FileNotFoundError as e:
                logger.error(f"β Failed to load Dia model files: {e}")
                logger.info("βΉοΈ This might be the first time loading the model. It will be downloaded automatically.")
            except Exception as e:
                logger.error(f"β Failed to initialize Dia model: {e}")
                logger.info(f"Model initialization error: {type(e).__name__}: {e}")

        loaded = self.model is not None
        logger.info(f"Model availability check result: {loaded}")
        return loaded

    def is_available(self) -> bool:
        """Check if Dia TTS is available.

        Note that this is not a pure query: when the dependencies are flagged
        available it loads the model if that has not happened yet.

        Returns:
            bool: True only if the dependencies are available and the model
            is loaded.
        """
        logger.info(f"π Checking Dia availability: DIA_AVAILABLE={DIA_AVAILABLE}")
        if not DIA_AVAILABLE:
            logger.info("β Dia dependencies not available")
            return False
        model_available = self._ensure_model()
        logger.info(f"π Model availability: {model_available}")
        result = DIA_AVAILABLE and model_available
        logger.info(f"π― Dia TTS availability result: {result}")
        return result

    def get_available_voices(self) -> list[str]:
        """Get available voices for Dia.

        Returns:
            list[str]: Always ``['default']``.
        """
        # Dia typically uses a default voice
        return ['default']

    def _synthesize_wav_bytes(self, request: 'SpeechSynthesisRequest', operation: str) -> bytes:
        """Run Dia on the request text and return WAV-encoded bytes.

        Shared implementation behind ``_generate_audio`` and
        ``_generate_audio_stream``, which differ only in how they hand the
        bytes back to the caller.

        Args:
            request: Synthesis request; the text is read from
                ``request.text_content.text``.
            operation: Short description of the calling operation, forwarded
                to ``_handle_provider_error`` as error context.

        Returns:
            bytes: The generated audio as a WAV file image.

        Raises:
            SpeechSynthesisException: If the engine is unavailable, the 'dac'
                module is missing, or generation fails and the base-class
                error handler does not raise on its own.
        """
        if not self.is_available():
            raise SpeechSynthesisException("Dia TTS engine is not available")
        try:
            import torch
            # Extract parameters from request
            text = request.text_content.text
            # Generate audio using Dia
            with torch.inference_mode():
                output_audio_np = self.model.generate(
                    text,
                    max_tokens=None,
                    cfg_scale=3.0,
                    temperature=1.3,
                    top_p=0.95,
                    cfg_filter_top_k=35,
                    use_torch_compile=False,
                    verbose=False
                )
            if output_audio_np is None:
                raise SpeechSynthesisException("Dia model returned None for audio output")
            # Convert numpy array to bytes
            return self._numpy_to_bytes(output_audio_np, sample_rate=DEFAULT_SAMPLE_RATE)
        except Exception as e:
            if isinstance(e, ModuleNotFoundError) and "dac" in str(e):
                raise SpeechSynthesisException("Dia TTS engine failed due to missing 'dac' module") from e
            self._handle_provider_error(e, operation)
            # NOTE(review): _handle_provider_error is defined on the base class
            # and is presumably expected to raise. If it ever returns, fail
            # loudly here instead of letting callers receive None.
            raise SpeechSynthesisException(f"Dia TTS {operation} failed: {e}") from e

    def _generate_audio(self, request: 'SpeechSynthesisRequest') -> tuple[bytes, int]:
        """Generate audio using Dia TTS.

        Args:
            request: Synthesis request carrying the text to speak.

        Returns:
            tuple[bytes, int]: WAV bytes and the sample rate
            (``DEFAULT_SAMPLE_RATE``).

        Raises:
            SpeechSynthesisException: If the engine is unavailable or
                generation fails.
        """
        audio_bytes = self._synthesize_wav_bytes(request, "audio generation")
        return audio_bytes, DEFAULT_SAMPLE_RATE

    def _generate_audio_stream(self, request: 'SpeechSynthesisRequest') -> Iterator[tuple[bytes, int, bool]]:
        """Generate audio stream using Dia TTS.

        Args:
            request: Synthesis request carrying the text to speak.

        Yields:
            tuple[bytes, int, bool]: A single item holding the complete WAV
            bytes, the sample rate and ``True`` as the third element.

        Raises:
            SpeechSynthesisException: On first iteration, if the engine is
                unavailable or generation fails.
        """
        audio_bytes = self._synthesize_wav_bytes(request, "streaming audio generation")
        # Dia generates complete audio in one go
        yield audio_bytes, DEFAULT_SAMPLE_RATE, True

    def _numpy_to_bytes(self, audio_array: np.ndarray, sample_rate: int) -> bytes:
        """Convert numpy audio array to bytes.

        Args:
            audio_array: Audio samples to encode.
            sample_rate: Sample rate written into the WAV header.

        Returns:
            bytes: The audio encoded as an in-memory WAV file.

        Raises:
            SpeechSynthesisException: If encoding fails for any reason.
        """
        try:
            # Encode into an in-memory buffer; getvalue() returns the whole
            # contents regardless of the current stream position.
            with io.BytesIO() as buffer:
                sf.write(buffer, audio_array, sample_rate, format='WAV')
                return buffer.getvalue()
        except Exception as e:
            raise SpeechSynthesisException(f"Failed to convert audio to bytes: {str(e)}") from e