Tomtom84 committed on
Commit
a3af518
·
verified ·
1 Parent(s): 10be82b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -124
app.py CHANGED
@@ -1,186 +1,139 @@
1
  import os
2
  import json
 
3
  import torch
4
- import numpy as np
5
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
6
  from huggingface_hub import login
7
- from transformers import AutoModelForCausalLM, AutoTokenizer
8
  from snac import SNAC
 
9
 
10
- # — HF‑Token & Login (falls gesetzt)
11
  HF_TOKEN = os.getenv("HF_TOKEN")
12
  if HF_TOKEN:
13
  login(HF_TOKEN)
14
 
15
- # — Device auswählen
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
 
 
18
  app = FastAPI()
19
 
 
20
  @app.get("/")
21
  async def read_root():
22
  return {"message": "Hello, world!"}
23
 
24
- # — Globale Modelle —
25
- model = None
26
- tokenizer = None
27
- snac_model = None
28
-
29
  @app.on_event("startup")
30
  async def load_models():
31
- global model, tokenizer, snac_model
32
-
33
- # 1) SNAC laden
34
- snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
35
-
36
- # 2) Orpheus‑TTS (public “natural”-Variante)
37
  REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
38
  tokenizer = AutoTokenizer.from_pretrained(REPO)
39
  model = AutoModelForCausalLM.from_pretrained(
40
  REPO,
41
- device_map="auto" if device == "cuda" else None,
42
  torch_dtype=torch.bfloat16 if device == "cuda" else None,
43
  low_cpu_mem_usage=True
44
- ).to(device)
 
45
  model.config.pad_token_id = model.config.eos_token_id
46
 
47
- # — Marker und Offsets —
48
- START_TOKEN = 128259
49
- END_TOKENS = [128009, 128260]
50
- AUDIO_OFFSET = 128266
51
-
52
- def process_single_prompt(prompt: str, voice: str) -> list[int]:
53
- # Prompt zusammenstellen
54
- text = f"{voice}: {prompt}" if voice and voice != "in_prompt" else prompt
55
-
56
- # Tokenize + Marker
57
- ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
58
- start = torch.tensor([[START_TOKEN]], dtype=torch.int64, device=device)
59
- end = torch.tensor([END_TOKENS], dtype=torch.int64, device=device)
60
  input_ids = torch.cat([start, ids, end], dim=1)
61
  attention_mask = torch.ones_like(input_ids)
62
-
63
- # Generieren
64
- gen = model.generate(
65
- input_ids=input_ids,
66
- attention_mask=attention_mask,
67
- max_new_tokens=4000,
68
- do_sample=True,
69
- temperature=0.6,
70
- top_p=0.95,
71
- repetition_penalty=1.1,
72
- eos_token_id=128258,
73
- use_cache=True,
74
- )
75
-
76
- # Nach letztem START_TOKEN croppen
77
- token_to_find = 128257
78
- token_to_remove = 128258
79
- idxs = (gen == token_to_find).nonzero(as_tuple=True)[1]
80
- if idxs.numel() > 0:
81
- cropped = gen[:, idxs[-1] + 1 :]
82
- else:
83
- cropped = gen
84
-
85
- # Padding entfernen & Länge auf Vielfaches von 7 bringen
86
- row = cropped[0][cropped[0] != token_to_remove]
87
- new_len = (row.size(0) // 7) * 7
88
- trimmed = row[:new_len].tolist()
89
-
90
- # Offset abziehen
91
- return [t - AUDIO_OFFSET for t in trimmed]
92
-
93
- def redistribute_codes(code_list: list[int]) -> np.ndarray:
94
- # 7er‑Blöcke auf 3 Layer verteilen
95
- layer1, layer2, layer3 = [], [], []
96
- for i in range(len(code_list) // 7):
97
- b = code_list[7*i : 7*i+7]
98
- layer1.append(b[0])
99
- layer2.append(b[1] - 4096)
100
- layer3.append(b[2] - 2*4096)
101
- layer3.append(b[3] - 3*4096)
102
- layer2.append(b[4] - 4*4096)
103
- layer3.append(b[5] - 5*4096)
104
- layer3.append(b[6] - 6*4096)
105
-
106
  codes = [
107
- torch.tensor(layer1, device=device).unsqueeze(0),
108
- torch.tensor(layer2, device=device).unsqueeze(0),
109
- torch.tensor(layer3, device=device).unsqueeze(0),
110
  ]
111
- audio = snac_model.decode(codes).squeeze().cpu().numpy()
112
- return audio # float32 @24 kHz
113
 
 
114
  @app.websocket("/ws/tts")
115
  async def tts_ws(ws: WebSocket):
116
  await ws.accept()
117
  try:
118
- msg = await ws.receive_text()
119
  req = json.loads(msg)
120
  text = req.get("text", "")
121
- voice = req.get("voice", "")
122
 
123
- # 1) Prompt vorbereiten
124
  input_ids, attention_mask = prepare_inputs(text, voice)
125
  past_kvs = None
126
- buffer = []
127
 
128
- # 2) Token‑für‑Token (oder in kleinen Blöcken)
129
  while True:
130
- # Nur max_new_tokens=50 pro Aufruf
131
- out = model.generate(
132
- input_ids=input_ids if past_kvs is None else None,
133
  attention_mask=attention_mask if past_kvs is None else None,
134
  past_key_values=past_kvs,
135
  use_cache=True,
136
- do_sample=True,
137
- temperature=0.7,
138
- top_p=0.95,
139
- repetition_penalty=1.1,
140
- max_new_tokens=50,
141
- eos_token_id=128258,
142
- return_dict_in_generate=True,
143
- output_past_key_values=True,
144
- return_legacy_cache=True, # falls Ihr noch das alte past_key_values-Format braucht
145
  )
146
-
147
- # Extrahiere neue Token (ohne die already generated ones)
148
- new_ids = out.sequences[0, input_ids.shape[-1]:].tolist()
149
  past_kvs = out.past_key_values
150
 
151
- for tok in new_ids:
152
- if tok == model.config.eos_token_id:
153
- # Stream zu Ende
154
- break
155
- if tok == 128257: # Reset-Start‑Marker
156
- buffer = []
157
- continue
158
- buffer.append(tok - AUDIO_OFFSET)
159
-
160
- # Sobald wir 7 Audio‑Codes gesammelt haben → dekodieren & schicken
161
- if len(buffer) == 7:
162
- pcm = decode_block(buffer)
163
- buffer = []
164
- await ws.send_bytes(pcm)
165
-
166
- # Wenn EOS im Chunk war, abbrechen
167
- if model.config.eos_token_id in new_ids:
168
- break
169
 
170
- # Danach weiter mit nächsten 50 Tokens,
171
- # input_ids & attention_mask nur beim ersten Aufruf nötig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  input_ids = None
173
  attention_mask = None
174
 
175
- # 3) Am Ende WebSocket sauber schließen
176
  await ws.close()
177
 
178
  except WebSocketDisconnect:
 
179
  pass
 
180
  except Exception as e:
 
181
  print("Error in /ws/tts:", e)
182
  await ws.close(code=1011)
183
-
184
- if __name__ == "__main__":
185
- import uvicorn
186
- uvicorn.run("app:app", host="0.0.0.0", port=7860)
 
1
  import os
2
  import json
3
+ import asyncio
4
  import torch
 
5
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
6
  from huggingface_hub import login
 
7
  from snac import SNAC
8
+ from transformers import AutoModelForCausalLM, AutoTokenizer
9
 
10
+ # — HF‑Token & Login —
11
  HF_TOKEN = os.getenv("HF_TOKEN")
12
  if HF_TOKEN:
13
  login(HF_TOKEN)
14
 
15
+ # — Device wählen
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
 
18
+ # — FastAPI instanziieren —
19
  app = FastAPI()
20
 
21
+ # — Hello‑Route, damit GET / nicht 404 wirft —
22
  @app.get("/")
23
  async def read_root():
24
  return {"message": "Hello, world!"}
25
 
26
+ # — Modelle bei Startup laden
 
 
 
 
27
  @app.on_event("startup")
28
  async def load_models():
29
+ global tokenizer, model, snac
30
+ snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
 
 
 
 
31
  REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
32
  tokenizer = AutoTokenizer.from_pretrained(REPO)
33
  model = AutoModelForCausalLM.from_pretrained(
34
  REPO,
35
+ device_map="auto",
36
  torch_dtype=torch.bfloat16 if device == "cuda" else None,
37
  low_cpu_mem_usage=True
38
+ )
39
+ # Für pad-token fallback auf eos
40
  model.config.pad_token_id = model.config.eos_token_id
41
 
42
+ # — Hilfsfunktionen
43
START_TOKEN = 128259           # marker prepended before the text prompt
END_TOKENS = [128009, 128260]  # markers appended after the text prompt
RESET_TOKEN = 128257           # signals the start of a fresh audio segment
AUDIO_OFFSET = 128266          # id of the first audio-code token
# NOTE: the previous `model.config.eos_token_id if 'model' in globals() else
# 128258` ran at import time, before the startup event defines `model`, so
# the guard was always False and the value was always the literal below.
EOS_TOKEN = 128258
48
+
49
def prepare_inputs(text: str, voice: str):
    """Tokenize ``"<voice>: <text>"`` and wrap it in start/end marker tokens.

    Returns a ``(input_ids, attention_mask)`` pair on the module's device.
    """
    prompt_ids = tokenizer(f"{voice}: {text}", return_tensors="pt").input_ids.to(device)
    start_ids = torch.tensor([[START_TOKEN]], device=device)
    end_ids = torch.tensor([END_TOKENS], device=device)
    input_ids = torch.cat([start_ids, prompt_ids, end_ids], dim=1)
    return input_ids, torch.ones_like(input_ids)
57
+
58
def decode_block(block: list[int]):
    """Decode exactly 7 SNAC audio codes into one 16-bit PCM byte block."""
    # Each of the 7 slots belongs to a fixed SNAC layer and carries an
    # additive offset of slot_index * 4096 that must be removed first.
    slot_to_layer = (1, 2, 3, 3, 2, 3, 3)
    layers: dict[int, list[int]] = {1: [], 2: [], 3: []}
    for slot, layer in enumerate(slot_to_layer):
        layers[layer].append(block[slot] - slot * 4096)
    codes = [
        torch.tensor(layers[i], device=device).unsqueeze(0)
        for i in (1, 2, 3)
    ]
    audio = snac.decode(codes).squeeze().cpu().numpy()
    # Scale float audio (assumed in [-1, 1] — confirm against SNAC docs)
    # to little-endian int16 PCM bytes.
    return (audio * 32767).astype("int16").tobytes()
76
 
77
+ # — WebSocket‑Endpoint für TTS Streaming —
78
  @app.websocket("/ws/tts")
79
  async def tts_ws(ws: WebSocket):
80
  await ws.accept()
81
  try:
82
+ msg = await ws.receive_text()
83
  req = json.loads(msg)
84
  text = req.get("text", "")
85
+ voice = req.get("voice", "Jakob")
86
 
 
87
  input_ids, attention_mask = prepare_inputs(text, voice)
88
  past_kvs = None
89
+ collected = []
90
 
91
+ # Token‑für‑Token mit eigener Sampling‑Schleife
92
  while True:
93
+ out = model(
94
+ input_ids=input_ids if past_kvs is None else None,
 
95
  attention_mask=attention_mask if past_kvs is None else None,
96
  past_key_values=past_kvs,
97
  use_cache=True,
 
 
 
 
 
 
 
 
 
98
  )
99
+ logits = out.logits[:, -1, :]
 
 
100
  past_kvs = out.past_key_values
101
 
102
+ # Sampling
103
+ probs = torch.softmax(logits, dim=-1)
104
+ nxt = torch.multinomial(probs, num_samples=1).item()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
+ # EOS fertig
107
+ if nxt == EOS_TOKEN:
108
+ break
109
+ # RESET → alte Sammlung verwerfen
110
+ if nxt == RESET_TOKEN:
111
+ collected = []
112
+ # und input_ids für nächsten Durchlauf auf None setzen
113
+ input_ids = None
114
+ attention_mask = None
115
+ continue
116
+
117
+ # Audio‑Code abziehen & sammeln
118
+ collected.append(nxt - AUDIO_OFFSET)
119
+ # jede 7 Codes → dekodieren & streamen
120
+ if len(collected) == 7:
121
+ pcm = decode_block(collected)
122
+ collected = []
123
+ await ws.send_bytes(pcm)
124
+
125
+ # nur beim allerersten Schritt mit IDs arbeiten
126
  input_ids = None
127
  attention_mask = None
128
 
129
+ # Stream sauber beenden
130
  await ws.close()
131
 
132
  except WebSocketDisconnect:
133
+ # Client hat Disconnect gemacht → nichts tun
134
  pass
135
+
136
  except Exception as e:
137
+ # auf Fehler 1011 senden
138
  print("Error in /ws/tts:", e)
139
  await ws.close(code=1011)