Tomtom84 committed on
Commit
9bf14d0
·
verified ·
1 Parent(s): e9a25ed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -142
app.py CHANGED
@@ -3,168 +3,130 @@ import json
3
  import asyncio
4
  import torch
5
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
6
- from fastapi.responses import PlainTextResponse
7
- from dotenv import load_dotenv
8
  from snac import SNAC
9
  from transformers import AutoModelForCausalLM, AutoTokenizer
10
- from peft import PeftModel
11
 
12
- # — ENV & HF‑AUTH
13
- load_dotenv()
14
  HF_TOKEN = os.getenv("HF_TOKEN")
15
  if HF_TOKEN:
16
- # automatisch über huggingface-cli eingeloggt
17
- os.environ["HUGGINGFACE_HUB_TOKEN"] = HF_TOKEN
18
 
19
- # — FastAPI
20
- app = FastAPI()
21
-
22
- @app.get("/")
23
- async def hello():
24
- return PlainTextResponse("Hallo Welt!")
25
-
26
- # — Device konfigurieren —
27
  device = "cuda" if torch.cuda.is_available() else "cpu"
28
 
29
- # — SNAC laden
30
- print("Loading SNAC model…")
31
- snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
32
-
33
- # — Orpheus/Kartoffel‑3B über PEFT laden —
34
- model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
35
- print(f"Loading base LM + PEFT from {model_name}…")
36
- base = AutoModelForCausalLM.from_pretrained(
37
- model_name,
38
- device_map="auto",
39
- torch_dtype=torch.bfloat16,
40
- )
41
- model = PeftModel.from_pretrained(
42
- base,
43
- model_name,
44
- device_map="auto",
45
- )
46
- model.eval()
47
 
48
- tokenizer = AutoTokenizer.from_pretrained(model_name)
49
- # sicherstellen, dass pad_token_id gesetzt ist
50
- model.config.pad_token_id = model.config.eos_token_id
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  # — Hilfsfunktionen —
53
- def prepare_prompt(text: str, voice: str):
54
- """Setzt Start‑ und End‑Marker um den eigentlichen Prompt."""
55
- if voice:
56
- full = f"{voice}: {text}"
57
- else:
58
- full = text
59
- start = torch.tensor([[128259]], dtype=torch.int64) # BOS für Audio
60
- end = torch.tensor([[128009, 128260]], dtype=torch.int64) # ggf. Speaker‑ID + Marker
61
- enc = tokenizer(full, return_tensors="pt").input_ids
62
- seq = torch.cat([start, enc, end], dim=1).to(device)
63
- mask = torch.ones_like(seq).to(device)
64
- return seq, mask
65
-
66
- def extract_audio_tokens(generated: torch.LongTensor):
67
- """Croppe alles bis zum echten Audio-Start, entferne EOS und mache 7er-Batches."""
68
- bos_tok = 128257
69
- eos_tok = 128258
70
-
71
- # letzten Start‑Token finden und ab da weiter
72
- idxs = (generated == bos_tok).nonzero(as_tuple=True)[1]
73
- if idxs.numel() > 0:
74
- cut = idxs[-1].item() + 1
75
- cropped = generated[:, cut:]
76
- else:
77
- cropped = generated
78
-
79
- # EOS‑Marker entfernen
80
- flat = cropped[0][cropped[0] != eos_tok]
81
-
82
- # nur ein Vielfaches von 7 behalten
83
- length = (flat.size(0) // 7) * 7
84
- flat = flat[:length]
85
-
86
- # Die Audio‑Token beginnen ab Offset 128266
87
- return [(t.item() - 128266) for t in flat]
88
-
89
- def decode_and_stream(tokens: list[int], ws: WebSocket):
90
- """Wandelt 7er‑Gruppen in Wave‑Samples um und streamt in 0.1 s Chunks."""
91
- # gruppiere nach 7 und dekodiere jeweils
92
- pcm16 = bytearray()
93
- offset = 0
94
- while offset + 7 <= len(tokens):
95
- block = tokens[offset:offset+7]
96
- offset += 7
97
-
98
- # SNAC‑Input vorbereiten
99
- # Layer‑1: direkt, Layer‑2/3 mit Offsets
100
- l1, l2, l3 = [], [], []
101
- l1.append(block[0])
102
- l2.append(block[1] - 4096)
103
- l3.append(block[2] - 2*4096)
104
- l3.append(block[3] - 3*4096)
105
- l2.append(block[4] - 4*4096)
106
- l3.append(block[5] - 5*4096)
107
- l3.append(block[6] - 6*4096)
108
-
109
- t1 = torch.tensor(l1, device=device).unsqueeze(0)
110
- t2 = torch.tensor(l2, device=device).unsqueeze(0)
111
- t3 = torch.tensor(l3, device=device).unsqueeze(0)
112
- audio = snac.decode([t1, t2, t3]).squeeze().cpu().numpy()
113
-
114
- # in PCM16 @24 kHz
115
- pcm = (audio * 32767).astype("int16").tobytes()
116
- pcm16.extend(pcm)
117
-
118
- # in 0.1 s‑Chunks (2400 Samples ×2 Bytes)
119
- chunk_size = 2400 * 2
120
- for i in range(0, len(pcm16), chunk_size):
121
- ws.send_bytes(pcm16[i : i+chunk_size])
122
- # ohne Pause kann das WebSocket überlastet werden
123
- asyncio.sleep(0.1)
124
-
125
- # — WebSocket TTS Endpoint —
126
  @app.websocket("/ws/tts")
127
  async def tts_ws(ws: WebSocket):
128
  await ws.accept()
129
  try:
 
 
 
 
 
 
 
 
 
 
 
 
130
  while True:
131
- raw = await ws.receive_text()
132
- req = json.loads(raw)
133
- text = req.get("text", "")
134
- voice = req.get("voice", "")
135
-
136
- # Prompt vorbereiten
137
- ids, mask = prepare_prompt(text, voice)
138
-
139
- # Audio‑Token generieren
140
- gen = model.generate(
141
- input_ids=ids,
142
- attention_mask=mask,
143
- max_new_tokens=4000,
144
- do_sample=True,
145
- temperature=0.7,
146
- top_p=0.95,
147
- repetition_penalty=1.1,
148
- eos_token_id=128258,
149
- forced_bos_token_id=128259,
150
  use_cache=True,
151
  )
152
-
153
- codes = extract_audio_tokens(gen)
154
- # stream synchron
155
- await decode_and_stream(codes, ws)
156
-
157
- # sauber schließen
158
- await ws.close(code=1000)
159
- break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
  except WebSocketDisconnect:
162
- print("Client disconnected")
 
163
  except Exception as e:
 
164
  print("Error in /ws/tts:", e)
165
  await ws.close(code=1011)
166
-
167
- # — Lokal starten —
168
- if __name__ == "__main__":
169
- import uvicorn
170
- uvicorn.run("app:app", host="0.0.0.0", port=7860)
 
3
  import asyncio
4
  import torch
5
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
6
+ from huggingface_hub import login
 
7
  from snac import SNAC
8
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
9
 
10
+ # — HF‑Token & Login
 
11
  HF_TOKEN = os.getenv("HF_TOKEN")
12
  if HF_TOKEN:
13
+ login(HF_TOKEN)
 
14
 
15
+ # — Device wählen —
 
 
 
 
 
 
 
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
 
18
+ # — FastAPI instanziieren
19
+ app = FastAPI()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
+ # Hello‑Route, damit kein 404 bei GET / mehr kommt —
22
+ @app.get("/")
23
+ async def read_root():
24
+ return {"message": "Hello, world!"}
25
+
26
+ # — Modelle bei Startup laden —
27
+ @app.on_event("startup")
28
+ async def load_models():
29
+ global tokenizer, model, snac
30
+ # SNAC laden
31
+ snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
32
+ # TTS‑Modell laden
33
+ model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
34
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
35
+ model = AutoModelForCausalLM.from_pretrained(
36
+ model_name,
37
+ device_map={"": 0} if device == "cuda" else None,
38
+ torch_dtype=torch.bfloat16 if device == "cuda" else None,
39
+ low_cpu_mem_usage=True
40
+ )
41
+ # Pad‑ID auf EOS einstellen
42
+ model.config.pad_token_id = model.config.eos_token_id
43
 
44
  # — Hilfsfunktionen —
45
+ def prepare_inputs(text: str, voice: str):
46
+ prompt = f"{voice}: {text}"
47
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
48
+ # Start‑/End‑Marker
49
+ start = torch.tensor([[128259]], dtype=torch.int64, device=device)
50
+ end = torch.tensor([[128009, 128260]], dtype=torch.int64, device=device)
51
+ ids = torch.cat([start, input_ids, end], dim=1)
52
+ mask = torch.ones_like(ids)
53
+ return ids, mask
54
+
55
+ def decode_block(block_tokens: list[int]):
56
+ # aus 7 Tokens einen SNAC‑Decode‑Block bauen
57
+ layer1, layer2, layer3 = [], [], []
58
+ b = block_tokens
59
+ layer1.append(b[0])
60
+ layer2.append(b[1] - 4096)
61
+ layer3.append(b[2] - 2*4096)
62
+ layer3.append(b[3] - 3*4096)
63
+ layer2.append(b[4] - 4*4096)
64
+ layer3.append(b[5] - 5*4096)
65
+ layer3.append(b[6] - 6*4096)
66
+ codes = [
67
+ torch.tensor(layer1, device=device).unsqueeze(0),
68
+ torch.tensor(layer2, device=device).unsqueeze(0),
69
+ torch.tensor(layer3, device=device).unsqueeze(0),
70
+ ]
71
+ # ergibt FloatTensor shape (1, N), @24 kHz
72
+ audio = snac.decode(codes).squeeze().cpu().numpy()
73
+ # in PCM16 umwandeln
74
+ return (audio * 32767).astype("int16").tobytes()
75
+
76
+ # WebSocket Endpoint für TTS Streaming —
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  @app.websocket("/ws/tts")
78
  async def tts_ws(ws: WebSocket):
79
  await ws.accept()
80
  try:
81
+ # erst die Anfrage als JSON empfangen
82
+ msg = await ws.receive_text()
83
+ req = json.loads(msg)
84
+ text = req.get("text", "")
85
+ voice = req.get("voice", "Jakob")
86
+
87
+ # Inputs bauen
88
+ input_ids, attention_mask = prepare_inputs(text, voice)
89
+ past_kvs = None
90
+ collected = []
91
+
92
+ # Token‑für‑Token loop
93
  while True:
94
+ out = model(
95
+ input_ids=input_ids if past_kvs is None else None,
96
+ attention_mask=attention_mask if past_kvs is None else None,
97
+ past_key_values=past_kvs,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  use_cache=True,
99
  )
100
+ logits = out.logits[:, -1, :]
101
+ past_kvs = out.past_key_values
102
+
103
+ # Sampling
104
+ probs = torch.softmax(logits, dim=-1)
105
+ nxt = torch.multinomial(probs, num_samples=1).item()
106
+
107
+ # Ende, wenn EOS
108
+ if nxt == model.config.eos_token_id:
109
+ break
110
+ # Reset bei neuem Start‑Marker
111
+ if nxt == 128257:
112
+ collected = []
113
+ continue
114
+
115
+ # Audio‑Code offsetten und sammeln
116
+ collected.append(nxt - 128266)
117
+ # sobald 7 Stück, direkt dekodieren und senden
118
+ if len(collected) == 7:
119
+ pcm = decode_block(collected)
120
+ collected = []
121
+ await ws.send_bytes(pcm)
122
+
123
+ # nach Ende sauber schließen
124
+ await ws.close()
125
 
126
  except WebSocketDisconnect:
127
+ # Client hat disconnectet
128
+ pass
129
  except Exception as e:
130
+ # bei Fehlern 1011 senden
131
  print("Error in /ws/tts:", e)
132
  await ws.close(code=1011)