Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -41,18 +41,19 @@ class AudioMask(LogitsProcessor):
|
|
41 |
end_token = start_token + 4096
|
42 |
allowed_audio = torch.arange(start_token, end_token, device=self.allow.device)
|
43 |
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
|
53 |
if self.sent_blocks: # ab 1. Block EOS zulassen
|
54 |
allowed = torch.cat([allowed, self.eos])
|
55 |
|
|
|
56 |
mask = logits.new_full(logits.shape, float("-inf"))
|
57 |
mask[:, allowed] = 0
|
58 |
return logits + mask
|
@@ -161,13 +162,18 @@ async def tts(ws: WebSocket):
|
|
161 |
if t == NEW_BLOCK:
|
162 |
buf.clear()
|
163 |
continue
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
buf
|
169 |
-
|
170 |
-
|
|
|
|
|
|
|
|
|
|
|
171 |
|
172 |
except (StopIteration, WebSocketDisconnect):
|
173 |
pass
|
|
|
41 |
end_token = start_token + 4096
|
42 |
allowed_audio = torch.arange(start_token, end_token, device=self.allow.device)
|
43 |
|
44 |
+
# Only allow NEW_BLOCK if buffer is full, otherwise only allow audio tokens
|
45 |
+
if self.buffer_pos == 7:
|
46 |
+
allowed = torch.cat([
|
47 |
+
torch.tensor([NEW_BLOCK], device=self.allow.device),
|
48 |
+
allowed_audio
|
49 |
+
])
|
50 |
+
else:
|
51 |
+
allowed = allowed_audio # Only allow audio tokens
|
52 |
|
53 |
if self.sent_blocks: # ab 1. Block EOS zulassen
|
54 |
allowed = torch.cat([allowed, self.eos])
|
55 |
|
56 |
+
mask = logits.new_full(logits.shape, float("-inf"))
|
57 |
mask = logits.new_full(logits.shape, float("-inf"))
|
58 |
mask[:, allowed] = 0
|
59 |
return logits + mask
|
|
|
162 |
if t == NEW_BLOCK:
|
163 |
buf.clear()
|
164 |
continue
|
165 |
+
# Only append if it's an audio token
|
166 |
+
if t >= AUDIO_BASE and t < AUDIO_BASE + AUDIO_SPAN:
|
167 |
+
buf.append(t - AUDIO_BASE)
|
168 |
+
# masker.buffer_pos += 1 # Removed increment here
|
169 |
+
if len(buf) == 7:
|
170 |
+
await ws.send_bytes(decode_block(buf))
|
171 |
+
buf.clear()
|
172 |
+
masker.sent_blocks = 1 # ab jetzt EOS zulässig
|
173 |
+
# masker.buffer_pos = 0 # Removed reset here
|
174 |
+
else:
|
175 |
+
# Optional: Log unexpected tokens
|
176 |
+
print(f"DEBUG: Skipping non-audio token: {t}", flush=True)
|
177 |
|
178 |
except (StopIteration, WebSocketDisconnect):
|
179 |
pass
|