openfree committed on
Commit
6e12ecc
·
verified ·
1 Parent(s): a159e51

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -60
app.py CHANGED
@@ -11,6 +11,8 @@ import wave
11
  import base64
12
  import numpy as np
13
  import soundfile as sf
 
 
14
  from dataclasses import dataclass
15
  from typing import List, Tuple, Dict, Optional
16
  from pathlib import Path
@@ -34,7 +36,7 @@ from transformers import (
34
 
35
  # Spark TTS imports
36
  try:
37
- from transformers import AutoModel
38
  SPARK_AVAILABLE = True
39
  except:
40
  SPARK_AVAILABLE = False
@@ -65,8 +67,7 @@ class UnifiedAudioConverter:
65
  self.local_model = None
66
  self.tokenizer = None
67
  self.melo_models = None
68
- self.spark_model = None
69
- self.spark_tokenizer = None
70
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
71
 
72
  def initialize_api_mode(self, api_key: str):
@@ -90,19 +91,31 @@ class UnifiedAudioConverter:
90
  )
91
 
92
  def initialize_spark_tts(self):
93
- """Initialize Spark TTS model"""
94
- if SPARK_AVAILABLE and self.spark_model is None:
 
 
 
 
 
 
 
95
  try:
96
- self.spark_model = AutoModel.from_pretrained(
97
- "SparkAudio/Spark-TTS-0.5B",
98
- trust_remote_code=True,
99
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
100
- ).to(self.device)
101
- # Spark TTS๋Š” ๋ณ„๋„์˜ ํ† ํฌ๋‚˜์ด์ €๊ฐ€ ํ•„์š”ํ•˜์ง€ ์•Š์„ ์ˆ˜ ์žˆ์Œ
102
- print("Spark TTS model loaded successfully")
103
  except Exception as e:
104
- print(f"Failed to load Spark TTS: {e}")
105
-
 
 
 
 
 
 
106
  def initialize_melo_tts(self):
107
  """Initialize MeloTTS models"""
108
  if MELO_AVAILABLE and self.melo_models is None:
@@ -274,54 +287,79 @@ class UnifiedAudioConverter:
274
  return tmp_path
275
 
276
  def text_to_speech_spark(self, conversation_json: Dict, progress=None) -> Tuple[str, str]:
277
- """Convert text to speech using Spark TTS"""
278
- if not SPARK_AVAILABLE or self.spark_model is None:
279
  raise RuntimeError("Spark TTS not available")
280
 
281
  try:
282
- combined_audio = []
283
- sample_rate = 22050 # Default sample rate for Spark TTS
284
 
285
- # Two different speaker configurations for variety
286
- speakers = ["female_1", "male_1"] # Adjust based on actual Spark TTS speaker options
 
 
 
287
 
288
  for i, turn in enumerate(conversation_json["conversation"]):
289
  text = turn["text"]
290
  if not text.strip():
291
  continue
292
-
293
- # Generate audio with Spark TTS
294
- # Note: The exact API might differ, adjust based on actual Spark TTS documentation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295
  try:
296
- with torch.no_grad():
297
- audio_output = self.spark_model.synthesize(
298
- text=text,
299
- voice=speakers[i % 2] if len(speakers) > 1 else speakers[0],
300
- speed=1.0,
301
- # Add other parameters as needed
302
- )
303
-
304
- # Convert to numpy array if needed
305
- if isinstance(audio_output, torch.Tensor):
306
- audio_output = audio_output.cpu().numpy()
307
-
308
- combined_audio.append(audio_output)
309
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
310
  except Exception as e:
311
- print(f"Error generating audio for turn {i}: {e}")
312
- # Generate silence as fallback
313
- silence = np.zeros(int(sample_rate * 0.5)) # 0.5 second silence
314
- combined_audio.append(silence)
315
-
316
- # Combine all audio segments
317
- if combined_audio:
318
- final_audio = np.concatenate(combined_audio)
319
-
320
- # Save to file
321
- output_path = "spark_podcast_output.wav"
322
- sf.write(output_path, final_audio, sample_rate)
323
  else:
324
- raise RuntimeError("No audio generated")
325
 
326
  # Generate conversation text
327
  conversation_text = "\n".join(
@@ -329,7 +367,7 @@ class UnifiedAudioConverter:
329
  for i, turn in enumerate(conversation_json["conversation"])
330
  )
331
 
332
- return output_path, conversation_text
333
 
334
  except Exception as e:
335
  raise RuntimeError(f"Failed to convert text to speech with Spark TTS: {e}")
@@ -386,15 +424,18 @@ class UnifiedAudioConverter:
386
  try:
387
  audio_segments = []
388
  for filename in filenames:
389
- audio_segment = AudioSegment.from_file(filename)
390
- audio_segments.append(audio_segment)
 
391
 
392
- combined = sum(audio_segments)
393
- combined.export(output_file, format="wav")
 
394
 
395
  # Clean up temporary files
396
  for filename in filenames:
397
- os.remove(filename)
 
398
 
399
  except Exception as e:
400
  raise RuntimeError(f"Failed to combine audio files: {e}")
@@ -462,7 +503,7 @@ async def regenerate_audio(conversation_text: str, tts_engine: str = "Edge-TTS")
462
  )
463
  elif tts_engine == "Spark-TTS":
464
  if not SPARK_AVAILABLE:
465
- return "Spark TTS not available. Please install required dependencies.", None
466
  converter.initialize_spark_tts()
467
  output_file, _ = converter.text_to_speech_spark(conversation_json)
468
  else: # MeloTTS
@@ -519,8 +560,8 @@ with gr.Blocks(theme='soft', title="URL to Podcast Converter") as demo:
519
 
520
  gr.Markdown("""
521
  **Recommended:**
522
- - 🌟 **Edge-TTS**: Best quality, cloud-based
523
- - 🤖 **Spark-TTS**: Local AI model, good quality
524
 
525
  **Additional Option:**
526
  - ⚡ **MeloTTS**: Local processing, GPU recommended
@@ -558,14 +599,24 @@ with gr.Blocks(theme='soft', title="URL to Podcast Converter") as demo:
558
  visible=True
559
  )
560
 
561
- # TTS 엔진별 설명 추가
562
  with gr.Row():
563
  gr.Markdown("""
564
  ### TTS Engine Details:
565
 
566
  - **Edge-TTS**: Microsoft's cloud TTS service with high-quality natural voices. Requires internet connection.
567
- - **Spark-TTS**: SparkAudio's local AI model (0.5B parameters). Runs on your device, good for privacy.
 
 
 
568
  - **MeloTTS**: Local TTS with multiple voice options. GPU recommended for better performance.
 
 
 
 
 
 
 
569
  """)
570
 
571
  gr.Examples(
 
11
  import base64
12
  import numpy as np
13
  import soundfile as sf
14
+ import subprocess
15
+ import shutil
16
  from dataclasses import dataclass
17
  from typing import List, Tuple, Dict, Optional
18
  from pathlib import Path
 
36
 
37
  # Spark TTS imports
38
  try:
39
+ from huggingface_hub import snapshot_download
40
  SPARK_AVAILABLE = True
41
  except:
42
  SPARK_AVAILABLE = False
 
67
  self.local_model = None
68
  self.tokenizer = None
69
  self.melo_models = None
70
+ self.spark_model_dir = None
 
71
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
72
 
73
  def initialize_api_mode(self, api_key: str):
 
91
  )
92
 
93
  def initialize_spark_tts(self):
94
+ """Initialize Spark TTS model by downloading if needed"""
95
+ if not SPARK_AVAILABLE:
96
+ raise RuntimeError("Spark TTS dependencies not available")
97
+
98
+ model_dir = "pretrained_models/Spark-TTS-0.5B"
99
+
100
+ # Check if model exists, if not download it
101
+ if not os.path.exists(model_dir):
102
+ print("Downloading Spark-TTS model...")
103
  try:
104
+ os.makedirs("pretrained_models", exist_ok=True)
105
+ snapshot_download(
106
+ "SparkAudio/Spark-TTS-0.5B",
107
+ local_dir=model_dir
108
+ )
109
+ print("Spark-TTS model downloaded successfully")
 
110
  except Exception as e:
111
+ raise RuntimeError(f"Failed to download Spark-TTS model: {e}")
112
+
113
+ self.spark_model_dir = model_dir
114
+
115
+ # Check if we have the CLI inference script
116
+ if not os.path.exists("cli/inference.py"):
117
+ print("Warning: Spark-TTS CLI not found. Please clone the Spark-TTS repository.")
118
+
119
  def initialize_melo_tts(self):
120
  """Initialize MeloTTS models"""
121
  if MELO_AVAILABLE and self.melo_models is None:
 
287
  return tmp_path
288
 
289
  def text_to_speech_spark(self, conversation_json: Dict, progress=None) -> Tuple[str, str]:
290
+ """Convert text to speech using Spark TTS CLI"""
291
+ if not SPARK_AVAILABLE or not self.spark_model_dir:
292
  raise RuntimeError("Spark TTS not available")
293
 
294
  try:
295
+ output_dir = self._create_output_directory()
296
+ audio_files = []
297
 
298
+ # Create different voice characteristics for different speakers
299
+ voice_configs = [
300
+ {"prompt_text": "Hello, welcome to our podcast. I'm your host today.", "gender": "female"},
301
+ {"prompt_text": "Thank you for having me. I'm excited to be here.", "gender": "male"}
302
+ ]
303
 
304
  for i, turn in enumerate(conversation_json["conversation"]):
305
  text = turn["text"]
306
  if not text.strip():
307
  continue
308
+
309
+ # Use different voice config for each speaker
310
+ voice_config = voice_configs[i % len(voice_configs)]
311
+
312
+ output_file = os.path.join(output_dir, f"spark_output_{i}.wav")
313
+
314
+ # Run Spark TTS CLI inference
315
+ cmd = [
316
+ "python", "-m", "cli.inference",
317
+ "--text", text,
318
+ "--device", "0" if torch.cuda.is_available() else "cpu",
319
+ "--save_dir", output_dir,
320
+ "--model_dir", self.spark_model_dir,
321
+ "--prompt_text", voice_config["prompt_text"],
322
+ "--output_name", f"spark_output_{i}.wav"
323
+ ]
324
+
325
  try:
326
+ # Run the command
327
+ result = subprocess.run(
328
+ cmd,
329
+ capture_output=True,
330
+ text=True,
331
+ timeout=60,
332
+ cwd="." # Make sure we're in the right directory
333
+ )
 
 
 
 
 
334
 
335
+ if result.returncode == 0:
336
+ audio_files.append(output_file)
337
+ else:
338
+ print(f"Spark TTS error for turn {i}: {result.stderr}")
339
+ # Create a short silence as fallback
340
+ silence = np.zeros(int(22050 * 1.0)) # 1 second of silence
341
+ sf.write(output_file, silence, 22050)
342
+ audio_files.append(output_file)
343
+
344
+ except subprocess.TimeoutExpired:
345
+ print(f"Spark TTS timeout for turn {i}")
346
+ # Create silence as fallback
347
+ silence = np.zeros(int(22050 * 1.0))
348
+ sf.write(output_file, silence, 22050)
349
+ audio_files.append(output_file)
350
  except Exception as e:
351
+ print(f"Error running Spark TTS for turn {i}: {e}")
352
+ # Create silence as fallback
353
+ silence = np.zeros(int(22050 * 1.0))
354
+ sf.write(output_file, silence, 22050)
355
+ audio_files.append(output_file)
356
+
357
+ # Combine all audio files
358
+ if audio_files:
359
+ final_output = os.path.join(output_dir, "spark_combined.wav")
360
+ self._combine_audio_files(audio_files, final_output)
 
 
361
  else:
362
+ raise RuntimeError("No audio files generated")
363
 
364
  # Generate conversation text
365
  conversation_text = "\n".join(
 
367
  for i, turn in enumerate(conversation_json["conversation"])
368
  )
369
 
370
+ return final_output, conversation_text
371
 
372
  except Exception as e:
373
  raise RuntimeError(f"Failed to convert text to speech with Spark TTS: {e}")
 
424
  try:
425
  audio_segments = []
426
  for filename in filenames:
427
+ if os.path.exists(filename):
428
+ audio_segment = AudioSegment.from_file(filename)
429
+ audio_segments.append(audio_segment)
430
 
431
+ if audio_segments:
432
+ combined = sum(audio_segments)
433
+ combined.export(output_file, format="wav")
434
 
435
  # Clean up temporary files
436
  for filename in filenames:
437
+ if os.path.exists(filename):
438
+ os.remove(filename)
439
 
440
  except Exception as e:
441
  raise RuntimeError(f"Failed to combine audio files: {e}")
 
503
  )
504
  elif tts_engine == "Spark-TTS":
505
  if not SPARK_AVAILABLE:
506
+ return "Spark TTS not available. Please install required dependencies and clone the Spark-TTS repository.", None
507
  converter.initialize_spark_tts()
508
  output_file, _ = converter.text_to_speech_spark(conversation_json)
509
  else: # MeloTTS
 
560
 
561
  gr.Markdown("""
562
  **Recommended:**
563
+ - 🌟 **Edge-TTS**: Best quality, cloud-based, instant setup
564
+ - 🤖 **Spark-TTS**: Local AI model (0.5B), zero-shot voice cloning
565
 
566
  **Additional Option:**
567
  - ⚡ **MeloTTS**: Local processing, GPU recommended
 
599
  visible=True
600
  )
601
 
602
+ # TTS 엔진별 설명 및 설치 안내 추가
603
  with gr.Row():
604
  gr.Markdown("""
605
  ### TTS Engine Details:
606
 
607
  - **Edge-TTS**: Microsoft's cloud TTS service with high-quality natural voices. Requires internet connection.
608
+ - **Spark-TTS**: SparkAudio's local AI model (0.5B parameters) with zero-shot voice cloning capability.
609
+ - **Setup required**: Clone [Spark-TTS repository](https://github.com/SparkAudio/Spark-TTS) in current directory
610
+ - Features: Bilingual support (Chinese/English), controllable speech generation
611
+ - License: CC BY-NC-SA (Non-commercial use only)
612
  - **MeloTTS**: Local TTS with multiple voice options. GPU recommended for better performance.
613
+
614
+ ### Spark-TTS Setup Instructions:
615
+ ```bash
616
+ git clone https://github.com/SparkAudio/Spark-TTS.git
617
+ cd Spark-TTS
618
+ pip install -r requirements.txt
619
+ ```
620
  """)
621
 
622
  gr.Examples(