Tomtom84 committed on
Commit
d408dd5
·
1 Parent(s): b3e4aa7
Files changed (1) hide show
  1. app.py +84 -77
app.py CHANGED
@@ -14,15 +14,14 @@ HF_TOKEN = os.getenv("HF_TOKEN")
14
  if HF_TOKEN:
15
  login(token=HF_TOKEN)
16
 
17
- # — Device
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
 
20
  # — Modelle laden —
21
- print("Loading SNAC model")
22
  snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
23
 
24
  model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
25
- print("Downloading Orpheus weights (konfig + safetensors)…")
26
  snapshot_download(
27
  repo_id=model_name,
28
  allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json"],
@@ -33,112 +32,120 @@ snapshot_download(
33
  ]
34
  )
35
 
36
- print("Loading Orpheus model")
37
- model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
 
 
38
  model.config.pad_token_id = model.config.eos_token_id
39
-
40
  tokenizer = AutoTokenizer.from_pretrained(model_name)
41
 
42
- # — Hilfsfunktionen
 
 
 
 
 
 
43
def process_prompt(text: str, voice: str):
    """Build (input_ids, attention_mask) for one prompt.

    The text is tagged with the speaker voice, tokenized, and framed by the
    model's special tokens: 128259 in front, [128009, 128260] behind.
    """
    tagged = f"{voice}: {text}"
    text_ids = tokenizer(tagged, return_tensors="pt").input_ids.to(device)
    prefix = torch.tensor([[128259]], dtype=torch.int64, device=device)
    suffix = torch.tensor([[128009, 128260]], dtype=torch.int64, device=device)
    ids = torch.cat([prefix, text_ids, suffix], dim=1)
    # Every position is real input, so the mask is all ones.
    mask = torch.ones_like(ids)
    return ids, mask
52
 
53
def parse_output(generated_ids: torch.LongTensor) -> list[int]:
    """Return the raw token list found after the last 128257 start marker.

    If no start marker is present the whole first row is used; every
    128258 (end) token is stripped from the result.
    """
    START_MARKER = 128257
    END_MARKER = 128258

    hits = (generated_ids == START_MARKER).nonzero(as_tuple=True)[1]
    # Keep only what follows the final start marker, if one exists.
    tail = generated_ids[:, hits[-1].item() + 1:] if hits.numel() else generated_ids

    first_row = tail[0]
    return first_row[first_row != END_MARKER].tolist()
67
-
68
def redistribute_codes(code_list: list[int]) -> bytes:
    """Distribute flat codes onto the three SNAC layers and decode to PCM16 bytes.

    Each group of 7 consecutive codes forms one frame: code 0 goes to layer 1,
    codes 1 and 4 to layer 2, codes 2, 3, 5, 6 to layer 3, each with its
    positional offset of k*4096 removed. Returns little-endian int16 PCM
    (SNAC output is float32 @ 24 kHz).
    """
    l1, l2, l3 = [], [], []
    # BUGFIX: iterate only over COMPLETE 7-code frames. The previous
    # (len + 1) // 7 count read base[6] past the end whenever
    # len(code_list) % 7 == 6, raising IndexError; a trailing partial
    # frame is now simply dropped.
    for i in range(len(code_list) // 7):
        base = code_list[7 * i : 7 * i + 7]
        l1.append(base[0])
        l2.append(base[1] - 4096)
        l3.append(base[2] - 2 * 4096)
        l3.append(base[3] - 3 * 4096)
        l2.append(base[4] - 4 * 4096)
        l3.append(base[5] - 5 * 4096)
        l3.append(base[6] - 6 * 4096)

    # Decode on whatever device the SNAC model lives on.
    dev = next(snac.parameters()).device
    codes = [
        torch.tensor(l1, device=dev).unsqueeze(0),
        torch.tensor(l2, device=dev).unsqueeze(0),
        torch.tensor(l3, device=dev).unsqueeze(0),
    ]
    audio = snac.decode(codes).squeeze().cpu().numpy()  # float32 @ 24 kHz
    pcm16 = (audio * 32767).astype("int16").tobytes()
    return pcm16
90
 
91
- # — FastAPI + WebSocket-Endpoint
92
  app = FastAPI()
93
 
 
 
 
 
 
 
94
@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    """Serve TTS requests over one WebSocket until the client disconnects.

    Each request is a JSON text frame {"text": ..., "voice": ...}; the reply
    is a sequence of PCM16 binary frames followed by a JSON "eos" event.
    """
    await ws.accept()
    try:
        while True:
            # 1) receive and decode the request
            request = json.loads(await ws.receive_text())
            text = request.get("text", "")
            voice = request.get("voice", "Jakob")

            # 2) prompt -> ids/mask
            ids, mask = process_prompt(text, voice)

            # 3) token generation (whole utterance at once)
            gen_ids = model.generate(
                input_ids=ids,
                attention_mask=mask,
                max_new_tokens=2000,
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
                repetition_penalty=1.1,
                eos_token_id=128258,
            )

            # 4) parse + SNAC -> PCM16 bytes
            pcm16 = redistribute_codes(parse_output(gen_ids))

            # 5) stream the audio in 0.1 s chunks (2400 samples * 2 bytes @ 24 kHz)
            chunk_sz = 2400 * 2
            for offset in range(0, len(pcm16), chunk_sz):
                await ws.send_bytes(pcm16[offset : offset + chunk_sz])
                await asyncio.sleep(0.1)

            # 6) end-of-stream signal; connection stays open for the next request
            await ws.send_json({"event": "eos"})

    except WebSocketDisconnect:
        print("Client disconnected")
    except Exception as e:
        print("Error in /ws/tts:", e)
        # close only after the error has been reported
        await ws.close(code=1011)
141
 
 
142
# Local entry point: serve the FastAPI app on all interfaces, port 7860.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860)
 
14
  if HF_TOKEN:
15
  login(token=HF_TOKEN)
16
 
17
+ # — Gerät wählen
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
 
20
  # — Modelle laden —
21
+ print("Loading SNAC model...")
22
  snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
23
 
24
  model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
 
25
  snapshot_download(
26
  repo_id=model_name,
27
  allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json"],
 
32
  ]
33
  )
34
 
35
+ print("Loading Orpheus model...")
36
+ model = AutoModelForCausalLM.from_pretrained(
37
+ model_name, torch_dtype=torch.bfloat16
38
+ ).to(device)
39
  model.config.pad_token_id = model.config.eos_token_id
 
40
  tokenizer = AutoTokenizer.from_pretrained(model_name)
41
 
42
# — Constants for token mapping —
AUDIO_TOKEN_OFFSET = 128266  # subtracted from a generated id to obtain the raw SNAC code
START_TOKEN = 128259  # prepended to every prompt in process_prompt
SOS_TOKEN = 128257  # start-of-audio marker: the sampling loop resets collected codes on it
EOS_TOKEN = 128258  # end marker: the sampling loop stops on it

# — Helper functions —
49
def process_prompt(text: str, voice: str):
    """Tokenize a voice-tagged prompt and frame it with the model's
    special tokens; return (input_ids, attention_mask) on `device`."""
    framed = f"{voice}: {text}"
    text_ids = tokenizer(framed, return_tensors="pt").input_ids.to(device)
    head = torch.tensor([[START_TOKEN]], dtype=torch.int64, device=device)
    tail = torch.tensor([[128009, 128260]], dtype=torch.int64, device=device)
    ids = torch.cat([head, text_ids, tail], dim=1)
    # All positions are real tokens -> mask of ones, same shape as ids.
    mask = torch.ones_like(ids, dtype=torch.int64, device=device)
    return ids, mask
57
 
58
def redistribute_codes(block: list[int], snac_model: SNAC):
    """Demultiplex flat audio codes onto SNAC's three codebook layers and
    decode them to a float32 waveform (NumPy array, 24 kHz).

    Per 7-code frame: code 0 -> layer 1; codes 1, 4 -> layer 2;
    codes 2, 3, 5, 6 -> layer 3, each minus its positional k*4096 offset.
    Only complete frames are decoded; a trailing partial frame is ignored.
    """
    layer1, layer2, layer3 = [], [], []
    for start in range(0, 7 * (len(block) // 7), 7):
        c0, c1, c2, c3, c4, c5, c6 = block[start : start + 7]
        layer1.append(c0)
        layer2.extend((c1 - 4096, c4 - 4 * 4096))
        layer3.extend((c2 - 2 * 4096, c3 - 3 * 4096, c5 - 5 * 4096, c6 - 6 * 4096))

    target = next(snac_model.parameters()).device
    stacked = [
        torch.tensor(layer1, device=target).unsqueeze(0),
        torch.tensor(layer2, device=target).unsqueeze(0),
        torch.tensor(layer3, device=target).unsqueeze(0),
    ]
    waveform = snac_model.decode(stacked)  # Tensor[1, T]
    return waveform.squeeze().cpu().numpy()
 
78
 
79
+ # — FastAPI Setup
80
  app = FastAPI()
81
 
82
# 1) Hello-world endpoint (liveness check)
@app.get("/")
async def root():
    """Return a static greeting so deployments can be probed."""
    return {"message": "Hallo Welt"}
86
+
87
# 2) WebSocket token-by-token TTS
@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    """Stream TTS audio over a WebSocket, decoding as tokens are sampled.

    Protocol: the client sends a JSON text frame {"text": ..., "voice": ...};
    the server streams PCM16 binary frames (one per 7-code SNAC frame) and
    finishes with a JSON text frame {"event": "eos"}. The connection stays
    open for further requests.
    """
    await ws.accept()
    try:
        while True:
            # Receive JSON with text & voice
            raw = await ws.receive_text()
            req = json.loads(raw)
            text, voice = req.get("text", ""), req.get("voice", "Jakob")

            ids, mask = process_prompt(text, voice)

            past_kv = None
            collected = []

            # Incremental sampling loop: one forward pass per token.
            with torch.no_grad():
                for _ in range(2000):  # hard cap on generated tokens
                    out = model(
                        input_ids=ids,
                        # The explicit mask is only needed on the first
                        # (full-prompt) pass; afterwards the KV cache
                        # carries the positional state.
                        attention_mask=mask if past_kv is None else None,
                        past_key_values=past_kv,
                        use_cache=True,
                    )
                    logits = out.logits[:, -1, :]
                    next_id = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)
                    past_kv = out.past_key_values

                    # BUGFIX: with a KV cache the model still needs the newly
                    # sampled token as input on every subsequent step. The old
                    # code passed input_ids=None once past_kv was set, which
                    # gives the forward pass no input at all.
                    ids = next_id

                    token = next_id.item()
                    # End of generation
                    if token == EOS_TOKEN:
                        break
                    # Start-of-audio marker: restart the frame buffer
                    if token == SOS_TOKEN:
                        collected = []
                        continue

                    # Convert the token id into a raw audio code
                    collected.append(token - AUDIO_TOKEN_OFFSET)

                    # Every 7 codes form one SNAC frame -> decode & stream now
                    if len(collected) >= 7:
                        block = collected[:7]
                        collected = collected[7:]
                        audio_np = redistribute_codes(block, snac)
                        pcm16 = (audio_np * 32767).astype("int16").tobytes()
                        await ws.send_bytes(pcm16)

            # Finally signal end-of-stream
            await ws.send_text(json.dumps({"event": "eos"}))

    except WebSocketDisconnect:
        print("Client disconnected")
    except Exception as e:
        print("Error in /ws/tts:", e)
        await ws.close(code=1011)
147
 
148
# For local testing: serve the app directly via uvicorn on port 7860.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860)