# app.py  ─────────────────────────────────────────────────────────────
import os, json, torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from transformers import (AutoTokenizer, AutoModelForCausalLM, LogitsProcessor)
from transformers.cache_utils import Cache
from snac import SNAC

# ── 0. HF‑Login & Device ─────────────────────────────────────────────
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Work around a Flash-Attention bug in PyTorch 2.2.x
torch.backends.cuda.enable_flash_sdp(False)

# ── 1. Constants ─────────────────────────────────────────────────────
REPO              = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
CHUNK_TOKENS      = 50                      # tokens per mini-generate step
START_TOKEN       = 128259
NEW_BLOCK_TOKEN   = 128257
EOS_TOKEN         = 128258
AUDIO_BASE        = 128266                 # first audio-code token ID
VALID_AUDIO_IDS   = torch.arange(AUDIO_BASE, AUDIO_BASE + 7 * 4096)  # all 7 slots
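
# Layout assumption (consistent with decode_block below): one SNAC frame is
# 7 audio tokens, and slot i (0-6) occupies its own 4096-wide ID range at
# AUDIO_BASE + i * 4096, hence 7 * 4096 valid audio IDs in total.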

# ── 2. Dynamic logits masker ─────────────────────────────────────────
class DynamicAudioMask(LogitsProcessor):
    """
    blockt EOS, bis mindestens `min_audio_blocks` gesendet wurden
    """
    def __init__(self, audio_ids: torch.Tensor, min_audio_blocks: int = 1):
        super().__init__()
        self.audio_ids    = audio_ids
        self.ctrl_ids     = torch.tensor([NEW_BLOCK_TOKEN], device=audio_ids.device)
        self.min_blocks   = min_audio_blocks
        self.blocks_done  = 0

    def __call__(self, input_ids, scores):
        # scores: (batch, vocab) next-token logits; everything outside the
        # allowed set is pushed to -inf so it can never be sampled.
        allowed = torch.cat([self.audio_ids, self.ctrl_ids])
        if self.blocks_done >= self.min_blocks:              # EOS is now allowed
            allowed = torch.cat([allowed, torch.tensor([EOS_TOKEN], device=scores.device)])
        mask = torch.full_like(scores, float("-inf"))
        mask[:, allowed] = 0
        return scores + mask
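
# Example behaviour: on the first step only audio codes and NEW_BLOCK_TOKEN
# can be sampled; once blocks_done >= min_audio_blocks (a full 7-token block
# has been streamed), EOS joins the allowed set and generation may end.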

# ── 3. FastAPI skeleton ──────────────────────────────────────────────
app = FastAPI()

@app.get("/")
async def ping():
    return {"msg": "Orpheus‑TTS up & running"}

@app.on_event("startup")
async def load_models():
    global tok, model, snac, masker
    print("⏳ Lade Modelle …")

    tok  = AutoTokenizer.from_pretrained(REPO)
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)

    model = AutoModelForCausalLM.from_pretrained(
        REPO,
        low_cpu_mem_usage=True,
        device_map={"": 0} if device == "cuda" else None,
        torch_dtype=torch.bfloat16 if device == "cuda" else None,
    )
    model.config.pad_token_id = model.config.eos_token_id
    model.config.use_cache    = True

    masker = DynamicAudioMask(VALID_AUDIO_IDS.to(device))
    print("βœ…Β Modelle geladen")

# ── 4. Helper functions ──────────────────────────────────────────────
def build_inputs(text: str, voice: str):
    prompt = f"{voice}: {text}"
    ids = tok(prompt, return_tensors="pt").input_ids.to(device)
    ids = torch.cat(
        [ torch.tensor([[START_TOKEN]], device=device),
          ids,
          torch.tensor([[128009, 128260]], device=device) ],
        dim=1,
    )
    attn = torch.ones_like(ids)
    return ids, attn
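
# Framing note: 128259 opens the prompt and 128009/128260 close it; 128009
# is Llama-3's <|eot_id|>, and 128260 is Orpheus' end-of-prompt marker (the
# latter is an assumption based on the Orpheus reference pipeline).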

def decode_block(block7: list[int]) -> bytes:
    # Split one 7-token frame onto SNAC's three codebook layers and strip
    # the per-slot offset (slot i is offset by i * 4096).
    l1, l2, l3 = [], [], []
    b = block7
    l1.append(b[0])
    l2.append(b[1] - 4096)
    l3.extend([b[2] - 8192, b[3] - 12288])
    l2.append(b[4] - 16384)
    l3.extend([b[5] - 20480, b[6] - 24576])

    codes = [
        torch.tensor(l1, device=device).unsqueeze(0),
        torch.tensor(l2, device=device).unsqueeze(0),
        torch.tensor(l3, device=device).unsqueeze(0),
    ]
    audio = snac.decode(codes).squeeze().cpu().numpy()
    return (audio * 32767).astype("int16").tobytes()
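
# Note: decode_block yields raw mono 16-bit PCM; with the snac_24khz decoder
# this corresponds to s16le at a 24 kHz sample rate.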

# ── 5. WebSocket TTS endpoint ────────────────────────────────────────
@app.websocket("/ws/tts")
async def tts(ws: WebSocket):
    await ws.accept()
    try:
        req = json.loads(await ws.receive_text())
        text  = req.get("text", "")
        voice = req.get("voice", "Jakob")

        ids, attn = build_inputs(text, voice)
        past      = None
        buf       = []
        masker.blocks_done = 0                    # reset the shared masker per request

        while True:
            prev_len = ids.shape[1]               # sequence length before this step
            out = model.generate(
                input_ids       = ids,
                attention_mask  = attn,
                past_key_values = past,
                max_new_tokens  = CHUNK_TOKENS,
                logits_processor= [masker],        # ► dynamic masker
                do_sample=True, temperature=0.7, top_p=0.95,
                return_dict_in_generate=True,
                use_cache=True,
            )

            # Cache & neue Tokens extrahieren --------------------------------
            pkv = out.past_key_values
            if isinstance(pkv, Cache):             # HFΒ >=Β 4.47
                pkv = pkv.to_legacy()
            past = pkv

            new = out.sequences[0, -out.num_generated_tokens :].tolist()
            print("new tokens:", new[:20])          # Debug‑Ausgabe

            # ----------------------------------------------------------------
            for t in new:
                if t == EOS_TOKEN:
                    raise StopIteration          # caught below: ends the stream

                if t == NEW_BLOCK_TOKEN:
                    buf.clear()
                    continue

                buf.append(t - AUDIO_BASE)
                if len(buf) == 7:
                    await ws.send_bytes(decode_block(buf))
                    buf.clear()
                    masker.blocks_done += 1         # ► EOS may now be sampled

            # Feed the full sequence into the next step; the prefix is served
            # from the KV cache, so only the new tokens are recomputed.
            ids  = out.sequences
            attn = torch.ones_like(ids)

    except (StopIteration, WebSocketDisconnect):
        pass
    except Exception as e:
        print("❌ WS‑Error:", e)
        if ws.client_state.name != "DISCONNECTED":
            await ws.close(code=1011)
    finally:
        if ws.client_state.name != "DISCONNECTED":
            try:
                await ws.close()
            except RuntimeError:
                pass   # close frame was already sent

# ── 6. Local start (uvicorn) ─────────────────────────────────────────
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860)
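
# ── 7. Example client (sketch) ───────────────────────────────────────
# A minimal client sketch, assuming the third-party `websockets` package
# and a server on localhost:7860; it collects the streamed PCM frames into
# a file that can be played back as s16le @ 24 kHz:
#
#   import asyncio, json, websockets
#
#   async def main():
#       async with websockets.connect("ws://localhost:7860/ws/tts") as ws:
#           await ws.send(json.dumps({"text": "Hallo Welt!", "voice": "Jakob"}))
#           pcm = b""
#           try:
#               while True:
#                   pcm += await ws.recv()       # raw 16-bit PCM chunks
#           except websockets.exceptions.ConnectionClosed:
#               pass                             # server closed: stream done
#           with open("out.pcm", "wb") as f:
#               f.write(pcm)
#
#   asyncio.run(main())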