Tomtom84 commited on
Commit
2c15189
Β·
1 Parent(s): f3890ef
Files changed (1) hide show
  1. app.py +100 -44
app.py CHANGED
@@ -1,4 +1,6 @@
1
- import os, json, asyncio
 
 
2
  import torch
3
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
4
  from dotenv import load_dotenv
@@ -6,87 +8,141 @@ from snac import SNAC
6
  from transformers import AutoModelForCausalLM, AutoTokenizer
7
  from huggingface_hub import login, snapshot_download
8
 
 
9
  load_dotenv()
10
- if (tok := os.getenv("HF_TOKEN")):
11
- login(token=tok)
 
12
 
 
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
 
15
- print("Loading SNAC…")
 
16
  snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
17
 
 
18
  model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
 
 
19
  snapshot_download(
20
  repo_id=model_name,
21
  allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json"],
22
- ignore_patterns=[ "optimizer.pt", "pytorch_model.bin", "training_args.bin",
23
- "scheduler.pt", "tokenizer.*", "vocab.json", "merges.txt" ]
 
 
24
  )
25
 
26
- print("Loading Orpheus…")
27
  model = AutoModelForCausalLM.from_pretrained(
28
  model_name,
29
- torch_dtype=torch.bfloat16
30
  )
31
  model = model.to(device)
32
  model.config.pad_token_id = model.config.eos_token_id
33
 
34
  tokenizer = AutoTokenizer.from_pretrained(model_name)
35
 
36
- # β€” Helper Functions (wie gehabt) β€”
37
 
38
  def process_prompt(text: str, voice: str):
 
 
 
39
  prompt = f"{voice}: {text}"
40
- inputs = tokenizer(prompt, return_tensors="pt").to(device)
41
  start = torch.tensor([[128259]], device=device)
42
  end = torch.tensor([[128009, 128260]], device=device)
43
- return torch.cat([start, inputs.input_ids, end], dim=1)
44
 
45
- def parse_output(ids: torch.LongTensor):
46
- st, rm = 128257, 128258
47
- idxs = (ids==st).nonzero(as_tuple=True)[1]
48
- cropped = ids[:, idxs[-1].item()+1:] if idxs.numel()>0 else ids
49
- row = cropped[0][cropped[0]!=rm]
 
 
 
 
 
 
50
  return row.tolist()
51
 
52
- def redistribute_codes(codes: list[int], snac_model: SNAC):
53
- # … genau wie vorher …
54
- # return numpy array
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  app = FastAPI()
57
 
58
  @app.get("/")
59
- async def root():
60
- return {"status":"ok","msg":"Hello, Orpheus TTS up!"}
61
 
62
  @app.websocket("/ws/tts")
63
- async def ws_tts(ws: WebSocket):
64
  await ws.accept()
65
  try:
66
- msg = json.loads(await ws.receive_text())
67
- text, voice = msg.get("text",""), msg.get("voice","Jakob")
68
- ids = process_prompt(text, voice)
69
- gen = model.generate(
70
- input_ids=ids,
71
- max_new_tokens=2000,
72
- do_sample=True, temperature=0.7, top_p=0.95,
73
- repetition_penalty=1.1,
74
- eos_token_id=model.config.eos_token_id,
75
- )
76
- codes = parse_output(gen)
77
- audio_np = redistribute_codes(codes, snac)
78
- pcm16 = (audio_np*32767).astype("int16").tobytes()
79
- chunk = 2400*2
80
- for i in range(0,len(pcm16),chunk):
81
- await ws.send_bytes(pcm16[i:i+chunk])
82
- await asyncio.sleep(0.1)
83
- await ws.close()
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  except WebSocketDisconnect:
85
- print("Client left")
86
  except Exception as e:
87
- print("Error in /ws/tts:",e)
88
  await ws.close(code=1011)
89
 
90
- if __name__=="__main__":
 
 
91
  import uvicorn
92
- uvicorn.run("app:app",host="0.0.0.0",port=7860)
 
1
+ import os
2
+ import json
3
+ import asyncio
4
  import torch
5
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
6
  from dotenv import load_dotenv
 
8
  from transformers import AutoModelForCausalLM, AutoTokenizer
9
  from huggingface_hub import login, snapshot_download
10
 
11
+ # ─── ENV & HF TOKEN ────────────────────────────────────────────────────────────
12
  load_dotenv()
13
+ HF_TOKEN = os.getenv("HF_TOKEN")
14
+ if HF_TOKEN:
15
+ login(token=HF_TOKEN)
16
 
17
+ # ─── DEVICE ─────────────────────────────────────────────────────────────────────
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
 
20
+ # ─── SNAC ───────────────────────────────────────────────────────────────────────
21
+ print("Loading SNAC model…")
22
  snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
23
 
24
+ # ─── ORPHEUS ────────────────────────────────────────────────────────────────────
25
  model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
26
+
27
+ # pre‑download only the config + safetensors, damit das Image schlank bleibt
28
  snapshot_download(
29
  repo_id=model_name,
30
  allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json"],
31
+ ignore_patterns=[
32
+ "optimizer.pt", "pytorch_model.bin", "training_args.bin",
33
+ "scheduler.pt", "tokenizer.*", "vocab.json", "merges.txt"
34
+ ]
35
  )
36
 
37
+ print("Loading Orpheus model…")
38
  model = AutoModelForCausalLM.from_pretrained(
39
  model_name,
40
+ torch_dtype=torch.bfloat16 # optional: beschleunigt das FP16‑Àhnliche Rechnen
41
  )
42
  model = model.to(device)
43
  model.config.pad_token_id = model.config.eos_token_id
44
 
45
  tokenizer = AutoTokenizer.from_pretrained(model_name)
46
 
47
+ # ─── HILFSFUNKTIONEN ──────────────────────────────────────────────────────────
48
 
49
  def process_prompt(text: str, voice: str):
50
+ """
51
+ Baut aus Text+Voice ein batch‑Tensor input_ids fΓΌr `model.generate`.
52
+ """
53
  prompt = f"{voice}: {text}"
54
+ tok = tokenizer(prompt, return_tensors="pt").to(device)
55
  start = torch.tensor([[128259]], device=device)
56
  end = torch.tensor([[128009, 128260]], device=device)
57
+ return torch.cat([start, tok.input_ids, end], dim=1)
58
 
59
+ def parse_output(generated_ids: torch.LongTensor):
60
+ """
61
+ Schneidet bis zum letzten 128257 und entfernt 128258, gibt reine Token‑Liste zurΓΌck.
62
+ """
63
+ START, PAD = 128257, 128258
64
+ idxs = (generated_ids == START).nonzero(as_tuple=True)[1]
65
+ if idxs.numel() > 0:
66
+ cropped = generated_ids[:, idxs[-1].item()+1:]
67
+ else:
68
+ cropped = generated_ids
69
+ row = cropped[0][cropped[0] != PAD]
70
  return row.tolist()
71
 
72
+ def redistribute_codes(code_list: list[int], snac_model: SNAC):
73
+ """
74
+ Verteilt 7er‑BlΓΆcke auf die drei SNAC‑Layer und dekodiert zu Audio (numpy float32).
75
+ """
76
+ layer1, layer2, layer3 = [], [], []
77
+ for i in range((len(code_list) + 1) // 7):
78
+ base = code_list[7*i : 7*i+7]
79
+ layer1.append(base[0])
80
+ layer2.append(base[1] - 4096)
81
+ layer3.append(base[2] - 2*4096)
82
+ layer3.append(base[3] - 3*4096)
83
+ layer2.append(base[4] - 4*4096)
84
+ layer3.append(base[5] - 5*4096)
85
+ layer3.append(base[6] - 6*4096)
86
+ dev = next(snac_model.parameters()).device
87
+ codes = [
88
+ torch.tensor(layer1, device=dev).unsqueeze(0),
89
+ torch.tensor(layer2, device=dev).unsqueeze(0),
90
+ torch.tensor(layer3, device=dev).unsqueeze(0),
91
+ ]
92
+ audio = snac_model.decode(codes)
93
+ return audio.detach().squeeze().cpu().numpy()
94
+
95
+ # ─── FASTAPI ───────���────────────────────────────────────────────────────────────
96
 
97
  app = FastAPI()
98
 
99
  @app.get("/")
100
+ async def healthcheck():
101
+ return {"status": "ok", "msg": "Hello, Orpheus TTS up!"}
102
 
103
  @app.websocket("/ws/tts")
104
+ async def tts_ws(ws: WebSocket):
105
  await ws.accept()
106
  try:
107
+ while True:
108
+ # 1) Eintreffende JSON‐Nachricht parsen
109
+ data = json.loads(await ws.receive_text())
110
+ text = data.get("text", "")
111
+ voice = data.get("voice", "Jakob")
112
+
113
+ # 2) Prompt β†’ input_ids
114
+ ids = process_prompt(text, voice)
115
+
116
+ # 3) Token‐Erzeugung
117
+ gen_ids = model.generate(
118
+ input_ids=ids,
119
+ max_new_tokens=2000, # hier z.B. 20k geht auch, wird aber speicherintensiv
120
+ do_sample=True,
121
+ temperature=0.7,
122
+ top_p=0.95,
123
+ repetition_penalty=1.1,
124
+ eos_token_id=model.config.eos_token_id,
125
+ )
126
+
127
+ # 4) Tokens β†’ Code‐Liste β†’ Audio
128
+ codes = parse_output(gen_ids)
129
+ audio_np = redistribute_codes(codes, snac)
130
+
131
+ # 5) PCM16‐Stream in 0.1 s‐BlΓΆcken
132
+ pcm16 = (audio_np * 32767).astype("int16").tobytes()
133
+ chunk = 2400 * 2
134
+ for i in range(0, len(pcm16), chunk):
135
+ await ws.send_bytes(pcm16[i : i+chunk])
136
+ await asyncio.sleep(0.1)
137
+
138
  except WebSocketDisconnect:
139
+ print("Client disconnected")
140
  except Exception as e:
141
+ print("Error in /ws/tts:", e)
142
  await ws.close(code=1011)
143
 
144
+ # ─── START ──────────────────────────────────────────────────────────────────────
145
+
146
+ if __name__ == "__main__":
147
  import uvicorn
148
+ uvicorn.run("app:app", host="0.0.0.0", port=7860, log_level="info")