import os
import json
import torch
import numpy as np
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer
from snac import SNAC

# — HF token & login (if set) —
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)

# — Select device —
device = "cuda" if torch.cuda.is_available() else "cpu"

app = FastAPI()

@app.get("/")
async def read_root():
    return {"message": "Hello, world!"}

# — Global models —
model = None
tokenizer = None
snac_model = None

# — Startup: load SNAC & Orpheus —
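# (FastAPI's @app.on_event hooks are deprecated in favor of lifespan
#  handlers in newer versions, but still work for this purpose.)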
@app.on_event("startup")
async def load_models():
    global model, tokenizer, snac_model
    # 1) SNAC vocoder (24 kHz)
    snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
    # 2) Orpheus‑TTS
    REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_synthetic-v0.1"
    tokenizer = AutoTokenizer.from_pretrained(REPO)
    model = AutoModelForCausalLM.from_pretrained(
        REPO,
        device_map="auto" if device == "cuda" else None,
        torch_dtype=torch.bfloat16 if device == "cuda" else None,
        low_cpu_mem_usage=True,
    )
    if device != "cuda":
        # With device_map="auto", accelerate has already placed the weights;
        # calling .to() on such a model fails, so move it manually only on CPU.
        model = model.to(device)
    model.config.pad_token_id = model.config.eos_token_id

# — Markers and offsets from the Orpheus template —
START_TOKEN  = 128259
END_TOKENS   = [128009, 128260]
AUDIO_OFFSET = 128266
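# 128257 marks the start and 128258 the end of the generated audio code
# stream (see process_single_prompt below); the audio codes themselves
# are stored in the vocabulary shifted up by AUDIO_OFFSET.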

def process_single_prompt(prompt: str, voice: str) -> list[int]:
    # Assemble the prompt
    if voice and voice != "in_prompt":
        text = f"{voice}: {prompt}"
    else:
        text = prompt
    # Tokenize and add marker tokens
    ids = tokenizer(text, return_tensors="pt").input_ids
    start = torch.tensor([[START_TOKEN]], dtype=torch.int64)
    end   = torch.tensor([END_TOKENS], dtype=torch.int64)
    input_ids = torch.cat([start, ids, end], dim=1).to(device)
    attention_mask = torch.ones_like(input_ids)

    # Generate
    gen = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=4000,
        do_sample=True,
        temperature=0.6,
        top_p=0.95,
        repetition_penalty=1.1,
        eos_token_id=128258,
        use_cache=True,
    )

    # Find the last audio-start marker (128257) and crop everything before it
    token_to_find   = 128257
    token_to_remove = 128258
    idxs = (gen == token_to_find).nonzero(as_tuple=True)[1]
    if idxs.numel() > 0:
        cropped = gen[:, idxs[-1] + 1 :]
    else:
        cropped = gen

    # Strip the end/padding token
    row = cropped[0][cropped[0] != token_to_remove]
    # Trim the length to a multiple of 7 (one SNAC frame = 7 codes)
    new_len = (row.size(0) // 7) * 7
    trimmed = row[:new_len].tolist()
    # Subtract the vocabulary offset
    return [t - AUDIO_OFFSET for t in trimmed]

def redistribute_codes(code_list: list[int]) -> np.ndarray:
    # Distribute each 7-code block across SNAC's 3 hierarchy layers and decode
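    # Per 7-code frame [c0..c6], as implemented in the loop below:
    #   layer1 <- c0
    #   layer2 <- c1 - 4096,   c4 - 4*4096
    #   layer3 <- c2 - 2*4096, c3 - 3*4096, c5 - 5*4096, c6 - 6*4096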
    layer1, layer2, layer3 = [], [], []
    for i in range(len(code_list) // 7):
        b = code_list[7*i : 7*i+7]
        layer1.append(b[0])
        layer2.append(b[1] -   4096)
        layer3.append(b[2] - 2*4096)
        layer3.append(b[3] - 3*4096)
        layer2.append(b[4] - 4*4096)
        layer3.append(b[5] - 5*4096)
        layer3.append(b[6] - 6*4096)

    codes = [
        torch.tensor(layer1, device=device).unsqueeze(0),
        torch.tensor(layer2, device=device).unsqueeze(0),
        torch.tensor(layer3, device=device).unsqueeze(0),
    ]
    audio = snac_model.decode(codes).squeeze().cpu().numpy()
    return audio  # float32 @24 kHz

# — WebSocket endpoint for TTS —
@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    await ws.accept()
    try:
        # 1) Receive text + voice
        msg = await ws.receive_text()
        req = json.loads(msg)
        text  = req.get("text", "")
        voice = req.get("voice", "")

        # 2) Prompt → code list
        with torch.no_grad():
            codes = process_single_prompt(text, voice)
            audio_np = redistribute_codes(codes)

        # 3) Convert to PCM16 and send (clip first to avoid int16 wraparound)
        pcm16 = (np.clip(audio_np, -1.0, 1.0) * 32767).astype(np.int16).tobytes()
        await ws.send_bytes(pcm16)

        # 4) Close cleanly
        await ws.close()

    except WebSocketDisconnect:
        pass
    except Exception as e:
        print("Error in /ws/tts:", e)
        try:
            await ws.close(code=1011)
        except RuntimeError:
            # The socket may already be closed; closing twice raises.
            pass

if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860)  # "app:app" assumes this file is named app.py
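
# --- Example client (sketch) ---
# A minimal sketch of how a client could call /ws/tts, assuming the
# third-party `websockets` package and a server running on localhost:7860;
# the file names here are illustrative. Save as a separate script,
# e.g. tts_client.py:
#
#   import asyncio
#   import json
#   import wave
#
#   import websockets
#
#   async def main():
#       async with websockets.connect("ws://localhost:7860/ws/tts") as ws:
#           await ws.send(json.dumps({"text": "Hallo Welt!", "voice": ""}))
#           pcm16 = await ws.recv()  # raw little-endian PCM16 bytes
#           # Wrap the raw PCM in a WAV container: mono, 16 bit, 24 kHz.
#           with wave.open("out.wav", "wb") as f:
#               f.setnchannels(1)
#               f.setsampwidth(2)
#               f.setframerate(24000)
#               f.writeframes(pcm16)
#
#   if __name__ == "__main__":
#       asyncio.run(main())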