Tomtom84 committed
Commit b3e4aa7 · verified · 1 Parent(s): e97a876

Update app.py

Files changed (1)
  1. app.py +59 -65
app.py CHANGED
@@ -22,7 +22,7 @@ print("Loading SNAC model…")
 snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
 
 model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
-# Only the config + safetensors; everything else is ignored
+print("Downloading Orpheus weights (config + safetensors)…")
 snapshot_download(
     repo_id=model_name,
     allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json"],
@@ -34,71 +34,59 @@ snapshot_download(
 )
 
 print("Loading Orpheus model…")
-model = AutoModelForCausalLM.from_pretrained(
-    model_name, torch_dtype=torch.bfloat16
-).to(device)
+model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
 model.config.pad_token_id = model.config.eos_token_id
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-
-# — Constants for audio tokens —
-# (must match your training; here 128266)
-AUDIO_TOKEN_OFFSET = 128266
 
-# Helper functions —
+tokenizer = AutoTokenizer.from_pretrained(model_name)
 
+# — Helper functions —
 def process_prompt(text: str, voice: str):
+    """Build input_ids and attention_mask for a prompt."""
     prompt = f"{voice}: {text}"
-    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
-    # Per the specification:
-    # start_token=128259, end_tokens=(128009, 128260)
-    start = torch.tensor([[128259]], dtype=torch.int64)
-    end = torch.tensor([[128009, 128260]], dtype=torch.int64)
-    ids = torch.cat([start, input_ids, end], dim=1).to(device)
-    mask = torch.ones_like(ids).to(device)
+    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
+    start = torch.tensor([[128259]], dtype=torch.int64, device=device)
+    end = torch.tensor([[128009, 128260]], dtype=torch.int64, device=device)
+    ids = torch.cat([start, input_ids, end], dim=1)
+    mask = torch.ones_like(ids)
     return ids, mask
 
-def parse_output(generated_ids: torch.LongTensor):
-    """
-    Crop after the last 128257 start token, remove padding (128258),
-    then subtract the audio offset to get the real code IDs.
-    """
-    # find the last audio start token
-    token_to_start = 128257
-    token_to_remove = model.config.eos_token_id  # 128258
+def parse_output(generated_ids: torch.LongTensor) -> list[int]:
+    """Extract the raw token list after the last 128257 start token."""
+    token_to_find = 128257
+    token_to_remove = 128258
 
-    idxs = (generated_ids == token_to_start).nonzero(as_tuple=True)[1]
+    idxs = (generated_ids == token_to_find).nonzero(as_tuple=True)[1]
     if idxs.numel() > 0:
-        cut = idxs[-1].item() + 1
-        cropped = generated_ids[:, cut:]
+        cropped = generated_ids[:, idxs[-1].item() + 1 :]
     else:
         cropped = generated_ids
 
-    # flatten & remove PAD, then subtract the offset
-    flat = cropped[0][cropped[0] != token_to_remove]
-    codes = [(int(t) - AUDIO_TOKEN_OFFSET) for t in flat]
-    return codes
-
-def redistribute_codes(code_list: list[int], snac_model: SNAC):
-    """
-    Distribute the flat code list across 3 layers and decode with SNAC.
-    """
-    layer1, layer2, layer3 = [], [], []
-    for i in range(len(code_list) // 7):
+    row = cropped[0]
+    row = row[row != token_to_remove]
+    return row.tolist()
+
+def redistribute_codes(code_list: list[int]) -> bytes:
+    """Distribute the codes across the three SNAC layers and decode to PCM16 bytes."""
+    l1, l2, l3 = [], [], []
+    for i in range((len(code_list) + 1) // 7):
         base = code_list[7*i : 7*i+7]
-        layer1.append(base[0])
-        layer2.append(base[1] - 4096)
-        layer3.append(base[2] - 2*4096)
-        layer3.append(base[3] - 3*4096)
-        layer2.append(base[4] - 4*4096)
-        layer3.append(base[5] - 5*4096)
-        layer3.append(base[6] - 6*4096)
-
-    dev = next(snac_model.parameters()).device
-    c1 = torch.tensor(layer1, device=dev).unsqueeze(0)
-    c2 = torch.tensor(layer2, device=dev).unsqueeze(0)
-    c3 = torch.tensor(layer3, device=dev).unsqueeze(0)
-    audio = snac_model.decode([c1, c2, c3])
-    return audio.detach().squeeze().cpu().numpy()
+        l1.append(base[0])
+        l2.append(base[1] - 4096)
+        l3.append(base[2] - 2*4096)
+        l3.append(base[3] - 3*4096)
+        l2.append(base[4] - 4*4096)
+        l3.append(base[5] - 5*4096)
+        l3.append(base[6] - 6*4096)
+
+    dev = next(snac.parameters()).device
+    codes = [
+        torch.tensor(l1, device=dev).unsqueeze(0),
+        torch.tensor(l2, device=dev).unsqueeze(0),
+        torch.tensor(l3, device=dev).unsqueeze(0),
+    ]
+    audio = snac.decode(codes).squeeze().cpu().numpy()  # float32 @ 24 kHz
+    pcm16 = (audio * 32767).astype("int16").tobytes()
+    return pcm16
 
 # — FastAPI + WebSocket endpoint —
 app = FastAPI()
@@ -108,41 +96,47 @@ async def tts_ws(ws: WebSocket):
     await ws.accept()
     try:
         while True:
+            # 1) Receive the message
             msg = await ws.receive_text()
             data = json.loads(msg)
             text = data.get("text", "")
             voice = data.get("voice", "Jakob")
 
-            # 1) Prompt → token tensors
+            # 2) Prompt → IDs/mask
             ids, mask = process_prompt(text, voice)
 
-            # 2) Generation
+            # 3) Token generation
             gen_ids = model.generate(
                 input_ids=ids,
                 attention_mask=mask,
-                max_new_tokens=200,  # for debugging
+                max_new_tokens=2000,
                 do_sample=True,
                 temperature=0.7,
                 top_p=0.95,
                 repetition_penalty=1.1,
-                eos_token_id=model.config.eos_token_id,
+                eos_token_id=128258,
            )
 
-            # 3) Tokens → code list → audio (float32 @ 24 kHz)
-            code_list = parse_output(gen_ids)
-            audio_np = redistribute_codes(code_list, snac)
+            # 4) Parse + SNAC → PCM16 bytes
+            codes = parse_output(gen_ids)
+            pcm16 = redistribute_codes(codes)
+            chunk_sz = 2400 * 2  # 0.1 s @ 24 kHz
 
-            # 4) Stream as PCM16 in 0.1 s chunks (2400 samples)
-            pcm16 = (audio_np * 32767).astype("int16").tobytes()
-            chunk = 2400 * 2
-            for i in range(0, len(pcm16), chunk):
-                await ws.send_bytes(pcm16[i : i+chunk])
+            # 5) Stream the audio chunks
+            for i in range(0, len(pcm16), chunk_sz):
+                await ws.send_bytes(pcm16[i : i + chunk_sz])
                 await asyncio.sleep(0.1)
 
+            # 6) End-of-stream signal
+            await ws.send_json({"event": "eos"})
+
+            # (connection stays open for the next request)
+
     except WebSocketDisconnect:
         print("Client disconnected")
     except Exception as e:
         print("Error in /ws/tts:", e)
+        # close only after the error has been reported
         await ws.close(code=1011)
 
 if __name__ == "__main__":
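A note on the rewritten helpers: the removed parse_output subtracted AUDIO_TOKEN_OFFSET = 128266 from every token before returning code IDs, while the new parse_output returns the raw token IDs and redistribute_codes only subtracts the per-position multiples of 4096. If the audio tokens still start at 128266, that offset has to be re-applied before decoding. A minimal sketch of one way to do that, assuming the token layout from the removed code still holds (parse_output_with_offset is a hypothetical helper, not part of this commit):

    AUDIO_TOKEN_OFFSET = 128266  # assumption: same token layout as the removed code

    def parse_output_with_offset(generated_ids):
        """Like the new parse_output, but shifted back into SNAC code space."""
        token_to_find = 128257    # audio start token
        token_to_remove = 128258  # padding / EOS token
        idxs = (generated_ids == token_to_find).nonzero(as_tuple=True)[1]
        cropped = generated_ids[:, idxs[-1].item() + 1 :] if idxs.numel() > 0 else generated_ids
        row = cropped[0]
        row = row[row != token_to_remove]
        # subtract the offset so redistribute_codes sees small code IDs again
        return [int(t) - AUDIO_TOKEN_OFFSET for t in row]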
 
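For reference, redistribute_codes walks the flat code list in frames of 7 and fans each frame out over the three SNAC codebooks: position 0 goes to layer 1, positions 1 and 4 go to layer 2 (minus 4096 and 4*4096), and positions 2, 3, 5 and 6 go to layer 3 (minus 2*4096, 3*4096, 5*4096 and 6*4096). A tiny standalone illustration of that mapping on made-up values, with no model involved:

    # one dummy frame of 7 codes, chosen so the subtractions are easy to follow
    frame = [10, 4096 + 11, 2*4096 + 12, 3*4096 + 13, 4*4096 + 14, 5*4096 + 15, 6*4096 + 16]

    l1 = [frame[0]]                              # layer 1 gets 1 code per frame
    l2 = [frame[1] - 4096, frame[4] - 4*4096]    # layer 2 gets 2 codes per frame
    l3 = [frame[2] - 2*4096, frame[3] - 3*4096,  # layer 3 gets 4 codes per frame
          frame[5] - 5*4096, frame[6] - 6*4096]

    print(l1, l2, l3)  # [10] [11, 14] [12, 13, 15, 16]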
 
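On the client side, the endpoint expects one JSON message with text and voice, streams back raw PCM16 audio (mono, 24 kHz) in 0.1 s binary chunks, and finishes with an {"event": "eos"} text message. A minimal client sketch using the third-party websockets package; the host and port are placeholders (the __main__ block is not shown in this diff), and the audio is only collected rather than played back:

    import asyncio
    import json

    import websockets  # assumption: `pip install websockets`


    async def main():
        uri = "ws://localhost:8000/ws/tts"  # placeholder host/port; path taken from the app
        async with websockets.connect(uri) as ws:
            await ws.send(json.dumps({"text": "Hallo Welt", "voice": "Jakob"}))
            pcm = bytearray()
            while True:
                msg = await ws.recv()
                if isinstance(msg, bytes):
                    pcm.extend(msg)  # raw PCM16 @ 24 kHz
                elif json.loads(msg).get("event") == "eos":
                    break  # server signalled end of stream
            print(f"received {len(pcm) // 2} samples")


    asyncio.run(main())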