# NOTE(review): removed non-code page residue ("Spaces:" / "Running" / "Running")
# that was copied in from a Hugging Face Spaces web page — it is not Python source.
# streaming.py - Fixed version to send incremental delta chunks
import json
import re
import logging
from typing import AsyncGenerator, Callable, Optional, Iterator, Tuple
from dataclasses import dataclass
import sys
import time
from selenium.common.exceptions import NoSuchElementException, StaleElementReferenceException, TimeoutException
from selenium.webdriver.common.by import By
from selenium.webdriver.remote.webelement import WebElement
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC

# Module-wide logging: route everything to stdout so a container/orchestrator
# can capture it alongside other process output.
logging.basicConfig(level=logging.INFO, stream=sys.stdout, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
@dataclass
class StreamConfig:
    """Configuration for streaming behavior.

    Fix: `dataclass` was imported but the decorator was missing, so the
    annotated fields were plain class attributes and per-instance overrides
    (``StreamConfig(timeout=5.0)``) raised ``TypeError``. Decorating keeps
    ``StreamConfig()`` behaving exactly as before while also generating a
    keyword constructor for every field.

    Attributes:
        timeout: Overall timeout (seconds) for the stream generation.
        retry_on_error: Whether the generator should retry after an error.
        max_retries: Maximum number of retries on error.
        poll_interval: StreamProcessor poll period (seconds).
        response_timeout: Max time for Selenium to wait for response content.
        stabilization_timeout: Extra wait after the Selenium stream ends.
        max_inactivity: Max time between chunks from the Selenium stream.
    """
    timeout: float = 300.0  # Overall timeout for the stream generation
    retry_on_error: bool = True
    max_retries: int = 3
    # Configs below are more for the StreamProcessor
    poll_interval: float = 0.05
    response_timeout: float = 900  # Max time for Selenium to wait for response content
    stabilization_timeout: float = 1.0  # Extra wait after Selenium stream ends
    max_inactivity: float = 10.0  # Max time between chunks from Selenium stream
class StreamingResponseGenerator:
    """
    Generates Server-Sent Events (SSE) for streaming chat completions.

    Consumes an async generator of text deltas and re-emits them as
    OpenAI-style ``chat.completion.chunk`` SSE frames, always terminated
    by a ``data: [DONE]`` sentinel — even on error.
    """

    def __init__(self, config: Optional[StreamConfig] = None):
        self.config = config or StreamConfig()

    @staticmethod
    def _sse_frame(chunk: dict) -> str:
        # Serialize one chunk as a single SSE data frame (blank-line terminated).
        return f"data: {json.dumps(chunk)}\n\n"

    @staticmethod
    def _build_chunk(completion_id: str, created: int, model: str,
                     delta: dict, finish_reason: Optional[str]) -> dict:
        # Common skeleton shared by every chunk emitted during the stream.
        return {
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model,
            "choices": [{
                "index": 0,
                "delta": delta,
                "finish_reason": finish_reason
            }]
        }

    async def create_response(
        self,
        completion_id: str,
        created: int,
        model: str,
        prompt: str,  # This is the full_prompt for the send_message_func
        send_message_func: Callable[[str, str], AsyncGenerator[str, None]]  # Called with (prompt, model)
    ) -> AsyncGenerator[str, None]:
        """
        Creates an SSE stream from the text deltas generated by send_message_func.

        Yields:
            SSE frame strings ("data: {...}\\n\\n"), ending with "data: [DONE]\\n\\n".
        """
        logger.info(f"[{completion_id}] Starting streaming response generation for model '{model}'.")
        role_emitted = False
        total_text_for_logging = ""
        try:
            # `send_message_func` is `ChatHandler.send_message_and_stream_response`;
            # it takes (prompt, model_id) and yields text deltas.
            async for piece in send_message_func(prompt, model):
                if not piece:  # Skip empty chunks from the source.
                    continue
                total_text_for_logging += piece
                delta = {"content": piece}
                if not role_emitted:
                    # The first contentful chunk must also carry the assistant role.
                    delta["role"] = "assistant"
                    role_emitted = True
                yield self._sse_frame(self._build_chunk(completion_id, created, model, delta, None))
            # If the source finished without ever producing content, still emit
            # one (empty) assistant chunk so clients see a role before the
            # finish_reason chunk — matching OpenAI's chunk ordering.
            if not role_emitted:
                logger.warning(f"[{completion_id}] Stream ended without any content. Sending empty assistant chunk before finish.")
                yield self._sse_frame(self._build_chunk(
                    completion_id, created, model,
                    {"role": "assistant", "content": ""}, None))
            # Terminal chunk: empty delta, finish_reason "stop".
            logger.info(f"[{completion_id}] Yielding final chunk. Total content length: {len(total_text_for_logging)} chars.")
            yield self._sse_frame(self._build_chunk(completion_id, created, model, {}, "stop"))
            logger.info(f"[{completion_id}] Yielding [DONE] signal.")
            yield "data: [DONE]\n\n"
        except Exception as e:
            logger.error(f"[{completion_id}] Error during streaming response generation: {e}", exc_info=True)
            # Surface the failure to the client as a final content delta.
            delta = {"content": f"\n\nError in stream: {str(e)}"}
            if not role_emitted:
                delta["role"] = "assistant"
            yield self._sse_frame(self._build_chunk(completion_id, created, model, delta, "error"))
            yield "data: [DONE]\n\n"  # Always end with [DONE]
class StreamProcessor:
    """Processes streaming content with stabilization detection from Selenium elements.

    Two cooperating layers:
      * ``poll_element_text_stream`` polls a DOM element's text and yields the
        incremental additions (deltas).
      * ``read_stream_with_stabilization`` wraps any delta iterator, enforcing a
        max-inactivity cutoff and a trailing stabilization wait.
    """

    def __init__(self, config: Optional[StreamConfig] = None):
        # Use a passed config or default StreamConfig values for processor-specific settings
        default_cfg = StreamConfig()
        self.poll_interval = config.poll_interval if config else default_cfg.poll_interval
        self.response_timeout = config.response_timeout if config else default_cfg.response_timeout  # For poll_element_text_stream's max_wait
        self.stabilization_timeout = config.stabilization_timeout if config else default_cfg.stabilization_timeout
        self.max_inactivity = config.max_inactivity if config else default_cfg.max_inactivity
        self.request_id_for_log = "StreamProc"  # Default, should be updated if possible

    def set_request_id(self, request_id: str):
        """Tag subsequent log lines from this processor with the caller's request id."""
        self.request_id_for_log = request_id

    def read_stream_with_stabilization(self, stream_iterator: Iterator[str]) -> Iterator[str]:
        """
        Wraps an iterator of text deltas, forwarding them while monitoring for inactivity.
        Terminates if max_inactivity is breached or after the underlying stream is exhausted
        and an additional stabilization_timeout has passed.

        Args:
            stream_iterator: Source of text deltas (typically poll_element_text_stream).

        Yields:
            The non-empty deltas from stream_iterator, unmodified.
        """
        last_successful_yield_time = time.time()
        log_prefix = f"[{self.request_id_for_log}/ReadStream]"
        try:
            for chunk_delta in stream_iterator:  # `stream_iterator` is `poll_element_text_stream`
                current_time = time.time()
                # NOTE(review): a chunk that arrives *after* the inactivity window
                # has already elapsed is discarded, not forwarded — confirm this
                # drop-on-breach behavior is intended.
                if current_time - last_successful_yield_time > self.max_inactivity:
                    logger.warning(f"{log_prefix} Max inactivity ({self.max_inactivity:.2f}s) breached. Ending stream.")
                    return
                if chunk_delta:
                    yield chunk_delta
                    last_successful_yield_time = current_time
                # No specific handling for empty chunk here as poll_element_text_stream filters them.
            # Underlying stream_iterator is now exhausted.
            # Wait for an additional stabilization_timeout period.
            time_since_last_yield = time.time() - last_successful_yield_time
            if time_since_last_yield < self.stabilization_timeout:
                wait_needed = self.stabilization_timeout - time_since_last_yield
                if wait_needed > 0.01:  # Only log/sleep if meaningful wait
                    logger.debug(f"{log_prefix} Underlying stream exhausted. Waiting for stabilization: {wait_needed:.2f}s.")
                    time.sleep(wait_needed)
            logger.debug(f"{log_prefix} Stabilization complete or underlying stream ended.")
        except Exception as e:
            logger.error(f"{log_prefix} Error: {e}", exc_info=True)
            return

    def poll_element_text_stream(self, driver, element_locator: Tuple[str, str]) -> Iterator[str]:
        """
        Polls a web element's text content and yields incremental changes (deltas).
        Stops if the element disappears, or if `self.response_timeout` is reached and text has not changed.

        Args:
            driver: Selenium WebDriver instance.
            element_locator: (By strategy, selector) tuple identifying the element.

        Yields:
            Text appended to the element since the previous poll. Assumes the
            element's text is append-only; the delta is the suffix past the
            previously seen prefix.
        """
        log_prefix = f"[{self.request_id_for_log}/PollStream]"
        start_time_for_stream = time.time()
        last_text = ""
        element_previously_found = False
        last_change_time = time.time()
        logger.debug(f"{log_prefix} Starting to poll {element_locator} for text. Max wait: {self.response_timeout}s, Poll interval: {self.poll_interval}s.")
        while time.time() - start_time_for_stream < self.response_timeout:
            loop_start_time = time.time()
            try:
                # Re-evaluate WebDriverWait for presence here; the short wait
                # (2x poll_interval) doubles as this iteration's element check.
                element = WebDriverWait(driver, self.poll_interval * 2, poll_frequency=self.poll_interval/2).until(
                    EC.presence_of_element_located(element_locator)
                )
                element_previously_found = True
                current_text = element.text
                if current_text != last_text:
                    new_text_delta = current_text[len(last_text):]
                    if new_text_delta:
                        yield new_text_delta
                        last_change_time = time.time()  # Update time of last actual text change
                    last_text = current_text
                else:
                    # Text has not changed. If no change for `max_inactivity` (handled by read_stream_with_stabilization)
                    # or if `response_timeout` (overall) is hit, it will stop.
                    # This specific condition checks if text has stabilized for a duration longer than max_inactivity
                    # If this poller's `response_timeout` is very long, it relies on the outer wrapper's max_inactivity.
                    if time.time() - last_change_time > self.max_inactivity:  # Check inactivity from *this* poller's perspective
                        logger.info(f"{log_prefix} Text has not changed for {self.max_inactivity:.2f}s. Assuming stable.")
                        return
            except TimeoutException:  # From WebDriverWait if element not present within its short timeout
                if element_previously_found:
                    logger.info(f"{log_prefix} Element {element_locator} became non-present after being found. Assuming stream ended.")
                    return
                # else: Element not yet appeared, continue polling up to self.response_timeout
                logger.debug(f"{log_prefix} Element {element_locator} not present yet. Continuing poll.")
            except StaleElementReferenceException:
                # NOTE(review): resetting last_text to "" means text already yielded
                # will be re-yielded when the replacement element is found — confirm
                # downstream consumers tolerate this potential duplication.
                logger.warning(f"{log_prefix} StaleElementReferenceException for {element_locator}. Resetting and retrying find.")
                last_text = ""
                element_previously_found = False
            except NoSuchElementException:  # Should be caught by WebDriverWait's TimeoutException mostly
                logger.warning(f"{log_prefix} NoSuchElementException for {element_locator} (should be rare with WebDriverWait).")
                if element_previously_found: return
            except Exception as e:
                logger.error(f"{log_prefix} Unexpected error polling {element_locator}: {e}", exc_info=True)
                return
            # Sleep only for the remainder of this poll interval, so the loop
            # period stays close to poll_interval regardless of work done above.
            elapsed_in_loop = time.time() - loop_start_time
            sleep_duration = self.poll_interval - elapsed_in_loop
            if sleep_duration > 0:
                time.sleep(sleep_duration)
        logger.info(f"{log_prefix} Max_wait ({self.response_timeout:.2f}s) reached for polling {element_locator} or stream naturally ended.")
async def create_streaming_response(
    completion_id: str,
    created: int,
    model: str,
    prompt: str,
    send_message_func: Callable[[str, str], AsyncGenerator[str, None]],
    stream_config: Optional[StreamConfig] = None  # Allow passing specific config
) -> AsyncGenerator[str, None]:
    """Convenience wrapper: build a StreamingResponseGenerator and delegate to it."""
    sse_stream = StreamingResponseGenerator(config=stream_config).create_response(
        completion_id, created, model, prompt, send_message_func
    )
    async for sse_chunk in sse_stream:
        yield sse_chunk
__all__ = ['create_streaming_response', 'StreamingResponseGenerator', 'StreamConfig', 'StreamProcessor'] |