import os, json, asyncio
import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from dotenv import load_dotenv
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login, snapshot_download
load_dotenv()
# Log in to the Hugging Face Hub if an HF_TOKEN is set (needed for gated repos).
if (tok := os.getenv("HF_TOKEN")):
    login(token=tok)
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Loading SNAC…")
snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
# Pre-download only the files needed for inference (config + safetensors weights).
snapshot_download(
    repo_id=model_name,
    allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json"],
    ignore_patterns=[
        "optimizer.pt", "pytorch_model.bin", "training_args.bin",
        "scheduler.pt", "tokenizer.*", "vocab.json", "merges.txt",
    ],
)
print("Loading Orpheus…")
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16
)
model = model.to(device)
model.config.pad_token_id = model.config.eos_token_id
tokenizer = AutoTokenizer.from_pretrained(model_name)
# — Helper functions (as before) —
def process_prompt(text: str, voice: str):
    # Build the "voice: text" prompt and wrap it with Orpheus' special
    # start/end framing tokens.
    prompt = f"{voice}: {text}"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    start = torch.tensor([[128259]], device=device)
    end = torch.tensor([[128009, 128260]], device=device)
    return torch.cat([start, inputs.input_ids, end], dim=1)
def parse_output(ids: torch.LongTensor):
    # Keep only the tokens after the last start-of-audio marker (128257)
    # and drop the end-of-audio marker (128258).
    st, rm = 128257, 128258
    idxs = (ids == st).nonzero(as_tuple=True)[1]
    cropped = ids[:, idxs[-1].item() + 1:] if idxs.numel() > 0 else ids
    row = cropped[0][cropped[0] != rm]
    return row.tolist()
def redistribute_codes(codes: list[int], snac_model: SNAC):
    # … exactly as before … — returns a numpy array.
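    # The original body is elided above; what follows is a minimal sketch of
    # the usual Orpheus→SNAC decoding step, assuming the constants from the
    # reference Orpheus code (audio tokens offset by 128266, 7 tokens per
    # frame, a 4096 offset per codebook slot) — not verified for this repo.
    layer_1, layer_2, layer_3 = [], [], []
    for i in range(len(codes) // 7):  # one SNAC frame per 7 generated tokens
        frame = [c - 128266 for c in codes[7 * i: 7 * i + 7]]
        layer_1.append(frame[0])
        layer_2.append(frame[1] - 4096)
        layer_3.append(frame[2] - 2 * 4096)
        layer_3.append(frame[3] - 3 * 4096)
        layer_2.append(frame[4] - 4 * 4096)
        layer_3.append(frame[5] - 5 * 4096)
        layer_3.append(frame[6] - 6 * 4096)
    layers = [
        torch.tensor(layer_1, device=device).unsqueeze(0),
        torch.tensor(layer_2, device=device).unsqueeze(0),
        torch.tensor(layer_3, device=device).unsqueeze(0),
    ]
    with torch.no_grad():
        audio = snac_model.decode(layers)  # shape (1, 1, samples), float in [-1, 1]
    return audio.squeeze().cpu().numpy()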
app = FastAPI()

@app.get("/")
async def root():
    return {"status": "ok", "msg": "Hello, Orpheus TTS up!"}
@app.websocket("/ws/tts")
async def ws_tts(ws: WebSocket):
    await ws.accept()
    try:
        # Expect a single JSON message: {"text": "...", "voice": "..."}.
        msg = json.loads(await ws.receive_text())
        text, voice = msg.get("text", ""), msg.get("voice", "Jakob")
        ids = process_prompt(text, voice)
        gen = model.generate(
            input_ids=ids,
            max_new_tokens=2000,
            do_sample=True, temperature=0.7, top_p=0.95,
            repetition_penalty=1.1,
            eos_token_id=model.config.eos_token_id,
        )
        codes = parse_output(gen)
        audio_np = redistribute_codes(codes, snac)
        # Convert the float waveform to 16-bit PCM and stream it in
        # 2400-sample chunks (0.1 s at SNAC's 24 kHz sample rate).
        pcm16 = (audio_np * 32767).astype("int16").tobytes()
        chunk = 2400 * 2  # 2400 samples × 2 bytes per int16 sample
        for i in range(0, len(pcm16), chunk):
            await ws.send_bytes(pcm16[i:i + chunk])
            await asyncio.sleep(0.1)
        await ws.close()
    except WebSocketDisconnect:
        print("Client left")
    except Exception as e:
        print("Error in /ws/tts:", e)
        await ws.close(code=1011)
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860)
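# Example client — a sketch of the /ws/tts protocol, assuming the third-party
# `websockets` package and that this file runs as app.py on port 7860:
#
#   import asyncio, json, websockets
#
#   async def main():
#       async with websockets.connect("ws://localhost:7860/ws/tts") as ws:
#           await ws.send(json.dumps({"text": "Hallo Welt", "voice": "Jakob"}))
#           pcm = b""
#           try:
#               while True:
#                   pcm += await ws.recv()  # raw 16-bit PCM, 24 kHz, mono
#           except websockets.exceptions.ConnectionClosedOK:
#               pass  # server closed the socket after the last chunk
#           with open("out.pcm", "wb") as f:
#               f.write(pcm)
#
#   asyncio.run(main())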