Tomtom84 commited on
Commit
f63f843
·
verified ·
1 Parent(s): 986d4cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -105
app.py CHANGED
@@ -1,6 +1,8 @@
1
  import os
2
  import json
3
  import asyncio
 
 
4
  import torch
5
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
6
  from huggingface_hub import login
@@ -10,142 +12,145 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
10
  # — HF‑Token & Login —
11
  HF_TOKEN = os.getenv("HF_TOKEN")
12
  if HF_TOKEN:
13
- login(HF_TOKEN)
14
 
15
- # — Gerät wählen
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
 
18
- # — Modell‑Parameter —
19
- MODEL_NAME = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
20
- START_MARKER = 128259 # <|startoftranscript|>
21
- RESTART_MARKER = 128257 # <|startoftranscript_again|>
22
- EOS_TOKEN = 128258 # <|endoftranscript|>
23
- AUDIO_TOKEN_OFFSET = 128266 # Offset zum Zurückrechnen
24
- BLOCK_TOKENS = 7 # SNAC erwartet 7 Audio‑Tokens pro Block
25
- CHUNK_TOKENS = 50 # Anzahl neuer Tokens pro Generate‑Runde
26
-
27
  # — FastAPI instanziieren —
28
  app = FastAPI()
29
 
30
- # — Damit GET / nicht 404 wirft
31
  @app.get("/")
32
- async def read_root():
33
- return {"message": "Orpheus TTS Server ist live 🎙️"}
34
 
35
  # — Modelle bei Startup laden —
36
  @app.on_event("startup")
37
  async def load_models():
38
  global tokenizer, model, snac
39
- # SNAC laden
40
  snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
41
- # Tokenizer
42
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
43
- # TTS‑LM
 
44
  model = AutoModelForCausalLM.from_pretrained(
45
- MODEL_NAME,
46
- device_map="auto",
47
  torch_dtype=torch.bfloat16 if device=="cuda" else None,
48
  low_cpu_mem_usage=True
49
- )
50
- model.config.pad_token_id = EOS_TOKEN
 
51
 
52
- # — Eingabe aufbereiten
 
 
 
 
 
53
  def prepare_inputs(text: str, voice: str):
54
- prompt = f"{voice}: {text}"
55
- input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
56
- start = torch.tensor([[START_MARKER]], device=device)
57
- end = torch.tensor([[128009, EOS_TOKEN]], device=device)
58
- ids = torch.cat([start, input_ids, end], dim=1)
59
- attn_mask = torch.ones_like(ids)
60
- return ids, attn_mask
61
-
62
- # — Aus 7 Audio‑Tokens ein PCM‑Block erzeugen —
63
- def decode_block(block: list[int]) -> bytes:
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  l1, l2, l3 = [], [], []
65
- b = block
66
  l1.append(b[0])
67
- l2.append(b[1] - 4096)
68
- l3.append(b[2] - 2*4096)
69
- l3.append(b[3] - 3*4096)
70
- l2.append(b[4] - 4*4096)
71
- l3.append(b[5] - 5*4096)
72
- l3.append(b[6] - 6*4096)
 
73
  codes = [
74
- torch.tensor(l1, device=device).unsqueeze(0),
75
- torch.tensor(l2, device=device).unsqueeze(0),
76
- torch.tensor(l3, device=device).unsqueeze(0),
77
  ]
78
  audio = snac.decode(codes).squeeze().cpu().numpy()
79
- pcm16 = (audio * 32767).astype("int16").tobytes()
80
- return pcm16
81
-
82
- # — Generator: kleine Chunks token‑weise erzeugen und block‑weise dekodieren —
83
- async def generate_and_stream(ws: WebSocket, ids, attn_mask):
84
- buffer: list[int] = []
85
- past_kvs = None
86
-
87
- while True:
88
- # wir rufen model.generate in Häppchen auf
89
- outputs = model.generate(
90
- input_ids = ids if past_kvs is None else None,
91
- attention_mask = attn_mask if past_kvs is None else None,
92
- past_key_values= past_kvs,
93
- use_cache = True,
94
- max_new_tokens = CHUNK_TOKENS,
95
- do_sample = True,
96
- temperature = 0.7,
97
- top_p = 0.95,
98
- repetition_penalty = 1.1,
99
- eos_token_id = EOS_TOKEN,
100
- pad_token_id = EOS_TOKEN,
101
- return_dict_in_generate = True,
102
- output_scores = False,
103
- )
104
-
105
- # update past_kvs
106
- past_kvs = outputs.past_key_values
107
-
108
- # erhalte nur die gerade neu generierten Token
109
- seq = outputs.sequences[0]
110
- new_tokens = seq[-CHUNK_TOKENS:].tolist() if past_kvs is not None else seq[ids.shape[-1]:].tolist()
111
-
112
- for tok in new_tokens:
113
- # Neustart bei erneutem START‑Marker
114
- if tok == RESTART_MARKER:
115
- buffer = []
116
- continue
117
- # Ende
118
- if tok == EOS_TOKEN:
119
- return
120
- # Audio‑Code berechnen
121
- buffer.append(tok - AUDIO_TOKEN_OFFSET)
122
- # sobald 7 Audio‑Tokens, dekodieren und streamen
123
- if len(buffer) >= BLOCK_TOKENS:
124
- block = buffer[:BLOCK_TOKENS]
125
- buffer = buffer[BLOCK_TOKENS:]
126
- pcm = decode_block(block)
127
- await ws.send_bytes(pcm)
128
 
129
- # — WebSocket‑Endpoint für TTS Streaming —
130
  @app.websocket("/ws/tts")
131
  async def tts_ws(ws: WebSocket):
132
  await ws.accept()
133
  try:
134
- data = await ws.receive_text()
135
- req = json.loads(data)
136
- text = req.get("text", "")
137
- voice = req.get("voice", "Jakob")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
- ids, attn_mask = prepare_inputs(text, voice)
140
- await generate_and_stream(ws, ids, attn_mask)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
 
142
  await ws.close()
143
  except WebSocketDisconnect:
144
- pass
145
  except Exception as e:
146
- print("Error in /ws/tts:", e)
147
  await ws.close(code=1011)
148
-
149
- if __name__ == "__main__":
150
- import uvicorn
151
- uvicorn.run("app:app", host="0.0.0.0", port=7860)
 
1
  import os
2
  import json
3
  import asyncio
4
+ import logging
5
+
6
  import torch
7
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
8
  from huggingface_hub import login
 
12
  # — HF‑Token & Login —
13
  HF_TOKEN = os.getenv("HF_TOKEN")
14
  if HF_TOKEN:
15
+ login(token=HF_TOKEN)
16
 
17
+ # — Device auswählen
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
 
 
 
 
 
 
 
 
 
 
20
  # — FastAPI instanziieren —
21
  app = FastAPI()
22
 
23
+ # — Einfacher GET‑Endpunkt, damit / keine 404 liefert
24
  @app.get("/")
25
+ async def root():
26
+ return {"message": "Hello, world!"}
27
 
28
  # — Modelle bei Startup laden —
29
  @app.on_event("startup")
30
  async def load_models():
31
  global tokenizer, model, snac
32
+ logging.info("Lade SNAC...")
33
  snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
34
+
35
+ REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
36
+ logging.info("Lade TTS‑Modell...")
37
+ tokenizer = AutoTokenizer.from_pretrained(REPO)
38
  model = AutoModelForCausalLM.from_pretrained(
39
+ REPO,
40
+ device_map="auto" if device=="cuda" else None,
41
  torch_dtype=torch.bfloat16 if device=="cuda" else None,
42
  low_cpu_mem_usage=True
43
+ ).to(device)
44
+ model.config.pad_token_id = model.config.eos_token_id
45
+ logging.info("Modelle geladen ✔️")
46
 
47
+ # — Konstanten für Audio‑Token und SNAC‑Blockgröße
48
+ AUDIO_TOKEN_OFFSET = 128266
49
+ AUDIO_CODE_SIZE = 4096
50
+ BLOCK_SIZE = 7
51
+
52
+ # — Hilfsfunktion: Prompt in Token/Mask umwandeln —
53
  def prepare_inputs(text: str, voice: str):
54
+ prompt = f"{voice}: {text}"
55
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
56
+ start = torch.tensor([[128259]], dtype=torch.int64, device=device)
57
+ end = torch.tensor([[128009, 128260]], dtype=torch.int64, device=device)
58
+ ids = torch.cat([start, input_ids, end], dim=1)
59
+ mask = torch.ones_like(ids)
60
+ return ids, mask
61
+
62
+ # — Hilfsfunktion: Dekodiere genau 7 Audio‑Codes
63
+ def decode_block(block_tokens: list[int]):
64
+ # Filter invalid
65
+ clean = []
66
+ for t in block_tokens:
67
+ code = t - AUDIO_TOKEN_OFFSET
68
+ if 0 <= code < AUDIO_CODE_SIZE:
69
+ clean.append(code)
70
+ else:
71
+ logging.warning(f"Ungültiger Audio‑Token {t}, skippe ihn")
72
+ if len(clean) != BLOCK_SIZE:
73
+ # Hier werfen wir raus, um nicht per CUDA‑Assertion zu crashen
74
+ logging.error(f"Block nicht gültig (saubere Codes={clean}), werfe Exception")
75
+ raise ValueError(f"Audio‑Block muss {BLOCK_SIZE} sauber haben, habe {len(clean)}")
76
+ # Baue SNAC‑Eingabe
77
  l1, l2, l3 = [], [], []
78
+ b = clean
79
  l1.append(b[0])
80
+ l2.append(b[1])
81
+ # das Original verschachtelte Layer‑Mapping
82
+ l3.append(b[2])
83
+ l3.append(b[3])
84
+ l2.append(b[4])
85
+ l3.append(b[5])
86
+ l3.append(b[6])
87
  codes = [
88
+ torch.tensor(l1, dtype=torch.int64, device=device).unsqueeze(0),
89
+ torch.tensor(l2, dtype=torch.int64, device=device).unsqueeze(0),
90
+ torch.tensor(l3, dtype=torch.int64, device=device).unsqueeze(0),
91
  ]
92
  audio = snac.decode(codes).squeeze().cpu().numpy()
93
+ return (audio * 32767).astype("int16").tobytes()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
+ # — WebSocket‑Endpoint für TTSStreaming —
96
  @app.websocket("/ws/tts")
97
  async def tts_ws(ws: WebSocket):
98
  await ws.accept()
99
  try:
100
+ # 1) Input empfangen
101
+ msg = await ws.receive_text()
102
+ data = json.loads(msg)
103
+ text = data.get("text", "")
104
+ voice = data.get("voice", "Jakob")
105
+
106
+ # 2) Prompt → Input‑Tensors
107
+ input_ids, attention_mask = prepare_inputs(text, voice)
108
+ past_kvs = None
109
+ buffer = []
110
+
111
+ # 3) Token‑Loop (du kannst hier auch max_new_tokens=50 fahren,
112
+ # indem Du in jedem Durchgang bis zu 50 Token samplet und aufsummierst)
113
+ while True:
114
+ out = model(
115
+ input_ids=input_ids if past_kvs is None else None,
116
+ attention_mask=attention_mask if past_kvs is None else None,
117
+ past_key_values=past_kvs,
118
+ use_cache=True,
119
+ )
120
+ logits = out.logits[:, -1, :]
121
+ past_kvs = out.past_key_values
122
+ probs = torch.softmax(logits, dim=-1)
123
+ next_token = torch.multinomial(probs, num_samples=1).item()
124
+
125
+ # Ende‑Bedingungen
126
+ if next_token == model.config.eos_token_id:
127
+ break
128
+ if next_token == 128257:
129
+ # neuer Start → Buffer resetten
130
+ buffer = []
131
+ continue
132
 
133
+ buffer.append(next_token)
134
+ # immer, wenn wir ≥7 Codes sammeln, → dekodieren + senden
135
+ while len(buffer) >= BLOCK_SIZE:
136
+ block = buffer[:BLOCK_SIZE]
137
+ buffer = buffer[BLOCK_SIZE:]
138
+ try:
139
+ pcm = decode_block(block)
140
+ except Exception as e:
141
+ logging.error(f"Fehler beim Dekodieren: {e}")
142
+ await ws.close(code=1011)
143
+ return
144
+ await ws.send_bytes(pcm)
145
+
146
+ # Input nur beim ersten Schritt mitgeben
147
+ input_ids = None
148
+ attention_mask = None
149
 
150
+ # 4) nach Ende sauber schließen
151
  await ws.close()
152
  except WebSocketDisconnect:
153
+ logging.info("Client hat WS geschlossen")
154
  except Exception as e:
155
+ logging.error(f"Unbehandelter Fehler in /ws/tts: {e}")
156
  await ws.close(code=1011)