Tomtom84 committed on
Commit 9cd424e · 1 Parent(s): d408dd5
Files changed (1)
  1. app.py +85 -91
app.py CHANGED
@@ -14,14 +14,17 @@ HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)

- # — Select device
device = "cuda" if torch.cuda.is_available() else "cpu"

- # — Load models —
print("Loading SNAC model...")
snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)

model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
snapshot_download(
    repo_id=model_name,
    allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json"],
@@ -34,118 +37,109 @@ snapshot_download(

print("Loading Orpheus model...")
model = AutoModelForCausalLM.from_pretrained(
-     model_name, torch_dtype=torch.bfloat16
).to(device)
model.config.pad_token_id = model.config.eos_token_id
tokenizer = AutoTokenizer.from_pretrained(model_name)

- # — Constants for token mapping —
- AUDIO_TOKEN_OFFSET = 128266
- START_TOKEN = 128259
- SOS_TOKEN = 128257
- EOS_TOKEN = 128258

# — Helper functions —
def process_prompt(text: str, voice: str):
    prompt = f"{voice}: {text}"
-     input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
-     start = torch.tensor([[START_TOKEN]], dtype=torch.int64, device=device)
-     end = torch.tensor([[128009, 128260]], dtype=torch.int64, device=device)
-     ids = torch.cat([start, input_ids, end], dim=1)
-     mask = torch.ones_like(ids, dtype=torch.int64, device=device)
-     return ids, mask
-
- def redistribute_codes(block: list[int], snac_model: SNAC):
-     # exactly as before: 7 codes → 3 layers → SNAC.decode → NumPy float32 @ 24 kHz
-     l1, l2, l3 = [], [], []
-     for i in range(len(block)//7):
-         b = block[7*i:7*i+7]
-         l1.append(b[0])
-         l2.append(b[1] - 4096)
-         l3.append(b[2] - 2*4096)
-         l3.append(b[3] - 3*4096)
-         l2.append(b[4] - 4*4096)
-         l3.append(b[5] - 5*4096)
-         l3.append(b[6] - 6*4096)
    dev = next(snac_model.parameters()).device
    codes = [
-         torch.tensor(l1, device=dev).unsqueeze(0),
-         torch.tensor(l2, device=dev).unsqueeze(0),
-         torch.tensor(l3, device=dev).unsqueeze(0),
    ]
-     audio = snac_model.decode(codes)  # → Tensor[1, T]
-     return audio.squeeze().cpu().numpy()

- # — FastAPI setup
app = FastAPI()

- # 1) Hello-world endpoint
@app.get("/")
- async def root():
-     return {"message": "Hallo Welt"}

- # 2) WebSocket token-by-token TTS
@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    await ws.accept()
    try:
-         while True:
-             # receive JSON with text & voice
-             raw = await ws.receive_text()
-             req = json.loads(raw)
-             text, voice = req.get("text", ""), req.get("voice", "Jakob")
-             ids, mask = process_prompt(text, voice)
-
-             past_kv = None
-             collected = []
-
-             # generate token by token in the sampling loop
-             with torch.no_grad():
-                 for _ in range(2000):  # max. 2000 tokens
-                     out = model(
-                         input_ids=ids if past_kv is None else None,
-                         attention_mask=mask if past_kv is None else None,
-                         past_key_values=past_kv,
-                         use_cache=True,
-                     )
-                     logits = out.logits[:, -1, :]
-                     next_id = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)
-                     past_kv = out.past_key_values
-
-                     token = next_id.item()
-                     # end of generation
-                     if token == EOS_TOKEN:
-                         break
-                     # reset on SOS
-                     if token == SOS_TOKEN:
-                         collected = []
-                         continue
-
-                     # convert to an audio code
-                     collected.append(token - AUDIO_TOKEN_OFFSET)
-
-                     # as soon as 7 codes are collected → decode & stream immediately
-                     if len(collected) >= 7:
-                         block = collected[:7]
-                         collected = collected[7:]
-                         audio_np = redistribute_codes(block, snac)
-                         pcm16 = (audio_np * 32767).astype("int16").tobytes()
-                         await ws.send_bytes(pcm16)
-
-                     # from now on, rely only on past_kv
-                     ids = None
-                     mask = None
-
-             # finally signal end of stream
-             await ws.send_text(json.dumps({"event": "eos"}))

    except WebSocketDisconnect:
        print("Client disconnected")
    except Exception as e:
        print("Error in /ws/tts:", e)
        await ws.close(code=1011)
-
- # for local testing
- if __name__ == "__main__":
-     import uvicorn
-     uvicorn.run("app:app", host="0.0.0.0", port=7860)
@@ -14,14 +14,17 @@ HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)

+ # — Device
device = "cuda" if torch.cuda.is_available() else "cpu"

+ # — Load SNAC —
print("Loading SNAC model...")
snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)

+ # — Prepare the Orpheus model —
model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
+
+ # Config + weights only (allows a slimmer container)
snapshot_download(
    repo_id=model_name,
    allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json"],
@@ -34,118 +37,109 @@ snapshot_download(

print("Loading Orpheus model...")
model = AutoModelForCausalLM.from_pretrained(
+     model_name,
+     torch_dtype=torch.bfloat16,
+     device_map="auto",
).to(device)
model.config.pad_token_id = model.config.eos_token_id
+
tokenizer = AutoTokenizer.from_pretrained(model_name)


# — Helper functions —
+
def process_prompt(text: str, voice: str):
    prompt = f"{voice}: {text}"
+     inputs = tokenizer(prompt, return_tensors="pt").to(device)
+     # add start/end tokens
+     start = torch.tensor([[128259]], device=device)
+     end = torch.tensor([[128009, 128260]], device=device)
+     input_ids = torch.cat([start, inputs.input_ids, end], dim=1)
+     return input_ids
+
+ def parse_output(generated_ids: torch.LongTensor):
+     token_to_find = 128257
+     token_to_remove = 128258
+
+     idxs = (generated_ids == token_to_find).nonzero(as_tuple=True)[1]
+     if idxs.numel() > 0:
+         cropped = generated_ids[:, idxs[-1].item() + 1 :]
+     else:
+         cropped = generated_ids
+
+     row = cropped[0][cropped[0] != token_to_remove]
+     return row.tolist()
+
+ def redistribute_codes(code_list: list[int], snac_model: SNAC):
+     layer1, layer2, layer3 = [], [], []
+     for i in range((len(code_list) + 1) // 7):
+         base = code_list[7*i : 7*i+7]
+         layer1.append(base[0])
+         layer2.append(base[1] - 4096)
+         layer3.append(base[2] - 2*4096)
+         layer3.append(base[3] - 3*4096)
+         layer2.append(base[4] - 4*4096)
+         layer3.append(base[5] - 5*4096)
+         layer3.append(base[6] - 6*4096)
+
    dev = next(snac_model.parameters()).device
    codes = [
+         torch.tensor(layer1, device=dev).unsqueeze(0),
+         torch.tensor(layer2, device=dev).unsqueeze(0),
+         torch.tensor(layer3, device=dev).unsqueeze(0),
    ]
+     audio = snac_model.decode(codes)
+     return audio.detach().squeeze().cpu().numpy()
+

+ # — FastAPI App
app = FastAPI()

@app.get("/")
+ async def hello():
+     return {"message": "Hello, Orpheus TTS is up and running!"}

@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    await ws.accept()
    try:
+         # only ONE request per connection
+         raw = await ws.receive_text()
+         data = json.loads(raw)
+         text = data.get("text", "")
+         voice = data.get("voice", "Jakob")
+
+         # 1) Text → input_ids
+         input_ids = process_prompt(text, voice)
+
+         # 2) Generation
+         gen_ids = model.generate(
+             input_ids=input_ids,
+             max_new_tokens=2000,  # can be raised if needed
+             do_sample=True,
+             temperature=0.7,
+             top_p=0.95,
+             repetition_penalty=1.1,
+             eos_token_id=model.config.eos_token_id,
+         )
+
+         # 3) Tokens → audio
+         codes = parse_output(gen_ids)
+         audio_np = redistribute_codes(codes, snac)
+
+         # 4) Stream PCM16 bytes in ~0.1 s chunks
+         pcm16 = (audio_np * 32767).astype("int16").tobytes()
+         chunk_size = 2400 * 2  # 2400 samples @ 24 kHz = 0.1 s, 2 bytes per sample
+         for i in range(0, len(pcm16), chunk_size):
+             await ws.send_bytes(pcm16[i : i+chunk_size])
+             await asyncio.sleep(0.1)
+
+         # close cleanly; the client receives ConnectionClosedOK
+         await ws.close()

    except WebSocketDisconnect:
        print("Client disconnected")
    except Exception as e:
+         # log and close with an error code
        print("Error in /ws/tts:", e)
        await ws.close(code=1011)
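
For quick local testing of the new single-request endpoint, here is a minimal client sketch (hypothetical, not part of this commit). It assumes the app is reachable at ws://localhost:7860 and that the `websockets` package is installed; it sends one request to /ws/tts and writes the streamed PCM16 chunks to a 24 kHz mono WAV file.

# test_client.py — hypothetical helper, not part of this commit
import asyncio
import json
import wave

import websockets  # pip install websockets


async def main():
    async with websockets.connect("ws://localhost:7860/ws/tts") as ws:
        # one request per connection, matching the new server behaviour
        await ws.send(json.dumps({"text": "Hallo Welt", "voice": "Jakob"}))
        pcm = b""
        try:
            while True:
                msg = await ws.recv()
                if isinstance(msg, bytes):  # raw PCM16 audio chunk
                    pcm += msg
        except websockets.ConnectionClosedOK:
            pass  # server closes the socket after the last chunk

    # 16-bit mono PCM at 24 kHz, matching the SNAC decoder output
    with wave.open("out.wav", "wb") as f:
        f.setnchannels(1)
        f.setsampwidth(2)
        f.setframerate(24000)
        f.writeframes(pcm)


if __name__ == "__main__":
    asyncio.run(main())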