import os
import json
import asyncio
import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer

# — Hugging Face token & login —
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)

# — Choose device —
device = "cuda" if torch.cuda.is_available() else "cpu"

# — Instantiate FastAPI —
app = FastAPI()

# — Hello route so that GET / no longer returns a 404 —
@app.get("/")
async def read_root():
    return {"message": "Hello, world!"}

# — Load models at startup —
@app.on_event("startup")
async def load_models():
    global tokenizer, model, snac
    # Load the SNAC audio codec
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
    # Load the TTS model
    model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map={"": 0} if device == "cuda" else None,
        torch_dtype=torch.bfloat16 if device == "cuda" else None,
        low_cpu_mem_usage=True
    )
    model.eval()
    # Use the EOS token as the pad token
    model.config.pad_token_id = model.config.eos_token_id

# — Helper functions —
def prepare_inputs(text: str, voice: str):
    prompt = f"{voice}: {text}"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    # Wrap the prompt in the start/end marker tokens the Orpheus
    # prompt format expects
    start = torch.tensor([[128259]], dtype=torch.int64, device=device)
    end   = torch.tensor([[128009, 128260]], dtype=torch.int64, device=device)
    ids   = torch.cat([start, input_ids, end], dim=1)
    mask  = torch.ones_like(ids)
    return ids, mask

def decode_block(block_tokens: list[int]):
    # Build one SNAC decode block from a frame of 7 tokens; each frame
    # position carries a fixed multiple-of-4096 codebook offset
    b = block_tokens
    layer1 = [b[0]]
    layer2 = [b[1] - 4096, b[4] - 4 * 4096]
    layer3 = [b[2] - 2 * 4096, b[3] - 3 * 4096,
              b[5] - 5 * 4096, b[6] - 6 * 4096]
    codes = [
        torch.tensor(layer1, device=device).unsqueeze(0),
        torch.tensor(layer2, device=device).unsqueeze(0),
        torch.tensor(layer3, device=device).unsqueeze(0),
    ]
    # decoding yields a float tensor of shape (1, N) at 24 kHz
    with torch.no_grad():
        audio = snac.decode(codes).squeeze().cpu().numpy()
    # convert to 16-bit PCM; clamp first so the int16 cast cannot wrap
    return (audio * 32767).clip(-32767, 32767).astype("int16").tobytes()
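
# For reference, the frame-to-layer mapping implemented above:
#
#   frame position:   0     1     2     3      4      5      6
#   SNAC layer:       1     2     3     3      2      3      3
#   offset removed:   0     4096  8192  12288  16384  20480  24576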

# — WebSocket endpoint for TTS streaming —
@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    await ws.accept()
    try:
        # first receive the request as JSON
        msg = await ws.receive_text()
        req = json.loads(msg)
        text  = req.get("text", "")
        voice = req.get("voice", "Jakob")

        # Build inputs
        input_ids, attention_mask = prepare_inputs(text, voice)
        past_kvs = None
        collected = []

        # Token-by-token generation loop
        while True:
            with torch.no_grad():
                out = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    past_key_values=past_kvs,
                    use_cache=True,
                )
            logits = out.logits[:, -1, :]
            past_kvs = out.past_key_values

            # Sampling
            probs   = torch.softmax(logits, dim=-1)
            nxt_tok = torch.multinomial(probs, num_samples=1)  # shape (1, 1)
            nxt     = nxt_tok.item()

            # With the KV cache primed, feed only the newly sampled token on
            # the next step; the attention mask must grow with the cache
            input_ids      = nxt_tok
            attention_mask = torch.cat(
                [attention_mask, torch.ones_like(nxt_tok)], dim=1
            )

            # Stop on EOS
            if nxt == model.config.eos_token_id:
                break
            # Reset when a new start-of-audio marker appears
            if nxt == 128257:
                collected = []
                continue

            # Remove the audio-token offset and collect the code
            collected.append(nxt - 128266)
            # once 7 codes are collected, decode and stream immediately
            if len(collected) == 7:
                pcm = decode_block(collected)
                collected = []
                await ws.send_bytes(pcm)

        # close cleanly when generation is done
        await ws.close()

    except WebSocketDisconnect:
        # client disconnected
        pass
    except Exception as e:
        # on errors, close the socket with code 1011 (internal error)
        print("Error in /ws/tts:", e)
        await ws.close(code=1011)
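
# — Local entrypoint and example client —
# A minimal sketch for running the server locally; assumes uvicorn is
# installed and is not part of the original deployment setup.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

# Hypothetical client sketch (assumes the `websockets` package; file name
# and sample text are illustrative). It sends one request and writes the
# streamed raw 16-bit PCM (24 kHz, mono) to disk:
#
#   import asyncio, json, websockets
#
#   async def main():
#       async with websockets.connect("ws://localhost:8000/ws/tts") as ws:
#           await ws.send(json.dumps({"text": "Hallo Welt", "voice": "Jakob"}))
#           with open("out.pcm", "wb") as f:
#               while True:
#                   try:
#                       f.write(await ws.recv())
#                   except websockets.exceptions.ConnectionClosed:
#                       break
#
#   asyncio.run(main())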