Tomtom84 committed on
Commit
bca75ea
·
verified ·
1 Parent(s): b32ff77

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -133
app.py CHANGED
@@ -1,167 +1,145 @@
1
- import os
2
- import json
3
- import asyncio
4
- import torch
5
- # Bugfix für PyTorch 2.2.x Flash‑SDP‑Assertion
6
- torch.backends.cuda.enable_flash_sdp(False)
7
-
8
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
9
  from huggingface_hub import login
 
10
  from snac import SNAC
11
- from transformers import AutoModelForCausalLM, AutoTokenizer
12
 
13
- # — HF‑Token & Login —
14
  HF_TOKEN = os.getenv("HF_TOKEN")
15
  if HF_TOKEN:
16
  login(HF_TOKEN)
17
 
18
- # — Device wählen —
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
 
21
- # — FastAPI instanzieren —
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  app = FastAPI()
23
 
24
- # — Hello‑Route, damit GET / kein 404 mehr gibt —
25
  @app.get("/")
26
- async def read_root():
27
- return {"message": "Orpheus TTS WebSocket Server lΓ€uft"}
28
 
29
- # — Modelle beim Startup laden —
30
  @app.on_event("startup")
31
  async def load_models():
32
- global tokenizer, model, snac
33
-
34
- # SNAC für Audio‑Decoding
35
- snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
36
-
37
- # Orpheus‑TTS Base
38
- REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
39
- tokenizer = AutoTokenizer.from_pretrained(REPO)
40
  model = AutoModelForCausalLM.from_pretrained(
41
  REPO,
42
- device_map={"": 0} if device=="cuda" else None,
43
- torch_dtype=torch.bfloat16 if device=="cuda" else None,
44
- low_cpu_mem_usage=True
45
- #return_legacy_cache=True # für compatibility mit past_key_values als Tuple
46
- ).to(device)
47
  model.config.pad_token_id = model.config.eos_token_id
48
- # optional, aber explizit:
49
- model.config.use_cache = True
50
-
51
- # --- Logit‑Masking vorbereiten ---
52
- # reine Audio‑Tokens laufen von 128266 bis 128266+4096-1
53
- AUDIO_OFFSET = 128266
54
- AUDIO_COUNT = 4096
55
- valid_audio = torch.arange(AUDIO_OFFSET, AUDIO_OFFSET + AUDIO_COUNT, device=device)
56
- ctrl_tokens = torch.tensor([128257, model.config.eos_token_id], device=device)
57
- global ALLOWED_IDS
58
- ALLOWED_IDS = torch.cat([valid_audio, ctrl_tokens])
59
-
60
- def sample_from_logits(logits: torch.Tensor) -> int:
61
- """
62
- Maskt alle IDs außer ALLOWED_IDS und sampelt dann einen Token.
63
- """
64
- # logits: [1, vocab_size]
65
- mask = torch.full_like(logits, float("-inf"))
66
- mask[0, ALLOWED_IDS] = 0.0
67
- probs = torch.softmax(logits + mask, dim=-1)
68
- return torch.multinomial(probs, num_samples=1).item()
69
-
70
- def prepare_inputs(text: str, voice: str):
71
- prompt = f"{voice}: {text}"
72
- ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
73
- # Start‐/End‐Marker
74
- start = torch.tensor([[128259]], dtype=torch.int64, device=device)
75
- end = torch.tensor([[128009, 128260]], dtype=torch.int64, device=device)
76
- input_ids = torch.cat([start, ids, end], dim=1)
77
- attention_mask = torch.ones_like(input_ids, device=device)
78
- return input_ids, attention_mask
79
-
80
- def decode_block(block: list[int]) -> bytes:
81
- """
82
- Aus 7 gesampelten Audio‑Codes einen PCM‑16‑Byte‐Block dekodieren.
83
- Hier erwarten wir block[i] = raw_token - 128266.
84
- """
85
- layer1, layer2, layer3 = [], [], []
86
- b = block
87
- layer1.append(b[0])
88
- layer2.append(b[1] - 4096)
89
- layer3.append(b[2] - 2*4096)
90
- layer3.append(b[3] - 3*4096)
91
- layer2.append(b[4] - 4*4096)
92
- layer3.append(b[5] - 5*4096)
93
- layer3.append(b[6] - 6*4096)
94
-
95
- dev = next(snac.parameters()).device
96
- codes = [
97
- torch.tensor(layer1, device=dev).unsqueeze(0),
98
- torch.tensor(layer2, device=dev).unsqueeze(0),
99
- torch.tensor(layer3, device=dev).unsqueeze(0),
100
- ]
101
- audio = snac.decode(codes).squeeze().cpu().numpy()
102
- # in PCM16 umwandeln
103
- pcm16 = (audio * 32767).astype("int16").tobytes()
104
- return pcm16
105
-
106
- # — WebSocket Endpoint für TTS Streaming —
107
  @app.websocket("/ws/tts")
108
- async def tts_ws(ws: WebSocket):
109
  await ws.accept()
110
  try:
111
- msg = await ws.receive_text()
112
- req = json.loads(msg)
113
- text = req.get("text", "")
114
- voice = req.get("voice", "Jakob")
115
 
116
- # Inputs vorbereiten
117
- input_ids, attention_mask = prepare_inputs(text, voice)
118
- past_kvs = None
119
- buffer = [] # sammelt die 7 Audio‑Codes
120
 
121
- # Token‑für‑Token Loop
122
  while True:
123
- out = model(
124
- input_ids=input_ids if past_kvs is None else None,
125
- attention_mask=attention_mask if past_kvs is None else None,
126
- past_key_values=past_kvs,
 
 
 
127
  use_cache=True,
128
- return_dict=True
129
  )
130
- past_kvs = out.past_key_values
131
- next_tok = sample_from_logits(out.logits[:, -1, :])
132
-
133
- # Ende?
134
- if next_tok == model.config.eos_token_id:
135
- break
136
-
137
- # Reset bei neuem Audio‑Block‑Start
138
- if next_tok == 128257:
139
- buffer.clear()
140
- input_ids = torch.tensor([[next_tok]], device=device)
141
- attention_mask = torch.ones_like(input_ids)
142
- continue
143
-
144
- # Audio‑Code sammeln (Offset abziehen)
145
- buffer.append(next_tok - 128266)
146
- # sobald wir 7 Codes haben → dekodieren & senden
147
- if len(buffer) == 7:
148
- pcm = decode_block(buffer)
149
- buffer.clear()
150
- await ws.send_bytes(pcm)
151
-
152
- # nächster Schritt: genau diesen Token wieder einspeisen
153
- input_ids = torch.tensor([[next_tok]], device=device)
154
- attention_mask = torch.ones_like(input_ids)
155
-
156
- # sauber beenden
157
- await ws.close()
158
- except WebSocketDisconnect:
159
  pass
160
  except Exception as e:
161
- print("Error in /ws/tts:", e)
162
  await ws.close(code=1011)
 
 
 
163
 
164
- # — CLI zum lokalen Testen —
165
  if __name__ == "__main__":
166
  import uvicorn
167
  uvicorn.run("app:app", host="0.0.0.0", port=7860)
 
1
+ # app.py ─────────────────────────────────────────────────────────────
2
+ import os, json, asyncio, torch
 
 
 
 
 
3
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
4
  from huggingface_hub import login
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessor
6
  from snac import SNAC
 
7
 
8
# ── 0. Hugging Face auth & device selection ──────────────────────────
# Log in to the Hugging Face Hub only when a token is supplied through
# the HF_TOKEN environment variable; without it no login is attempted.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Switch off the flash scaled-dot-product kernel
# (flash-attention bug in PyTorch 2.2.x).
torch.backends.cuda.enable_flash_sdp(False)
17
+
18
+ # ── 1. Konstanten ────────────────────────────────────────────────────
19
# Hugging Face repository the tokenizer and the language model are loaded from.
REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_synthetic-v0.1"
# Upper bound on new tokens requested per generate() call in the endpoint.
CHUNK_TOKENS = 50
# Token ID prepended to every prompt by build_prompt().
START_TOKEN = 128259
# Control token that makes the endpoint discard a partially filled frame.
NEW_BLOCK_TOKEN = 128257
# Token ID that ends streaming in the endpoint.
EOS_TOKEN = 128258
# Offset subtracted from every sampled audio token before decoding.
AUDIO_BASE = 128266
# decode_snac() removes an extra offset of k * 4096 from slot k (k = 0..6)
# of each 7-code frame, so audio tokens span seven consecutive 4096-wide
# ranges above AUDIO_BASE.  Allowing only the first range would make every
# code in slots 1..6 negative after the offset is removed.
VALID_AUDIO_IDS = torch.arange(AUDIO_BASE, AUDIO_BASE + 7 * 4096)
26
+
27
+ # ── 2. Logit‑Processor zum Maskieren ────────────────────────────────
28
class AudioLogitMask(LogitsProcessor):
    """Logits processor that keeps only a fixed set of token IDs sampleable.

    Every vocabulary entry whose index is not listed in ``allowed_ids`` has
    ``-inf`` added to its score; listed entries are left unchanged.
    """

    def __init__(self, allowed_ids: torch.Tensor):
        """Remember the token IDs that may still be sampled."""
        super().__init__()
        self.allowed = allowed_ids

    def __call__(self, input_ids, scores):
        """Return ``scores`` with all non-allowed columns pushed to ``-inf``.

        ``scores`` is indexed as ``[batch, vocab]``; ``input_ids`` is unused.
        """
        # Start from "everything forbidden", then re-open the allowed columns.
        penalty = torch.full_like(scores, float("-inf"))
        penalty[:, self.allowed] = 0
        return torch.add(scores, penalty)
38
+
39
# Tokens the model may emit while generating: every audio-code ID plus the
# two control tokens (new block, end of stream), moved to the run device.
ALLOWED_IDS = torch.cat(
    (VALID_AUDIO_IDS, torch.tensor((NEW_BLOCK_TOKEN, EOS_TOKEN))),
    dim=0,
).to(device)
# Shared processor instance handed to generate() by the endpoint.
MASKER = AudioLogitMask(ALLOWED_IDS)
43
+
44
# ── 3. FastAPI skeleton ──────────────────────────────────────────────
app = FastAPI()


@app.get("/")
async def ping():
    """Answer ``GET /`` with a small JSON status payload."""
    status = {"msg": "Orpheus‑TTS OK"}
    return status


@app.on_event("startup")
async def load_models():
    """Load tokenizer, SNAC decoder and language model when the app starts.

    The loaded objects are published as the module globals ``tok``,
    ``snac`` and ``model`` so the helpers and the endpoint can use them.
    """
    global tok, model, snac

    use_gpu = device == "cuda"

    tok = AutoTokenizer.from_pretrained(REPO)
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
    model = AutoModelForCausalLM.from_pretrained(
        REPO,
        low_cpu_mem_usage=True,
        # On CUDA place the whole model on GPU 0 in bfloat16; on CPU keep
        # the library defaults.
        device_map={"": 0} if use_gpu else None,
        torch_dtype=torch.bfloat16 if use_gpu else None,
    )

    cfg = model.config
    # Reuse the EOS token as padding token and keep the KV cache enabled.
    cfg.pad_token_id = cfg.eos_token_id
    cfg.use_cache = True
64
+
65
+ # ── 4. Hilfs‑Funktionen ─────────────────────────────────────────────
66
def build_prompt(text: str, voice: str):
    """Build the model input for one synthesis request.

    The text is prefixed with the voice name (``"<voice>: <text>"``),
    tokenized, and wrapped between ``START_TOKEN`` and two fixed trailing
    marker IDs (128009, 128260).

    Returns:
        A pair ``(input_ids, attention_mask)``; both are tensors on
        ``device`` with a leading batch dimension of 1, and the mask is
        all ones.
    """
    prompt_ids = tok(f"{voice}: {text}", return_tensors="pt").input_ids.to(device)
    head = torch.tensor([[START_TOKEN]], device=device)
    tail = torch.tensor([[128009, 128260]], device=device)
    # Concatenate along the sequence axis: start marker, prompt, end markers.
    full_ids = torch.cat((head, prompt_ids, tail), dim=1)
    attention = torch.ones_like(full_ids)
    return full_ids, attention
78
+
79
def decode_snac(block7: list[int]) -> bytes:
    """Decode one frame of seven audio codes into 16-bit PCM bytes.

    ``block7`` holds seven codes with ``AUDIO_BASE`` already subtracted.
    Slot ``k`` of the frame still carries an offset of ``k * 4096``; once
    that is removed the codes are distributed over the three SNAC layers:

        layer 1: slot 0
        layer 2: slots 1, 4
        layer 3: slots 2, 3, 5, 6

    Args:
        block7: exactly seven integers, slot ``k`` in
            ``[k * 4096, (k + 1) * 4096)``.

    Returns:
        The decoded audio as raw little-endian-native ``int16`` bytes.

    Raises:
        ValueError: if the frame does not have seven entries or a code
            falls outside the range expected for its slot.
    """
    if len(block7) != 7:
        raise ValueError(f"expected 7 codes per frame, got {len(block7)}")

    # Strip the per-slot offset so every code becomes a codebook index.
    local = [code - slot * 4096 for slot, code in enumerate(block7)]
    if any(not 0 <= code < 4096 for code in local):
        raise ValueError(f"audio code out of range for its slot: {list(block7)}")

    layers = (
        [local[0]],
        [local[1], local[4]],
        [local[2], local[3], local[5], local[6]],
    )

    # Decode without autograd: the result is only converted to NumPy, and
    # a tensor that still requires grad cannot be converted.
    with torch.inference_mode():
        codes = [torch.tensor(layer, device=device).unsqueeze(0) for layer in layers]
        audio = snac.decode(codes).squeeze()
        # Clip before scaling so samples outside [-1, 1] saturate instead
        # of wrapping around in the int16 cast.
        scaled = audio.clamp(-1.0, 1.0) * 32767

    return scaled.cpu().numpy().astype("int16").tobytes()
92
+
93
# ── 5. WebSocket endpoint ────────────────────────────────────────────
@app.websocket("/ws/tts")
async def tts(ws: WebSocket):
    """Stream synthesized speech for one request over a WebSocket.

    Protocol as implemented here: the client sends a single JSON text
    message with the optional keys ``"text"`` (default ``""``) and
    ``"voice"`` (default ``"Jakob"``).  The server replies with binary
    messages, each carrying the bytes returned by ``decode_snac`` for one
    complete 7-code frame, and closes the socket once the model emits
    ``EOS_TOKEN``.

    The socket is closed with code 1011 after an unexpected error and with
    the normal-closure code otherwise.  No close frame is sent when the
    connection is already down.
    """
    await ws.accept()
    close_code = 1000
    try:
        req = json.loads(await ws.receive_text())
        text = req.get("text", "")
        voice = req.get("voice", "Jakob")

        seq, attn = build_prompt(text, voice)
        past = None
        frame = []
        finished = False

        while not finished:
            known = seq.shape[1]
            # Every round passes the full sequence so far together with the
            # cache of the previous round; generate() then only has to
            # process the tokens the cache does not cover yet.
            # NOTE(review): generate() is synchronous and blocks the event
            # loop while it runs — consider moving it to a worker thread.
            out = model.generate(
                input_ids=seq,
                attention_mask=attn,
                past_key_values=past,
                max_new_tokens=CHUNK_TOKENS,
                logits_processor=[MASKER],
                do_sample=True, temperature=0.7, top_p=0.95,
                eos_token_id=EOS_TOKEN,
                pad_token_id=EOS_TOKEN,
                use_cache=True,
                return_dict_in_generate=True,
            )
            past = out.past_key_values
            seq = out.sequences
            attn = torch.ones_like(seq)

            # out.sequences contains the input followed by the new tokens.
            fresh = seq[0, known:].tolist()
            if not fresh:
                # Nothing new was produced; stop instead of spinning.
                break

            for t in fresh:
                if t == EOS_TOKEN:
                    finished = True
                    break
                if t == NEW_BLOCK_TOKEN:
                    # A new block starts: drop any partially filled frame.
                    frame.clear()
                    continue
                frame.append(t - AUDIO_BASE)
                if len(frame) == 7:
                    await ws.send_bytes(decode_snac(frame))
                    frame.clear()

    except WebSocketDisconnect:
        # Client went away; there is nobody left to notify.
        pass
    except Exception as e:
        print("WS‑Error:", e)
        close_code = 1011
    finally:
        # Close exactly once, and only if neither side has closed yet.
        still_open = (
            ws.client_state.name != "DISCONNECTED"
            and ws.application_state.name != "DISCONNECTED"
        )
        if still_open:
            await ws.close(code=close_code)
141
 
142
# ── 6. Local test ────────────────────────────────────────────────────
if __name__ == "__main__":
    # Started directly as a script: serve the app with uvicorn on all
    # interfaces, port 7860.
    import uvicorn

    bind = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run("app:app", **bind)