Tomtom84 committed on
Commit
5031731
·
verified Β·
1 Parent(s): 1bad3fb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -76
app.py CHANGED
@@ -1,149 +1,173 @@
1
- # app.py ──────────────────────────────────────────────────────────────
2
- import os, json, asyncio, torch, logging
3
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
4
  from huggingface_hub import login
5
- from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor
6
- from transformers.generation.utils import Cache
7
  from snac import SNAC
8
 
9
- # ── 0. Auth & Device ────────────────────────────────────────────────
10
  HF_TOKEN = os.getenv("HF_TOKEN")
11
  if HF_TOKEN:
12
  login(HF_TOKEN)
13
 
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
- torch.backends.cuda.enable_flash_sdp(False) # Flash‑Bug umgehen
16
- logging.getLogger("transformers.generation.utils").setLevel("ERROR")
17
 
18
- # ── 1. Konstanten ───────────────────────────────────────────────────
19
- MODEL_REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
20
- CHUNK_TOKENS = 50
21
-
22
- START_TOKEN = 128259 # <𝑠>
23
- NEW_BLOCK_TOKEN = 128257 # πŸ”Šβ€‘Start
24
- EOS_TOKEN = 128258 # <eos>
25
- PROMPT_END = [128009, 128260]
26
- AUDIO_BASE = 128266
27
 
 
 
 
 
 
 
 
28
  VALID_AUDIO_IDS = torch.arange(AUDIO_BASE, AUDIO_BASE + 4096)
29
 
30
- # ── 2. Logit‑Masker ─────────────────────────────────────────────────
31
- class AudioMask(LogitsProcessor):
32
- def __init__(self, allowed: torch.Tensor):
 
 
 
33
  super().__init__()
34
- self.allowed = allowed
 
 
 
35
 
36
  def __call__(self, input_ids, scores):
 
 
 
37
  mask = torch.full_like(scores, float("-inf"))
38
- mask[:, self.allowed] = 0
39
  return scores + mask
40
 
41
- ALLOWED_IDS = torch.cat([
42
- VALID_AUDIO_IDS,
43
- torch.tensor([START_TOKEN, NEW_BLOCK_TOKEN, EOS_TOKEN])
44
- ]).to(device)
45
- MASKER = AudioMask(ALLOWED_IDS)
46
-
47
- # ── 3. FastAPI GrundgerΓΌst ──────────────────────────────────────────
48
  app = FastAPI()
49
 
50
  @app.get("/")
51
  async def ping():
52
- return {"message": "Orpheus‑TTSΒ ready"}
53
 
54
  @app.on_event("startup")
55
  async def load_models():
56
- global tok, model, snac
57
- tok = AutoTokenizer.from_pretrained(MODEL_REPO)
 
 
 
 
58
  model = AutoModelForCausalLM.from_pretrained(
59
- MODEL_REPO,
60
  low_cpu_mem_usage=True,
61
  device_map={"": 0} if device == "cuda" else None,
62
  torch_dtype=torch.bfloat16 if device == "cuda" else None,
63
  )
64
  model.config.pad_token_id = model.config.eos_token_id
65
- snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
 
 
 
66
 
67
- # ── 4. Hilfsfunktionen ──────────────────────────────────────────────
68
  def build_inputs(text: str, voice: str):
69
- prompt = f"{voice}: {text}" if voice and voice != "in_prompt" else text
70
  ids = tok(prompt, return_tensors="pt").input_ids.to(device)
71
- ids = torch.cat([
72
- torch.tensor([[START_TOKEN]], device=device),
73
- ids,
74
- torch.tensor([PROMPT_END], device=device)
75
- ], 1)
76
- mask = torch.ones_like(ids)
77
- return ids, mask # shape (1,Β L)
 
78
 
79
  def decode_block(block7: list[int]) -> bytes:
80
  l1, l2, l3 = [], [], []
81
  b = block7
82
  l1.append(b[0])
83
- l2.append(b[1] - 4096)
84
- l3 += [b[2]-8192, b[3]-12288]
85
  l2.append(b[4] - 16384)
86
- l3 += [b[5]-20480, b[6]-24576]
87
 
88
- codes = [torch.tensor(x, device=device).unsqueeze(0) for x in (l1, l2, l3)]
 
 
 
 
89
  audio = snac.decode(codes).squeeze().cpu().numpy()
90
  return (audio * 32767).astype("int16").tobytes()
91
 
92
- # ── 5. WebSocket‑Endpoint ───────────────────────────────────────────
93
  @app.websocket("/ws/tts")
94
  async def tts(ws: WebSocket):
95
  await ws.accept()
96
  try:
97
  req = json.loads(await ws.receive_text())
98
- ids, attn = build_inputs(req.get("text", ""), req.get("voice", "Jakob"))
99
- prompt_len = ids.size(1)
100
- past, buf = None, []
 
 
 
101
 
102
  while True:
103
  out = model.generate(
104
- input_ids=ids if past is None else None,
105
- attention_mask=attn if past is None else None,
106
- past_key_values=past,
107
- max_new_tokens=CHUNK_TOKENS,
108
- logits_processor=[MASKER],
109
- do_sample=True, top_p=0.95, temperature=0.7,
110
  return_dict_in_generate=True,
111
  use_cache=True,
112
- return_legacy_cache=True, # β‡  Warnung verschwindet
113
  )
114
 
115
- past = out.past_key_values # unverΓ€ndert weiterreichen
116
- seq = out.sequences[0].tolist()
117
- new = seq[prompt_len:]; prompt_len = len(seq)
 
 
118
 
119
- if not new: # selten, aber mΓΆglich
120
- continue
121
 
 
122
  for t in new:
123
  if t == EOS_TOKEN:
124
- await ws.close()
125
- return
126
  if t == NEW_BLOCK_TOKEN:
127
- buf.clear(); continue
128
- if t < AUDIO_BASE: # sollte durch Maske nie passieren
129
  continue
 
130
  buf.append(t - AUDIO_BASE)
131
  if len(buf) == 7:
132
  await ws.send_bytes(decode_block(buf))
133
  buf.clear()
 
134
 
135
- # Ab jetzt nur noch Cache – IDs & Mask nicht mehr nΓΆtig
136
- ids = attn = None
137
 
138
- except WebSocketDisconnect:
139
  pass
140
  except Exception as e:
141
- print("WS‑Error:", e)
142
- if ws.client_state.name == "CONNECTED":
143
  await ws.close(code=1011)
144
-
145
- # ── 6. Lokaler Start ────────────────────────────────────────────────
 
 
 
 
 
 
146
  if __name__ == "__main__":
147
- import uvicorn, sys
148
- port = int(sys.argv[1]) if len(sys.argv) > 1 else 7860
149
- uvicorn.run("app:app", host="0.0.0.0", port=port)
 
1
+ # app.py ─────────────────────────────────────────────────────────────
2
+ import os, json, asyncio, torch
3
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
4
  from huggingface_hub import login
5
+ from transformers import (AutoTokenizer, AutoModelForCausalLM, LogitsProcessor)
6
+ from transformers.generation.utils import Cache
7
  from snac import SNAC
8
 
9
# ── 0. HF login & device selection ──────────────────────────────────
# Authenticate against the Hugging Face Hub when a token is supplied
# via the environment (needed for the gated model repo).
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)

# All models and tensors below are placed on this device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Work around a Flash-Attention bug in PyTorch 2.2.x by disabling
# the flash scaled-dot-product kernel.
torch.backends.cuda.enable_flash_sdp(False)
 
 
 
 
 
 
 
18
 
19
# ── 1. Constants ─────────────────────────────────────────────────────
REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
CHUNK_TOKENS = 50          # new tokens per mini-generate step
START_TOKEN = 128259       # prompt start marker
NEW_BLOCK_TOKEN = 128257   # marks the start of a new audio block
EOS_TOKEN = 128258         # end of speech
AUDIO_BASE = 128266        # first audio-code token id
VALID_AUDIO_IDS = torch.arange(AUDIO_BASE, AUDIO_BASE + 4096)  # 4096 audio codes
27
 
28
# ── 2. Dynamic logit masker ──────────────────────────────────────────
class DynamicAudioMask(LogitsProcessor):
    """Restrict sampling to audio/control tokens during TTS decoding.

    EOS is blocked until at least ``min_audio_blocks`` complete audio
    blocks have been produced; the caller signals progress by
    incrementing ``blocks_done`` after each decoded block.
    """

    def __init__(self, audio_ids: torch.Tensor, min_audio_blocks: int = 1):
        super().__init__()
        self.audio_ids = audio_ids
        self.ctrl_ids = torch.tensor([NEW_BLOCK_TOKEN], device=audio_ids.device)
        self.min_blocks = min_audio_blocks
        self.blocks_done = 0  # bumped externally after each sent audio block
        # Hoisted out of __call__: the two possible allowed-id sets never
        # change, so build them once instead of torch.cat-ing (and
        # allocating a fresh EOS tensor) on every decoding step.
        self._allowed_no_eos = torch.cat([self.audio_ids, self.ctrl_ids])
        self._allowed_with_eos = torch.cat(
            [self._allowed_no_eos,
             torch.tensor([EOS_TOKEN], device=audio_ids.device)]
        )

    def __call__(self, input_ids, scores):
        # EOS only becomes sampleable once enough audio blocks are done.
        allowed = (
            self._allowed_with_eos
            if self.blocks_done >= self.min_blocks
            else self._allowed_no_eos
        )
        # Defensive: the original indexed with scores.device-resident ids;
        # move ours over in the (unlikely) case the devices differ.
        if allowed.device != scores.device:
            allowed = allowed.to(scores.device)
        mask = torch.full_like(scores, float("-inf"))
        mask[:, allowed] = 0
        return scores + mask
47
 
48
# ── 3. FastAPI scaffolding ───────────────────────────────────────────
app = FastAPI()
50
 
51
@app.get("/")
async def ping():
    """Health-check endpoint confirming the service is up."""
    status = {"msg": "Orpheus‑TTS up & running"}
    return status
54
 
55
@app.on_event("startup")
async def load_models():
    """Load tokenizer, TTS model, SNAC vocoder and the logit masker once.

    Publishes ``tok``, ``model``, ``snac`` and ``masker`` as module-level
    globals consumed by the request handlers below.
    """
    global tok, model, snac, masker
    print("⏳ Lade Modelle …")

    tok = AutoTokenizer.from_pretrained(REPO)
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)

    model = AutoModelForCausalLM.from_pretrained(
        REPO,
        low_cpu_mem_usage=True,
        # On GPU: place the whole model on device 0 and use bf16 weights.
        device_map={"": 0} if device == "cuda" else None,
        torch_dtype=torch.bfloat16 if device == "cuda" else None,
    )
    # The model ships without a pad token; reuse EOS so generate() is happy.
    model.config.pad_token_id = model.config.eos_token_id
    model.config.use_cache = True  # KV cache for incremental chunked decoding

    masker = DynamicAudioMask(VALID_AUDIO_IDS.to(device))
    print("✅ Modelle geladen")
74
 
75
# ── 4. Helper functions ──────────────────────────────────────────────
def build_inputs(text: str, voice: str):
    """Tokenize ``voice: text`` and frame it as an Orpheus TTS prompt.

    Returns ``(input_ids, attention_mask)``, both shaped (1, L) on `device`.
    """
    prompt = f"{voice}: {text}"
    token_ids = tok(prompt, return_tensors="pt").input_ids.to(device)
    # Frame the tokenized prompt: START_TOKEN opens it, and 128009/128260
    # are the prompt-end marker tokens expected by the model.
    start = torch.tensor([[START_TOKEN]], device=device)
    end = torch.tensor([[128009, 128260]], device=device)
    framed = torch.cat([start, token_ids, end], dim=1)
    return framed, torch.ones_like(framed)
87
 
88
def decode_block(block7: list[int]) -> bytes:
    """Decode 7 audio codes into 16-bit PCM bytes via the SNAC vocoder.

    The seven codes interleave SNAC's three codebook levels; each slot
    carries a fixed offset (a multiple of 4096) that is subtracted here.
    """
    b = block7
    level1 = [b[0]]
    level2 = [b[1] - 4096, b[4] - 16384]
    level3 = [b[2] - 8192, b[3] - 12288, b[5] - 20480, b[6] - 24576]

    snac_codes = [
        torch.tensor(level1, device=device).unsqueeze(0),
        torch.tensor(level2, device=device).unsqueeze(0),
        torch.tensor(level3, device=device).unsqueeze(0),
    ]
    waveform = snac.decode(snac_codes).squeeze().cpu().numpy()
    # Scale float [-1, 1] samples to int16 PCM.
    return (waveform * 32767).astype("int16").tobytes()
104
 
105
# ── 5. WebSocket TTS endpoint ───────────────────────────────────────
@app.websocket("/ws/tts")
async def tts(ws: WebSocket):
    """Stream synthesized speech over a WebSocket.

    Protocol: the client sends one JSON message {"text": ..., "voice": ...};
    the server streams raw int16-PCM audio blocks as binary frames and
    closes the socket when the model emits EOS.
    """
    await ws.accept()
    try:
        req = json.loads(await ws.receive_text())
        text = req.get("text", "")
        voice = req.get("voice", "Jakob")

        ids, attn = build_inputs(text, voice)
        consumed = ids.size(1)  # tokens already fed to / returned by the model
        past = None
        buf = []
        # FIX: masker is a module-level singleton — without this reset,
        # blocks_done accumulates across requests and EOS is allowed
        # immediately on every connection after the first.
        masker.blocks_done = 0
        finished = False

        while not finished:
            out = model.generate(
                input_ids=ids if past is None else None,
                attention_mask=attn if past is None else None,
                past_key_values=past,
                max_new_tokens=CHUNK_TOKENS,
                logits_processor=[masker],  # dynamic audio/EOS mask
                do_sample=True, temperature=0.7, top_p=0.95,
                return_dict_in_generate=True,
                use_cache=True,
            )

            # Carry the KV cache into the next mini-generate step.
            pkv = out.past_key_values
            if isinstance(pkv, Cache):       # HF >= 4.47 returns a Cache object
                pkv = pkv.to_legacy_cache()  # FIX: API is to_legacy_cache(), not to_legacy()
            past = pkv

            # FIX: generate outputs have no `num_generated_tokens` attribute;
            # slice off the prefix we have already consumed instead.
            seq = out.sequences[0].tolist()
            new = seq[consumed:]
            consumed = len(seq)
            print("new tokens:", new[:20])  # debug output

            for t in new:
                if t == EOS_TOKEN:
                    # FIX: end via a flag instead of raising StopIteration
                    # inside a coroutine (fragile control flow).
                    finished = True
                    break
                if t == NEW_BLOCK_TOKEN:
                    buf.clear()
                    continue
                buf.append(t - AUDIO_BASE)
                if len(buf) == 7:  # one complete SNAC block
                    await ws.send_bytes(decode_block(buf))
                    buf.clear()
                    masker.blocks_done += 1  # EOS may now be sampled

            # Subsequent steps run purely from the cache — no new ids needed.
            ids, attn = None, None

    except WebSocketDisconnect:
        pass
    except Exception as e:
        print("❌ WS‑Error:", e)
        if ws.client_state.name != "DISCONNECTED":
            await ws.close(code=1011)
    finally:
        if ws.client_state.name != "DISCONNECTED":
            try:
                await ws.close()
            except RuntimeError:
                pass  # close frame was already sent
169
+
170
# ── 6. Local start (uvicorn) ─────────────────────────────────────────
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860)