Tomtom84 committed on
Commit
c70d8eb
·
verified ·
1 Parent(s): 9bf14d0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -58
app.py CHANGED
@@ -12,18 +12,18 @@ HF_TOKEN = os.getenv("HF_TOKEN")
12
  if HF_TOKEN:
13
  login(HF_TOKEN)
14
 
15
- # — Device wählen
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
 
18
  # — FastAPI instanziieren —
19
  app = FastAPI()
20
 
21
- # — Hello‑Route, damit kein 404 bei GET / mehr kommt
22
  @app.get("/")
23
  async def read_root():
24
  return {"message": "Hello, world!"}
25
 
26
- # — Modelle bei Startup laden —
27
  @app.on_event("startup")
28
  async def load_models():
29
  global tokenizer, model, snac
@@ -34,99 +34,122 @@ async def load_models():
34
  tokenizer = AutoTokenizer.from_pretrained(model_name)
35
  model = AutoModelForCausalLM.from_pretrained(
36
  model_name,
37
- device_map={"": 0} if device == "cuda" else None,
38
- torch_dtype=torch.bfloat16 if device == "cuda" else None,
39
  low_cpu_mem_usage=True
40
- )
41
- # Pad‑ID auf EOS einstellen
42
  model.config.pad_token_id = model.config.eos_token_id
43
 
44
- # — Hilfsfunktionen
45
  def prepare_inputs(text: str, voice: str):
46
  prompt = f"{voice}: {text}"
47
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
48
- # Start‑/End‑Marker
49
  start = torch.tensor([[128259]], dtype=torch.int64, device=device)
50
  end = torch.tensor([[128009, 128260]], dtype=torch.int64, device=device)
51
  ids = torch.cat([start, input_ids, end], dim=1)
52
- mask = torch.ones_like(ids)
53
  return ids, mask
54
 
55
- def decode_block(block_tokens: list[int]):
56
- # aus 7 Tokens einen SNAC‑Decode‑Block bauen
57
- layer1, layer2, layer3 = [], [], []
58
- b = block_tokens
59
- layer1.append(b[0])
60
- layer2.append(b[1] - 4096)
61
- layer3.append(b[2] - 2*4096)
62
- layer3.append(b[3] - 3*4096)
63
- layer2.append(b[4] - 4*4096)
64
- layer3.append(b[5] - 5*4096)
65
- layer3.append(b[6] - 6*4096)
66
  codes = [
67
- torch.tensor(layer1, device=device).unsqueeze(0),
68
- torch.tensor(layer2, device=device).unsqueeze(0),
69
- torch.tensor(layer3, device=device).unsqueeze(0),
70
  ]
71
- # ergibt FloatTensor shape (1, N), @24 kHz
72
  audio = snac.decode(codes).squeeze().cpu().numpy()
73
- # in PCM16 umwandeln
74
  return (audio * 32767).astype("int16").tobytes()
75
 
76
- # — WebSocket Endpoint für TTS Streaming
77
  @app.websocket("/ws/tts")
78
  async def tts_ws(ws: WebSocket):
79
  await ws.accept()
80
  try:
81
- # erst die Anfrage als JSON empfangen
82
  msg = await ws.receive_text()
83
  req = json.loads(msg)
84
  text = req.get("text", "")
85
  voice = req.get("voice", "Jakob")
86
 
87
- # Inputs bauen
88
  input_ids, attention_mask = prepare_inputs(text, voice)
89
  past_kvs = None
90
- collected = []
 
 
 
 
 
 
 
91
 
92
- # Token‑für‑Token loop
93
  while True:
94
- out = model(
95
- input_ids=input_ids if past_kvs is None else None,
96
  attention_mask=attention_mask if past_kvs is None else None,
97
- past_key_values=past_kvs,
 
 
 
 
 
98
  use_cache=True,
 
 
 
99
  )
100
- logits = out.logits[:, -1, :]
101
  past_kvs = out.past_key_values
 
 
102
 
103
- # Sampling
104
- probs = torch.softmax(logits, dim=-1)
105
- nxt = torch.multinomial(probs, num_samples=1).item()
106
 
107
- # Ende, wenn EOS
108
- if nxt == model.config.eos_token_id:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  break
110
- # Reset bei neuem Start‑Marker
111
- if nxt == 128257:
112
- collected = []
113
- continue
114
-
115
- # Audio‑Code offsetten und sammeln
116
- collected.append(nxt - 128266)
117
- # sobald 7 Stück, direkt dekodieren und senden
118
- if len(collected) == 7:
119
- pcm = decode_block(collected)
120
- collected = []
121
- await ws.send_bytes(pcm)
122
-
123
- # nach Ende sauber schließen
124
- await ws.close()
125
 
 
 
 
 
 
126
  except WebSocketDisconnect:
127
- # Client hat disconnectet
128
- pass
129
  except Exception as e:
130
- # bei Fehlern 1011 senden
131
  print("Error in /ws/tts:", e)
132
  await ws.close(code=1011)
 
 
 
 
 
 
12
  if HF_TOKEN:
13
  login(HF_TOKEN)
14
 
15
+ # — Device auswählen
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
 
18
  # — FastAPI instanziieren —
19
  app = FastAPI()
20
 
21
+ # — Hello‑Route, damit GET / nicht 404 gibt
22
  @app.get("/")
23
  async def read_root():
24
  return {"message": "Hello, world!"}
25
 
26
+ # — Modelle beim Startup laden —
27
  @app.on_event("startup")
28
  async def load_models():
29
  global tokenizer, model, snac
 
34
  tokenizer = AutoTokenizer.from_pretrained(model_name)
35
  model = AutoModelForCausalLM.from_pretrained(
36
  model_name,
37
+ device_map="auto" if device=="cuda" else None,
38
+ torch_dtype=torch.bfloat16 if device=="cuda" else None,
39
  low_cpu_mem_usage=True
40
+ ).to(device)
 
41
  model.config.pad_token_id = model.config.eos_token_id
42
 
43
+ # — Input‑Vorbereitung
44
  def prepare_inputs(text: str, voice: str):
45
  prompt = f"{voice}: {text}"
46
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
 
47
  start = torch.tensor([[128259]], dtype=torch.int64, device=device)
48
  end = torch.tensor([[128009, 128260]], dtype=torch.int64, device=device)
49
  ids = torch.cat([start, input_ids, end], dim=1)
50
+ mask = torch.ones_like(ids, device=device)
51
  return ids, mask
52
 
53
+ # SNAC‑Dekodierung eines 7‑Token‑Blocks → PCM16‑Bytes
54
+ def decode_block(tokens: list[int]) -> bytes:
55
+ l1, l2, l3 = [], [], []
56
+ b = tokens
57
+ l1.append(b[0])
58
+ l2.append(b[1]-4096)
59
+ l3.append(b[2]-2*4096)
60
+ l3.append(b[3]-3*4096)
61
+ l2.append(b[4]-4*4096)
62
+ l3.append(b[5]-5*4096)
63
+ l3.append(b[6]-6*4096)
64
  codes = [
65
+ torch.tensor(l1, device=device).unsqueeze(0),
66
+ torch.tensor(l2, device=device).unsqueeze(0),
67
+ torch.tensor(l3, device=device).unsqueeze(0),
68
  ]
 
69
  audio = snac.decode(codes).squeeze().cpu().numpy()
 
70
  return (audio * 32767).astype("int16").tobytes()
71
 
72
+ # — WebSocket‑Endpoint mit Chunked‑Generate (max_new_tokens=50)
73
  @app.websocket("/ws/tts")
74
  async def tts_ws(ws: WebSocket):
75
  await ws.accept()
76
  try:
77
+ # 1) Anfrage einlesen
78
  msg = await ws.receive_text()
79
  req = json.loads(msg)
80
  text = req.get("text", "")
81
  voice = req.get("voice", "Jakob")
82
 
83
+ # 2) Inputs bauen
84
  input_ids, attention_mask = prepare_inputs(text, voice)
85
  past_kvs = None
86
+ buffer_codes: list[int] = []
87
+
88
+ # 3) Chunk‑Generate‑Loop
89
+ chunk_size = 50
90
+ eos_id = model.config.eos_token_id
91
+
92
+ # Wir tracken bisher erzeugte Länge, um abzugrenzen, was neu ist
93
+ prev_len = 0
94
 
 
95
  while True:
96
+ out = model.generate(
97
+ input_ids = input_ids if past_kvs is None else None,
98
  attention_mask=attention_mask if past_kvs is None else None,
99
+ max_new_tokens=chunk_size,
100
+ do_sample=True,
101
+ temperature=0.7,
102
+ top_p=0.95,
103
+ repetition_penalty=1.1,
104
+ eos_token_id=eos_id,
105
  use_cache=True,
106
+ return_dict_in_generate=True,
107
+ output_scores=False,
108
+ past_key_values=past_kvs
109
  )
110
+ # Update past_kvs und sequences
111
  past_kvs = out.past_key_values
112
+ seqs = out.sequences # (1, total_length)
113
+ total_len = seqs.shape[1]
114
 
115
+ # 4) Neue Tokens extrahieren
116
+ new_tokens = seqs[0, prev_len:total_len].tolist()
117
+ prev_len = total_len
118
 
119
+ # 5) Jeden neuen Token aufbereiten
120
+ for tok in new_tokens:
121
+ if tok == eos_id:
122
+ # Ende
123
+ new_tokens = [] # clean up
124
+ break
125
+ if tok == 128257:
126
+ buffer_codes.clear()
127
+ continue
128
+ # offset und puffern
129
+ buffer_codes.append(tok - 128266)
130
+ # sobald 7 Codes gesammelt, dekodieren & senden
131
+ if len(buffer_codes) >= 7:
132
+ block = buffer_codes[:7]
133
+ buffer_codes = buffer_codes[7:]
134
+ pcm = decode_block(block)
135
+ await ws.send_bytes(pcm)
136
+
137
+ # 6) Abbruch, wenn EOS im Chunk war
138
+ if eos_id in new_tokens:
139
  break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
+ # Inputs nur im ersten Durchgang übergeben; danach auf None setzen
142
+ input_ids = attention_mask = None
143
+
144
+ # 7) Zum Schluss sauber schließen
145
+ await ws.close()
146
  except WebSocketDisconnect:
147
+ return
 
148
  except Exception as e:
 
149
  print("Error in /ws/tts:", e)
150
  await ws.close(code=1011)
151
+
152
+ # — Main für lokalen Test —
153
+ if __name__ == "__main__":
154
+ import uvicorn
155
+ uvicorn.run("app:app", host="0.0.0.0", port=7860)