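# NOTE: the three lines that stood here ("Spaces:", "Sleeping", "Sleeping") are
# page-status text captured together with the code; they are not part of the
# program and are kept only as this comment.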
import asyncio
import logging
from urllib.parse import urljoin, urlparse, unquote

import gradio as gr
import httpx
import uvicorn
import websockets
from fastapi import FastAPI, Request, Response, WebSocket, WebSocketDisconnect
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import HTMLResponse
from gradio_client import Client
# Configure the root logger at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The FastAPI application object served by uvicorn at the bottom of the file.
app = FastAPI()
class HttpClient:
    """Relay incoming requests to an upstream URL over one shared ``httpx.AsyncClient``.

    The underlying client uses a 30-second timeout and does not follow
    redirects, so a 3xx answer from upstream is relayed to the caller as-is.
    """

    # Request headers that are not copied upstream. ``accept-encoding`` is
    # dropped so that httpx advertises only the encodings it can decode itself;
    # otherwise upstream could answer with a body this proxy cannot decode.
    _SKIPPED_REQUEST_HEADERS = ("host", "connection", "accept-encoding")

    # Response headers that are never copied back to the caller.
    _SKIPPED_RESPONSE_HEADERS = frozenset({"connection", "transfer-encoding"})

    # Response headers that describe the compressed wire form of the body.
    # ``response.content`` is already decoded by httpx, so these no longer
    # match the bytes that are relayed and must be dropped for encoded bodies.
    _ENCODING_RESPONSE_HEADERS = frozenset({"content-encoding", "content-length"})

    def __init__(self):
        """Create the shared async client (30 s timeout, redirects not followed)."""
        self.client = httpx.AsyncClient(
            timeout=httpx.Timeout(30.0),
            follow_redirects=False,
        )

    async def forward_request(self, request: Request, target_url: str):
        """Send ``request`` to ``target_url`` and return the upstream answer.

        Args:
            request: The incoming request; its method, headers and body are
                replayed against the target.
            target_url: Absolute URL the request is forwarded to.

        Returns:
            A ``Response`` carrying the upstream status, body and headers.
            Failures are reported as responses instead of being raised:
            504 on an httpx timeout, 502 on an httpx network error and 500
            for any other exception.
        """
        try:
            method = request.method
            headers = dict(request.headers)
            for name in self._SKIPPED_REQUEST_HEADERS:
                headers.pop(name, None)
            body = await request.body()

            logger.info("Forwarding %s request to %s", method, target_url)
            upstream = await self.client.request(
                method=method,
                url=target_url,
                headers=headers,
                content=body,
            )

            skipped = set(self._SKIPPED_RESPONSE_HEADERS)
            if "content-encoding" in upstream.headers:
                # The relayed body is the decoded one, so the upstream
                # encoding and length would describe different bytes.
                skipped |= self._ENCODING_RESPONSE_HEADERS

            proxied = Response(
                content=upstream.content,
                status_code=upstream.status_code,
            )
            # Copy headers pair by pair so repeated headers (e.g. several
            # Set-Cookie lines) stay separate instead of being comma-joined.
            for name, value in upstream.headers.multi_items():
                lowered = name.lower()
                if lowered in skipped:
                    continue
                if lowered == "content-length":
                    # Replace, rather than duplicate, the length that the
                    # Response constructor may already have filled in.
                    proxied.headers["content-length"] = value
                else:
                    proxied.headers.append(name, value)
            return proxied
        except httpx.TimeoutException:
            logger.error(
                "Timeout error while forwarding request to %s", target_url
            )
            return Response(
                content="Request timeout error",
                status_code=504,
            )
        except httpx.NetworkError as e:
            logger.error("Network error while forwarding request: %s", e)
            return Response(
                content=f"Network error: {str(e)}",
                status_code=502,
            )
        except Exception as e:
            # Boundary handler: log with traceback, answer with a 500.
            logger.exception("Error forwarding request: %s", e)
            return Response(
                content=f"Request error: {str(e)}",
                status_code=500,
            )

    async def close(self):
        """Close the underlying ``httpx.AsyncClient``."""
        await self.client.aclose()
# Module-level HttpClient instance shared by the request handlers below.
http_client = HttpClient()
async def read_root():
    """Return the text of ``index.html`` from the current working directory.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the locale of the host the server runs on.

    Raises:
        FileNotFoundError: If ``index.html`` does not exist.
    """
    # NOTE(review): no route decorator is visible on this function in this
    # view -- confirm how it is registered with ``app``.
    with open("index.html", encoding="utf-8") as f:
        return f.read()
async def gradio_client(repo_id: str, api_name: str, request: Request):
    """Call ``/{api_name}`` on the Gradio app ``repo_id`` and return its result.

    The request body must be a JSON object whose ``"args"`` entry is a list;
    its items are passed as positional arguments to ``Client.predict``.

    Args:
        repo_id: Identifier handed to ``gradio_client.Client``.
        api_name: Endpoint name; a leading ``/`` is added before the call.
        request: Incoming request carrying the JSON body.

    Returns:
        Whatever ``Client.predict`` returns, or a 400 ``Response`` when the
        body is not valid JSON or lacks an ``"args"`` list.
    """
    try:
        data = await request.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
        return Response(
            content="Request body must be valid JSON",
            status_code=400,
        )

    args = data.get("args") if isinstance(data, dict) else None
    if not isinstance(args, list):
        return Response(
            content='Request body must be a JSON object with an "args" list',
            status_code=400,
        )

    # Both constructing the client and predicting are synchronous calls, so
    # both run in the thread pool to keep the event loop free.
    client = await run_in_threadpool(Client, repo_id)
    return await run_in_threadpool(
        client.predict, *args, api_name=f"/{api_name}"
    )
async def web_client(request: Request, path: str):
    """Forward the request to a target URL taken from the path or a header.

    Target resolution:
      1. If ``path`` itself starts with ``http://`` or ``https://`` it is the
         target.
      2. Otherwise the ``X-Target-Url`` header is the base; a non-empty
         ``path`` is joined onto it.
    The incoming query string, if any, is appended to the resolved target.

    Args:
        request: Incoming request to relay.
        path: Path portion captured by the route.

    Returns:
        The relayed upstream response, or a 400 ``Response`` when no valid
        absolute target URL can be determined.
    """
    # NOTE(review): the target is fully caller-controlled, so this endpoint
    # can reach any host the server can reach -- confirm that is intended.

    def has_scheme_and_host(url: str) -> bool:
        """Return True when ``url`` parses with both a scheme and a netloc."""
        try:
            parts = urlparse(url)
        except ValueError:
            return False
        return bool(parts.scheme and parts.netloc)

    if path.lower().startswith(("http://", "https://")):
        target_url = path
    else:
        target_url = request.headers.get("X-Target-Url")
        if target_url and path:
            if not has_scheme_and_host(target_url):
                return Response(
                    content="Invalid X-Target-Url header",
                    status_code=400,
                )
            # Exactly one slash between the header base and the path.
            target_url = urljoin(target_url.rstrip("/") + "/", path.lstrip("/"))

    if not target_url:
        return Response(
            content="Missing X-Target-Url header or URL in path",
            status_code=400,
        )

    if not has_scheme_and_host(target_url):
        return Response(
            content="Invalid X-Target-Url header or URL in path",
            status_code=400,
        )

    # The route's path parameter never contains the query string, so it has to
    # be re-attached explicitly or upstream would not see it.
    query = request.url.query
    if query:
        parts = urlparse(target_url)
        merged = f"{parts.query}&{query}" if parts.query else query
        target_url = parts._replace(query=merged).geturl()

    return await http_client.forward_request(request, target_url)
async def websocket_client(websocket: WebSocket, path: str):
    """Relay a WebSocket connection to a target server in both directions.

    The target comes from the ``X-Target-Url`` header; without the header the
    (URL-decoded) ``path`` must itself be a ``ws://`` or ``wss://`` URL.
    ``http(s)://`` targets are rewritten to ``ws(s)://``. When the target came
    from the header, a ``path`` that is not itself a WebSocket URL is joined
    onto it. The incoming query string, if any, is appended to the target.

    Text and binary frames are both relayed. The relay ends as soon as either
    side closes, and the client socket is then closed exactly once:
    1000 on normal completion, 1008 for an invalid target URI or a failed
    handshake, 1011 for any other error.

    Args:
        websocket: The incoming client connection.
        path: Path portion captured by the route.
    """
    ws_schemes = ("ws://", "wss://")

    # Decode once and use the decoded form for every "is the path a full
    # WebSocket URL?" decision, so an encoded URL is never appended to itself.
    decoded_path = unquote(path) if path and "%" in path else path
    path_is_ws_url = bool(decoded_path) and decoded_path.lower().startswith(ws_schemes)

    target_url = websocket.headers.get("X-Target-Url")
    if not target_url:
        if not path_is_ws_url:
            await websocket.close(
                code=1008,
                reason="Missing X-Target-Url header or invalid URL in path",
            )
            return
        target_url = decoded_path

    try:
        parsed_url = urlparse(target_url)
        target_is_valid = bool(parsed_url.scheme and parsed_url.netloc)
    except ValueError:
        target_is_valid = False
    if not target_is_valid:
        await websocket.close(code=1008, reason="Invalid target URL")
        return

    await websocket.accept()

    # Map http(s) targets onto the matching WebSocket scheme.
    lowered_target = target_url.lower()
    if lowered_target.startswith("https://"):
        ws_target_url = "wss://" + target_url[8:]
    elif lowered_target.startswith("http://"):
        ws_target_url = "ws://" + target_url[7:]
    else:
        ws_target_url = target_url

    if path and not path_is_ws_url:
        ws_target_url = urljoin(ws_target_url.rstrip("/") + "/", path.lstrip("/"))

    # The route's path parameter does not carry the query string.
    query = websocket.url.query
    if query:
        parts = urlparse(ws_target_url)
        merged = f"{parts.query}&{query}" if parts.query else query
        ws_target_url = parts._replace(query=merged).geturl()

    close_code = 1000
    close_reason = None
    try:
        async with websockets.connect(ws_target_url) as target_ws:

            async def forward_client_to_target():
                """Pump frames from the client to the target until disconnect."""
                try:
                    while True:
                        message = await websocket.receive()
                        if message["type"] == "websocket.disconnect":
                            return
                        payload = message.get("text")
                        if payload is None:
                            payload = message.get("bytes")
                        if payload is not None:
                            await target_ws.send(payload)
                except (WebSocketDisconnect, websockets.ConnectionClosed):
                    pass

            async def forward_target_to_client():
                """Pump frames from the target to the client until it closes."""
                try:
                    while True:
                        data = await target_ws.recv()
                        if isinstance(data, (bytes, bytearray)):
                            await websocket.send_bytes(bytes(data))
                        else:
                            await websocket.send_text(data)
                except (WebSocketDisconnect, websockets.ConnectionClosed):
                    pass

            relays = [
                asyncio.create_task(forward_client_to_target()),
                asyncio.create_task(forward_target_to_client()),
            ]
            try:
                # Stop when EITHER direction ends; waiting for both would
                # leave the surviving direction blocked on a dead peer.
                await asyncio.wait(relays, return_when=asyncio.FIRST_COMPLETED)
            finally:
                for relay in relays:
                    relay.cancel()  # no-op for a relay that already finished
                outcomes = await asyncio.gather(*relays, return_exceptions=True)
            for outcome in outcomes:
                # CancelledError is not an Exception subclass, so cancelled
                # relays are skipped here.
                if isinstance(outcome, Exception):
                    logger.error("WebSocket relay stopped with an error: %r", outcome)
    except websockets.InvalidURI:
        close_code, close_reason = 1008, "Invalid WebSocket URL"
    except websockets.InvalidHandshake:
        close_code, close_reason = 1008, "WebSocket handshake failed"
    except Exception as e:
        logger.error(f"Error in WebSocket connection: {str(e)}")
        close_code, close_reason = 1011, "Internal server error"
    finally:
        # Close exactly once. Best effort: the client may already be gone.
        try:
            await websocket.close(code=close_code, reason=close_reason)
        except Exception as close_error:
            logger.debug("Client WebSocket close skipped: %r", close_error)
async def health_check():
    """Report liveness as a constant ``{"status": "ok"}`` mapping."""
    payload = dict(status="ok")
    return payload
async def shutdown_event():
    """Close the module-level ``http_client``."""
    # NOTE(review): presumably registered as an application shutdown hook; no
    # decorator is visible in this view -- confirm.
    pending_close = http_client.close()
    await pending_close
if __name__ == "__main__":
    # When run as a script, serve the app with uvicorn on all interfaces.
    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)