app.py
CHANGED
@@ -14,15 +14,14 @@ HF_TOKEN = os.getenv("HF_TOKEN")
 if HF_TOKEN:
     login(token=HF_TOKEN)
 
-# — …
+# — Select device —
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 # — Load models —
-print("Loading SNAC model…")
+print("Loading SNAC model...")
 snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
 
 model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
-print("Downloading Orpheus weights (config + safetensors)…")
 snapshot_download(
     repo_id=model_name,
     allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json"],
@@ -33,112 +32,120 @@ snapshot_download(
 ]
 )
 
-print("Loading Orpheus model…")
-model = AutoModelForCausalLM.from_pretrained(…)
+print("Loading Orpheus model...")
+model = AutoModelForCausalLM.from_pretrained(
+    model_name, torch_dtype=torch.bfloat16
+).to(device)
 model.config.pad_token_id = model.config.eos_token_id
-
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 
-# — …
+# — Constants for token mapping —
+AUDIO_TOKEN_OFFSET = 128266
+START_TOKEN = 128259
+SOS_TOKEN = 128257
+EOS_TOKEN = 128258
+
+# — Helper functions —
 def process_prompt(text: str, voice: str):
-    """Build input_ids and an attention_mask for a prompt."""
     prompt = f"{voice}: {text}"
     input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
-    start = torch.tensor([[128259]], dtype=torch.int64, device=device)
+    start = torch.tensor([[START_TOKEN]], dtype=torch.int64, device=device)
     end = torch.tensor([[128009, 128260]], dtype=torch.int64, device=device)
-    ids = …
-    mask = …
+    ids = torch.cat([start, input_ids, end], dim=1)
+    mask = torch.ones_like(ids, dtype=torch.int64, device=device)
     return ids, mask
 
-def …(generated_ids):
-    token_to_find = 128257
-    token_to_remove = 128258
-
-    idxs = (generated_ids == token_to_find).nonzero(as_tuple=True)[1]
-    if idxs.numel() > 0:
-        cropped = generated_ids[:, idxs[-1].item() + 1 :]
-    else:
-        cropped = generated_ids
-
-    row = cropped[0]
-    row = row[row != token_to_remove]
-    return row.tolist()
-
-def redistribute_codes(code_list: list[int]) -> bytes:
-    """Distribute the codes across the three SNAC layers and decode to PCM16 bytes."""
+def redistribute_codes(block: list[int], snac_model: SNAC):
+    # exactly as before: 7 codes → 3 layers → SNAC.decode → NumPy float32 @ 24 kHz
     l1, l2, l3 = [], [], []
-    for i in range(…):
-        …
-        l1.append(…)
-        l2.append(…)
-        l3.append(…)
-        l3.append(…)
-        l2.append(…)
-        l3.append(…)
-        l3.append(…)
-
-    dev = next(snac.parameters()).device
+    for i in range(len(block)//7):
+        b = block[7*i:7*i+7]
+        l1.append(b[0])
+        l2.append(b[1] - 4096)
+        l3.append(b[2] - 2*4096)
+        l3.append(b[3] - 3*4096)
+        l2.append(b[4] - 4*4096)
+        l3.append(b[5] - 5*4096)
+        l3.append(b[6] - 6*4096)
+    dev = next(snac_model.parameters()).device
     codes = [
         torch.tensor(l1, device=dev).unsqueeze(0),
         torch.tensor(l2, device=dev).unsqueeze(0),
        torch.tensor(l3, device=dev).unsqueeze(0),
     ]
-    audio = …
-
-    return pcm16
+    audio = snac_model.decode(codes)  # → Tensor[1, T]
+    return audio.squeeze().cpu().numpy()
 
-# — FastAPI …
+# — FastAPI setup —
 app = FastAPI()
 
+# 1) Hello-world endpoint
+@app.get("/")
+async def root():
+    return {"message": "Hallo Welt"}
+
+# 2) WebSocket token-by-token TTS
 @app.websocket("/ws/tts")
 async def tts_ws(ws: WebSocket):
     await ws.accept()
     try:
         while True:
-            # …
-            …
-            text = …
-            voice = data.get("voice", "Jakob")
-
-            # 2) Prompt → IDs/mask
+            # receive JSON with text & voice
+            raw = await ws.receive_text()
+            req = json.loads(raw)
+            text, voice = req.get("text", ""), req.get("voice", "Jakob")
             ids, mask = process_prompt(text, voice)
 
-            …
+            past_kv = None
+            collected = []
+
+            # generate token by token in the sampling loop
+            with torch.no_grad():
+                for _ in range(2000):  # at most 2000 tokens
+                    out = model(
+                        input_ids=ids,
+                        attention_mask=mask,
+                        past_key_values=past_kv,
+                        use_cache=True,
+                    )
+                    logits = out.logits[:, -1, :]
+                    next_id = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)
+                    past_kv = out.past_key_values
+
+                    token = next_id.item()
+                    # end of generation
+                    if token == EOS_TOKEN:
+                        break
+                    # reset on start-of-speech
+                    if token == SOS_TOKEN:
+                        collected = []
+                        continue
+
+                    # convert to an audio code
+                    collected.append(token - AUDIO_TOKEN_OFFSET)
+
+                    # once 7 codes are collected: decode & stream immediately
+                    if len(collected) >= 7:
+                        block = collected[:7]
+                        collected = collected[7:]
+                        audio_np = redistribute_codes(block, snac)
+                        pcm16 = (audio_np * 32767).astype("int16").tobytes()
+                        await ws.send_bytes(pcm16)
+
+                    # from here on, feed only the freshly sampled token plus past_kv
+                    ids = next_id
+                    mask = None
+
+            # finally, signal end of stream
+            await ws.send_text(json.dumps({"event": "eos"}))
 
     except WebSocketDisconnect:
         print("Client disconnected")
     except Exception as e:
         print("Error in /ws/tts:", e)
-        # Close only after the error has been reported
         await ws.close(code=1011)
 
+# for local testing
 if __name__ == "__main__":
     import uvicorn
     uvicorn.run("app:app", host="0.0.0.0", port=7860)
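
The stream produced by the handler above can be exercised with a small client. The following is a minimal sketch, assuming the Space runs locally on port 7860 and the third-party websockets package is installed (pip install websockets; any WebSocket client would do). The JSON request shape, the "eos" event, and the mono PCM16 @ 24 kHz format are taken from the handler; the output filename and example text are arbitrary.

import asyncio
import json
import wave

import websockets


async def main():
    async with websockets.connect("ws://localhost:7860/ws/tts") as ws:
        # same JSON shape the server expects: text + voice
        await ws.send(json.dumps({"text": "Hallo, wie geht es dir?", "voice": "Jakob"}))

        pcm = bytearray()
        while True:
            msg = await ws.recv()
            if isinstance(msg, bytes):
                # raw PCM16 chunk, streamed per decoded 7-code block
                pcm.extend(msg)
            elif json.loads(msg).get("event") == "eos":
                break

        # wrap the collected samples in a WAV container for playback
        with wave.open("out.wav", "wb") as f:
            f.setnchannels(1)       # mono
            f.setsampwidth(2)       # 16-bit
            f.setframerate(24000)   # SNAC 24 kHz
            f.writeframes(bytes(pcm))


asyncio.run(main())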