Tomtom84 committed on
Commit
479f253
·
verified ·
1 Parent(s): 4c833ce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -70
app.py CHANGED
@@ -1,97 +1,94 @@
1
- # app.py -------------------------------------------------------------
2
- import os, json, torch
3
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
4
  from huggingface_hub import login
5
- from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor
6
- from transformers.generation.utils import Cache
7
  from snac import SNAC
8
 
9
  # ── 0. Auth & Device ────────────────────────────────────────────────
10
- if (tok := os.getenv("HF_TOKEN")):
11
- login(tok)
 
12
 
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
- torch.backends.cuda.enable_flash_sdp(False) # PyTorch‑2.2 fix
 
15
 
16
  # ── 1. Konstanten ───────────────────────────────────────────────────
17
- REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
18
- CHUNK_TOKENS = 50 # ≤ 50 → < 1 s Latenz
19
- START_TOKEN = 128259
20
- NEW_BLOCK_TOKEN = 128257
21
- EOS_TOKEN = 128258
22
- AUDIO_BASE = 128266
23
- VALID_AUDIO_IDS = torch.arange(AUDIO_BASE, AUDIO_BASE + 4096)
24
-
25
- # ── 2. Logit‑Maske (nur Audio‑ und Steuer‑Token) ──────────────────
 
 
 
26
  class AudioMask(LogitsProcessor):
27
- def __init__(self, allowed: torch.Tensor): # allowed @device!
 
28
  self.allowed = allowed
29
 
30
- def __call__(self, _ids, scores):
31
  mask = torch.full_like(scores, float("-inf"))
32
- mask[:, self.allowed] = 0.0
33
  return scores + mask
34
 
35
- ALLOWED_IDS = torch.cat(
36
- [VALID_AUDIO_IDS,
37
- torch.tensor([NEW_BLOCK_TOKEN, EOS_TOKEN])]
38
- ).to(device)
39
  MASKER = AudioMask(ALLOWED_IDS)
40
 
41
  # ── 3. FastAPI Grundgerüst ──────────────────────────────────────────
42
  app = FastAPI()
43
 
44
  @app.get("/")
45
- async def root():
46
- return {"msg": "Orpheus‑TTS ready"}
47
-
48
- # global handles
49
- tok = model = snac = None
50
 
51
  @app.on_event("startup")
52
  async def load_models():
53
  global tok, model, snac
54
- tok = AutoTokenizer.from_pretrained(REPO)
55
- snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
56
  model = AutoModelForCausalLM.from_pretrained(
57
- REPO,
58
  low_cpu_mem_usage=True,
59
  device_map={"": 0} if device == "cuda" else None,
60
  torch_dtype=torch.bfloat16 if device == "cuda" else None,
61
  )
62
  model.config.pad_token_id = model.config.eos_token_id
63
- model.config.use_cache = True
64
 
65
- # ── 4. Helper ───────────────────────────────────────────────────────
66
  def build_inputs(text: str, voice: str):
67
- prompt = f"{voice}: {text}"
68
  ids = tok(prompt, return_tensors="pt").input_ids.to(device)
69
- ids = torch.cat(
70
- [
71
- torch.tensor([[START_TOKEN]], device=device),
72
- ids,
73
- torch.tensor([[128009, 128260]], device=device),
74
- ],
75
- 1,
76
- )
77
- return ids, torch.ones_like(ids)
78
-
79
- def decode_block(b7: list[int]) -> bytes:
80
  l1, l2, l3 = [], [], []
81
- l1.append(b7[0])
82
- l2.append(b7[1] - 4096)
83
- l3.extend([b7[2] - 8192, b7[3] - 12288])
84
- l2.append(b7[4] - 16384)
85
- l3.extend([b7[5] - 20480, b7[6] - 24576])
 
86
 
87
  codes = [torch.tensor(x, device=device).unsqueeze(0) for x in (l1, l2, l3)]
88
  audio = snac.decode(codes).squeeze().cpu().numpy()
89
  return (audio * 32767).astype("int16").tobytes()
90
 
91
- def new_tokens_only(full_seq, prev_len):
92
- """liefert Liste der Tokens, die *neu* hinzukamen"""
93
- return full_seq[prev_len:].tolist()
94
-
95
  # ── 5. WebSocket‑Endpoint ───────────────────────────────────────────
96
  @app.websocket("/ws/tts")
97
  async def tts(ws: WebSocket):
@@ -103,47 +100,49 @@ async def tts(ws: WebSocket):
103
  past, buf = None, []
104
 
105
  while True:
106
- gen = model.generate(
107
  input_ids=ids if past is None else None,
108
  attention_mask=attn if past is None else None,
109
  past_key_values=past,
110
  max_new_tokens=CHUNK_TOKENS,
111
  logits_processor=[MASKER],
112
- do_sample=True, temperature=0.7, top_p=0.95,
113
  return_dict_in_generate=True,
114
- use_cache=True, return_legacy_cache=True,
 
115
  )
116
 
117
- past = gen.past_key_values if not isinstance(gen.past_key_values, Cache) else gen.past_key_values.to_legacy()
118
- seq = gen.sequences[0].tolist()
119
- new_tok = seq[prompt_len:]
120
- prompt_len = len(seq)
121
 
122
- if not new_tok:
123
- continue # selten, aber möglich
124
 
125
- for t in new_tok:
126
  if t == EOS_TOKEN:
127
- # ein einziges Close‑Frame genügt
128
- await ws.close() # <── einziges explizites close
129
  return
130
  if t == NEW_BLOCK_TOKEN:
131
  buf.clear(); continue
 
 
132
  buf.append(t - AUDIO_BASE)
133
  if len(buf) == 7:
134
  await ws.send_bytes(decode_block(buf))
135
  buf.clear()
136
 
137
- ids = attn = None # nur noch Cache
 
138
 
139
  except WebSocketDisconnect:
140
- pass # Client ging von selbst
141
  except Exception as e:
142
  print("WS‑Error:", e)
143
  if ws.client_state.name == "CONNECTED":
144
- await ws.close(code=1011) # Fehler melden
145
 
146
- # ── 6. Local run ────────────────────────────────────────────────────
147
  if __name__ == "__main__":
148
  import uvicorn, sys
149
  port = int(sys.argv[1]) if len(sys.argv) > 1 else 7860
 
1
+ # app.py ──────────────────────────────────────────────────────────────
2
+ import os, json, asyncio, torch, logging
3
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
4
  from huggingface_hub import login
5
+ from transformers import (AutoTokenizer, AutoModelForCausalLM,
6
+ LogitsProcessor, generation_utils)
7
  from snac import SNAC
8
 
9
  # ── 0. Auth & Device ────────────────────────────────────────────────
10
+ HF_TOKEN = os.getenv("HF_TOKEN")
11
+ if HF_TOKEN:
12
+ login(HF_TOKEN)
13
 
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
+ torch.backends.cuda.enable_flash_sdp(False) # Flash‑Bug umgehen
16
+ logging.getLogger("transformers.generation.utils").setLevel("ERROR")
17
 
18
  # ── 1. Konstanten ───────────────────────────────────────────────────
19
+ MODEL_REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
20
+ CHUNK_TOKENS = 50
21
+
22
+ START_TOKEN = 128259 # <𝑠>
23
+ NEW_BLOCK_TOKEN = 128257 # 🔊‑Start
24
+ EOS_TOKEN = 128258 # <eos>
25
+ PROMPT_END = [128009, 128260]
26
+ AUDIO_BASE = 128266
27
+
28
+ VALID_AUDIO_IDS = torch.arange(AUDIO_BASE, AUDIO_BASE + 4096)
29
+
30
+ # ── 2. Logit‑Masker ─────────────────────────────────────────────────
31
  class AudioMask(LogitsProcessor):
32
+ def __init__(self, allowed: torch.Tensor):
33
+ super().__init__()
34
  self.allowed = allowed
35
 
36
+ def __call__(self, input_ids, scores):
37
  mask = torch.full_like(scores, float("-inf"))
38
+ mask[:, self.allowed] = 0
39
  return scores + mask
40
 
41
+ ALLOWED_IDS = torch.cat([
42
+ VALID_AUDIO_IDS,
43
+ torch.tensor([START_TOKEN, NEW_BLOCK_TOKEN, EOS_TOKEN])
44
+ ]).to(device)
45
  MASKER = AudioMask(ALLOWED_IDS)
46
 
47
  # ── 3. FastAPI Grundgerüst ──────────────────────────────────────────
48
  app = FastAPI()
49
 
50
  @app.get("/")
51
+ async def ping():
52
+ return {"message": "Orpheus‑TTSΒ ready"}
 
 
 
53
 
54
  @app.on_event("startup")
55
  async def load_models():
56
  global tok, model, snac
57
+ tok = AutoTokenizer.from_pretrained(MODEL_REPO)
 
58
  model = AutoModelForCausalLM.from_pretrained(
59
+ MODEL_REPO,
60
  low_cpu_mem_usage=True,
61
  device_map={"": 0} if device == "cuda" else None,
62
  torch_dtype=torch.bfloat16 if device == "cuda" else None,
63
  )
64
  model.config.pad_token_id = model.config.eos_token_id
65
+ snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
66
 
67
+ # ── 4. Hilfsfunktionen ──────────────────────────────────────────────
68
  def build_inputs(text: str, voice: str):
69
+ prompt = f"{voice}: {text}" if voice and voice != "in_prompt" else text
70
  ids = tok(prompt, return_tensors="pt").input_ids.to(device)
71
+ ids = torch.cat([
72
+ torch.tensor([[START_TOKEN]], device=device),
73
+ ids,
74
+ torch.tensor([PROMPT_END], device=device)
75
+ ], 1)
76
+ mask = torch.ones_like(ids)
77
+ return ids, mask # shape (1, L)
78
+
79
+ def decode_block(block7: list[int]) -> bytes:
 
 
80
  l1, l2, l3 = [], [], []
81
+ b = block7
82
+ l1.append(b[0])
83
+ l2.append(b[1] - 4096)
84
+ l3 += [b[2]-8192, b[3]-12288]
85
+ l2.append(b[4] - 16384)
86
+ l3 += [b[5]-20480, b[6]-24576]
87
 
88
  codes = [torch.tensor(x, device=device).unsqueeze(0) for x in (l1, l2, l3)]
89
  audio = snac.decode(codes).squeeze().cpu().numpy()
90
  return (audio * 32767).astype("int16").tobytes()
91
 
 
 
 
 
92
  # ── 5. WebSocket‑Endpoint ───────────────────────────────────────────
93
  @app.websocket("/ws/tts")
94
  async def tts(ws: WebSocket):
 
100
  past, buf = None, []
101
 
102
  while True:
103
+ out = model.generate(
104
  input_ids=ids if past is None else None,
105
  attention_mask=attn if past is None else None,
106
  past_key_values=past,
107
  max_new_tokens=CHUNK_TOKENS,
108
  logits_processor=[MASKER],
109
+ do_sample=True, top_p=0.95, temperature=0.7,
110
  return_dict_in_generate=True,
111
+ use_cache=True,
112
+ return_legacy_cache=True, # ⇠ Warnung verschwindet
113
  )
114
 
115
+ past = out.past_key_values # unverändert weiterreichen
116
+ seq = out.sequences[0].tolist()
117
+ new = seq[prompt_len:]; prompt_len = len(seq)
 
118
 
119
+ if not new: # selten, aber möglich
120
+ continue
121
 
122
+ for t in new:
123
  if t == EOS_TOKEN:
124
+ await ws.close()
 
125
  return
126
  if t == NEW_BLOCK_TOKEN:
127
  buf.clear(); continue
128
+ if t < AUDIO_BASE: # sollte durch Maske nie passieren
129
+ continue
130
  buf.append(t - AUDIO_BASE)
131
  if len(buf) == 7:
132
  await ws.send_bytes(decode_block(buf))
133
  buf.clear()
134
 
135
+ # Ab jetzt nur noch Cache – IDs & Mask nicht mehr nötig
136
+ ids = attn = None
137
 
138
  except WebSocketDisconnect:
139
+ pass
140
  except Exception as e:
141
  print("WS‑Error:", e)
142
  if ws.client_state.name == "CONNECTED":
143
+ await ws.close(code=1011)
144
 
145
+ # ── 6. Lokaler Start ────────────────────────────────────────────────
146
  if __name__ == "__main__":
147
  import uvicorn, sys
148
  port = int(sys.argv[1]) if len(sys.argv) > 1 else 7860