Tomtom84 committed on
Commit
7d18470
·
verified ·
1 Parent(s): bb5c241

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -15
app.py CHANGED
@@ -41,18 +41,19 @@ class AudioMask(LogitsProcessor):
41
  end_token = start_token + 4096
42
  allowed_audio = torch.arange(start_token, end_token, device=self.allow.device)
43
 
44
- allowed = torch.cat([
45
- torch.tensor([NEW_BLOCK], device=self.allow.device),
46
- allowed_audio
47
- ])
48
-
49
- # Penalize NEW_BLOCK if buffer is not empty
50
- if self.buffer_pos > 0:
51
- logits[:, NEW_BLOCK] = float("-inf") # Apply a large negative penalty
52
 
53
  if self.sent_blocks: # ab 1. Block EOS zulassen
54
  allowed = torch.cat([allowed, self.eos])
55
 
 
56
  mask = logits.new_full(logits.shape, float("-inf"))
57
  mask[:, allowed] = 0
58
  return logits + mask
@@ -161,13 +162,18 @@ async def tts(ws: WebSocket):
161
  if t == NEW_BLOCK:
162
  buf.clear()
163
  continue
164
- buf.append(t - AUDIO_BASE)
165
- # masker.buffer_pos += 1 # Removed increment here
166
- if len(buf) == 7:
167
- await ws.send_bytes(decode_block(buf))
168
- buf.clear()
169
- masker.sent_blocks = 1 # ab jetzt EOS zulässig
170
- # masker.buffer_pos = 0 # Removed reset here
 
 
 
 
 
171
 
172
  except (StopIteration, WebSocketDisconnect):
173
  pass
 
41
  end_token = start_token + 4096
42
  allowed_audio = torch.arange(start_token, end_token, device=self.allow.device)
43
 
44
+ # Only allow NEW_BLOCK if buffer is full, otherwise only allow audio tokens
45
+ if self.buffer_pos == 7:
46
+ allowed = torch.cat([
47
+ torch.tensor([NEW_BLOCK], device=self.allow.device),
48
+ allowed_audio
49
+ ])
50
+ else:
51
+ allowed = allowed_audio # Only allow audio tokens
52
 
53
  if self.sent_blocks: # ab 1. Block EOS zulassen
54
  allowed = torch.cat([allowed, self.eos])
55
 
56
+ mask = logits.new_full(logits.shape, float("-inf"))
57
  mask = logits.new_full(logits.shape, float("-inf"))
58
  mask[:, allowed] = 0
59
  return logits + mask
 
162
  if t == NEW_BLOCK:
163
  buf.clear()
164
  continue
165
+ # Only append if it's an audio token
166
+ if t >= AUDIO_BASE and t < AUDIO_BASE + AUDIO_SPAN:
167
+ buf.append(t - AUDIO_BASE)
168
+ # masker.buffer_pos += 1 # Removed increment here
169
+ if len(buf) == 7:
170
+ await ws.send_bytes(decode_block(buf))
171
+ buf.clear()
172
+ masker.sent_blocks = 1 # ab jetzt EOS zulässig
173
+ # masker.buffer_pos = 0 # Removed reset here
174
+ else:
175
+ # Optional: Log unexpected tokens
176
+ print(f"DEBUG: Skipping non-audio token: {t}", flush=True)
177
 
178
  except (StopIteration, WebSocketDisconnect):
179
  pass