File size: 17,144 Bytes
0b5b901
dbb4a9f
 
 
 
 
 
4189fe1
9bf14d0
dbb4a9f
 
 
 
 
 
 
 
 
 
 
 
0316ec3
e3958ab
479f253
 
dbb4a9f
479f253
2008a3f
dbb4a9f
e3958ab
 
dbb4a9f
 
55145d2
1d792aa
55145d2
 
53012c3
dbb4a9f
 
 
e3958ab
dbb4a9f
479f253
1d792aa
 
 
 
 
d11cc63
1d792aa
e3958ab
dbb4a9f
55145d2
a0cc672
dbb4a9f
a0cc672
e3958ab
dbb4a9f
 
0dfc310
dbb4a9f
 
 
 
 
 
 
1d792aa
dbb4a9f
 
 
 
 
1d792aa
dbb4a9f
 
53012c3
 
641d199
53012c3
 
dbb4a9f
 
 
 
1d792aa
 
dbb4a9f
 
1d792aa
 
dbb4a9f
 
53012c3
 
 
 
 
 
 
 
 
 
 
 
 
 
55145d2
 
dbb4a9f
d11cc63
53012c3
 
dbb4a9f
53012c3
 
 
 
 
 
 
 
 
dbb4a9f
53012c3
 
 
96dc59a
53012c3
 
1d792aa
 
 
 
53012c3
dbb4a9f
53012c3
 
 
 
 
 
 
 
dbb4a9f
 
 
53012c3
dbb4a9f
 
 
 
1d792aa
dbb4a9f
1d792aa
 
96dc59a
1d792aa
 
 
dbb4a9f
 
 
53012c3
 
dbb4a9f
 
 
1d792aa
 
53012c3
55145d2
dbb4a9f
 
1d792aa
 
 
 
 
 
 
 
 
55145d2
1d792aa
55145d2
1d792aa
96dc59a
1d792aa
55145d2
53012c3
1d792aa
 
 
55145d2
1d792aa
55145d2
 
1d792aa
55145d2
1d792aa
 
55145d2
 
1d792aa
55145d2
 
1d792aa
55145d2
1d792aa
 
55145d2
 
 
 
 
 
1d792aa
55145d2
 
 
 
 
 
53012c3
1d792aa
55145d2
 
 
 
 
 
 
 
96dc59a
55145d2
53012c3
55145d2
 
 
 
 
53012c3
55145d2
 
 
 
53012c3
55145d2
53012c3
55145d2
1d792aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55145d2
 
 
 
1d792aa
55145d2
 
1d792aa
55145d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53012c3
 
 
55145d2
 
 
 
 
 
 
 
 
 
53012c3
 
55145d2
 
 
 
 
53012c3
 
55145d2
 
 
53012c3
55145d2
 
 
 
53012c3
1d792aa
55145d2
 
1d792aa
55145d2
96dc59a
 
 
1d792aa
96dc59a
 
1d792aa
96dc59a
1d792aa
 
 
 
96dc59a
1d792aa
55145d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53012c3
55145d2
53012c3
55145d2
53012c3
55145d2
 
 
 
 
 
 
 
 
 
53012c3
55145d2
1d792aa
55145d2
 
 
 
 
 
 
 
 
 
1d792aa
 
53012c3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
# app.py ──────────────────────────────────────────────────────────────
import os
import json
import torch
import asyncio
import traceback # Import traceback for better error logging

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, StoppingCriteria, StoppingCriteriaList
# Import BaseStreamer for the interface
from transformers.generation.streamers import BaseStreamer
from snac import SNAC # Ensure you have 'pip install snac'

# --- Runtime singletons -------------------------------------------------
# All of these stay None until the FastAPI startup hook assigns them.
tok = model = snac = masker = stopping_criteria = None

# Compute device: first CUDA GPU when one is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# 0) Optional Hugging Face login -----------------------------------------
# Authenticates only when an HF_TOKEN environment variable is set.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    print("πŸ”‘ Logging in to Hugging Face Hub...")
    login(HF_TOKEN)

# If the PyTorch 2.2 flash-SDP bug shows up, disable that kernel with:
#     torch.backends.cuda.enable_flash_sdp(False)

# 1) Model and token-id constants ----------------------------------------
REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
START_TOKEN = 128259  # first id of every prompt (see build_prompt)
NEW_BLOCK = 128257    # last id of every prompt; the streamer clears its buffer on it
EOS_TOKEN = 128258    # default stop id; the startup hook may replace it from the model config
CODEBOOK_SIZE = 4096  # number of codes per audio slot
AUDIO_BASE = 128266   # first token id of the audio-code range
AUDIO_SPAN = CODEBOOK_SIZE * 7  # 7 slots x 4096 codes = 28672 audio token ids
AUDIO_IDS_CPU = torch.arange(AUDIO_BASE, AUDIO_BASE + AUDIO_SPAN)

# 2) Logit‑Mask -------------------------------------------------------
class AudioMask(LogitsProcessor):
    """Logits processor that leaves only audio-related token ids samplable.

    Every id outside the allowed set gets ``-inf`` added to its score. The
    allowed set is the block-marker id plus ``audio_ids``; the EOS id is added
    to it only while ``sent_blocks`` is greater than zero (the streamer sets it
    after emitting its first decoded block).
    """

    def __init__(self, audio_ids: torch.Tensor, new_block_token_id: int, eos_token_id: int):
        super().__init__()
        target = audio_ids.device
        block_marker = torch.tensor([new_block_token_id], device=target, dtype=torch.long)
        # Ids allowed before any audio has been sent: block marker + audio codes.
        self.allow = torch.cat([block_marker, audio_ids], dim=0)
        self.eos = torch.tensor([eos_token_id], device=target, dtype=torch.long)
        # Same set plus the stop id, used once at least one block went out.
        self.allow_with_eos = torch.cat([self.allow, self.eos], dim=0)
        # 0 = no audio block sent yet (EOS still forbidden).
        self.sent_blocks = 0

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return ``scores`` with every disallowed column pushed to ``-inf``."""
        if self.sent_blocks > 0:
            permitted = self.allow_with_eos
        else:
            permitted = self.allow
        penalty = torch.full_like(scores, float("-inf"))
        penalty[:, permitted] = 0
        return scores + penalty

    def reset(self):
        """Clear the sent-block counter so EOS is forbidden again."""
        self.sent_blocks = 0

# 3) StoppingCriteria for EOS ---------------------------------------
class EosStoppingCriteria(StoppingCriteria):
    """Stopping criterion that ends a sequence once it emits ``eos_token_id``."""

    def __init__(self, eos_token_id: int):
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
        """Report, per sequence, whether its most recent token is the EOS id.

        Args:
            input_ids: ids generated so far, indexed ``[batch, position]``.
            scores: unused; part of the StoppingCriteria interface.

        Returns:
            Bool tensor of shape ``(batch,)``: True for rows whose last token
            equals ``eos_token_id``. A per-row tensor (instead of one Python
            bool) keeps this valid for batch sizes > 1; for a single sequence
            it still evaluates to plain True/False.
        """
        if input_ids.shape[1] == 0:
            # No tokens yet, so no row can have reached EOS.
            return torch.zeros(input_ids.shape[0], dtype=torch.bool, device=input_ids.device)
        return input_ids[:, -1] == self.eos_token_id

# 4) Custom AudioStreamer -------------------------------------------
class AudioStreamer(BaseStreamer):
    """Streams generated audio codes to a WebSocket as raw int16 PCM chunks.

    ``model.generate`` feeds token ids into :meth:`put` and calls :meth:`end`
    when it stops. Every 7 audio-range tokens (stored with AUDIO_BASE
    subtracted) form one frame. Each frame is decoded with SNAC into PCM bytes,
    and sending is scheduled on ``loop`` with
    ``asyncio.run_coroutine_threadsafe``, because ``put`` is not called from
    that loop (the endpoint runs generation via ``asyncio.to_thread``).

    Attributes:
        buf: code values collected for the frame in progress.
        tasks: cross-thread send futures that have not finished yet; each
            removes itself from the set when done.
    """

    # Logging every token floods the logs and slows the generation loop, so it
    # is opt-in: set TTS_DEBUG_TOKENS to any non-empty value to enable it.
    debug_tokens = bool(os.getenv("TTS_DEBUG_TOKENS"))

    def __init__(self, ws: WebSocket, snac_decoder: SNAC, audio_mask: AudioMask, loop: asyncio.AbstractEventLoop, target_device: str):
        self.ws = ws
        self.snac = snac_decoder
        self.masker = audio_mask  # its sent_blocks flag is set after the first sent frame
        self.loop = loop
        self.device = target_device
        self.buf: list[int] = []
        self.tasks = set()

    def _decode_block(self, block7: list[int]) -> bytes:
        """Decode one 7-value frame of audio codes into int16 PCM bytes.

        Args:
            block7: seven code values with AUDIO_BASE already subtracted. Each
                is reduced modulo CODEBOOK_SIZE to a codebook index in
                ``[0, CODEBOOK_SIZE)``. Slots map to the three SNAC code
                levels as l1=[0], l2=[1, 4], l3=[2, 3, 5, 6].

        Returns:
            Native-endian int16 sample bytes, or ``b""`` on any failure (the
            failure is logged, never raised).
        """
        if len(block7) != 7:
            print(f"Streamer Warning: _decode_block received {len(block7)} tokens, expected 7. Skipping.")
            return b""

        try:
            # Strip the per-slot offset so every value is a plain codebook index.
            vals = [code % CODEBOOK_SIZE for code in block7]
            l1 = [vals[0]]
            l2 = [vals[1], vals[4]]
            l3 = [vals[2], vals[3], vals[5], vals[6]]
        except Exception as e_map:
            print(f"Streamer Error: Exception during code value extraction/mapping: {e_map}. Block: {block7}")
            return b""

        try:
            # One row per level, batch dimension of 1, on the decoder's device.
            codes = [
                torch.tensor(level, dtype=torch.long, device=self.device).unsqueeze(0)
                for level in (l1, l2, l3)
            ]
        except Exception as e_tensor:
            print(f"Streamer Error: Exception during tensor conversion: {e_tensor}. l1={l1}, l2={l2}, l3={l3}")
            return b""

        try:
            with torch.no_grad():
                audio = self.snac.decode(codes)[0]
        except Exception as e_decode:
            print(f"Streamer Error: Exception during snac.decode: {e_decode}")
            print(f"Input codes shapes: {[c.shape for c in codes]}")
            print(f"Input codes dtypes: {[c.dtype for c in codes]}")
            print(f"Input codes devices: {[c.device for c in codes]}")
            print(f"Input code values (min/max): L1({min(l1)}/{max(l1)}) L2({min(l2)}/{max(l2)}) L3({min(l3)}/{max(l3)})")
            return b""

        try:
            audio_np = audio.squeeze().detach().cpu().numpy()
            # Nothing guarantees the decoder output stays inside [-1, 1], and
            # astype("int16") does not saturate out-of-range values (they wrap
            # into loud clicks), so clip before scaling.
            audio_np = audio_np.clip(-1.0, 1.0)
            return (audio_np * 32767).astype("int16").tobytes()
        except Exception as e_post:
            print(f"Streamer Error: Exception during post-processing: {e_post}. Audio tensor shape: {audio.shape}")
            return b""

    async def _send_audio_bytes(self, data: bytes):
        """Send one PCM chunk over the WebSocket (runs on ``self.loop``).

        Send failures are logged or ignored rather than raised, so a client that
        disconnects mid-stream does not surface as an error in the send future.
        """
        if not data:
            return
        try:
            await self.ws.send_bytes(data)
        except WebSocketDisconnect:
            print("Streamer: WebSocket disconnected during send.")
        except Exception as e:
            # Sends that race with the socket being closed are expected; stay quiet about those.
            if "Cannot call \"send\" once a close message has been sent" not in str(e):
                print(f"Streamer: Error sending bytes: {e}")

    def _emit_block(self):
        """Decode the buffered frame, schedule its send and unlock EOS in the mask."""
        audio_bytes = self._decode_block(self.buf)
        self.buf.clear()
        if not audio_bytes:
            return
        future = asyncio.run_coroutine_threadsafe(self._send_audio_bytes(audio_bytes), self.loop)
        self.tasks.add(future)
        # Drop finished futures so the set does not grow for long generations.
        future.add_done_callback(self.tasks.discard)
        # EOS becomes samplable only after the first frame was handed off for sending.
        if self.masker.sent_blocks == 0:
            self.masker.sent_blocks = 1

    def put(self, value: torch.LongTensor):
        """Consume token ids from generate(), decoding and sending each full frame.

        Args:
            value: tensor of token ids of any rank; it is flattened. Ids outside
                the audio range (e.g. prompt tokens, if generate passes them)
                are skipped, NEW_BLOCK resets the frame buffer, and EOS stops
                processing the rest of this call.
        """
        if value.numel() == 0:
            return
        # reshape(-1) always yields a flat list, even for a single (0-d) token.
        token_ids = value.reshape(-1).cpu().tolist()

        for t in token_ids:
            if self.debug_tokens:
                print(f"Streamer received token ID: {t}")

            if t == EOS_TOKEN:
                break
            if t == NEW_BLOCK:
                self.buf.clear()
                continue
            if AUDIO_BASE <= t < AUDIO_BASE + AUDIO_SPAN:
                self.buf.append(t - AUDIO_BASE)

            if len(self.buf) == 7:
                self._emit_block()

    def end(self):
        """Called by generate() when generation finishes; drops any partial frame."""
        if self.buf:
            print(f"Streamer: End of generation with incomplete block ({len(self.buf)} tokens). Discarding.")
            self.buf.clear()

# 5) FastAPI application ----------------------------------------------
# Module-level ASGI app; the startup hook, the "/" route and the "/ws/tts"
# WebSocket route are registered on it through decorators further down.
app = FastAPI()

def _resolve_eos_token_id(config_eos, tokenizer_eos, default: int) -> int:
    """Reduce the model's configured EOS id(s) to the single stop id to use.

    Args:
        config_eos: ``model.config.eos_token_id``; may be None, an int, or a
            list of ints (multi-EOS configs).
        tokenizer_eos: ``tok.eos_token_id``; only used to warn on a mismatch.
        default: id kept when the config does not name one usable id (the
            module-level EOS_TOKEN).

    Returns:
        ``default`` when the config id is missing, or when it is a list that
        contains ``default`` or holds several ids; otherwise the config's id.
    """
    if isinstance(config_eos, (list, tuple)):
        ids = [int(i) for i in config_eos]
        if default in ids:
            return default
        if len(ids) > 1:
            # Several candidate ids and none is ours: a list cannot serve as a
            # single stop id (it would break the AudioMask tensors), so keep ours.
            print(f"WARNING: model.config.eos_token_id lists several ids {ids}; keeping {default}.")
            return default
        config_eos = ids[0] if ids else None

    if config_eos is None:
        print("🚨 WARNING: model.config.eos_token_id is None!")
        return default
    if tokenizer_eos is not None and config_eos != tokenizer_eos:
        print(f"⚠️ WARNING: Mismatch! model.config.eos_token_id ({config_eos}) != tok.eos_token_id ({tokenizer_eos}). Using model config ID.")
    return config_eos


@app.on_event("startup")
async def load_models_startup():
    """Load tokenizer, SNAC decoder and TTS model; build the mask and stop criteria.

    Populates the module globals ``tok``, ``snac``, ``model``, ``masker`` and
    ``stopping_criteria``, and may rebind ``EOS_TOKEN`` from the model config.
    """
    global tok, model, snac, masker, stopping_criteria, EOS_TOKEN

    print(f"πŸš€ Starting up on device: {device}")
    print("⏳ Lade Modelle …", flush=True)

    tok = AutoTokenizer.from_pretrained(REPO)
    print("Tokenizer loaded.")

    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
    print(f"SNAC loaded to {device}.")

    # Half precision on GPU (bfloat16 where supported), full precision on CPU.
    model_dtype = torch.float32
    if device == "cuda":
        if torch.cuda.is_bf16_supported():
            model_dtype = torch.bfloat16
            print("Using bfloat16 for model.")
        else:
            model_dtype = torch.float16
            print("Using float16 for model.")

    model = AutoModelForCausalLM.from_pretrained(
        REPO,
        device_map={"": 0} if device == "cuda" else None,
        torch_dtype=model_dtype,
        low_cpu_mem_usage=True,
    )

    # --- Resolve the stop token ---
    # NOTE(review): AudioMask allows only NEW_BLOCK, the audio-code range and
    # EOS_TOKEN. If a config id replaces the default 128258 here, 128258 can no
    # longer be sampled. Orpheus-style checkpoints presumably end speech with
    # 128258 — confirm which id this checkpoint actually emits.
    final_eos_token_id = _resolve_eos_token_id(model.config.eos_token_id, tok.eos_token_id, EOS_TOKEN)
    if final_eos_token_id != EOS_TOKEN:
        print(f"πŸ”„ Updating EOS_TOKEN constant from {EOS_TOKEN} to {final_eos_token_id}")
        EOS_TOKEN = final_eos_token_id

    # Pad with the stop id (single unpadded sequences, so this is mostly a default).
    model.config.pad_token_id = EOS_TOKEN
    print(f"Using EOS Token ID: {EOS_TOKEN}")

    print(f"Model loaded to {model.device} with dtype {model.dtype}.")
    model.eval()

    audio_ids_device = AUDIO_IDS_CPU.to(device)
    masker = AudioMask(audio_ids_device, NEW_BLOCK, EOS_TOKEN)
    print("AudioMask initialized.")

    stopping_criteria = StoppingCriteriaList([EosStoppingCriteria(EOS_TOKEN)])
    print("StoppingCriteria initialized.")

    print("βœ… Modelle geladen und bereit!", flush=True)

@app.get("/")
def hello():
    """Health check: confirm the service is up."""
    payload = dict(status="ok", message="TTS Service is running")
    return payload

# 6) Prompt-building helper -------------------------------------------
def build_prompt(text: str, voice: str) -> tuple[torch.Tensor, torch.Tensor]:
    """Assemble the model inputs for one TTS request.

    The id sequence is ``[START_TOKEN] + tokenizer("<voice>: <text>") +
    [NEW_BLOCK]``, created on the module-level ``device``. The attention mask
    is all ones because there is a single, unpadded sequence.

    Returns:
        ``(input_ids, attention_mask)``, both shaped ``(1, seq_len)``.

    NOTE(review): upstream Orpheus prompts are commonly closed with extra
    end-of-text / end-of-human control ids rather than going straight to
    NEW_BLOCK — confirm the layout this checkpoint was trained on.
    """
    def _single(token_id: int) -> torch.Tensor:
        # A (1, 1) long tensor holding one control token.
        return torch.tensor([[token_id]], device=device, dtype=torch.long)

    encoded = tok(f"{voice}: {text}", return_tensors="pt")
    text_ids = encoded.input_ids.to(device)
    input_ids = torch.cat((_single(START_TOKEN), text_ids, _single(NEW_BLOCK)), dim=1)
    return input_ids, torch.ones_like(input_ids)

# 7) WebSocket endpoint (simplified, streamer-based) ---------------------
@app.websocket("/ws/tts")
async def tts(ws: WebSocket):
    """Serve one text-to-speech request over a WebSocket.

    Protocol: the client sends one JSON text message with optional ``"text"``
    and ``"voice"`` keys. While the model generates, the server streams audio
    as binary messages (int16 PCM chunks from AudioStreamer), then closes with
    code 1000. Empty text or invalid JSON closes with 1003; any other error is
    sent as ``{"error": ...}`` text and the socket closes with 1011.
    """
    await ws.accept()
    print("πŸ”Œ Client connected")
    streamer = None
    main_loop = asyncio.get_running_loop()

    def _is_open() -> bool:
        # Check both sides: once this handler has sent a close frame itself,
        # client_state can still read CONNECTED while application_state does
        # not, and calling close() a second time would raise.
        return ws.client_state.name == "CONNECTED" and ws.application_state.name == "CONNECTED"

    try:
        req_text = await ws.receive_text()
        print(f"Received request: {req_text}")
        req = json.loads(req_text)
        text = req.get("text", "Hallo Welt, wie geht es dir heute?")
        voice = req.get("voice", "Jakob")

        if not text:
            print("⚠️ Request text is empty.")
            await ws.close(code=1003, reason="Text cannot be empty")
            return

        print(f"Generating audio for: '{text}' with voice '{voice}'")
        ids, attn = build_prompt(text, voice)
        # Per-request mask: AudioMask carries mutable sent_blocks state, so
        # sharing the global `masker` let concurrent requests reset/unlock
        # each other's EOS gating.
        request_masker = AudioMask(AUDIO_IDS_CPU.to(device), NEW_BLOCK, EOS_TOKEN)
        streamer = AudioStreamer(ws, snac, request_masker, main_loop, device)

        print("Starting generation in background thread...")
        await asyncio.to_thread(
            model.generate,
            input_ids=ids,
            attention_mask=attn,
            max_new_tokens=1500,
            logits_processor=[request_masker],
            stopping_criteria=stopping_criteria,
            do_sample=True,
            temperature=0.7,
            # top_p intentionally not passed (disabled while debugging repetition).
            repetition_penalty=1.2,
            use_cache=True,
            streamer=streamer
        )
        print("Generation thread finished.")

        # Chunk sends were scheduled from the worker thread and may still be
        # pending; wait for them so the close in `finally` does not drop the tail.
        pending = [asyncio.wrap_future(f) for f in list(streamer.tasks)]
        if pending:
            await asyncio.gather(*pending, return_exceptions=True)

    except WebSocketDisconnect:
        print("πŸ”Œ Client disconnected.")
    except json.JSONDecodeError:
        print("❌ Invalid JSON received.")
        if _is_open():
            await ws.close(code=1003, reason="Invalid JSON format")
    except Exception as e:
        error_details = traceback.format_exc()
        print(f"❌ WS‑Error: {e}\n{error_details}", flush=True)
        error_payload = json.dumps({"error": str(e)})
        try:
            if _is_open():
                await ws.send_text(error_payload)
        except Exception:
            pass  # best effort: the client may already be gone
        if _is_open():
            await ws.close(code=1011)
    finally:
        if streamer is not None:
            try:
                streamer.end()
            except Exception as e_end:
                print(f"Error during streamer.end(): {e_end}")

        print("Closing connection.")
        if _is_open():
            try:
                await ws.close(code=1000)
            except RuntimeError as e_close:
                print(f"Runtime error closing websocket: {e_close}")
            except Exception as e_close_final:
                print(f"Error closing websocket: {e_close_final}")
        elif ws.client_state.name != "DISCONNECTED":
            print(f"WebSocket final state: {ws.client_state.name}")
        print("Connection closed.")

# 8) Development entry point -------------------------------------------
if __name__ == "__main__":
    import uvicorn

    print("Starting Uvicorn server...")
    # Each worker process would run the startup hook and load its own copy of
    # the models; pass workers=1 explicitly if globals or GPU memory misbehave.
    server_options = {"host": "0.0.0.0", "port": 7860, "log_level": "info"}
    uvicorn.run("app:app", **server_options)