Tomtom84 commited on
Commit
a0cc672
·
verified ·
1 Parent(s): 5d55203

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -38
app.py CHANGED
@@ -35,28 +35,11 @@ class AudioMask(LogitsProcessor):
35
  self.sent_blocks = 0
36
  self.buffer_pos = 0 # Added buffer position
37
 
38
- def __call__(self, input_ids, logits):
39
- # Calculate allowed tokens based on buffer position
40
- start_token = AUDIO_BASE + self.buffer_pos * 4096
41
- end_token = start_token + 4096
42
- allowed_audio = torch.arange(start_token, end_token, device=self.allow.device)
43
-
44
- # Only allow NEW_BLOCK if buffer is full, otherwise only allow audio tokens
45
- if self.buffer_pos == 7:
46
- allowed = torch.cat([
47
- torch.tensor([NEW_BLOCK], device=self.allow.device),
48
- allowed_audio
49
- ])
50
- else:
51
- allowed = allowed_audio # Only allow audio tokens
52
-
53
- if self.sent_blocks: # ab 1. Block EOS zulassen
54
- allowed = torch.cat([allowed, self.eos])
55
-
56
- mask = logits.new_full(logits.shape, float("-inf"))
57
- mask = logits.new_full(logits.shape, float("-inf"))
58
- mask[:, allowed] = 0
59
- return logits + mask
60
 
61
  # 3) FastAPI Grundgerüst ---------------------------------------------
62
  app = FastAPI()
@@ -94,7 +77,7 @@ def build_prompt(text: str, voice: str):
94
 
95
  def decode_block(block7: list[int]) -> bytes:
96
  l1,l2,l3=[],[],[]
97
- l1.append(block7[0] - (0 * 4096)) # Subtract AUDIO_BASE + position 0 offset
98
  l2.append(block7[1] + (1 * 4096)) # Subtract AUDIO_BASE + position 1 offset
99
  l3 += [block7[2] + (2 * 4096), block7[3] + (3 * 4096)] # Subtract AUDIO_BASE + position offsets
100
  l2.append(block7[4] + (4 * 4096)) # Subtract AUDIO_BASE + position 4 offset
@@ -120,13 +103,11 @@ async def tts(ws: WebSocket):
120
  past = None
121
  offset_len = ids.size(1) # wie viele Tokens existieren schon
122
  last_tok = None
123
- buf = []
124
  # masker.buffer_pos = 0 # Removed initialization here
 
125
 
126
  while True:
127
- # Update buffer_pos based on current buffer length before generation
128
- masker.buffer_pos = len(buf)
129
-
130
  # --- Mini‑Generate (Cache Disabled for Debugging) -------------------------------------------
131
  gen = model.generate(
132
  input_ids = ids, # Always use full sequence
@@ -164,17 +145,14 @@ async def tts(ws: WebSocket):
164
  continue
165
  # Only append if it's an audio token
166
  # Only append if it's an audio token
167
- if t >= AUDIO_BASE and t < AUDIO_BASE + AUDIO_SPAN:
168
- buf.append(t - AUDIO_BASE) # Append token relative to AUDIO_BASE
169
- # masker.buffer_pos += 1 # Removed increment here
170
- if len(buf) == 7:
171
- await ws.send_bytes(decode_block(buf))
172
- buf.clear()
173
- masker.sent_blocks = 1 # ab jetzt EOS zulässig
174
- # masker.buffer_pos = 0 # Removed reset here
175
- else:
176
- # Optional: Log unexpected tokens
177
- print(f"DEBUG: Skipping non-audio token: {t}", flush=True)
178
 
179
  except (StopIteration, WebSocketDisconnect):
180
  pass
 
35
  self.sent_blocks = 0
36
  self.buffer_pos = 0 # Added buffer position
37
 
38
+ def __call__(self, input_ids, scores):
39
+ allow = torch.cat([self.allow, self.eos]) # Reverted masking logic
40
+ mask = torch.full_like(scores, float("-inf"))
41
+ mask[:, allow] = 0
42
+ return scores + mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  # 3) FastAPI Grundgerüst ---------------------------------------------
45
  app = FastAPI()
 
77
 
78
  def decode_block(block7: list[int]) -> bytes:
79
  l1,l2,l3=[],[],[]
80
+ l1.append(block7[0] + (0 * 4096)) # Subtract AUDIO_BASE + position 0 offset
81
  l2.append(block7[1] + (1 * 4096)) # Subtract AUDIO_BASE + position 1 offset
82
  l3 += [block7[2] + (2 * 4096), block7[3] + (3 * 4096)] # Subtract AUDIO_BASE + position offsets
83
  l2.append(block7[4] + (4 * 4096)) # Subtract AUDIO_BASE + position 4 offset
 
103
  past = None
104
  offset_len = ids.size(1) # wie viele Tokens existieren schon
105
  last_tok = None
106
+ buf = []
107
  # masker.buffer_pos = 0 # Removed initialization here
108
+ # Removed buffer_pos update before generation
109
 
110
  while True:
 
 
 
111
  # --- Mini‑Generate (Cache Disabled for Debugging) -------------------------------------------
112
  gen = model.generate(
113
  input_ids = ids, # Always use full sequence
 
145
  continue
146
  # Only append if it's an audio token
147
  # Only append if it's an audio token
148
+ buf.append(t - AUDIO_BASE) # Reverted to appending relative token
149
+ # masker.buffer_pos += 1 # Removed increment here
150
+ if len(buf) == 7:
151
+ await ws.send_bytes(decode_block(buf))
152
+ buf.clear()
153
+ masker.sent_blocks = 1 # ab jetzt EOS zulässig
154
+ # masker.buffer_pos = 0 # Removed reset here
155
+ # Removed else block for skipping non-audio tokens
 
 
 
156
 
157
  except (StopIteration, WebSocketDisconnect):
158
  pass