# lmarena.ai scraping proxy: exposes an OpenAI-compatible HTTP API backed by a
# Selenium-driven browser session (with optional Gemini-assisted captcha solving).
import asyncio | |
import time | |
import uuid | |
from typing import List, Optional, Dict, Union, Callable, Iterator, AsyncGenerator | |
from contextlib import asynccontextmanager | |
from dataclasses import dataclass | |
import logging | |
import os | |
import io | |
from google import genai | |
from google.genai import types | |
import computer_control_helper | |
from fastapi import FastAPI, HTTPException, Request | |
from fastapi.responses import StreamingResponse | |
from pydantic import BaseModel, Field | |
from seleniumbase import Driver | |
from selenium.webdriver.common.by import By | |
from selenium.webdriver.common.keys import Keys | |
from selenium.webdriver.support.ui import WebDriverWait | |
from selenium.webdriver.support import expected_conditions as EC | |
from selenium.common.exceptions import TimeoutException, NoSuchElementException, StaleElementReferenceException | |
import random | |
import pyautogui | |
# Import the fully updated streaming module | |
from streaming import StreamProcessor, create_streaming_response, StreamConfig, StreamingResponseGenerator | |
import base64 | |
import mss | |
import mss.tools | |
from PIL import Image | |
# Virtual display setup for Linux headless environments
import platform
if platform.system() == 'Linux':
    # Imported only on Linux, where pyvirtualdisplay is presumably installed.
    from pyvirtualdisplay import Display
    # Off-screen display (visible=0) so the real, non-headless Chrome window and
    # pyautogui keystrokes have a screen to render on.
    display = Display(visible=0, size=(1920, 1080))
    display.start()
    logger = logging.getLogger(__name__)
    # NOTE(review): logging is not configured until basicConfig runs below, so this
    # INFO message is likely dropped by logging's last-resort handler (WARNING+).
    logger.info("Started virtual display for Linux environment")
# Process-wide logging: console plus a persistent file next to the script.
# force=True removes any handlers installed earlier (e.g. by imported libraries)
# so this configuration always wins.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler("lmarena.log")
    ],
    force=True
)
# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)
@dataclass
class Config:
    """Tuning knobs for the scraper and API, grouped in one place.

    Values are read through the module-level ``config`` instance; defaults are
    also reachable as class attributes.
    """
    lmarena_url: str = "https://beta.lmarena.ai/?mode=direct"
    driver_timeout: int = 2                                  # seconds for element waits
    response_timeout: int = 900                              # max seconds to wait for a reply
    poll_interval: float = 0.05                              # seconds between DOM polls
    new_chat_click_max_attempts: int = 3
    new_chat_click_retry_delay_seconds: float = 0.1
    new_chat_click_success_pause_seconds: float = 0.5
    page_load_wait_after_refresh_seconds: float = 5
    stabilization_timeout: float = 1.0
    max_inactivity: float = 10.0
# Single shared configuration instance used across the module.
config = Config()
logger.info("Configuration loaded.")
class LmArenaError(Exception):
    """Base class for scraper-level failures (driver, page state, model selection)."""
    pass

class APIError(Exception):
    """HTTP-facing error carrying the status code to return to the client."""
    def __init__(self, message: str, status_code: int):
        # Bug fix: previously Exception.__init__ was never called, so str(e) and
        # repr(e) were empty, hiding the message in logs and tracebacks.
        super().__init__(message)
        self.message = message
        self.status_code = status_code

class ChatInteractionError(APIError):
    """Browser-side chat flow failed; mapped to HTTP 502 (bad gateway)."""
    def __init__(self, message: str):
        super().__init__(message, status_code=502)

class ModelSelectionError(LmArenaError):
    """The model dropdown could not be found or interacted with."""
    pass

class DriverNotAvailableError(LmArenaError):
    """The Selenium driver has not been initialized (or was torn down)."""
    pass
class Message(BaseModel):
    # One chat turn. `content` is either plain text or a list of content-part
    # dicts (presumably OpenAI-style parts — confirm against clients).
    role: str
    content: Union[str, List[Dict[str, str]]]
class ChatCompletionRequest(BaseModel):
    # OpenAI-style request body, extended with scraper-specific toggles that are
    # forwarded to the streaming processor (StreamConfig).
    messages: List[Message]
    model: str
    stream: Optional[bool] = False
    stream_raw_html: Optional[bool] = False
    convert_html_to_markdown: Optional[bool] = True
class Usage(BaseModel):
    """OpenAI-compatible token accounting block."""
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
class Choice(BaseModel):
    # One entry of a completion's `choices` array. `message` carries the full
    # assistant turn for non-streaming responses; `delta` is presumably filled by
    # the streaming path (populated outside this file).
    index: int
    message: Optional[Dict[str, str]] = None
    delta: Optional[Dict[str, str]] = None
    finish_reason: Optional[str] = None
class ChatCompletionResponse(BaseModel):
    """OpenAI-compatible chat completion payload."""
    id: str
    object: str
    created: int
    model: str
    choices: List[Choice]
    usage: Optional[Usage] = None
class ModelInfo(BaseModel):
    # OpenAI-style model descriptor; `created` is stamped at construction time.
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "lmarena"
class ModelListResponse(BaseModel):
    # OpenAI-style list envelope for the models endpoint.
    object: str = "list"
    data: List[ModelInfo]
class DriverManager:
    """Owns the singleton Selenium driver plus the optional Gemini client used for captchas.

    Blocking Selenium/pyautogui work runs on worker threads via run_in_executor so
    the event loop stays responsive; ``_lock`` serializes initialize/cleanup.
    """

    def __init__(self):
        logger.info("DriverManager instance created.")
        self._driver: Optional[Driver] = None
        self._lock = asyncio.Lock()
        self._genai_client = None
        api_key = os.environ.get("GEMINI_API_KEY")
        if api_key:
            try:
                self._genai_client = genai.Client(api_key=api_key)
                logger.info("Gemini client initialized successfully.")
            except Exception as e:
                logger.error(f"Failed to initialize Gemini client: {e}", exc_info=True)
                self._genai_client = None
        else:
            logger.info("GEMINI_API_KEY not set, Gemini client will not be used for captcha.")

    async def initialize(self) -> None:
        """Create the undetected-Chrome driver, open lmarena, and clear any captchas.

        Raises:
            LmArenaError: if the driver cannot be created or the page never becomes ready.
        """
        async with self._lock:
            if self._driver is not None:
                logger.warning("Driver initialization called but driver already exists.")
                return
            loop = asyncio.get_running_loop()
            logger.info("Initializing Selenium driver...")

            def _sync_initialize_driver_logic():
                # Runs on a worker thread: every call below is blocking.
                logger.info("Executing synchronous driver initialization and enhanced readiness check.")
                temp_driver = None
                try:
                    # headless=False: uc_gui_click_captcha and pyautogui need a real window
                    # (rendered into the virtual display on Linux).
                    temp_driver = Driver(uc=True, headless=False)
                    logger.info("Driver instantiated. Opening URL...")
                    temp_driver.open(config.lmarena_url)
                    logger.info(f"URL '{config.lmarena_url}' opened.")
                    logger.info("Attempting to solve initial (Cloudflare-style) captcha with uc_gui_click_captcha()...")
                    temp_driver.uc_gui_click_captcha()
                    logger.info("uc_gui_click_captcha() completed. Main site should be loading now.")
                    self._perform_sync_captcha_checks(temp_driver)
                    return temp_driver
                except Exception as e:
                    logger.error(f"Synchronous driver initialization failed: {e}", exc_info=True)
                    if temp_driver:
                        temp_driver.quit()
                    raise LmArenaError(f"Failed to initialize driver: {e}") from e

            try:
                self._driver = await loop.run_in_executor(None, _sync_initialize_driver_logic)
                logger.info("Selenium driver initialization process completed successfully.")
            except Exception as e:
                logger.error(f"Asynchronous driver initialization failed: {e}", exc_info=True)
                if self._driver:
                    try:
                        driver_to_quit = self._driver
                        self._driver = None
                        await loop.run_in_executor(None, driver_to_quit.quit)
                    except Exception as quit_e:
                        logger.error(f"Failed to quit driver after initialization error: {quit_e}")
                if isinstance(e, LmArenaError):
                    raise
                raise LmArenaError(f"Failed to initialize driver: {e}") from e

    def _human_like_reload(self, driver: Driver):
        """Reload the page via OS-level keystrokes (F5 or Ctrl+F5) with a randomized pause.

        The `driver` argument is unused but kept for interface symmetry with callers.
        """
        logger.info("Performing human-like page reload")
        if random.random() > 0.5:
            logger.info("Using F5 key")
            pyautogui.press('f5')
        else:
            logger.info("Using FN+F5 key combination")
            pyautogui.keyDown('ctrl')
            time.sleep(0.3)
            pyautogui.press('f5')
            time.sleep(0.3)
            pyautogui.keyUp('ctrl')
        sleep_time = random.uniform(0.5, 2.0)
        time.sleep(sleep_time)
        logger.info(f"Page reloaded after {sleep_time:.2f}s delay")

    def _perform_sync_captcha_checks(self, driver: Driver):
        """Detect an on-site ('Verify Human') captcha and, if found, solve it visually via Gemini.

        Synchronous (blocking Selenium/pyautogui calls) — invoke from a worker thread.
        Returns silently when the main UI is ready; raises LmArenaError when a captcha
        is present but no Gemini client is configured.
        """
        logger.info("Checking for on-site ('Verify Human') captcha...")
        try:
            textarea_locator = (By.TAG_NAME, "textarea")
            textarea = WebDriverWait(driver, 2).until(EC.element_to_be_clickable(textarea_locator))
            # BUG FIX: this flag was initialized to True, which made the "no captcha"
            # fast path below unreachable and forced the AI solver on every startup.
            captcha_present = False
            # Give captcha widgets time to render before scanning for them.
            time.sleep(10)
            for selector in ['iframe[src*="api2/anchor"]', 'iframe[src*="recaptcha"]', '.g-recaptcha', '.h-captcha']:
                try:
                    if driver.find_element(By.CSS_SELECTOR, selector).is_displayed():
                        captcha_present = True
                        break
                except NoSuchElementException:
                    continue
            if not captcha_present:
                # Fallback: some captcha variants render plain text instead of an iframe.
                try:
                    if driver.find_element(By.XPATH, "//*[contains(text(), 'Verify you are human')]").is_displayed():
                        captcha_present = True
                except NoSuchElementException:
                    pass
            if textarea.is_enabled() and textarea.is_displayed() and not captcha_present:
                logger.info("No on-site captcha detected. Main UI is ready.")
                return
            else:
                try:
                    screenshot_b64 = driver.get_screenshot_as_base64()
                    logger.info(f"Captcha encountered screenshot (base64): {screenshot_b64}")
                except Exception as ss_e:
                    logger.error(f"Failed to capture screenshot on captcha detection: {ss_e}", exc_info=True)
                logger.info("Textarea not ready or an on-site captcha indicator was found. Proceeding with AI solver.")
        except (TimeoutException, NoSuchElementException):
            logger.info("Chat input textarea not interactable. Proceeding with AI captcha solver.")
        except Exception as e:
            logger.warning(f"Unexpected error checking UI state for on-site captcha: {e}", exc_info=True)
        if not self._genai_client:
            logger.error("On-site captcha detected, but Gemini client not available. Cannot proceed.")
            raise LmArenaError("AI Captcha solver is required but not configured.")
        try:
            logger.info("Starting visual AI check for on-site captcha.")
            screenshot = computer_control_helper.capture_screen()
            if not screenshot:
                logger.error("Failed to capture screen for AI captcha check.")
                return
            img_byte_arr = io.BytesIO()
            screenshot.save(img_byte_arr, format='PNG')
            contents = [
                types.Part.from_bytes(mime_type="image/png", data=img_byte_arr.getvalue()),
                """find the text "Verify you are human". do not give me the coordinates of the text itself - give me the coordinates of the small box to the LEFT of the text. Example response:
``json
[
{"box_2d": [504, 151, 541, 170], "label": "box"}
]
``
If you cannot find the checkbox, respond with "No checkbox found".
""",
            ]
            generate_content_config = types.GenerateContentConfig(response_mime_type="text/plain")
            logger.info("Sending screenshot to Gemini API for analysis.")
            response_stream = self._genai_client.models.generate_content_stream(
                model="gemini-2.0-flash",
                contents=contents,
                config=generate_content_config,
            )
            full_response_text = "".join(chunk.text for chunk in response_stream)
            logger.info(f"Received Gemini response for on-site captcha check: {full_response_text}")
            if "No checkbox found" in full_response_text:
                logger.info("Gemini indicated no checkbox found for on-site captcha.")
            else:
                parsed_data = computer_control_helper.parse_json_safely(full_response_text)
                click_target = None
                if isinstance(parsed_data, list) and parsed_data:
                    if isinstance(parsed_data[0], dict) and "box_2d" in parsed_data[0]:
                        click_target = parsed_data[0]
                elif isinstance(parsed_data, dict) and "box_2d" in parsed_data:
                    click_target = parsed_data
                if click_target:
                    logger.info(f"On-site captcha checkbox found via Gemini. Clicking coordinates: {click_target}")
                    computer_control_helper.perform_click(click_target)
                    time.sleep(3)
                    logger.info("Click performed. Now reloading page as requested for post-AI solve.")
                    self._human_like_reload(driver)
                    time.sleep(config.page_load_wait_after_refresh_seconds)
                else:
                    logger.info("No valid 'box_2d' data found in Gemini response. Reloading as fallback.")
                    self._human_like_reload(driver)
        except Exception as e:
            logger.error(f"An error occurred during AI visual captcha check: {e}", exc_info=True)

    async def cleanup(self) -> None:
        """Quit the driver (if any) on a worker thread and clear the reference."""
        async with self._lock:
            if self._driver:
                logger.info("Cleaning up and quitting Selenium driver...")
                loop = asyncio.get_running_loop()
                # Clear the reference first so no other coroutine hands out a dying driver.
                driver_to_quit = self._driver
                self._driver = None
                try:
                    await loop.run_in_executor(None, driver_to_quit.quit)
                    logger.info("Driver quit successfully.")
                except Exception as e:
                    logger.error(f"Error during driver cleanup: {e}", exc_info=True)

    def get_driver(self) -> Driver:
        """Return the live driver or raise DriverNotAvailableError if not initialized."""
        if self._driver is None:
            raise DriverNotAvailableError("Driver not available")
        return self._driver

    async def _select_model(self, model_id: str) -> None:
        """Pick a model in the page's combobox by typing its id into the search box.

        Raises:
            ModelSelectionError: if the dropdown/search UI cannot be driven.
        """
        driver = self.get_driver()
        logger.info(f"Selecting model: {model_id}")

        def _sync_select_model_logic(drv: Driver, m_id: str):
            try:
                dropdown_locator = (By.XPATH, "//button[@data-sentry-source-file='select-model.tsx' and @role='combobox']")
                WebDriverWait(drv, config.driver_timeout).until(EC.element_to_be_clickable(dropdown_locator)).click()
                search_locator = (By.XPATH, "//input[@placeholder='Search models' and @cmdk-input]")
                search_element = WebDriverWait(drv, config.driver_timeout).until(EC.visibility_of_element_located(search_locator))
                search_element.clear()
                search_element.send_keys(m_id)
                # ENTER commits the top search hit as the selected model.
                search_element.send_keys(Keys.ENTER)
                logger.info(f"Selected model: {m_id}")
            except (NoSuchElementException, TimeoutException) as e:
                raise ModelSelectionError(f"Failed to select model {m_id}. Original error: {type(e).__name__}") from e
            except Exception as e_sync:
                raise ModelSelectionError(f"Failed to select model {m_id}") from e_sync

        try:
            await asyncio.get_running_loop().run_in_executor(None, _sync_select_model_logic, driver, model_id)
        except ModelSelectionError:
            raise
        except Exception as e_exec:
            raise ModelSelectionError(f"Failed to select model {model_id} due to executor error: {e_exec}") from e_exec

    async def _retry_with_reload(self, driver: Driver, model_id: str):
        """Reload the page via keyboard and retry model selection once.

        NOTE(review): uses blocking time.sleep/pyautogui calls directly inside a
        coroutine, stalling the event loop for their duration — consider moving
        this sequence into run_in_executor.
        """
        try:
            pyautogui.press('f5')
            time.sleep(0.5)
            pyautogui.keyDown('ctrl')
            time.sleep(0.3)
            pyautogui.press('f5')
            time.sleep(0.3)
            pyautogui.keyUp('ctrl')
            WebDriverWait(driver, config.page_load_wait_after_refresh_seconds).until(EC.presence_of_element_located((By.XPATH, "//input[@placeholder='Search models' and @cmdk-input]")))
            await self._select_model(model_id)
        except Exception as reload_err:
            logger.error(f"Reload and retry failed: {reload_err}", exc_info=True)
            with mss.mss() as sct:
                # Grab one frame and reuse it — the original grabbed twice, risking a
                # size/pixel mismatch between the two frames.
                shot = sct.grab(sct.monitors[1])
                img = Image.frombytes("RGB", shot.size, shot.bgra, "raw", "BGRX")
            img_bytes = io.BytesIO()
            img.save(img_bytes, format="PNG")
            b64 = base64.b64encode(img_bytes.getvalue()).decode('utf-8')
            logger.error(f"Screenshot base64 after failed reload: {b64}")
            raise ModelSelectionError(f"Failed after reload attempt: {reload_err}") from reload_err

    def generate_reload_button_location(self, driver: Driver) -> str:
        """Ask Gemini to locate the page's reload button from a screenshot.

        Returns the raw model response text (expected to contain box_2d JSON), or
        "[]" when the Gemini client is unavailable or the call fails. Performs a
        manual keyboard reload as a fallback when the model finds no button.
        """
        logger.info("Generating reload button location with Gemini")
        # BUG FIX: _genai_client is None when GEMINI_API_KEY is unset; dereferencing
        # it below previously raised AttributeError.
        if not self._genai_client:
            logger.error("Gemini client not available; cannot locate reload button.")
            return "[]"
        try:
            with mss.mss() as sct:
                # Grab once; reuse the frame for both size and pixel data.
                shot = sct.grab(sct.monitors[1])
                img = Image.frombytes("RGB", shot.size, shot.bgra, "raw", "BGRX")
            buf = io.BytesIO()
            img.save(buf, format="PNG")
            img_bytes = buf.getvalue()
            contents = [
                types.Part.from_bytes(mime_type="image/png", data=img_bytes),
                """Find the reload button on the page. It might be labeled with words like "Reload", "Refresh", or have a circular arrow icon. Return the coordinates of the button in the following format:
``json
[
{"box_2d": [x1, y1, x2, y2], "label": "reload button"}
]
``
If you cannot find the reload button, respond with "No reload button found".
""",
            ]
            generate_content_config = types.GenerateContentConfig(response_mime_type="text/plain")
            response_stream = self._genai_client.models.generate_content_stream(model="gemini-2.0-flash", contents=contents, config=generate_content_config)
            full_response_text = "".join(chunk.text for chunk in response_stream)
            logger.info(f"Gemini response for reload button: {full_response_text}")
            if "No reload button found" in full_response_text:
                logger.info("AI did not find reload button, performing manual F5 and FN+F5 reloads.")
                try:
                    pyautogui.press('f5')
                    time.sleep(0.5)
                    pyautogui.keyDown('ctrl')
                    time.sleep(0.3)
                    pyautogui.press('f5')
                    time.sleep(0.3)
                    pyautogui.keyUp('ctrl')
                    time.sleep(0.5)
                    with mss.mss() as sct:
                        shot = sct.grab(sct.monitors[1])
                        img = Image.frombytes("RGB", shot.size, shot.bgra, "raw", "BGRX")
                    buf2 = io.BytesIO()
                    img.save(buf2, format="PNG")
                    b64_2 = base64.b64encode(buf2.getvalue()).decode('utf-8')
                    logger.info(f"Screenshot base64 after manual reload: {b64_2}")
                except Exception as manu_err:
                    logger.error(f"Manual reload simulation failed: {manu_err}", exc_info=True)
            return full_response_text
        except Exception as e:
            logger.error(f"Error generating reload button location: {e}", exc_info=True)
            return "[]"
driver_manager = DriverManager() | |
class ChatHandler:
    """Stateless namespace of helpers that drive one chat round-trip through the browser.

    All methods are now explicitly @staticmethod — they were previously plain
    functions inside the class body, which only worked because every caller
    accessed them via the class object (``ChatHandler.method(...)``).
    """

    @staticmethod
    async def send_message_and_stream_response(prompt: str, model_id: str, stream_raw_html: bool = False, convert_html_to_markdown: bool = True):
        """Select the model, submit the prompt, and yield response chunks as they render.

        Always attempts to click 'New Chat' afterwards so the page is reset for
        the next request, even when the interaction failed.

        Raises:
            ChatInteractionError: wrapping any failure in the interaction (HTTP 502).
        """
        driver = driver_manager.get_driver()
        request_id = str(uuid.uuid4())
        logger.info(f"[{request_id}] Starting chat. Model: '{model_id}', RawHTML: {stream_raw_html}, MarkdownMode: {convert_html_to_markdown}.")
        try:
            if model_id:
                await driver_manager._select_model(model_id)
            sanitized_prompt = ChatHandler._sanitize_for_bmp(prompt)
            logger.info(f"[{request_id}] Sending prompt (first 50 chars): '{sanitized_prompt[:50]}...'")
            await ChatHandler._send_prompt(driver, sanitized_prompt)
            await ChatHandler._handle_agreement_dialog(driver)
            logger.info(f"[{request_id}] Prompt sent. Streaming response...")
            async for chunk in ChatHandler._stream_response(driver, stream_raw_html, convert_html_to_markdown):
                yield chunk
            logger.info(f"[{request_id}] Finished streaming response from browser.")
        except Exception as e:
            logger.error(f"[{request_id}] Chat interaction failed: {e}", exc_info=True)
            raise ChatInteractionError(f"Chat interaction failed: {e}") from e
        finally:
            logger.info(f"[{request_id}] Cleaning up chat session by clicking 'New Chat'.")
            try:
                await ChatHandler._click_new_chat(driver, request_id)
            except Exception as e_cleanup:
                logger.error(f"[{request_id}] Error clicking 'New Chat' during cleanup: {e_cleanup}", exc_info=True)

    @staticmethod
    def _sanitize_for_bmp(text: str) -> str:
        """Strip characters above U+FFFF (outside the Basic Multilingual Plane);
        ChromeDriver typically cannot send non-BMP characters via send_keys."""
        return ''.join(c for c in text if ord(c) <= 0xFFFF)

    @staticmethod
    async def _send_prompt(driver: Driver, prompt: str):
        """Type the prompt into the page textarea; the trailing newline submits it."""
        logger.info("Typing prompt into textarea.")
        await asyncio.get_running_loop().run_in_executor(None, lambda: driver.type('textarea', prompt + "\n"))
        logger.info("Prompt submitted.")

    @staticmethod
    async def _handle_agreement_dialog(driver: Driver):
        """Dismiss the terms dialog if it appears; a no-op when it is absent."""
        logger.info("Checking for 'Agree' button in dialog.")
        if await asyncio.get_running_loop().run_in_executor(None, lambda: driver.click_if_visible("//button[normalize-space()='Agree']", timeout=1)):
            logger.info("'Agree' button found and clicked.")
        else:
            logger.info("'Agree' button not visible, skipping.")

    @staticmethod
    async def _stream_response(driver: Driver, stream_raw_html: bool, convert_html_to_markdown: bool) -> AsyncGenerator[str, None]:
        """Wait for the response container, then yield processed text chunks from it.

        Errors are reported in-band as a final text chunk rather than raised, so a
        partially-streamed reply is not discarded.
        """
        try:
            # Targets the newest conversation turn's content grid.
            content_container_locator = (By.XPATH, "(//ol[contains(@class, 'flex-col-reverse')]/div[.//h2[starts-with(@id, 'radix-')]])[1]//div[contains(@class, 'grid') and contains(@class, 'pt-4')]")
            WebDriverWait(driver, config.response_timeout).until(EC.presence_of_element_located(content_container_locator))
            stream_config = StreamConfig(
                poll_interval=config.poll_interval,
                response_timeout=config.response_timeout,
                stabilization_timeout=config.stabilization_timeout,
                max_inactivity=config.max_inactivity,
                stream_raw_html=stream_raw_html,
                convert_html_to_markdown=convert_html_to_markdown
            )
            stream_processor = StreamProcessor(config=stream_config)
            processed_stream_iterator = stream_processor.get_processed_text_stream(
                driver=driver,
                element_locator=content_container_locator
            )
            async for chunk in ChatHandler._sync_to_async(processed_stream_iterator):
                yield chunk
        except TimeoutException:
            logger.error("Streaming error: Timed out waiting for response container to appear.", exc_info=True)
            yield "\n\nError: Timed out waiting for response from the page."
        except Exception as e:
            logger.error(f"Streaming error: {e}", exc_info=True)
            yield f"\n\nError: {str(e)}"

    @staticmethod
    async def _sync_to_async(sync_iter: Iterator[str]) -> AsyncGenerator[str, None]:
        """Adapt a blocking iterator to async, yielding control between items.

        NOTE(review): next() on the underlying iterator still blocks the event
        loop; sleep(0) only lets other tasks run between items.
        """
        for item in sync_iter:
            yield item
            await asyncio.sleep(0)

    @staticmethod
    async def _click_new_chat(driver: Driver, request_id: str):
        """Click the 'New Chat' link to reset the page for the next request."""
        logger.info(f"[{request_id}] Attempting to click 'New Chat' button.")
        await asyncio.get_running_loop().run_in_executor(None, lambda: driver.click("//a[contains(@class, 'whitespace-nowrap') and .//h2[contains(text(), 'New Chat')]]"))
        logger.info(f"[{request_id}] 'New Chat' button clicked successfully.")
async def get_available_models() -> List[str]:
    """Open the model dropdown in the live page, scrape every model id, then close it."""
    driver = driver_manager.get_driver()

    def _scrape(browser: Driver) -> List[str]:
        # Runs on a worker thread: every Selenium call below is blocking.
        logger.info("Scraping available models...")
        dropdown_locator = (By.XPATH, "//button[@data-sentry-source-file='select-model.tsx' and @role='combobox']")
        model_item_locator = (By.XPATH, "//div[@cmdk-item and @data-value]")
        try:
            dropdown_button = WebDriverWait(browser, config.driver_timeout).until(EC.element_to_be_clickable(dropdown_locator))
            dropdown_button.click()
            logger.info("Model dropdown clicked.")
            WebDriverWait(browser, config.driver_timeout).until(EC.presence_of_all_elements_located(model_item_locator))
            time.sleep(0.5)
            found = []
            for item in browser.find_elements(*model_item_locator):
                value = item.get_attribute('data-value')
                if value:
                    found.append(value)
            logger.info(f"Found {len(found)} models.")
            dropdown_button.click()
            logger.info("Closed model dropdown.")
            return found
        except (TimeoutException, NoSuchElementException) as e:
            logger.error(f"Failed to scrape models: {e}", exc_info=True)
            # Best-effort: try to close the dropdown so the page is usable afterwards.
            try:
                browser.find_element(*dropdown_locator).click()
            except Exception as close_e:
                logger.warning(f"Could not close model dropdown after error: {close_e}")
            raise LmArenaError(f"Could not find or interact with the model dropdown: {e}") from e

    try:
        return await asyncio.get_event_loop().run_in_executor(None, _scrape, driver)
    except Exception as e:
        logger.error(f"Error executing model scraping in executor: {e}", exc_info=True)
        raise
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: start the Selenium driver on startup, quit it on shutdown.

    Bug fix: the @asynccontextmanager decorator (imported at the top of the file
    but unused) was missing — FastAPI expects the lifespan argument to be an async
    context-manager factory, and a bare async generator function is only accepted
    via a deprecated compatibility path.
    """
    logger.info("Application startup sequence initiated.")
    try:
        await driver_manager.initialize()
        logger.info("Application startup sequence completed successfully.")
    except Exception as e:
        logger.critical(f"A critical error occurred during application startup: {e}", exc_info=True)
        await driver_manager.cleanup()
        raise
    yield
    logger.info("Application shutdown sequence initiated.")
    await driver_manager.cleanup()
    logger.info("Application shutdown sequence completed.")
app = FastAPI(lifespan=lifespan) | |
# Bug fix: no route decorator was present, so this endpoint was never registered.
@app.get("/health")
async def health_check():
    """Liveness probe: reports whether the Selenium driver is initialized."""
    try:
        driver_manager.get_driver()
        return {"status": "healthy", "driver": "available"}
    except DriverNotAvailableError:
        return {"status": "unhealthy", "driver": "unavailable"}
# Bug fix: no route decorator was present, so this endpoint was never registered.
# Registered under both paths: the log text references "/models", while
# OpenAI-compatible clients expect "/v1/models".
@app.get("/models")
@app.get("/v1/models")
async def list_models():
    """List model ids scraped from the page, retrying once after captcha resolution."""
    logger.info("Received request for /models endpoint.")
    try:
        model_ids = await get_available_models()
        return ModelListResponse(data=[ModelInfo(id=model_id) for model_id in model_ids])
    except DriverNotAvailableError:
        raise HTTPException(status_code=503, detail="Service unavailable: The backend driver is not ready.")
    except Exception as e:
        logger.error(f"An unexpected error occurred while fetching models: {e}", exc_info=True)
        # Attempt captcha resolution then retry
        try:
            drv = driver_manager.get_driver()
            # Run the blocking Selenium/captcha work on a thread so the event loop
            # stays responsive (the original called it inline and blocked).
            await asyncio.get_event_loop().run_in_executor(None, driver_manager._perform_sync_captcha_checks, drv)
            logger.info("Retrying model fetch after captcha checks...")
            model_ids = await get_available_models()
            return ModelListResponse(data=[ModelInfo(id=mid) for mid in model_ids])
        except DriverNotAvailableError:
            raise HTTPException(status_code=503, detail="Service unavailable: The backend driver is not ready.")
        except Exception as retry_e:
            logger.error(f"Retry after captcha checks failed: {retry_e}", exc_info=True)
            raise HTTPException(status_code=500, detail=f"An unexpected error occurred while fetching models: {str(retry_e)}")
# Bug fix: no route decorator was present, so this endpoint was never registered.
# Registered under both the bare and /v1 prefixes for OpenAI-client compatibility.
@app.post("/chat/completions")
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """OpenAI-compatible chat completion endpoint driven by the scraped browser session."""
    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
    created_timestamp = int(time.time())
    logger.info(f"[{completion_id}] Received chat completion request: model='{request.model}', stream={request.stream}, md_convert={request.convert_html_to_markdown}")

    def _content_text(content) -> str:
        # Flatten list-style content parts to text; previously messages with
        # non-string content were silently dropped from the prompt.
        if isinstance(content, str):
            return content
        return "\n".join(part.get("text", "") for part in content if isinstance(part, dict))

    full_prompt = "\n".join(_content_text(msg.content) for msg in request.messages)
    try:
        driver_manager.get_driver()  # fail fast with 503 if the browser is not up
        send_message_func = lambda p, m: ChatHandler.send_message_and_stream_response(
            prompt=p, model_id=m, stream_raw_html=request.stream_raw_html, convert_html_to_markdown=request.convert_html_to_markdown
        )
        if request.stream:
            return StreamingResponse(
                create_streaming_response(
                    completion_id=completion_id, created=created_timestamp, model=request.model, prompt=full_prompt, send_message_func=send_message_func
                ), media_type="text/event-stream"
            )
        else:
            return await _create_non_streaming_response(
                completion_id, created_timestamp, request.model, full_prompt, request.convert_html_to_markdown
            )
    except DriverNotAvailableError:
        raise HTTPException(status_code=503, detail="Service unavailable: The backend driver is not ready.")
    except APIError as e:
        raise HTTPException(status_code=e.status_code, detail=e.message)
    except HTTPException:
        # Don't re-wrap an already-built HTTP error (e.g. from the non-streaming helper).
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An unexpected processing error occurred: {str(e)}")
async def _create_non_streaming_response(completion_id: str, created: int, model: str, prompt: str, convert_html_to_markdown: bool) -> ChatCompletionResponse:
    """Collect the whole browser-streamed answer into one completion payload.

    Raises:
        HTTPException(500): for unexpected failures; APIError subclasses (such as
        ChatInteractionError) propagate so the caller can map their status codes.
    """
    try:
        content_parts = [chunk async for chunk in ChatHandler.send_message_and_stream_response(prompt, model, stream_raw_html=False, convert_html_to_markdown=convert_html_to_markdown)]
        final_content = "".join(content_parts)
        return ChatCompletionResponse(
            id=completion_id, object="chat.completion", created=created, model=model,
            choices=[Choice(index=0, message={"role": "assistant", "content": final_content}, finish_reason="stop")],
            # Token counts are unknown when scraping a web UI; report zeros.
            usage=Usage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
        )
    except APIError:
        # Bug fix: the generic handler below used to swallow ChatInteractionError,
        # turning the caller's intended 502 into a 500.
        raise
    except Exception as e:
        logger.error(f"[{completion_id}] Exception during non-streaming response creation: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Error processing non-streaming request.")
if __name__ == "__main__":
    import uvicorn

    # Captcha solving depends on Gemini; warn loudly when the key is absent,
    # but keep serving (the rest of the proxy still works).
    key_present = bool(os.getenv("GEMINI_API_KEY"))
    if key_present:
        logger.info("GEMINI_API_KEY is set.")
    else:
        logger.error("FATAL: GEMINI_API_KEY environment variable not set. Captcha solving will be disabled.")
    logger.info("Starting Uvicorn server on 0.0.0.0:8000.")
    uvicorn.run(app, host="0.0.0.0", port=8000)