File size: 3,947 Bytes
a09ea48
 
 
0316ec3
4189fe1
0316ec3
a09ea48
 
 
0316ec3
a09ea48
 
 
 
 
2008a3f
a09ea48
0316ec3
 
 
a09ea48
0316ec3
 
a09ea48
 
 
 
 
0316ec3
 
 
a09ea48
 
 
0316ec3
a09ea48
 
 
 
 
ad94d02
a09ea48
0316ec3
 
a09ea48
 
 
 
0316ec3
a09ea48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0316ec3
a09ea48
 
 
0316ec3
a09ea48
 
97006e1
a09ea48
4189fe1
 
a8606ac
a09ea48
 
4189fe1
 
a09ea48
4189fe1
 
a09ea48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4189fe1
a09ea48
 
 
 
4189fe1
a09ea48
4189fe1
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import asyncio
import json
import os
import threading

import torch
from dotenv import load_dotenv
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer

# — Environment & Hugging Face authentication —
# Load variables from a .env file (if present), then authenticate against the
# Hugging Face Hub only when an HF_TOKEN is actually configured.
load_dotenv()
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)

# — Device selection & model loading —
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# SNAC neural codec (24 kHz checkpoint) turns audio codes back into a waveform.
print("Loading SNAC model...")
snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
snac = snac.to(device)

# Orpheus causal LM that emits the audio tokens; weights loaded in bfloat16.
model_name = "canopylabs/3b-de-ft-research_release"
print("Loading Orpheus model...")
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
model = model.to(device)
# Reuse the EOS id as padding id (presumably to give generate() a pad id — TODO confirm).
model.config.pad_token_id = model.config.eos_token_id

tokenizer = AutoTokenizer.from_pretrained(model_name)

# — Helper functions —
def process_prompt(text: str, voice: str):
    """Build the model input for *text* spoken by *voice*.

    The string ``"<voice>: <text>"`` is tokenized and wrapped as
    ``[128259] + prompt_ids + [128009, 128260]`` (presumably Orpheus'
    start-of-human / end-of-text / end-of-human markers — confirm against the
    model card).

    Returns:
        ``(input_ids, attention_mask)``, both shaped ``(1, seq_len)`` and placed
        on ``device``; the mask is all ones.
    """
    prompt_ids = tokenizer(f"{voice}: {text}", return_tensors="pt").input_ids
    prefix = torch.tensor([[128259]], dtype=torch.int64)
    suffix = torch.tensor([[128009, 128260]], dtype=torch.int64)
    input_ids = torch.cat((prefix, prompt_ids, suffix), dim=1).to(device)
    attention_mask = torch.ones_like(input_ids).to(device)
    return input_ids, attention_mask

def parse_output(generated_ids: torch.LongTensor):
    """Decode Orpheus-generated token ids into a waveform with SNAC.

    Only the first row of ``generated_ids`` is decoded. Tokens after the last
    start-of-speech marker (128257) are kept, end-of-speech markers (128258)
    are dropped, and the rest is read as frames of 7 audio tokens that are
    redistributed over SNAC's three codebook layers (1 + 2 + 4 codes per frame).
    A trailing partial frame is ignored.

    Args:
        generated_ids: ``(batch, seq_len)`` token ids as returned by
            ``model.generate`` (the prompt may be included; it is cropped off at
            the start-of-speech marker).

    Returns:
        float32 numpy waveform at 24000 Hz; an empty array when no complete
        7-token frame was generated.

    Raises:
        ValueError: if a generated token does not map into the 4096-entry
            codebook of its frame position (decoding it would index out of
            range, which on CUDA aborts with a device-side assert).
    """
    start_of_speech = 128257
    end_of_speech = 128258
    # Audio tokens live above the text vocabulary: code c at frame position k
    # is token 128266 + k * 4096 + c (per the reference Orpheus decoder). The
    # original code only removed the k * 4096 part, leaving ids ~128k for SNAC.
    audio_token_offset = 128266
    codebook_size = 4096
    frame_len = 7

    idxs = (generated_ids == start_of_speech).nonzero(as_tuple=True)[1]
    if idxs.numel() > 0:
        cropped = generated_ids[:, idxs[-1].item() + 1:]
    else:
        cropped = generated_ids

    row = cropped[0]
    row = row[row != end_of_speech]

    n_frames = row.numel() // frame_len
    if n_frames == 0:
        # Nothing decodable: return an empty float32 waveform instead of
        # calling SNAC with empty code tensors.
        return torch.zeros(0, dtype=torch.float32).numpy()

    frames = (row[: n_frames * frame_len] - audio_token_offset).reshape(n_frames, frame_len)
    # Remove the per-position k * 4096 offset (position 0 has none).
    frames = frames - torch.arange(frame_len, device=frames.device) * codebook_size
    if bool(((frames < 0) | (frames >= codebook_size)).any()):
        raise ValueError("generated tokens fall outside the SNAC codebook range")

    # Per frame: layer 1 <- [0]; layer 2 <- [1], [4]; layer 3 <- [2], [3], [5], [6].
    layer1 = frames[:, 0]
    layer2 = frames[:, [1, 4]].reshape(-1)
    layer3 = frames[:, [2, 3, 5, 6]].reshape(-1)
    codes = [layer.unsqueeze(0).to(device) for layer in (layer1, layer2, layer3)]

    audio = snac.decode(codes).detach().squeeze().cpu().numpy()
    return audio  # float32 numpy at 24000 Hz

# — FastAPI app + WebSocket endpoint —
app = FastAPI()

# Synthesis runs in worker threads (see tts_ws); this lock keeps GPU inference
# serialized across connections, as it effectively was when it ran on the loop.
_inference_lock = threading.Lock()

# 0.1 s @ 24 kHz = 2400 samples = 4800 bytes of 16-bit PCM.
_CHUNK_BYTES = 2400 * 2
_CHUNK_SECONDS = 0.1


def _synthesize(text: str, voice: str) -> bytes:
    """Generate speech for *text* in *voice* and return little-endian int16 PCM bytes.

    Blocking (model inference); call it from a worker thread, not the event loop.
    """
    # torch.no_grad() is thread-local, so it must be entered in the thread that
    # actually runs the model, i.e. here.
    with _inference_lock, torch.no_grad():
        ids, mask = process_prompt(text, voice)
        gen_ids = model.generate(
            input_ids=ids,
            attention_mask=mask,
            max_new_tokens=1200,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            repetition_penalty=1.1,
            eos_token_id=128258,
        )
        audio = parse_output(gen_ids)
    # The decoder may overshoot [-1, 1]; clip so the int16 cast saturates instead
    # of wrapping around into loud clicks. "<i2" pins little-endian byte order.
    return (audio.clip(-1.0, 1.0) * 32767).astype("<i2").tobytes()


@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    """Stream synthesized speech over a WebSocket.

    Each incoming text frame is parsed as JSON; ``"text"`` (default ``""``) and
    ``"voice"`` (default ``"jana"``) are read from it. The reply is a series of
    binary frames of 16-bit PCM at 24 kHz, 4800 bytes (0.1 s) each except
    possibly the last, sent at roughly real-time pace. On an unexpected error
    the socket is closed with code 1011.
    """
    await ws.accept()
    loop = asyncio.get_running_loop()
    try:
        while True:
            data = json.loads(await ws.receive_text())
            text = data.get("text", "")
            voice = data.get("voice", "jana")
            # Inference takes seconds; running it in the default executor keeps
            # the event loop (other clients, keep-alives) responsive meanwhile.
            pcm16 = await loop.run_in_executor(None, _synthesize, text, voice)
            for start in range(0, len(pcm16), _CHUNK_BYTES):
                await ws.send_bytes(pcm16[start:start + _CHUNK_BYTES])
                await asyncio.sleep(_CHUNK_SECONDS)  # pacing
    except WebSocketDisconnect:
        print("Client disconnected")
    except Exception as e:
        print("Error in /ws/tts:", e)
        try:
            await ws.close(code=1011)
        except Exception:
            # Best effort: the connection may already be closed, in which case
            # there is nobody left to notify.
            pass

if __name__ == "__main__":
    import uvicorn

    # Pass the app object, not the "app:app" import string: the string form makes
    # uvicorn import this file again as module "app" (re-running the model
    # loading above and doubling memory use), and breaks if the file is renamed.
    uvicorn.run(app, host="0.0.0.0", port=7860)