Tomtom84 commited on
Commit
a4cfefc
·
verified ·
1 Parent(s): f63f843

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -88
app.py CHANGED
@@ -1,9 +1,10 @@
1
  import os
2
  import json
3
  import asyncio
4
- import logging
5
-
6
  import torch
 
 
 
7
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
8
  from huggingface_hub import login
9
  from snac import SNAC
@@ -12,145 +13,153 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
12
  # — HF‑Token & Login —
13
  HF_TOKEN = os.getenv("HF_TOKEN")
14
  if HF_TOKEN:
15
- login(token=HF_TOKEN)
16
 
17
- # — Device auswählen
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
 
20
- # — FastAPI instanziieren
21
  app = FastAPI()
22
 
23
- # — Einfacher GETEndpunkt, damit / keine 404 liefert
24
  @app.get("/")
25
- async def root():
26
- return {"message": "Hello, world!"}
27
 
28
- # — Modelle bei Startup laden —
29
  @app.on_event("startup")
30
  async def load_models():
31
  global tokenizer, model, snac
32
- logging.info("Lade SNAC...")
 
33
  snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
34
 
 
35
  REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
36
- logging.info("Lade TTS‑Modell...")
37
  tokenizer = AutoTokenizer.from_pretrained(REPO)
38
  model = AutoModelForCausalLM.from_pretrained(
39
  REPO,
40
- device_map="auto" if device=="cuda" else None,
41
  torch_dtype=torch.bfloat16 if device=="cuda" else None,
42
- low_cpu_mem_usage=True
 
43
  ).to(device)
44
  model.config.pad_token_id = model.config.eos_token_id
45
- logging.info("Modelle geladen ✔️")
46
 
47
- # Konstanten für AudioToken und SNAC‑Blockgröße —
48
- AUDIO_TOKEN_OFFSET = 128266
49
- AUDIO_CODE_SIZE = 4096
50
- BLOCK_SIZE = 7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
- # — Hilfsfunktion: Prompt in Token/Mask umwandeln —
53
  def prepare_inputs(text: str, voice: str):
54
  prompt = f"{voice}: {text}"
55
- input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
 
56
  start = torch.tensor([[128259]], dtype=torch.int64, device=device)
57
  end = torch.tensor([[128009, 128260]], dtype=torch.int64, device=device)
58
- ids = torch.cat([start, input_ids, end], dim=1)
59
- mask = torch.ones_like(ids)
60
- return ids, mask
61
-
62
- # — Hilfsfunktion: Dekodiere genau 7 Audio‑Codes →
63
- def decode_block(block_tokens: list[int]):
64
- # Filter invalid
65
- clean = []
66
- for t in block_tokens:
67
- code = t - AUDIO_TOKEN_OFFSET
68
- if 0 <= code < AUDIO_CODE_SIZE:
69
- clean.append(code)
70
- else:
71
- logging.warning(f"Ungültiger Audio‑Token {t}, skippe ihn")
72
- if len(clean) != BLOCK_SIZE:
73
- # Hier werfen wir raus, um nicht per CUDA‑Assertion zu crashen
74
- logging.error(f"Block nicht gültig (saubere Codes={clean}), werfe Exception")
75
- raise ValueError(f"Audio‑Block muss {BLOCK_SIZE} sauber haben, habe {len(clean)}")
76
- # Baue SNAC‑Eingabe
77
- l1, l2, l3 = [], [], []
78
- b = clean
79
- l1.append(b[0])
80
- l2.append(b[1])
81
- # das Original verschachtelte Layer‑Mapping
82
- l3.append(b[2])
83
- l3.append(b[3])
84
- l2.append(b[4])
85
- l3.append(b[5])
86
- l3.append(b[6])
87
  codes = [
88
- torch.tensor(l1, dtype=torch.int64, device=device).unsqueeze(0),
89
- torch.tensor(l2, dtype=torch.int64, device=device).unsqueeze(0),
90
- torch.tensor(l3, dtype=torch.int64, device=device).unsqueeze(0),
91
  ]
92
  audio = snac.decode(codes).squeeze().cpu().numpy()
93
- return (audio * 32767).astype("int16").tobytes()
 
 
94
 
95
- # — WebSocketEndpoint für TTSStreaming —
96
  @app.websocket("/ws/tts")
97
  async def tts_ws(ws: WebSocket):
98
  await ws.accept()
99
  try:
100
- # 1) Input empfangen
101
  msg = await ws.receive_text()
102
- data = json.loads(msg)
103
- text = data.get("text", "")
104
- voice = data.get("voice", "Jakob")
105
 
106
- # 2) Prompt → Input‑Tensors
107
  input_ids, attention_mask = prepare_inputs(text, voice)
108
  past_kvs = None
109
- buffer = []
110
 
111
- # 3) Token‑Loop (du kannst hier auch max_new_tokens=50 fahren,
112
- # indem Du in jedem Durchgang bis zu 50 Token samplet und aufsummierst)
113
  while True:
114
  out = model(
115
  input_ids=input_ids if past_kvs is None else None,
116
  attention_mask=attention_mask if past_kvs is None else None,
117
  past_key_values=past_kvs,
118
  use_cache=True,
 
119
  )
120
- logits = out.logits[:, -1, :]
121
- past_kvs = out.past_key_values
122
- probs = torch.softmax(logits, dim=-1)
123
- next_token = torch.multinomial(probs, num_samples=1).item()
124
 
125
- # Ende‑Bedingungen
126
- if next_token == model.config.eos_token_id:
127
  break
128
- if next_token == 128257:
129
- # neuer Start Buffer resetten
130
- buffer = []
 
 
 
131
  continue
132
 
133
- buffer.append(next_token)
134
- # immer, wenn wir ≥7 Codes sammeln, → dekodieren + senden
135
- while len(buffer) >= BLOCK_SIZE:
136
- block = buffer[:BLOCK_SIZE]
137
- buffer = buffer[BLOCK_SIZE:]
138
- try:
139
- pcm = decode_block(block)
140
- except Exception as e:
141
- logging.error(f"Fehler beim Dekodieren: {e}")
142
- await ws.close(code=1011)
143
- return
144
  await ws.send_bytes(pcm)
145
 
146
- # Input nur beim ersten Schritt mitgeben
147
- input_ids = None
148
- attention_mask = None
149
 
150
- # 4) nach Ende sauber schließen
151
  await ws.close()
152
  except WebSocketDisconnect:
153
- logging.info("Client hat WS geschlossen")
154
  except Exception as e:
155
- logging.error(f"Unbehandelter Fehler in /ws/tts: {e}")
156
  await ws.close(code=1011)
 
 
 
 
 
 
1
  import os
2
  import json
3
  import asyncio
 
 
4
  import torch
5
+ # Bugfix für PyTorch 2.2.x Flash‑SDP‑Assertion
6
+ torch.backends.cuda.enable_flash_sdp(False)
7
+
8
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
9
  from huggingface_hub import login
10
  from snac import SNAC
 
13
  # — HF‑Token & Login —
14
  HF_TOKEN = os.getenv("HF_TOKEN")
15
  if HF_TOKEN:
16
+ login(HF_TOKEN)
17
 
18
+ # — Device wählen
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
 
21
+ # — FastAPI instanzieren
22
  app = FastAPI()
23
 
24
+ # — HelloRoute, damit GET / kein 404 mehr gibt
25
  @app.get("/")
26
+ async def read_root():
27
+ return {"message": "Orpheus TTS WebSocket Server läuft"}
28
 
29
+ # — Modelle beim Startup laden —
30
  @app.on_event("startup")
31
  async def load_models():
32
  global tokenizer, model, snac
33
+
34
+ # SNAC für Audio‑Decoding
35
  snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
36
 
37
+ # Orpheus‑TTS Base
38
  REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
 
39
  tokenizer = AutoTokenizer.from_pretrained(REPO)
40
  model = AutoModelForCausalLM.from_pretrained(
41
  REPO,
42
+ device_map={"": 0} if device=="cuda" else None,
43
  torch_dtype=torch.bfloat16 if device=="cuda" else None,
44
+ low_cpu_mem_usage=True,
45
+ return_legacy_cache=True # für compatibility mit past_key_values als Tuple
46
  ).to(device)
47
  model.config.pad_token_id = model.config.eos_token_id
 
48
 
49
+ # --- LogitMasking vorbereiten ---
50
+ # reine Audio‑Tokens laufen von 128266 bis 128266+4096-1
51
+ AUDIO_OFFSET = 128266
52
+ AUDIO_COUNT = 4096
53
+ valid_audio = torch.arange(AUDIO_OFFSET, AUDIO_OFFSET + AUDIO_COUNT, device=device)
54
+ ctrl_tokens = torch.tensor([128257, model.config.eos_token_id], device=device)
55
+ global ALLOWED_IDS
56
+ ALLOWED_IDS = torch.cat([valid_audio, ctrl_tokens])
57
+
58
+ def sample_from_logits(logits: torch.Tensor) -> int:
59
+ """
60
+ Maskt alle IDs außer ALLOWED_IDS und sampelt dann einen Token.
61
+ """
62
+ # logits: [1, vocab_size]
63
+ mask = torch.full_like(logits, float("-inf"))
64
+ mask[0, ALLOWED_IDS] = 0.0
65
+ probs = torch.softmax(logits + mask, dim=-1)
66
+ return torch.multinomial(probs, num_samples=1).item()
67
 
 
68
  def prepare_inputs(text: str, voice: str):
69
  prompt = f"{voice}: {text}"
70
+ ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
71
+ # Start‐/End‐Marker
72
  start = torch.tensor([[128259]], dtype=torch.int64, device=device)
73
  end = torch.tensor([[128009, 128260]], dtype=torch.int64, device=device)
74
+ input_ids = torch.cat([start, ids, end], dim=1)
75
+ attention_mask = torch.ones_like(input_ids, device=device)
76
+ return input_ids, attention_mask
77
+
78
+ def decode_block(block: list[int]) -> bytes:
79
+ """
80
+ Aus 7 gesampelten Audio‑Codes einen PCM‑16‑Byte‐Block dekodieren.
81
+ Hier erwarten wir block[i] = raw_token - 128266.
82
+ """
83
+ layer1, layer2, layer3 = [], [], []
84
+ b = block
85
+ layer1.append(b[0])
86
+ layer2.append(b[1] - 4096)
87
+ layer3.append(b[2] - 2*4096)
88
+ layer3.append(b[3] - 3*4096)
89
+ layer2.append(b[4] - 4*4096)
90
+ layer3.append(b[5] - 5*4096)
91
+ layer3.append(b[6] - 6*4096)
92
+
93
+ dev = next(snac.parameters()).device
 
 
 
 
 
 
 
 
 
94
  codes = [
95
+ torch.tensor(layer1, device=dev).unsqueeze(0),
96
+ torch.tensor(layer2, device=dev).unsqueeze(0),
97
+ torch.tensor(layer3, device=dev).unsqueeze(0),
98
  ]
99
  audio = snac.decode(codes).squeeze().cpu().numpy()
100
+ # in PCM16 umwandeln
101
+ pcm16 = (audio * 32767).astype("int16").tobytes()
102
+ return pcm16
103
 
104
+ # — WebSocket Endpoint für TTS Streaming —
105
  @app.websocket("/ws/tts")
106
  async def tts_ws(ws: WebSocket):
107
  await ws.accept()
108
  try:
 
109
  msg = await ws.receive_text()
110
+ req = json.loads(msg)
111
+ text = req.get("text", "")
112
+ voice = req.get("voice", "Jakob")
113
 
114
+ # Inputs vorbereiten
115
  input_ids, attention_mask = prepare_inputs(text, voice)
116
  past_kvs = None
117
+ buffer = [] # sammelt die 7 Audio‑Codes
118
 
119
+ # Token‑für‑Token Loop
 
120
  while True:
121
  out = model(
122
  input_ids=input_ids if past_kvs is None else None,
123
  attention_mask=attention_mask if past_kvs is None else None,
124
  past_key_values=past_kvs,
125
  use_cache=True,
126
+ return_dict=True
127
  )
128
+ past_kvs = out.past_key_values
129
+ next_tok = sample_from_logits(out.logits[:, -1, :])
 
 
130
 
131
+ # Ende?
132
+ if next_tok == model.config.eos_token_id:
133
  break
134
+
135
+ # Reset bei neuem Audio‑Block‑Start
136
+ if next_tok == 128257:
137
+ buffer.clear()
138
+ input_ids = torch.tensor([[next_tok]], device=device)
139
+ attention_mask = torch.ones_like(input_ids)
140
  continue
141
 
142
+ # Audio‑Code sammeln (Offset abziehen)
143
+ buffer.append(next_tok - 128266)
144
+ # sobald wir 7 Codes haben → dekodieren & senden
145
+ if len(buffer) == 7:
146
+ pcm = decode_block(buffer)
147
+ buffer.clear()
 
 
 
 
 
148
  await ws.send_bytes(pcm)
149
 
150
+ # nächster Schritt: genau diesen Token wieder einspeisen
151
+ input_ids = torch.tensor([[next_tok]], device=device)
152
+ attention_mask = torch.ones_like(input_ids)
153
 
154
+ # sauber beenden
155
  await ws.close()
156
  except WebSocketDisconnect:
157
+ pass
158
  except Exception as e:
159
+ print("Error in /ws/tts:", e)
160
  await ws.close(code=1011)
161
+
162
+ # — CLI zum lokalen Testen —
163
+ if __name__ == "__main__":
164
+ import uvicorn
165
+ uvicorn.run("app:app", host="0.0.0.0", port=7860)