Tomtom84 committed
Commit d44e840 · verified · 1 Parent(s): 0b5b901

Update app.py

Files changed (1): app.py +58 -61
app.py CHANGED
@@ -6,46 +6,47 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor
 from transformers.generation.utils import Cache
 from snac import SNAC
 
-# ── 0 · Login & Device ───────────────────────────────────────────────
+# 0 · Auth & Device ---------------------------------------------------
 HF_TOKEN = os.getenv("HF_TOKEN")
 if HF_TOKEN:
     login(HF_TOKEN)
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
-torch.backends.cuda.enable_flash_sdp(False)   # CUDA assert fix
-
-# ── 1 · Constants ────────────────────────────────────────────────────
-REPO         = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
-CHUNK_TOKENS = 50
-START_TOKEN  = 128259
-NEW_BLOCK    = 128257
-EOS_TOKEN    = 128258
-AUDIO_BASE   = 128266
-VALID_AUDIO  = torch.arange(AUDIO_BASE, AUDIO_BASE + 4096)
-
-# ── 2 · Logit masker ─────────────────────────────────────────────────
-class DynamicAudioMask(LogitsProcessor):
+torch.backends.cuda.enable_flash_sdp(False)   # SDP assert fix
+
+# 1 · Constants --------------------------------------------------------
+REPO         = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
+CHUNK_TOKENS = 50
+START_TOKEN  = 128259
+NEW_BLOCK    = 128257
+EOS_TOKEN    = 128258
+AUDIO_BASE   = 128266
+AUDIO_SPAN   = 4096 * 7   # 28,672 codes
+VALID_AUDIO  = torch.arange(AUDIO_BASE, AUDIO_BASE + AUDIO_SPAN)
+
+# 2 · Logit masker ------------------------------------------------------
+class DynamicMask(LogitsProcessor):
     def __init__(self, audio_ids: torch.Tensor, min_blocks: int = 1):
         super().__init__()
-        self.audio_ids  = audio_ids
-        self.ctrl_ids   = torch.tensor([NEW_BLOCK], device=audio_ids.device)
-        self.min_blocks = min_blocks
-        self.blocks     = 0
-    def __call__(self, inp, scores):
+        self.audio_ids = audio_ids
+        self.ctrl_ids  = torch.tensor([NEW_BLOCK], device=audio_ids.device)
+        self.blocks    = 0
+        self.min_blk   = min_blocks
+    def __call__(self, inp_ids, scores):
         allow = torch.cat([self.audio_ids, self.ctrl_ids])
-        if self.blocks >= self.min_blocks:
+        if self.blocks >= self.min_blk:
            allow = torch.cat([allow,
                               torch.tensor([EOS_TOKEN], device=scores.device)])
        mask = torch.full_like(scores, float("-inf"))
        mask[:, allow] = 0
        return scores + mask
 
-# ── 3 · FastAPI app ──────────────────────────────────────────────────
+# 3 · FastAPI app -------------------------------------------------------
 app = FastAPI()
 
 @app.get("/")
 async def root():
-    return {"msg": "Orpheus-TTS alive"}
+    return {"msg": "Orpheus-TTS online"}
 
 @app.on_event("startup")
 async def load():
@@ -57,13 +58,14 @@ async def load():
         REPO,
         low_cpu_mem_usage=True,
         device_map={"": 0} if device == "cuda" else None,
-        torch_dtype=torch.bfloat16 if device == "cuda" else None)
+        torch_dtype=torch.bfloat16 if device == "cuda" else None,
+    )
     model.config.pad_token_id = model.config.eos_token_id
     model.config.use_cache = True
-    masker = DynamicAudioMask(VALID_AUDIO.to(device))
+    masker = DynamicMask(VALID_AUDIO.to(device))
     print("✅ Models loaded")
 
-# ── 4 · Helper functions ─────────────────────────────────────────────
+# 4 · Helper functions --------------------------------------------------
 def build_inputs(text: str, voice: str):
     prompt = f"{voice}: {text}"
     ids = tok(prompt, return_tensors="pt").input_ids.to(device)
@@ -72,66 +74,61 @@ def build_inputs(text:str, voice:str):
                      torch.tensor([[128009, 128260]], device=device)], 1)
     return ids, torch.ones_like(ids)
 
-def decode_block(block):
+def decode_block(b):
     l1, l2, l3 = [], [], []
-    l1.append(block[0])
-    l2.append(block[1] - 4096)
-    l3.extend([block[2] - 8192, block[3] - 12288])
-    l2.append(block[4] - 16384)
-    l3.extend([block[5] - 20480, block[6] - 24576])
+    l1.append(b[0])
+    l2.append(b[1] - 4096)
+    l3 += [b[2] - 8192,  b[3] - 12288]
+    l2.append(b[4] - 16384)
+    l3 += [b[5] - 20480, b[6] - 24576]
     codes = [torch.tensor(x, device=device).unsqueeze(0) for x in (l1, l2, l3)]
     audio = snac.decode(codes).squeeze().cpu().numpy()
     return (audio * 32767).astype("int16").tobytes()
 
-# ── 5 · WebSocket TTS ────────────────────────────────────────────────
+# 5 · WebSocket endpoint ------------------------------------------------
 @app.websocket("/ws/tts")
-async def tts(ws:WebSocket):
+async def tts(ws: WebSocket):
     await ws.accept()
     try:
-        req = json.loads(await ws.receive_text())
-        text = req.get("text", "")
-        voice = req.get("voice", "Jakob")
-
-        ids, attn = build_inputs(text, voice)
-        total_len = ids.shape[1]   # prompt length
-        past = None
-        last_tok = None
-        buf = []
+        req = json.loads(await ws.receive_text())
+        ids, attn = build_inputs(req.get("text", ""), req.get("voice", "Jakob"))
+        past, last_tok, buf = None, None, []
+        prompt_len = ids.shape[1]
 
         while True:
+            print(f"DEBUG: Before generate - past is None: {past is None}")  # Added logging
             out = model.generate(
-                input_ids       = ids if past is None else torch.tensor([[last_tok]], device=device),
-                attention_mask  = attn if past is None else None,
-                past_key_values = past,
-                max_new_tokens  = CHUNK_TOKENS,
-                logits_processor= [masker],
+                input_ids       = ids if past is None else torch.tensor([[last_tok]], device=device),
+                attention_mask  = attn if past is None else None,
+                past_key_values = past,
+                max_new_tokens  = CHUNK_TOKENS,
+                logits_processor=[masker],
                 do_sample=True, temperature=0.7, top_p=0.95,
                 use_cache=True, return_dict_in_generate=True,
                 return_legacy_cache=True)
-
+            print(f"DEBUG: After generate - type of out.past_key_values: {type(out.past_key_values)}")  # Added logging
            pkv = out.past_key_values
+            print(f"DEBUG: After getting pkv - type of pkv: {type(pkv)}")  # Added logging
            if isinstance(pkv, Cache): pkv = pkv.to_legacy_cache()
            past = pkv
+            print(f"DEBUG: After cache handling - past is None: {past is None}")  # Added logging
 
-            seq = out.sequences[0].tolist()
-            new = seq[total_len:]   # everything *after* the prompt
-            total_len = len(seq)    # for the next iteration
-            print("new tokens:", new[:32])
-
-            if not new:             # nothing generated
-                raise StopIteration
+            seq = out.sequences[0].tolist()
+            new = seq[prompt_len:]; prompt_len = len(seq)
+            print("new tokens:", new[:25])
 
+            if not new: raise StopIteration
            for t in new:
                last_tok = t
                if t == EOS_TOKEN: raise StopIteration
-                if t == NEW_BLOCK:
-                    buf.clear(); continue
-                buf.append(t - AUDIO_BASE)
-                if len(buf) == 7:
+                if t == NEW_BLOCK: buf.clear(); continue
+                buf.append(t - AUDIO_BASE)
+                if len(buf) == 7:
                    await ws.send_bytes(decode_block(buf))
                    buf.clear()
                    masker.blocks += 1
-            ids, attn = None, None   # single-token steps from here on
+
+            ids, attn = None, None   # single-token steps from here on
 
    except (StopIteration, WebSocketDisconnect):
        pass
@@ -144,7 +141,7 @@ async def tts(ws:WebSocket):
     try: await ws.close()
     except RuntimeError: pass
 
-# ── 6 · local run ────────────────────────────────────────────────────
+# 6 · Local run ---------------------------------------------------------
 if __name__ == "__main__":
     import uvicorn
     uvicorn.run("app:app", host="0.0.0.0", port=7860)
 