Spaces:
Sleeping
Sleeping
File size: 9,132 Bytes
d067c3d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 |
from fastapi import FastAPI, Request, Response, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.concurrency import run_in_threadpool
from gradio_client import Client
import gradio as gr
import uvicorn
import httpx
import websockets
import asyncio
from urllib.parse import urljoin, urlparse, unquote
import logging
# Route INFO-level (and higher) records through the root handler, then take
# this module's own named logger for the proxy's messages.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)

# ASGI application object; the routes below are registered on it.
app = FastAPI()
class HttpClient:
    """Async reverse-proxy helper wrapping one shared ``httpx.AsyncClient``.

    The instance owns a single connection pool; call ``close()`` to release it.
    """

    # Hop-by-hop headers (RFC 7230 section 6.1) only describe a single
    # transport hop, so a proxy must not relay them in either direction.
    _HOP_BY_HOP_HEADERS = frozenset({
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "trailers",
        "transfer-encoding",
        "upgrade",
    })
    # "host" must describe the upstream, which httpx derives from the URL.
    # "accept-encoding" is dropped so httpx advertises only encodings it can
    # decode itself (the body is relayed decoded, see below).
    _STRIPPED_REQUEST_HEADERS = _HOP_BY_HOP_HEADERS | {"host", "accept-encoding"}
    # httpx's ``Response.content`` is the *decoded* body, so the upstream's
    # content-encoding/content-length no longer describe it; Starlette
    # recomputes content-length from the body actually sent.
    _STRIPPED_RESPONSE_HEADERS = _HOP_BY_HOP_HEADERS | {"content-encoding", "content-length"}

    def __init__(self):
        # 30 s timeout; redirects are handed back to the caller, not followed.
        self.client = httpx.AsyncClient(
            timeout=httpx.Timeout(30.0),
            follow_redirects=False,
        )

    async def forward_request(self, request: Request, target_url: str):
        """Forward *request* to *target_url* and relay the upstream response.

        Args:
            request: Incoming request; its method, full body and headers
                (minus hop-by-hop, ``host`` and ``accept-encoding``) are reused.
            target_url: Absolute URL to send the request to.

        Returns:
            A ``Response`` mirroring the upstream status, headers and body.
            Failures are reported as a plain-text ``Response`` rather than
            raised: 504 on timeout, 502 on network error, 500 otherwise.
        """
        try:
            method = request.method
            # A list of pairs (not a dict) keeps repeated headers intact.
            headers = [
                (key, value)
                for key, value in request.headers.items()
                if key.lower() not in self._STRIPPED_REQUEST_HEADERS
            ]
            body = await request.body()
            logger.info("Forwarding %s request to %s", method, target_url)
            upstream = await self.client.request(
                method=method,
                url=target_url,
                headers=headers,
                content=body,
            )
            return self._relay_response(upstream, method)
        except httpx.TimeoutException:
            logger.error("Timeout error while forwarding request to %s", target_url)
            return Response(content="Request timeout error", status_code=504)
        except httpx.NetworkError as e:
            logger.error("Network error while forwarding request: %s", e)
            return Response(content=f"Network error: {str(e)}", status_code=502)
        except Exception as e:
            # Unexpected failure: keep the traceback in the log.
            logger.exception("Error forwarding request: %s", e)
            return Response(content=f"Request error: {str(e)}", status_code=500)

    def _relay_response(self, upstream, method):
        """Convert an ``httpx.Response`` into a Starlette ``Response``."""
        response = Response(content=upstream.content, status_code=upstream.status_code)
        # multi_items() yields repeated headers (e.g. several Set-Cookie)
        # separately; a plain dict would merge them into one broken value.
        for key, value in upstream.headers.multi_items():
            if key.lower() not in self._STRIPPED_RESPONSE_HEADERS:
                response.headers.append(key, value)
        # A HEAD reply has no body to measure, so keep the upstream length when
        # it describes the unencoded body a GET through this proxy would return.
        if (
            method == "HEAD"
            and "content-length" in upstream.headers
            and "content-encoding" not in upstream.headers
        ):
            response.headers["content-length"] = upstream.headers["content-length"]
        return response

    async def close(self):
        """Close the underlying ``httpx.AsyncClient`` and its connection pool."""
        await self.client.aclose()
# Module-wide proxy client used by the /wp/ HTTP route; its connection pool
# is released by the "shutdown" event handler.
http_client: HttpClient = HttpClient()
@app.get("/", response_class=HTMLResponse)
async def read_root():
    """Serve ``index.html`` (resolved against the working directory) as HTML.

    The file is read as UTF-8 in the threadpool so the blocking disk read
    does not stall the event loop.
    """
    def _read_index():
        # Explicit encoding: the platform default is not guaranteed to be UTF-8.
        with open("index.html", encoding="utf-8") as f:
            return f.read()

    return await run_in_threadpool(_read_index)
@app.post("/gp/{repo_id:path}/{api_name:path}")
async def gradio_client(repo_id: str, api_name: str, request: Request):
    """Call ``/<api_name>`` on the Gradio app ``repo_id`` and return its result.

    The request body must be a JSON object with an ``"args"`` list; its items
    are passed positionally to ``Client.predict``.

    Returns:
        The ``predict`` result, or a 400 ``Response`` when the body is not
        valid JSON or lacks an ``"args"`` list.
    """
    try:
        data = await request.json()
    except ValueError:  # json.JSONDecodeError subclasses ValueError
        return Response(content="Request body must be valid JSON", status_code=400)
    args = data.get("args") if isinstance(data, dict) else None
    if not isinstance(args, list):
        return Response(
            content='Request body must be a JSON object with an "args" list',
            status_code=400,
        )
    # gradio_client.Client is synchronous and contacts the remote app while
    # being constructed, so build it in the threadpool just like predict().
    client = await run_in_threadpool(Client, repo_id)
    return await run_in_threadpool(client.predict, *args, api_name=f"/{api_name}")
# Prefixes that mark the route path itself as a complete target URL.
_HTTP_SCHEMES = ("http://", "https://")


def _is_absolute_http_url(url):
    """Return True if *url* parses as an absolute http(s) URL with a host."""
    try:
        parsed = urlparse(url)
    except ValueError:  # e.g. a malformed IPv6 literal
        return False
    return parsed.scheme in ("http", "https") and bool(parsed.netloc)


def _resolve_http_target(path, header_url, query):
    """Build the upstream URL for a ``/wp/`` request.

    Args:
        path: The ``{path}`` route parameter.
        header_url: Value of the ``X-Target-Url`` header, or None.
        query: Raw query string of the incoming request ("" if none).

    Returns:
        The absolute http(s) URL to forward to.

    Raises:
        ValueError: With a client-facing message when no valid URL can be built.
    """
    if path.lower().startswith(_HTTP_SCHEMES):
        # A full URL in the path takes priority over the header.
        target_url = path
    elif header_url:
        if not _is_absolute_http_url(header_url):
            raise ValueError("Invalid X-Target-Url header")
        target_url = header_url
        if path:
            # Treat the header as a directory base and append the route path.
            target_url = urljoin(header_url.rstrip('/') + '/', path.lstrip('/'))
    else:
        raise ValueError("Missing X-Target-Url header or URL in path")

    if not _is_absolute_http_url(target_url):
        raise ValueError("Invalid X-Target-Url header or URL in path")

    if query:
        # The route parameter never carries the query string, so re-attach it,
        # merging with any query already present on the target.
        parts = urlparse(target_url)
        merged = f"{parts.query}&{query}" if parts.query else query
        target_url = parts._replace(query=merged).geturl()
    return target_url


@app.api_route("/wp/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"])
async def web_client(request: Request, path: str):
    """Forward any HTTP request to an upstream URL.

    The target is the route path when it starts with ``http://``/``https://``;
    otherwise it is the ``X-Target-Url`` header with the path appended. The
    incoming query string is forwarded. Responds 400 when no valid http(s)
    target can be determined.
    """
    # NOTE(review): this forwards to any host the caller names (open proxy /
    # SSRF risk if internal services are reachable); add an allow-list if needed.
    try:
        target_url = _resolve_http_target(
            path, request.headers.get("X-Target-Url"), request.url.query
        )
    except ValueError as exc:
        return Response(content=str(exc), status_code=400)
    return await http_client.forward_request(request, target_url)
# Prefixes that mark the (decoded) route path itself as a WebSocket URL.
_WS_SCHEMES = ("ws://", "wss://")


def _to_ws_url(url):
    """Rewrite an http(s):// URL to ws(s)://; any other URL is returned unchanged."""
    lowered = url.lower()
    if lowered.startswith("https://"):
        return "wss://" + url[len("https://"):]
    if lowered.startswith("http://"):
        return "ws://" + url[len("http://"):]
    return url


async def _pump_client_to_upstream(websocket, upstream):
    """Relay text and binary frames from the client to the upstream socket."""
    try:
        while True:
            message = await websocket.receive()
            if message["type"] == "websocket.disconnect":
                return
            if message.get("text") is not None:
                await upstream.send(message["text"])
            elif message.get("bytes") is not None:
                await upstream.send(message["bytes"])
    except websockets.ConnectionClosed:
        pass


async def _pump_upstream_to_client(websocket, upstream):
    """Relay text and binary frames from the upstream socket to the client."""
    try:
        while True:
            message = await upstream.recv()
            if isinstance(message, bytes):
                await websocket.send_bytes(message)
            else:
                await websocket.send_text(message)
    except websockets.ConnectionClosed:
        pass


async def _relay(websocket, upstream):
    """Pump frames in both directions until either direction finishes.

    As soon as one side stops, the other pump is cancelled so a half-closed
    pair of sockets is not kept alive. Unexpected pump errors are logged.
    """
    tasks = [
        asyncio.create_task(_pump_client_to_upstream(websocket, upstream)),
        asyncio.create_task(_pump_upstream_to_client(websocket, upstream)),
    ]
    try:
        await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    finally:
        for task in tasks:
            task.cancel()
        # Let the cancelled pump unwind; collect outcomes instead of raising.
        results = await asyncio.gather(*tasks, return_exceptions=True)
    for result in results:
        if isinstance(result, Exception):
            logger.warning("WebSocket relay direction failed: %r", result)


@app.websocket("/wp/{path:path}")
async def websocket_client(websocket: WebSocket, path: str):
    """Proxy a WebSocket connection to an upstream WebSocket server.

    The target is the ``X-Target-Url`` header (with the route path appended),
    or, without the header, the route path itself when it (percent-decoded)
    starts with ``ws://``/``wss://``. http(s) targets are rewritten to ws(s),
    and the incoming query string is forwarded. Invalid targets are rejected
    with close code 1008 before the handshake is accepted.
    """
    header_url = websocket.headers.get("X-Target-Url")
    # Percent-decode so an encoded ws%3A%2F%2F... path is recognised as a URL.
    decoded_path = unquote(path) if path and '%' in path else path
    path_is_url = bool(decoded_path) and decoded_path.lower().startswith(_WS_SCHEMES)

    if header_url:
        target_url = header_url
    elif path_is_url:
        target_url = decoded_path
    else:
        await websocket.close(code=1008, reason="Missing X-Target-Url header or invalid URL in path")
        return

    try:
        parsed_url = urlparse(target_url)
    except ValueError:
        parsed_url = None
    if parsed_url is None or not parsed_url.scheme or not parsed_url.netloc:
        await websocket.close(code=1008, reason="Invalid target URL")
        return

    await websocket.accept()

    ws_target_url = _to_ws_url(target_url)
    # Only a header-supplied base gets the route path appended; when the path
    # *is* the target URL there is nothing left to join.
    if header_url and path and not path_is_url:
        ws_target_url = urljoin(ws_target_url.rstrip('/') + '/', path.lstrip('/'))

    query = websocket.url.query
    if query:
        parts = urlparse(ws_target_url)
        merged = f"{parts.query}&{query}" if parts.query else query
        ws_target_url = parts._replace(query=merged).geturl()

    close_code, close_reason = 1000, None
    try:
        async with websockets.connect(ws_target_url) as upstream:
            await _relay(websocket, upstream)
    except websockets.InvalidURI:
        close_code, close_reason = 1008, "Invalid WebSocket URL"
    except websockets.InvalidHandshake:
        close_code, close_reason = 1008, "WebSocket handshake failed"
    except Exception as e:
        logger.exception("Error in WebSocket connection: %s", e)
        close_code, close_reason = 1011, "Internal server error"
    finally:
        try:
            await websocket.close(code=close_code, reason=close_reason)
        except Exception:
            # Client already gone or socket already closed: nothing to tell it.
            logger.debug("WebSocket already closed", exc_info=True)
@app.get("/health")
async def health_check():
    """Liveness probe: always reports that the service is up."""
    return dict(status="ok")
# NOTE(review): `on_event` is deprecated in recent FastAPI releases in favour
# of lifespan handlers; migrating would also touch the `app = FastAPI()` line.
@app.on_event("shutdown")
async def shutdown_event():
    """Close the shared ``http_client`` when the application shuts down.

    This awaits ``HttpClient.close()``, which closes its ``httpx.AsyncClient``.
    """
    await http_client.close()
if __name__ == "__main__":
    # When run as a script, serve on every interface at port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")
|