File size: 3,014 Bytes
f3890ef
0316ec3
4189fe1
0316ec3
a09ea48
 
0dfc310
0316ec3
a09ea48
f3890ef
 
2008a3f
1ab029d
0316ec3
f3890ef
a09ea48
0316ec3
674acbf
0dfc310
 
 
f3890ef
 
0dfc310
f001a32
f3890ef
d408dd5
9cd424e
f3890ef
 
 
a09ea48
9cd424e
b3e4aa7
0dfc310
f3890ef
9cd424e
a09ea48
 
9cd424e
 
 
f3890ef
9cd424e
f3890ef
 
 
 
 
9cd424e
 
f3890ef
 
 
97006e1
4189fe1
 
d408dd5
f3890ef
 
d408dd5
a8606ac
f3890ef
a09ea48
4189fe1
f3890ef
 
 
 
 
 
 
9cd424e
 
 
f3890ef
9cd424e
f3890ef
 
 
 
9cd424e
 
4189fe1
f3890ef
a09ea48
f3890ef
a09ea48
f3890ef
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import asyncio
import json
import os
import threading

import torch
from dotenv import load_dotenv
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login, snapshot_download
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer

load_dotenv()
tok = os.getenv("HF_TOKEN")
if tok:
    # Authenticate with the Hugging Face Hub only when a token is configured.
    login(token=tok)

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print("Loading SNAC…")
snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
snac = snac.to(device)

model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
# Pre-fetch only the config and safetensors weights into the local HF cache.
# NOTE(review): the tokenizer files excluded here are presumably fetched later
# by AutoTokenizer.from_pretrained — confirm the exclusion is intentional.
snapshot_download(
    repo_id=model_name,
    allow_patterns=[
        "config.json",
        "*.safetensors",
        "model.safetensors.index.json",
    ],
    ignore_patterns=[
        "optimizer.pt",
        "pytorch_model.bin",
        "training_args.bin",
        "scheduler.pt",
        "tokenizer.*",
        "vocab.json",
        "merges.txt",
    ],
)

print("Loading Orpheus…")
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
# Reuse the EOS id as the pad id (presumably because the checkpoint defines no pad token).
model.config.pad_token_id = model.config.eos_token_id

tokenizer = AutoTokenizer.from_pretrained(model_name)

# — Helper functions (unchanged from the previous version) —

def process_prompt(text: str, voice: str):
    """Encode ``"<voice>: <text>"`` as a prompt tensor for the Orpheus model.

    The tokenized prompt is wrapped between the special id 128259 and the ids
    128009, 128260. NOTE(review): these are presumably the start-of-human /
    end-of-text / end-of-human markers of the Orpheus prompt format — confirm.

    Returns:
        A (1, seq_len) id tensor on the module-level ``device``.
    """
    encoded = tokenizer(f"{voice}: {text}", return_tensors="pt").to(device)
    prefix = torch.tensor([[128259]], device=device)
    suffix = torch.tensor([[128009, 128260]], device=device)
    return torch.cat((prefix, encoded.input_ids, suffix), dim=1)

def parse_output(ids: torch.LongTensor):
    """Return the row-0 token ids that follow the last 128257 marker, minus 128258s.

    The cut point is the column after the last hit reported by ``nonzero``; for a
    single-row batch this is the last occurrence of 128257. If the marker never
    occurs, nothing is cut. Every 128258 id is then dropped. NOTE(review): in the
    Orpheus token scheme these are presumably start-/end-of-speech markers.

    Args:
        ids: 2-D tensor of token ids, e.g. the output of ``model.generate``.

    Returns:
        The remaining ids of the first row as a list of Python ints.
    """
    start_marker, drop_id = 128257, 128258
    marker_cols = torch.nonzero(ids == start_marker, as_tuple=True)[1]
    if marker_cols.numel() > 0:
        ids = ids[:, marker_cols[-1].item() + 1:]
    first_row = ids[0]
    keep = first_row != drop_id
    return first_row[keep].tolist()

def redistribute_codes(codes: list[int], snac_model: "SNAC"):
    """Decode a flat list of Orpheus audio token ids into a waveform with SNAC.

    Audio is emitted as frames of 7 tokens; slot ``k`` of a frame carries one SNAC
    code shifted by ``128266 + k * 4096``. Slot 0 feeds layer 1, slots 1 and 4 feed
    layer 2, and slots 2, 3, 5 and 6 feed layer 3. This follows the Orpheus-TTS
    reference decoder. NOTE(review): confirm it matches the implementation this
    placeholder ("genau wie vorher") stood for.

    Args:
        codes: Generated token ids, e.g. from ``parse_output``. Ids outside the
            audio-token range are ignored. A trailing partial frame is dropped.
        snac_model: SNAC decoder. Codes are created on the device of its parameters.

    Returns:
        1-D float32 numpy array of samples. It is empty when no complete valid
        frame was found, in which case the decoder is not called.
    """
    audio_token_offset = 128266
    codebook_size = 4096
    frame_len = 7

    # Keep only ids inside the audio-token range (drops e.g. trailing EOS-type ids,
    # which would otherwise become negative or out-of-range SNAC codes).
    audio_ids = [
        c - audio_token_offset
        for c in codes
        if audio_token_offset <= c < audio_token_offset + frame_len * codebook_size
    ]

    layer_1, layer_2, layer_3 = [], [], []
    # Step through complete frames only; a trailing partial frame is ignored.
    for start in range(0, len(audio_ids) - frame_len + 1, frame_len):
        frame = [c - k * codebook_size for k, c in enumerate(audio_ids[start:start + frame_len])]
        if not all(0 <= c < codebook_size for c in frame):
            # A token landed in the wrong slot; skip the frame rather than let
            # SNAC fail with an out-of-range embedding index.
            continue
        layer_1.append(frame[0])
        layer_2.extend((frame[1], frame[4]))
        layer_3.extend((frame[2], frame[3], frame[5], frame[6]))

    if not layer_1:
        return torch.zeros(0, dtype=torch.float32).numpy()

    snac_device = next(snac_model.parameters()).device
    layers = [
        torch.tensor(layer, dtype=torch.long, device=snac_device).unsqueeze(0)
        for layer in (layer_1, layer_2, layer_3)
    ]
    with torch.no_grad():
        audio = snac_model.decode(layers)
    # reshape(-1) (not squeeze) so even a single-sample result stays 1-D.
    return audio.reshape(-1).float().cpu().numpy()

app = FastAPI()


@app.get("/")
async def root():
    """Liveness probe: always reports that the TTS service is up."""
    return dict(status="ok", msg="Hello, Orpheus TTS up!")

# Serializes model/SNAC work. Generation runs in worker threads, so without this
# lock concurrent connections could drive the shared model at the same time.
_MODEL_LOCK = threading.Lock()


def _synthesize(text: str, voice: str):
    """Generate speech for ``text`` in ``voice`` and return float audio samples.

    Blocking: runs generation and SNAC decoding synchronously. Call it from a
    worker thread (as ``ws_tts`` does), never directly on the event loop.
    """
    with _MODEL_LOCK:
        ids = process_prompt(text, voice)
        gen = model.generate(
            input_ids=ids,
            # Single unpadded prompt: every position is real. An explicit mask
            # avoids transformers guessing one from pad_token_id (== eos here).
            attention_mask=torch.ones_like(ids),
            max_new_tokens=2000,
            do_sample=True, temperature=0.7, top_p=0.95,
            repetition_penalty=1.1,
            # NOTE(review): parse_output strips id 128258, which suggests it ends
            # the audio stream; confirm model.config.eos_token_id stops on it.
            eos_token_id=model.config.eos_token_id,
        )
        codes = parse_output(gen)
        return redistribute_codes(codes, snac)


@app.websocket("/ws/tts")
async def ws_tts(ws: WebSocket):
    """Synthesize one utterance per connection and stream it back as raw PCM.

    The client sends one JSON text message with optional keys ``"text"`` (default
    "") and ``"voice"`` (default "Jakob"). The server replies with binary frames of
    native-endian int16 samples, 2400 samples each. One frame is sent every 0.1 s,
    which is roughly real time assuming 24 kHz audio (the snac_24khz checkpoint).
    The server then closes the socket. On failure it closes with code 1011.
    """
    await ws.accept()
    try:
        msg = json.loads(await ws.receive_text())
        text, voice = msg.get("text", ""), msg.get("voice", "Jakob")
        # Generation is heavy and blocking; keep it off the event loop so other
        # connections (and this socket's keepalive) are still served.
        audio_np = await asyncio.to_thread(_synthesize, text, voice)
        # Clip first: samples outside [-1, 1] would wrap around in the int16 cast.
        pcm16 = (audio_np.clip(-1.0, 1.0) * 32767).astype("int16").tobytes()
        chunk = 2400 * 2  # 2400 int16 samples = 4800 bytes
        for i in range(0, len(pcm16), chunk):
            await ws.send_bytes(pcm16[i:i + chunk])
            await asyncio.sleep(0.1)
        await ws.close()
    except WebSocketDisconnect:
        print("Client left")
    except Exception as e:
        print("Error in /ws/tts:", e)
        try:
            await ws.close(code=1011)
        except Exception:
            # Best effort: the socket may already be closed (e.g. the failure
            # came from a send), and closing it again raises.
            pass

if __name__ == "__main__":
    # Development entry point. The import string "app:app" assumes this file is
    # importable as the module `app`.
    import uvicorn

    uvicorn.run(
        "app:app",
        host="0.0.0.0",  # listen on all interfaces
        port=7860,
    )