Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -7,12 +7,12 @@ from huggingface_hub import login
|
|
7 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
8 |
from snac import SNAC
|
9 |
|
10 |
-
# — HF‑Token & Login (
|
11 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
12 |
if HF_TOKEN:
|
13 |
login(HF_TOKEN)
|
14 |
|
15 |
-
# — Device
|
16 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
17 |
|
18 |
app = FastAPI()
|
@@ -26,39 +26,38 @@ model = None
|
|
26 |
tokenizer = None
|
27 |
snac_model = None
|
28 |
|
29 |
-
# — Startup: SNAC & Orpheus laden —
|
30 |
@app.on_event("startup")
|
31 |
async def load_models():
|
32 |
global model, tokenizer, snac_model
|
33 |
-
|
|
|
34 |
snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
|
35 |
-
|
36 |
-
|
|
|
37 |
tokenizer = AutoTokenizer.from_pretrained(REPO)
|
38 |
model = AutoModelForCausalLM.from_pretrained(
|
39 |
REPO,
|
40 |
-
device_map="auto" if device=="cuda" else None,
|
41 |
-
torch_dtype=torch.bfloat16 if device=="cuda" else None,
|
42 |
low_cpu_mem_usage=True
|
43 |
).to(device)
|
44 |
model.config.pad_token_id = model.config.eos_token_id
|
45 |
|
46 |
-
# — Marker und Offsets
|
47 |
START_TOKEN = 128259
|
48 |
END_TOKENS = [128009, 128260]
|
49 |
AUDIO_OFFSET = 128266
|
50 |
|
51 |
def process_single_prompt(prompt: str, voice: str) -> list[int]:
|
52 |
-
# Prompt
|
53 |
-
if voice and voice != "in_prompt"
|
54 |
-
|
55 |
-
else:
|
56 |
-
text = prompt
|
57 |
# Tokenize + Marker
|
58 |
-
ids = tokenizer(text, return_tensors="pt").input_ids
|
59 |
-
start = torch.tensor([[START_TOKEN]], dtype=torch.int64)
|
60 |
-
end = torch.tensor([END_TOKENS],
|
61 |
-
input_ids = torch.cat([start, ids, end], dim=1)
|
62 |
attention_mask = torch.ones_like(input_ids)
|
63 |
|
64 |
# Generieren
|
@@ -74,8 +73,8 @@ def process_single_prompt(prompt: str, voice: str) -> list[int]:
|
|
74 |
use_cache=True,
|
75 |
)
|
76 |
|
77 |
-
#
|
78 |
-
token_to_find
|
79 |
token_to_remove = 128258
|
80 |
idxs = (gen == token_to_find).nonzero(as_tuple=True)[1]
|
81 |
if idxs.numel() > 0:
|
@@ -83,16 +82,16 @@ def process_single_prompt(prompt: str, voice: str) -> list[int]:
|
|
83 |
else:
|
84 |
cropped = gen
|
85 |
|
86 |
-
# Padding entfernen
|
87 |
row = cropped[0][cropped[0] != token_to_remove]
|
88 |
-
# Aus Länge ein Vielfaches von 7 machen
|
89 |
new_len = (row.size(0) // 7) * 7
|
90 |
trimmed = row[:new_len].tolist()
|
|
|
91 |
# Offset abziehen
|
92 |
return [t - AUDIO_OFFSET for t in trimmed]
|
93 |
|
94 |
def redistribute_codes(code_list: list[int]) -> np.ndarray:
|
95 |
-
#
|
96 |
layer1, layer2, layer3 = [], [], []
|
97 |
for i in range(len(code_list) // 7):
|
98 |
b = code_list[7*i : 7*i+7]
|
@@ -112,27 +111,25 @@ def redistribute_codes(code_list: list[int]) -> np.ndarray:
|
|
112 |
audio = snac_model.decode(codes).squeeze().cpu().numpy()
|
113 |
return audio # float32 @24 kHz
|
114 |
|
115 |
-
# — WebSocket‑Endpoint für TTS —
|
116 |
@app.websocket("/ws/tts")
|
117 |
async def tts_ws(ws: WebSocket):
|
118 |
await ws.accept()
|
119 |
try:
|
120 |
-
# 1) Text + Voice empfangen
|
121 |
msg = await ws.receive_text()
|
122 |
req = json.loads(msg)
|
123 |
text = req.get("text", "")
|
124 |
voice = req.get("voice", "")
|
125 |
|
126 |
-
#
|
127 |
with torch.no_grad():
|
128 |
codes = process_single_prompt(text, voice)
|
129 |
audio_np = redistribute_codes(codes)
|
130 |
|
131 |
-
#
|
132 |
pcm16 = (audio_np * 32767).astype(np.int16).tobytes()
|
133 |
await ws.send_bytes(pcm16)
|
134 |
|
135 |
-
#
|
136 |
await ws.close()
|
137 |
|
138 |
except WebSocketDisconnect:
|
|
|
7 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
8 |
from snac import SNAC
|
9 |
|
10 |
+
# — HF‑Token & Login (falls gesetzt) —
|
11 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
12 |
if HF_TOKEN:
|
13 |
login(HF_TOKEN)
|
14 |
|
15 |
+
# — Device auswählen —
|
16 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
17 |
|
18 |
app = FastAPI()
|
|
|
26 |
tokenizer = None
|
27 |
snac_model = None
|
28 |
|
|
|
29 |
@app.on_event("startup")
|
30 |
async def load_models():
|
31 |
global model, tokenizer, snac_model
|
32 |
+
|
33 |
+
# 1) SNAC laden
|
34 |
snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
|
35 |
+
|
36 |
+
# 2) Orpheus‑TTS (public “natural”-Variante)
|
37 |
+
REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
|
38 |
tokenizer = AutoTokenizer.from_pretrained(REPO)
|
39 |
model = AutoModelForCausalLM.from_pretrained(
|
40 |
REPO,
|
41 |
+
device_map="auto" if device == "cuda" else None,
|
42 |
+
torch_dtype=torch.bfloat16 if device == "cuda" else None,
|
43 |
low_cpu_mem_usage=True
|
44 |
).to(device)
|
45 |
model.config.pad_token_id = model.config.eos_token_id
|
46 |
|
47 |
+
# — Marker und Offsets —
|
48 |
START_TOKEN = 128259
|
49 |
END_TOKENS = [128009, 128260]
|
50 |
AUDIO_OFFSET = 128266
|
51 |
|
52 |
def process_single_prompt(prompt: str, voice: str) -> list[int]:
|
53 |
+
# Prompt zusammenstellen
|
54 |
+
text = f"{voice}: {prompt}" if voice and voice != "in_prompt" else prompt
|
55 |
+
|
|
|
|
|
56 |
# Tokenize + Marker
|
57 |
+
ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
|
58 |
+
start = torch.tensor([[START_TOKEN]], dtype=torch.int64, device=device)
|
59 |
+
end = torch.tensor([END_TOKENS], dtype=torch.int64, device=device)
|
60 |
+
input_ids = torch.cat([start, ids, end], dim=1)
|
61 |
attention_mask = torch.ones_like(input_ids)
|
62 |
|
63 |
# Generieren
|
|
|
73 |
use_cache=True,
|
74 |
)
|
75 |
|
76 |
+
# Nach letztem START_TOKEN croppen
|
77 |
+
token_to_find = 128257
|
78 |
token_to_remove = 128258
|
79 |
idxs = (gen == token_to_find).nonzero(as_tuple=True)[1]
|
80 |
if idxs.numel() > 0:
|
|
|
82 |
else:
|
83 |
cropped = gen
|
84 |
|
85 |
+
# Padding entfernen & Länge auf Vielfaches von 7 bringen
|
86 |
row = cropped[0][cropped[0] != token_to_remove]
|
|
|
87 |
new_len = (row.size(0) // 7) * 7
|
88 |
trimmed = row[:new_len].tolist()
|
89 |
+
|
90 |
# Offset abziehen
|
91 |
return [t - AUDIO_OFFSET for t in trimmed]
|
92 |
|
93 |
def redistribute_codes(code_list: list[int]) -> np.ndarray:
|
94 |
+
# 7er‑Blöcke auf 3 Layer verteilen
|
95 |
layer1, layer2, layer3 = [], [], []
|
96 |
for i in range(len(code_list) // 7):
|
97 |
b = code_list[7*i : 7*i+7]
|
|
|
111 |
audio = snac_model.decode(codes).squeeze().cpu().numpy()
|
112 |
return audio # float32 @24 kHz
|
113 |
|
|
|
114 |
@app.websocket("/ws/tts")
|
115 |
async def tts_ws(ws: WebSocket):
|
116 |
await ws.accept()
|
117 |
try:
|
|
|
118 |
msg = await ws.receive_text()
|
119 |
req = json.loads(msg)
|
120 |
text = req.get("text", "")
|
121 |
voice = req.get("voice", "")
|
122 |
|
123 |
+
# 1) Prompt → Codes → Audio
|
124 |
with torch.no_grad():
|
125 |
codes = process_single_prompt(text, voice)
|
126 |
audio_np = redistribute_codes(codes)
|
127 |
|
128 |
+
# 2) In PCM16 wandeln & senden
|
129 |
pcm16 = (audio_np * 32767).astype(np.int16).tobytes()
|
130 |
await ws.send_bytes(pcm16)
|
131 |
|
132 |
+
# 3) sauber schließen
|
133 |
await ws.close()
|
134 |
|
135 |
except WebSocketDisconnect:
|