# NOTE: removed stray "Spaces:"/"Running" UI-capture artifacts — they were not valid Python.
# streaming.py - FINAL CORRECTED version with robust, state-aware processing. | |
import json | |
import re | |
import logging | |
from typing import AsyncGenerator, Callable, Optional, Iterator, Tuple | |
from dataclasses import dataclass | |
import sys, re, html | |
import time | |
from selenium.common.exceptions import NoSuchElementException, StaleElementReferenceException, TimeoutException | |
from selenium.webdriver.common.by import By | |
from selenium.webdriver.remote.webelement import WebElement | |
from selenium.webdriver.support.ui import WebDriverWait | |
from selenium.webdriver.support import expected_conditions as EC | |
from bs4 import BeautifulSoup, Tag | |
logging.basicConfig(level=logging.INFO, stream=sys.stdout, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') | |
logger = logging.getLogger(__name__) | |
# --- Helper Classes (largely unchanged, but included for completeness) --- | |
class HtmlToMarkdownConverter:
    """
    Converts a BeautifulSoup Tag element into a Markdown string.

    Designed to be called on finalized, complete HTML elements: block-level
    tags (<pre>, <p>, <ul>/<ol>, <h1>-<h6>) are handled explicitly; any other
    element falls back to its plain text content.
    """

    # Fallback pattern from the original implementation: only matches when the
    # class attribute literally *starts* with "language-".
    _LANG_CLASS_RE = re.compile(r'class="language-(.*?)"', re.IGNORECASE)

    def convert_element(self, el: 'Tag') -> str:
        """Convert one finalized block-level element to Markdown ('' for non-Tags)."""
        if not isinstance(el, Tag):
            return ""
        name = el.name
        if name == 'pre':
            lang = self._detect_code_language(el)
            # .get_text() from BS4 correctly preserves newlines from the code structure
            content = el.get_text().strip()
            return f'```{lang}\n{content}\n```'
        if name == 'p':
            # Convert child tags like <strong> or <code> within the paragraph
            return ''.join(self.convert_inline(child) for child in el.contents)
        if name in ('ul', 'ol'):
            # The site's HTML for lists already contains the bullet/number in the text.
            items = [li.get_text().strip() for li in el.find_all('li', recursive=False)]
            return '\n'.join(items)
        if name in ('h1', 'h2', 'h3', 'h4', 'h5', 'h6'):
            level = int(name[1])
            return '#' * level + ' ' + el.get_text().strip()
        return el.get_text()

    def _detect_code_language(self, pre_el: 'Tag') -> str:
        """Extract the 'language-*' hint for a <pre> block.

        Prefers a structured lookup over the element's (and its descendants')
        ``class`` attribute so multi-class attributes like
        ``class="hljs language-python"`` are detected; falls back to the
        original serialized-HTML regex for backward compatibility.
        """
        for tag in [pre_el] + pre_el.find_all(True):
            for cls in tag.get('class') or []:
                if cls.lower().startswith('language-'):
                    return cls[len('language-'):].strip()
        match = self._LANG_CLASS_RE.search(str(pre_el))
        return match.group(1).strip() if match else ''

    def convert_inline(self, el) -> str:
        """Recursively converts inline elements (str / <strong> / <em> / <code> / <a>) to markdown."""
        if isinstance(el, str):
            return html.unescape(el)
        if not isinstance(el, Tag):
            return ""
        content = ''.join(self.convert_inline(child) for child in el.contents)
        if el.name in ('strong', 'b'):
            return f'**{content}**'
        if el.name in ('em', 'i'):
            return f'*{content}*'
        if el.name == 'code':
            return f'`{content}`'
        if el.name == 'a':
            href = el.get('href', '')
            return f'[{content}]({href})'
        return content
@dataclass
class StreamConfig:
    """Configuration for streaming behavior.

    Now a proper dataclass: the original annotated the fields but never
    applied the decorator (imported at the top of the file), so per-instance
    keyword overrides such as ``StreamConfig(poll_interval=0.1)`` failed.
    Defaults are unchanged and class-attribute access still works.
    """
    timeout: float = 300.0             # overall request timeout (seconds)
    retry_on_error: bool = True
    max_retries: int = 3
    poll_interval: float = 0.05        # Selenium polling cadence (seconds)
    response_timeout: float = 900      # hard cap on a single response (seconds)
    stabilization_timeout: float = 1.0
    max_inactivity: float = 10.0       # end polling after this long with no DOM change
    stream_raw_html: bool = False
    convert_html_to_markdown: bool = True
class StreamingResponseGenerator:
    """Generates Server-Sent Events (SSE) for streaming chat completions."""

    def __init__(self, config: Optional[StreamConfig] = None):
        self.config = config or StreamConfig()

    @staticmethod
    def _chunk(completion_id, created, model, delta, finish_reason):
        """Build one OpenAI-style chat.completion.chunk payload."""
        return {
            "id": completion_id, "object": "chat.completion.chunk", "created": created, "model": model,
            "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}]
        }

    @staticmethod
    def _sse(payload):
        """Serialize a chunk payload as a single SSE frame."""
        return f"data: {json.dumps(payload)}\n\n"

    async def create_response(
        self,
        completion_id: str,
        created: int,
        model: str,
        prompt: str,
        send_message_func: Callable[..., AsyncGenerator[str, None]]
    ) -> AsyncGenerator[str, None]:
        """Drive *send_message_func* and yield OpenAI-compatible SSE frames.

        The first non-empty delta carries role="assistant"; an empty stream
        still emits an empty assistant chunk before the final stop chunk and
        the [DONE] sentinel. Errors are reported in-band as a final chunk
        with finish_reason "error".
        """
        logger.info(f"[{completion_id}] Starting streaming response generation for model '{model}'.")
        role_sent = False
        total_content = ""
        try:
            async for piece in send_message_func(prompt, model):
                if not piece:
                    continue
                total_content += piece
                delta = {"content": piece}
                if not role_sent:
                    # Only the first chunk declares the assistant role.
                    delta["role"] = "assistant"
                    role_sent = True
                yield self._sse(self._chunk(completion_id, created, model, delta, None))
            if not role_sent:
                logger.warning(f"[{completion_id}] Stream ended without content. Sending empty assistant chunk.")
                empty_delta = {"role": "assistant", "content": ""}
                yield self._sse(self._chunk(completion_id, created, model, empty_delta, None))
            logger.info(f"[{completion_id}] Yielding final chunk. Total content length: {len(total_content)} chars.")
            yield self._sse(self._chunk(completion_id, created, model, {}, "stop"))
            logger.info(f"[{completion_id}] Yielding [DONE] signal.")
            yield "data: [DONE]\n\n"
        except Exception as e:
            logger.error(f"[{completion_id}] Error during streaming response generation: {e}", exc_info=True)
            err_delta = {"content": f"\n\nError in stream: {str(e)}"}
            if not role_sent:
                err_delta["role"] = "assistant"
            yield self._sse(self._chunk(completion_id, created, model, err_delta, "error"))
            yield "data: [DONE]\n\n"
# --- Main Processor with Corrected Logic --- | |
class StreamProcessor:
    """Processes streaming content from Selenium, with stabilization and content conversion.

    Pipeline:
      1. ``_poll_element_content_stream`` snapshots the target element's
         innerHTML on a fixed cadence until it stops changing.
      2. ``_convert_html_stream_to_markdown_deltas`` statefully converts each
         snapshot to Markdown and yields only newly appended text, holding
         back a trailing <pre> block until the stream finishes so partially
         streamed code is never emitted.
    """
    def __init__(self, config: Optional[StreamConfig] = None):
        # Copy tunables off the config so the processor is self-contained.
        cfg = config or StreamConfig()
        self.poll_interval = cfg.poll_interval
        self.response_timeout = cfg.response_timeout
        self.stabilization_timeout = cfg.stabilization_timeout  # NOTE(review): stored but never read in this class
        self.max_inactivity = cfg.max_inactivity
        self.request_id_for_log = "StreamProc"
        self.markdown_converter = HtmlToMarkdownConverter()
    def set_request_id(self, request_id: str):
        """Tag subsequent log lines with the caller's request id."""
        self.request_id_for_log = request_id
    def get_processed_text_stream(self, driver, element_locator: Tuple[str, str]) -> Iterator[str]:
        """Return an iterator of Markdown text deltas for the element at *element_locator*."""
        html_stream = self._poll_element_content_stream(driver, element_locator)
        return self._convert_html_stream_to_markdown_deltas(html_stream)
    def _convert_html_stream_to_markdown_deltas(self, html_iterator: Iterator[str]) -> Iterator[str]:
        """
        Statefully converts a stream of HTML snapshots to Markdown deltas.
        It identifies "finalized" vs "active" blocks to avoid streaming
        volatile, incomplete code blocks, preventing corruption.
        """
        last_yielded_markdown = ""
        last_html_snapshot = ""
        for full_html_snapshot in html_iterator:
            last_html_snapshot = full_html_snapshot
            # Re-wrap the fragment in <ol> so lxml parses it under the same
            # root element the site uses for its message list.
            soup = BeautifulSoup(f"<ol>{full_html_snapshot}</ol>", 'lxml')
            # Find all content blocks in the last message bubble.
            # The structure is <ol> -> <div> (bubble) -> ... -> <div class="prose"> -> <p/pre/ul>
            all_prose_divs = soup.select('div.prose')
            if not all_prose_divs:
                continue
            # Only the FIRST child of each prose div is taken — presumably each
            # div wraps exactly one block element; verify against the live markup.
            content_elements = [div.contents[0] for div in all_prose_divs if div.contents]
            finalized_elements = content_elements
            # Check if the very last element is an incomplete code block.
            # If so, we don't process it in this pass, we wait for it to be finalized.
            # NOTE(review): assumes the child is a Tag; a bare NavigableString
            # first child may not expose .name the same way — confirm.
            if content_elements and content_elements[-1].name == 'pre':
                finalized_elements = content_elements[:-1]
            # Convert all finalized elements to markdown
            md_pieces = [self.markdown_converter.convert_element(el) for el in finalized_elements]
            current_safe_markdown = "\n\n".join(md_pieces)
            # Yield the delta if it's a simple append
            if current_safe_markdown != last_yielded_markdown:
                if current_safe_markdown.startswith(last_yielded_markdown):
                    delta = current_safe_markdown[len(last_yielded_markdown):]
                    if delta:
                        # NOTE(review): lstrip('\n') strips the "\n\n" block
                        # separator from appended deltas, so consumers that
                        # concatenate deltas lose paragraph breaks between
                        # blocks — confirm this is intentional.
                        yield delta.lstrip('\n')
                # If the snapshot is not a pure append (content was rewritten
                # in place), we resync silently without yielding anything.
                last_yielded_markdown = current_safe_markdown
        # After the loop, the stream has finished. Process the very last snapshot in full.
        # This is where a held-back trailing <pre> block finally gets emitted.
        final_soup = BeautifulSoup(f"<ol>{last_html_snapshot}</ol>", 'lxml')
        all_prose_divs = final_soup.select('div.prose')
        content_elements = [div.contents[0] for div in all_prose_divs if div.contents]
        final_md_pieces = [self.markdown_converter.convert_element(el) for el in content_elements]
        final_markdown = "\n\n".join(final_md_pieces)
        if final_markdown.startswith(last_yielded_markdown):
            final_delta = final_markdown[len(last_yielded_markdown):]
            if final_delta:
                yield final_delta.lstrip('\n')
    def _poll_element_content_stream(self, driver, element_locator: Tuple[str, str]) -> Iterator[str]:
        """Yield the located element's innerHTML each time it changes.

        Terminates when content has been stable for ``max_inactivity`` seconds,
        on an unexpected polling error, or after ``response_timeout`` seconds
        total. Transient Selenium exceptions are treated as "keep polling".
        """
        log_prefix = f"[{self.request_id_for_log}/PollStream]"
        start_time = time.time()
        last_content = ""
        last_change_time = time.time()
        while time.time() - start_time < self.response_timeout:
            loop_start_time = time.time()
            try:
                element = WebDriverWait(driver, self.poll_interval * 2).until(EC.presence_of_element_located(element_locator))
                current_content = element.get_attribute('innerHTML')
                if current_content != last_content:
                    yield current_content
                    last_content = current_content
                    last_change_time = time.time()
                elif time.time() - last_change_time > self.max_inactivity:
                    logger.info(f"{log_prefix} Content stable for {self.max_inactivity:.2f}s. Ending poll.")
                    return
            except (TimeoutException, StaleElementReferenceException, NoSuchElementException):
                # Expected while the page is still rendering/streaming; keep polling.
                logger.debug(f"{log_prefix} Element not present or stale. Continuing poll.")
            except Exception as e:
                logger.error(f"{log_prefix} Unexpected error polling: {e}", exc_info=True)
                return
            # Sleep only for the remainder of the interval so the cadence stays
            # close to poll_interval regardless of how long the DOM check took.
            sleep_duration = self.poll_interval - (time.time() - loop_start_time)
            if sleep_duration > 0: time.sleep(sleep_duration)
        logger.warning(f"{log_prefix} Polling finished due to max wait ({self.response_timeout:.2f}s).")
async def create_streaming_response(
    completion_id: str,
    created: int,
    model: str,
    prompt: str,
    send_message_func: Callable[..., AsyncGenerator[str, None]],
    stream_config: Optional[StreamConfig] = None
) -> AsyncGenerator[str, None]:
    """Module-level convenience wrapper around StreamingResponseGenerator.

    Builds a generator with *stream_config* and re-yields every SSE frame
    produced by its create_response coroutine.
    """
    sse_stream = StreamingResponseGenerator(config=stream_config).create_response(
        completion_id, created, model, prompt, send_message_func
    )
    async for sse_chunk in sse_stream:
        yield sse_chunk
# Explicit public API for `from streaming import *` consumers.
__all__ = ['create_streaming_response', 'StreamingResponseGenerator', 'StreamConfig', 'StreamProcessor', 'HtmlToMarkdownConverter']