"""Automatic dependency installer for TTS providers."""

import logging
import subprocess
import sys
import importlib
from typing import List, Dict, Optional, Tuple
import os

logger = logging.getLogger(__name__)


class DependencyInstaller:
    """Utility class for automatically installing missing dependencies.

    Packages are installed by running ``pip`` under the current interpreter
    (``sys.executable -m pip``).
    """

    # Source repository for Dia TTS, installed with ``pip install git+<url>``.
    _DIA_GIT_URL = "https://github.com/nari-labs/dia.git"

    # (import name, pip distribution name) pairs installed one by one when the
    # git install of Dia fails.
    _DIA_DEPENDENCIES = (
        ("torch", "torch"),
        ("transformers", "transformers"),
        ("accelerate", "accelerate"),
        ("soundfile", "soundfile"),
        ("dac", "descript-audio-codec"),
    )

    # Modules reported by get_installation_status(), in report order.
    _STATUS_MODULES = (
        "torch",
        "transformers",
        "accelerate",
        "soundfile",
        "numpy",
        "dac",
        "dia",
    )

    def __init__(self):
        """Initialize the dependency installer."""
        # Names of packages this instance installed successfully; used to skip
        # repeated pip runs for the same package within one session.
        self.installed_packages = set()

    def check_module_available(self, module_name: str) -> bool:
        """
        Check if a module is available for import.

        The module is actually imported (not merely located), so this also
        runs the module's import-time code.

        Args:
            module_name: Name of the module to check

        Returns:
            bool: True if module is available, False otherwise (including when
            the module exists but raises while being imported)
        """
        # Packages installed by pip after the interpreter started may not be
        # seen by the path finders until their directory caches are reset.
        importlib.invalidate_caches()
        try:
            importlib.import_module(module_name)
        except ImportError:
            return False
        except Exception as e:
            # The module is present but its import-time code failed (e.g. a
            # missing native library); it is not usable, so report it as such
            # instead of crashing the caller.
            logger.warning(f"Module {module_name} is present but failed to import: {e}")
            return False
        return True

    def _run_pip(
        self, cmd: List[str], package_name: str, timeout: int, *, from_git: bool = False
    ) -> bool:
        """Run a pip command and record *package_name* as installed on success.

        Args:
            cmd: Full pip command line (list form, no shell involved)
            package_name: Name used for logging and session tracking
            timeout: Seconds to wait before giving up on pip
            from_git: Whether this is a git install (only affects log wording)

        Returns:
            bool: True if pip exited with status 0, False otherwise
        """
        via = " from git" if from_git else ""
        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=timeout,
            )
        except subprocess.TimeoutExpired:
            label = "Git installation" if from_git else "Installation"
            logger.error(f"{label} of {package_name} timed out")
            return False
        except Exception as e:
            # Best-effort installer: report failures (e.g. OSError launching
            # the interpreter) as a False result rather than raising.
            logger.error(f"Error installing {package_name}{via}: {e}")
            return False

        if result.returncode != 0:
            logger.error(f"Failed to install {package_name}{via}: {result.stderr}")
            return False

        logger.info(f"Successfully installed {package_name}{via}")
        self.installed_packages.add(package_name)
        return True

    def install_package(self, package_name: str, upgrade: bool = False) -> bool:
        """
        Install a package using pip.

        Args:
            package_name: Name of the package to install
            upgrade: Whether to upgrade if already installed

        Returns:
            bool: True if installation succeeded (or the package was already
            installed by this instance), False otherwise
        """
        # An empty requirement can never succeed; fail fast instead of
        # spawning pip just to get an error back.
        if not package_name or not package_name.strip():
            logger.error("Cannot install package: empty package name")
            return False

        if package_name in self.installed_packages:
            logger.info(f"Package {package_name} already installed in this session")
            return True

        cmd = [sys.executable, "-m", "pip", "install"]
        if upgrade:
            cmd.append("--upgrade")
        cmd.append(package_name)

        logger.info(f"Installing package: {package_name}")
        return self._run_pip(cmd, package_name, timeout=300)  # 5 minute timeout

    @staticmethod
    def _package_name_from_git_url(git_url: str) -> str:
        """Derive a tracking name from the last path segment of *git_url*.

        Drops a ``#fragment``, trailing slashes, an ``@ref`` suffix on the
        last segment and a trailing ``.git``, e.g.
        ``https://github.com/org/repo.git@v1`` -> ``repo``.
        """
        url = git_url.split("#", 1)[0].rstrip("/")
        last = url.rsplit("/", 1)[-1].split("@", 1)[0]
        # Only strip a *trailing* ".git"; str.replace would also mangle names
        # that merely contain ".git" (e.g. "my.github-tool").
        if last.endswith(".git"):
            last = last[: -len(".git")]
        return last

    def install_from_git(self, git_url: str, package_name: Optional[str] = None) -> bool:
        """
        Install a package from a git repository.

        Args:
            git_url: Git repository URL (with or without a leading ``git+``)
            package_name: Optional package name for tracking; derived from
                the URL's last path segment when omitted

        Returns:
            bool: True if installation succeeded (or the package was already
            installed by this instance), False otherwise
        """
        package_name = package_name or self._package_name_from_git_url(git_url)

        if package_name in self.installed_packages:
            logger.info(f"Package {package_name} already installed in this session")
            return True

        # pip's VCS syntax needs exactly one "git+" prefix.
        requirement = git_url if git_url.startswith("git+") else f"git+{git_url}"
        cmd = [sys.executable, "-m", "pip", "install", requirement]

        logger.info(f"Installing package from git: {git_url}")
        # 10 minute timeout: git installs clone the repository and build it.
        return self._run_pip(cmd, package_name, timeout=600, from_git=True)

    def install_dia_dependencies(self) -> Tuple[bool, List[str]]:
        """
        Install all dependencies required for Dia TTS.

        Strategy:
          1. Do nothing if ``dia`` already imports.
          2. ``pip install`` Dia from its git repository.
          3. If that fails, install each known dependency individually, then
             retry the git install.

        Returns:
            Tuple[bool, List[str]]: (success, list of error messages). The
            error list is empty whenever success is True.
        """
        if self.check_module_available("dia"):
            logger.info("Dia TTS is already available")
            return True, []

        # Installing Dia from git is expected to pull in its dependencies via
        # its own packaging metadata (the original note says this includes
        # descript-audio-codec) -- not verifiable from this file.
        logger.info("Installing Dia TTS and all dependencies from GitHub")
        if self.install_from_git(self._DIA_GIT_URL, "dia"):
            logger.info("Successfully installed Dia TTS and dependencies")
            return True, []

        errors = ["Failed to install Dia TTS from git"]

        # Fallback: install the dependencies one by one, then retry Dia itself.
        logger.info("Git install failed, trying individual dependencies...")
        dependencies_ok = True
        for module_name, package_name in self._DIA_DEPENDENCIES:
            if self.check_module_available(module_name):
                continue
            logger.info(f"Installing missing dependency: {package_name}")
            if not self.install_package(package_name):
                errors.append(f"Failed to install {package_name}")
                dependencies_ok = False

        if not dependencies_ok:
            return False, errors

        # dia may have become importable as a side effect of the installs above.
        if self.check_module_available("dia"):
            logger.info("Dia TTS became available after installing dependencies")
            return True, []

        if self.install_from_git(self._DIA_GIT_URL, "dia"):
            return True, []

        errors.append("Failed to install Dia TTS after installing dependencies")
        return False, errors

    def install_dependencies_for_provider(self, provider_name: str) -> Tuple[bool, List[str]]:
        """
        Install dependencies for a specific TTS provider.

        Args:
            provider_name: Name of the TTS provider (case-insensitive)

        Returns:
            Tuple[bool, List[str]]: (success, list of error messages)
        """
        if provider_name.lower() == "dia":
            return self.install_dia_dependencies()
        return False, [f"Unknown provider: {provider_name}"]

    def verify_installation(self, module_name: str) -> bool:
        """
        Verify that a module was installed correctly.

        Args:
            module_name: Name of the module to verify

        Returns:
            bool: True if module can be imported, False otherwise
        """
        # Reset the path finders' directory caches so a package installed after
        # interpreter start-up is discoverable. Entries already in sys.modules
        # are deliberately left alone: removing only a top-level entry while
        # its submodules (and any loaded extension modules) remain can leave
        # the package half re-initialised.
        importlib.invalidate_caches()
        try:
            importlib.import_module(module_name)
        except ImportError as e:
            logger.error(f"Failed to verify installation of {module_name}: {e}")
            return False
        logger.info(f"Successfully verified installation of {module_name}")
        return True

    def get_installation_status(self) -> Dict[str, bool]:
        """
        Get the installation status of key dependencies.

        Returns:
            Dict[str, bool]: Dictionary mapping module names to availability status
        """
        return {module: self.check_module_available(module) for module in self._STATUS_MODULES}

    def install_with_retry(self, package_name: str, max_retries: int = 3) -> bool:
        """
        Install a package with retry logic.

        Args:
            package_name: Name of the package to install
            max_retries: Maximum number of attempts (no attempt is made if < 1)

        Returns:
            bool: True if installation succeeded, False otherwise
        """
        for attempt in range(max_retries):
            if self.install_package(package_name):
                return True
            if attempt < max_retries - 1:
                logger.warning(
                    f"Installation attempt {attempt + 1} failed for {package_name}, retrying..."
                )
            else:
                logger.error(f"All {max_retries} installation attempts failed for {package_name}")
        return False


# Global instance for reuse (created lazily by get_dependency_installer()).
_dependency_installer: Optional[DependencyInstaller] = None


def get_dependency_installer() -> DependencyInstaller:
    """
    Get a global dependency installer instance.

    Returns:
        DependencyInstaller: Global dependency installer instance
    """
    global _dependency_installer
    if _dependency_installer is None:
        _dependency_installer = DependencyInstaller()
    return _dependency_installer


def install_dia_dependencies() -> Tuple[bool, List[str]]:
    """
    Convenience function to install Dia TTS dependencies.

    Returns:
        Tuple[bool, List[str]]: (success, list of error messages)
    """
    installer = get_dependency_installer()
    return installer.install_dia_dependencies()


def check_and_install_module(module_name: str, package_name: Optional[str] = None) -> bool:
    """
    Check if a module is available and install it if not.

    Args:
        module_name: Name of the module to check
        package_name: Name of the package to install (defaults to module_name)

    Returns:
        bool: True if module is available after check/install, False otherwise
    """
    installer = get_dependency_installer()
    if installer.check_module_available(module_name):
        return True

    package_name = package_name or module_name
    if installer.install_package(package_name):
        return installer.verify_installation(module_name)
    return False