Tomtom84 commited on
Commit
67c3132
Β·
1 Parent(s): 2c15189
Files changed (1) hide show
  1. app.py +56 -46
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  import json
3
  import asyncio
 
4
  import torch
5
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
6
  from dotenv import load_dotenv
@@ -8,74 +9,78 @@ from snac import SNAC
8
  from transformers import AutoModelForCausalLM, AutoTokenizer
9
  from huggingface_hub import login, snapshot_download
10
 
11
- # ─── ENV & HF TOKEN ────────────────────────────────────────────────────────────
12
  load_dotenv()
13
  HF_TOKEN = os.getenv("HF_TOKEN")
14
  if HF_TOKEN:
15
  login(token=HF_TOKEN)
16
 
17
- # ─── DEVICE ─────────────────────────────────────────────────────────────────────
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
 
20
- # ─── SNAC ───────────────────────────────────────────────────────────────────────
21
- print("Loading SNAC model…")
22
  snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
23
 
24
- # ─── ORPHEUS ────────────────────────────────────────────────────────────────────
25
  model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
26
-
27
- # pre‑download only the config + safetensors, damit das Image schlank bleibt
28
  snapshot_download(
29
  repo_id=model_name,
30
  allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json"],
31
  ignore_patterns=[
32
  "optimizer.pt", "pytorch_model.bin", "training_args.bin",
33
- "scheduler.pt", "tokenizer.*", "vocab.json", "merges.txt"
 
34
  ]
35
  )
36
 
37
- print("Loading Orpheus model…")
38
  model = AutoModelForCausalLM.from_pretrained(
39
  model_name,
40
- torch_dtype=torch.bfloat16 # optional: beschleunigt das FP16‑Àhnliche Rechnen
41
- )
42
- model = model.to(device)
43
  model.config.pad_token_id = model.config.eos_token_id
44
 
45
  tokenizer = AutoTokenizer.from_pretrained(model_name)
46
 
47
- # ─── HILFSFUNKTIONEN ──────────────────────────────────────────────────────────
48
 
49
  def process_prompt(text: str, voice: str):
50
- """
51
- Baut aus Text+Voice ein batch‑Tensor input_ids fΓΌr `model.generate`.
52
- """
53
  prompt = f"{voice}: {text}"
54
- tok = tokenizer(prompt, return_tensors="pt").to(device)
55
- start = torch.tensor([[128259]], device=device)
56
- end = torch.tensor([[128009, 128260]], device=device)
57
- return torch.cat([start, tok.input_ids, end], dim=1)
 
 
58
 
59
  def parse_output(generated_ids: torch.LongTensor):
60
- """
61
- Schneidet bis zum letzten 128257 und entfernt 128258, gibt reine Token‑Liste zurΓΌck.
62
- """
63
- START, PAD = 128257, 128258
64
- idxs = (generated_ids == START).nonzero(as_tuple=True)[1]
65
  if idxs.numel() > 0:
66
- cropped = generated_ids[:, idxs[-1].item()+1:]
 
67
  else:
68
  cropped = generated_ids
69
- row = cropped[0][cropped[0] != PAD]
70
- return row.tolist()
 
 
71
 
72
  def redistribute_codes(code_list: list[int], snac_model: SNAC):
73
  """
74
- Verteilt 7er‑BlΓΆcke auf die drei SNAC‑Layer und dekodiert zu Audio (numpy float32).
 
75
  """
 
76
  layer1, layer2, layer3 = [], [], []
77
- for i in range((len(code_list) + 1) // 7):
78
- base = code_list[7*i : 7*i+7]
 
79
  layer1.append(base[0])
80
  layer2.append(base[1] - 4096)
81
  layer3.append(base[2] - 2*4096)
@@ -83,6 +88,11 @@ def redistribute_codes(code_list: list[int], snac_model: SNAC):
83
  layer2.append(base[4] - 4*4096)
84
  layer3.append(base[5] - 5*4096)
85
  layer3.append(base[6] - 6*4096)
 
 
 
 
 
86
  dev = next(snac_model.parameters()).device
87
  codes = [
88
  torch.tensor(layer1, device=dev).unsqueeze(0),
@@ -92,31 +102,32 @@ def redistribute_codes(code_list: list[int], snac_model: SNAC):
92
  audio = snac_model.decode(codes)
93
  return audio.detach().squeeze().cpu().numpy()
94
 
95
- # ─── FASTAPI ────────────────────────────────────────────────────────────────────
96
 
97
  app = FastAPI()
98
 
99
  @app.get("/")
100
- async def healthcheck():
101
- return {"status": "ok", "msg": "Hello, Orpheus TTS up!"}
102
 
103
  @app.websocket("/ws/tts")
104
  async def tts_ws(ws: WebSocket):
105
  await ws.accept()
106
  try:
107
  while True:
108
- # 1) Eintreffende JSON‐Nachricht parsen
109
  data = json.loads(await ws.receive_text())
110
  text = data.get("text", "")
111
  voice = data.get("voice", "Jakob")
112
 
113
- # 2) Prompt β†’ input_ids
114
- ids = process_prompt(text, voice)
115
 
116
- # 3) Token‐Erzeugung
117
  gen_ids = model.generate(
118
  input_ids=ids,
119
- max_new_tokens=2000, # hier z.B. 20k geht auch, wird aber speicherintensiv
 
120
  do_sample=True,
121
  temperature=0.7,
122
  top_p=0.95,
@@ -124,25 +135,24 @@ async def tts_ws(ws: WebSocket):
124
  eos_token_id=model.config.eos_token_id,
125
  )
126
 
127
- # 4) Tokens β†’ Code‐Liste β†’ Audio
128
- codes = parse_output(gen_ids)
129
  audio_np = redistribute_codes(codes, snac)
130
 
131
- # 5) PCM16‐Stream in 0.1 s‐BlΓΆcken
132
  pcm16 = (audio_np * 32767).astype("int16").tobytes()
133
- chunk = 2400 * 2
134
  for i in range(0, len(pcm16), chunk):
135
  await ws.send_bytes(pcm16[i : i+chunk])
136
  await asyncio.sleep(0.1)
137
 
 
138
  except WebSocketDisconnect:
139
  print("Client disconnected")
140
  except Exception as e:
141
  print("Error in /ws/tts:", e)
142
  await ws.close(code=1011)
143
 
144
- # ─── START ──────────────────────────────────────────────────────────────────────
145
-
146
  if __name__ == "__main__":
147
  import uvicorn
148
- uvicorn.run("app:app", host="0.0.0.0", port=7860, log_level="info")
 
1
  import os
2
  import json
3
  import asyncio
4
+ import numpy as np
5
  import torch
6
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
7
  from dotenv import load_dotenv
 
9
  from transformers import AutoModelForCausalLM, AutoTokenizer
10
  from huggingface_hub import login, snapshot_download
11
 
12
+ # β€” ENV & HF‑AUTH β€”
13
  load_dotenv()
14
  HF_TOKEN = os.getenv("HF_TOKEN")
15
  if HF_TOKEN:
16
  login(token=HF_TOKEN)
17
 
18
+ # β€” Device β€”
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
 
21
+ # β€” Modelle laden β€”
22
+ print("Loading SNAC model...")
23
  snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
24
 
 
25
  model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
26
+ print("Downloading model weights (config + safetensors)...")
 
27
  snapshot_download(
28
  repo_id=model_name,
29
  allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json"],
30
  ignore_patterns=[
31
  "optimizer.pt", "pytorch_model.bin", "training_args.bin",
32
+ "scheduler.pt", "tokenizer.json", "tokenizer_config.json",
33
+ "special_tokens_map.json", "vocab.json", "merges.txt", "tokenizer.*"
34
  ]
35
  )
36
 
37
+ print("Loading Orpheus model...")
38
  model = AutoModelForCausalLM.from_pretrained(
39
  model_name,
40
+ torch_dtype=torch.bfloat16
41
+ ).to(device)
 
42
  model.config.pad_token_id = model.config.eos_token_id
43
 
44
  tokenizer = AutoTokenizer.from_pretrained(model_name)
45
 
46
+ # β€” Hilfsfunktionen β€”
47
 
48
  def process_prompt(text: str, voice: str):
49
+ """Bereitet input_ids und attention_mask fΓΌr das Modell vor."""
 
 
50
  prompt = f"{voice}: {text}"
51
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids
52
+ start = torch.tensor([[128259]], dtype=torch.int64)
53
+ end = torch.tensor([[128009, 128260]], dtype=torch.int64)
54
+ ids = torch.cat([start, input_ids, end], dim=1).to(device)
55
+ mask = torch.ones_like(ids).to(device)
56
+ return ids, mask
57
 
58
  def parse_output(generated_ids: torch.LongTensor):
59
+ """Extrahiere rohe Tokenliste nach dem letzten 128257-Start-Token."""
60
+ token_to_find = 128257
61
+ token_to_remove = 128258
62
+
63
+ idxs = (generated_ids == token_to_find).nonzero(as_tuple=True)[1]
64
  if idxs.numel() > 0:
65
+ cut = idxs[-1].item() + 1
66
+ cropped = generated_ids[:, cut:]
67
  else:
68
  cropped = generated_ids
69
+
70
+ # Entferne EOS‑Token
71
+ row = cropped[0]
72
+ return row[row != token_to_remove].tolist()
73
 
74
  def redistribute_codes(code_list: list[int], snac_model: SNAC):
75
  """
76
+ Verteilt die Token nur in kompletten 7er‑BlΓΆcken auf die drei SNAC‑Layer
77
+ und dekodiert in Audio. UnvollstΓ€ndige Reste (<7 Tokens) werden verworfen.
78
  """
79
+ n_blocks = len(code_list) // 7
80
  layer1, layer2, layer3 = [], [], []
81
+
82
+ for i in range(n_blocks):
83
+ base = code_list[7*i : 7*i + 7]
84
  layer1.append(base[0])
85
  layer2.append(base[1] - 4096)
86
  layer3.append(base[2] - 2*4096)
 
88
  layer2.append(base[4] - 4*4096)
89
  layer3.append(base[5] - 5*4096)
90
  layer3.append(base[6] - 6*4096)
91
+
92
+ if not layer1:
93
+ # kein kompletter Block β†’ leeres Audio
94
+ return np.zeros(0, dtype=np.float32)
95
+
96
  dev = next(snac_model.parameters()).device
97
  codes = [
98
  torch.tensor(layer1, device=dev).unsqueeze(0),
 
102
  audio = snac_model.decode(codes)
103
  return audio.detach().squeeze().cpu().numpy()
104
 
105
+ # β€” FastAPI Setup β€”
106
 
107
  app = FastAPI()
108
 
109
  @app.get("/")
110
+ def greet_json():
111
+ return {"Hello": "World!"}
112
 
113
  @app.websocket("/ws/tts")
114
  async def tts_ws(ws: WebSocket):
115
  await ws.accept()
116
  try:
117
  while True:
118
+ # Erwartet JSON: {"text": "...", "voice": "Jakob"}
119
  data = json.loads(await ws.receive_text())
120
  text = data.get("text", "")
121
  voice = data.get("voice", "Jakob")
122
 
123
+ # 1) Tokens vorbereiten
124
+ ids, mask = process_prompt(text, voice)
125
 
126
+ # 2) Generierung
127
  gen_ids = model.generate(
128
  input_ids=ids,
129
+ attention_mask=mask,
130
+ max_new_tokens=2000, # hier nach Bedarf anpassen
131
  do_sample=True,
132
  temperature=0.7,
133
  top_p=0.95,
 
135
  eos_token_id=model.config.eos_token_id,
136
  )
137
 
138
+ # 3) Tokens β†’ Code-Liste β†’ Audio
139
+ codes = parse_output(gen_ids)
140
  audio_np = redistribute_codes(codes, snac)
141
 
142
+ # 4) in 0.1s‑StΓΌcken PCM16 streamen
143
  pcm16 = (audio_np * 32767).astype("int16").tobytes()
144
+ chunk = 2400 * 2 # 2400 samples @24kHz = 0.1s * 2 bytes
145
  for i in range(0, len(pcm16), chunk):
146
  await ws.send_bytes(pcm16[i : i+chunk])
147
  await asyncio.sleep(0.1)
148
 
149
+ # Ende der while‐Schleife
150
  except WebSocketDisconnect:
151
  print("Client disconnected")
152
  except Exception as e:
153
  print("Error in /ws/tts:", e)
154
  await ws.close(code=1011)
155
 
 
 
156
  if __name__ == "__main__":
157
  import uvicorn
158
+ uvicorn.run("app:app", host="0.0.0.0", port=7860)