File size: 5,667 Bytes
2c15189
 
a3af518
0316ec3
a4cfefc
 
 
4189fe1
9bf14d0
d9ea17d
a3af518
0316ec3
a3af518
d9ea17d
2c15189
a4cfefc
2008a3f
a4cfefc
1ab029d
0316ec3
a4cfefc
9bf14d0
0dfc310
a4cfefc
9bf14d0
a4cfefc
 
9bf14d0
a4cfefc
9bf14d0
d9ea17d
a3af518
a4cfefc
 
a3af518
f63f843
a4cfefc
f63f843
 
9bf14d0
f63f843
a4cfefc
986d4cd
a4cfefc
 
f63f843
 
a3af518
a4cfefc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f63f843
a3af518
f63f843
a4cfefc
 
f63f843
 
a4cfefc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9bf14d0
a4cfefc
 
 
9bf14d0
a3af518
a4cfefc
 
 
9bf14d0
a4cfefc
a8606ac
2c15189
a09ea48
4189fe1
f63f843
a4cfefc
 
 
f63f843
a4cfefc
f63f843
 
a4cfefc
f63f843
a4cfefc
f63f843
 
 
 
 
 
a4cfefc
f63f843
a4cfefc
 
f63f843
a4cfefc
 
f63f843
a4cfefc
 
 
 
 
 
f63f843
d9ea17d
a4cfefc
 
 
 
 
 
f63f843
 
a4cfefc
 
 
a3af518
a4cfefc
c70d8eb
4189fe1
a4cfefc
a09ea48
a4cfefc
a09ea48
a4cfefc
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import os
import json
import asyncio
import torch
# Workaround for the PyTorch 2.2.x flash-SDP assertion failure
torch.backends.cuda.enable_flash_sdp(False)

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer

# — Hugging Face token & login —
# Log in only when HF_TOKEN is set; the model repo may be gated.
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)

# — Select compute device (GPU if available) —
device = "cuda" if torch.cuda.is_available() else "cpu"

# — Instantiate the FastAPI application —
app = FastAPI()

# — Hello route so that GET / no longer returns a 404 —
@app.get("/")
async def read_root():
    """Liveness endpoint for the root path."""
    greeting = "Orpheus TTS WebSocket Server läuft"
    return {"message": greeting}

# — Load models at application startup —
@app.on_event("startup")
async def load_models():
    """Load the SNAC audio codec and the Orpheus TTS model.

    Populates the module globals ``tokenizer``, ``model``, ``snac`` and
    ``ALLOWED_IDS`` (the set of token ids the sampler may emit) that the
    helpers below rely on.
    """
    global tokenizer, model, snac

    # SNAC codec for decoding audio tokens back into a waveform
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)

    # Orpheus TTS base model (German)
    REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
    tokenizer = AutoTokenizer.from_pretrained(REPO)
    model = AutoModelForCausalLM.from_pretrained(
        REPO,
        device_map={"": 0} if device == "cuda" else None,
        torch_dtype=torch.bfloat16 if device == "cuda" else None,
        low_cpu_mem_usage=True,
        return_legacy_cache=True,  # keep past_key_values as a legacy tuple
    ).to(device)
    model.config.pad_token_id = model.config.eos_token_id

    # --- Prepare logit masking ---
    # Audio tokens are laid out as 7 codebook positions of 4096 codes each,
    # starting at AUDIO_OFFSET: raw_id = AUDIO_OFFSET + position*4096 + code.
    # decode_block() subtracts exactly these per-position offsets, so the
    # mask must cover all 7*4096 ids — the previous value of 4096 only
    # permitted position 0 and made positions 1-6 unreachable.
    AUDIO_OFFSET = 128266
    AUDIO_COUNT  = 7 * 4096
    valid_audio = torch.arange(AUDIO_OFFSET, AUDIO_OFFSET + AUDIO_COUNT, device=device)
    # 128257 starts a new audio block; EOS terminates generation.
    ctrl_tokens = torch.tensor([128257, model.config.eos_token_id], device=device)
    global ALLOWED_IDS
    ALLOWED_IDS = torch.cat([valid_audio, ctrl_tokens])

def sample_from_logits(logits: torch.Tensor, allowed_ids=None) -> int:
    """Sample one token id from ``logits``, restricted to a set of allowed ids.

    Parameters
    ----------
    logits : torch.Tensor
        Next-token logits of shape ``[1, vocab_size]``.
    allowed_ids : torch.Tensor, optional
        1-D tensor of permitted token ids.  Defaults to the module-level
        ``ALLOWED_IDS`` prepared at startup.

    Returns
    -------
    int
        The sampled token id.
    """
    ids = ALLOWED_IDS if allowed_ids is None else allowed_ids
    # Mask every id outside `ids` with -inf so softmax gives it zero probability.
    mask = torch.full_like(logits, float("-inf"))
    mask[:, ids] = 0.0  # batch-safe indexing (original assumed batch size 1)
    probs = torch.softmax(logits + mask, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()

def prepare_inputs(text: str, voice: str):
    """Tokenize "<voice>: <text>" and frame it with start/end marker tokens.

    Returns a ``(input_ids, attention_mask)`` pair on the module device.
    """
    prompt_ids = tokenizer(f"{voice}: {text}", return_tensors="pt").input_ids.to(device)
    # Special marker tokens wrapping the prompt
    start_marker = torch.tensor([[128259]], dtype=torch.int64, device=device)
    end_markers  = torch.tensor([[128009, 128260]], dtype=torch.int64, device=device)
    full_ids = torch.cat([start_marker, prompt_ids, end_markers], dim=1)
    return full_ids, torch.ones_like(full_ids, device=device)

def decode_block(block: list[int]) -> bytes:
    """Decode 7 sampled audio codes into one block of PCM16 bytes.

    Each entry is expected to be ``raw_token - 128266``; position ``i``
    additionally carries a codebook offset of ``i * 4096`` that is removed
    here.  The 7 codes are demultiplexed into SNAC's three codebook layers
    (1 + 2 + 4 codes) and decoded into a mono waveform.

    Raises
    ------
    ValueError
        If ``block`` does not contain exactly 7 codes.
    """
    if len(block) != 7:
        raise ValueError(f"decode_block expects exactly 7 codes, got {len(block)}")

    # Strip the per-position codebook offset, then route each code to its
    # SNAC layer: position 0 -> layer 1, 1/4 -> layer 2, 2/3/5/6 -> layer 3.
    c = [block[i] - i * 4096 for i in range(7)]
    layer1 = [c[0]]
    layer2 = [c[1], c[4]]
    layer3 = [c[2], c[3], c[5], c[6]]

    dev = next(snac.parameters()).device
    codes = [
        torch.tensor(layer1, device=dev).unsqueeze(0),
        torch.tensor(layer2, device=dev).unsqueeze(0),
        torch.tensor(layer3, device=dev).unsqueeze(0),
    ]
    audio = snac.decode(codes).squeeze().cpu().numpy()
    # Convert the float waveform (assumed in [-1, 1]) to PCM16 bytes
    return (audio * 32767).astype("int16").tobytes()

# — WebSocket endpoint for TTS streaming —
@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    """Stream synthesized speech over a WebSocket.

    Protocol: the client sends one JSON text message
    ``{"text": ..., "voice": ...}``; the server replies with a stream of
    binary PCM16 frames (one per 7-code audio block) and then closes.
    """
    await ws.accept()
    try:
        msg = await ws.receive_text()
        req = json.loads(msg)
        text  = req.get("text", "")
        voice = req.get("voice", "Jakob")

        # Prepare prompt tokens
        input_ids, attention_mask = prepare_inputs(text, voice)
        past_kvs = None
        buffer   = []  # collects the 7 audio codes of the current block

        # Token-by-token generation loop
        while True:
            # BUGFIX: the single-token input_ids must be passed even when a
            # KV cache is present — the original passed input_ids=None on
            # every step after the first, and a transformers forward pass
            # requires input_ids or inputs_embeds.
            out = model(
                input_ids=input_ids,
                attention_mask=attention_mask if past_kvs is None else None,
                past_key_values=past_kvs,
                use_cache=True,
                return_dict=True
            )
            past_kvs = out.past_key_values
            next_tok = sample_from_logits(out.logits[:, -1, :])

            # End of generation?
            if next_tok == model.config.eos_token_id:
                break

            # Reset on a new audio-block start marker
            if next_tok == 128257:
                buffer.clear()
                input_ids      = torch.tensor([[next_tok]], device=device)
                attention_mask = torch.ones_like(input_ids)
                continue

            # Collect an audio code (base offset removed)
            buffer.append(next_tok - 128266)
            # once 7 codes are accumulated → decode & send
            if len(buffer) == 7:
                pcm = decode_block(buffer)
                buffer.clear()
                await ws.send_bytes(pcm)

            # next step: feed exactly this token back in
            input_ids      = torch.tensor([[next_tok]], device=device)
            attention_mask = torch.ones_like(input_ids)

        # clean shutdown
        await ws.close()
    except WebSocketDisconnect:
        pass
    except Exception as e:
        print("Error in /ws/tts:", e)
        # The socket may already be closed/disconnected; a second close
        # raises RuntimeError in starlette and must not mask the error above.
        try:
            await ws.close(code=1011)
        except RuntimeError:
            pass

# — CLI entry point for local testing —
def _serve() -> None:
    """Run the app under uvicorn on 0.0.0.0:7860."""
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860)


if __name__ == "__main__":
    _serve()