"""Base class for TTS provider implementations.""" import logging import os import time import tempfile from abc import ABC, abstractmethod from typing import Iterator, Optional, TYPE_CHECKING from pathlib import Path if TYPE_CHECKING: from ...domain.models.speech_synthesis_request import SpeechSynthesisRequest from ...domain.models.audio_content import AudioContent from ...domain.models.audio_chunk import AudioChunk from ...domain.interfaces.speech_synthesis import ISpeechSynthesisService from ...domain.exceptions import SpeechSynthesisException logger = logging.getLogger(__name__) class TTSProviderBase(ISpeechSynthesisService, ABC): """Abstract base class for TTS provider implementations.""" def __init__(self, provider_name: str, supported_languages: list[str] = None): """ Initialize the TTS provider. Args: provider_name: Name of the TTS provider supported_languages: List of supported language codes """ self.provider_name = provider_name self.supported_languages = supported_languages or [] self._output_dir = self._ensure_output_directory() def synthesize(self, request: 'SpeechSynthesisRequest') -> 'AudioContent': """ Synthesize speech from text. Args: request: The speech synthesis request Returns: AudioContent: The synthesized audio Raises: SpeechSynthesisException: If synthesis fails """ try: logger.info(f"Starting synthesis with {self.provider_name} provider") self._validate_request(request) # Generate audio using provider-specific implementation audio_data, sample_rate = self._generate_audio(request) # Create AudioContent from the generated data from ...domain.models.audio_content import AudioContent audio_content = AudioContent( data=audio_data, format='wav', # Most providers output WAV sample_rate=sample_rate, duration=self._calculate_duration(audio_data, sample_rate), filename=f"{self.provider_name}_{int(time.time())}.wav" ) logger.info(f"Synthesis completed successfully with {self.provider_name}") return audio_content except Exception as e: logger.error(f"Synthesis failed with {self.provider_name}: {str(e)}") raise SpeechSynthesisException(f"TTS synthesis failed: {str(e)}") from e def synthesize_stream(self, request: 'SpeechSynthesisRequest') -> Iterator['AudioChunk']: """ Synthesize speech from text as a stream. Args: request: The speech synthesis request Returns: Iterator[AudioChunk]: Stream of audio chunks Raises: SpeechSynthesisException: If synthesis fails """ try: logger.info(f"Starting streaming synthesis with {self.provider_name} provider") self._validate_request(request) # Generate audio stream using provider-specific implementation chunk_index = 0 for audio_data, sample_rate, is_final in self._generate_audio_stream(request): from ...domain.models.audio_chunk import AudioChunk chunk = AudioChunk( data=audio_data, format='wav', sample_rate=sample_rate, chunk_index=chunk_index, is_final=is_final, timestamp=time.time() ) yield chunk chunk_index += 1 logger.info(f"Streaming synthesis completed with {self.provider_name}") except Exception as e: logger.error(f"Streaming synthesis failed with {self.provider_name}: {str(e)}") raise SpeechSynthesisException(f"TTS streaming synthesis failed: {str(e)}") from e @abstractmethod def _generate_audio(self, request: 'SpeechSynthesisRequest') -> tuple[bytes, int]: """ Generate audio data from synthesis request. Args: request: The speech synthesis request Returns: tuple: (audio_data_bytes, sample_rate) """ pass @abstractmethod def _generate_audio_stream(self, request: 'SpeechSynthesisRequest') -> Iterator[tuple[bytes, int, bool]]: """ Generate audio data stream from synthesis request. Args: request: The speech synthesis request Returns: Iterator: (audio_data_bytes, sample_rate, is_final) tuples """ pass @abstractmethod def is_available(self) -> bool: """ Check if the TTS provider is available and ready to use. Returns: bool: True if provider is available, False otherwise """ pass @abstractmethod def get_available_voices(self) -> list[str]: """ Get list of available voices for this provider. Returns: list[str]: List of voice identifiers """ pass def _validate_request(self, request: 'SpeechSynthesisRequest') -> None: """ Validate the synthesis request. Args: request: The synthesis request to validate Raises: SpeechSynthesisException: If request is invalid """ if not request.text_content.text.strip(): raise SpeechSynthesisException("Text content cannot be empty") if self.supported_languages and request.text_content.language not in self.supported_languages: raise SpeechSynthesisException( f"Language {request.text_content.language} not supported by {self.provider_name}. " f"Supported languages: {self.supported_languages}" ) available_voices = self.get_available_voices() if available_voices and request.voice_settings.voice_id not in available_voices: raise SpeechSynthesisException( f"Voice {request.voice_settings.voice_id} not available for {self.provider_name}. " f"Available voices: {available_voices}" ) def _ensure_output_directory(self) -> Path: """ Ensure output directory exists and return its path. Returns: Path: Path to the output directory """ output_dir = Path(tempfile.gettempdir()) / "tts_output" output_dir.mkdir(exist_ok=True) return output_dir def _generate_output_path(self, prefix: str = None, extension: str = "wav") -> Path: """ Generate a unique output path for audio files. Args: prefix: Optional prefix for the filename extension: File extension (default: wav) Returns: Path: Unique file path """ prefix = prefix or self.provider_name timestamp = int(time.time() * 1000) filename = f"{prefix}_{timestamp}.{extension}" return self._output_dir / filename def _calculate_duration(self, audio_data: bytes, sample_rate: int, channels: int = 1, sample_width: int = 2) -> float: """ Calculate audio duration from raw audio data. Args: audio_data: Raw audio data in bytes sample_rate: Sample rate in Hz channels: Number of audio channels (default: 1) sample_width: Sample width in bytes (default: 2 for 16-bit) Returns: float: Duration in seconds """ if not audio_data or sample_rate <= 0: return 0.0 bytes_per_sample = channels * sample_width total_samples = len(audio_data) // bytes_per_sample return total_samples / sample_rate def _cleanup_temp_files(self, max_age_hours: int = 24) -> None: """ Clean up old temporary files. Args: max_age_hours: Maximum age of files to keep in hours """ try: current_time = time.time() max_age_seconds = max_age_hours * 3600 for file_path in self._output_dir.glob("*"): if file_path.is_file(): file_age = current_time - file_path.stat().st_mtime if file_age > max_age_seconds: file_path.unlink() logger.debug(f"Cleaned up old temp file: {file_path}") except Exception as e: logger.warning(f"Failed to cleanup temp files: {str(e)}") def _handle_provider_error(self, error: Exception, context: str = "") -> None: """ Handle provider-specific errors and convert to domain exceptions. Args: error: The original error context: Additional context about when the error occurred """ error_msg = f"{self.provider_name} error" if context: error_msg += f" during {context}" error_msg += f": {str(error)}" logger.error(error_msg, exc_info=True) raise SpeechSynthesisException(error_msg) from error