Tomtom84 committed
Commit dbb4a9f · verified · 1 Parent(s): 325e9ba

Update app.py

Files changed (1)
  1. app.py +329 -185
app.py CHANGED
@@ -1,224 +1,368 @@
  # app.py ──────────────────────────────────────────────────────────────
- import os, json, torch, asyncio
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
  from huggingface_hub import login
- from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, DynamicCache # Added StaticCache
- from snac import SNAC

  # 0) Login + Device ---------------------------------------------------
  HF_TOKEN = os.getenv("HF_TOKEN")
  if HF_TOKEN:
      login(HF_TOKEN)

- device = "cuda" if torch.cuda.is_available() else "cpu"
- #torch.backends.cuda.enable_flash_sdp(False) # PyTorch‑2.2‑Bug

  # 1) Konstanten -------------------------------------------------------
- REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
- CHUNK_TOKENS = 50
- START_TOKEN = 128259
- NEW_BLOCK = 128257
- EOS_TOKEN = 128258
- AUDIO_BASE = 128266
- AUDIO_SPAN = 4096 * 7 # 28 672 Codes
- AUDIO_IDS = torch.arange(AUDIO_BASE, AUDIO_BASE + AUDIO_SPAN) # Renamed VALID_AUDIO to AUDIO_IDS
-
- # 2) Logit‑Mask (NEW_BLOCK + Audio; EOS erst nach 1. Block) ----------
  class AudioMask(LogitsProcessor):
-     def __init__(self, audio_ids: torch.Tensor):
          super().__init__()
          self.allow = torch.cat([
-             torch.tensor([NEW_BLOCK], device=audio_ids.device),
              audio_ids
-         ])
-         self.eos = torch.tensor([EOS_TOKEN], device=audio_ids.device)
-         self.sent_blocks = 0
-         self.buffer_pos = 0 # Added buffer position

-     def __call__(self, input_ids, scores):
-         allow = torch.cat([self.allow, self.eos]) # Reverted masking logic
          mask = torch.full_like(scores, float("-inf"))
-         mask[:, allow] = 0
          return scores + mask

- # 3) FastAPI Grundgerüst ---------------------------------------------
- app = FastAPI()

- @app.get("/")
- def hello():
-     return {"status": "ok"}

  @app.on_event("startup")
- def load_models():
-     global tok, model, snac, masker
      print("⏳ Lade Modelle …", flush=True)

-     tok = AutoTokenizer.from_pretrained(REPO)
-     snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
      model = AutoModelForCausalLM.from_pretrained(
          REPO,
-         device_map={"": 0} if device == "cuda" else None,
-         torch_dtype=torch.bfloat16 if device == "cuda" else None,
-         low_cpu_mem_usage=True,
      )
-     model.config.pad_token_id = model.config.eos_token_id
-     masker = AudioMask(AUDIO_IDS.to(device))
-
-     print("✅ Modelle geladen", flush=True)
-
- # 4) Helper -----------------------------------------------------------
- def build_prompt(text: str, voice: str):
-     prompt_ids = tok(f"{voice}: {text}", return_tensors="pt").input_ids.to(device)
-     ids = torch.cat([torch.tensor([[START_TOKEN]], device=device),
-                      prompt_ids,
-                      torch.tensor([[128009, 128260]], device=device)], 1)
-     attn = torch.ones_like(ids)
-     return ids, attn # Ensure attention mask is created
-
- def decode_block(block7: list[int]) -> bytes:
-     l1,l2,l3=[],[],[]
-     l1.append(block7[0] - 0 * 4096) # Subtract position 0 offset
-     l2.append(block7[1] - 1 * 4096) # Subtract position 1 offset
-     l3 += [block7[2] - 2 * 4096, block7[3] - 3 * 4096] # Subtract position offsets
-     l2.append(block7[4] - 4 * 4096) # Subtract position 4 offset
-     l3 += [block7[5] - 5 * 4096, block7[6] - 6 * 4096] # Subtract position offsets
-
-     with torch.no_grad():
-         codes = [torch.tensor(x, device=device).unsqueeze(0)
-                  for x in (l1, l2, l3)]
-         audio = snac.decode(codes).squeeze().detach().cpu().numpy()
-
-     return (audio * 32767).astype("int16").tobytes()
-
- # 5) WebSocket‑Endpoint ----------------------------------------------
  @app.websocket("/ws/tts")
  async def tts(ws: WebSocket):
      await ws.accept()
      try:
-         req = json.loads(await ws.receive_text())
-         text = req.get("text", "")
-         voice = req.get("voice", "Jakob")

          ids, attn = build_prompt(text, voice)
-         past = None
-         ids, attn = build_prompt(text, voice)
-         past = None # Holds the DynamicCache object from past_key_values
-         buf = []
-         last_tok = None # Initialize last_tok
-
-         while True:
-             # Determine inputs for this iteration
-             if past is None:
-                 # First iteration: Use the full prompt
-                 current_input_ids = ids
-                 current_attn_mask = attn
-                 # DO NOT pass cache_position on the first run
-                 current_cache_position = None
-             else:
-                 # Subsequent iterations: Use only the last token
-                 if last_tok is None:
-                     print("Error: last_tok is None before subsequent generate call.")
-                     break # Should not happen if generation proceeded
-                 current_input_ids = torch.tensor([[last_tok]], device=device)
-                 current_attn_mask = None # Not needed when past_key_values is provided
-                 # DO NOT pass cache_position; let DynamicCache handle it
-                 current_cache_position = None
-
-             # --- Call model.generate ---
-             try:
-                 gen = model.generate(
-                     input_ids=current_input_ids,
-                     attention_mask=current_attn_mask,
-                     past_key_values=past,
-                     cache_position=current_cache_position, # Will be None after first iteration
-                     max_new_tokens=CHUNK_TOKENS,
-                     logits_processor=[masker],
-                     do_sample=True, temperature=0.7, top_p=0.95,
-                     use_cache=True,
-                     return_dict_in_generate=True,
-                     return_legacy_cache=False # Ensures DynamicCache
-                 )
-             except Exception as e:
-                 print(f"❌ Error during model.generate: {e}")
-                 import traceback
-                 traceback.print_exc()
-                 break # Exit loop on generation error
-
-             # --- Process Output ---
-             # Get the full sequence generated *up to this point*
-             full_sequence_now = gen.sequences # Get the sequence tensor
-
-             # Determine the sequence length *before* this generation call using the cache
-             # If past is None, the previous length was the initial prompt length
-             prev_seq_len = past.get_seq_length() if past is not None else ids.shape
-
-             # The new tokens are those generated *in this call*
-             # These appear *after* the previously cached sequence length
-             # Ensure slicing is correct even if no new tokens are generated
-             if full_sequence_now.shape > prev_seq_len:
-                 new_token_ids = full_sequence_now[prev_seq_len:]
-                 new = new_token_ids.tolist() # Convert tensor to list
-             else:
-                 new = [] # No new tokens generated
-
-             if not new: # If no new tokens were generated, stop
-                 print("No new tokens generated, stopping.")
-                 break
-
-             # Update past_key_values for the *next* iteration
-             past = gen.past_key_values # Update the cache state
-
-             # Get the very last token generated in *this* call for the *next* input
-             last_tok = new[-1]
-
-             # ----- Token‑Handling (process the 'new' list) -----
-             eos_found = False
-             for t in new:
-                 if t == EOS_TOKEN:
-                     print("EOS token encountered.")
-                     eos_found = True
-                     break # Stop processing tokens in this chunk
-                 if t == NEW_BLOCK:
-                     buf.clear()
-                     continue
-                 # Check if token is within the expected audio range
-                 if AUDIO_BASE <= t < AUDIO_BASE + AUDIO_SPAN:
-                     buf.append(t - AUDIO_BASE)
-                 else:
-                     # Log unexpected tokens if necessary
-                     # print(f"Warning: Generated token {t} outside expected audio range.")
-                     pass # Ignore unexpected tokens for now
-
-                 if len(buf) == 7:
-                     await ws.send_bytes(decode_block(buf))
-                     buf.clear()
-                     # Allow EOS only after the first full block is sent
-                     if not masker.sent_blocks:
-                         masker.sent_blocks = 1
-
-             if eos_found:
-                 # Handle any remaining buffer content if needed (e.g., log incomplete block)
-                 if len(buf) > 0:
-                     print(f"Warning: Incomplete audio block at EOS: {len(buf)} tokens. Discarding.")
-                     buf.clear()
-                 break # Exit the while loop
-
-     except (StopIteration, WebSocketDisconnect):
-         pass
      except Exception as e:
-         print("❌ WS‑Error:", e, flush=True)
-         import traceback
-         traceback.print_exc()
-         if ws.client_state.name != "DISCONNECTED":
-             await ws.close(code=1011)
      finally:
          if ws.client_state.name != "DISCONNECTED":
              try:
-                 await ws.close()
-             except RuntimeError:
-                 pass

- # 6) Dev‑Start --------------------------------------------------------
  if __name__ == "__main__":
-     import uvicorn, sys
-     uvicorn.run("app:app", host="0.0.0.0", port=7860, log_level="info")
 
 
 
  # app.py ──────────────────────────────────────────────────────────────
+ import os
+ import json
+ import torch
+ import asyncio
+ import traceback # Import traceback for better error logging
+
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
  from huggingface_hub import login
+ from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, StoppingCriteria, StoppingCriteriaList
+ # Import BaseStreamer for the interface
+ from transformers.generation.streamers import BaseStreamer
+ from snac import SNAC # Ensure you have 'pip install snac'
+
+ # --- Globals (populated in load_models) ---
+ tok = None
+ model = None
+ snac = None
+ masker = None
+ stopping_criteria = None
+ device = "cuda" if torch.cuda.is_available() else "cpu"

  # 0) Login + Device ---------------------------------------------------
  HF_TOKEN = os.getenv("HF_TOKEN")
  if HF_TOKEN:
+     print("🔑 Logging in to Hugging Face Hub...")
      login(HF_TOKEN)

+ # torch.backends.cuda.enable_flash_sdp(False) # Uncomment if needed for PyTorch‑2.2‑Bug

  # 1) Konstanten -------------------------------------------------------
+ REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
+ # CHUNK_TOKENS = 50 # Not directly used with the streamer approach
+ START_TOKEN = 128259
+ NEW_BLOCK = 128257
+ EOS_TOKEN = 128258
+ AUDIO_BASE = 128266
+ AUDIO_SPAN = 4096 * 7 # 28672 Codes
+ # Create AUDIO_IDS on the correct device later in load_models
+ AUDIO_IDS_CPU = torch.arange(AUDIO_BASE, AUDIO_BASE + AUDIO_SPAN)
+
+ # 2) Logit‑Mask -------------------------------------------------------
  class AudioMask(LogitsProcessor):
+     def __init__(self, audio_ids: torch.Tensor, new_block_token_id: int, eos_token_id: int):
          super().__init__()
+         # Allow NEW_BLOCK and all valid audio tokens initially
          self.allow = torch.cat([
+             torch.tensor([new_block_token_id], device=audio_ids.device), # Add NEW_BLOCK token ID
              audio_ids
+         ], dim=0)
+         self.eos = torch.tensor([eos_token_id], device=audio_ids.device) # Store EOS token ID as tensor
+         self.allow_with_eos = torch.cat([self.allow, self.eos], dim=0) # Precompute combined tensor
+         self.sent_blocks = 0 # State: Number of audio blocks sent

+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+         # Determine which tokens are allowed based on whether blocks have been sent
+         current_allow = self.allow_with_eos if self.sent_blocks > 0 else self.allow
+
+         # Create a mask initialized to negative infinity
          mask = torch.full_like(scores, float("-inf"))
+         # Set allowed token scores to 0 (effectively allowing them)
+         mask[:, current_allow] = 0
+         # Apply the mask to the scores
          return scores + mask

+     def reset(self):
+         """Resets the state for a new generation request."""
+         self.sent_blocks = 0

+ # 3) StoppingCriteria für EOS ---------------------------------------
+ # generate() needs explicit stopping criteria when using a streamer
+ class EosStoppingCriteria(StoppingCriteria):
+     def __init__(self, eos_token_id: int):
+         self.eos_token_id = eos_token_id
+
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+         # Check if the *last* generated token is the EOS token
+         if input_ids.shape[1] > 0 and input_ids[:, -1] == self.eos_token_id:
+             # print("StoppingCriteria: EOS detected.")
+             return True
+         return False
+
+ # 4) Benutzerdefinierter AudioStreamer -------------------------------
+ class AudioStreamer(BaseStreamer):
+     def __init__(self, ws: WebSocket, snac_decoder: SNAC, audio_mask: AudioMask, loop: asyncio.AbstractEventLoop):
+         self.ws = ws
+         self.snac = snac_decoder
+         self.masker = audio_mask # Reference to the mask to update sent_blocks
+         self.loop = loop # Event loop of the main thread for run_coroutine_threadsafe
+         self.device = snac_decoder.device # Get device from the decoder
+         self.buf: list[int] = [] # Buffer for audio token values (AUDIO_BASE subtracted)
+         self.tasks = set() # Keep track of pending send tasks
+
+     def _decode_block(self, block7: list[int]) -> bytes:
+         """
+         Decodes a block of 7 audio token values (AUDIO_BASE subtracted) into audio bytes.
+         NOTE: The mapping from the 7 tokens to the 3 SNAC codebooks (l1, l2, l3)
+         is based on a common interleaving hypothesis. Verify if model docs specify otherwise.
+         """
+         if len(block7) != 7:
+             print(f"Streamer Warning: _decode_block received {len(block7)} tokens, expected 7. Skipping.")
+             return b"" # Return empty bytes if block is incomplete
+
+         # --- Hypothesis: Interleaving mapping ---
+         # Assumes 7 tokens map to 3 codebooks like this:
+         # Codebook 1 (l1) uses tokens at indices 0, 3, 6
+         # Codebook 2 (l2) uses tokens at indices 1, 4
+         # Codebook 3 (l3) uses tokens at indices 2, 5
+         try:
+             l1 = [block7[0], block7[3], block7[6]]
+             l2 = [block7[1], block7[4]]
+             l3 = [block7[2], block7[5]]
+         except IndexError:
+             print(f"Streamer Error: Index out of bounds during token mapping. Block: {block7}")
+             return b""
+
+         # Convert lists to tensors on the correct device
+         codes_l1 = torch.tensor(l1, dtype=torch.long, device=self.device).unsqueeze(0)
+         codes_l2 = torch.tensor(l2, dtype=torch.long, device=self.device).unsqueeze(0)
+         codes_l3 = torch.tensor(l3, dtype=torch.long, device=self.device).unsqueeze(0)
+         codes = [codes_l1, codes_l2, codes_l3] # List of tensors for SNAC
+
+         # Decode using SNAC
+         with torch.no_grad():
+             audio = self.snac.decode(codes)[0] # Decode expects list of tensors, result might have batch dim
+
+         # Squeeze, move to CPU, convert to numpy
+         audio_np = audio.squeeze().detach().cpu().numpy()
+
+         # Convert to 16-bit PCM bytes
+         audio_bytes = (audio_np * 32767).astype("int16").tobytes()
+         return audio_bytes
+
+     async def _send_audio_bytes(self, data: bytes):
+         """Coroutine to send bytes over WebSocket."""
+         if not data: # Don't send empty bytes
+             return
+         try:
+             await self.ws.send_bytes(data)
+             # print(f"Streamer: Sent {len(data)} audio bytes.")
+         except WebSocketDisconnect:
+             print("Streamer: WebSocket disconnected during send.")
+         except Exception as e:
+             print(f"Streamer: Error sending bytes: {e}")
+
+     def put(self, value: torch.LongTensor):
+         """
+         Receives new token IDs (Tensor) from generate() (runs in worker thread).
+         Processes tokens, decodes full blocks, and schedules sending via run_coroutine_threadsafe.
+         """
+         # Ensure value is on CPU and flatten to a list of ints
+         if value.numel() == 0:
+             return
+         new_token_ids = value.squeeze().tolist()
+         if isinstance(new_token_ids, int): # Handle single token case
+             new_token_ids = [new_token_ids]
+
+         for t in new_token_ids:
+             if t == EOS_TOKEN:
+                 # print("Streamer: EOS token encountered.")
+                 # EOS is handled by StoppingCriteria; no action needed here except maybe logging.
+                 break # Stop processing this batch if EOS is found
+
+             if t == NEW_BLOCK:
+                 # print("Streamer: NEW_BLOCK token encountered.")
+                 # NEW_BLOCK indicates the start of audio, might reset buffer if needed
+                 self.buf.clear()
+                 continue # Move to the next token
+
+             # Check if token is within the expected audio range
+             if AUDIO_BASE <= t < AUDIO_BASE + AUDIO_SPAN:
+                 self.buf.append(t - AUDIO_BASE) # Store value relative to base
+             else:
+                 # Log unexpected tokens (like START_TOKEN or others if generation goes wrong)
+                 # print(f"Streamer Warning: Ignoring unexpected token {t}")
+                 pass # Ignore tokens outside the audio range
+
+             # If buffer has 7 tokens, decode and send
+             if len(self.buf) == 7:
+                 audio_bytes = self._decode_block(self.buf)
+                 self.buf.clear() # Clear buffer after processing
+
+                 if audio_bytes: # Only send if decoding was successful
+                     # Schedule the async send function to run on the main event loop
+                     future = asyncio.run_coroutine_threadsafe(self._send_audio_bytes(audio_bytes), self.loop)
+                     self.tasks.add(future)
+                     # Optional: Remove completed tasks to prevent memory leak if generation is very long
+                     future.add_done_callback(self.tasks.discard)
+
+                 # Allow EOS only after the first full block has been processed and scheduled for sending
+                 if self.masker.sent_blocks == 0:
+                     # print("Streamer: First audio block processed, allowing EOS.")
+                     self.masker.sent_blocks = 1 # Update state in the mask
+
+         # Note: No need to explicitly wait for tasks here. put() should return quickly.
+
+     def end(self):
+         """Called by generate() when generation finishes."""
+         # Handle any remaining tokens in the buffer (optional, here we discard them)
+         if len(self.buf) > 0:
+             print(f"Streamer: End of generation with incomplete block ({len(self.buf)} tokens). Discarding.")
+             self.buf.clear()
+
+         # Optional: Wait briefly for any outstanding send tasks to complete?
+         # This is tricky because end() is sync. A robust solution might involve
+         # signaling the WebSocket handler to wait before closing.
+         # For simplicity, we rely on FastAPI/Uvicorn's graceful shutdown handling.
+         # print(f"Streamer: Generation finished. Pending send tasks: {len(self.tasks)}")
+         pass
+
+ # 5) FastAPI App ------------------------------------------------------
+ app = FastAPI()

  @app.on_event("startup")
+ async def load_models_startup(): # Make startup async if needed for future async loads
+     global tok, model, snac, masker, stopping_criteria, device, AUDIO_IDS_CPU
+
+     print(f"🚀 Starting up on device: {device}")
      print("⏳ Lade Modelle …", flush=True)

+     tok = AutoTokenizer.from_pretrained(REPO)
+     print("Tokenizer loaded.")
+
+     # Load SNAC first (usually smaller)
+     snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
+     print(f"SNAC loaded to {snac.device}.")
+
+     # Load the main model
      model = AutoModelForCausalLM.from_pretrained(
          REPO,
+         device_map={"": 0} if device == "cuda" else None, # Assign to GPU 0 if cuda
+         torch_dtype=torch.bfloat16 if device == "cuda" and torch.cuda.is_bf16_supported() else torch.float32, # Use bfloat16 if supported
+         low_cpu_mem_usage=True, # Good practice for large models
      )
+     model.config.pad_token_id = model.config.eos_token_id # Set pad token
+     print(f"Model loaded to {model.device}.")
+
+     # Ensure model is in evaluation mode
+     model.eval()
+
+     # Initialize AudioMask (needs AUDIO_IDS on the correct device)
+     audio_ids_device = AUDIO_IDS_CPU.to(device)
+     masker = AudioMask(audio_ids_device, NEW_BLOCK, EOS_TOKEN)
+     print("AudioMask initialized.")
+
+     # Initialize StoppingCriteria
+     # IMPORTANT: Create the list and add the criteria instance
+     stopping_criteria = StoppingCriteriaList([EosStoppingCriteria(EOS_TOKEN)])
+     print("StoppingCriteria initialized.")
+
+     print("✅ Modelle geladen und bereit!", flush=True)
+
+ @app.get("/")
+ def hello():
+     return {"status": "ok", "message": "TTS Service is running"}
+
+ # 6) Helper zum Prompt Bauen -------------------------------------------
+ def build_prompt(text: str, voice: str) -> tuple[torch.Tensor, torch.Tensor]:
+     """Builds the input_ids and attention_mask for the model."""
+     # Format: <START> <VOICE>: <TEXT> <NEW_BLOCK>
+     prompt_text = f"{voice}: {text}"
+     prompt_ids = tok(prompt_text, return_tensors="pt").input_ids.to(device)
+
+     # Construct input_ids tensor
+     input_ids = torch.cat([
+         torch.tensor([[START_TOKEN]], device=device), # Start token
+         prompt_ids,                                   # Encoded prompt
+         torch.tensor([[NEW_BLOCK]], device=device)    # New block token to trigger audio
+     ], dim=1)
+
+     # Create attention mask (all ones)
+     attention_mask = torch.ones_like(input_ids)
+     return input_ids, attention_mask
+
+ # 7) WebSocket‑Endpoint (vereinfacht mit Streamer) ---------------------
  @app.websocket("/ws/tts")
  async def tts(ws: WebSocket):
      await ws.accept()
+     print("Client connected")
+     streamer = None # Initialize for finally block
+     main_loop = asyncio.get_running_loop() # Get the current event loop
+
      try:
+         # Receive configuration
+         req_text = await ws.receive_text()
+         print(f"Received request: {req_text}")
+         req = json.loads(req_text)
+         text = req.get("text", "Hallo Welt, wie geht es dir heute?") # Default text
+         voice = req.get("voice", "Jakob") # Default voice

+         if not text:
+             await ws.close(code=1003, reason="Text cannot be empty")
+             return
+
+         print(f"Generating audio for: '{text}' with voice '{voice}'")
+
+         # Prepare prompt
          ids, attn = build_prompt(text, voice)
+
+         # --- Reset stateful components ---
+         masker.reset() # CRITICAL: Reset the mask state for the new request
+
+         # --- Create Streamer Instance ---
+         streamer = AudioStreamer(ws, snac, masker, main_loop)
+
+         # --- Run model.generate in a separate thread ---
+         # This prevents blocking the main FastAPI event loop
+         print("Starting generation...")
+         await asyncio.to_thread(
+             model.generate,
+             input_ids=ids,
+             attention_mask=attn,
+             max_new_tokens=1500, # Limit generation length (adjust as needed)
+             logits_processor=[masker],
+             stopping_criteria=stopping_criteria,
+             do_sample=False, # Use greedy decoding for potentially more stable audio
+             # do_sample=True, temperature=0.7, top_p=0.95, # Or use sampling
+             use_cache=True,
+             streamer=streamer # Pass the custom streamer
+             # No need to manage past_key_values manually
+         )
+         print("Generation finished.")
+
+     except WebSocketDisconnect:
+         print("Client disconnected.")
+     except json.JSONDecodeError:
+         print("❌ Invalid JSON received.")
+         await ws.close(code=1003, reason="Invalid JSON format") # 1003 = Cannot accept data type
      except Exception as e:
+         error_details = traceback.format_exc()
+         print(f"❌ WS‑Error: {e}\n{error_details}", flush=True)
+         # Try to send an error message before closing, if possible
+         error_payload = json.dumps({"error": str(e)})
+         try:
+             if ws.client_state.name == "CONNECTED":
+                 await ws.send_text(error_payload) # Send error as text/json
+         except Exception:
+             pass # Ignore error during error reporting
+         # Close with internal server error code
+         if ws.client_state.name == "CONNECTED":
+             await ws.close(code=1011) # 1011 = Internal Server Error
      finally:
+         # Ensure streamer's end method is called if it exists
+         if streamer:
+             try:
+                 streamer.end()
+             except Exception as e_end:
+                 print(f"Error during streamer.end(): {e_end}")
+
+         # Ensure WebSocket is closed
+         print("Closing connection.")
          if ws.client_state.name != "DISCONNECTED":
              try:
+                 await ws.close(code=1000) # 1000 = Normal Closure
+             except RuntimeError as e_close:
+                 # Can happen if connection is already closing/closed
+                 print(f"Runtime error closing websocket: {e_close}")
+             except Exception as e_close_final:
+                 print(f"Error closing websocket: {e_close_final}")
+         print("Connection closed.")

+ # 8) Dev‑Start --------------------------------------------------------
  if __name__ == "__main__":
+     import uvicorn
+     print("Starting Uvicorn server...")
+     # Use reload=True only for development, remove for production
+     uvicorn.run("app:app", host="0.0.0.0", port=7860, log_level="info") #, reload=True)
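
Usage note (not part of the commit): below is a minimal client sketch for exercising the new /ws/tts endpoint. It assumes the server from this commit is running locally on port 7860 and that the third-party websockets package is installed (pip install websockets); the file name client_example.py and the out.wav output are purely illustrative. The server streams raw 16-bit mono PCM at 24 kHz (the rate of the snac_24khz decoder), so the sketch wraps the received bytes in a WAV container.

# client_example.py — hypothetical client for the /ws/tts endpoint above
import asyncio
import json
import wave

import websockets  # third-party client library, not used by the server itself


async def main() -> None:
    uri = "ws://localhost:7860/ws/tts"
    pcm = bytearray()

    async with websockets.connect(uri) as ws:
        # The endpoint expects one JSON text frame with "text" and "voice".
        await ws.send(json.dumps({"text": "Hallo Welt!", "voice": "Jakob"}))

        # Audio arrives as binary frames until the server closes the connection.
        try:
            while True:
                msg = await ws.recv()
                if isinstance(msg, bytes):
                    pcm.extend(msg)
                else:
                    print("Server message:", msg)  # e.g. an error payload
        except websockets.ConnectionClosed:
            pass

    # Wrap the raw PCM in a WAV container for easy playback.
    with wave.open("out.wav", "wb") as wav:
        wav.setnchannels(1)
        wav.setsampwidth(2)      # int16 samples
        wav.setframerate(24000)  # matches the snac_24khz decoder
        wav.writeframes(bytes(pcm))
    print(f"Wrote {len(pcm)} bytes of PCM to out.wav")


if __name__ == "__main__":
    asyncio.run(main())

Each binary WebSocket frame corresponds to one decoded 7-token SNAC block, so a client could also play the audio incrementally instead of buffering it to a file.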