Tomtom84 committed on
Commit
f92444a
Β·
verified Β·
1 Parent(s): 9ef5e61

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -76
app.py CHANGED
@@ -1,53 +1,52 @@
1
- # app.py ─────────────────────────────────────────────────────────────
2
- import os, json, asyncio, torch
3
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
4
  from huggingface_hub import login
5
- from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessor
6
  from transformers.generation.utils import Cache
7
  from snac import SNAC
8
 
9
- # ── 0.Β HF‑Auth & Device ──────────────────────────────────────────────
10
- HF_TOKEN = os.getenv("HF_TOKEN")
11
- if HF_TOKEN:
12
- login(HF_TOKEN)
13
 
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
-
16
- # Flash‑Attention‑Bug in PyTorchΒ 2.2.x
17
- torch.backends.cuda.enable_flash_sdp(False)
18
-
19
- # ── 1.Β Konstanten ────────────────────────────────────────────────────
20
- REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
21
- CHUNK_TOKENS = 50
22
- START_TOKEN = 128259
23
- NEW_BLOCK_TOKEN = 128257
24
- EOS_TOKEN = 128258
25
- AUDIO_BASE = 128266
26
- VALID_AUDIO_IDS = torch.arange(AUDIO_BASE, AUDIO_BASE + 4096)
27
-
28
- # ── 2.Β Logit‑Processor zum Maskieren ────────────────────────────────
29
- class AudioLogitMask(LogitsProcessor):
30
- def __init__(self, allowed_ids: torch.Tensor):
31
- super().__init__()
32
- self.allowed = allowed_ids
33
-
34
- def __call__(self, input_ids, scores):
35
- # scores shape: [batch, vocab]
36
  mask = torch.full_like(scores, float("-inf"))
37
- mask[:, self.allowed] = 0
38
  return scores + mask
39
 
40
  ALLOWED_IDS = torch.cat(
41
- [VALID_AUDIO_IDS, torch.tensor([NEW_BLOCK_TOKEN, EOS_TOKEN])]
 
42
  ).to(device)
43
- MASKER = AudioLogitMask(ALLOWED_IDS)
44
 
45
- # ── 3.Β FastAPI ‑ GrundgerΓΌst ─────────────────────────────────────────
46
  app = FastAPI()
47
 
48
  @app.get("/")
49
- async def ping():
50
- return {"msg": "Orpheus‑TTS OK"}
 
 
 
51
 
52
  @app.on_event("startup")
53
  async def load_models():
@@ -63,11 +62,11 @@ async def load_models():
63
  model.config.pad_token_id = model.config.eos_token_id
64
  model.config.use_cache = True
65
 
66
- # ── 4.Β Hilfs‑Funktionen ─────────────────────────────────────────────
67
- def build_prompt(text:str, voice:str):
68
- base = f"{voice}: {text}"
69
- ids = tok(base, return_tensors="pt").input_ids.to(device)
70
- ids = torch.cat(
71
  [
72
  torch.tensor([[START_TOKEN]], device=device),
73
  ids,
@@ -77,29 +76,32 @@ def build_prompt(text:str, voice:str):
77
  )
78
  return ids, torch.ones_like(ids)
79
 
80
- def decode_snac(block7:list[int])->bytes:
81
- l1,l2,l3=[],[],[]
82
- b=block7
83
- l1.append(b[0])
84
- l2.append(b[1]-4096)
85
- l3.extend([b[2]-8192, b[3]-12288])
86
- l2.append(b[4]-16384)
87
- l3.extend([b[5]-20480, b[6]-24576])
88
-
89
- codes=[torch.tensor(x,device=device).unsqueeze(0)
90
- for x in (l1,l2,l3)]
91
- audio=snac.decode(codes).squeeze().cpu().numpy()
92
- return (audio*32767).astype("int16").tobytes()
93
-
94
- # ── 5.Β WebSocket‑Endpoint ───────────────────────────────────────────
 
 
95
  @app.websocket("/ws/tts")
96
  async def tts(ws: WebSocket):
97
  await ws.accept()
98
  try:
99
  req = json.loads(await ws.receive_text())
100
  ids, attn = build_inputs(req.get("text", ""), req.get("voice", "Jakob"))
 
101
 
102
- past = None # Cache
103
  buf = []
104
 
105
  while True:
@@ -112,14 +114,18 @@ async def tts(ws: WebSocket):
112
  do_sample=True, top_p=0.95, temperature=0.7,
113
  return_dict_in_generate=True,
114
  use_cache=True,
115
- return_legacy_cache=True, # ⚑ wichtig
116
  )
117
 
118
- # ⚑ Legacy‑Cache weitergeben
119
- past = gen.past_key_values
 
 
 
 
120
 
121
- # die tatsΓ€chlich erzeugten neuen Tokens
122
- new_tok = gen.sequences[0, -gen.num_generated_tokens :].tolist()
123
 
124
  for t in new_tok:
125
  if t == EOS_TOKEN:
@@ -132,24 +138,20 @@ async def tts(ws: WebSocket):
132
  await ws.send_bytes(decode_block(buf))
133
  buf.clear()
134
 
135
- # ab jetzt nur Cache, keine neuen IDs mehr nΓΆtig
136
- ids, attn = None, None
137
 
138
- except (StopIteration, WebSocketDisconnect):
139
- pass # normales Ende
140
  except Exception as e:
141
  print("WS‑Error:", e)
142
- if ws.client_state.name != "DISCONNECTED":
143
- await ws.close(code=1011) # Fehlercode nur, falls noch offen
144
  finally:
145
- try:
146
- if ws.client_state.name != "DISCONNECTED":
147
- await ws.close() # sauberes Close
148
- except RuntimeError:
149
- # Starlette hat bereits ein Close‑Frame verschickt
150
- pass
151
-
152
- # ── 6.Β Lokaler Test ─────────────────────────────────────────────────
153
  if __name__ == "__main__":
154
- import uvicorn
155
- uvicorn.run("app:app", host="0.0.0.0", port=7860)
 
 
1
+ # app.py -------------------------------------------------------------
2
+ import os, json, torch
3
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
4
  from huggingface_hub import login
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor
6
  from transformers.generation.utils import Cache
7
  from snac import SNAC
8
 
9
# ── 0. Auth & Device ────────────────────────────────────────────────
# Use a dedicated name for the HF token: the previous walrus target
# `tok` shadowed the module-level tokenizer handle that the startup
# hook assigns and `build_inputs` reads.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(hf_token)

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.backends.cuda.enable_flash_sdp(False)  # PyTorch 2.2 flash-SDP workaround

# ── 1. Constants ────────────────────────────────────────────────────
REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
CHUNK_TOKENS = 50         # <= 50 tokens per generate() call keeps latency under ~1 s
START_TOKEN = 128259      # prepended marker token for the prompt
NEW_BLOCK_TOKEN = 128257  # control token: starts a new 7-token audio block
EOS_TOKEN = 128258        # control token: end of generation
AUDIO_BASE = 128266       # first audio-code token id
VALID_AUDIO_IDS = torch.arange(AUDIO_BASE, AUDIO_BASE + 4096)
+
25
# ── 2. Logit mask (allow only audio and control tokens) ─────────────
class AudioMask(LogitsProcessor):
    """Forces sampling to choose only audio codes or control tokens.

    `allowed` must already live on the same device as the scores
    tensor, otherwise the advanced indexing in __call__ fails.
    """

    def __init__(self, allowed: torch.Tensor):
        # Restore the super().__init__() call the previous revision had;
        # LogitsProcessor subclasses should honor the base initializer.
        super().__init__()
        self.allowed = allowed

    def __call__(self, _ids, scores):
        # scores: [batch, vocab]. Start from -inf everywhere, then
        # re-open the allowed ids so only they can be sampled.
        mask = torch.full_like(scores, float("-inf"))
        mask[:, self.allowed] = 0.0
        return scores + mask
34
 
35
# Precompute every id generation may emit: all 4096 audio codes plus
# the two control tokens, moved once to the model device.
_control_ids = torch.tensor([NEW_BLOCK_TOKEN, EOS_TOKEN])
ALLOWED_IDS = torch.cat([VALID_AUDIO_IDS, _control_ids]).to(device)
MASKER = AudioMask(ALLOWED_IDS)
40
 
41
# ── 3. FastAPI skeleton ─────────────────────────────────────────────
app = FastAPI()


@app.get("/")
async def root():
    """Liveness probe — confirms the server process is up."""
    return {"msg": "Orpheus‑TTS ready"}


# Global model handles; populated by the startup hook below.
tok = model = snac = None
50
 
51
  @app.on_event("startup")
52
  async def load_models():
 
62
  model.config.pad_token_id = model.config.eos_token_id
63
  model.config.use_cache = True
64
 
65
+ # ── 4. Helper ───────────────────────────────────────────────────────
66
+ def build_inputs(text: str, voice: str):
67
+ prompt = f"{voice}: {text}"
68
+ ids = tok(prompt, return_tensors="pt").input_ids.to(device)
69
+ ids = torch.cat(
70
  [
71
  torch.tensor([[START_TOKEN]], device=device),
72
  ids,
 
76
  )
77
  return ids, torch.ones_like(ids)
78
 
79
def decode_block(b7: list[int]) -> bytes:
    """Decode one 7-token SNAC block into 16-bit little-endian PCM bytes.

    The seven generated tokens carry codes for SNAC's three codebook
    layers; each slot has a fixed additive offset that is removed here
    before the codes are handed to the vocoder.
    """
    layer1 = [b7[0]]
    layer2 = [b7[1] - 4096, b7[4] - 16384]
    layer3 = [b7[2] - 8192, b7[3] - 12288, b7[5] - 20480, b7[6] - 24576]

    codes = [
        torch.tensor(layer, device=device).unsqueeze(0)
        for layer in (layer1, layer2, layer3)
    ]
    audio = snac.decode(codes).squeeze().cpu().numpy()
    # float waveform in [-1, 1] -> int16 PCM byte stream
    return (audio * 32767).astype("int16").tobytes()
90
+
91
def new_tokens_only(full_seq, prev_len):
    """Return the tokens appended after position *prev_len* as a list.

    Accepts either a tensor (as `generate` returns) or a plain Python
    list — the websocket loop passes `gen.sequences[0].tolist()`, and a
    list slice has no `.tolist()`, so the previous unconditional
    `.tolist()` call raised AttributeError on every invocation.
    """
    tail = full_seq[prev_len:]
    return tail.tolist() if hasattr(tail, "tolist") else list(tail)
94
+
95
+ # ── 5. WebSocket‑Endpoint ───────────────────────────────────────────
96
  @app.websocket("/ws/tts")
97
  async def tts(ws: WebSocket):
98
  await ws.accept()
99
  try:
100
  req = json.loads(await ws.receive_text())
101
  ids, attn = build_inputs(req.get("text", ""), req.get("voice", "Jakob"))
102
+ prompt_len = ids.size(1) # LΓ€nge des Prompts
103
 
104
+ past = None
105
  buf = []
106
 
107
  while True:
 
114
  do_sample=True, top_p=0.95, temperature=0.7,
115
  return_dict_in_generate=True,
116
  use_cache=True,
117
+ return_legacy_cache=True, # wichtig <4.49
118
  )
119
 
120
+ # Cache fΓΌr den nΓ€chsten Loop
121
+ past = gen.past_key_values if not isinstance(gen.past_key_values, Cache) else gen.past_key_values.to_legacy()
122
+
123
+ seq = gen.sequences[0].tolist()
124
+ new_tok = new_tokens_only(seq, prompt_len)
125
+ prompt_len = len(seq) # nΓ€chstes Delta
126
 
127
+ if not new_tok: # (selten) nichts erzeugt β‡’ weiter
128
+ continue
129
 
130
  for t in new_tok:
131
  if t == EOS_TOKEN:
 
138
  await ws.send_bytes(decode_block(buf))
139
  buf.clear()
140
 
141
+ ids = None; attn = None # ab jetzt nur noch Cache
 
142
 
143
+ except (StopAsyncIteration, WebSocketDisconnect):
144
+ pass
145
  except Exception as e:
146
  print("WS‑Error:", e)
147
+ if ws.client_state.name == "CONNECTED":
148
+ await ws.close(code=1011)
149
  finally:
150
+ if ws.client_state.name == "CONNECTED":
151
+ await ws.close()
152
+
153
# ── 6. Local run ────────────────────────────────────────────────────
if __name__ == "__main__":
    import sys
    import uvicorn

    # Optional CLI override: `python app.py 8000`
    serve_port = int(sys.argv[1]) if len(sys.argv) > 1 else 7860
    uvicorn.run("app:app", host="0.0.0.0", port=serve_port)