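# Spaces: Running on Zero
# NOTE(review): the line above looks like a hosting-page status banner captured
# together with the file (it was repeated twice); it is not program code.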
import spaces # ์ถ๊ฐ | |
import gradio as gr | |
import os | |
import asyncio | |
import torch | |
import io | |
import json | |
import re | |
import httpx | |
import tempfile | |
import wave | |
import base64 | |
import numpy as np | |
import soundfile as sf | |
import subprocess | |
import shutil | |
import requests | |
import logging | |
from datetime import datetime, timedelta | |
from dataclasses import dataclass | |
from typing import List, Tuple, Dict, Optional | |
from pathlib import Path | |
from threading import Thread | |
from dotenv import load_dotenv | |
# PDF processing imports | |
from langchain_community.document_loaders import PyPDFLoader | |
# Edge TTS imports | |
import edge_tts | |
from pydub import AudioSegment | |
# OpenAI imports | |
from openai import OpenAI | |
# Transformers imports (for legacy local mode) | |
from transformers import ( | |
AutoModelForCausalLM, | |
AutoTokenizer, | |
TextIteratorStreamer, | |
BitsAndBytesConfig, | |
) | |
# Llama CPP imports (for new local mode).
# Optional: when any of these packages is missing the module still loads and
# LLAMA_CPP_AVAILABLE tells callers that the llama.cpp backend cannot be used.
try:
    from llama_cpp import Llama
    from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
    from llama_cpp_agent.providers import LlamaCppPythonProvider
    from llama_cpp_agent.chat_history import BasicChatHistory
    from llama_cpp_agent.chat_history.messages import Roles
    from huggingface_hub import hf_hub_download
    LLAMA_CPP_AVAILABLE = True
except ImportError:
    LLAMA_CPP_AVAILABLE = False

# Spark TTS imports.
# `except Exception` (instead of a bare `except`) keeps the best-effort
# behaviour without swallowing KeyboardInterrupt / SystemExit.
try:
    from huggingface_hub import snapshot_download
    SPARK_AVAILABLE = True
except Exception:
    SPARK_AVAILABLE = False

# MeloTTS imports (for local mode)
try:
    # Download the unidic dictionary only when it is not already present.
    # NOTE(review): the path is specific to a python3.10 system install —
    # confirm it matches the deployment image.
    if not os.path.exists("/usr/local/lib/python3.10/site-packages/unidic"):
        try:
            os.system("python -m unidic download")
        except Exception as download_error:
            # Best effort: a failed download must not prevent start-up, but
            # it should leave a trace instead of being silently ignored.
            logging.warning("unidic download failed: %s", download_error)
    from melo.api import TTS as MeloTTS
    MELO_AVAILABLE = True
except Exception:
    MELO_AVAILABLE = False
# Load variables from a .env file (if any) before reading the environment.
load_dotenv()
# Brave Search API settings
# The API token comes from the BSEARCH_API environment variable; it is None
# when the variable is unset, which the search helpers treat as "disabled".
BRAVE_KEY = os.getenv("BSEARCH_API")
BRAVE_ENDPOINT = "https://api.search.brave.com/res/v1/web/search"
@dataclass
class ConversationConfig:
    """Tunable settings shared by the conversation and audio pipeline.

    Declared as a dataclass so individual fields can be overridden per
    instance, e.g. ``ConversationConfig(max_tokens=2000)``; calling it with
    no arguments yields the defaults below.
    """

    max_words: int = 6000  # increased from 4000 to 6000 (1.5x)
    # Prefix prepended to a URL when fetching its text content.
    prefix_url: str = "https://r.jina.ai/"
    # Model used through the remote API client.
    api_model_name: str = "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"
    # Hugging Face model used by the legacy local (transformers) mode.
    legacy_local_model_name: str = "NousResearch/Hermes-2-Pro-Llama-3-8B"
    # New local model settings (GGUF file name and the repo it is fetched from)
    local_model_name: str = "Private-BitSix-Mistral-Small-3.1-24B-Instruct-2503.gguf"
    local_model_repo: str = "ginigen/Private-BitSix-Mistral-Small-3.1-24B-Instruct-2503"
    # Token budgets (increased)
    max_tokens: int = 4500  # increased from 3000 to 4500 (1.5x)
    max_new_tokens: int = 9000  # increased from 6000 to 9000 (1.5x)
    min_conversation_turns: int = 12  # minimum number of conversation turns
    max_conversation_turns: int = 15  # maximum number of conversation turns
def brave_search(query: str, count: int = 8, freshness_days: int | None = None): | |
"""Brave Search API๋ฅผ ์ฌ์ฉํ์ฌ ์ต์ ์ ๋ณด ๊ฒ์""" | |
if not BRAVE_KEY: | |
return [] | |
params = {"q": query, "count": str(count)} | |
if freshness_days: | |
dt_from = (datetime.utcnow() - timedelta(days=freshness_days)).strftime("%Y-%m-%d") | |
params["freshness"] = dt_from | |
try: | |
r = requests.get( | |
BRAVE_ENDPOINT, | |
headers={"Accept": "application/json", "X-Subscription-Token": BRAVE_KEY}, | |
params=params, | |
timeout=15 | |
) | |
raw = r.json().get("web", {}).get("results") or [] | |
return [{ | |
"title": r.get("title", ""), | |
"url": r.get("url", r.get("link", "")), | |
"snippet": r.get("description", r.get("text", "")), | |
"host": re.sub(r"https?://(www\.)?", "", r.get("url", "")).split("/")[0] | |
} for r in raw[:count]] | |
except Exception as e: | |
logging.error(f"Brave search error: {e}") | |
return [] | |
def format_search_results(query: str, for_keyword: bool = False) -> str:
    """Run a search and render the top hits as a text block.

    Keyword mode fetches 5 rows, shows up to 4 with snippets capped at 200
    characters plus a source line, and applies no freshness filter. Normal
    mode fetches 3 rows from the last 7 days and shows up to 2 as bullet
    lines with snippets capped at 100 characters. Returns "" when the search
    yields nothing.
    """
    if for_keyword:
        fetch_count, shown, snippet_cap, recency = 5, 4, 200, None
    else:
        fetch_count, shown, snippet_cap, recency = 3, 2, 100, 7

    rows = brave_search(query, fetch_count, freshness_days=recency)
    if not rows:
        return ""

    rendered = []
    for row in rows[:shown]:
        body = row['snippet']
        # Over-long snippets are cut and marked with an ellipsis.
        if len(body) > snippet_cap:
            body = body[:snippet_cap] + "..."
        if for_keyword:
            rendered.append(f"**{row['title']}**\n{body}\nSource: {row['host']}")
        else:
            rendered.append(f"- {row['title']}: {body}")
    return "\n\n".join(rendered) + "\n"
def extract_keywords_for_search(text: str, language: str = "English") -> List[str]:
    """Pick at most one search keyword from the beginning of ``text``.

    Only the first 500 characters are inspected, to keep processing cheap.

    Args:
        text: Source text.
        language: "Korean" selects the Hangul heuristic; any other value
            uses the capitalised-word heuristic.

    Returns:
        A list holding the single chosen keyword, or an empty list when no
        candidate is found.
    """
    text_sample = text[:500]
    if language == "Korean":
        # Runs of two or more Hangul syllables (U+AC00..U+D7A3). The previous
        # character class was mis-encoded and matched no Korean text at all.
        hangul_runs = re.findall(r'[가-힣]{2,}', text_sample)
        if not hangul_runs:
            return []
        # Longest run wins; on ties the earliest occurrence is kept.
        return [max(hangul_runs, key=len)]

    # English: words longer than 4 characters that start with a capital
    # letter; surrounding punctuation is stripped from the candidates.
    candidates = [
        word.strip('.,!?;:')
        for word in text_sample.split()
        if len(word) > 4 and word[0].isupper()
    ]
    if candidates:
        return [max(candidates, key=len)]  # the single longest word
    return []
def search_and_compile_content(keyword: str, language: str = "English") -> str:
    """Search for a keyword and compile the hits into one text block.

    Three queries (news / general info / trends for the current year) are
    issued; the top two results of each are rendered as a bold title plus
    snippet and joined under a short introduction line.

    Args:
        keyword: Topic to search for.
        language: "Korean" uses the Korean query suffixes and introduction;
            anything else uses the English ones.

    Returns:
        The compiled text, or a short notice string when the search API is
        not configured or nothing was found.
    """
    if not BRAVE_KEY:
        return f"Search API not available. Using keyword: {keyword}"

    # Use the current year instead of a hard-coded one so the trend query
    # does not go stale.
    year = datetime.now().year

    # Adjust the search queries to the language.
    # NOTE(review): the Korean literals below look mis-encoded (mojibake) in
    # this copy of the file; they are kept as-is — confirm against the
    # original UTF-8 source.
    if language == "Korean":
        queries = [
            f"{keyword} ์ต์ ๋ด์ค",
            f"{keyword} ์ ๋ณด",
            f"{keyword} ํธ๋ ๋ {year}"
        ]
    else:
        queries = [
            f"{keyword} latest news",
            f"{keyword} explained",
            f"{keyword} trends {year}"
        ]

    all_content = []
    for query in queries:
        results = brave_search(query, count=3)
        for r in results[:2]:  # top 2 results per query
            all_content.append(f"**{r['title']}**\n{r['snippet']}\n")

    if not all_content:
        return f"No search results found for: {keyword}"

    compiled = "\n\n".join(all_content)

    # Prepend a keyword-based introduction.
    if language == "Korean":
        intro = f"'{keyword}'์ ๋ํ ์ต์ ์ ๋ณด์ ํธ๋ ๋:\n\n"
    else:
        intro = f"Latest information and trends about '{keyword}':\n\n"
    return intro + compiled
class UnifiedAudioConverter: | |
def __init__(self, config: ConversationConfig):
    """Keep the configuration and start with every backend handle unset.

    Backends are created later by the ``initialize_*`` methods; only the
    torch device string is decided here.
    """
    self.config = config

    # Remote API client
    self.llm_client = None

    # Legacy local model and its tokenizer
    self.legacy_local_model = None
    self.legacy_tokenizer = None

    # New local LLM and the model name it was loaded for
    self.local_llm = None
    self.local_llm_model = None

    # TTS backends
    self.melo_models = None
    self.spark_model_dir = None

    if torch.cuda.is_available():
        self.device = "cuda"
    else:
        self.device = "cpu"
def initialize_api_mode(self, api_key: str):
    """Create the OpenAI-compatible client that targets the Together endpoint."""
    together_base_url = "https://api.together.xyz/v1"
    client = OpenAI(api_key=api_key, base_url=together_base_url)
    self.llm_client = client
def initialize_local_mode(self):
    """Load the configured GGUF model with llama.cpp, downloading it if needed.

    Does nothing when a model with the configured name is already loaded.

    Raises:
        RuntimeError: when the llama.cpp packages are unavailable, or when
            the download / model load fails.
    """
    if not LLAMA_CPP_AVAILABLE:
        raise RuntimeError("Llama CPP dependencies not available. Please install llama-cpp-python and llama-cpp-agent.")

    wanted_model = self.config.local_model_name
    already_loaded = self.local_llm is not None and self.local_llm_model == wanted_model
    if already_loaded:
        return

    try:
        # Fetch the model file into ./models.
        hf_hub_download(
            repo_id=self.config.local_model_repo,
            filename=self.config.local_model_name,
            local_dir="./models"
        )
        model_path_local = os.path.join("./models", self.config.local_model_name)
        if not os.path.exists(model_path_local):
            raise RuntimeError(f"Model file not found at {model_path_local}")

        # Offload layers to the GPU only when CUDA is present.
        gpu_layers = 81 if torch.cuda.is_available() else 0
        self.local_llm = Llama(
            model_path=model_path_local,
            flash_attn=True,
            n_gpu_layers=gpu_layers,
            n_batch=1024,
            n_ctx=16384,
        )
        self.local_llm_model = self.config.local_model_name
        print(f"Local LLM initialized: {model_path_local}")
    except Exception as e:
        print(f"Failed to initialize local LLM: {e}")
        raise RuntimeError(f"Failed to initialize local LLM: {e}")
def initialize_legacy_local_mode(self):
    """Load the fallback Hugging Face model (4-bit) and its tokenizer once."""
    if self.legacy_local_model is not None:
        return  # already loaded

    four_bit_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16
    )
    model_id = self.config.legacy_local_model_name
    self.legacy_local_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=four_bit_config
    )
    # The tokenizer is pinned to a fixed revision of the repository.
    self.legacy_tokenizer = AutoTokenizer.from_pretrained(
        self.config.legacy_local_model_name,
        revision='8ab73a6800796d84448bc936db9bac5ad9f984ae'
    )
def initialize_spark_tts(self):
    """Make sure the Spark-TTS weights exist locally and remember their path.

    Raises:
        RuntimeError: when the Spark dependencies are missing or the model
            download fails.
    """
    if not SPARK_AVAILABLE:
        raise RuntimeError("Spark TTS dependencies not available")

    model_dir = "pretrained_models/Spark-TTS-0.5B"
    model_missing = not os.path.exists(model_dir)
    if model_missing:
        # First use: pull the model snapshot into the local directory.
        print("Downloading Spark-TTS model...")
        try:
            os.makedirs("pretrained_models", exist_ok=True)
            snapshot_download(
                "SparkAudio/Spark-TTS-0.5B",
                local_dir=model_dir
            )
            print("Spark-TTS model downloaded successfully")
        except Exception as e:
            raise RuntimeError(f"Failed to download Spark-TTS model: {e}")
    self.spark_model_dir = model_dir

    # Inference is run through the CLI script; only warn when it is absent.
    cli_present = os.path.exists("cli/inference.py")
    if not cli_present:
        print("Warning: Spark-TTS CLI not found. Please clone the Spark-TTS repository.")
def initialize_melo_tts(self):
    """Create the English MeloTTS model once, when MeloTTS is importable."""
    if not MELO_AVAILABLE or self.melo_models is not None:
        return
    english_model = MeloTTS(language="EN", device=self.device)
    self.melo_models = {"EN": english_model}
def fetch_text(self, url: str) -> str:
    """Download the text of ``url`` through the configured prefix service.

    Raises:
        ValueError: for an empty URL or one without an http(s) scheme.
        RuntimeError: when the HTTP request fails.
    """
    if not url:
        raise ValueError("URL cannot be empty")
    if not url.startswith(("http://", "https://")):
        raise ValueError("URL must start with 'http://' or 'https://'")

    # The configured prefix is placed directly in front of the target URL.
    target = f"{self.config.prefix_url}{url}"
    try:
        reply = httpx.get(target, timeout=60.0)
        reply.raise_for_status()
    except httpx.HTTPError as e:
        raise RuntimeError(f"Failed to fetch URL: {e}")
    return reply.text
def extract_text_from_pdf(self, pdf_file) -> str:
    """Extract the text of every page of a PDF.

    Args:
        pdf_file: A filesystem path (``str`` or path-like), or a readable
            binary file object whose content is first spooled to a
            temporary ``.pdf`` file.

    Returns:
        The page texts joined with newlines.

    Raises:
        RuntimeError: when reading or parsing the PDF fails.
    """
    is_path = isinstance(pdf_file, (str, os.PathLike))
    temp_path = None
    try:
        if is_path:
            # The usual case: the UI hands over a path, not a file object.
            pdf_path = os.fspath(pdf_file)
        else:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
                # Record the name before writing so a failed write is still
                # cleaned up below.
                temp_path = tmp_file.name
                tmp_file.write(pdf_file.read())
            pdf_path = temp_path

        # Load the PDF and merge the text of all pages.
        pages = PyPDFLoader(pdf_path).load()
        return "\n".join(page.page_content for page in pages)
    except Exception as e:
        raise RuntimeError(f"Failed to extract text from PDF: {e}") from e
    finally:
        # Remove the temporary copy on success *and* on failure; previously
        # it was leaked whenever loading raised.
        if temp_path is not None and os.path.exists(temp_path):
            os.unlink(temp_path)
def _get_messages_formatter_type(self, model_name):
    """Choose the chat formatter: CHATML for Mistral/BitSix names, LLAMA_3 otherwise."""
    chatml_markers = ("Mistral", "BitSix")
    uses_chatml = any(marker in model_name for marker in chatml_markers)
    return MessagesFormatterType.CHATML if uses_chatml else MessagesFormatterType.LLAMA_3
def _build_prompt(self, text: str, language: str = "English", search_context: str = "") -> str:
    """Build the user prompt asking for a radio talk-show script as JSON.

    Args:
        text: Source content. It is cut to 4500 characters when a search
            context is supplied and to 6000 otherwise; truncated text gets a
            trailing "...".
        language: "Korean" selects the Korean instructions and speaker
            names; any other value selects the English variant (Alex/Jordan).
        search_context: Optional extra text inserted as its own section
            between the content and the instructions.

    Returns:
        The prompt string, ending with a 12-entry JSON skeleton of
        alternating speakers with empty "text" fields.

    NOTE(review): the Korean literals in this method appear mis-encoded
    (mojibake) in this copy of the file; they are runtime text and are left
    untouched here — verify against the original UTF-8 source.
    """
    # Limit the text length (a smaller budget when search context is added)
    max_text_length = 4500 if search_context else 6000
    if len(text) > max_text_length:
        text = text[:max_text_length] + "..."
    if language == "Korean":
        # Conversation template extended to more turns
        template = """
{
"conversation": [
{"speaker": "์ค์", "text": ""},
{"speaker": "๋ฏผํธ", "text": ""},
{"speaker": "์ค์", "text": ""},
{"speaker": "๋ฏผํธ", "text": ""},
{"speaker": "์ค์", "text": ""},
{"speaker": "๋ฏผํธ", "text": ""},
{"speaker": "์ค์", "text": ""},
{"speaker": "๋ฏผํธ", "text": ""},
{"speaker": "์ค์", "text": ""},
{"speaker": "๋ฏผํธ", "text": ""},
{"speaker": "์ค์", "text": ""},
{"speaker": "๋ฏผํธ", "text": ""}
]
}
"""
        # Optional section carrying the search results
        context_part = ""
        if search_context:
            context_part = f"# ์ต์ ๊ด๋ จ ์ ๋ณด:\n{search_context}\n"
        # Content, optional context, style guidelines, then the JSON skeleton
        base_prompt = (
            f"# ์๋ณธ ์ฝํ ์ธ :\n{text}\n\n"
            f"{context_part}"
            f"์ ๋ด์ฉ์ผ๋ก ๋ผ๋์ค ๋๋ด ํ๋ก๊ทธ๋จ ๋๋ณธ์ ์์ฑํด์ฃผ์ธ์.\n\n"
            f"## ํต์ฌ ์ง์นจ:\n"
            f"1. **๋ํ ์คํ์ผ**: ์ค์ ๋ผ๋์ค ๋๋ด์ฒ๋ผ ์์ฃผ ์์ฐ์ค๋ฝ๊ณ ํธ์ํ ๊ตฌ์ด์ฒด ์ฌ์ฉ\n"
            f"2. **ํ์ ์ญํ **:\n"
            f" - ์ค์: ์งํ์/ํธ์คํธ (์ฃผ๋ก ์ง๋ฌธํ๊ณ ๋ํ๋ฅผ ์ด๋์ด๊ฐ)\n"
            f" - ๋ฏผํธ: ์ ๋ฌธ๊ฐ (์ง๋ฌธ์ ๋ตํ๊ณ ์ค๋ช ํจ)\n"
            f"3. **๋ํ ํจํด**:\n"
            f" - ์ค์๋ ์ฃผ๋ก ์งง์ ์ง๋ฌธ์ด๋ ๋ฆฌ์ก์ (\"์, ๊ทธ๋ ๊ตฐ์\", \"ํฅ๋ฏธ๋กญ๋ค์\", \"๊ทธ๋ผ ~๋ ์ด๋ค๊ฐ์?\")\n"
            f" - ๋ฏผํธ๋ 1-2๋ฌธ์ฅ์ผ๋ก ๊ฐ๊ฒฐํ๊ฒ ๋ต๋ณ\n"
            f" - ์ ๋ ํ ์ฌ๋์ด 3๋ฌธ์ฅ ์ด์ ์ฐ์์ผ๋ก ๋งํ์ง ์์\n"
            f"4. **์์ฐ์ค๋ฌ์**:\n"
            f" - \"์...\", \"์...\", \"๋ค,\" ๊ฐ์ ์ถ์์ ์ฌ์ฉ\n"
            f" - ๋๋ก๋ ์๋๋ฐฉ ๋ง์ ์งง๊ฒ ๋ฐ์ (\"๋ง์์\", \"๊ทธ๋ ์ฃ \")\n"
            f"5. **ํ์ ๊ท์น**: ์๋ก ์กด๋๋ง ์ฌ์ฉ, 12-15ํ ๋ํ ๊ตํ\n\n"
            f"JSON ํ์์ผ๋ก๋ง ๋ฐํ:\n{template}"
        )
        return base_prompt
    else:
        # The English template is extended as well
        template = """
{
"conversation": [
{"speaker": "Alex", "text": ""},
{"speaker": "Jordan", "text": ""},
{"speaker": "Alex", "text": ""},
{"speaker": "Jordan", "text": ""},
{"speaker": "Alex", "text": ""},
{"speaker": "Jordan", "text": ""},
{"speaker": "Alex", "text": ""},
{"speaker": "Jordan", "text": ""},
{"speaker": "Alex", "text": ""},
{"speaker": "Jordan", "text": ""},
{"speaker": "Alex", "text": ""},
{"speaker": "Jordan", "text": ""}
]
}
"""
        # Optional section carrying the search results
        context_part = ""
        if search_context:
            context_part = f"# Latest Information:\n{search_context}\n"
        # Content, optional context, style guidelines, then the JSON skeleton
        base_prompt = (
            f"# Content:\n{text}\n\n"
            f"{context_part}"
            f"Create a natural radio talk show conversation.\n\n"
            f"## Key Guidelines:\n"
            f"1. **Style**: Natural, conversational English like a real radio show\n"
            f"2. **Roles**:\n"
            f" - Alex: Host (asks questions, guides conversation)\n"
            f" - Jordan: Expert (answers, explains)\n"
            f"3. **Pattern**:\n"
            f" - Alex mostly asks short questions or reacts (\"I see\", \"Interesting\", \"What about...?\")\n"
            f" - Jordan gives brief 1-2 sentence answers\n"
            f" - Never more than 2-3 sentences per turn\n"
            f"4. **Natural flow**:\n"
            f" - Use fillers like \"Well,\" \"You know,\" \"Actually,\"\n"
            f" - Short reactions (\"Right\", \"Exactly\")\n"
            f"5. **Length**: 12-15 exchanges total\n\n"
            f"Return JSON only:\n{template}"
        )
        return base_prompt
def _build_messages_for_local(self, text: str, language: str = "English", search_context: str = "") -> List[Dict]:
    """Build a two-message chat (system + user) for a local LLM.

    The system message describes the scriptwriter role for the selected
    language; the user message is whatever ``_build_prompt`` returns for the
    same arguments.

    NOTE(review): the Korean system message appears mis-encoded (mojibake)
    in this copy of the file; it is runtime text and is left untouched —
    verify against the original UTF-8 source.
    """
    if language == "Korean":
        # Korean scriptwriter instructions (host / expert roles and tone)
        system_message = (
            "๋น์ ์ ํ๊ตญ ์ต๊ณ ์ ๋ผ๋์ค ๋๋ด ํ๋ก๊ทธ๋จ ์๊ฐ์ ๋๋ค. "
            "์ค์ ๋ผ๋์ค ๋ฐฉ์ก์ฒ๋ผ ์์ฐ์ค๋ฝ๊ณ ์๋๊ฐ ์๋ ๋ํ๋ฅผ ๋ง๋ค์ด๋ ๋๋ค.\n\n"
            "ํต์ฌ ์์น:\n"
            "1. ๋ผ๋์ค ์งํ์(์ค์)๋ ์ฃผ๋ก ์งง์ ์ง๋ฌธ๊ณผ ๋ฆฌ์ก์ ์ผ๋ก ๋ํ๋ฅผ ์ด๋์ด๊ฐ๋๋ค\n"
            "2. ์ ๋ฌธ๊ฐ(๋ฏผํธ)๋ ์ง๋ฌธ์ ๊ฐ๊ฒฐํ๊ณ ์ดํดํ๊ธฐ ์ฝ๊ฒ ๋ตํฉ๋๋ค\n"
            "3. ํ ๋ฒ์ ๋๋ฌด ๋ง์ ์ ๋ณด๋ฅผ ์ ๋ฌํ์ง ์๊ณ , ๋ํ๋ฅผ ํตํด ์ ์ง์ ์ผ๋ก ํ์ด๊ฐ๋๋ค\n"
            "4. \"์...\", \"์...\", \"๋ค,\" ๋ฑ ์์ฐ์ค๋ฌ์ด ๊ตฌ์ด์ฒด ํํ์ ์ฌ์ฉํฉ๋๋ค\n"
            "5. ์ฒญ์ทจ์๊ฐ ๋ผ๋์ค๋ฅผ ๋ฃ๋ ๊ฒ์ฒ๋ผ ๋ชฐ์ ํ ์ ์๋๋ก ์์ํ๊ฒ ์์ฑํฉ๋๋ค\n"
            "6. ๋ฐ๋์ ์๋ก ์กด๋๋ง์ ์ฌ์ฉํ๋ฉฐ, ์ ์คํ๋ฉด์๋ ์น๊ทผํ ํค์ ์ ์งํฉ๋๋ค"
        )
    else:
        # English scriptwriter instructions
        system_message = (
            "You are an expert radio talk show scriptwriter who creates engaging, "
            "natural conversations that sound like real radio broadcasts.\n\n"
            "Key principles:\n"
            "1. The host (Alex) mainly asks short questions and gives reactions to guide the conversation\n"
            "2. The expert (Jordan) answers concisely and clearly\n"
            "3. Information is revealed gradually through dialogue, not in long monologues\n"
            "4. Use natural speech patterns with fillers like 'Well,' 'You know,' etc.\n"
            "5. Make it sound like an actual radio show that listeners would enjoy\n"
            "6. Keep each turn brief - no more than 2-3 sentences"
        )
    return [
        {"role": "system", "content": system_message},
        {"role": "user", "content": self._build_prompt(text, language, search_context)}
    ]
def extract_conversation_local(self, text: str, language: str = "English", progress=None) -> Dict:
    """Generate the conversation JSON with the llama.cpp local model.

    Steps: optionally gather a Brave search context, load the local model,
    ask it for a script, and parse the first JSON object found in the reply.
    Any exception in this path is printed and the call falls back to
    ``extract_conversation_legacy_local``.

    Args:
        text: Source content for the script.
        language: "Korean" or any other value for English.
        progress: Not used here; only forwarded to the legacy fallback.

    Returns:
        The parsed conversation dict (or whatever the fallback returns).

    NOTE(review): the Korean system message appears mis-encoded (mojibake)
    in this copy of the file; it is runtime text and is left untouched —
    verify against the original UTF-8 source.
    """
    try:
        # Build the search context (skipped for keyword-based content)
        search_context = ""
        if BRAVE_KEY and not text.startswith("Keyword-based content:"):
            try:
                keywords = extract_keywords_for_search(text, language)
                if keywords:
                    search_query = keywords[0] if language == "Korean" else f"{keywords[0]} latest news"
                    search_context = format_search_results(search_query)
                    print(f"Search context added for: {search_query}")
            except Exception as e:
                # Search is optional: continue without context on failure.
                print(f"Search failed, continuing without context: {e}")
        # Try the new local LLM first
        self.initialize_local_mode()
        chat_template = self._get_messages_formatter_type(self.config.local_model_name)
        provider = LlamaCppPythonProvider(self.local_llm)
        # Enhanced radio-style system message
        if language == "Korean":
            system_message = (
                "๋น์ ์ ํ๊ตญ์ ์ธ๊ธฐ ๋ผ๋์ค ๋๋ด ํ๋ก๊ทธ๋จ ์ ๋ฌธ ์๊ฐ์ ๋๋ค. "
                "์ฒญ์ทจ์๋ค์ด ์ค์ ๋ผ๋์ค๋ฅผ ๋ฃ๋ ๊ฒ์ฒ๋ผ ๋ชฐ์ ํ ์ ์๋ ์์ฐ์ค๋ฌ์ด ๋ํ๋ฅผ ๋ง๋ญ๋๋ค.\n\n"
                "์์ฑ ๊ท์น:\n"
                "1. ์งํ์(์ค์)๋ ์ฃผ๋ก ์งง์ ์ง๋ฌธ์ผ๋ก ๋ํ๋ฅผ ์ด๋์ด๊ฐ์ธ์ (\"๊ทธ๋ ๊ตฐ์\", \"์ด๋ค ์ ์ด ํน๋ณํ๊ฐ์?\", \"์ฒญ์ทจ์๋ถ๋ค์ด ๊ถ๊ธํดํ์ค ๊ฒ ๊ฐ์๋ฐ์\")\n"
                "2. ์ ๋ฌธ๊ฐ(๋ฏผํธ)๋ 1-2๋ฌธ์ฅ์ผ๋ก ๊ฐ๊ฒฐํ๊ฒ ๋ต๋ณํ์ธ์\n"
                "3. ์ ๋ ํ ์ฌ๋์ด 3๋ฌธ์ฅ ์ด์ ์ฐ์์ผ๋ก ๋งํ์ง ๋ง์ธ์\n"
                "4. ๊ตฌ์ด์ฒด์ ์ถ์์๋ฅผ ์์ฐ์ค๋ฝ๊ฒ ์ฌ์ฉํ์ธ์\n"
                "5. ๋ฐ๋์ ์๋ก ์กด๋๋ง์ ์ฌ์ฉํ์ธ์\n"
                "6. 12-15ํ์ ๋ํ ๊ตํ์ผ๋ก ๊ตฌ์ฑํ์ธ์\n"
                "7. JSON ํ์์ผ๋ก๋ง ์๋ตํ์ธ์"
            )
        else:
            system_message = (
                "You are a professional radio talk show scriptwriter creating engaging, "
                "natural conversations that sound like real radio broadcasts.\n\n"
                "Writing rules:\n"
                "1. Host (Alex) mainly asks short questions to guide the conversation (\"I see\", \"What makes it special?\", \"Our listeners might wonder...\")\n"
                "2. Expert (Jordan) answers in 1-2 concise sentences\n"
                "3. Never have one person speak more than 2-3 sentences at once\n"
                "4. Use natural speech patterns and fillers\n"
                "5. Create 12-15 conversation exchanges\n"
                "6. Respond only in JSON format"
            )
        agent = LlamaCppAgent(
            provider,
            system_prompt=system_message,
            predefined_messages_formatter_type=chat_template,
            debug_output=False
        )
        # Sampling settings for a single, non-streamed completion
        settings = provider.get_provider_default_settings()
        settings.temperature = 0.85  # raised slightly for more natural dialogue
        settings.top_k = 40
        settings.top_p = 0.95
        settings.max_tokens = self.config.max_tokens  # use the increased token budget
        settings.repeat_penalty = 1.1
        settings.stream = False
        # Start from an empty chat history; the prompt is the only user turn.
        messages = BasicChatHistory()
        prompt = self._build_prompt(text, language, search_context)
        response = agent.get_chat_response(
            prompt,
            llm_sampling_settings=settings,
            chat_history=messages,
            returns_streaming_generator=False,
            print_output=False
        )
        # JSON parsing: the pattern matches a brace-delimited object that
        # contains at most one level of nested braces.
        pattern = r"\{(?:[^{}]|(?:\{[^{}]*\}))*\}"
        json_match = re.search(pattern, response)
        if json_match:
            conversation_data = json.loads(json_match.group())
            # Check the conversation length. A short result is only
            # reported; it is still returned as-is.
            if len(conversation_data["conversation"]) < self.config.min_conversation_turns:
                print(f"Conversation too short ({len(conversation_data['conversation'])} turns), regenerating...")
                # Retry logic could be added here
            return conversation_data
        else:
            raise ValueError("No valid JSON found in local LLM response")
    except Exception as e:
        # Any failure above (model load, generation, parsing) falls back to
        # the legacy local model, reusing the search context gathered so far.
        print(f"Local LLM failed: {e}, falling back to legacy local method")
        return self.extract_conversation_legacy_local(text, language, progress, search_context)
def extract_conversation_legacy_local(self, text: str, language: str = "English", progress=None, search_context: str = "") -> Dict:
    """Generate the conversation JSON with the legacy transformers model.

    The model runs in a background thread and its output is collected from
    a text streamer; the first JSON object found in the output is returned.
    On any failure the built-in default conversation for the language is
    returned instead of raising.
    """
    try:
        self.initialize_legacy_local_mode()

        # Enhanced radio-style system message
        if language == "Korean":
            system_message = (
                "๋น์ ์ ๋ผ๋์ค ๋๋ด ํ๋ก๊ทธ๋จ ์๊ฐ์ ๋๋ค. "
                "์งํ์(์ค์)๋ ์งง์ ์ง๋ฌธ์ผ๋ก, ์ ๋ฌธ๊ฐ(๋ฏผํธ)๋ ๊ฐ๊ฒฐํ ๋ต๋ณ์ผ๋ก "
                "์์ฐ์ค๋ฌ์ด ๋ํ๋ฅผ ๋ง๋์ธ์. ์๋ก ์กด๋๋ง์ ์ฌ์ฉํ๊ณ , "
                "ํ ๋ฒ์ 2-3๋ฌธ์ฅ ์ด๋ด๋ก ๋งํ์ธ์. 12-15ํ ๋ํ ๊ตํ์ผ๋ก ๊ตฌ์ฑํ์ธ์."
            )
        else:
            system_message = (
                "You are a radio talk show scriptwriter. "
                "Create natural dialogue where the host (Alex) asks short questions "
                "and the expert (Jordan) gives brief answers. "
                "Keep each turn to 2-3 sentences max. Create 12-15 exchanges."
            )

        tokenizer = self.legacy_tokenizer
        chat = [
            {"role": "system", "content": system_message},
            {"role": "user", "content": self._build_prompt(text, language, search_context)}
        ]
        # Generation stops at the tokenizer's EOS or at the "<|eot_id|>" token.
        terminators = [
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<|eot_id|>")
        ]
        rendered_chat = tokenizer.apply_chat_template(
            chat, tokenize=False, add_generation_prompt=True
        )
        model_inputs = tokenizer([rendered_chat], return_tensors="pt").to(self.device)
        streamer = TextIteratorStreamer(
            tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
        )
        generate_kwargs = dict(
            model_inputs,
            streamer=streamer,
            max_new_tokens=self.config.max_new_tokens,  # use the increased token budget
            do_sample=True,
            temperature=0.85,
            eos_token_id=terminators,
        )
        worker = Thread(target=self.legacy_local_model.generate, kwargs=generate_kwargs)
        worker.start()

        # Drain the streamer until generation ends.
        generated = "".join(streamer)

        json_match = re.search(r"\{(?:[^{}]|(?:\{[^{}]*\}))*\}", generated)
        if not json_match:
            raise ValueError("No valid JSON found in legacy local response")
        return json.loads(json_match.group())
    except Exception as e:
        print(f"Legacy local model also failed: {e}")
        # Last resort: the canned default conversation.
        if language == "Korean":
            return self._get_default_korean_conversation()
        return self._get_default_english_conversation()
def _get_default_korean_conversation(self) -> Dict: | |
"""๋ ๊ธด ๊ธฐ๋ณธ ํ๊ตญ์ด ๋ํ ํ ํ๋ฆฟ""" | |
return { | |
"conversation": [ | |
{"speaker": "์ค์", "text": "์๋ ํ์ธ์, ์ฌ๋ฌ๋ถ! ์ค๋๋ ์ ํฌ ํ์บ์คํธ๋ฅผ ์ฐพ์์ฃผ์ ์ ์ ๋ง ๊ฐ์ฌํฉ๋๋ค."}, | |
{"speaker": "๋ฏผํธ", "text": "๋ค, ์๋ ํ์ธ์! ์ค๋ ์ ๋ง ํฅ๋ฏธ๋ก์ด ์ฃผ์ ๋ฅผ ์ค๋นํ์ต๋๋ค."}, | |
{"speaker": "์ค์", "text": "์, ๊ทธ๋์? ์ด๋ค ๋ด์ฉ์ธ์ง ์ ๋ง ๊ถ๊ธํ๋ฐ์?"}, | |
{"speaker": "๋ฏผํธ", "text": "์ค๋์ ์ต๊ทผ ๋ง์ ๋ถ๋ค์ด ๊ด์ฌ์ ๊ฐ์ง๊ณ ๊ณ์ ์ฃผ์ ์ ๋ํด ์ด์ผ๊ธฐํด๋ณผ๊น ํด์."}, | |
{"speaker": "์ค์", "text": "์, ์์ฆ ์ ๋ง ํ์ ๊ฐ ๋๊ณ ์์ฃ . ๊ตฌ์ฒด์ ์ผ๋ก ์ด๋ค ์ธก๋ฉด์ ๋ค๋ฃฐ ์์ ์ด์ ๊ฐ์?"}, | |
{"speaker": "๋ฏผํธ", "text": "๋ค, ๋จผ์ ๊ธฐ๋ณธ์ ์ธ ๊ฐ๋ ๋ถํฐ ์ฐจ๊ทผ์ฐจ๊ทผ ์ค๋ช ๋๋ฆฌ๊ณ , ์ค์ํ์ ์ด๋ป๊ฒ ์ ์ฉํ ์ ์๋์ง ์์๋ณผ๊ฒ์."}, | |
{"speaker": "์ค์", "text": "์ข์์! ์ฒญ์ทจ์๋ถ๋ค๋ ์ดํดํ๊ธฐ ์ฝ๊ฒ ์ค๋ช ํด์ฃผ์ค ๊ฑฐ์ฃ ?"}, | |
{"speaker": "๋ฏผํธ", "text": "๋ฌผ๋ก ์ด์ฃ . ์ต๋ํ ์ฝ๊ณ ์ฌ๋ฏธ์๊ฒ ํ์ด์ ์ค๋ช ๋๋ฆด๊ฒ์."}, | |
{"speaker": "์ค์", "text": "๊ทธ๋ผ ๋ณธ๊ฒฉ์ ์ผ๋ก ์์ํด๋ณผ๊น์?"}, | |
{"speaker": "๋ฏผํธ", "text": "๋ค, ์ข์ต๋๋ค. ์ฐ์ ์ด ์ฃผ์ ๊ฐ ์ ์ค์ํ์ง๋ถํฐ ๋ง์๋๋ฆด๊ฒ์."}, | |
{"speaker": "์ค์", "text": "์, ๋ง์์. ๊ทธ ๋ถ๋ถ์ด ์ ๋ง ์ค์ํ์ฃ ."}, | |
{"speaker": "๋ฏผํธ", "text": "์ต๊ทผ ์ฐ๊ตฌ ๊ฒฐ๊ณผ๋ฅผ ๋ณด๋ฉด ์ ๋ง ๋๋ผ์ด ๋ฐ๊ฒฌ๋ค์ด ๋ง์์ด์."} | |
] | |
} | |
def _get_default_english_conversation(self) -> Dict: | |
"""Enhanced default English conversation template""" | |
return { | |
"conversation": [ | |
{"speaker": "Alex", "text": "Welcome everyone to our podcast! We have a fascinating topic today."}, | |
{"speaker": "Jordan", "text": "Thanks, Alex. I'm excited to dive into this subject with our listeners."}, | |
{"speaker": "Alex", "text": "So, what makes this topic particularly relevant right now?"}, | |
{"speaker": "Jordan", "text": "Well, there have been some significant developments recently that everyone should know about."}, | |
{"speaker": "Alex", "text": "Interesting! Can you break it down for us?"}, | |
{"speaker": "Jordan", "text": "Absolutely. Let me start with the basics and build from there."}, | |
{"speaker": "Alex", "text": "That sounds perfect. Our listeners will appreciate that approach."}, | |
{"speaker": "Jordan", "text": "So, first, let's understand what we're really talking about here."}, | |
{"speaker": "Alex", "text": "Right, the fundamentals are crucial."}, | |
{"speaker": "Jordan", "text": "Exactly. And once we grasp that, the rest becomes much clearer."}, | |
{"speaker": "Alex", "text": "I'm already learning something new! What's next?"}, | |
{"speaker": "Jordan", "text": "Now, here's where it gets really interesting..."} | |
] | |
} | |
def extract_conversation_api(self, text: str, language: str = "English") -> Dict: | |
"""Extract conversation using API with enhanced radio style""" | |
if not self.llm_client: | |
raise RuntimeError("API mode not initialized") | |
try: | |
# ๊ฒ์ ์ปจํ ์คํธ ์์ฑ | |
search_context = "" | |
if BRAVE_KEY and not text.startswith("Keyword-based content:"): | |
try: | |
keywords = extract_keywords_for_search(text, language) | |
if keywords: | |
search_query = keywords[0] if language == "Korean" else f"{keywords[0]} latest news" | |
search_context = format_search_results(search_query) | |
print(f"Search context added for: {search_query}") | |
except Exception as e: | |
print(f"Search failed, continuing without context: {e}") | |
# ๊ฐํ๋ ๋ผ๋์ค ์คํ์ผ ํ๋กฌํํธ | |
if language == "Korean": | |
system_message = ( | |
"๋น์ ์ ํ๊ตญ์ ์ธ๊ธฐ ๋ผ๋์ค ๋๋ด ํ๋ก๊ทธ๋จ ์๊ฐ์ ๋๋ค. " | |
"์ค์ ๋ผ๋์ค ๋ฐฉ์ก์ฒ๋ผ ์์ฐ์ค๋ฝ๊ณ ํธ์ํ ๋ํ๋ฅผ ๋ง๋์ธ์.\n" | |
"์ค์(์งํ์)๋ ์ฃผ๋ก ์งง์ ์ง๋ฌธ๊ณผ ๋ฆฌ์ก์ ์ผ๋ก ๋ํ๋ฅผ ์ด๋๊ณ , " | |
"๋ฏผํธ(์ ๋ฌธ๊ฐ)๋ 1-2๋ฌธ์ฅ์ผ๋ก ๊ฐ๊ฒฐํ๊ฒ ๋ต๋ณํฉ๋๋ค. " | |
"๊ตฌ์ด์ฒด์ ์ถ์์๋ฅผ ์ฌ์ฉํ๊ณ , ๋ฐ๋์ ์๋ก ์กด๋๋ง์ ์ฌ์ฉํ์ธ์. " | |
"12-15ํ์ ๋ํ ๊ตํ์ผ๋ก ๊ตฌ์ฑํ์ธ์." | |
) | |
else: | |
system_message = ( | |
"You are a professional radio talk show scriptwriter. " | |
"Create natural, engaging dialogue like a real radio broadcast. " | |
"Alex (host) mainly asks short questions and gives reactions, " | |
"while Jordan (expert) answers in 1-2 concise sentences. " | |
"Use conversational language with natural fillers. " | |
"Create 12-15 conversation exchanges." | |
) | |
chat_completion = self.llm_client.chat.completions.create( | |
messages=[ | |
{"role": "system", "content": system_message}, | |
{"role": "user", "content": self._build_prompt(text, language, search_context)} | |
], | |
model=self.config.api_model_name, | |
temperature=0.85, | |
) | |
pattern = r"\{(?:[^{}]|(?:\{[^{}]*\}))*\}" | |
json_match = re.search(pattern, chat_completion.choices[0].message.content) | |
if not json_match: | |
raise ValueError("No valid JSON found in response") | |
return json.loads(json_match.group()) | |
except Exception as e: | |
raise RuntimeError(f"Failed to extract conversation: {e}") | |
def parse_conversation_text(self, conversation_text: str) -> Dict:
    """Turn "Speaker: text" lines back into the conversation dict.

    Lines without a colon are ignored; only the first colon of a line
    separates the speaker from the text.
    """
    turns = []
    for raw_line in conversation_text.strip().split('\n'):
        name, colon, spoken = raw_line.partition(':')
        if not colon:
            continue  # not a "speaker: text" line
        turns.append({"speaker": name.strip(), "text": spoken.strip()})
    return {"conversation": turns}
async def text_to_speech_edge(self, conversation_json: Dict, language: str = "English") -> Tuple[str, str]:
    """Synthesise every turn with Edge TTS and merge the clips.

    Args:
        conversation_json: Dict with a "conversation" list of turns, each
            having a "text" entry and optionally a "speaker" entry.
        language: "Korean" selects the Korean voices; anything else the
            English ones.

    Returns:
        ``(path_to_combined_audio, transcript)`` where the transcript has
        one "Speaker: text" line per turn.

    Raises:
        RuntimeError: when synthesis or merging fails.
    """
    output_dir = Path(self._create_output_directory())
    filenames = []
    try:
        # Voice selection per language; two voices are alternated by turn
        # index (both are male voices).
        if language == "Korean":
            voices = [
                "ko-KR-HyunsuNeural",  # male voice 1 (calm, trustworthy)
                "ko-KR-InJoonNeural"   # male voice 2 (lively, friendly)
            ]
        else:
            voices = [
                "en-US-AndrewMultilingualNeural",  # male voice 1
                "en-US-BrianMultilingualNeural"    # male voice 2
            ]
        for i, turn in enumerate(conversation_json["conversation"]):
            filename = output_dir / f"output_{i}.wav"
            voice = voices[i % len(voices)]
            tmp_path = await self._generate_audio_edge(turn["text"], voice)
            # The clip is created in the system temp directory, which may be
            # on a different filesystem than output_dir; os.rename fails in
            # that case, shutil.move falls back to copy + delete.
            shutil.move(tmp_path, str(filename))
            filenames.append(str(filename))

        # Combine audio files
        final_output = os.path.join(output_dir, "combined_output.wav")
        self._combine_audio_files(filenames, final_output)

        # Generate conversation text
        conversation_text = "\n".join(
            f"{turn.get('speaker', f'Speaker {i+1}')}: {turn['text']}"
            for i, turn in enumerate(conversation_json["conversation"])
        )
        return final_output, conversation_text
    except Exception as e:
        raise RuntimeError(f"Failed to convert text to speech: {e}") from e
async def _generate_audio_edge(self, text: str, voice: str) -> str: | |
"""Generate audio using Edge TTS""" | |
if not text.strip(): | |
raise ValueError("Text cannot be empty") | |
voice_short_name = voice.split(" - ")[0] if " - " in voice else voice | |
communicate = edge_tts.Communicate(text, voice_short_name) | |
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file: | |
tmp_path = tmp_file.name | |
await communicate.save(tmp_path) | |
return tmp_path | |
def text_to_speech_spark(self, conversation_json: Dict, language: str = "English", progress=None) -> Tuple[str, str]:
    """Synthesise every turn with the Spark-TTS CLI and merge the clips.

    Each non-empty turn is rendered by a ``cli.inference`` subprocess. When
    a turn fails (non-zero exit, timeout or any other error) one second of
    silence is written in its place so the remaining turns keep their
    order.

    Args:
        conversation_json: Dict with a "conversation" list of turns.
        language: "Korean" selects the Korean prompt texts; anything else
            the English ones.
        progress: Unused.

    Returns:
        ``(path_to_combined_audio, transcript)``.

    Raises:
        RuntimeError: when Spark TTS is unavailable/uninitialised or the
            conversion fails.
    """
    import sys  # local import: only needed to locate the running interpreter

    if not SPARK_AVAILABLE or not self.spark_model_dir:
        raise RuntimeError("Spark TTS not available")

    def write_silence(path: str) -> None:
        # One second of silence at 22050 Hz, used in place of a failed turn.
        sf.write(path, np.zeros(int(22050 * 1.0)), 22050)

    try:
        output_dir = self._create_output_directory()
        audio_files = []

        # Different prompt texts give the two speakers different voices.
        # NOTE(review): the Korean prompt texts appear mis-encoded (mojibake)
        # in this copy of the file; kept as-is — verify against the original.
        if language == "Korean":
            voice_configs = [
                {"prompt_text": "์๋ ํ์ธ์, ์ค๋ ํ์บ์คํธ ์งํ์ ๋งก์ ์ค์์ ๋๋ค. ์ฌ๋ฌ๋ถ๊ณผ ํจ๊ป ํฅ๋ฏธ๋ก์ด ์ด์ผ๊ธฐ๋ฅผ ๋๋ ๋ณด๊ฒ ์ต๋๋ค.", "gender": "male"},
                {"prompt_text": "์๋ ํ์ธ์, ์ ๋ ์ค๋ ์ด ์ฃผ์ ์ ๋ํด ์ค๋ช ๋๋ฆด ๋ฏผํธ์ ๋๋ค. ์ฝ๊ณ ์ฌ๋ฏธ์๊ฒ ์ค๋ช ๋๋ฆด๊ฒ์.", "gender": "male"}
            ]
        else:
            voice_configs = [
                {"prompt_text": "Hello everyone, I'm Alex, your host for today's podcast. Let's explore this fascinating topic together.", "gender": "male"},
                {"prompt_text": "Hi, I'm Jordan. I'm excited to share my insights on this subject with you all today.", "gender": "male"}
            ]

        # Run the CLI with the interpreter that runs this process, so the
        # subprocess sees the same environment and installed packages.
        python_executable = sys.executable or "python"

        for i, turn in enumerate(conversation_json["conversation"]):
            text = turn["text"]
            if not text.strip():
                continue

            # Alternate the voice configuration by turn index.
            voice_config = voice_configs[i % len(voice_configs)]
            output_file = os.path.join(output_dir, f"spark_output_{i}.wav")

            cmd = [
                python_executable, "-m", "cli.inference",
                "--text", text,
                "--device", "0" if torch.cuda.is_available() else "cpu",
                "--save_dir", output_dir,
                "--model_dir", self.spark_model_dir,
                "--prompt_text", voice_config["prompt_text"],
                "--output_name", f"spark_output_{i}.wav"
            ]

            succeeded = False
            try:
                result = subprocess.run(
                    cmd,
                    capture_output=True,
                    text=True,
                    timeout=60,
                    cwd="."  # the CLI module is resolved from the working directory
                )
                succeeded = result.returncode == 0
                if not succeeded:
                    print(f"Spark TTS error for turn {i}: {result.stderr}")
            except subprocess.TimeoutExpired:
                print(f"Spark TTS timeout for turn {i}")
            except Exception as e:
                print(f"Error running Spark TTS for turn {i}: {e}")

            if not succeeded:
                write_silence(output_file)
            audio_files.append(output_file)

        # Combine all audio files
        if not audio_files:
            raise RuntimeError("No audio files generated")
        final_output = os.path.join(output_dir, "spark_combined.wav")
        self._combine_audio_files(audio_files, final_output)

        # Generate conversation text
        conversation_text = "\n".join(
            f"{turn.get('speaker', f'Speaker {i+1}')}: {turn['text']}"
            for i, turn in enumerate(conversation_json["conversation"])
        )
        return final_output, conversation_text
    except Exception as e:
        raise RuntimeError(f"Failed to convert text to speech with Spark TTS: {e}") from e
def text_to_speech_melo(self, conversation_json: Dict, progress=None) -> Tuple[str, str]:
    """Synthesize a conversation with MeloTTS and return the audio file path.

    Turns are rendered one by one with the English MeloTTS model, cycling
    through the available speaker voices, and concatenated into one MP3.

    Args:
        conversation_json: Dict with a ``"conversation"`` list; each turn
            must have ``"text"`` and may have ``"speaker"``.
        progress: Optional object exposing ``tqdm``; forwarded to MeloTTS
            as its progress bar.

    Returns:
        ``(audio_path, conversation_text)`` where ``conversation_text`` has
        one ``"Speaker: text"`` line per turn.

    Raises:
        RuntimeError: If MeloTTS is not importable or no model is loaded.
    """
    if not MELO_AVAILABLE or not self.melo_models:
        raise RuntimeError("MeloTTS not available")

    model = self.melo_models["EN"]
    speakers = ["EN-Default", "EN-US"]
    turns = conversation_json["conversation"]
    combined_audio = AudioSegment.empty()

    for i, turn in enumerate(turns):
        # Alternate voices so consecutive turns sound like different people.
        speaker_id = model.hps.data.spk2id[speakers[i % len(speakers)]]
        bio = io.BytesIO()
        model.tts_to_file(
            turn["text"], speaker_id, bio, speed=1.0,
            pbar=progress.tqdm if progress else None,
            format="wav"
        )
        bio.seek(0)
        combined_audio += AudioSegment.from_file(bio, format="wav")

    # Write to a unique file: a fixed name in the working directory would be
    # overwritten when several requests are synthesized at the same time.
    with tempfile.NamedTemporaryFile(
        prefix="melo_podcast_", suffix=".mp3", delete=False
    ) as tmp:
        final_audio_path = tmp.name
    combined_audio.export(final_audio_path, format="mp3")

    conversation_text = "\n".join(
        f"{turn.get('speaker', f'Speaker {i+1}')}: {turn['text']}"
        for i, turn in enumerate(turns)
    )
    return final_audio_path, conversation_text
def _create_output_directory(self) -> str:
    """Create a uniquely named output directory in the working directory.

    The name is 16 lowercase hex characters. Hex is used instead of
    URL-safe base64 because base64 names may begin with ``-`` (and end with
    ``=``); such a name is misread as an option when the directory is
    passed as a command-line argument value.

    Returns:
        The relative name of the created directory.
    """
    folder_name = os.urandom(8).hex()
    os.makedirs(folder_name, exist_ok=True)
    return folder_name
def _combine_audio_files(self, filenames: List[str], output_file: str) -> None:
    """Concatenate audio files into a single WAV and delete the inputs.

    Args:
        filenames: Paths of the audio files to join, in playback order.
            Paths that do not exist are ignored.
        output_file: Destination path of the combined WAV file.

    Raises:
        ValueError: If ``filenames`` is empty.
        RuntimeError: If none of the inputs exist, or if loading or
            exporting the audio fails.
    """
    if not filenames:
        raise ValueError("No input files provided")

    existing = [name for name in filenames if os.path.exists(name)]
    if not existing:
        # Without this check nothing is written and the caller would hand
        # back a path to a file that was never created.
        raise RuntimeError(
            "Failed to combine audio files: none of the input files exist"
        )

    try:
        audio_segments = [AudioSegment.from_file(name) for name in existing]
        combined = sum(audio_segments)
        combined.export(output_file, format="wav")
        # Clean up the per-turn temporary files once the export succeeded.
        for name in existing:
            if os.path.exists(name):
                os.remove(name)
    except Exception as e:
        raise RuntimeError(f"Failed to combine audio files: {e}") from e
# Module-wide converter, built with the default configuration and shared by
# every handler defined below.
_default_config = ConversationConfig()
converter = UnifiedAudioConverter(_default_config)
async def synthesize(article_input, input_type: str = "URL", mode: str = "Local", tts_engine: str = "Edge-TTS", language: str = "English"):
    """Turn a URL, PDF, or keyword into a podcast-style conversation transcript.

    ``tts_engine`` is accepted for signature parity with the UI handlers but
    is not used here; no audio is produced by this function.

    Returns:
        ``(message, None)``. ``message`` is the transcript on success, a
        validation prompt for missing input, or ``"Error: ..."`` on failure.
    """
    try:
        # Stage 1: obtain the source text for the selected input type.
        if input_type == "URL":
            if not (article_input and isinstance(article_input, str)):
                return "Please provide a valid URL.", None
            source_text = converter.fetch_text(article_input)
        elif input_type == "PDF":
            if not article_input:
                return "Please upload a PDF file.", None
            source_text = converter.extract_text_from_pdf(article_input)
        else:
            # Keyword: search for the topic and compile the findings.
            if not (article_input and isinstance(article_input, str)):
                return "Please provide a keyword or topic.", None
            compiled = search_and_compile_content(article_input, language)
            # Marker prefix identifying keyword-derived content.
            source_text = f"Keyword-based content:\n{compiled}"

        # Stage 2: cap the text at the configured word budget.
        tokens = source_text.split()
        if len(tokens) > converter.config.max_words:
            source_text = " ".join(tokens[:converter.config.max_words])

        # Stage 3: build the conversation, falling back to the other backend
        # when the preferred one fails.
        if mode == "Local":
            try:
                dialogue = converter.extract_conversation_local(source_text, language)
            except Exception as local_err:
                print(f"Local mode failed: {local_err}, trying API fallback")
                fallback_key = os.environ.get("TOGETHER_API_KEY")
                if not fallback_key:
                    raise RuntimeError("Local mode failed and no API key available for fallback")
                converter.initialize_api_mode(fallback_key)
                dialogue = converter.extract_conversation_api(source_text, language)
        else:
            remote_key = os.environ.get("TOGETHER_API_KEY")
            if remote_key:
                try:
                    converter.initialize_api_mode(remote_key)
                    dialogue = converter.extract_conversation_api(source_text, language)
                except Exception as api_err:
                    print(f"API mode failed: {api_err}, falling back to local mode")
                    dialogue = converter.extract_conversation_local(source_text, language)
            else:
                print("API key not found, falling back to local mode")
                dialogue = converter.extract_conversation_local(source_text, language)

        # Stage 4: render one "Speaker: text" line per turn.
        rendered_lines = [
            f"{turn.get('speaker', f'Speaker {idx+1}')}: {turn['text']}"
            for idx, turn in enumerate(dialogue["conversation"])
        ]
        return "\n".join(rendered_lines), None
    except Exception as err:
        return f"Error: {str(err)}", None
async def regenerate_audio(conversation_text: str, tts_engine: str = "Edge-TTS", language: str = "English"):
    """Generate podcast audio from (possibly user-edited) conversation text.

    Args:
        conversation_text: Transcript in the form parsed by
            ``converter.parse_conversation_text``. ``None``, empty, and
            whitespace-only values are rejected with a prompt.
        tts_engine: ``"Edge-TTS"``, ``"Spark-TTS"``; any other value selects
            MeloTTS. Korean always uses Edge-TTS regardless of this value.
        language: Output language name, e.g. ``"English"`` or ``"Korean"``.

    Returns:
        ``(status_message, audio_path)``; ``audio_path`` is ``None`` whenever
        no audio was produced.
    """
    # Guard against None as well as blank text; calling .strip() on None
    # would raise before the error handling below is reached.
    if not conversation_text or not conversation_text.strip():
        return "Please provide conversation text.", None
    try:
        # Parse the conversation text back to JSON format
        conversation_json = converter.parse_conversation_text(conversation_text)
        if not conversation_json["conversation"]:
            return "No valid conversation found in the text.", None

        # Korean is routed to Edge-TTS unconditionally, so the other engines
        # never see Korean input and need no Korean-specific check.
        if language == "Korean":
            tts_engine = "Edge-TTS"

        # Generate audio based on TTS engine
        if tts_engine == "Edge-TTS":
            output_file, _ = await converter.text_to_speech_edge(conversation_json, language)
        elif tts_engine == "Spark-TTS":
            if not SPARK_AVAILABLE:
                return "Spark TTS not available. Please install required dependencies and clone the Spark-TTS repository.", None
            converter.initialize_spark_tts()
            output_file, _ = converter.text_to_speech_spark(conversation_json, language)
        else:  # MeloTTS
            if not MELO_AVAILABLE:
                return "MeloTTS not available. Please install required dependencies.", None
            converter.initialize_melo_tts()
            output_file, _ = converter.text_to_speech_melo(conversation_json)
        return "Audio generated successfully!", output_file
    except Exception as e:
        return f"Error generating audio: {str(e)}", None
def synthesize_sync(article_input, input_type: str = "URL", mode: str = "Local", tts_engine: str = "Edge-TTS", language: str = "English"):
    """Blocking entry point: run ``synthesize`` to completion and return its result."""
    pending = synthesize(article_input, input_type, mode, tts_engine, language)
    return asyncio.run(pending)
def regenerate_audio_sync(conversation_text: str, tts_engine: str = "Edge-TTS", language: str = "English"):
    """Blocking entry point: run ``regenerate_audio`` to completion and return its result."""
    pending = regenerate_audio(conversation_text, tts_engine, language)
    return asyncio.run(pending)
def update_tts_engine_for_korean(language):
    """Return the TTS engine radio configured for the selected language.

    For ``"Korean"`` the radio is locked to Edge-TTS; for any other language
    all three engines are offered and the radio is interactive. The selected
    value is reset to Edge-TTS in both cases.
    """
    if language != "Korean":
        radio_options = dict(
            choices=["Edge-TTS", "Spark-TTS", "MeloTTS"],
            value="Edge-TTS",
            label="TTS Engine",
            info="Edge-TTS: Cloud-based, natural voices | Spark-TTS: Local AI model | MeloTTS: Local, requires GPU",
            interactive=True,
        )
    else:
        radio_options = dict(
            choices=["Edge-TTS"],
            value="Edge-TTS",
            label="TTS Engine",
            info="ํ๊ตญ์ด๋ Edge-TTS๋ง ์ง์๋ฉ๋๋ค",
            interactive=False,
        )
    return gr.Radio(**radio_options)
def toggle_input_visibility(input_type):
    """Show only the input widget that matches ``input_type``.

    Returns three visibility updates, in order: URL textbox, PDF upload,
    keyword textbox. Any value other than "URL" or "PDF" selects keyword.
    """
    show_url = input_type == "URL"
    show_pdf = input_type == "PDF"
    show_keyword = not (show_url or show_pdf)
    return (
        gr.update(visible=show_url),
        gr.update(visible=show_pdf),
        gr.update(visible=show_keyword),
    )
# Model initialization (at app startup)
# Fetch the configured local model file into ./models when llama.cpp support
# is importable. A failed download is reported but does not stop startup.
if LLAMA_CPP_AVAILABLE:
    try:
        model_path = hf_hub_download(
            local_dir="./models",
            repo_id=converter.config.local_model_repo,
            filename=converter.config.local_model_name,
        )
        print(f"Model downloaded to: {model_path}")
    except Exception as download_error:
        print(f"Failed to download model at startup: {download_error}")
# Gradio Interface
# NOTE(review): the nesting of the layout containers below was reconstructed
# from context; confirm it against the intended page layout.
with gr.Blocks(theme='soft', title="AI Podcast Generator") as demo:
    gr.Markdown("# ๐๏ธ AI Podcast Generator")
    gr.Markdown("Convert any article, blog, PDF document, or topic into an engaging podcast conversation!")

    # Banner summarising the configured LLMs, conversation length and search.
    with gr.Row():
        gr.Markdown(f"""
### ๐ค Enhanced Configuration:
- **Primary**: Local LLM ({converter.config.local_model_name}) - Runs on your device
- **Fallback**: API LLM ({converter.config.api_model_name}) - Used when local fails
- **Status**: {"โ Llama CPP Available" if LLAMA_CPP_AVAILABLE else "โ Llama CPP Not Available - Install llama-cpp-python"}
- **Conversation Length**: {converter.config.min_conversation_turns}-{converter.config.max_conversation_turns} exchanges (1.5x longer)
- **Search**: {"โ Brave Search Enabled" if BRAVE_KEY else "โ Brave Search Not Available - Set BSEARCH_API"}
- **New**: ๐ฏ Keyword input for topic-based podcast generation
""")

    with gr.Row():
        with gr.Column(scale=3):
            # Input type selector (URL, PDF, or keyword)
            input_type_selector = gr.Radio(
                choices=["URL", "PDF", "Keyword"],
                value="URL",
                label="Input Type",
                info="Choose between URL, PDF file upload, or keyword/topic"
            )
            # URL input (visible by default)
            url_input = gr.Textbox(
                label="Article URL",
                placeholder="Enter the article URL here...",
                value="",
                visible=True
            )
            # PDF upload (hidden until "PDF" is selected)
            pdf_input = gr.File(
                label="Upload PDF",
                file_types=[".pdf"],
                visible=False
            )
            # Keyword input (hidden until "Keyword" is selected)
            keyword_input = gr.Textbox(
                label="Topic/Keyword",
                placeholder="Enter a topic or keyword (e.g., 'AI trends', '์ธ๊ณต์ง๋ฅ ์ต์ ๋ํฅ')",
                value="",
                visible=False,
                info="The system will search for latest information about this topic"
            )
        with gr.Column(scale=1):
            # Output language selection
            language_selector = gr.Radio(
                choices=["English", "Korean"],
                value="English",
                label="Language / ์ธ์ด",
                info="Select output language / ์ถ๋ ฅ ์ธ์ด๋ฅผ ์ ํํ์ธ์"
            )
            mode_selector = gr.Radio(
                choices=["Local", "API"],
                value="Local",
                label="Processing Mode",
                info="Local: Runs on device (Primary) | API: Cloud-based (Fallback)"
            )
            # TTS engine selection
            with gr.Group():
                gr.Markdown("### TTS Engine Selection")
                tts_selector = gr.Radio(
                    choices=["Edge-TTS", "Spark-TTS", "MeloTTS"],
                    value="Edge-TTS",
                    label="TTS Engine",
                    info="Edge-TTS: Cloud-based, natural voices | Spark-TTS: Local AI model | MeloTTS: Local, requires GPU"
                )
            # NOTE(review): this help text may belong inside the Group above.
            gr.Markdown("""
**๐ป Radio Talk Show Style:**
- Natural, conversational dialogue
- Host asks short questions
- Expert gives brief, clear answers
- 12-15 conversation exchanges
**๐ Keyword Feature:**
- Enter any topic to generate a podcast
- Automatically searches latest information
- Creates engaging discussion from search results
**๐ฐ๐ท ํ๊ตญ์ด ์ง์:**
- ์์ฐ์ค๋ฌ์ด ๋ผ๋์ค ๋๋ด ์คํ์ผ
- ์งํ์(์ค์)๊ฐ ์งง์ ์ง๋ฌธ์ผ๋ก ๋ํ ์ ๋
- ์ ๋ฌธ๊ฐ(๋ฏผํธ)๊ฐ ๊ฐ๊ฒฐํ๊ฒ ๋ต๋ณ
- ์ต์ ์ ๋ณด ์๋ ๊ฒ์ ๋ฐ ๋ฐ์
""")

    convert_btn = gr.Button("๐ฏ Generate Conversation / ๋ํ ์์ฑ", variant="primary", size="lg")

    with gr.Row():
        with gr.Column():
            conversation_output = gr.Textbox(
                label="Generated Conversation (Editable) / ์์ฑ๋ ๋ํ (ํธ์ง ๊ฐ๋ฅ)",
                lines=30,  # tall box to fit long conversations
                max_lines=60,
                interactive=True,
                placeholder="Generated conversation will appear here. You can edit it before generating audio.\n์์ฑ๋ ๋ํ๊ฐ ์ฌ๊ธฐ์ ํ์๋ฉ๋๋ค. ์ค๋์ค ์์ฑ ์ ์ ํธ์งํ ์ ์์ต๋๋ค.\n\n๋ผ๋์ค ๋๋ด ์คํ์ผ๋ก ์์ฐ์ค๋ฝ๊ฒ ์งํ๋ฉ๋๋ค.",
                info="Edit the conversation as needed. Format: 'Speaker Name: Text' / ํ์์ ๋ฐ๋ผ ๋ํ๋ฅผ ํธ์งํ์ธ์. ํ์: 'ํ์ ์ด๋ฆ: ํ ์คํธ'"
            )
            # Button that turns the (edited) transcript into audio
            # NOTE(review): the Markdown hint may sit outside this Row.
            with gr.Row():
                generate_audio_btn = gr.Button("๐๏ธ Generate Audio from Text / ํ ์คํธ์์ ์ค๋์ค ์์ฑ", variant="secondary", size="lg")
                gr.Markdown("*Edit the conversation above, then click to generate audio / ์์ ๋ํ๋ฅผ ํธ์งํ ํ ํด๋ฆญํ์ฌ ์ค๋์ค๋ฅผ ์์ฑํ์ธ์*")
        with gr.Column():
            audio_output = gr.Audio(
                label="Podcast Audio / ํ์บ์คํธ ์ค๋์ค",
                type="filepath",
                interactive=False
            )
            # Status message box
            status_output = gr.Textbox(
                label="Status / ์ํ",
                interactive=False,
                visible=True
            )

    # NOTE(review): the "Keyword" example rows put an empty string into
    # url_input (the first entry of `inputs`), not into keyword_input.
    gr.Examples(
        examples=[
            ["https://huggingface.co/blog/openfree/cycle-navigator", "URL", "Local", "Edge-TTS", "English"],
            ["", "Keyword", "Local", "Edge-TTS", "English"],  # Keyword example
            ["https://huggingface.co/papers/2505.14810", "URL", "Local", "Edge-TTS", "Korean"],
            ["", "Keyword", "Local", "Edge-TTS", "Korean"],  # Korean keyword example
        ],
        inputs=[url_input, input_type_selector, mode_selector, tts_selector, language_selector],
        outputs=[conversation_output, status_output],
        fn=synthesize_sync,
        cache_examples=False,
    )

    # Show only the input widget matching the selected input type.
    input_type_selector.change(
        fn=toggle_input_visibility,
        inputs=[input_type_selector],
        outputs=[url_input, pdf_input, keyword_input]
    )

    # Refresh the TTS engine options whenever the language changes.
    language_selector.change(
        fn=update_tts_engine_for_korean,
        inputs=[language_selector],
        outputs=[tts_selector]
    )

    # Event wiring
    def get_article_input(input_type, url_input, pdf_input, keyword_input):
        """Pick the widget value that corresponds to ``input_type``."""
        if input_type == "URL":
            return url_input
        if input_type == "PDF":
            return pdf_input
        return keyword_input  # Keyword

    # The callback stays a lambda so the function name Gradio sees is unchanged.
    convert_btn.click(
        fn=lambda input_type, url_input, pdf_input, keyword_input, mode, tts, lang: synthesize_sync(
            get_article_input(input_type, url_input, pdf_input, keyword_input), input_type, mode, tts, lang
        ),
        inputs=[input_type_selector, url_input, pdf_input, keyword_input, mode_selector, tts_selector, language_selector],
        outputs=[conversation_output, status_output]
    )

    generate_audio_btn.click(
        fn=regenerate_audio_sync,
        inputs=[conversation_output, tts_selector, language_selector],
        outputs=[status_output, audio_output]
    )
# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    queued_app = demo.queue(api_open=True, default_concurrency_limit=10)
    queued_app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_api=True,
    )