Tomtom84 commited on
Commit
0b5b901
Β·
verified Β·
1 Parent(s): 87012a8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -101
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # app.py ─────────────────────────────────────────────────────────────
2
  import os, json, torch, asyncio
3
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
4
  from huggingface_hub import login
@@ -6,109 +6,97 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor
6
  from transformers.generation.utils import Cache
7
  from snac import SNAC
8
 
9
- # ── 0. HF‑Auth & Device ──────────────────────────────────────────────
10
  HF_TOKEN = os.getenv("HF_TOKEN")
11
  if HF_TOKEN:
12
  login(HF_TOKEN)
13
 
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
-
16
- # Flash‑SDP‑Bug (PyTorch 2.2) deaktivieren
17
- torch.backends.cuda.enable_flash_sdp(False)
18
-
19
- # ── 1. Konstanten ────────────────────────────────────────────────────
20
- REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
21
- CHUNK_TOKENS = 50 # β€žMini‑Generateβ€œβ€‘LΓ€nge
22
- START_TOKEN = 128259
23
- NEW_BLOCK_TOKEN = 128257
24
- EOS_TOKEN = 128258
25
- AUDIO_BASE = 128266
26
- VALID_AUDIO_IDS = torch.arange(AUDIO_BASE, AUDIO_BASE + 4096)
27
-
28
- # ── 2. Dynamischer Logit‑Masker ──────────────────────────────────────
29
  class DynamicAudioMask(LogitsProcessor):
30
- """LΓ€sst zu Beginn nur Audio‑ und NEW_BLOCK‑Tokens zu;
31
- EOS erst, wenn min_audio_blocks fertig sind."""
32
- def __init__(self, audio_ids: torch.Tensor, min_audio_blocks: int = 1):
33
  super().__init__()
34
- self.audio_ids = audio_ids
35
- self.ctrl_ids = torch.tensor([NEW_BLOCK_TOKEN], device=audio_ids.device)
36
- self.min_blocks = min_audio_blocks
37
- self.blocks_done = 0
38
-
39
- def __call__(self, input_ids, scores):
40
- allowed = torch.cat([self.audio_ids, self.ctrl_ids])
41
- if self.blocks_done >= self.min_blocks:
42
- allowed = torch.cat([allowed,
43
- torch.tensor([EOS_TOKEN],
44
- device=scores.device)])
45
  mask = torch.full_like(scores, float("-inf"))
46
- mask[:, allowed] = 0
47
  return scores + mask
48
 
49
- # ── 3. FastAPI GrundgerΓΌst ───────────────────────────────────────────
50
  app = FastAPI()
51
 
52
  @app.get("/")
53
- async def ping():
54
- return {"msg": "Orpheus‑TTS OK"}
55
 
56
  @app.on_event("startup")
57
- async def load_models():
58
  global tok, model, snac, masker
59
  print("⏳ Lade Modelle …")
60
-
61
- tok = AutoTokenizer.from_pretrained(REPO)
62
- snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
63
-
64
  model = AutoModelForCausalLM.from_pretrained(
65
  REPO,
66
  low_cpu_mem_usage=True,
67
- device_map={"": 0} if device == "cuda" else None,
68
- torch_dtype=torch.bfloat16 if device == "cuda" else None,
69
- )
70
  model.config.pad_token_id = model.config.eos_token_id
71
  model.config.use_cache = True
72
-
73
- masker = DynamicAudioMask(VALID_AUDIO_IDS.to(device))
74
  print("βœ…Β Modelle geladen")
75
 
76
- # ── 4. Hilfs‑Funktionen ──────────────────────────────────────────────
77
- def build_inputs(text: str, voice: str):
78
  prompt = f"{voice}: {text}"
79
  ids = tok(prompt, return_tensors="pt").input_ids.to(device)
80
- ids = torch.cat(
81
- [ torch.tensor([[START_TOKEN]], device=device),
82
- ids,
83
- torch.tensor([[128009, 128260]], device=device) ], dim=1)
84
  return ids, torch.ones_like(ids)
85
 
86
- def decode_block(block7: list[int]) -> bytes:
87
- l1, l2, l3 = [], [], []
88
- b = block7
89
- l1.append(b[0])
90
- l2.append(b[1] - 4096)
91
- l3.extend([b[2] - 8192, b[3] - 12288])
92
- l2.append(b[4] - 16384)
93
- l3.extend([b[5] - 20480, b[6] - 24576])
94
-
95
- codes = [torch.tensor(x, device=device).unsqueeze(0) for x in (l1,l2,l3)]
96
- audio = snac.decode(codes).squeeze().cpu().numpy()
97
- return (audio * 32767).astype("int16").tobytes()
98
-
99
- # ── 5. WebSocket‑Endpoint ────────────────────────────────────────────
100
  @app.websocket("/ws/tts")
101
- async def tts(ws: WebSocket):
102
  await ws.accept()
103
  try:
104
- req = json.loads(await ws.receive_text())
105
- text = req.get("text", "")
106
- voice = req.get("voice", "Jakob")
107
 
108
- ids, attn = build_inputs(text, voice)
109
- past = None
110
- last_tok = None
111
- buf = []
 
112
 
113
  while True:
114
  out = model.generate(
@@ -118,40 +106,32 @@ async def tts(ws: WebSocket):
118
  max_new_tokens = CHUNK_TOKENS,
119
  logits_processor= [masker],
120
  do_sample=True, temperature=0.7, top_p=0.95,
121
- return_dict_in_generate=True,
122
- use_cache=True,
123
- return_legacy_cache=True # verhindert Cache‑Warnung
124
- )
125
 
126
  pkv = out.past_key_values
127
- if isinstance(pkv, Cache): # HFΒ β‰₯Β 4.47
128
- pkv = pkv.to_legacy()
129
  past = pkv
130
 
131
- new_toks = out.sequences[0, -out.num_generated_tokens:].tolist()
132
- print("new tokens:", new_toks[:32]) # Debug‑Ausgabe
 
 
133
 
134
- if not new_toks:
135
  raise StopIteration
136
 
137
- for t in new_toks:
138
  last_tok = t
139
-
140
- if t == EOS_TOKEN:
141
- raise StopIteration
142
-
143
- if t == NEW_BLOCK_TOKEN:
144
- buf.clear()
145
- continue
146
-
147
- buf.append(t - AUDIO_BASE)
148
- if len(buf) == 7:
149
  await ws.send_bytes(decode_block(buf))
150
  buf.clear()
151
- masker.blocks_done += 1
152
-
153
- # ab jetzt nur noch 1Β Token + Cache
154
- ids, attn = None, None
155
 
156
  except (StopIteration, WebSocketDisconnect):
157
  pass
@@ -161,12 +141,10 @@ async def tts(ws: WebSocket):
161
  await ws.close(code=1011)
162
  finally:
163
  if ws.client_state.name != "DISCONNECTED":
164
- try:
165
- await ws.close()
166
- except RuntimeError:
167
- pass
168
 
169
- # ── 6. Lokaler Test‑Start ───────────────────────────────────────────
170
  if __name__ == "__main__":
171
  import uvicorn
172
  uvicorn.run("app:app", host="0.0.0.0", port=7860)
 
1
+ # app.py ──────────────────────────────────────────────────────────────
2
  import os, json, torch, asyncio
3
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
4
  from huggingface_hub import login
 
6
  from transformers.generation.utils import Cache
7
  from snac import SNAC
8
 
9
+ # ── 0Β Β·Β Login & Device ───────────────────────────────────────────────
10
  HF_TOKEN = os.getenv("HF_TOKEN")
11
  if HF_TOKEN:
12
  login(HF_TOKEN)
13
 
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
+ torch.backends.cuda.enable_flash_sdp(False) #Β CUDA‑Assert‑Fix
16
+
17
+ # ── 1Β Β·Β Konstanten ───────────────────────────────────────────────────
18
+ REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
19
+ CHUNK_TOKENS = 50
20
+ START_TOKEN = 128259
21
+ NEW_BLOCK = 128257
22
+ EOS_TOKEN = 128258
23
+ AUDIO_BASE = 128266
24
+ VALID_AUDIO = torch.arange(AUDIO_BASE, AUDIO_BASE+4096)
25
+
26
+ # ── 2Β Β·Β Logit‑Masker ─────────────────────────────────────────────────
 
 
27
  class DynamicAudioMask(LogitsProcessor):
28
+ def __init__(self, audio_ids: torch.Tensor, min_blocks:int=1):
 
 
29
  super().__init__()
30
+ self.audio_ids = audio_ids
31
+ self.ctrl_ids = torch.tensor([NEW_BLOCK], device=audio_ids.device)
32
+ self.min_blocks = min_blocks
33
+ self.blocks = 0
34
+ def __call__(self, inp, scores):
35
+ allow = torch.cat([self.audio_ids, self.ctrl_ids])
36
+ if self.blocks >= self.min_blocks:
37
+ allow = torch.cat([allow,
38
+ torch.tensor([EOS_TOKEN], device=scores.device)])
 
 
39
  mask = torch.full_like(scores, float("-inf"))
40
+ mask[:, allow] = 0
41
  return scores + mask
42
 
43
+ # ── 3Β Β·Β FastAPI‑App ──────────────────────────────────────────────────
44
  app = FastAPI()
45
 
46
  @app.get("/")
47
+ async def root():
48
+ return {"msg": "Orpheus‑TTS alive"}
49
 
50
  @app.on_event("startup")
51
+ async def load():
52
  global tok, model, snac, masker
53
  print("⏳ Lade Modelle …")
54
+ tok = AutoTokenizer.from_pretrained(REPO)
55
+ snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
 
 
56
  model = AutoModelForCausalLM.from_pretrained(
57
  REPO,
58
  low_cpu_mem_usage=True,
59
+ device_map={"":0} if device=="cuda" else None,
60
+ torch_dtype=torch.bfloat16 if device=="cuda" else None)
 
61
  model.config.pad_token_id = model.config.eos_token_id
62
  model.config.use_cache = True
63
+ masker = DynamicAudioMask(VALID_AUDIO.to(device))
 
64
  print("βœ…Β Modelle geladen")
65
 
66
+ # ── 4Β Β·Β Hilfsfunktionen ──────────────────────────────────────────────
67
+ def build_inputs(text:str, voice:str):
68
  prompt = f"{voice}: {text}"
69
  ids = tok(prompt, return_tensors="pt").input_ids.to(device)
70
+ ids = torch.cat([torch.tensor([[START_TOKEN]], device=device),
71
+ ids,
72
+ torch.tensor([[128009,128260]], device=device)],1)
 
73
  return ids, torch.ones_like(ids)
74
 
75
+ def decode_block(block):
76
+ l1,l2,l3=[],[],[]
77
+ l1.append(block[0])
78
+ l2.append(block[1]-4096)
79
+ l3.extend([block[2]-8192, block[3]-12288])
80
+ l2.append(block[4]-16384)
81
+ l3.extend([block[5]-20480, block[6]-24576])
82
+ codes=[torch.tensor(x,device=device).unsqueeze(0) for x in (l1,l2,l3)]
83
+ audio=snac.decode(codes).squeeze().cpu().numpy()
84
+ return (audio*32767).astype("int16").tobytes()
85
+
86
+ # ── 5Β Β·Β WebSocket‑TTS ────────────────────────────────────────────────
 
 
87
  @app.websocket("/ws/tts")
88
+ async def tts(ws:WebSocket):
89
  await ws.accept()
90
  try:
91
+ req = json.loads(await ws.receive_text())
92
+ text = req.get("text","")
93
+ voice = req.get("voice","Jakob")
94
 
95
+ ids, attn = build_inputs(text, voice)
96
+ total_len = ids.shape[1] # LΓ€nge des Prompts
97
+ past = None
98
+ last_tok = None
99
+ buf = []
100
 
101
  while True:
102
  out = model.generate(
 
106
  max_new_tokens = CHUNK_TOKENS,
107
  logits_processor= [masker],
108
  do_sample=True, temperature=0.7, top_p=0.95,
109
+ use_cache=True, return_dict_in_generate=True,
110
+ return_legacy_cache=True)
 
 
111
 
112
  pkv = out.past_key_values
113
+ if isinstance(pkv, Cache): pkv = pkv.to_legacy()
 
114
  past = pkv
115
 
116
+ seq = out.sequences[0].tolist()
117
+ new = seq[total_len:] # alles *nach* Prompt
118
+ total_len = len(seq) # fΓΌrs nΓ€chste Mal
119
+ print("new tokens:", new[:32])
120
 
121
+ if not new: # nichts generiert
122
  raise StopIteration
123
 
124
+ for t in new:
125
  last_tok = t
126
+ if t == EOS_TOKEN: raise StopIteration
127
+ if t == NEW_BLOCK:
128
+ buf.clear(); continue
129
+ buf.append(t-AUDIO_BASE)
130
+ if len(buf)==7:
 
 
 
 
 
131
  await ws.send_bytes(decode_block(buf))
132
  buf.clear()
133
+ masker.blocks += 1
134
+ ids, attn = None, None # ab jetzt 1‑Token‑Step
 
 
135
 
136
  except (StopIteration, WebSocketDisconnect):
137
  pass
 
141
  await ws.close(code=1011)
142
  finally:
143
  if ws.client_state.name != "DISCONNECTED":
144
+ try: await ws.close()
145
+ except RuntimeError: pass
 
 
146
 
147
+ # ── 6Β Β·Β local run ────────────────────────────────────────────────────
148
  if __name__ == "__main__":
149
  import uvicorn
150
  uvicorn.run("app:app", host="0.0.0.0", port=7860)