# streaming.py - Fixed version to send incremental delta chunks
import json
import re
import logging
from typing import AsyncGenerator, Callable, Optional, Iterator, Tuple
from dataclasses import dataclass
import sys
import time
from selenium.common.exceptions import NoSuchElementException, StaleElementReferenceException, TimeoutException
from selenium.webdriver.common.by import By
from selenium.webdriver.remote.webelement import WebElement
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
logging.basicConfig(level=logging.INFO, stream=sys.stdout, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


@dataclass
class StreamConfig:
    """Configuration for streaming behavior"""
    timeout: float = 300.0  # Overall timeout for the stream generation
    retry_on_error: bool = True
    max_retries: int = 3
    # The settings below are used by StreamProcessor.
    poll_interval: float = 0.05
    response_timeout: float = 900.0  # Max time for Selenium to wait for response content
    stabilization_timeout: float = 1.0  # Extra wait after the Selenium stream ends
    max_inactivity: float = 10.0  # Max time between chunks from the Selenium stream
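

# Illustrative sketch (not used by the module): how a caller might tune StreamConfig.
# The field names are the real ones defined above; the values are arbitrary examples.
def _example_fast_stream_config() -> StreamConfig:
    """Return a StreamConfig tuned for a fast-polling, short-lived stream (example only)."""
    return StreamConfig(
        poll_interval=0.02,         # poll the DOM every 20 ms
        max_inactivity=5.0,         # give up after 5 s with no new text
        stabilization_timeout=0.5,  # linger 0.5 s after the stream ends
    )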


class StreamingResponseGenerator:
    """
    Generates Server-Sent Events (SSE) for streaming chat completions.
    It takes an async generator producing text deltas and formats them into SSE chunks.
    """
    def __init__(self, config: Optional[StreamConfig] = None):
        self.config = config or StreamConfig()

    async def create_response(
        self,
        completion_id: str,
        created: int,
        model: str,
        prompt: str,  # The full prompt forwarded to send_message_func
        send_message_func: Callable[[str, str], AsyncGenerator[str, None]],  # Called with (prompt, model)
    ) -> AsyncGenerator[str, None]:
"""
Creates an SSE stream from the text deltas generated by send_message_func.
"""
logger.info(f"[{completion_id}] Starting streaming response generation for model '{model}'.")
first_chunk_sent = False
accumulated_content_for_logging = ""
try:
# `send_message_func` is `ChatHandler.send_message_and_stream_response`
# It takes (prompt, model_id) and yields text deltas.
async for content_delta in send_message_func(prompt, model):
if not content_delta: # Skip empty chunks from the source
continue
accumulated_content_for_logging += content_delta
# logger.debug(f"[{completion_id}] Received content delta: '{content_delta[:50].replace(chr(10), ' ')}...' (Total: {len(accumulated_content_for_logging)})")
delta_payload = {"content": content_delta}
if not first_chunk_sent:
# The first contentful chunk should also carry the role.
delta_payload["role"] = "assistant"
first_chunk_sent = True
chunk_data = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": [{
"index": 0,
"delta": delta_payload,
"finish_reason": None
}]
}
chunk_json = json.dumps(chunk_data)
# logger.debug(f"[{completion_id}] Yielding delta chunk: data: {chunk_json}")
yield f"data: {chunk_json}\n\n"
# After the loop, if no content was ever sent but the stream finished,
# ensure at least one chunk (possibly empty with role) is sent before finish_reason.
# However, OpenAI spec usually sends finish_reason in a new chunk.
if not first_chunk_sent: # If stream ended without any content
logger.warning(f"[{completion_id}] Stream ended without any content. Sending empty assistant chunk before finish.")
empty_assistant_chunk = {
"id": completion_id, "object": "chat.completion.chunk", "created": created, "model": model,
"choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": None}]
}
yield f"data: {json.dumps(empty_assistant_chunk)}\n\n"
# Send the final chunk with finish_reason
final_chunk_data = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": [{
"index": 0,
"delta": {}, # Empty delta
"finish_reason": "stop"
}]
}
final_chunk_json = json.dumps(final_chunk_data)
logger.info(f"[{completion_id}] Yielding final chunk. Total content length: {len(accumulated_content_for_logging)} chars.")
yield f"data: {final_chunk_json}\n\n"
logger.info(f"[{completion_id}] Yielding [DONE] signal.")
yield "data: [DONE]\n\n"
except Exception as e:
logger.error(f"[{completion_id}] Error during streaming response generation: {e}", exc_info=True)
error_payload = {"content": f"\n\nError in stream: {str(e)}"}
if not first_chunk_sent:
error_payload["role"] = "assistant"
error_chunk_data = {
"id": completion_id, "object": "chat.completion.chunk", "created": created, "model": model,
"choices": [{"index": 0, "delta": error_payload, "finish_reason": "error" }]
}
yield f"data: {json.dumps(error_chunk_data)}\n\n"
yield "data: [DONE]\n\n" # Always end with [DONE]


class StreamProcessor:
    """Processes streaming content from Selenium elements, with stabilization detection."""
    def __init__(self, config: Optional[StreamConfig] = None):
        # Use the passed config, or default StreamConfig values, for processor-specific settings.
        cfg = config or StreamConfig()
        self.poll_interval = cfg.poll_interval
        self.response_timeout = cfg.response_timeout  # Used as the max wait by poll_element_text_stream
        self.stabilization_timeout = cfg.stabilization_timeout
        self.max_inactivity = cfg.max_inactivity
        self.request_id_for_log = "StreamProc"  # Default; update via set_request_id when possible

    def set_request_id(self, request_id: str):
        self.request_id_for_log = request_id

    def read_stream_with_stabilization(self, stream_iterator: Iterator[str]) -> Iterator[str]:
        """
        Wraps an iterator of text deltas, forwarding them while monitoring for inactivity.
        Terminates if max_inactivity is breached, or after the underlying stream is exhausted
        and an additional stabilization_timeout has passed.
        """
        last_successful_yield_time = time.time()
        log_prefix = f"[{self.request_id_for_log}/ReadStream]"
        try:
            for chunk_delta in stream_iterator:  # `stream_iterator` is `poll_element_text_stream`
                current_time = time.time()
                if current_time - last_successful_yield_time > self.max_inactivity:
                    logger.warning(f"{log_prefix} Max inactivity ({self.max_inactivity:.2f}s) breached. Ending stream.")
                    return
                if chunk_delta:
                    yield chunk_delta
                    last_successful_yield_time = current_time
                # Empty chunks need no special handling here; poll_element_text_stream filters them out.
            # The underlying stream_iterator is now exhausted.
            # Wait out the remainder of the stabilization_timeout period.
            time_since_last_yield = time.time() - last_successful_yield_time
            if time_since_last_yield < self.stabilization_timeout:
                wait_needed = self.stabilization_timeout - time_since_last_yield
                if wait_needed > 0.01:  # Only log/sleep if the wait is meaningful
                    logger.debug(f"{log_prefix} Underlying stream exhausted. Waiting for stabilization: {wait_needed:.2f}s.")
                    time.sleep(wait_needed)
            logger.debug(f"{log_prefix} Stabilization complete or underlying stream ended.")
        except Exception as e:
            logger.error(f"{log_prefix} Error: {e}", exc_info=True)
            return

    def poll_element_text_stream(self, driver, element_locator: Tuple[str, str]) -> Iterator[str]:
        """
        Polls a web element's text content and yields incremental changes (deltas).
        Stops if the element disappears, if the text stops changing for max_inactivity,
        or once self.response_timeout is reached.
        """
        log_prefix = f"[{self.request_id_for_log}/PollStream]"
        start_time_for_stream = time.time()
        last_text = ""
        element_previously_found = False
        last_change_time = time.time()
        logger.debug(f"{log_prefix} Starting to poll {element_locator} for text. Max wait: {self.response_timeout}s, poll interval: {self.poll_interval}s.")
        while time.time() - start_time_for_stream < self.response_timeout:
            loop_start_time = time.time()
            try:
                # Re-acquire the element on every iteration; the short WebDriverWait keeps each
                # poll cheap while tolerating brief DOM churn.
                element = WebDriverWait(driver, self.poll_interval * 2, poll_frequency=self.poll_interval / 2).until(
                    EC.presence_of_element_located(element_locator)
                )
                element_previously_found = True
                current_text = element.text
                if current_text != last_text:
                    # Note: this delta computation assumes the element's text only grows
                    # (append-only); non-append changes produce an empty or incorrect delta.
                    new_text_delta = current_text[len(last_text):]
                    if new_text_delta:
                        yield new_text_delta
                        last_change_time = time.time()  # Time of the last actual text change
                    last_text = current_text
                else:
                    # Text has not changed. The outer wrapper (read_stream_with_stabilization)
                    # enforces its own max_inactivity, but this poller also stops once the text
                    # has been stable for longer than max_inactivity, in case response_timeout
                    # is very long.
                    if time.time() - last_change_time > self.max_inactivity:
                        logger.info(f"{log_prefix} Text has not changed for {self.max_inactivity:.2f}s. Assuming stable.")
                        return
            except TimeoutException:  # Raised by WebDriverWait when the element is absent within its short timeout
                if element_previously_found:
                    logger.info(f"{log_prefix} Element {element_locator} became non-present after being found. Assuming stream ended.")
                    return
                # The element has not appeared yet; keep polling up to self.response_timeout.
                logger.debug(f"{log_prefix} Element {element_locator} not present yet. Continuing poll.")
            except StaleElementReferenceException:
                logger.warning(f"{log_prefix} StaleElementReferenceException for {element_locator}. Resetting and retrying find.")
                last_text = ""
                element_previously_found = False
            except NoSuchElementException:  # Mostly subsumed by WebDriverWait's TimeoutException
                logger.warning(f"{log_prefix} NoSuchElementException for {element_locator} (rare when using WebDriverWait).")
                if element_previously_found:
                    return
            except Exception as e:
                logger.error(f"{log_prefix} Unexpected error polling {element_locator}: {e}", exc_info=True)
                return
            # Sleep only for the remainder of the poll interval.
            elapsed_in_loop = time.time() - loop_start_time
            sleep_duration = self.poll_interval - elapsed_in_loop
            if sleep_duration > 0:
                time.sleep(sleep_duration)
        logger.info(f"{log_prefix} Max wait ({self.response_timeout:.2f}s) reached while polling {element_locator}; ending stream.")
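

# Illustrative sketch (not used by the module): read_stream_with_stabilization accepts any
# plain iterator of deltas, which makes the inactivity/stabilization logic easy to exercise
# without a browser. The config values here are arbitrary examples.
def _example_stabilized_iteration() -> str:
    processor = StreamProcessor(StreamConfig(stabilization_timeout=0.1, max_inactivity=2.0))
    deltas = iter(["chunk-1 ", "chunk-2 ", "chunk-3"])
    return "".join(processor.read_stream_with_stabilization(deltas))


# Illustrative sketch (not used by the module): collecting one complete response by chaining
# poll_element_text_stream through read_stream_with_stabilization. The CSS selector is a
# hypothetical placeholder; substitute the real response container of the target site.
def _example_collect_response(driver) -> str:
    processor = StreamProcessor()
    processor.set_request_id("example-req")
    locator = (By.CSS_SELECTOR, "div.assistant-response")  # hypothetical locator
    raw_deltas = processor.poll_element_text_stream(driver, locator)
    return "".join(processor.read_stream_with_stabilization(raw_deltas))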


# Convenience wrapper for creating the generator instance.
async def create_streaming_response(
    completion_id: str,
    created: int,
    model: str,
    prompt: str,
    send_message_func: Callable[[str, str], AsyncGenerator[str, None]],
    stream_config: Optional[StreamConfig] = None,  # Allow passing a specific config
) -> AsyncGenerator[str, None]:
    generator = StreamingResponseGenerator(config=stream_config)
    async for chunk in generator.create_response(completion_id, created, model, prompt, send_message_func):
        yield chunk

__all__ = ['create_streaming_response', 'StreamingResponseGenerator', 'StreamConfig', 'StreamProcessor']
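

# Illustrative sketch (assumption: the surrounding project exposes this over FastAPI, which
# this module does not import). `app`, `ChatRequest`, `build_prompt`, and `chat_handler` are
# hypothetical names standing in for the rest of the project; left commented out so the
# module stays import-safe.
#
# from fastapi.responses import StreamingResponse
#
# @app.post("/v1/chat/completions")
# async def chat_completions(request: ChatRequest):
#     return StreamingResponse(
#         create_streaming_response(
#             completion_id=f"chatcmpl-{int(time.time() * 1000)}",
#             created=int(time.time()),
#             model=request.model,
#             prompt=build_prompt(request.messages),
#             send_message_func=chat_handler.send_message_and_stream_response,
#         ),
#         media_type="text/event-stream",
#     )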