Tomtom84 committed on
Commit
d4630a2
·
verified ·
1 Parent(s): 3281189

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +117 -108
app.py CHANGED
@@ -1,154 +1,162 @@
1
  import os
2
  import json
3
  import asyncio
4
- import numpy as np
5
  import torch
6
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
 
7
  from dotenv import load_dotenv
8
  from snac import SNAC
9
  from transformers import AutoModelForCausalLM, AutoTokenizer
10
- from huggingface_hub import login, snapshot_download
11
 
12
  # — ENV & HF‑AUTH —
13
  load_dotenv()
14
  HF_TOKEN = os.getenv("HF_TOKEN")
15
  if HF_TOKEN:
16
- login(token=HF_TOKEN)
 
17
 
18
- # — Device
 
 
 
 
 
 
 
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
 
21
- # — Modelle laden —
22
- print("Loading SNAC model...")
23
  snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
24
 
 
25
  model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
26
- print("Downloading model weights (config + safetensors)...")
27
- snapshot_download(
28
- repo_id=model_name,
29
- allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json"],
30
- ignore_patterns=[
31
- "optimizer.pt", "pytorch_model.bin", "training_args.bin",
32
- "scheduler.pt", "tokenizer.json", "tokenizer_config.json",
33
- "special_tokens_map.json", "vocab.json", "merges.txt", "tokenizer.*"
34
- ]
35
  )
36
-
37
- print("Loading Orpheus model...")
38
- model = AutoModelForCausalLM.from_pretrained(
39
  model_name,
40
- torch_dtype=torch.bfloat16
41
- ).to(device)
42
- model.config.pad_token_id = model.config.eos_token_id
43
- tokenizer = AutoTokenizer.from_pretrained(model_name)
44
 
45
- # Konstanten —
46
- AUDIO_TOKEN_OFFSET = 128266 # globaler Offset der Audio‑Tokens
 
47
 
48
  # — Hilfsfunktionen —
49
-
50
- def process_prompt(text: str, voice: str):
51
- """Bereitet input_ids und attention_mask für das Modell vor."""
52
- prompt = f"{voice}: {text}"
53
- input_ids = tokenizer(prompt, return_tensors="pt").input_ids
54
- start = torch.tensor([[128259]], dtype=torch.int64)
55
- end = torch.tensor([[128009, 128260]], dtype=torch.int64)
56
- ids = torch.cat([start, input_ids, end], dim=1).to(device)
57
- mask = torch.ones_like(ids).to(device)
58
- return ids, mask
59
-
60
- def parse_output(generated_ids: torch.LongTensor):
61
- """Extrahiere rohe Tokenliste nach dem letzten 128257-Start-Token."""
62
- token_to_find = 128257
63
- token_to_remove = 128258
64
-
65
- idxs = (generated_ids == token_to_find).nonzero(as_tuple=True)[1]
 
 
 
66
  if idxs.numel() > 0:
67
  cut = idxs[-1].item() + 1
68
- cropped = generated_ids[:, cut:]
69
  else:
70
- cropped = generated_ids
71
-
72
- row = cropped[0]
73
- # entferne das EOS‑Marker‑Token
74
- return row[row != token_to_remove].tolist()
75
-
76
- def redistribute_codes(raw_codes: list[int], snac_model: SNAC):
77
- """
78
- Subtrahiere erst den globalen Offset, dann packe in 7er-Blöcke und dekodiere.
79
- Unvollständige Reste (<7 Tokens) werden verworfen.
80
- """
81
- # 1) Offset abziehen
82
- codes = [c - AUDIO_TOKEN_OFFSET for c in raw_codes]
83
-
84
- # 2) Nur ganze 7er‑Blöcke
85
- n_blocks = len(codes) // 7
86
- if n_blocks == 0:
87
- return np.zeros(0, dtype=np.float32)
88
-
89
- layer1, layer2, layer3 = [], [], []
90
- for i in range(n_blocks):
91
- b = codes[7*i : 7*i+7]
92
- layer1.append(b[0])
93
- layer2.append(b[1] - 4096)
94
- layer3.append(b[2] - 2*4096)
95
- layer3.append(b[3] - 3*4096)
96
- layer2.append(b[4] - 4*4096)
97
- layer3.append(b[5] - 5*4096)
98
- layer3.append(b[6] - 6*4096)
99
-
100
- # 3) SNAC‑Layer‑Tensors bauen und dekodieren
101
- dev = next(snac_model.parameters()).device
102
- t1 = torch.tensor(layer1, device=dev).unsqueeze(0)
103
- t2 = torch.tensor(layer2, device=dev).unsqueeze(0)
104
- t3 = torch.tensor(layer3, device=dev).unsqueeze(0)
105
- audio = snac_model.decode([t1, t2, t3])
106
-
107
- return audio.detach().squeeze().cpu().numpy()
108
-
109
- # — FastAPI Setup —
110
-
111
- app = FastAPI()
112
-
113
- @app.get("/")
114
- async def hello():
115
- return {"message": "Hello World"}
116
-
 
 
117
  @app.websocket("/ws/tts")
118
  async def tts_ws(ws: WebSocket):
119
  await ws.accept()
120
  try:
121
  while True:
122
- # Empfang: {"text":"...", "voice":"Jakob"}
123
- data = json.loads(await ws.receive_text())
124
- text = data.get("text", "")
125
- voice = data.get("voice", "Jakob")
126
 
127
- # 1) Eingabe → Tokens
128
- ids, mask = process_prompt(text, voice)
129
 
130
- # 2) Generierung
131
- gen_ids = model.generate(
132
  input_ids=ids,
133
  attention_mask=mask,
134
- max_new_tokens=2000, # nach Bedarf hochsetzen
135
  do_sample=True,
136
  temperature=0.7,
137
  top_p=0.95,
138
  repetition_penalty=1.1,
139
- eos_token_id=model.config.eos_token_id,
 
 
140
  )
141
 
142
- # 3) Tokens → Audio‑Codes → PCM
143
- raw_codes = parse_output(gen_ids)
144
- audio_np = redistribute_codes(raw_codes, snac)
145
- pcm16 = (audio_np * 32767).astype("int16").tobytes()
146
 
147
- # 4) Stream in 0.1 s‑Chunks
148
- chunk = 2400 * 2 # 2400 Samples @24 kHz = 0.1 s * 2 Bytes
149
- for i in range(0, len(pcm16), chunk):
150
- await ws.send_bytes(pcm16[i : i+chunk])
151
- await asyncio.sleep(0.1)
152
 
153
  except WebSocketDisconnect:
154
  print("Client disconnected")
@@ -156,6 +164,7 @@ async def tts_ws(ws: WebSocket):
156
  print("Error in /ws/tts:", e)
157
  await ws.close(code=1011)
158
 
 
159
  if __name__ == "__main__":
160
  import uvicorn
161
  uvicorn.run("app:app", host="0.0.0.0", port=7860)
 
1
  import os
2
  import json
3
  import asyncio
 
4
  import torch
5
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
6
+ from fastapi.responses import PlainTextResponse
7
  from dotenv import load_dotenv
8
  from snac import SNAC
9
  from transformers import AutoModelForCausalLM, AutoTokenizer
10
+ from peft import PeftModel
11
 
12
  # — ENV & HF‑AUTH —
13
  load_dotenv()
14
  HF_TOKEN = os.getenv("HF_TOKEN")
15
  if HF_TOKEN:
16
+ # automatisch über huggingface-cli eingeloggt
17
+ os.environ["HUGGINGFACE_HUB_TOKEN"] = HF_TOKEN
18
 
19
# — FastAPI application —
app = FastAPI()


@app.get("/")
async def hello():
    """Answer the root route with a fixed plain-text greeting."""
    greeting = PlainTextResponse("Hallo Welt!")
    return greeting
25
+
26
# — Select the compute device —
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# — Load the SNAC codec —
print("Loading SNAC model")
snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
snac = snac.to(device)

# — Load Orpheus/Kartoffel-3B through PEFT —
model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
print(f"Loading base LM + PEFT from {model_name}…")
base = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
# NOTE(review): the same repo id is used for the base model and for the
# adapter; this presumably requires the repo to ship PEFT adapter files --
# confirm against the repository contents.
model = PeftModel.from_pretrained(base, model_name, device_map="auto")
model.eval()

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Make sure a pad token id is configured by reusing the EOS id.
model.config.pad_token_id = model.config.eos_token_id
51
 
52
  # — Hilfsfunktionen —
53
def prepare_prompt(text: str, voice: str):
    """Build the model input tensors for one TTS request.

    The prompt is ``"<voice>: <text>"`` when ``voice`` is non-empty and just
    ``text`` otherwise. Its token ids are framed by the marker id 128259 in
    front and the ids 128009 and 128260 behind.

    Args:
        text: The text to be spoken.
        voice: Optional speaker name; an empty string omits the prefix.

    Returns:
        A ``(input_ids, attention_mask)`` tuple on ``device``. The mask has
        the same shape as the ids and consists only of ones.
    """
    prompt = text if not voice else f"{voice}: {text}"
    body_ids = tokenizer(prompt, return_tensors="pt").input_ids

    # Marker ids wrapped around the tokenized prompt.
    prefix = torch.tensor([[128259]], dtype=torch.int64)
    suffix = torch.tensor([[128009, 128260]], dtype=torch.int64)

    input_ids = torch.cat((prefix, body_ids, suffix), dim=1).to(device)
    attention_mask = torch.ones_like(input_ids).to(device)
    return input_ids, attention_mask
65
+
66
def extract_audio_tokens(generated: torch.LongTensor) -> list[int]:
    """Extract the audio codes from a generated token sequence.

    Only the first row of ``generated`` is used. Everything up to and
    including the last start marker (128257) is dropped, every end marker
    (128258) is removed, and the remainder is truncated to a whole number of
    7-token frames. The global audio offset (128266) is then subtracted.

    Args:
        generated: Token ids of shape ``(batch, seq_len)`` as returned by
            ``model.generate``.

    Returns:
        A flat list of Python ints whose length is a multiple of 7; it is
        empty when fewer than 7 tokens remain.
    """
    bos_tok = 128257
    eos_tok = 128258
    audio_offset = 128266

    row = generated[0]

    # Search the start marker in the row that is actually decoded, so that
    # other batch rows cannot influence the cut position.
    starts = (row == bos_tok).nonzero(as_tuple=True)[0]
    if starts.numel() > 0:
        row = row[starts[-1].item() + 1:]

    # Remove end markers, then keep only complete 7-token frames.
    row = row[row != eos_tok]
    usable = (row.size(0) // 7) * 7

    # One vectorised subtraction and a single transfer to Python instead of
    # one ``.item()`` call (and device sync) per token.
    return (row[:usable] - audio_offset).tolist()
88
+
89
async def decode_and_stream(tokens: list[int], ws: WebSocket):
    """Decode SNAC audio codes to 16-bit PCM and stream it over ``ws``.

    ``tokens`` is read in frames of 7 codes. Within a frame, position ``k``
    carries an extra offset of ``k * 4096`` which is removed here; the codes
    are then distributed over the three SNAC layers (1, 2 and 4 codes per
    frame). Frames containing a code outside ``[0, 4096)`` are skipped, and a
    trailing partial frame is ignored.

    All frames are decoded in a single ``snac.decode`` call. The PCM was
    already buffered completely before sending, so decoding frame by frame
    gained nothing and decoded every frame without its neighbours.

    This is a coroutine: the caller awaits it, and ``send_bytes`` as well as
    ``asyncio.sleep`` only take effect when awaited.

    Args:
        tokens: Flat list of audio codes with the global offset removed.
        ws: Accepted WebSocket the PCM bytes are sent to.
    """
    codebook_size = 4096
    frame_len = 7

    l1, l2, l3 = [], [], []
    for start in range(0, len(tokens) - frame_len + 1, frame_len):
        block = tokens[start:start + frame_len]
        # Remove the per-position offset inside the frame.
        codes = [c - slot * codebook_size for slot, c in enumerate(block)]
        if any(not 0 <= c < codebook_size for c in codes):
            # A non-audio token slipped in; an out-of-range index would make
            # the codebook lookup fail, so drop the whole frame.
            continue
        l1.append(codes[0])
        l2.extend((codes[1], codes[4]))
        l3.extend((codes[2], codes[3], codes[5], codes[6]))

    if not l1:
        return

    # No autograd graph is needed for inference; without this (or detach)
    # converting the decoder output to NumPy can fail.
    with torch.no_grad():
        layers = [
            torch.tensor(layer, dtype=torch.long, device=device).unsqueeze(0)
            for layer in (l1, l2, l3)
        ]
        audio = snac.decode(layers)

    # Clip to [-1, 1] so the int16 conversion cannot wrap around.
    samples = audio.detach().squeeze().clamp(-1.0, 1.0).cpu().numpy()
    pcm16 = (samples * 32767).astype("int16").tobytes()

    # Send in 0.1 s chunks (2400 samples x 2 bytes at 24 kHz).
    chunk_size = 2400 * 2
    for i in range(0, len(pcm16), chunk_size):
        await ws.send_bytes(pcm16[i:i + chunk_size])
        # Pause between chunks so the WebSocket is not flooded.
        await asyncio.sleep(0.1)
124
+
125
# — WebSocket TTS endpoint —
@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    """Serve one text-to-speech request per WebSocket connection.

    The client sends a single JSON text message with the keys ``"text"`` and
    optionally ``"voice"``. The handler generates audio tokens, streams the
    decoded audio back and closes the connection with code 1000. Any error
    other than a client disconnect is printed and the connection is closed
    with code 1011.
    """
    await ws.accept()
    try:
        raw = await ws.receive_text()
        req = json.loads(raw)
        text = req.get("text", "")
        voice = req.get("voice", "")

        # Prepare the prompt tensors.
        ids, mask = prepare_prompt(text, voice)

        # Generate the audio tokens in a worker thread: generation can take
        # many seconds and would otherwise block the event loop (other
        # routes and WebSocket keep-alives) for its whole duration.
        # NOTE(review): this allows generations of different connections to
        # overlap -- confirm that is acceptable for the deployment.
        gen = await asyncio.to_thread(
            model.generate,
            input_ids=ids,
            attention_mask=mask,
            max_new_tokens=4000,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            repetition_penalty=1.1,
            eos_token_id=128258,
            forced_bos_token_id=128259,
            use_cache=True,
        )

        codes = extract_audio_tokens(gen)
        await decode_and_stream(codes, ws)

        # Close cleanly; exactly one request is handled per connection.
        await ws.close(code=1000)

    except WebSocketDisconnect:
        print("Client disconnected")
    except Exception as e:
        print("Error in /ws/tts:", e)
        try:
            await ws.close(code=1011)
        except RuntimeError:
            # The connection was already closed; nothing left to do.
            pass
166
 
167
# — Run locally —
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when started as a script.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run("app:app", **server_options)