Tomtom84 committed on
Commit
3d65908
·
verified ·
1 Parent(s): 9066efe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -1
app.py CHANGED
@@ -33,11 +33,22 @@ class AudioMask(LogitsProcessor):
33
  ])
34
  self.eos = torch.tensor([EOS_TOKEN], device=audio_ids.device)
35
  self.sent_blocks = 0
 
36
 
37
  def __call__(self, input_ids, logits):
38
- allowed = self.allow
 
 
 
 
 
 
 
 
 
39
  if self.sent_blocks: # ab 1. Block EOS zulassen
40
  allowed = torch.cat([allowed, self.eos])
 
41
  mask = logits.new_full(logits.shape, float("-inf"))
42
  mask[:, allowed] = 0
43
  return logits + mask
@@ -105,6 +116,7 @@ async def tts(ws: WebSocket):
105
  offset_len = ids.size(1) # wie viele Tokens existieren schon
106
  last_tok = None
107
  buf = []
 
108
 
109
  while True:
110
  # --- Mini‑Generate (Cache Disabled for Debugging) -------------------------------------------
@@ -143,10 +155,12 @@ async def tts(ws: WebSocket):
143
  buf.clear()
144
  continue
145
  buf.append(t - AUDIO_BASE)
 
146
  if len(buf) == 7:
147
  await ws.send_bytes(decode_block(buf))
148
  buf.clear()
149
  masker.sent_blocks = 1 # ab jetzt EOS zulässig
 
150
 
151
  except (StopIteration, WebSocketDisconnect):
152
  pass
 
33
  ])
34
  self.eos = torch.tensor([EOS_TOKEN], device=audio_ids.device)
35
  self.sent_blocks = 0
36
+ self.buffer_pos = 0 # Added buffer position
37
 
38
  def __call__(self, input_ids, logits):
39
+ # Calculate allowed tokens based on buffer position
40
+ start_token = AUDIO_BASE + self.buffer_pos * 4096
41
+ end_token = start_token + 4096
42
+ allowed_audio = torch.arange(start_token, end_token, device=self.allow.device)
43
+
44
+ allowed = torch.cat([
45
+ torch.tensor([NEW_BLOCK], device=self.allow.device),
46
+ allowed_audio
47
+ ])
48
+
49
  if self.sent_blocks: # ab 1. Block EOS zulassen
50
  allowed = torch.cat([allowed, self.eos])
51
+
52
  mask = logits.new_full(logits.shape, float("-inf"))
53
  mask[:, allowed] = 0
54
  return logits + mask
 
116
  offset_len = ids.size(1) # wie viele Tokens existieren schon
117
  last_tok = None
118
  buf = []
119
+ masker.buffer_pos = 0 # Initialize buffer position for masker
120
 
121
  while True:
122
  # --- Mini‑Generate (Cache Disabled for Debugging) -------------------------------------------
 
155
  buf.clear()
156
  continue
157
  buf.append(t - AUDIO_BASE)
158
+ masker.buffer_pos += 1 # Increment buffer position
159
  if len(buf) == 7:
160
  await ws.send_bytes(decode_block(buf))
161
  buf.clear()
162
  masker.sent_blocks = 1 # ab jetzt EOS zulässig
163
+ masker.buffer_pos = 0 # Reset buffer position after sending a block
164
 
165
  except (StopIteration, WebSocketDisconnect):
166
  pass