Tomtom84 committed on
Commit 0316ec3 · verified · 1 Parent(s): b1adbd7

Update app.py

Files changed (1)
  1. app.py +238 -105
app.py CHANGED
@@ -1,113 +1,246 @@
- from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
- from fastapi.responses import StreamingResponse, JSONResponse
- import outetts
- import io
- import json
- import base64
- import struct
- import os
- # Initialize the interface
- interface = outetts.Interface(
-     config=outetts.ModelConfig.auto_config(
-         model=outetts.Models.VERSION_1_0_SIZE_1B,
-         # For llama.cpp backend
-         #backend=outetts.Backend.LLAMACPP,
-         #quantization=outetts.LlamaCppQuantization.FP16
-         # For transformers backend
-         backend=outetts.Backend.HF,
-     )
  )

- # Load the default speaker profile
- speaker = interface.load_default_speaker("EN-FEMALE-1-NEUTRAL")

- app = FastAPI()

- @app.get("/")
- def greet_json():
-     return {"Hello": "World!"}

- @app.websocket("/ws/tts")
- async def websocket_tts(websocket: WebSocket):
-     await websocket.accept()
      try:
-         while True:
-             # Receive a text chunk from the client
-             data = await websocket.receive_text()
-             # Status: Warming up
-             await websocket.send_text(json.dumps({"generation_status": "Warming up TTS model"}))
-             output = interface.generate(
-                 config=outetts.GenerationConfig(
-                     text=data,
-                     generation_type=outetts.GenerationType.CHUNKED,
-                     speaker=speaker,
-                     sampler_config=outetts.SamplerConfig(
-                         temperature=0.4
-                     ),
-                 )
              )
-             # Status: Generating linguistic features
-             await websocket.send_text(json.dumps({"generation_status": "Generating linguistic features"}))
-             # Save to buffer
-             import uuid
-             temp_path = f"temp_{uuid.uuid4().hex}.wav"
-             output.save(temp_path)
-             chunk_size = 4096
-             try:
-                 with open(temp_path, "rb") as f:
-                     wav_data = f.read()
-                 # WAV header is typically 44 bytes, but let's detect it robustly
-                 # Find the end of the header (data chunk)
-                 if wav_data[:4] != b'RIFF' or wav_data[8:12] != b'WAVE':
-                     raise ValueError("Not a valid WAV file")
-                 # Find 'data' subchunk
-                 data_offset = wav_data.find(b'data')
-                 if data_offset == -1:
-                     raise ValueError("No 'data' chunk found in WAV file")
-                 header_end = data_offset + 8  # 'data' + size (4 bytes)
-                 wav_header = bytearray(wav_data[:header_end])
-                 pcm_data = wav_data[header_end:]
-                 # Patch header: set data length to 0xFFFFFFFF (unknown/streaming)
-                 wav_header[data_offset+4:data_offset+8] = (0xFFFFFFFF).to_bytes(4, 'little')
-                 # Send header + first PCM chunk
-                 first_chunk = pcm_data[:chunk_size]
-                 audio_b64 = base64.b64encode(wav_header + first_chunk).decode("ascii")
-                 await websocket.send_text(json.dumps({
-                     "data": {
-                         "audio_bytes": audio_b64,
-                         "duration": None,
-                         "request_finished": False
-                     }
-                 }))
-                 # Send rest of PCM data in chunks (without header)
-                 idx = chunk_size
-                 while idx < len(pcm_data):
-                     chunk = pcm_data[idx:idx+chunk_size]
-                     if not chunk:
-                         break
-                     audio_b64 = base64.b64encode(chunk).decode("ascii")
-                     await websocket.send_text(json.dumps({
-                         "data": {
-                             "audio_bytes": audio_b64,
-                             "duration": None,
-                             "request_finished": False
-                         }
-                     }))
-                     idx += chunk_size
-             finally:
-                 try:
-                     os.remove(temp_path)
-                 except FileNotFoundError:
-                     pass
-             # Final event
-             await websocket.send_text(json.dumps({
-                 "data": {
-                     "audio_bytes": "",
-                     "duration": None,
-                     "request_finished": True
-                 }
-             }))
-     except WebSocketDisconnect:
-         pass
      except Exception as e:
-         await websocket.send_text(json.dumps({"error": str(e)}))
+ import spaces
+ from snac import SNAC
+ import torch
+ import gradio as gr
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from huggingface_hub import snapshot_download
+ from dotenv import load_dotenv
+ load_dotenv()
+
+ # Check if CUDA is available
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ print("Loading SNAC model...")
+ snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
+ snac_model = snac_model.to(device)
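+ # SNAC vocoder: decodes the hierarchical audio codes produced by Orpheus into a 24 kHz waveform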
+
+ model_name = "canopylabs/3b-de-ft-research_release"
+ #"canopylabs/orpheus-3b-0.1-ft"
+
+ # Download only model config and safetensors
+ snapshot_download(
+     repo_id=model_name,
+     allow_patterns=[
+         "config.json",
+         "*.safetensors",
+         "model.safetensors.index.json",
+     ],
+     ignore_patterns=[
+         "optimizer.pt",
+         "pytorch_model.bin",
+         "training_args.bin",
+         "scheduler.pt",
+         "tokenizer.json",
+         "tokenizer_config.json",
+         "special_tokens_map.json",
+         "vocab.json",
+         "merges.txt",
+         "tokenizer.*"
+     ]
  )

+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
+ model.to(device)
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ print(f"Orpheus model loaded to {device}")
+
+ # Process text prompt
+ def process_prompt(prompt, voice, tokenizer, device):
+     prompt = f"{voice}: {prompt}"
+     input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+
+     start_token = torch.tensor([[128259]], dtype=torch.int64)  # Start of human
+     end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)  # End of text, End of human
+
+     modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)  # SOH SOT Text EOT EOH
+
+     # No padding needed for single input
+     attention_mask = torch.ones_like(modified_input_ids)
+
+     return modified_input_ids.to(device), attention_mask.to(device)

+ # Parse output tokens to audio
+ def parse_output(generated_ids):
+     token_to_find = 128257
+     token_to_remove = 128258
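+     # 128257 marks the start of the audio-token stream in the model output;
+     # 128258 is the end-of-speech token (also passed as eos_token_id to generate() below)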
+
+     token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)

+     if len(token_indices[1]) > 0:
+         last_occurrence_idx = token_indices[1][-1].item()
+         cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
+     else:
+         cropped_tensor = generated_ids

+     processed_rows = []
+     for row in cropped_tensor:
+         masked_row = row[row != token_to_remove]
+         processed_rows.append(masked_row)
+
+     code_lists = []
+     for row in processed_rows:
+         row_length = row.size(0)
+         new_length = (row_length // 7) * 7
+         trimmed_row = row[:new_length]
+         trimmed_row = [t - 128266 for t in trimmed_row]
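+         # Audio-code token ids start at 128266, so subtracting maps them back to raw SNAC code values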
+         code_lists.append(trimmed_row)
+
+     return code_lists[0]  # Return just the first one for single sample
+
+ # Redistribute codes for audio generation
+ def redistribute_codes(code_list, snac_model):
+     device = next(snac_model.parameters()).device  # Get the device of SNAC model
+
+     layer_1 = []
+     layer_2 = []
+     layer_3 = []
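+     # Each group of 7 codes encodes one SNAC frame: 1 code for layer 1, 2 for layer 2
+     # and 4 for layer 3, each offset by (position within the group) * 4096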
+     for i in range((len(code_list)+1)//7):
+         layer_1.append(code_list[7*i])
+         layer_2.append(code_list[7*i+1]-4096)
+         layer_3.append(code_list[7*i+2]-(2*4096))
+         layer_3.append(code_list[7*i+3]-(3*4096))
+         layer_2.append(code_list[7*i+4]-(4*4096))
+         layer_3.append(code_list[7*i+5]-(5*4096))
+         layer_3.append(code_list[7*i+6]-(6*4096))
+
+     # Move tensors to the same device as the SNAC model
+     codes = [
+         torch.tensor(layer_1, device=device).unsqueeze(0),
+         torch.tensor(layer_2, device=device).unsqueeze(0),
+         torch.tensor(layer_3, device=device).unsqueeze(0)
+     ]
+
+     audio_hat = snac_model.decode(codes)
+     return audio_hat.detach().squeeze().cpu().numpy()  # Always return CPU numpy array
+
+ # Main generation function
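+ # @spaces.GPU() requests GPU time for each call when running on a Hugging Face ZeroGPU Space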
+ @spaces.GPU()
+ def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new_tokens, progress=gr.Progress()):
+     if not text.strip():
+         return None
+
      try:
+         progress(0.1, "Processing text...")
+         input_ids, attention_mask = process_prompt(text, voice, tokenizer, device)
+
+         progress(0.3, "Generating speech tokens...")
+         with torch.no_grad():
+             generated_ids = model.generate(
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+                 max_new_tokens=max_new_tokens,
+                 do_sample=True,
+                 temperature=temperature,
+                 top_p=top_p,
+                 repetition_penalty=repetition_penalty,
+                 num_return_sequences=1,
+                 eos_token_id=128258,
              )
+
+         progress(0.6, "Processing speech tokens...")
+         code_list = parse_output(generated_ids)
+
+         progress(0.8, "Converting to audio...")
+         audio_samples = redistribute_codes(code_list, snac_model)
+
+         return (24000, audio_samples)  # Return sample rate and audio
      except Exception as e:
+         print(f"Error generating speech: {e}")
+         return None
+
+ # Examples for the UI
+ examples = [
+     ["Hey there my name is Tara, <chuckle> and I'm a speech generation model that can sound like a person.", "tara", 0.6, 0.95, 1.1, 1200],
+     ["I've also been taught to understand and produce paralinguistic things <sigh> like sighing, or <laugh> laughing, or <yawn> yawning!", "dan", 0.7, 0.95, 1.1, 1200],
+     ["I live in San Francisco, and have, uhm let's see, 3 billion 7 hundred ... <gasp> well, lets just say a lot of parameters.", "leah", 0.6, 0.9, 1.2, 1200],
+     ["Sometimes when I talk too much, I need to <cough> excuse myself. <sniffle> The weather has been quite cold lately.", "leo", 0.65, 0.9, 1.1, 1200],
+     ["Public speaking can be challenging. <groan> But with enough practice, anyone can become better at it.", "jess", 0.7, 0.95, 1.1, 1200],
+     ["The hike was exhausting but the view from the top was absolutely breathtaking! <sigh> It was totally worth it.", "mia", 0.65, 0.9, 1.15, 1200],
+     ["Did you hear that joke? <laugh> I couldn't stop laughing when I first heard it. <chuckle> It's still funny.", "zac", 0.7, 0.95, 1.1, 1200],
+     ["After running the marathon, I was so tired <yawn> and needed a long rest. <sigh> But I felt accomplished.", "zoe", 0.6, 0.95, 1.1, 1200]
+ ]
+
+ # Available voices
+ VOICES = ["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe"]
+
+ # Available Emotive Tags
+ EMOTIVE_TAGS = ["`<laugh>`", "`<chuckle>`", "`<sigh>`", "`<cough>`", "`<sniffle>`", "`<groan>`", "`<yawn>`", "`<gasp>`"]
+
+ # Create Gradio interface
+ with gr.Blocks(title="Orpheus Text-to-Speech") as demo:
+     gr.Markdown(f"""
+     # 🎵 [Orpheus Text-to-Speech](https://github.com/canopyai/Orpheus-TTS)
+     Enter your text below and hear it converted to natural-sounding speech with the Orpheus TTS model.
+
+     ## Tips for better prompts:
+     - Add paralinguistic elements like {", ".join(EMOTIVE_TAGS)} or `uhm` for more human-like speech.
+     - Longer text prompts generally work better than very short phrases.
+     - Increasing `repetition_penalty` and `temperature` makes the model speak faster.
+     """)
+     with gr.Row():
+         with gr.Column(scale=3):
+             text_input = gr.Textbox(
+                 label="Text to speak",
+                 placeholder="Enter your text here...",
+                 lines=5
+             )
+             voice = gr.Dropdown(
+                 choices=VOICES,
+                 value="tara",
+                 label="Voice"
+             )
+
+             with gr.Accordion("Advanced Settings", open=False):
+                 temperature = gr.Slider(
+                     minimum=0.1, maximum=1.5, value=0.6, step=0.05,
+                     label="Temperature",
+                     info="Higher values (0.7-1.0) create more expressive but less stable speech"
+                 )
+                 top_p = gr.Slider(
+                     minimum=0.1, maximum=1.0, value=0.95, step=0.05,
+                     label="Top P",
+                     info="Nucleus sampling threshold"
+                 )
+                 repetition_penalty = gr.Slider(
+                     minimum=1.0, maximum=2.0, value=1.1, step=0.05,
+                     label="Repetition Penalty",
+                     info="Higher values discourage repetitive patterns"
+                 )
+                 max_new_tokens = gr.Slider(
+                     minimum=100, maximum=2000, value=1200, step=100,
+                     label="Max Length",
+                     info="Maximum length of generated audio (in tokens)"
+                 )
+
+             with gr.Row():
+                 submit_btn = gr.Button("Generate Speech", variant="primary")
+                 clear_btn = gr.Button("Clear")
+
+         with gr.Column(scale=2):
+             audio_output = gr.Audio(label="Generated Speech", type="numpy")
+
+     # Set up examples
+     gr.Examples(
+         examples=examples,
+         inputs=[text_input, voice, temperature, top_p, repetition_penalty, max_new_tokens],
+         outputs=audio_output,
+         fn=generate_speech,
+         cache_examples=True,
+     )
+
+     # Set up event handlers
+     submit_btn.click(
+         fn=generate_speech,
+         inputs=[text_input, voice, temperature, top_p, repetition_penalty, max_new_tokens],
+         outputs=audio_output
+     )
+
+     clear_btn.click(
+         fn=lambda: (None, None),
+         inputs=[],
+         outputs=[text_input, audio_output]
+     )
+
+ # Launch the app
+ if __name__ == "__main__":
+     demo.queue().launch(share=False, ssr_mode=False)