Tomtom84 committed on
Commit
e97a876
·
verified ·
1 Parent(s): 674acbf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -36
app.py CHANGED
@@ -14,17 +14,15 @@ HF_TOKEN = os.getenv("HF_TOKEN")
14
  if HF_TOKEN:
15
  login(token=HF_TOKEN)
16
 
17
- # — Debug: CPU‑Modus zum Entwickeln, später wieder "cuda"
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
- #device = "cpu"
20
 
21
  # — Modelle laden —
22
- print("Loading SNAC model...")
23
  snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
24
 
25
  model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
26
- # optional: explizites snapshot_download (entfernt große Dateien)
27
-
28
  snapshot_download(
29
  repo_id=model_name,
30
  allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json"],
@@ -35,19 +33,24 @@ snapshot_download(
35
  ]
36
  )
37
 
38
- print("Loading Orpheus model...")
39
  model = AutoModelForCausalLM.from_pretrained(
40
  model_name, torch_dtype=torch.bfloat16
41
  ).to(device)
42
  model.config.pad_token_id = model.config.eos_token_id
43
-
44
  tokenizer = AutoTokenizer.from_pretrained(model_name)
45
 
 
 
 
 
46
  # — Hilfsfunktionen —
47
 
48
  def process_prompt(text: str, voice: str):
49
  prompt = f"{voice}: {text}"
50
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids
 
 
51
  start = torch.tensor([[128259]], dtype=torch.int64)
52
  end = torch.tensor([[128009, 128260]], dtype=torch.int64)
53
  ids = torch.cat([start, input_ids, end], dim=1).to(device)
@@ -55,30 +58,32 @@ def process_prompt(text: str, voice: str):
55
  return ids, mask
56
 
57
  def parse_output(generated_ids: torch.LongTensor):
58
- """Extrahiere rohe Tokenliste nach dem letzten 128257-Start-Token."""
59
- token_to_find = 128257
60
- token_to_remove = 128258
61
-
62
- # 1) Finde letztes Start-Token, croppe
63
- idxs = (generated_ids == token_to_find).nonzero(as_tuple=True)[1]
 
 
 
64
  if idxs.numel() > 0:
65
  cut = idxs[-1].item() + 1
66
  cropped = generated_ids[:, cut:]
67
  else:
68
  cropped = generated_ids
69
 
70
- # 2) Entferne Padding-Markierungen
71
- rows = []
72
- for row in cropped:
73
- rows.append(row[row != token_to_remove])
74
-
75
- # 3) Flache Liste zurückgeben
76
- return rows[0].tolist()
77
 
78
  def redistribute_codes(code_list: list[int], snac_model: SNAC):
79
- """Verteile die Codes auf drei Layer, dekodiere in Audio."""
 
 
80
  layer1, layer2, layer3 = [], [], []
81
- for i in range((len(code_list) + 1) // 7):
82
  base = code_list[7*i : 7*i+7]
83
  layer1.append(base[0])
84
  layer2.append(base[1] - 4096)
@@ -89,13 +94,11 @@ def redistribute_codes(code_list: list[int], snac_model: SNAC):
89
  layer3.append(base[6] - 6*4096)
90
 
91
  dev = next(snac_model.parameters()).device
92
- codes = [
93
- torch.tensor(layer1, device=dev).unsqueeze(0),
94
- torch.tensor(layer2, device=dev).unsqueeze(0),
95
- torch.tensor(layer3, device=dev).unsqueeze(0),
96
- ]
97
- audio = snac_model.decode(codes)
98
- return audio.detach().squeeze().cpu().numpy() # float32 @24 kHz
99
 
100
  # — FastAPI + WebSocket-Endpoint —
101
  app = FastAPI()
@@ -110,28 +113,28 @@ async def tts_ws(ws: WebSocket):
110
  text = data.get("text", "")
111
  voice = data.get("voice", "Jakob")
112
 
113
- # 1) Prompt → Tokens
114
  ids, mask = process_prompt(text, voice)
115
 
116
- # 2) Token-Generation (erst klein testen!)
117
  gen_ids = model.generate(
118
  input_ids=ids,
119
  attention_mask=mask,
120
- max_new_tokens=200, # zum Debuggen klein halten
121
  do_sample=True,
122
  temperature=0.7,
123
  top_p=0.95,
124
  repetition_penalty=1.1,
125
- eos_token_id=128258,
126
  )
127
 
128
- # 3) Tokens → Code-Liste → Audio
129
  code_list = parse_output(gen_ids)
130
  audio_np = redistribute_codes(code_list, snac)
131
 
132
- # 4) PCM16-Bytes und Stream in 0.1s-Chunks
133
  pcm16 = (audio_np * 32767).astype("int16").tobytes()
134
- chunk = 2400 * 2 # 2400 samples @24kHz → 0.1s * 2 bytes
135
  for i in range(0, len(pcm16), chunk):
136
  await ws.send_bytes(pcm16[i : i+chunk])
137
  await asyncio.sleep(0.1)
 
14
  if HF_TOKEN:
15
  login(token=HF_TOKEN)
16
 
17
+ # — Device
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
19
 
20
  # — Modelle laden —
21
+ print("Loading SNAC model")
22
  snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
23
 
24
  model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
25
+ # Nur die Konfig + Safetensors, alles andere wird ignoriert
 
26
  snapshot_download(
27
  repo_id=model_name,
28
  allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json"],
 
33
  ]
34
  )
35
 
36
+ print("Loading Orpheus model")
37
  model = AutoModelForCausalLM.from_pretrained(
38
  model_name, torch_dtype=torch.bfloat16
39
  ).to(device)
40
  model.config.pad_token_id = model.config.eos_token_id
 
41
  tokenizer = AutoTokenizer.from_pretrained(model_name)
42
 
43
+ # — Konstanten für Audio‑Token →
44
+ # (muss übereinstimmen mit Deinem Training; hier 128266)
45
+ AUDIO_TOKEN_OFFSET = 128266
46
+
47
  # — Hilfsfunktionen —
48
 
49
  def process_prompt(text: str, voice: str):
50
  prompt = f"{voice}: {text}"
51
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids
52
+ # Laut Spezifikation:
53
+ # start_token=128259, end_tokens=(128009,128260)
54
  start = torch.tensor([[128259]], dtype=torch.int64)
55
  end = torch.tensor([[128009, 128260]], dtype=torch.int64)
56
  ids = torch.cat([start, input_ids, end], dim=1).to(device)
 
58
  return ids, mask
59
 
60
  def parse_output(generated_ids: torch.LongTensor):
61
+ """
62
+ Croppt nach dem letzten 128257-Start-Token, entfernt Padding (128258)
63
+ und zieht dann den Audio‑Offset ab, um echte Code‑IDs zu bekommen.
64
+ """
65
+ # finde letztes Audio‑StartToken
66
+ token_to_start = 128257
67
+ token_to_remove = model.config.eos_token_id # 128258
68
+
69
+ idxs = (generated_ids == token_to_start).nonzero(as_tuple=True)[1]
70
  if idxs.numel() > 0:
71
  cut = idxs[-1].item() + 1
72
  cropped = generated_ids[:, cut:]
73
  else:
74
  cropped = generated_ids
75
 
76
+ # flatten & remove PAD, dann Offset abziehen
77
+ flat = cropped[0][cropped[0] != token_to_remove]
78
+ codes = [(int(t) - AUDIO_TOKEN_OFFSET) for t in flat]
79
+ return codes
 
 
 
80
 
81
  def redistribute_codes(code_list: list[int], snac_model: SNAC):
82
+ """
83
+ Verteilt die flache Code‑Liste in 3 Layers und dekodiert mit SNAC.
84
+ """
85
  layer1, layer2, layer3 = [], [], []
86
+ for i in range(len(code_list) // 7):
87
  base = code_list[7*i : 7*i+7]
88
  layer1.append(base[0])
89
  layer2.append(base[1] - 4096)
 
94
  layer3.append(base[6] - 6*4096)
95
 
96
  dev = next(snac_model.parameters()).device
97
+ c1 = torch.tensor(layer1, device=dev).unsqueeze(0)
98
+ c2 = torch.tensor(layer2, device=dev).unsqueeze(0)
99
+ c3 = torch.tensor(layer3, device=dev).unsqueeze(0)
100
+ audio = snac_model.decode([c1, c2, c3])
101
+ return audio.detach().squeeze().cpu().numpy()
 
 
102
 
103
  # — FastAPI + WebSocket-Endpoint —
104
  app = FastAPI()
 
113
  text = data.get("text", "")
114
  voice = data.get("voice", "Jakob")
115
 
116
+ # 1) Prompt → Token‑Tensoren
117
  ids, mask = process_prompt(text, voice)
118
 
119
+ # 2) Generation
120
  gen_ids = model.generate(
121
  input_ids=ids,
122
  attention_mask=mask,
123
+ max_new_tokens=200, # zum Debug
124
  do_sample=True,
125
  temperature=0.7,
126
  top_p=0.95,
127
  repetition_penalty=1.1,
128
+ eos_token_id=model.config.eos_token_id,
129
  )
130
 
131
+ # 3) Token → CodeListe → Audio (Float32 @24 kHz)
132
  code_list = parse_output(gen_ids)
133
  audio_np = redistribute_codes(code_list, snac)
134
 
135
+ # 4) In 0.1 s‑Chunks (2400 Samples) als PCM16 streamen
136
  pcm16 = (audio_np * 32767).astype("int16").tobytes()
137
+ chunk = 2400 * 2
138
  for i in range(0, len(pcm16), chunk):
139
  await ws.send_bytes(pcm16[i : i+chunk])
140
  await asyncio.sleep(0.1)