Tomtom84 commited on
Commit
87012a8
Β·
verified Β·
1 Parent(s): 5d73119

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -42
app.py CHANGED
@@ -1,46 +1,47 @@
1
  # app.py ─────────────────────────────────────────────────────────────
2
- import os, json, asyncio, torch
3
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
4
  from huggingface_hub import login
5
- from transformers import (AutoTokenizer, AutoModelForCausalLM, LogitsProcessor)
6
  from transformers.generation.utils import Cache
7
  from snac import SNAC
8
 
9
- # ── 0. HF‑Login & Device ─────────────────────────────────────────────
10
  HF_TOKEN = os.getenv("HF_TOKEN")
11
  if HF_TOKEN:
12
  login(HF_TOKEN)
13
 
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
 
16
- #Β Flash‑Attention‑Bug in PyTorchΒ 2.2.x umgehen
17
  torch.backends.cuda.enable_flash_sdp(False)
18
 
19
  # ── 1. Konstanten ────────────────────────────────────────────────────
20
  REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
21
- CHUNK_TOKENS = 50 # pro mini‑generate
22
  START_TOKEN = 128259
23
  NEW_BLOCK_TOKEN = 128257
24
  EOS_TOKEN = 128258
25
- AUDIO_BASE = 128266 # erster Audio‑Code
26
  VALID_AUDIO_IDS = torch.arange(AUDIO_BASE, AUDIO_BASE + 4096)
27
 
28
  # ── 2. Dynamischer Logit‑Masker ──────────────────────────────────────
29
  class DynamicAudioMask(LogitsProcessor):
30
- """
31
- blockt EOS, bis mindestens `min_audio_blocks` gesendet wurden
32
- """
33
  def __init__(self, audio_ids: torch.Tensor, min_audio_blocks: int = 1):
34
  super().__init__()
35
- self.audio_ids = audio_ids
36
- self.ctrl_ids = torch.tensor([NEW_BLOCK_TOKEN], device=audio_ids.device)
37
- self.min_blocks = min_audio_blocks
38
- self.blocks_done = 0
39
 
40
  def __call__(self, input_ids, scores):
41
  allowed = torch.cat([self.audio_ids, self.ctrl_ids])
42
- if self.blocks_done >= self.min_blocks: # jetzt darf EOS dazu
43
- allowed = torch.cat([allowed, torch.tensor([EOS_TOKEN], device=scores.device)])
 
 
44
  mask = torch.full_like(scores, float("-inf"))
45
  mask[:, allowed] = 0
46
  return scores + mask
@@ -50,7 +51,7 @@ app = FastAPI()
50
 
51
  @app.get("/")
52
  async def ping():
53
- return {"msg": "Orpheus‑TTS up & running"}
54
 
55
  @app.on_event("startup")
56
  async def load_models():
@@ -79,30 +80,23 @@ def build_inputs(text: str, voice: str):
79
  ids = torch.cat(
80
  [ torch.tensor([[START_TOKEN]], device=device),
81
  ids,
82
- torch.tensor([[128009, 128260]], device=device) ],
83
- dim=1,
84
- )
85
- attn = torch.ones_like(ids)
86
- return ids, attn
87
 
88
  def decode_block(block7: list[int]) -> bytes:
89
  l1, l2, l3 = [], [], []
90
  b = block7
91
  l1.append(b[0])
92
  l2.append(b[1] - 4096)
93
- l3.extend([b[2] - 8192, b[3] - 12288])
94
  l2.append(b[4] - 16384)
95
  l3.extend([b[5] - 20480, b[6] - 24576])
96
 
97
- codes = [
98
- torch.tensor(l1, device=device).unsqueeze(0),
99
- torch.tensor(l2, device=device).unsqueeze(0),
100
- torch.tensor(l3, device=device).unsqueeze(0),
101
- ]
102
  audio = snac.decode(codes).squeeze().cpu().numpy()
103
  return (audio * 32767).astype("int16").tobytes()
104
 
105
- # ── 5. WebSocket‑TTS‑Endpoint ───────────────────────────────────────
106
  @app.websocket("/ws/tts")
107
  async def tts(ws: WebSocket):
108
  await ws.accept()
@@ -111,9 +105,9 @@ async def tts(ws: WebSocket):
111
  text = req.get("text", "")
112
  voice = req.get("voice", "Jakob")
113
 
114
- ids, attn = build_inputs(text, voice) # vollstΓ€ndiger Prompt
115
  past = None
116
- last_tok = None # <- NEU
117
  buf = []
118
 
119
  while True:
@@ -126,24 +120,22 @@ async def tts(ws: WebSocket):
126
  do_sample=True, temperature=0.7, top_p=0.95,
127
  return_dict_in_generate=True,
128
  use_cache=True,
129
- return_legacy_cache=True, # <- Warnung unterdrΓΌcken
130
  )
131
 
132
- # ----- Cache & neue Token --------------------------------------
133
  pkv = out.past_key_values
134
- if isinstance(pkv, Cache): # HFΒ >=Β 4.47
135
  pkv = pkv.to_legacy()
136
  past = pkv
137
 
138
- new = out.sequences[0, -out.num_generated_tokens :].tolist()
139
- print("new tokens:", new[:20]) # Debug‑Print
140
 
141
- if not new: # Safety – nichts erzeugt
142
  raise StopIteration
143
 
144
- # ----- Token‑Handling ------------------------------------------
145
- for t in new:
146
- last_tok = t # speichern fΓΌr nΓ€chste Runde
147
 
148
  if t == EOS_TOKEN:
149
  raise StopIteration
@@ -156,9 +148,9 @@ async def tts(ws: WebSocket):
156
  if len(buf) == 7:
157
  await ws.send_bytes(decode_block(buf))
158
  buf.clear()
159
- masker.blocks_done += 1 # nach 1.Β Block darf EOS
160
 
161
- # ab nΓ€chster Runde nur 1Β Token + Cache
162
  ids, attn = None, None
163
 
164
  except (StopIteration, WebSocketDisconnect):
@@ -174,7 +166,7 @@ async def tts(ws: WebSocket):
174
  except RuntimeError:
175
  pass
176
 
177
- # ── 6. Lokaler Start (uvicorn) ───────────────────────────────────────
178
  if __name__ == "__main__":
179
  import uvicorn
180
  uvicorn.run("app:app", host="0.0.0.0", port=7860)
 
1
  # app.py ─────────────────────────────────────────────────────────────
2
+ import os, json, torch, asyncio
3
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
4
  from huggingface_hub import login
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor
6
  from transformers.generation.utils import Cache
7
  from snac import SNAC
8
 
9
+ # ── 0. HF‑Auth & Device ──────────────────────────────────────────────
10
  HF_TOKEN = os.getenv("HF_TOKEN")
11
  if HF_TOKEN:
12
  login(HF_TOKEN)
13
 
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
 
16
+ # Flash‑SDP‑Bug (PyTorch 2.2) deaktivieren
17
  torch.backends.cuda.enable_flash_sdp(False)
18
 
19
  # ── 1. Konstanten ────────────────────────────────────────────────────
20
  REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
21
+ CHUNK_TOKENS = 50 # β€žMini‑Generateβ€œβ€‘LΓ€nge
22
  START_TOKEN = 128259
23
  NEW_BLOCK_TOKEN = 128257
24
  EOS_TOKEN = 128258
25
+ AUDIO_BASE = 128266
26
  VALID_AUDIO_IDS = torch.arange(AUDIO_BASE, AUDIO_BASE + 4096)
27
 
28
  # ── 2. Dynamischer Logit‑Masker ──────────────────────────────────────
29
  class DynamicAudioMask(LogitsProcessor):
30
+ """LΓ€sst zu Beginn nur Audio‑ und NEW_BLOCK‑Tokens zu;
31
+ EOS erst, wenn min_audio_blocks fertig sind."""
 
32
  def __init__(self, audio_ids: torch.Tensor, min_audio_blocks: int = 1):
33
  super().__init__()
34
+ self.audio_ids = audio_ids
35
+ self.ctrl_ids = torch.tensor([NEW_BLOCK_TOKEN], device=audio_ids.device)
36
+ self.min_blocks = min_audio_blocks
37
+ self.blocks_done = 0
38
 
39
  def __call__(self, input_ids, scores):
40
  allowed = torch.cat([self.audio_ids, self.ctrl_ids])
41
+ if self.blocks_done >= self.min_blocks:
42
+ allowed = torch.cat([allowed,
43
+ torch.tensor([EOS_TOKEN],
44
+ device=scores.device)])
45
  mask = torch.full_like(scores, float("-inf"))
46
  mask[:, allowed] = 0
47
  return scores + mask
 
51
 
52
  @app.get("/")
53
  async def ping():
54
+ return {"msg": "Orpheus‑TTS OK"}
55
 
56
  @app.on_event("startup")
57
  async def load_models():
 
80
  ids = torch.cat(
81
  [ torch.tensor([[START_TOKEN]], device=device),
82
  ids,
83
+ torch.tensor([[128009, 128260]], device=device) ], dim=1)
84
+ return ids, torch.ones_like(ids)
 
 
 
85
 
86
  def decode_block(block7: list[int]) -> bytes:
87
  l1, l2, l3 = [], [], []
88
  b = block7
89
  l1.append(b[0])
90
  l2.append(b[1] - 4096)
91
+ l3.extend([b[2] - 8192, b[3] - 12288])
92
  l2.append(b[4] - 16384)
93
  l3.extend([b[5] - 20480, b[6] - 24576])
94
 
95
+ codes = [torch.tensor(x, device=device).unsqueeze(0) for x in (l1,l2,l3)]
 
 
 
 
96
  audio = snac.decode(codes).squeeze().cpu().numpy()
97
  return (audio * 32767).astype("int16").tobytes()
98
 
99
+ # ── 5. WebSocket‑Endpoint ────────────────────────────────────────────
100
  @app.websocket("/ws/tts")
101
  async def tts(ws: WebSocket):
102
  await ws.accept()
 
105
  text = req.get("text", "")
106
  voice = req.get("voice", "Jakob")
107
 
108
+ ids, attn = build_inputs(text, voice)
109
  past = None
110
+ last_tok = None
111
  buf = []
112
 
113
  while True:
 
120
  do_sample=True, temperature=0.7, top_p=0.95,
121
  return_dict_in_generate=True,
122
  use_cache=True,
123
+ return_legacy_cache=True # verhindert Cache‑Warnung
124
  )
125
 
 
126
  pkv = out.past_key_values
127
+ if isinstance(pkv, Cache): # HFΒ β‰₯Β 4.47
128
  pkv = pkv.to_legacy()
129
  past = pkv
130
 
131
+ new_toks = out.sequences[0, -out.num_generated_tokens:].tolist()
132
+ print("new tokens:", new_toks[:32]) # Debug‑Ausgabe
133
 
134
+ if not new_toks:
135
  raise StopIteration
136
 
137
+ for t in new_toks:
138
+ last_tok = t
 
139
 
140
  if t == EOS_TOKEN:
141
  raise StopIteration
 
148
  if len(buf) == 7:
149
  await ws.send_bytes(decode_block(buf))
150
  buf.clear()
151
+ masker.blocks_done += 1
152
 
153
+ # ab jetzt nur noch 1Β Token + Cache
154
  ids, attn = None, None
155
 
156
  except (StopIteration, WebSocketDisconnect):
 
166
  except RuntimeError:
167
  pass
168
 
169
+ # ── 6. Lokaler Test‑Start ───────────────────────────────────────────
170
  if __name__ == "__main__":
171
  import uvicorn
172
  uvicorn.run("app:app", host="0.0.0.0", port=7860)