Update app.py

app.py CHANGED
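In summary, this commit replaces the hard-coded EOS_TOKEN = 128258 constant with an actual_eos_token_id resolved at startup from the model config (falling back to the tokenizer's eos_token_id), threads that ID through AudioMask, EosStoppingCriteria, AudioStreamer and the model.generate call, adds explicit sampling parameters (do_sample, temperature, top_p, repetition_penalty) to generation, and hardens the WebSocket send and close paths against "Cannot call send" errors after a client disconnect.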
@@ -18,6 +18,7 @@ model = None
 snac = None
 masker = None
 stopping_criteria = None
+actual_eos_token_id = None  # Will be determined during startup
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 # 0) Login + Device ---------------------------------------------------
@@ -33,7 +34,7 @@ REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
 # CHUNK_TOKENS = 50 # Not directly used by us with the streamer approach
 START_TOKEN = 128259
 NEW_BLOCK = 128257
-EOS_TOKEN = 128258
+# EOS_TOKEN = 128258 # REMOVED - Will be determined from model/tokenizer config
 AUDIO_BASE = 128266
 AUDIO_SPAN = 4096 * 7  # 28672 Codes
 CODEBOOK_SIZE = 4096  # Explicitly define the codebook size
@@ -41,45 +42,61 @@ CODEBOOK_SIZE = 4096 # Explicitly define the codebook size
 AUDIO_IDS_CPU = torch.arange(AUDIO_BASE, AUDIO_BASE + AUDIO_SPAN)
 
 # 2) Logit mask -------------------------------------------------------
+# Uses the dynamically determined EOS token ID
 class AudioMask(LogitsProcessor):
     def __init__(self, audio_ids: torch.Tensor, new_block_token_id: int, eos_token_id: int):
         super().__init__()
+        # Ensure input tensors are Long type for concatenation if needed, although indices are usually int
+        new_block_tensor = torch.tensor([new_block_token_id], device=audio_ids.device, dtype=torch.long)
+        eos_tensor = torch.tensor([eos_token_id], device=audio_ids.device, dtype=torch.long)
+
         # Allow NEW_BLOCK and all valid audio tokens initially
-        self.allow = torch.cat([
-            torch.tensor([new_block_token_id], device=audio_ids.device, dtype=torch.long),
-            audio_ids,
-        ], dim=0)
-        self.eos = torch.tensor([eos_token_id], device=audio_ids.device, dtype=torch.long)
-        self.allow_with_eos = torch.cat([self.allow, self.eos], dim=0)
+        self.allow = torch.cat([new_block_tensor, audio_ids], dim=0)
+        self.eos = eos_tensor  # Store EOS token ID as tensor
+        self.allow_with_eos = torch.cat([self.allow, self.eos], dim=0)  # Precompute combined tensor
         self.sent_blocks = 0  # State: Number of audio blocks sent
 
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        # Determine which tokens are allowed based on whether blocks have been sent
         current_allow = self.allow_with_eos if self.sent_blocks > 0 else self.allow
+
+        # Create a mask initialized to negative infinity
         mask = torch.full_like(scores, float("-inf"))
+        # Set allowed token scores to 0 (effectively allowing them)
         mask[:, current_allow] = 0
+        # Apply the mask to the scores
         return scores + mask
 
     def reset(self):
+        """Resets the state for a new generation request."""
         self.sent_blocks = 0
 
 # 3) Stopping criteria for EOS ---------------------------------------
+# Uses the dynamically determined EOS token ID
 class EosStoppingCriteria(StoppingCriteria):
     def __init__(self, eos_token_id: int):
         self.eos_token_id = eos_token_id
+        if self.eos_token_id is None:
+            print("⚠️ EosStoppingCriteria initialized with eos_token_id=None!")
 
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        if self.eos_token_id is None:
+            return False  # Cannot stop if EOS ID is unknown
+        # Check if the *last* generated token is the EOS token
         if input_ids.shape[1] > 0 and input_ids[:, -1] == self.eos_token_id:
+            # print("StoppingCriteria: EOS detected.")
             return True
         return False
 
 # 4) Custom AudioStreamer -------------------------------
 class AudioStreamer(BaseStreamer):
-    def __init__(self, ws: WebSocket, snac_decoder: SNAC, audio_mask: AudioMask, loop: asyncio.AbstractEventLoop, target_device: str):
+    def __init__(self, ws: WebSocket, snac_decoder: SNAC, audio_mask: AudioMask, loop: asyncio.AbstractEventLoop, target_device: str, eos_token_id: int):
         self.ws = ws
         self.snac = snac_decoder
         self.masker = audio_mask
         self.loop = loop
         self.device = target_device
+        self.eos_token_id = eos_token_id  # Store EOS ID for potential use in put (optional)
         self.buf: list[int] = []
         self.tasks = set()
 
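For intuition, the mask simply shifts disallowed logits to negative infinity so sampling can only pick allowed IDs. A minimal, self-contained sketch of that behaviour with a toy vocabulary (the IDs below are illustrative, not the real Orpheus token IDs):

import torch

scores = torch.randn(1, 10)            # one decoding step of raw logits
allow = torch.tensor([3, 5, 6, 7])     # stand-ins for NEW_BLOCK + audio IDs
mask = torch.full_like(scores, float("-inf"))
mask[:, allow] = 0
masked = scores + mask                 # everything outside allow is now -inf
assert torch.isinf(masked[0, 0]) and not torch.isinf(masked[0, 3])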
@@ -105,7 +122,6 @@ class AudioStreamer(BaseStreamer):
         code_val_6 = block7[6] % CODEBOOK_SIZE
 
         # --- Map the extracted code values to the SNAC codebooks (l1, l2, l3) ---
-        # Using the structure from the user's previous version, believed to be correct
         l1 = [code_val_0]
         l2 = [code_val_1, code_val_4]
         l3 = [code_val_2, code_val_3, code_val_5, code_val_6]
@@ -130,15 +146,12 @@ class AudioStreamer(BaseStreamer):
         # --- Decode using SNAC ---
         try:
             with torch.no_grad():
-                audio = self.snac.decode(codes)[0]  # Decode expects list of tensors, result might have batch dim
+                audio = self.snac.decode(codes)[0]
         except Exception as e_decode:
-            # Add more detailed logging here if it fails again
             print(f"Streamer Error: Exception during snac.decode: {e_decode}")
             print(f"Input codes shapes: {[c.shape for c in codes]}")
             print(f"Input codes dtypes: {[c.dtype for c in codes]}")
             print(f"Input codes devices: {[c.device for c in codes]}")
-            # Avoid printing potentially huge lists, maybe just check min/max?
             print(f"Input code values (min/max): L1({min(l1)}/{max(l1)}) L2({min(l2)}/{max(l2)}) L3({min(l3)}/{max(l3)})")
             return b""
 
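The l1/l2/l3 mapping above packs each 7-token block into SNAC's three codebook levels (1, 2 and 4 codes per block, coarse to fine). A hedged sketch of the tensor layout that snac.decode consumes, using toy code values; the exact device and dtype handling lives in the elided lines of this hunk and may differ:

import torch

device = "cpu"                        # assumption; the app uses its target_device
block = [11, 22, 33, 44, 55, 66, 77]  # toy values, already reduced % CODEBOOK_SIZE
l1 = [block[0]]
l2 = [block[1], block[4]]
l3 = [block[2], block[3], block[5], block[6]]
codes = [torch.tensor(l, dtype=torch.long, device=device).unsqueeze(0)
         for l in (l1, l2, l3)]
print([tuple(c.shape) for c in codes])  # [(1, 1), (1, 2), (1, 4)]
# audio = snac_model.decode(codes)[0]   # decode takes this list-of-tensors layout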
@@ -160,7 +173,12 @@
         except WebSocketDisconnect:
             print("Streamer: WebSocket disconnected during send.")
         except Exception as e:
-            print(f"Streamer: Error sending bytes: {e}")
+            # Handle cases where sending fails after connection closed
+            if "Cannot call \"send\" once a close message has been sent" in str(e):
+                # This is expected if client disconnects during generation, suppress repetitive logs
+                pass
+            else:
+                print(f"Streamer: Error sending bytes: {e}")
 
     def put(self, value: torch.LongTensor):
         """
@@ -169,30 +187,34 @@
         """
         if value.numel() == 0:
             return
+        # Ensure value is on CPU and flatten to a list of ints
+        new_token_ids = value.squeeze().cpu().tolist()
         if isinstance(new_token_ids, int):
             new_token_ids = [new_token_ids]
 
         for t in new_token_ids:
-            if t == EOS_TOKEN:
-                break
+            # No need to check for EOS here, StoppingCriteria handles it
             if t == NEW_BLOCK:
                 self.buf.clear()
                 continue
+
             if AUDIO_BASE <= t < AUDIO_BASE + AUDIO_SPAN:
                 self.buf.append(t - AUDIO_BASE)  # Store value relative to base
-            # else: # Optionally log ignored tokens
+            # else: # Optionally log ignored tokens outside audio range
+            #     if t != self.eos_token_id: # Don't warn about the EOS token itself
+            #         print(f"Streamer Warning: Ignoring unexpected token {t}")
 
             if len(self.buf) == 7:
                 audio_bytes = self._decode_block(self.buf)
                 self.buf.clear()
 
                 if audio_bytes:
+                    # Schedule the async send function to run on the main event loop
                     future = asyncio.run_coroutine_threadsafe(self._send_audio_bytes(audio_bytes), self.loop)
                     self.tasks.add(future)
                     future.add_done_callback(self.tasks.discard)
 
+                # Allow EOS only after the first full block has been processed
                 if self.masker.sent_blocks == 0:
                     self.masker.sent_blocks = 1
 
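put() runs on the worker thread that asyncio.to_thread spawns for model.generate, so it cannot await the WebSocket send directly; run_coroutine_threadsafe hands the coroutine back to the main event loop, which is the pattern used above. A self-contained illustration of that handoff (names are illustrative):

import asyncio

async def deliver(chunk: bytes) -> None:
    print(f"delivering {len(chunk)} bytes on the event loop")

def worker(loop: asyncio.AbstractEventLoop) -> None:
    # Plain thread: schedule the coroutine on the loop instead of awaiting it.
    future = asyncio.run_coroutine_threadsafe(deliver(b"\x00" * 7), loop)
    future.result()  # block this thread until the send completes (optional)

async def main() -> None:
    loop = asyncio.get_running_loop()
    await asyncio.to_thread(worker, loop)

asyncio.run(main())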
@@ -201,7 +223,6 @@
         if len(self.buf) > 0:
             print(f"Streamer: End of generation with incomplete block ({len(self.buf)} tokens). Discarding.")
             self.buf.clear()
-        # print(f"Streamer: Generation finished. Pending send tasks: {len(self.tasks)}")
         pass
 
 # 5) FastAPI App ------------------------------------------------------
@@ -209,7 +230,7 @@ app = FastAPI()
 
 @app.on_event("startup")
 async def load_models_startup():
-    global tok, model, snac, masker, stopping_criteria, device, AUDIO_IDS_CPU
+    global tok, model, snac, masker, stopping_criteria, device, AUDIO_IDS_CPU, actual_eos_token_id
 
     print(f"🚀 Starting up on device: {device}")
     print("⏳ Lade Modelle …", flush=True)
@@ -218,7 +239,7 @@
     print("Tokenizer loaded.")
 
     snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
-    print(f"SNAC loaded to {device}.")
+    print(f"SNAC loaded to {device}.")
 
     model_dtype = torch.float32
     if device == "cuda":
@@ -235,25 +256,40 @@
         torch_dtype=model_dtype,
         low_cpu_mem_usage=True,
     )
-    model.config.pad_token_id = model.config.eos_token_id
     print(f"Model loaded to {model.device} with dtype {model.dtype}.")
     model.eval()
 
+    # --- Determine and set the correct EOS token ID ---
+    conf_eos = model.config.eos_token_id
+    tok_eos = tok.eos_token_id
+    print(f"Model Config EOS ID: {conf_eos}")
+    print(f"Tokenizer EOS ID: {tok_eos}")
+
+    if conf_eos is not None:
+        actual_eos_token_id = conf_eos
+    elif tok_eos is not None:
+        actual_eos_token_id = tok_eos
+        print(f"⚠️ Model config EOS ID is None, using Tokenizer EOS ID: {actual_eos_token_id}")
+    else:
+        raise ValueError("Could not determine EOS token ID from model config or tokenizer.")
+
+    print(f"Using EOS Token ID: {actual_eos_token_id}")
+    # Set pad_token_id to eos_token_id if not already set (common practice for generation)
+    if model.config.pad_token_id is None:
+        print(f"Setting model.config.pad_token_id to EOS token ID ({actual_eos_token_id})")
+        model.config.pad_token_id = actual_eos_token_id
+    # --- End EOS Token ID determination ---
+
     audio_ids_device = AUDIO_IDS_CPU.to(device)
-    masker = AudioMask(audio_ids_device, NEW_BLOCK, EOS_TOKEN)
+    # Pass the determined EOS ID to the mask
+    masker = AudioMask(audio_ids_device, NEW_BLOCK, actual_eos_token_id)
     print("AudioMask initialized.")
 
-    stopping_criteria = StoppingCriteriaList([EosStoppingCriteria(EOS_TOKEN)])
+    # Pass the determined EOS ID to the stopping criteria
+    stopping_criteria = StoppingCriteriaList([EosStoppingCriteria(actual_eos_token_id)])
     print("StoppingCriteria initialized.")
 
     print("✅ Modelle geladen und bereit!", flush=True)
-    print(f"Tokenizer EOS ID: {tok.eos_token_id}")
-    print(f"Model Config EOS ID: {model.config.eos_token_id}")
-    print(f"Constant EOS_TOKEN: {EOS_TOKEN}")
-    if tok.eos_token_id != EOS_TOKEN or model.config.eos_token_id != EOS_TOKEN:
-        print("⚠️ WARNING: EOS_TOKEN constant might not match model/tokenizer configuration!")
-        # Consider updating EOS_TOKEN if they differ, e.g.:
-        # EOS_TOKEN = model.config.eos_token_id
 
 @app.get("/")
 def hello():
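A quick sanity check that could follow the resolution logic above: map the resolved ID back to its token string via the standard transformers tokenizer API. This snippet reuses tok and actual_eos_token_id from the startup code and is a suggestion, not part of the commit:

# Hedged sanity check: confirm the resolved EOS ID round-trips through the tokenizer.
eos_str = tok.convert_ids_to_tokens(actual_eos_token_id)
print(f"EOS id {actual_eos_token_id} decodes to token {eos_str!r}")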
@@ -277,6 +313,7 @@ def build_prompt(text: str, voice: str) -> tuple[torch.Tensor, torch.Tensor]:
 # 7) WebSocket endpoint (simplified with streamer) ---------------------
 @app.websocket("/ws/tts")
 async def tts(ws: WebSocket):
+    global actual_eos_token_id  # Ensure we can access the determined EOS ID
     await ws.accept()
     print("🔌 Client connected")
     streamer = None
@@ -297,24 +334,27 @@
             print(f"Generating audio for: '{text}' with voice '{voice}'")
             ids, attn = build_prompt(text, voice)
             masker.reset()
-            streamer = AudioStreamer(ws, snac, masker, main_loop, device)
+            # Pass the determined EOS ID to the streamer as well (optional, for logging/checks)
+            streamer = AudioStreamer(ws, snac, masker, main_loop, device, actual_eos_token_id)
 
             print("Starting generation in background thread...")
+            # Use sampling parameters to avoid repetition
             await asyncio.to_thread(
+                model.generate,
+                input_ids=ids,
+                attention_mask=attn,
+                max_new_tokens=2500,  # Increased slightly, adjust as needed
+                logits_processor=[masker],
+                stopping_criteria=stopping_criteria,
+                # --- Sampling Parameters ---
+                do_sample=True,
+                temperature=0.6,
+                top_p=0.9,
+                repetition_penalty=1.15,
+                # --- End Sampling Parameters ---
+                use_cache=True,
+                streamer=streamer,
+                eos_token_id=actual_eos_token_id  # Explicitly pass correct EOS ID here too
             )
             print("Generation thread finished.")
 
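For reference, a minimal client for this endpoint, written with the third-party websockets package. The JSON request shape {"text": ..., "voice": ...} is an assumption based on the text/voice variables above (the receive and parse code is outside this diff), as are the port and the voice name; the binary frames are the raw audio blocks sent by the streamer:

import asyncio
import json
import websockets  # pip install websockets

async def main() -> None:
    # Port 7860 is the usual HF Spaces default; adjust for your deployment.
    async with websockets.connect("ws://localhost:7860/ws/tts") as ws:
        # Hypothetical request shape; "martin" is a placeholder voice name.
        await ws.send(json.dumps({"text": "Hallo Welt", "voice": "martin"}))
        with open("out.raw", "wb") as f:
            async for message in ws:       # iterate until the server closes
                if isinstance(message, bytes):
                    f.write(message)       # raw audio block from AudioStreamer

asyncio.run(main())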
@@ -347,7 +387,9 @@
         try:
             await ws.close(code=1000)
         except RuntimeError as e_close:
-            print(f"Runtime error closing websocket: {e_close}")
+            # Suppress "Cannot call 'send'..." error during final close if already disconnected
+            if "Cannot call \"send\"" not in str(e_close):
+                print(f"Runtime error closing websocket: {e_close}")
         except Exception as e_close_final:
             print(f"Error closing websocket: {e_close_final}")
     elif ws.client_state.name != "DISCONNECTED":