import os, json, asyncio
import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from dotenv import load_dotenv
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login, snapshot_download

load_dotenv()
if (tok := os.getenv("HF_TOKEN")):  # authenticate against the Hub if a token is set
    login(token=tok)

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Loading SNAC…") | |
snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device) | |
model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1" | |
snapshot_download( | |
repo_id=model_name, | |
allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json"], | |
ignore_patterns=[ "optimizer.pt", "pytorch_model.bin", "training_args.bin", | |
"scheduler.pt", "tokenizer.*", "vocab.json", "merges.txt" ] | |
) | |
print("Loading Orpheus…") | |
model = AutoModelForCausalLM.from_pretrained( | |
model_name, | |
torch_dtype=torch.bfloat16 | |
) | |
model = model.to(device) | |
model.config.pad_token_id = model.config.eos_token_id | |
tokenizer = AutoTokenizer.from_pretrained(model_name) | |
# — Helper functions (as before) —

def process_prompt(text: str, voice: str):
    prompt = f"{voice}: {text}"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # Orpheus prompt framing: 128259 = start-of-human, 128009 = end-of-text,
    # 128260 = end-of-human.
    start = torch.tensor([[128259]], device=device)
    end = torch.tensor([[128009, 128260]], device=device)
    return torch.cat([start, inputs.input_ids, end], dim=1)
def parse_output(ids: torch.LongTensor):
    st, rm = 128257, 128258  # 128257 = start-of-audio marker, 128258 = end/padding
    idxs = (ids == st).nonzero(as_tuple=True)[1]
    cropped = ids[:, idxs[-1].item() + 1:] if idxs.numel() > 0 else ids
    row = cropped[0][cropped[0] != rm]
    row = row[: (row.size(0) // 7) * 7]  # SNAC codes come in frames of 7 tokens
    return (row - 128266).tolist()  # 128266 = offset of the first audio code in the vocab
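# Example (hypothetical values): a generation ending in
#   [..., 128257, 128266 + 12, 128266 + 4100, ..., 128258]
# is cropped after the last 128257, stripped of 128258, trimmed to a
# multiple of 7, and returned as 0-based codes such as [12, 4100, ...].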
def redistribute_codes(codes: list[int], snac_model: SNAC):
    # Reconstruction of the elided original, following the standard Orpheus
    # layout: each frame of 7 tokens maps onto SNAC's three codebook layers,
    # with each frame position carrying an offset of n * 4096.
    layer_1, layer_2, layer_3 = [], [], []
    for i in range(len(codes) // 7):
        layer_1.append(codes[7 * i])
        layer_2.append(codes[7 * i + 1] - 4096)
        layer_3.append(codes[7 * i + 2] - 2 * 4096)
        layer_3.append(codes[7 * i + 3] - 3 * 4096)
        layer_2.append(codes[7 * i + 4] - 4 * 4096)
        layer_3.append(codes[7 * i + 5] - 5 * 4096)
        layer_3.append(codes[7 * i + 6] - 6 * 4096)
    layers = [torch.tensor(l).unsqueeze(0).to(device) for l in (layer_1, layer_2, layer_3)]
    with torch.no_grad():
        audio = snac_model.decode(layers)
    return audio.detach().squeeze().cpu().numpy()  # float32 waveform at 24 kHz
app = FastAPI()

@app.get("/")
async def root():
    return {"status": "ok", "msg": "Hello, Orpheus TTS up!"}
@app.websocket("/ws/tts")
async def ws_tts(ws: WebSocket):
    await ws.accept()
    try:
        msg = json.loads(await ws.receive_text())
        text, voice = msg.get("text", ""), msg.get("voice", "Jakob")
        ids = process_prompt(text, voice)
        gen = model.generate(
            input_ids=ids,
            max_new_tokens=2000,
            do_sample=True, temperature=0.7, top_p=0.95,
            repetition_penalty=1.1,
            eos_token_id=model.config.eos_token_id,
        )
        codes = parse_output(gen)
        audio_np = redistribute_codes(codes, snac)
        pcm16 = (audio_np * 32767).astype("int16").tobytes()
        chunk = 2400 * 2  # 2400 samples × 2 bytes = 0.1 s of 24 kHz 16-bit PCM
        for i in range(0, len(pcm16), chunk):
            await ws.send_bytes(pcm16[i:i + chunk])
            await asyncio.sleep(0.1)  # pace the stream at roughly real time
        await ws.close()
    except WebSocketDisconnect:
        print("Client left")
    except Exception as e:
        print("Error in /ws/tts:", e)
        await ws.close(code=1011)
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860)
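# A minimal client sketch for testing (assumes the third-party `websockets`
# package; endpoint path and message shape match the handler above):
#
#   import asyncio, json, websockets
#
#   async def demo():
#       async with websockets.connect("ws://localhost:7860/ws/tts") as ws:
#           await ws.send(json.dumps({"text": "Hallo Welt", "voice": "Jakob"}))
#           pcm = b""
#           try:
#               while True:
#                   pcm += await ws.recv()
#           except websockets.exceptions.ConnectionClosed:
#               pass
#           # `pcm` now holds raw 24 kHz mono 16-bit PCM; add a WAV header to play it.
#
#   asyncio.run(demo())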