File size: 5,137 Bytes
2c15189
 
 
67c3132
0316ec3
4189fe1
0316ec3
a09ea48
 
0dfc310
0316ec3
67c3132
a09ea48
2c15189
 
 
2008a3f
67c3132
1ab029d
0316ec3
67c3132
 
a09ea48
0316ec3
674acbf
67c3132
0dfc310
 
 
2c15189
 
67c3132
 
2c15189
0dfc310
f001a32
67c3132
d408dd5
9cd424e
67c3132
 
a09ea48
9cd424e
b3e4aa7
0dfc310
67c3132
9cd424e
a09ea48
67c3132
a09ea48
67c3132
 
 
 
 
 
9cd424e
2c15189
67c3132
 
 
 
 
2c15189
67c3132
 
2c15189
 
67c3132
 
 
 
9cd424e
2c15189
 
67c3132
 
2c15189
67c3132
2c15189
67c3132
 
 
2c15189
 
 
 
 
 
 
67c3132
 
 
 
 
2c15189
 
 
 
 
 
 
 
 
67c3132
97006e1
4189fe1
 
d408dd5
67c3132
 
d408dd5
a8606ac
2c15189
a09ea48
4189fe1
2c15189
67c3132
2c15189
 
 
 
67c3132
 
2c15189
67c3132
2c15189
 
67c3132
 
2c15189
 
 
 
 
 
 
67c3132
 
2c15189
 
67c3132
2c15189
67c3132
2c15189
 
 
 
67c3132
4189fe1
2c15189
a09ea48
2c15189
a09ea48
f3890ef
2c15189
f3890ef
67c3132
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import os
import json
import asyncio
import threading

import numpy as np
import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from dotenv import load_dotenv
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login, snapshot_download

# — ENV & HF‑AUTH —
# Load variables from a .env file (if present) into the process environment.
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
# Log in to the Hugging Face Hub only when a token is configured.
# NOTE(review): presumably required for gated/private repos — confirm whether
# the repos below need it.
if HF_TOKEN:
    login(token=HF_TOKEN)

# — Device —
# Use CUDA when available, otherwise fall back to CPU; shared by both models.
device = "cuda" if torch.cuda.is_available() else "cpu"

# — Load models —
# SNAC audio codec (24 kHz variant); its decode() turns the generated codes
# back into a waveform (see redistribute_codes).
print("Loading SNAC model...")
snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)

# Orpheus causal LM checkpoint (per its repo name, a German 3B fine-tune) that
# generates the audio tokens.
model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
print("Downloading model weights (config + safetensors)...")
# Pre-fetch only the config and safetensors weights into the local HF cache.
# NOTE(review): none of the ignore_patterns below match allow_patterns, so the
# ignore list appears redundant. The tokenizer files it lists are presumably
# still downloaded on demand by AutoTokenizer.from_pretrained further down.
snapshot_download(
    repo_id=model_name,
    allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json"],
    ignore_patterns=[
        "optimizer.pt", "pytorch_model.bin", "training_args.bin",
        "scheduler.pt", "tokenizer.json", "tokenizer_config.json",
        "special_tokens_map.json", "vocab.json", "merges.txt", "tokenizer.*"
    ]
)

print("Loading Orpheus model...")
# Weights are loaded in bfloat16 and then moved to `device`.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16
).to(device)
# Reuse the EOS id as the padding id.
# NOTE(review): presumably to avoid generate()'s missing-pad-token warning.
model.config.pad_token_id = model.config.eos_token_id

tokenizer = AutoTokenizer.from_pretrained(model_name)

# — Helper functions —

def process_prompt(text: str, voice: str):
    """Encode ``"<voice>: <text>"`` into model inputs.

    The tokenized prompt is framed by fixed special-token ids: 128259 in front,
    128009 and 128260 behind.

    Returns:
        ``(input_ids, attention_mask)``: int64 tensors of shape (1, seq_len)
        on ``device``; the mask is all ones (single unpadded sequence).
    """
    encoded = tokenizer(f"{voice}: {text}", return_tensors="pt").input_ids
    prefix = torch.tensor([[128259]], dtype=torch.int64)
    suffix = torch.tensor([[128009, 128260]], dtype=torch.int64)

    input_ids = torch.cat((prefix, encoded, suffix), dim=1).to(device)
    # ones_like keeps input_ids' dtype and device, so no extra transfer is needed.
    attention_mask = torch.ones_like(input_ids)
    return input_ids, attention_mask

def parse_output(generated_ids: torch.LongTensor):
    """Extract the SNAC code list from a generated token sequence.

    Only row 0 of ``generated_ids`` is used. The audio section is taken to be
    everything after the LAST start-of-audio marker (128257), up to the first
    end-of-speech marker (128258). Each remaining token then has the
    audio-token offset (128266) subtracted. The result has the layout
    ``redistribute_codes`` expects: token k of each 7-token frame lies in
    [k*4096, (k+1)*4096).

    Args:
        generated_ids: 2-D token tensor as returned by ``model.generate``.

    Returns:
        list[int] of codes. It is empty when no start-of-audio marker was
        generated, because the remaining tokens would be prompt/text tokens
        rather than audio codes.
    """
    start_token = 128257         # start of audio
    end_token = 128258           # end of speech
    audio_token_offset = 128266  # id of the first audio code token

    row = generated_ids[0]
    starts = (row == start_token).nonzero(as_tuple=True)[0]
    if starts.numel() == 0:
        return []
    audio = row[starts[-1].item() + 1:]

    # Anything after end-of-speech (e.g. end-of-turn tokens when generation does
    # not stop at 128258) is not audio and would decode to invalid codes.
    ends = (audio == end_token).nonzero(as_tuple=True)[0]
    if ends.numel() > 0:
        audio = audio[: ends[0].item()]

    return [token - audio_token_offset for token in audio.tolist()]

def redistribute_codes(code_list: list[int], snac_model: "SNAC"):
    """Decode a flat code list into audio with SNAC.

    Codes are consumed in complete 7-token frames; a trailing partial frame
    (fewer than 7 tokens) is dropped. Token k of each frame carries an offset
    of k * 4096. That offset is removed before the tokens are split across
    the three SNAC code layers:

        layer 1 <- token 0
        layer 2 <- tokens 1, 4
        layer 3 <- tokens 2, 3, 5, 6

    Args:
        code_list: flat list of codes. Token k of every frame must lie in
            [k * 4096, (k + 1) * 4096).
        snac_model: SNAC model. Its ``decode`` receives the three layers as
            (1, n) int64 tensors on the model's device.

    Returns:
        The decoded audio as a squeezed numpy array, or an empty float32
        array when there is no complete frame.

    Raises:
        ValueError: if any code falls outside its frame position's range.
    """
    n_frames = len(code_list) // 7
    if n_frames == 0:
        # No complete frame -> empty audio.
        return np.zeros(0, dtype=np.float32)

    frames = np.asarray(code_list[: 7 * n_frames], dtype=np.int64).reshape(n_frames, 7)
    frames = frames - 4096 * np.arange(7, dtype=np.int64)
    # Validate on the host. An out-of-range index inside SNAC's lookup fails
    # far less clearly; on CUDA it surfaces as a device-side assert that can
    # leave the CUDA context unusable for later requests.
    if frames.min() < 0 or frames.max() >= 4096:
        raise ValueError(
            "SNAC codes out of range; expected token k of each 7-token frame "
            "in [k*4096, (k+1)*4096)"
        )

    # Frame columns feeding SNAC layers 1, 2 and 3 (row-major order keeps the
    # per-frame interleaving of the original loop).
    layer_columns = ([0], [1, 4], [2, 3, 5, 6])
    dev = next(snac_model.parameters()).device
    codes = [
        torch.tensor(frames[:, cols].reshape(-1), device=dev).unsqueeze(0)
        for cols in layer_columns
    ]

    # Pure inference: don't build an autograd graph through the decoder.
    with torch.inference_mode():
        audio = snac_model.decode(codes)
    return audio.squeeze().cpu().numpy()

# — FastAPI setup —

app = FastAPI()


@app.get("/")
def greet_json():
    """Return a static JSON greeting for ``GET /``."""
    return dict(Hello="World!")

# Serializes model/SNAC work across connections. The original handler ran it on
# the event loop, which made it one-at-a-time implicitly; this keeps that while
# moving the work off the loop.
_MODEL_LOCK = threading.Lock()


def _synthesize_pcm16(text: str, voice: str) -> bytes:
    """Run the blocking TTS pipeline for one request; return mono PCM16 bytes.

    Must be called from a worker thread, not the event loop.
    """
    with _MODEL_LOCK:
        # 1) Prepare tokens
        ids, mask = process_prompt(text, voice)

        # 2) Generation
        # NOTE(review): upstream Orpheus examples appear to stop generation at
        # 128258 (end of speech); confirm model.config.eos_token_id matches.
        gen_ids = model.generate(
            input_ids=ids,
            attention_mask=mask,
            max_new_tokens=2000,    # adjust as needed
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            repetition_penalty=1.1,
            eos_token_id=model.config.eos_token_id,
        )

        # 3) Tokens -> code list -> audio
        codes = parse_output(gen_ids)
        audio_np = redistribute_codes(codes, snac)

    # 4) Float -> int16. Clip first so out-of-range samples saturate instead of
    #    wrapping around in the integer cast (audible clicks).
    return (np.clip(audio_np, -1.0, 1.0) * 32767).astype("int16").tobytes()


@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    """Text-to-speech over a WebSocket.

    Each incoming text frame must be JSON ``{"text": "...", "voice": "Jakob"}``.
    ``voice`` defaults to "Jakob" and ``text`` to "". For each request the
    synthesized audio is streamed back as binary frames of raw int16 PCM, in
    0.1 s chunks sent at roughly real-time pace. No end-of-utterance message is
    sent. On client disconnect the handler returns. On any other error it
    closes the socket with code 1011.
    """
    await ws.accept()
    try:
        while True:
            data = json.loads(await ws.receive_text())
            text  = data.get("text", "")
            voice = data.get("voice", "Jakob")

            # Generation and decoding are heavy and synchronous; running them
            # in a thread keeps other connections (and GET /) responsive.
            pcm16 = await asyncio.to_thread(_synthesize_pcm16, text, voice)

            # Stream in 0.1 s chunks: 2400 samples @ 24 kHz * 2 bytes per sample.
            chunk = 2400 * 2
            for i in range(0, len(pcm16), chunk):
                await ws.send_bytes(pcm16[i : i + chunk])
                await asyncio.sleep(0.1)
    except WebSocketDisconnect:
        print("Client disconnected")
    except Exception as e:
        print("Error in /ws/tts:", e)
        try:
            await ws.close(code=1011)
        except RuntimeError:
            # The socket was already closed (e.g. the error came from a send);
            # nothing more to do.
            pass

if __name__ == "__main__":
    import uvicorn

    # Pass the app object rather than the "app:app" import string. When this
    # file runs as __main__, the string form makes uvicorn import it a second
    # time as module `app`, re-running all module-level model loading (two
    # copies of the weights in memory).
    uvicorn.run(app, host="0.0.0.0", port=7860)