Tomtom84 committed
Commit f3890ef · 1 Parent(s): 9cd424e
Files changed (1)
  1. app.py +41 -94
app.py CHANGED
@@ -1,6 +1,4 @@
- import os
- import json
- import asyncio
+ import os, json, asyncio
  import torch
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
  from dotenv import load_dotenv
@@ -8,138 +6,87 @@ from snac import SNAC
  from transformers import AutoModelForCausalLM, AutoTokenizer
  from huggingface_hub import login, snapshot_download

- # — ENV & HF-AUTH —
  load_dotenv()
- HF_TOKEN = os.getenv("HF_TOKEN")
- if HF_TOKEN:
-     login(token=HF_TOKEN)
+ if (tok := os.getenv("HF_TOKEN")):
+     login(token=tok)

- # — Device —
  device = "cuda" if torch.cuda.is_available() else "cpu"

- # Load SNAC —
- print("Loading SNAC model...")
+ print("Loading SNAC…")
  snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)

- # — Prepare the Orpheus model —
  model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
-
- # Config + weights only (allows a slimmer container)
  snapshot_download(
      repo_id=model_name,
      allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json"],
-     ignore_patterns=[
-         "optimizer.pt", "pytorch_model.bin", "training_args.bin",
-         "scheduler.pt", "tokenizer.json", "tokenizer_config.json",
-         "special_tokens_map.json", "vocab.json", "merges.txt", "tokenizer.*"
-     ]
+     ignore_patterns=["optimizer.pt", "pytorch_model.bin", "training_args.bin",
+                      "scheduler.pt", "tokenizer.*", "vocab.json", "merges.txt"]
  )

- print("Loading Orpheus model...")
+ print("Loading Orpheus")
  model = AutoModelForCausalLM.from_pretrained(
      model_name,
-     torch_dtype=torch.bfloat16,
-     device_map="auto",
- ).to(device)
+     torch_dtype=torch.bfloat16
+ )
+ model = model.to(device)
  model.config.pad_token_id = model.config.eos_token_id

  tokenizer = AutoTokenizer.from_pretrained(model_name)

-
- # — Helper functions —
+ # — Helper Functions (as before) —

  def process_prompt(text: str, voice: str):
      prompt = f"{voice}: {text}"
      inputs = tokenizer(prompt, return_tensors="pt").to(device)
-     # add start/end tokens
      start = torch.tensor([[128259]], device=device)
      end = torch.tensor([[128009, 128260]], device=device)
-     input_ids = torch.cat([start, inputs.input_ids, end], dim=1)
-     return input_ids
+     return torch.cat([start, inputs.input_ids, end], dim=1)

- def parse_output(generated_ids: torch.LongTensor):
-     token_to_find = 128257
-     token_to_remove = 128258
-
-     idxs = (generated_ids == token_to_find).nonzero(as_tuple=True)[1]
-     if idxs.numel() > 0:
-         cropped = generated_ids[:, idxs[-1].item() + 1 :]
-     else:
-         cropped = generated_ids
-
-     row = cropped[0][cropped[0] != token_to_remove]
+ def parse_output(ids: torch.LongTensor):
+     st, rm = 128257, 128258
+     idxs = (ids == st).nonzero(as_tuple=True)[1]
+     cropped = ids[:, idxs[-1].item() + 1:] if idxs.numel() > 0 else ids
+     row = cropped[0][cropped[0] != rm]
      return row.tolist()

- def redistribute_codes(code_list: list[int], snac_model: SNAC):
-     layer1, layer2, layer3 = [], [], []
-     for i in range((len(code_list) + 1) // 7):
-         base = code_list[7*i : 7*i+7]
-         layer1.append(base[0])
-         layer2.append(base[1] - 4096)
-         layer3.append(base[2] - 2*4096)
-         layer3.append(base[3] - 3*4096)
-         layer2.append(base[4] - 4*4096)
-         layer3.append(base[5] - 5*4096)
-         layer3.append(base[6] - 6*4096)
-
-     dev = next(snac_model.parameters()).device
-     codes = [
-         torch.tensor(layer1, device=dev).unsqueeze(0),
-         torch.tensor(layer2, device=dev).unsqueeze(0),
-         torch.tensor(layer3, device=dev).unsqueeze(0),
-     ]
-     audio = snac_model.decode(codes)
-     return audio.detach().squeeze().cpu().numpy()
-
+ def redistribute_codes(codes: list[int], snac_model: SNAC):
+     # same as before
+     # return numpy array

- # — FastAPI App —
  app = FastAPI()

  @app.get("/")
- async def hello():
-     return {"message": "Hello, Orpheus TTS is up and running!"}
+ async def root():
+     return {"status":"ok","msg":"Hello, Orpheus TTS up!"}

  @app.websocket("/ws/tts")
- async def tts_ws(ws: WebSocket):
+ async def ws_tts(ws: WebSocket):
      await ws.accept()
      try:
-         # **Only ONE request per connection**
-         raw = await ws.receive_text()
-         data = json.loads(raw)
-         text = data.get("text", "")
-         voice = data.get("voice", "Jakob")
-
-         # 1) Text → input_ids
-         input_ids = process_prompt(text, voice)
-
-         # 2) Generation
-         gen_ids = model.generate(
-             input_ids=input_ids,
-             max_new_tokens=2000,  # can be raised here
-             do_sample=True,
-             temperature=0.7,
-             top_p=0.95,
+         msg = json.loads(await ws.receive_text())
+         text, voice = msg.get("text",""), msg.get("voice","Jakob")
+         ids = process_prompt(text, voice)
+         gen = model.generate(
+             input_ids=ids,
+             max_new_tokens=2000,
+             do_sample=True, temperature=0.7, top_p=0.95,
              repetition_penalty=1.1,
              eos_token_id=model.config.eos_token_id,
          )
-
-         # 3) Tokens → audio
-         codes = parse_output(gen_ids)
+         codes = parse_output(gen)
          audio_np = redistribute_codes(codes, snac)
-
-         # 4) Stream PCM16 bytes in ~0.1 s chunks
-         pcm16 = (audio_np * 32767).astype("int16").tobytes()
-         chunk_size = 2400 * 2  # 2400 samples @ 24 kHz = 0.1 s, 2 bytes per sample
-         for i in range(0, len(pcm16), chunk_size):
-             await ws.send_bytes(pcm16[i : i+chunk_size])
+         pcm16 = (audio_np*32767).astype("int16").tobytes()
+         chunk = 2400*2
+         for i in range(0,len(pcm16),chunk):
+             await ws.send_bytes(pcm16[i:i+chunk])
              await asyncio.sleep(0.1)
-
-         # Close cleanly; client receives ConnectionClosedOK
          await ws.close()
-
      except WebSocketDisconnect:
-         print("Client disconnected")
+         print("Client left")
      except Exception as e:
-         # Log and close with an error code
-         print("Error in /ws/tts:", e)
+         print("Error in /ws/tts:",e)
          await ws.close(code=1011)
+
+ if __name__=="__main__":
+     import uvicorn
+     uvicorn.run("app:app",host="0.0.0.0",port=7860)
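
Note that in the new file redistribute_codes is committed as a stub whose body is only the two comments ("same as before" / "return numpy array"), so app.py as committed will not run. A minimal sketch of the body that comment refers to, reconstructed from the removed lines; the only adaptations are the parameter name (codes instead of code_list, matching the new signature) and renaming the inner list to layers so it does not shadow that parameter:

def redistribute_codes(codes: list[int], snac_model: SNAC):
    # Split the flat token stream into 7-token frames and undo the
    # per-position offsets (multiples of 4096) to recover the three SNAC layers.
    layer1, layer2, layer3 = [], [], []
    for i in range((len(codes) + 1) // 7):
        base = codes[7*i : 7*i+7]
        layer1.append(base[0])
        layer2.append(base[1] - 4096)
        layer3.append(base[2] - 2*4096)
        layer3.append(base[3] - 3*4096)
        layer2.append(base[4] - 4*4096)
        layer3.append(base[5] - 5*4096)
        layer3.append(base[6] - 6*4096)

    dev = next(snac_model.parameters()).device
    layers = [
        torch.tensor(layer1, device=dev).unsqueeze(0),
        torch.tensor(layer2, device=dev).unsqueeze(0),
        torch.tensor(layer3, device=dev).unsqueeze(0),
    ]
    # Decode the three codebook layers back to a waveform and return it as numpy.
    audio = snac_model.decode(layers)
    return audio.detach().squeeze().cpu().numpy()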
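
For reference, a minimal client sketch for the /ws/tts endpoint. It assumes the server is running locally on port 7860 (as in the new __main__ block) and that the websockets package is installed on the client side; the client script and the output filename are illustrative and not part of this commit. It sends one JSON request per connection, as the handler expects, and writes the received 16-bit mono PCM chunks to a 24 kHz WAV file:

import asyncio, json, wave
import websockets  # assumed client-side dependency, not used by app.py

async def main():
    async with websockets.connect("ws://localhost:7860/ws/tts") as ws:
        # One request per connection, matching the server's protocol
        await ws.send(json.dumps({"text": "Hallo Welt", "voice": "Jakob"}))
        with wave.open("out.wav", "wb") as wav:
            wav.setnchannels(1)      # mono
            wav.setsampwidth(2)      # 16-bit PCM
            wav.setframerate(24000)  # SNAC runs at 24 kHz
            try:
                while True:
                    chunk = await ws.recv()  # raw PCM16 bytes
                    wav.writeframes(chunk)
            except websockets.exceptions.ConnectionClosedOK:
                pass  # server called ws.close() after streaming all chunks

asyncio.run(main())

Each received chunk is at most 4800 bytes (2400 samples at 2 bytes per sample), about 0.1 s of audio, matching the server's chunk size and its 0.1 s sleep between sends.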