File size: 6,238 Bytes
479f253
 
4189fe1
9bf14d0
4520cbe
 
d9ea17d
0316ec3
f92444a
479f253
 
 
2008a3f
1ab029d
479f253
 
f92444a
 
479f253
 
 
 
 
 
 
 
 
 
 
 
f92444a
479f253
 
f92444a
 
479f253
bca75ea
479f253
bca75ea
 
479f253
 
 
 
f92444a
bca75ea
f92444a
9bf14d0
0dfc310
9bf14d0
479f253
 
9bf14d0
 
d9ea17d
bca75ea
479f253
9bf14d0
479f253
bca75ea
 
 
 
f63f843
479f253
bca75ea
479f253
f92444a
479f253
f92444a
479f253
 
 
 
 
 
 
 
 
f92444a
479f253
 
 
 
 
 
f92444a
 
 
 
 
 
a8606ac
bca75ea
a09ea48
4189fe1
bca75ea
9ef5e61
4c833ce
 
f63f843
 
479f253
bca75ea
 
 
 
 
479f253
bca75ea
479f253
 
f63f843
9ef5e61
479f253
 
 
f92444a
479f253
 
9ef5e61
479f253
9ef5e61
479f253
4c833ce
9ef5e61
4c833ce
479f253
 
9ef5e61
 
 
bca75ea
 
479f253
 
bca75ea
4c833ce
479f253
a09ea48
bca75ea
f92444a
479f253
f92444a
479f253
a4cfefc
f92444a
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# app.py ──────────────────────────────────────────────────────────────
import os, json, asyncio, torch, logging
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor
from transformers.generation.utils import Cache          
from snac import SNAC

# ── 0. Auth & device ────────────────────────────────────────────────
# Authenticate against the Hugging Face Hub only when HF_TOKEN is set.
if HF_TOKEN := os.getenv("HF_TOKEN"):
    login(HF_TOKEN)

# Prefer the GPU whenever torch can see one.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Disable the flash scaled-dot-product-attention backend
# (original note: "Flash-Bug umgehen", i.e. work around a flash-attention bug).
torch.backends.cuda.enable_flash_sdp(False)
# Only let ERROR-level (and above) messages from transformers' generation code through.
logging.getLogger("transformers.generation.utils").setLevel("ERROR")

# ── 1. Constants ────────────────────────────────────────────────────
MODEL_REPO        = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
CHUNK_TOKENS      = 50                # new tokens requested per generate() call

START_TOKEN       = 128259            # prepended to every prompt (build_inputs)
NEW_BLOCK_TOKEN   = 128257            # start of an audio block; resets the frame buffer
EOS_TOKEN         = 128258            # end of audio; the WebSocket handler closes on it
PROMPT_END        = [128009, 128260]  # appended after the tokenized text (build_inputs)
AUDIO_BASE        = 128266            # id of the first audio-code token

# Audio tokens come in frames of 7. Token k of a frame carries a SNAC code in
# 0..4095 shifted into its own band of 4096 ids (offset k * 4096, mirrored by
# the subtractions in decode_block). The audio vocabulary therefore spans
# 7 * 4096 ids. Restricting it to the first 4096 ids made positions 1..6 of every
# frame undecodable.
AUDIO_CODEBOOK_SIZE = 4096
AUDIO_FRAME_TOKENS  = 7
VALID_AUDIO_IDS   = torch.arange(
    AUDIO_BASE, AUDIO_BASE + AUDIO_FRAME_TOKENS * AUDIO_CODEBOOK_SIZE
)

# ── 2. Logit mask ───────────────────────────────────────────────────
class AudioMask(LogitsProcessor):
    """Logits processor that only lets a fixed set of token ids through.

    Every vocabulary column not listed in ``allowed`` is set to -inf. Only the
    permitted ids can therefore be selected.
    """

    def __init__(self, allowed: torch.Tensor):
        super().__init__()
        # Tensor of permitted token ids, used directly as a column index.
        self.allowed = allowed

    def __call__(self, input_ids, scores):
        """Return ``scores`` with every non-allowed column pushed to -inf."""
        penalty = scores.new_full(scores.shape, float("-inf"))
        penalty[:, self.allowed] = 0
        return penalty + scores

# Ids the model may emit: every audio-code id plus the three control tokens,
# moved to the compute device once so the masker can index scores directly.
ALLOWED_IDS = torch.cat(
    (VALID_AUDIO_IDS, torch.tensor((START_TOKEN, NEW_BLOCK_TOKEN, EOS_TOKEN)))
).to(device)
MASKER = AudioMask(ALLOWED_IDS)

# ── 3. FastAPI skeleton ─────────────────────────────────────────────
app = FastAPI()


@app.get("/")
async def ping():
    """Health check endpoint returning a fixed status message."""
    # The message previously contained a mis-encoded non-breaking space ("Β ").
    return {"message": "Orpheus‑TTS ready"}

@app.on_event("startup")
async def load_models():
    """Load the tokenizer, the Orpheus LM and the SNAC decoder at startup.

    The results are stored in the module globals ``tok``, ``model`` and ``snac``
    read by the request handlers. On CUDA the LM is placed on GPU 0 in
    bfloat16. Otherwise the library defaults are used for placement and dtype.
    """
    global tok, model, snac
    use_gpu = device == "cuda"
    lm_kwargs = {
        "low_cpu_mem_usage": True,
        "device_map": {"": 0} if use_gpu else None,
        "torch_dtype": torch.bfloat16 if use_gpu else None,
    }

    tok = AutoTokenizer.from_pretrained(MODEL_REPO)
    model = AutoModelForCausalLM.from_pretrained(MODEL_REPO, **lm_kwargs)
    # Use the EOS id as the padding id.
    model.config.pad_token_id = model.config.eos_token_id
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)

# ── 4. Helper functions ─────────────────────────────────────────────
def build_inputs(text: str, voice: str):
    """Tokenize a TTS prompt and wrap it in the Orpheus control tokens.

    The prompt is ``"<voice>: <text>"`` unless ``voice`` is empty or the
    sentinel ``"in_prompt"``. In that case ``text`` is used verbatim.

    Returns:
        ``(input_ids, attention_mask)``, both of shape (1, L) on ``device``.
        The ids are START_TOKEN, then the tokenizer output for the prompt
        (including whatever special tokens the tokenizer adds), then
        PROMPT_END. The mask is all ones.
    """
    if not voice or voice == "in_prompt":
        prompt = text
    else:
        prompt = f"{voice}: {text}"

    body = tok(prompt, return_tensors="pt").input_ids.to(device)
    head = torch.tensor([[START_TOKEN]], device=device)
    tail = torch.tensor([PROMPT_END], device=device)
    input_ids = torch.cat((head, body, tail), dim=1)
    return input_ids, torch.ones_like(input_ids)

def decode_block(block7: list[int]) -> bytes:
    """Decode one 7-token Orpheus audio frame into 16-bit PCM bytes.

    ``block7`` holds the frame's tokens already shifted by AUDIO_BASE. Token k
    lives in band k * 4096, so subtracting that offset yields a SNAC code
    in 0..4095. The codes are regrouped into SNAC's three layers (1, 2 and 4
    codes) and decoded.

    Returns:
        Raw int16 samples (native byte order) as bytes. Presumably 24 kHz mono,
        given the "snac_24khz" checkpoint; confirm against the client. Returns
        ``b""`` when any code falls outside its 4096-wide band. Such a frame
        would otherwise index outside SNAC's codebooks, which on CUDA can abort
        with a device-side assert.

    Raises:
        ValueError: if ``block7`` does not contain exactly 7 tokens.
    """
    if len(block7) != 7:
        raise ValueError(f"expected 7 tokens per frame, got {len(block7)}")

    # Remove each position's band offset to recover the raw codebook index.
    c = [t - k * 4096 for k, t in enumerate(block7)]
    if any(not 0 <= x < 4096 for x in c):
        return b""

    # Same layer layout as before: L1=[c0], L2=[c1, c4], L3=[c2, c3, c5, c6].
    layers = ([c[0]], [c[1], c[4]], [c[2], c[3], c[5], c[6]])
    codes = [torch.tensor(x, device=device).unsqueeze(0) for x in layers]

    # Module parameters require grad by default, and Tensor.numpy() refuses
    # tensors that require grad, so decode without autograd.
    with torch.inference_mode():
        audio = snac.decode(codes)

    # Clamp before scaling so out-of-range samples saturate instead of
    # overflowing int16.
    pcm = (audio.squeeze().float().clamp(-1.0, 1.0) * 32767).to(torch.int16)
    return pcm.cpu().numpy().tobytes()

# ── 5. WebSocket endpoint ───────────────────────────────────────────
# Upper bound on tokens generated per request, so a model that never emits
# EOS_TOKEN cannot keep a connection (and the GPU) busy forever.
MAX_NEW_TOKENS_PER_REQUEST = 8192


def _generate_chunk(ids, attn, past):
    """Run one blocking ``model.generate`` call of up to CHUNK_TOKENS tokens.

    ``ids``/``attn`` cover the full sequence so far (prompt plus everything
    generated). ``past`` is the cache from the previous call, or None. Passing
    the full sequence together with the cache lets generate() feed only the
    tokens the cache has not processed yet. Passing ``input_ids=None``
    instead makes generate() start from a fresh BOS-only input.
    """
    return model.generate(
        input_ids=ids,
        attention_mask=attn,
        past_key_values=past,
        max_new_tokens=CHUNK_TOKENS,
        logits_processor=[MASKER],
        eos_token_id=EOS_TOKEN,          # end the chunk early once audio ends
        do_sample=True, top_p=0.95, temperature=0.7,
        return_dict_in_generate=True,
        use_cache=True,
        return_legacy_cache=True,
    )


def _parse_request(raw):
    """Return the decoded JSON request if it is an object, otherwise None."""
    try:
        req = json.loads(raw)
    except json.JSONDecodeError:
        return None
    return req if isinstance(req, dict) else None


@app.websocket("/ws/tts")
async def tts(ws: WebSocket):
    """Stream synthesized speech for a single request over a WebSocket.

    The client sends one JSON text message with optional keys ``"text"``
    (default ``""``) and ``"voice"`` (default ``"Jakob"``). Each complete
    7-token audio frame is decoded with decode_block() and sent as one binary
    message. The socket is closed normally when the model emits EOS_TOKEN or
    the per-request token budget is used up. A malformed request is closed
    with 1003, and an unexpected server error with 1011.
    """
    await ws.accept()
    try:
        req = _parse_request(await ws.receive_text())
        if req is None:
            await ws.close(code=1003)     # unsupported / malformed payload
            return

        ids, attn = build_inputs(req.get("text", ""), req.get("voice", "Jakob"))
        limit = ids.size(1) + MAX_NEW_TOKENS_PER_REQUEST
        consumed = ids.size(1)            # tokens already turned into output
        past, buf = None, []

        while ids.size(1) < limit:
            # generate() blocks; run it off the event loop so other
            # connections and disconnect handling stay responsive.
            out = await asyncio.to_thread(_generate_chunk, ids, attn, past)
            past = out.past_key_values
            ids = out.sequences
            attn = torch.ones_like(ids)

            new = ids[0, consumed:].tolist()
            consumed = ids.size(1)
            if not new:                   # no progress: stop instead of spinning
                break

            for t in new:
                if t == EOS_TOKEN:
                    await ws.close()
                    return
                if t == NEW_BLOCK_TOKEN:
                    buf.clear()
                    continue
                if t < AUDIO_BASE:        # control token such as START_TOKEN
                    continue
                buf.append(t - AUDIO_BASE)
                if len(buf) == 7:
                    pcm = decode_block(buf)
                    if pcm:
                        await ws.send_bytes(pcm)
                    buf.clear()

        await ws.close()                  # token budget exhausted
    except WebSocketDisconnect:
        pass
    except Exception:
        logging.getLogger(__name__).exception("WS‑Error")
        if ws.client_state.name == "CONNECTED":
            try:
                await ws.close(code=1011)
            except RuntimeError:
                pass                      # socket already closed underneath us

# ── 6. Local start ──────────────────────────────────────────────────
if __name__ == "__main__":
    import sys

    import uvicorn

    # Optional first CLI argument: port to listen on (default 7860).
    port = int(sys.argv[1]) if len(sys.argv) > 1 else 7860
    # Pass the app object itself. The "app:app" import string would import
    # this file a second time under the module name "app", re-running all
    # module-level setup, and only works when the file is named app.py.
    uvicorn.run(app, host="0.0.0.0", port=port)