Tomtom84 committed
Commit a09ea48 · verified · 1 Parent(s): 97006e1

Update app.py
Files changed (1): app.py (+84 -245)
app.py CHANGED
@@ -1,277 +1,116 @@
-import spaces
-from snac import SNAC
 import torch
 from fastapi import FastAPI, WebSocket, WebSocketDisconnect
-import gradio as gr
-import os
-from transformers import AutoModelForCausalLM, AutoTokenizer
-from huggingface_hub import snapshot_download
 from dotenv import load_dotenv
-load_dotenv()
-
-# Check if HF_TOKEN is available
-token = os.getenv("HF_TOKEN")
-if token:
-    from huggingface_hub import login
-    login(token=token)
-else:
-    print("⚠️ No HF_TOKEN found – gated model will fail.")
-
-# Check if CUDA is available
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 print("Loading SNAC model...")
-snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
-snac_model = snac_model.to(device)
 
 model_name = "canopylabs/3b-de-ft-research_release"
-#"canopylabs/orpheus-3b-0.1-ft"
-
-# Download only model config and safetensors
-snapshot_download(
-    repo_id=model_name,
-    allow_patterns=[
-        "config.json",
-        "*.safetensors",
-        "model.safetensors.index.json",
-    ],
-    ignore_patterns=[
-        "optimizer.pt",
-        "pytorch_model.bin",
-        "training_args.bin",
-        "scheduler.pt",
-        "tokenizer.json",
-        "tokenizer_config.json",
-        "special_tokens_map.json",
-        "vocab.json",
-        "merges.txt",
-        "tokenizer.*"
-    ]
-)
-
-model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
-model.to(device)
 tokenizer = AutoTokenizer.from_pretrained(model_name)
-print(f"Orpheus model loaded to {device}")
-
-# Process text prompt
-def process_prompt(prompt, voice, tokenizer, device):
-    prompt = f"{voice}: {prompt}"
     input_ids = tokenizer(prompt, return_tensors="pt").input_ids
-
-    start_token = torch.tensor([[128259]], dtype=torch.int64)  # Start of human
-    end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)  # End of text, End of human
-
-    modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)  # SOH SOT Text EOT EOH
-
-    # No padding needed for single input
-    attention_mask = torch.ones_like(modified_input_ids)
-
-    return modified_input_ids.to(device), attention_mask.to(device)
 
-# Parse output tokens to audio
-def parse_output(generated_ids):
     token_to_find = 128257
     token_to_remove = 128258
-
-    token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)
-
-    if len(token_indices[1]) > 0:
-        last_occurrence_idx = token_indices[1][-1].item()
-        cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
     else:
-        cropped_tensor = generated_ids
-
-    processed_rows = []
-    for row in cropped_tensor:
-        masked_row = row[row != token_to_remove]
-        processed_rows.append(masked_row)
-
-    code_lists = []
-    for row in processed_rows:
-        row_length = row.size(0)
-        new_length = (row_length // 7) * 7
-        trimmed_row = row[:new_length]
-        trimmed_row = [t - 128266 for t in trimmed_row]
-        code_lists.append(trimmed_row)
-
-    return code_lists[0]  # Return just the first one for single sample
-
-# Redistribute codes for audio generation
-def redistribute_codes(code_list, snac_model):
-    device = next(snac_model.parameters()).device  # Get the device of SNAC model
-
-    layer_1 = []
-    layer_2 = []
-    layer_3 = []
-    for i in range((len(code_list)+1)//7):
-        layer_1.append(code_list[7*i])
-        layer_2.append(code_list[7*i+1]-4096)
-        layer_3.append(code_list[7*i+2]-(2*4096))
-        layer_3.append(code_list[7*i+3]-(3*4096))
-        layer_2.append(code_list[7*i+4]-(4*4096))
-        layer_3.append(code_list[7*i+5]-(5*4096))
-        layer_3.append(code_list[7*i+6]-(6*4096))
-
-    # Move tensors to the same device as the SNAC model
     codes = [
-        torch.tensor(layer_1, device=device).unsqueeze(0),
-        torch.tensor(layer_2, device=device).unsqueeze(0),
-        torch.tensor(layer_3, device=device).unsqueeze(0)
     ]
-
-    audio_hat = snac_model.decode(codes)
-    return audio_hat.detach().squeeze().cpu().numpy()  # Always return CPU numpy array
-
-# Main generation function
-@spaces.GPU()
-def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new_tokens, progress=gr.Progress()):
-    if not text.strip():
-        return None
-
-    try:
-        progress(0.1, "Processing text...")
-        input_ids, attention_mask = process_prompt(text, voice, tokenizer, device)
-
-        progress(0.3, "Generating speech tokens...")
-        with torch.no_grad():
-            generated_ids = model.generate(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                max_new_tokens=max_new_tokens,
-                do_sample=True,
-                temperature=temperature,
-                top_p=top_p,
-                repetition_penalty=repetition_penalty,
-                num_return_sequences=1,
-                eos_token_id=128258,
-            )
-
-        progress(0.6, "Processing speech tokens...")
-        code_list = parse_output(generated_ids)
-
-        progress(0.8, "Converting to audio...")
-        audio_samples = redistribute_codes(code_list, snac_model)
-
-        return (24000, audio_samples)  # Return sample rate and audio
-    except Exception as e:
-        print(f"Error generating speech: {e}")
-        return None
-
-# Examples for the UI
-examples = [
-    ["Hey there my name is Tara, <chuckle> and I'm a speech generation model that can sound like a person.", "tara", 0.6, 0.95, 1.1, 1200],
-    ["I've also been taught to understand and produce paralinguistic things <sigh> like sighing, or <laugh> laughing, or <yawn> yawning!", "dan", 0.7, 0.95, 1.1, 1200],
-    ["I live in San Francisco, and have, uhm let's see, 3 billion 7 hundred ... <gasp> well, lets just say a lot of parameters.", "leah", 0.6, 0.9, 1.2, 1200],
-    ["Sometimes when I talk too much, I need to <cough> excuse myself. <sniffle> The weather has been quite cold lately.", "leo", 0.65, 0.9, 1.1, 1200],
-    ["Public speaking can be challenging. <groan> But with enough practice, anyone can become better at it.", "jess", 0.7, 0.95, 1.1, 1200],
-    ["The hike was exhausting but the view from the top was absolutely breathtaking! <sigh> It was totally worth it.", "mia", 0.65, 0.9, 1.15, 1200],
-    ["Did you hear that joke? <laugh> I couldn't stop laughing when I first heard it. <chuckle> It's still funny.", "zac", 0.7, 0.95, 1.1, 1200],
-    ["After running the marathon, I was so tired <yawn> and needed a long rest. <sigh> But I felt accomplished.", "zoe", 0.6, 0.95, 1.1, 1200]
-]
-
-# Available voices
-VOICES = ["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe"]
-
-# Available Emotive Tags
-EMOTIVE_TAGS = ["`<laugh>`", "`<chuckle>`", "`<sigh>`", "`<cough>`", "`<sniffle>`", "`<groan>`", "`<yawn>`", "`<gasp>`"]
-
-# Create Gradio interface
-with gr.Blocks(title="Orpheus Text-to-Speech") as demo:
-    gr.Markdown(f"""
-    # 🎵 [Orpheus Text-to-Speech](https://github.com/canopyai/Orpheus-TTS)
-    Enter your text below and hear it converted to natural-sounding speech with the Orpheus TTS model.
-
-    ## Tips for better prompts:
-    - Add paralinguistic elements like {", ".join(EMOTIVE_TAGS)} or `uhm` for more human-like speech.
-    - Longer text prompts generally work better than very short phrases
-    - Increasing `repetition_penalty` and `temperature` makes the model speak faster.
-    """)
-    with gr.Row():
-        with gr.Column(scale=3):
-            text_input = gr.Textbox(
-                label="Text to speak",
-                placeholder="Enter your text here...",
-                lines=5
-            )
-            voice = gr.Dropdown(
-                choices=VOICES,
-                value="tara",
-                label="Voice"
-            )
-
-            with gr.Accordion("Advanced Settings", open=False):
-                temperature = gr.Slider(
-                    minimum=0.1, maximum=1.5, value=0.6, step=0.05,
-                    label="Temperature",
-                    info="Higher values (0.7-1.0) create more expressive but less stable speech"
-                )
-                top_p = gr.Slider(
-                    minimum=0.1, maximum=1.0, value=0.95, step=0.05,
-                    label="Top P",
-                    info="Nucleus sampling threshold"
-                )
-                repetition_penalty = gr.Slider(
-                    minimum=1.0, maximum=2.0, value=1.1, step=0.05,
-                    label="Repetition Penalty",
-                    info="Higher values discourage repetitive patterns"
-                )
-                max_new_tokens = gr.Slider(
-                    minimum=100, maximum=2000, value=1200, step=100,
-                    label="Max Length",
-                    info="Maximum length of generated audio (in tokens)"
-                )
-
-            with gr.Row():
-                submit_btn = gr.Button("Generate Speech", variant="primary")
-                clear_btn = gr.Button("Clear")
-
-        with gr.Column(scale=2):
-            audio_output = gr.Audio(label="Generated Speech", type="numpy")
-
-    # Set up event handlers
-    submit_btn.click(
-        fn=generate_speech,
-        inputs=[text_input, voice, temperature, top_p, repetition_penalty, max_new_tokens],
-        outputs=audio_output
-    )
-
-    clear_btn.click(
-        fn=lambda: (None, None),
-        inputs=[],
-        outputs=[text_input, audio_output]
-    )
-
-# Enable queuing for Gradio
-# Create FastAPI app and mount Gradio ASGI app (HTTP mode)
 
 app = FastAPI()
-app.mount("/", demo.app)
 
-# WebSocket TTS endpoint
 @app.websocket("/ws/tts")
-async def websocket_tts(websocket: WebSocket):
-    await websocket.accept()
     try:
         while True:
-            msg = await websocket.receive_text()
             data = json.loads(msg)
             text = data.get("text", "")
-            voice = data.get("voice", VOICES[0])
-            _, audio = generate_speech(text, voice, 0.7, 0.95, 1.1, 1200)
-            chunk_size = 2400  # 0.1s at 24kHz
-            for i in range(0, len(audio), chunk_size):
-                chunk = audio[i:i+chunk_size]
-                await websocket.send_bytes(chunk.tobytes())
-            await websocket.send_text("__END__")
     except WebSocketDisconnect:
-        print("Client disconnected from /ws/tts")
 
-# Launch when run directly
-def main():
     import uvicorn
     uvicorn.run("app:app", host="0.0.0.0", port=7860)
-
-if __name__ == "__main__":
-    main()
+import os
+import json
+import asyncio
 import torch
 from fastapi import FastAPI, WebSocket, WebSocketDisconnect
 from dotenv import load_dotenv
+from snac import SNAC
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from huggingface_hub import login
 
+# Environment & HF auth
+load_dotenv()
+HF_TOKEN = os.getenv("HF_TOKEN")
+if HF_TOKEN:
+    login(token=HF_TOKEN)
 
+# Device & model loading
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 print("Loading SNAC model...")
+snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
 
 model_name = "canopylabs/3b-de-ft-research_release"
+print("Loading Orpheus model...")
+model = AutoModelForCausalLM.from_pretrained(
+    model_name, torch_dtype=torch.bfloat16
+).to(device)
+model.config.pad_token_id = model.config.eos_token_id
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 
+# Helper functions
+def process_prompt(text: str, voice: str):
+    prompt = f"{voice}: {text}"
     input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+    start = torch.tensor([[128259]], dtype=torch.int64)  # start of human
+    end = torch.tensor([[128009, 128260]], dtype=torch.int64)  # end of text, end of human
+    ids = torch.cat([start, input_ids, end], dim=1).to(device)
+    mask = torch.ones_like(ids).to(device)
+    return ids, mask
 
+def parse_output(generated_ids: torch.LongTensor):
     token_to_find = 128257
     token_to_remove = 128258
+    idxs = (generated_ids == token_to_find).nonzero(as_tuple=True)[1]
+    if idxs.numel() > 0:
+        last = idxs[-1].item()
+        cropped = generated_ids[:, last+1:]
     else:
+        cropped = generated_ids
+    # drop the eos/padding marker token (128258)
+    rows = []
+    for row in cropped:
+        row = row[row != token_to_remove]
+        rows.append(row)
+    flat = rows[0].tolist()
+    # undo the per-position offsets and regroup into the three SNAC layers
+    layer1, layer2, layer3 = [], [], []
+    for i in range(len(flat)//7):
+        base = flat[7*i:7*i+7]
+        layer1.append(base[0])
+        layer2.append(base[1] - 4096)
+        layer3.extend([base[2] - (2*4096), base[3] - (3*4096)])
+        layer2.append(base[4] - (4*4096))
+        layer3.extend([base[5] - (5*4096), base[6] - (6*4096)])
     codes = [
+        torch.tensor(layer1, device=device).unsqueeze(0),
+        torch.tensor(layer2, device=device).unsqueeze(0),
+        torch.tensor(layer3, device=device).unsqueeze(0),
     ]
+    audio = snac.decode(codes).detach().squeeze().cpu().numpy()
+    return audio  # float32 numpy at 24000 Hz
 
+# FastAPI + WebSocket endpoint
 app = FastAPI()
 
 @app.websocket("/ws/tts")
+async def tts_ws(ws: WebSocket):
+    await ws.accept()
     try:
         while True:
+            msg = await ws.receive_text()
             data = json.loads(msg)
             text = data.get("text", "")
+            voice = data.get("voice", "jana")
+            # Generate speech tokens
+            ids, mask = process_prompt(text, voice)
+            with torch.no_grad():
+                gen_ids = model.generate(
+                    input_ids=ids,
+                    attention_mask=mask,
+                    max_new_tokens=1200,
+                    do_sample=True,
+                    temperature=0.7,
+                    top_p=0.95,
+                    repetition_penalty=1.1,
+                    eos_token_id=128258,
+                )
+            # Convert to waveform
+            audio = parse_output(gen_ids)
+            # PCM16 conversion & chunking
+            pcm16 = (audio * 32767).astype('int16').tobytes()
+            # 0.1 s @ 24 kHz = 2400 samples = 4800 bytes
+            chunk_size = 2400 * 2
+            for i in range(0, len(pcm16), chunk_size):
+                await ws.send_bytes(pcm16[i:i+chunk_size])
+                await asyncio.sleep(0.1)  # pace chunks to roughly real time
     except WebSocketDisconnect:
+        print("Client disconnected")
+    except Exception as e:
+        print("Error in /ws/tts:", e)
+        await ws.close(code=1011)
 
+if __name__ == "__main__":
     import uvicorn
     uvicorn.run("app:app", host="0.0.0.0", port=7860)
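
For reference, a minimal client sketch for the reworked /ws/tts protocol (not part of the commit; the example text, voice, output path, and timeout are illustrative assumptions). It sends one JSON request, collects the raw PCM16 chunks the handler streams back, and writes them to a WAV file. It assumes the `websockets` package and a server on localhost:7860; since this version of the handler no longer sends the old `__END__` sentinel, the sketch simply stops after a short receive timeout.

```python
import asyncio
import json
import wave

import websockets  # pip install websockets

async def main():
    pcm = bytearray()
    async with websockets.connect("ws://localhost:7860/ws/tts") as ws:
        # One utterance per JSON message, matching the handler above.
        await ws.send(json.dumps({"text": "Hallo Welt!", "voice": "jana"}))
        try:
            while True:
                # The server streams raw PCM16 chunks (~0.1 s each) with no
                # end marker in this version, so stop on a receive timeout.
                pcm.extend(await asyncio.wait_for(ws.recv(), timeout=2.0))
        except (asyncio.TimeoutError, websockets.ConnectionClosed):
            pass
    with wave.open("out.wav", "wb") as f:
        f.setnchannels(1)      # mono
        f.setsampwidth(2)      # 16-bit samples
        f.setframerate(24000)  # SNAC decodes at 24 kHz
        f.writeframes(bytes(pcm))

asyncio.run(main())
```

Because the server paces chunks at 0.1 s intervals, a 2 s timeout comfortably separates end-of-utterance from normal streaming.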