Spaces:
Build error
Build error
"""Automatic dependency installer for TTS providers.""" | |
import logging | |
import subprocess | |
import sys | |
import importlib | |
from typing import List, Dict, Optional, Tuple | |
import os | |
logger = logging.getLogger(__name__) | |
class DependencyInstaller:
    """Utility class for automatically installing missing dependencies.

    Installation is delegated to ``pip`` running in a subprocess of the
    current interpreter. Names of packages installed successfully are kept in
    ``installed_packages`` so repeated requests on the same instance do not
    run pip again.
    """

    # Repository Dia TTS is installed from.
    _DIA_GIT_URL = "https://github.com/nari-labs/dia.git"

    # (import name, pip requirement) pairs installed one by one when the git
    # install of Dia fails.
    _DIA_FALLBACK_DEPENDENCIES = (
        ("torch", "torch"),
        ("transformers", "transformers"),
        ("accelerate", "accelerate"),
        ("soundfile", "soundfile"),
        ("dac", "descript-audio-codec"),
    )

    # Import names reported by get_installation_status().
    _STATUS_MODULES = (
        "torch",
        "transformers",
        "accelerate",
        "soundfile",
        "numpy",
        "dac",
        "dia",
    )

    def __init__(self):
        """Initialize the dependency installer."""
        # Tracking names of everything this instance installed successfully.
        self.installed_packages = set()

    def check_module_available(self, module_name: str) -> bool:
        """
        Check if a module is available for import.

        Args:
            module_name: Name of the module to check

        Returns:
            bool: True if the module imports cleanly, False otherwise
        """
        # pip runs in a subprocess, so files it created after interpreter
        # start-up may be hidden by stale finder caches until invalidated.
        importlib.invalidate_caches()
        try:
            importlib.import_module(module_name)
        except ImportError:
            return False
        except Exception as exc:
            # Covers an invalid name (e.g. empty string) and packages that
            # are present but fail while importing; both are unusable.
            logger.warning("Could not import module %r: %s", module_name, exc)
            return False
        return True

    def _pip_install(
        self,
        pip_args: List[str],
        package_name: str,
        timeout: int,
        from_git: bool = False,
    ) -> bool:
        """Run ``pip install`` with *pip_args*; record *package_name* on success."""
        suffix = " from git" if from_git else ""
        cmd = [sys.executable, "-m", "pip", "install", *pip_args]
        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=timeout,
            )
        except subprocess.TimeoutExpired:
            prefix = "Git installation" if from_git else "Installation"
            logger.error("%s of %s timed out", prefix, package_name)
            return False
        except Exception as exc:
            # Best effort: any failure to launch pip is reported as False.
            logger.error("Error installing %s%s: %s", package_name, suffix, exc)
            return False

        if result.returncode != 0:
            logger.error(
                "Failed to install %s%s: %s", package_name, suffix, result.stderr
            )
            return False

        logger.info("Successfully installed %s%s", package_name, suffix)
        self.installed_packages.add(package_name)
        return True

    @staticmethod
    def _package_name_from_git_url(git_url: str) -> str:
        """Derive a tracking name from the last path segment of *git_url*."""
        # Trailing slashes would otherwise leave an empty last segment.
        tail = git_url.strip().rstrip("/").rsplit("/", 1)[-1]
        # Drop a "#egg=..." fragment and an "@ref" pin, if present.
        tail = tail.split("#", 1)[0].split("@", 1)[0]
        # Remove ".git" only as a suffix, never from the middle of the name.
        if tail.endswith(".git"):
            tail = tail[: -len(".git")]
        return tail or git_url

    def install_package(self, package_name: str, upgrade: bool = False) -> bool:
        """
        Install a package using pip.

        Args:
            package_name: Name of the package to install
            upgrade: Whether to pass ``--upgrade`` to pip

        Returns:
            bool: True if installation succeeded (or the package was already
            installed by this instance), False otherwise
        """
        if not package_name or not package_name.strip():
            logger.error("Cannot install a package with an empty name")
            return False

        if package_name in self.installed_packages:
            logger.info("Package %s already installed in this session", package_name)
            return True

        pip_args = ["--upgrade"] if upgrade else []
        pip_args.append(package_name)

        logger.info("Installing package: %s", package_name)
        # 5 minute timeout
        return self._pip_install(pip_args, package_name, timeout=300)

    def install_from_git(self, git_url: str, package_name: Optional[str] = None) -> bool:
        """
        Install a package from a git repository.

        Args:
            git_url: Git repository URL
            package_name: Optional package name for tracking; derived from
                the URL's last path segment when omitted

        Returns:
            bool: True if installation succeeded (or the package was already
            installed by this instance), False otherwise
        """
        if not git_url or not git_url.strip():
            logger.error("Cannot install from an empty git URL")
            return False

        package_name = package_name or self._package_name_from_git_url(git_url)

        if package_name in self.installed_packages:
            logger.info("Package %s already installed in this session", package_name)
            return True

        logger.info("Installing package from git: %s", git_url)
        # 10 minute timeout for git installs
        return self._pip_install(
            [f"git+{git_url}"], package_name, timeout=600, from_git=True
        )

    def install_dia_dependencies(self) -> Tuple[bool, List[str]]:
        """
        Install all dependencies required for Dia TTS.

        Tries a git install of Dia first. If that fails, installs the
        fallback dependencies individually and then retries the git install.

        Returns:
            Tuple[bool, List[str]]: (success, list of error messages)
        """
        if self.check_module_available("dia"):
            logger.info("Dia TTS is already available")
            return True, []

        # Installing from git is expected to pull in Dia's own dependencies.
        logger.info("Installing Dia TTS and all dependencies from GitHub")
        if self.install_from_git(self._DIA_GIT_URL, "dia"):
            logger.info("Successfully installed Dia TTS and dependencies")
            return True, []

        errors = ["Failed to install Dia TTS from git"]

        # Fallback: try installing individual dependencies if git install fails
        logger.info("Git install failed, trying individual dependencies...")
        dependencies_ok = True
        for module_name, package_name in self._DIA_FALLBACK_DEPENDENCIES:
            if self.check_module_available(module_name):
                continue
            logger.info("Installing missing dependency: %s", package_name)
            if not self.install_package(package_name):
                errors.append(f"Failed to install {package_name}")
                dependencies_ok = False

        if not dependencies_ok:
            return False, errors

        if self.check_module_available("dia"):
            # Dia became importable once its dependencies were present. The
            # initial git failure is still reported alongside the success.
            return True, errors

        # Try installing Dia again after dependencies
        if self.install_from_git(self._DIA_GIT_URL, "dia"):
            return True, []

        errors.append("Failed to install Dia TTS after installing dependencies")
        return False, errors

    def install_dependencies_for_provider(self, provider_name: str) -> Tuple[bool, List[str]]:
        """
        Install dependencies for a specific TTS provider.

        Args:
            provider_name: Name of the TTS provider (case-insensitive)

        Returns:
            Tuple[bool, List[str]]: (success, list of error messages)
        """
        # Tolerate None instead of raising AttributeError on .lower().
        if (provider_name or "").lower() == "dia":
            return self.install_dia_dependencies()
        return False, [f"Unknown provider: {provider_name}"]

    def verify_installation(self, module_name: str) -> bool:
        """
        Verify that a module was installed correctly.

        Args:
            module_name: Name of the module to verify

        Returns:
            bool: True if module can be imported, False otherwise
        """
        # Make freshly installed files visible to the import system.
        importlib.invalidate_caches()
        # Clear import cache to ensure fresh import
        sys.modules.pop(module_name, None)
        try:
            importlib.import_module(module_name)
        except Exception as exc:
            # Any import-time failure means the installation is unusable.
            logger.error("Failed to verify installation of %s: %s", module_name, exc)
            return False

        logger.info("Successfully verified installation of %s", module_name)
        return True

    def get_installation_status(self) -> Dict[str, bool]:
        """
        Get the installation status of key dependencies.

        Returns:
            Dict[str, bool]: Dictionary mapping module names to availability status
        """
        return {
            module: self.check_module_available(module)
            for module in self._STATUS_MODULES
        }

    def install_with_retry(self, package_name: str, max_retries: int = 3) -> bool:
        """
        Install a package with retry logic.

        Args:
            package_name: Name of the package to install
            max_retries: Maximum number of installation attempts; values
                below 1 result in no attempt being made

        Returns:
            bool: True if installation succeeded, False otherwise
        """
        for attempt in range(1, max_retries + 1):
            if self.install_package(package_name):
                return True
            if attempt < max_retries:
                logger.warning(
                    "Installation attempt %s failed for %s, retrying...",
                    attempt,
                    package_name,
                )

        logger.error(
            "All %s installation attempts failed for %s", max_retries, package_name
        )
        return False
# Module-level installer shared by the convenience functions; built lazily.
_dependency_installer = None


def get_dependency_installer() -> DependencyInstaller:
    """
    Return the shared dependency installer, creating it on first use.

    Returns:
        DependencyInstaller: The module-level installer instance
    """
    global _dependency_installer
    shared = _dependency_installer
    if shared is None:
        shared = _dependency_installer = DependencyInstaller()
    return shared
def install_dia_dependencies() -> Tuple[bool, List[str]]:
    """
    Install the Dia TTS dependencies through the shared installer.

    Returns:
        Tuple[bool, List[str]]: (success, list of error messages)
    """
    return get_dependency_installer().install_dia_dependencies()
def check_and_install_module(module_name: str, package_name: Optional[str] = None) -> bool:
    """
    Ensure a module is importable, installing its package when it is not.

    Args:
        module_name: Import name of the module to check
        package_name: pip package to install when the module is missing
            (falls back to module_name)

    Returns:
        bool: True if the module is importable after the check/install,
        False otherwise
    """
    shared = get_dependency_installer()

    # Nothing to do when the module already imports.
    if shared.check_module_available(module_name):
        return True

    requirement = package_name if package_name else module_name
    if not shared.install_package(requirement):
        return False

    # pip succeeded; confirm the module can actually be imported now.
    return shared.verify_installation(module_name)