Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -22,7 +22,7 @@ print("Loading SNAC model…")
|
|
22 |
snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
|
23 |
|
24 |
model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
|
25 |
-
|
26 |
snapshot_download(
|
27 |
repo_id=model_name,
|
28 |
allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json"],
|
@@ -34,71 +34,59 @@ snapshot_download(
|
|
34 |
)
|
35 |
|
36 |
print("Loading Orpheus model…")
|
37 |
-
model = AutoModelForCausalLM.from_pretrained(
|
38 |
-
model_name, torch_dtype=torch.bfloat16
|
39 |
-
).to(device)
|
40 |
model.config.pad_token_id = model.config.eos_token_id
|
41 |
-
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
42 |
-
|
43 |
-
# — Konstanten für Audio‑Token →
|
44 |
-
# (muss übereinstimmen mit Deinem Training; hier 128266)
|
45 |
-
AUDIO_TOKEN_OFFSET = 128266
|
46 |
|
47 |
-
|
48 |
|
|
|
49 |
def process_prompt(text: str, voice: str):
|
|
|
50 |
prompt = f"{voice}: {text}"
|
51 |
-
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
ids = torch.cat([start, input_ids, end], dim=1).to(device)
|
57 |
-
mask = torch.ones_like(ids).to(device)
|
58 |
return ids, mask
|
59 |
|
60 |
-
def parse_output(generated_ids: torch.LongTensor):
|
61 |
-
"""
|
62 |
-
|
63 |
-
|
64 |
-
"""
|
65 |
-
# finde letztes Audio‑Start‑Token
|
66 |
-
token_to_start = 128257
|
67 |
-
token_to_remove = model.config.eos_token_id # 128258
|
68 |
|
69 |
-
idxs = (generated_ids ==
|
70 |
if idxs.numel() > 0:
|
71 |
-
|
72 |
-
cropped = generated_ids[:, cut:]
|
73 |
else:
|
74 |
cropped = generated_ids
|
75 |
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
"""
|
85 |
-
layer1, layer2, layer3 = [], [], []
|
86 |
-
for i in range(len(code_list) // 7):
|
87 |
base = code_list[7*i : 7*i+7]
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
dev = next(
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
|
|
|
|
|
|
102 |
|
103 |
# — FastAPI + WebSocket-Endpoint —
|
104 |
app = FastAPI()
|
@@ -108,41 +96,47 @@ async def tts_ws(ws: WebSocket):
|
|
108 |
await ws.accept()
|
109 |
try:
|
110 |
while True:
|
|
|
111 |
msg = await ws.receive_text()
|
112 |
data = json.loads(msg)
|
113 |
text = data.get("text", "")
|
114 |
voice = data.get("voice", "Jakob")
|
115 |
|
116 |
-
#
|
117 |
ids, mask = process_prompt(text, voice)
|
118 |
|
119 |
-
#
|
120 |
gen_ids = model.generate(
|
121 |
input_ids=ids,
|
122 |
attention_mask=mask,
|
123 |
-
max_new_tokens=
|
124 |
do_sample=True,
|
125 |
temperature=0.7,
|
126 |
top_p=0.95,
|
127 |
repetition_penalty=1.1,
|
128 |
-
eos_token_id=
|
129 |
)
|
130 |
|
131 |
-
#
|
132 |
-
|
133 |
-
|
|
|
134 |
|
135 |
-
#
|
136 |
-
|
137 |
-
|
138 |
-
for i in range(0, len(pcm16), chunk):
|
139 |
-
await ws.send_bytes(pcm16[i : i+chunk])
|
140 |
await asyncio.sleep(0.1)
|
141 |
|
|
|
|
|
|
|
|
|
|
|
142 |
except WebSocketDisconnect:
|
143 |
print("Client disconnected")
|
144 |
except Exception as e:
|
145 |
print("Error in /ws/tts:", e)
|
|
|
146 |
await ws.close(code=1011)
|
147 |
|
148 |
if __name__ == "__main__":
|
|
|
22 |
snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
|
23 |
|
24 |
model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
|
25 |
+
print("Downloading Orpheus weights (konfig + safetensors)…")
|
26 |
snapshot_download(
|
27 |
repo_id=model_name,
|
28 |
allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json"],
|
|
|
34 |
)
|
35 |
|
36 |
print("Loading Orpheus model…")
|
37 |
+
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
|
|
|
|
|
38 |
model.config.pad_token_id = model.config.eos_token_id
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
41 |
|
42 |
+
# — Hilfsfunktionen —
|
43 |
def process_prompt(text: str, voice: str):
    """Build the model input tensors for a single TTS request.

    The prompt text is ``"<voice>: <text>"``. Its token IDs are wrapped
    with hard-coded marker tokens: 128259 in front, then 128009 and 128260
    behind (presumably the prompt delimiters the model was trained with;
    verify against the model card).

    Args:
        text: The text to be spoken.
        voice: Speaker name placed in front of the text.

    Returns:
        A tuple ``(ids, mask)``: the int64 token IDs of shape
        ``(1, sequence_length)`` on ``device`` and an all-ones attention
        mask of the same shape.
    """
    encoded = tokenizer(f"{voice}: {text}", return_tensors="pt")
    text_ids = encoded.input_ids.to(device)

    # Both marker tensors share dtype and device with the tokenised text.
    tensor_opts = {"dtype": torch.int64, "device": device}
    prefix = torch.tensor([[128259]], **tensor_opts)
    suffix = torch.tensor([[128009, 128260]], **tensor_opts)

    ids = torch.cat((prefix, text_ids, suffix), dim=1)
    # Nothing is padded, so every position is attended to.
    return ids, torch.ones_like(ids)
|
52 |
|
53 |
+
def parse_output(
    generated_ids: torch.LongTensor,
    audio_token_offset: int = 128266,
) -> list[int]:
    """Extract decoder-ready audio codes from a generated token sequence.

    Only the first row of ``generated_ids`` is used. Everything up to and
    including the last occurrence of token 128257 is discarded (when that
    token is absent the whole row is kept), every 128258 token is removed,
    the remainder is cut down to a whole number of 7-token frames, and
    ``audio_token_offset`` is subtracted from each remaining token.

    Args:
        generated_ids: 2-D tensor of token IDs, one row per sequence.
        audio_token_offset: Vocabulary ID of the first audio token; it is
            subtracted so the codes start at 0. Pass 0 to get the raw
            token IDs back. The default of 128266 has to match the value
            used when the model was trained.

    Returns:
        A flat list of integer codes whose length is a multiple of 7.
    """
    token_to_find = 128257
    token_to_remove = 128258

    # Column indices of every start marker; the last one wins.
    idxs = (generated_ids == token_to_find).nonzero(as_tuple=True)[1]
    if idxs.numel() > 0:
        cropped = generated_ids[:, idxs[-1].item() + 1:]
    else:
        cropped = generated_ids

    row = cropped[0]
    row = row[row != token_to_remove]

    # The decoder consumes codes in frames of 7; an incomplete trailing
    # frame cannot be decoded, so it is dropped here.
    usable = (row.numel() // 7) * 7
    return (row[:usable] - audio_token_offset).tolist()
|
67 |
+
|
68 |
+
def redistribute_codes(code_list: list[int]) -> bytes:
    """Decode a flat list of interleaved SNAC codes into PCM16 bytes.

    The list is read in frames of 7 codes. Within a frame, position 0
    goes to layer 1, positions 1 and 4 to layer 2, and positions 2, 3, 5
    and 6 to layer 3; position ``k`` has ``k * 4096`` subtracted first.
    The three layers are then decoded by the module-level ``snac`` model.

    Args:
        code_list: Audio codes, 7 per frame. Codes after the last complete
            frame are ignored.

    Returns:
        The decoded audio as raw signed 16-bit samples (24 kHz according
        to the SNAC checkpoint name), or ``b""`` when ``code_list`` does
        not contain a single complete frame.
    """
    # Count only complete frames: rounding up would index past the end of
    # a 6-element trailing slice.
    frame_count = len(code_list) // 7
    if frame_count == 0:
        return b""

    l1, l2, l3 = [], [], []
    for i in range(frame_count):
        base = code_list[7 * i: 7 * i + 7]
        l1.append(base[0])
        l2.append(base[1] - 4096)
        l3.append(base[2] - 2 * 4096)
        l3.append(base[3] - 3 * 4096)
        l2.append(base[4] - 4 * 4096)
        l3.append(base[5] - 5 * 4096)
        l3.append(base[6] - 6 * 4096)

    dev = next(snac.parameters()).device
    codes = [
        torch.tensor(l1, device=dev).unsqueeze(0),
        torch.tensor(l2, device=dev).unsqueeze(0),
        torch.tensor(l3, device=dev).unsqueeze(0),
    ]

    # Decode without autograd tracking: a tensor that requires grad cannot
    # be converted with .numpy().
    with torch.inference_mode():
        audio = snac.decode(codes)

    # Clamp to [-1, 1] so the int16 cast cannot wrap around on overshoot.
    samples = audio.squeeze().clamp(-1.0, 1.0).cpu().numpy()
    return (samples * 32767).astype("int16").tobytes()
|
90 |
|
91 |
# — FastAPI + WebSocket-Endpoint —
|
92 |
app = FastAPI()
|
|
|
96 |
await ws.accept()
|
97 |
try:
|
98 |
while True:
|
99 |
+
# 1) Nachricht empfangen
|
100 |
msg = await ws.receive_text()
|
101 |
data = json.loads(msg)
|
102 |
text = data.get("text", "")
|
103 |
voice = data.get("voice", "Jakob")
|
104 |
|
105 |
+
# 2) Prompt → IDs/Mask
|
106 |
ids, mask = process_prompt(text, voice)
|
107 |
|
108 |
+
# 3) Token-Generation
|
109 |
gen_ids = model.generate(
|
110 |
input_ids=ids,
|
111 |
attention_mask=mask,
|
112 |
+
max_new_tokens=2000,
|
113 |
do_sample=True,
|
114 |
temperature=0.7,
|
115 |
top_p=0.95,
|
116 |
repetition_penalty=1.1,
|
117 |
+
eos_token_id=128258,
|
118 |
)
|
119 |
|
120 |
+
# 4) Parse + SNAC → PCM16‑Bytes
|
121 |
+
codes = parse_output(gen_ids)
|
122 |
+
pcm16 = redistribute_codes(codes)
|
123 |
+
chunk_sz = 2400 * 2 # 0.1 s @24 kHz
|
124 |
|
125 |
+
# 5) Stream audio‑Chunks
|
126 |
+
for i in range(0, len(pcm16), chunk_sz):
|
127 |
+
await ws.send_bytes(pcm16[i : i + chunk_sz])
|
|
|
|
|
128 |
await asyncio.sleep(0.1)
|
129 |
|
130 |
+
# 6) Ende‑Signal
|
131 |
+
await ws.send_json({"event": "eos"})
|
132 |
+
|
133 |
+
# (Verbindung bleibt offen für nächste Anfrage)
|
134 |
+
|
135 |
except WebSocketDisconnect:
|
136 |
print("Client disconnected")
|
137 |
except Exception as e:
|
138 |
print("Error in /ws/tts:", e)
|
139 |
+
# Schließe erst, nachdem Fehler gemeldet
|
140 |
await ws.close(code=1011)
|
141 |
|
142 |
if __name__ == "__main__":
|