Tomtom84 committed on
Commit
7b0d42c
·
verified ·
1 Parent(s): d9ea17d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -28
app.py CHANGED
@@ -7,12 +7,12 @@ from huggingface_hub import login
7
  from transformers import AutoModelForCausalLM, AutoTokenizer
8
  from snac import SNAC
9
 
10
- # — HF‑Token & Login (wenn gesetzt) —
11
  HF_TOKEN = os.getenv("HF_TOKEN")
12
  if HF_TOKEN:
13
  login(HF_TOKEN)
14
 
15
- # — Device wählen
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
 
18
  app = FastAPI()
@@ -26,39 +26,38 @@ model = None
26
  tokenizer = None
27
  snac_model = None
28
 
29
- # — Startup: SNAC & Orpheus laden —
30
  @app.on_event("startup")
31
  async def load_models():
32
  global model, tokenizer, snac_model
33
- # 1) SNAC
 
34
  snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
35
- # 2) Orpheus‑TTS
36
- REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_synthetic-v0.1"
 
37
  tokenizer = AutoTokenizer.from_pretrained(REPO)
38
  model = AutoModelForCausalLM.from_pretrained(
39
  REPO,
40
- device_map="auto" if device=="cuda" else None,
41
- torch_dtype=torch.bfloat16 if device=="cuda" else None,
42
  low_cpu_mem_usage=True
43
  ).to(device)
44
  model.config.pad_token_id = model.config.eos_token_id
45
 
46
- # — Marker und Offsets aus der Vorlage
47
  START_TOKEN = 128259
48
  END_TOKENS = [128009, 128260]
49
  AUDIO_OFFSET = 128266
50
 
51
  def process_single_prompt(prompt: str, voice: str) -> list[int]:
52
- # Prompt zusammenbauen
53
- if voice and voice != "in_prompt":
54
- text = f"{voice}: {prompt}"
55
- else:
56
- text = prompt
57
  # Tokenize + Marker
58
- ids = tokenizer(text, return_tensors="pt").input_ids
59
- start = torch.tensor([[START_TOKEN]], dtype=torch.int64)
60
- end = torch.tensor([END_TOKENS], dtype=torch.int64)
61
- input_ids = torch.cat([start, ids, end], dim=1).to(device)
62
  attention_mask = torch.ones_like(input_ids)
63
 
64
  # Generieren
@@ -74,8 +73,8 @@ def process_single_prompt(prompt: str, voice: str) -> list[int]:
74
  use_cache=True,
75
  )
76
 
77
- # letzten START_TOKEN finden & croppen
78
- token_to_find = 128257
79
  token_to_remove = 128258
80
  idxs = (gen == token_to_find).nonzero(as_tuple=True)[1]
81
  if idxs.numel() > 0:
@@ -83,16 +82,16 @@ def process_single_prompt(prompt: str, voice: str) -> list[int]:
83
  else:
84
  cropped = gen
85
 
86
- # Padding entfernen
87
  row = cropped[0][cropped[0] != token_to_remove]
88
- # Aus Länge ein Vielfaches von 7 machen
89
  new_len = (row.size(0) // 7) * 7
90
  trimmed = row[:new_len].tolist()
 
91
  # Offset abziehen
92
  return [t - AUDIO_OFFSET for t in trimmed]
93
 
94
  def redistribute_codes(code_list: list[int]) -> np.ndarray:
95
- # Die 7er‑Blöcke auf 3 Layer verteilen und dekodieren
96
  layer1, layer2, layer3 = [], [], []
97
  for i in range(len(code_list) // 7):
98
  b = code_list[7*i : 7*i+7]
@@ -112,27 +111,25 @@ def redistribute_codes(code_list: list[int]) -> np.ndarray:
112
  audio = snac_model.decode(codes).squeeze().cpu().numpy()
113
  return audio # float32 @24 kHz
114
 
115
- # — WebSocket‑Endpoint für TTS —
116
  @app.websocket("/ws/tts")
117
  async def tts_ws(ws: WebSocket):
118
  await ws.accept()
119
  try:
120
- # 1) Text + Voice empfangen
121
  msg = await ws.receive_text()
122
  req = json.loads(msg)
123
  text = req.get("text", "")
124
  voice = req.get("voice", "")
125
 
126
- # 2) Prompt → Code‑Liste
127
  with torch.no_grad():
128
  codes = process_single_prompt(text, voice)
129
  audio_np = redistribute_codes(codes)
130
 
131
- # 3) In PCM16 konvertieren und senden
132
  pcm16 = (audio_np * 32767).astype(np.int16).tobytes()
133
  await ws.send_bytes(pcm16)
134
 
135
- # 4) sauber schließen
136
  await ws.close()
137
 
138
  except WebSocketDisconnect:
 
7
  from transformers import AutoModelForCausalLM, AutoTokenizer
8
  from snac import SNAC
9
 
10
+ # — HF‑Token & Login (falls gesetzt) —
11
  HF_TOKEN = os.getenv("HF_TOKEN")
12
  if HF_TOKEN:
13
  login(HF_TOKEN)
14
 
15
+ # — Device auswählen
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
 
18
  app = FastAPI()
 
26
  tokenizer = None
27
  snac_model = None
28
 
 
29
  @app.on_event("startup")
30
  async def load_models():
31
  global model, tokenizer, snac_model
32
+
33
+ # 1) SNAC laden
34
  snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
35
+
36
+ # 2) Orpheus‑TTS (public “natural”-Variante)
37
+ REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
38
  tokenizer = AutoTokenizer.from_pretrained(REPO)
39
  model = AutoModelForCausalLM.from_pretrained(
40
  REPO,
41
+ device_map="auto" if device == "cuda" else None,
42
+ torch_dtype=torch.bfloat16 if device == "cuda" else None,
43
  low_cpu_mem_usage=True
44
  ).to(device)
45
  model.config.pad_token_id = model.config.eos_token_id
46
 
47
+ # — Marker und Offsets —
48
  START_TOKEN = 128259
49
  END_TOKENS = [128009, 128260]
50
  AUDIO_OFFSET = 128266
51
 
52
  def process_single_prompt(prompt: str, voice: str) -> list[int]:
53
+ # Prompt zusammenstellen
54
+ text = f"{voice}: {prompt}" if voice and voice != "in_prompt" else prompt
55
+
 
 
56
  # Tokenize + Marker
57
+ ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
58
+ start = torch.tensor([[START_TOKEN]], dtype=torch.int64, device=device)
59
+ end = torch.tensor([END_TOKENS], dtype=torch.int64, device=device)
60
+ input_ids = torch.cat([start, ids, end], dim=1)
61
  attention_mask = torch.ones_like(input_ids)
62
 
63
  # Generieren
 
73
  use_cache=True,
74
  )
75
 
76
+ # Nach letztem START_TOKEN croppen
77
+ token_to_find = 128257
78
  token_to_remove = 128258
79
  idxs = (gen == token_to_find).nonzero(as_tuple=True)[1]
80
  if idxs.numel() > 0:
 
82
  else:
83
  cropped = gen
84
 
85
+ # Padding entfernen & Länge auf Vielfaches von 7 bringen
86
  row = cropped[0][cropped[0] != token_to_remove]
 
87
  new_len = (row.size(0) // 7) * 7
88
  trimmed = row[:new_len].tolist()
89
+
90
  # Offset abziehen
91
  return [t - AUDIO_OFFSET for t in trimmed]
92
 
93
  def redistribute_codes(code_list: list[int]) -> np.ndarray:
94
+ # 7er‑Blöcke auf 3 Layer verteilen
95
  layer1, layer2, layer3 = [], [], []
96
  for i in range(len(code_list) // 7):
97
  b = code_list[7*i : 7*i+7]
 
111
  audio = snac_model.decode(codes).squeeze().cpu().numpy()
112
  return audio # float32 @24 kHz
113
 
 
114
  @app.websocket("/ws/tts")
115
  async def tts_ws(ws: WebSocket):
116
  await ws.accept()
117
  try:
 
118
  msg = await ws.receive_text()
119
  req = json.loads(msg)
120
  text = req.get("text", "")
121
  voice = req.get("voice", "")
122
 
123
+ # 1) Prompt → Codes → Audio
124
  with torch.no_grad():
125
  codes = process_single_prompt(text, voice)
126
  audio_np = redistribute_codes(codes)
127
 
128
+ # 2) In PCM16 wandeln & senden
129
  pcm16 = (audio_np * 32767).astype(np.int16).tobytes()
130
  await ws.send_bytes(pcm16)
131
 
132
+ # 3) sauber schließen
133
  await ws.close()
134
 
135
  except WebSocketDisconnect: