Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -33,11 +33,22 @@ class AudioMask(LogitsProcessor):
|
|
33 |
])
|
34 |
self.eos = torch.tensor([EOS_TOKEN], device=audio_ids.device)
|
35 |
self.sent_blocks = 0
|
|
|
36 |
|
37 |
def __call__(self, input_ids, logits):
|
38 |
-
allowed
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
if self.sent_blocks: # ab 1. Block EOS zulassen
|
40 |
allowed = torch.cat([allowed, self.eos])
|
|
|
41 |
mask = logits.new_full(logits.shape, float("-inf"))
|
42 |
mask[:, allowed] = 0
|
43 |
return logits + mask
|
@@ -105,6 +116,7 @@ async def tts(ws: WebSocket):
|
|
105 |
offset_len = ids.size(1) # wie viele Tokens existieren schon
|
106 |
last_tok = None
|
107 |
buf = []
|
|
|
108 |
|
109 |
while True:
|
110 |
# --- Mini‑Generate (Cache Disabled for Debugging) -------------------------------------------
|
@@ -143,10 +155,12 @@ async def tts(ws: WebSocket):
|
|
143 |
buf.clear()
|
144 |
continue
|
145 |
buf.append(t - AUDIO_BASE)
|
|
|
146 |
if len(buf) == 7:
|
147 |
await ws.send_bytes(decode_block(buf))
|
148 |
buf.clear()
|
149 |
masker.sent_blocks = 1 # ab jetzt EOS zulässig
|
|
|
150 |
|
151 |
except (StopIteration, WebSocketDisconnect):
|
152 |
pass
|
|
|
33 |
])
|
34 |
self.eos = torch.tensor([EOS_TOKEN], device=audio_ids.device)
|
35 |
self.sent_blocks = 0
|
36 |
+
self.buffer_pos = 0 # Added buffer position
|
37 |
|
38 |
def __call__(self, input_ids, logits):
    """Mask ``logits`` down to the tokens valid at the current buffer position.

    The permitted ids are NEW_BLOCK, the 4096 consecutive audio ids that
    start at ``AUDIO_BASE + self.buffer_pos * 4096`` and, once
    ``self.sent_blocks`` is truthy, the EOS id. ``-inf`` is added to every
    other column; ``input_ids`` is not consulted.

    Returns:
        A new tensor, ``logits`` plus the additive mask.
    """
    # Audio id window belonging to the current position within the block.
    window_start = AUDIO_BASE + self.buffer_pos * 4096
    device = self.allow.device
    pieces = [
        torch.tensor([NEW_BLOCK], device=device),
        torch.arange(window_start, window_start + 4096, device=device),
    ]
    # EOS becomes selectable only after sent_blocks has been set.
    if self.sent_blocks:
        pieces.append(self.eos)
    permitted = torch.cat(pieces)

    # Additive mask: 0 keeps a logit, -inf rules the token out.
    penalty = torch.full_like(logits, float("-inf"))
    penalty[:, permitted] = 0
    return logits + penalty
|
|
|
116 |
offset_len = ids.size(1) # wie viele Tokens existieren schon
|
117 |
last_tok = None
|
118 |
buf = []
|
119 |
+
masker.buffer_pos = 0 # Initialize buffer position for masker
|
120 |
|
121 |
while True:
|
122 |
# --- Mini‑Generate (Cache Disabled for Debugging) -------------------------------------------
|
|
|
155 |
buf.clear()
|
156 |
continue
|
157 |
buf.append(t - AUDIO_BASE)
|
158 |
+
masker.buffer_pos += 1 # Increment buffer position
|
159 |
if len(buf) == 7:
|
160 |
await ws.send_bytes(decode_block(buf))
|
161 |
buf.clear()
|
162 |
masker.sent_blocks = 1 # ab jetzt EOS zulässig
|
163 |
+
masker.buffer_pos = 0 # Reset buffer position after sending a block
|
164 |
|
165 |
except (StopIteration, WebSocketDisconnect):
|
166 |
pass
|