File size: 10,255 Bytes
fdc056d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
"""Automatic dependency installer for TTS providers."""

import logging
import subprocess
import sys
import importlib
from typing import List, Dict, Optional, Tuple
import os

# Module-level logger; every install attempt, success, and failure in this
# module is reported through it rather than raised to the caller.
logger = logging.getLogger(__name__)


class DependencyInstaller:
    """Install missing Python packages at runtime by shelling out to ``pip``.

    Every pip invocation uses ``sys.executable -m pip`` so packages land in the
    environment of the running interpreter. Successful installs are remembered
    in ``installed_packages`` so repeated requests within the same session do
    not re-run pip. Failures are logged and reported via return values; the
    install methods do not raise.
    """

    # Git repository the Dia TTS package is installed from.
    _DIA_GIT_URL = "https://github.com/nari-labs/dia.git"

    # (import name, pip distribution name) pairs installed one by one when the
    # git install of Dia fails.
    _DIA_FALLBACK_DEPENDENCIES = (
        ("torch", "torch"),
        ("transformers", "transformers"),
        ("accelerate", "accelerate"),
        ("soundfile", "soundfile"),
        ("dac", "descript-audio-codec"),
    )

    # Import names reported by get_installation_status().
    _STATUS_MODULES = (
        "torch",
        "transformers",
        "accelerate",
        "soundfile",
        "numpy",
        "dac",
        "dia",
    )

    def __init__(self):
        """Initialize the installer with an empty per-session install record."""
        # Names that pip installed successfully through this instance.
        self.installed_packages = set()

    def check_module_available(self, module_name: str) -> bool:
        """
        Check if a module is available for import.

        This performs a real import (running the module's top-level code),
        not just a lookup.

        Args:
            module_name: Name of the module to check

        Returns:
            bool: True if the import succeeds, False if it raises ImportError
        """
        try:
            importlib.import_module(module_name)
        except ImportError:
            return False
        return True

    def _record_install(self, package_name: str) -> None:
        """Remember a successful install and refresh the import system's caches."""
        self.installed_packages.add(package_name)
        # pip has just written new files onto sys.path. Import finders may
        # still hold directory listings from before the install, so invalidate
        # them or the new module might not be found in this process.
        importlib.invalidate_caches()

    @staticmethod
    def _package_name_from_git_url(git_url: str) -> str:
        """Derive a tracking name from the last path segment of *git_url*.

        Trailing slashes and a trailing ``.git`` suffix are removed; a ``.git``
        occurring elsewhere in the segment is preserved.
        """
        name = git_url.rstrip("/").rsplit("/", 1)[-1]
        if name.endswith(".git"):
            name = name[: -len(".git")]
        return name

    def install_package(self, package_name: str, upgrade: bool = False) -> bool:
        """
        Install a package using pip.

        Args:
            package_name: pip requirement to install (name or specifier)
            upgrade: Pass ``--upgrade`` to pip. An upgrade request always runs
                pip, even if the package was already installed this session.

        Returns:
            bool: True if installation succeeded (or already happened this
            session and no upgrade was requested), False otherwise
        """
        if package_name in self.installed_packages and not upgrade:
            logger.info(f"Package {package_name} already installed in this session")
            return True

        cmd = [sys.executable, "-m", "pip", "install"]
        if upgrade:
            cmd.append("--upgrade")
        cmd.append(package_name)

        logger.info(f"Installing package: {package_name}")
        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=300,  # 5 minute timeout
            )
        except subprocess.TimeoutExpired:
            logger.error(f"Installation of {package_name} timed out")
            return False
        except Exception as e:
            # Best-effort installer: report failures such as being unable to
            # launch pip instead of propagating them.
            logger.error(f"Error installing {package_name}: {e}")
            return False

        if result.returncode != 0:
            logger.error(f"Failed to install {package_name}: {result.stderr}")
            return False

        logger.info(f"Successfully installed {package_name}")
        self._record_install(package_name)
        return True

    def install_from_git(self, git_url: str, package_name: Optional[str] = None) -> bool:
        """
        Install a package from a git repository.

        Args:
            git_url: Git repository URL (passed to pip as ``git+<url>``)
            package_name: Optional package name for tracking; defaults to the
                last path segment of the URL without a trailing ``.git``

        Returns:
            bool: True if installation succeeded (or already happened this
            session), False otherwise
        """
        package_name = package_name or self._package_name_from_git_url(git_url)

        if package_name in self.installed_packages:
            logger.info(f"Package {package_name} already installed in this session")
            return True

        cmd = [sys.executable, "-m", "pip", "install", f"git+{git_url}"]

        logger.info(f"Installing package from git: {git_url}")
        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=600,  # 10 minute timeout for git installs
            )
        except subprocess.TimeoutExpired:
            logger.error(f"Git installation of {package_name} timed out")
            return False
        except Exception as e:
            # Best-effort installer: report instead of propagating.
            logger.error(f"Error installing {package_name} from git: {e}")
            return False

        if result.returncode != 0:
            logger.error(f"Failed to install {package_name} from git: {result.stderr}")
            return False

        logger.info(f"Successfully installed {package_name} from git")
        self._record_install(package_name)
        return True

    def install_dia_dependencies(self) -> Tuple[bool, List[str]]:
        """
        Install all dependencies required for Dia TTS.

        If ``dia`` already imports, nothing is done. Otherwise Dia is installed
        from git. If that fails, the known dependencies are installed one by
        one and the git install is retried.

        Returns:
            Tuple[bool, List[str]]: (success, list of error messages). If the
            individually installed dependencies alone make ``dia`` importable,
            success is True and the list still carries the initial git failure
            for diagnostics.
        """
        if self.check_module_available("dia"):
            logger.info("Dia TTS is already available")
            return True, []

        # Install Dia TTS from git - this will automatically install all dependencies
        # including descript-audio-codec as specified in pyproject.toml
        logger.info("Installing Dia TTS and all dependencies from GitHub")
        if self.install_from_git(self._DIA_GIT_URL, "dia"):
            logger.info("Successfully installed Dia TTS and dependencies")
            return True, []

        errors = ["Failed to install Dia TTS from git"]

        # Fallback: try installing individual dependencies if git install fails
        logger.info("Git install failed, trying individual dependencies...")
        deps_ok = True
        for module_name, package_name in self._DIA_FALLBACK_DEPENDENCIES:
            if self.check_module_available(module_name):
                continue
            logger.info(f"Installing missing dependency: {package_name}")
            if not self.install_package(package_name):
                errors.append(f"Failed to install {package_name}")
                deps_ok = False

        if not deps_ok:
            return False, errors

        if self.check_module_available("dia"):
            return True, errors

        # Try installing Dia again now that its dependencies are in place
        if self.install_from_git(self._DIA_GIT_URL, "dia"):
            return True, []

        errors.append("Failed to install Dia TTS after installing dependencies")
        return False, errors

    def install_dependencies_for_provider(self, provider_name: str) -> Tuple[bool, List[str]]:
        """
        Install dependencies for a specific TTS provider.

        Args:
            provider_name: Name of the TTS provider (matched case-insensitively)

        Returns:
            Tuple[bool, List[str]]: (success, list of error messages)
        """
        if provider_name.lower() == "dia":
            return self.install_dia_dependencies()
        return False, [f"Unknown provider: {provider_name}"]

    def verify_installation(self, module_name: str) -> bool:
        """
        Verify that a module was installed correctly.

        Args:
            module_name: Name of the module to verify

        Returns:
            bool: True if module can be imported, False otherwise
        """
        # Drop any cached module object and refresh the finders' directory
        # caches so a package installed after start-up is actually seen.
        sys.modules.pop(module_name, None)
        importlib.invalidate_caches()
        try:
            importlib.import_module(module_name)
        except ImportError as e:
            logger.error(f"Failed to verify installation of {module_name}: {e}")
            return False
        logger.info(f"Successfully verified installation of {module_name}")
        return True

    def get_installation_status(self) -> Dict[str, bool]:
        """
        Get the installation status of key dependencies.

        Returns:
            Dict[str, bool]: Dictionary mapping module names to availability status
        """
        return {name: self.check_module_available(name) for name in self._STATUS_MODULES}

    def install_with_retry(self, package_name: str, max_retries: int = 3) -> bool:
        """
        Install a package with retry logic.

        Args:
            package_name: Name of the package to install
            max_retries: Maximum number of attempts; values below 1 make no
                attempt and return False

        Returns:
            bool: True if installation succeeded, False otherwise
        """
        for attempt in range(max_retries):
            if self.install_package(package_name):
                return True

            if attempt < max_retries - 1:
                logger.warning(f"Installation attempt {attempt + 1} failed for {package_name}, retrying...")
            else:
                logger.error(f"All {max_retries} installation attempts failed for {package_name}")

        return False


# Global instance for reuse; stays None until get_dependency_installer()
# creates it on first call.
_dependency_installer: Optional[DependencyInstaller] = None


def get_dependency_installer() -> DependencyInstaller:
    """
    Return the shared DependencyInstaller, creating it on first use.

    The same instance is handed out on every later call, so its record of
    packages installed this session is shared by all callers.

    Returns:
        DependencyInstaller: Global dependency installer instance
    """
    global _dependency_installer
    if _dependency_installer is not None:
        return _dependency_installer
    _dependency_installer = DependencyInstaller()
    return _dependency_installer


def install_dia_dependencies() -> Tuple[bool, List[str]]:
    """
    Install Dia TTS dependencies using the shared installer instance.

    Thin wrapper around ``DependencyInstaller.install_dia_dependencies`` on the
    object returned by ``get_dependency_installer()``.

    Returns:
        Tuple[bool, List[str]]: (success, list of error messages)
    """
    return get_dependency_installer().install_dia_dependencies()


def check_and_install_module(module_name: str, package_name: Optional[str] = None) -> bool:
    """
    Ensure *module_name* is importable, pip-installing it if necessary.

    Args:
        module_name: Import name to check
        package_name: pip package that provides the module; falls back to
            ``module_name`` when omitted or empty

    Returns:
        bool: True if the module is importable already, or after a successful
        install followed by a successful verification import; False otherwise
    """
    installer = get_dependency_installer()

    if installer.check_module_available(module_name):
        return True

    # Only verify when pip reported success; otherwise short-circuit to False.
    target = package_name or module_name
    return installer.install_package(target) and installer.verify_installation(module_name)