Tomtom84 committed
Commit e3958ab · verified · 1 Parent(s): 9e2fbd8

Update app.py

Files changed (1):
  1. app.py +107 -94
app.py CHANGED
@@ -3,146 +3,159 @@ import os, json, torch, asyncio
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
  from huggingface_hub import login
  from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor
- from transformers.generation.utils import Cache
  from snac import SNAC

- # 0 · Auth & Device ---------------------------------------------------
  HF_TOKEN = os.getenv("HF_TOKEN")
  if HF_TOKEN:
      login(HF_TOKEN)

  device = "cuda" if torch.cuda.is_available() else "cpu"
- torch.backends.cuda.enable_flash_sdp(False)  # SDPA assert fix
-
- # 1 · Constants --------------------------------------------------------
- REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
- CHUNK_TOKENS = 50
- START_TOKEN = 128259
- NEW_BLOCK = 128257
- EOS_TOKEN = 128258
- AUDIO_BASE = 128266
- AUDIO_SPAN = 4096 * 7  # 28,672 codes
- VALID_AUDIO = torch.arange(AUDIO_BASE, AUDIO_BASE + AUDIO_SPAN)
-
- # 2 · Logit masker -----------------------------------------------------
- class DynamicMask(LogitsProcessor):
-     def __init__(self, audio_ids: torch.Tensor, min_blocks: int = 1):
          super().__init__()
-         self.audio_ids = audio_ids
-         self.ctrl_ids = torch.tensor([NEW_BLOCK], device=audio_ids.device)
-         self.blocks = 0
-         self.min_blk = min_blocks
-     def __call__(self, inp_ids, scores):
-         allow = torch.cat([self.audio_ids, self.ctrl_ids])
-         if self.blocks >= self.min_blk:
-             allow = torch.cat([allow,
-                                torch.tensor([EOS_TOKEN], device=scores.device)])
-         mask = torch.full_like(scores, float("-inf"))
-         mask[:, allow] = 0
-         return scores + mask
-
- # 3 · FastAPI app ------------------------------------------------------
  app = FastAPI()

  @app.get("/")
- async def root():
-     return {"msg": "Orpheus-TTS online"}

  @app.on_event("startup")
- async def load():
      global tok, model, snac, masker
-     print("⏳ Loading models …")
      tok = AutoTokenizer.from_pretrained(REPO)
      snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
      model = AutoModelForCausalLM.from_pretrained(
          REPO,
          low_cpu_mem_usage=True,
-         device_map={"": 0} if device == "cuda" else None,
-         torch_dtype=torch.bfloat16 if device == "cuda" else None,
      )
      model.config.pad_token_id = model.config.eos_token_id
-     model.config.use_cache = True
-     masker = DynamicMask(VALID_AUDIO.to(device))
-     print("✅ Models loaded")
-
- # 4 · Helpers ----------------------------------------------------------
- def build_inputs(text: str, voice: str):
-     prompt = f"{voice}: {text}"
-     ids = tok(prompt, return_tensors="pt").input_ids.to(device)
-     ids = torch.cat([torch.tensor([[START_TOKEN]], device=device),
-                      ids,
-                      torch.tensor([[128009, 128260]], device=device)], 1)
-     return ids, torch.ones_like(ids)
-
- def decode_block(b):
      l1, l2, l3 = [], [], []
-     l1.append(b[0])
-     l2.append(b[1] - 4096)
-     l3 += [b[2] - 8192, b[3] - 12288]
-     l2.append(b[4] - 16384)
-     l3 += [b[5] - 20480, b[6] - 24576]
-     codes = [torch.tensor(x, device=device).unsqueeze(0) for x in (l1, l2, l3)]
      with torch.no_grad():
          audio = snac.decode(codes).squeeze().detach().cpu().numpy()
      return (audio * 32767).astype("int16").tobytes()

- # 5 · WebSocket endpoint -----------------------------------------------
  @app.websocket("/ws/tts")
  async def tts(ws: WebSocket):
      await ws.accept()
      try:
          req = json.loads(await ws.receive_text())
-         ids, attn = build_inputs(req.get("text", ""), req.get("voice", "Jakob"))
-         past, last_tok, buf = None, None, []
-         prompt_len = ids.shape[1]

          while True:
-             print(f"DEBUG: Before generate - past is None: {past is None}")  # Added logging
-             out = model.generate(
-                 input_ids = ids if past is None else torch.tensor([[last_tok]], device=device),
-                 attention_mask = attn if past is None else None,
-                 past_key_values = past,
-                 max_new_tokens = CHUNK_TOKENS,
-                 logits_processor=[masker],
                  do_sample=True, temperature=0.7, top_p=0.95,
-                 use_cache=True, return_dict_in_generate=True,
-                 return_legacy_cache=True)
-             print(f"DEBUG: After generate - type of out.past_key_values: {type(out.past_key_values)}")  # Added logging
-             pkv = out.past_key_values
-             print(f"DEBUG: After getting pkv - type of pkv: {type(pkv)}")  # Added logging
-             if isinstance(pkv, Cache): pkv = pkv.to_legacy()
-             past = pkv
-             print(f"DEBUG: After cache handling - past is None: {past is None}")  # Added logging
-
-             seq = out.sequences[0].tolist()
-             new = seq[prompt_len:]; prompt_len = len(seq)
-             print("new tokens:", new[:25])
-
-             if not new: raise StopIteration
              for t in new:
-                 last_tok = t
-                 if t == EOS_TOKEN: raise StopIteration
-                 if t == NEW_BLOCK: buf.clear(); continue
                  buf.append(t - AUDIO_BASE)
                  if len(buf) == 7:
                      await ws.send_bytes(decode_block(buf))
                      buf.clear()
-                     masker.blocks += 1
-
-             ids, attn = None, None  # from here on: single-token steps

      except (StopIteration, WebSocketDisconnect):
          pass
      except Exception as e:
-         print("❌ WS-Error:", e)
          if ws.client_state.name != "DISCONNECTED":
              await ws.close(code=1011)
      finally:
          if ws.client_state.name != "DISCONNECTED":
-             try: await ws.close()
-             except RuntimeError: pass

- # 6 · Local run --------------------------------------------------------
  if __name__ == "__main__":
-     import uvicorn
-     uvicorn.run("app:app", host="0.0.0.0", port=7860)
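
Both the removed decode_block(b) above and the rewritten decode_block(block7) in the updated listing below assume the same 7-token frame layout: position 0 holds the level-1 SNAC code, positions 1 and 4 the level-2 codes, and positions 2, 3, 5 and 6 the level-3 codes, with position i offset by i * 4096 so the token ranges stay disjoint. A minimal sketch of that unpacking (pure Python, illustrative function name, not part of the commit):

# Sketch of the 7-token SNAC frame layout assumed by decode_block.
# Position i carries its code plus an offset of i * 4096.
def split_frame(frame7):
    assert len(frame7) == 7
    l1 = [frame7[0]]                            # level 1: one code
    l2 = [frame7[1] - 4096, frame7[4] - 16384]  # level 2: two codes
    l3 = [frame7[2] - 8192, frame7[3] - 12288,  # level 3: four codes
          frame7[5] - 20480, frame7[6] - 24576]
    return l1, l2, l3

# Example: a frame carrying code 7 at every position.
print(split_frame([7, 4103, 8199, 12295, 16391, 20487, 24583]))
# -> ([7], [7, 7], [7, 7, 7, 7])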
 
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
  from huggingface_hub import login
  from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor
  from snac import SNAC

+ # 0) Login + device ----------------------------------------------------
  HF_TOKEN = os.getenv("HF_TOKEN")
  if HF_TOKEN:
      login(HF_TOKEN)

  device = "cuda" if torch.cuda.is_available() else "cpu"
+ torch.backends.cuda.enable_flash_sdp(False)  # PyTorch 2.2 bug
+
+ # 1) Constants ---------------------------------------------------------
+ REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
+ CHUNK_TOKENS = 50
+ START_TOKEN = 128259
+ NEW_BLOCK = 128257
+ EOS_TOKEN = 128258
+ AUDIO_BASE = 128266
+ AUDIO_IDS = torch.arange(AUDIO_BASE, AUDIO_BASE + 4096)
+
+ # 2) Logit mask (NEW_BLOCK + audio; EOS only after the first block) ----
+ class AudioMask(LogitsProcessor):
+     def __init__(self, audio_ids: torch.Tensor):
          super().__init__()
+         self.allow = torch.cat([
+             torch.tensor([NEW_BLOCK], device=audio_ids.device),
+             audio_ids
+         ])
+         self.eos = torch.tensor([EOS_TOKEN], device=audio_ids.device)
+         self.sent_blocks = 0
+
+     def __call__(self, input_ids, logits):
+         allowed = self.allow
+         if self.sent_blocks:  # allow EOS from the first block on
+             allowed = torch.cat([allowed, self.eos])
+         mask = logits.new_full(logits.shape, float("-inf"))
+         mask[:, allowed] = 0
+         return logits + mask
+
+ # 3) FastAPI scaffolding -----------------------------------------------
  app = FastAPI()

  @app.get("/")
+ def hello():
+     return {"status": "ok"}

  @app.on_event("startup")
+ def load_models():
      global tok, model, snac, masker
+     print("⏳ Loading models …", flush=True)
+
      tok = AutoTokenizer.from_pretrained(REPO)
      snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
      model = AutoModelForCausalLM.from_pretrained(
          REPO,
+         device_map={"": 0} if device == "cuda" else None,
+         torch_dtype=torch.bfloat16 if device == "cuda" else None,
          low_cpu_mem_usage=True,
      )
      model.config.pad_token_id = model.config.eos_token_id
+     masker = AudioMask(AUDIO_IDS.to(device))
+
+     print("✅ Models loaded", flush=True)
+
+ # 4) Helpers -------------------------------------------------------------
+ def build_prompt(text: str, voice: str):
+     prompt_ids = tok(f"{voice}: {text}", return_tensors="pt").input_ids.to(device)
+     ids = torch.cat([torch.tensor([[START_TOKEN]], device=device),
+                      prompt_ids,
+                      torch.tensor([[128009, 128260]], device=device)], 1)
+     attn = torch.ones_like(ids)
+     return ids, attn
+
+ def decode_block(block7: list[int]) -> bytes:
      l1, l2, l3 = [], [], []
+     l1.append(block7[0])
+     l2.append(block7[1] - 4096)
+     l3 += [block7[2] - 8192, block7[3] - 12288]
+     l2.append(block7[4] - 16384)
+     l3 += [block7[5] - 20480, block7[6] - 24576]
+
      with torch.no_grad():
+         codes = [torch.tensor(x, device=device).unsqueeze(0)
+                  for x in (l1, l2, l3)]
          audio = snac.decode(codes).squeeze().detach().cpu().numpy()
+
      return (audio * 32767).astype("int16").tobytes()

+ # 5) WebSocket endpoint --------------------------------------------------
  @app.websocket("/ws/tts")
  async def tts(ws: WebSocket):
      await ws.accept()
      try:
          req = json.loads(await ws.receive_text())
+         text = req.get("text", "")
+         voice = req.get("voice", "Jakob")
+
+         ids, attn = build_prompt(text, voice)
+         past = None
+         offset_len = ids.size(1)  # how many tokens already exist
+         last_tok = None
+         buf = []

          while True:
+             # --- mini-generate --------------------------------------------
+             gen = model.generate(
+                 input_ids = ids if past is None else torch.tensor([[last_tok]], device=device),
+                 attention_mask = attn if past is None else None,
+                 past_key_values = past,
+                 max_new_tokens = CHUNK_TOKENS,
+                 logits_processor = [masker],
                  do_sample=True, temperature=0.7, top_p=0.95,
+                 use_cache=True
+             )
+
+             # ----- slice out the newly generated tokens -------------------
+             new = gen[0, offset_len:].tolist()
+             if not new:  # nothing new -> done
+                 break
+             offset_len += len(new)
+
+             # ----- continue with the cache (last PKV is stored on the model)
+             past = model._past_key_values
+             last_tok = new[-1]
+
+             print("new tokens:", new[:25], flush=True)
+
+             # ----- token handling -----------------------------------------
              for t in new:
+                 if t == EOS_TOKEN:
+                     raise StopIteration
+                 if t == NEW_BLOCK:
+                     buf.clear()
+                     continue
                  buf.append(t - AUDIO_BASE)
                  if len(buf) == 7:
                      await ws.send_bytes(decode_block(buf))
                      buf.clear()
+                     masker.sent_blocks = 1  # EOS allowed from now on

      except (StopIteration, WebSocketDisconnect):
          pass
      except Exception as e:
+         print("❌ WS-Error:", e, flush=True)
          if ws.client_state.name != "DISCONNECTED":
              await ws.close(code=1011)
      finally:
          if ws.client_state.name != "DISCONNECTED":
+             try:
+                 await ws.close()
+             except RuntimeError:
+                 pass

+ # 6) Dev start -----------------------------------------------------------
  if __name__ == "__main__":
+     import uvicorn, sys
+     uvicorn.run("app:app", host="0.0.0.0", port=7860, log_level="info")
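
For reference, a minimal client sketch for the /ws/tts endpoint above. It assumes the third-party websockets package and a server running locally on port 7860 (both assumptions, not part of this commit); the endpoint streams raw int16 PCM, one decoded 7-token block per message, at the 24 kHz rate of the snac_24khz decoder:

# Hypothetical client (not part of app.py): requests one utterance and
# writes the streamed 16-bit PCM chunks to a mono 24 kHz WAV file.
import asyncio, json, wave

import websockets  # assumption: pip install websockets

async def main():
    uri = "ws://localhost:7860/ws/tts"  # assumed local dev server
    pcm = b""
    async with websockets.connect(uri) as ws:
        await ws.send(json.dumps({"text": "Hallo Welt", "voice": "Jakob"}))
        try:
            while True:
                pcm += await ws.recv()  # one audio block per message
        except websockets.ConnectionClosed:
            pass  # server closes the socket when generation ends
    with wave.open("out.wav", "wb") as f:
        f.setnchannels(1)      # mono
        f.setsampwidth(2)      # int16
        f.setframerate(24000)  # snac_24khz sample rate
        f.writeframes(pcm)

asyncio.run(main())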