Tomtom84 committed on
Commit 325e9ba · verified · 1 Parent(s): 9a2b198

Update app.py

Files changed (1): app.py (+91 -29)
app.py CHANGED
@@ -11,11 +11,11 @@ if HF_TOKEN:
     login(HF_TOKEN)
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
-torch.backends.cuda.enable_flash_sdp(False)    # PyTorch 2.2 bug
+#torch.backends.cuda.enable_flash_sdp(False)   # PyTorch 2.2 bug
 
 # 1) Constants ---------------------------------------------------------
 REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
-CHUNK_TOKENS = 7
+CHUNK_TOKENS = 50
 START_TOKEN = 128259
 NEW_BLOCK = 128257
 EOS_TOKEN = 128258
@@ -101,45 +101,107 @@ async def tts(ws: WebSocket):
 
         ids, attn = build_prompt(text, voice)
         past = None
-        offset_len = ids.size(1)
-        past = None
-        last_tok = None
-        buf = []
+        ids, attn = build_prompt(text, voice)
+        past = None       # Holds the DynamicCache object from past_key_values
+        buf = []
+        last_tok = None   # Initialize last_tok
 
         while True:
-            next_cache_pos = torch.tensor([past.get_seq_length()], device=device)
-            gen = model.generate(
-                input_ids       = ids if past is None else torch.tensor([[last_tok]], device=device),
-                attention_mask  = attn if past is None else None,
-                past_key_values = past,
-                cache_position  = next_cache_pos,
-                max_new_tokens  = CHUNK_TOKENS,
-                logits_processor=[masker],
-                do_sample=True, temperature=0.7, top_p=0.95,
-                use_cache=True, return_dict_in_generate=True,
-                return_legacy_cache=False
-            )
-
-            # newly generated tokens after the previous end of the sequence
-            new_tokens = gen.sequences[0, offset_len:].tolist()
-            if not new_tokens:
+            # Determine inputs for this iteration
+            if past is None:
+                # First iteration: use the full prompt
+                current_input_ids = ids
+                current_attn_mask = attn
+                # DO NOT pass cache_position on the first run
+                current_cache_position = None
+            else:
+                # Subsequent iterations: use only the last token
+                if last_tok is None:
+                    print("Error: last_tok is None before subsequent generate call.")
+                    break  # Should not happen if generation proceeded
+                current_input_ids = torch.tensor([[last_tok]], device=device)
+                current_attn_mask = None   # Not needed when past_key_values is provided
+                # DO NOT pass cache_position; let DynamicCache handle it
+                current_cache_position = None
+
+            # --- Call model.generate ---
+            try:
+                gen = model.generate(
+                    input_ids=current_input_ids,
+                    attention_mask=current_attn_mask,
+                    past_key_values=past,
+                    cache_position=current_cache_position,  # Always None; the cache tracks positions
+                    max_new_tokens=CHUNK_TOKENS,
+                    logits_processor=[masker],
+                    do_sample=True, temperature=0.7, top_p=0.95,
+                    use_cache=True,
+                    return_dict_in_generate=True,
+                    return_legacy_cache=False  # Ensures a DynamicCache is returned
+                )
+            except Exception as e:
+                print(f"❌ Error during model.generate: {e}")
+                import traceback
+                traceback.print_exc()
+                break  # Exit loop on generation error
+
+            # --- Process output ---
+            # Get the full sequence generated *up to this point*
+            full_sequence_now = gen.sequences  # Sequence tensor, shape [1, seq_len]
+
+            # Determine the sequence length *before* this generation call using
+            # the cache. If past is None, it is the initial prompt length.
+            prev_seq_len = past.get_seq_length() if past is not None else ids.shape[1]
+
+            # The new tokens are those generated *in this call*; they appear
+            # after the previously cached sequence length. Slice defensively in
+            # case no new tokens were generated.
+            if full_sequence_now.shape[1] > prev_seq_len:
+                new_token_ids = full_sequence_now[0, prev_seq_len:]
+                new = new_token_ids.tolist()  # Convert tensor to list
+            else:
+                new = []  # No new tokens generated
+
+            if not new:  # If no new tokens were generated, stop
+                print("No new tokens generated, stopping.")
                 break
 
-            offset_len += len(new_tokens)   # the cache has grown
-            past = gen.past_key_values      # carry the cache into the next round
-            last_tok = new_tokens[-1]
+            # Update past_key_values for the *next* iteration
+            past = gen.past_key_values  # Update the cache state
 
-            for t in new_tokens:
+            # Get the very last token generated in *this* call for the *next* input
+            last_tok = new[-1]
+
+            # ----- Token handling (process the 'new' list) -----
+            eos_found = False
+            for t in new:
                 if t == EOS_TOKEN:
-                    raise StopIteration
+                    print("EOS token encountered.")
+                    eos_found = True
+                    break  # Stop processing tokens in this chunk
                 if t == NEW_BLOCK:
                     buf.clear()
                     continue
-                buf.append(t - AUDIO_BASE)
+                # Check whether the token is within the expected audio range
+                if AUDIO_BASE <= t < AUDIO_BASE + AUDIO_SPAN:
+                    buf.append(t - AUDIO_BASE)
+                else:
+                    # Log unexpected tokens if necessary
+                    # print(f"Warning: Generated token {t} outside expected audio range.")
+                    pass  # Ignore unexpected tokens for now
+
                 if len(buf) == 7:
                     await ws.send_bytes(decode_block(buf))
                     buf.clear()
-                    masker.sent_blocks = 1   # EOS is allowed from now on
+                    # Allow EOS only after the first full block has been sent
+                    if not masker.sent_blocks:
+                        masker.sent_blocks = 1
+
+            if eos_found:
+                # Handle any remaining buffer content (e.g., log an incomplete block)
+                if len(buf) > 0:
+                    print(f"Warning: Incomplete audio block at EOS: {len(buf)} tokens. Discarding.")
+                    buf.clear()
+                break  # Exit the while loop
 
     except (StopIteration, WebSocketDisconnect):
         pass
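
For context: the rewritten loop relies on Transformers' ability to continue generation from the DynamicCache returned by a previous generate() call, instead of maintaining cache_position by hand. The sketch below is a minimal, self-contained illustration of that continuation pattern, not the app's code (gpt2 and the chunk size of 8 are placeholders). Unlike app.py, it feeds the whole sequence back in and lets generate() skip the cached prefix, which is the more defensive variant of the same trick.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

enc = tok("Hello, my name is", return_tensors="pt")
ids, attn = enc.input_ids, enc.attention_mask
past = DynamicCache()  # filled in place by generate()

for _ in range(3):  # three chunks of up to 8 tokens each
    prev_len = ids.shape[1]
    out = model.generate(
        input_ids=ids,
        attention_mask=attn,
        past_key_values=past,        # re-used cache; cached positions are skipped
        max_new_tokens=8,
        do_sample=True, temperature=0.7, top_p=0.95,
        use_cache=True,
        return_dict_in_generate=True,
        pad_token_id=tok.eos_token_id,
    )
    new = out.sequences[0, prev_len:].tolist()  # tokens produced in this chunk
    if not new:
        break
    # Feed the whole sequence back in; generate() only runs the un-cached tail.
    ids = out.sequences
    attn = torch.ones_like(ids)
    past = out.past_key_values
    print(tok.decode(new))

Because the cache already covers everything except the newest token, each later call only has to run the model over that un-cached tail, which is what makes per-chunk streaming cheap.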
 
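
The EOS gating in the loop (masker.sent_blocks) implies a custom logits processor that forbids the EOS token until at least one complete audio block has been streamed. That processor is defined elsewhere in app.py and is not part of this diff; the stand-in below is a hypothetical sketch (the class name EosGate and everything except the sent_blocks attribute are assumptions) of how such a gate is usually written.

import torch
from transformers import LogitsProcessor

EOS_TOKEN = 128258  # same constant as in app.py

class EosGate(LogitsProcessor):
    """Hypothetical stand-in for `masker`: blocks EOS until a block was sent."""

    def __init__(self):
        self.sent_blocks = 0  # the streaming loop flips this to 1 after the
                              # first complete 7-code block goes out

    def __call__(self, input_ids: torch.LongTensor,
                 scores: torch.FloatTensor) -> torch.FloatTensor:
        if not self.sent_blocks:
            scores[:, EOS_TOKEN] = float("-inf")  # EOS cannot be sampled yet
        return scores

Wired up exactly as in the diff: masker = EosGate(), then model.generate(..., logits_processor=[masker]).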