import os, json, asyncio
import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from dotenv import load_dotenv
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login, snapshot_download
load_dotenv()
if (tok := os.getenv("HF_TOKEN")):
    login(token=tok)
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Loading SNAC…")
snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
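# Pre-fetch only the weight shards and config; tokenizer files are excluded
# here and downloaded on demand by AutoTokenizer below.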
snapshot_download(
    repo_id=model_name,
    allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json"],
    ignore_patterns=["optimizer.pt", "pytorch_model.bin", "training_args.bin",
                     "scheduler.pt", "tokenizer.*", "vocab.json", "merges.txt"],
)
print("Loading Orpheus…")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
)
model = model.to(device)
model.config.pad_token_id = model.config.eos_token_id
tokenizer = AutoTokenizer.from_pretrained(model_name)

# — Helper functions (as before) —
def process_prompt(text: str, voice: str):
    prompt = f"{voice}: {text}"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # Orpheus control tokens: 128259 = start-of-human, 128009 = end-of-text,
    # 128260 = end-of-human; they bracket the text prompt.
    start = torch.tensor([[128259]], device=device)
    end = torch.tensor([[128009, 128260]], device=device)
    return torch.cat([start, inputs.input_ids, end], dim=1)
def parse_output(ids: torch.LongTensor):
    # 128257 marks the start of the audio-code stream; 128258 is the
    # end-of-speech token and is stripped from the result.
    st, rm = 128257, 128258
    idxs = (ids == st).nonzero(as_tuple=True)[1]
    cropped = ids[:, idxs[-1].item() + 1:] if idxs.numel() > 0 else ids
    row = cropped[0][cropped[0] != rm]
    return row.tolist()
def redistribute_codes(codes: list[int], snac_model: SNAC):
    # … exactly as before …
    # returns a numpy array
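    # The original body is elided above; what follows is a minimal sketch,
    # assuming the usual Orpheus layout of 7 interleaved SNAC codes per frame
    # and an audio-token offset of 128266 (both taken from the public Orpheus
    # demos, not from this file).
    codes = [c - 128266 for c in codes]        # strip the audio-token offset
    codes = codes[: (len(codes) // 7) * 7]     # keep only whole 7-code frames
    l1, l2, l3 = [], [], []
    for i in range(len(codes) // 7):
        f = codes[7 * i : 7 * i + 7]
        l1.append(f[0])
        l2.append(f[1] - 4096)
        l3.append(f[2] - 2 * 4096)
        l3.append(f[3] - 3 * 4096)
        l2.append(f[4] - 4 * 4096)
        l3.append(f[5] - 5 * 4096)
        l3.append(f[6] - 6 * 4096)
    layers = [
        torch.tensor(l1, device=device).unsqueeze(0),
        torch.tensor(l2, device=device).unsqueeze(0),
        torch.tensor(l3, device=device).unsqueeze(0),
    ]
    with torch.no_grad():
        audio = snac_model.decode(layers)      # (1, 1, num_samples) at 24 kHz
    return audio.squeeze().cpu().numpy()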

app = FastAPI()

@app.get("/")
async def root():
    return {"status": "ok", "msg": "Hello, Orpheus TTS up!"}

@app.websocket("/ws/tts")
async def ws_tts(ws: WebSocket):
    await ws.accept()
    try:
        msg = json.loads(await ws.receive_text())
        text, voice = msg.get("text", ""), msg.get("voice", "Jakob")
        ids = process_prompt(text, voice)
        # NOTE: generate() is synchronous and blocks the event loop while it
        # runs; acceptable for a single-client demo.
        gen = model.generate(
            input_ids=ids,
            max_new_tokens=2000,
            do_sample=True, temperature=0.7, top_p=0.95,
            repetition_penalty=1.1,
            eos_token_id=model.config.eos_token_id,
        )
        codes = parse_output(gen)
        audio_np = redistribute_codes(codes, snac)
        pcm16 = (audio_np * 32767).astype("int16").tobytes()
        # 2400 samples × 2 bytes = 100 ms of 24 kHz mono PCM16 per chunk,
        # paced to roughly real time by the 0.1 s sleep below.
        chunk = 2400 * 2
        for i in range(0, len(pcm16), chunk):
            await ws.send_bytes(pcm16[i:i + chunk])
            await asyncio.sleep(0.1)
        await ws.close()
    except WebSocketDisconnect:
        print("Client left")
    except Exception as e:
        print("Error in /ws/tts:", e)
        await ws.close(code=1011)

if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860)
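
# Example client (a sketch, assuming the third-party `websockets` package,
# which is not a dependency of this app): send one JSON message with `text`
# and `voice`, then read raw PCM16 chunks until the server closes the socket.
#
#   import asyncio, json, websockets
#
#   async def main():
#       async with websockets.connect("ws://localhost:7860/ws/tts") as ws:
#           await ws.send(json.dumps({"text": "Hallo Welt", "voice": "Jakob"}))
#           pcm = b""
#           try:
#               while True:
#                   pcm += await ws.recv()
#           except websockets.ConnectionClosed:
#               pass
#           # pcm now holds 24 kHz mono int16 audio
#
#   asyncio.run(main())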