Tomtom84 commited on
Commit
3281189
·
1 Parent(s): 67c3132
Files changed (1) hide show
  1. app.py +40 -37
app.py CHANGED
@@ -40,9 +40,11 @@ model = AutoModelForCausalLM.from_pretrained(
40
  torch_dtype=torch.bfloat16
41
  ).to(device)
42
  model.config.pad_token_id = model.config.eos_token_id
43
-
44
  tokenizer = AutoTokenizer.from_pretrained(model_name)
45
 
 
 
 
46
  # — Hilfsfunktionen —
47
 
48
  def process_prompt(text: str, voice: str):
@@ -67,39 +69,41 @@ def parse_output(generated_ids: torch.LongTensor):
67
  else:
68
  cropped = generated_ids
69
 
70
- # Entferne EOS‑Token
71
  row = cropped[0]
 
72
  return row[row != token_to_remove].tolist()
73
 
74
- def redistribute_codes(code_list: list[int], snac_model: SNAC):
75
  """
76
- Verteilt die Token nur in kompletten 7er‑Blöcken auf die drei SNAC‑Layer
77
- und dekodiert in Audio. Unvollständige Reste (<7 Tokens) werden verworfen.
78
  """
79
- n_blocks = len(code_list) // 7
80
- layer1, layer2, layer3 = [], [], []
81
 
82
- for i in range(n_blocks):
83
- base = code_list[7*i : 7*i + 7]
84
- layer1.append(base[0])
85
- layer2.append(base[1] - 4096)
86
- layer3.append(base[2] - 2*4096)
87
- layer3.append(base[3] - 3*4096)
88
- layer2.append(base[4] - 4*4096)
89
- layer3.append(base[5] - 5*4096)
90
- layer3.append(base[6] - 6*4096)
91
-
92
- if not layer1:
93
- # kein kompletter Block → leeres Audio
94
  return np.zeros(0, dtype=np.float32)
95
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  dev = next(snac_model.parameters()).device
97
- codes = [
98
- torch.tensor(layer1, device=dev).unsqueeze(0),
99
- torch.tensor(layer2, device=dev).unsqueeze(0),
100
- torch.tensor(layer3, device=dev).unsqueeze(0),
101
- ]
102
- audio = snac_model.decode(codes)
103
  return audio.detach().squeeze().cpu().numpy()
104
 
105
  # — FastAPI Setup —
@@ -107,27 +111,27 @@ def redistribute_codes(code_list: list[int], snac_model: SNAC):
107
  app = FastAPI()
108
 
109
  @app.get("/")
110
- def greet_json():
111
- return {"Hello": "World!"}
112
 
113
  @app.websocket("/ws/tts")
114
  async def tts_ws(ws: WebSocket):
115
  await ws.accept()
116
  try:
117
  while True:
118
- # Erwartet JSON: {"text": "...", "voice": "Jakob"}
119
  data = json.loads(await ws.receive_text())
120
  text = data.get("text", "")
121
  voice = data.get("voice", "Jakob")
122
 
123
- # 1) Tokens vorbereiten
124
  ids, mask = process_prompt(text, voice)
125
 
126
  # 2) Generierung
127
  gen_ids = model.generate(
128
  input_ids=ids,
129
  attention_mask=mask,
130
- max_new_tokens=2000, # hier nach Bedarf anpassen
131
  do_sample=True,
132
  temperature=0.7,
133
  top_p=0.95,
@@ -135,18 +139,17 @@ async def tts_ws(ws: WebSocket):
135
  eos_token_id=model.config.eos_token_id,
136
  )
137
 
138
- # 3) Tokens → Code-ListeAudio
139
- codes = parse_output(gen_ids)
140
- audio_np = redistribute_codes(codes, snac)
 
141
 
142
- # 4) in 0.1sStücken PCM16 streamen
143
- pcm16 = (audio_np * 32767).astype("int16").tobytes()
144
- chunk = 2400 * 2 # 2400 samples @24kHz = 0.1s * 2 bytes
145
  for i in range(0, len(pcm16), chunk):
146
  await ws.send_bytes(pcm16[i : i+chunk])
147
  await asyncio.sleep(0.1)
148
 
149
- # Ende der while‐Schleife
150
  except WebSocketDisconnect:
151
  print("Client disconnected")
152
  except Exception as e:
 
40
  torch_dtype=torch.bfloat16
41
  ).to(device)
42
  model.config.pad_token_id = model.config.eos_token_id
 
43
  tokenizer = AutoTokenizer.from_pretrained(model_name)
44
 
45
+ # — Konstanten —
46
+ AUDIO_TOKEN_OFFSET = 128266 # globaler Offset der Audio‑Tokens
47
+
48
  # — Hilfsfunktionen —
49
 
50
  def process_prompt(text: str, voice: str):
 
69
  else:
70
  cropped = generated_ids
71
 
 
72
  row = cropped[0]
73
+ # entferne das EOS‑Marker‑Token
74
  return row[row != token_to_remove].tolist()
75
 
76
+ def redistribute_codes(raw_codes: list[int], snac_model: SNAC):
77
  """
78
+ Subtrahiere erst den globalen Offset, dann packe in 7er-Blöcke und dekodiere.
79
+ Unvollständige Reste (<7 Tokens) werden verworfen.
80
  """
81
+ # 1) Offset abziehen
82
+ codes = [c - AUDIO_TOKEN_OFFSET for c in raw_codes]
83
 
84
+ # 2) Nur ganze 7er‑Blöcke
85
+ n_blocks = len(codes) // 7
86
+ if n_blocks == 0:
 
 
 
 
 
 
 
 
 
87
  return np.zeros(0, dtype=np.float32)
88
 
89
+ layer1, layer2, layer3 = [], [], []
90
+ for i in range(n_blocks):
91
+ b = codes[7*i : 7*i+7]
92
+ layer1.append(b[0])
93
+ layer2.append(b[1] - 4096)
94
+ layer3.append(b[2] - 2*4096)
95
+ layer3.append(b[3] - 3*4096)
96
+ layer2.append(b[4] - 4*4096)
97
+ layer3.append(b[5] - 5*4096)
98
+ layer3.append(b[6] - 6*4096)
99
+
100
+ # 3) SNAC‑Layer‑Tensors bauen und dekodieren
101
  dev = next(snac_model.parameters()).device
102
+ t1 = torch.tensor(layer1, device=dev).unsqueeze(0)
103
+ t2 = torch.tensor(layer2, device=dev).unsqueeze(0)
104
+ t3 = torch.tensor(layer3, device=dev).unsqueeze(0)
105
+ audio = snac_model.decode([t1, t2, t3])
106
+
 
107
  return audio.detach().squeeze().cpu().numpy()
108
 
109
  # — FastAPI Setup —
 
111
  app = FastAPI()
112
 
113
  @app.get("/")
114
+ async def hello():
115
+ return {"message": "Hello World"}
116
 
117
  @app.websocket("/ws/tts")
118
  async def tts_ws(ws: WebSocket):
119
  await ws.accept()
120
  try:
121
  while True:
122
+ # Empfang: {"text":"...", "voice":"Jakob"}
123
  data = json.loads(await ws.receive_text())
124
  text = data.get("text", "")
125
  voice = data.get("voice", "Jakob")
126
 
127
+ # 1) Eingabe → Tokens
128
  ids, mask = process_prompt(text, voice)
129
 
130
  # 2) Generierung
131
  gen_ids = model.generate(
132
  input_ids=ids,
133
  attention_mask=mask,
134
+ max_new_tokens=2000, # nach Bedarf hochsetzen
135
  do_sample=True,
136
  temperature=0.7,
137
  top_p=0.95,
 
139
  eos_token_id=model.config.eos_token_id,
140
  )
141
 
142
+ # 3) Tokens → Audio‑CodesPCM
143
+ raw_codes = parse_output(gen_ids)
144
+ audio_np = redistribute_codes(raw_codes, snac)
145
+ pcm16 = (audio_np * 32767).astype("int16").tobytes()
146
 
147
+ # 4) Stream in 0.1 sChunks
148
+ chunk = 2400 * 2 # 2400 Samples @24 kHz = 0.1 s * 2 Bytes
 
149
  for i in range(0, len(pcm16), chunk):
150
  await ws.send_bytes(pcm16[i : i+chunk])
151
  await asyncio.sleep(0.1)
152
 
 
153
  except WebSocketDisconnect:
154
  print("Client disconnected")
155
  except Exception as e: