Tomtom84 committed on
Commit 0dfc310 · verified · 1 Parent(s): a09ea48

Update app.py

Files changed (1): app.py +79 -49
app.py CHANGED
@@ -6,21 +6,34 @@ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
 from dotenv import load_dotenv
 from snac import SNAC
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from huggingface_hub import login
+from huggingface_hub import login, snapshot_download
 
-# — Environment & HF auth
+# — ENV & HF AUTH
 load_dotenv()
 HF_TOKEN = os.getenv("HF_TOKEN")
 if HF_TOKEN:
     login(token=HF_TOKEN)
 
-# — Device & load models
+# — Debug: CPU mode for development, switch back to "cuda" later
 device = "cuda" if torch.cuda.is_available() else "cpu"
+# device = "cpu"
 
+# — Load models —
 print("Loading SNAC model...")
 snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
 
 model_name = "canopylabs/3b-de-ft-research_release"
+# optional: explicit snapshot_download (excludes the large files)
+snapshot_download(
+    repo_id=model_name,
+    allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json"],
+    ignore_patterns=[
+        "optimizer.pt", "pytorch_model.bin", "training_args.bin",
+        "scheduler.pt", "tokenizer.json", "tokenizer_config.json",
+        "special_tokens_map.json", "vocab.json", "merges.txt", "tokenizer.*"
+    ]
+)
+
 print("Loading Orpheus model...")
 model = AutoModelForCausalLM.from_pretrained(
     model_name, torch_dtype=torch.bfloat16
@@ -29,49 +42,61 @@ model.config.pad_token_id = model.config.eos_token_id
 
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 
-# — Helper functions —
+# — Helper functions —
+
 def process_prompt(text: str, voice: str):
     prompt = f"{voice}: {text}"
     input_ids = tokenizer(prompt, return_tensors="pt").input_ids
     start = torch.tensor([[128259]], dtype=torch.int64)
-    end = torch.tensor([[128009, 128260]], dtype=torch.int64)
-    ids = torch.cat([start, input_ids, end], dim=1).to(device)
+    end = torch.tensor([[128009, 128260]], dtype=torch.int64)
+    ids = torch.cat([start, input_ids, end], dim=1).to(device)
     mask = torch.ones_like(ids).to(device)
     return ids, mask
 
 def parse_output(generated_ids: torch.LongTensor):
-    token_to_find = 128257
+    """Extract the raw token list after the last 128257 start token."""
+    token_to_find = 128257
     token_to_remove = 128258
+
+    # 1) Find the last start token and crop after it
    idxs = (generated_ids == token_to_find).nonzero(as_tuple=True)[1]
    if idxs.numel() > 0:
-        last = idxs[-1].item()
-        cropped = generated_ids[:, last+1:]
+        cut = idxs[-1].item() + 1
+        cropped = generated_ids[:, cut:]
    else:
        cropped = generated_ids
-    # remove padding token markers
+
+    # 2) Remove padding markers
    rows = []
    for row in cropped:
-        row = row[row != token_to_remove]
-        rows.append(row)
-    flat = rows[0].tolist()
-    # adjust and regroup
+        rows.append(row[row != token_to_remove])
+
+    # 3) Return the flat list
+    return rows[0].tolist()
+
+def redistribute_codes(code_list: list[int], snac_model: SNAC):
+    """Distribute the codes across three layers and decode to audio."""
    layer1, layer2, layer3 = [], [], []
-    for i in range(len(flat)//7):
-        base = flat[7*i:7*i+7]
+    for i in range((len(code_list) + 1) // 7):
+        base = code_list[7*i : 7*i+7]
        layer1.append(base[0])
-        layer2.append(base[1]-4096)
-        layer3.extend([base[2]-(2*4096), base[3]-(3*4096)])
-        layer2.append(base[4]-4*4096)
-        layer3.extend([base[5]-(5*4096), base[6]-(6*4096)])
+        layer2.append(base[1] - 4096)
+        layer3.append(base[2] - 2*4096)
+        layer3.append(base[3] - 3*4096)
+        layer2.append(base[4] - 4*4096)
+        layer3.append(base[5] - 5*4096)
+        layer3.append(base[6] - 6*4096)
+
+    dev = next(snac_model.parameters()).device
    codes = [
-        torch.tensor(layer1, device=device).unsqueeze(0),
-        torch.tensor(layer2, device=device).unsqueeze(0),
-        torch.tensor(layer3, device=device).unsqueeze(0),
+        torch.tensor(layer1, device=dev).unsqueeze(0),
+        torch.tensor(layer2, device=dev).unsqueeze(0),
+        torch.tensor(layer3, device=dev).unsqueeze(0),
    ]
-    audio = snac.decode(codes).detach().squeeze().cpu().numpy()
-    return audio  # float32 numpy at 24000 Hz
+    audio = snac_model.decode(codes)
+    return audio.detach().squeeze().cpu().numpy()  # float32 @ 24 kHz
 
-# — FastAPI + WebSocket endpoint —
+# — FastAPI + WebSocket endpoint —
 app = FastAPI()
 
 @app.websocket("/ws/tts")
@@ -80,31 +105,36 @@ async def tts_ws(ws: WebSocket):
    try:
        while True:
            msg = await ws.receive_text()
-            data = json.loads(msg)
-            text = data.get("text", "")
+            data = json.loads(msg)
+            text = data.get("text", "")
            voice = data.get("voice", "jana")
-            # Generate tokens
+
+            # 1) Prompt → tokens
            ids, mask = process_prompt(text, voice)
-            with torch.no_grad():
-                gen_ids = model.generate(
-                    input_ids=ids,
-                    attention_mask=mask,
-                    max_new_tokens=1200,
-                    do_sample=True,
-                    temperature=0.7,
-                    top_p=0.95,
-                    repetition_penalty=1.1,
-                    eos_token_id=128258,
-                )
-            # Convert to waveform
-            audio = parse_output(gen_ids)
-            # PCM16 conversion & chunking
-            pcm16 = (audio * 32767).astype('int16').tobytes()
-            # 0.1 s @24 kHz = 2400 samples = 4800 bytes
-            chunk_size = 2400 * 2
-            for i in range(0, len(pcm16), chunk_size):
-                await ws.send_bytes(pcm16[i:i+chunk_size])
-                await asyncio.sleep(0.1)  # pacing
+
+            # 2) Token generation (test small first!)
+            gen_ids = model.generate(
+                input_ids=ids,
+                attention_mask=mask,
+                max_new_tokens=200,  # keep small for debugging
+                do_sample=True,
+                temperature=0.7,
+                top_p=0.95,
+                repetition_penalty=1.1,
+                eos_token_id=128258,
+            )
+
+            # 3) Token code list → audio
+            code_list = parse_output(gen_ids)
+            audio_np = redistribute_codes(code_list, snac)
+
+            # 4) PCM16 bytes, streamed in 0.1 s chunks
+            pcm16 = (audio_np * 32767).astype("int16").tobytes()
+            chunk = 2400 * 2  # 2400 samples @ 24 kHz (0.1 s), 2 bytes each
+            for i in range(0, len(pcm16), chunk):
+                await ws.send_bytes(pcm16[i : i+chunk])
+                await asyncio.sleep(0.1)
+
    except WebSocketDisconnect:
        print("Client disconnected")
    except Exception as e:
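
A note on the new redistribute_codes helper: it assumes Orpheus emits SNAC codes as flat 7-token frames, one code for the first (coarsest) codebook, two for the second, and four for the third, with each slot offset by n*4096 so the seven positions occupy disjoint token ranges. A worked single-frame example (the token values here are hypothetical):

# One hypothetical 7-token frame as emitted by the model:
frame = [117, 4296, 8300, 12391, 16684, 20580, 24712]

# Subtracting each slot's offset n*4096 maps every code back into [0, 4096),
# exactly as redistribute_codes does:
layer1 = [frame[0]]                              # -> [117]
layer2 = [frame[1] - 4096, frame[4] - 4*4096]    # -> [200, 300]
layer3 = [frame[2] - 2*4096, frame[3] - 3*4096,
          frame[5] - 5*4096, frame[6] - 6*4096]  # -> [108, 103, 100, 136]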
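
For reference, the wire protocol this endpoint now implements: the client sends one JSON text frame ({"text": ..., "voice": ...}) and receives the synthesized utterance as raw mono PCM16 at 24 kHz, paced in 0.1 s binary chunks. A minimal client sketch, assuming the third-party websockets package; the host and port, the idle timeout used to detect end of stream (the server sends no explicit end marker), and the output filename are all assumptions for illustration:

import asyncio
import json
import wave

import websockets  # pip install websockets

async def main():
    uri = "ws://localhost:7860/ws/tts"  # placeholder host/port
    pcm = bytearray()
    async with websockets.connect(uri) as ws:
        await ws.send(json.dumps({"text": "Hallo Welt", "voice": "jana"}))
        try:
            while True:
                # the server streams raw PCM16 chunks, ~0.1 s each at 24 kHz
                pcm.extend(await asyncio.wait_for(ws.recv(), timeout=5.0))
        except asyncio.TimeoutError:
            pass  # no chunk for 5 s: assume the utterance is complete

    # wrap the raw samples in a WAV container: mono, 16-bit, 24 kHz
    with wave.open("out.wav", "wb") as f:
        f.setnchannels(1)
        f.setsampwidth(2)
        f.setframerate(24000)
        f.writeframes(bytes(pcm))

asyncio.run(main())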